zoukankan      html  css  js  c++  java
  • pytorch seq2seq闲聊机器人加入attention机制

    attention.py

    """
    实现attention
    """
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import config
    
    
    class Attention(nn.Module):
        """Luong-style attention weights between a decoder state and encoder outputs.

        Scoring methods:
          * "dot":     score(h_t, h_s) = h_t^T h_s
          * "general": score(h_t, h_s) = h_t^T W h_s
          * "concat":  score(h_t, h_s) = v^T tanh(W [h_t; h_s])

        ``forward`` returns softmax-normalised weights over the encoder time axis,
        shape [batch_size, encoder_max_len]. No padding mask is applied, so padded
        encoder positions also receive weight.
        """

        def __init__(self, method="general"):
            super(Attention, self).__init__()
            assert method in ["dot", "general", "concat"], "attention method error"
            self.method = method
            # Encoder outputs are sized as 2 * chatbot_encoder_hidden_size
            # (presumably a bidirectional encoder -- confirm against encoder.py).
            encoder_output_size = config.chatbot_encoder_hidden_size * 2
            decoder_hidden_size = config.chatbot_decoder_hidden_size

            if method == "general":
                # Maps the decoder state into encoder-output space so it can be
                # dotted with every encoder output.
                self.W = nn.Linear(decoder_hidden_size, encoder_output_size, bias=False)

            if method == "concat":
                # W consumes [h_t; h_s], which is decoder_hidden_size + encoder_output_size wide.
                self.W = nn.Linear(decoder_hidden_size + encoder_output_size,
                                   decoder_hidden_size * 2, bias=False)
                self.V = nn.Linear(decoder_hidden_size * 2, 1, bias=False)

        def forward(self, decoder_hidden, encoder_outputs):
            """Return attention weights of shape [batch_size, encoder_max_len].

            :param decoder_hidden: [num_layers, batch_size, decoder_hidden_size]
                (only the last layer is used) or [batch_size, decoder_hidden_size].
            :param encoder_outputs: [batch_size, encoder_max_len, encoder_output_size].
            :raises ValueError: if ``self.method`` was changed to an unknown value.
            """
            if self.method == "dot":
                return self.dot_score(decoder_hidden, encoder_outputs)
            if self.method == "general":
                return self.general_score(decoder_hidden, encoder_outputs)
            if self.method == "concat":
                return self.concat_score(decoder_hidden, encoder_outputs)
            raise ValueError(f"unknown attention method: {self.method!r}")

        @staticmethod
        def _query(decoder_hidden):
            """Reduce the decoder state to a [batch_size, hidden] query.

            A 3-D state is indexed at its last entry along dim 0 (the top layer
            of a unidirectional multi-layer GRU); a 2-D state is used unchanged.
            """
            return decoder_hidden[-1] if decoder_hidden.dim() == 3 else decoder_hidden

        def dot_score(self, decoder_hidden, encoder_outputs):
            """Dot score h_t^T h_s; requires decoder and encoder-output widths to match."""
            query = self._query(decoder_hidden).unsqueeze(-1)        # [batch_size, hidden, 1]
            scores = torch.bmm(encoder_outputs, query).squeeze(-1)   # [batch_size, encoder_max_len]
            return F.softmax(scores, dim=-1)

        def general_score(self, decoder_hidden, encoder_outputs):
            """Bilinear score h_t^T W h_s."""
            query = self.W(self._query(decoder_hidden)).unsqueeze(-1)  # [batch_size, encoder_output_size, 1]
            scores = torch.bmm(encoder_outputs, query).squeeze(-1)     # [batch_size, encoder_max_len]
            return F.softmax(scores, dim=-1)

        def concat_score(self, decoder_hidden, encoder_outputs):
            """Additive score v^T tanh(W [h_t; h_s])."""
            encoder_max_len = encoder_outputs.size(1)
            # Broadcast the query to every encoder position without copying memory.
            query = self._query(decoder_hidden).unsqueeze(1).expand(-1, encoder_max_len, -1)
            h_cated = torch.cat([query, encoder_outputs], dim=-1)   # [batch_size, max_len, dec + enc_out]
            scores = self.V(torch.tanh(self.W(h_cated))).squeeze(-1)  # [batch_size, encoder_max_len]
            return F.softmax(scores, dim=-1)

        # Backward-compatible aliases for the original (misspelled) method names.
        general_socre = general_score
        concat_socre = concat_score
    

      decoder.py

    """
    实现解码器
    """
    import torch.nn as nn
    import config
    import torch
    import torch.nn.functional as F
    import numpy as np
    import random
    from chatbot.attention import Attention
    
    
    class Decoder(nn.Module):
        """GRU decoder with attention over the encoder outputs.

        Each step embeds the previous token, advances the GRU one step, attends
        over ``encoder_outputs`` with the top GRU layer's state, and projects
        ``[context; gru_output]`` to log-probabilities over ``config.target_ws``.
        """

        # Probability that a ``forward`` call decodes with teacher forcing, i.e.
        # feeds the ground-truth token ``target[:, t]`` as the next input instead
        # of the model's own greedy prediction. Can be overridden per instance.
        teacher_forcing_ratio = 0.5

        def __init__(self):
            super(Decoder, self).__init__()

            self.embedding = nn.Embedding(num_embeddings=len(config.target_ws),
                                          embedding_dim=config.chatbot_decoder_embedding_dim,
                                          padding_idx=config.target_ws.PAD)

            # batch_first=True applies to inputs/outputs only; the hidden state is
            # [num_layers, batch_size, chatbot_decoder_hidden_size].
            self.gru = nn.GRU(input_size=config.chatbot_decoder_embedding_dim,
                              hidden_size=config.chatbot_decoder_hidden_size,
                              num_layers=config.chatbot_decoder_number_layer,
                              bidirectional=False,
                              batch_first=True,
                              dropout=config.chatbot_decoder_dropout)

            self.fc = nn.Linear(config.chatbot_decoder_hidden_size, len(config.target_ws))
            self.attn = Attention(method="general")
            # Projects [context_vector; gru_output] back to decoder_hidden_size.
            # in_features = 2 * decoder_hidden_size, so the encoder outputs (and
            # therefore the context vector) must be decoder_hidden_size wide.
            self.fc_attn = nn.Linear(config.chatbot_decoder_hidden_size * 2,
                                     config.chatbot_decoder_hidden_size, bias=False)

        def _initial_input(self, batch_size):
            """Return a [batch_size, 1] LongTensor of SOS ids on ``config.device``."""
            return torch.full((batch_size, 1), config.target_ws.SOS,
                              dtype=torch.long, device=config.device)

        def _empty_outputs(self, batch_size):
            """Return a zero [batch_size, target_max_len, vocab_size] buffer on ``config.device``."""
            return torch.zeros([batch_size, config.chatbot_target_max_len, len(config.target_ws)],
                               device=config.device)

        def forward(self, encoder_hidden, target, encoder_outputs):
            """Decode ``config.chatbot_target_max_len`` steps for training.

            :param encoder_hidden: initial GRU state, [num_layers, batch_size, decoder_hidden_size].
            :param target: [batch_size, >= chatbot_target_max_len] token ids; only read
                when teacher forcing is chosen for this call.
            :param encoder_outputs: [batch_size, encoder_max_len, decoder_hidden_size].
            :return: (decoder_outputs [batch_size, max_len, vocab_size] log-probs, final hidden).
            """
            decoder_hidden = encoder_hidden
            batch_size = encoder_hidden.size(1)
            decoder_input = self._initial_input(batch_size)
            decoder_outputs = self._empty_outputs(batch_size)

            # One coin flip per call decides the input policy for the whole sequence.
            use_teacher_forcing = random.random() < self.teacher_forcing_ratio
            for t in range(config.chatbot_target_max_len):
                decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
                decoder_outputs[:, t, :] = decoder_output_t
                if use_teacher_forcing:
                    decoder_input = target[:, t].unsqueeze(-1)                     # ground truth
                else:
                    decoder_input = decoder_output_t.argmax(dim=-1).unsqueeze(-1)  # own prediction
            return decoder_outputs, decoder_hidden

        def forward_step(self, decoder_input, decoder_hidden, encoder_outputs):
            """Run one decoding step.

            :param decoder_input: [batch_size, 1] token ids.
            :param decoder_hidden: [num_layers, batch_size, decoder_hidden_size].
            :param encoder_outputs: [batch_size, encoder_max_len, decoder_hidden_size].
            :return: (log-probs [batch_size, vocab_size], new hidden state).
            """
            decoder_input_embedded = self.embedding(decoder_input)        # [batch_size, 1, emb_dim]
            out, decoder_hidden = self.gru(decoder_input_embedded, decoder_hidden)  # out: [batch_size, 1, hidden]

            # Attention query: the top GRU layer's state, kept as [1, batch_size, hidden]
            # so it also works when num_layers > 1.
            attn_weight = self.attn(decoder_hidden[-1:], encoder_outputs)   # [batch_size, encoder_max_len]
            context_vector = torch.bmm(attn_weight.unsqueeze(1), encoder_outputs).squeeze(1)  # [batch_size, hidden]
            attention_result = torch.tanh(
                self.fc_attn(torch.cat([context_vector, out.squeeze(1)], dim=-1)))  # [batch_size, hidden]

            out_fc = F.log_softmax(self.fc(attention_result), dim=-1)      # [batch_size, vocab_size]
            return out_fc, decoder_hidden

        def evaluate(self, encoder_hidden, encoder_outputs):
            """Greedy decoding for inference.

            :return: (decoder_outputs [batch_size, max_len, vocab_size] log-probs,
                      predict_result numpy array [batch_size, max_len] of token ids).
            """
            decoder_hidden = encoder_hidden
            batch_size = encoder_hidden.size(1)
            decoder_input = self._initial_input(batch_size)
            decoder_outputs = self._empty_outputs(batch_size)

            predict_steps = []  # one [batch_size] array per time step
            for t in range(config.chatbot_target_max_len):
                decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
                decoder_outputs[:, t, :] = decoder_output_t
                index = decoder_output_t.argmax(dim=-1)                     # [batch_size]
                predict_steps.append(index.cpu().detach().numpy())
                decoder_input = index.unsqueeze(-1)                         # [batch_size, 1]
            # Transpose so each row is one sequence's predictions.
            predict_result = np.array(predict_steps).transpose()
            return decoder_outputs, predict_result
    

      seq2seq.py

    """
    完成seq2seq模型
    """
    import torch.nn as nn
    from chatbot.encoder import Encoder
    from chatbot.decoder import Decoder
    
    
    class Seq2Seq(nn.Module):
        """Chatbot sequence-to-sequence model: an ``Encoder`` feeding a ``Decoder``."""

        def __init__(self):
            super().__init__()
            self.encoder = Encoder()
            self.decoder = Decoder()

        def forward(self, input, input_len, target):
            """Training pass.

            :param input: source batch, handed to the encoder with ``input_len``.
            :param input_len: source lengths, passed straight to the encoder.
            :param target: target batch, passed straight to the decoder.
            :return: the decoder outputs; the decoder's final hidden state is dropped.
            """
            enc_out, enc_state = self.encoder(input, input_len)
            dec_out, _final_state = self.decoder(enc_state, target, enc_out)
            return dec_out

        def evaluate(self, input, input_len):
            """Inference pass: encode, then let the decoder predict greedily.

            :return: tuple (decoder outputs, predicted token ids) from ``Decoder.evaluate``.
            """
            enc_out, enc_state = self.encoder(input, input_len)
            dec_out, predictions = self.decoder.evaluate(enc_state, enc_out)
            return dec_out, predictions
    

      

    多思考也是一种努力：做出正确的分析和选择。因为我们的时间和精力都有限，所以要把时间花在更有价值的地方。
  • 相关阅读:
    2.12 使用@DataProvider
    2.11 webdriver中使用 FileUtils ()
    Xcode8 添加PCH文件
    The app icon set "AppIcon" has an unassigned child告警
    Launch Image
    iOS App图标和启动画面尺寸
    iPhone屏幕尺寸、分辨率及适配
    Xcode下载失败 使用已购项目页面再试一次
    could not find developer disk image
    NSDate与 NSString 、long long类型的相互转化
  • 原文地址:https://www.cnblogs.com/LiuXinyu12378/p/12380384.html
Copyright © 2011-2022 走看看