  • Training and testing a convolutional neural network on MNIST with MXNet

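    A complete handwritten-digit recognition walkthrough in MXNet: from data preprocessing, through network training, to saving the model and logs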

    import numpy as np  
    import mxnet as mx  
    import logging  
      
    logging.getLogger().setLevel(logging.DEBUG)  
      
    batch_size = 100  
    mnist = mx.test_utils.get_mnist()  
    train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)  
    val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)  
      
    data = mx.sym.var('data')   
    # first conv layer  
    conv1= mx.sym.Convolution(data=data, kernel=(5,5), num_filter=20)  
    tanh1= mx.sym.Activation(data=conv1, act_type="tanh")  
    pool1= mx.sym.Pooling(data=tanh1, pool_type="max", kernel=(2,2), stride=(2,2))  
    # second conv layer  
    conv2= mx.sym.Convolution(data=pool1, kernel=(5,5), num_filter=50)  
    tanh2= mx.sym.Activation(data=conv2, act_type="tanh")  
    pool2= mx.sym.Pooling(data=tanh2, pool_type="max", kernel=(2,2), stride=(2,2))  
    # first fullc layer  
    flatten= mx.sym.Flatten(data=pool2)  
    fc1= mx.sym.FullyConnected(data=flatten, num_hidden=500)  
    tanh3= mx.sym.Activation(data=fc1, act_type="tanh")  
    # second fullc  
    fc2= mx.sym.FullyConnected(data=tanh3, num_hidden=10)  
    # softmax loss  
    lenet= mx.sym.SoftmaxOutput(data=fc2, name='softmax')  
      
    # create a trainable module on the CPU (pass context=mx.gpu(0) to train on GPU 0)  
    lenet_model = mx.mod.Module(  
                    symbol=lenet,   
                    context=mx.cpu())  
      
    # train for 10 epochs with SGD, reporting accuracy every 100 batches  
    lenet_model.fit(train_iter,  
                    eval_data=val_iter,  
                    optimizer='sgd',  
                    optimizer_params={'learning_rate':0.1},  
                    eval_metric='acc',  
                    batch_end_callback = mx.callback.Speedometer(batch_size, 100),  
                    num_epoch=10)  
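
The layer sizes in the network above can be checked by hand. The following is a minimal sketch in plain Python (no MXNet needed), assuming 28x28 single-channel MNIST images and the default stride of 1 and padding of 0 for the two convolutions, since the script does not set them. The helper `conv_out` and the variable names are illustrative and are not part of the original script.

```python
def conv_out(size, kernel, stride=1, pad=0):
    """Output length along one spatial axis for a conv or pooling window."""
    return (size + 2 * pad - kernel) // stride + 1

side = 28                                   # assumed MNIST input: 1 x 28 x 28
side = conv_out(side, kernel=5)             # conv1: 28 -> 24, 20 channels
side = conv_out(side, kernel=2, stride=2)   # pool1: 24 -> 12
side = conv_out(side, kernel=5)             # conv2: 12 -> 8, 50 channels
side = conv_out(side, kernel=2, stride=2)   # pool2: 8 -> 4
flat_features = 50 * side * side            # input width of fc1 after Flatten

# Learnable parameters per layer: weights plus one bias per output unit.
params = {
    "conv1": 20 * (1 * 5 * 5) + 20,
    "conv2": 50 * (20 * 5 * 5) + 50,
    "fc1": flat_features * 500 + 500,
    "fc2": 500 * 10 + 10,
}
total_params = sum(params.values())

print("flattened features:", flat_features)
print("parameters per layer:", params)
print("total parameters:", total_params)
```

Under these assumptions, most of the parameters sit in the first fully connected layer, which is typical for LeNet-style networks.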


    INFO:root:Epoch[0] Batch [100] Speed: 1504.57 samples/sec accuracy=0.113564
    INFO:root:Epoch[0] Batch [200] Speed: 1516.40 samples/sec accuracy=0.118100
    INFO:root:Epoch[0] Batch [300] Speed: 1515.71 samples/sec accuracy=0.116600
    INFO:root:Epoch[0] Batch [400] Speed: 1505.61 samples/sec accuracy=0.110200
    INFO:root:Epoch[0] Batch [500] Speed: 1406.21 samples/sec accuracy=0.107600
    INFO:root:Epoch[0] Train-accuracy=0.108081
    INFO:root:Epoch[0] Time cost=40.572
    INFO:root:Epoch[0] Validation-accuracy=0.102800
    INFO:root:Epoch[1] Batch [100] Speed: 1451.87 samples/sec accuracy=0.115050
    INFO:root:Epoch[1] Batch [200] Speed: 1476.86 samples/sec accuracy=0.179600
    INFO:root:Epoch[1] Batch [300] Speed: 1409.67 samples/sec accuracy=0.697100
    INFO:root:Epoch[1] Batch [400] Speed: 1379.52 samples/sec accuracy=0.871900
    INFO:root:Epoch[1] Batch [500] Speed: 1374.88 samples/sec accuracy=0.901000
    INFO:root:Epoch[1] Train-accuracy=0.925556
    INFO:root:Epoch[1] Time cost=42.527
    INFO:root:Epoch[1] Validation-accuracy=0.936900
    INFO:root:Epoch[2] Batch [100] Speed: 1376.59 samples/sec accuracy=0.936436
    INFO:root:Epoch[2] Batch [200] Speed: 1379.29 samples/sec accuracy=0.948100
    INFO:root:Epoch[2] Batch [300] Speed: 1375.07 samples/sec accuracy=0.953400
    INFO:root:Epoch[2] Batch [400] Speed: 1369.65 samples/sec accuracy=0.958600
    INFO:root:Epoch[2] Batch [500] Speed: 1371.79 samples/sec accuracy=0.960900
    INFO:root:Epoch[2] Train-accuracy=0.966667
    INFO:root:Epoch[2] Time cost=43.660
    INFO:root:Epoch[2] Validation-accuracy=0.972900
    INFO:root:Epoch[3] Batch [100] Speed: 1230.74 samples/sec accuracy=0.969505
    INFO:root:Epoch[3] Batch [200] Speed: 1335.27 samples/sec accuracy=0.970800
    INFO:root:Epoch[3] Batch [300] Speed: 1264.43 samples/sec accuracy=0.972600
    INFO:root:Epoch[3] Batch [400] Speed: 1242.03 samples/sec accuracy=0.974100
    INFO:root:Epoch[3] Batch [500] Speed: 1322.77 samples/sec accuracy=0.974600
    INFO:root:Epoch[3] Train-accuracy=0.976465
    INFO:root:Epoch[3] Time cost=46.860
    INFO:root:Epoch[3] Validation-accuracy=0.980700
    INFO:root:Epoch[4] Batch [100] Speed: 1342.42 samples/sec accuracy=0.978020
    INFO:root:Epoch[4] Batch [200] Speed: 1339.98 samples/sec accuracy=0.980600
    INFO:root:Epoch[4] Batch [300] Speed: 1344.36 samples/sec accuracy=0.981000
    INFO:root:Epoch[4] Batch [400] Speed: 1338.13 samples/sec accuracy=0.980000
    INFO:root:Epoch[4] Batch [500] Speed: 1343.76 samples/sec accuracy=0.979000
    INFO:root:Epoch[4] Train-accuracy=0.983535
    INFO:root:Epoch[4] Time cost=44.694
    INFO:root:Epoch[4] Validation-accuracy=0.985700
    INFO:root:Epoch[5] Batch [100] Speed: 1333.50 samples/sec accuracy=0.981584
    INFO:root:Epoch[5] Batch [200] Speed: 1342.07 samples/sec accuracy=0.985400
    INFO:root:Epoch[5] Batch [300] Speed: 1339.04 samples/sec accuracy=0.984300
    INFO:root:Epoch[5] Batch [400] Speed: 1323.42 samples/sec accuracy=0.983500
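
The script above only prints the log to the console. The sketch below shows one possible way to also keep it in a file and read the validation accuracy back per epoch, using only the Python standard library. It assumes that the training messages go through the root logger, which the `INFO:root:` prefix in the log above suggests. The path `train.log`, the two `logging.info` calls that stand in for `lenet_model.fit(...)`, and the regular expression are illustrative assumptions; the two messages are copied from the log above.

```python
import logging
import os
import re
import tempfile

# Hypothetical log location; any writable path would do.
log_path = os.path.join(tempfile.mkdtemp(), "train.log")

# Attach a file handler to the root logger, using the same
# LEVEL:name:message layout that appears in the console log.
root = logging.getLogger()
root.setLevel(logging.DEBUG)
file_handler = logging.FileHandler(log_path)
file_handler.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
root.addHandler(file_handler)

# Stand-in for lenet_model.fit(...): two messages copied from the log above.
logging.info("Epoch[0] Validation-accuracy=0.102800")
logging.info("Epoch[1] Validation-accuracy=0.936900")

# Detach and flush the handler so the file is complete before reading it.
root.removeHandler(file_handler)
file_handler.close()

# Read the file back and collect validation accuracy per epoch.
pattern = re.compile(r"Epoch\[(\d+)\] Validation-accuracy=([0-9.]+)")
val_acc = {}
with open(log_path) as f:
    for line in f:
        match = pattern.search(line)
        if match:
            val_acc[int(match.group(1))] = float(match.group(2))

for epoch in sorted(val_acc):
    print("epoch", epoch, "validation accuracy", val_acc[epoch])
```

In a real run, the handler would be added before calling `fit` and removed once training finishes.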

  • Original article: https://www.cnblogs.com/adong7639/p/8953163.html