zoukankan      html  css  js  c++  java
  • PyTorch 中 model(input)[0] 的含义与返回对象类型

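    对于 model = Model(),调用 model(input) 会执行 Model 类中定义的 forward(input) 函数:nn.Module 实现了 __call__,它在调用 forward 的同时还会运行已注册的 hooks,因此应使用 model(input),而不是直接调用 model.forward(input)。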

    举个例子

     1 import math, random
     2 import numpy as np
     3 
     4 import torch
     5 import torch.nn as nn
     6 import torch.optim as optim
     7 import torch.autograd as autograd 
     8 import torch.nn.functional as F
     USE_CUDA = torch.cuda.is_available()


     def Variable(*args, **kwargs):
         """Build an ``autograd.Variable`` and move it to the GPU if CUDA is available.

         Every positional and keyword argument is forwarded unchanged to
         ``autograd.Variable``. Whether to call ``.cuda()`` is decided by the
         module-level ``USE_CUDA`` flag at call time.

         NOTE: since PyTorch 0.4 ``autograd.Variable`` is a deprecated shim
         that simply returns a ``torch.Tensor``.
         """
         wrapped = autograd.Variable(*args, **kwargs)
         if USE_CUDA:
             return wrapped.cuda()
         return wrapped
    11 
    class Encoder(nn.Module):
        """Single-layer feature encoder: a Linear projection followed by ReLU.

        Args:
            din: size of the input's last dimension.
            hidden_dim: size of the produced embedding.
        """

        def __init__(self, din=32, hidden_dim=128):
            super().__init__()
            # "fc" is part of the state_dict keys; keep the attribute name stable.
            self.fc = nn.Linear(in_features=din, out_features=hidden_dim)

        def forward(self, x):
            """Return ``relu(fc(x))``, applied over the last dimension of ``x``."""
            projected = self.fc(x)
            return projected.relu()
    20 
    class AttModel(nn.Module):
        """Masked dot-product attention over the second dimension of the input.

        Values, keys and queries are ReLU-activated Linear projections of the
        input. Attention weights are a softmax over masked query-key scores.
        The attended values pass through a Linear + ReLU output layer.

        Args:
            n_node: not used by this layer; kept so existing callers that pass
                it keep working.
            din: size of the input's last dimension.
            hidden_dim: size of the value/key/query projections.
            dout: size of the output's last dimension.
        """

        # Finite penalty subtracted from masked-out scores so that softmax
        # assigns them (numerically) zero weight.
        _MASK_PENALTY = 9e15

        def __init__(self, n_node, din, hidden_dim, dout):
            super(AttModel, self).__init__()
            # Attribute names are part of the state_dict keys; keep them stable.
            self.fcv = nn.Linear(din, hidden_dim)
            self.fck = nn.Linear(din, hidden_dim)
            self.fcq = nn.Linear(din, hidden_dim)
            self.fcout = nn.Linear(hidden_dim, dout)

        def forward(self, x, mask):
            """Apply masked attention to ``x``.

            Args:
                x: 3-D tensor ``(batch, n, din)`` (``torch.bmm`` requires
                    batched 3-D operands).
                mask: values broadcastable to ``(batch, n, n)``. Entry
                    ``[b, i, j]`` set to 1/True lets row i attend to row j;
                    0/False blocks it. Float, integer and bool masks are
                    accepted. Presumably 0/1-valued, since other values rescale
                    the scores; confirm with callers.

            Returns:
                Tensor of shape ``(batch, n, dout)``.
            """
            v = F.relu(self.fcv(x))
            q = F.relu(self.fcq(x))
            k = F.relu(self.fck(x)).permute(0, 2, 1)  # (batch, hidden_dim, n)
            scores = torch.bmm(q, k)  # (batch, n, n)

            # Cast the mask to the scores' dtype/device. `1 - mask` raises for
            # bool tensors. A mask of another float dtype (e.g. float64) would
            # promote `att`, and the bmm with `v` would then fail on a dtype
            # mismatch.
            mask = torch.as_tensor(mask, dtype=scores.dtype, device=scores.device)
            att = F.softmax(scores * mask - self._MASK_PENALTY * (1 - mask), dim=2)

            out = torch.bmm(att, v)
            # NOTE: an earlier variant added a residual connection here
            # (out = out + v); it is currently disabled.
            return F.relu(self.fcout(out))
    39 
    class Q_Net(nn.Module):
        """Linear output head with no activation.

        Maps the last dimension of its input from ``hidden_dim`` to ``dout``
        (presumably one Q-value per action; confirm with callers).
        """

        def __init__(self, hidden_dim, dout):
            super().__init__()
            # "fc" is part of the state_dict keys; keep the attribute name stable.
            self.fc = nn.Linear(in_features=hidden_dim, out_features=dout)

        def forward(self, x):
            """Return ``fc(x)`` unchanged (raw, unbounded scores)."""
            return self.fc(x)
    View Code
     class DGN(nn.Module):
         """Encoder, two masked attention layers, then a linear output head.

         Args:
             n_agent: forwarded to both attention layers.
             num_inputs: size of the input's last dimension.
             hidden_dim: feature size shared by the encoder, attention layers
                 and head input.
             num_actions: size of the output's last dimension.
         """

         def __init__(self, n_agent, num_inputs, hidden_dim, num_actions):
             super().__init__()
             # Registration order and attribute names determine parameter
             # initialisation order and state_dict keys; keep both unchanged.
             self.encoder = Encoder(num_inputs, hidden_dim)
             self.att_1 = AttModel(n_agent, hidden_dim, hidden_dim, hidden_dim)
             self.att_2 = AttModel(n_agent, hidden_dim, hidden_dim, hidden_dim)
             self.q_net = Q_Net(hidden_dim, num_actions)

         def forward(self, x, mask):
             """Encode ``x``, apply att_1 then att_2 (both with ``mask``), then q_net."""
             features = self.encoder(x)
             for layer in (self.att_1, self.att_2):
                 features = layer(features, mask)
             return self.q_net(features)

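    在调试器的监视(Watch)窗口中查看:

     model 本身是 nn.Module 对象,model(input) 的返回值才是 Tensor 类型。

    故 model(input)[0] 取的是返回张量第 0 维(batch 维)上的第一个元素,即该 batch 中第一个样本的输出,而不是“第一个 batch”。对 DGN 而言,若输入 x 的形状为 (batch, n_agent, num_inputs),则 model(input)[0] 的形状为 (n_agent, num_actions)。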

  • 相关阅读:
    周末小练习
    第十二届全国大学生信息安全竞赛总结与反思
    sql注入学习心得与sqlmap使用心得
    2019“嘉韦思”杯RSA256题目wp
    斐波那契数列求解的三种方法
    二叉树的下一个节点
    替换空格
    二维数组中的查找
    不修改数组找出重复数字
    数组中重复数字
  • 原文地址:https://www.cnblogs.com/yuelien/p/15631896.html
Copyright © 2011-2022 走看看