zoukankan      html  css  js  c++  java
  • 龙良曲pytorch学习笔记_迁移学习

      1 import torch
      2 from torch import optim,nn
      3 import visdom
      4 import torchvision
      5 from torch.utils.data import DataLoader
      6 
      7 from pokemon import Pokemon
      8 
      9 # from resnet import ResNet18
     10 # 可以加载直接加载好的状态
     11 from torchvision.models import resnet18
     12 
     13 from utils import Flatten
     14 
     15 batchsz = 32
     16 lr = 1e-3
     17 epochs = 10
     18 
     19 device = torch.device('cuda')
     20 # 设置随机种子保证能够复现出来
     21 torch.manual_seed(1234)
     22 
     23 train_db = Pokemon('pokemon',224,mode = 'train')
     24 val_db = Pokemon('pokemon',224,mode = 'val')
     25 test_db = Pokemon('pokemon',224,mode = 'test')
     26 
     27 train_loader = DataLoader(train_db,batch_size = batchsz,shuffle = True,num_workers = 4)
     28 val_loader = DataLoader(val_db,batch_size = batchsz,num_workers = 2)
     29 test_loader = DataLoader(test_db,batch_size = batchsz,num_workers = 2)
     30 
     31 # visdom
     32 viz = visdom.Visdom()
     33 
     34 def evalute(model,loader):
     35     
     36     correct = 0
     37     total = len(loader.dataset)
     38     
     39     for x,y in loader:
     40         x,y = x.to(device),y.to(device)
     41         with torch.no_grad():
     42             logits = model(x)
     43             pred = logits.argmax(dim = 1)
     44             correct += torch.eq(pred,y).sum().float().item()
     45         
     46     return correct / total
     47 
     48 def main():
     49     
     50     # model = ResNet18(5).to(device)
     51     trained_model = resnet18(pretrained = True)
     52     # 取出前17层,加*打散数据
     53     model = nn.Sequential(*list(train_model.children())[:-1], # [b,512,1,1]
     54                           Flatten(), # [b,512,1,1] --> [b,512]
     55                           nn.Linear(512,5)
     56                           ).to(device)
     57     
     58     optimizer = optim.Adam(model.parameters().lr = lr)
     59     criteon = nn.CrossEntropyLoss
     60     
     61     best_acc,best_epoch = 0,0
     62     global_step = 0
     63     # visdom
     64     viz.line([0],[-1],win = 'loss',opts = dict(title = 'loss'))
     65     viz.line([0],[-1],win = 'val_acc',opts = dict(title = 'val_acc'))
     66     
     67     for epoch in range(epochs):
     68         
     69         for step,(x,y) in enumerate(train_loader):
     70             
     71             # x: [b,3,224,224] ,y : [b]
     72             x,y = x.to(device),y.to(device)
     73             
     74             # logits是没经过loss的
     75             logits = model(x)
     76             # CrossEntropyLoss会在内部进行onehot,所以不需要自己写
     77             loss = criteon(logits,y).item()
     78             
     79             optimizer.zero_grad()
     80             loss.backward()
     81             optimizer.step()
     82             
     83             # visdom
     84             viz.line([loss.item()],[global_step],win = 'loss',update = 'append')
     85             global_step += 1
     86             
     87         if epoch % 2 == 0:
     88         
     89             val_acc = evalute(model,val_loader)
     90             
     91             if val_acc > best_acc:
     92                 best_epoch = epoch
     93                 best_acc = val_acc
     94                 
     95                 torch.save(model.state_dict(),'best.mdl')
     96                 # visdom
     97                 viz.line([val_acc],[global_step],win = 'val_acc',update = 'append')
     98                 
     99     print('best_acc:',best_acc,'best_epoch',best_epoch)
    100     
    101     model.load_state_dict(torch.load('best.mdl'))
    102     print('loaded from skpt!')
    103     
    104     test_acc = evalute(model,test_loader)
    105     print('test_acc',test_acc)
    106 
    107 
    108 if __name__ == '__main__'
    109     main()
  • 相关阅读:
    python深浅copy探究
    构建squid代理服务器
    python列表和元组操作
    python字符串操作
    Apache虚拟主机
    Apache访问控制
    部署AWStats分析系统
    LAMP平台部署
    二分查找
    设计模式六大原则
  • 原文地址:https://www.cnblogs.com/fxw-learning/p/12331530.html
Copyright © 2011-2022 走看看