zoukankan      html  css  js  c++  java
  • 推荐系统偏好SVD C++

    最近想研究一下推荐系统，比较经典的算法就是SVD了。具体理论不多讲，直接上代码。

    先贴张效果图吧（数据规模：userNum = 6040，itemNum = 3900）。

    本文链接:http://www.cnblogs.com/wn19910213/p/3617781.html

    上代码咯:

    SVD.h

     #ifndef SVD_H_INCLUDED
     #define SVD_H_INCLUDED

     #include <vector>
     #include <cstring>
     #include <string>   // std::string is used below; <cstring> only declares C string functions

     // SVD.cpp uses string/cout/ifstream unqualified and relies on this
     // directive, so it is kept even though it leaks into every includer.
     using namespace std;

     /// Biased matrix-factorisation recommender trained by gradient steps.
     /// Prediction for (user u, item i): avg + Bu[u] + Bi[i] + dot(Pu[u], Qi[i]),
     /// clamped to [1, 5].
     class SVD{
         public:
             /// Builds the model. A NULL array argument makes the constructor
             /// allocate and initialise that array itself; a non-NULL pointer is
             /// adopted as-is. NOTE(review): the destructor delete[]s all four
             /// arrays either way, so adopted arrays must come from new[].
             SVD(double*,double*,int,double**,double**);
             ~SVD();

             /// Runs one update pass over a training file, adjusting the
             /// biases and factor vectors record by record.
             void loadTrainFile(string);
             /// Predicts one rating. NOTE(review): the mean is taken as int, so
             /// a fractional global average is truncated; switching to double
             /// must be done here and in the definition in SVD.cpp together.
             double predictScore(int,double,double,double*,double*);
             /// Returns the root-mean-square error of the model on a test file.
             double Validate(string,double,double*,double*,double**,double**);
     //    private:
             // Members stay public because main() passes svd.Bu/Bi/Pu/Qi to Validate.
             double* Bi;     // item bias, indexed by itemId - 1
             double* Bu;     // user bias, indexed by userId - 1
             int factor;     // length of every factor vector
             double** Qi;    // item factor matrix, one row per item
             double** Pu;    // user factor matrix, one row per user

         private:
             // The object delete[]s its arrays on destruction, so a copy would
             // free them twice. Copying is therefore disabled: declared private
             // and intentionally left undefined (pre-C++11 idiom).
             SVD(const SVD&);
             SVD& operator=(const SVD&);
     };


     #endif // SVD_H_INCLUDED

    SVD.cpp

      1 #include <cmath>
      2 #include <iostream>
      3 #include <cstring>
      4 #include <cstdlib>
      5 #include <fstream>
      6 #include "SVD.h"
      7 
      8 
      9 int userNum = 6040;
     10 int itemNum = 3900;
     11 double AVG = 3.579231;
     12 double lr = 0.01;
     13 double theta = 0.05;
     14 double preRmse = 1000000.0;
     15 
     16 int main()
     17 {
     18     string trainFile = "/home/ja/CADATA/SVD/ml_data/training.txt";
     19     string testFile = "/home/ja/CADATA/SVD/ml_data/test.txt";
     20     srand(0);
     21     SVD svd(NULL,NULL,0,NULL,NULL);
     22 
     23     for(size_t i=0;i<100;i++){
     24         svd.loadTrainFile(trainFile);
     25         //lr *= 0.9;
     26         double curRmse = svd.Validate(testFile,AVG,svd.Bu,svd.Bi,svd.Pu,svd.Qi);
     27         cout << "test_Rmse in step " << i << ": " << curRmse << endl;
     28         if(curRmse >= preRmse){
     29             break;
     30         }
     31         else{
     32             preRmse = curRmse;
     33         }
     34     }
     35     return 0;
     36 }
     37 
     38 double SVD::Validate(string testfile,double avg,double* bu,double* bi,double** pu,double** qi){
     39     ifstream fin(testfile.c_str());
     40     if(!fin){
     41         cout << "error" << endl;
     42     }
     43     int userId,itemId,rating,t;
     44     int n = 0;
     45     double rmse;
     46     while(fin >> userId >> itemId >> rating >> t){
     47         n++;
     48         double pScore = predictScore(avg,bu[userId-1],bi[itemId-1],pu[userId-1],qi[itemId-1]);
     49         rmse += (rating - pScore) * (rating - pScore);
     50     }
     51     fin.close();
     52     return sqrt(rmse/n);
     53 }
     54 
     55 double SVD::predictScore(int avg,double bu,double bi,double* pu,double* qi){
     56     double tmp = 0.0;
     57     for(size_t i=0;i<factor;i++){
     58         tmp += pu[i] * qi[i];
     59     }
     60 
     61     double score = avg + bu + bi + tmp;
     62     if(score > 5){
     63         score = 5;
     64     }
     65     if(score < 1){
     66         score = 1;
     67     }
     68     return score;
     69 }
     70 
     71 void SVD::loadTrainFile(string file){
     72     ifstream fin(file.c_str());
     73     if(!fin){
     74         cout << "error" << endl;
     75     }
     76 
     77     int userId,itemId,rating,t;
     78     while(fin >> userId >> itemId >> rating >> t){
     79         double predict = predictScore(AVG,Bu[userId-1],Bi[itemId-1],Pu[userId-1],Qi[itemId-1]);
     80         double error = rating - predict;
     81         Bu[userId-1] += lr * (error - theta * Bu[userId-1]);
     82         Bi[itemId-1] += lr * (error - theta * Bi[itemId-1]);
     83 
     84         for(size_t i=0;i<factor;i++){
     85             double Tmp = Pu[userId-1][i];
     86             Pu[userId-1][i] += lr * (error * Qi[itemId-1][i] - theta * Pu[userId-1][i]);
     87             Qi[itemId-1][i] += lr * (error * Tmp - theta * Qi[itemId-1][i]);
     88         }
     89     }
     90     fin.close();
     91 }
     92 
     93 SVD::SVD(double* bi,double* bu,int k,double** qi,double** pu){
     94 
     95     if(bi == NULL){
     96         Bi = new double[itemNum];
     97         for(size_t i=0;i<itemNum;i++){
     98             Bi[i] = 0.0;
     99         }
    100     }
    101     else{
    102         Bi = bi;
    103     }
    104 
    105     if(bu == NULL){
    106         Bu = new double[userNum];
    107         for(size_t i=0;i<userNum;i++){
    108             Bu[i] = 0.0;
    109         }
    110     }
    111     else{
    112         Bu = bu;
    113     }
    114 
    115     factor = 10;
    116 
    117     if(qi == NULL){
    118         Qi = new double* [itemNum];
    119         for(size_t i=0;i<itemNum;i++){
    120             Qi[i] = new double[factor];
    121         }
    122 
    123         for(size_t i=0;i<itemNum;i++){
    124             for(size_t j=0;j<factor;j++){
    125                 Qi[i][j] = 0.1 * (rand() / (RAND_MAX + 1.0)) / sqrt(factor);
    126             }
    127         }
    128     }
    129     else{
    130         Qi = qi;
    131     }
    132 
    133     if(pu == NULL){
    134         Pu = new double* [userNum];
    135         for(size_t i=0;i<userNum;i++){
    136             Pu[i] = new double[factor];
    137         }
    138 
    139         for(size_t i=0;i<userNum;i++){
    140             for(size_t j=0;j<factor;j++){
    141                 Pu[i][j] = 0.1 * (rand() / (RAND_MAX + 1.0)) / sqrt(factor);
    142             }
    143         }
    144     }
    145     else{
    146         Pu = pu;
    147     }
    148 }
    149 
    150 SVD::~SVD(){
    151     delete[] Bi;
    152     delete[] Bu;
    153     for(size_t i=0;i<userNum;i++){
    154         delete[] Pu[i];
    155     }
    156     for(size_t i=0;i<itemNum;i++){
    157         delete[] Qi[i];
    158     }
    159     delete[] Pu;
    160     delete[] Qi;
    161 }
  • 相关阅读:
    使用MySQL存储过程连续插入多条记录
    为什么编程语言以及数据库要从1970年1月1日开始计算时
    关于shtml页面include问题解决方案
    简单实用的FTP操作类
    js实现完美身份证号有效性验证
    .htaccess文件的建立和rewrite_module的启用
    php冒泡排序
    php实现汉诺塔问题(递归)
    简单的mysql数据库备份程序
    选择排序的php实现
  • 原文地址:https://www.cnblogs.com/wn19910213/p/3617781.html
Copyright © 2011-2022 走看看