zoukankan      html  css  js  c++  java
  • pytorch-第五章使用迁移学习实现分类任务-torchvison

    torchvision 主要是由三大模块组成, model, transforms, datasets 

     transforms 主要可以进行数据增强 

    datasets 主要下载一些常用的数据集如mnist数据集

    model 主要是将原来的模型进行下载 

    第一部分: 数据集的准备工作 
        第一步: 使用transforms进行数据的增强操作, 使用torch.utils.data.DataLoader()构造批量数据集

        第二步: 将数据集重新转换为原来的样子, 即转换为numpy格式,变化颜色通道, 将均值和标准差弥补上,使用image.clip(0, 1) 将数据限制在0和1之间,最后进行图像的显示 

    第二部分:  数据集的训练工作 
                   第一步: 使用 initiallize_model() 初始化网络 

         第二步: 对网络进行训练, 将效果最好的结果保存在路径下,返回最好的模型的参数结果 

    第三部分: 数据集的测试工作 

        第一步: 对于输入的当张图片进行测试, 这里需要对输入的图片做valid操作 

        第二步: 对一个batch的valid进行测试,最后对结果进行显示 

    import os
    import numpy as np
    import torch
    from torch import nn
    from torch import optim
    from torchvision import transforms, datasets, models
    import matplotlib.pyplot as plt
    import time
    import copy
    from PIL import Image
    
    
    # 第一部分数据的准备:
    
    
    # 数据读取与预处理操作
    data_dir = './flower_data'
    train_dir = '/train'
    test_dir = '/test'
    
    
    # 第一步: 数据的制作
    data_transform = {
        "train": transforms.Compose([
                                    transforms.Resize(256),
                                    transforms.RandomRotation(45),
                                    transforms.CenterCrop(224),
                                    transforms.RandomHorizontalFlip(p=0.5),
                                    transforms.RandomVerticalFlip(p=0.5),
                                    transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
                                    transforms.RandomGrayscale(p=0.025),
                                    transforms.ToTensor(),
                                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                    ]),
        'valid': transforms.Compose([
                                    transforms.Resize(256),
                                    transforms.CenterCrop(224),
                                    transforms.ToTensor(),
                                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    }
    
    
    def im_convert(tensor):
        image = tensor.to('cpu').clone().detach() # clone() 修改image不会修改tensor
        image = image.numpy().squeeze() # 去除尺寸为1的维度
        image = image.transpose(1, 2, 0)
        image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
        image = image.clip(0, 1)
    
        return image
    
    
    
    
    batch_size = 8
    image_dataset = {x:datasets.ImageFolder(os.path.join(data_dir + train_dir), data_transform[x]) for x in ['train', 'valid']}
    dataloaders = {x:torch.utils.data.DataLoader(image_dataset[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
    
    
    # 第二步:获得一个batch的验证集,进行图像的显示
    dataiter = iter(dataloaders['valid'])
    inputs, labels = dataiter.next()
    
    
    fig = plt.figure(figsize=(20, 12))
    row = 2
    columns = 4
    
    for idx in range(row * columns):
        ax = fig.add_subplot(row, columns, idx+1, xticks=[], yticks=[])
        plt.imshow(im_convert(inputs[idx]))
    plt.show()
    
    
    
    # 第二部分: 进行模型的训练操作
    # 进行模型的下载
    
    model_name = 'resnet'
    
    # 是否使用人家训练好的模型参数
    feature_extract = True
    
    train_on_gpu = torch.cuda.is_available()
    if not train_on_gpu:
        print('CUDA is not availabel.Training on GPU...')
    else:
        print('CUDA is not availabel! Training on CPU')
    
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    
    
    model_ft = models.resnet152()
    print(model_ft)
    
    
    
    def set_parameter_requires_grad(model, feature_extract):
        if feature_extract:
            for param in model.parameters(): # 将反应的层数进行冻结
                param.requires_grad = False
    
    # 初始化模型
    # 第一步: 进行模型的初始化操作
    def initiallize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    
        model_ft = None
        input_size = 0
    
        if model_name == 'resnet':
            """
            Resnet 152 
            """
            model_ft = models.resnet152(pretrained=use_pretrained)
            set_parameter_requires_grad(model_ft, feature_extract)
            num_ftrs = model_ft.fc.in_features # 最后一层全连接的个数
            model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, num_classes),
                                        nn.LogSoftmax(dim=1))
    
            input_size = 224
    
        return model_ft, input_size
    
    
    model_ft, input_size = initiallize_model(model_name, 17, feature_extract, use_pretrained=True)
    
    # 进行GPU训练
    model_ft = model_ft.to(device)
    
    
    # 模型保存
    filename = 'checkpoint.pth'
    
    param_to_update = model_ft.parameters()
    print(param_to_update)
    # # 进行所有层的训练
    # for param in model_ft.parameters():
    #     param.requires_grad = True
    
    
    if feature_extract:
        param_to_update = []
        for name, param in model_ft.named_parameters():
            if param.requires_grad == True: # 如果param.requires_grad是否进行训练 
                param_to_update.append(param)
                print('	', name)
    else:
        for name, param in model_ft.named_parameters():
            if param.requires_grad == True:
                print('	', name)
    
    print(model_ft)
    
    # 优化器设置
    optmizer_ft = optim.Adam(param_to_update, lr= 1e-2)  # 输入为需要优化的参数 
    scheduler = optim.lr_scheduler.StepLR(optmizer_ft, step_size=7, gamma=0.1) # 对于学习率每7个epoch进行一次衰减
    criterion = nn.NLLLoss() #输入的是一个对数概率和标签值 
    
    
    # 第二步: 进行模型的训练模块
    def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False, filename=filename):
        since = time.time()
        best_acc = 0
    
        if os.path.exists(filename):
            checkpoint = torch.load(filename)
            best_acc = checkpoint('best_acc')
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
    
        model.to(device)
    
        val_acc_history = []
        train_acc_history = []
        train_losses = []
        valid_losses = []
        LRs = [optimizer.param_groups[0]['lr']]
    
        best_model_wts = copy.deepcopy(model.state_dict())
    
    
        for epoch in range(num_epochs):
            print("Epoch {} / {}".format(epoch, num_epochs - 1))
            print('-' *10)
    
            # 训练和验证
            for phase in ['train', 'valid']:
                if phase == 'train':
                    model.train()
                else:
                    model.eval()
    
                running_loss = 0.0
                running_corrects = 0
    
                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    print('runing')
                    # 清零操作
                    optimizer.zero_grad()
                    with torch.set_grad_enabled(phase == 'train'):
                        if is_inception and phase == 'train':
                            outputs, aux_outputs = model(inputs)
                            loss1 = criterion(outputs, labels)
                            loss2 = criterion(aux_outputs, labels)
                            loss = loss1 + 0.4 * loss2
                        else:
                            outputs = model(inputs)
                            loss = criterion(outputs, labels)
    
                        pred = torch.argmax(outputs, 1)
    
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()
    
                    # 计算损失值
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(pred == labels.data)
    
                epoch_loss = running_loss / len(dataloaders[phase].dataset)
                epoch_acc = running_corrects / len(dataloaders[phase].dataset)
    
                time_elapased = time.time() - since
                print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapased // 60, time_elapased % 60))
                print("{} loss:{:.4f} Acc:{.4f}".format(phase, epoch_loss, epoch_acc))
    
                # 将效果最好的一次模型进行保存
                if phase == 'valid' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    state = {
                        'static_dict':model.state_dict(),
                        'best_acc':best_acc,
                        'optimizer':optimizer.state_dict(),
                    }
                    torch.save(state, filename)
    
                if phase == 'valid':
                    val_acc_history.append(epoch_acc)
                    valid_losses.append(epoch_loss)
                    scheduler.step()
                if phase == 'train':
                    train_acc_history.append(epoch_acc)
                    train_losses.append(epoch_loss)
    
        time_elapased = time.time() - since
        print('Training complete in {:0.f}m {:0.f}s'.format(time_elapased // 60, time_elapased % 60))
        print('Best val Acc{:4f}'.format(best_acc))
    
        # 训练结束以后使用最好的一次模型当做模型保存的结果
        model.load_state_dict(best_model_wts)
    
        return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
    
    
    # 第三部分:进行模型的测试
    def predict(model_name, num_classes, feature_extract, image_path):
        # 获得初始化的模型 
        model_ft, inputs_size = initiallize_model(model_name, num_classes, feature_extract)
    
        model_ft.to(device)
    
        # 加载训练好的网络结构 
        filename = 'checkpoint.pth'
    
        checkpoint = torch.load(filename)
        best_acc = checkpoint['best_acc']
        model_ft.load_state_dict(checkpoint['state_dict'])
        # 将输入的图片进行处理,使得可以用于进行网络的训练 
        def process_image(image_path):
            # 读取测试图片
            img = Image.open(image_path)
    
            if img.size[0] > img.size[1]:
                img.thumbnail((10000, 256))
            else:
                img.thumbnail((256, 10000))
            # Crop操作
            left_margin = (img.width - 224) / 2
            right_margin = left_margin + 224
            bottom_margin = (img.height - 224) / 2
            top_margin = bottom_margin + 224
            img = img.crop((left_margin, bottom_margin, right_margin, top_margin))
    
            img = np.array(img) / 255
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])
            img = (img - mean) / std
    
            img = img.transpose([2, 0, 1])
    
            return img
    
        img = process_image(image_path)
        outputs = model_ft(torch.tensor([img]))  # 进行一张图片的测试
    
        # 第二步: 获得一个batch的测试数据进行测试
        dataiter = iter(dataloaders['valid'])
    
        images, labels = dataiter.next()
    
        model_ft.eval()
    
        if train_on_gpu:
            outputs = model_ft(images.cuda())
        else:
            outputs = model_ft(images)
    
        _, preds_tensor = torch.max(outputs, 1)
    
        preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
    
        fig = plt.figure(figsize=(20, 20))
        columns = 4
        rows = 2
        for idx in range(columns * rows):
            ax = fig.add_subplot(row, columns, idx + 1, xticks=[], yticks=[])
            plt.imshow(im_convert(images[idx]))
            ax.set_title("{}?{}".format(preds[idx], labels.data[idx]),
                         color='green' if preds[idx] == labels.data[idx] else 'red')
        plt.show()
    
    
    if __name__=='__main__':
    
        # train_model(model_ft, dataloaders, criterion, optmizer_ft)
        image_path = r'C:Usersqq302Desktoppytorch学习第四章卷积神经网络实战flower_data	rain4image_0242.jpg'
        predict(model_name, 17, feature_extract, image_path)

      

        

  • 相关阅读:
    (转)JavaScript html js 地区二级联动,省市二级联动,省市县js+xml三级联动
    (转)PHP分页程序源码
    html select的事件 方法 属性
    mysql 插入中文乱码解决方案 转
    (转)jquery遍历之parent()与parents()的区别 及 parentsUntil() 方法
    作业二 总结
    第一次实验总结
    自我介绍
    linux环境下快速定位位置的一个小hack
    Script of modifying ether card MAC address under linux
  • 原文地址:https://www.cnblogs.com/my-love-is-python/p/12713683.html
Copyright © 2011-2022 走看看