zoukankan      html  css  js  c++  java
  • 【PyTorch】state_dict详解

    这篇博客来自csdn,完全用于学习。

    Introduce

    在pytorch中,torch.nn.Module模块中的state_dict变量存放训练过程中需要学习的权重和偏执系数,state_dict作为python的字典对象将每一层的参数映射成tensor张量,需要注意的是torch.nn.Module模块中的state_dict只包含卷积层和全连接层的参数,当网络中存在batchnorm时,例如vgg网络结构,torch.nn.Module模块中的state_dict也会存放batchnorm's running_mean,关于batchnorm详解可见https://blog.csdn.net/wzy_zju/article/details/81262453

    torch.optim模块中的Optimizer优化器对象也存在一个state_dict对象,此处的state_dict字典对象包含state和param_groups的字典对象,而param_groups key对应的value也是一个由学习率,动量等参数组成的一个字典对象。

    因为state_dict本质上Python字典对象,所以可以很好地进行保存、更新、修改和恢复操作(python字典结构的特性),从而为PyTorch模型和优化器增加了大量的模块化。

    Sample

    通过一个简单的案例来输出state_dict字典对象中存放的变量

    #encoding:utf-8
     
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision
    import numpy as mp
    import matplotlib.pyplot as plt
    import torch.nn.functional as F
     
    #define model
    class TheModelClass(nn.Module):
        def __init__(self):
            super(TheModelClass,self).__init__()
            self.conv1=nn.Conv2d(3,6,5)
            self.pool=nn.MaxPool2d(2,2)
            self.conv2=nn.Conv2d(6,16,5)
            self.fc1=nn.Linear(16*5*5,120)
            self.fc2=nn.Linear(120,84)
            self.fc3=nn.Linear(84,10)
     
        def forward(self,x):
            x=self.pool(F.relu(self.conv1(x)))
            x=self.pool(F.relu(self.conv2(x)))
            x=x.view(-1,16*5*5)
            x=F.relu(self.fc1(x))
            x=F.relu(self.fc2(x))
            x=self.fc3(x)
            return x
     
    def main():
        # Initialize model
        model = TheModelClass()
     
        #Initialize optimizer
        optimizer=optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
     
        #print model's state_dict
        print('Model.state_dict:')
        for param_tensor in model.state_dict():
            #打印 key value字典
            print(param_tensor,'	',model.state_dict()[param_tensor].size())
     
        #print optimizer's state_dict
        print('Optimizer,s state_dict:')
        for var_name in optimizer.state_dict():
            print(var_name,'	',optimizer.state_dict()[var_name])
     
     
     
    if __name__=='__main__':
        main()
     

    output:

    Model.state_dict:
    conv1.weight      torch.Size([6, 3, 5, 5])
    conv1.bias      torch.Size([6])
    conv2.weight      torch.Size([16, 6, 5, 5])
    conv2.bias      torch.Size([16])
    fc1.weight      torch.Size([120, 400])
    fc1.bias      torch.Size([120])
    fc2.weight      torch.Size([84, 120])
    fc2.bias      torch.Size([84])
    fc3.weight      torch.Size([10, 84])
    fc3.bias      torch.Size([10])
    Optimizer,s state_dict:
    state      {}
    param_groups      [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]
    csdn
    CSDN
    csdn
    CSDN
  • 相关阅读:
    LVM(逻辑卷管理器)部署、扩容、缩小
    部署磁盘阵列
    docker安装
    Linux基础命令
    awk补充
    awk
    shell脚本--grep与正则表达式
    文本处理工具 -wc、cut、sort、uniq的用法及参数
    Shell脚本编程原理
    重定向与管道符
  • 原文地址:https://www.cnblogs.com/peixu/p/13456971.html
Copyright © 2011-2022 走看看