zoukankan      html  css  js  c++  java
  • pytorch-卷积基本网络结构-提取网络参数-初始化网络参数

    基本的卷积神经网络

    from torch import nn
    
    class SimpleCNN(nn.Module):
        def __init__(self):
            super(SimpleCNN, self).__init__()
            layer1 = nn.Sequential() # 将网络模型进行添加
            layer1.add_module('conv1', nn.Conv2d(3, 32, 3, 1, padding=1)) # nn.Conv
            layer1.add_module('relu1', nn.ReLU(True))
            layer1.add_module('pool1', nn.MaxPool2d(2, 2))
            self.layer1 = layer1
    
            layer2 = nn.Sequential()
            layer2.add_module('conv2', nn.Conv2d(32, 64, 3, 1, padding=1))
            layer2.add_module('relu2', nn.ReLU(True))
            layer2.add_module('pool2', nn.MaxPool2d(2, 2))
            self.layer2 = layer2
    
            layer3 = nn.Sequential()
            layer3.add_module('conv3', nn.Conv2d(64, 128, 3, 1, padding=1))
            layer3.add_module('relu3', nn.ReLU(True))
            layer3.add_module('pool3', nn.MaxPool2d(2, 2))
            self.layer3 = layer3
    
            layer4 = nn.Sequential()
            layer4.add_module('fc1', nn.Linear(2048, 512))
            layer4.add_module('fc_relu1', nn.ReLU(True))
            layer4.add_module('fc2', nn.Linear(512, 64))
            layer4.add_module('fc_relu2', nn.ReLU(True))
            layer4.add_module('fc3', nn.Linear(64, 10))
            self.layer4 = layer4
    
        def forward(self, x):
            conv1 = self.layer1(x)
            conv2 = self.layer2(conv1)
            conv3 = self.layer3(conv2)
            fc_input = conv3.view(conv3.size(0), -1)
            fc_out = self.layer4(fc_input)
    
            return fc_out
    
    model = SimpleCNN()
    # print(model) # 打印输出网络结构

    提取前两层网络结构 

    new_model = nn.Sequential(*list(model.children())[:2])  # 提取前两层的网络结构, 构造nn.Sequential网络串接, * 表示将里面的内容一个个传进去

    提取所有的卷积层网络

    conv_model = nn.Sequential()
    # 提取所有的卷积层操作
    for name, layer in model.named_modules():
        if isinstance(layer, nn.Conv2d):
            name = name.replace('.', '_')
            conv_model.add_module(name, layer)
    print(conv_model)

    打印卷积层的网络名字

    for param in model.named_parameters():
        print(param)

    对权重参数进行初始化操作

    from torch.nn import init
    # 对权重参数进行初始化操作
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            init.normal(m.weight.data)
            init.xavier_normal(m.weight.data)
            init.kaiming_normal(m.weight.data)
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_()
  • 相关阅读:
    bootstrap在线引用 bootstrap百度调用
    CentOS7下安装MySQL5.7安装与配置(YUM)
    screen命令的常见用法
    Nginx主要用来干什么
    linux-Centos7安装python3并与python2共存
    爬虫小问题之以为是编码问题,却是headers中参数问题
    LabWindows/CVI基础
    STM32 命名方法
    Ubuntu14.04虚拟机下基本操作(typical安装)
    网关,路由器,交换机,猫小结
  • 原文地址:https://www.cnblogs.com/my-love-is-python/p/11725664.html
Copyright © 2011-2022 走看看