  • MXNet in Practice (1): Getting Started and Running the MNIST Dataset

    I've recently been exploring MXNet and TensorFlow, and I have both set up. I've run quite a bit of code on TensorFlow; overall it's pleasant to use, the documentation is rich, and once you're familiar with the API, writing code is straightforward.

    Today I compared the two platforms head to head. Running the same MNIST example, TensorFlow was slower by well over an order of magnitude: MXNet finished in about half a minute, while TensorFlow took 13 minutes.

    So how do you get things running in MXNet?

    cd /mxnet/example/image-classification
    python train_mnist.py

    I'm using the latest MXNet release. When you run the script, it downloads the dataset automatically and the training log starts scrolling past.
    Let's look at how the script is written, to build up a mental model of programming in MXNet:
    # find_mxnet is a small helper shipped with the examples; it puts the
    # local mxnet build on sys.path before the real import below.
    import find_mxnet

    import mxnet as mx
    import argparse
    import os, sys
    import train_model

    def _download(data_dir):
        if not os.path.isdir(data_dir):
            os.system("mkdir " + data_dir)
        os.chdir(data_dir)
        if (not os.path.exists('train-images-idx3-ubyte')) or \
           (not os.path.exists('train-labels-idx1-ubyte')) or \
           (not os.path.exists('t10k-images-idx3-ubyte')) or \
           (not os.path.exists('t10k-labels-idx1-ubyte')):
            os.system("wget http://data.dmlc.ml/mxnet/data/mnist.zip")
            os.system("unzip -u mnist.zip; rm mnist.zip")
        os.chdir("..")

    def get_loc(data, attr={'lr_mult':'0.01'}):
        """
        the localisation network in lenet-stn, it will increase acc about more than 1%,
        when num-epoch >=15
        """
        loc = mx.symbol.Convolution(data=data, num_filter=30, kernel=(5, 5), stride=(2,2))
        loc = mx.symbol.Activation(data = loc, act_type='relu')
        loc = mx.symbol.Pooling(data=loc, kernel=(2, 2), stride=(2, 2), pool_type='max')
        loc = mx.symbol.Convolution(data=loc, num_filter=60, kernel=(3, 3), stride=(1,1), pad=(1, 1))
        loc = mx.symbol.Activation(data = loc, act_type='relu')
        loc = mx.symbol.Pooling(data=loc, global_pool=True, kernel=(2, 2), pool_type='avg')
        loc = mx.symbol.Flatten(data=loc)
        loc = mx.symbol.FullyConnected(data=loc, num_hidden=6, name="stn_loc", attr=attr)
        return loc

    def get_mlp():
        """
        multi-layer perceptron
        """
        data = mx.symbol.Variable('data')
        fc1  = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
        act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
        fc2  = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
        act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
        fc3  = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10)
        mlp  = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax')
        return mlp
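
    A quick aside (not part of the script): because these are symbolic graphs, you can sanity-check a network before training it. A minimal sketch, assuming get_mlp above is already defined; the batch size of 128 is just an example:

    mlp = get_mlp()
    # All learnable arguments, plus the implicit 'data' and 'softmax_label' inputs.
    print(mlp.list_arguments())
    # Infer every shape in the graph from the input shape alone.
    arg_shapes, out_shapes, aux_shapes = mlp.infer_shape(data=(128, 784))
    print(out_shapes)   # [(128, 10)] -- one 10-way softmax row per image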

    def get_lenet(add_stn=False):
        """
        LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick
        Haffner. "Gradient-based learning applied to document recognition."
        Proceedings of the IEEE (1998)
        """
        data = mx.symbol.Variable('data')
        if(add_stn):
            data = mx.sym.SpatialTransformer(data=data, loc=get_loc(data), target_shape = (28,28),
                                             transform_type="affine", sampler_type="bilinear")
        # first conv
        conv1 = mx.symbol.Convolution(data=data, kernel=(5,5), num_filter=20)
        tanh1 = mx.symbol.Activation(data=conv1, act_type="tanh")
        pool1 = mx.symbol.Pooling(data=tanh1, pool_type="max",
                                  kernel=(2,2), stride=(2,2))
        # second conv
        conv2 = mx.symbol.Convolution(data=pool1, kernel=(5,5), num_filter=50)
        tanh2 = mx.symbol.Activation(data=conv2, act_type="tanh")
        pool2 = mx.symbol.Pooling(data=tanh2, pool_type="max",
                                  kernel=(2,2), stride=(2,2))
        # first fullc
        flatten = mx.symbol.Flatten(data=pool2)
        fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)
        tanh3 = mx.symbol.Activation(data=fc1, act_type="tanh")
        # second fullc
        fc2 = mx.symbol.FullyConnected(data=tanh3, num_hidden=10)
        # loss
        lenet = mx.symbol.SoftmaxOutput(data=fc2, name='softmax')
        return lenet
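
    If the graphviz Python package is installed, you can also render the symbol graph to eyeball the layer stack (again a sketch, not part of train_mnist.py; the filename is arbitrary):

    lenet = get_lenet()
    dot = mx.viz.plot_network(symbol=lenet, shape={'data': (1, 1, 28, 28)})
    dot.render('lenet-graph')   # writes lenet-graph.pdf via graphviz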

    def get_iterator(data_shape):
        def get_iterator_impl(args, kv):
            data_dir = args.data_dir
            if '://' not in args.data_dir:
                _download(args.data_dir)
            flat = len(data_shape) != 3   # flatten images for the MLP's 784-d input

            train           = mx.io.MNISTIter(
                image       = data_dir + "train-images-idx3-ubyte",
                label       = data_dir + "train-labels-idx1-ubyte",
                input_shape = data_shape,
                batch_size  = args.batch_size,
                shuffle     = True,
                flat        = flat,
                num_parts   = kv.num_workers,
                part_index  = kv.rank)

            val = mx.io.MNISTIter(
                image       = data_dir + "t10k-images-idx3-ubyte",
                label       = data_dir + "t10k-labels-idx1-ubyte",
                input_shape = data_shape,
                batch_size  = args.batch_size,
                flat        = flat,
                num_parts   = kv.num_workers,
                part_index  = kv.rank)

            return (train, val)
        return get_iterator_impl
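
    The iterators returned here follow the standard mx.io.DataIter protocol, so you can also step through batches by hand. A minimal single-machine sketch (the paths assume _download above has run into mnist/; batch size is illustrative):

    train = mx.io.MNISTIter(
        image       = "mnist/train-images-idx3-ubyte",
        label       = "mnist/train-labels-idx1-ubyte",
        input_shape = (784,),
        batch_size  = 128,
        flat        = True,
        shuffle     = True)
    for batch in train:
        print(batch.data[0].shape, batch.label[0].shape)   # (128, 784) (128,)
        break
    train.reset()   # rewind before handing the iterator to a trainer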

    def parse_args():
        parser = argparse.ArgumentParser(description='train an image classifier on mnist')
        parser.add_argument('--network', type=str, default='mlp',
                            choices = ['mlp', 'lenet', 'lenet-stn'],
                            help = 'the network to use')
        parser.add_argument('--data-dir', type=str, default='mnist/',
                            help='the input data directory')
        parser.add_argument('--gpus', type=str,
                            help='the gpus will be used, e.g "0,1,2,3"')
        parser.add_argument('--num-examples', type=int, default=60000,
                            help='the number of training examples')
        parser.add_argument('--batch-size', type=int, default=128,
                            help='the batch size')
        parser.add_argument('--lr', type=float, default=.1,
                            help='the initial learning rate')
        parser.add_argument('--model-prefix', type=str,
                            help='the prefix of the model to load/save')
        parser.add_argument('--save-model-prefix', type=str,
                            help='the prefix of the model to save')
        parser.add_argument('--num-epochs', type=int, default=10,
                            help='the number of training epochs')
        parser.add_argument('--load-epoch', type=int,
                            help="load the model on an epoch using the model-prefix")
        parser.add_argument('--kv-store', type=str, default='local',
                            help='the kvstore type')
        parser.add_argument('--lr-factor', type=float, default=1,
                            help='times the lr with a factor for every lr-factor-epoch epoch')
        parser.add_argument('--lr-factor-epoch', type=float, default=1,
                            help='the number of epoch to factor the lr, could be .5')
        return parser.parse_args()


    if __name__ == '__main__':
        args = parse_args()


        if args.network == 'mlp':
            data_shape = (784, )
            net = get_mlp()
        elif args.network == 'lenet-stn':
            data_shape = (1, 28, 28)
            net = get_lenet(True)
        else:
            data_shape = (1, 28, 28)
            net = get_lenet()

        # train
        train_model.fit(args, net, get_iterator(data_shape))
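
    Thanks to the argument parser, you can switch networks or hardware without touching the code, for example:

    python train_mnist.py --network lenet --gpus 0
    python train_mnist.py --network lenet-stn --num-epochs 20
    python train_mnist.py --network mlp --batch-size 256 --lr 0.05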

    Start with the main block: it parses the command-line arguments and builds the network symbol (including the input data shape), then passes those three things to the existing train_model module and training begins. The program structure is quite clear, and the modularization is well done.
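
    train_model.py itself isn't reproduced here, but in this generation of MXNet it boils down to the mx.model.FeedForward API. A rough sketch of the idea (illustrative only, not the actual file; args, net and data_shape are the variables from the main block above):

    kv = mx.kvstore.create(args.kv_store)
    train, val = get_iterator(data_shape)(args, kv)

    # Bind the symbol to a device and fit it on the two iterators.
    model = mx.model.FeedForward(
        ctx           = mx.cpu(),   # or [mx.gpu(int(i)) for i in args.gpus.split(',')]
        symbol        = net,
        num_epoch     = args.num_epochs,
        learning_rate = args.lr)
    model.fit(X                  = train,
              eval_data          = val,
              kvstore            = kv,
              batch_end_callback = mx.callback.Speedometer(args.batch_size))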
    Next, the argument handling. The parser covers most of what you would otherwise put in a Caffe prototxt: the dataset path, batch size, learning rate, number of epochs, and so on (the loss function, by contrast, lives in the symbol itself via SoftmaxOutput). As for building the network structure, it is like stacking building blocks layer by layer, driven by the parsed arguments or by parameters you define yourself. Once the blocks are stacked, training starts.
    One drawback of Caffe is that it isn't flexible enough: you're writing configuration files rather than code, and it always feels like working at arm's length. MXNet and TensorFlow are more convenient; they expose APIs, so you can compose and define network structures however you like. In short, the latter two frameworks do modularization better and provide low-level APIs that support writing your own networks, whereas adding a custom layer in Caffe takes real effort.
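
    To make the flexibility point concrete: adding, say, a Dropout layer to the MLP above is just one more block (an illustrative variant, not part of train_mnist.py):

    data = mx.symbol.Variable('data')
    fc1  = mx.symbol.FullyConnected(data=data, num_hidden=128)
    act1 = mx.symbol.Activation(data=fc1, act_type='relu')
    drop = mx.symbol.Dropout(data=act1, p=0.5)   # the extra block
    fc2  = mx.symbol.FullyConnected(data=drop, num_hidden=10)
    net  = mx.symbol.SoftmaxOutput(data=fc2, name='softmax')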






  • Original post: https://www.cnblogs.com/whu-zeng/p/6010188.html