zoukankan      html  css  js  c++  java
  • PyTorch 模型参数初始化

    1.使用apply()

    举例说明:

    • Encoder:设计的编码器模型
    • weights_init(): 用来初始化模型
    • model.apply():实现初始化
    # coding:utf-8
    from torch import nn
    
    def weights_init(mod):
        """DCGAN-style initialiser meant to be passed to ``model.apply()``.

        Modules are matched by class name, so any class whose name contains
        'Conv' (Conv1d/2d/3d, ConvTranspose*, ...) or 'BatchNorm' is handled:

        * 'Conv' classes: weight ~ N(0, 0.02); the bias (if any) is left unchanged.
        * 'BatchNorm' classes: weight (gamma) ~ N(1, 0.02), bias (beta) = 0.

        Other modules are ignored, as are matching modules that have no
        weight/bias tensor (e.g. BatchNorm built with ``affine=False``, or a
        user-defined container whose class name happens to contain 'Conv').
        The class name of every visited module is printed so the order in
        which ``apply()`` visits modules can be inspected.

        Args:
            mod: the module currently being visited by ``apply()``.
        """
        classname = mod.__class__.__name__
        print(classname)
        # Missing or None parameters are skipped instead of crashing on ``None.data``.
        weight = getattr(mod, 'weight', None)
        bias = getattr(mod, 'bias', None)
        if 'Conv' in classname:
            if weight is not None:
                nn.init.normal_(weight, 0.0, 0.02)
        elif 'BatchNorm' in classname:
            if weight is not None:
                # gamma ~ N(1, 0.02)
                nn.init.normal_(weight, 1.0, 0.02)
            if bias is not None:
                # beta = 0
                nn.init.zeros_(bias)
    
    class Encoder(nn.Module):
        """DCGAN-style convolutional encoder mapping an image to a 1x1 feature map.

        A first stride-2 conv halves the spatial size. Each further
        Conv-BatchNorm-LeakyReLU stage halves it again and doubles the channel
        count until the feature map is 4x4. A final 4x4 conv without padding
        then reduces it to 1x1 with ``z_channels`` channels.

        Args:
            input_size: spatial size of the input (assumed square, H == W).
                Must be 16 times a power of two (16, 32, 64, ...); otherwise the
                halving stages do not land exactly on 4x4.
            input_channels: number of channels of the input image.
            base_channnes: channels produced by the first conv (the misspelt
                name is kept for backward compatibility); doubled at each stage.
            z_channels: number of output channels.

        Raises:
            AssertionError: if ``input_size`` is not 16 times a power of two.
        """

        def __init__(self, input_size, input_channels, base_channnes, z_channels):
            super(Encoder, self).__init__()
            # The original check only required a multiple of 16, but e.g. 48
            # halves 24 -> 12 -> 6 -> 3: the loop below stops under 4x4, the
            # final 4x4 kernel no longer fits and forward() fails. The size must
            # halve exactly down to 4, i.e. be 16 times a power of two.
            # Explicit raise (not ``assert``) so the check survives ``python -O``;
            # AssertionError is kept for compatibility with existing callers.
            reduced = input_size
            while reduced > 16 and reduced % 2 == 0:
                reduced //= 2
            if reduced != 16:
                raise AssertionError(
                    "input_size has to be 16 times a power of 2 (16, 32, 64, ...)")

            models = nn.Sequential()
            # First stage: stride-2 conv halves the spatial size (no BatchNorm).
            models.add_module('Conv2_{0}_{1}'.format(input_channels, base_channnes),
                              nn.Conv2d(input_channels, base_channnes, 4, 2, 1, bias=False))
            models.add_module('LeakyReLU_{0}'.format(base_channnes),
                              nn.LeakyReLU(0.2, inplace=True))
            temp_size = input_size // 2

            # Conv-BN-LeakyReLU stages: each halves the spatial size and doubles
            # the channel count, until the feature map is 4x4.
            while temp_size > 4:
                models.add_module('Conv2_{0}_{1}'.format(base_channnes, base_channnes * 2),
                                  nn.Conv2d(base_channnes, base_channnes * 2, 4, 2, 1, bias=False))
                models.add_module('BatchNorm2d_{0}'.format(base_channnes * 2),
                                  nn.BatchNorm2d(base_channnes * 2))
                models.add_module('LeakyReLU_{0}'.format(base_channnes * 2),
                                  nn.LeakyReLU(0.2, inplace=True))
                base_channnes *= 2
                temp_size //= 2

            # Final 4x4 conv with stride 1 and no padding turns 4x4 into 1x1.
            models.add_module('Conv2_{0}_{1}'.format(base_channnes, z_channels),
                              nn.Conv2d(base_channnes, z_channels, 4, 1, 0, bias=False))
            self.models = models

        def forward(self, x):
            """Encode ``x`` of shape (N, input_channels, input_size, input_size)
            into a tensor of shape (N, z_channels, 1, 1)."""
            x = self.models(x)
            return x
    
    if __name__ == '__main__':
        # Encoder for 256x256 RGB images: 64 base channels, 100 output channels.
        encoder = Encoder(256, 3, 64, 100)
        # apply() calls weights_init on every submodule and finally on the
        # encoder itself; each call prints the visited module's class name.
        encoder.apply(weights_init)
        # Print every parameter name, and dump the values of one BatchNorm
        # layer to check that gamma ~ N(1, 0.02) and beta == 0.
        inspected = ('models.BatchNorm2d_128.weight', 'models.BatchNorm2d_128.bias')
        for param_name, tensor in encoder.named_parameters():
            print(param_name)
            if param_name in inspected:
                print(tensor)

    运行输出:

    # 依次传入初始化函数的 module 类名(apply 先递归访问各子模块,最后才访问 Sequential 和 Encoder 本身)
    Conv2d
    LeakyReLU
    Conv2d
    BatchNorm2d
    LeakyReLU
    Conv2d
    BatchNorm2d
    LeakyReLU
    Conv2d
    BatchNorm2d
    LeakyReLU
    Conv2d
    BatchNorm2d
    LeakyReLU
    Conv2d
    BatchNorm2d
    LeakyReLU
    Conv2d
    Sequential
    Encoder
    
    # 输出name的格式,并根据条件打印出BatchNorm2d-128的两个参数
    models.Conv2_3_64.weight
    models.Conv2_64_128.weight
    models.BatchNorm2d_128.weight
    Parameter containing:
    tensor([1.0074, 0.9865, 1.0188, 1.0015, 0.9757, 1.0393, 0.9813, 1.0135, 1.0227,
            0.9903, 1.0490, 1.0102, 0.9920, 0.9878, 1.0060, 0.9944, 0.9993, 1.0139,
            0.9987, 0.9888, 0.9816, 0.9951, 1.0017, 0.9818, 0.9922, 0.9627, 0.9883,
            0.9985, 0.9759, 0.9962, 1.0183, 1.0199, 1.0033, 1.0475, 0.9586, 0.9916,
            1.0354, 0.9956, 0.9998, 1.0022, 1.0307, 1.0141, 1.0062, 1.0082, 1.0111,
            0.9683, 1.0372, 0.9967, 1.0157, 1.0299, 1.0352, 0.9961, 0.9901, 1.0274,
            0.9727, 1.0042, 1.0278, 1.0134, 0.9648, 0.9887, 1.0225, 1.0175, 1.0002,
            0.9988, 0.9839, 1.0023, 0.9913, 0.9657, 1.0404, 1.0197, 1.0221, 0.9925,
            0.9962, 0.9910, 0.9865, 1.0342, 1.0156, 0.9688, 1.0015, 1.0055, 0.9751,
            1.0304, 1.0132, 0.9778, 0.9900, 1.0092, 0.9745, 1.0067, 1.0077, 1.0057,
            1.0117, 0.9850, 1.0309, 0.9918, 0.9945, 0.9935, 0.9746, 1.0366, 0.9913,
            0.9564, 1.0071, 1.0370, 0.9774, 1.0126, 1.0040, 0.9946, 1.0080, 1.0126,
            0.9761, 0.9811, 0.9974, 0.9992, 1.0338, 1.0104, 0.9931, 1.0204, 1.0230,
            1.0255, 0.9969, 1.0079, 1.0127, 0.9816, 1.0132, 0.9884, 0.9691, 0.9922,
            1.0166, 0.9980], requires_grad=True)
    models.BatchNorm2d_128.bias
    Parameter containing:
    tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)
    models.Conv2_128_256.weight
    models.BatchNorm2d_256.weight
    models.BatchNorm2d_256.bias
    models.Conv2_256_512.weight
    models.BatchNorm2d_512.weight
    models.BatchNorm2d_512.bias
    models.Conv2_512_1024.weight
    models.BatchNorm2d_1024.weight
    models.BatchNorm2d_1024.bias
    models.Conv2_1024_2048.weight
    models.BatchNorm2d_2048.weight
    models.BatchNorm2d_2048.bias
    models.Conv2_2048_100.weight

    2. 在定义网络时直接初始化

    import torch.nn as nn
    import torch.nn.init as init
    import torch.nn.functional as F
    
    class Discriminator(nn.Module):
        """Fully connected discriminator with 6 linear layers.

        Five hidden ``Linear`` layers of width 1000, each followed by
        ``LeakyReLU(0.2)``, then a final ``Linear(1000, 2)`` producing two output
        values per input (presumably two class logits — confirm with callers).

        Args:
            z_dim: size of the last dimension of the input ``z``.
        """

        def __init__(self, z_dim):
            super(Discriminator, self).__init__()
            self.z_dim = z_dim
            self.net = nn.Sequential(
                nn.Linear(z_dim, 1000),
                nn.LeakyReLU(0.2, True),
                nn.Linear(1000, 1000),
                nn.LeakyReLU(0.2, True),
                nn.Linear(1000, 1000),
                nn.LeakyReLU(0.2, True),
                nn.Linear(1000, 1000),
                nn.LeakyReLU(0.2, True),
                nn.Linear(1000, 1000),
                nn.LeakyReLU(0.2, True),
                nn.Linear(1000, 2),
            )
            self.weight_init()

        def weight_init(self, mode='normal'):
            """(Re-)initialise all parameters of the network.

            Args:
                mode: 'normal' uses ``normal_init``; 'kaiming' uses ``kaiming_init``.

            Raises:
                ValueError: if ``mode`` is neither 'normal' nor 'kaiming'.
                    Previously an unknown mode fell through and crashed with
                    UnboundLocalError.
            """
            if mode == 'kaiming':
                initializer = kaiming_init
            elif mode == 'normal':
                initializer = normal_init
            else:
                raise ValueError(
                    "unknown init mode {!r}; expected 'normal' or 'kaiming'".format(mode))

            # modules() walks the whole module tree (self, self.net and its
            # layers). The initializers ignore module types they do not handle,
            # so this does not depend on every child being an iterable Sequential.
            for m in self.modules():
                initializer(m)

        def forward(self, z):
            # NOTE: squeeze() removes every size-1 dimension, so a batch of one
            # yields shape (2,) instead of (1, 2).
            return self.net(z).squeeze()
    
    def kaiming_init(m):
        """Initialise one module with Kaiming-normal weights.

        * Linear / Conv2d: weight drawn with ``init.kaiming_normal_`` (PyTorch's
          default arguments); bias, if present, set to 0.
        * BatchNorm1d / BatchNorm2d: weight (gamma) set to 1, bias (beta) set
          to 0. Both are skipped when absent (``affine=False``).

        Any other module is left untouched, so this can be called on every
        module of a network.

        Args:
            m: the module to initialise.
        """
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            init.kaiming_normal_(m.weight)
            if m.bias is not None:
                init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            # affine=False BatchNorm has weight = bias = None; the original
            # ``m.weight.data`` crashed on it.
            if m.weight is not None:
                init.ones_(m.weight)
            if m.bias is not None:
                init.zeros_(m.bias)
    
    def normal_init(m):
        """Initialise one module with small Gaussian weights.

        * Linear / Conv2d: weight ~ N(0, 0.02); bias, if present, set to 0.
        * BatchNorm1d / BatchNorm2d: weight (gamma) set to 1, bias (beta) set
          to 0. Both are skipped when absent (``affine=False``).

        Any other module is left untouched, so this can be called on every
        module of a network.

        Args:
            m: the module to initialise.
        """
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            init.normal_(m.weight, 0, 0.02)
            if m.bias is not None:
                init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            # affine=False BatchNorm has weight = bias = None; the original
            # ``m.weight.data`` crashed on it.
            if m.weight is not None:
                init.ones_(m.weight)
            if m.bias is not None:
                init.zeros_(m.bias)

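    构造 Discriminator 时,__init__ 末尾会自动调用 self.weight_init()(默认 mode='normal');如需改用 Kaiming 初始化,可再调用 weight_init(mode='kaiming')。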

  • 相关阅读:
    java 网络编程
    JAVA 中for-each循环使用方法
    JAVA 常用集合接口List、Set、Map总结
    android学习计划
    ExtJs
    jQuery easyui
    MVC
    简易servlet计算器
    使用servlet实现用户注册功能
    用JavaBean实现数据库的连接和关闭,在jsp页面输出数据库中student表中学生的信息
  • 原文地址:https://www.cnblogs.com/wanghui-garcia/p/11385160.html
Copyright © 2011-2022 走看看