  • CenterLoss

    Reference: https://blog.csdn.net/fxwfxw7037681/article/details/114440117

    Center Loss

    Center loss is concerned with intra-class distance: we assume that samples of the same class should lie as close as possible to their class center. This is a basic assumption in clustering; note that it says nothing about inter-class distance!
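
    For reference, the center loss of Wen et al. (ECCV 2016), which this post implements, penalizes the squared distance between each sample's feature and its class center:

    \[ L_{C} = \frac{1}{2} \sum_{i=1}^{m} \left \| x_{i} - c_{y_{i}} \right \|_{2}^{2} \]

    where \(x_{i}\) is the feature vector of sample \(i\), \(c_{y_{i}}\) is the center of its class, and \(m\) is the batch size.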

    Let self.centers hold the centers of all classes, with shape (num_class, feat_dim), where num_class is the number of classes and feat_dim is the dimensionality of each node's feature vector. Let the BiLSTM output be out.shape = (seq_len, batch, feat_dim); we reshape it to x.shape = (seq_len * batch, feat_dim). The process shown in the figure then produces dist_map:

    dist_map = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_class) + \
               torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_class, batch_size).t()
    
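    To make the broadcasting concrete, here is a small runnable sketch of the two expanded terms; the sizes (4 nodes, 8 classes, 2 feature dimensions) are assumptions for illustration, not values from the post:

    import torch

    batch_size, num_class, feat_dim = 4, 8, 2            # illustrative sizes
    x = torch.randn(batch_size, feat_dim)                # flattened BiLSTM features
    centers = torch.randn(num_class, feat_dim)           # one center per class

    xx = torch.pow(x, 2).sum(dim=1, keepdim=True)        # (4, 1): ||x_i||^2 per node
    cc = torch.pow(centers, 2).sum(dim=1, keepdim=True)  # (8, 1): ||c_j||^2 per center
    dist_map = xx.expand(batch_size, num_class) + cc.expand(num_class, batch_size).t()
    print(dist_map.shape)  # torch.Size([4, 8]); entry (i, j) = ||x_i||^2 + ||c_j||^2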

    Why are two "distance" terms added when computing the distance? Because the real "distance" is only completed by:

    dist_map.addmm_(x, self.centers.t(), beta=1, alpha=-2)
    

    So why does this expression represent the distance between two points?

    Explanation:

    In the figure above, c1, c2, …, c8 denote the squared distances from the 8 class centers to the origin, and b1, b2, b3, b4 denote the squared distances from each node's feature vector to the origin. The rightmost map in the figure is their element-wise sum. With these definitions, we can show why the code above yields the distance between every sample and every center:

    • First, dist_map.addmm_(x, self.centers.t(), beta=1, alpha=-2) computes:

      \[ dist\_map = \beta \cdot dist\_map + \alpha \left( X \times self.centers^{T} \right) \]

      Here \(\beta\) is 1 and \(\alpha\) is -2, which gives:

      \[ dist\_map = dist\_map - 2 \left( X \times self.centers^{T} \right) \]

    • Suppose we have two points \(A\) and \(B\); the squared distance between them is:

      \[ \left( A - B \right)^{2} = A^{2} + B^{2} - 2AB \]

      This identity proves that dist_map represents the distance between samples and centers, except that what we store here is the squared distance (a runnable sanity check follows below).
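
      As a sanity check, the following self-contained sketch (same illustrative sizes as above, chosen by us) verifies that the expand-then-addmm trick matches the pairwise squared Euclidean distance computed directly with torch.cdist:

      import torch

      batch_size, num_class, feat_dim = 4, 8, 2    # illustrative sizes
      x = torch.randn(batch_size, feat_dim)
      centers = torch.randn(num_class, feat_dim)

      # ||x_i||^2 + ||c_j||^2 for every sample/center pair
      dist_map = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, num_class) + \
                 torch.pow(centers, 2).sum(dim=1, keepdim=True).expand(num_class, batch_size).t()
      # Add -2 * x @ centers.T to complete the squared distance.
      dist_map = dist_map.addmm(x, centers.t(), beta=1, alpha=-2)

      assert torch.allclose(dist_map, torch.cdist(x, centers) ** 2, atol=1e-4)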

    Code

    
    import torch
    import torch.nn as nn
    
    
    class CenterLoss(nn.Module):
        """
        Reference:
            Wen et al. A Discriminative Feature Learning Approach
                for Deep Face Recognition. ECCV 2016.
            https://blog.csdn.net/fxwfxw7037681/article/details/114440117
        Attribute::
            num_class: [int], 类别数量;
            feat_dim: [int], 特征向量的维度;
        """
    
        def __init__(self, num_class=10, feat_dim=2, use_gpu=True):
            super(CenterLoss, self).__init__()
            self.num_class = num_class
            self.feat_dim = feat_dim
            self.use_gpu = use_gpu
            if self.use_gpu:
                self.centers = nn.Parameter(torch.randn(self.num_class, self.feat_dim).cuda())
            else:
                self.centers = nn.Parameter(torch.randn(self.num_class, self.feat_dim))
            nn.init.normal_(self.centers, mean=0, std=1)
    
        def forward(self, x, labels):
            """
            :param x: 特征图,shape为 (batch_size, feat_dim)
            :param labels: GT label, shape 为 (batch_size)
            :return:
            """
            batch_size = x.size(0)
            dist_map = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_class) + \
                       torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_class, batch_size).t()
            dist_map.addmm_(x, self.centers.t(), beta=1, alpha=-2)
            classes = torch.arange(self.num_class).long()
            if self.use_gpu:
                classes = classes.cuda()
            # Mask of shape (batch_size, num_class): 1 where class j is sample i's label.
            labels = labels.unsqueeze(1).expand(batch_size, self.num_class)
            mask = labels.eq(classes.expand(batch_size, self.num_class))
            # Keep only each sample's squared distance to its own class center.
            dist = dist_map * mask.float()
            # Clamp for numerical stability, then average over the batch.
            loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
            return loss
    
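    A minimal usage sketch (the feature tensor and label values below are made up for illustration): following Wen et al., the center loss is typically combined with a softmax classification loss as total_loss = softmax_loss + lambda * center_loss; here we exercise only the center term:

    criterion = CenterLoss(num_class=10, feat_dim=2, use_gpu=False)
    features = torch.randn(16, 2)          # e.g. penultimate-layer features
    labels = torch.randint(0, 10, (16,))
    loss = criterion(features, labels)
    loss.backward()                        # the centers are Parameters and get gradients too
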
  • Original article: https://www.cnblogs.com/dan-baishucaizi/p/14637011.html