zoukankan      html  css  js  c++  java
  • Caffe模型读取 Sanny.Liu

    caffe模型最终以protobuf的形式保存。要将一个已经训练好的caffe模型读取出来,可以参考如下:

    1,包含的头文件:

    #include <google/protobuf/io/coded_stream.h>
    #include <google/protobuf/io/zero_copy_stream_impl.h>
    #include <google/protobuf/text_format.h>
    
    #include "caffe/proto/caffe.pb.cc"    //在caffe的src/caffe/proto目录里,是编译后自动生成的,其中包括(caffe.pb.cc caffe.pb.d caffe.pb.h caffe.pb.o.warnings.txt)
    

    2,读取网络Message:

    bool loadCaffeNet(const std::string& model_list, Message* proto){  //
    
    	using google::protobuf::io::FileInputStream;
    	using google::protobuf::io::ZeroCopyInputStream;
    	using google::protobuf::io::CodedInputStream;
    	
    	//Message * proto;
    	std::vector<std::string> model_names;
    	boost::split(model_names, model_list, boost::is_any_of(",") );
    	bool success = false;
    	for (int i = 0; i < model_names.size(); ++i) {
    		std::cout<< "Finetuning from " << model_names[i];
    		const char* filename = model_names[i].c_str();
    		int fd = open(filename, O_RDONLY);
    		if( fd < 0 ){
    			std::cout << "File not found: " << fd;
    			return -1;
    		}
    		
    		ZeroCopyInputStream* raw_input = new FileInputStream(fd);
    		CodedInputStream* coded_input = new CodedInputStream(raw_input);		
    		coded_input->SetTotalBytesLimit(INT_MAX, 536870912);
    
    		success = proto->ParseFromCodedStream(coded_input);
    		
    		delete coded_input;
    		delete raw_input;
    		close(fd);
    		
    		return success;
    	}
    	return success;
    }
    

    3,参考caffe/proto/caffe.pb.cc 文件,获取对应的参数

      例如读取文件后:

         std::string trained_filename = "lenet_iter_10000.caffemodel";
    
    	caffe::NetParameter net_protobuf;
    	
    	if(loadCaffeNet(trained_filename, &net_protobuf)){  
    		std::cout<<"load net param success"<<std::endl;
    	}else{
    		std::cout<<"load net param failed"<<std::endl;
    	}
    

      获取网络层数:

    // Number of layers stored in the parsed network.
    int num_source_layers = net_protobuf.layer_size();

    // For every layer: print its name, type and blob count, copy blob 0 into
    // `weight` and blob 1 into `bias`, and print the dimensions of blob 0.
    // NOTE(review): `weight` and `bias` are not declared in this snippet; they
    // are assumed to be `float*` variables owned by the surrounding code - TODO
    // confirm. Each layer overwrites the previous pointers and nothing here
    // frees them, so the surrounding code has to take ownership of (or free)
    // every buffer.
    for(int i=0; i<num_source_layers; ++i){
    	// Bind by reference: copying a LayerParameter would copy all its blobs.
    	const caffe::LayerParameter& layer_param = net_protobuf.layer(i);
    	std::cout << layer_param.name() << std::endl;
    	std::cout << layer_param.type() << std::endl;

    	const int blobsize = layer_param.blobs_size();
    	std::cout << "blobs_size: "<<blobsize << std::endl;
    	for(int j=0; j<blobsize; j++){
    		const caffe::BlobProto& blob = layer_param.blobs(j);
    		const int dataSize = blob.data_size();

    		if(j==0){
    			std::cout << "  weight data_size: "<<dataSize << std::endl;
    			weight = static_cast<float*>(malloc(dataSize*sizeof(float)));
    			if(weight != NULL){
    				for(int index=0; index<dataSize; index++){
    					weight[index] = blob.data(index);
    				}
    			}
    			else if(dataSize > 0){
    				std::cout << "  failed to allocate weight buffer" << std::endl;
    			}

    			std::cout<<" Convolution->:"<<std::endl;
    			// Print only the dimensions the blob really has: weights are not
    			// always 4-D (e.g. fully connected layers), and indexing past
    			// dim_size() is out of range. 4-D blobs print exactly as before.
    			const char* const dimNames[4] = {"weight_n", "weight_c", "weight_h", "weight_w"};
    			const int dimCount = blob.shape().dim_size();
    			for(int d=0; d<dimCount && d<4; ++d){
    				std::cout<<" layer_param.blobs "<<dimNames[d]<<" "<<blob.shape().dim(d)<<std::endl;
    			}
    		}
    		else if(j==1){
    			std::cout << "  bias data_size: "<<dataSize << std::endl;
    			bias = static_cast<float*>(malloc(dataSize*sizeof(float)));
    			if(bias != NULL){
    				for(int index=0; index<dataSize; index++){
    					bias[index] = blob.data(index);
    				}
    			}
    			else if(dataSize > 0){
    				std::cout << "  failed to allocate bias buffer" << std::endl;
    			}
    		}
    	}
    }
    

      以上仅仅是部分代码,需要注意调试!

    其中caffe.pb.cc和caffe.pb.h文件是基于caffe.proto文件生成的,生成命令为:protoc caffe.proto --cpp_out=. ,即按照当前安装的protobuf版本,由caffe.proto生成对应版本的.cc和.h文件。

  • 相关阅读:
    int k=0;k=k++;结果等于0,为什么?
    CentOS在无法连接外网的服务器上安装软件(以docker为例)
    docker搭建 elasticsearch 集群
    docker容器之间进行网络通信
    Kafka集群搭建
    Zookeeper集群搭建及常用命令
    SpringBoot将Swagger2文档导出为markdown或html
    Linux(CentOS7)虚拟机修改 NAT模式固定IP
    Linux Maven私服(Nexus)搭建
    Linux配置开机自启动
  • 原文地址:https://www.cnblogs.com/hansjorn/p/4816059.html
Copyright © 2011-2022 走看看