zoukankan      html  css  js  c++  java
  • #mxnet# infer_shape ,附 module 中 symbol 提取

    infer_shape for symbol

    形状推断是mxnet的一特色,即使撇开这样做的原因是mxnet强制要求的,其提供的功能也是很helpful的。
    infer_shape通常是被封装起来供其内部使用,但也可以把symbol.infer_shape单独提出来,作为函数:

    import mxnet as mx
    d=mx.sym.Variable('data')
    conv1_w=mx.sym.Variable('kw')
    conv1=mx.sym.Convolution(data=d,weight=conv1_w,kernel=(3,3),num_filter=num_filter,no_bias=True,name='conv1')
    loss=mx.sym.MakeLoss(data=conv1)
    
    in_shape,out_shape,uax_shape=loss.infer_shape(data=(1,1,30,30),kw=(1,1,3,3)) # 直接写参数名, 此处 kw 可省略
    in_shape,out_shape,uax_shape
    # ([(1L, 1L, 30L, 30L), (1L, 1L, 3L, 3L)], [(1L, 1L, 28L, 28L)], [])
    

    for module

    另外,上面用的是 symbol,有时需要从打包好的module里面提取symbol(mxnet的doc实在是...AWS掺和进来草根本性也不见提升啊):

    import mxnet as mx
    d=mx.sym.Variable('data')
    conv1_w=mx.sym.Variable('kw')
    conv1=mx.sym.Convolution(data=d,weight=conv1_w,kernel=(3,3),num_filter=num_filter,no_bias=True,name='conv1')
    loss=mx.sym.MakeLoss(data=conv1)
    
    mod = mx.mod.Moudle(symbol=loss)
    get_conv1 = mod.symbol.get_internals()['conv1_output']
    get_conv1
    #<Symbol conv1>
    
  • 相关阅读:
    学习笔记之正向代理和反向代理的区别
    PHP程序员的进阶之路
    go语言笔记——切片函数常见操作,增删改查和搜索、排序
    golang的垃圾回收(GC)机制
    堆栈的详细讲解
    springAop必导jar包
    sring框架的jdbc应用
    下载jar包方法
    mysql数据乱码
    Eclipse打包java工程
  • 原文地址:https://www.cnblogs.com/chenyliang/p/6782071.html
Copyright © 2011-2022 走看看