zoukankan      html  css  js  c++  java
  • [Machine Learning]朴素贝叶斯(NaiveBayes)

    C++ 描述:

     1 #include <iostream>
     2 #include <string>
     3 #include <fstream>
     4 #include <sstream>
     5 #include <vector>
     6 #include <map>
     7 #include <set>
     8 
     9 using namespace std;
    10 
    /**
     * Naive Bayes classifier for integer-coded categorical data.
     *
     * Every training row is a whitespace-separated list of integers whose last
     * value is the class label and whose preceding values are the features.
     * Probabilities are plain relative frequencies (no smoothing), so a feature
     * value never observed together with a label has probability 0 for it.
     *
     * Usage: load_data() one or more files, then train_model(), then predict().
     */
    class NaiveBayes {
    public:
        /**
         * Appends the rows of the text file at `path` to the training set.
         * Lines that yield fewer than two integers (blank lines, a lone label)
         * are skipped. Prints "open file error" to stderr and terminates the
         * process with status 1 when the file cannot be opened.
         */
        void load_data(std::string path);

        /**
         * Recomputes the prior and conditional probabilities from all loaded
         * rows. May be called repeatedly; earlier estimates are discarded.
         */
        void train_model();

        /**
         * Returns the label with the highest value of
         * P(label) * prod_i P(feature_i == item[i] | label).
         *
         * Only the first `feature_count` elements of `item` are used, so both a
         * bare feature vector and a full data row (features followed by a
         * label) are accepted. If every class scores 0 the label with the
         * highest prior is returned; if the model is untrained, 0 is returned.
         */
        int predict(const std::vector<int> &item) const;

    private:
        // ((feature index, feature value), class label)
        using FeatureKey = std::pair<std::pair<int, int>, int>;
        using Index = std::vector<int>::size_type;

        std::vector<std::vector<int>> data;
        std::map<FeatureKey, double> c_p; // conditional prob P(feature i == value | label)
        std::map<int, double> p_p;        // prior prob P(label)
        Index feature_count = 0;          // widest feature vector seen while training
    };

    void NaiveBayes::load_data(std::string path) {
        std::ifstream fin(path.c_str());
        if (!fin) {
            std::cerr << "open file error" << std::endl;
            exit(1);
        }

        std::string line;
        while (std::getline(fin, line)) {
            std::stringstream fields(line);
            std::vector<int> row;
            int value;
            while (fields >> value) {
                row.push_back(value);
            }
            // A usable row needs at least one feature plus the label; this also
            // guarantees train_model() never sees an empty row.
            if (row.size() > 1) {
                data.push_back(row);
            }
        }
        // fin is closed by its destructor.
    }

    void NaiveBayes::train_model() {
        // Start from scratch so retraining does not accumulate old estimates.
        p_p.clear();
        c_p.clear();
        feature_count = 0;
        if (data.empty()) {
            return;
        }

        // Count with integers first; dividing once at the end avoids the
        // rounding drift of summing many small fractions.
        std::map<int, int> label_count;
        std::map<FeatureKey, int> joint_count;
        for (const auto &row : data) {
            const Index features = row.size() - 1; // last column is the label
            const int label = row.back();
            ++label_count[label];
            if (features > feature_count) {
                feature_count = features;
            }
            for (Index i = 0; i < features; ++i) {
                // Only the row's own label is credited, and the column index is
                // part of the key so equal values in different columns stay apart.
                const FeatureKey key(std::make_pair(static_cast<int>(i), row[i]), label);
                ++joint_count[key];
            }
        }

        const double total = static_cast<double>(data.size());
        for (const auto &entry : label_count) {
            p_p[entry.first] = entry.second / total;
        }
        for (const auto &entry : joint_count) {
            const int label = entry.first.second;
            c_p[entry.first] = static_cast<double>(entry.second) / label_count[label];
        }
    }

    int NaiveBayes::predict(const std::vector<int> &item) const {
        if (p_p.empty()) {
            return 0; // untrained model: nothing to choose from
        }

        // Fallback answer: the most frequent class (first one on ties).
        int result = p_p.begin()->first;
        double best_prior = p_p.begin()->second;
        for (const auto &entry : p_p) {
            if (entry.second > best_prior) {
                best_prior = entry.second;
                result = entry.first;
            }
        }

        // Ignore anything beyond the trained feature columns (e.g. a trailing label).
        const Index used = item.size() < feature_count ? item.size() : feature_count;

        double max_prob = 0.0;
        for (const auto &entry : p_p) {
            const int label = entry.first;
            double prob = entry.second;
            for (Index i = 0; i < used && prob > 0.0; ++i) {
                const FeatureKey key(std::make_pair(static_cast<int>(i), item[i]), label);
                // find() instead of operator[] so a lookup never mutates the model.
                const auto hit = c_p.find(key);
                prob = (hit == c_p.end()) ? 0.0 : prob * hit->second;
            }

            if (prob > max_prob) {
                max_prob = prob;
                result = label;
            }
        }

        return result;
    }
    80 
    81 int main() {
    82     NaiveBayes naive_bayes;
    83     naive_bayes.load_data(string("result.txt"));
    84     naive_bayes.train_model();
    85 
    86     vector<int> item{2, 4};
    87     cout << naive_bayes.predict(item);
    88     return 0;
    89 }

    数据集(保存为 result.txt;每行以空格分隔,最后一列为类别标签,其余各列为特征取值):

    1 4 -1
    1 5 -1
    1 5 1
    1 4 1
    1 4 -1
    2 4 -1
    2 5 -1
    2 5 1
    2 6 1
    2 6 1
    3 6 1
    3 5 1
    3 5 1
    3 6 1
    3 6 -1
  • 相关阅读:
    C语言I博客作业04
    C语言II博客作业04
    C语言II作业03
    C语言II博客作业02
    C语言II博客作业01
    第一周C语言作业
    C语言I博客作业02
    C语言I博客作业08
    C语言I博客作业07
    C语言I博客作业06
  • 原文地址:https://www.cnblogs.com/skycore/p/5127725.html
Copyright © 2011-2022 走看看