zoukankan      html  css  js  c++  java
  • MXNet 查看模型 params 的网络结构

    import mxnet as mx  
    import pdb  
    def load_checkpoint(prefix='model', epoch=0):
        """Load the parameters stored in ``<prefix>-<epoch:04d>.params``.

        With the default arguments this reads ``model-0000.params`` relative
        to the current working directory, i.e. the same file the original
        hard-coded version loaded.

        :param prefix: Prefix of model name (default ``'model'``).
        :param epoch: Epoch number of the checkpoint to load (default ``0``);
            formatted as a zero-padded 4-digit number.
        :return: (arg_params, aux_params)
            arg_params : dict of str to NDArray
                Entries whose saved key starts with ``arg:``, keyed by the
                part after the first ``:``.
            aux_params : dict of str to NDArray
                Entries whose saved key starts with ``aux:``, keyed by the
                part after the first ``:``.
            Entries with any other type tag are ignored.
        """
        save_dict = mx.nd.load('%s-%04d.params' % (prefix, epoch))
        arg_params = {}
        aux_params = {}
        for k, v in save_dict.items():
            # Keys are expected to look like '<type>:<name>', e.g.
            # 'arg:conv0_weight'; split on the first ':' only so names that
            # themselves contain ':' are kept intact.
            tp, name = k.split(':', 1)
            if tp == 'arg':
                arg_params[name] = v
            elif tp == 'aux':
                aux_params[name] = v
        return arg_params, aux_params
      
      
    def convert_context(params, ctx):
        """Return a new dict with every value placed on context ``ctx``.

        :param params: dict of str to NDArray
        :param ctx: the context to convert to (e.g. ``mx.cpu()``)
        :return: a new dict with the same keys, where each value is
            ``original_value.as_in_context(ctx)``
        """
        return {name: array.as_in_context(ctx) for name, array in params.items()}
      
      
    def load_param(convert=False, ctx=None):
        """Load the checkpoint via ``load_checkpoint()``, optionally re-homing it.

        :param convert: when True, pass both parameter dicts through
            ``convert_context`` before returning them.
        :param ctx: target context used only when ``convert`` is True; if it
            is None, ``mx.cpu()`` is used. Ignored when ``convert`` is False.
        :return: (arg_params, aux_params)
        """
        arg_params, aux_params = load_checkpoint()
        if not convert:
            return arg_params, aux_params
        target = mx.cpu() if ctx is None else ctx
        return (convert_context(arg_params, target),
                convert_context(aux_params, target))
      
      
    if __name__ == '__main__':
        # Print the name and shape of every loaded parameter: first the
        # arg_params dict, then the aux_params dict returned by load_param().
        result = load_param()
        # A single-argument print(...) works on both Python 2 and 3; the
        # original `print 'result is'` statement is a SyntaxError on Python 3.
        print('result is')
        for params in result:
            for key, value in params.items():
                print(key, value.shape)

    python showmxmodel.py 2>&1 | tee log.txt
    result is
    ('stage3_unit2_bn1_beta', (256L,))
    ('stage3_unit2_bn3_beta', (256L,))
    ('stage3_unit11_bn1_gamma', (256L,))
    ('stage3_unit5_bn3_gamma', (256L,))
    ('stage3_unit3_conv1_weight', (256L, 256L, 3L, 3L))
    ('stage2_unit1_bn3_gamma', (128L,))
    ('stage3_unit4_conv1_weight', (256L, 256L, 3L, 3L))
    ('stage3_unit12_bn3_beta', (256L,))
    ('stage2_unit2_bn3_beta', (128L,))
    ('conv0_weight', (64L, 3L, 3L, 3L))
    ('stage3_unit11_relu1_gamma', (256L,))
    ('stage4_unit1_conv1sc_weight', (512L, 256L, 1L, 1L))
    ('stage3_unit1_conv1sc_weight', (256L, 128L, 1L, 1L))
    ('bn1_beta', (512L,))
    ('stage1_unit2_bn2_beta', (64L,))
    ('stage3_unit2_conv2_weight', (256L, 256L, 3L, 3L))
    ('stage1_unit2_conv1_weight', (64L, 64L, 3L, 3L))
    ('stage3_unit14_bn2_beta', (256L,))
    ('stage4_unit2_bn3_beta', (512L,))
    ('stage3_unit8_bn1_gamma', (256L,))
    ('stage3_unit7_bn1_gamma', (256L,))
    ('stage2_unit3_bn1_beta', (128L,))
    ('stage2_unit4_conv1_weight', (128L, 128L, 3L, 3L))
    ('stage3_unit2_bn2_gamma', (256L,))
    ('stage1_unit1_conv1_weight', (64L, 64L, 3L, 3L))
    ('stage3_unit9_conv2_weight', (256L, 256L, 3L, 3L))
    ('stage3_unit13_conv1_weight', (256L, 256L, 3L, 3L))
    ('stage3_unit1_relu1_gamma', (256L,))
    ('stage4_unit1_bn3_beta', (512L,))
    ('stage2_unit1_bn2_beta', (128L,))
    ('stage3_unit14_conv1_weight', (256L, 256L, 3L, 3L))
    ('stage3_unit8_bn1_beta', (256L,))
    ('stage3_unit11_conv1_weight', (256L, 256L, 3L, 3L))
    ('stage1_unit1_bn3_gamma', (64L,))
    ('stage2_unit2_conv2_weight', (128L, 128L, 3L, 3L))
    ('stage4_unit2_bn1_gamma', (512L,))
    ('stage3_unit3_bn1_gamma', (256L,))
    ('stage1_unit3_bn2_gamma', (64L,))
    ('stage1_unit3_bn3_gamma', (64L,))
    ('stage4_unit2_relu1_gamma', (512L,))
    ('stage3_unit10_conv2_weight', (256L, 256L, 3L, 3L))
    ('stage3_unit12_conv1_weight', (256L, 256L, 3L, 3L))
    ('stage3_unit2_relu1_gamma', (256L,))
    ('stage3_unit10_bn2_beta', (256L,))
    ('stage2_unit3_bn3_gamma', (128L,))
    ('stage2_unit3_bn2_beta', (128L,))
    ('stage3_unit8_bn3_beta', (256L,))
    ('fc1_gamma', (512L,))
    ('stage3_unit14_bn3_gamma', (256L,))
    ('stage3_unit9_bn3_gamma', (256L,))
    ('stage2_unit3_bn3_beta', (128L,))
    ('stage3_unit1_sc_gamma', (256L,))
    ('stage3_unit7_bn1_beta', (256L,))
    ('stage1_unit2_bn3_beta', (64L,))
    ('stage3_unit14_relu1_gamma', (256L,))
    ('stage3_unit13_bn2_beta', (256L,))
    ('stage2_unit1_conv1sc_weight', (128L, 64L, 1L, 1L))
    ('bn0_beta', (64L,))
    ('stage3_unit12_bn1_gamma', (256L,))
    ('stage2_unit1_sc_gamma', (128L,))
    ('relu0_gamma', (64L,))
    ('stage2_unit2_bn2_gamma', (128L,))
    ('stage3_unit4_relu1_gamma', (256L,))

  • 相关阅读:
    A*寻路算法
    Flump使用GPU渲染Flash动画
    Flash AS3.0 垃圾回收机制
    flash builder无法启动的解决方法
    AS3.0 BitmapData类介绍
    x&(x-1)表达式的意义
    Feathers: Stage3D加速的UI组件
    Knockout.js入门
    TcxStyleRepository使用示例
    TPageControl使用代码节选
  • 原文地址:https://www.cnblogs.com/adong7639/p/9173854.html
Copyright © 2011-2022 走看看