zoukankan      html  css  js  c++  java
  • pytorch迁移学习mobilenet1

    上一篇博客讲了怎么制作参数字典,这一篇讲怎么做迁移学习,以及怎么按层迁移(冻结)参数。代码还有待优化,现在先看看吧:

    import torch
    import torch.nn as nn
    from torch import optim
    import visdom
    from torch.utils.data import DataLoader
    from MobileNet.mobilenet_v1 import MobileNet
    from MobileNet.iris_csv import Iris
    
    # Training hyper-parameters.
    batch_size = 16
    base_learning_rate = 1e-4

    epoches = 10
    # Fix the RNG seed so that weight initialisation and shuffling are repeatable.
    torch.manual_seed(1234)
    # Visdom client used for the live loss / accuracy plots.
    vis = visdom.Visdom()

    # The three dataset splits are read from the same root directory.
    # NOTE(review): the numeric arguments (64, 128) are presumably the image
    # size expected by Iris -- confirm against MobileNet/iris_csv.py.
    data_root = '/root/demo'
    train_db = Iris(data_root, 64, 128, 'train')
    validation_db = Iris(data_root, 64, 128, 'validation')
    # The test split previously pointed at the relative path 'root/demo'
    # (missing leading slash), unlike the other two splits.
    test_db = Iris(data_root, 64, 128, 'test')

    # Only the training loader shuffles; evaluation order does not matter.
    train_loader = DataLoader(train_db, batch_size=batch_size, shuffle=True, num_workers=4)
    validation_loader = DataLoader(validation_db, batch_size=batch_size, num_workers=2)
    test_loader = DataLoader(test_db, batch_size=batch_size, num_workers=2)
    def evaluate(model,loader):
        """Return the top-1 classification accuracy of ``model`` over ``loader``.

        The model is switched to evaluation mode while predictions are
        computed (so layers such as BatchNorm and Dropout behave
        deterministically) and its previous train/eval mode is restored
        afterwards, even if the forward pass raises.

        Args:
            model: an ``nn.Module`` mapping a batch ``x`` to per-class logits
                of shape ``(batch, num_classes)``.
            loader: an iterable of ``(x, y)`` batches exposing a ``dataset``
                attribute whose length is the total number of samples.

        Returns:
            The fraction of correctly classified samples as a float, or
            ``0.0`` when the dataset is empty.
        """
        total_num = len(loader.dataset)
        if total_num == 0:
            # Nothing to score; avoid dividing by zero.
            return 0.0

        was_training = model.training
        model.eval()
        correct = 0
        try:
            # No gradients are needed for scoring.
            with torch.no_grad():
                for x, y in loader:
                    pred = model(x).argmax(dim=1)
                    correct += torch.eq(pred, y).sum().item()
        finally:
            # Leave the model in whatever mode the caller had it in.
            model.train(was_training)
        return correct / total_num
    def adapt_weights(pthfile,module):
        """Copy every compatible tensor from a checkpoint file into ``module``.

        An entry of the checkpoint is used only when ``module`` has a state
        entry with the same key AND the same shape. Everything else (unknown
        keys, or same-named tensors with a different shape) is ignored, and
        the corresponding entries of ``module`` keep their current values.

        Args:
            pthfile: path to a file written with ``torch.save`` that contains
                a ``{key: tensor}`` state dictionary.
            module: the ``nn.Module`` that receives the weights (modified in
                place).
        """
        module_dict = module.state_dict()
        # NOTE: torch.load unpickles the file -- only load checkpoints from a
        # trusted source.
        # map_location='cpu' lets a checkpoint saved on a GPU be read on a
        # CPU-only machine; load_state_dict copies the values onto whatever
        # device the module's tensors already live on.
        pretrained_dict = torch.load(pthfile, map_location='cpu')
        compatible = {
            k: v
            for k, v in pretrained_dict.items()
            if k in module_dict and getattr(v, 'shape', None) == module_dict[k].shape
        }
        module_dict.update(compatible)
        module.load_state_dict(module_dict)
    
    def main():
        """Fine-tune a MobileNet whose weights come from a converted checkpoint.

        Steps:
          1. Build the network and initialise its ``upchannel`` layer.
          2. Load the matching entries of '/root/tf_to_torch.pth'.
          3. Freeze every parameter whose name is among the state-dict keys
             except the last two, and print the names that stay trainable.
          4. Train with SGD + cross-entropy for ``epoches`` epochs, plotting
             the loss per step and the validation accuracy per epoch in
             visdom, and saving the best-scoring weights to 'best.pth'.
        """
        # NOTE(review): 35 is presumably the number of output classes --
        # confirm against MobileNet/mobilenet_v1.py.
        mod = MobileNet(35)
        nn.init.kaiming_normal_(mod.upchannel.weight, nonlinearity='relu')
        nn.init.constant_(mod.upchannel.bias, 0.1)
        # Same key-filtered load as before, via the shared helper instead of
        # an inline copy of it.
        adapt_weights('/root/tf_to_torch.pth', mod)

        # Everything except the last two state-dict entries is frozen.
        frozen_names = set(list(mod.state_dict().keys())[:-2])
        for name, param in mod.named_parameters():
            if name in frozen_names:
                param.requires_grad = False
            if param.requires_grad:
                print(name)

        # Only hand the still-trainable parameters to the optimizer.
        trainable = [p for p in mod.parameters() if p.requires_grad]
        optimizer = optim.SGD(trainable, lr=base_learning_rate)
        fun_loss = nn.CrossEntropyLoss()

        # Create the two plot windows with a dummy starting point.
        vis.line([0.], [-1], win='train_loss', opts=dict(title='train_loss'))
        vis.line([0.], [-1], win='validation_acc', opts=dict(title='validation_acc'))

        global_step = 0
        best_epoch, best_acc = 0, 0
        # Use the module-level epoch count instead of a second hard-coded 10.
        for epoch in range(epoches):
            # Make sure training behaviour is active regardless of what the
            # evaluation step did to the module's mode.
            mod.train()
            for x, y in train_loader:
                logits = mod(x)
                loss = fun_loss(logits, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                vis.line([loss.item()], [global_step], win='train_loss', update='append')
                global_step += 1

            # Validate after every epoch and keep the best checkpoint.
            val_acc = evaluate(mod, validation_loader)
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch
                torch.save(mod.state_dict(), 'best.pth')
            # Plot every epoch's accuracy, not only the improving ones.
            vis.line([val_acc], [global_step], win='validation_acc', update='append')

        print('best acc', best_acc, 'best epoch', best_epoch)
    
    # Start training only when this file is executed as a script.
    if __name__ == '__main__':
        main()

    代码里带root的地方是我电脑上的路径,根据自己的工程修改就行。freeze_list是不参与更新的参数的key名称列表(上面的代码取的是state_dict里除最后两个以外的全部key),你不想让哪一层的参数更新,就把那一层的参数名写进去,然后用

    for name,param in mod.named_parameters():

    这一行得到模型中所有参数的参数名和参数本身。如果name在freeze_list当中,就把它的requires_grad设为False将其冻结,否则这个参数会在训练中被更新;冻结之后,这些层只作为特征提取器使用。

  • 相关阅读:
    基于HttpListener的web服务器
    基于TcpListener的web服务器
    一个简单的web服务器
    c# 6.0新特性(二)
    c# 6.0新特性(一)
    c#之Redis实践list,hashtable
    html5摇一摇[转]
    在Microsoft-IIS/10.0上面部署mvc站点的时候,出现404的错误
    [实战]MVC5+EF6+MySql企业网盘实战(28)——其他列表
    让DELPHI自带的richedit控件显示图片
  • 原文地址:https://www.cnblogs.com/daremosiranaihana/p/12845585.html
Copyright © 2011-2022 走看看