zoukankan      html  css  js  c++  java
  • 李宏毅机器学习课程笔记-13.6模型压缩代码实战

    本文为作者学习李宏毅机器学习课程时参照样例完成homework7的记录。

    全部课程PPT、数据和代码下载链接:

    链接:https://pan.baidu.com/s/1n_N7aoaNxxwqO03EmV5Bjg 提取码:tpmc

    代码仓库:https://github.com/chouxianyu/LHY_ML2020_Codes

    任务描述

    通过Architecture Design、Knowledge Distillation、Network Pruning和Weight Quantization这4种模型压缩策略,用一个非常小的model完成homework3中食物图片分类的任务。

    1.Architecture Design

    MobileNet提出了Depthwise & Pointwise Convolution。我们在这里实现MobileNet v1这个比较小的network,后续使用Knowledge Distillation策略训练它,然后对它进行剪枝和量化。

    2.Knowledge Distillation

    将ResNet18作为Teacher Net(使用torchvision中的ResNet18,仅将num_classes改成11,加载助教训练好的Accuracy约为88.4%的参数),将上一步(1.Architecture Design)设计的小model作为Student Net,使用Knowledge_Distillation策略训练Student Net。

    Loss计算方法为(Loss = alpha T^2 imes KL(frac{ ext{Teacher's Logits}}{T} || frac{ ext{Student's Logits}}{T}) + (1-alpha)( ext{Original Loss})),关于为什么要对student进行logsoftmax可见https://github.com/peterliht/knowledge-distillation-pytorch/issues/2

    论文《Distilling the Knowledge in a Neural Network》:https://arxiv.org/abs/1503.02531

    3.Network Pruning

    对上一步(2.Knowledge_Distillation)训练好的Student Net做剪枝。

    根据论文《Learning Efficient Convolutional Networks through Network Slimming》,论文链接:https://arxiv.org/abs/1708.06519
    BatchNorm层中的gamma值和一些特定卷积核(或者全连接层的一个神经元)相关联,因此可以使用BatchNorm层中的gamma值判断相关通道的重要性。

    Student Net中CNN部分有几个结构相同的Sequential,其结构、权重名称、实现代码、权重形状如下表所示。

    # name meaning code weight shape
    0 cnn.{i}.0 Depthwise Convolution nn.Conv2d(x, x, 3, 1, 1, group=x) (x, 1, 3, 3)
    1 cnn.{i}.1 Batch Normalization nn.BatchNorm2d(x) (x)
    2 ReLU6 nn.ReLU6
    3 cnn.{i}.3 Pointwise Convolution nn.Conv2d(x, y, 1), (y, x, 1, 1)
    4 MaxPooling nn.MaxPool2d(2, 2, 0)

    独立剪枝prune_count次,每次剪枝的剪枝率按prune_rate逐渐增大,剪枝后微调finetune_epochs个epoch。

    4.Weight Quantization

    对第二步(2.Knowledge_Distillation)训练好的Student Net做量化(用更少的bit表示一个value)。

    torch预设的FloatTensor是32bit,而FloatTensor最低可以是16bit。

    如何将32bit转成8bit的int呢?对每个weight进行min-max normalization,然后乘以(2^8-1)再四舍五入成整数,这样就可以转成uint8了。

    数据集描述

    数据集为homework3中食物图片分类数据集。

    11个图片类别,训练集中有9866张图片,验证集中有3430张图片,测试集中有3347张图片。

    训练集和验证集中图片命名格式为类别_编号.jpg,编号不重要。

    代码

    https://github.com/chouxianyu/LHY_ML2020_Codes/tree/master/hw7_NetworkCompression


    Github(github.com):@chouxianyu

    Github Pages(github.io):@臭咸鱼

    知乎(zhihu.com):@臭咸鱼

    博客园(cnblogs.com):@臭咸鱼

    B站(bilibili.com):@绝版臭咸鱼

    微信公众号:@臭咸鱼

    转载请注明出处,欢迎讨论和交流!


  • 相关阅读:
    1-命令格式
    android default_workspace.xml
    多维算法思考(二):关于八皇后问题解法的探讨
    C#对excel的操作
    c#读取xml操作
    一个用 C# 实现操作 XML 文件的公共类代码
    如何使用 LINQ 执行插入、修改和删除操作
    C#--编程
    LINQ常用操作
    用线性组合表示两个数的最大公约数
  • 原文地址:https://www.cnblogs.com/chouxianyu/p/14743443.html
Copyright © 2011-2022 走看看