zoukankan      html  css  js  c++  java
  • pytorchnum_flat_features(x)

    #coding=utf-8
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.autograd import Variable

    class Net(nn.Module):
    #定义Net的初始化函数,这个函数定义了该神经网络的基本结构
    def __init__(self):
    super(Net, self).__init__() #复制并使用Net的父类的初始化方法,即先运行nn.Module的初始化函数
    self.conv1 = nn.Conv2d(1, 6, 5) # 定义conv1函数的是图像卷积函数:输入为图像(1个频道,即灰度图),输出为 6张特征图, 卷积核为5x5正方形
    self.conv2 = nn.Conv2d(6, 16, 5)# 定义conv2函数的是图像卷积函数:输入为6张特征图,输出为16张特征图, 卷积核为5x5正方形
    self.fc1 = nn.Linear(16*5*5, 120) # 定义fc1(fullconnect)全连接函数1为线性函数:y = Wx + b,并将16*5*5个节点连接到120个节点上。
    self.fc2 = nn.Linear(120, 84)#定义fc2(fullconnect)全连接函数2为线性函数:y = Wx + b,并将120个节点连接到84个节点上。
    self.fc3 = nn.Linear(84, 10)#定义fc3(fullconnect)全连接函数3为线性函数:y = Wx + b,并将84个节点连接到10个节点上。

    #定义该神经网络的向前传播函数,该函数必须定义,一旦定义成功,向后传播函数也会自动生成(autograd)
    def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) #输入x经过卷积conv1之后,经过激活函数ReLU(原来这个词是激活函数的意思),使用2x2的窗口进行最大池化Max pooling,然后更新到x。
    x = F.max_pool2d(F.relu(self.conv2(x)), 2) #输入x经过卷积conv2之后,经过激活函数ReLU,使用2x2的窗口进行最大池化Max pooling,然后更新到x。
    x = x.view(-1, self.num_flat_features(x)) #view函数将张量x变形成一维的向量形式,总特征数并不改变,为接下来的全连接作准备。
    x = F.relu(self.fc1(x)) #输入x经过全连接1,再经过ReLU激活函数,然后更新x
    x = F.relu(self.fc2(x)) #输入x经过全连接2,再经过ReLU激活函数,然后更新x
    x = self.fc3(x) #输入x经过全连接3,然后更新x
    return x

    #使用num_flat_features函数计算张量x的总特征量(把每个数字都看出是一个特征,即特征总量),比如x是4*2*2的张量,那么它的特征总量就是16。
    def num_flat_features(self, x):

    size = x.size()[1:] # 这里为什么要使用[1:],是因为pytorch只接受批输入,也就是说一次性输入好几张图片,那么输入数据张量的维度自然上升到了4维。【1:】让我们把注意力放在后3维上面

    num_features = 1
    for s in size:
    num_features *= s
    return num_features


    net = Net()
    net

    # 以下代码是为了看一下我们需要训练的参数的数量
    print(net)
    params = list(net.parameters())

    k=0
    for i in params:
    l =1
    print ("该层的结构:"+str(list(i.size())))
    for j in i.size():
    l *= j
    print ("参数和:"+str(l))
    k = k+l

    print ("总参数和:"+ str(k))


    def num_flat_features(x):
    size = x.size()[1:] # 这里为什么要使用[1:],是因为pytorch只接受批输入,也就是说一次性输入好几张图片,那么输入数据张量的维度自然上升到了4维。【1:】让我们把注意力放在后3维上面

    num_features = 1
    for s in size:
    num_features *= s
    return num_features
    a=torch.arange(1,17).resize(2,2,2,2)
    print(a.size())
    print(a.size()[1:])
    size=num_flat_features(a)
    print(size)

    Net(
    (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (fc1): Linear(in_features=400, out_features=120, bias=True)
    (fc2): Linear(in_features=120, out_features=84, bias=True)
    (fc3): Linear(in_features=84, out_features=10, bias=True)
    )
    该层的结构:[6, 1, 5, 5]
    参数和:150
    该层的结构:[6]
    参数和:6
    该层的结构:[16, 6, 5, 5]
    参数和:2400
    该层的结构:[16]
    参数和:16
    该层的结构:[120, 400]
    参数和:48000
    该层的结构:[120]
    参数和:120
    该层的结构:[84, 120]
    参数和:10080
    该层的结构:[84]
    参数和:84
    该层的结构:[10, 84]
    参数和:840
    该层的结构:[10]
    参数和:10
    总参数和:61706
    torch.Size([2, 2, 2, 2])
    torch.Size([2, 2, 2])
    8

  • 相关阅读:
    使用C#实现DHT磁力搜索的BT种子后端管理程序+数据库设计(开源)
    便携版WinSCP在命令行下同步文件夹
    ffmpeg (ffprobe)分析文件关键帧时间点
    sqlite删除数据或者表后,回收数据库文件大小
    ubuntu 20.04下 freeswitch 配合 fail2ban 防恶意访问
    ffmpeg使用nvenc编码的结论记录
    PC版跑跑卡丁车 故事模式 亚瑟传说章节 卡美洛庆典 2阶段 心灵之眼 攻略
    There was an error loading or playing the video
    Nvidia RTX Voice 启动报错修复方法
    火狐浏览器 关闭跨域限制
  • 原文地址:https://www.cnblogs.com/tianyudizhua/p/15505363.html
Copyright © 2011-2022 走看看