zoukankan      html  css  js  c++  java
  • 多类 SVM 的损失函数及其梯度计算

    CS231n Convolutional Neural Networks for Visual Recognition —— optimization

    1. 多类 SVM 的损失函数(Multiclass SVM loss)

    在给出类别预测前的输出结果是实数值, 也即根据 score function 得到的 score(s=f(xi,W)),

    Li=jyimax(0,sjsyi+Δ),Δ=1

    • yi 表示真实的类别,syi 在真实类别上的得分;
    • sj,jyi 在其他非真实类别上的得分,也即预测错误时的得分;

    则在全体训练样本上的平均损失为:

    L=1Ni=1NLi

    delta = 1
    scores = np.dot(W, X)
    correct_scores = scores[y, np.arange(num_samples)]
    
    diff = score - correct_scores + delta
    diff[y, np.arange(num_samples)] = 0
    
    thresh = np.maximum(0, diff)
    loss = np.sum(thresh)
    loss /= num_samples

    2. 优化(optimization):梯度计算

    首先来看损失函数的定义,如下为第 i 个样本的损失值(Wc×dXd×Nd 特征向量的维度,c:输出类别的个数):

    Li==jyimax(0,sjsyi+1)jyi[max(0,wTjxiwTyixi+1)]

    • 遍历 j,就是遍历 W 每一列的每一个元素, wTjxij=1,,c;i=1,,N
    • wTj 表示 W 的每一行,共 c 行;

    下面的额关键是如何求得损失函数关于参数 wj,wyi 的梯度:

    wyiLi=jyi1(wTjxiwTyixi+Δ>0)xiwjLi=1(wTjxiwTyixi+Δ>0)xijyi

    binary = thresh 
    binary[thresh > 0] = 1          # 实现 indicator 函数
    
    col_sum = np.sum(binary, axis=0)
    binary[y, np.arange(num_samples)] = -col_sum
    
    dW = np.dot(binary, X.T)        # binary 维度信息:c*N, X 维度信息:d*N
    dW /= N
    
    dW += reg * W
    
  • 相关阅读:
    85. Maximal Rectangle
    120. Triangle
    72. Edit Distance
    39. Combination Sum
    44. Wildcard Matching
    138. Copy List with Random Pointer
    91. Decode Ways
    142. Linked List Cycle II
    异或的性质及应用
    64. Minimum Path Sum
  • 原文地址:https://www.cnblogs.com/mtcnn/p/9421595.html
Copyright © 2011-2022 走看看