zoukankan      html  css  js  c++  java
  • TrajPreModel

    轨迹预测模型

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    #######################################
    class TrajPreModel(nn.Module):
        """Self-attention model that scores the next location of a trajectory.

        Pipeline: location embedding -> dropout -> self-attention -> dropout ->
        projection onto the location vocabulary -> log-softmax.

        Args:
            loc_size: Number of distinct locations (embedding rows and output
                classes).
            loc_emb_size: Dimension of each location embedding.
            hidden_size: Stored on the instance; not used by any layer built in
                this class.
            head_num: Number of heads handed to ``MultiSelfAttention``.
            dropout_p: Dropout probability, used both by this module's dropout
                layer and passed to ``MultiSelfAttention``.
            is_weight_sharing: When True the output projection reuses the
                embedding matrix (no bias) instead of the ``fc`` layer.
                Defaults to False, which matches the previous hard-coded
                behaviour.

        Note:
            ``MultiSelfAttention`` is not defined in this class; it must be
            available in the enclosing module.
        """

        def __init__(self, loc_size: int = 528, loc_emb_size: int = 128,
                     hidden_size: int = 32, head_num: int = 1,
                     dropout_p: float = 0,
                     is_weight_sharing: bool = False) -> None:
            super().__init__()
            self.loc_size = loc_size
            self.loc_emb_size = loc_emb_size
            self.hidden_size = hidden_size
            self.heads = head_num
            self.dropout_p = dropout_p
            self.is_weight_sharing = is_weight_sharing

            # Embedding table; ``self.weight`` is an alias of the same
            # parameter so it can double as the output projection matrix
            # (shape [loc_size, loc_emb_size]) when weights are shared.
            self.emb_loc = nn.Embedding(self.loc_size, self.loc_emb_size)
            self.weight = self.emb_loc.weight

            # ------------- model ---------------
            self.attention = MultiSelfAttention(self.heads, self.loc_emb_size, dropout=self.dropout_p)
            # ``fc`` is always created (even when weights are shared) so the
            # set of parameters / state_dict keys does not depend on the flag.
            self.fc = nn.Linear(self.loc_emb_size, self.loc_size)
            self.dropout = nn.Dropout(p=dropout_p)
            self.init_weights()

        def init_weights(self) -> None:
            """Re-initialise parameters selected by substring of their name.

            * names containing ``weight_ih`` -> Xavier-uniform
            * names containing ``weight_hh`` -> orthogonal
            * names containing ``bias``      -> constant 0

            Parameters whose name matches none of these keep their default
            initialisation. The three groups are processed one after another
            (all ``weight_ih`` first, then ``weight_hh``, then ``bias``).
            """
            named = list(self.named_parameters())
            # The in-place ``*_`` initialisers replace the deprecated
            # ``nn.init.xavier_uniform`` / ``nn.init.orthogonal`` aliases.
            for name, param in named:
                if 'weight_ih' in name:
                    nn.init.xavier_uniform_(param.data)
            for name, param in named:
                if 'weight_hh' in name:
                    nn.init.orthogonal_(param.data)
            for name, param in named:
                if 'bias' in name:
                    nn.init.constant_(param.data, 0)

        def forward(self, x):
            """Compute log-probabilities over locations for every input position.

            Args:
                x: Indexable input; only ``x[1]`` is read and it is fed to the
                    embedding layer as location indices (presumably shaped
                    [batch_size, seq_len] -- confirm against the caller).

            Returns:
                A 2-D tensor with ``loc_size`` columns holding log-softmax
                scores. All leading dimensions are flattened by the final
                ``view``, i.e. [batch_size * seq_len, loc_size] under the
                shape assumption above.
            """
            seq = x[1]
            embedded = self.dropout(self.emb_loc(seq))

            # Self-attention with query = key = value = embedded sequence.
            attended = self.dropout(self.attention(embedded, embedded, embedded))

            if self.is_weight_sharing:
                # Tied weights: project with the embedding matrix, no bias.
                logits = F.linear(attended, self.weight)
            else:
                logits = self.fc(attended)

            score = F.log_softmax(logits, dim=-1)
            return score.view(-1, self.loc_size)
    
    
  • 相关阅读:
    platform_device和platform_driver
    理解和认识udev
    platform_device和platform_driver
    bzImage的概要生成过程
    shell 字符表
    分析mtk6516如何加入自己的驱动
    理解和使用Linux的硬件抽象层HAL
    bzImage的概要生成过程
    理解和认识udev
    shell 字符表
  • 原文地址:https://www.cnblogs.com/lixyuan/p/12919950.html
Copyright © 2011-2022 走看看