zoukankan      html  css  js  c++  java
  • 通过pytorch,零基础了解深度学习

    这是我的代码和注释,你可以通过直接复制代码到你的pycharm中跑起来。

    你不需要另外去准备数据集,当本地没有数据集运行代码就会自动下载

    这是一个很小的项目,你不需要准备GPU

    mnist_train.py

      1 import torch
      2 
      3 # nn包用来完成神经网络的搭建
      4 from torch import nn
      5 
      6 # functional包含常用的函数
      7 from torch.nn import functional as F
      8 
      9 # optim优化数据包,用来更新权重
     10 from torch import optim
     11 
     12 # 视觉相关的工具包
     13 import torchvision
     14 
     15 # 导入画图工具包
     16 from matplotlib import pyplot as plt
     17 
     18 # 从utils 包里导入所需工具
     19 from utils import plot_image, plot_curve, one_hot
     20 
     21 # step1. load dataset 加载数据集
     22 
     23 # 这里设定一次处理多少张图片
     24 batch_size = 512
     25 
     26 # 加载训练集
     27 train_loader = torch.utils.data.DataLoader(
     28     # 加载MNIST数据集(1.图片路径,2.指定下载的图片为text还是train,3.download若1本地没有则去网上下载,
     29     # 4.transform格式转换,网上图片一般为numpy格式,转为totensor格式)
     30     torchvision.datasets.MNIST('mnist_data', train=True, download=True,
     31                                transform=torchvision.transforms.Compose([
     32                                    torchvision.transforms.ToTensor(),
     33                                    torchvision.transforms.Normalize(
     34                                        (0.1307,), (0.3081,))
     35                                    # 这个参数是正则化,防止过拟合,防止参数过多或过大,避免模型过复杂。有L1正则化和L2正则化,这里是让参数维持在0的附近均匀的分配
     36                                ])),  # 0.3081是均差
     37     batch_size=batch_size, shuffle=True)  # 加载数据并随机打散数据
     38 
     39 # 加载测试集
     40 test_loader = torch.utils.data.DataLoader(
     41     torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
     42                                transform=torchvision.transforms.Compose([
     43                                    torchvision.transforms.ToTensor(),
     44                                    torchvision.transforms.Normalize(
     45                                        (0.1307,), (0.3081,))
     46                                ])),
     47     batch_size=batch_size, shuffle=False)
     48 
     49 
     50 # # 查看图片
     51 x, y = next(iter(train_loader))
     52 print(x.shape, y.shape, x.min(), x.max())
     53 plot_image(x, y, 'image sample')
     54 
     55 
     56 # 设置网络层
     57 
     58 class Net(nn.Module):
     59 
     60     def __init__(self):
     61         super(Net, self).__init__()
     62 
     63         # xw + b
     64         # 第一层,第一个参数为图像大小,第二个参数根据经验值设置输出层大小
     65         self.fc1 = nn.Linear(28 * 28, 256)
     66         # 第二层,第一个个参数为上一层的输出大小,第二个大小根据经验设置输出层大小
     67         self.fc2 = nn.Linear(256, 64)
     68         # 最后一层,第一个值为上一层输出大小,第二个参数为输出的种类数
     69         self.fc3 = nn.Linear(64, 10)
     70 
     71     # 计算函数
     72     def forward(self, x):
     73         # x:[b,1,28,28]  #relu将线性函数调整变种为非线性函数
     74         # h1=relu(xw1 +b1)
     75         x = F.relu(self.fc1(x))
     76         # h2=relu(h1w2+b20
     77         x = F.relu(self.fc2(x))
     78         # 第三层为输出层,一般输出概率值
     79         x = self.fc3(x)
     80 
     81         return x
     82 
     83 
     84 # 对创建的神经网络进行初始化
     85 net = Net()
     86 
     87 # 设置对计算后的梯度进行梯度更新方法,这里采用SGD随机梯度下降,lr是学习率
     88 optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
     89 
     90 train_loss = []
     91 
     92 for epoch in range(3):
     93     # 对整个数据集迭代三次
     94     for batch_idx, (x, y) in enumerate(train_loader):
     95         # 对整个数据集迭代一次
     96 
     97         # x :
     98         # print(x.shape,y.shape)
     99 
    100         # 输入
    101         x = x.view(x.size(0), 28 * 28)
    102 
    103         # 输出
    104         out = net(x)  # 我们的目的是将输出更加接近于y
    105 
    106         # 将真实的y转为独热编码
    107         y_onehot = one_hot(y)
    108 
    109         # 通过mse_loss计算误差值,也就是均方差
    110         loss = F.mse_loss(out, y_onehot)
    111 
    112         # 清零梯度
    113         optimizer.zero_grad()
    114         # 计算梯度
    115         loss.backward()
    116         # 更新梯度
    117         optimizer.step()
    118 
    119         # 最后我们会得到较为合适的[w1,b1,w2,b2,w3,b3]
    120 
    121         # 将loss数据收集,以便用matplotlib将其变化图示化
    122         train_loss.append(loss.item())
    123 
    124 
    125         # 查看loss下降的变化
    126         if batch_idx % 10 == 0:
    127 
    128             print(epoch, batch_idx, loss.item())
    129 
    130 
    131  plot_curve(train_loss)
    132 
    133 # 我们最终想要看到的并不是loss而是准确率
    134 # 准确度的测试
    135 # 在test测试集取数据然后进行测试
    136 total_correct = 0
    137 for x,y in test_loader:
    138     x  = x.view(x.size(0), 28*28)
    139     out = net(x)
    140     # out: [b, 10] => pred: [b]
    141     pred = out.argmax(dim=1)
    142     correct = pred.eq(y).sum().float().item()
    143     total_correct += correct
    144 145 total_num = len(test_loader.dataset)
    146 acc = total_correct / total_num
    147 print('test acc:', acc)
    148 149 x, y = next(iter(test_loader))
    150 out = net(x.view(x.size(0), 28*28))
    151 pred = out.argmax(dim=1)
    152 plot_image(x, pred, 'test')

    utils.py文件中包含的是画图函数,和独热编码的函数,可以直接调用,比如上面的代码就调用了它。将它一并放入你的pycharm中

    import torch
    from matplotlib import pyplot as plt
    
    
    # 画一条曲线
    def plot_curve(data):
        fig = plt.figure()
        plt.plot(range(len(data)), data, color='blue')
        plt.legend(['value'], loc='upper right')
        plt.xlabel('step')
        plt.ylabel('value')
        plt.show()
    
    
    
    # 可视化查看识别结果
    def plot_image(img, label, name):
        fig = plt.figure()
        for i in range(6):
            plt.subplot(2, 3, i + 1)
            plt.tight_layout()
            plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
            plt.title("{}: {}".format(name, label[i].item()))
            plt.xticks([])
            plt.yticks([])
        plt.show()
    
    
    def one_hot(label, depth=10):
        out = torch.zeros(label.size(0), depth)
        idx = torch.LongTensor(label).view(-1, 1)
        out.scatter_(dim=1, index=idx, value=1)  # 生成独热编码
        return out
  • 相关阅读:
    atoi (String to Integer) leetcode
    按层逆遍历一棵树,使用满二叉树存储
    unix网络编程-配置unp.h头文件
    ListView系列(七)——Adapter内的onItemClick监听器四个arg参数
    Windows系统下安装VirtualBox,系统找不到指定路径的做法
    Android Fragment完全解析,关于碎片你所需知道的一切
    【Android开源框架列表】
    fragment报错
    2013 年开源中国 10 大热门 Java 开源项目
    【移动开发】Android中三种超实用的滑屏方式汇总(ViewPager、ViewFlipper、ViewFlow)
  • 原文地址:https://www.cnblogs.com/98ZHANG/p/12700969.html
Copyright © 2011-2022 走看看