zoukankan      html  css  js  c++  java
  • 推荐算法——非负矩阵分解(NMF)

    一、矩阵分解回想

    在博文推荐算法——基于矩阵分解的推荐算法中,提到了将用户-商品矩阵进行分解。从而实现对未打分项进行打分。

    矩阵分解是指将一个矩阵分解成两个或者多个矩阵的乘积。对于上述的用户-商品矩阵(评分矩阵),记为Vm×n。能够将其分解成两个或者多个矩阵的乘积,如果分解成两个矩阵Wm×kHk×n。我们要使得矩阵Wm×kHk×n的乘积能够还原原始的矩阵Vm×n

    Vm×nWm×k×Hk×n=V^m×n

    当中,矩阵Wm×k表示的是m个用户与k个主题之间的关系,而矩阵Hk×n表示的是k个主题与n个商品之间的关系。

    通常在用户对商品进行打分的过程中。打分是非负的,这就要求:

    Wm×k0

    Hk×n0

    这便是非负矩阵分解(Non-negtive Matrix Factorization, NMF)的来源。

    二、非负矩阵分解

    2.1、非负矩阵分解的形式化定义

    上面简介了非负矩阵分解的基本含义。简单来讲,非负矩阵分解是在矩阵分解的基础上对分解完毕的矩阵加上非负的限制条件。即对于用户-商品矩阵Vm×n,找到两个矩阵Wm×kHk×n,使得:

    Vm×nWm×k×Hk×n=V^m×n

    同一时候要求:

    Wm×k0

    Hk×n0

    2.2、损失函数

    为了能够定量的比較矩阵Vm×n和矩阵V^m×n的近似程度。在參考文献1中作者提出了两种损失函数的定义方式:

    • 平方距离

    AB2=i,j(Ai,jBi,j)2

    • KL散度

    D(AB)=i,j(Ai,jlogAi,jBi,jAi,j+Bi,j)

    在KL散度的定义中,D(AB)0。当且仅当A=B时取得等号。

    当定义好损失函数后,须要求解的问题就变成了例如以下的形式,相应于不同的损失函数:

    求解例如以下的最小化问题:

    • minimizeVWH2s.t.W0,H0

    • minimizeD(VWH)s.t.W0,H0

    2.3、优化问题的求解

    在參考文献1中,作者提出了乘法更新规则(multiplicative update rules),详细的操作例如以下:

    对于平方距离的损失函数:

    Wi,k=Wi,k(VHT)i,k(WHHT)i,k

    Hk,j=Hk,j(WTV)k,j(WTWH)k,j

    对于KL散度的损失函数:

    Wi,k=Wi,kuHk,uVi,u/(WH)i,uvHk,v

    Hk,j=Hk,juWu,kVu,j/(WH)u,j)vWv,k

    上述的乘法规则主要是为了在计算的过程中保证非负,而基于梯度下降的方法中,加减运算无法保证非负。事实上上述的乘法更新规则与基于梯度下降的算法是等价的。以下以平方距离为损失函数说明上述过程的等价性:

    平方损失函数能够写成:

    l=i=1mj=1n[Vi,j(k=1rWi,kHk,j)]2

    使用损失函数对Hk,j求偏导数:

    lHk,j=i=1mj=1n[2(Vi,j(k=1rWi,kHk,j))(Wi,k)]=2[(WTV)k,j(WTWH)k,j]

    则依照梯度下降法的思路:

    Hk,j=Hk,jηk,jlHk,j

    即为:

    Hk,j=Hk,j+ηk,j[(WTV)k,j(WTWH)k,j]

    ηk,j=Hk,j(WTWH)k,j,即能够得到上述的乘法更新规则的形式。

    2.4、非负矩阵分解的实现

    对于例如以下的矩阵:

    这里写图片描写叙述

    通过非负矩阵分解。得到例如以下的两个矩阵:

    这里写图片描写叙述

    这里写图片描写叙述

    对原始矩阵的还原为:
    这里写图片描写叙述

    实现的代码

    #!/bin/python
    
    from numpy import * 
    
    def load_data(file_path):
        f = open(file_path)
        V = []
        for line in f.readlines():
            lines = line.strip().split("	")
            data = []
            for x in lines:
                data.append(float(x))
            V.append(data)
        return mat(V)
    
    def train(V, r, k, e):
        m, n = shape(V)
        W = mat(random.random((m, r)))
        H = mat(random.random((r, n)))
    
        for x in xrange(k):
            #error 
            V_pre = W * H
            E = V - V_pre
            #print E
            err = 0.0
            for i in xrange(m):
                for j in xrange(n):
                    err += E[i,j] * E[i,j]
            print err
    
            if err < e:
                break
    
            a = W.T * V
            b = W.T * W * H
            #c = V * H.T
            #d = W * H * H.T
            for i_1 in xrange(r):
                for j_1 in xrange(n):
                    if b[i_1,j_1] != 0:
                        H[i_1,j_1] = H[i_1,j_1] * a[i_1,j_1] / b[i_1,j_1]
    
            c = V * H.T
            d = W * H * H.T
            for i_2 in xrange(m):
                for j_2 in xrange(r):
                    if d[i_2, j_2] != 0:
                        W[i_2,j_2] = W[i_2,j_2] * c[i_2,j_2] / d[i_2, j_2]
    
        return W,H 
    
    
    if __name__ == "__main__":
        #file_path = "./data_nmf"
        file_path = "./data1"
    
        V = load_data(file_path)
        W, H = train(V, 2, 100, 1e-5 )
    
        print V
        print W
        print H
        print W * H
    

    收敛曲线例如以下图所看到的:

    这里写图片描写叙述

    '''
    Date:20160411
    @author: zhaozhiyong
    '''
    
    from pylab import *
    from numpy import *
    
    data = []
    
    f = open("result_nmf")
    for line in f.readlines():
        lines = line.strip()
        data.append(lines)
    
    n = len(data)
    x = range(n)
    plot(x, data, color='r',linewidth=3)
    plt.title('Convergence curve')
    plt.xlabel('generation')
    plt.ylabel('loss')
    show()

    參考文献

  • 相关阅读:
    Python NLP入门教程
    一个月入门Python爬虫,轻松爬取大规模数据
    Python爬虫实战案例:爬取爱奇艺VIP视频
    探索Python F-strings是如何工作
    Ruby 和 Python 分析器是如何工作的?
    超级干货,python常用函数大总结
    Python 开发者的 6 个必备库,你都了解吗?
    神经网络中 BP 算法的原理与 Python 实现源码解析
    新手程序员必学的代码编程技巧
    零基础小白怎么用Python做表格?
  • 原文地址:https://www.cnblogs.com/gavanwanggw/p/7337227.html
Copyright © 2011-2022 走看看