  • PyTorch (2): Custom Neural Network Models

    1. nn.Module

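    We can define a model as a class that inherits from nn.Module. This is the approach to use when the model is more complex than what a Sequential container can express.
    Defining two methods, __init__ and forward, is enough to implement a custom network model.
    __init__() defines the model architecture by creating each layer.
    forward() implements the forward pass and returns y_pred.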

    import torch
    
    
    class TwoLayerNet(torch.nn.Module):
        def __init__(self, D_in, H, D_out):
            """
            In the constructor we instantiate two nn.Linear modules and assign them as
            member variables.
            """
            super(TwoLayerNet, self).__init__()
            self.linear1 = torch.nn.Linear(D_in, H)
            self.linear2 = torch.nn.Linear(H, D_out)
    
        def forward(self, x):
            """
            In the forward function we accept a Tensor of input data and we must return
            a Tensor of output data. We can use Modules defined in the constructor as
            well as arbitrary operators on Tensors.
            """
            h_relu = self.linear1(x).clamp(min=0)
            y_pred = self.linear2(h_relu)
            return y_pred
    
    
    
    N, D_in, H, D_out = 64, 1000, 100, 10
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)
    
    model = TwoLayerNet(D_in, H, D_out)
    
    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
    for t in range(500):
        # Forward pass: compute the predictions for x
        y_pred = model(x)
        # Compute and print the loss
        loss = criterion(y_pred, y)
        print(t, loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
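
    To illustrate the point about models that go beyond Sequential, here is a minimal sketch of a module whose forward pass adds a skip connection, a branching data flow that a plain chain of layers in torch.nn.Sequential does not express on its own. The class name SkipNet, its layer names, and the sizes are illustrative assumptions, not part of the original post.

```python
import torch


class SkipNet(torch.nn.Module):
    """Hypothetical example: a hidden block with a skip (residual) connection."""

    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.input_layer = torch.nn.Linear(D_in, H)
        self.hidden_layer = torch.nn.Linear(H, H)
        self.output_layer = torch.nn.Linear(H, D_out)

    def forward(self, x):
        h = torch.relu(self.input_layer(x))
        # Add the block's input back to its output; this branching data flow
        # is the reason forward() is written by hand here.
        h = h + torch.relu(self.hidden_layer(h))
        return self.output_layer(h)


model = SkipNet(D_in=1000, H=100, D_out=10)
y_pred = model(torch.randn(64, 1000))
print(y_pred.shape)

# Layers assigned as attributes in __init__ are registered automatically,
# so model.parameters() already contains their weights and biases.
for name, param in model.named_parameters():
    print(name, tuple(param.shape))
```

    Because the layers are registered automatically, such a model can be passed to an optimizer through model.parameters() exactly as TwoLayerNet is above.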
    

    2. An Example: FizzBuzz

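    FizzBuzz is a simple counting game. The rules are as follows: count upward from 1; say "fizz" for a multiple of 3, "buzz" for a multiple of 5, and "fizzbuzz" for a multiple of 15; otherwise say the number itself.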
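
    As a reference point for what the network is supposed to learn, the rules can be written directly in plain Python. This is only a sketch of the game itself; the function name fizz_buzz is chosen here for illustration.

```python
def fizz_buzz(i):
    """Return what a player should say for the number i."""
    # Check 15 first: every multiple of 15 is also a multiple of 3 and of 5.
    if i % 15 == 0:
        return "fizzbuzz"
    if i % 5 == 0:
        return "buzz"
    if i % 3 == 0:
        return "fizz"
    return str(i)


first_fifteen = [fizz_buzz(i) for i in range(1, 16)]
print(" ".join(first_fifteen))
# 1 2 fizz 4 buzz fizz 7 8 fizz buzz 11 fizz 13 14 fizzbuzz
```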

    # Encode the desired output as a class index: 0 = number, 1 = "fizz", 2 = "buzz", 3 = "fizzbuzz"
    def fizz_buzz_encode(i):
        if   i % 15 == 0: return 3
        elif i % 5  == 0: return 2
        elif i % 3  == 0: return 1
        else:             return 0
        
    def fizz_buzz_decode(i, prediction):
        return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]
    

    First, define the model's inputs and outputs (the training data).

    import numpy as np
    import torch
    
    NUM_DIGITS = 10
    
    # Represent each input by an array of its binary digits.
    def binary_encode(i, num_digits):
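        return np.array([i >> d & 1 for d in range(num_digits)])[::-1] # shift right by d bits, then AND with 1 to extract bit d
    # Right-shift operator: ">>" shifts all binary digits of its left operand to the right; the number to the right of ">>" gives the number of positions.
    trX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)])
    trY = torch.LongTensor([fizz_buzz_encode(i) for i in range(101, 2 ** NUM_DIGITS)]) # the targets are class labels, so use a LongTensor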
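
    The bit manipulation in binary_encode can be checked by hand. The sketch below is a plain-Python variant that returns a list instead of a NumPy array; the name binary_encode_list is hypothetical. It encodes 101, the first number in the training range, and then reads the bits back.

```python
NUM_DIGITS = 10


def binary_encode_list(i, num_digits):
    """Plain-Python variant of binary_encode: most significant bit first."""
    # (i >> d) & 1 extracts bit d of i; d = 0 is the least significant bit.
    bits = [(i >> d) & 1 for d in range(num_digits)]
    return bits[::-1]  # reverse so the most significant bit comes first


encoded = binary_encode_list(101, NUM_DIGITS)
print(encoded)  # [0, 0, 0, 1, 1, 0, 0, 1, 0, 1]

# Reading the bits back as a base-2 number recovers the original integer.
decoded = int("".join(str(b) for b in encoded), 2)
print(decoded)  # 101

# Number of training examples produced by range(101, 2 ** NUM_DIGITS).
num_examples = len(range(101, 2 ** NUM_DIGITS))
print(num_examples)  # 923
```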
    

    Next, define the model, the loss function, and the optimizer in PyTorch.

    # Define the model
    NUM_HIDDEN = 100
    model = torch.nn.Sequential(
        torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
        torch.nn.ReLU(),
        torch.nn.Linear(NUM_HIDDEN, 4)
    )
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr = 0.05)
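
    torch.nn.CrossEntropyLoss takes the raw scores (logits) of the last linear layer together with integer class labels, which is why the model ends without a softmax layer and the targets are stored as class indices. As a rough illustration of the quantity involved for a single sample, the plain-Python sketch below computes softmax followed by the negative log-probability of the true class. The function name and the example logits are made up for illustration, and this is not PyTorch's own implementation.

```python
import math


def cross_entropy_single(logits, target):
    """Softmax cross-entropy for one sample (illustrative only)."""
    # Subtract the maximum logit before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    # -log(softmax(logits)[target]) equals log_sum_exp - logits[target].
    return log_sum_exp - logits[target]


logits = [2.0, 0.5, -1.0, 0.1]  # hypothetical scores for the 4 classes

# Small loss: class 0 has the highest score.
loss_true_top = cross_entropy_single(logits, target=0)
# Larger loss: class 2 has the lowest score.
loss_true_bottom = cross_entropy_single(logits, target=2)
print(loss_true_top, loss_true_bottom)
```

    With its default settings, CrossEntropyLoss averages this per-sample value over the batch.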
    

    The training code for the model follows.

    # Start training it
    BATCH_SIZE = 128
    for epoch in range(10000):
        for start in range(0, len(trX), BATCH_SIZE):
            end = start + BATCH_SIZE
            batchX = trX[start:end]
            batchY = trY[start:end]
    
            y_pred = model(batchX)
            loss = loss_fn(y_pred, batchY)
    
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
        # Find loss on training data
        loss = loss_fn(model(trX), trY).item()
        print('Epoch:', epoch, 'Loss:', loss)
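
    Once training finishes, the model can be applied to numbers it has not seen; the training range above starts at 101, so 1 to 100 are available for this. The sketch below is one possible way to do it, not code from the original post: it picks the class with the highest score for each input and decodes it. To stay self-contained it builds a fresh, untrained model of the same shape, so its output is not meaningful until the trained model is substituted. predict_fizz_buzz and untrained_model are hypothetical names.

```python
import numpy as np
import torch

NUM_DIGITS = 10
NUM_HIDDEN = 100


def binary_encode(i, num_digits):
    return np.array([i >> d & 1 for d in range(num_digits)])[::-1]


def fizz_buzz_decode(i, prediction):
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]


def predict_fizz_buzz(model, numbers):
    """Hypothetical helper: run the model on the given integers and decode."""
    inputs = torch.tensor(
        np.array([binary_encode(i, NUM_DIGITS) for i in numbers]),
        dtype=torch.float32,
    )
    with torch.no_grad():  # no gradients are needed for inference
        scores = model(inputs)
    # The predicted class is the index of the largest score in each row.
    predicted_classes = scores.argmax(dim=1).tolist()
    return [fizz_buzz_decode(i, c) for i, c in zip(numbers, predicted_classes)]


# Stand-in with the same architecture as the model above; replace it with the
# trained model to get meaningful predictions.
untrained_model = torch.nn.Sequential(
    torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(NUM_HIDDEN, 4),
)

numbers = list(range(1, 101))
predictions = predict_fizz_buzz(untrained_model, numbers)
print(predictions[:15])
```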
    
  • Original article: https://www.cnblogs.com/leimu/p/13230723.html