zoukankan      html  css  js  c++  java
  • 基于ALS最小二乘法的推荐系统 spark实现 sparkMllib

    ALS 是交替最小二乘 (alternating least squares)的简称。在机器学习的上下文中,ALS 特指使用交替最小二乘求解的一个协同推荐算法。它通过观察到的所有用户给产品的打分,来推断每个用户的喜好并向用户推荐适合的产品。

    用户打分矩阵(行表示商品,列表示用户,每行表示用户对多个商品的评分)

    其中,A(i,j)表示用户user i对物品item j的打分。但是,用户不会对所以物品打分,图中?表示用户没有打分的情况,所以这个矩阵A很多元素都是空的,我们称其为“缺失值(missing value)”。在推荐系统中,我们希望得到用户对所有物品的打分情况,如果用户没有对一个物品打分,那么就需要预测用户是否会对该物品打分,以及会打多少分。这就是所谓的“矩阵补全(填空)”。

    推荐也就是一种对用户为打分的商品做预测,得到的评分做排序取TOpN列表

    ALS 的核心就是下面这个假设:打分矩阵是近似低秩的。换句话说,一个 的打分矩阵 A 可以用两个小矩阵的乘积来近似:。这样我们就把整个系统的自由度从一下降到了

    把人们的喜好和电影的特征都投到这个低维空间,一个人的喜好映射到了一个低维向量,一个电影的特征变成了纬度相同的向量,那么这个人和这个电影的相似度就可以表述成这两个向量之间的内积。 我们把打分理解成相似度,那么打分矩阵A就可以由用户喜好矩阵和产品特征矩阵的乘积来近似了。

    ALS的目标函数

    这里 R 指观察到的 (用户,产品)集

    我们把一个协同推荐的问题通过低秩假设成功转变成了一个优化问题。下面要讨论的内容很显然:这个优化问题怎么解?其实答案已经在 ALS 的名字里给出——交替最小二乘。ALS 的目标函数不是凸的,而且变量互相耦合在一起,所以它并不算好解。但如果我们把用户特征矩阵U和产品特征矩阵V固定其一,这个问题立刻变成了一个凸的而且可拆分的问题。比如我们固定U,那么目标函数就可以写成。其中关于每个产品特征的部分是独立的,也就是说固定U求我们只需要最小化就好了,这个问题就是经典的最小二乘问题。所谓“交替”,就是指我们先随机生成然后固定它求解,再固定求解,这样交替进行下去。因为每步迭代都会降低重构误差,并且误差是有下界的,所以 ALS 一定会收敛。但由于问题是非凸的,ALS 并不保证会收敛到全局最优解。但在实际应用中,ALS 对初始点不是很敏感,是不是全局最优解造成的影响并不大。

    参考:http://www.csdn.net/article/2015-05-07/2824641

    ALS在spark中的实现

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.{SparkContext, SparkConf}
    
    import org.apache.spark.mllib.recommendation.ALS
    import org.apache.spark.mllib.recommendation.MatrixFactorizationModel
    import org.apache.spark.mllib.recommendation.Rating
    
    object ALSTest {
        def main(args: Array[String]) {
            val conf = new SparkConf().setAppName("test").setMaster("local")
            val sc = new SparkContext(conf)
            val sql  = new SQLContext(sc);
            val data: RDD[String] = sc.textFile("data.txt")
            val ratings = data.map(_.split(',') match { case Array(user, item, rate) =>
                Rating(user.toInt, item.toInt, rate.toDouble)
            })
    
            // Build the recommendation model using ALS
            val rank = 10
            val numIterations = 10
    
            val model = ALS.train(ratings, rank, numIterations, 0.01)
    
            // Evaluate the model on rating data
            val usersProducts: RDD[(Int, Int)] = ratings.map { case Rating(user, product, rate) =>
                (user, product)
            }
            //推荐id为3
            val uuid =3
            val filter = usersProducts.filter(_._1 == uuid).map(_._2).collect().toSeq//评价过的电影
            val mov: RDD[Int] = usersProducts.map(_._2).distinct().filter(!filter.contains(_))
            mov.foreach(println)
    
            val recommendations: RDD[((Int, Int), Double)] = model.predict(mov.map((uuid,_))).map{case Rating(user, product, rate) =>
                ((user,product),rate)
            }.map(kv => (kv._2,kv._1)).sortByKey(false).map(kv => (kv._2,kv._1))
    
            recommendations.take(10).foreach{
                value =>
                    println("电影id:"+value._1._2+"   评分"+value._2+"")
            }
    //        val test: RDD[Rating] = model.predict(mov.map((uuid,_)))
    //        test.foreach(println)
    
            val predictions =
                model.predict(usersProducts).map { case Rating(user, product, rate) =>
                    ((user, product), rate)}
            //((用户id,商品id),(评分,预测评分))
            val ratesAndPreds: RDD[((Int, Int), (Double, Double))] = ratings.map { case Rating(user, product, rate) =>
                ((user, product), rate)
            }.join(predictions)
    //        ratesAndPreds.foreach(println)
            val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) =>
                val err = (r1 - r2)
                err * err
            }.mean()
            println("Mean Squared Error = " + MSE)
    
    
        }
    
    }
    

      部分样本数据

    1,1,5.0
    1,2,1.0
    1,3,5.0
    1,4,1.0
    2,1,5.0
    2,2,1.0
    2,3,5.0
    2,4,1.0
    3,1,1.0
    3,4,5.0
    4,1,1.0
    4,2,5.0
    4,3,1.0
    4,4,5.0
  • 相关阅读:
    python的with语句
    flask如何实现https以及自定义证书的制作
    flask及扩展源码解读
    加密的那些事
    SQLALchemy如何查询mysql某个区间内的数据
    集群设备之间的资源共享
    pycryptodom的源码安装
    github创建项目,并提交本地文件
    响应头里的"Last-Modified"值是怎么来的?
    SQL2005 数据库——查看索引
  • 原文地址:https://www.cnblogs.com/xiaoma0529/p/6934090.html
Copyright © 2011-2022 走看看