  • 6. EM Algorithm: Detailed Code Implementation of a Gaussian Mixture Model (GMM) + Lasso

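    1. Preface

    We previously covered "4. EM Algorithm: Detailed Code Implementation of the Gaussian Mixture Model (GMM)", which went through essentially every step involved in a GMM and the problems you are likely to run into. This time we take it one step further and walk through a detailed implementation of a GMM (Gaussian mixture model) fitted with EM whose log-likelihood includes a penalty term.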

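    2. Principle

    Because a penalty term is now added to the log-likelihood we maximize, the derivation has to change in a few places.

    In the penalized GMM we assume that each covariance matrix is diagonal. The Gaussian density of a sample then factorizes over the dimensions: for each dimension we evaluate a one-dimensional Gaussian with the corresponding entries of \(\mu_k\) and \(\sigma_k\), and add up the resulting log-densities (equivalently, multiply the densities). There is no need to evaluate a multivariate Gaussian, and no need to enforce positive definiteness of a full covariance matrix.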
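    A quick numerical check of this factorization, as a sketch assuming NumPy and SciPy are available (the helper name diag_gaussian_logpdf is hypothetical, not from the original code): summing per-dimension scipy.stats.norm log-densities gives the same values as scipy.stats.multivariate_normal with the corresponding diagonal covariance.

```python
import numpy as np
from scipy.stats import multivariate_normal, norm


def diag_gaussian_logpdf(X, mu, sigma):
    """Log-density of each row of X under N(mu, diag(sigma**2)).

    With a diagonal covariance the dimensions are independent, so the joint
    log-density is the sum of the 1-D log-densities.
    """
    # norm.logpdf broadcasts mu and sigma (shape (D,)) across the rows of X (shape (N, D))
    return norm.logpdf(X, loc=mu, scale=sigma).sum(axis=1)


rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))
mu = np.array([0.5, -1.0, 2.0])
sigma = np.array([1.0, 0.5, 2.0])

per_dim_sum = diag_gaussian_logpdf(X, mu, sigma)
full_mvn = multivariate_normal(mean=mu, cov=np.diag(sigma ** 2)).logpdf(X)
print(np.allclose(per_dim_sum, full_mvn))  # True
```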

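    We add the following penalty term to the log-likelihood in equation (1):

    \[\lambda\sum_{k=1}^K\sum_{j=1}^P\frac{|\mu_{kj}-\bar{x}_j|}{s_j}\]

    Here \(P\) is the dimension of the samples, \(\bar{x}_j\) is the sample mean of dimension \(j\), and \(s_j\) is the standard deviation of dimension \(j\). The penalty is an L1 norm that constrains \(\mu_k\): it pulls each coordinate \(\mu_{kj}\) toward the overall mean \(\bar{x}_j\), with the distance measured in units of \(s_j\).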
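    A small sketch of the penalty term, assuming the component means are stored in a (K, P) array and that \(s_j\) is the population standard deviation (ddof=0); the original does not say which estimator it uses, and lasso_penalty is a hypothetical helper name.

```python
import numpy as np


def lasso_penalty(mu, X, lam):
    """lam * sum_k sum_j |mu[k, j] - xbar_j| / s_j for a (K, P) matrix of means."""
    xbar = X.mean(axis=0)  # per-dimension sample mean
    s = X.std(axis=0)      # per-dimension standard deviation (ddof=0, an assumption)
    return lam * np.sum(np.abs(mu - xbar) / s)


X = np.array([[0.0, 0.0], [2.0, 4.0]])   # xbar = [1, 2], s = [1, 2]
mu = np.array([[1.0, 2.0], [3.0, 0.0]])  # first row sits exactly on xbar: no penalty
print(lasso_penalty(mu, X, 0.5))  # 1.5
```

    A mean that coincides with \(\bar{x}\) contributes nothing, which is why the L1 term tends to push uninformative dimensions of \(\mu_k\) onto the global mean.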

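    With the penalty, (1) becomes (up to an additive constant, writing the diagonal covariance as \(\Sigma_k=\operatorname{diag}(\sigma_{k1}^2,\dots,\sigma_{kP}^2)\) and using \(t\) for the iteration index so it does not clash with the dimension index \(j\))

    \[L(\theta,\theta^{(t)})=\sum_{i=1}^N\sum_{k=1}^K\gamma_{ik}\left[\log\pi_k-\frac{1}{2}\sum_{j=1}^P\left(\log\sigma_{kj}^2+\frac{(x_{ij}-\mu_{kj})^2}{\sigma_{kj}^2}\right)\right]-\lambda\sum_{k=1}^K\sum_{j=1}^P\frac{|\mu_{kj}-\bar{x}_j|}{s_j}\]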
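    Where the case split comes from, as a worked step: holding \(\gamma_{ik}\) and \(\sigma_{kj}\) fixed and setting the (sub)derivative of the penalized objective with respect to \(\mu_{kj}\) to zero gives the expression below, with \(\operatorname{sign}(0)\) taken as \(+1\) to match the \(\ge\) convention of the update. Because the sign depends on the unknown \(\mu_{kj}\) itself, the implementation evaluates it at the current estimate of \(\mu_{kj}\).

    \[\frac{\partial L}{\partial \mu_{kj}}=\sum_{i=1}^N\gamma_{ik}\frac{x_{ij}-\mu_{kj}}{\sigma_{kj}^2}-\frac{\lambda}{s_j}\operatorname{sign}(\mu_{kj}-\bar{x}_j)=0\;\Longrightarrow\;\mu_{kj}=\frac{1}{n_k}\left(\sum_{i=1}^N\gamma_{ik}x_{ij}-\frac{\lambda\sigma_{kj}^2}{s_j}\operatorname{sign}(\mu_{kj}-\bar{x}_j)\right),\quad n_k=\sum_{i=1}^N\gamma_{ik}\]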

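    Note that because the penalty contains an absolute value, the derivative with respect to \(\mu_{kj}\) has to be taken case by case. The original update (2),

    \[\mu_{kj}=\frac{1}{n_k}\sum_{i=1}^N\gamma_{ik}x_{ij},\qquad n_k=\sum_{i=1}^N\gamma_{ik}\]

    therefore becomes

    \[\mu_{kj}=\begin{cases}\frac{1}{n_k}\left(\sum_{i=1}^N\gamma_{ik}x_{ij}-\frac{\lambda\sigma_{kj}^2}{s_j}\right), & \mu_{kj}\ge\bar{x}_j\\ \frac{1}{n_k}\left(\sum_{i=1}^N\gamma_{ik}x_{ij}+\frac{\lambda\sigma_{kj}^2}{s_j}\right), & \mu_{kj}<\bar{x}_j\end{cases}\]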
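    The case rule above picks the branch from the current \(\mu_{kj}\). A common alternative for L1-penalized updates, swapped in here and not the author's method, is to solve the one-dimensional problem in closed form by soft-thresholding: move the weighted mean \(m=\frac{1}{n_k}\sum_i\gamma_{ik}x_{ij}\) toward \(\bar{x}_j\) by shift \(=\lambda\sigma_{kj}^2/(s_j n_k)\), and set \(\mu_{kj}=\bar{x}_j\) exactly if it would cross over. This sketch assumes \(\sigma_{kj}\) is held at its current value; the function name is hypothetical.

```python
import numpy as np


def soft_threshold_mean(weighted_mean, xbar, shift):
    """Minimiser of (n_k / (2 sigma^2)) * (mu - m)^2 + (lam / s) * |mu - xbar|.

    With shift = lam * sigma**2 / (s * n_k), the solution moves the weighted mean m
    toward xbar by `shift` and clamps it to xbar if it would cross over.
    Works element-wise on NumPy arrays as well as on scalars.
    """
    d = weighted_mean - xbar
    return xbar + np.sign(d) * np.maximum(np.abs(d) - shift, 0.0)


print(soft_threshold_mean(5.0, 1.0, 1.5))   # 3.5  (moved toward xbar by the shift)
print(soft_threshold_mean(1.5, 1.0, 1.5))   # 1.0  (within the threshold: clamped to xbar)
print(soft_threshold_mean(-3.0, 1.0, 1.5))  # -1.5
```

    When the weighted mean lies more than shift away from \(\bar{x}_j\), on the same side as the current \(\mu_{kj}\), both rules give the same value; they differ only near \(\bar{x}_j\), where soft-thresholding yields exact zeros in \(\mu_{kj}-\bar{x}_j\).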

    3. Implementation

    • Unlike the unpenalized GMM, GMM+LASSO computes the Gaussian density differently, because the covariance is diagonal.
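    import numpy as np
    from scipy.stats import norm

    # Log Gaussian density of each sample. With a diagonal covariance it is simply the sum
    # over dimensions of the 1-D Gaussian log-densities (one mu, sigma per dimension).
    def log_prob(self, X, mu, sigma):
        # X: (N, D); mu, sigma: (D,) for one component. norm.logpdf broadcasts over the rows.
        return norm.logpdf(X, loc=mu, scale=sigma).sum(axis=1)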
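    The E-step itself is not shown in the post. As a hypothetical sketch of how such per-component log-densities are typically combined (this is not the author's e_step; the signature and return values are assumptions), responsibilities can be computed in log space with scipy.special.logsumexp to avoid underflow.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm


def e_step(X, alpha, mu, sigma):
    """Hypothetical E-step for a diagonal-covariance GMM.

    X: (N, D); alpha: (K,); mu, sigma: (K, D).
    Returns (log_norm, gamma): log p(x_i) per sample and (N, K) responsibilities.
    """
    K = len(alpha)
    # log(alpha_k) + sum_j log N(x_ij | mu_kj, sigma_kj) for every sample and component
    log_joint = np.column_stack([
        np.log(alpha[k]) + norm.logpdf(X, loc=mu[k], scale=sigma[k]).sum(axis=1)
        for k in range(K)
    ])
    log_norm = logsumexp(log_joint, axis=1)       # log p(x_i), computed stably
    gamma = np.exp(log_joint - log_norm[:, None])  # each row sums to 1
    return log_norm, gamma


rng = np.random.default_rng(0)
X = rng.normal(size=(6, 2))
alpha = np.array([0.3, 0.7])
mu = np.array([[0.0, 0.0], [3.0, 3.0]])
sigma = np.ones((2, 2))
log_norm, gamma = e_step(X, alpha, mu, sigma)
total_loglike = log_norm.sum()  # unpenalized log-likelihood of the data
```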
    
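    • In the M-step, the update for \(\mu_k\) changes: we first compare \(\mu_{kj}\) with the sample mean of dimension \(j\) (self.means[j]) to decide the sign of the shift that comes from the absolute value.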
    import numpy as np

    def m_step(self, step):
        # Effective number of samples per component, n_k = sum_i gamma_ik, shape (K, 1)
        gammaNorm = np.sum(self.gamma, axis=0).reshape(self.K, 1)
        self.alpha = gammaNorm / np.sum(gammaNorm)
        for k in range(self.K):
            Nk = gammaNorm[k, 0]
            if Nk == 0:
                continue  # empty component: leave its parameters unchanged
            for j in range(self.D):
                if step >= self.beginPenaltyTime:
                    # Penalty shift lambda * sigma_kj^2 / (s_j * n_k). Its sign is decided by comparing
                    # the current mu_kj with the sample mean of dimension j, which opens up the
                    # absolute value in the lasso term.
                    shift = np.square(self.sigma[k, j]) * self.penalty / (self.std[j] * Nk)
                    if self.mu[k, j] < self.means[j]:
                        shift = -shift
                else:
                    shift = 0.0
                self.mu[k, j] = np.dot(self.gamma[:, k], self.X[:, j]) / Nk - shift
                # The variance update uses the freshly updated mu_kj
                self.sigma[k, j] = np.sqrt(np.sum(self.gamma[:, k] * np.square(self.X[:, j] - self.mu[k, j])) / Nk)
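    The double loop above can also be written with array operations. This is a sketch of the same branch rule as a standalone function; the name and signature are hypothetical, it assumes every component has nonzero responsibility (the loop version skips components with n_k = 0), and the caller is assumed to pass penalty=0 before beginPenaltyTime.

```python
import numpy as np


def m_step(X, gamma, mu_old, sigma_old, means, std, penalty):
    """Vectorised penalised M-step using the same sign rule as the loop version.

    X: (N, D); gamma: (N, K); mu_old, sigma_old: (K, D); means, std: (D,).
    """
    Nk = gamma.sum(axis=0)                       # (K,) effective counts n_k
    alpha = Nk / Nk.sum()                        # mixing weights
    weighted_mean = (gamma.T @ X) / Nk[:, None]  # (K, D) unpenalised mu update
    # Shift lambda * sigma^2 / (s_j * n_k); its sign follows the old mu vs. the sample mean
    shift = penalty * sigma_old ** 2 / (std[None, :] * Nk[:, None])
    sign = np.where(mu_old >= means[None, :], 1.0, -1.0)
    mu = weighted_mean - sign * shift
    # The variance update uses the new means, as in the loop version
    diff2 = (X[:, None, :] - mu[None, :, :]) ** 2  # (N, K, D)
    sigma = np.sqrt(np.einsum('nk,nkd->kd', gamma, diff2) / Nk[:, None])
    return alpha, mu, sigma


rng = np.random.default_rng(1)
X = rng.normal(size=(8, 2))
gamma = rng.dirichlet([1.0, 1.0], size=8)  # random responsibilities, rows sum to 1
means, std = X.mean(axis=0), X.std(axis=0)
alpha, mu, sigma = m_step(X, gamma, np.zeros((2, 2)), np.ones((2, 2)), means, std, 0.0)
```

    With penalty=0 the update reduces to the ordinary weighted means and variances of an unpenalized diagonal GMM.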
    
    • Finally, the log-likelihood has to be computed with the penalty subtracted.
    def GMM_EM(self):
        self.init_paras()
        for i in range(self.times):
            # M-step
            self.m_step(i)
            # E-step: per-sample log-normalizers and responsibilities
            logGammaNorm, self.gamma = self.e_step(self.X)
            # Unpenalized log-likelihood
            loglike = self.logLikelihood(logGammaNorm)
            # Lasso penalty, only once it has been switched on
            pen = 0.0
            if i >= self.beginPenaltyTime:
                for j in range(self.D):
                    pen += self.penalty * np.sum(np.abs(self.mu[:, j] - self.means[j])) / self.std[j]

            # print("step = %s, alpha = %s, loglike = %s" % (i, [round(p[0], 5) for p in self.alpha.tolist()], round(loglike - pen, 5)))
            # Optional convergence check on the penalized objective (disabled in the original):
            # if abs(self.loglike - (loglike - pen)) < self.tol:
            #     break

            # Store the penalized log-likelihood
            self.loglike = loglike - pen
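    To see all the pieces working together, here is an end-to-end sketch on synthetic data. It is not the author's class: the function name fit_penalized_gmm, the initialization (K random samples as means, global standard deviations as sigmas, equal weights), the small variance floor, and the convergence test on the penalized objective are all assumptions.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm


def fit_penalized_gmm(X, K, penalty, n_iter=100, begin_penalty=0, tol=1e-6, seed=0):
    """Minimal penalised EM for a diagonal GMM (a sketch, not the original class)."""
    rng = np.random.default_rng(seed)
    N, D = X.shape
    means, std = X.mean(axis=0), X.std(axis=0)
    # Assumed initialisation: K random samples as means, global std as sigmas, equal weights
    mu = X[rng.choice(N, size=K, replace=False)].copy()
    sigma = np.tile(std, (K, 1))
    alpha = np.full(K, 1.0 / K)

    def e_step(alpha, mu, sigma):
        log_joint = np.column_stack([
            np.log(alpha[k]) + norm.logpdf(X, loc=mu[k], scale=sigma[k]).sum(axis=1)
            for k in range(K)])
        log_norm = logsumexp(log_joint, axis=1)
        return log_norm, np.exp(log_joint - log_norm[:, None])

    _, gamma = e_step(alpha, mu, sigma)
    history = []
    for t in range(n_iter):
        lam = penalty if t >= begin_penalty else 0.0
        # M-step with the sign-based shift (sign taken from the old mu, shift from the old sigma)
        Nk = gamma.sum(axis=0)
        alpha = Nk / N
        shift = lam * sigma ** 2 / (std[None, :] * Nk[:, None])
        sign = np.where(mu >= means[None, :], 1.0, -1.0)
        mu = (gamma.T @ X) / Nk[:, None] - sign * shift
        diff2 = (X[:, None, :] - mu[None, :, :]) ** 2
        sigma = np.sqrt(np.einsum('nk,nkd->kd', gamma, diff2) / Nk[:, None])
        sigma = np.maximum(sigma, 1e-6)  # guard against collapsing variances (assumption)
        # E-step, then the penalised log-likelihood
        log_norm, gamma = e_step(alpha, mu, sigma)
        pen = lam * np.sum(np.abs(mu - means) / std)
        history.append(log_norm.sum() - pen)
        # Only test convergence once two consecutive values are both penalised
        if t > begin_penalty and abs(history[-1] - history[-2]) < tol:
            break
    return alpha, mu, sigma, history


rng = np.random.default_rng(42)
X = np.vstack([rng.normal(-3, 1, size=(100, 2)), rng.normal(3, 1, size=(100, 2))])
alpha, mu, sigma, hist = fit_penalized_gmm(X, K=2, penalty=1.0, begin_penalty=5)
```

    With penalty=0 this is plain EM, so the recorded log-likelihood should never decrease; with a positive penalty the sign-based update is a heuristic and that guarantee no longer strictly applies.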
    

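    4. Results of the GMM implementation

    Using my GMM+LASSO implementation, I fitted each data set with several penalty values (each k/penalty setting appears twice in the log below), selected the k and penalty with the largest log-likelihood, and compared the result with sklearn.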
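    The log does not say how the sklearn log-likelihood was obtained. For a like-for-like comparison with the total log-likelihood reported by the custom implementation, one option, sketched here on synthetic data (the data set is an assumption), is to fit sklearn's GaussianMixture with covariance_type="diag" to match the diagonal assumption and multiply score(X), which is the mean per-sample log-likelihood, by the number of samples.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
X = np.vstack([rng.normal(-2, 1, size=(150, 3)), rng.normal(2, 1, size=(150, 3))])

# covariance_type="diag" matches the diagonal-covariance assumption of the penalized model
gm = GaussianMixture(n_components=2, covariance_type="diag", random_state=0).fit(X)

total_loglike = gm.score(X) * X.shape[0]  # score() returns the mean log-likelihood per sample
print(gm.weights_, total_loglike)
```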

    fileName = amix1-est.dat, k = 2, penalty = 0 alpha = [0.52838, 0.47162], loglike = -693.34677
    fileName = amix1-est.dat, k = 2, penalty = 0 alpha = [0.52838, 0.47162], loglike = -693.34677
    fileName = amix1-est.dat, k = 2, penalty = 1 alpha = [0.52789, 0.47211], loglike = -695.26835
    fileName = amix1-est.dat, k = 2, penalty = 1 alpha = [0.52789, 0.47211], loglike = -695.26835
    fileName = amix1-est.dat, k = 2, penalty = 2 alpha = [0.52736, 0.47264], loglike = -697.17009
    fileName = amix1-est.dat, k = 2, penalty = 2 alpha = [0.52736, 0.47264], loglike = -697.17009
    myself GMM alpha = [0.52838, 0.47162], loglikelihood = -693.34677, bestP = 0
    sklearn GMM alpha = [0.53372, 0.46628], loglikelihood = -176.73112
    succ = 299/300
    succ = 0.9966666666666667
    [0 1 0 0 1 1 0 1 1 1 0 0 1 0 0 1 0 0 0 1]
    [0 1 0 0 1 0 0 1 1 1 0 0 1 0 0 1 0 0 0 1]
    fileName = amix1-tst.dat, loglike = -2389.1852339407087
    fileName = amix1-val.dat, loglike = -358.1157431278091
    fileName = amix2-est.dat, k = 2, penalty = 0 alpha = [0.56, 0.44], loglike = 53804.54265
    fileName = amix2-est.dat, k = 2, penalty = 0 alpha = [0.82, 0.18], loglike = 24902.5522
    fileName = amix2-est.dat, k = 2, penalty = 1 alpha = [0.82, 0.18], loglike = 23902.65183
    fileName = amix2-est.dat, k = 2, penalty = 1 alpha = [0.56, 0.44], loglike = 52929.96459
    fileName = amix2-est.dat, k = 2, penalty = 2 alpha = [0.82, 0.18], loglike = 22907.40397
    fileName = amix2-est.dat, k = 2, penalty = 2 alpha = [0.82, 0.18], loglike = 22907.40397
    myself GMM alpha = [0.56, 0.44], loglikelihood = 53804.54265, bestP = 0
    sklearn GMM alpha = [0.56217, 0.43783], loglikelihood = 11738677.90164
    succ = 200/200
    succ = 1.0
    [0 1 0 0 1 0 1 1 0 0 0 1 1 0 1 0 1 1 0 1]
    [0 1 0 0 1 0 1 1 0 0 0 1 1 0 1 0 1 1 0 1]
    fileName = amix2-tst.dat, loglike = 51502.878096147084
    fileName = amix2-val.dat, loglike = 6071.217012747491
    fileName = golub-est.dat, k = 2, penalty = 0 alpha = [0.575, 0.425], loglike = -24790.19895
    fileName = golub-est.dat, k = 2, penalty = 0 alpha = [0.525, 0.475], loglike = -24440.82743
    fileName = golub-est.dat, k = 2, penalty = 1 alpha = [0.55, 0.45], loglike = -25582.27485
    fileName = golub-est.dat, k = 2, penalty = 1 alpha = [0.6, 0.4], loglike = -26137.97508
    fileName = golub-est.dat, k = 2, penalty = 2 alpha = [0.55, 0.45], loglike = -26686.02411
    fileName = golub-est.dat, k = 2, penalty = 2 alpha = [0.55, 0.45], loglike = -26941.68964
    myself GMM alpha = [0.525, 0.475], loglikelihood = -24440.82743, bestP = 0
    sklearn GMM alpha = [0.5119, 0.4881], loglikelihood = 13627728.10766
    succ = 29/40
    succ = 0.725
    [0 1 0 1 0 1 0 1 0 1 0 0 0 0 0 1 1 1 0 1]
    [0 1 0 1 1 1 0 0 1 1 0 1 0 1 0 0 1 1 0 0]
    fileName = golub-tst.dat, loglike = -12949.606698037718
    fileName = golub-val.dat, loglike = -11131.35137056415
    

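    5. Summary

    With these modifications we have a working GMM+LASSO implementation. If you have ideas for improving it, or spot any mistakes, I would be glad to hear from you.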

  • Original article: https://www.cnblogs.com/huangyc/p/10279240.html