zoukankan      html  css  js  c++  java
  • MXNet官网案例分析--Train MLP on MNIST

    本文是MXNet的官网案例: Train MLP on MNIST. MXNet所有的模块如下图所示:

    第一步: 准备数据

    从下面程序可以看出,MXNet里面的数据是一个4维NDArray.

    import mxnet as mx
    
    # mxnet.io.MXDataIter, shape=(128,1,28,28)
    train = mx.io.MNISTIter(
        image = '/home/zhaopace/MXNet/mxnet/example/adversary/data/train-images-idx3-ubyte',
        label = '/home/zhaopace/MXNet/mxnet/example/adversary/data/train-labels-idx1-ubyte',
        batch_size = 128,
        data_shape = (784, )
    ) 
    # mxnet.io.MXDataIter, shape=(128,1,28,28)    
    val = mx.io.MNISTIter(
        image = '/home/zhaopace/MXNet/mxnet/example/adversary/data/t10k-images-idx3-ubyte',
        label = '/home/zhaopace/MXNet/mxnet/example/adversary/data/t10k-labels-idx1-ubyte',
        batch_size = 128,
        data_shape = (784, )
    ) 

    Second: 符号式编程, 生成一个两层的MLP

    # Declare a two-layer MLP
    data = mx.symbol.Variable('data')  # data layer
    fc1  = mx.symbol.FullyConnected(data=data, num_hidden=128)  # full connected layer 1
    act1 = mx.symbol.Activation(data=fc1, act_type="relu")  # activation layer(relu activation function)
    fc2  = mx.symbol.FullyConnected(data=act1, num_hidden=64)
    act2 = mx.symbol.Activation(data=fc2, act_type="relu")
    fc3  = mx.symbol.FullyConnected(data=act2, num_hidden=10)
    mlp  = mx.symbol.SoftmaxOutput(data=fc3, name="softmax")  # Softmax layer 

    一个CNN网络最基本的几层:

    输入层: mx.symbol.Variable()

    激活层: mx.symbol.Activation()

    Batch正则化: mx.symbol.BatchNorm()

    Dropout: mx.symbol.Dropout()

    全连接层: mx.symbol.FullyConnected()

    池化层: mx.symbol.Pooling()

    卷积层: mx.symbol.Convolution()

    Softmax输出: mx.symbol.SoftmaxOutput()

    LRN: mx.symbol.LRN()

    ......

    mx.symbol.FullyConnected(*args, **kwargs)

    功能: 对input作矩阵乘法, 并且加上一个偏置. 将shape为(batch_size, input_dim)的input变成(batch_size, num_hidden)的输出;

    输入参数:

    • data:  Symbol类型, 输入数据;
    • weight:  Symbol类型, 权重矩阵;
    • biasSymbol类型, 偏置参数;
    • num_hidden: int型, 必要参数, 隐层节点的数目;
    • no_bias: 布尔型, 可选参数, defalut=False, 表示是否不要偏置参数
    • name:  字符串类型, 可选参数, 计算结果symbol的名称;

    输出参数:

    • 输出是一个Symbol: the result symbol

    Last: 训练以及测试

    # Type: mxnet.model.FeedForward 
    # Train a model on the data 
    model = mx.model.FeedForward(
        symbol = mlp,
        num_epoch = 20,
        learning_rate = .1
    )
    model.fit(X = train, eval_data = val)
    
    # Predict
    model.predict(X = train)

    class mxnet.model.FeedForward(sklearn.base.BaseEstimator)

    输入参数:

    • symbol: Symbol类型, 网络的symbol结构配置;
    • ctx:
    • num_epoch: int型, 可选参数,是一个训练参数, 训练的迭代次数;
    • epoch_size: 一次epoch使用的batches数目, 默认情况下为(num_train_examples / batch_size)
    • optimizer:q
    • initializer:
    • numpy_batch_size:
    • ......

    图2 mxnet.model函数列表

  • 相关阅读:
    完全开源Android网络框架 — 基于JAVA原生的HTTP框架
    博客园—Android客户端
    撸一个Android高性能日历控件,高仿魅族
    Android开发登陆博客园的正确方式
    基于pthread的线程池实现
    重复造轮子系列——基于FastReport设计打印模板实现桌面端WPF套打和商超POS高度自适应小票打印
    重复造轮子系列——基于Ocelot实现类似支付宝接口模式的网关
    零基础ASP.NET Core WebAPI团队协作开发
    零基础ASP.NET Core MVC插件式开发
    jquery对下拉框的操作
  • 原文地址:https://www.cnblogs.com/zhao441354231/p/6080830.html
Copyright © 2011-2022 走看看