zoukankan      html  css  js  c++  java
  • caffe源代码改动:抽取随意一张图片的特征


    caffe源码修改:抽取任意一张图片的特征

    目前caffe还不是很完善,输入的图片数据需要在prototxt中指定路径。

    可是我们往往有这么一个需求:训练后得到一个模型文件。我们想拿这个模型文件来对一张图片抽取特征或者预測分类等。

    如果非得在prototxt中指定路径,就很不方便。

    因此,这样的工具才是我们需要的:通过命令行把图片路径传递给一个可执行文件,然后caffe读入图片数据,进行一次正向传播。


    因此我做了这么一个工具。用来抽取随意一张图片的特征。

    这个工具的用法如下:


    extract_one_feature.bin ./model/caffe_reference_imagenet_model ./examples/_temp/imagenet_val.prototxt fc7 ./examples/_temp/features /media/G/imageset/clothing/针织衫/针织衫_426.jpg CPU

    參数1:./model/caffe_reference_imagenet_model是训练后的模型文件

    參数2:./examples/_temp/imagenet_val.prototxt 网络配置文件

    參数3:fc7是blob的名字

    參数4:./examples/_temp/features 将该图片的特征保存在该文件

    參数5:图片路径

    參数6:GPU或者CPU模式


    (其实我还想到一个更好的工具:如果该可执行文件是监听模式的,即通过某种方式把图片路径传递给该进程,进程接到任务就执行。

    这样的话,就不需要每次抽取一张图片都重新申请内存空间。(*^__^*) 嘻嘻……)


    下面给出初步的修改方法,大家可以根据自己的需求再修改。



    extract_one_feature.cpp(该文件参考源码中的extract_features.cpp修改而来)

    #include <stdio.h>  // for snprintf
    #include <string>
    #include <vector>
    #include <iostream>
    #include <fstream>
    
    #include "boost/algorithm/string.hpp"
    #include "google/protobuf/text_format.h"
    #include "leveldb/db.h"
    #include "leveldb/write_batch.h"
    
    #include "caffe/blob.hpp"
    #include "caffe/common.hpp"
    #include "caffe/net.hpp"
    #include "caffe/proto/caffe.pb.h"
    #include "caffe/util/io.hpp"
    #include "caffe/vision_layers.hpp"
    
    using namespace caffe;  // NOLINT(build/namespaces)
    
    // Forward declaration; the pipeline template is defined further down in this
    // file.
    template<typename Dtype>
    int feature_extraction_pipeline(int argc, char** argv);
    
    // Program entry point: runs the single-precision (float) pipeline and uses
    // its return value as the process exit status.
    int main(int argc, char** argv) {
      // A double-precision run would call feature_extraction_pipeline<double>.
      const int exit_status = feature_extraction_pipeline<float>(argc, argv);
      return exit_status;
    }
    
    /**
     * @brief Minimal plain-text feature writer backed by a std::ofstream.
     *
     * Despite the "Db" in its name this class does not use leveldb: everything
     * passed to write() is appended to an ordinary file with operator<<.
     *
     * @tparam Dtype numeric type of the feature values being written.
     */
    template<typename Dtype>
    class writeDb
    {
    public:
    	/**
    	 * @brief Opens the output file for writing.
    	 * @param dbName path of the file to write to.
    	 * @return true if the stream is open afterwards, false if the file
    	 *         could not be opened. Callers that ignore the result behave
    	 *         exactly as before.
    	 */
    	bool open(const std::string &dbName)
    	{
    		db.open(dbName.c_str());
    		return db.is_open();
    	}
    	/// Appends one numeric value using the stream's default formatting.
    	void write(const Dtype &data)
    	{
    		db<<data;
    	}
    	/// Appends a piece of text (labels, separators, line terminators).
    	void write(const std::string &str)
    	{
    		db<<str;
    	}
    	/// Closes the file if it was opened.
    	virtual ~writeDb()
    	{
    		if (db.is_open())
    		{
    			db.close();
    		}
    	}
    private:
    	std::ofstream db;  // underlying output file
    };
    
    template<typename Dtype>
    int feature_extraction_pipeline(int argc, char** argv) {
      ::google::InitGoogleLogging(argv[0]);
      const int num_required_args = 6;
      if (argc < num_required_args) {
        LOG(ERROR)<<
        "This program takes in a trained network and an input data layer, and then"
        " extract features of the input data produced by the net.
    "
        "Usage: extract_features  pretrained_net_param"
        "  feature_extraction_proto_file  extract_feature_blob_name1[,name2,...]"
        "  save_feature_leveldb_name1[,name2,...]  img_path  [CPU/GPU]"
        "  [DEVICE_ID=0]
    "
        "Note: you can extract multiple features in one pass by specifying"
        " multiple feature blob names and leveldb names seperated by ','."
        " The names cannot contain white space characters and the number of blobs"
        " and leveldbs must be equal.";
        return 1;
      }
      int arg_pos = num_required_args;
    
      arg_pos = num_required_args;
      if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == 0) {
        LOG(ERROR)<< "Using GPU";
        uint device_id = 0;
        if (argc > arg_pos + 1) {
          device_id = atoi(argv[arg_pos + 1]);
          CHECK_GE(device_id, 0);
        }
        LOG(ERROR) << "Using Device_id=" << device_id;
        Caffe::SetDevice(device_id);
        Caffe::set_mode(Caffe::GPU);
      } else {
        LOG(ERROR) << "Using CPU";
        Caffe::set_mode(Caffe::CPU);
      }
      Caffe::set_phase(Caffe::TEST);
    
      arg_pos = 0;  // the name of the executable
      string pretrained_binary_proto(argv[++arg_pos]);//网络模型參数文件
    
      string feature_extraction_proto(argv[++arg_pos]);
    
      shared_ptr<Net<Dtype> > feature_extraction_net(
          new Net<Dtype>(feature_extraction_proto));
    
      feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto);//将网络參数load进内存
    
    
      string extract_feature_blob_names(argv[++arg_pos]);
      vector<string> blob_names;//要抽取特征的layer的名字,能够是多个
      boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(","));
    
      string save_feature_leveldb_names(argv[++arg_pos]);
      vector<string> leveldb_names;// 这里我改写成一个levedb为一个文件,数据格式不使用真正的levedb,而是自己定义
      boost::split(leveldb_names, save_feature_leveldb_names,
                   boost::is_any_of(","));
      CHECK_EQ(blob_names.size(), leveldb_names.size()) <<
          " the number of blob names and leveldb names must be equal";
      size_t num_features = blob_names.size();
    
      for (size_t i = 0; i < num_features; i++) {
        CHECK(feature_extraction_net->has_blob(blob_names[i]))  //检測blob的名字在网络中是否存在
            << "Unknown feature blob name " << blob_names[i]
            << " in the network " << feature_extraction_proto;
      }
    
    
      vector<shared_ptr<writeDb<Dtype> > > feature_dbs;
      for (size_t i = 0; i < num_features; ++i) //打开db,准备写入数据
      {
        LOG(INFO)<< "Opening db " << leveldb_names[i];
        writeDb<Dtype>* db = new writeDb<Dtype>();
        db->open(leveldb_names[i]);
        feature_dbs.push_back(shared_ptr<writeDb<Dtype> >(db));
      }
    
    
    
      LOG(ERROR)<< "Extacting Features";
    
      const shared_ptr<Layer<Dtype> > layer = feature_extraction_net->layer_by_name("data");//获取第一层
      MyImageDataLayer<Dtype>* my_layer = (MyImageDataLayer<Dtype>*)layer.get();
      my_layer->setImgPath(argv[++arg_pos],1);//"/media/G/imageset/clothing/针织衫/针织衫_1.jpg"
      //设置图片路径
    
      vector<Blob<float>*> input_vec;
      vector<int> image_indices(num_features, 0);
      int num_mini_batches = 1;//atoi(argv[++arg_pos]);//共多少次迭代。

    每次迭代的数量在prototxt用batchsize指定 for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) //共num_mini_batches次迭代 { feature_extraction_net->Forward(input_vec);//一次正向传播 for (int i = 0; i < num_features; ++i) //多层特征 { const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net ->blob_by_name(blob_names[i]); int batch_size = feature_blob->num(); int dim_features = feature_blob->count() / batch_size; Dtype* feature_blob_data; for (int n = 0; n < batch_size; ++n) { feature_blob_data = feature_blob->mutable_cpu_data() + feature_blob->offset(n); feature_dbs[i]->write("3 "); for (int d = 0; d < dim_features; ++d) { feature_dbs[i]->write((Dtype)(d+1)); feature_dbs[i]->write(":"); feature_dbs[i]->write(feature_blob_data[d]); feature_dbs[i]->write(" "); } feature_dbs[i]->write(" "); } // for (int n = 0; n < batch_size; ++n) } // for (int i = 0; i < num_features; ++i) } // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) LOG(ERROR)<< "Successfully extracted the features!"; return 0; }


    my_data_layer.cpp(參考image_data_layer改动)

    #include <fstream>  // NOLINT(readability/streams)
    #include <iostream>  // NOLINT(readability/streams)
    #include <string>
    #include <utility>
    #include <vector>
    
    #include "caffe/layer.hpp"
    #include "caffe/util/io.hpp"
    #include "caffe/util/math_functions.hpp"
    #include "caffe/util/rng.hpp"
    #include "caffe/vision_layers.hpp"
    
    namespace caffe {
    
    
    // Destructor: nothing is released explicitly here. The destructor is named
    // without a template argument list because "~MyImageDataLayer<Dtype>()" is
    // not a valid destructor declarator in C++20 and newer GCC releases reject it.
    template <typename Dtype>
    MyImageDataLayer<Dtype>::~MyImageDataLayer() {
    }
    
    
    // Replaces the layer's image list with a single (path, label) entry.
    template <typename Dtype>
    void MyImageDataLayer<Dtype>::setImgPath(string path,int label)
    {
    	// assign() discards whatever was listed before and stores exactly one pair.
    	lines_.assign(1, std::make_pair(path, label));
    }
    
    
    /**
     * Sizes the two top blobs (image, label) and the layer's staging buffers.
     *
     * One image has to be read from disk to learn the blob dimensions. This
     * version uses the hard-coded file /home/linger/init.jpg for that purpose
     * and always works with a batch of exactly one image. Settings are read from
     * the layer's image_data_param.
     */
    template <typename Dtype>
    void MyImageDataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
          vector<Blob<Dtype>*>* top) {
      Layer<Dtype>::SetUp(bottom, top);
    
      const ImageDataParameter& data_param =
          this->layer_param_.image_data_param();
      const int new_height = data_param.new_height();
      const int new_width = data_param.new_width();
      CHECK((new_height == 0 && new_width == 0) ||
          (new_height > 0 && new_width > 0)) << "Current implementation requires "
          "new_height and new_width to be set at the same time.";
    
      // An image is needed on disk to initialise the blobs. Two interchangeable
      // ways to obtain its path:
      //
      // (1) read the file list named by the prototxt (disabled):
      //
      //   const string& source = this->layer_param_.image_data_param().source();
      //   LOG(INFO) << "Opening file " << source;
      //   std::ifstream infile(source.c_str());
      //   string filename;
      //   int label;
      //   while (infile >> filename >> label) {
      //     lines_.push_back(std::make_pair(filename, label));
      //   }
      //
      // (2) hard-code the path of the initialisation image (active):
      lines_.push_back(std::make_pair("/home/linger/init.jpg",1));
      lines_id_ = 0;
    
      // Load that image; only its dimensions matter here.
      Datum datum;
      CHECK(ReadImageToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
                             new_height, new_width, &datum));
    
      // Image blobs: a positive crop_size fixes the spatial size, otherwise the
      // size of the loaded image is used.
      const int crop_size = data_param.crop_size();
      const int num_images = 1;  // batch_size from the prototxt is ignored
      const bool use_crop = crop_size > 0;
      const int out_height = use_crop ? crop_size : datum.height();
      const int out_width = use_crop ? crop_size : datum.width();
      (*top)[0]->Reshape(num_images, datum.channels(), out_height, out_width);
      prefetch_data_.Reshape(num_images, datum.channels(), out_height, out_width);
      LOG(INFO) << "output data size: " << (*top)[0]->num() << ","
          << (*top)[0]->channels() << "," << (*top)[0]->height() << ","
          << (*top)[0]->width();
    
      // Label blobs: one scalar per image.
      (*top)[1]->Reshape(num_images, 1, 1, 1);
      prefetch_label_.Reshape(num_images, 1, 1, 1);
    
      // Remember the dimensions of the loaded image; fetchData() indexes with
      // these values.
      datum_channels_ = datum.channels();
      datum_height_ = datum.height();
      datum_width_ = datum.width();
      datum_size_ = datum_channels_ * datum_height_ * datum_width_;
      CHECK_GT(datum_height_, crop_size);
      CHECK_GT(datum_width_, crop_size);
    
      // Mean image: loaded from file when configured, otherwise a blob with the
      // image's dimensions is set up in its place.
      if (data_param.has_mean_file()) {
        const string& mean_file = data_param.mean_file();
        BlobProto blob_proto;
        LOG(INFO) << "Loading mean file from" << mean_file;
        ReadProtoFromBinaryFile(mean_file.c_str(), &blob_proto);
        data_mean_.FromProto(blob_proto);
        CHECK_EQ(data_mean_.num(), 1);
        CHECK_EQ(data_mean_.channels(), datum_channels_);
        CHECK_EQ(data_mean_.height(), datum_height_);
        CHECK_EQ(data_mean_.width(), datum_width_);
      } else {
        data_mean_.Reshape(1, datum_channels_, datum_height_, datum_width_);
      }
    
      // Touch the CPU side of the buffers once. These calls are inherited from
      // the prefetching ImageDataLayer this file was derived from; no prefetch
      // thread is started in this function.
      prefetch_data_.mutable_cpu_data();
      prefetch_label_.mutable_cpu_data();
      data_mean_.cpu_data();
    }
    
    //---------------- Loading the pixel data of a single image ----------------
    /**
     * @brief Reads the image selected by lines_[lines_id_] and stores its
     *        mean-subtracted, scaled pixels and its label in the staging buffers
     *        prefetch_data_ / prefetch_label_.
     *
     * Exactly one image is processed. When crop_size is non-zero the centre
     * crop_size x crop_size window is copied, otherwise the whole image. The
     * mirror flag is only validated; no flipping is performed here.
     *
     * Fatal (CHECK) conditions: the image cannot be read, or its dimensions
     * differ from the ones cached by SetUp(). Indexing below relies on those
     * cached dimensions, so a mismatching image would be read out of bounds.
     */
    template <typename Dtype>
    void MyImageDataLayer<Dtype>::fetchData() {
      Datum datum;
      CHECK(prefetch_data_.count());
      Dtype* top_data = prefetch_data_.mutable_cpu_data();
      Dtype* top_label = prefetch_label_.mutable_cpu_data();
      // Bound by reference: only a handful of fields are read from it.
      const ImageDataParameter& image_data_param =
          this->layer_param_.image_data_param();
      const Dtype scale = image_data_param.scale();
      const int batch_size = 1;  // only a single image is needed here
    
      const int crop_size = image_data_param.crop_size();
      const bool mirror = image_data_param.mirror();
      const int new_height = image_data_param.new_height();
      const int new_width = image_data_param.new_width();
    
      if (mirror && crop_size == 0) {
        LOG(FATAL) << "Current implementation requires mirror and crop_size to be "
            << "set at the same time.";
      }
      // Dimensions cached by SetUp() from the initialisation image.
      const int channels = datum_channels_;
      const int height = datum_height_;
      const int width = datum_width_;
      const int size = datum_size_;
      const int lines_size = lines_.size();
      const Dtype* mean = data_mean_.cpu_data();
    
      for (int item_id = 0; item_id < batch_size; ++item_id) {
        CHECK_GT(lines_size, lines_id_);
        const string& image_path = lines_[lines_id_].first;
        // Skipping an unreadable image would leave the previous contents of the
        // staging buffers in place and pass them on as this image's data, so
        // fail loudly, as SetUp() does for its initialisation image.
        CHECK(ReadImageToDatum(image_path, lines_[lines_id_].second,
                               new_height, new_width, &datum))
            << "Could not load image " << image_path;
        CHECK_EQ(datum.channels(), channels)
            << "Unexpected channel count for image " << image_path;
        CHECK_EQ(datum.height(), height)
            << "Unexpected height for image " << image_path
            << "; set new_height/new_width so that inputs are resized";
        CHECK_EQ(datum.width(), width)
            << "Unexpected width for image " << image_path
            << "; set new_height/new_width so that inputs are resized";
    
        const string& data = datum.data();
        if (crop_size) {
          CHECK(data.size()) << "Image cropping only support uint8 data";
          // Always take the centre crop; no random cropping is done here.
          const int h_off = (height - crop_size) / 2;
          const int w_off = (width - crop_size) / 2;
    
          // Copy the cropped window into top_data.
          for (int c = 0; c < channels; ++c) {
            for (int h = 0; h < crop_size; ++h) {
              for (int w = 0; w < crop_size; ++w) {
                const int top_index = ((item_id * channels + c) * crop_size + h)
                                      * crop_size + w;
                const int data_index =
                    (c * height + h + h_off) * width + w + w_off;
                const Dtype datum_element =
                    static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
                top_data[top_index] = (datum_element - mean[data_index]) * scale;
              }
            }
          }
        } else {
          // Copy the whole image into top_data.
          if (data.size()) {
            // Pixel bytes are stored in datum.data().
            for (int j = 0; j < size; ++j) {
              const Dtype datum_element =
                  static_cast<Dtype>(static_cast<uint8_t>(data[j]));
              top_data[item_id * size + j] = (datum_element - mean[j]) * scale;
            }
          } else {
            // No byte data: take the values from datum.float_data instead.
            for (int j = 0; j < size; ++j) {
              top_data[item_id * size + j] =
                  (datum.float_data(j) - mean[j]) * scale;
            }
          }
        }
        top_label[item_id] = datum.label();  // label that goes with the image
      }
    }
    
    // Forward pass of the input layer: reloads the image via fetchData(), then
    // copies the staged pixels to top blob 0 and the staged label to top blob 1.
    // Always returns 0.
    template <typename Dtype>
    Dtype MyImageDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
          vector<Blob<Dtype>*>* top) {
      // Refresh the staging buffers first.
      fetchData();
    
      Blob<Dtype>* const data_out = (*top)[0];
      Blob<Dtype>* const label_out = (*top)[1];
    
      caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(),
                 data_out->mutable_cpu_data());
      caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(),
                 label_out->mutable_cpu_data());
    
      return static_cast<Dtype>(0.);
    }
    
    // No STUB_GPU_FORWARD here. The stub that used to sit at this spot named
    // ImageDataLayer, a class this file does not define, and MyImageDataLayer
    // declares no Forward_gpu of its own for a stub to implement.
    
    INSTANTIATE_CLASS(MyImageDataLayer);
    
    }  // namespace caffe
    


    在data_layers.hpp中加入以下代码,参考ImageDataLayer写的。

    /**
     * @brief Input layer that feeds one image, whose path can be changed at run
     *        time, into the network.
     *
     * Takes no bottom blobs and produces two top blobs (image data and label).
     * The image to load is chosen with setImgPath(); Forward_cpu() reloads it on
     * every call.
     *
     * @tparam Dtype element type of the blobs.
     */
    template <typename Dtype>
    class MyImageDataLayer : public Layer<Dtype>  {
     public:
      // Scalar members start at zero so that they never hold indeterminate values
      // before SetUp() assigns them.
      explicit MyImageDataLayer(const LayerParameter& param)
          : Layer<Dtype>(param), lines_id_(0), datum_channels_(0),
            datum_height_(0), datum_width_(0), datum_size_(0), phase_() {}
      virtual ~MyImageDataLayer();
      // Shapes the top blobs and staging buffers from an initialisation image.
      virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
          vector<Blob<Dtype>*>* top);
    
      virtual inline LayerParameter_LayerType type() const {
        return LayerParameter_LayerType_MY_IMAGE_DATA;
      }
      virtual inline int ExactNumBottomBlobs() const { return 0; }
      virtual inline int ExactNumTopBlobs() const { return 2; }
      // Loads the currently selected image into the staging buffers.
      void fetchData();
      // Replaces the image list with the single (path, label) entry given.
      void setImgPath(string path,int label);
     protected:
      // Reloads the image and copies data and label into the top blobs.
      virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
          vector<Blob<Dtype>*>* top);
    
      // Intentionally empty: this layer has no bottom blobs.
      virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
          const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
    
    
      vector<std::pair<std::string, int> > lines_;  // (image path, label) entries
      int lines_id_;         // index into lines_ of the entry to load
      int datum_channels_;   // dimensions recorded by SetUp()
      int datum_height_;
      int datum_width_;
      int datum_size_;       // channels * height * width
      Blob<Dtype> prefetch_data_;   // staging buffer for the image data
      Blob<Dtype> prefetch_label_;  // staging buffer for the label
      Blob<Dtype> data_mean_;       // mean subtracted from every pixel
      Caffe::Phase phase_;   // not referenced by the member functions of this layer
    };


    修改caffe.proto,在适当的位置加入以下内容,也是参考image_data写的。


    // Enum value for the new layer type. The layer class returns it from type()
    // as LayerParameter_LayerType_MY_IMAGE_DATA, so it presumably belongs in the
    // LayerType enum of LayerParameter - confirm against caffe.proto.
    MY_IMAGE_DATA = 36;


    // Parameter field for the new layer (presumably a field of LayerParameter,
    // next to image_data_param - confirm against caffe.proto).
    // NOTE(review): the MyImageDataLayer code reads its settings through
    // image_data_param(), not through this field.
    optional MyImageDataParameter my_image_data_param = 36;


    // Message that stores parameters used by MyImageDataLayer.
    // NOTE(review): the MyImageDataLayer code reads scale, mean_file, crop_size,
    // mirror, new_height and new_width, and it reads them via image_data_param().
    // source, batch_size, rand_skip and shuffle are not read by that code.
    message MyImageDataParameter {
      // Specify the data source.
      optional string source = 1;
      // For data pre-processing, we can do simple scaling and subtracting the
      // data mean, if provided. Note that the mean subtraction is always carried
      // out before scaling.
      optional float scale = 2 [default = 1];
      // Path of the file holding the mean to subtract.
      optional string mean_file = 3;
      // Specify the batch size.
      optional uint32 batch_size = 4;
      // Side length of the square crop taken from each image; 0 disables cropping.
      optional uint32 crop_size = 5 [default = 0];
      // Specify if we want to randomly mirror data.
      optional bool mirror = 6 [default = false];
      // The rand_skip variable is for the data layer to skip a few data points
      // to avoid all asynchronous sgd clients to start at the same point. The skip
      // point would be set as rand_skip * rand(0,1). Note that rand_skip should not
      // be larger than the number of keys in the leveldb.
      optional uint32 rand_skip = 7 [default = 0];
      // Whether or not ImageLayer should shuffle the list of files at every epoch.
      optional bool shuffle = 8 [default = false];
      // It will also resize images if new_height or new_width are not zero.
      // Both have to be zero or both positive.
      optional uint32 new_height = 9 [default = 0];
      optional uint32 new_width = 10 [default = 0];
    }


    以上几处改动在caffe.proto中的位置并不相邻,可以参考image_data对应条目所在的位置来添加。



    本文作者:linger

    本文链接:http://blog.csdn.net/lingerlanlan/article/details/39400375



  • 相关阅读:
    [dev][ipsec][esp] ipsec链路中断的感知问题
    [dev] Go语言查看doc与生成API doc
    [daily]gtk程序不跟随系统的dark主题
    [dev] Go的协程切换问题
    基因程序设计/基因编程/遗传编程
    [daily][emacs][go] 配置emacs go-mode的编辑环境以及环境变量问题
    Java Spring中@Query中使用JPQL LIKE 写法
    JavaScript 使用HTML DOM的oninput事件,实时监听value值变化
    Java中执行.exe文件
    Java关于List<String> 进行排序,重写Comparator()方法
  • 原文地址:https://www.cnblogs.com/wgwyanfs/p/7076121.html
Copyright © 2011-2022 走看看