zoukankan      html  css  js  c++  java
  • 知识蒸馏

    知识蒸馏

    一. Distilling the Knowledge in a Neural Network

    知识蒸馏的开端之作,简单叙述蒸馏过程:

    • 先训练一个大网络,比如Resnet50用于分类任务
    • 搭建一个小网络训练结构,比如mobilenetV2
    • 训练小网络的同时推理大网络,大网络的结果去指导小网络(KDLoss用于估计分布的相似性)

    类似的代码:链接地址

    类似的文章:链接地址

    比较简单的过程:

    # 教师输出和学生输出得到loss1,学生输出和label得到loss2,按一定比例结合进行反向传播
    def loss_fn_kd(outputs, labels, teacher_outputs, params):
        """
        Compute the knowledge-distillation (KD) loss given outputs, labels.
        "Hyperparameters": temperature and alpha
        NOTE: the KL Divergence for PyTorch comparing the softmaxs of teacher
        and student expects the input tensor to be log probabilities! See Issue #2
        """
        alpha = params.alpha
        T = params.temperature
        KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),
                                 F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + 
                  F.cross_entropy(outputs, labels) * (1. - alpha)
    
        return KD_loss
    
    
    

    二. Fast Human Pose Estimation Pytorch

    论文:链接地址

    代码:链接地址

    论文没有实质的创新,KDLoss直接使用MSE对Heatmap进行分布相似估计,正常Loss也使用MSE,按一定比例结核即可

    注释:看见有人说蒸馏必须网络结构类似,不然效果反而会下降(待尝试)

    # 关键点不可见的情况下只进行KDLoss,可见的情况下进行KDLoss和正常训练Loss
    for j in range(0, len(output)):
      	_output = output[j]
      	for i in range(gtmask.shape[0]):
              if gtmask[i] < 0.1:
              # unlabeled data, gtmask=0.0, kdloss only
              # need to dividen train_batch to keep number equal
              kdloss_unlabeled += criterion(_output[i,:,:,:], toutput[i, :,:,:])/train_batch
          	else:
              # labeled data: kdloss + gtloss
              gtloss += criterion(_output[i,:,:,:], target_var[i, :,:,:])/train_batch
              kdloss += criterion(_output[i,:,:,:], toutput[i,:,:,:])/train_batch
    
    loss_labeled = kdloss_alpha * (kdloss) + (1 - kdloss_alpha)*gtloss
    total_loss   = loss_labeled + unkdloss_alpha * kdloss_unlabeled
    
  • 相关阅读:
    C# httpclient获取cookies实现模拟web登录
    C# httpclient获取cookies实现模拟web登录
    长连接与短连接的区别以及使用场景
    长连接与短连接的区别以及使用场景
    vuejs项目性能优化总结
    vuejs项目性能优化总结
    C# 发送HTTP请求(可加入Cookies)
    C# 发送HTTP请求(可加入Cookies)
    集合框架系列教材 (十六)- 其他
    集合框架系列教材 (十五)- 关系与区别
  • 原文地址:https://www.cnblogs.com/wjy-lulu/p/14146240.html
Copyright © 2011-2022 走看看