zoukankan      html  css  js  c++  java
  • Sequential 类的设备迁移

    之前因为RNN模块 没有export 方法,直接用了 cpickle 强行保存。现在要载入保存的数据,用于inference。需要解决训练时的context和 载入时 device不一致的问题。
    找了下,发现ParameterDict里面有个 reset_ctx可以用:

    import mxnet as mx                                                                                                                                           
    import numpy as np
    nn = mx.gluon.nn
    net = nn.Sequential()
    net.add(
        ¦   nn.Dense(10))
    ctx = mx.cpu()
    _x = np.random.randint(0,256,(5,199))
    x = mx.nd.array(_x)
    net.initialize()
    
    
    y= net(x)
    print y
    print('cross to gpu device...')
    ctx = mx.gpu()
    x = x.as_in_context( ctx )
    try:
        y = net(x)
    except:
        print 'forward failed, try reset_ctx for ParameterDict...'
        net.collect_params().reset_ctx( ctx )
    y= net(x)
    print y
    print 'test ok'
    
  • 相关阅读:
    vant 移动helloworld
    ts
    study vant
    uniapp 上传图片
    electron
    1
    测试vue模板
    [Java] Spring_1700_Spring_DataSource
    [Java] Spring_1600_AOP_XML
    [Java] Spring_1500_AOP_Annotation
  • 原文地址:https://www.cnblogs.com/chenyliang/p/9493448.html
Copyright © 2011-2022 走看看