zoukankan      html  css  js  c++  java
  • TrajPreModel

    轨迹预测模型

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    #######################################
    class TrajPreModel(nn.Module):
        """Self-attention model that scores location ids for each input position.

        Pipeline: location embedding -> dropout -> self-attention -> dropout ->
        projection to ``loc_size`` logits -> log-softmax.

        Args:
            loc_size: Number of distinct location ids (rows of the embedding and
                width of the output distribution).
            loc_emb_size: Embedding width; also the width handed to the attention
                layer and the input width of the output projection ``fc``.
            hidden_size: Stored as ``self.hidden_size``; not used by any layer
                built in this class.
            head_num: Number of heads passed to ``MultiSelfAttention``.
            dropout_p: Dropout probability, passed to the attention layer and
                used for this model's own ``nn.Dropout``.
            is_weight_sharing: When True, the output projection reuses the
                embedding matrix via ``F.linear(output, self.weight)`` (no bias)
                instead of ``self.fc``. Defaults to False, matching the value that
                used to be hard-coded.
        """

        def __init__(self, loc_size=528, loc_emb_size=128, hidden_size=32, head_num=1,
                     dropout_p=0, is_weight_sharing=False):
            super().__init__()
            self.loc_size = loc_size
            self.loc_emb_size = loc_emb_size
            self.hidden_size = hidden_size
            self.heads = head_num
            self.dropout_p = dropout_p

            # Location-id embedding. Assigning its Parameter to ``self.weight``
            # registers the same tensor again under the name "weight"; the
            # weight-sharing output path reads it from there. Kept as-is so
            # existing state_dict keys stay unchanged.
            self.emb_loc = nn.Embedding(self.loc_size, self.loc_emb_size)
            self.weight = self.emb_loc.weight

            # ------------- model -------------
            # MultiSelfAttention is defined elsewhere in the project; forward()
            # calls it as attention(query, key, value).
            self.attention = MultiSelfAttention(self.heads, self.loc_emb_size, dropout=self.dropout_p)
            self.fc = nn.Linear(self.loc_emb_size, self.loc_size)
            self.is_weight_sharing = is_weight_sharing
            self.init_weights()
            self.dropout = nn.Dropout(p=dropout_p)

        def init_weights(self):
            """Re-initialise parameters (including those of submodules) by name.

            * names containing ``weight_ih`` -> Xavier-uniform
            * names containing ``weight_hh`` -> orthogonal
            * names containing ``bias``      -> zeros

            The ``weight_ih``/``weight_hh`` patterns are recurrent-layer names.
            None of the layers created directly in ``__init__`` use them, so for
            this class they only matter if ``self.attention`` contains such
            parameters (not visible here). ``fc.bias`` is always zeroed.
            """
            # Three separate passes keep the same order of random draws as
            # before, so seeded runs reproduce identical initial weights.
            named = list(self.named_parameters())
            for name, param in named:
                if 'weight_ih' in name:
                    nn.init.xavier_uniform_(param.data)
            for name, param in named:
                if 'weight_hh' in name:
                    nn.init.orthogonal_(param.data)
            for name, param in named:
                if 'bias' in name:
                    nn.init.constant_(param.data, 0)

        def forward(self, x):
            """Return log-probabilities over locations for every input position.

            Args:
                x: Indexable input; only ``x[1]`` is read, as the tensor of
                    location ids fed to the embedding (presumably
                    [batch_size, seq_len] -- TODO confirm with the caller).

            Returns:
                Log-softmax scores reshaped to ``[N, loc_size]``. Assuming the
                attention layer preserves the leading dimensions, N is
                batch_size * seq_len for a 2-D ``x[1]``. The result is flattened,
                not ``[batch_size, seq_len, loc_size]``.
            """
            seq = x[1]
            output = self.dropout(self.emb_loc(seq))

            # Self-attention: query, key and value are all the embedded sequence.
            output = self.attention(output, output, output)
            output = self.dropout(output)

            if self.is_weight_sharing:
                # Tied output layer: project with the embedding matrix.
                logits = F.linear(output, self.weight)
            else:
                logits = self.fc(output)

            score = F.log_softmax(logits, dim=-1)
            # reshape (rather than view) does not require a contiguous tensor.
            return score.reshape(-1, self.loc_size)
    
    
  • 相关阅读:
    hexo部署失败如何解决
    github设置添加SSH
    鼠标相对于屏幕的位置、鼠标相对于窗口的位置和获取鼠标相对于文档的位置
    git push origin master 错误解决办法
    js设计模式:工厂模式、构造函数模式、原型模式、混合模式
    d3.js实现自定义多y轴折线图
    计算机网络之HTTP(上)基础知识点
    Node.js学习笔记(一)基础介绍
    计算机组成
    Ajax及跨域
  • 原文地址:https://www.cnblogs.com/lixyuan/p/12919950.html
Copyright © 2011-2022 走看看