zoukankan      html  css  js  c++  java
  • matlab-2

      1 function varargout = gmm(X, K_or_centroids)
      2 % ============================================================
      3 % Expectation-Maximization iteration implementation of
      4 % Gaussian Mixture Model.
      5 %
      6 % PX = GMM(X, K_OR_CENTROIDS)
      7 % [PX MODEL] = GMM(X, K_OR_CENTROIDS)
      8 %
      9 %  - X: N-by-D data matrix.
     10 %  - K_OR_CENTROIDS: either K indicating the number of
     11 %       components or a K-by-D matrix indicating the
     12 %       choosing of the initial K centroids.
     13 %
     14 %  - PX: N-by-K matrix indicating the probability of each
     15 %       component generating each point.
     16 %  - MODEL: a structure containing the parameters for a GMM:
     17 %       MODEL.Miu: a K-by-D matrix.
     18 %       MODEL.Sigma: a D-by-D-by-K matrix.
     19 %       MODEL.Pi: a 1-by-K vector.
     20 % ============================================================
     21  
     22     threshold = 1e-15;
     23     [N, D] = size(X);
     24  
     25     if isscalar(K_or_centroids)
     26         K = K_or_centroids;
     27         % randomly pick centroids
     28         rndp = randperm(N);
     29         centroids = X(rndp(1:K), :);
     30     else
     31         K = size(K_or_centroids, 1);
     32         centroids = K_or_centroids;
     33     end
     34  
     35     % initial values
     36     [pMiu pPi pSigma] = init_params();
     37  
     38     Lprev = -inf;
     39     while true
     40         Px = calc_prob();
     41  
     42         % new value for pGamma
     43         pGamma = Px .* repmat(pPi, N, 1);
     44         pGamma = pGamma ./ repmat(sum(pGamma, 2), 1, K);
     45  
     46         % new value for parameters of each Component
     47         Nk = sum(pGamma, 1);
     48         pMiu = diag(1./Nk) * pGamma' * X;
     49         pPi = Nk/N;
     50         for kk = 1:K
     51             Xshift = X-repmat(pMiu(kk, :), N, 1);
     52             pSigma(:, :, kk) = (Xshift' * ...
     53                 (diag(pGamma(:, kk)) * Xshift)) / Nk(kk);
     54         end
     55  
     56         % check for convergence
     57         L = sum(log(Px*pPi'));
     58         if L-Lprev < threshold
     59             break;
     60         end
     61         Lprev = L;
     62     end
     63  
     64     if nargout == 1
     65         varargout = {Px};
     66     else
     67         model = [];
     68         model.Miu = pMiu;
     69         model.Sigma = pSigma;
     70         model.Pi = pPi;
     71         varargout = {Px, model};
     72     end
     73  
     74     function [pMiu pPi pSigma] = init_params()
     75         pMiu = centroids;
     76         pPi = zeros(1, K);
     77         pSigma = zeros(D, D, K);
     78  
     79         % hard assign x to each centroids
     80         distmat = repmat(sum(X.*X, 2), 1, K) + ...
     81             repmat(sum(pMiu.*pMiu, 2)', N, 1) - ...
     82             2*X*pMiu';
     83         [dummy labels] = min(distmat, [], 2);
     84  
     85         for k=1:K
     86             Xk = X(labels == k, :);
     87             pPi(k) = size(Xk, 1)/N;
     88             pSigma(:, :, k) = cov(Xk);
     89         end
     90     end
     91  
     92     function Px = calc_prob()
     93         Px = zeros(N, K);
     94         for k = 1:K
     95             Xshift = X-repmat(pMiu(k, :), N, 1);
     96             inv_pSigma = inv(pSigma(:, :, k));
     97             tmp = sum((Xshift*inv_pSigma) .* Xshift, 2);
     98             coef = (2*pi)^(-D/2) * sqrt(det(inv_pSigma));
     99             Px(:, k) = coef * exp(-0.5*tmp);
    100         end
    101     end
    102 end

    (注:此段代码为GMM的EM算法实现)

    1、isscalar

    该函数用于判断输入参数是否是一个标量。在matlab中所谓标量,即1行1列的矩阵。
    语法格式:
    TF = isscalar(A)
    如果矩阵A是一行一列的,则返回逻辑1(true),否则返回逻辑0(false)。
    相关函数:isa、isvector
     
    2、随机函数

    a)rand函数

    rand(n):生成0到1之间的n阶随机数方阵

    rand(m,n):生成0到1之间的m×n的随机数矩阵

    b)randint函数

    randint(m,n,[1 N]):生成m×n的在1到N之间的随机整数矩阵,其效果与randint(m,n,N+1)相同。

    >> randint(3,4,[1 10])

    ans =

         5     7     4    10
         5     1     2     7
         8     7     8     6
    >> randint(3,4,11)

    ans =

        10     9     6     9
         5    10     8     9
        10     0     2     6

    c)randperm函数

    randperm(n):产生一个1到n的随机顺序。
    >> randperm(10)

    ans =

         6     4     8     9     3     5     7   10     2     1

    3、

    xn是一个向量,也就是一维数组,xn(k:-1:k-M+1)的意义:假设k=10,M=5,则该式变为xn(10:-1:6),则x = xn(10:-1:6)的意思就算把xn(10)至xn(6)共五个数按从10到6的顺序赋给x(1)到x(5),即x(1)=xn(10),x(2)=xn(9)....,如果是正向的就不用加-1,例如xn(6:10),默认间隔为1.

    4、inf、nan

    Matlab中的Inf和-Inf分别代表正无穷和负无穷;

    NaN表示非数值的值;

    无穷一般是由于0 做了分母或者运算溢出,产生了超出双精度浮点数数值范围的结果;

    非数值量则是因为0/0,或者Inf/Inf型的非正常运算。

    5、zeros函数和ones函数

    zeros函数——生成零矩阵

    ones函数——生成全1阵

    【zeros的使用方法】

    B=zeros(n):生成n×n全零阵。

    B=zeros(m,n):生成m×n全零阵。

    B=zeros([m n]):生成m×n全零阵。

    B=zeros(d1,d2,d3……):生成d1×d2×d3×……全零阵或数组。

    B=zeros([d1 d2 d3……]):生成d1×d2×d3×……全零阵或数组。

    B=zeros(size(A)):生成与矩阵A相同大小的全零阵。

    【ones的使用方法】

    ones的使用方法与zeros的使用方法类似。

    6、repmat函数

    repmat 即 Replicate Matrix ,复制和平铺矩阵

    a)B = repmat(A,m,n)
    将矩阵 A 复制 m×n 块,即把 A 作为 B 的元素,B 由 m×n 个 A 平铺而成。B 的维数是 [size(A,1)*m, size(A,2)*n] 。
    >> A = [1,2;3,4]
    A =
    1 2
    3 4
    >> B = repmat(A,2,3)
    B =
    1 2 1 2 1 2
    3 4 3 4 3 4
    1 2 1 2 1 2
    3 4 3 4 3 4
    b)B = repmat(A,[m n])
    与 B = repmat(A,m,n) 用法一致。
     
    7、matlab代码中省略号代表改行没结束,进行续行。
     
    8、max函数

    a)MAX函数的几种形式 
    (1)max(a)

     (2)max(a,b)

     (3)max(a,[],dim) 

    (4)[C,I]=max(a)

     (5)[C,I]=max(a,[],dim) 
    b)举例说明函数意思 
    (1)max(a) 
    如果a是一个矩阵,比如a=[1,2,3;4,5,6],max(a)的意思就是找出矩阵每列的最大值, 

    本例中:max(a)=[4,5,6] 
    (2)max(a,b) 
    如果a和b都是大于1维的矩阵,那么要求a和b的行列的维数都要相等,函数的结果是比较a和b中每个元素的大小,

    比如: 
    a=[1,2,3;4,5,6]      b=[4,5,6;7,8,3] 

    max(a,b)=[4,5,6;7,8,6] 

    另外,如果a和b中至少有一个是常数,也是可以的。

    比如:a=[1,2,3;4,5,6]      b=3     c=5 
    max(a,b)=[3,3,3;4,5,6]       

    max(b,c)=5 
    (3)max(a,[],dim) 
    这个函数的意思是针对于2维矩阵的,dim是英文字母dimension的缩写,意思是维数。 

    当dim=1时,比较的a矩阵的行,也就是和max(a)的效果是一样的;

    当dim2时,比较的是a矩阵的列。

    下面举个例子: 
    a=[1,2,3;4,5,6]      

     max(a)=max(a,[],1)=[4,5,6]    比较的第一行和第二行的值  

    max(a,[],2)=[3,6]

    (4)[C,I]=max(a) 
    C表示的是矩阵a每列的最大值,I表示的是每个最大值对应的下标:

     下面举例说明: 
    还是刚才那个例子:a=[1,2,3;4,5,6]          [C,I]=max(a) 
    结果显示的是C=[4,5,6]       I=[2,2,2]   返回的是最大值对应的行号。 
    (5)[C,I]=max(a,[],dim) 
    同理:如果dim=1时,其结果和[c,i]=max(a)是一样的。 

    当dim=2时,同样上面的矩阵a,我们运行一下: 
    [c,i]=max(a,[],2)     结果是:c=[3,6]   i=[3,3]    i返回的是矩阵a的列号。

    9、sum函数

    sum(x,2)表示矩阵x的横向相加,求每行的和,结果是列向量。
    而缺省的sum(x)就是竖向相加,求每列的和,结果是行向量。

    A>0的结果是得到一个逻辑矩阵,大小跟原来的A一致,
    A中大于零的元素的位置置为1,小于等于零的位置置为0。

    所以横向求和以后,就是求A中每行大于零的元素个数。

    例如
    >> A=randn(5)


    A =

       -0.4326    1.1909   -0.1867    0.1139    0.2944
       -1.6656    1.1892    0.7258    1.0668   -1.3362
        0.1253   -0.0376   -0.5883    0.0593    0.7143
        0.2877    0.3273    2.1832   -0.0956    1.6236
       -1.1465    0.1746   -0.1364   -0.8323   -0.6918

    >> sum(A)

    ans =

       -2.8316    2.8444    1.9976    0.3120    0.6043

    >> sum(A>0)

    ans =

         2     4     2     3     3

    >> sum(A<0)

    ans =

         3     1     3     2     2

    >> sum(A,2)

    ans =

        0.9800
       -0.0200
        0.2730
        4.3261
       -2.6324

    >> sum(A>0,2)

    ans =

         3
         3
         3
         4
         1

    sum(A<0,2)

    ans =

         2
         2
         2
         1
         4

     10、x.*x的含义:

    它表示两个矩阵的相对应元素之间直接进行乘积运算。例如,A=[1 2 ;3 4 ],B=[5 6;7 8] .C=A.*B=[1*5 2*6;3*7 4*8]=[5 12;21 28].

    11、size函数

    size(A)函数是用来求矩阵的大小的,你必须首先弄清楚A到底是什么,大小是多少。

    比如说一个A是一个3×4的二维矩阵:

          a)、size(A) %直接显示出A大小

           输出:ans=

                              3    4

           b)、s=size(A)%返回一个行向量s,s的第一个元素是矩阵的行数,第二个元素是矩阵的列数

           输出:s=

                              3    4

           c)、[r,c]=size(A)%将矩阵A的行数返回到第一个输出变量r,将矩阵的列数返回到第二个输出变量c

           输出:r=

                              3

                    c=

                              4

           d)、[r,c,m]=size(A)

           输出:r=

                              3

                    c=

                              4

                    m=

                              1

    也就说它把二维矩阵当作第三维为1的三维矩阵,这也如同我们把n维列向量当作n×1的矩阵一样

           e)、当a是一个n维行向量时,size(A)把其当成一个1×n的矩阵,因此size(a)的结果是

           ans

                      1   n

    而不是a的元素个数n

           f)、size(A,n)

           如果在size函数的输入参数中再添加一项n,并用1或2为n赋值,则 size将返回矩阵的行数或列数。其中r=size(A,1)该语句返回的是矩阵A的行数, c=size(A,2) 该语句返回的是矩阵A的列数。

    12、diag函数

    diag函数功能:矩阵对角元素的提取和创建对角阵

    设以下X为方阵,v为向量

    a、X = diag(v,k)当v是一个含有n个元素的向量时,返回一个n+abs(k)阶方阵X,向量v在矩阵X中的第k个对角线上,k=0表示主对角线,k>0表示在主对角线上方,k<0表示在主对角线下方。例1:

    v=[1 2 3];
    diag(v, 3)

    ans =

         0     0     0     1     0     0
         0     0     0     0     2     0
         0     0     0     0     0     3
         0     0     0     0     0     0
         0     0     0     0     0     0
         0     0     0     0     0     0

    注:从主对角矩阵上方的第三个位置开始按对角线方向产生数据的

    例2:

    v=[1 2 3];
    diag(v, -1)
    ans =
          0 0 0 0
          1 0 0 0
          0 2 0 0
          0 0 3 0

    注:从主对角矩阵下方的第一个位置开始按对角线方向产生数据的

    b、X = diag(v)

    向量v在方阵X的主对角线上,类似于diag(v,k),k=0的情况。

    例3:

    v=[1 2 3];
    diag(v)

    ans =

    1 0 0
    0 2 0
    0 0 3

    注:写成了对角矩阵的形式

     cv = diag(X,k)

    返回列向量v,v由矩阵X的第k个对角线上的元素形成

    例4:

     v=[1 0 3;2 3 1;4 5 3];
    diag(v,1)

    ans =

         0
         1

    注:把主对角线上方的第一个数据作为起始数据,按对角线顺序取出写成列向量形式

    d、v = diag(X)返回矩阵X的主对角线上的元素,类似于diag(X,k),k=0的情况例5:

    v=[1 0 0;0 3 0;0 0 3];
    diag(v)

    ans =

    1
    3
    3

    或改为:

    v=[1 0 3;2 3 1;4 5 3];
    diag(v)

    ans =

    1
    3
    3

    注:把主对角线的数据取出写成列向量形式

    e、diag(diag(X))

    取出X矩阵的对角元,然后构建一个以X对角元为对角的对角矩阵。
    例6:

     X=[1 2;3 4]       
     diag(diag(X))

    X =

         1     2
         3     4

    ans =

         1     0
         0     4

  • 相关阅读:
    java数组的相关方法
    spring boot 文件目录
    mysql 数据库安装,datagrip安装,datagrip连接数据库
    linux maven 的安装与配置
    java String字符串常量常用方法
    java 命名规范
    deepin 安装open jdk
    jetbrains(idea,webstorm,pycharm,datagrip)修改背景,主题,添加特效,汉化
    JVM学习(九)volatile应用
    JVM学习(八)指令重排序
  • 原文地址:https://www.cnblogs.com/xiaojingang/p/4448454.html
Copyright © 2011-2022 走看看