zoukankan      html  css  js  c++  java
  • pytorch seq2seq模型示例

    以下代码可以让你更加熟悉seq2seq模型机制

    """
        test
    """
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.autograd import Variable
    
    # Build the vocabulary.
    # Each entry is a (source word, target word) training pair.
    seq_data = [
        ['man', 'women'],
        ['black', 'white'],
        ['king', 'queen'],
        ['girl', 'boy'],
        ['up', 'down'],
        ['high', 'low'],
    ]
    # The first three symbols are control characters used by make_batch:
    # 'S' starts the decoder input, 'E' ends the target, 'P' is padding.
    char_arr = list('SEPabcdefghijklmnopqrstuvwxyz')
    # Map every character to its index in char_arr.
    num_dict = dict(zip(char_arr, range(len(char_arr))))

    # Network hyper-parameters.
    n_step = 5                   # every word is padded with 'P' up to this length
    n_hidden = 128               # size of the RNN hidden state
    n_class = len(num_dict)      # vocabulary size == width of a one-hot vector
    batch_size = len(seq_data)   # the whole training set forms a single batch
    
    # 准备数据
    def make_batch(seq_data):
        """Encode (source, target) word pairs as tensors for the seq2seq model.

        Each word is right-padded with 'P' up to ``n_step`` characters; words
        that are already ``n_step`` or longer are left unchanged. The caller's
        ``seq_data`` is never modified.

        Args:
            seq_data: iterable of pairs ``[source_word, target_word]`` whose
                characters are all keys of ``num_dict``.

        Returns:
            A tuple ``(input_batch, output_batch, target_batch)``:

            * input_batch  -- float32 one-hot encoder input; for words no
              longer than ``n_step`` its shape is (batch, n_step, n_class).
            * output_batch -- float32 one-hot decoder input, i.e. 'S' followed
              by the padded target word: (batch, n_step + 1, n_class).
            * target_batch -- int64 class indices of the padded target word
              followed by 'E': (batch, n_step + 1).

        Raises:
            KeyError: if a word contains a character missing from ``num_dict``.
        """
        input_batch, output_batch, target_batch = [], [], []
        # Row i of the identity matrix is the one-hot vector of class i.
        one_hot = np.eye(n_class)

        for seq in seq_data:
            # Pad into new strings instead of writing back into ``seq`` so the
            # caller's data is left untouched.
            source, target_word = (
                word + 'P' * (n_step - len(word)) for word in seq[:2]
            )
            enc_ids = [num_dict[ch] for ch in source]
            dec_ids = [num_dict[ch] for ch in 'S' + target_word]
            tgt_ids = [num_dict[ch] for ch in target_word + 'E']

            input_batch.append(one_hot[enc_ids])
            output_batch.append(one_hot[dec_ids])
            target_batch.append(tgt_ids)

        # Merge into one ndarray first: building a tensor from a Python list
        # of ndarrays is very slow and makes PyTorch emit a warning.
        # The Variable wrapper is not needed here; tensors can be passed to
        # the model directly.
        return (
            torch.tensor(np.array(input_batch), dtype=torch.float32),
            torch.tensor(np.array(output_batch), dtype=torch.float32),
            torch.tensor(target_batch, dtype=torch.long),
        )
    
    # Encode the whole training set once; the training loop below reuses
    # these three tensors on every epoch.
    input_batch, output_batch, target_batch = make_batch(seq_data)
    
    
    # 创建网络
    class Seq2Seq(nn.Module):
        """
        Encoder-decoder network built from two single-layer RNNs.

        Key points:
        1. The network has one encoder and one decoder with the same RNN
           structure; a final fully connected layer produces the prediction.
        2. You should be familiar with how the RNN module is laid out.
        3. The essence of seq2seq: the state produced by the encoder is fed
           to the decoder as its initial hidden state.
        """
        def __init__(self):
            super().__init__()
            # input_size is the width of each per-step input vector (one-hot
            # over the vocabulary); hidden_size is the hidden-state dimension.
            # No ``dropout`` argument is passed: nn.RNN applies dropout only
            # between stacked layers, so on these single-layer RNNs it had no
            # effect and merely triggered a UserWarning.
            self.enc = nn.RNN(input_size=n_class, hidden_size=n_hidden)
            self.dec = nn.RNN(input_size=n_class, hidden_size=n_hidden)
            self.fc = nn.Linear(n_hidden, n_class)

        def forward(self, enc_input, enc_hidden, dec_input):
            """Run the encoder, then the decoder seeded with the encoder state.

            Args:
                enc_input: encoder input, batch-first
                    (batch, enc_steps, n_class).
                enc_hidden: initial encoder hidden state (1, batch, n_hidden).
                dec_input: decoder input, batch-first
                    (batch, dec_steps, n_class).

            Returns:
                Per-step class scores, sequence-first
                (dec_steps, batch, n_class).
            """
            # nn.RNN expects (seq_len, batch_size, n_class), so swap the first
            # two axes of the batch-first inputs.
            enc_input = enc_input.transpose(0, 1)
            dec_input = dec_input.transpose(0, 1)
            # Only the final encoder hidden state is kept.
            _, enc_states = self.enc(enc_input, enc_hidden)
            # The decoder starts from the encoder's final state.
            outputs, _ = self.dec(dec_input, enc_states)
            pred = self.fc(outputs)

            return pred
    
    
    # training
    model = Seq2Seq()
    loss_fun = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(5000):
        # All-zero initial encoder hidden state: (num_layers, batch, n_hidden).
        hidden = Variable(torch.zeros(1, batch_size, n_hidden))

        optimizer.zero_grad()
        pred = model(input_batch, hidden, output_batch)
        # The model returns (steps, batch, n_class); make it batch-first so
        # each row lines up with one row of target_batch.
        pred = pred.transpose(0, 1)
        # Total loss is the sum of the per-sample cross-entropy losses.
        loss = 0
        for sample_pred, sample_target in zip(pred, target_batch):
            loss += loss_fun(sample_pred, sample_target)
        if (epoch + 1) % 1000 == 0:
            # .item() extracts the Python float from the 0-dim loss tensor.
            print('Epoch: %d   Cost: %f' % (epoch + 1, loss.item()))
        loss.backward()
        optimizer.step()
    
    
    # 测试
    def translate(word):
        """Return the trained model's output word for ``word``.

        The decoder is fed 'S' followed by padding only, and the predicted
        characters are read off in a single forward pass.

        Args:
            word: source word made of characters present in ``num_dict``.

        Returns:
            The predicted characters up to (not including) the first 'E',
            with padding characters removed. If the model never predicts
            'E', the whole predicted sequence is used.
        """
        input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]])
        # Initial hidden state, shape (1, 1, n_hidden): one layer, batch of one.
        hidden = Variable(torch.zeros(1, 1, n_hidden))
        # Inference only, so skip autograd bookkeeping.
        with torch.no_grad():
            # output shape: (decoder steps, 1, n_class)
            output = model(input_batch, hidden, output_batch)
        # Most likely class index at every decoder step, as plain ints.
        predict = output.argmax(dim=2).squeeze(1).tolist()
        decoded = [char_arr[i] for i in predict]
        # Cut at the first end symbol; an under-trained model may never emit
        # one, in which case keep everything instead of raising ValueError.
        end = decoded.index('E') if 'E' in decoded else len(decoded)
        translated = ''.join(decoded[:end])

        return translated.replace('P', '')
    
    # Demo: translate one source word taken from the training pairs.
    print('girl ->', translate('girl'))

    参考:https://blog.csdn.net/weixin_43632501/article/details/98525673

  • 相关阅读:
    Libcurl
    Inno Setup教程
    APICloud平台的融云2.0集成
    关于mysql建立索引 复合索引 索引类型
    linux恢复误删除文件-extundelete
    OpenStack QA
    Android之应用程序怎样调用支付宝接口
    NYOJ 22 素数求和问题
    Mycat(5):聊天消息表数据库按月分表实践,平滑扩展
    opencv对图像进行边缘及角点检測
  • 原文地址:https://www.cnblogs.com/demo-deng/p/11811090.html
Copyright © 2011-2022 走看看