zoukankan      html  css  js  c++  java
  • pytorch基础-搭建网络

    搭建网络的步骤大致为以下:

    1.准备数据

    2. 定义网络结构model

    3. 定义损失函数
    4. 定义优化算法 optimizer
    5. 训练
      5.1 准备好tensor形式的输入数据和标签(可选)
      5.2 前向传播计算网络输出output和计算损失函数loss
      5.3 反向传播更新参数
        以下三句话一句也不能少:
        5.3.1 optimizer.zero_grad()  将上次迭代计算的梯度值清0
        5.3.2 loss.backward()  反向传播,计算梯度值
        5.3.3 optimizer.step()  更新权值参数
      5.4 保存训练集上的loss和验证集上的loss以及准确率以及打印训练信息。(可选
    6. 图示训练过程中loss和accuracy的变化情况(可选)
    7. 在测试集上测试

    代码注释都写的很详细 

     1 import torch
     2 import torch.nn.functional as F
     3 import matplotlib.pyplot as plt
     4 
     5 # 1.准备数据 generate data
     6 x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
     7 print(x.shape)
     8 y=x*x+0.2*torch.rand(x.size())
     9 #显示数据散点图
    10 plt.scatter(x.data.numpy(),y.data.numpy())
    11 
    12 # 2.定义网络结构 build net
    13 class Net(torch.nn.Module):
    14     #n_feature:输入特征个数  n_hidden:隐藏层个数 n_output:输出层个数
    15     def __init__(self,n_feature,n_hidden,n_output):
    16         # super表示继承Net的父类,并同时初始化父类的参数
    17         super(Net,self).__init__()
    18         # nn.Linear代表线性层 代表y=w*x+b  其中w的shape为[n_hidden,n_feature] b的shape为[n_hidden]
    19         # y=w^T*x+b 这里w的维度是转置前的维度 所以是反的
    20         self.hidden =torch.nn.Linear(n_feature,n_hidden)
    21         self.predict =torch.nn.Linear(n_hidden,n_output)
    22         print(self.hidden.weight)
    23         print(self.predict.weight)
    24     #定义一个前向传播过程函数
    25     def forward(self, x):
    26         #         n_feature  n_hidden  n_output
    27         #举例(2,5,1)   2         5         1
    28         #                    -  **  -
    29         #             ** - - -  **  - -
    30         #                    -  **  - - - **
    31         #             ** - - -  **  - -
    32         #                    -  **  -
    33         #            输入层    隐藏层    输出层
    34         x=F.relu(self.hidden(x))
    35         x=self.predict(x)
    36         return x
    37 # 实例化一个网络为net
    38 net = Net(n_feature=1,n_hidden=10,n_output=1)
    39 print(net)
    40 # 3.定义损失函数 这里使用均方误差(mean square error)
    41 loss_func=torch.nn.MSELoss()
    42 # 4.定义优化器 这里使用随机梯度下降
    43 optimizer=torch.optim.SGD(net.parameters(),lr=0.2)
    44 #定义300遍更新 每10遍显示一次
    45 plt.ion()
    46 # 5.训练
    47 for t in range(100):
    48     prediction = net(x)     # input x and predict based on x
    49     loss = loss_func(prediction, y)     # must be (1. nn output, 2. target)
    50     # 5.3反向传播三步不可少
    51     optimizer.zero_grad()   # clear gradients for next train
    52     loss.backward()         # backpropagation, compute gradients
    53     optimizer.step()        # apply gradients
    54 
    55     if t % 10 == 0:
    56         # plot and show learning process
    57         plt.cla()
    58         plt.scatter(x.data.numpy(), y.data.numpy())
    59         plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    60         plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color':  'red'})
    61         plt.show()
    62         plt.pause(0.1)
    63 
    64 plt.ioff()

    参考:莫烦python

  • 相关阅读:
    父子进程 signal 出现 Interrupted system call 问题
    一个测试文章
    《淘宝客户端 for Android》项目实战 html webkit android css3
    Django 中的 ForeignKey ContentType GenericForeignKey 对应的数据库结构
    coreseek 出现段错误和Unigram dictionary load Error 新情况(Gentoo)
    一个 PAM dbus 例子
    漫画统计学 T分数
    解决 paramiko 安装问题 Unable to find vcvarsall.bat
    20141202
    js
  • 原文地址:https://www.cnblogs.com/bob-jianfeng/p/11407955.html
Copyright © 2011-2022 走看看