zoukankan      html  css  js  c++  java
  • GMM demo

    # GMM model
    # 2015/7/22
    library(mvtnorm)
    
    # Simulate a two-component bivariate Gaussian mixture (reproducibly).
    set.seed(1)
    n1 <- 1000
    n2 <- 1000
    # True parameters of the two components
    mu1 <- c(0, 1)
    mu2 <- c(-5, -6)
    sigma1 <- matrix(c(1, .5, .5, 2), nrow = 2)
    sigma2 <- matrix(c(2, .5, .5, 1), nrow = 2)
    y1 <- rep(1, n1)
    y2 <- rep(2, n2)
    x1 <- rmvnorm(n1, mean = mu1, sigma = sigma1)
    x2 <- rmvnorm(n2, mean = mu2, sigma = sigma2)

    # Pooled sample: rows 1..n1 come from component 1, the rest from component 2.
    x <- rbind(x1, x2)
    # True component label of each row of x. c() (not rbind()) keeps this a
    # length-(n1 + n2) vector aligned with the rows of x.
    y <- c(y1, y2)

    ns <- 2
    ngrid <- 100
    
    # Evaluate the bivariate normal density N(mu, sigma) on the grid x by y.
    #
    # x, y: numeric grid coordinates; mu: length-2 mean vector;
    # sigma: 2 x 2 covariance matrix (must be invertible).
    # Returns a length(x) by length(y) matrix with z[i, j] equal to the density
    # at (x[i], y[j]) -- the layout graphics::contour() expects for z.
    mv.gauss <- function(x, y, mu, sigma) {
      # expand.grid varies x fastest, which matches the column-major fill of
      # the final matrix() call.
      grid <- as.matrix(expand.grid(x, y))
      centered <- sweep(grid, 2, mu)
      # Row-wise quadratic form (p - mu)' sigma^{-1} (p - mu) for every grid point.
      quad <- rowSums((centered %*% solve(sigma)) * centered)
      dens <- exp(-quad / 2) / (2 * pi * sqrt(det(sigma)))
      matrix(dens, nrow = length(x), ncol = length(y))
    }
    # Density of the multivariate normal N(mu, sigma) at a single point x.
    #
    # x, mu: numeric vectors of the same length d; sigma: d x d covariance
    # matrix (must be invertible). Returns a numeric scalar.
    # The normalising constant is (2*pi)^(d/2) * sqrt(det(sigma)), i.e.
    # 2*pi*sqrt(det(sigma)) in the bivariate case.
    gauss_density <- function(x, mu, sigma) {
      centered <- x - mu
      d <- length(centered)
      # (x - mu)' sigma^{-1} (x - mu); solve(sigma, v) avoids forming the inverse.
      quad <- drop(crossprod(centered, solve(sigma, centered)))
      exp(-quad / 2) / ((2 * pi)^(d / 2) * sqrt(det(sigma)))
    }
    # Overlay density contours of both mixture components on the current plot.
    #
    # ngrid: number of grid points per axis; ns: half-width of each component's
    # plotting box, in standard deviations; mv.gauss: function(x, y, mu, sigma)
    # returning a length(x) by length(y) density matrix; mu1/sigma1 and
    # mu2/sigma2: mean and 2 x 2 covariance of components 1 (drawn red) and
    # 2 (drawn blue). Requires an open plot (contours use add = TRUE).
    # Returns NULL invisibly.
    plot_contour <- function(ngrid, ns, mv.gauss, mu1, sigma1, mu2, sigma2) {
      # Contours for one component over mu +/- ns * sd on each axis. Each
      # range is built from that component's OWN mean and standard deviations.
      draw_component <- function(mu, sigma, col) {
        sds <- sqrt(diag(sigma))
        x.range <- seq(mu[1] - ns * sds[1], mu[1] + ns * sds[1], length.out = ngrid)
        y.range <- seq(mu[2] - ns * sds[2], mu[2] + ns * sds[2], length.out = ngrid)
        z <- mv.gauss(x.range, y.range, mu, sigma)
        contour(x.range, y.range, z, add = TRUE, col = col, lwd = 3)
      }
      draw_component(mu1, sigma1, "red")
      draw_component(mu2, sigma2, "blue")
      invisible(NULL)
    }
    # Scatter-plot the pooled samples rbind(x1, x2) for one EM iteration.
    #
    # The current means are marked with filled squares (mu1 red, mu2 blue), and
    # the title reports them next to the reference values hard-coded in the
    # format string. Density contours are then drawn by plot_contour() with
    # ns = 4. `iter` is printed with %d, so it should be a whole number.
    plot_iter <- function(ngrid, x1, x2, mv.gauss, mu1, sigma1, mu2, sigma2,
                          iter = 1) {
      # Kept as `x` because plot() derives its default axis labels from it.
      x <- rbind(x1, x2)
      plot(
        x[, 1],
        x[, 2],
        type = "p",
        main = sprintf(
          "Iter %d: mu1=(%.2f, %.2f)/(-5,-6) mu2=(%.2f, %.2f)/(0,1)",
          iter, mu1[1], mu1[2], mu2[1], mu2[2]
        )
      )
      points(mu1[1], mu1[2], col = "red", pch = 15)
      points(mu2[1], mu2[2], col = "blue", pch = 15)
      plot_contour(ngrid, 4, mv.gauss, mu1, sigma1, mu2, sigma2)
    }
    # Log-likelihood of a two-component Gaussian mixture.
    #
    # x: n x d data matrix (one observation per row); phi: mixing weights
    # c(phi1, phi2); mu1/mu2: length-d means; sigma1/sigma2: d x d invertible
    # covariances.
    # Returns sum over ALL rows i of
    #   log(phi1 * N(x_i | mu1, sigma1) + phi2 * N(x_i | mu2, sigma2)).
    obj_value <- function(x, phi, mu1, sigma1, mu2, sigma2) {
      x <- as.matrix(x)
      # Density of every row of x under N(mu, sigma). solve() and det() run
      # once per component instead of once per observation.
      row_density <- function(mu, sigma) {
        centered <- sweep(x, 2, mu)
        quad <- rowSums((centered %*% solve(sigma)) * centered)
        exp(-quad / 2) / ((2 * pi)^(ncol(x) / 2) * sqrt(det(sigma)))
      }
      mixture <- phi[1] * row_density(mu1, sigma1) +
        phi[2] * row_density(mu2, sigma2)
      sum(log(mixture))
    }
    # Reference plot using the true generating parameters (iteration 1 label).
    plot_iter(ngrid, x1, x2, mv.gauss, mu1, sigma1, mu2, sigma2)

    # EM starting point: two unit-covariance components near the origin.
    mu1_i <- c(0, 0)
    mu2_i <- c(1, 0)
    sigma1_i <- diag(2)
    sigma2_i <- diag(2)
    plot_iter(ngrid, x1, x2, mv.gauss, mu1_i, sigma1_i, mu2_i, sigma2_i, 0)
    
    # Fit the two-component mixture by EM for a fixed number of iterations,
    # plotting each iterate and recording the objective from obj_value().
    n <- n1 + n2
    # Responsibilities: w[i, k] = P(component k | x[i, ]) under current params.
    w <- array(0, dim = c(n, 2))
    phi1 <- 0.5
    phi2 <- 0.5
    num_iter <- 12
    obj_val <- rep(0, num_iter)
    for (ii in seq_len(num_iter)) {
      # E-step: unnormalised posterior weight per component, then normalise
      # each row so it sums to 1.
      for (i in seq_len(n)) {
        w[i, 1] <- phi1 * gauss_density(x[i, ], mu1_i, sigma1_i)
        w[i, 2] <- phi2 * gauss_density(x[i, ], mu2_i, sigma2_i)
        w[i, ] <- w[i, ] / sum(w[i, ])
      }

      # M-step: closed-form weighted updates of mixing weights and means.
      phi1 <- mean(w[, 1])
      phi2 <- mean(w[, 2])
      mu1_i <- colSums(w[, 1] * x) / sum(w[, 1])
      mu2_i <- colSums(w[, 2] * x) / sum(w[, 2])
      # Weighted scatter sum_i w_ik (x_i - mu_k)(x_i - mu_k)^T computed as a
      # single crossprod instead of n rank-one updates.
      centered1 <- sweep(x, 2, mu1_i)
      sigma1_i <- crossprod(centered1, w[, 1] * centered1) / sum(w[, 1])
      centered2 <- sweep(x, 2, mu2_i)
      sigma2_i <- crossprod(centered2, w[, 2] * centered2) / sum(w[, 2])

      plot_iter(ngrid, x1, x2, mv.gauss, mu1_i, sigma1_i, mu2_i, sigma2_i, ii)
      obj_val[ii] <- obj_value(x, c(phi1, phi2), mu1_i, sigma1_i, mu2_i, sigma2_i)
    }
    plot(obj_val, type = "l", main = "Objective function: log likelihood",
         xlab = "#Iteration")
    print(c(phi1, phi2))
    print(sigma1_i)
    print(sigma2_i)
  • 相关阅读:
    fiddler 抓包工具(新猿旺学习总结)
    Monkey之常用ADB命令(新猿旺学习总结)
    APP压力测试 monkey(新猿旺学习总结)
    linux 系统shell运行程序不退出
    c++字节对齐编译器指令#pragma
    vmware 14 新安装centos7 没法联网
    windows dll的def文件
    c编译器字节对齐指令
    centos 7 进入图形界面
    cent os 7 与cent os 6区别
  • 原文地址:https://www.cnblogs.com/shalijiang/p/4684208.html
Copyright © 2011-2022 走看看