Matlab 代码:
function varargout = gmm(X, K_or_centroids)
% GMM  Fit a Gaussian mixture model to data with the EM algorithm.
%
%   Px = gmm(X, K_or_centroids)
%   [Px, model] = gmm(X, K_or_centroids)
%
%   Inputs:
%     X              - N-by-D data matrix, one sample per row.
%     K_or_centroids - either a scalar K (K rows of X are picked at random
%                      as initial centroids) or a K-by-D matrix of initial
%                      centroids.
%
%   Outputs:
%     Px    - N-by-K matrix; Px(n,k) is the Gaussian density of sample n
%             under component k, evaluated with the returned parameters.
%     model - struct with fields
%               Miu   : K-by-D component means
%               Sigma : D-by-D-by-K component covariances
%               Pi    : 1-by-K mixing weights
%
%   NOTE: the initial covariance of each component is cov() of the samples
%   nearest to its centroid; a component that starts with fewer than two
%   samples therefore gets a degenerate initial covariance.

    % Convergence threshold; also used as a tiny ridge on the covariances.
    threshold = 1e-15;

    % Data dimensions.
    [N, D] = size(X);

    if isscalar(K_or_centroids)
        % Scalar: choose K distinct samples at random as the centroids.
        K = K_or_centroids;
        rnpm = randperm(N);            % random permutation of 1..N
        centroids = X(rnpm(1:K), :);
    else
        % Matrix: one initial centroid per row.
        K = size(K_or_centroids, 1);
        centroids = K_or_centroids;
    end

    % Initial model parameters.
    [pMiu, pPi, pSigma] = init_params();

    Lprev = -inf;
    while true
        % E-step: component densities, Px is N-by-K.
        Px = calc_prob();

        % Joint weight of each (sample, component): pi_k * N(x_n | k).
        % (Multiplying by pPi -- not dividing -- is what Bayes' rule needs.)
        weighted = Px .* repmat(pPi, N, 1);
        mixDensity = sum(weighted, 2);          % N-by-1 mixture density

        % Log-likelihood under the CURRENT parameters. Checking it here
        % keeps Px and the returned model consistent with each other.
        L = sum(log(mixDensity));
        % Written as ~(... >= ...) so that a NaN difference (e.g. -inf
        % minus -inf) terminates the loop instead of spinning forever.
        if ~(L - Lprev >= threshold)
            break;
        end
        Lprev = L;

        % Responsibilities: normalise every row to sum to one.
        pGamma = weighted ./ repmat(mixDensity, 1, K);

        % M-step: re-estimate every component.
        Nk = sum(pGamma, 1);                    % effective sample counts
        pMiu = diag(1 ./ Nk) * pGamma' * X;
        pPi = Nk / N;
        for kk = 1:K
            Xshift = X - repmat(pMiu(kk, :), N, 1);
            % Scale rows by the responsibilities without materialising an
            % N-by-N diagonal matrix.
            pSigma(:, :, kk) = (Xshift' * (repmat(pGamma(:, kk), 1, D) .* Xshift)) / Nk(kk);
        end
    end

    % Assemble outputs according to the number requested.
    if nargout == 1
        varargout = {Px};
    else
        model = [];
        model.Miu = pMiu;
        model.Sigma = pSigma;
        model.Pi = pPi;
        varargout = {Px, model};
    end

    function [pMiu, pPi, pSigma] = init_params()
        % Initialise means, weights and covariances from a hard
        % nearest-centroid assignment of the samples.
        pMiu = centroids;              % K-by-D means
        pPi = zeros(1, K);             % mixing weights
        pSigma = zeros(D, D, K);       % one D-by-D covariance per component

        % Squared distances: |x - mu|^2 = |x|^2 + |mu|^2 - 2*x*mu'
        distmat = repmat(sum(X .* X, 2), 1, K) + ...
                  repmat(sum(pMiu .* pMiu, 2)', N, 1) - ...
                  2 * X * pMiu';
        [~, labels] = min(distmat, [], 2);   % nearest centroid per sample

        for k = 1:K
            Xk = X(labels == k, :);
            pPi(k) = size(Xk, 1) / N;
            pSigma(:, :, k) = cov(Xk);
        end
    end

    function Px = calc_prob()
        % Evaluate the Gaussian density of every sample under every
        % component using the current pMiu / pSigma.
        Px = zeros(N, K);
        for k = 1:K
            Xshift = X - repmat(pMiu(k, :), N, 1);
            % Tiny ridge on the diagonal guards against a singular matrix.
            inv_pSigma = inv(pSigma(:, :, k) + threshold * eye(D));
            tmp = sum((Xshift * inv_pSigma) .* Xshift, 2);   % Mahalanobis^2
            coef = (2*pi)^(-D/2) * sqrt(det(inv_pSigma));
            Px(:, k) = coef * exp(-1/2 * tmp);
        end
    end

end
测试主程序:
% Demo script: fit a 4-component GMM to the samples in testSet.txt and
% plot every cluster together with its estimated mean.
clear all
clc

data = load('testSet.txt');
numClusters = 4;
[PX, Model] = gmm(data, numClusters);
[~, index] = max(PX');   % column-wise max of PX' -> most likely component per sample

cent = Model.Miu;
figure
for c = 1:numClusters
    members = find(index == c);
    % Samples assigned to component c (first two columns of the data).
    scatter(data(members, 1), data(members, 2))
    hold on
    % Estimated mean of component c, drawn as a large filled marker.
    scatter(cent(c, 1), cent(c, 2), 150, 'filled');
end
示意图:
参考自:http://www.voidcn.com/blog/llp1992/article/p-2308490.html