zoukankan      html  css  js  c++  java
  • LSTM的简洁实现

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import math
    class LSTM(nn.Module):
        def __init__(self, indim, hidim, outdim):
            super(LSTM, self).__init__()
            self.LSTM = nn.LSTM(indim, hidim, 2)# 设定层数为两层
            self.Linear = nn.Linear(hidim, outdim)
        
        def forward(self, input, h0, c0):
            input = input.type(torch.float32)
            state = state.type(torch.float32)
            if torch.cuda.is_available():
                input = input.cuda()
                state = state.cuda()
            state = (h0, c0)
            y, state = self.LSTM(input, state)
            y = self.Linear(y.reshape(-1, y.shape[-1]))
            return y, state
    

    尝龟,没啥好说的,LSTM就是多了一个C状态,因为C状态也是需要初始化的,于是我们在用模型计算时候一定要注意传入的初始状态要包括h0 c0

    其他的都是尝龟。

  • 相关阅读:
    P1082 同余方程
    P2678 跳石头
    P2827 蚯蚓
    P1351 联合权值
    P2822 组合数问题
    P3958 奶酪
    P2296 寻找道路
    P2661 信息传递
    平时问题总结
    平时总结
  • 原文地址:https://www.cnblogs.com/kalicener/p/15547466.html
Copyright © 2011-2022 走看看