zoukankan      html  css  js  c++  java
  • 迁移学习 ResNet(transfer learning with a pretrained ResNet-18)

      1 import torch
      2 import numpy as np
      3 import torchvision
      4 import torch.nn as nn
      5 
      6 from torchvision import datasets,transforms,models
      7 import matplotlib.pyplot as plt
      8 import time
      9 import os
     10 import copy
     11 print("Torchvision Version:",torchvision.__version__)
     12 
     13 data_dir="./hymenoptera_data"
     14 batch_size=32
     15 input_size=224
     16 model_name="resnet"
     17 num_classes=2
     18 num_epochs=15
     19 feature_extract=True
     20 data_transforms={
     21     "train":transforms.Compose([
     22         transforms.RandomResizedCrop(input_size),
     23         transforms.RandomHorizontalFlip(),
     24         transforms.ToTensor(),
     25         transforms.Normalize([0.482,0.456,0.406],[0.229,0.224,0.225])
     26     ]),
     27     "val":transforms.Compose([
     28 
     29     transforms.RandomResizedCrop(input_size),
     30     transforms.RandomHorizontalFlip(),
     31     transforms.ToTensor(),
     32     transforms.Normalize([0.482, 0.456, 0.406], [0.229, 0.224, 0.225])
     33 ]),
     34 }
     35 image_datasets={x:datasets.ImageFolder(os.path.join(data_dir,x),data_transforms[x])
     36                 for x in ["train",'val']}
     37 dataloader_dict={x:torch.utils.data.DataLoader(image_datasets[x],batch_size=batch_size,
     38                 shuffle=True)for x in ['train','val']}
     39 device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
     40 inputs,labels=next(iter(dataloader_dict["train"]))
     41 #print(inputs.shape)#一个batch
     42 #print(labels)
     43 
     44 
      45 # Load the ResNet model (typo fixed: "resent") and replace its fully-connected layer
     46 def set_parameter_requires_grad(model,feature_extract):
     47     if feature_extract:
     48         for param in model.parameters():
     49             param.requires_grad=False
     50 
     51 def initialize_model(model_name,num_classes,feature_extract,use_pretrained=True):
     52     if model_name=="resnet":
     53         model_ft=models.resnet18(pretrained=use_pretrained)
     54         set_parameter_requires_grad(model_ft,feature_extract)
     55         num_ftrs=model_ft.fc.in_features
     56         model_ft.fc=nn.Linear(num_ftrs,num_classes)
     57         input_size=224
     58     else:
     59         print("model not implemented")
     60         return None,None
     61 
     62     return model_ft,input_size
# Build the feature-extraction model: a pretrained ResNet-18 with a new
# num_classes-way fc head (downloads ImageNet weights on first run).
model_ft,input_size=initialize_model(model_name,num_classes,feature_extract,use_pretrained=True)
#print(model_ft)
print('-'*200)
     66 
     67 
     68 def train_model(model,dataloaders,loss_fn,optimizer,num_epochs):
     69     best_model_wts=copy.deepcopy(model.state_dict)
     70     best_acc=0.
     71     val_acc_history=[]
     72     for epoch in range(num_epochs):
     73         for phase in ["train","val"]:
     74             running_loss=0.
     75             running_corrects=0.
     76             if phase=="train":
     77                 model.train()
     78             else:
     79                 model.eval()
     80 
     81             for inputs,labels in dataloaders[phase]:
     82                 inputs,labels=inputs.to(device),labels.to(device)
     83 
     84                 with torch.autograd.set_grad_enabled(phase=="train"):
     85                     outputs=model(inputs)
     86                     loss=loss_fn(outputs,labels)
     87                 preds=outputs.argmax(dim=1)
     88                 if phase=="train":
     89                     optimizer.zero_grad()
     90                     loss.backward()
     91                     optimizer.step()
     92                 running_loss+=loss.item()*inputs.size(0)
     93                 running_corrects+=torch.sum(preds.view(-1)==labels.view(-1)).item()
     94 
     95             epoch_loss=running_loss/len(dataloaders[phase].dataset)
     96             epoch_acc=running_corrects/len(dataloaders[phase].dataset)
     97 
     98             print("Phase{}  loss:{},   acc:{}".format(phase,epoch_loss,epoch_acc))
     99 
    100             if phase=="val" and epoch_acc>best_acc:
    101                 best_acc=epoch_acc
    102                 best_model_wts=copy.deepcopy(model.state_dict())
    103             if phase=="val":
    104                 val_acc_history.append(epoch_acc)
    105     model.load_state_dict(best_model_wts)
    106     return  model,val_acc_history
    107 
# ---- Feature extraction run ----
# Move the model to the GPU/CPU selected above.
model_ft=model_ft.to(device)
# Optimize only the parameters left trainable (the replaced fc head);
# the frozen backbone parameters are filtered out.
optimizer=torch.optim.SGD(filter(lambda  p: p.requires_grad,model_ft.parameters()),
                          lr=0.001,momentum=0.9)
loss_fn=nn.CrossEntropyLoss()
# (The printed note explains, in Chinese, that feature extraction keeps the
# pretrained weights fixed and trains only the replaced layers, using the
# pretrained CNN purely as a feature extractor.)
print("feature extraction: 我们不再改变训练模型的参数,而是只更新我们改变过的部分模型参数。"
      "我们之所以叫它feature extraction是因为我们把预训练的CNN模型当做一个特征提取模型,利用提取出来的特征做来完成我们的训练任务。")
_,ohist=train_model(model_ft,dataloader_dict,loss_fn,optimizer,num_epochs=num_epochs)

print("-"*200)
    117 
    118 
    119 model_scratch,_=initialize_model(model_name,num_classes,feature_extract=False,use_pretrained=False)
    120 model_scratch=model_ft.to(device)
    121 optimizer=torch.optim.SGD(filter(lambda  p: p.requires_grad,model_ft.parameters()),
    122                           lr=0.001,momentum=0.9)
    123 loss_fn=nn.CrossEntropyLoss()
    124 print("fine tuning: 从一个预训练模型开始,我们改变一些模型的架构,然后继续训练整个模型的参数。")
    125 _,scratch_ohist=train_model(model_ft,dataloader_dict,loss_fn,optimizer,num_epochs=num_epochs)
    126 
    127 plt.title("Accuracy vs. Training Epoch")
    128 plt.xlabel("Training Epoch")
    129 plt.ylabel("Accuracy")
    130 plt.plot(range(1,num_epochs+1),ohist,label="Pretrained")
    131 plt.plot(range(1,num_epochs+1),scratch_ohist,label="No_pretrained")
    132 plt.ylim((0,1.))
    133 plt.xticks(np.arange(1,num_epochs+1,1.0))
    134 plt.legend()
    135 plt.show()
  • 相关阅读:
    python中文编码
    Python习题纠错1
    Python中的变量
    Python之注释
    python初步学习
    java输入数据并排序
    五月最后一天
    @component注解
    多线程回顾
    赖床分子想改变--
  • 原文地址:https://www.cnblogs.com/-xuewuzhijing-/p/12987581.html
Copyright © 2011-2022 走看看