关于多分类
我们常见的逻辑回归、SVM等常用于解决二分类问题,对于多分类问题,比如识别手写数字,它就需要10个分类,同样也可以用逻辑回归或SVM,只是需要多个二分类来组成多分类,但这里讨论另外一种方式来解决多分类——softmax。
关于softmax
softmax的函数为
P(i)=exp(θTix)∑Kk=1exp(θTkx)P(i)=exp(θiTx)∑k=1Kexp(θkTx)
可以看到它有多个值,所有值加起来刚好等于1,每个输出都映射到了0到1区间,可以看成是概率问题。
θTixθiTx为多个输入,训练其实就是为了逼近最佳的θTθT。
如何多分类
从下图看,神经网络中包含了输入层,然后通过两个特征层处理,最后通过softmax分析器就能得到不同条件下的概率,这里需要分成三个类别,最终会得到y=0、y=1、y=2的概率值。
继续看下面的图,三个输入通过softmax后得到一个数组[0.05 , 0.10 , 0.85],这就是soft的功能。
计算过程直接看下图,其中zLiziL即为θTixθiTx,三个输入的值分别为3、1、-3,ezez的值为20、2.7、0.05,再分别除以累加和得到最终的概率值,0.88、0.12、0。
代价函数
对于训练集{(x(1),y(1)),...,(x(m),y(m))}{(x(1),y(1)),...,(x(m),y(m))},有y(i)∈{1,2,3...,k}y(i)∈{1,2,3...,k},总共有k个分类。对于每个输入x都会有对应每个类的概率,即p(y=j|x)p(y=j|x),从向量角度来看,有,
hθ(x(i))=⎡⎣⎢⎢⎢⎢⎢p(y(i)=1|x(i);θ)p(y(i)=2|x(i);θ)⋮p(y(i)=k|x(i);θ)⎤⎦⎥⎥⎥⎥⎥=1∑kj=1eθTj⋅x(i)⎡⎣⎢⎢⎢⎢⎢eθT1⋅x(i)eθT2⋅x(i)⋮eθTk⋅x(i)⎤⎦⎥⎥⎥⎥⎥hθ(x(i))=[p(y(i)=1|x(i);θ)p(y(i)=2|x(i);θ)⋮p(y(i)=k|x(i);θ)]=1∑j=1keθjT⋅x(i)[eθ1T⋅x(i)eθ2T⋅x(i)⋮eθkT⋅x(i)]
softmax的代价函数定为如下,其中包含了示性函数1{j=y(i)}1{j=y(i)},表示如果第i个样本的类别为j则yij=1yij=1。代价函数可看成是最大化似然函数,也即是最小化负对数似然函数。
J(θ)=−1m[∑mi=1∑kj=11{y(i)=j}⋅log(p(y(i)=j|x(i);θ))]J(θ)=−1m[∑i=1m∑j=1k1{y(i)=j}⋅log(p(y(i)=j|x(i);θ))]
其中,p(y(i)=j|x(i);θ)=exp(θTix)∑Kk=1exp(θTkx)p(y(i)=j|x(i);θ)=exp(θiTx)∑k=1Kexp(θkTx)则,
J(θ)=−1m[∑mi=1∑kj=11{y(i)=j}⋅(θTjx(i)−log(∑kl=1eθTl⋅x(i)))]J(θ)=−1m[∑i=1m∑j=1k1{y(i)=j}⋅(θjTx(i)−log(∑l=1keθlT⋅x(i)))]
一般使用梯度下降优化算法来最小化代价函数,而其中会涉及到偏导数,即θj:=θj−αδθjJ(θ)θj:=θj−αδθjJ(θ),则J(θ)J(θ)对θjθj求偏导,得到,
∇J(θ)∇θj=−1m∑mi=1[∇∑kj=11{y(i)=j}θTjx(i)∇θj−∇∑kj=11{y(i)=j}log(∑kl=1eθTl⋅x(i)))∇θj]∇J(θ)∇θj=−1m∑i=1m[∇∑j=1k1{y(i)=j}θjTx(i)∇θj−∇∑j=1k1{y(i)=j}log(∑l=1keθlT⋅x(i)))∇θj]
=−1m∑mi=1[1{y(i)=j}x(i)−∇∑kj=11{y(i)=j}∑kl=1eθTl⋅x(i)∑kl=1eθTl⋅x(i)∇θj]=−1m∑i=1m[1{y(i)=j}x(i)−∇∑j=1k1{y(i)=j}∑l=1keθlT⋅x(i)∑l=1keθlT⋅x(i)∇θj]
=−1m∑mi=1[1{y(i)=j}x(i)−x(i)eθTj⋅x(i)∑kl=1eθTl⋅x(i)]=−1m∑i=1m[1{y(i)=j}x(i)−x(i)eθjT⋅x(i)∑l=1keθlT⋅x(i)]
=−1m∑mi=1x(i)[1{y(i)=j}−p(y(i)=j|x(i);θ)]=−1m∑i=1mx(i)[1{y(i)=j}−p(y(i)=j|x(i);θ)]
得到代价函数对参数权重的梯度就可以优化了。
使用场景
在多分类场景中可以用softmax也可以用多个二分类器组合成多分类,比如多个逻辑分类器或SVM分类器等等。该使用softmax还是组合分类器,主要看分类的类别是否互斥,如果互斥则用softmax,如果不是互斥的则使用组合分类器。