zoukankan      html  css  js  c++  java
  • RBM代码注释c++

    这份代码很多地方都看不懂,只能甘拜下风,暂且把它当作工具来用吧。

    主函数:

    main.cpp

      1 #include "rbmpredictdata.h"
      2 #include "rbmdata.h"
      3 #include "rbm.h"
      4 #include "rbmparallel.h"
      5 #include <fstream>
      6 #include <string>
      7 #include <cstdlib>
      8 #include <deque>
      9 
     10 int main(int argc, char* argv[]) {
     11     parseArgs(argc, argv);   //第1步、这一步取得三个值 Hidden层数、trainfilename、testfilename
     12     srand(seed);
     13     printConfig();                  //第2步、打印初始化的数据
     14 
     15     RbmPredictData predictData;     //构造出对象
     16     safeLoad(predictData, testFilename);    //第3步、预测数据化成矩阵形式
     17     cout << "Done loading test data" << flush;
     18 
     19     RbmData data;
     20     safeLoad(data, trainFilename);            //步骤和第3步完全一样,这是训练集
     21     cout << "\rDone loading data.     " << endl;
     22 
     23     Rbm* r = NULL; 
     24     if (parallel) 
     25         r = new RbmParallel(nThreads, data, nHidden);   
     26     else
     27         r = new Rbm(data, nHidden);          //第4步、构造rbm,初始化各种W,hb,vb。查看rbm.cpp
     28     r->momentum = initialMomentum;
     29     r->hBiasLearnRate = hLearn;
     30     r->vBiasLearnRate = vLearn;
     31     r->WlearnRate = wLearn;
     32     r->weightDecay = wCost;
     33 
     34     for (int i = 1; i <= nEpochs; i++) {
     35         int increment = extractIncrements(i);    //浮云
     36         if (increment) {
     37             r->T += increment;
     38             cout << "\tT = " << r->T << endl;
     39         }
     40         if (i == finalMomentumStart)
     41             r->momentum = finalMomentum;
     42 
     43         r->performEpoch(data);        //第5步、最关键的步骤,查看rbm.cpp
     44 
     45         if (!epochsToSaveAt.empty() && epochsToSaveAt.front() == i) {
     46             epochsToSaveAt.pop_front();
     47             stringstream t;
     48             t << savePrefix << i;
     49             double d = r->predict(data, predictData, t.str());
     50             cout << i << ": saving to " << t.str() << "(" << d << ")" << endl;
     51         } else if (predictAlways) {
     52             double d = r->predict(data, predictData);            //第6步、预测,查看rbm.cpp
     53             cout << i << ": " << d << endl;
     54         } else {
     55             cout << i << ": omitting prediction" << endl;
     56         }
     57     }
     58 
     59     return 0;
     60 }
     61 
     62 //第1步、取参数
     63 void parseArgs(int argc, char* argv[]) {
     64     int current = 1;
     65     // TODO: make more user friendly in terms of error handling.
     66     //       and coding style leaves something to be desired...
     67     while (current < argc) {
     68         if (strcmp("-h", argv[current]) == 0) {
     69             printf(helpString, argv[0]);
     70             exit(0);
     71         } else if (strcmp("-d", argv[current]) == 0) {
     72             cout << "Defaults: " << endl;
     73             printConfig();
     74             exit(0);
     75         } else if (strcmp("-v", argv[current]) == 0) {
     76             vLearn = atof(argv[current + 1]);
     77             current += 1;
     78         } else if (strcmp("-H", argv[current]) == 0) {
     79             hLearn = atof(argv[current + 1]);
     80             current += 1;
     81         } else if (strcmp("-w", argv[current]) == 0) {
     82             wLearn = atof(argv[current + 1]);
     83             current += 1;
     84         } else if (strcmp("-n", argv[current]) == 0) {
     85             nHidden = atoi(argv[current + 1]);
     86             current += 1;
     87         } else if (strcmp("-i", argv[current]) == 0) {
     88             initialMomentum = atof(argv[current + 1]);
     89             current += 1;
     90         } else if (strcmp("-m", argv[current]) == 0) {
     91             finalMomentum = atof(argv[current + 1]);
     92             finalMomentumStart = atoi(argv[current + 2]);
     93             current += 2;
     94         } else if (strcmp("-e", argv[current]) == 0) {
     95             nEpochs = atoi(argv[current + 1]);
     96             current += 1;
     97         } else if (strcmp("-c", argv[current]) == 0) {
     98             wCost = atof(argv[current + 1]);
     99             current += 1;
    100         } else if (strcmp("-t", argv[current]) == 0) {
    101             do {
    102                 current += 1;
    103                 tIncrements.push_back(atoi(argv[current]));
    104             } while (current + 1 < argc && 
    105                     '0' <= argv[current + 1][0] && argv[current + 1][0] <= '9');
    106         } else if (strcmp("-s", argv[current]) == 0) {
    107             current += 1;
    108             seed = atoi(argv[current]);
    109         } else if (strcmp("--save", argv[current]) == 0) {
    110             current += 1;
    111             savePrefix = argv[current];
    112             do {
    113                 current += 1;
    114                 epochsToSaveAt.push_back(atoi(argv[current]));
    115             } while (current + 1 < argc && 
    116                     '0' <= argv[current + 1][0] && argv[current + 1][0] <= '9');
    117         } else if (strcmp("--never", argv[current]) == 0) {
    118             predictAlways = false;
    119         } else if (argc - current == 2) {
    120             trainFilename = argv[current];
    121         } else if (argc - current == 1) {
    122             testFilename = argv[current];
    123         } else {
    124             cerr << "ERROR: Unknown option: " << argv[current] << endl
    125                  << "Exiting now." << endl;
    126             exit(1);
    127         }
    128         current += 1;
    129     }
    130 }
    131 
    132 //第2步、打印参数
    133 void printConfig() {
    134     cout << "Learning rates:" << endl
    135          << "  visible        " << vLearn << endl
    136          << "  hidden         " << hLearn << endl
    137          << "  weights        " << wLearn << endl
    138          << "  cost           " << wCost << endl;
    139     cout << "Hidden nodes:    " << nHidden << endl;
    140     if (initialMomentum != 0.0 && finalMomentum != 0.0) {
    141         cout << "Momentum:        " << endl
    142              << "  initial        " << initialMomentum << endl
    143              << "  final          " << finalMomentum << endl
    144              << "  start          " << finalMomentumStart << endl;
    145     }
    146     if (tIncrements.size() > 0) {
    147         cout << "T-increment on:  ";
    148         for (unsigned i = 0; i < tIncrements.size(); i++) {
    149             cout << tIncrements[i] << ' ';
    150         }
    151         cout << endl;
    152     }
    153     if (epochsToSaveAt.size() > 0) {
    154         cout << "Saving at epochs:";
    155         for (unsigned i = 0; i < epochsToSaveAt.size(); i++) {
    156             cout << epochsToSaveAt[i] << ' ';
    157         }
    158         cout << endl;
    159         cout << "Save prefix:     " << savePrefix << endl;
    160     }
    161     cout << "Datasets:      " << endl
    162          << "  train          " << trainFilename << endl
    163          << "  test           " << testFilename << endl;
    164     cout << "Epochs:          " << nEpochs << endl;
    165     cout << "Random seed:     " << seed << endl;
    166     cout << "Predict always:  " << (predictAlways? "yes" : "no") << endl;
    167     cout << endl;
    168 }
    169 
    170 //第3步、这会调用rbmpredictdata.cpp的safeLoad函数,步骤3.1。
    171 void safeLoad(RbmData& p, const string& fname) {
    172     ifstream f(fname.c_str());
    173     if (!f) {
    174         cerr << "ERROR: " << fname << " could not be opened." << endl;
    175         cerr << "Exiting now." << endl;
    176         exit(1);
    177     }
    178     p.loadTsv(f);
    179     f.close();
    180 }
    181 
    // printf format string printed by -h; the single %s receives argv[0].
    // main() saves predictions to savePrefix followed by the epoch number with
    // no extension, so the --save description and example say exactly that.
    char helpString[] =
        "Usage: %s [arguments] <trainfile> <testfile>\n"
        "\n"
        "Arguments:\n"
        "  -h                    Print Help (this message) and exit\n"
        "  -v <float>            Set visible bias learning rate to <float>\n"
        "  -H <float>            Set hidden bias learning rate to <float>\n"
        "  -w <float>            Set weight learning rate to <float>\n"
        "  -c <float>            Set weight-cost coefficient to <float>\n"
        "  -n <uint>             Use <uint> hidden nodes\n"
        "  -i <float>            Set initial momentum to <float>\n"
        "  -m <float> <uint>     Set final momentum to <float> at epoch <uint>\n"
        "  -e <uint>             Perform <uint> epochs\n"
        "  -t <uint> [<uint>...] Increase T by one at epoch <uint>, <uint> ...\n"
        "  -d                    Print the defaults and exit\n"
        "  -s <uint>             Set the random seed to <uint>\n"
        "  --save <string>       Save the predictions of the model after epoch\n"
        "         <uint> [...]        <uint> to file: <string><uint>\n"
        "  --never               Do not predict, unless specified to save.\n"
        "\n"
        "Example usage:\n"
        "  ./rbm -t 5 5 7 train.dat test.dat\n"
        "     Set T to 3 at epoch 5, and to 4 at epoch 7.\n"
        "  ./rbm --never --save mypredfile 10 20 train.dat test.dat\n"
        "     Generates two files: mypredfile10 and mypredfile20\n";
    207 
    // Global run configuration: filled in by parseArgs() and read by main() and
    // printConfig(). The initializers below are the defaults.

    // Positional file names and the --save prefix.
    string trainFilename;
    string testFilename;
    string savePrefix;

    // Learning rates (-v, -H, -w) and weight cost (-c).
    float vLearn = 0.005;
    float hLearn = 0.005;
    float wLearn = 0.005;
    float wCost = 0.005;

    // Model size (-n) and training length (-e).
    int nHidden = 100;
    int nEpochs = 40;

    // Momentum: main() uses initialMomentum (-i) until epoch finalMomentumStart,
    // then finalMomentum (both set by -m <float> <uint>).
    float finalMomentum = 0.0;
    float initialMomentum = 0.0;
    int finalMomentumStart = 5;

    // Epochs at which T is increased (-t) and predictions are saved (--save).
    deque<int> tIncrements;
    deque<int> epochsToSaveAt;

    int seed = 1;               // -s; passed to srand() in main()
    bool predictAlways = true;  // cleared by --never
    bool parallel = false;      // selects RbmParallel; no option in parseArgs sets it
    int nThreads = 32;          // thread count handed to RbmParallel
    219 
    220 int extractIncrements(int i) {
    221     int increment = 0;
    222     while (tIncrements.size() != 0 && tIncrements.front() == i) {
    223         increment += 1;
    224         tIncrements.pop_front();
    225     }
    226     return increment;
    227 } 

    核心部分来了:

    rbm.h

     1 #ifndef RBM_H
     2 #define RBM_H
     3 
     4 #include "rbmdata.h"
     5 #include "rbmpredictdata.h"
     6 #include <Eigen/Dense>
     7 #include <vector>
     8 #include <thread>
     9 
    10 using namespace Eigen;
    11 using namespace std;
    12 
    13 typedef Matrix<bool, 1, Dynamic> RowVectorXb;
    14 
    15 class Rbm {
    16 public:
    17     Rbm(const RbmData& data, int nHidden);
    18 
    19     virtual void performEpoch(const RbmData& data);
    20 
    21     virtual double predict(
    22             const RbmData& data, 
    23             const RbmPredictData& predictData, 
    24             const string& filename);
    25 
    26     virtual double predict(
    27             const RbmData& data,
    28             const RbmPredictData& predictData);
    29 
    30     virtual double predict(
    31             const RbmData& data,
    32             const RbmPredictData& predictData,
    33             ostream& predictStream);
    34 
    35     static float hBiasLearnRate;
    36     static float vBiasLearnRate;
    37     static float WlearnRate;
    38     static float weightDecay;
    39     static float momentum;
    40     static int T;
    41 private:
    42     Rbm();
    43 
    44     void negActivation(const RowVectorXf& h0states);
    45 
    46     void gibbsSample();
    47 
    48     void normalizedNegActivation(const RowVectorXf& h0states);
    49     void softmax(const RowVectorXf& h0states);
    50 
    51     void initVisibleBias(const RbmData& data);
    52 
    53     void applyMomentum(const RbmData& data, int user);
    54     void selectWeights(const RbmData& data, int rangeStart, int rangeEnd);
    55 
    56 public: // These are public so RbmParallel can use them
    57     MatrixXf W, Wsel, Wmomentum;
    58     RowVectorXf vBias, vBiasSel, vBiasMomentum, hBias, hBiasMomentum;
    59 
    60     void performEpoch(const RbmData& data, int userStart, int userEnd);
    61 
    62 private:
    63     RowVectorXf h0probs, hTProbs;
    64     RowVectorXb h0states;
    65 
    66     RowVectorXf negData;
    67 
    68     MatrixXf posProds, negProds;
    69 
    70     int nHidden;
    71     int nClasses;
    72 };
    73 
    74 #endif

    rbm.cpp

      #include "rbm.h"
      #include <algorithm>
      #include <cmath>
      #include <cstdlib>
      #include <fstream>
      #include <iomanip>
      #include <iostream>
      #include <numeric>
      #include <sstream>
      #include <thread>
      #include <utility>
      7 
      8 float Rbm::hBiasLearnRate = 0.001;
      9 float Rbm::vBiasLearnRate = 0.008;
     10 float Rbm::WlearnRate = 0.0006;
     11 float Rbm::weightDecay = 0.0001;
     12 float Rbm::momentum = 0.5;
     13 int Rbm::T = 1;
     14 
     15 
     16 //第4步、初始化w,hb,vb
     17 //W = N(movie) * K(评分类数) * M(隐层节点数)
     18 //vb = N(movie) * K(评分类数)
     19 //hb = M(隐层节点数)
     20 Rbm::Rbm(const RbmData& data, int nHidden) 
     21         : nHidden(nHidden), nClasses(data.nClasses)
     22 {
     23     W.setRandom(data.nMovies * nClasses, nHidden);
     24     W.array() *= 0.01;
     25     vBias.setZero(data.nMovies * nClasses);
     26     initVisibleBias(data);                    //可见层的bias初始化。
     27     hBias.setZero(nHidden);
     28 }
     29 
     30 
     31 //4.1所建立的movies和ratings根据用户id进行匹配
     32 
     33 void Rbm::initVisibleBias(const RbmData& data) {
     34     MatrixXi totals = MatrixXi::Zero(1, data.nMovies);
     35     int nUsers = data.range.size() - 1;
     36     for (int i = 0; i < nUsers; i++) {
     37         
     38 //segment根据data.range(i), data.range(i + 1)也就是每行的userid数建立矩阵        
     39         const auto& movies = data.movies.segment(
     40                 data.range(i), data.range(i + 1) - data.range(i));
     41         const auto& ratings = data.ratings.segment(
     42                 data.range(i), data.range(i + 1) - data.range(i));
     43                 
     44 //表示userid=0有多少个 1有多少。。。且和每行的movies、ratings相等                
     45         int amountOfRatings = data.range(i + 1) - data.range(i);     
     46         for (int r = 0; r < amountOfRatings; r++) {
     47             vBias(movies(r)) += ratings(r);
     48             int exactMovie = movies(r) / nClasses;
     49             totals(exactMovie) += ratings(r);
     50         }
     51     }
     52     for (int m = 0; m < totals.size(); m++) {
     53         for (int c = 0; c < nClasses; c++) {
     54             if (vBias(m * nClasses + c) != 0) {
     55                 vBias(m * nClasses + c) /= totals(m);
     56                 vBias(m * nClasses + c) = 
     57                     log(vBias(m * nClasses + c));
     58             }
     59         }
     60     }
     61 }
     62 
     63 
     64 //第5步、坑爹的又调用下面的
     65 void Rbm::performEpoch(const RbmData& data) {
     66     performEpoch(data, 0, data.range.size() - 1);
     67 }
     68 
     69 //第5.1步、开刀
     70 void Rbm::performEpoch(const RbmData& data, int userStart, int userEnd) {
     71     vector<int> randomizedIds(userEnd - userStart, 0);       //定义了End-Start个0元素
     72     for (int i = userStart; i < userEnd; i++) 
     73         randomizedIds[i - userStart] = i;                                            //根据传进来的值知道这个相当是userid个数
     74     random_shuffle(randomizedIds.begin(), randomizedIds.end());    //打乱顺序
     75 
     76     for (unsigned i = 0; i < randomizedIds.size(); i++) {
     77         int rangeStart = data.range(randomizedIds[i]);
     78         int rangeEnd = data.range(randomizedIds[i] + 1); // end is exclusive
     79         int rangeLength = rangeEnd - rangeStart;                //用户id为i的那一行有几个电影评分了*5;
     80         
     81         selectWeights(data, rangeStart, rangeEnd);            //寻找和userid匹配的W
     82         const auto& visData = data.ratings.segment(rangeStart, rangeLength);   //这个userid矩阵化
     83 
     84         h0probs = 1 / (1 + (-visData*Wsel - hBias).array().exp());        //h0
     85         h0states = h0probs.array() > 
     86             (h0probs.Random(h0probs.size()).array() + 1) / 2;           //
     87         negActivation(h0states.cast<float>());                                                    //
     88 
     89         hTProbs = 1 / (1 + (-negData*Wsel - hBias).array().exp());      //h1
     90         for (int t = 1; t < T; t++)
     91             gibbsSample();                                                                                            //gibbs
     92 
     93         posProds.noalias() = visData.transpose() * h0probs;
     94         negProds.noalias() = negData.transpose() * hTProbs; 
     95 
     96         if (i > 0)
     97             applyMomentum(data, randomizedIds[i - 1]);                        //userid>0
     98 
     99 
    100 //下面各种更新参数
    101         hBiasMomentum.noalias() = hBiasLearnRate * (h0probs - hTProbs);
    102         hBias.noalias() += hBiasMomentum;
    103         vBiasMomentum.noalias() = vBiasLearnRate * (visData - negData);
    104         for (int r = rangeStart; r < rangeEnd; r++)
    105             vBias(data.movies(r)) += vBiasMomentum(r - rangeStart);
    106 
    107         Wmomentum.noalias() = posProds - negProds;
    108         for (int r = 0; r < rangeLength; r++)
    109             W.row(data.movies(r + rangeStart)).noalias() += 
    110                 WlearnRate * (Wmomentum.row(r) - weightDecay*Wsel.row(r));
    111     }
    112 }
    113 
    114 void Rbm::negActivation(const RowVectorXf& h0states) {
    115     softmax(h0states);
    116 }
    117 
    118 void Rbm::gibbsSample() {
    119     h0states = 
    120         hTProbs.array() > (hTProbs.Random(hTProbs.size()).array() + 1) / 2;
    121     negActivation(h0states.cast<float>());
    122     hTProbs = 1 / (1 + (-negData*Wsel - hBias).array().exp());
    123 }
    124 
    125 void Rbm::normalizedNegActivation(const RowVectorXf& h0states) {
    126     softmax(h0states);
    127 }
    128 
    129 void Rbm::softmax(const RowVectorXf& h0states) {
    130     negData = (h0states*Wsel.transpose() + vBiasSel);
    131     for (int m = 0; m < negData.size(); m += nClasses) {
    132         negData.segment(m, nClasses).array() -=
    133             negData.segment(m, nClasses).maxCoeff();           //userid=i那行减去其最大值
    134     }
    135     negData.array() = negData.array().exp();
    136     for (int m = 0; m < negData.size(); m += nClasses) {
    137         negData.segment(m, nClasses).array() /=
    138             negData.segment(m, nClasses).sum();                     //userid=i那行减去其最大值,不知何用意?
    139     }
    140 }
    141 
    142 void Rbm::selectWeights(const RbmData& data, int rangeStart, int rangeEnd) {
    143     int rangeLength = rangeEnd - rangeStart;
    144     Wsel.resize(rangeLength, nHidden);
    145     vBiasSel.resize(rangeLength);
    146     for (int r = rangeStart; r < rangeEnd; r++) {
    147         Wsel.row(r - rangeStart).noalias() = W.row(data.movies(r));
    148         vBiasSel(r - rangeStart) = vBias(data.movies(r));
    149     }
    150 }
    151 
    152 void Rbm::applyMomentum(const RbmData& data, int user) {
    153     if (momentum == 0.0) return;
    154     int rangeStart = data.range(user);
    155     int rangeEnd = data.range(user + 1); // end is exclusive
    156     int rangeLength = rangeEnd - rangeStart;
    157 
    158     hBias.noalias() += momentum * hBiasMomentum;
    159     for (int r = rangeStart; r < rangeEnd; r++)
    160         vBias(data.movies(r)) += momentum * vBiasMomentum(r - rangeStart);
    161 
    162     for (int r = 0; r < rangeLength; r++) {
    163         W.row(data.movies(r + rangeStart)).noalias() += 
    164             momentum * Wmomentum.row(r);
    165     }
    166 }
    167 
    168 double Rbm::predict(const RbmData& data, const RbmPredictData& predictData,
    169         const string& fname) {
    170     ofstream out(fname.c_str());
    171     double d = predict(data, predictData, out);
    172     out.close();
    173     return d;
    174 }
    175 
    176 //第6步、又特么调用下面的
    177 double Rbm::predict(const RbmData& data, const RbmPredictData& predictData) {
    178     stringstream dontcare;        //输入流,传入参数和目标对象类型自动推导
    179     return predict(data, predictData, dontcare);
    180 }
    181 
    182 //第6.1步。主要思想如下
    183 //由之前训练重构的数据和预测数据相减算rmse
    184 double Rbm::predict(const RbmData& data, const RbmPredictData& predictData,
    185         ostream& predictStream) {
    186     double rmse = 0.0;
    187     int predictCount = 0;
    188     for (unsigned i = 0; i < predictData.userIds.size(); i++) {
    189         int userId = predictData.userIds[i];
    190         int rangeStart = data.range(userId);
    191         int rangeEnd = data.range(userId + 1); // end is exclusive
    192         selectWeights(data, rangeStart, rangeEnd);
    193         const auto& visData = 
    194             data.ratings.segment(rangeStart, rangeEnd - rangeStart);
    195         h0probs = 1 / (1 + (-visData*Wsel - hBias).array().exp());       //这里又不知道在干嘛。
    196 
    197         rangeStart = predictData.range(i);
    198         rangeEnd = predictData.range(i + 1);
    199         selectWeights(predictData, rangeStart, rangeEnd);
    200         normalizedNegActivation(h0probs);
    201         const auto& actualData = 
    202             predictData.ratings.segment(rangeStart, rangeEnd - rangeStart);    //预测集ratings矩阵化
    203 
    204         for (int r = 0; r < actualData.size(); r += nClasses) {
    205             float actual = 0;
    206             float predicted = 0;
    207             for (int c = 0; c < nClasses; c++) {
    208                 actual += actualData(r + c) * (c + 1);          //把5维的0 1 转化成评分数1~k
    209                 predicted += negData(r + c) * (c + 1);                
    210             }
    211             predictStream << predicted << endl;
    212             float t = actual - predicted;
    213             rmse += t * t;
    214         }
    215         predictCount += actualData.size() / nClasses;
    216     }
    217     return sqrt(rmse / predictCount);
    218 }

    rbmparallel.h

     1 #ifndef RBM_PARALLEL_H
     2 #define RBM_PARALLEL_H
     3 #include "rbm.h"
     4 #include "rbmdata.h"
     5 
     6 class RbmParallel : public Rbm {
     7 public:
     8     RbmParallel(int nThreads, const RbmData& data, int nHidden);
     9 
    10     void performEpoch(const RbmData& data);
    11 
    12 private:
    13     RbmParallel();
    14 
    15     void startEpochs(const RbmData& data, int batchStart, int batchSize);
    16     void joinExecution();
    17     void updateWeights();
    18     void synchronizeWeights();
    19 
    20     static void performEpochInThread(
    21             Rbm& r, const RbmData& data, int userStart, int userEnd);
    22 
    23     int nThreads;
    24     int nUsers;
    25     vector<Rbm> subRbms;
    26     vector<thread> threads;
    27 };
    28 
    29 #endif

    rbmparallel.cpp

     1 #include "rbmparallel.h"
     2 
     3 RbmParallel::RbmParallel(int nThreads, const RbmData& data, int nHidden) : 
     4     Rbm(data, nHidden), 
     5     nThreads(nThreads), 
     6     nUsers(data.range.size() - 1),
     7     subRbms(nThreads - 1, Rbm(data, nHidden))
     8 {
     9     synchronizeWeights();
    10 }
    11 
    12 void RbmParallel::performEpoch(const RbmData& data) {
    13     int batchSize = nUsers;
    14     for (int batchStart = 0; batchStart < nUsers; batchStart += batchSize) {
    15         startEpochs(data, batchStart, batchSize);
    16         joinExecution();
    17         updateWeights();
    18     }
    19 }
    20 
    21 void RbmParallel::startEpochs(
    22         const RbmData& data, int batchStart, int batchSize) {
    23     threads.clear();
    24     int usersPerStep = batchSize / nThreads;
    25     int userStart = batchStart;
    26     for (int i = 0; i < nThreads - 1; i++) {
    27         threads.push_back(
    28                 thread(&RbmParallel::performEpochInThread, 
    29                     ref(subRbms[i]), ref(data), 
    30                     userStart, userStart + usersPerStep));
    31         userStart += usersPerStep;
    32     }
    33     Rbm::performEpoch(data, userStart, nUsers);
    34 }
    35 
    36 void RbmParallel::performEpochInThread(
    37         Rbm& r, const RbmData& data, int userStart, int userEnd) {
    38     r.performEpoch(data, userStart, userEnd);
    39 }
    40 
    41 void RbmParallel::joinExecution() {
    42     for (auto it = threads.begin(); it != threads.end(); it++) {
    43         it->join();
    44     }
    45 }
    46 
    47 void RbmParallel::updateWeights() {
    48     float updateFactor = 1.0 / nThreads;
    49     W *= updateFactor;
    50     vBias *= updateFactor;
    51     hBias *= updateFactor;
    52     for (auto it = subRbms.begin(); it != subRbms.end(); it++) {
    53         W += it->W * updateFactor;
    54         vBias += it->vBias * updateFactor;
    55         hBias += it->hBias * updateFactor;
    56     }
    57     synchronizeWeights();
    58 }
    59 
    60 void RbmParallel::synchronizeWeights() {
    61     for (auto it = subRbms.begin(); it != subRbms.end(); it++) {
    62         it->W = W;
    63         it->hBias = hBias;
    64         it->vBias = vBias;
    65     }
    66 }

    基础太差,看了一个星期还是看不懂,伤不起。

  • 相关阅读:
    处理溢出
    电梯调度之需求分析
    求二维矩阵和最大的子矩阵
    四则运算改进,结果判断
    结对开发
    四则运算题测试阶段
    阶段二站立会议(2)
    阶段二站立会议(1)
    课程改进意见
    场景调研
  • 原文地址:https://www.cnblogs.com/wn19910213/p/3581024.html
Copyright © 2011-2022 走看看