zoukankan      html  css  js  c++  java
  • 【深度学习】paddlepaddle——基于卷积神经网络的手写字识别案例

      1 # 1、导包
      2 import paddle.fluid as fluid
      3 import paddle
      4 import time
      5 
      6 start = time.time()
      7 
      8 
      9 def test_program(exe, feeder, program, fetch_list, reader):
     10     """
     11     测试进程
     12     :param exe:执行器
     13     :param feeder: 数据与网络关系
     14     :param program: 测试主进程
     15     :param fetch_list: 需要执行之后返回的损失与准确率
     16     :param reader: 测试reader
     17     :return:
     18     """
     19     # 训练次数
     20     count = 0
     21     # 整个测试集的总损失
     22     sum_loss = 0
     23     # 整个训练集的准确率
     24     sum_acc = 0
     25     for test_data in reader():
     26         test_avg_loss_value, test_acc_values = exe.run(
     27             program=program,  # 测试主进程
     28             feed=feeder.feed(test_data),  # 给测试喂数据
     29             fetch_list=fetch_list  # 需要执行之后返回的值
     30         )
     31 
     32         sum_loss += test_avg_loss_value
     33         sum_acc += test_acc_values
     34         count += 1
     35     # 得到整个训练集的平均损失,与整个训练集的准确率
     36     test_avg_loss = sum_loss / count
     37     test_acc = sum_acc / count
     38 
     39     return test_avg_loss, test_acc
     40 
     41 
     42 # 2、数据处理---paddlepaddle 自带的mnist数据已经经过了数据处理
     43 
     44 # 3、定义reader
     45 # paddlepaddle给我们已经定义好了reader,只需要去调用
     46 
     47 # 4、指定训练场所
     48 place = fluid.CPUPlace()
     49 
     50 # 5、配置网络
     51 # 特征数据层
     52 image = fluid.layers.data(name="image", shape=[1, 28, 28], append_batch_size=True, dtype="float64")
     53 # 目标数据层
     54 label = fluid.layers.data(name="label", shape=[1], append_batch_size=True, dtype="int64")
     55 # 设计两个卷积、激活、池化 之后 + fc的卷积神经网络
     56 conv1 = fluid.nets.simple_img_conv_pool(
     57     input=image,  # 输入
     58     num_filters=20,  # 卷积核个数
     59     filter_size=3,  # 卷积核大小
     60     pool_size=2,  # 池化大小
     61     pool_stride=2,  # 池化步长
     62     act="relu",  # 激活函数
     63 )
     64 conv2 = fluid.nets.simple_img_conv_pool(
     65     input=conv1,  # 输入
     66     num_filters=10,  # 卷积核个数
     67     filter_size=5,  # 卷积核大小
     68     pool_size=2,  # 池化大小
     69     pool_stride=2,  # 池化步长
     70     act="relu",  # 激活函数
     71 )
     72 y_predict = fluid.layers.fc(input=conv2, size=10, act="softmax", name="output_layer")
     73 
     74 # 6、损失
     75 # 交叉熵损失
     76 loss = fluid.layers.cross_entropy(input=y_predict, label=label)
     77 # 计算平均损失
     78 avg_loss = fluid.layers.mean(loss)
     79 
     80 # 计算准确率
     81 acc = fluid.layers.accuracy(input=y_predict, label=label)
     82 
     83 # 7、指定优化---sgd随机梯度下降优化算法
     84 sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
     85 # 指定去优化损失
     86 sgd_optimizer.minimize(avg_loss)
     87 
     88 # 8、指定网络与数据层的关系
     89 feeder = fluid.DataFeeder(feed_list=[image, label], place=place)
     90 
     91 # 9、构建执行器
     92 # 训练执行器
     93 exe_train = fluid.Executor(place=place)
     94 # 测试执行器
     95 exe_test = fluid.Executor(place=place)
     96 
     97 # 10、初始化网络参数
     98 # 初始化参数进程
     99 startup_program = fluid.default_startup_program()
    100 exe_train.run(startup_program)
    101 # 主进程
    102 # 训练主进程
    103 train_main_program = fluid.default_main_program()
    104 # 测试主进程
    105 test_main_program = train_main_program.clone(for_test=True)
    106 
    107 # 11、获取图片数据
    108 # 并不是直接拿到数据就往网络里面送
    109 # 构建一个缓冲区,--打乱顺序,--再往网络里面送
    110 # paddle.dataset.mnist.train() ----paddlepaddle的训练reader
    111 # 缓冲区大小buf_size与批次大小batch_size 并没有多大的关系
    112 # 一般设计的时候注意:buf_size 略微需要比batch_size 大一点就可以
    113 # 而且batch_size 不能过大
    114 # 训练reader 与测试reader 的batch_size数量必须一致
    115 train_reader = paddle.batch(
    116     paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=50),
    117     batch_size=10
    118 )
    119 test_reader = paddle.batch(
    120     paddle.reader.shuffle(paddle.dataset.mnist.test(), buf_size=50),
    121     batch_size=10
    122 )
    123 
    124 # 12、训练
    125 # 指定训练轮数
    126 loop_num = 2
    127 # 定义的执行次数
    128 step = 0
    129 
    130 flag = False
    131 
    132 for loop in range(loop_num):
    133     print("第%d轮训练" % loop)
    134     # train_data 每批次的数据
    135     for train_data in train_reader():
    136         # 执行器运行训练主进程
    137         train_avg_loss_value, train_acc_value = exe_train.run(
    138             program=train_main_program,  # 训练主进程
    139             feed=feeder.feed(train_data),  # 利用数据层与网络构建好的关系,将真实的数据喂给网络
    140             fetch_list=[avg_loss, acc]  # 执行之后需要返回的结果的值
    141         )
    142         # 每隔10步来打印一下损失与准确率
    143         if step % 10 == 0 and step != 0:
    144             print("第%d次训练的损失为%f,准确率为%f" % (step, train_avg_loss_value, train_acc_value))
    145 
    146         step += 1
    147 
    148         # 每隔100步 去测试集中测试一下训练效果
    149         if step % 100 == 0 and step != 0:
    150             test_avg_loss, test_acc = test_program(exe_test,
    151                                                    feeder,
    152                                                    test_main_program,
    153                                                    fetch_list=[avg_loss, acc],
    154                                                    reader=test_reader
    155                                                    )
    156             print("*" * 100)
    157             print("测试集的损失为:%f,准确率为:%f" % (test_avg_loss, test_acc))
    158             print("*" * 100)
    159             if test_avg_loss <= 0.1 and test_acc >= 0.98:
    160                 flag = True
    161 
    162                 print("最终测试集的损失为:%f,准确率为:%f" % (test_avg_loss, test_acc))
    163                 end = time.time()
    164 
    165                 print("运行总时长为:", end - start)
    166                 break
    167     if flag:
    168         break
  • 相关阅读:
    如何将CentOS的默认启动界面修改为图形界面or字符界面
    如何将CentOS的默认启动界面修改为图形界面or字符界面
    virtualbox下CentOS7安装增强功能
    蓝牙4.0
    HC-SR04超声波测距
    STM32F407 通用同步异步收发器(串口)
    STM32F4 TIM(外设定时器)
    STM32F4 系统定时器
    STM32F4 异常与中断
    LED和按键实验
  • 原文地址:https://www.cnblogs.com/Tree0108/p/12116316.html
Copyright © 2011-2022 走看看