zoukankan      html  css  js  c++  java
  • Softmax回归

    Reference:

    http://ufldl.stanford.edu/wiki/index.php/Softmax_regression

    http://deeplearning.net/tutorial/logreg.html

    起源:Logistic的二类分类

    Softmax回归是Logistic回归的泛化版本,用于解决线性多类(K类)的分类问题。

    Logistic回归可以看作是Softmax回归在K=2时的特例。Softmax函数即是K分类版的Logistc函数。

    裸Softmax回归的效果很差,因为没有隐层结构,归根还是是线性回归。所以在深度学习里,Softmax则通常作为MLP的输出层。

    即,将BP网络和Softmax结合起来,取BP网络的隐层映射机制、Softmax的多分类机制,加以组合形成新的MLP架构。

    这么做的原因就是,传统BP网络的输出层是个多神经元的自行设计接口层,比如常见的log2(K)方法,转多分类需要麻烦的编码。

    但实际上,隐层(可看作是input)到输出层的映射原理等效于Softmax,既然Softmax拥有概率取分类的方法,何必再用低效的编码方法?

    Part I  如何从2类转化为K类?

    解决方案是引入K组(W、b)参数,即有K个分隔超平面,选择$max P(Y=j|x^{i}, heta,b)$作为最终分类即可。

    由于存在K组参数,原来的$h( heta)=sigmoid(Inner)$将从单个值,变成一个大小为K的向量。

     

    Part II  变化的目标函数

    Logistic的目标函数: $J( heta)=sum_{i=1}^{m}(1-y^{(i)})log(1-h_{ heta}(x^{i})+y^{i}log(h_{ heta}(x^{(i)}))$

    在Softmax里,由于$h_{ heta}(x^{(i)}$已经变成了向量,所以不能再使用。

    实际上,在Logistic的推导里,$h_{ heta}(x^{(i)})$只是偶然而已,$P(y=0|x; heta)=h( heta)$。

    即$P(y|x; heta))$才是真正的概率分布函数,上述情况只是二项分布的特例。由于y的取值变成的K类,所以新的K项分布概率密度分布表示如下:

    $P(y^{(i)}=j|x; heta)=frac{e^{W_{j}X^{i}}}{sum_{l=1}^{k}e^{W_{l}X^{i}}}$

    且定义$1{y_{i}=j}=(y_{i}==j)?1:0$

    则  $J( heta)=sum_{i=1}^{m}sum_{j=0}^{l}1{y_{i}=j}logfrac{e^{W_{j}X^{i}}}{sum_{l=1}^{k}e^{W_{l}X^{i}}}$

    仔细观察,其实就是$h_{ heta}(x^{(i)})$这个向量根据$y^{(i)}$情况抽取的单个值而已,这就是Logistic函数的修改版本——Softmax函数

    梯度变成:$frac{partial J( heta_{j})}{partial heta_{j}}=sum_{i=1}^{m}x^{(i)}(1{y_{i}=j}-P(y^{(i)}=j|x; heta_{j})),j=1,2....k$

    可以使用梯度上升算法了(下降算法也可,即取均值加上负号,变成负对数似然函数):

    $ heta_{j}^{new}= heta_{j}^{new}+alphafrac{partial J( heta_{j})}{partial heta_{j}},j=1,2....k$

    Part III  C++代码与实现

    #include "cstdio"
    #include "iostream"
    #include "fstream"
    #include "vector"
    #include "sstream"
    #include "string"
    #include "math.h"
    using namespace std;
    #define N 500
    #define delta 0.0001
    #define alpha 0.1
    #define cin fin
    #define K 2
    #define Dim dataSet[0].feature.size()
    struct Data
    {
        vector<double> feature;
        int y;
        Data(vector<double> feature,int y):feature(feature),y(y) {}
    };
    struct Parament
    {
        vector<double> w;
        double b;
        Parament() {}
        Parament(vector<double> w,double b):w(w),b(b) {}
    };
    vector<Data> dataSet;
    vector<Parament> parament;
    void read()
    {
        ifstream fin("fullTrain.txt");
        double fea;int cls;
        string line;
        while(getline(cin,line))
        {
            stringstream sin(line);
            vector<double> feature;
            while(sin>>fea) feature.push_back(fea);
            cls=feature.back();feature.pop_back();
            dataSet.push_back(Data(feature,cls));
        }
        for(int i=0;i<K;i++) parament.push_back(Parament(vector<double>(Dim,0.0),0.0));
    }
    double calcInner(Parament param,Data data)
    {
        double ret=0.0;
        for(int i=0;i<data.feature.size();i++) ret+=(param.w[i]*data.feature[i]);
        return ret+param.b;
    }
    double calcProb(int j,Data data)
    {
        double ret=0.0,spec=0.0;
        for(int l=1;l<=K;l++)
        {
            double tmp=exp(calcInner(parament[l-1],data));
            if(l==j) spec=tmp;
            ret+=tmp;
        }
        return spec/ret;
    }
    double calcLW()
    {
        double ret=0.0;
        for(int i=0;i<dataSet.size();i++)
        {
            double prob=calcProb(dataSet[i].y,dataSet[i]);
            ret+=log(prob);
        }
        return ret;
    }
    void gradient(int iter)
    {
        /*batch (logistic)
        for(int i=0;i<param.w.size();i++)
        {
            double ret=0.0;
            for(int j=0;j<dataSet.size();j++)
            {
                double ALPHA=(double)0.1/(iter+j+1)+0.1;
                ret+=ALPHA*(dataSet[j].y-sigmoid(param,dataSet[j]))*dataSet[j].feature[i];
            }
            param.w[i]+=ret;
        }
        for(int i=0;i<dataSet.size();i++) ret+=alpha*(dataSet[i].y-sigmoid(param,dataSet[i]));
        */
        //random
        for(int j=0;j<dataSet.size();j++)
        {
            double ret=0.0,prob=0.0;
            double ALPHA=(double)0.1/(iter+j+1)+0.1;
            for(int k=1;k<=K;k++)
            {
                prob=((dataSet[j].y==k?1:0)-calcProb(k,dataSet[j]));
                for(int i=0;i<Dim;i++) parament[k-1].w[i]+=ALPHA*prob*dataSet[j].feature[i];
                parament[k-1].b+=ALPHA*prob;
            }
        }
    }
    void classify()
    {
        ifstream fin("fullTest.txt");
        double fea;int cls,no=0;
        string line;
        while(getline(cin,line))
        {
            stringstream sin(line);
            vector<double> feature;
            while(sin>>fea) feature.push_back(fea);
            cls=feature.back();feature.pop_back();
            int bestClass=-1;double bestP=-1;
            for(int i=1;i<=K;i++)
            {
                double p=calcProb(i,Data(feature,cls));
                if(p>bestP) {bestP=p;bestClass=i;}
            }
            cout<<"Test:"<<++no<<"  origin:"<<cls<<" classify:"<<bestClass<<endl;
        }
    }
    void mainProcess()
    {
        double objLW=calcLW(),newLW;
        int iter=0;
        gradient(iter);
        newLW=calcLW();
        while(fabs(newLW-objLW)>delta)
        {
            objLW=newLW;
            gradient(iter);
            newLW=calcLW();
            iter++;
            //if(iter%5==0) cout<<"iter: "<<iter<<"  target value: "<<newLW<<endl;
        }
        cout<<endl<<endl;
    }
    int main()
    {
        read();
        mainProcess();
        classify();
    }
    Softmax

    Part IV  测试

    使用Iris鸢尾花数据集:http://archive.ics.uci.edu/ml/datasets/Iris,是三类分类问题

    该数据集的第三组数据是非线性的,若K=3训练,则因为非线性数据扰乱,错误率很大。

    若K=2,则代码等效于Logistic回归,错误率相近。

  • 相关阅读:
    转:彻底搞清楚javascript中的require、import和export
    转:博客园新随笔 添加锚点
    转:深入浅出空间索引:为什么需要空间索引
    转:常见的空间索引方法
    可视化&地图__公司收集
    js json转xml(可自定义属性,区分大小写)
    Python3.6之给指定用户发送微信消息
    微信服务号发送模板消息
    log4j封装方法,输出集合
    Java封装servlet发送请求(二)
  • 原文地址:https://www.cnblogs.com/neopenx/p/4316611.html
Copyright © 2011-2022 走看看