zoukankan      html  css  js  c++  java
  • 多类 SVM 的损失函数及其梯度计算

    CS231n Convolutional Neural Networks for Visual Recognition —— optimization

    1. 多类 SVM 的损失函数(Multiclass SVM loss)

    在给出类别预测前的输出结果是实数值, 也即根据 score function 得到的 score(s=f(xi,W)),

    Li=jyimax(0,sjsyi+Δ),Δ=1

    • yi 表示真实的类别,syi 在真实类别上的得分;
    • sj,jyi 在其他非真实类别上的得分,也即预测错误时的得分;

    则在全体训练样本上的平均损失为:

    L=1Ni=1NLi

    delta = 1
    scores = np.dot(W, X)
    correct_scores = scores[y, np.arange(num_samples)]
    
    diff = score - correct_scores + delta
    diff[y, np.arange(num_samples)] = 0
    
    thresh = np.maximum(0, diff)
    loss = np.sum(thresh)
    loss /= num_samples

    2. 优化(optimization):梯度计算

    首先来看损失函数的定义,如下为第 i 个样本的损失值(Wc×dXd×Nd 特征向量的维度,c:输出类别的个数):

    Li==jyimax(0,sjsyi+1)jyi[max(0,wTjxiwTyixi+1)]

    • 遍历 j,就是遍历 W 每一列的每一个元素, wTjxij=1,,c;i=1,,N
    • wTj 表示 W 的每一行,共 c 行;

    下面的额关键是如何求得损失函数关于参数 wj,wyi 的梯度:

    wyiLi=jyi1(wTjxiwTyixi+Δ>0)xiwjLi=1(wTjxiwTyixi+Δ>0)xijyi

    binary = thresh 
    binary[thresh > 0] = 1          # 实现 indicator 函数
    
    col_sum = np.sum(binary, axis=0)
    binary[y, np.arange(num_samples)] = -col_sum
    
    dW = np.dot(binary, X.T)        # binary 维度信息:c*N, X 维度信息:d*N
    dW /= N
    
    dW += reg * W
    
  • 相关阅读:
    WCF三种通信方式
    Linux发布WebApi
    Supervisor Linux程序进程管理
    Centos安装Mongodb
    本地网址连不上远程mysql问题
    .Net之垃圾回收算法
    .Net之托管堆资源分配
    Centos7+ASP.Net Core 运行
    ASP .Net Core 使用 Dapper 轻型ORM框架
    转载 Jquery中AJAX参数详细介绍
  • 原文地址:https://www.cnblogs.com/mtcnn/p/9421594.html
Copyright © 2011-2022 走看看