    Parsing Model Parameters

    I. Caffe Model

    A caffemodel file is serialized with protobuf.

    1. With a prototxt

    Depends on caffe (its Python interface).

    If both the prototxt and the caffemodel files are available, you can load the network through the interface caffe provides and then walk through it.

    The weight is net.params[param_name][0].data. For the bias, check the length of net.params[param_name]: if it is greater than 1 the layer has a bias, stored in net.params[param_name][1].data; otherwise the layer has no bias parameter.

    #!/usr/bin/env python
    import caffe
    import numpy as np

    # show parameters in full; without this, long arrays printed directly
    # would be abbreviated with "..."
    np.set_printoptions(threshold=np.inf)

    # deploy file
    MODEL_FILE = 'caffe_deploy.prototxt'
    # pre-trained caffe model
    PRETRAIN_FILE = 'caffe_iter_10000.caffemodel'

    # file that receives the extracted parameters
    params_txt = 'params.txt'
    pf = open(params_txt, 'w')

    # load the network in test mode
    net = caffe.Net(MODEL_FILE, PRETRAIN_FILE, caffe.TEST)

    # iterate over every parameterized layer
    for param_name in net.params.keys():
        # weight parameters
        weight = net.params[param_name][0].data
        # bias parameters, if the layer has them
        has_bias = len(net.params[param_name]) > 1
        if has_bias:
            bias = net.params[param_name][1].data

        # the layer's name (the "top" name in the prototxt)
        pf.write(param_name)
        pf.write('\n')

        # write the weights
        pf.write('\n' + param_name + '_weight:\n\n')
        # the weights form a multi-dimensional array; flatten for output
        for w in weight.reshape(-1):
            pf.write('%ff, ' % w)

        # write the bias
        if has_bias:
            pf.write('\n\n' + param_name + '_bias:\n\n')
            # the bias is also multi-dimensional; flatten for output
            for b in bias.reshape(-1):
                pf.write('%ff, ' % b)

        pf.write('\n\n')

    pf.close()
    

    References:
    http://blog.csdn.net/u011762313/article/details/49851795
    https://www.cnblogs.com/denny402/p/5686257.html

    2. Without a prototxt

    Depends on caffe.proto.caffe_pb2.

    If there is no prototxt file, call caffe_pb2 directly to parse the caffemodel. The members to read follow caffe.proto, found under src/caffe/proto in the caffe tree and written in protobuf syntax; compiling it with protoc (for example, protoc --cpp_out=. caffe.proto) generates the caffe.pb.cc and caffe.pb.h files used from C++.

    Location of the weight parameters:
    NetParameter -> LayerParameter -> BlobProto -> data
    model -> layer -> blobs[0] -> data

    Location of the bias parameters (if present):
    NetParameter -> LayerParameter -> BlobProto -> data
    model -> layer -> blobs[1] -> data

    (Old models may instead populate the deprecated layers field, which holds V1LayerParameter entries.)

    The following are the relevant definitions from caffe.proto.
    NetParameter mainly contains name, input, input_shape, and layer.
    LayerParameter mainly contains name, type, bottom, top, blobs, param (learning-rate multipliers), transform_param, convolution_param, and so on.

    message NetParameter {
      optional string name = 1; // consider giving the network a name
      // DEPRECATED. See InputParameter. The input blobs to the network.
      repeated string input = 3;
      // DEPRECATED. See InputParameter. The shape of the input blobs.
      repeated BlobShape input_shape = 8;
    
      // 4D input dimensions -- deprecated.  Use "input_shape" instead.
      // If specified, for each input blob there should be four
      // values specifying the num, channels, height and width of the input blob.
      // Thus, there should be a total of (4 * #input) numbers.
      repeated int32 input_dim = 4;
    
      // Whether the network will force every layer to carry out backward operation.
      // If set False, then whether to carry out backward is determined
      // automatically according to the net structure and learning rates.
      optional bool force_backward = 5 [default = false];
      // The current "state" of the network, including the phase, level, and stage.
      // Some layers may be included/excluded depending on this state and the states
      // specified in the layers' include and exclude fields.
      optional NetState state = 6;
    
      // Print debugging information about results while running Net::Forward,
      // Net::Backward, and Net::Update.
      optional bool debug_info = 7 [default = false];
    
      // The layers that make up the net.  Each of their configurations, including
      // connectivity and behavior, is specified as a LayerParameter.
      repeated LayerParameter layer = 100;  // ID 100 so layers are printed last.
    
      // DEPRECATED: use 'layer' instead.
      repeated V1LayerParameter layers = 2;
    }
    
    message LayerParameter {
      optional string name = 1; // the layer name
      optional string type = 2; // the layer type
      repeated string bottom = 3; // the name of each bottom blob
      repeated string top = 4; // the name of each top blob
    
      // The train / test phase for computation.
      optional Phase phase = 10;
    
      // The amount of weight to assign each top blob in the objective.
      // Each layer assigns a default value, usually of either 0 or 1,
      // to each top blob.
      repeated float loss_weight = 5;
    
      // Specifies training parameters (multipliers on global learning constants,
      // and the name and other settings used for weight sharing).
      repeated ParamSpec param = 6;
    
      // The blobs containing the numeric parameters of the layer.
      repeated BlobProto blobs = 7;
    
      // Specifies whether to backpropagate to each bottom. If unspecified,
      // Caffe will automatically infer whether each input needs backpropagation
      // to compute parameter gradients. If set to true for some inputs,
      // backpropagation to those inputs is forced; if set false for some inputs,
      // backpropagation to those inputs is skipped.
      //
      // The size must be either 0 or equal to the number of bottoms.
      repeated bool propagate_down = 11;
    
      // Rules controlling whether and when a layer is included in the network,
      // based on the current NetState.  You may specify a non-zero number of rules
      // to include OR exclude, but not both.  If no include or exclude rules are
      // specified, the layer is always included.  If the current NetState meets
      // ANY (i.e., one or more) of the specified rules, the layer is
      // included/excluded.
      repeated NetStateRule include = 8;
      repeated NetStateRule exclude = 9;
    
      // Parameters for data pre-processing.
      optional TransformationParameter transform_param = 100;
    
      // Parameters shared by loss layers.
      optional LossParameter loss_param = 101;
    
      // Layer type-specific parameters.
      //
      // Note: certain layers may have more than one computational engine
      // for their implementation. These layers include an Engine type and
      // engine parameter for selecting the implementation.
      // The default for the engine is set by the ENGINE switch at compile-time.
      optional AccuracyParameter accuracy_param = 102;
      optional ArgMaxParameter argmax_param = 103;
      optional BatchNormParameter batch_norm_param = 139;
      optional BiasParameter bias_param = 141;
      optional ConcatParameter concat_param = 104;
      optional ContrastiveLossParameter contrastive_loss_param = 105;
      optional ConvolutionParameter convolution_param = 106;
      optional CropParameter crop_param = 144;
      optional DataParameter data_param = 107;
      optional DropoutParameter dropout_param = 108;
      optional DummyDataParameter dummy_data_param = 109;
      optional EltwiseParameter eltwise_param = 110;
      optional ELUParameter elu_param = 140;
      optional EmbedParameter embed_param = 137;
      optional ExpParameter exp_param = 111;
      optional FlattenParameter flatten_param = 135;
      optional HDF5DataParameter hdf5_data_param = 112;
      optional HDF5OutputParameter hdf5_output_param = 113;
      optional HingeLossParameter hinge_loss_param = 114;
      optional ImageDataParameter image_data_param = 115;
      optional InfogainLossParameter infogain_loss_param = 116;
      optional InnerProductParameter inner_product_param = 117;
      optional InputParameter input_param = 143;
      optional LogParameter log_param = 134;
      optional LRNParameter lrn_param = 118;
      optional MemoryDataParameter memory_data_param = 119;
      optional MVNParameter mvn_param = 120;
      optional ParameterParameter parameter_param = 145;
      optional PoolingParameter pooling_param = 121;
      optional PowerParameter power_param = 122;
      optional PReLUParameter prelu_param = 131;
      optional PythonParameter python_param = 130;
      optional RecurrentParameter recurrent_param = 146;
      optional ReductionParameter reduction_param = 136;
      optional ReLUParameter relu_param = 123;
      optional ReshapeParameter reshape_param = 133;
      optional ScaleParameter scale_param = 142;
      optional SigmoidParameter sigmoid_param = 124;
      optional SoftmaxParameter softmax_param = 125;
      optional SPPParameter spp_param = 132;
      optional SliceParameter slice_param = 126;
      optional TanHParameter tanh_param = 127;
      optional ThresholdParameter threshold_param = 128;
      optional TileParameter tile_param = 138;
      optional WindowDataParameter window_data_param = 129;
    }
    
    

    BlobProto stores the trained parameters themselves; its two important members are shape and data.
    [packed = true] means the values are stored contiguously: a single length prefix is written first, followed by the raw values back to back. Without packed encoding, each value would carry its own field tag, wasting space. (caffe.proto uses proto2 syntax, where packed must be requested explicitly; only proto3 makes it the default for scalar repeated fields.)

    message BlobProto {
      optional BlobShape shape = 7;
      repeated float data = 5 [packed = true];
      repeated float diff = 6 [packed = true];
      repeated double double_data = 8 [packed = true];
      repeated double double_diff = 9 [packed = true];
    }
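
    As a small illustration, a BlobProto can be turned into a usable numpy array as sketched below (blob_to_array is our own helper name, not part of caffe; note that older models may fill the deprecated num/channels/height/width fields instead of shape):

    import numpy as np

    def blob_to_array(blob):
        # blob is a caffe_pb2.BlobProto; its packed float data arrives as a
        # flat sequence that we reshape according to blob.shape.dim
        arr = np.array(blob.data, dtype=np.float32)
        if blob.shape.dim:
            arr = arr.reshape(tuple(blob.shape.dim))
        return arr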
    

    The following code reads the fields of a caffemodel:

    import caffe.proto.caffe_pb2 as caffe_pb2
    import pdb

    caffemodel_filename = '/home/gr/deepwork/HyperLPR/lpr.caffemodel'

    model = caffe_pb2.NetParameter()

    f = open(caffemodel_filename, 'rb')
    model.ParseFromString(f.read())
    f.close()

    layers = model.layer
    print 'name: ' + model.name
    # drop into the debugger here to explore the parsed message interactively
    pdb.set_trace()

    for layer in layers:
        print layer.name + ':'
        if len(layer.blobs) > 0:
            print '\tweight filter ' + str(layer.blobs[0].shape.dim) + ': ' + str(layer.blobs[0].data[0])
            if len(layer.blobs) > 1:
                print '\tbias filter ' + str(layer.blobs[1].shape.dim) + ': ' + str(layer.blobs[1].data[0])
        else:
            print '\tno parameters'

    You can also regenerate a corresponding prototxt from the parsed model, as sketched below.
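
    A minimal sketch of that idea using google.protobuf's text_format module (the file names here are just placeholders): strip the trained blobs, then dump the remaining structure as text.

    import caffe.proto.caffe_pb2 as caffe_pb2
    from google.protobuf import text_format

    model = caffe_pb2.NetParameter()
    with open('lpr.caffemodel', 'rb') as f:
        model.ParseFromString(f.read())

    # drop the trained weights so only the network structure remains
    for layer in model.layer:
        del layer.blobs[:]

    with open('recovered.prototxt', 'w') as f:
        f.write(text_format.MessageToString(model))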

    References:
    https://www.cnblogs.com/zjutzz/p/6185452.html
    http://blog.csdn.net/jiongnima/article/details/72904526
    http://blog.csdn.net/seven_first/article/details/47418887#message-layerparameter
    https://www.cnblogs.com/autyinjing/p/6495103.html

    3. Parsing with C++

    Depends on protobuf.

    Parsing again follows caffe.proto:

    Location of the weight parameters:
    NetParameter -> LayerParameter -> BlobProto -> data
    msg -> layer -> blobs[0] -> data

    Location of the bias parameters (if present):
    NetParameter -> LayerParameter -> BlobProto -> data
    msg -> layer -> blobs[1] -> data

    The C++ code generated by protoc provides an accessor named after each field; prefixing the name with mutable_ returns a modifiable pointer, which is how repeated fields are usually manipulated. Given the message below:
    string name = student->name(); fetches the student's name.
    RepeatedPtrField<string>* classes = student->mutable_classes(); fetches all course names; string first_class = classes->Get(0); fetches the first course, and the generated classes(index) accessor simply calls Get().

    message Student {
        // name
        optional string name = 1;
        // courses
        repeated string classes = 2;
    }
    

    For a repeated float data field, the following member functions are generated:

     inline int data_size() const;
     inline void clear_data();
     static const int kDataFieldNumber = 5;
     inline float data(int index) const;
     inline void set_data(int index, float value);
     inline void add_data(float value);
     inline const ::google::protobuf::RepeatedField< float >& data() const;
     inline ::google::protobuf::RepeatedField< float >* mutable_data();
    

    Parameter-extraction code:

    #include <stdio.h>
    #include <string.h>
    #include <fstream>
    #include <iostream>
    #include "caffe.pb.h"

    using namespace std;
    using namespace caffe;

    int main(int argc, char* argv[])
    {
        caffe::NetParameter msg;

        // parse the binary caffemodel into a NetParameter message
        fstream input("/home/gr/deepwork/caffe-tensorflow/examples/mnist/lenet_iter_10000.caffemodel", ios::in | ios::binary);
        if (!msg.ParseFromIstream(&input))
        {
            cerr << "Failed to parse caffemodel." << endl;
            return -1;
        }

        ::google::protobuf::RepeatedPtrField<LayerParameter>* layers = msg.mutable_layer();
        ::google::protobuf::RepeatedPtrField<LayerParameter>::iterator it = layers->begin();
        for (; it != layers->end(); ++it)
        {
            cout << it->name() << endl;
            cout << it->type() << endl;
            // blobs[0] holds the weights, blobs[1] (if present) the bias
            ::google::protobuf::RepeatedPtrField<BlobProto>* blobs = it->mutable_blobs();
            for (int i = 0; i < blobs->size(); ++i) {
                const BlobProto& blob = blobs->Get(i);
                const ::google::protobuf::RepeatedField<float>& data = blob.data();
                for (int j = 0; j < data.size(); ++j) {
                    cout << data.Get(j) << " ";
                }
                cout << endl;
            }
        }

        return 0;
    }
    
    

    References:
    http://blog.csdn.net/zr459927180/article/details/50904938
    http://blog.csdn.net/dachao_xu/article/details/50899534

    II. TensorFlow

    conv2d:

    if normalizer_fn is None and a biases_initializer is provided then a biases variable would be created and added to the activations.

    In other words, a bias is created only when no normalizer_fn is supplied and a biases_initializer is supplied; in all other cases the layer has no bias.
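
    A minimal sketch of that behavior, assuming TensorFlow 1.x with tf.contrib.slim (the scope names with_bias and no_bias are arbitrary):

    import tensorflow as tf
    slim = tf.contrib.slim

    inputs = tf.placeholder(tf.float32, [1, 28, 28, 1])
    # normalizer_fn is None and the default biases_initializer applies,
    # so both weights and biases variables are created
    net1 = slim.conv2d(inputs, 32, [3, 3], scope='with_bias')
    # a normalizer_fn is supplied, so no biases variable is created
    net2 = slim.conv2d(inputs, 32, [3, 3], normalizer_fn=slim.batch_norm,
                       scope='no_bias')

    for v in tf.trainable_variables():
        print(v.name)
    # expected: with_bias/weights and with_bias/biases, but for no_bias only
    # the weights plus the batch-norm beta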

    If you build the network with TensorFlow's high-level slim interface instead of defining Variables yourself, tf.trainable_variables() returns the variables that will be trained:

    params = tf.trainable_variables()
    
    for idx, v in enumerate(params):
        print("  param {:15}: {:15}   {}".format(idx, str(v.get_shape()), v.name))
    

    Result: the printed list shows each trainable variable's index, shape, and name (the original post includes screenshots of this output here).

    You can use tf.train.Saver to load the trained weights and then fetch each parameter:

    saver = tf.train.Saver()
    params = tf.trainable_variables()
    fp = open('mnist_model.txt', 'w')

    with tf.Session() as sess:
        saver.restore(sess, './tmp/mnist_model.ckpt')
        for param in params:
            # evaluate the variable to get its numeric value
            v = sess.run(param)
            fp.write(param.name)
            fp.write(str(v))
            fp.write('\n')

    fp.close()
    
    Original post: https://www.cnblogs.com/gr-nick/p/9141444.html