多分类问题的交叉熵
在多分类问题中,损失函数(loss function)为交叉熵(cross entropy)损失函数。对于样本点(x,y)来说,y是真实的标签,在多分类问题中,其取值只可能为标签集合labels. 我们假设有K个标签值,且第i个样本预测为第k个标签值的概率为(p_{i,k}), 即(p_{i,k} = operatorname{Pr}(t_{i,k} = 1)), 一共有N个样本,则该数据集的损失函数为
[L_{log}(Y, P) = -log operatorname{Pr}(Y|P) = - frac{1}{N} sum_{i=0}^{N-1} sum_{k=0}^{K-1} y_{i,k} log p_{i,k}
]
一个例子
在Python的sklearn模块中,提供了一个函数log_loss()来计算多分类问题的交叉熵。再根据我们在博客Sklearn中二分类问题的交叉熵计算对log_loss()函数的源代码的分析,我们不难利用上面的计算公式用自己的方法来实现交叉熵的求值。
我们给出的例子如下:
y_true = ['1', '4', '5'] # 样本的真实标签
y_pred = [[0.1, 0.6, 0.3, 0, 0, 0, 0, 0, 0, 0],
[0, 0.3, 0.2, 0, 0.5, 0, 0, 0, 0, 0],
[0.6, 0.3, 0, 0, 0, 0.1, 0, 0, 0, 0]
] # 样本的预测概率
labels = ['0','1','2','3','4','5','6','7','8','9'] # 所有标签
在这个例子中,一个有3个样本,标签为1,4,5,一共是10个标签,y_pred是对每个样本的所有标签的预测值。
接下来我们将会用log_loss()函数和自己的方法分别来实现这个例子的交叉熵的计算,完整的Python代码如下:
from sklearn.metrics import log_loss
from sklearn.preprocessing import LabelBinarizer
from math import log
y_true = ['1', '4', '5'] # 样本的真实标签
y_pred = [[0.1, 0.6, 0.3, 0, 0, 0, 0, 0, 0, 0],
[0, 0.3, 0.2, 0, 0.5, 0, 0, 0, 0, 0],
[0.6, 0.3, 0, 0, 0, 0.1, 0, 0, 0, 0]
] # 样本的预测概率
labels = ['0','1','2','3','4','5','6','7','8','9'] # 所有标签
# 利用sklearn中的log_loss()函数计算交叉熵
sk_log_loss = log_loss(y_true, y_pred, labels=labels)
print("Loss by sklearn is:%s." %sk_log_loss)
# 利用公式实现交叉熵
# 交叉熵的计算公式网址为:
# http://scikit-learn.org/stable/modules/model_evaluation.html#log-loss
# 对样本的真实标签进行标签二值化
lb = LabelBinarizer()
lb.fit(labels)
transformed_labels = lb.transform(y_true)
# print(transformed_labels)
N = len(y_true) # 样本个数
K = len(labels) # 标签个数
eps = 1e-15 # 预测概率的控制值
Loss = 0 # 损失值初始化
for i in range(N):
for k in range(K):
# 控制预测概率在[eps, 1-eps]内,避免求对数时出现问题
if y_pred[i][k] < eps:
y_pred[i][k] = eps
if y_pred[i][k] > 1-eps:
y_pred[i][k] = 1-eps
# 多分类问题的交叉熵计算公式
Loss -= transformed_labels[i][k]*log(y_pred[i][k])
Loss /= N
print("Loss by equation is:%s." % Loss)
输出的结果如下:
Loss by sklearn is:1.16885263244.
Loss by equation is:1.16885263244.
这说明我们能够用公式来自己实现交叉熵的计算了,是不是很神奇呢?
多分类问题的交叉熵计算是建立在二分类问题的交叉熵计算的基础上,有了我们对log_loss()函数的源代码的研究,那就用自己的方法来实现多(二)分类问题的交叉熵计算就不是问题了~~
本次分享到此结束,欢迎大家交流~~
注意:本人现已开通两个微信公众号: 因为Python(微信号为:python_math)以及轻松学会Python爬虫(微信号为:easy_web_scrape), 欢迎大家关注哦~~