zoukankan      html  css  js  c++  java
  • pytorch seq2seq模型示例

    以下代码可以让你更加熟悉seq2seq模型机制

    """
        test
    """
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.autograd import Variable
    
    # 创建字典
    seq_data = [['man', 'women'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'], ['high', 'low']]
    char_arr = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz']
    num_dict = {n:i for i,n in enumerate(char_arr)}
    
    # 网络参数
    n_step = 5
    n_hidden = 128
    n_class = len(num_dict)
    batch_size = len(seq_data)
    
    # 准备数据
    def make_batch(seq_data):
        input_batch, output_batch, target_batch =[], [], []
    
        for seq in seq_data:
            for i in range(2):
                seq[i] = seq[i] + 'P' * (n_step-len(seq[i]))
            input = [num_dict[n] for n in seq[0]]
            ouput = [num_dict[n] for n in ('S'+ seq[1])]
            target = [num_dict[n] for n in (seq[1]) + 'E']
    
            input_batch.append(np.eye(n_class)[input])
            output_batch.append(np.eye(n_class)[ouput])
            target_batch.append(target)
    
        return Variable(torch.Tensor(input_batch)), Variable(torch.Tensor(output_batch)), Variable(torch.LongTensor(target_batch))
    
    input_batch, output_batch, target_batch = make_batch(seq_data)
    
    
    # 创建网络
    class Seq2Seq(nn.Module):
        """
        要点:
        1.该网络包含一个encoder和一个decoder,使用的RNN的结构相同,最后使用全连接接预测结果
        2.RNN网络结构要熟知
        3.seq2seq的精髓:encoder层生成的参数作为decoder层的输入
        """
        def __init__(self):
            super().__init__()
            # 此处的input_size是每一个节点可接纳的状态,hidden_size是隐藏节点的维度
            self.enc = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)
            self.dec = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)
            self.fc = nn.Linear(n_hidden, n_class)
    
        def forward(self, enc_input, enc_hidden, dec_input):
            # RNN要求输入:(seq_len, batch_size, n_class),这里需要转置一下
            enc_input = enc_input.transpose(0,1)
            dec_input = dec_input.transpose(0,1)
            _, enc_states = self.enc(enc_input, enc_hidden)
            outputs, _ = self.dec(dec_input, enc_states)
            pred = self.fc(outputs)
    
            return pred
    
    
    # training
    model = Seq2Seq()
    loss_fun = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    for epoch in range(5000):
        hidden = Variable(torch.zeros(1, batch_size, n_hidden))
    
        optimizer.zero_grad()
        pred = model(input_batch, hidden, output_batch)
        pred = pred.transpose(0, 1)
        loss = 0
        for i in range(len(seq_data)):
            temp = pred[i]
            tar = target_batch[i]
            loss +=  loss_fun(pred[i], target_batch[i])
        if (epoch + 1) % 1000 == 0:
            print('Epoch: %d   Cost: %f' % (epoch + 1, loss))
        loss.backward()
        optimizer.step()
    
    
    # 测试
    def translate(word):
        input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]])
        # hidden 形状 (1, 1, n_class)
        hidden = Variable(torch.zeros(1, 1, n_hidden))
        # output 形状(6,1, n_class)
        output = model(input_batch, hidden, output_batch)
        predict = output.data.max(2, keepdim=True)[1]
        decoded = [char_arr[i] for i in predict]
        end = decoded.index('E')
        translated = ''.join(decoded[:end])
    
        return translated.replace('P', '')
    
    print('girl ->', translate('girl'))

    参考:https://blog.csdn.net/weixin_43632501/article/details/98525673

  • 相关阅读:
    gbk学习笔记
    在freebsd下编译nodejs,出现无法找到execinfo.h头文件的错误
    php 截取GBK文档某个位置开始的n个字符
    linux下,phpstorm配置oracle jdk
    gb2312学习笔记
    freebsd下vim默认的vi操作方式太难用,可通过启用vim自带配置文件解决
    freebsd通过ssh远程登陆慢,用户认证时间长解决办法
    php输出全部gb2312编码内的汉字
    visibility:hidden 与 display:none 的区别
    java 实现文件/文件夹复制、删除、移动(二)
  • 原文地址:https://www.cnblogs.com/demo-deng/p/11811090.html
Copyright © 2011-2022 走看看