  • Defining a simple neural network with PyTorch

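    I have just started learning PyTorch, so here is a short note: a small convolutional network, a forward pass on a random input, and an MSE loss.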

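    """
        Test script: a simple LeNet-style network in PyTorch
    """
    
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    class Net(nn.Module):
        '''A neural network built with PyTorch.'''
        def __init__(self):
            # Call the parent class (nn.Module) constructor
            super(Net, self).__init__()
            # 1 input channel, 6 output channels, 5x5 kernel
            self.conv1 = nn.Conv2d(1, 6, 5)
            # 6 input channels, 16 output channels, 5x5 kernel
            self.conv2 = nn.Conv2d(6, 16, 5)
            # 16 channels * 5 * 5 is the flattened size after the second
            # pooling step, for a 32x32 input image
            self.fc1 = nn.Linear(16*5*5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)
    
        def forward(self, x):
            # conv1 -> ReLU -> 2x2 max pooling: (N, 1, 32, 32) -> (N, 6, 14, 14)
            x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
            # A square window can be given as one number: -> (N, 16, 5, 5)
            x = F.max_pool2d(F.relu(self.conv2(x)), 2)
            # Flatten everything except the batch dimension: -> (N, 400)
            x = x.view(-1, self.num_flat_features(x))
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    
        def num_flat_features(self, x):
            # Product of all dimensions except the batch dimension
            size = x.size()[1:]
            num_features = 1
            for s in size:
                num_features *= s
            return num_features
    
    
    net = Net()
    # Inspect the network structure
    print(net)
    
    # List the learnable parameters (weights and biases) of the model
    params = list(net.parameters())
    print(len(params))
    for param in params:
        print(param.size())
    
    # Input data: a batch of one single-channel 32x32 image
    input_tensor = torch.randn(1, 1, 32, 32)
    print(input_tensor)
    out = net(input_tensor)
    print(out)
    
    # Loss function
    target = torch.arange(1, 11, dtype=torch.float32)
    # Reshape the target from (10,) to (1, 10) to match the output shape
    target = target.view(1, -1)
    print(target)
    criterion = nn.MSELoss()
    loss = criterion(out, target)
    print(loss)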
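
    The layer sizes in the code follow from standard output-size arithmetic: with stride 1 and no padding (the nn.Conv2d defaults), a 5x5 convolution shrinks each side by 4, and a 2x2 max pool halves it, so 32 -> 28 -> 14 -> 10 -> 5. Below is a pure-Python sketch (no PyTorch needed) that traces a 32x32 input through the layers and lists the weight and bias shapes the parameter loop should print. The helper names `conv_out`, `trace_shapes`, `param_shapes` and `count` are hypothetical, chosen only for illustration.

```python
# Pure-Python sketch of the size arithmetic behind Net (no PyTorch needed).
# Assumptions: stride 1 and no padding for the convolutions, and
# non-overlapping 2x2 max pooling (stride 2), as in the code above.

def conv_out(n, kernel, stride=1, padding=0):
    """Spatial size after a convolution or pooling window."""
    return (n + 2 * padding - kernel) // stride + 1


def trace_shapes(size=32):
    """Return (stage, channels, side length) after each stage for a square input."""
    stages = []
    size = conv_out(size, 5)            # conv1: 5x5 kernel
    stages.append(("conv1", 6, size))
    size = conv_out(size, 2, stride=2)  # 2x2 max pooling
    stages.append(("pool1", 6, size))
    size = conv_out(size, 5)            # conv2: 5x5 kernel
    stages.append(("conv2", 16, size))
    size = conv_out(size, 2, stride=2)  # 2x2 max pooling
    stages.append(("pool2", 16, size))
    return stages


def param_shapes():
    """Weight and bias shapes per layer, in the order the layers are registered."""
    return [
        (6, 1, 5, 5), (6,),         # conv1: (out, in, kH, kW), bias
        (16, 6, 5, 5), (16,),       # conv2
        (120, 16 * 5 * 5), (120,),  # fc1: (out_features, in_features), bias
        (84, 120), (84,),           # fc2
        (10, 84), (10,),            # fc3
    ]


def count(shape):
    """Number of values in a tensor of the given shape."""
    total = 1
    for dim in shape:
        total *= dim
    return total


if __name__ == "__main__":
    for name, channels, side in trace_shapes():
        print(f"{name}: {channels} x {side} x {side}")
    shapes = param_shapes()
    print(len(shapes), "parameter tensors,", sum(count(s) for s in shapes), "values")
```

    Run as a script, it reports 10 parameter tensors holding 61706 values in total; most of them (48000 weights) sit in fc1, because it connects all 400 flattened features to 120 units.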

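    nn.MSELoss() with its default reduction='mean' squares the element-wise differences between the output and the target and averages them. The sketch below reproduces that computation in plain Python for two flat lists of equal length; the function name `mse` and the all-zeros prediction are hypothetical, for illustration only.

```python
def mse(prediction, target):
    """Mean squared error, as computed by nn.MSELoss() with reduction='mean'."""
    if len(prediction) != len(target):
        raise ValueError("prediction and target must have the same length")
    squared = [(p - t) ** 2 for p, t in zip(prediction, target)]
    return sum(squared) / len(squared)


if __name__ == "__main__":
    # Same target values as in the post: 1.0, 2.0, ..., 10.0
    target = [float(i) for i in range(1, 11)]
    # Hypothetical network output: all zeros
    prediction = [0.0] * 10
    print(mse(prediction, target))  # prints 38.5, i.e. (1 + 4 + ... + 100) / 10
```
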

  • Original article: https://www.cnblogs.com/demo-deng/p/10599049.html