infer_shape for symbol
形状推断是mxnet的一特色,即使撇开这样做的原因是mxnet强制要求的,其提供的功能也是很helpful的。
infer_shape通常是被封装起来供其内部使用,但也可以把symbol.infer_shape单独提出来,作为函数:
import mxnet as mx
d=mx.sym.Variable('data')
conv1_w=mx.sym.Variable('kw')
conv1=mx.sym.Convolution(data=d,weight=conv1_w,kernel=(3,3),num_filter=num_filter,no_bias=True,name='conv1')
loss=mx.sym.MakeLoss(data=conv1)
in_shape,out_shape,uax_shape=loss.infer_shape(data=(1,1,30,30),kw=(1,1,3,3)) # 直接写参数名, 此处 kw 可省略
in_shape,out_shape,uax_shape
# ([(1L, 1L, 30L, 30L), (1L, 1L, 3L, 3L)], [(1L, 1L, 28L, 28L)], [])
for module
另外,上面用的是 symbol,有时需要从打包好的module里面提取symbol(mxnet的doc实在是...AWS掺和进来草根本性也不见提升啊):
import mxnet as mx
d=mx.sym.Variable('data')
conv1_w=mx.sym.Variable('kw')
conv1=mx.sym.Convolution(data=d,weight=conv1_w,kernel=(3,3),num_filter=num_filter,no_bias=True,name='conv1')
loss=mx.sym.MakeLoss(data=conv1)
mod = mx.mod.Moudle(symbol=loss)
get_conv1 = mod.symbol.get_internals()['conv1_output']
get_conv1
#<Symbol conv1>