  • PyTorch: getting the last-layer output of an LSTM, unidirectional or bidirectional

    Unidirectional LSTM

    import torch.nn as nn
    import torch
    
    seq_len = 20
    batch_size = 64
    embedding_dim = 100
    num_embeddings = 300
    hidden_size = 128
    number_layer = 3
    
    input = torch.randint(low=0,high=256,size=[batch_size,seq_len])  #[64,20]
    
    embedding = nn.Embedding(num_embeddings,embedding_dim)
    
    input_embeded = embedding(input)  #[64,20,100]
    
    # Optional: transpose to swap batch_size and seq_len (only needed when batch_first=False)
    # input_embeded = input_embeded.transpose(0,1)
    # input_embeded = input_embeded.permute(1,0,2)
    # Instantiate the LSTM
    
    lstm = nn.LSTM(input_size=embedding_dim,hidden_size=hidden_size,batch_first=True,num_layers=number_layer)
    
    output,(h_n,c_n) = lstm(input_embeded)
    print(output.size()) #[64,20,128]       [batch_size,seq_len,hidden_size]
    print(h_n.size()) #[3,64,128]           [number_layer,batch_size,hidden_size]
    print(c_n.size()) # same shape as h_n
    
    
    # Output at the last time step
    output_last = output[:,-1,:]
    # h_n of the last layer
    h_n_last = h_n[-1]
    
    print(output_last.size())
    print(h_n_last.size())
    # The output at the last time step equals the last layer's h_n
    print(output_last.eq(h_n_last))


    D:anacondapython.exe C:/Users/liuxinyu/Desktop/pytorch_test/day4/LSTM练习.py
    torch.Size([64, 20, 128])
    torch.Size([3, 64, 128])
    torch.Size([3, 64, 128])
    torch.Size([64, 128])
    torch.Size([64, 128])
    tensor([[True, True, True, ..., True, True, True],
    [True, True, True, ..., True, True, True],
    [True, True, True, ..., True, True, True],
    ...,
    [True, True, True, ..., True, True, True],
    [True, True, True, ..., True, True, True],
    [True, True, True, ..., True, True, True]])

    Process finished with exit code 0
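
The element-wise `eq` printout above can be condensed into a single boolean. The following is a minimal sketch, not part of the original post: it uses smaller sizes, feeds random floats directly instead of going through the embedding layer, and fixes a seed, all of which are assumptions made for brevity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # assumption: fixed seed, only for reproducibility

# Assumed small sizes so the example runs quickly
batch_size, seq_len, embedding_dim, hidden_size, num_layers = 4, 5, 8, 16, 3

lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size,
               batch_first=True, num_layers=num_layers)

# Random floats stand in for the embedded input: [batch_size, seq_len, embedding_dim]
x = torch.randn(batch_size, seq_len, embedding_dim)
output, (h_n, c_n) = lstm(x)

output_last = output[:, -1, :]  # top layer, last time step: [batch_size, hidden_size]
h_n_last = h_n[-1]              # final hidden state of the top layer: [batch_size, hidden_size]

# One boolean instead of a full tensor of True values
same = torch.allclose(output_last, h_n_last)
print(same)
```

`torch.allclose` compares within a small tolerance, which is usually a safer habit for floating-point tensors than exact equality, although exact equality also holds in the run shown above.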

    Bidirectional LSTM

    import torch.nn as nn
    import torch
    
    seq_len = 20
    batch_size = 64
    embedding_dim = 100
    num_embeddings = 300
    hidden_size = 128
    number_layer = 3
    
    input = torch.randint(low=0,high=256,size=[batch_size,seq_len])  #[64,20]
    
    embedding = nn.Embedding(num_embeddings,embedding_dim)
    
    input_embeded = embedding(input)  #[64,20,100]
    
    # Optional: transpose to swap batch_size and seq_len (only needed when batch_first=False)
    # input_embeded = input_embeded.transpose(0,1)
    # input_embeded = input_embeded.permute(1,0,2)
    # Instantiate the LSTM
    
    lstm = nn.LSTM(input_size=embedding_dim,hidden_size=hidden_size,batch_first=True,num_layers=number_layer,bidirectional=True)
    
    output,(h_n,c_n) = lstm(input_embeded)
    print(output.size()) #[64,20,128*2]       [batch_size,seq_len,hidden_size*2]
    print(h_n.size()) #[3*2,64,128]           [number_layer*2,batch_size,hidden_size]
    print(c_n.size()) # same shape as h_n
    
    
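    # Final output of the backward direction: first time step, last 128 features
    output_last = output[:,0,-128:]
    # h_n of the backward direction in the last layer
    h_n_last = h_n[-1]
    
    print(output_last.size())
    print(h_n_last.size())
    # The backward direction's final output equals the last layer's backward h_n
    print(output_last.eq(h_n_last))
    
    # Final output of the forward direction: last time step, first 128 features
    output_last = output[:,-1,:128]
    # h_n of the forward direction in the last layer
    h_n_last = h_n[-2]
    # The forward direction's final output equals the last layer's forward h_n
    print(output_last.eq(h_n_last))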


    D:anacondapython.exe C:/Users/liuxinyu/Desktop/pytorch_test/day4/双向LSTM练习.py
    torch.Size([64, 20, 256])
    torch.Size([6, 64, 128])
    torch.Size([6, 64, 128])
    torch.Size([64, 128])
    torch.Size([64, 128])
    tensor([[True, True, True, ..., True, True, True],
    [True, True, True, ..., True, True, True],
    [True, True, True, ..., True, True, True],
    ...,
    [True, True, True, ..., True, True, True],
    [True, True, True, ..., True, True, True],
    [True, True, True, ..., True, True, True]])
    tensor([[True, True, True, ..., True, True, True],
    [True, True, True, ..., True, True, True],
    [True, True, True, ..., True, True, True],
    ...,
    [True, True, True, ..., True, True, True],
    [True, True, True, ..., True, True, True],
    [True, True, True, ..., True, True, True]])

    Process finished with exit code 0
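
The indices `h_n[-1]` and `h_n[-2]` work because the first axis of `h_n` interleaves layers and directions. A sketch that makes this explicit by reshaping `h_n` to `[num_layers, 2, batch_size, hidden_size]` is shown below. The smaller sizes, the random float input, and the names `forward_last`, `backward_last`, and `summary` are illustrative assumptions, not part of the original post.

```python
import torch
import torch.nn as nn

# Assumed small sizes so the example runs quickly
batch_size, seq_len, embedding_dim, hidden_size, num_layers = 4, 5, 8, 16, 3

lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size,
               batch_first=True, num_layers=num_layers, bidirectional=True)

x = torch.randn(batch_size, seq_len, embedding_dim)
output, (h_n, c_n) = lstm(x)  # output: [4, 5, 32], h_n: [6, 4, 16]

# Split the first axis of h_n into (layer, direction)
h_n_split = h_n.reshape(num_layers, 2, batch_size, hidden_size)
forward_last = h_n_split[-1, 0]   # same tensor values as h_n[-2]
backward_last = h_n_split[-1, 1]  # same tensor values as h_n[-1]

# The forward direction finishes at the last time step,
# the backward direction finishes at the first time step.
fw_ok = torch.allclose(output[:, -1, :hidden_size], forward_last)
bw_ok = torch.allclose(output[:, 0, hidden_size:], backward_last)

# One common choice: concatenate both directions into a single vector per sequence
summary = torch.cat([forward_last, backward_last], dim=-1)  # [batch_size, hidden_size*2]
print(fw_ok, bw_ok, summary.shape)
```

Concatenating the two final states is only one possible way to summarize a sequence. Note that `output[:, -1, :]` on its own is not equivalent, since its backward half has seen only the last token.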

  • Original article: https://www.cnblogs.com/LiuXinyu12378/p/12322993.html