zoukankan      html  css  js  c++  java
  • 【EM】C++代码实现

    看了原理和别人的代码后，终于自己写了一个EM的实现。

    我从网上找了一些身高性别的数据,用EM算法通过身高信息来识别性别。

    实现的效果还行：初始均值取男生170、女生160（方差都是10）时，正确率为84%；初始均值取男生165、女生150（方差都是10）时，正确率为79%。（下面代码中使用的初始值为男生180、女生150，方差都是10。）

    正确率与初始值有关：EM算法只保证收敛到似然函数的局部极大值，不同的初始值可能收敛到不同的结果。

    /*
    试图用EM算法来根据输入的身高来区分性别
    */
    
    #include<algorithm>
    #include<cmath>
    #include<fstream>
    #include<iostream>
    #include<string>
    #include<vector>
    using namespace std;
    
    // Pi at full double precision, as a typed, scoped constant instead of a
    // 6-digit preprocessor macro.
    const double PI = 3.14159265358979323846;
    // Deliberately no max() macro: a function-like max macro evaluates its
    // arguments twice and breaks every later std::max(...) call. With
    // <algorithm> included and `using namespace std;` in effect, unqualified
    // max(a, b) on two values of the same type resolves to std::max.
    
    // A plain pair of floats, e.g. one value for each of two mixture
    // components (f1 -> component 0, f2 -> component 1).
    struct FLOAT2 { float f1, f2; };
    // Parameters of a one-dimensional normal distribution.
    struct Gaussian
    {
        float mean;  // centre of the distribution
        float var;   // variance (sigma squared), not the standard deviation
    };
    
    // One labelled sample: recorded sex and height.
    struct EMData
    {
        char  sex;      // first letter of the label in the data file: 'm' or 'f'
        float fHeight;  // height value as read from the file
    };
    
    //获取身高性别数据
    int getdata(vector<EMData> &Data)
    {
        ifstream fin;
        fin.open("data.txt");
        if(!fin)
        {
            cout<<"error: can't open the file."<<endl;
            return -1;
        }
    
        while(!fin.eof())
        {
            char c[10];
            float height;
            fin >> c >> height;
            EMData data;
            data.sex = c[0];
            data.fHeight = height;
            Data.push_back(data);
        }
    
        return 0;
    }
    
    //根据身高数据区分性别, 返回正确率
    float predict(vector<EMData> Data)
    {
        //设符合正态分布
        Gaussian sex[2];
        float a[2]; //男女生所占百分比
        float t = 1;
        float tlimit = 0.000001; //收敛条件
    
        //赋初值 下标0表示男生 1表示女生
        sex[0].mean = 180.0;
        sex[0].var = 10.0;
        sex[1].mean = 150.0;
        sex[1].var = 10.0;
        a[0] = 0.5;
        a[1] = 0.5;
    
        while(t > tlimit)
        {
            Gaussian sex_old[2];
            float a_old[2];
            sex_old[0] = sex[0];
            sex_old[1] = sex[1];
            a_old[0] = a[0];
            a_old[1] = a[1];
    
            //计算每个样本分别被两个模型抽中的概率
            vector<FLOAT2> px;
        
            vector<EMData>::iterator it;
            for(it = Data.begin(); it < Data.end(); it++)
            {
                FLOAT2 p;
                p.f1 = 1/(sqrt(2 * PI * sex[0].var)) * exp(-(it->fHeight - sex[0].mean) * (it->fHeight - sex[0].mean) / (2 * sex[0].var));
                p.f2 = 1/(sqrt(2 * PI * sex[1].var)) * exp(-(it->fHeight - sex[1].mean) * (it->fHeight - sex[1].mean) / (2 * sex[1].var));
                px.push_back(p);
            }
    
            //E步
            //计算每个样本属于男生或女生的概率
            vector<FLOAT2>::iterator it2;
            for(it2 = px.begin(); it2 < px.end(); it2++)
            {
                float sum = 0.0;
                (*it2).f1 *= a[0];
                sum += (*it2).f1;
                (*it2).f2 *= a[1];
                sum += (*it2).f2;
    
                (*it2).f1 = (*it2).f1/sum;
                (*it2).f2 = (*it2).f2/sum;
            }
    
            //M步
            float sum_male = 0, sum_female = 0;
            float sum_mean_male = 0, sum_mean_female = 0;
            for(it2 = px.begin(), it = Data.begin(); it2 < px.end(); it2++, it++)
            {
                sum_male += (*it2).f1;
                sum_female += (*it2).f2;
                sum_mean_male += (*it2).f1 * (it->fHeight);
                sum_mean_female += (*it2).f2 * (it->fHeight);
            }
            //更新a
            a[0] = sum_male/(sum_male + sum_female);
            a[1] = sum_female/(sum_male + sum_female);
    
            //更新均值
            sex[0].mean = sum_mean_male/ sum_male;
            sex[1].mean = sum_mean_female/ sum_female;
    
            //更新方差
            float sum_var_male = 0, sum_var_female = 0;
            for(it2 = px.begin(), it = Data.begin(); it2 < px.end(); it2++, it++)
            {
                sum_var_male += (*it2).f1 * ((it->fHeight) - sex[0].mean) * ((it->fHeight) - sex[0].mean);
                sum_var_female += (*it2).f2 * ((it->fHeight) - sex[1].mean) * ((it->fHeight) - sex[1].mean);
            }
            sex[0].var = sum_var_male / sum_male;
            sex[1].var = sum_var_female / sum_female;
    
            //计算变化率
            t = max((a[0] - a_old[0])/a_old[0], (a[1] - a_old[1])/a_old[1]);
            t = max(t, (sex[0].mean - sex_old[0].mean)/sex_old[0].mean);
            t = max(t, (sex[1].mean - sex_old[1].mean)/sex_old[1].mean);
            t = max(t, (sex[0].var - sex_old[0].var)/sex_old[0].var);
            t = max(t, (sex[1].var - sex_old[1].var)/sex_old[1].var);
        }
    
        //计算正确率
        int correct_num = 0;
        float correct_rate = 0;
        vector<EMData>::iterator it;
        for(it = Data.begin(); it < Data.end(); it++)
        {
            float p[2];
            char csex;
            for(int i = 0; i < 2; i++)
            {
                p[i] = 1/(sqrt(2 * PI * sex[i].var)) * exp(-(it->fHeight - sex[i].mean) * (it->fHeight - sex[i].mean) / (2 * sex[i].var));
            }
    
            csex = (p[0] > p[1]) ? 'm' : 'f';
            if(csex == it->sex)
                correct_num++;
        }
    
        correct_rate = (float)correct_num / Data.size();
        return correct_rate;
    }
    
    int main()
    {
        vector<EMData> Data;
        getdata(Data);
        float correct_rate = predict(Data);
        cout << "correct rate = "<< correct_rate << endl;
        return 0;
    }

    数据（data.txt 的内容，每行一条记录，格式为“性别 身高”）：

    male    164
    female    156
    male    168
    female    160
    female    162
    male    187
    female    162
    male    167
    female    160.5
    female    160
    female    158
    female    164
    female    165
    male    174
    female    166
    female    158
    male     162
    male    175
    male    170
    female    161
    female    169
    female    161
    female    160
    female    167
    male    176
    male    169
    male    178
    male    165
    female    155
    male    183
    male    171
    male    179
    female    154
    male    172
    female    172
    male    173
    male    172
    male    175
    male    160
    male    160
    male    160
    male    175
    male    163
    male    181
    male    172
    male    175
    male    175
    male    167
    male    172
    male    169
    male    172
    male    175
    male    172
    male    170
    male    158
    male    167
    male    164
    male    176
    male    182
    male    173
    male    176
    male    163
    male    166
    male    162
    male    169
    male    163
    male    163
    male    176
    male    169
    male    173
    male    163
    male    167
    male    176
    male    168
    male    167
    male    170
    female    155
    female    157
    female    165
    female    156
    female    155
    female    156
    female    160
    female    158
    female    162
    female    162
    female    155
    female    163
    female    160
    female    162
    female    165
    female    159
    female    147
    female    163
    female    157
    female    160
    female    162
    female    158
    female    155
    female    165
    female    161
    female    159
    female    163
    female    158
    female    155
    female    162
    female    157
    female    159
    female    152
    female    156
    female    165
    female    154
    female    156
    female    162
  • 相关阅读:
    9月22日 又上锁妖塔
    1396. 【2014年鄞州区】挖掘机(d.pas/c/cpp)
    栓奶牛——二分解法
    P6188 [NOI Online 入门组]文具订购 题解
    HDC.Cloud | 基于IoT Studio自助生成10万行代码的奥秘
    华为云API Explorer开发者生态平台正式上线
    【华为云技术分享】揭秘华为云DLI背后的核心计算引擎
    【华为云技术分享】ARM体系结构基础(2)
    【华为云技术分享】HDC.Cloud | 以数字资产模型为核心驱动的一站式IoT数据分析实践
    【华为云技术分享】数据赋能,如何精细化保障企业大数据安全
  • 原文地址:https://www.cnblogs.com/dplearning/p/3981578.html
Copyright © 2011-2022 走看看