概
[q_1 = frac{exp(z_i/T)}{sum_j exp(z_j/T)}.
]
主要内容
这篇文章或许重点是在迁移学习上, 一个重点就是其认为soft labels (即概率向量)比hard target (one-hot向量)含有更多的信息. 比如, 数字模型判别数字(2)为(3)和(7)的概率分别是0.1, 0.01, 这说明这个数字(2)很有可能和(3)长的比较像, 这是one-hot无法带来的信息.
于是乎, 现在的情况是:
-
以及有一个训练好的且往往效果比较好但是计量大的模型(t);
-
我们打算用一个小的模型(s)去近似这个已有的模型;
-
策略是每个样本(x), 先根据(t(x))获得soft logits (z in mathbb{R}^K), 其中(K)是类别数, 且(z)未经softmax.
-
最后我们希望根据下面的损失函数来训练(s):
[mathcal{L(x, y)} = T^2 cdot mathcal{L}_{soft}(x, y) + lambda cdotmathcal{L}_{hard}(x, y) ]
其中
[mathcal{L}_{soft}(x, y) = -sum_{i=1}^K p_i(x) log q_i (x) = -sum_{i=1}^K
frac{exp(v_i(x)/T)}{sum_j exp(v_j(x)/T)}
log frac{exp(z_i(x)/T)}{sum_j exp(z_j(x)/T)}
]
[mathcal{L}_{hard}(x, y) = -log
frac{exp(z_y(x))}{sum_j exp(z_j(x))}
]
至于(T^2)是怎么来的, 这是为了配平梯度的magnitude.
[egin{array}{ll}
frac{partial mathcal{L}_{soft}}{partial z_k}
&= -sum_{i=1}^K frac{p_i}{q_i} frac{partial q_i}{partial z_k}
= -frac{1}{T}p_k -sum_{i=1}^K frac{p_i}{q_i} cdot (-frac{1}{T}q_i q_k) \
&= -frac{1}{T} (p_k -sum_{i=1}^K p_iq_k) = frac{1}{T}(q_k-p_k) \
&= frac{1}{T} (frac{e^{z_i/T}}{sum_j e^{z_j/T}} - frac{e^{v_i/T}}{sum_j e^{v_j/T}}) .
end{array}
]
当(T)足够大的时候, 并假设(sum_j z_j=0 = sum_j v_j =0), 有
[frac{partial mathcal{L}_{soft}}{partial z_k} approx frac{1}{KT^2} (z_k - v_k).
]
故需要加个(T^2)取抵消这部分的影响.
代码
其实一直很好奇的一点是这部分代码在pytorch里是怎么实现的, 毕竟pytorch里的交叉熵是
[-log p_y(x)
]
另外很恶心的一点是, 我看大家都用的是 KLDivLOSS, 但是其实现居然是:
[mathcal{L}(x, y) = y cdot log y - y cdot x,
]
注: 这里的(cdot)是逐项的.
def kl_div(x, y):
return y * (torch.log(y) - x)
x = torch.randn(2, 3)
y = torch.randn(2, 3).abs() + 1
loss1 = F.kl_div(x, y, reduction="none")
loss2 = kl_div(x, y)
这时, 出来的结果长这样
tensor([[-1.5965, 2.2040, -0.8753],
[ 3.9795, 0.0910, 1.0761]])
tensor([[-1.5965, 2.2040, -0.8753],
[ 3.9795, 0.0910, 1.0761]])
又或者:
def kl_div(x, y):
return (y * (torch.log(y) - x)).sum(dim=1).mean()
torch.manual_seed(10086)
x = torch.randn(2, 3)
y = torch.randn(2, 3).abs() + 1
loss1 = F.kl_div(x, y, reduction="batchmean")
loss2 = kl_div(x, y)
print(loss1)
print(loss2)
tensor(2.4394)
tensor(2.4394)
所以如果真要弄, 应该要
def soft_loss(z, v, T=10.):
# z: logits
# v: targets
z = F.log_softmax(z / T, dim=1)
v = F.softmax(v / T, dim=1)
return F.kl_div(z, v, reduction="batchmean")