zoukankan      html  css  js  c++  java
  • 损失函数Center Loss 代码解析

    center loss来自ECCV2016的一篇论文:A Discriminative Feature Learning Approach for Deep Face Recognition。 
    论文链接:http://ydwen.github.io/papers/WenECCV16.pdf 
    代码链接:https://github.com/davidsandberg/facenet

    理论解析请参看 https://blog.csdn.net/u014380165/article/details/76946339

    下面给出centerloss的计算公式以及更新公式

    下面的代码是facenet作者利用tensorflow实现的centerloss代码

    def center_loss(features, label, alfa, nrof_classes):
        """Center loss based on the paper "A Discriminative Feature Learning Approach for Deep Face Recognition"
           (http://ydwen.github.io/papers/WenECCV16.pdf)
           https://blog.csdn.net/u014380165/article/details/76946339
        """
        nrof_features = features.get_shape()[1]
      #训练过程中,需要保存当前所有类中心的全连接预测特征centers, 每个batch的计算都要先读取已经保存的centers centers
    = tf.get_variable('centers', [nrof_classes, nrof_features], dtype=tf.float32, initializer=tf.constant_initializer(0), trainable=False) label = tf.reshape(label, [-1]) centers_batch = tf.gather(centers, label)#获取当前batch对应的类中心特征 diff = (1 - alfa) * (centers_batch - features)#计算当前的类中心与特征的差异,用于Cj的的梯度更新,这里facenet的作者做了一个 1-alfa操作,比较奇怪,和原论文不同 centers = tf.scatter_sub(centers, label, diff)#更新梯度Cj,对于上图中步骤6,tensorflow会将该变量centers保留下来,用于计算下一个batch的centerloss loss = tf.reduce_mean(tf.square(features - centers_batch))#计算当前的centerloss 对应于Lc return loss, centers
  • 相关阅读:
    docker安装
    win8换win7的操作方法
    java数组实现队列
    springMVC源码学习之获取参数名
    SpringMVC源码学习之request处理流程
    LeetCode 231. Power of Two
    LeetCode 202. Happy Number
    LeetCode 171. Excel Sheet Column Number
    Eclipse 保存代码时,不自动换行设置
    LeetCode 141. Linked List Cycle
  • 原文地址:https://www.cnblogs.com/adong7639/p/9090421.html
Copyright © 2011-2022 走看看