  • Deep Learning: Building Models and Accessing Model Parameters (2020.3.11)

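    Today I mainly studied how to define Module classes with the nn package in torch. The code below covers building model classes and accessing their parameters. The shortest way to build a model is 'net = nn.Sequential(NestMLP(), nn.Linear(30, 20), FancyMLP())'; the parameters of each layer are initialized by default when the layer is created.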

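    # 4.1 Model construction
    import torch
    from torch import nn  # nn.Module is the base class that the nn package provides for building models
    
    # Define the MLP class
    class MLP(nn.Module):
        def __init__(self, **kwargs):
            super(MLP, self).__init__(**kwargs)  # call the parent constructor so nn.Module can set up its internals
            self.hidden = nn.Linear(784, 256)  # hidden layer
            self.act = nn.ReLU()
            self.output = nn.Linear(256, 10)  # output layer
        
        # Define the forward computation; backward() does not need to be written,
        # because autograd derives it automatically
        def forward(self, x):
            a = self.act(self.hidden(x))
            return self.output(a)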
    
    # Instantiate net and pass in the input X to run the forward computation
    X = torch.rand(2, 784)
    net = MLP()
    net(X)
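
    As a quick sanity check on the MLP above, the sketch below re-declares the same class so that it runs on its own, then inspects the output shape and the total number of parameters. The fixed seed and the variable names out_shape and n_params are assumptions added for illustration.

```python
import torch
from torch import nn


class MLP(nn.Module):
    """Same two-layer perceptron as above: 784 -> 256 -> 10."""

    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(784, 256)
        self.act = nn.ReLU()
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        return self.output(self.act(self.hidden(x)))


torch.manual_seed(0)  # assumption: fixed seed only so reruns are repeatable
X = torch.rand(2, 784)
net = MLP()
out = net(X)  # calling the module dispatches to forward()

out_shape = tuple(out.shape)
n_params = sum(p.numel() for p in net.parameters())
print(out_shape)  # (2, 10): one row of 10 scores per input sample
print(n_params)   # 784*256 + 256 + 256*10 + 10 = 203530
```

    Calling net(X) rather than net.forward(X) is the usual convention, since the call goes through nn.Module and runs any registered hooks.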
    

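    # 4.1.2 Subclasses of Module
    from collections import OrderedDict  # imported at module level so that __init__ can see the name
    
    class MySequential(nn.Module):
        def __init__(self, *args):
            super(MySequential, self).__init__()
            if len(args) == 1 and isinstance(args[0], OrderedDict):  # an OrderedDict of named modules was passed in
                for key, module in args[0].items():
                    self.add_module(key, module)  # add_module registers the module in self._modules
            else:  # the modules were passed in as positional arguments
                for idx, module in enumerate(args):
                    self.add_module(str(idx), module)
        
        def forward(self, input):
            # self._modules preserves insertion order, so the modules run in the order they were added
            for module in self._modules.values():
                input = module(input)
            return input
    
    net = MySequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
    print(net)
    net(X)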
    

    Output: (not preserved in the source)
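
    The hand-written Sequential container above also accepts a single OrderedDict, a branch that the positional call does not exercise. The sketch below re-declares the container (with the import at module level) and builds it from named modules; the keys 'hidden', 'act' and 'output' are chosen here for illustration.

```python
from collections import OrderedDict

import torch
from torch import nn


class MySequential(nn.Module):
    def __init__(self, *args):
        super().__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def forward(self, x):
        for module in self._modules.values():
            x = module(x)
        return x


# Illustrative keys; any valid attribute names would do.
named_net = MySequential(OrderedDict([
    ("hidden", nn.Linear(784, 256)),
    ("act", nn.ReLU()),
    ("output", nn.Linear(256, 10)),
]))

child_names = [name for name, _ in named_net.named_children()]
out = named_net(torch.rand(2, 784))
print(child_names)       # ['hidden', 'act', 'output']
print(named_net.hidden)  # registered modules are also reachable as attributes
```

    Naming the modules makes the printed model and the parameter names easier to read than the default '0', '1', '2'.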

    # The ModuleList class
    net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
    net.append(nn.Linear(256, 10))  # append works as it does on a Python list
    print(net[0])  # index access, as with a list
    print(net)
    

    Output: (not preserved in the source)
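
    One reason to prefer ModuleList over a plain Python list is that its modules are registered with the parent, so their parameters show up in parameters(). The sketch below compares the two; the class names WithModuleList and WithPlainList are hypothetical and exist only for this comparison.

```python
import torch
from torch import nn


class WithModuleList(nn.Module):  # hypothetical name, for illustration only
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 3), nn.Linear(3, 1)])

    def forward(self, x):
        # ModuleList defines no forward(); the layers are applied explicitly
        for layer in self.layers:
            x = layer(x)
        return x


class WithPlainList(nn.Module):  # hypothetical name, for illustration only
    def __init__(self):
        super().__init__()
        # A plain list is just a Python attribute: nothing gets registered
        self.layers = [nn.Linear(4, 3), nn.Linear(3, 1)]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


n_registered = len(list(WithModuleList().parameters()))
n_plain = len(list(WithPlainList().parameters()))
out = WithModuleList()(torch.rand(2, 4))
print(n_registered)  # 4: two weights and two biases
print(n_plain)       # 0: an optimizer would never see these layers
```

    With the plain list the forward pass still runs, but an optimizer built from parameters() would have nothing to update.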

    # The ModuleDict class
    net = nn.ModuleDict({'linear': nn.Linear(784, 256), 'act': nn.ReLU()})
    net['output'] = nn.Linear(256, 10)  # add a module under a new key
    print(net['linear'])  # access by key
    print(net.output)  # access as an attribute
    print(net)
    
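    # Building a more complex model
    class FancyMLP(nn.Module):
        def __init__(self, **kwargs):
            super(FancyMLP, self).__init__(**kwargs)
            
            self.rand_weight = torch.rand((20, 20), requires_grad=False)  # constant weight, never trained
            self.linear = nn.Linear(20, 20)
            
        def forward(self, x):
            x = self.linear(x)
            # combine the constant weight with mm and relu
            x = nn.functional.relu(torch.mm(x, self.rand_weight.data) + 1)
            
            x = self.linear(x)  # the same layer is reused, so both calls share its parameters
            # control flow: .item() returns a Python number for the comparison
            while x.norm().item() > 1:
                x /= 2
            if x.norm().item() < 0.8:  # a norm is never negative, so the threshold has to be positive
                x *= 10
            return x.sum()
            
    
    X = torch.rand(2, 20)
    net = FancyMLP()
    print(net)
    net(X)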
    

    Output: (not preserved in the source)
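
    Two details of FancyMLP are easy to miss: the constant weight is not a trainable parameter, and the linear layer contributes only one weight and one bias even though it is called twice. The sketch below shows both with a simplified variant that omits the control flow. It stores the constant with register_buffer, which is a swapped-in alternative to the plain tensor attribute used above rather than the original author's approach; the class name BufferMLP is hypothetical.

```python
import torch
from torch import nn


class BufferMLP(nn.Module):  # hypothetical name; simplified variant of FancyMLP
    def __init__(self):
        super().__init__()
        # A buffer is saved in state_dict and follows .to(device),
        # but it is not returned by parameters()
        self.register_buffer("rand_weight", torch.rand(20, 20))
        self.linear = nn.Linear(20, 20)

    def forward(self, x):
        x = self.linear(x)
        x = torch.relu(torch.mm(x, self.rand_weight) + 1)
        x = self.linear(x)  # second call reuses the same weight and bias
        return x.sum()


net = BufferMLP()
param_names = [name for name, _ in net.named_parameters()]
state_keys = list(net.state_dict().keys())

y = net(torch.rand(2, 20))
y.backward()  # gradients from both uses of self.linear accumulate in one .grad

print(param_names)                    # ['linear.weight', 'linear.bias']
print("rand_weight" in state_keys)    # True
print(net.rand_weight.requires_grad)  # False
```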

    # Nesting FancyMLP and Sequential
    class NestMLP(nn.Module):
        def __init__(self, **kwargs):
            super(NestMLP, self).__init__(**kwargs)
            self.net = nn.Sequential(nn.Linear(40,30), nn.ReLU())
        
        def forward(self, x):
            return self.net(x)
    
    net = nn.Sequential(NestMLP(), nn.Linear(30, 20), FancyMLP())
    
    X = torch.rand(2, 40)
    print(net)
    net(X)
    

    Output: (not preserved in the source)
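
    Nesting also determines how parameters are named: each level of the hierarchy adds a prefix. The sketch below re-declares NestMLP and, to stay short, leaves FancyMLP out of the outer Sequential; that omission is an assumption made for this example only.

```python
import torch
from torch import nn


class NestMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(40, 30), nn.ReLU())

    def forward(self, x):
        return self.net(x)


# Assumption: FancyMLP is omitted so the example stays self-contained
net = nn.Sequential(NestMLP(), nn.Linear(30, 20))

param_names = [name for name, _ in net.named_parameters()]
out = net(torch.rand(2, 40))
print(param_names)
# ['0.net.0.weight', '0.net.0.bias', '1.weight', '1.bias']
print(tuple(out.shape))  # (2, 20)
```

    The name '0.net.0.weight' reads from the outside in: child 0 of the outer Sequential, its attribute net, then child 0 of the inner Sequential.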

    # 4.2 Accessing, initializing, and sharing model parameters
    import torch
    from torch import nn
    from torch.nn import init
    
    net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1)) 
    
    print(net)
    X = torch.rand(2, 4)
    Y = net(X).sum()
    

    Output: (not preserved in the source)
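
    Reducing the output to the scalar Y makes it possible to call backward() directly. The sketch below, which assumes the same three-layer network, shows that a parameter's grad is None until a backward pass has run and afterwards has the same shape as the parameter.

```python
import torch
from torch import nn

net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))
X = torch.rand(2, 4)
Y = net(X).sum()  # scalar, so backward() needs no extra argument

grad_before = net[0].weight.grad  # None: no backward pass has run yet
Y.backward()
grad_after_shape = tuple(net[0].weight.grad.shape)

print(grad_before)       # None
print(grad_after_shape)  # (3, 4), the same shape as net[0].weight
```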

    # Accessing model parameters
    print(type(net.named_parameters()))
    for name, param in net.named_parameters():
        print(name, param.size())
    

    Output: (not preserved in the source)
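
    The init module imported in this section is not used in the code above. As a minimal sketch of how it could be applied to the same network, the loop below re-initializes weights from a normal distribution and sets biases to zero. The standard deviation 0.01 is an arbitrary value chosen for illustration, not something taken from the original notes.

```python
import torch
from torch import nn
from torch.nn import init

net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))

# named_parameters() is lazy: it returns a generator of (name, Parameter) pairs
gen_type = type(net.named_parameters()).__name__
shapes = {name: tuple(p.size()) for name, p in net.named_parameters()}
print(gen_type)  # generator
print(shapes)    # the ReLU layer (index 1) has no parameters, so no '1.*' keys

for name, param in net.named_parameters():
    if name.endswith("weight"):
        init.normal_(param, mean=0.0, std=0.01)  # assumed std, illustration only
    else:
        init.constant_(param, 0.0)

first_weight = net[0].weight  # a single layer's parameter, reached by index
print(type(first_weight).__name__)  # Parameter
print(net[0].bias)
```

    The functions in init whose names end with an underscore modify the tensor in place and do not record the change for autograd.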

  • Original article: https://www.cnblogs.com/somedayLi/p/12461476.html