zoukankan      html  css  js  c++  java
  • 【pytorch】使用迁移学习(resnet18)训练mnist数据集

    预备知识

    • 自己搭建cnn模型训练mnist(不使用迁移学习)

    https://blog.csdn.net/qq_42951560/article/details/109565625

    • pytorch官方的迁移学习教程(蚂蚁、蜜蜂分类)

    https://blog.csdn.net/qq_42951560/article/details/109950786

    学习目标

    今天我们尝试在pytorch中使用迁移学习来训练mnist数据集。

    如何迁移

    预训练模型

    迁移学习需要选择一个预训练模型,我们这个任务也不是特别大,选择resnet18就行了。

    数据预处理

    resnet18输入的CHW(3, 224, 224)

    mnist数据集中单张图片CHW(1, 28, 28)

    所以我们需要对mnist数据集做一下预处理:
    在这里插入图片描述

    # 预处理
    my_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.Grayscale(3),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,0.1307,0.1307), (0.3081,0.3081,0.3081)),
        ])
    # 训练集
    train_file = datasets.MNIST(
        root='./dataset/',
        train=True,
        transform=my_transform
    )
    # 测试集
    test_file = datasets.MNIST(
        root='./dataset/',
        train=False,
        transform=my_transform
    )
    

    pytorch中数据增强和图像处理的教程(torchvision.transforms)可以看我的这篇文章

    改全连接层

    resnet18是在imagenet上训练的,输出特征数是1000;而对于mnist来说,需要分10类,因此要改一下全连接层的输出。

    model = models.resnet18(pretrained=True)
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, 10)
    

    调整学习率

    之前设置的Adam的学习率是1e-3,现在使用了迁移学习,所以学习率调小一点,改为1e-4

    训练结果

    resnet18相较于普通的一两层卷积网络来说已经比较深了,并且mnsit数据集还是挺大的,总共有7万张图片。为了节省时间,我们使用7张GeForce GTX 1080 Ti来训练:

    • 数据并行(DataParallel)
    EPOCH: 01/10 STEP: 67/67 LOSS: 0.0266 ACC: 0.9940 VAL-LOSS: 0.0246 VAL-ACC: 0.9938 TOTAL-TIME: 102
    EPOCH: 02/10 STEP: 67/67 LOSS: 0.0141 ACC: 0.9973 VAL-LOSS: 0.0177 VAL-ACC: 0.9948 TOTAL-TIME: 80
    EPOCH: 03/10 STEP: 67/67 LOSS: 0.0067 ACC: 0.9990 VAL-LOSS: 0.0147 VAL-ACC: 0.9958 TOTAL-TIME: 80
    EPOCH: 04/10 STEP: 67/67 LOSS: 0.0042 ACC: 0.9995 VAL-LOSS: 0.0151 VAL-ACC: 0.9948 TOTAL-TIME: 80
    EPOCH: 05/10 STEP: 67/67 LOSS: 0.0029 ACC: 0.9997 VAL-LOSS: 0.0143 VAL-ACC: 0.9955 TOTAL-TIME: 80
    EPOCH: 06/10 STEP: 67/67 LOSS: 0.0019 ACC: 0.9999 VAL-LOSS: 0.0133 VAL-ACC: 0.9962 TOTAL-TIME: 80
    EPOCH: 07/10 STEP: 67/67 LOSS: 0.0013 ACC: 1.0000 VAL-LOSS: 0.0132 VAL-ACC: 0.9963 TOTAL-TIME: 80
    EPOCH: 08/10 STEP: 67/67 LOSS: 0.0008 ACC: 1.0000 VAL-LOSS: 0.0132 VAL-ACC: 0.9963 TOTAL-TIME: 79
    EPOCH: 09/10 STEP: 67/67 LOSS: 0.0006 ACC: 1.0000 VAL-LOSS: 0.0122 VAL-ACC: 0.9962 TOTAL-TIME: 79
    EPOCH: 10/10 STEP: 67/67 LOSS: 0.0005 ACC: 1.0000 VAL-LOSS: 0.0131 VAL-ACC: 0.9959 TOTAL-TIME: 79
    | BEST-MODEL | EPOCH: 07/10 STEP: 67/67 LOSS: 0.0013 ACC: 1.0000 VAL-LOSS: 0.0132 VAL-ACC: 0.9963
    

    训练10轮,最佳的模型出现在第7轮,最大准确率是0.9963。在这篇文章中,我们自己搭了两层的卷积,也训练了10轮,最大准确率是0.9923。准确率提高了0.0040,我们要知道测试集共有1万张图片,也就是多预测对了40张图片,已经提升很高。当然,因为网络变深了,所以训练花费的时间也就增加了。

    引用参考

    https://blog.csdn.net/qq_42951560/article/details/109950786

  • 相关阅读:
    cf1100 F. Ivan and Burgers
    cf 1033 D. Divisors
    LeetCode 17. 电话号码的字母组合
    LeetCode 491. 递增的子序列
    LeetCode 459.重复的子字符串
    LeetCode 504. 七进制数
    LeetCode 3.无重复字符的最长子串
    LeetCode 16.06. 最小差
    LeetCode 77. 组合
    LeetCode 611. 有效三角形个数
  • 原文地址:https://www.cnblogs.com/ghgxj/p/14219077.html
Copyright © 2011-2022 走看看