zoukankan      html  css  js  c++  java
  • #mxnet# infer_shape ,附 module 中 symbol 提取

    infer_shape for symbol

    形状推断是mxnet的一特色,即使撇开这样做的原因是mxnet强制要求的,其提供的功能也是很helpful的。
    infer_shape通常是被封装起来供其内部使用,但也可以把symbol.infer_shape单独提出来,作为函数:

    import mxnet as mx
    d=mx.sym.Variable('data')
    conv1_w=mx.sym.Variable('kw')
    conv1=mx.sym.Convolution(data=d,weight=conv1_w,kernel=(3,3),num_filter=num_filter,no_bias=True,name='conv1')
    loss=mx.sym.MakeLoss(data=conv1)
    
    in_shape,out_shape,uax_shape=loss.infer_shape(data=(1,1,30,30),kw=(1,1,3,3)) # 直接写参数名, 此处 kw 可省略
    in_shape,out_shape,uax_shape
    # ([(1L, 1L, 30L, 30L), (1L, 1L, 3L, 3L)], [(1L, 1L, 28L, 28L)], [])
    

    for module

    另外,上面用的是 symbol,有时需要从打包好的module里面提取symbol(mxnet的doc实在是...AWS掺和进来草根本性也不见提升啊):

    import mxnet as mx
    d=mx.sym.Variable('data')
    conv1_w=mx.sym.Variable('kw')
    conv1=mx.sym.Convolution(data=d,weight=conv1_w,kernel=(3,3),num_filter=num_filter,no_bias=True,name='conv1')
    loss=mx.sym.MakeLoss(data=conv1)
    
    mod = mx.mod.Moudle(symbol=loss)
    get_conv1 = mod.symbol.get_internals()['conv1_output']
    get_conv1
    #<Symbol conv1>
    
  • 相关阅读:
    Orcad Pspice仿真
    AD导入Allegro brd文件(导入后找不到PCB的解决方法)
    VJTAG转VME DTB
    win10 非Unicode应用程序显示设置
    MFC多文档视图编程总结
    VC MFC开发示例下载
    FPGA仿真及时序约束分析
    VMWARE Thin APP
    VPX技术基础概论
    SecureCRT脚本(VBS)运行
  • 原文地址:https://www.cnblogs.com/chenyliang/p/6782071.html
Copyright © 2011-2022 走看看