zoukankan      html  css  js  c++  java
  • pytorch循环神经网络实现回归预测 代码

    pytorch循环神经网络实现回归预测

    学习视频:莫烦python

    # RNN for regression: predict cos(t) from sin(t)
    import torch
    import numpy as np
    import torch.nn as nn
    import torch.utils.data as Data
    import matplotlib.pyplot as plt
    import torchvision
    
    
    # ---- Hyper-parameters ------------------------------------------------
    TIME_STEP: int = 10   # points sampled per training window (RNN sequence length)
    INPUT_SIZE: int = 1   # features per time step fed to the RNN (one sin(t) value)
    LR: float = 0.02      # learning rate passed to the optimizer

    # Optional data preview: one period of the input (sin, red) against the
    # target (cos, blue). Uncomment to look at the task before training.
    # t = np.linspace(0, np.pi * 2, 100, dtype=float)
    # x = np.sin(t)
    # y = np.cos(t)
    # plt.plot(t, x, 'r-', label='input (sin)')
    # plt.plot(t, y, 'b-', label='target (cos)')
    # plt.legend(loc='best')
    # plt.show()
    
    class RNN_Net(nn.Module):
        """Elman RNN followed by a linear read-out applied at every time step.

        Maps a batch-first input sequence of shape (batch, seq_len, input_size)
        to an output sequence of shape (batch, seq_len, 1). It also returns the
        final hidden state, so the caller can carry it into the next window.

        Args:
            input_size: Features per time step. ``None`` (the default) uses the
                module-level ``INPUT_SIZE`` hyper-parameter.
            hidden_size: Width of the RNN hidden state (default 32).
            num_layers: Number of stacked RNN layers (default 1).
        """

        def __init__(self, input_size=None, hidden_size=32, num_layers=1):
            super().__init__()
            if input_size is None:
                # Resolved at construction time so callers may override it
                # without touching the module-level hyper-parameter.
                input_size = INPUT_SIZE
            self.rnn = nn.RNN(
                input_size=input_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                batch_first=True,  # tensors are laid out (batch, seq_len, features)
            )
            self.out = nn.Linear(hidden_size, 1)

        def forward(self, x, h_state=None):
            """Run the RNN over ``x`` and project every time step to one value.

            Args:
                x: Input tensor of shape (batch, seq_len, input_size).
                h_state: Initial hidden state of shape
                    (num_layers, batch, hidden_size), or ``None`` to start
                    from zeros.

            Returns:
                ``(outputs, h_state)``: ``outputs`` has shape
                (batch, seq_len, 1) and ``h_state`` is the final hidden state.
            """
            r_out, h_state = self.rnn(x, h_state)
            # nn.Linear operates on the last dimension, so applying it to the
            # whole (batch, seq_len, hidden) tensor equals projecting each time
            # step separately and stacking along dim=1, but in one batched op.
            return self.out(r_out), h_state
    
    rnn = RNN_Net()

    optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
    loss_func = nn.MSELoss()

    plt.ion()  # interactive mode so the figure updates while training runs
    h_state = None  # the first window starts from the RNN's default (zero) state
    for step in range(60):
        # Each window covers the next interval [step*pi, (step+1)*pi] of the
        # t axis, sampled at TIME_STEP evenly spaced points.
        start, end = step * np.pi, (step + 1) * np.pi
        steps = np.linspace(start, end, TIME_STEP, dtype=np.float32)
        x_np = np.sin(steps)  # input: sin(t)
        y_np = np.cos(steps)  # target: cos(t)

        # Add batch and feature axes: (TIME_STEP,) -> (1, TIME_STEP, 1).
        x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])
        y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])

        prediction, h_state = rnn(x, h_state)
        # Keep the hidden state's values for the next window but cut it out of
        # the autograd graph (truncated backprop through time). Without this,
        # the next backward() would try to traverse the previous window's
        # graph, which was freed by this window's backward().
        # .detach() is the supported API; .data hides in-place edits from autograd.
        h_state = h_state.detach()

        loss = loss_func(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        plt.plot(steps, y_np, 'r-')  # target in red
        # flatten() collapses (1, TIME_STEP, 1) to (TIME_STEP,) for plotting.
        plt.plot(steps, prediction.detach().numpy().flatten(), 'b-')  # prediction in blue
        plt.draw()
        plt.pause(0.05)
    plt.ioff()
    plt.show()
  • 相关阅读:
    同一域环境下SQLServer DB Failover故障转移配置详解
    WebAPI项目中使用SwaggerUI
    Failed to initialize the Common Language Runtime
    WCF Throttling 限流的三道闸口
    Entity Framework 乐观并发控制
    MVC3不能正确识别JSON中的Enum枚举值
    编写高质量代码改善C#程序的157个建议读书笔记【11-20】
    如果下次做模板,我就使用Nvelocity
    对于react中的this.setState的理解
    对于react中redux的理解
  • 原文地址:https://www.cnblogs.com/yz-lucky77/p/13927822.html
Copyright © 2011-2022 走看看