  • PyTorch transfer learning with AlexNet

    Without further ado, here is the code. To check how state_dict actually works, I wrote the fully connected part a little differently; later I will try transferring other models and see whether there is a better way to do this. Dictionaries still feel awkward to me - lists are the only Python container I really get along with, everything else is hard to use.
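
    Before the full code, here is a minimal sketch of the state_dict idiom the training script below relies on: copy the pretrained entries whose names also exist in the custom model, merge them into its state_dict, and load the result. The helper name load_matching_weights is only for illustration and does not appear in the actual scripts.

    import torch.nn as nn

    # Sketch of the key-matching idiom used in the training script further down.
    # It works for any custom module whose parameter names partially overlap with
    # the pretrained network's (here the overlap is 'features.*' and 'classifier.*').
    def load_matching_weights(model: nn.Module, pretrained: nn.Module) -> nn.Module:
        model_dict = model.state_dict()
        # keep only pretrained entries whose names (and shapes) exist in the model
        matched = {k: v for k, v in pretrained.state_dict().items()
                   if k in model_dict and v.shape == model_dict[k].shape}
        model_dict.update(matched)
        model.load_state_dict(model_dict)
        return model

    The full model definition and training script follow.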

    import torch
    import torch.nn as nn
    from torchvision.models import alexnet

    # load the pretrained AlexNet and inspect one of its conv layers via state_dict
    alex=alexnet(pretrained=True)
    # print(alex)
    # print(alex.state_dict().keys())
    pretrained_dict=alex.state_dict()
    weight_0=pretrained_dict['features.3.weight']
    bias_0=pretrained_dict['features.3.bias']
    print(weight_0.shape)
    print(bias_0.shape)

    class alex_net(nn.Module):
        def __init__(self,num_classes):
            super(alex_net, self).__init__()
            # same layout as torchvision's AlexNet, so the pretrained 'features.*'
            # and 'classifier.*' keys line up with this module's state_dict
            self.features=nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2),
                nn.Conv2d(64, 192, kernel_size=5, padding=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2),
                nn.Conv2d(192, 384, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(384, 256, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 256, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2),
            )
            self.avgpool=nn.AdaptiveAvgPool2d((6,6))
            self.classifier=nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(256 * 6 * 6, 4096),
                nn.ReLU(inplace=True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(inplace=True),
                # nn.Linear(4096,num_classes)
            )
            # the final Linear is kept outside 'classifier' under its own name, so its
            # key never matches the pretrained 'classifier.6.*' entries and it stays
            # randomly initialized for the new task
            self.gategory=nn.Linear(4096, num_classes)

        def forward(self,input):
            out=self.features(input)
            out=self.avgpool(out)
            out=torch.flatten(out,1)
            out=self.classifier(out)
            out=self.gategory(out)
            return out

    model=alex_net(num_classes=5)
    print(model.state_dict().keys())
    # print(model)
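
    As an aside on whether there is a better way: a common and much simpler pattern, sketched below and not what this post's code does, is to keep torchvision's AlexNet object and replace only its last classifier layer, which avoids touching state_dict at all (the class count 5 just mirrors the example above).

    import torch.nn as nn
    from torchvision.models import alexnet

    # Sketch: reuse the pretrained network directly and swap only the head.
    # In torchvision's AlexNet the final Linear is classifier[6] (4096 -> 1000).
    model = alexnet(pretrained=True)
    model.classifier[6] = nn.Linear(4096, 5)  # 5 output classes for the new task
    print(model.classifier)

    The training script below sticks with the explicit state_dict approach instead.
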
    import torch
    from torch import optim,nn
    import visdom
    from torchvision.models import alexnet
    from torch.utils.data import DataLoader
    
    from transfer_learning.poke import Pokemonn
    from transfer_learning.model import alex_net
    batch_size=16
    learning_rate=1e-3
    # device=torch.device('cuda')
    epoches=10
    # fix the random seed so results are reproducible
    torch.manual_seed(1234)
    vis = visdom.Visdom()
    train_db=Pokemonn('/Users/wenyu/Desktop/TorchProject/Pokemann/pokeman',227,mode='train')
    validation_db=Pokemonn('/Users/wenyu/Desktop/TorchProject/Pokemann/pokeman',227,mode='validation')
    test_db=Pokemonn('/Users/wenyu/Desktop/TorchProject/Pokemann/pokeman',227,mode='test')
    
    train_loader=DataLoader(train_db,batch_size=batch_size,shuffle=True,num_workers=4)
    validation_loader=DataLoader(validation_db,batch_size=batch_size,num_workers=2)
    test_loader=DataLoader(test_db,batch_size=batch_size,num_workers=2)
    def evaluate(model,loader):
        # compute classification accuracy over the whole loader
        model.eval()
        correct=0
        total_num=len(loader.dataset)
        for x,y in loader:
            # x,y=x.to(device),y.to(device)
            with torch.no_grad():
                logits=model(x)
                pred=logits.argmax(dim=1)
            correct+=torch.eq(pred,y).sum().float().item()
        return correct/total_num
    def main():
        # model = ResNet18(5).to(device)
        model=alex_net(5)
        model_dict=model.state_dict()
        pretrained_model=alexnet(pretrained=True)
        pretrained_dict=pretrained_model.state_dict()
        # keep only the pretrained entries whose names also exist in the custom model,
        # merge them into its state_dict, and load the result; the new 'gategory'
        # layer keeps its random initialization
        pretrained_dict={k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        optimizer=optim.SGD(model.parameters(),lr=learning_rate)
        fun_loss=nn.CrossEntropyLoss()
        vis.line([0.], [-1], win='train_loss', opts=dict(title='train_loss'))
        vis.line([0.], [-1], win='validation_acc', opts=dict(title='validation_acc'))
        global_step=0
        best_epoch,best_acc=0,0
        for epoch in range(epoches):
            model.train()  # re-enable dropout after evaluate() put the model in eval mode
            for step,(x,y) in enumerate(train_loader):
                # x,y=x.to(device),y.to(device)
                logits=model(x)
                loss=fun_loss(logits,y)
                # pred=logits.argmax(dim=1)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                vis.line([loss.item()],[global_step],win='train_loss',update='append')
                global_step += 1

            if epoch % 1==0:
                val_acc=evaluate(model, validation_loader)
                # plot the validation accuracy every epoch, not only when it improves
                vis.line([val_acc],[global_step],win='validation_acc',update='append')
                if val_acc>best_acc:
                    best_acc=val_acc
                    best_epoch=epoch
                    torch.save(model.state_dict(),'best.mdl')
    
        print('best acc',best_acc,'best epoch',best_epoch)
        model.load_state_dict(torch.load('best.mdl'))
        print('load from ckpt')
    
        test_acc=evaluate(model,test_loader)
        print(test_acc)
    
    
    if __name__ == '__main__':
        main()
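
    The script above fine-tunes every layer with the same learning rate. A common variation, sketched here under the assumption that the copied weights should stay fixed at first, is to freeze the pretrained layers and optimize only the new head:

    from torch import optim
    from transfer_learning.model import alex_net

    # Sketch: freeze the layers copied from the pretrained network and train only
    # the new head (the 'gategory' Linear layer in the model definition above).
    model = alex_net(5)
    # ...load the matching pretrained weights exactly as in main() above...
    for p in model.features.parameters():
        p.requires_grad = False
    for p in model.classifier.parameters():
        p.requires_grad = False
    # only the final Linear layer's parameters are handed to the optimizer
    optimizer = optim.SGD(model.gategory.parameters(), lr=1e-3)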

    The training part mostly follows 龙龙老师's code, and the dataset is his as well; I just wanted a quick check of how transfer learning is used, and I will try MobileNet next. His PyTorch course is really clear and easy to follow, but the transfer learning part is not very complete, so if you want to learn it you will need to read more elsewhere.

  • Original post: https://www.cnblogs.com/daremosiranaihana/p/12762467.html