  • Study Notes 13: Fine-tuning a Model

    ResNet pretrained model

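    Unlike the VGG model from the previous note, ResNet requires us to directly overwrite its last fully connected layer, the module `model.fc`.
    First, take a look at the structure of the ResNet model (e.g. with `print(model)`); for ResNet-101 it ends with `(fc): Linear(in_features=2048, out_features=1000, bias=True)`.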

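    First set requires_grad = False on all parameters,
    then define a new fc layer and use it to overwrite the original one.
    The newly defined fc layer has requires_grad = True by default, so it is the only part that gets trained.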

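    # Freeze every pretrained parameter (the attribute must be spelled requires_grad)
    for p in model.parameters():
        p.requires_grad = False
    
    # Replace the 1000-class ImageNet head with a new 4-class head
    in_f = model.fc.in_features
    model.fc = nn.Linear(in_f, 4)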
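    The freezing loop only works if the attribute is spelled exactly `requires_grad`: a misspelling such as `requries_grad` raises no error, because Python simply attaches a new, unused attribute to each tensor. A quick sanity check is to count the trainable parameters after freezing. The sketch below uses a small hypothetical stand-in model (`TinyNet`, with a `backbone` and an `fc` head) instead of ResNet-101 so that it runs without downloading weights; `count_trainable` is likewise a helper introduced only for illustration.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for ResNet: a small feature extractor plus an `fc` head."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.fc = nn.Linear(32, 10)

    def forward(self, x):
        return self.fc(self.backbone(x))


def count_trainable(model):
    # Total number of scalar weights that will receive gradients
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


model = TinyNet()

# Misspelled attribute: Python just attaches a new, unused attribute to each tensor
for p in model.parameters():
    p.requries_grad = False
before = count_trainable(model)
print(before)  # 874: nothing was frozen

# Correct spelling actually freezes the weights
for p in model.parameters():
    p.requires_grad = False

# A freshly created head has requires_grad=True, so only it is trainable
in_f = model.fc.in_features
model.fc = nn.Linear(in_f, 4)
after = count_trainable(model)
print(after)  # 132 = 32 * 4 weights + 4 biases
```

    Running the same two-line count on the real ResNet model is a cheap way to confirm that only the new head will be trained.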
    

    When defining the optimizer, pass in only the parameters of the fc layer, not the parameters of the whole model.

    optimizer = torch.optim.Adam(model.fc.parameters(), lr = 0.001)
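    Passing `model.fc.parameters()` is enough here because the head is the only trainable part. A slightly more general idiom is to hand the optimizer every parameter whose `requires_grad` is `True`, which keeps working if more layers are unfrozen later. A minimal sketch, assuming a hypothetical two-layer toy model with modules named `backbone` and `fc`; it also checks that one optimizer step leaves the frozen weights untouched.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model with a "backbone" and a head named fc, mimicking ResNet's layout
model = nn.Sequential(OrderedDict([
    ("backbone", nn.Linear(8, 16)),
    ("fc", nn.Linear(16, 4)),
]))
for p in model.backbone.parameters():
    p.requires_grad = False

# Give the optimizer only the parameters that are still trainable
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=0.001)

backbone_before = model.backbone.weight.detach().clone()
fc_before = model.fc.weight.detach().clone()

# One training step on random data
x = torch.randn(5, 8)
y = torch.randint(0, 4, (5,))
loss = nn.CrossEntropyLoss()(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

backbone_frozen = torch.equal(model.backbone.weight, backbone_before)
fc_updated = not torch.equal(model.fc.weight, fc_before)
print(backbone_frozen, fc_updated)  # True True
```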
    

    Fine-tuning

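    The general steps of fine-tuning are:

    • Redefine the fully connected layer
    • Train the redefined fully connected layer
    • Unfreeze some of the other layers
    • Train the whole model
      Note: fine-tuning can only start once the new fully connected layer has been trained, so the model is in effect trained twice.
      In this second stage the optimizer receives the parameters of the whole model, and a smaller learning rate is used (0.0001 instead of 0.001).
      Code: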
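    # Unfreeze every layer so the whole network is trained
    for param in model.parameters():
        param.requires_grad = True
    
    extend_epoch = 30
    # A smaller learning rate so the pretrained weights are only adjusted gently
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)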
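    The step list mentions unfreezing only some of the other layers, while the code above unfreezes everything. One common variant is to unfreeze just the last block and give it a smaller learning rate than the new head, using optimizer parameter groups. In torchvision's ResNet the last residual stage is `model.layer4`; the sketch below uses a hypothetical toy model whose submodules are named `layer3`, `layer4` and `fc` to mimic that layout, and the two learning rates are illustrative values, not ones taken from this note.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Hypothetical stand-in that mimics ResNet's attribute names
model = nn.Sequential(OrderedDict([
    ("layer3", nn.Linear(8, 16)),
    ("layer4", nn.Linear(16, 16)),
    ("fc", nn.Linear(16, 4)),
]))

# State after stage 1: everything frozen except the new head
for p in model.parameters():
    p.requires_grad = False
for p in model.fc.parameters():
    p.requires_grad = True

# Stage 2 variant: unfreeze only the last block, keep layer3 frozen
for p in model.layer4.parameters():
    p.requires_grad = True

# Separate parameter groups so the pretrained block moves more slowly than the head
optimizer = torch.optim.Adam([
    {"params": model.layer4.parameters(), "lr": 1e-5},
    {"params": model.fc.parameters(), "lr": 1e-4},
])

for group in optimizer.param_groups:
    # prints "1e-05 272" for layer4, then "0.0001 68" for fc
    print(group["lr"], sum(p.numel() for p in group["params"]))
```

    Unfreezing every layer and passing `model.parameters()` with a single learning rate gives back the procedure used in this note.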
    

    Full code

    import torch
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import torch.nn as nn
    import torch.nn.functional as F
    import torchvision
    from torchvision import datasets, transforms, models
    import os
    import shutil
    # Jupyter-only magic: remove this line when running as a plain .py script
    %matplotlib inline
    
    train_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomCrop(192),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(0.2),  # degrees=0.2, i.e. rotation of at most ±0.2 degrees
        transforms.ColorJitter(brightness = 0.5),
        transforms.ColorJitter(contrast = 0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5])
    ])
    test_transform = transforms.Compose([
        transforms.Resize((192, 192)),
        transforms.ToTensor(),
        transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5])
    ])
    train_ds = datasets.ImageFolder(
        "E:/datasets2/29-42/29-42/dataset2/4weather/train",
        transform = train_transform
    )
    test_ds = datasets.ImageFolder(
        "E:/datasets2/29-42/29-42/dataset2/4weather/test",
        transform = test_transform
    )
    train_dl = torch.utils.data.DataLoader(train_ds, batch_size = 8, shuffle = True)
    test_dl = torch.utils.data.DataLoader(test_ds, batch_size = 8)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
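    # Newer torchvision versions prefer weights=models.ResNet101_Weights.DEFAULT over pretrained=True
    model = models.resnet101(pretrained = True)
    for p in model.parameters():
        p.requires_grad = False     # freeze the pretrained backbone
    in_f = model.fc.in_features
    model.fc = nn.Linear(in_f, 4)   # new 4-class head, trainable by default
    model = model.to(device)        # fit() moves the inputs to device, so the model must be there too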
    
    loss_func = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.fc.parameters(), lr = 0.001)
    epochs = 30
    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 7, gamma = 0.1)
    
    def fit(epoch, model, trainloader, testloader):
        correct = 0
        total = 0
        running_loss = 0
        
        model.train()
        for x, y in trainloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_func(y_pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                y_pred = torch.argmax(y_pred, dim = 1)
                correct += (y_pred == y).sum().item()
                total += y.size(0)
                running_loss += loss.item()
    
        exp_lr_scheduler.step()
        
        epoch_acc = correct / total
        epoch_loss = running_loss / len(trainloader)  # loss.item() is already a per-batch mean, so average over batches
        
        test_correct = 0
        test_total = 0
        test_running_loss = 0
        
        model.eval()
        with torch.no_grad():
            for x, y in testloader:
                x, y = x.to(device), y.to(device)
                y_pred = model(x)
                loss = loss_func(y_pred, y)
                y_pred = torch.argmax(y_pred, dim = 1)
                test_correct += (y_pred == y).sum().item()
                test_total += y.size(0)
                test_running_loss += loss.item()
        epoch_test_acc = test_correct / test_total
        epoch_test_loss = test_running_loss / len(testloader)  # average of per-batch mean losses
        
        print('epoch: ', epoch, 
              'loss: ', round(epoch_loss, 3),
              'accuracy: ', round(epoch_acc, 3),
              'test_loss: ', round(epoch_test_loss, 3),
              'test_accuracy: ', round(epoch_test_acc, 3))
        
        return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc
    
    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []
    for epoch in range(epochs):
        epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch, model, train_dl, test_dl)
        train_loss.append(epoch_loss)
        train_acc.append(epoch_acc)
        test_loss.append(epoch_test_loss)
        test_acc.append(epoch_test_acc)
    
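    # Stage 2: unfreeze everything and fine-tune with a smaller learning rate
    for param in model.parameters():
        param.requires_grad = True
    extend_epoch = 30
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)
    # The old scheduler still points at the stage-1 optimizer, so create one for the new optimizer
    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 7, gamma = 0.1)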
    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []
    for epoch in range(extend_epoch):
        epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch, model, train_dl, test_dl)
        train_loss.append(epoch_loss)
        train_acc.append(epoch_acc)
        test_loss.append(epoch_test_loss)
        test_acc.append(epoch_test_acc)
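    `StepLR(optimizer, step_size = 7, gamma = 0.1)` multiplies the learning rate by `gamma` every `step_size` calls to `scheduler.step()`, i.e. lr(e) = base_lr * gamma ** (e // step_size). Since `fit` calls `exp_lr_scheduler.step()` once per epoch, the small pure-Python sketch below tabulates the resulting schedule for both stages (base rates 0.001 and 0.0001, 30 epochs each). It assumes stage 2 gets its own StepLR scheduler, because a scheduler only ever adjusts the optimizer it was constructed with; `step_lr` and `schedule` are hypothetical helpers, not PyTorch functions.

```python
def step_lr(base_lr, epoch, step_size=7, gamma=0.1):
    """Learning rate in effect during 0-based `epoch` under a StepLR-style schedule,
    assuming scheduler.step() is called once at the end of every epoch."""
    return base_lr * gamma ** (epoch // step_size)


def schedule(base_lr, epochs, step_size=7, gamma=0.1):
    # One learning rate per epoch
    return [step_lr(base_lr, e, step_size, gamma) for e in range(epochs)]


stage1 = schedule(0.001, 30)   # training only the new fc head
stage2 = schedule(0.0001, 30)  # fine-tuning the whole network, with its own scheduler

for e in (0, 6, 7, 14, 21, 28, 29):
    print(f"epoch {e:2d}: stage 1 lr = {stage1[e]:.0e}, stage 2 lr = {stage2[e]:.0e}")
```

    Under this schedule the last two epochs of each stage (28 and 29) run at 10^-4 times that stage's base rate.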
    
  • Original article: https://www.cnblogs.com/miraclepbc/p/14360807.html