zoukankan      html  css  js  c++  java
  • (原)InsightFace及其mxnet代码

    转载请注明出处:

    http://www.cnblogs.com/darkknightzh/p/8525287.html

    论文

    InsightFace : Additive Angular Margin Loss for Deep Face Recognition

    https://arxiv.org/abs/1801.07698

    官方mxnet代码:

    https://github.com/deepinsight/insightface

    说明:没用过mxnet,下面的代码注释只是纯粹从代码的角度来分析并进行注释,如有错误之处,敬请谅解,并欢迎指出。

    先查看sphereface,查看$psi ( heta )$的介绍:http://www.cnblogs.com/darkknightzh/p/8524937.html

    论文arcface中,定义$psi ( heta )$为:

    $psi ( heta )=cos ({{ heta }_{yi}}+m)$

    同时对w及x均进行了归一化,为了使得训练能收敛,增加了一个参数s=64,最终loss如下:

    $L=-frac{1}{m}sumlimits_{i=1}^{m}{log frac{{{e}^{s(cos ({{ heta }_{yi}}+m))}}}{{{e}^{s(cos ({{ heta }_{yi}}+m))}}+sum olimits_{j=n,j e yi}^{n}{{{e}^{scos {{ heta }_{j}}}}}}}$

    其中,

    ${{W}_{j}}=frac{{{W}_{j}}}{left| {{W}_{j}} ight|}$,${{x}_{i}}=frac{{{x}_{i}}}{left| {{x}_{i}} ight|}$,$cos {{ heta }_{j}}=W_{j}^{T}{{x}_{i}}$

    程序中先对w及x归一化,然后通过全连接层得到cosθ,再扩大s倍,得到scosθ。

    对于yi处,由于

    $cos ( heta +m)=cos heta cos m-sin heta sin m$

    以及

    $sin heta =sqrt{1-{{cos }^{2}} heta }$

    得到sinθ。

    由于$cos ( heta +m)$非单调,设置了easy_margin标志,当其为真时,使用0作为阈值,当特征和权重的cos值小于0,直接截断;当其为假时,使用cos(pi-m)=-cos(m)作为阈值。该阈值小于0。

    之后判断时,当easy_margin为真时,若s*cos(θ+m)小于0,直接使用s*cos(θ);当easy_margin为假时,若s*cos(θ+m)小于0,使用s*cos(θ)-s*m*sin(m)。

    具体的代码如下(完整代码见参考网址):

     1     s = args.margin_s  # 参数s
     2     m = args.margin_m  # 参数m
     3 
     4     _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0) # (C,F)
     5     _weight = mx.symbol.L2Normalization(_weight, mode='instance')   # 对w进行归一化
     6     nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s # 对x进行归一化,并得到s*x,(B,F)
     7     fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') # Y=XW'+b,(B,F)*(C,F)'=(B,C),'为转置,此处得到scos(theta)
     8     
     9     zy = mx.sym.pick(fc7, gt_label, axis=1)  # 得到fc7中gt_label位置的值。(B,1)或者(B),即当前batch中yi处的scos(theta)
    10     cos_t = zy/s  # 由于fc7及zy均为cos的s倍,此处除以s,得到实际的cos值。(B,1)或者(B)
    11     
    12     cos_m = math.cos(m)
    13     sin_m = math.sin(m)
    14     mm = math.sin(math.pi-m)*m # sin(pi-m)*m = sin(m)*m
    15     threshold = math.cos(math.pi-m)  # 阈值,避免theta + m >= pi,实际上threshold < 0
    16     if args.easy_margin:
    17       cond = mx.symbol.Activation(data=cos_t, act_type='relu') #easy_margin=True,直接使用0作为阈值,得到超过阈值的索引
    18     else:
    19       cond_v = cos_t - threshold #easy_margin=False,使用threshold(负数)作为阈值。
    20       cond = mx.symbol.Activation(data=cond_v, act_type='relu') # 得到超过阈值的索引
    21     body = cos_t*cos_t  # 通过cos*cos + sin * sin = 1, 来得到sin_theta
    22     body = 1.0-body
    23     sin_t = mx.sym.sqrt(body)  # sin_theta
    24     new_zy = cos_t*cos_m # cos(theta+m)=cos(theta)*cos(m)-sin(theta)*sin(m),此处为cos(theta)*cos(m)
    25     b = sin_t*sin_m # 此处为sin(theta)*sin(m)
    26     new_zy = new_zy - b # 此处为cos(theta)*cos(m)-sin(theta)*sin(m)=cos(theta+m)
    27     new_zy = new_zy*s # 此处为s*cos(theta+m),扩充了s倍
    28     if args.easy_margin:
    29       zy_keep = zy   # zy_keep为zy,即s*cos(theta)
    30     else:
    31       zy_keep = zy - s*mm  # zy_keep为zy-s*sin(m)*m=s*cos(theta)-s*m*sin(m)
    32     new_zy = mx.sym.where(cond, new_zy, zy_keep) # cond中>0的保持new_zy=s*cos(theta+m)不变,<0的裁剪为zy_keep= s*cos(theta) or s*cos(theta)-s*m*sin(m)
    33 
    34     diff = new_zy - zy # 
    35     diff = mx.sym.expand_dims(diff, 1)
    36     gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
    37     body = mx.sym.broadcast_mul(gt_one_hot, diff) # 对应yi处为new_zy - zy
    38     fc7 = fc7+body # 对应yi处,fc7=zy + (new_zy - zy) = new_zy,即cond中>0的为s*cos(theta+m),<0的裁剪为s*cos(theta) or s*cos(theta)-s*m*sin(m)
  • 相关阅读:
    golang 数据结构 优先队列(堆)
    leetcode刷题笔记5210题 球会落何处
    leetcode刷题笔记5638题 吃苹果的最大数目
    leetcode刷题笔记5637题 判断字符串的两半是否相似
    剑指 Offer 28. 对称的二叉树
    剑指 Offer 27. 二叉树的镜像
    剑指 Offer 26. 树的子结构
    剑指 Offer 25. 合并两个排序的链表
    剑指 Offer 24. 反转链表
    剑指 Offer 22. 链表中倒数第k个节点
  • 原文地址:https://www.cnblogs.com/darkknightzh/p/8525287.html
Copyright © 2011-2022 走看看