zoukankan      html  css  js  c++  java
  • (原)CosFace/AM-Softmax及其mxnet代码

    转载请注明出处:

    http://www.cnblogs.com/darkknightzh/p/8525241.html

    论文:

    CosFace: Large Margin Cosine Loss for Deep Face Recognition

    https://arxiv.org/abs/1801.09414

    Additive Margin Softmax for Face Verification

    https://arxiv.org/abs/1801.05599

    第一篇论文目前无代码

    第二篇论文官方代码:

    https://github.com/happynear/AMSoftmax

    这两篇论文第三方mxnet代码:

    https://github.com/deepinsight/insightface

    说明:没用过mxnet,下面的代码注释只是纯粹从代码的角度来分析并进行注释,如有错误之处,敬请谅解,并欢迎指出。

    先查看sphereface,查看$psi ( heta )$的介绍:http://www.cnblogs.com/darkknightzh/p/8524937.html

    论文AM中定义$psi ( heta )$为:

    $psi ( heta )=cos ( heta )-m$

    sphereface中只对w进行归一化,AM中对w及x均进行了归一化,不过为了使得训练能收敛,增加了一个参数s=30,最终AM如下:

    ${{L}_{AMS}}=-frac{1}{n}sumlimits_{i=1}^{n}{log frac{{{e}^{scenterdot (cos {{ heta }_{yi}}-m)}}}{{{e}^{scenterdot (cos {{ heta }_{yi}}-m)}}+sum olimits_{j=1,j e yi}^{c}{{{e}^{scenterdot cos {{ heta }_{j}}}}}}}=-frac{1}{n}sumlimits_{i=1}^{n}{log frac{{{e}^{scenterdot (W_{yi}^{T}{{f}_{i}}-m)}}}{{{e}^{scenterdot (W_{yi}^{T}{{f}_{i}}-m)}}+sum olimits_{j=1,j e yi}^{c}{{{e}^{sW_{j}^{T}{{f}_{i}}}}}}}$

    程序中计算时,$scenterdot (cos {{ heta }_{yi}}-m)=scenterdot cos {{ heta }_{yi}}-sm$,分别计算$scenterdot cos {{ heta }_{yi}}$,sm。而后将yi处的减去sm,之后通过log softmax,得到概率,在计算损失。

    具体的代码如下(完整代码请见参考网址中mxnet的代码):

     1     s = args.margin_s  # 参数s
     2     m = args.margin_m  # 参数m
     3     _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0) # (C,F)
     4     _weight = mx.symbol.L2Normalization(_weight, mode='instance')  # 对w进行归一化
     5     
     6     nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s # 对x进行归一化,并得到s*x,(B,F)
     7     fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') # Y=XW'+b,(B,F)*(C,F)'=(B,C), '为转置
     8        
     9     s_m = s*m  # 计算s*m
    10     gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = s_m, off_value = 0.0) # 得到one-hot矩阵,每行对应i处值为s_m
    11     fc7 = fc7-gt_one_hot  # 将对应i处的减去s_m
  • 相关阅读:
    根据类生成数据库连接
    C# 获取动态类中所有的字段
    mysql 基础配置经验
    CSS小笔记
    jquery知识location.search
    Eclipse 启动tomcat 访问主页报错404
    window下安裝redis服務
    用maven创建web工程
    @WebListener 注解方式实现监听
    Dubbo-admin管理平台的安装
  • 原文地址:https://www.cnblogs.com/darkknightzh/p/8525241.html
Copyright © 2011-2022 走看看