zoukankan      html  css  js  c++  java
  • 利用Module模块把构建的神经网络跑起来

    训练一个神经网络往往只需要简单的几步:

    1. 准备训练数据
    2. 初始化模型的参数
    3. 模型向往计算与向后计算
    4. 更新模型参数
    5. 设置相关的checkpoint

    如果上述的每个步骤都需要我们写Python的代码去一步步实现,未免显的繁琐,好在MXNet提供了Module模块来解决这个问题,Module把训练和推理中一些常用到的步骤代码进行了封装。对于一定已经用Symbol定义好的神经网络,我们可以很容易的使用Module提供的一些高层次接口或一些中间层次的接口来让整个训练或推理容易操作起来。

    下面我们将通过在UCI letter recognition数据集上训练一个多层感知机来说明Module模块的用法。

    第一步 加载一个数据集

    我们先下载一个数据集,然后按80:20的比例划分训练集与测试集。我们通过MXNet的IO模块提供的数据迭代器每次返回一个batch size =32的训练样本

    import logging
    logging.getLogger().setLevel(logging.INFO)
    import mxnet as mx
    import numpy as np
    
    # 数据以文本形式保存,每行一个样本,每一行数据之间用','分割,每一个字符为label
    fname = mx.test_utils.download('http://archive.ics.uci.edu/ml/machine-learning-databases/letter-recognition/letter-recognition.data')
    data = np.genfromtxt(fname, delimiter=',')[:,1:]
    label = np.array([ord(l.split(',')[0])-ord('A') for l in open(fname, 'r')])
    
    batch_size = 32
    ntrain = int(data.shape[0]*0.8)
    train_iter = mx.io.NDArrayIter(data[:ntrain, :], label[:ntrain], batch_size, shuffle=True)
    val_iter = mx.io.NDArrayIter(data[ntrain:, :], label[ntrain:], batch_size)
    

    第二步 定义一个network

    net = mx.sym.var('data')
    net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=64)
    net = mx.sym.Activation(data=net, name='relu1', act_type='relu')
    net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=26)
    net = mx.sym.SoftmaxOutput(net, name='softmax')
    mx.viz.plot_network(net)
    

    第三步 创建一个Module

    我们可以通过mx.mod.Module接口创建一个Module对象,它接收下面几个参数:

    • symbol:神经网络的定义
    • context:执行运算的设备
    • data_names:网络输入数据的列表
    • label_names:网络输入标签的列表

    对于我们在第二步定义的net,只有一个输入数据即data,输入标签名为softmax_label,这个是我们在使用SoftmaxOutput操作时,自动命名的。

    mod = mx.mod.Module(symbol=net, 
                        context=mx.cpu(), 
                        data_names=['data'], 
                        label_names=['softmax_label'])
    

    Module的中间层次的接口

    中间层次的接口主要是为了给开发者足够的灵活性,也方便排查问题。我们下面会先列出来Moduel模块有哪些常见的中间层API,然后再利用这个API来训练我们刚才定义的网络。

    • bind:绑定输入数据的形状,分配内存
    • init_params:初始化网络参数
    • init_optimizer:指定优化方法,比如sgd
    • metric.create:指定评价方法
    • forward:向前计算
    • update_metric:根据上一次的forward结果,更新评价指标
    • backward:反射传播
    • update:根据优化方法和梯度更新模型的参数
    # allocate memory given the input data and label shapes
    mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
    # initialize parameters by uniform random numbers
    mod.init_params(initializer=mx.init.Uniform(scale=.1))
    # use SGD with learning rate 0.1 to train
    mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1), ))
    # use accuracy as the metric
    metric = mx.metric.create('acc')
    # train 5 epochs, i.e. going over the data iter one pass
    for epoch in range(5):
        train_iter.reset()
        metric.reset()
        for batch in train_iter:
            mod.forward(batch, is_train=True)       # compute predictions
            mod.update_metric(metric, batch.label)  # accumulate prediction accuracy
            mod.backward()                          # compute gradients
            mod.update()                            # update parameters
        print('Epoch %d, Training %s' % (epoch, metric.get()))
    

    Module 高层次的API

    训练

    Moudle模块同时提供了高层次的API来完成训练、预测和评估。不像使用中间层次API那样繁琐,我们只需要一个接口fit就可以完成上面的步骤。

    # reset train_iter to the beginning
    train_iter.reset()
    
    # create a module
    mod = mx.mod.Module(symbol=net,
                        context=mx.cpu(),
                        data_names=['data'],
                        label_names=['softmax_label'])
    
    # fit the module
    mod.fit(train_iter,
            eval_data=val_iter,
            optimizer='sgd',
            optimizer_params={'learning_rate':0.1},
            eval_metric='acc',
            num_epoch=8)
    

    预测和评估

    使用Moudle.predict可以得到数据的predict的结果。如果我们对结果不关心,我们可以使用score接口直接计算验证数据集的准确率。

    y = mod.predict(val_iter)
    score = mod.score(val_iter, ['acc'])
    print("Accuracy score is %f" % (score[0][1]))
    

    上面的代码中我们使用了acc来计算准确率,我们还可以设置其他评估方法,如:top_k_acc,F1,RMSE,MSE,MAE,ce等。

    训练模型的保存

    我们可以通过设计一个checkpoint calback来在训练过程中每个epoch结束后保存模型的参数

    # construct a callback function to save checkpoints
    model_prefix = 'mx_mlp'
    checkpoint = mx.callback.do_checkpoint(model_prefix)
    
    mod = mx.mod.Module(symbol=net)
    mod.fit(train_iter, num_epoch=5, epoch_end_callback=checkpoint)
    

    使用load_checkpoint来加载已经保存的模型参数,随后我们可以把这些参数加载到Moudle中

    sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
    # assign the loaded parameters to the module
    mod.set_params(arg_params, aux_params)
    

    我们也可以不使用set_params,而是直接在fit接口中指定已经保存的checkpoint的参数,这些保存的参数会替代fit原本的参数初始化。

    mod = mx.mod.Module(symbol=sym)
    mod.fit(train_iter,
            num_epoch=21,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=3)
    
  • 相关阅读:
    ajax 传递参数中文乱码解决办法
    jQuery 时间戳转化成时间
    IDEA2017 导入 SVN上的 Myeclipse或Eclipse 项目
    ajax返回json数据,对其中日期的解析
    MYSQL 按照字母排序查询
    JVM介绍
    正则表达式
    could not find the main class错误
    转:MyEclipse使用总结——MyEclipse10安装SVN插件
    转:Oracle数据库sqlplus与plsqldev解决乱码
  • 原文地址:https://www.cnblogs.com/ronny/p/8571855.html
Copyright © 2011-2022 走看看