zoukankan      html  css  js  c++  java
  • PyTorch seq2seq 闲聊机器人加入 attention 机制

    attention.py

    """
    实现attention
    """
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import config
    
    
    class Attention(nn.Module):
        """Luong-style attention scores between a decoder state and encoder outputs.

        Supported scoring functions (``method``):

        * ``"dot"``     -- score(h_t, h_s) = h_t^T h_s
        * ``"general"`` -- score(h_t, h_s) = h_t^T W h_s
        * ``"concat"``  -- score(h_t, h_s) = v^T tanh(W [h_t; h_s])

        Layer sizes are read from ``config``: the decoder state is
        ``config.chatbot_decoder_hidden_size`` wide and each encoder output is
        assumed to be ``config.chatbot_encoder_hidden_size * 2`` wide
        (presumably a bidirectional encoder -- TODO confirm against the encoder).
        ``"dot"`` additionally requires those two widths to be equal.
        """

        METHODS = ("dot", "general", "concat")

        def __init__(self, method="general"):
            super(Attention, self).__init__()
            if method not in self.METHODS:
                raise ValueError("attention method error: expected one of %s, got %r"
                                 % (self.METHODS, method))
            self.method = method
            decoder_size = config.chatbot_decoder_hidden_size
            encoder_size = config.chatbot_encoder_hidden_size * 2
            if method == "general":
                # Map the decoder state into the encoder-output space so the two
                # can be compared with a dot product. Same shape as before when
                # decoder_size == encoder_size, so existing checkpoints still load.
                self.W = nn.Linear(decoder_size, encoder_size, bias=False)
            elif method == "concat":
                # W consumes [decoder state ; encoder output], so its input width
                # is the sum of both widths.
                self.W = nn.Linear(decoder_size + encoder_size, decoder_size * 2, bias=False)
                self.V = nn.Linear(decoder_size * 2, 1, bias=False)

        def forward(self, decoder_hidden, encoder_outputs):
            """Return attention weights over the encoder time steps.

            :param decoder_hidden: GRU hidden state [num_layers, batch_size, decoder_hidden_size]
                (only the last layer is used) or an already-reduced [batch_size, decoder_hidden_size]
            :param encoder_outputs: [batch_size, encoder_max_len, encoder_output_size]
            :return: softmax-normalised weights, [batch_size, encoder_max_len]
            """
            if self.method == "dot":
                return self.dot_score(decoder_hidden, encoder_outputs)
            if self.method == "general":
                return self.general_score(decoder_hidden, encoder_outputs)
            if self.method == "concat":
                return self.concat_score(decoder_hidden, encoder_outputs)
            raise ValueError("attention method error: %r" % (self.method,))

        @staticmethod
        def _last_layer(decoder_hidden):
            """Reduce a hidden state to [batch_size, hidden_size] by taking its last layer."""
            # For a unidirectional multi-layer GRU the last entry along dim 0 is the
            # top layer; the previous squeeze(0) only worked for num_layers == 1.
            return decoder_hidden[-1] if decoder_hidden.dim() == 3 else decoder_hidden

        def dot_score(self, decoder_hidden, encoder_outputs):
            """score = H_t^T * H_s

            :param decoder_hidden: [num_layers, batch_size, hidden_size]
            :param encoder_outputs: [batch_size, encoder_max_len, hidden_size]
            :return: attention weights, [batch_size, encoder_max_len]
            """
            query = self._last_layer(decoder_hidden).unsqueeze(-1)  # [batch_size, hidden_size, 1]
            scores = torch.bmm(encoder_outputs, query).squeeze(-1)  # [batch_size, encoder_max_len]
            return F.softmax(scores, dim=-1)

        def general_score(self, decoder_hidden, encoder_outputs):
            """score = H_t^T * W * H_s

            :param decoder_hidden: [num_layers, batch_size, decoder_hidden_size]
            :param encoder_outputs: [batch_size, encoder_max_len, encoder_output_size]
            :return: attention weights, [batch_size, encoder_max_len]
            """
            # W: [batch_size, decoder_hidden_size] -> [batch_size, encoder_output_size]
            query = self.W(self._last_layer(decoder_hidden)).unsqueeze(-1)  # [batch_size, encoder_output_size, 1]
            scores = torch.bmm(encoder_outputs, query).squeeze(-1)  # [batch_size, encoder_max_len]
            return F.softmax(scores, dim=-1)

        def concat_score(self, decoder_hidden, encoder_outputs):
            """score = V * tanh(W * [H_t ; H_s])

            :param decoder_hidden: [num_layers, batch_size, decoder_hidden_size]
            :param encoder_outputs: [batch_size, encoder_max_len, encoder_output_size]
            :return: attention weights, [batch_size, encoder_max_len]
            """
            encoder_max_len = encoder_outputs.size(1)
            # Broadcast the decoder state to every encoder time step:
            # [batch_size, decoder_hidden_size] -> [batch_size, encoder_max_len, decoder_hidden_size]
            query = self._last_layer(decoder_hidden).unsqueeze(1).expand(-1, encoder_max_len, -1)
            features = torch.cat([query, encoder_outputs], dim=-1)  # [batch_size, encoder_max_len, dec + enc]
            # nn.Linear applies to the last dimension, so no flattening is needed.
            scores = self.V(torch.tanh(self.W(features))).squeeze(-1)  # [batch_size, encoder_max_len]
            return F.softmax(scores, dim=-1)

        # Backward-compatible aliases for the original (misspelled) method names.
        general_socre = general_score
        concat_socre = concat_score
    

      decoder.py

    """
    实现解码器
    """
    import torch.nn as nn
    import config
    import torch
    import torch.nn.functional as F
    import numpy as np
    import random
    from chatbot.attention import Attention
    
    
    class Decoder(nn.Module):
        """GRU decoder with ("general") attention over the encoder outputs.

        Sizes come from ``config``: vocabulary ``config.target_ws``, embedding
        size ``config.chatbot_decoder_embedding_dim``, hidden size
        ``config.chatbot_decoder_hidden_size`` and decoding length
        ``config.chatbot_target_max_len``.
        """

        def __init__(self):
            super(Decoder, self).__init__()

            self.embedding = nn.Embedding(num_embeddings=len(config.target_ws),
                                          embedding_dim=config.chatbot_decoder_embedding_dim,
                                          padding_idx=config.target_ws.PAD)

            # The initial hidden state is the encoder's final hidden state, so it
            # must have shape [num_layers, batch_size, chatbot_decoder_hidden_size].
            self.gru = nn.GRU(input_size=config.chatbot_decoder_embedding_dim,
                              hidden_size=config.chatbot_decoder_hidden_size,
                              num_layers=config.chatbot_decoder_number_layer,
                              bidirectional=False,
                              batch_first=True,
                              dropout=config.chatbot_decoder_dropout)

            # Projects the attentional hidden state onto vocabulary logits.
            self.fc = nn.Linear(config.chatbot_decoder_hidden_size, len(config.target_ws))
            self.attn = Attention(method="general")
            # Combines [context vector ; GRU output] back to the hidden size; assumes
            # the context vector is hidden-size wide as well -- TODO confirm.
            self.fc_attn = nn.Linear(config.chatbot_decoder_hidden_size * 2, config.chatbot_decoder_hidden_size, bias=False)

        def _init_decoding(self, encoder_hidden):
            """Build the first-step input (all SOS, [batch_size, 1]) and a zeroed output buffer."""
            batch_size = encoder_hidden.size(1)
            decoder_input = torch.full((batch_size, 1), config.target_ws.SOS,
                                       dtype=torch.long, device=config.device)
            # [batch_size, max_len, vocab_size], filled one time step at a time.
            decoder_outputs = torch.zeros(batch_size, config.chatbot_target_max_len,
                                          len(config.target_ws), device=config.device)
            return decoder_input, decoder_outputs

        def forward(self, encoder_hidden, target, encoder_outputs, teacher_forcing_ratio=0.5):
            """Decode a whole sequence for training.

            With probability ``teacher_forcing_ratio`` the whole batch is decoded with
            teacher forcing (the ground-truth token ``target[:, t]`` is the next input);
            otherwise the model's own greedy prediction is fed back.

            :param encoder_hidden: initial hidden state, [num_layers, batch_size, hidden_size]
            :param target: ground-truth token ids, [batch_size, >= chatbot_target_max_len]
            :param encoder_outputs: [batch_size, encoder_max_len, encoder_output_size]
            :param teacher_forcing_ratio: probability of teacher forcing; the default 0.5
                matches the original fixed coin flip
            :return: (log-probabilities [batch_size, chatbot_target_max_len, vocab_size],
                      final decoder hidden state)
            """
            decoder_hidden = encoder_hidden
            decoder_input, decoder_outputs = self._init_decoding(encoder_hidden)
            use_teacher_forcing = random.random() < teacher_forcing_ratio

            for t in range(config.chatbot_target_max_len):
                decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
                decoder_outputs[:, t, :] = decoder_output_t
                if use_teacher_forcing:
                    # Feed the ground-truth token as the next input.
                    decoder_input = target[:, t].unsqueeze(-1)
                else:
                    # Feed back the model's own greedy prediction.
                    _, index = decoder_output_t.max(dim=-1)
                    decoder_input = index.unsqueeze(-1)  # [batch_size, 1]
            return decoder_outputs, decoder_hidden

        def forward_step(self, decoder_input, decoder_hidden, encoder_outputs):
            """Run a single decoding step with attention.

            :param decoder_input: token ids, [batch_size, 1]
            :param decoder_hidden: [num_layers, batch_size, hidden_size]
            :param encoder_outputs: [batch_size, encoder_max_len, encoder_output_size]
            :return: (log-probabilities [batch_size, vocab_size], new hidden state)
            """
            decoder_input_embeded = self.embedding(decoder_input)  # [batch_size, 1, embedding_dim]

            # out: [batch_size, 1, hidden_size] (top layer);
            # decoder_hidden: [num_layers, batch_size, hidden_size]
            out, decoder_hidden = self.gru(decoder_input_embeded, decoder_hidden)

            # 1. Attention weights from the top layer's state -> [batch_size, encoder_max_len].
            #    Slicing [-1:] keeps the [1, batch_size, hidden_size] layout for num_layers > 1.
            attn_weight = self.attn(decoder_hidden[-1:], encoder_outputs)
            # 2. Context vector = attention-weighted sum of encoder outputs
            #    -> [batch_size, encoder_output_size].
            context_vector = torch.bmm(attn_weight.unsqueeze(1), encoder_outputs).squeeze(1)
            # 3. Attentional hidden state tanh(W_c [context ; out]) -> [batch_size, hidden_size].
            attention_result = torch.tanh(self.fc_attn(torch.cat([context_vector, out.squeeze(1)], dim=-1)))

            out_fc = F.log_softmax(self.fc(attention_result), dim=-1)  # [batch_size, vocab_size]
            return out_fc, decoder_hidden

        def evaluate(self, encoder_hidden, encoder_outputs):
            """Greedy decoding for inference (no target available).

            :param encoder_hidden: initial hidden state, [num_layers, batch_size, hidden_size]
            :param encoder_outputs: [batch_size, encoder_max_len, encoder_output_size]
            :return: (log-probabilities [batch_size, chatbot_target_max_len, vocab_size],
                      predicted ids as an ndarray, one row per sequence)
            """
            decoder_hidden = encoder_hidden
            decoder_input, decoder_outputs = self._init_decoding(encoder_hidden)

            predict_result = []  # one [batch_size] array per time step
            for t in range(config.chatbot_target_max_len):
                decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
                decoder_outputs[:, t, :] = decoder_output_t

                _, index = decoder_output_t.max(dim=-1)
                predict_result.append(index.detach().cpu().numpy())
                decoder_input = index.unsqueeze(-1)  # [batch_size, 1]
            # Stacked as [max_len, batch_size]; transpose so each row is one prediction.
            return decoder_outputs, np.array(predict_result).transpose()
    

      seq2seq.py

    """
    完成seq2seq模型
    """
    import torch.nn as nn
    from chatbot.encoder import Encoder
    from chatbot.decoder import Decoder
    
    
    class Seq2Seq(nn.Module):
        """Encoder-decoder chatbot model that feeds an ``Encoder`` into an attention ``Decoder``."""

        def __init__(self):
            super().__init__()
            self.encoder = Encoder()
            self.decoder = Decoder()

        def forward(self, input, input_len, target):
            """Training pass: encode ``input`` and decode it against ``target``.

            :return: the decoder's per-step outputs (its final hidden state is dropped)
            """
            enc_out, enc_state = self.encoder(input, input_len)
            step_outputs, _final_state = self.decoder(enc_state, target, enc_out)
            return step_outputs

        def evaluate(self, input, input_len):
            """Inference pass: encode ``input`` and greedily decode without a target.

            :return: (decoder outputs, predictions) exactly as ``Decoder.evaluate`` yields them
            """
            enc_out, enc_state = self.encoder(input, input_len)
            step_outputs, predictions = self.decoder.evaluate(enc_state, enc_out)
            return step_outputs, predictions
    

      

    多思考也是一种努力，做出正确的分析和选择，因为我们的时间和精力都有限，所以把时间花在更有价值的地方。
  • 相关阅读:
    云计算-MapReduce
    云计算--hbase shell
    云计算--hdfs dfs 命令
    云计算--MPI
    jQuery 效果
    jQuery 效果
    JQuery效果隐藏/显示
    JQuery教程
    六级啊啊啊
    jQuery 安装
  • 原文地址:https://www.cnblogs.com/LiuXinyu12378/p/12380384.html
Copyright © 2011-2022 走看看