  • The Module part of pytorch Containers

     Reference: https://pytorch.org/docs/stable/nn.html

    Containers

    Module

    CLASS torch.nn.Module

    Base class for all neural network modules.

    The models you define must be subclasses of this class, i.e. inherit from it.

    Modules can also contain other modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

    import torch.nn as nn
    import torch.nn.functional as F
    
    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)
    
        def forward(self, x):
           x = F.relu(self.conv1(x))
           return F.relu(self.conv2(x))

    In this example, nn.Conv2d(20, 20, 5) is in fact a submodule.

    Submodules assigned in this way are registered, and their parameters will also be converted when you call to() and similar methods.
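
    As a quick check (a minimal sketch reusing the Model class above), submodules assigned as attributes are exactly what children() and parameters() report:

    model = Model()
    print(list(model.children()))          # [Conv2d(1, 20, ...), Conv2d(20, 20, ...)]
    print(len(list(model.parameters())))   # 4: two weight tensors and two bias tensors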

    Methods:

    1)

    cpu()

    Moves all model parameters and buffers to the CPU.

    2)

    cuda(device=None)

    Moves all model parameters and buffers to the GPU. Because this may make the associated parameters and buffers different objects, it must be called before constructing the optimizer if the module will live on the GPU while being optimized.

    Parameters:

    device (int, optional) – if specified, all parameters will be copied to that device
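
    For example (a minimal sketch, assuming a CUDA device is available), move the module first and only then build the optimizer from its parameters:

    import torch
    from torch import nn

    model = nn.Linear(2, 2)
    if torch.cuda.is_available():
        model.cuda()  # parameters and buffers are replaced by CUDA copies
    # construct the optimizer only after the move, so that it holds the CUDA parameters
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)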

    3)

    double()

    Casts all floating point parameters and buffers to the double datatype.

    float()

    Casts all floating point parameters and buffers to the float datatype.

    half()

    Casts all floating point parameters and buffers to the half datatype.

    Example:

    import torch
    from torch import nn
    linear = nn.Linear(2, 2)
    print(linear.weight)
    
    linear.double()
    print(linear.weight)

    Output:

    Parameter containing:
    tensor([[-0.0890,  0.2313],
            [-0.4812, -0.3514]], requires_grad=True)
    Parameter containing:
    tensor([[-0.0890,  0.2313],
            [-0.4812, -0.3514]], dtype=torch.float64, requires_grad=True)

    4)

    type(dst_type)

    Casts all parameters and buffers to the given dst_type.

    Parameters:

    dst_type (type or string) – the desired type

    Example (shown here on a plain tensor; calling type() on a module casts its parameters and buffers in the same way):

    import torch
    input = torch.FloatTensor([-0.8728,  0.3632, -0.0547])
    print(input)
    print(input.type(torch.double))

    Output:

    tensor([-0.8728,  0.3632, -0.0547])
    tensor([-0.8728,  0.3632, -0.0547], dtype=torch.float64)

    5)

    to(*args, **kwargs)

    Moves and/or casts the parameters and buffers.

    It can be called in three ways:

    • to(device=None, dtype=None, non_blocking=False)
    • to(dtype, non_blocking=False)
    • to(tensor, non_blocking=False)

    Its signature is similar to torch.Tensor.to(), but it only accepts floating point desired dtypes, i.e. dtype can only be a floating point type such as float, double or half. If given, this method will cast only the floating point parameters and buffers to the specified dtype; integer parameters and buffers are moved to the given device with their dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g. moving a CPU tensor with pinned memory to a CUDA device.

    Parameters:

    • device (torch.device) – the desired device of the parameters and buffers in this module

    • dtype (torch.dtype) – the desired floating point type that the floating point parameters and buffers in this module will be cast to

    • tensor (torch.Tensor) – a Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module

    Example:

    import torch
    from torch import nn
    linear = nn.Linear(2, 2)
    print(linear.weight)
    
    linear.to(torch.double)
    print(linear.weight)
    
    gpu1 = torch.device("cuda:0")
    linear.to(gpu1, dtype=torch.half, non_blocking=True)
    print(linear.weight)
    
    cpu = torch.device("cpu")
    linear.to(cpu)
    print(linear.weight)

    Output:

    Parameter containing:
    tensor([[0.4604, 0.5215],
            [0.5981, 0.5912]], requires_grad=True)
    Parameter containing:
    tensor([[0.4604, 0.5215],
            [0.5981, 0.5912]], dtype=torch.float64, requires_grad=True)
    Parameter containing:
    tensor([[0.4604, 0.5215],
            [0.5981, 0.5913]], device='cuda:0', dtype=torch.float16,
           requires_grad=True)
    Parameter containing:
    tensor([[0.4604, 0.5215],
            [0.5981, 0.5913]], dtype=torch.float16, requires_grad=True)

    6)

    type(dst_type)

    Casts all parameters and buffers to dst_type.

    Parameters:

    dst_type (type or string) – the desired type

    Example:

    import torch
    from torch import nn
    linear = nn.Linear(2, 2)
    print(linear.weight)
    
    linear.type(torch.double)
    print(linear.weight)

    Output:

    Parameter containing:
    tensor([[ 0.4370, -0.6806],
            [-0.4628, -0.4366]], requires_grad=True)
    Parameter containing:
    tensor([[ 0.4370, -0.6806],
            [-0.4628, -0.4366]], dtype=torch.float64, requires_grad=True)

    7)

    forward(*input)

    Defines the computation performed at every call.
    Should be overridden by all subclasses.

    ⚠️ Although the recipe for the forward pass needs to be defined within this function, you should afterwards call the module instance itself rather than this function, since the former takes care of running the registered hooks while the latter silently ignores them.

    This is the function we define when writing a module:

    def forward(self, x):

    When you call the model, this function is invoked:

    import torch
    import torchvision.models as models
    
    alexnet = models.alexnet()
    input_data = torch.randn(1, 3, 224, 224)  # a dummy batch of one 3x224x224 image
    output = alexnet(input_data)              # calling the instance invokes forward()
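
    The difference matters once hooks are involved. A minimal sketch (the print_shape hook below is only an illustration): a hook registered with register_forward_hook() fires when the module instance is called, but not when forward() is called directly:

    import torch
    from torch import nn

    def print_shape(module, inp, out):
        # forward hooks receive (module, input, output)
        print(type(module).__name__, out.shape)

    linear = nn.Linear(2, 2)
    linear.register_forward_hook(print_shape)

    x = torch.randn(3, 2)
    y = linear(x)           # runs the hook: prints "Linear torch.Size([3, 2])"
    y = linear.forward(x)   # silently skips the hook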

    8)

    apply(fn)

    Applies fn recursively to every submodule (the modules returned by .children()) as well as to itself. Typical use is initializing the parameters of a model (see also torch-nn-init).

    Parameters:

    • fn (Module -> None): function to be applied to each submodule

    Returns:

    • self

    Return type:

    • Module

    Example: see the separate post on initializing model parameters in pytorch.
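
    A minimal sketch of that typical use (the init_weights function here is illustrative, not from that post): initialize every Linear layer of a model via apply():

    import torch
    from torch import nn

    def init_weights(m):
        # called once for every submodule and once for the model itself
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.zeros_(m.bias)

    model = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 1))
    model.apply(init_weights)  # returns the model itself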

    9)

    named_parameters(prefix='', recurse=True)

    Returns an iterator over module parameters, yielding both the name of the parameter and the parameter itself.

    Parameters:

    • prefix (str) – prefix to prepend to all parameter names.

    • recurse (bool) – if True, yields parameters of this module and all its submodules recursively; if False, yields only parameters that are direct members of this module
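
    A minimal sketch of iterating over it on a small linear layer:

    from torch import nn

    linear = nn.Linear(2, 2)
    for name, param in linear.named_parameters():
        print(name, param.size())
    # weight torch.Size([2, 2])
    # bias torch.Size([2])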

    The Encoder example e (defined in 12) below) also relies on this: from the returned names we can see that a parameter value can be fetched directly by name:

    e.models.Conv2_3_64.weight.data

    Output:

    tensor([[[[ 1.8686e-02, -1.1276e-02,  1.0743e-02, -3.7258e-03],
              [ 1.7356e-02, -4.6002e-03, -1.5800e-02,  1.4272e-03],
              [-8.9406e-03,  2.8417e-02,  7.3844e-03, -2.0131e-02],
              [ 2.7378e-02, -1.3940e-02, -9.2417e-03, -1.3656e-02]],
    
             [[-2.6638e-02,  2.6307e-02, -2.9532e-02,  2.6932e-02],
              [-7.9886e-03,  3.4983e-03, -5.5121e-02,  1.8271e-02],
              [-4.3825e-02,  4.7733e-02, -3.5117e-02, -1.0677e-02],
              [-2.6437e-02, -4.5605e-03,  1.1901e-02, -1.9924e-02]],
    
             [[ 1.2108e-02, -2.0034e-02, -4.3065e-02, -4.4073e-03],
              [ 2.4294e-02,  2.0997e-04,  2.0511e-02,  4.0354e-02],
              [-7.4128e-03,  1.2180e-02,  2.1586e-02, -3.2092e-02],
              [-1.0036e-02, -1.3512e-02,  2.8016e-03,  1.7150e-02]]],
    
    
            [[[ 1.3010e-02, -7.7286e-03, -1.8568e-02,  2.6519e-03],
              [ 1.7086e-02, -3.7209e-03,  1.2222e-02, -9.8183e-03],
              [-1.2987e-02, -1.5011e-02,  1.0018e-02, -1.8424e-02],
              [-9.8759e-03,  3.1524e-03,  1.8473e-04,  3.0876e-02]],
    
             [[ 1.1653e-02, -3.5415e-02, -3.7799e-02,  1.5948e-02],
              [ 1.5886e-02, -2.0727e-02,  9.9321e-03, -2.6632e-02],
              [-1.3989e-02, -2.2149e-02, -1.6303e-02, -6.1840e-03],
              [-3.0577e-02, -8.2477e-03,  3.2550e-02,  3.0350e-02]],
    
             [[ 4.9647e-05,  2.5028e-02,  5.4636e-03, -2.2217e-02],
              [-1.7287e-02, -9.8452e-03, -2.1045e-02,  5.6478e-03],
              [ 9.7147e-03,  2.0614e-02, -1.5295e-02,  3.4130e-02],
              [ 4.1918e-02, -3.1760e-02,  7.8219e-03,  5.0951e-03]]],
    
    
            [[[-1.5743e-02,  3.2101e-02, -5.7166e-03,  3.7152e-02],
              [-8.6509e-03, -2.9025e-02,  1.2311e-02,  4.1298e-02],
              [ 1.3912e-02, -2.6538e-02,  1.2670e-02, -2.8338e-02],
              [ 1.7593e-04,  5.0950e-03, -3.0340e-02,  2.1955e-03]],
    
             [[ 4.7826e-03,  1.9481e-02,  5.3423e-03, -1.2969e-02],
              [ 5.1746e-03, -3.3188e-03, -2.3011e-02,  3.4073e-02],
              [ 1.5636e-02, -5.5335e-02,  1.1528e-03, -1.3905e-02],
              [ 9.9208e-03, -8.0908e-03, -9.8275e-03, -2.1614e-02]],
    
             [[ 9.2276e-03, -7.6164e-03,  8.6449e-03, -5.7667e-03],
              [ 2.2497e-02, -2.6568e-02,  2.9182e-02,  1.0791e-02],
              [ 2.8791e-02, -3.9055e-02,  4.0457e-04, -2.1397e-03],
              [-4.0300e-03, -2.0704e-03, -1.7246e-02,  3.2432e-02]]],
    
    
            ...,
    
    
            [[[ 1.7486e-02,  1.1616e-02, -1.2516e-02, -9.7095e-03],
              [-1.2367e-02,  3.0512e-02,  5.0169e-02,  1.1539e-02],
              [ 1.6477e-04,  2.5155e-03, -3.5218e-02, -1.3211e-02],
              [-1.3205e-02,  1.0017e-02,  4.2839e-02, -6.9317e-03]],
    
             [[-1.2817e-02,  3.1915e-02,  7.9632e-03, -6.4066e-03],
              [ 3.8245e-02,  1.1355e-02,  1.5460e-02, -1.1245e-03],
              [ 2.1138e-02, -2.4878e-03,  3.1970e-03,  4.2895e-02],
              [-2.4187e-02, -4.8445e-04, -2.5516e-02,  4.0083e-02]],
    
             [[ 2.0978e-02, -1.5094e-02,  3.0770e-02,  2.5550e-02],
              [ 8.2029e-03,  1.4726e-03,  1.2099e-02, -2.1542e-02],
              [ 6.7198e-03, -1.7803e-02, -4.8138e-03, -1.2432e-02],
              [-3.7668e-03, -1.9681e-02, -2.0834e-03,  8.3174e-04]]],
    
    
            [[[ 3.1066e-03, -1.3706e-02,  9.3733e-03,  1.2344e-02],
              [ 1.6753e-02,  1.4869e-03, -2.0681e-03, -8.8953e-03],
              [-3.0745e-02,  1.1374e-02,  2.1523e-02, -2.4726e-02],
              [ 1.0182e-02,  2.0394e-02,  5.5662e-04,  2.0951e-02]],
    
             [[ 2.1782e-02,  6.3107e-04,  1.6017e-02,  2.7767e-03],
              [ 7.6418e-03, -8.8861e-03, -2.2702e-02, -1.9778e-02],
              [ 2.2941e-02,  4.4974e-03, -2.7368e-02, -9.5090e-05],
              [ 3.2708e-02, -3.3382e-03,  1.5445e-02, -1.7446e-02]],
    
             [[ 1.5597e-02, -3.0816e-02,  1.4011e-02, -2.7484e-02],
              [ 2.3591e-03,  4.3519e-02, -1.3367e-02,  1.3066e-02],
              [-7.6286e-03, -4.7996e-03,  5.1619e-03, -1.1260e-02],
              [-1.5147e-02,  1.2956e-02, -2.5945e-02,  2.2437e-02]]],
    
    
            [[[ 2.1797e-02,  2.7596e-03, -2.0974e-02, -4.3435e-03],
              [ 4.6751e-03, -4.2520e-02, -1.0819e-02,  7.4361e-03],
              [ 4.7468e-02, -2.4098e-02,  7.5113e-04, -2.3566e-02],
              [ 1.6562e-03,  1.5573e-02,  1.5934e-02,  1.9551e-02]],
    
             [[ 1.7714e-02,  1.6497e-02,  1.9895e-02, -1.3463e-02],
              [ 1.6372e-02, -1.3358e-02,  2.0040e-02, -4.1047e-02],
              [-3.9821e-03,  1.3126e-02, -1.4217e-02,  5.7594e-03],
              [-2.2151e-02, -1.7522e-02,  2.9157e-03,  2.4983e-02]],
    
             [[-2.5523e-02,  1.2045e-02,  2.9011e-03, -1.2715e-02],
              [ 2.8795e-02, -2.6586e-02,  1.8300e-02,  3.7996e-02],
              [ 1.2800e-02, -1.6446e-02, -5.4592e-03, -1.6855e-02],
              [-4.6871e-02,  3.9172e-02,  2.6660e-02, -3.2577e-02]]]])

    10)

    parameters(recurse=True)

    Returns an iterator over module parameters, yielding the parameters themselves directly.

    Parameters:

    • recurse (bool) – if True, yields parameters of this module and all its submodules recursively; if False, yields only parameters that are direct members of this module

    Example:

    for param in e.parameters():
        print(type(param.data), param.size())

    Output:

    <class 'torch.Tensor'> torch.Size([64, 3, 4, 4])
    <class 'torch.Tensor'> torch.Size([128, 64, 4, 4])
    <class 'torch.Tensor'> torch.Size([128])
    <class 'torch.Tensor'> torch.Size([128])
    <class 'torch.Tensor'> torch.Size([256, 128, 4, 4])
    <class 'torch.Tensor'> torch.Size([256])
    <class 'torch.Tensor'> torch.Size([256])
    <class 'torch.Tensor'> torch.Size([512, 256, 4, 4])
    <class 'torch.Tensor'> torch.Size([512])
    <class 'torch.Tensor'> torch.Size([512])
    <class 'torch.Tensor'> torch.Size([1024, 512, 4, 4])
    <class 'torch.Tensor'> torch.Size([1024])
    <class 'torch.Tensor'> torch.Size([1024])
    <class 'torch.Tensor'> torch.Size([2048, 1024, 4, 4])
    <class 'torch.Tensor'> torch.Size([2048])
    <class 'torch.Tensor'> torch.Size([2048])
    <class 'torch.Tensor'> torch.Size([100, 2048, 4, 4])

    11)

    register_parameter(name, param)

    Adds a parameter to the module.

    The parameter can be accessed as an attribute using the given name.

    Parameters:

    • name (string) – name of the parameter. The parameter can be accessed from this module using the given name

    • param (Parameter) – parameter to be added to the module

    Example:

    import torch as t
    from torch import nn
    from torch.autograd import Variable as V
     
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.register_parameter('param1', nn.Parameter(t.randn(3, 3)))
            # equivalent to: self.param1 = nn.Parameter(t.randn(3, 3))
            self.submodel1 = nn.Linear(3, 4)
        def forward(self, input):
            print('input : ', input.data)
            x = self.param1.mm(input)  # multiply param1 by the input to get x
            print(x.size())
            print()
            print('middle x :', x)
            x = self.submodel1(x)
            return x
        
    net = Net()
    x = V(t.randn(3, 3))
    output = net(x)
    print()
    print('output : ', output)
    print()
    for name, param in net.named_parameters():
        print(name)
        print(param.size())
        print(param)

    Output:

    input :  tensor([[-0.6774, -0.1080, -2.9368],
            [-0.7825,  1.4518, -1.5265],
            [-1.3426,  0.2754,  0.6105]])
    torch.Size([3, 3])
    
    middle x : tensor([[ 0.5576, -0.9339, -2.0338],
            [ 2.2566, -1.7668, -4.6034],
            [-0.0908, -0.6854, -0.2914]], grad_fn=<MmBackward>)
    
    output :  tensor([[-1.1309, -1.0884, -0.3657, -1.6447],
            [-2.3293, -1.8145, -1.4426, -2.9277],
            [-0.3567, -0.7607,  0.2292, -0.7849]], grad_fn=<AddmmBackward>)
    
    param1
    torch.Size([3, 3])
    Parameter containing:
    tensor([[ 0.8252, -0.4768, -0.5539],
            [ 1.5196, -0.7191, -2.0285],
            [ 0.3769, -0.4731,  0.1532]], requires_grad=True)
    submodel1.weight
    torch.Size([4, 3])
    Parameter containing:
    tensor([[ 0.0304,  0.1698,  0.4314],
            [-0.1409,  0.2963,  0.0934],
            [-0.4779, -0.3330,  0.2111],
            [ 0.2737,  0.4682,  0.5285]], requires_grad=True)
    submodel1.bias
    torch.Size([4])
    Parameter containing:
    tensor([-0.1118, -0.5433,  0.0191, -0.2851], requires_grad=True)

    12)

    children()

    Returns an iterator over immediate child submodules.

    Example:

    # coding:utf-8
    from torch import nn
    
    class Encoder(nn.Module):
        def __init__(self, input_size, input_channels, base_channnes, z_channels):
    
            super(Encoder, self).__init__()
            # input_size must be a multiple of 16
            assert input_size % 16 == 0, "input_size has to be a multiple of 16"
    
            models = nn.Sequential()
            models.add_module('Conv2_{0}_{1}'.format(input_channels, base_channnes), nn.Conv2d(input_channels, base_channnes, 4, 2, 1, bias=False))
            models.add_module('LeakyReLU_{0}'.format(base_channnes), nn.LeakyReLU(0.2, inplace=True))
            # at this point the image size has been halved once
            temp_size = input_size/2
    
            # keep downsampling until the feature map is 4x4
            # this guarantees that, whatever the input size, the feature map ends up 4*4 after these layers
            while temp_size > 4 :
                models.add_module('Conv2_{0}_{1}'.format(base_channnes, base_channnes*2), nn.Conv2d(base_channnes, base_channnes*2, 4, 2, 1, bias=False))
                models.add_module('BatchNorm2d_{0}'.format(base_channnes*2), nn.BatchNorm2d(base_channnes*2))
                models.add_module('LeakyReLU_{0}'.format(base_channnes*2), nn.LeakyReLU(0.2, inplace=True))
                base_channnes *= 2
                temp_size /= 2
    
            # once the feature map is 4x4, append the final layer
            # so that the output is 1*1
            models.add_module('Conv2_{0}_{1}'.format(base_channnes, z_channels), nn.Conv2d(base_channnes, z_channels, 4, 1, 0, bias=False))
            self.models = models
    
        def forward(self, x):
            x = self.models(x)
            return x
    
    e = Encoder(256, 3, 64, 100)
    for child in e.children():
        print(child)

    Output:

    Sequential(
      (Conv2_3_64): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (LeakyReLU_64): LeakyReLU(negative_slope=0.2, inplace)
      (Conv2_64_128): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (BatchNorm2d_128): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyReLU_128): LeakyReLU(negative_slope=0.2, inplace)
      (Conv2_128_256): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (BatchNorm2d_256): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyReLU_256): LeakyReLU(negative_slope=0.2, inplace)
      (Conv2_256_512): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (BatchNorm2d_512): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyReLU_512): LeakyReLU(negative_slope=0.2, inplace)
      (Conv2_512_1024): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (BatchNorm2d_1024): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyReLU_1024): LeakyReLU(negative_slope=0.2, inplace)
      (Conv2_1024_2048): Conv2d(1024, 2048, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (BatchNorm2d_2048): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyReLU_2048): LeakyReLU(negative_slope=0.2, inplace)
      (Conv2_2048_100): Conv2d(2048, 100, kernel_size=(4, 4), stride=(1, 1), bias=False)
    )

    This yields the immediate child modules of the defined model — here the single Sequential that holds all the layers.

    13)

    named_children()

    Returns an iterator over immediate child submodules, yielding tuples of the module's name and the module itself (string, Module).

    Continuing with the bigger example above:

    for name, child in e.named_children():
        print(name)
        print(child)

    Output:

    models
    Sequential(
      (Conv2_3_64): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (LeakyReLU_64): LeakyReLU(negative_slope=0.2, inplace)
      (Conv2_64_128): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (BatchNorm2d_128): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyReLU_128): LeakyReLU(negative_slope=0.2, inplace)
      (Conv2_128_256): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (BatchNorm2d_256): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyReLU_256): LeakyReLU(negative_slope=0.2, inplace)
      (Conv2_256_512): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (BatchNorm2d_512): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyReLU_512): LeakyReLU(negative_slope=0.2, inplace)
      (Conv2_512_1024): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (BatchNorm2d_1024): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyReLU_1024): LeakyReLU(negative_slope=0.2, inplace)
      (Conv2_1024_2048): Conv2d(1024, 2048, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (BatchNorm2d_2048): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyReLU_2048): LeakyReLU(negative_slope=0.2, inplace)
      (Conv2_2048_100): Conv2d(2048, 100, kernel_size=(4, 4), stride=(1, 1), bias=False)
    )

    The name of this whole submodule is models, so its information can also be obtained directly with:

    e.models

    A simpler example:

    l = nn.Linear(2, 2)
    model = nn.Sequential(nn.Linear(2,2),
                 nn.ReLU(inplace=True),
                 nn.Sequential(l,l)
            )
    for name, module in model.named_children():
        print(name)
        print(module)

    Output:

    0
    Linear(in_features=2, out_features=2, bias=True)
    1
    ReLU(inplace)
    2
    Sequential(
      (0): Linear(in_features=2, out_features=2, bias=True)
      (1): Linear(in_features=2, out_features=2, bias=True)
    )

    14)

    modules()

    Returns an iterator over all modules in the network, descending level by level down to the innermost layers; any module that occurs more than once is returned only a single time.

    Continuing the example above:

    for module in e.modules():
        print(module)

    Output:

    Encoder(
      (models): Sequential(
        (Conv2_3_64): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (LeakyReLU_64): LeakyReLU(negative_slope=0.2, inplace)
        (Conv2_64_128): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (BatchNorm2d_128): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (LeakyReLU_128): LeakyReLU(negative_slope=0.2, inplace)
        (Conv2_128_256): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (BatchNorm2d_256): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (LeakyReLU_256): LeakyReLU(negative_slope=0.2, inplace)
        (Conv2_256_512): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (BatchNorm2d_512): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (LeakyReLU_512): LeakyReLU(negative_slope=0.2, inplace)
        (Conv2_512_1024): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (BatchNorm2d_1024): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (LeakyReLU_1024): LeakyReLU(negative_slope=0.2, inplace)
        (Conv2_1024_2048): Conv2d(1024, 2048, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (BatchNorm2d_2048): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (LeakyReLU_2048): LeakyReLU(negative_slope=0.2, inplace)
        (Conv2_2048_100): Conv2d(2048, 100, kernel_size=(4, 4), stride=(1, 1), bias=False)
      )
    )
    Sequential(
      (Conv2_3_64): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (LeakyReLU_64): LeakyReLU(negative_slope=0.2, inplace)
      (Conv2_64_128): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (BatchNorm2d_128): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyReLU_128): LeakyReLU(negative_slope=0.2, inplace)
      (Conv2_128_256): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (BatchNorm2d_256): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyReLU_256): LeakyReLU(negative_slope=0.2, inplace)
      (Conv2_256_512): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (BatchNorm2d_512): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyReLU_512): LeakyReLU(negative_slope=0.2, inplace)
      (Conv2_512_1024): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (BatchNorm2d_1024): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyReLU_1024): LeakyReLU(negative_slope=0.2, inplace)
      (Conv2_1024_2048): Conv2d(1024, 2048, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (BatchNorm2d_2048): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyReLU_2048): LeakyReLU(negative_slope=0.2, inplace)
      (Conv2_2048_100): Conv2d(2048, 100, kernel_size=(4, 4), stride=(1, 1), bias=False)
    )
    Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    LeakyReLU(negative_slope=0.2, inplace)
    Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    LeakyReLU(negative_slope=0.2, inplace)
    Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    LeakyReLU(negative_slope=0.2, inplace)
    Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    LeakyReLU(negative_slope=0.2, inplace)
    Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    LeakyReLU(negative_slope=0.2, inplace)
    Conv2d(1024, 2048, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    LeakyReLU(negative_slope=0.2, inplace)
    Conv2d(2048, 100, kernel_size=(4, 4), stride=(1, 1), bias=False)

    A simpler example:

    model = nn.Sequential(nn.Linear(2,2),
                 nn.ReLU(inplace=True),
                 nn.Sequential(nn.Linear(2,2),
                              nn.Linear(2,2)
                    )
            )
    for idx, m in enumerate(model.modules()):
        print(idx, '->', m)

    Output:

    0 -> Sequential(
      (0): Linear(in_features=2, out_features=2, bias=True)
      (1): ReLU(inplace)
      (2): Sequential(
        (0): Linear(in_features=2, out_features=2, bias=True)
        (1): Linear(in_features=2, out_features=2, bias=True)
      )
    )
    1 -> Linear(in_features=2, out_features=2, bias=True)
    2 -> ReLU(inplace)
    3 -> Sequential(
      (0): Linear(in_features=2, out_features=2, bias=True)
      (1): Linear(in_features=2, out_features=2, bias=True)
    )
    4 -> Linear(in_features=2, out_features=2, bias=True)
    5 -> Linear(in_features=2, out_features=2, bias=True)

    Notice that Linear is still returned twice here, because these are two distinct Linear modules; deduplication only applies when the very same module object is reused, as in the next example:

    l = nn.Linear(2, 2)
    model = nn.Sequential(nn.Linear(2,2),
                 nn.ReLU(inplace=True),
                 nn.Sequential(l,l)
            )
    for idx, m in enumerate(model.modules()):
        print(idx, '->', m)

    Output:

    0 -> Sequential(
      (0): Linear(in_features=2, out_features=2, bias=True)
      (1): ReLU(inplace)
      (2): Sequential(
        (0): Linear(in_features=2, out_features=2, bias=True)
        (1): Linear(in_features=2, out_features=2, bias=True)
      )
    )
    1 -> Linear(in_features=2, out_features=2, bias=True)
    2 -> ReLU(inplace)
    3 -> Sequential(
      (0): Linear(in_features=2, out_features=2, bias=True)
      (1): Linear(in_features=2, out_features=2, bias=True)
    )
    4 -> Linear(in_features=2, out_features=2, bias=True)

    Here the shared Linear is indeed returned only once.

    15)

    named_modules(memo=None, prefix='')

    Returns an iterator over all modules in the network, yielding tuples of the module's name and the module itself.

    for name, module in e.named_modules():
        print(name)
        print(module)


    Every module is yielded together with its qualified name, so a module can be accessed directly by that name:

    e.models.LeakyReLU_256

    Output:

    LeakyReLU(negative_slope=0.2, inplace)
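
    On the simple nested model from 13), the yielded names make the nesting explicit (a minimal sketch; the empty name refers to the model itself, and the shared Linear l appears only once):

    from torch import nn

    l = nn.Linear(2, 2)
    model = nn.Sequential(nn.Linear(2, 2),
                 nn.ReLU(inplace=True),
                 nn.Sequential(l, l)
            )
    for name, module in model.named_modules():
        print(name, '->', type(module).__name__)
    # '' -> Sequential (the model), 0 -> Linear, 1 -> ReLU,
    # 2 -> Sequential, 2.0 -> Linear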

    16)

    add_module(name, module) 

    Adds a child module to the current module.

    The added child module can be accessed as an attribute using the given name.

    Parameters:

    • name (string): name of the child module. The child module can be accessed from this module using the given name
    • module (Module): child module to be added to the module

    Example:

    The model above can then be written as:

    import torch.nn as nn
    
    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.model = nn.Sequential()
            self.model.add_module('conv1', nn.Conv2d(1, 20, 5))
            self.model.add_module('relu1', nn.ReLU(inplace=True))
            self.model.add_module('conv2', nn.Conv2d(20, 20, 5))
            self.model.add_module('relu2', nn.ReLU(inplace=True))
    
        def forward(self, x):
           x = self.model(x)
           return x

    17)

    buffers(recurse=True)

    Returns an iterator over module buffers. A buffer holds persistent state that is not a parameter but is carried from one forward pass to the next, such as the running mean and variance used by BatchNorm2d(), which are updated as batches pass through the BatchNorm2d() layer during training.

    Parameters:

    • recurse (bool) – if True, yields buffers of this module and all its submodules recursively; if False, yields only buffers that are direct members of this module

      # coding:utf-8
      from torch import nn
      
      
      class Encoder(nn.Module):
          def __init__(self, input_size, input_channels, base_channnes, z_channels):
      
              super(Encoder, self).__init__()
              # input_size must be a multiple of 16
              assert input_size % 16 == 0, "input_size has to be a multiple of 16"
      
              models = nn.Sequential()
              models.add_module('Conv2_{0}_{1}'.format(input_channels, base_channnes), nn.Conv2d(input_channels, base_channnes, 4, 2, 1, bias=False))
              models.add_module('LeakyReLU_{0}'.format(base_channnes), nn.LeakyReLU(0.2, inplace=True))
              # at this point the image size has been halved once
              temp_size = input_size/2
      
              # keep downsampling until the feature map is 4x4
              # this guarantees that, whatever the input size, the feature map ends up 4*4 after these layers
              while temp_size > 4 :
                  models.add_module('Conv2_{0}_{1}'.format(base_channnes, base_channnes*2), nn.Conv2d(base_channnes, base_channnes*2, 4, 2, 1, bias=False))
                  models.add_module('BatchNorm2d_{0}'.format(base_channnes*2), nn.BatchNorm2d(base_channnes*2))
                  models.add_module('LeakyReLU_{0}'.format(base_channnes*2), nn.LeakyReLU(0.2, inplace=True))
                  base_channnes *= 2
                  temp_size /= 2
      
              # once the feature map is 4x4, append the final layer
              # so that the output is 1*1
              models.add_module('Conv2_{0}_{1}'.format(base_channnes, z_channels), nn.Conv2d(base_channnes, z_channels, 4, 1, 0, bias=False))
              self.models = models
      
          def forward(self, x):
              x = self.models(x)
              return x
      
      e = Encoder(256, 3, 64, 100)
      
      for buffer in e.buffers():
          print(buffer)
          print(buffer.size())

    A fuller example, which runs the model and inspects the buffers before and after a forward pass:

    # coding:utf-8
    import torch
    from torch import nn
    from torch.autograd import Variable
    
    
    def weights_init(mod):
        """Custom weight-initialization function"""
        classname = mod.__class__.__name__
        if classname.find('Conv') != -1:    # 'Conv' and 'BatchNorm' here are the torch.nn class names
            mod.weight.data.normal_(0.0,0.02)
        elif classname.find('BatchNorm')!= -1:
            mod.weight.data.normal_(1.0, 0.02)  # initialize the BN layer's gamma from N(1.0, 0.02)
            mod.bias.data.fill_(0)  # initialize the BN layer's beta to 0
    
    class Encoder(nn.Module):
        def __init__(self, input_size, input_channels, base_channnes, z_channels):
    
            super(Encoder, self).__init__()
            # input_size must be a multiple of 16
            assert input_size % 16 == 0, "input_size has to be a multiple of 16"
    
            models = nn.Sequential()
            models.add_module('Conv2_{0}_{1}'.format(input_channels, base_channnes), nn.Conv2d(input_channels, base_channnes, 4, 2, 1, bias=False))
            models.add_module('LeakyReLU_{0}'.format(base_channnes), nn.LeakyReLU(0.2, inplace=True))
            # at this point the image size has been halved once
            temp_size = input_size/2
    
            # keep downsampling until the feature map is 4x4
            # this guarantees that, whatever the input size, the feature map ends up 4*4 after these layers
            while temp_size > 4 :
                models.add_module('Conv2_{0}_{1}'.format(base_channnes, base_channnes*2), nn.Conv2d(base_channnes, base_channnes*2, 4, 2, 1, bias=False))
                models.add_module('BatchNorm2d_{0}'.format(base_channnes*2), nn.BatchNorm2d(base_channnes*2))
                models.add_module('LeakyReLU_{0}'.format(base_channnes*2), nn.LeakyReLU(0.2, inplace=True))
                base_channnes *= 2
                temp_size /= 2
    
            # once the feature map is 4x4, append the final layer
            # so that the output is 1*1
            models.add_module('Conv2_{0}_{1}'.format(base_channnes, z_channels), nn.Conv2d(base_channnes, z_channels, 4, 1, 0, bias=False))
            self.models = models
    
        def forward(self, x):
            x = self.models(x)
            return x
    
    e = Encoder(256, 3, 64, 100)
    e.apply(weights_init)
    print('before running :')
    for buffer in e.buffers():
        print(buffer)
        print(buffer.size())
        
    x = Variable(torch.randn(2,3,256,256))
    output = e(x)
    
    print('after running :')
    for buffer in e.buffers():
        print(buffer)
        print(buffer.size())

    Output:

    before running :
    tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0.])
    torch.Size([128])
    tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1.])
    torch.Size([128])
    tensor(0)
    torch.Size([])
    tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
    torch.Size([256])
    tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1.])
    torch.Size([256])
    tensor(0)
    torch.Size([])
    tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0.])
    torch.Size([512])
    tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
            1., 1., 1., 1., 1., 1., 1., 1.])
    torch.Size([512])
    tensor(0)
    torch.Size([])
    tensor([0., 0., 0.,  ..., 0., 0., 0.])
    torch.Size([1024])
    tensor([1., 1., 1.,  ..., 1., 1., 1.])
    torch.Size([1024])
    tensor(0)
    torch.Size([])
    tensor([0., 0., 0.,  ..., 0., 0., 0.])
    torch.Size([2048])
    tensor([1., 1., 1.,  ..., 1., 1., 1.])
    torch.Size([2048])
    tensor(0)
    torch.Size([])
    after running :
    tensor([-3.3760e-03,  1.1698e-03,  3.6801e-03, -2.9386e-03, -9.6070e-04,
            -3.9772e-03,  4.3308e-04, -3.1600e-04,  4.0223e-04,  1.8968e-03,
             1.6064e-03,  3.1311e-03,  2.5905e-03, -1.9954e-03, -1.9760e-03,
             3.8538e-03, -2.7571e-03, -1.7814e-03,  1.2943e-04, -1.0755e-03,
            -2.7892e-03, -2.9490e-03,  1.4452e-03,  1.7381e-03, -2.8058e-03,
             4.1997e-04, -7.3607e-03,  7.9688e-04,  1.0959e-03, -3.6058e-03,
            -1.0386e-03, -7.6220e-04, -2.6786e-03,  5.3019e-03, -1.2099e-03,
             3.1005e-03, -2.4421e-03,  3.9982e-03, -1.3801e-03, -3.2220e-04,
             1.4922e-03,  6.3325e-04,  9.6503e-04, -1.5298e-03,  2.2660e-03,
            -2.3133e-03,  1.9339e-03, -2.4072e-03, -1.9225e-03, -9.9753e-04,
             2.3214e-03,  5.0352e-03, -1.1458e-03,  4.7263e-03,  1.1954e-03,
             3.3723e-03,  4.7266e-03, -4.6656e-03,  4.9964e-04, -2.2194e-03,
             1.7171e-03, -6.0177e-04, -2.5741e-03,  1.1872e-03, -4.0245e-03,
            -3.4781e-03,  1.4507e-03,  6.1694e-05,  1.4087e-03, -4.7972e-03,
            -2.6325e-03,  5.8721e-03, -2.2517e-03, -6.4260e-04, -1.9965e-03,
             8.3321e-04, -1.6526e-03,  1.1089e-03,  6.2366e-03, -2.7464e-03,
             4.5316e-03, -3.7131e-03,  1.9032e-03, -4.5944e-04,  1.5664e-03,
             1.0817e-03, -4.7231e-03, -1.8417e-03, -5.9235e-03,  9.6230e-04,
             2.7968e-03,  2.6654e-04,  1.0158e-03,  3.2729e-03,  1.4751e-03,
            -1.3901e-03,  1.1596e-03,  1.8867e-03,  3.4735e-04,  1.7324e-03,
             3.7804e-04,  2.5138e-03, -7.7367e-03,  3.7004e-03,  6.5667e-04,
            -3.0492e-04, -1.1047e-03,  3.0829e-03,  8.9938e-03, -4.8453e-03,
            -2.4141e-03, -2.5017e-03, -2.0548e-03, -1.3747e-03, -1.0339e-03,
            -2.4000e-03,  7.9873e-04,  7.9712e-04,  7.7021e-04, -2.6673e-04,
            -5.4646e-03,  3.6639e-03,  1.1140e-04, -1.6342e-03,  2.5980e-04,
             2.9192e-05,  1.5542e-03, -2.1954e-04])
    torch.Size([128])
    tensor([0.9003, 0.9003, 0.9004, 0.9003, 0.9003, 0.9004, 0.9003, 0.9004, 0.9003,
            0.9003, 0.9004, 0.9003, 0.9003, 0.9003, 0.9004, 0.9003, 0.9004, 0.9003,
            0.9003, 0.9003, 0.9003, 0.9003, 0.9003, 0.9003, 0.9003, 0.9003, 0.9004,
            0.9004, 0.9003, 0.9003, 0.9003, 0.9003, 0.9004, 0.9003, 0.9003, 0.9003,
            0.9003, 0.9003, 0.9003, 0.9003, 0.9004, 0.9004, 0.9003, 0.9003, 0.9003,
            0.9003, 0.9003, 0.9003, 0.9003, 0.9004, 0.9003, 0.9003, 0.9003, 0.9003,
            0.9003, 0.9003, 0.9003, 0.9003, 0.9003, 0.9003, 0.9004, 0.9003, 0.9003,
            0.9003, 0.9004, 0.9003, 0.9003, 0.9004, 0.9004, 0.9003, 0.9003, 0.9003,
            0.9003, 0.9003, 0.9003, 0.9003, 0.9003, 0.9003, 0.9003, 0.9003, 0.9003,
            0.9004, 0.9003, 0.9003, 0.9004, 0.9003, 0.9003, 0.9003, 0.9003, 0.9003,
            0.9003, 0.9003, 0.9003, 0.9003, 0.9003, 0.9003, 0.9003, 0.9003, 0.9003,
            0.9003, 0.9003, 0.9003, 0.9003, 0.9004, 0.9003, 0.9003, 0.9003, 0.9003,
            0.9003, 0.9004, 0.9003, 0.9003, 0.9003, 0.9003, 0.9004, 0.9003, 0.9003,
            0.9003, 0.9003, 0.9003, 0.9003, 0.9003, 0.9003, 0.9003, 0.9003, 0.9004,
            0.9003, 0.9003])
    torch.Size([128])
    tensor(1)
    torch.Size([])
    tensor([-0.0019, -0.0237, -0.0149, -0.0089,  0.0008,  0.0320, -0.0008, -0.0491,
             0.0075,  0.0201, -0.0215,  0.0047, -0.0195, -0.0045,  0.0030,  0.0399,
            -0.0123,  0.0014,  0.0482, -0.0182, -0.0409, -0.0087, -0.0104,  0.0543,
             0.0212,  0.0173,  0.0623, -0.0083,  0.0008, -0.0013, -0.0316, -0.0318,
            -0.0155, -0.0594,  0.0242, -0.0266, -0.0148,  0.0026,  0.0026,  0.0329,
            -0.0112,  0.0113,  0.0320,  0.0206, -0.0157,  0.0302,  0.0377, -0.0074,
            -0.0336, -0.0215, -0.0110, -0.0168, -0.0328,  0.0341,  0.0375,  0.0764,
            -0.0439, -0.0002, -0.0004,  0.0603, -0.0637,  0.0300, -0.0397, -0.0093,
             0.0191, -0.0357, -0.0260, -0.0022,  0.0356, -0.0065, -0.0297,  0.0398,
            -0.0045, -0.0121, -0.0308,  0.0257,  0.0023,  0.0278,  0.0019,  0.0233,
             0.0149,  0.0043,  0.0738, -0.0094,  0.0467, -0.0391, -0.0139, -0.0862,
             0.0327,  0.0174,  0.0600,  0.0419,  0.0353, -0.0563,  0.0173, -0.0065,
            -0.0263,  0.0086, -0.0065, -0.0103, -0.0190, -0.0085, -0.0237, -0.0348,
             0.0054, -0.0087, -0.0366,  0.0035, -0.0313,  0.0338, -0.0043,  0.0117,
            -0.0421, -0.0069,  0.0043, -0.0150, -0.0209, -0.0303, -0.0172,  0.0275,
            -0.0198,  0.0201,  0.0278, -0.0054,  0.0107, -0.0460,  0.0145,  0.0132,
             0.0185, -0.0072, -0.0604, -0.0555,  0.0024,  0.0016, -0.0203,  0.0131,
            -0.0095, -0.0277, -0.0319, -0.0508,  0.0157,  0.0187, -0.0374, -0.0069,
             0.0006,  0.0369, -0.0368,  0.0190, -0.0449, -0.0174, -0.0294,  0.0118,
             0.0156, -0.0023, -0.0215, -0.0277, -0.0202,  0.0006, -0.0061, -0.0270,
            -0.0335, -0.0117, -0.0078, -0.0142, -0.0146,  0.0530, -0.0320, -0.0071,
             0.0168,  0.0243,  0.0019,  0.0568,  0.0356,  0.0171,  0.0044,  0.0371,
            -0.0322,  0.0361, -0.0339, -0.0184, -0.0050,  0.0136, -0.0189, -0.0132,
             0.0293,  0.0327,  0.0003,  0.0728, -0.0283, -0.0161,  0.0203,  0.0029,
            -0.0185,  0.0667, -0.0415,  0.0123, -0.0130,  0.0591,  0.0022,  0.0059,
             0.0290,  0.0413,  0.0351, -0.0014,  0.0091,  0.0004,  0.0048, -0.0142,
             0.0122, -0.0014, -0.0210, -0.0031,  0.0061,  0.0272,  0.0164,  0.0112,
            -0.0013,  0.0124,  0.0151,  0.0094,  0.0321, -0.0046,  0.0433, -0.0329,
            -0.0341, -0.0119, -0.0256,  0.0374, -0.0197, -0.0075,  0.0088, -0.0352,
            -0.0273, -0.0386, -0.0048, -0.0119, -0.0162,  0.0164,  0.0436,  0.0001,
            -0.0352, -0.0520,  0.0033,  0.0385,  0.0317, -0.0395,  0.0175, -0.0227,
             0.0405,  0.0333, -0.0135, -0.0067, -0.0549,  0.0137,  0.0043, -0.0050])
    torch.Size([256])
    tensor([0.9342, 0.9350, 0.9314, 0.9305, 0.9322, 0.9304, 0.9327, 0.9336, 0.9343,
            0.9365, 0.9336, 0.9292, 0.9342, 0.9334, 0.9314, 0.9303, 0.9349, 0.9325,
            0.9361, 0.9327, 0.9334, 0.9324, 0.9318, 0.9336, 0.9331, 0.9323, 0.9315,
            0.9357, 0.9313, 0.9337, 0.9355, 0.9326, 0.9325, 0.9320, 0.9330, 0.9306,
            0.9346, 0.9337, 0.9346, 0.9331, 0.9318, 0.9360, 0.9384, 0.9320, 0.9348,
            0.9340, 0.9335, 0.9328, 0.9316, 0.9345, 0.9313, 0.9345, 0.9345, 0.9366,
            0.9358, 0.9338, 0.9321, 0.9294, 0.9332, 0.9355, 0.9353, 0.9328, 0.9359,
            0.9336, 0.9320, 0.9331, 0.9305, 0.9304, 0.9322, 0.9340, 0.9346, 0.9355,
            0.9325, 0.9326, 0.9326, 0.9305, 0.9337, 0.9338, 0.9353, 0.9326, 0.9334,
            0.9325, 0.9352, 0.9339, 0.9327, 0.9336, 0.9352, 0.9358, 0.9340, 0.9319,
            0.9318, 0.9372, 0.9348, 0.9328, 0.9330, 0.9333, 0.9336, 0.9311, 0.9331,
            0.9338, 0.9336, 0.9358, 0.9373, 0.9317, 0.9340, 0.9308, 0.9320, 0.9319,
            0.9393, 0.9369, 0.9316, 0.9340, 0.9353, 0.9369, 0.9347, 0.9311, 0.9348,
            0.9367, 0.9291, 0.9358, 0.9337, 0.9342, 0.9347, 0.9331, 0.9329, 0.9330,
            0.9313, 0.9306, 0.9336, 0.9327, 0.9315, 0.9323, 0.9316, 0.9318, 0.9362,
            0.9335, 0.9338, 0.9326, 0.9327, 0.9361, 0.9355, 0.9347, 0.9316, 0.9322,
            0.9329, 0.9336, 0.9319, 0.9307, 0.9350, 0.9316, 0.9369, 0.9347, 0.9345,
            0.9336, 0.9332, 0.9368, 0.9355, 0.9361, 0.9336, 0.9330, 0.9349, 0.9331,
            0.9355, 0.9334, 0.9364, 0.9366, 0.9341, 0.9356, 0.9342, 0.9316, 0.9339,
            0.9315, 0.9341, 0.9285, 0.9377, 0.9333, 0.9318, 0.9342, 0.9332, 0.9349,
            0.9346, 0.9320, 0.9312, 0.9321, 0.9317, 0.9328, 0.9346, 0.9309, 0.9330,
            0.9316, 0.9343, 0.9342, 0.9311, 0.9350, 0.9315, 0.9334, 0.9337, 0.9308,
            0.9343, 0.9338, 0.9335, 0.9343, 0.9318, 0.9355, 0.9337, 0.9360, 0.9327,
            0.9337, 0.9342, 0.9356, 0.9338, 0.9340, 0.9333, 0.9307, 0.9309, 0.9305,
            0.9341, 0.9340, 0.9311, 0.9327, 0.9316, 0.9318, 0.9358, 0.9329, 0.9334,
            0.9363, 0.9363, 0.9303, 0.9311, 0.9324, 0.9354, 0.9346, 0.9322, 0.9333,
            0.9327, 0.9328, 0.9334, 0.9341, 0.9309, 0.9355, 0.9304, 0.9329, 0.9315,
            0.9349, 0.9349, 0.9334, 0.9309, 0.9348, 0.9339, 0.9352, 0.9305, 0.9353,
            0.9361, 0.9348, 0.9344, 0.9316])
    torch.Size([256])
    tensor(1)
    torch.Size([])
    tensor([ 2.5769e-02,  3.2003e-02,  4.0426e-02,  2.5748e-02,  7.2832e-02,
             1.5658e-02,  2.5115e-02,  2.5380e-02,  4.3120e-02, -7.6767e-02,
             4.8386e-02, -1.7225e-02,  3.9784e-02, -1.3605e-02,  2.6205e-02,
            -3.3973e-02, -2.5717e-02, -8.6800e-03, -8.9120e-02, -3.1962e-02,
            -5.2733e-02, -2.7778e-02, -1.5557e-02, -4.8369e-02, -6.4511e-03,
            -2.6600e-02,  1.7034e-02, -4.5308e-02, -1.7030e-02, -3.1436e-02,
             1.1061e-03,  8.9047e-02, -1.4947e-02,  8.0814e-02, -7.3011e-03,
             3.2562e-02,  4.6302e-02, -2.9296e-02,  7.3519e-02,  4.7905e-02,
             2.4076e-03,  3.1211e-02, -5.2155e-02,  1.0838e-02,  5.7961e-02,
            -3.3471e-03, -2.8430e-03,  1.1444e-03, -3.2272e-02, -5.7009e-02,
            -9.2353e-02, -2.1453e-02,  4.7136e-02,  4.5234e-02, -1.0296e-02,
            -1.3034e-02,  1.4136e-02,  1.0600e-02, -4.7237e-02, -1.0242e-02,
            -1.4815e-02, -3.5088e-03, -6.3280e-02, -9.8644e-04, -2.3059e-02,
            -1.3445e-02,  2.9654e-02,  2.6669e-02, -1.7380e-02,  5.4696e-03,
             2.1582e-02,  6.5305e-02, -2.8333e-02, -1.4173e-02, -2.6366e-02,
             8.4090e-02,  1.0214e-02,  7.0343e-02, -3.8497e-02, -5.4475e-02,
             3.1934e-02,  6.2931e-02,  5.0918e-02, -1.5748e-02,  6.0137e-02,
             4.6816e-02,  4.8743e-02, -3.6490e-02, -1.4070e-02, -5.5744e-02,
            -8.7710e-03,  1.6054e-02, -2.4121e-02,  5.0592e-02, -1.0744e-02,
             1.4429e-02,  4.8309e-03, -2.8721e-02, -3.1048e-02, -1.1565e-02,
             8.7734e-02,  1.8962e-02, -1.6371e-02, -1.8743e-02,  2.2613e-03,
             7.1928e-03,  2.1043e-02, -1.4599e-02,  1.6153e-02,  3.7763e-02,
             4.3269e-03, -3.5493e-02, -5.2598e-02,  1.5344e-02,  2.1441e-02,
             9.2463e-02,  1.5741e-02, -3.8817e-02,  4.6949e-02, -1.0287e-02,
             4.6703e-02,  6.6172e-02, -1.3216e-02, -7.6751e-02,  3.6660e-02,
            -2.5026e-02,  7.6301e-02,  1.3926e-02,  1.5871e-02, -3.5111e-02,
            -7.1907e-03, -1.0339e-01, -4.4918e-02, -2.4152e-02,  6.3309e-02,
            -2.7762e-02, -2.2627e-02, -1.6631e-02, -1.9683e-03, -2.2786e-02,
            -3.9106e-02, -1.2523e-02, -2.3914e-02, -8.7628e-02, -5.3616e-02,
             3.7245e-02,  4.1308e-02,  5.8160e-02, -5.9610e-02, -1.4550e-02,
             5.9928e-03, -1.2012e-02,  1.2292e-02,  8.4839e-02, -5.1759e-03,
            -9.5818e-03,  3.8721e-02,  6.7283e-03,  4.6232e-02, -5.4140e-02,
             1.5234e-02,  7.6472e-02,  3.6063e-02, -3.9120e-03, -2.8301e-02,
             5.2318e-02,  6.3161e-03, -4.1881e-02, -2.7641e-02, -2.3957e-02,
             2.2977e-02, -5.3927e-02, -9.4426e-03,  2.3404e-02, -4.5836e-02,
             9.8488e-03, -5.1690e-02,  4.0070e-02, -1.3923e-02, -2.4386e-02,
            -1.1535e-02,  6.0975e-02, -1.7121e-02, -6.7577e-02,  6.4819e-02,
             3.5068e-02,  2.8911e-02,  2.9796e-02, -2.5551e-02,  9.4217e-02,
            -8.1372e-03,  4.0888e-03,  5.5938e-03, -3.6768e-02,  1.4441e-02,
            -1.8997e-02,  1.5464e-03, -5.3608e-05, -2.3572e-03, -2.8609e-02,
            -6.0448e-02,  4.6937e-02,  4.2591e-02,  1.9752e-02, -2.5235e-02,
             3.0911e-02,  2.5987e-02,  1.7226e-03,  1.0095e-02, -2.4058e-02,
             1.8213e-02,  5.4116e-02, -6.0333e-02,  2.6258e-02,  6.0458e-02,
            -4.2852e-03,  3.4615e-02,  5.5996e-03,  3.3450e-02, -1.6998e-02,
            -3.8624e-02,  4.7385e-02,  2.9592e-02,  3.5316e-02, -3.9366e-03,
             1.4218e-02,  3.8937e-02,  2.0447e-03,  1.6828e-02, -1.4085e-02,
             3.7000e-02, -1.1752e-02, -1.2822e-02,  5.1092e-03,  6.9776e-02,
            -1.7114e-02,  3.5346e-02, -8.4873e-03,  1.9357e-02, -2.8954e-02,
            -2.0002e-02, -1.7849e-02,  3.7224e-02, -2.0103e-03,  1.8310e-02,
             5.1715e-02,  8.5137e-03, -1.9735e-02, -3.8351e-02, -3.5967e-02,
            -4.5121e-02,  3.6773e-02,  1.2142e-02, -2.3320e-03,  1.4159e-02,
             5.1570e-03,  1.5933e-02, -1.6325e-02,  1.2221e-02,  3.4894e-03,
            -8.5704e-02,  4.0650e-02, -6.6170e-02, -6.2233e-02,  1.4543e-02,
             4.4968e-02, -3.8874e-02,  3.9377e-02,  3.0383e-03, -1.6053e-03,
             2.2372e-02,  1.3575e-02, -1.3049e-02, -1.4711e-02, -4.3797e-02,
             1.5224e-03,  1.9025e-02, -3.6885e-02, -6.7741e-03, -4.1376e-02,
             3.1974e-02, -4.0833e-02, -5.8944e-02, -5.9171e-02,  8.3822e-02,
            -1.9277e-02,  1.5525e-02,  3.1380e-02,  5.2410e-02,  2.4664e-02,
            -5.1298e-02, -7.0221e-02,  3.2354e-02,  1.4572e-02,  7.6821e-02,
            -6.8654e-02, -1.7554e-02,  3.6301e-02,  2.0001e-02, -2.6152e-02,
             7.6607e-02, -3.1379e-03, -6.6274e-02,  4.7406e-02,  2.7557e-02,
            -5.6120e-02,  4.6689e-02,  4.5309e-02,  2.6608e-02, -2.5557e-02,
            -3.5906e-02,  1.4348e-02, -2.2431e-03, -1.5763e-02,  4.9855e-02,
            -7.1161e-02,  4.2684e-02,  2.1841e-02,  6.4723e-03, -3.6387e-02,
             1.3752e-02, -3.2767e-02,  4.0802e-02, -3.6758e-02, -4.6568e-02,
            -4.7367e-03, -2.4984e-02,  3.2021e-02,  1.9488e-02,  6.1584e-03,
             2.8842e-02, -3.3784e-02,  2.7394e-02,  8.5315e-03,  2.7566e-02,
             5.0114e-02,  6.8048e-03,  8.5549e-04,  3.1000e-02,  1.0139e-02,
            -1.6105e-02, -2.5671e-02,  2.3197e-03, -4.2809e-02,  9.8833e-04,
             6.5868e-03, -4.5146e-02,  9.0819e-03, -4.7215e-02, -3.0381e-02,
            -4.1886e-04, -1.6289e-02,  1.2936e-02, -3.9101e-02, -5.7306e-02,
             2.9948e-04, -5.4190e-03, -1.9369e-02, -5.5113e-02, -2.4558e-02,
             1.4119e-02, -2.7469e-02,  1.5950e-02,  7.2587e-03, -1.1168e-02,
            -1.9534e-02, -5.6258e-03,  8.0654e-03,  1.0765e-02, -8.0776e-02,
             1.5469e-02,  4.3477e-02,  8.0382e-03,  6.1378e-02,  2.8184e-02,
             4.1482e-02, -6.5876e-02, -2.3104e-02,  3.5433e-02, -2.5846e-02,
             1.6766e-02, -4.4522e-02,  2.9070e-02, -4.0928e-02,  3.7450e-02,
             1.1707e-02, -1.5259e-02,  3.5983e-02, -1.5232e-03, -5.5514e-02,
             6.7920e-02, -5.3521e-02,  3.1599e-02, -1.2989e-02, -8.3181e-03,
             6.8398e-02, -3.8819e-02,  5.3185e-04, -1.8690e-02,  1.0082e-02,
            -2.5835e-03,  1.9094e-02, -5.2345e-02,  9.8490e-03,  4.7015e-02,
            -2.0472e-02, -8.0189e-03,  6.9176e-03, -1.2634e-03,  3.4354e-02,
             3.9389e-02,  1.7511e-02, -9.6130e-02, -1.2743e-02,  1.5633e-02,
             6.5546e-02,  7.6114e-03, -4.7284e-02, -1.2856e-02,  4.1969e-02,
             1.6360e-02, -7.5796e-02, -3.1768e-02,  2.0013e-02, -3.2596e-02,
            -1.3465e-02, -5.2833e-02, -3.1620e-02,  7.1866e-02,  2.0732e-02,
             9.5775e-02,  3.5919e-02, -2.1153e-03, -4.4907e-02, -4.3439e-03,
             1.3606e-02, -4.5540e-02,  9.4055e-03,  1.8481e-03,  5.4999e-02,
             1.3219e-02, -4.8859e-03, -1.5467e-02,  3.4535e-02,  4.9613e-02,
            -5.6436e-02, -1.9687e-03, -2.3989e-02,  3.3957e-02,  2.1383e-03,
            -3.8722e-02, -4.0204e-04, -3.8855e-02,  8.1356e-02, -2.0538e-02,
            -1.4779e-02, -5.3581e-02,  2.4808e-02, -1.5770e-02,  1.8319e-02,
            -1.7443e-02,  4.3508e-02, -5.5921e-02, -1.8543e-02, -6.7227e-03,
             1.7551e-02, -2.7990e-02,  1.5976e-02, -2.5273e-02, -1.3250e-02,
            -1.9063e-02, -3.9713e-02,  1.4416e-02,  3.1798e-02, -3.9206e-02,
             8.6097e-03,  9.0590e-03,  3.4666e-02, -4.3512e-02,  3.5496e-02,
             6.6108e-02,  5.4080e-02,  2.2509e-02,  3.4298e-02,  2.4821e-02,
             1.1323e-02, -1.8867e-02, -2.2725e-02, -1.8874e-02,  6.5678e-03,
            -6.2875e-02, -1.8410e-02,  7.7500e-03, -5.8016e-02, -4.4243e-02,
             5.3432e-02, -2.7515e-03,  3.1921e-02,  2.0511e-02,  1.4370e-02,
            -1.1303e-02,  6.8358e-03, -5.2930e-03, -7.3147e-03, -6.1960e-02,
             3.1448e-02,  1.9133e-03, -1.4177e-02,  1.3810e-02, -6.0344e-02,
            -1.9071e-02, -7.6946e-02])
    torch.Size([512])
    tensor([0.9661, 0.9703, 0.9690, 0.9684, 0.9720, 0.9656, 0.9736, 0.9639, 0.9667,
            0.9635, 0.9728, 0.9656, 0.9646, 0.9631, 0.9616, 0.9658, 0.9753, 0.9664,
            0.9746, 0.9702, 0.9706, 0.9662, 0.9702, 0.9619, 0.9635, 0.9661, 0.9746,
            0.9700, 0.9736, 0.9660, 0.9603, 0.9705, 0.9656, 0.9594, 0.9686, 0.9705,
            0.9678, 0.9590, 0.9656, 0.9600, 0.9688, 0.9733, 0.9623, 0.9717, 0.9732,
            0.9639, 0.9672, 0.9569, 0.9656, 0.9673, 0.9726, 0.9618, 0.9651, 0.9700,
            0.9619, 0.9621, 0.9657, 0.9720, 0.9642, 0.9640, 0.9700, 0.9668, 0.9639,
            0.9648, 0.9693, 0.9691, 0.9722, 0.9632, 0.9602, 0.9656, 0.9633, 0.9648,
            0.9645, 0.9632, 0.9650, 0.9691, 0.9656, 0.9714, 0.9617, 0.9653, 0.9651,
            0.9667, 0.9569, 0.9683, 0.9627, 0.9608, 0.9630, 0.9706, 0.9632, 0.9567,
            0.9595, 0.9516, 0.9625, 0.9641, 0.9671, 0.9655, 0.9690, 0.9577, 0.9645,
            0.9673, 0.9584, 0.9633, 0.9642, 0.9631, 0.9631, 0.9725, 0.9671, 0.9546,
            0.9689, 0.9586, 0.9703, 0.9685, 0.9600, 0.9684, 0.9655, 0.9652, 0.9759,
            0.9671, 0.9670, 0.9685, 0.9580, 0.9638, 0.9694, 0.9656, 0.9626, 0.9689,
            0.9571, 0.9661, 0.9634, 0.9612, 0.9666, 0.9591, 0.9638, 0.9588, 0.9724,
            0.9674, 0.9715, 0.9676, 0.9637, 0.9664, 0.9674, 0.9729, 0.9602, 0.9576,
            0.9643, 0.9638, 0.9635, 0.9673, 0.9614, 0.9627, 0.9672, 0.9682, 0.9700,
            0.9604, 0.9707, 0.9639, 0.9636, 0.9612, 0.9703, 0.9704, 0.9644, 0.9741,
            0.9670, 0.9703, 0.9580, 0.9684, 0.9616, 0.9649, 0.9647, 0.9646, 0.9631,
            0.9621, 0.9674, 0.9658, 0.9686, 0.9640, 0.9662, 0.9638, 0.9591, 0.9695,
            0.9737, 0.9678, 0.9676, 0.9661, 0.9598, 0.9652, 0.9670, 0.9578, 0.9731,
            0.9677, 0.9639, 0.9614, 0.9692, 0.9673, 0.9675, 0.9651, 0.9712, 0.9660,
            0.9695, 0.9691, 0.9610, 0.9610, 0.9653, 0.9659, 0.9643, 0.9696, 0.9621,
            0.9623, 0.9658, 0.9663, 0.9708, 0.9658, 0.9667, 0.9651, 0.9657, 0.9606,
            0.9638, 0.9627, 0.9642, 0.9688, 0.9589, 0.9657, 0.9658, 0.9671, 0.9650,
            0.9719, 0.9611, 0.9684, 0.9544, 0.9648, 0.9675, 0.9689, 0.9615, 0.9719,
            0.9757, 0.9738, 0.9663, 0.9638, 0.9684, 0.9674, 0.9601, 0.9683, 0.9672,
            0.9640, 0.9591, 0.9674, 0.9641, 0.9697, 0.9647, 0.9679, 0.9585, 0.9726,
            0.9648, 0.9691, 0.9642, 0.9686, 0.9666, 0.9721, 0.9686, 0.9726, 0.9679,
            0.9589, 0.9613, 0.9594, 0.9702, 0.9632, 0.9610, 0.9721, 0.9677, 0.9614,
            0.9626, 0.9689, 0.9656, 0.9695, 0.9744, 0.9796, 0.9554, 0.9693, 0.9680,
            0.9643, 0.9621, 0.9600, 0.9629, 0.9727, 0.9713, 0.9701, 0.9715, 0.9595,
            0.9661, 0.9617, 0.9763, 0.9720, 0.9638, 0.9636, 0.9693, 0.9616, 0.9689,
            0.9673, 0.9594, 0.9675, 0.9589, 0.9695, 0.9724, 0.9653, 0.9687, 0.9712,
            0.9741, 0.9621, 0.9684, 0.9639, 0.9690, 0.9702, 0.9672, 0.9627, 0.9664,
            0.9682, 0.9752, 0.9652, 0.9671, 0.9624, 0.9694, 0.9634, 0.9692, 0.9724,
            0.9647, 0.9619, 0.9625, 0.9610, 0.9710, 0.9644, 0.9590, 0.9622, 0.9715,
            0.9566, 0.9621, 0.9726, 0.9619, 0.9687, 0.9698, 0.9654, 0.9611, 0.9680,
            0.9673, 0.9687, 0.9570, 0.9650, 0.9609, 0.9674, 0.9647, 0.9755, 0.9595,
            0.9724, 0.9654, 0.9688, 0.9656, 0.9721, 0.9720, 0.9633, 0.9651, 0.9647,
            0.9623, 0.9679, 0.9631, 0.9699, 0.9641, 0.9742, 0.9761, 0.9663, 0.9742,
            0.9746, 0.9742, 0.9654, 0.9661, 0.9753, 0.9676, 0.9663, 0.9788, 0.9585,
            0.9627, 0.9580, 0.9644, 0.9590, 0.9660, 0.9650, 0.9658, 0.9623, 0.9668,
            0.9710, 0.9665, 0.9590, 0.9636, 0.9701, 0.9676, 0.9680, 0.9680, 0.9660,
            0.9597, 0.9692, 0.9680, 0.9696, 0.9714, 0.9627, 0.9640, 0.9615, 0.9642,
            0.9717, 0.9658, 0.9738, 0.9660, 0.9633, 0.9725, 0.9749, 0.9713, 0.9720,
            0.9588, 0.9676, 0.9602, 0.9709, 0.9658, 0.9614, 0.9696, 0.9670, 0.9628,
            0.9662, 0.9637, 0.9701, 0.9679, 0.9693, 0.9657, 0.9602, 0.9731, 0.9727,
            0.9639, 0.9687, 0.9695, 0.9600, 0.9542, 0.9620, 0.9682, 0.9663, 0.9678,
            0.9677, 0.9630, 0.9641, 0.9665, 0.9634, 0.9706, 0.9615, 0.9679, 0.9631,
            0.9646, 0.9709, 0.9632, 0.9667, 0.9644, 0.9668, 0.9775, 0.9641, 0.9626,
            0.9691, 0.9787, 0.9696, 0.9685, 0.9668, 0.9640, 0.9693, 0.9651, 0.9610,
            0.9627, 0.9694, 0.9613, 0.9639, 0.9687, 0.9625, 0.9576, 0.9662, 0.9642,
            0.9641, 0.9692, 0.9711, 0.9600, 0.9648, 0.9701, 0.9717, 0.9674, 0.9651,
            0.9563, 0.9682, 0.9668, 0.9681, 0.9599, 0.9641, 0.9668, 0.9670, 0.9654,
            0.9672, 0.9707, 0.9663, 0.9688, 0.9671, 0.9674, 0.9638, 0.9692, 0.9720,
            0.9636, 0.9732, 0.9721, 0.9734, 0.9628, 0.9640, 0.9585, 0.9725])
    torch.Size([512])
    tensor(1)
    torch.Size([])
    tensor([-0.0009,  0.0246, -0.0794,  ...,  0.0150,  0.0272, -0.0704])
    torch.Size([1024])
    tensor([1.0095, 1.0234, 1.0410,  ..., 1.0388, 1.0404, 1.0162])
    torch.Size([1024])
    tensor(1)
    torch.Size([])
    tensor([-0.0063,  0.0608,  0.0883,  ...,  0.0642, -0.0890, -0.0128])
    torch.Size([2048])
    tensor([1.1888, 1.1833, 1.1989,  ..., 1.2491, 1.0934, 1.1935])
    torch.Size([2048])
    tensor(1)
    torch.Size([])

    18)

    If you want to know what the values in the buffers actually mean, you can retrieve them together with their names.

    named_buffers(prefix='', recurse=True)

    Returns an iterator over module buffers, yielding both the name of each buffer and the buffer itself.

    Parameters:

    • prefix (str) – prefix to prepend to all buffer names.

    • recurse (bool) – if True, also yields the buffers of this module and of all its submodules; if False, only buffers that are direct members of this module.

    Example (here e is the model used in the previous buffers() example):

    for name, buffer in e.named_buffers():
        print(name)
        print(buffer.size())

    Output:

    models.BatchNorm2d_128.running_mean
    torch.Size([128])
    models.BatchNorm2d_128.running_var
    torch.Size([128])
    models.BatchNorm2d_128.num_batches_tracked
    torch.Size([])
    models.BatchNorm2d_256.running_mean
    torch.Size([256])
    models.BatchNorm2d_256.running_var
    torch.Size([256])
    models.BatchNorm2d_256.num_batches_tracked
    torch.Size([])
    models.BatchNorm2d_512.running_mean
    torch.Size([512])
    models.BatchNorm2d_512.running_var
    torch.Size([512])
    models.BatchNorm2d_512.num_batches_tracked
    torch.Size([])
    models.BatchNorm2d_1024.running_mean
    torch.Size([1024])
    models.BatchNorm2d_1024.running_var
    torch.Size([1024])
    models.BatchNorm2d_1024.num_batches_tracked
    torch.Size([])
    models.BatchNorm2d_2048.running_mean
    torch.Size([2048])
    models.BatchNorm2d_2048.running_var
    torch.Size([2048])
    models.BatchNorm2d_2048.num_batches_tracked
    torch.Size([])
    View Code

    19)

    register_buffer(name, tensor)

    Adds a persistent buffer to the module.
    This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but it is part of the module's persistent state.
    The buffer can be accessed as an attribute using the given name.

    Parameters:

    • name (string) – name of the buffer; the buffer can be accessed from this module using the given name

    • tensor (Tensor) – buffer to be registered

    Example:

    import torch as t
    from torch import nn

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            # equivalent to self.param1 = nn.Parameter(t.randn(3, 3))
            self.register_parameter('param1', nn.Parameter(t.randn(3, 3)))
            # persistent buffer: saved in state_dict(), but not a parameter
            self.register_buffer('running_mean', t.zeros(128))
            self.submodel1 = nn.Linear(3, 4)

        def forward(self, input):
            x = self.param1.mm(input)  # multiply param1 with the input
            x = self.submodel1(x)
            return x

    net = Net()
    x = t.randn(3, 3)
    output = net(x)

    for name, buffer in net.named_buffers():
        print(name)
        print(buffer.size())
        print(buffer)

    Output:

    running_mean
    torch.Size([128])
    tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0.])
    View Code

    Hook functions:

    Purpose: when training a network, hooks let you extract the parameters or feature maps of intermediate layers.

     20)

    register_forward_hook(hook)

    Registers a forward hook on the module.

    The hook will be called every time after forward() has computed an output. It has the following signature:

    hook(module, input, output) -> None or modified output

    The hook can modify the output. It can also modify the input in place, but since it is called only after forward() has run, that modification has no effect on the forward() computation itself.

    Returns:

      a handle that can be used to remove the added hook by calling handle.remove()

    Return type:

      torch.utils.hooks.RemovableHandle
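
    A minimal sketch (not from the original post) of the typical use: grabbing the output of an intermediate layer. The model, the features dict and the layer index are made up for illustration:

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 16, 3, padding=1))
    features = {}  # container for the captured outputs

    def save_output(module, input, output):
        # called right after the module's forward(); keep a detached copy of the output
        features['conv0'] = output.detach()

    handle = model[0].register_forward_hook(save_output)
    y = model(torch.randn(1, 3, 32, 32))
    print(features['conv0'].shape)   # torch.Size([1, 8, 32, 32])
    handle.remove()                  # stop capturing once it is no longer needed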

    21)

    register_forward_pre_hook(hook)

    Registers a forward pre-hook on the module.

    It differs from the previous method in that the previous hook is called after forward() runs, while this one is called before. It has the following signature:

    hook(module, input) -> None or modified input

    As you would expect, this hook can be used to inspect or modify the module's input. The hook may return either a tuple or a single modified value; a single returned value is wrapped into a tuple.

    Returns:

      a handle that can be used to remove the added hook by calling handle.remove()

    Return type:

      torch.utils.hooks.RemovableHandle
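
    A small sketch of a pre-hook (my own example, not from the original post) that normalizes the input before forward() runs; returning a single tensor is fine because it gets wrapped into a tuple:

    import torch
    import torch.nn as nn

    linear = nn.Linear(4, 2)

    def normalize_input(module, input):
        # input arrives as a tuple; return the modified input
        x, = input
        return (x - x.mean()) / (x.std() + 1e-6)

    handle = linear.register_forward_pre_hook(normalize_input)
    out = linear(torch.randn(1, 4))  # forward() sees the normalized tensor
    handle.remove()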

    22)

    register_backward_hook(hook)

    Registers a backward hook on the module.

    The hook will be called every time the gradients with respect to the module's inputs are computed. The hook argument must be a function with the following signature:

    hook(module, grad_input, grad_output) -> Tensor or None

    grad_input and grad_output may be tuples if the module has multiple inputs or outputs. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input, which will be used in place of grad_input in subsequent computations.

    Returns:

      a handle that can be used to remove the added hook by calling handle.remove()

    Return type:

      torch.utils.hooks.RemovableHandle

    Warning ⚠️:

      For complex modules that perform many operations, the current implementation will not behave as described above. In some failure cases grad_input and grad_output will only contain the gradients for a subset of the inputs and outputs. For such modules you should use torch.Tensor.register_hook() directly on the specific input or output tensors to get the required gradients.

    This issue, as well as the example below, is explained in more detail at https://oldpan.me/archives/pytorch-autograd-hook

    Example:

    import torch
    import torch.nn as nn

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    class MyMul(nn.Module):             # custom multiplication module (defined but not used below)
        def forward(self, input):
            out = input * 2
            return out

    class MyMean(nn.Module):            # custom division module
        def forward(self, input):
            out = input/4
            return out

    def tensor_hook(grad):
        print('tensor hook')
        print('grad:', grad)
        return grad

    class MyNet(nn.Module):
        def __init__(self):
            super(MyNet, self).__init__()
            self.f1 = nn.Linear(4, 1, bias=True)
            self.f2 = MyMean()
            self.weight_init()

        def forward(self, input):
            self.input = input
            output = self.f1(input)       # run f1 first, then f2
            output = self.f2(output)
            return output

        def weight_init(self):
            self.f1.weight.data.fill_(8.0)    # set the Linear weights to 8
            self.f1.bias.data.fill_(2.0)      # set the Linear bias to 2

        def my_hook(self, module, grad_input, grad_output):
            print('doing my_hook')
            print('original grad:', grad_input)
            print('original outgrad:', grad_output)

            return grad_input

    if __name__ == '__main__':

        input = torch.tensor([1, 2, 3, 4], dtype=torch.float32, requires_grad=True).to(device)

        net = MyNet()
        net.to(device)

        # both hooks must be registered before result = net(input) runs,
        # because the hooks are looked up while forward executes
        net.register_backward_hook(net.my_hook)
        input.register_hook(tensor_hook)
        result = net(input)

        print('result =', result)

        result.backward()

        print('input.grad:', input.grad)
        for param in net.parameters():
            print('{}:grad->{}'.format(param, param.grad))

    Output:

    result = tensor([20.5000], grad_fn=<DivBackward0>)
    # the module backward hook only binds to the last function executed inside the module,
    # so the gradients it reports below belong to the last op f2()
    # the forward pass is y = wx + b = f1(x), z = y / 4 = f2(y)
    doing my_hook
    original grad: (tensor([0.2500]), None)   # dz/dy = 1/4
    original outgrad: (tensor([1.]),)         # dz/dz = 1
    # tensor_hook was attached to the input x, so the grad it sees is dz/dx
    # dz/dx = dz/dy * dy/dx = 1/4 * w = 1/4 * 8 = 2
    # x has shape (1, 4), so the grad also has shape (1, 4)
    tensor hook
    grad: tensor([2., 2., 2., 2.])
    # printing input.grad directly (without the hook) gives the same values
    input.grad: tensor([2., 2., 2., 2.])
    # below are the gradients of w and b inside f1()
    # dz/dw = dz/dy * dy/dw = 1/4 * [x1, x2, x3, x4] = 1/4 * [1, 2, 3, 4] = [0.2500, 0.5000, 0.7500, 1.0000]
    # dz/db = dz/dy * dy/db = 1/4 * 1 = 0.25
    Parameter containing:
    tensor([[8., 8., 8., 8.]], requires_grad=True):grad->tensor([[0.2500, 0.5000, 0.7500, 1.0000]])
    Parameter containing:
    tensor([2.], requires_grad=True):grad->tensor([0.2500])

    Another example:

    #coding=UTF-8
    import torch
    from PIL import Image
    import numpy as np
    import torchvision.models as models
    
    alexnet = models.alexnet()
    
    print('The architecture of alexnet: ')
    for i in alexnet.named_modules():
        print(i)
    # print(alexnet.features[12])   # output of the last layer of the conv part
    # print(alexnet.classifier[4])  # output of the second-to-last Linear in the classifier
    
    imgSize = [224,224]
    
    img = Image.open('Tom_Hanks_54745.png')
    res_img = img.resize((imgSize[0],imgSize[1]))
    img = np.double(res_img)
    img = np.transpose(img, (2,0,1)) # h * w *c==> c*h*w
    input_data = torch.from_numpy(img).type(torch.FloatTensor)
    input_data = torch.unsqueeze(input_data, 0)
    
    
    def forward_hook(module, input, output):
        print('-'*8 + 'forward_hook' + '-'*8)
    
        print('number of input : ', len(input))
        print('number of output : ', len(output))
        print('shape of input : ', input[0].shape)
        print('shape of output : ', output.shape)
    
    def forward_hook_0(module, input, output):
        print('-'*8 + 'forward_hook_0' + '-'*8)
    
        print('number of input : ', len(input))
        print('number of output : ', len(output))
        print('shape of input : ', input[0].shape)
        print('shape of output : ', output.shape)
    
    def forward_hook_12(module, input, output):
        print('-'*8 + 'forward_hook_12' + '-'*8)
    
        print('number of input : ', len(input))
        print('number of output : ', len(output))
        print('shape of input : ', input[0].shape)
        print('shape of output : ', output.shape)
    
    # the backward hooks are used to obtain the gradients flowing through a layer
    def backward_hook(module, grad_input, grad_output):
        # registered on the whole network, this effectively reports the last executed layer
        print('-' * 8 + 'backward_hook' + '-' * 8)
        print('number of grad_input : ', len(grad_input))
        print('number of grad_output : ', len(grad_output))
        # grad_input is a tuple: (bias_grad, x_grad, weight_grad),
        # i.e. the gradients w.r.t. the three inputs of the last layer y = x * weight + bias
        print('shape of grad_input[0] : ', grad_input[0].shape) # gradient w.r.t. bias
        print('shape of grad_input[1] : ', grad_input[1].shape) # gradient w.r.t. x
        print('shape of grad_input[2] : ', grad_input[2].shape) # gradient w.r.t. weight
        # grad_output is a tuple: (grad_output, )
        print('shape of grad_output : ', grad_output[0].shape) # gradient w.r.t. y itself, all ones here
        # print(grad_output[0])
    
    
    def backward_hook_0(module, grad_input, grad_output):
        print('-' * 8 + 'backward_hook_0' + '-' * 8)
        print('number of grad_input : ', len(grad_input))
        print('number of grad_output : ', len(grad_output))
        # grad_input is a tuple: (None, weight_grad, bias_grad)
        # the first entry (grad w.r.t. the layer input, i.e. the image) is None because the input does not require grad
        print('grad_input[0] : ', grad_input[0])
        print('shape of grad_input[1] : ', grad_input[1].shape)
        print('shape of grad_input[2] : ', grad_input[2].shape)
        # grad_output is a tuple: (grad_output, )
        print('shape of grad_output : ', grad_output[0].shape)
    
    
    def backward_hook_12(module, grad_input, grad_output):
        print('-' * 8 + 'backward_hook_12' + '-' * 8)
        print('number of grad_input : ', len(grad_input))
        print('number of grad_output : ', len(grad_output))
        # grad_input is a tuple: (grad_input, )
        print('shape of grad_input : ', grad_input[0].shape)
        # grad_output is a tuple: (grad_output, )
        print('shape of grad_output : ', grad_output[0].shape)
    
    def backward_hook_classier_4(module, grad_input, grad_output):
        # registered on the second-to-last Linear layer to get the gradients of its parameters
        print('-' * 8 + 'backward_hook_classier_4' + '-' * 8)
        print('number of grad_input : ', len(grad_input))
        print('number of grad_output : ', len(grad_output))
        # grad_input is a tuple: (bias_grad, x_grad, weight_grad),
        # i.e. the gradients w.r.t. the three inputs of this layer y = x * weight + bias
        print('shape of grad_input[0] : ', grad_input[0].shape) # gradient w.r.t. bias
        print('shape of grad_input[1] : ', grad_input[1].shape) # gradient w.r.t. x
        print('shape of grad_input[2] : ', grad_input[2].shape) # gradient w.r.t. weight
        # grad_output is a tuple: (grad_output, )
        print('shape of grad_output : ', grad_output[0].shape) # gradient w.r.t. y; no longer all ones, since it is the gradient propagated down from the layers above
    
    
    def pre_forward_hook(module, input):
        print('-' * 8 + 'pre_forward_hook' + '-' * 8)
        # input arrives as a tuple: (input, )
        print('number of input : ', len(input))
        print('shape of input : ', input[0].shape)
    
    
    def pre_forward_hook_0(module, input):
        print('-' * 8 + 'pre_forward_hook_0' + '-' * 8)
        # input arrives as a tuple: (input, )
        print('number of input : ', len(input))
        print('shape of input : ', input[0].shape)
    
    
    # hooks registered on the whole network rather than on a specific layer:
    # a forward pre-hook on the network sees the network input, i.e. the same tensor the first layer receives
    pre_hook = alexnet.register_forward_pre_hook(pre_forward_hook)
    # a second pre-hook on the same module behaves identically:
    pre_hook_0 = alexnet.register_forward_pre_hook(pre_forward_hook_0)

    # the hooks below let you get the input/output of a given layer as well as their gradients
    # registered on the whole network, the forward hook sees the network input, i.e. [1, 3, 224, 224],
    # and the network output, here [1, 1000]; the backward hook effectively reports the last layer,
    # i.e. the gradients of that layer's inputs and output
    forward_hook = alexnet.register_forward_hook(forward_hook)
    backward_hook = alexnet.register_backward_hook(backward_hook)

    # registered on the first conv layer, so we get its intermediate output (feature map) and gradients
    forward_hook_0 = alexnet.features[0].register_forward_hook(forward_hook_0)
    backward_hook_0 = alexnet.features[0].register_backward_hook(backward_hook_0)

    # registered on layer 12 of the conv part
    forward_hook_12 = alexnet.features[12].register_forward_hook(forward_hook_12)
    backward_hook_12 = alexnet.features[12].register_backward_hook(backward_hook_12)

    # registered on layer 4 of the classifier
    backward_hook_classier_4= alexnet.classifier[4].register_backward_hook(backward_hook_classier_4)
    
    num_class = 1000
    output  = alexnet(input_data)
    print('-'*20)
    print('-'*5 + 'forward done' + '-'*5)
    print()
    
    output.backward(torch.ones(1,num_class))
    print('-'*20)
    print('-'*5 + 'backward done' + '-'*5)
    print()
    
    #### remove handle
    pre_hook.remove()
    pre_hook_0.remove()
    
    forward_hook.remove()
    backward_hook.remove()
    
    forward_hook_0.remove()
    backward_hook_0.remove()
    
    forward_hook_12.remove()
    backward_hook_12.remove()
    
    backward_hook_classier_4.remove()

    Output:

    /anaconda3/envs/deeplearning/bin/python3.6 /Users/wanghui/pytorch/face_data/learning.py
    The architecture of alexnet: 
    ('', AlexNet(
      (features): Sequential(
        (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
        (1): ReLU(inplace)
        (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
        (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (4): ReLU(inplace)
        (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
        (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (7): ReLU(inplace)
        (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (9): ReLU(inplace)
        (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU(inplace)
        (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
      (classifier): Sequential(
        (0): Dropout(p=0.5)
        (1): Linear(in_features=9216, out_features=4096, bias=True)
        (2): ReLU(inplace)
        (3): Dropout(p=0.5)
        (4): Linear(in_features=4096, out_features=4096, bias=True)
        (5): ReLU(inplace)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
      )
    ))
    ('features', Sequential(
      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (1): ReLU(inplace)
      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): ReLU(inplace)
      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace)
      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU(inplace)
      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace)
      (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    ))
    ('features.0', Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)))
    ('features.1', ReLU(inplace))
    ('features.2', MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False))
    ('features.3', Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)))
    ('features.4', ReLU(inplace))
    ('features.5', MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False))
    ('features.6', Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
    ('features.7', ReLU(inplace))
    ('features.8', Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
    ('features.9', ReLU(inplace))
    ('features.10', Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
    ('features.11', ReLU(inplace))
    ('features.12', MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False))
    ('avgpool', AdaptiveAvgPool2d(output_size=(6, 6)))
    ('classifier', Sequential(
      (0): Dropout(p=0.5)
      (1): Linear(in_features=9216, out_features=4096, bias=True)
      (2): ReLU(inplace)
      (3): Dropout(p=0.5)
      (4): Linear(in_features=4096, out_features=4096, bias=True)
      (5): ReLU(inplace)
      (6): Linear(in_features=4096, out_features=1000, bias=True)
    ))
    ('classifier.0', Dropout(p=0.5))
    ('classifier.1', Linear(in_features=9216, out_features=4096, bias=True))
    ('classifier.2', ReLU(inplace))
    ('classifier.3', Dropout(p=0.5))
    ('classifier.4', Linear(in_features=4096, out_features=4096, bias=True))
    ('classifier.5', ReLU(inplace))
    ('classifier.6', Linear(in_features=4096, out_features=1000, bias=True))
    --------pre_forward_hook--------
    number of input :  1
    shape of input :  torch.Size([1, 3, 224, 224])
    --------pre_forward_hook_0--------
    number of input :  1
    shape of input :  torch.Size([1, 3, 224, 224])
    --------forward_hook_0--------
    number of input :  1
    number of output :  1
    shape of input :  torch.Size([1, 3, 224, 224])
    shape of output :  torch.Size([1, 64, 55, 55])
    --------forward_hook_12--------
    number of input :  1
    number of output :  1
    shape of input :  torch.Size([1, 256, 13, 13])
    shape of output :  torch.Size([1, 256, 6, 6])
    --------forward_hook--------
    number of input :  1
    number of output :  1
    shape of input :  torch.Size([1, 3, 224, 224])
    shape of output :  torch.Size([1, 1000])
    --------------------
    -----forward done-----
    
    --------backward_hook--------
    number of grad_input :  3
    number of grad_output :  1
    shape of grad_input[0] :  torch.Size([1000])
    shape of grad_input[1] :  torch.Size([1, 4096])
    shape of grad_input[2] :  torch.Size([4096, 1000])
    shape of grad_output :  torch.Size([1, 1000])
    --------backward_hook_classier_4--------
    number of grad_input :  3
    number of grad_output :  1
    shape of grad_input[0] :  torch.Size([4096])
    shape of grad_input[1] :  torch.Size([1, 4096])
    shape of grad_input[2] :  torch.Size([4096, 4096])
    shape of grad_output :  torch.Size([1, 4096])
    --------backward_hook_12--------
    number of grad_input :  1
    number of grad_output :  1
    shape of grad_input :  torch.Size([1, 256, 13, 13])
    shape of grad_output :  torch.Size([1, 256, 6, 6])
    --------backward_hook_0--------
    number of grad_input :  3
    number of grad_output :  1
    grad_input[0] :  None
    shape of grad_input[1] :  torch.Size([64, 3, 11, 11])
    shape of grad_input[2] :  torch.Size([64])
    shape of grad_output :  torch.Size([1, 64, 55, 55])
    --------------------
    -----backward done-----
    
    
    Process finished with exit code 0

    In the example above, a separately named hook function is written for each registration. In itself, attaching the same hook function to several layers is allowed; the error

    TypeError: 'RemovableHandle' object is not callable

    appears when the variable that held the hook function is overwritten by the handle returned from register_*_hook (as with forward_hook = alexnet.register_forward_hook(forward_hook) above) and that handle is later registered or called as if it were still a function. Using a separate function per layer is also convenient because the inputs and outputs of each layer differ.

    The example above also shows that the hooks fire in a fixed order: pre-hooks before forward hooks during the forward pass, and backward hooks in reverse order of the layers during the backward pass.

    23)

    zero_grad()

    Sets the gradients of all model parameters to zero. Since gradients accumulate across backward() calls, this is typically done before every backward pass during training (i.e. at the start of each iteration, not just each epoch).
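
    A minimal training-step sketch (the model, loss and data below are placeholders, not from the original post) showing where zero_grad() fits:

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 1)
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for step in range(100):              # placeholder loop over batches
        x, y = torch.randn(8, 10), torch.randn(8, 1)
        model.zero_grad()                # or optimizer.zero_grad(), which clears the same grads here
        loss = criterion(model(x), y)
        loss.backward()                  # accumulates gradients into each param.grad
        optimizer.step()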

    24)

    state_dict(destination=None, prefix='', keep_vars=False)

    Returns a dictionary containing the whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are included; the keys are the corresponding parameter and buffer names.

    Example:

    import torchvision.models as models
    
    alexnet = models.alexnet()
    print(alexnet.state_dict().keys())

    Output:

    odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 'features.3.bias', 'features.6.weight', 'features.6.bias', 'features.8.weight', 'features.8.bias', 'features.10.weight', 'features.10.bias', 'classifier.1.weight', 'classifier.1.bias', 'classifier.4.weight', 'classifier.4.bias', 'classifier.6.weight', 'classifier.6.bias'])
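
    Because state_dict() is just an ordered dict of tensors, it is also what is normally serialized; a short sketch (the file name is arbitrary):

    import torch
    import torchvision.models as models

    alexnet = models.alexnet()
    sd = alexnet.state_dict()
    print(sd['features.0.weight'].shape)      # torch.Size([64, 3, 11, 11])

    torch.save(sd, 'alexnet_weights.pth')     # save only the weights, not the module object
    alexnet.load_state_dict(torch.load('alexnet_weights.pth'))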

    25)

    load_state_dict(state_dict, strict=True)

    Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, the keys of state_dict must exactly match the keys returned by this module's own state_dict() function. If you only want to load the entries that match and skip the rest (for example after replacing the last few layers of a model while keeping the earlier weights), set strict=False.

    Parameters:

    • state_dict (dict) – a dict containing parameters and persistent buffers

    • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict() function. Default: True
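
    A sketch of the partial-loading case mentioned above (the 10-class head is an assumption for illustration): the last classifier layer is replaced, the now-mismatching keys are dropped, and strict=False lets everything else load:

    import torch.nn as nn
    import torchvision.models as models

    pretrained_state = models.alexnet().state_dict()       # stands in for real pretrained weights

    model = models.alexnet()
    model.classifier[6] = nn.Linear(4096, 10)              # new head for a 10-class task

    # 'classifier.6.*' no longer matches in shape, so drop those entries first
    pretrained_state = {k: v for k, v in pretrained_state.items()
                        if not k.startswith('classifier.6')}
    model.load_state_dict(pretrained_state, strict=False)  # loads the rest, ignores the missing keys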

    26)

    train(mode=True)

    Sets the module in training mode.

    This has an effect only on certain modules, e.g. those containing Dropout or BatchNorm layers; see the documentation of those modules for details of their behavior in training/evaluation mode. A small sketch follows the parameter description below.

    Parameters:

    mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True.
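
    A quick sketch of what train() changes in practice: with a Dropout layer, repeated forward passes on the same input differ in training mode and become deterministic once training mode is switched off:

    import torch
    import torch.nn as nn

    net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
    x = torch.randn(1, 4)

    net.train()           # training mode: dropout is active
    print(net(x))
    print(net(x))         # usually differs from the previous line

    net.train(False)      # same as net.eval(): dropout becomes a no-op
    print(net(x))
    print(net(x))         # identical to the previous line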

    27)

    eval()

    Sets the module in evaluation mode.

    This has an effect only on certain modules, e.g. those containing Dropout or BatchNorm layers; see the documentation of those modules for details of their behavior in training/evaluation mode.

    Equivalent to self.train(False).
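
    In practice eval() is usually paired with torch.no_grad() at inference time; note that they are independent: eval() switches layer behavior (Dropout/BatchNorm), while no_grad() only disables gradient tracking. A minimal sketch:

    import torch
    import torchvision.models as models

    model = models.alexnet()
    model.eval()                               # Dropout disabled; BatchNorm (if present) uses running stats

    with torch.no_grad():                      # no autograd graph is built, saving memory
        logits = model(torch.randn(1, 3, 224, 224))
        pred = logits.argmax(dim=1)
    print(pred)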

    28)

    requires_grad_(requires_grad=True)

    Changes whether autograd should record operations on the parameters of this module, by setting the parameters' requires_grad attribute in place.
    This method is helpful for freezing part of a module for fine-tuning, or for training parts of a model individually (e.g. GAN training).

    Parameters:

    requires_grad (bool) – whether autograd should record operations on the parameters of this module. Default: True.

    Example: on an older PyTorch build, calling this method on a custom module raised

    AttributeError: 'Encoder' object has no attribute 'requires_grad_'

    which simply means that the installed version predates the addition of Module.requires_grad_. On versions that do provide it the method works as sketched below; setting requires_grad on each parameter by hand is an equivalent workaround on any version.
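
    A sketch assuming a PyTorch version that does provide Module.requires_grad_ (the AttributeError above simply came from an older build); the manual loop is the equivalent fallback:

    import torch
    import torchvision.models as models

    model = models.alexnet()

    # freeze the convolutional backbone for fine-tuning
    model.features.requires_grad_(False)       # needs a PyTorch version that has Module.requires_grad_

    # equivalent fallback that works on any version
    for p in model.features.parameters():
        p.requires_grad = False

    # only parameters that still require grad will be updated
    optimizer = torch.optim.SGD((p for p in model.parameters() if p.requires_grad), lr=0.01)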
