zoukankan      html  css  js  c++  java
  • TrajPreModel

    轨迹预测模型

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    #######################################
    class TrajPreModel(nn.Module):
        """self-attention model"""
        def __init__(self, loc_size=528, loc_emb_size=128, hidden_size=32, head_num=1, dropout_p=0):
            super(TrajPreModel, self).__init__()
            self.loc_size = loc_size
            self.loc_emb_size = loc_emb_size
            self.hidden_size = hidden_size
            self.heads = head_num
            self.dropout_p = dropout_p
            # embeding
            self.emb_loc = nn.Embedding(self.loc_size, self.loc_emb_size)
            self.weight = self.emb_loc.weight
                  
            #-------------model---------------
            self.attention = MultiSelfAttention(self.heads, self.loc_emb_size, dropout=self.dropout_p)
            self.fc = nn.Linear(self.loc_emb_size, self.loc_size)
            self.is_weight_sharing = False#is_weight_sharing
            self.init_weights()
            self.dropout = nn.Dropout(p=dropout_p)
    
        def init_weights(self):
            ih = (param.data for name, param in self.named_parameters() if 'weight_ih' in name)
            hh = (param.data for name, param in self.named_parameters() if 'weight_hh' in name)
            b = (param.data for name, param in self.named_parameters() if 'bias' in name)
            for t in ih:
                nn.init.xavier_uniform(t)
            for t in hh:
                nn.init.orthogonal(t)
            for t in b:
                nn.init.constant_(t, 0)
    
        def forward(self, x):
            
            seq = x[1] # [batch_size, seq_len]
            loc_emb = self.emb_loc(seq) 
            output = self.dropout(loc_emb)
            #Self-attention
            
            output = self.attention(output,output, output)
            output = self.dropout(output)
    
            if not self.is_weight_sharing:
                y = self.fc(output)
            else:
                y = F.linear(output, self.weight)
            
            score = F.log_softmax(y, dim=-1) 
            return score.view(-1, self.loc_size) # [batch_size, seq_len, loc_size]
    
    
  • 相关阅读:
    不要对父母说的十句话
    25句世界上最經典的話
    奥巴马(Obama)获胜演讲全文[中英对照]+高清视频下载
    【心理测试】男人眼中女人哪里最性感?
    一张图看出你是用左脑还是右脑
    我是有原则的~~(笑话转贴)
    各行各业常用大用大谎言
    你知道我今天为啥要来上班吗?
    女性抽烟腰会变粗?
    心理学家揭秘人临死时的感受
  • 原文地址:https://www.cnblogs.com/lixyuan/p/12919950.html
Copyright © 2011-2022 走看看