zoukankan      html  css  js  c++  java
  • GMM 的EM 实现

    算法逻辑在这里:

    http://www.cnblogs.com/Azhu/p/4131733.html

        贴之前先说下,本来呢是打算自己写一个的,在matlab 上,不过,实在是写不出来那么高效和健壮的,网上有很多实现的代码,例如上面参考里面的,那个代码明显有问题阿,然后因为那里面的代码与逻辑分析是一致的,那在其基础上修改看看,结果发现代码健壮性实在太差了,我的数据集是 70-by-2000 的矩阵,70个样本2000维,结果协方差的逆根本算不出来,全部是inf,那去前50维,还是算不出来,这个虽然逻辑是对的,但是这软件的局限阿。

        那只能用其他方法了,有一个写的很好的,下面会贴出来,不过都是矩阵运算,看是能看懂的,不过数学计算实在写不出来,按这么来的也只是跟着其敲了一遍,敲之前还看了半天才懂其的数学计算,matlab 的内置函数也不算熟,这里就顺便写下来好了。

    主函数:

    • 12-26 行是初始化类标号和其他参数,12行调用了初始标号的参数,实际上初始化的是R。
    • R 是一个n-by-k 矩阵,每项表示一个i-th 样本在 j-th GM 中的概率值,就是p(x|k)。
    • 因为是初始化,所以14行获取了当前类标号label。
    • 27 - 40 是迭代部分,通过判断是否收敛和迭代次数循环
    • 29 是m-step, 这跟我写的算法逻辑有点不同,不过不影响。
    • 29 m-step是假设知道了标号,训练GMM 模型参数,获得的是model。
    • 30 是 e-step,假设训练好了GMM ,计算样本的分配情况,其中loglikehood 是在e-step 中计算了。
    • 剩下的是收敛判断
     1 function [label, model, llh] = emgm(X, init)
     2 % Perform EM algorithm for fitting the Gaussian mixture model.
     3 %   X: d x n data matrix
     4 %   init: k (1 x 1) or label (1 x n, 1<=label(i)<=k) or center (d x k)
     5 % Written by Michael Chen (sth4nth@gmail.com).
     6 %% initialization
     7 % fprintf('EM for Gaussian mixture: running ... 
    ');
     8 % load('final_initlize');
     9 % X = dataset(1).x';
    10 % init = dataset(1).y';
    11 % R n-by-k 矩阵,表示i-th 样本属于j-th 个类的概率,初始化时候为1、0,迭代后变是权重化了。
    12 R = initialization(X,init);
    13 % label 表示n 个样本的类标号。
    14 [~,label(1,:)] = max(R,[],2);
    15 % 这句是为了处理类标号不连续的情况
    16 R = R(:,unique(label));
    17 
    18 %pect = zeros(size(label));
    19 % tol 是阀值控制
    20 tol = 1e-10;
    21 maxiter = 500;
    22 % loglikehood
    23 llh = -inf(1,maxiter);
    24 converged = false;
    25 % 当前迭代的标号
    26 t = 1;
    27 while ~converged && t < maxiter
    28     t = t+1;
    29     model = maximization(X,R);
    30     [R, llh(t)] = expectation(X,model);
    31    
    32     [~,label(:)] = max(R,[],2);
    33     u = unique(label);   % non-empty components
    34     if size(R,2) ~= size(u,2)
    35         R = R(:,u);   % remove empty components
    36     else
    37         converged = llh(t)-llh(t-1) < tol*abs(llh(t));
    38     end
    39 
    40 end
    41 llh = llh(2:t);
    42 % if converged
    43 %     fprintf('Converged in %d steps.
    ',t-1);
    44 %     llh  = t-1;  
    45 % else
    46 %     fprintf('Not converged in %d steps.
    ',maxiter);
    47 %     llh = maxiter;   
    48 % end

    初始化函数:

        这个函数很简单,没什么好解释的。

     1 %% init
     2 function R = initialization(X, init)
     3 % 初始化一共用4中方式,一种是给定GMM 模型的参数初始值,一种是给定k 的个数,一种是给各sample 的标号,一种是给出类的中心点
     4 [d,n] = size(X);
     5 if isstruct(init)  % initialize with a model
     6     R  = expectation(X,init);
     7 elseif length(init) == 1  % random initialization
     8     k = init;
     9     idx = randsample(n,k);
    10     m = X(:,idx);
    11     [~,label] = max(bsxfun(@minus,m'*X,dot(m,m,1)'/2),[],1);
    12     [u,~,label] = unique(label);
    13     while k ~= length(u)
    14         idx = randsample(n,k);
    15         m = X(:,idx);
    16         [~,label] = max(bsxfun(@minus,m'*X,dot(m,m,1)'/2),[],1);
    17         [u,~,label] = unique(label);
    18     end
    19     R = full(sparse(1:n,label,1,n,k,n));
    20 elseif size(init,1) == 1 && size(init,2) == n  % initialize with labels
    21     label = init;
    22     k = max(label);
    23     R = full(sparse(1:n,label,1,n,k,n));
    24 elseif size(init,1) == d  %initialize with only centers
    25     k = size(init,2);
    26     m = init;
    27     [~,label] = max(bsxfun(@minus,m'*X,dot(m,m,1)'/2),[],1);
    28     R = full(sparse(1:n,label,1,n,k,n));
    29 else
    30     error('ERROR: init is not valid.');
    31 end
    View Code

    m-step函数:

    • 输入参数 R解释参考上面。
    • 7 计算各类的sample 个数和,一个1-by-k matrix。
    • 8 7中的值除以样本总数就是 GM 的权重,同样是1-by-k matrix。
    • 9 计算GM 样本均值,mu 是个d-by-k matrix,每列表示 k-th GM 的样本均值。
    • 19 计算sqrtR 是为了 15-17行的计算中 结果刚好是R。
    • 15-17 sigma
    • 18行 应该是为了避免sigma 不能逆。
     1 %% m-step
     2 function model = maximization(X, R)
     3 [d,n] = size(X);
     4 % k 为类个数
     5 k = size(R,2);
     6 % 各类的sample个数
     7 nk = sum(R,1);
     8 w = nk/n;
     9 mu = bsxfun(@times, X*R, 1./nk);
    10 
    11 Sigma = zeros(d,d,k);
    12 % 这个值是为了下面计算时候得到R,
    13 sqrtR = sqrt(R);
    14 for i = 1:k
    15     Xo = bsxfun(@minus,X,mu(:,i));
    16     Xo = bsxfun(@times,Xo,sqrtR(:,i)');
    17     Sigma(:,:,i) = Xo*Xo'/nk(i);
    18     Sigma(:,:,i) = Sigma(:,:,i)+eye(d)*(1e-6); % add a prior for numerical stability
    19 end
    20 
    21 model.mu = mu;
    22 model.Sigma = Sigma;
    23 model.weight = w;

    e-step:

    • e step 需要解释很多阿。
    • 9 logRho,首先我们知道R 是每项表示一个i-th 样本在 j-th GM 中的概率值,计算公式如下,公式中x是d-by-1 的sample,也就是gamma 中的N()
    • %Gaussian posterior probability
      %N(x|pMiu,pSigma) = 1/((2pi)^(D/2))*(1/(abs(sigma))^0.5)*exp(-1/2*(x-pMiu)'pSigma^(-1)*(x-pMiu))

    • 问题是上面公式不一定能按步求出来阿,例如 sigma^-1,不一定解得出来阿,所以对上面得N()log 一下,后计算,同时避开计算sigma^-1,这个矩阵就是logRho
    • 20-31 便是12 行的函数调用,其中涉及了一堆矩阵转换,验证过没有错,计算的就是log 后的 N()
    • 14 上面公式 gamma 的分子部分。
    • 15-16 是计算当前的loglikehood。
    • 17 计算R 矩阵的log 形式。
    • 18 反计算R。
     1 %% e-step
     2 function [R, llh] = expectation(X, model)
     3 mu = model.mu;
     4 Sigma = model.Sigma;
     5 w = model.weight;
     6 
     7 n = size(X,2);
     8 k = size(mu,2);
     9 logRho = zeros(n,k);
    10 
    11 for i = 1:k
    12     logRho(:,i) = loggausspdf(X,mu(:,i),Sigma(:,:,i));
    13 end
    14 logRho = bsxfun(@plus,logRho,log(w));
    15 T = logsumexp(logRho,2);
    16 llh = sum(T)/n; % loglikelihood
    17 logR = bsxfun(@minus,logRho,T);
    18 R = exp(logR);
    19 %% log pdf
    20 function y = loggausspdf(X, mu, Sigma)
    21 
    22 d = size(X,1);
    23 X = bsxfun(@minus,X,mu);
    24 [U,p]= chol(Sigma);
    25 if p ~= 0
    26     error('ERROR: Sigma is not PD.');
    27 end
    28 Q = U'X;
    29 q = dot(Q,Q,1);  % quadratic term (M distance)
    30 c = d*log(2*pi)+2*sum(log(diag(U)));   % normalization constant
    31 y = -(c+q)/2;
  • 相关阅读:
    ECMAScript2017之async function
    ES3之closure ( 闭包 )
    RxJS之AsyncSubject
    RxJS之BehaviorSubject
    RxJS之Subject主题 ( Angular环境 )
    RxJS之工具操作符 ( Angular环境 )
    RxJS之转化操作符 ( Angular环境 )
    RxJS之过滤操作符 ( Angular环境 )
    RxJS之组合操作符 ( Angular环境 )
    关于Qt的StyleSheet作用范围
  • 原文地址:https://www.cnblogs.com/Azhu/p/4133454.html
Copyright © 2011-2022 走看看