zoukankan      html  css  js  c++  java
  • 从零开始学习MXnet(一)

      最近工作要开始用到MXnet,然而MXnet的文档写的实在是.....所以在这记录点东西,方便自己,也方便大家。

      我觉得搞清楚一个框架怎么使用,第一步就是用它来训练自己的数据,这是个很关键的一步。 

    一、MXnet数据预处理

      整个数据预处理的代码都集成在了toosl/im2rec.py中了,这个首先要造出一个list文件,lst文件有三列,分别是index label 图片路径。如下图所示:

           

      我这个label是瞎填的,所以都是0。另外最新的MXnet上面的im2rec是有问题的,它生成的list所有的index都是0,不过据说这个index没什么用.....但我还是改了一下。把yield生成器换成直接append即可。

      执行的命令如下:

        sudo python im2rec.py --list=True /home/erya/dhc/result/try /home/erya/dhc/result/ --recursive=True --shuffle=true --train-ratio=0.8 

      每个参数的意义在代码内部都可以查到,简单说一下这里用到的:--list=True说明这次的目的是make list,后面紧跟的是生成的list的名字的前缀,我这里是加了路径,然后是图片所在文件夹的路径,recursive是是否迭代的进入文件夹读取图片,--train-ratio则表示train和val在数据集中的比例。

    执行上面的命令后,会得到三个文件:

     

        然后再执行下面的命令生成最后的rec文件:

      sudo python im2rec.py /home/erya/dhc/result/try_val.lst  /home/erya/dhc/result --quality=100 

      以及,sudo python im2rec.py /home/erya/dhc/result/try_train.lst  /home/erya/dhc/result --quality=100 

     来生成相应的lst文件的rec文件,参数意义太简单就不说了..看着就明白,result是我存放图片的目录。

     

      这样最终就完成了数据的预处理,简单的说,就是先生成lst文件,这个其实完全可以自己做,而且后期我做segmentation的时候,label就是图片了..

     

    二、非常简单的小demo

    先上代码:

      

     1 import mxnet as mx
     2 import logging
     3 import numpy as np
     4 
     5 logger = logging.getLogger()
     6 logger.setLevel(logging.DEBUG)#暂时不需要管的log
     7 def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), act_type="relu"):
     8     conv = mx.symbol.Convolution(data=data, workspace=256,
     9                                  num_filter=num_filter, kernel=kernel, stride=stride, pad=pad)
    10     return conv   #我把这个删除到只有一个卷积的操作
    11 def DownsampleFactory(data, ch_3x3):
    12     # conv 3x3
    13     conv = ConvFactory(data=data, kernel=(3, 3), stride=(2, 2), num_filter=ch_3x3, pad=(1, 1))
    14     # pool
    15     pool = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type='max')
    16     # concat
    17     concat = mx.symbol.Concat(*[conv, pool])
    18     return concat
    19 def SimpleFactory(data, ch_1x1, ch_3x3):
    20     # 1x1
    21     conv1x1 = ConvFactory(data=data, kernel=(1, 1), pad=(0, 0), num_filter=ch_1x1)
    22     # 3x3
    23     conv3x3 = ConvFactory(data=data, kernel=(3, 3), pad=(1, 1), num_filter=ch_3x3)
    24     #concat
    25     concat = mx.symbol.Concat(*[conv1x1, conv3x3])
    26     return concat
    27 if __name__ == "__main__":
    28     batch_size = 1
    29     train_dataiter = mx.io.ImageRecordIter(
    30         shuffle=True,
    31         path_imgrec="/home/erya/dhc/result/try_train.rec",
    32         rand_crop=True,
    33         rand_mirror=True,
    34         data_shape=(3,28,28),
    35         batch_size=batch_size,
    36         preprocess_threads=1)#这里是使用我们之前的创造的数据,简单的说就是要自己写一个iter,然后把相应的参数填进去。
    37     test_dataiter = mx.io.ImageRecordIter(
    38         path_imgrec="/home/erya/dhc/result/try_val.rec",
    39         rand_crop=False,
    40         rand_mirror=False,
    41         data_shape=(3,28,28),
    42         batch_size=batch_size,
    43         round_batch=False,
    44         preprocess_threads=1)#同理
    45     data = mx.symbol.Variable(name="data")
    46     conv1 = ConvFactory(data=data, kernel=(3,3), pad=(1,1), num_filter=96, act_type="relu")
    47     in3a = SimpleFactory(conv1, 32, 32)
    48     fc = mx.symbol.FullyConnected(data=in3a, num_hidden=10)
    49     softmax = mx.symbol.SoftmaxOutput(name='softmax',data=fc)#上面就是定义了一个巨巨巨简单的结构
    50     # For demo purpose, this model only train 1 epoch
    51     # We will use the first GPU to do training
    52     num_epoch = 1
    53     model = mx.model.FeedForward(ctx=mx.gpu(), symbol=softmax, num_epoch=num_epoch,
    54                              learning_rate=0.05, momentum=0.9, wd=0.00001) #将整个model训练的架构定下来了,类似于caffe里面solver所做的事情。
    55 
    56 # we can add learning rate scheduler to the model
    57 # model = mx.model.FeedForward(ctx=mx.gpu(), symbol=softmax, num_epoch=num_epoch,
    58 #                              learning_rate=0.05, momentum=0.9, wd=0.00001,
    59 #                              lr_scheduler=mx.misc.FactorScheduler(2))
    60 model.fit(X=train_dataiter,
    61           eval_data=test_dataiter,
    62           eval_metric="accuracy",
    63           batch_end_callback=mx.callback.Speedometer(batch_size))#开跑数据。

     

      

  • 相关阅读:
    PHP+MySQL
    Appstore排名前十的程序员应用软件
    架构师的平凡之路
    程序员,如何三十而立?
    不懂技术也可以轻松开发一款APP
    php语法学习:轻松看懂PHP语言
    你真的了解软件测试行业吗?
    十个程序员必备的网站推荐
    从更高点看软件开发的侧重点
    php如何实现文件下载
  • 原文地址:https://www.cnblogs.com/daihengchen/p/5924768.html
Copyright © 2011-2022 走看看