zoukankan      html  css  js  c++  java
  • HybridBlock supports forwarding with both Symbol and NDArray

    gluon/image_classification.py代码有这么一段:

    import mxnet as mx
    from mxnet.gluon.model_zoo import vision as models
    ...
    net = models.get_model('vgg11', context, opt)     # 这里得到的是gluon的模型
    ...
    data = mx.sym.var('data')      # symbol
    out = net(data)                # 这里把symbol传到了gluon模型里?
    softmax = mx.sym.SoftmaxOutput(out, name='softmax')
    mod = mx.mod.Module(softmax, context=context)
    mx.viz.plot_network(softmax).view()    #打印出来symbol模型结构

    symbol传到了gluon模型里?深感疑惑,看了一下api发现可以的:

    首先是看到get_model方法返回的是gluon.HybridBlock类型

    而HybridBlock是同时支持Symbol和NDArray的!:

    所以这么搞是没问题的。 

    所以这个代码其实有3种训练模式: symbolic(符号式-类似tf)、hybrid(混合式-mx特性)、imperative(交互式-类似pytorch)

        if opt.mode == 'symbolic':          # 符号式
            data = mx.sym.var('data')
            if opt.dtype == 'float16':
                data = mx.sym.Cast(data=data, dtype=np.float16)
            out = net(data)
            if opt.dtype == 'float16':
                out = mx.sym.Cast(data=out, dtype=np.float32)
            softmax = mx.sym.SoftmaxOutput(out, name='softmax')
            mod = mx.mod.Module(softmax, context=context)
            train_data, val_data = get_data_iters(dataset, batch_size, opt)
            mod.fit(train_data,
                    eval_data=val_data,
                    num_epoch=opt.epochs,
                    kvstore=kv,
                    batch_end_callback = mx.callback.Speedometer(batch_size, max(1, opt.log_interval)),
                    epoch_end_callback = mx.callback.do_checkpoint('image-classifier-%s'% opt.model),
                    optimizer = 'sgd',
                    optimizer_params = {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum, 'multi_precision': True},
                    initializer = mx.init.Xavier(magnitude=2))
            mod.save_parameters('image-classifier-%s-%d-final.params'%(opt.model, opt.epochs))
        else:
            if opt.mode == 'hybrid':        # 混合式
                net.hybridize()
            train(opt, context)             # 交互式 
  • 相关阅读:
    Scala泛型
    Tensorflow激活函数
    20181030-4 每周例行报告
    20181023-3 每周例行报告
    20181016-10 每周例行报告
    20181009-9 每周例行报告
    第三周作业(4)——单元测试
    第三周作业(5)——代码规范
    第三周作业(2)——功能测试
    第三周作业(3)——词频统计--效能分析
  • 原文地址:https://www.cnblogs.com/king-lps/p/13068950.html
Copyright © 2011-2022 走看看