zoukankan      html  css  js  c++  java
  • pytorch搭建简单网络

    pytorch搭建一个简单神经网络

     1 import torch
     2 import torch.nn as nn
     3 
     4 # 定义数据
     5 # x:输入数据
     6 # y:标签
     7 x = torch.Tensor([[0.2, 0.4], [0.2, 0.3], [0.3, 0.4]])
     8 y = torch.Tensor([[0.6], [0.5], [0.7]])
     9 
    10 
    11 class MyNet(nn.Module):
    12     def __init__(self):
    13         # 调用基类构造函数
    14         super(MyNet, self).__init__()
    15         # 容器,使用时顺序调用各个层
    16         self.fc = nn.Sequential(
    17             # 定义三层
    18             # 输入层
    19             nn.Linear(2, 4),
    20             # 激活函数
    21             nn.Sigmoid(),
    22             # 隐藏层
    23             nn.Linear(4, 4),
    24             nn.Sigmoid(),
    25             # 输出层
    26             nn.Linear(4, 1),
    27         )
    28         # 优化器
    29         # params:优化对象
    30         # lr:学习率
    31         self.opt = torch.optim.Adam(params=self.parameters(), lr=0.001)
    32         # 损失函数,均方差
    33         self.mls = torch.nn.MSELoss()
    34 
    35     def forward(self, inputs):
    36         # 前向传播
    37         return self.fc(inputs)
    38 
    39     def train(self, x, y):
    40         # 训练
    41         # 得到输出结果
    42         out = self.forward(x)
    43         # 计算误差
    44         loss = self.mls(out, y)
    45         # print('loss', loss)
    46         # 梯度置零
    47         self.opt.zero_grad()
    48         # 误差反向传播
    49         loss.backward()
    50         # 更新权重
    51         self.opt.step()
    52 
    53     def test(self, x):
    54         # 测试,就是前向传播的过程
    55         return self.forward(x)
    56 
    57 
    58 net = MyNet()
    59 for i in range(10000):
    60     net.train(x, y)
    61 x = torch.Tensor([[0.4, 0.1]])
    62 out = net.test(x)
    63 print(out)  # 输出结果 tensor([[0.5205]], grad_fn=<AddmmBackward>)

    训练集较少,可能结果不是很好,主要是结构,毕竟刚开始接触这个pytorch

  • 相关阅读:
    项目结束后一点心得
    提交disabled按钮的几种方法
    发现VS2005一个BUG
    单一文件上传防止粘帖及格式限制
    MessageBox.Show常用的2个方法
    一点感受一点体会
    EXCEL导入GridView,然后再汇入数据库.
    2根ECC内存
    (转载)gridview添加删除确认对话框
    反射调用Method
  • 原文地址:https://www.cnblogs.com/MC-Curry/p/10107380.html
Copyright © 2011-2022 走看看