zoukankan      html  css  js  c++  java
  • pytorch定义一个简单的神经网络

    刚学习pytorch,简单记录一下

    """
        test Funcition
    """
    
    import torch
    from torch.autograd import Variable
    import torch.nn as nn
    import torch.nn.functional as F
    
    class Net(nn.Module):
        ''' a neural network with pytorch'''
        def __init__(self):
            # 父类的构造方法
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 6, 5)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16*5*5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)
    
        def forward(self, x):
            x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
            x = F.max_pool2d(F.relu(self.conv2(x)), 2)
            x = x.view(-1, self.num_flat_features(x))
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    
        def num_flat_features(self, x):
            size = x.size()[1:]
            num_features = 1
            for s in size:
                num_features *= s
            return num_features
    
    
    net = Net()
    # 查看网络
    print(net)
    
    # 查看模型需要学习的参数
    params = list(net.parameters())
    print(len(params))
    for param in params:
        print(param.size())
    
    # 输入数据
    input = Variable(torch.randn(1,1,32,32))
    print(input)
    out = net(input)
    print(out)
    
    # 损失函数
    target = Variable(torch.arange(1, 11, dtype=torch.float32))
    print(target)
    criterion = nn.MSELoss()
    loss = criterion(out, target)
    print(loss)

    输出结果:

  • 相关阅读:
    CF1175B Catch Overflow!
    震惊!一蒟蒻竟然写出fhqTreap
    树上差分
    洛谷 P3128 最大流Max Flow
    线段树的标记永久化/二维线段树模板
    矩阵加速~desire drive
    置换相关
    树形图们
    严格单调递增与非严格之间的转换
    记录延续性的一类dp
  • 原文地址:https://www.cnblogs.com/demo-deng/p/10599049.html
Copyright © 2011-2022 走看看