zoukankan      html  css  js  c++  java
  • GMM-实现聚类的代码示例

    Matlab 代码:

     1 % GMM code
     2 
     3 function varargout = gmm(X, K_or_centroids)
     4 
     5     % input X:N-by-D data matrix
     6     % input K_or_centroids: K-by-D centroids
     7     
     8     % 阈值
     9     threshold = 1e-15;
    10     % 读取数据维度
    11     [N, D] = size(X);
    12     % 判断输入质心是否为标量
    13     if isscalar(K_or_centroids)
    14         % 是标量,随机选取K个质心
    15         K = K_or_centroids;
    16         rnpm = randperm(N); % 打乱的N个序列
    17         centroids = X(rnpm(1:K), :);
    18     else   % 矩阵,给出每一类的初始化
    19         K = size(K_or_centroids, 1);
    20         centroids = K_or_centroids;
    21     end
    22     
    23     % 定义模型初值
    24     [pMiu pPi pSigma] = init_params();
    25     
    26     Lprev = -inf;
    27     while true
    28         % E-step,估算出概率值
    29         % Px: N-by-K 
    30         Px = calc_prob();
    31         
    32         % pGamma新的值,样本点所占的权重
    33         % pPi:1-by-K     pGamma:N-by-K
    34         pGamma = Px ./ repmat(pPi, N, 1);
    35         % 对pGamma的每一行进行求和,sum(x,2):每一行求和
    36         pGamma = pGamma ./ repmat(sum(pGamma, 2) ,1 , K);
    37         
    38         % M-step
    39         % 每一个组件给予新的值
    40         Nk = sum(pGamma,1);
    41         pMiu = diag(1./Nk)*pGamma'*X;
    42         pPi = Nk/N;
    43         for kk = 1:K
    44            Xshift = X - repmat(pMiu(kk, :) ,N, 1);
    45            pSigma(:,:,kk) = (Xshift'*(diag(pGamma(:,kk))*Xshift)) / Nk(kk);
    46         end
    47         
    48         % 观察收敛,convergence
    49         L = sum(log(Px*pPi'));
    50         if L-Lprev < threshold
    51             break;
    52         end
    53         Lprev = L;
    54         
    55     end
    56     
    57     % 输出参数判定
    58     if nargout == 1
    59         varargout = {Px};
    60     else
    61         model = [];
    62         model.Miu = pMiu;
    63         model.Sigma = pSigma;
    64         model.Pi = pPi;
    65         varargout = {Px, model};
    66     end
    67     
    68     function [pMiu pPi pSigma] = init_params()
    69        pMiu = centroids; % 均值,K类的中心
    70        pPi = zeros(1, K); % 概率
    71        pSigma = zeros(D, D, K); % 协方差,每一个都是D-by-D
    72        
    73        % (X - pMiu)^2 = X^2 + pMiu^2 - 2*X*pMiu
    74        distmat = repmat(sum(X.*X, 2), 1, K) + repmat(sum(pMiu.*pMiu, 2)', N, 1) - 2*X*pMiu';
    75        [dummy labels] = min(distmat, [], 2); % 找出每一行的最小值,并标出列的位置
    76        
    77        for k=1:K   %初始化参数
    78            Xk = X(labels == k, :);
    79            pPi(k) = size(Xk, 1)/N;
    80            pSigma(:, :, k) = cov(Xk);
    81        end             
    82     end
    83 
    84     % 计算概率值
    85     function Px = calc_prob()
    86         Px = zeros(N,K);
    87         for k=1:K
    88             Xshift = X - repmat(pMiu(k,:),N,1);
    89             inv_pSigma = inv(pSigma(:,:,k)+diag(repmat(threshold, 1, size(pSigma(:,:,k),1))));
    90             tmp = sum((Xshift*inv_pSigma).*Xshift, 2);
    91             coef = (2*pi)^(-D/2)*sqrt(det(inv_pSigma));
    92             Px(:,k) = coef * exp(-1/2*tmp);
    93         end
    94     end
    95       
    96 
    97 end

    测试主程序:

     1 % 测试代码
     2 clear all
     3 clc
     4 
     5 data = load('testSet.txt');
     6 [PX, Model] = gmm(data, 4);
     7 [~,index] = max(PX'); % 每一列的最大值
     8 
     9 cent = Model.Miu;
    10 figure
    11 I = find(index == 1);
    12 scatter(data(I,1), data(I,2))
    13 hold on 
    14 scatter(cent(1,1), cent(1,2) ,150, 'filled');
    15 hold on
    16 I = find(index == 2);
    17 scatter(data(I,1),data(I,2))
    18 hold on
    19 scatter(cent(2,1),cent(2,2),150,'filled')
    20 hold on
    21 I = find(index == 3);
    22 scatter(data(I,1),data(I,2))
    23 hold on
    24 scatter(cent(3,1),cent(3,2),150,'filled')
    25 hold on
    26 I = find(index == 4);
    27 scatter(data(I,1),data(I,2))
    28 hold on
    29 scatter(cent(4,1),cent(4,2),150,'filled')

    示意图:

     参考自:http://www.voidcn.com/blog/llp1992/article/p-2308490.html

  • 相关阅读:
    c++ 利用new动态的定义二维数组
    golang在linux后台执行的方法
    Linux安装配置go运行环境
    SpringCloud 笔记
    你真的了解 Unicode 和 UTF-8 吗?
    Elasticsearch 系列文章汇总(持续更新...)
    Maven 的依赖范围
    在 centos 上安装 virutalbox
    Java 异常总结
    使用 RabbitMQ 实现异步调用
  • 原文地址:https://www.cnblogs.com/demo-deng/p/7127979.html
Copyright © 2011-2022 走看看