zoukankan      html  css  js  c++  java
  • [loss]Triphard loss优雅的写法

    之前一直自己手写各种triphard,triplet损失函数, 写的比较暴力,然后今天一个学长给我在github上看了一个别人的triphard的写法,一开始没看懂,用的pytorch函数没怎么见过,看懂了之后, 被惊艳到了。。因此在此记录一下,以及详细注释一下

    class TripletLoss(nn.Module):
        def __init__(self, margin=0.3):
            super(TripletLoss, self).__init__()
            self.margin = margin
            self.ranking_loss = nn.MarginRankingLoss(margin=margin)  # 获得一个简单的距离triplet函数
    
        def forward(self, inputs, labels):
    
            n = inputs.size(0)  # 获取batch_size
            # Compute pairwise distance, replace by the official when merged
            dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)  # 每个数平方后, 进行加和(通过keepdim保持2维),再扩展成nxn维
            dist = dist + dist.t()  # 这样每个dis[i][j]代表的是第i个特征与第j个特征的平方的和
            dist.addmm_(1, -2, inputs, inputs.t())  # 然后减去2倍的 第i个特征*第j个特征 从而通过完全平方式得到 (a-b)^2
            dist = dist.clamp(min=1e-12).sqrt()  # 然后开方
    
            # For each anchor, find the hardest positive and negative
            mask = labels.expand(n, n).eq(labels.expand(n, n).t())  # 这里dist[i][j] = 1代表i和j的label相同, =0代表i和j的label不相同
            dist_ap, dist_an = [], []
            for i in range(n):
                dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))  # 在i与所有有相同label的j的距离中找一个最大的
                dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))  # 在i与所有不同label的j的距离找一个最小的
            dist_ap = torch.cat(dist_ap)  # 将list里的tensor拼接成新的tensor
            dist_an = torch.cat(dist_an)
    
            # Compute ranking hinge loss
            y = torch.ones_like(dist_an)  # 声明一个与dist_an相同shape的全1tensor
            loss = self.ranking_loss(dist_an, dist_ap, y)
            return loss
    
    
  • 相关阅读:
    由于版本依赖造成的YUM段错误
    CodeDom系列事件(event)定义和反射调用
    CodeSmith模板引擎系列二文件目录树
    F#初试打印目录文件树
    在IIS上SSL的部署和启动SSL安全
    CodeDom系列二程序基本结构符号三角形问题
    CodeDom系列目录
    CodeDom系列四Code生成
    CodeDom六实体类生成示例
    CodeDom系列五动态编译
  • 原文地址:https://www.cnblogs.com/kk17/p/10252440.html
Copyright © 2011-2022 走看看