zoukankan      html  css  js  c++  java
  • 使用LFM(Latent factor model)隐语义模型进行Top-N推荐

    最近在拜读项亮博士的《推荐系统实践》,系统的学习一下推荐系统的相关知识。今天学习了其中的隐语义模型在Top-N推荐中的应用,在此做一个总结。

    隐语义模型LFM和LSI,LDA,Topic Model其实都属于隐含语义分析技术,是一类概念,他们在本质上是相通的,都是找出潜在的主题或分类。这些技术一开始都是在文本挖掘领域中提出来的,近些年它们也被不断应用到其他领域中,并得到了不错的应用效果。比如,在推荐系统中它能够基于用户的行为对item进行自动聚类,也就是把item划分到不同类别/主题,这些主题/类别可以理解为用户的兴趣。

    对于一个用户来说,他们可能有不同的兴趣。就以作者举的豆瓣书单的例子来说,用户A会关注数学,历史,计算机方面的书,用户B喜欢机器学习,编程语言,离散数学方面的书, 用户C喜欢大师Knuth, Jiawei Han等人的著作。那我们在推荐的时候,肯定是向用户推荐他感兴趣的类别下的图书。那么前提是我们要对所有item(图书)进行分类。那如何分呢?大家注意到没有,分类标准这个东西是因人而异的,每个用户的想法都不一样。拿B用户来说,他喜欢的三个类别其实都可以算作是计算机方面的书籍,也就是说B的分类粒度要比A小;拿离散数学来讲,他既可以算作数学,也可当做计算机方面的类别,也就是说有些item不能简单的将其划归到确定的单一类别;拿C用户来说,他倾向的是书的作者,只看某几个特定作者的书,那么跟A,B相比它的分类角度就完全不同了。

    显然我们不能靠由单个人(编辑)或team的主观想法建立起来的分类标准对整个平台用户喜好进行标准化。

    此外我们还需要注意的两个问题:

    1. 我们在可见的用户书单中归结出3个类别,不等于该用户就只喜欢这3类,对其他类别的书就一点兴趣也没有。也就是说,我们需要了解用户对于所有类别的兴趣度。
    2. 对于一个给定的类来说,我们需要确定这个类中每本书属于该类别的权重。权重有助于我们确定该推荐哪些书给用户。

    下面我们就来看看LFM是如何解决上面的问题的?对于一个给定的用户行为数据集(数据集包含的是所有的user, 所有的item,以及每个user有过行为的item列表),使用LFM对其建模后,我们可以得到如下图所示的模型:(假设数据集中有3个user, 4个item, LFM建模的分类数为4)

     

    R矩阵是user-item矩阵,矩阵值Rij表示的是user i 对item j的兴趣度,这正是我们要求的值。对于一个user来说,当计算出他对所有item的兴趣度后,就可以进行排序并作出推荐。LFM算法从数据集中抽取出若干主题,作为user和item之间连接的桥梁,将R矩阵表示为P矩阵和Q矩阵相乘。其中P矩阵是user-class矩阵,矩阵值Pij表示的是user i对class j的兴趣度;Q矩阵式class-item矩阵,矩阵值Qij表示的是item j在class i中的权重,权重越高越能作为该类的代表。所以LFM根据如下公式来计算用户U对物品I的兴趣度

    我们发现使用LFM后, 

    1. 我们不需要关心分类的角度,结果都是基于用户行为统计自动聚类的,全凭数据自己说了算。
    2. 不需要关心分类粒度的问题,通过设置LFM的最终分类数就可控制粒度,分类数越大,粒度约细。
    3. 对于一个item,并不是明确的划分到某一类,而是计算其属于每一类的概率,是一种标准的软分类。
    4. 对于一个user,我们可以得到他对于每一类的兴趣度,而不是只关心可见列表中的那几个类。
    5. 对于每一个class,我们可以得到类中每个item的权重,越能代表这个类的item,权重越高。

    那么,接下去的问题就是如何计算矩阵P和矩阵Q中参数值。一般做法就是最优化损失函数来求参数。在定义损失函数之前,我们需要准备一下数据集并对兴趣度的取值做一说明。


    数据集应该包含所有的user和他们有过行为的(也就是喜欢)的item。所有的这些item构成了一个item全集。对于每个user来说,我们把他有过行为的item称为正样本,规定兴趣度RUI=1,此外我们还需要从item全集中随机抽样,选取与正样本数量相当的样本作为负样本,规定兴趣度为RUI=0。因此,兴趣的取值范围为[0,1]。


    采样之后原有的数据集得到扩充,得到一个新的user-item集K={(U,I)},其中如果(U,I)是正样本,则RUI=1,否则RUI=0。损失函数如下所示:

    上式中的是用来防止过拟合的正则化项,λ需要根据具体应用场景反复实验得到。损失函数的优化使用随机梯度下降算法:

    1. 通过求参数PUK和QKI的偏导确定最快的下降方向;

    1. 迭代计算不断优化参数(迭代次数事先人为设置),直到参数收敛。



    其中,α是学习速率,α越大,迭代下降的越快。α和λ一样,也需要根据实际的应用场景反复实验得到。本书中,作者在MovieLens数据集上进行实验,他取分类数F=100,α=0.02,λ=0.01。
                   【注意】:书中在上面四个式子中都缺少了


    综上所述,执行LFM需要:

    1. 根据数据集初始化P和Q矩阵(这是我暂时没有弄懂的地方,这个初始化过程到底是怎么样进行的,还恳请各位童鞋予以赐教。)
    2. 确定4个参数:分类数F,迭代次数N,学习速率α,正则化参数λ。

    LFM的伪代码可以表示如下:

    [python] view plaincopy
     
    1. def LFM(user_items, F, N, alpha, lambda):  
    2.     #初始化P,Q矩阵  
    3.     [P, Q] = InitModel(user_items, F)  
    4.     #开始迭代  
    5.     For step in range(0, N):  
    6.         #从数据集中依次取出user以及该user喜欢的iterms集  
    7.         for user, items in user_item.iterms():  
    8.             #随机抽样,为user抽取与items数量相当的负样本,并将正负样本合并,用于优化计算  
    9.             samples = RandSelectNegativeSamples(items)  
    10.             #依次获取item和user对该item的兴趣度  
    11.             for item, rui in samples.items():  
    12.                 #根据当前参数计算误差  PS:转载的博客中rui写成了eui
    13.                 eui = rui - Predict(user, item)  
    14.                 #优化参数  
    15.                 for f in range(0, F):  
    16.                     P[user][f] += alpha * (eui * Q[f][item] - lambda * P[user][f])  
    17.                     Q[f][item] += alpha * (eui * P[user][f] - lambda * Q[f][item])  
    18.         #每次迭代完后,都要降低学习速率。一开始的时候由于离最优值相差甚远,因此快速下降;  
    19.         #当优化到一定程度后,就需要放慢学习速率,慢慢的接近最优值。  
    20.         alpha *= 0.9  


    本人对书中的伪代码追加了注释,有不对的地方还请指正。


    当估算出P和Q矩阵后,我们就可以使用(*)式计算用户U对各个item的兴趣度值,并将兴趣度值最高的N个iterm(即TOP N)推荐给用户。

    总结来说,LFM具有成熟的理论基础,它是一个纯种的学习算法,通过最优化理论来优化指定的参数,建立最优的模型.

    ========================我是分割线我自豪=============================

    我不懂Python, 所以按照书里的步骤用java实现了, 期间走了好多弯路

    在这里说一下, 初始值的选择, 在迭代的时候alpha的初始值切记不要选太大了, 我之前一直用0.1, 然后每次都没收敛, 晕死我了, 还一直改代码+调试

    以为是其它原因, 浪费了大半天时间

    最后不小心把alpha改为0.5后发现就正常了

    改时间把代码扔上来

    准备下班了........................

    -------------------------------------------------我是不需要理由的分割线----------------------------------------------------

    时隔好久,为了准备面试,重要把之前的代码复习了一遍,顺便整理好放到Github上

    现在就扔到博客上来吧

    LFM的核心代码模块

      1 package org.juefan.alg;
      2 
      3 import java.text.SimpleDateFormat;
      4 import java.util.ArrayList;
      5 import java.util.Collections;
      6 import java.util.Comparator;
      7 import java.util.Date;
      8 import java.util.HashMap;
      9 import java.util.HashSet;
     10 import java.util.List;
     11 import java.util.Map;
     12 import java.util.Set;
     13 
     14 public class LFM {
     15     
     16     public static final int latent = 100;
     17     public static double alpha = 0.03;
     18     public static double lambda = 0.01;
     19     public static final int iteration = 1;
     20     public static final   int resys = 10;
     21     
     22     public static Map<Integer, List<Float>> UserMap = new HashMap<Integer, List<Float>>();
     23     public static Map<Integer, List<Float>> ItemMap = new HashMap<Integer, List<Float>>();
     24         
     25     public static compares compare = new compares();
     26     
     27     public  class State{
     28         public int TemID;
     29         public Set<Integer> set = new HashSet<Integer>();
     30         public float sim;
     31 
     32         /**用户集排序*/
     33         public State(Set<Integer> s, float s2){
     34             set.addAll(s);
     35             sim = s2;
     36         }
     37 
     38         /**Item排序*/
     39         public State(Integer i, float s){
     40             TemID = i;
     41             sim = s;
     42         }
     43     }
     44 
     45     public static class compares implements Comparator<Object>{
     46         @Override
     47         public int compare(Object o1, Object o2) {            
     48             State s1 = (State)o1;
     49             State s2 = (State)o2;
     50             return s1.sim < s2.sim ? 1:-1;
     51         }        
     52     }
     53     
     54     public  String toString(){
     55         return "LFM";
     56     }
     57     /**
     58      * 加载用户与项目的集合并初始化隐含矩阵
     59      * 注意隐含层的数值不能太大,建议在0.05左右
     60      * @param user
     61      * @param item
     62      */
     63     public LFM(Set<Integer> user, Set<Integer> item){
     64         for(Integer u:user){
     65             List<Float> tList = new ArrayList<Float>();
     66             for(int i = 0; i < latent; i++)
     67                 tList.add((float) ((float) Math.random() * 0.1));
     68             UserMap.put(u, tList);
     69         }
     70         for(Integer u:item){
     71             List<Float> tList = new ArrayList<Float>();
     72             for(int i = 0; i < latent; i++)
     73                 tList.add((float) ((float) Math.random() *0.1));
     74             ItemMap.put(u, tList);
     75         }
     76     }
     77     public LFM() {
     78         // TODO Auto-generated constructor stub
     79     }
     80     /**
     81      *  计算用户对某个物品的兴趣
     82      * @param uLV    用户与隐含类的关系
     83      * @param iLV    隐含类与物品的关系
     84      * @return 返回用户对某个物品的兴趣
     85      */
     86     public static float getPreference(List<Float> uLV, List<Float> iLV){
     87         float p = 0;
     88         for(int i = 0; i < latent; i++){
     89             p = p + uLV.get(i) * iLV.get(i);
     90         }
     91         return p;
     92     }
     93     
     94     /**预测评分差*/
     95     public static float Predict(float i1, float i2){
     96         return i1 - i2;
     97     }
     98     
     99     /**
    100      * 迭代求解隐含层
    101      * @param UserItem    
    102      */
    103     public static void LatentFactorModel(Map<Integer, Map<Integer, Float>> UserItem){
    104         SimpleDateFormat df = new SimpleDateFormat("yyyy-MM-dd_HH-mm");//设置日期格式
    105         for(int i = 0; i < iteration; i++){    
    106             System.out.println( df.format(new Date()) + "	第 " + (i + 1) + " 次迭代");
    107             for(int user: UserItem.keySet()){
    108                 for(int item: UserItem.get(user).keySet()){
    109                     float error = Predict(UserItem.get(user).get(item), 
    110                             getPreference(UserMap.get(user), ItemMap.get(item)));
    111                     for(int i1 = 0; i1 < latent; i1++){
    112                         UserMap.get(user).set(i1, (float) (UserMap.get(user).get(i1) + alpha * 
    113                                 (ItemMap.get(item).get(i1) * error - lambda * UserMap.get(user).get(i1))));
    114                         if (Float.isNaN(UserMap.get(user).get(i1) )) {
    115                             System.err.println("矩阵初始化或者参数有问题导致矩阵出现数值溢出");
    116                         }
    117                         ItemMap.get(item).set(i1, (float) (ItemMap.get(item).get(i1) + alpha * 
    118                                 (UserMap.get(user).get(i1) * error - lambda * ItemMap.get(item).get(i1))));
    119                     }
    120                 }
    121             }
    122             alpha = (float) (alpha * 0.9);
    123         }    
    124     }
    125     
    126     /**
    127      * 获取用户的最终推荐列表
    128      * @param map 项目的得分值表
    129      * @return
    130      */
    131     public Set<Integer> getResysK(Map<Integer, Float> map){
    132         List<State> tList = new ArrayList<State>();
    133         Set<Integer> set = new HashSet<Integer>();
    134         for(Integer key: map.keySet()){
    135             tList.add(new State(key,  map.get(key)));        
    136         }
    137         Collections.sort(tList, compare);
    138         for(int i = 0; i < tList.size() && i < resys; i++){
    139             set.add(tList.get(i).TemID);    
    140         }
    141         return set;
    142     }
    143     
    144     /**
    145      *  计算用户的推荐列表
    146      * @param user    用户的ID
    147      * @param item    用户的训练集
    148      * @return    用户的推荐列表
    149      */
    150     public Set<Integer> getResysList(int user, Map<Integer, Float> item){
    151         Map<Integer, Float> map = new HashMap<Integer, Float>();
    152         for(int i: ItemMap.keySet()){
    153             if(!item.containsKey(i))
    154                 map.put(i, getPreference(UserMap.get(user), ItemMap.get(i)));
    155         }
    156         return getResysK(map);
    157     }
    158 }

    接下来是具体的训练操作,比较重点的是训练数据和测试数据的选择,还有负例的生成

      1 package org.juefan.alg.test;
      2 
      3 import java.text.SimpleDateFormat;
      4 import java.util.ArrayList;
      5 import java.util.Date;
      6 import java.util.HashMap;
      7 import java.util.HashSet;
      8 import java.util.List;
      9 import java.util.Map;
     10 import java.util.Set;
     11 
     12 import org.juefan.IO.FileIO;
     13 import org.juefan.alg.LFM;
     14 import org.juefan.data.RatingData;
     15 import org.juefan.eva.Evaluation;
     16 
     17 public class TestLFM {
     18 
     19     public static Set<Integer> user = new HashSet<Integer>();
     20     public static Set<Integer> item = new HashSet<Integer>();
     21     public static List<Integer> itemList = new ArrayList<Integer>();
     22     public static Map<Integer, Integer> map = new HashMap<Integer, Integer>();
     23     public static Map<Integer, Integer> randMap = new HashMap<Integer, Integer>();    //倾向选择热门且用户未评价的为负例
     24 
     25     /**用户项目训练数据*/
     26     public static Map<Integer, Map<Integer, Float>>  UserItemTrain = new HashMap<Integer, Map<Integer, Float>> ();
     27     /**用户项目测试数据*/
     28     public static Map<Integer, Map<Integer, Float>>  UserItemTest = new HashMap<Integer, Map<Integer, Float>> ();
     29     
     30     public static LFM lfm = new LFM();
     31         
     32     public static Map<Integer, Float> getFu(Map<Integer, Float> item){
     33         Map<Integer, Float> map = new HashMap<Integer, Float>();    
     34         while(map.size() < item.size()*4 && item.size() + map.size() < TestLFM.item.size() * 0.8){
     35             /**抑制热门方式*/
     36             /*int rand = (int) (Math.random() * randMap.size());
     37             if(!item.containsKey(randMap.get(rand))){
     38                 map.put(randMap.get(rand), (float) 0);
     39             }*/
     40             /**同等对待方式*/
     41             int rand = (int) (Math.random() *  TestLFM.itemList.size());
     42             if(!item.containsKey( TestLFM.itemList.get(rand))){
     43                 map.put( TestLFM.itemList.get(rand), (float) 0);
     44             }
     45         }
     46         return map;
     47     }
     48 
     49     /**将Map的key加载进Set类*/
     50     public static Set<Integer> MapToSet(Map<Integer, Float> item){
     51         Set<Integer> tSet = new HashSet<Integer>();
     52         for(int k: item.keySet())
     53             tSet.add(k);
     54         return tSet;
     55     }
     56 
     57     /**
     58      * 测试入口
     59      */
     60     public static void main(String[] args) {
     61         System.setProperty("java.util.Arrays.useLegacyMergeSort", "true");  
     62         FileIO fileIO = new FileIO();
     63         fileIO.SetfileName(System.getProperty("user.dir") + "\data\input\ml-1m\ratings.dat");
     64         fileIO.FileRead();
     65         List<String> list = fileIO.cloneList();
     66         int num = 0;
     67         for(String s:list){
     68             RatingData data = new RatingData(s);
     69             float rand = (float) Math.random();
     70             if(rand >= (float)1/8){    //将数据随机分成训练数据和测试数据
     71                 if(UserItemTrain.containsKey(data.userID)){
     72                     UserItemTrain.get(data.userID).put(data.movieID, (float) 1);
     73                 }else {
     74                     Map<Integer, Float> tMap = new HashMap<Integer, Float>();
     75                     tMap.put(data.movieID, (float) 1);
     76                     UserItemTrain.put(data.userID, tMap);
     77                 }
     78                 //计算每个项目的热度
     79                 if(map.containsKey(data.movieID)){
     80                     map.put(data.movieID, map.get(data.movieID) + 1);
     81                 }else {
     82                     map.put(data.movieID, 1);
     83                 }
     84                 //构造项目分布映射
     85                 randMap.put(num++, data.movieID);
     86                 //收集用户列表和项目列表
     87                 user.add(data.userID);
     88                 item.add(data.movieID);
     89             }else {
     90                 if(UserItemTest.containsKey(data.userID)){
     91                     UserItemTest.get(data.userID).put(data.movieID, (float) 1);
     92                 }else {
     93                     Map<Integer, Float> tMap = new HashMap<Integer, Float>();
     94                     tMap.put(data.movieID, (float) 1);
     95                     UserItemTest.put(data.userID, tMap);
     96                 }
     97             }        
     98         }
     99         
    100 
    101         System.out.println("正在构造罗盘赌");
    102         for(Integer item: TestLFM.item){
    103             itemList.add(item);
    104         }
    105         int Fu = 0;
    106         for(int user: UserItemTrain.keySet()){
    107             UserItemTrain.get(user).putAll(getFu(UserItemTrain.get(user)));
    108             if(++Fu % 1000 == 0)
    109                 System.out.println("已构造 " + Fu +" 个负样本用户数据");
    110         }
    111         System.out.println("负样本生成完毕");
    112 
    113         SimpleDateFormat df = new SimpleDateFormat("yyyy-MM-dd_HH-mm");//设置日期格式
    114         String dataString = "\data\output\Result\" + df.format(new Date()) + "_result.txt";
    115         LFM lfm = new LFM(user, item);
    116         for(int trac = 0; trac <= 20; trac++){
    117             LFM.LatentFactorModel(UserItemTrain);
    118             for(int user:UserItemTrain.keySet()){
    119                 if(UserItemTest.containsKey(user)){
    120                     Evaluation.setEvaluation(MapToSet(UserItemTest.get(user)), lfm.getResysList(user, UserItemTrain.get(user)));
    121                 }
    122             }
    123             System.out.println("准确率 = " + Evaluation.getPrecision() * 100 + "%		召回率 = " + Evaluation.getRecall() * 100 + "%		覆盖率 = " + Evaluation.getCoverage()/item.size() * 100 + "%");
    124             FileIO.FileWrite(System.getProperty("user.dir") + dataString, "===================使用算法 : " + lfm.toString()
    125                     + "=====================
    具体参数: "            
    126                     + "
    latent = " + LFM.latent
    127                     +"
    alpha = " + LFM.alpha
    128                     +"
    lambda = " + LFM.lambda
    129                     + "
    准确率 = " + Evaluation.getPrecision() * 100 + "%		召回率 = " + Evaluation.getRecall() * 100 + "%		覆盖率 = " + Evaluation.getCoverage()/item.size() * 100 + "%
    ", true);
    130         }
    131     }
    132 
    133 }

    好了,基本上就这样了

    如果要看完整的代码欢迎到本人的Github上查看,里面还有相应的数据,还有一个UserCF代码模块

    Github地址:https://github.com/JueFan/RecommendSystem

    参考文章: http://blog.csdn.net/harryhuang1990/article/details/9924377#reply

  • 相关阅读:
    Python列表(即数组)
    Python中的关键字和内置函数
    python的变量和数据类型
    将数据写入本地txt
    Notepad++配置Python开发环境
    java中方法复写的作用进一步理解
    this表示当前对象的例子
    数组冒泡算法
    java实现星号三角形
    求1到1000之间同时能被3、5、7整除的数
  • 原文地址:https://www.cnblogs.com/juefan/p/3459799.html
Copyright © 2011-2022 走看看