1 #include <stdio.h> // for snprintf 2 #include <string> 3 #include <vector> 4 5 #include "boost/algorithm/string.hpp" 6 #include "google/protobuf/text_format.h" 7 8 #include "caffe/blob.hpp" 9 #include "caffe/common.hpp" 10 #include "caffe/net.hpp" 11 #include "caffe/proto/caffe.pb.h" 12 #include "caffe/util/db.hpp" 13 #include "caffe/util/io.hpp" 14 #include "caffe/vision_layers.hpp" 15 16 using caffe::Blob; 17 using caffe::Caffe; 18 using caffe::Datum; 19 using caffe::Net; 20 using boost::shared_ptr; 21 using std::string; 22 namespace db = caffe::db; 23 24 template<typename Dtype> 25 int feature_extraction_pipeline(int argc, char** argv); 26 27 int main(int argc, char** argv) { 28 return feature_extraction_pipeline<float>(argc, argv); 29 // return feature_extraction_pipeline<double>(argc, argv); 30 } 31 32 template<typename Dtype> 33 int feature_extraction_pipeline(int argc, char** argv) { 34 ::google::InitGoogleLogging(argv[0]); 35 const int num_required_args = 7; 36 if (argc < num_required_args) { 37 LOG(ERROR)<< 38 "This program takes in a trained network and an input data layer, and then" 39 " extract features of the input data produced by the net. " 40 "Usage: extract_features pretrained_net_param" 41 " feature_extraction_proto_file extract_feature_blob_name1[,name2,...]" 42 " save_feature_dataset_name1[,name2,...] num_mini_batches db_type" 43 " [CPU/GPU] [DEVICE_ID=0] " 44 "Note: you can extract multiple features in one pass by specifying" 45 " multiple feature blob names and dataset names separated by ','." 46 " The names cannot contain white space characters and the number of blobs" 47 " and datasets must be equal."; 48 return 1; 49 } 50 int arg_pos = num_required_args; 51 52 arg_pos = num_required_args; 53 if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == 0) { 54 LOG(ERROR)<< "Using GPU"; 55 uint device_id = 0; 56 if (argc > arg_pos + 1) { 57 device_id = atoi(argv[arg_pos + 1]); 58 CHECK_GE(device_id, 0); 59 } 60 LOG(ERROR) << "Using Device_id=" << device_id; 61 Caffe::SetDevice(device_id); 62 Caffe::set_mode(Caffe::GPU); 63 } else { 64 LOG(ERROR) << "Using CPU"; 65 Caffe::set_mode(Caffe::CPU); 66 } 67 68 arg_pos = 0; // the name of the executable 69 std::string pretrained_binary_proto(argv[++arg_pos]); 70 71 // Expected prototxt contains at least one data layer such as 72 // the layer data_layer_name and one feature blob such as the 73 // fc7 top blob to extract features. 74 /* 75 layers { 76 name: "data_layer_name" 77 type: DATA 78 data_param { 79 source: "/path/to/your/images/to/extract/feature/images_leveldb" 80 mean_file: "/path/to/your/image_mean.binaryproto" 81 batch_size: 128 82 crop_size: 227 83 mirror: false 84 } 85 top: "data_blob_name" 86 top: "label_blob_name" 87 } 88 layers { 89 name: "drop7" 90 type: DROPOUT 91 dropout_param { 92 dropout_ratio: 0.5 93 } 94 bottom: "fc7" 95 top: "fc7" 96 } 97 */ 98 std::string feature_extraction_proto(argv[++arg_pos]); 99 shared_ptr<Net<Dtype> > feature_extraction_net( 100 new Net<Dtype>(feature_extraction_proto, caffe::TEST)); 101 feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto); 102 103 std::string extract_feature_blob_names(argv[++arg_pos]); 104 std::vector<std::string> blob_names; 105 boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(",")); 106 107 std::string save_feature_dataset_names(argv[++arg_pos]); 108 std::vector<std::string> dataset_names; 109 boost::split(dataset_names, save_feature_dataset_names, 110 boost::is_any_of(",")); 111 CHECK_EQ(blob_names.size(), dataset_names.size()) << 112 " the number of blob names and dataset names must be equal"; 113 size_t num_features = blob_names.size(); 114 115 for (size_t i = 0; i < num_features; i++) { 116 CHECK(feature_extraction_net->has_blob(blob_names[i])) 117 << "Unknown feature blob name " << blob_names[i] 118 << " in the network " << feature_extraction_proto; 119 } 120 121 int num_mini_batches = atoi(argv[++arg_pos]); 122 123 std::vector<shared_ptr<db::DB> > feature_dbs; 124 std::vector<shared_ptr<db::Transaction> > txns; 125 const char* db_type = argv[++arg_pos]; 126 for (size_t i = 0; i < num_features; ++i) { 127 LOG(INFO)<< "Opening dataset " << dataset_names[i]; 128 shared_ptr<db::DB> db(db::GetDB(db_type)); 129 db->Open(dataset_names.at(i), db::NEW); 130 feature_dbs.push_back(db); 131 shared_ptr<db::Transaction> txn(db->NewTransaction()); 132 txns.push_back(txn); 133 } 134 135 LOG(ERROR)<< "Extacting Features"; 136 137 Datum datum; 138 const int kMaxKeyStrLength = 100; 139 char key_str[kMaxKeyStrLength]; 140 std::vector<Blob<float>*> input_vec; 141 std::vector<int> image_indices(num_features, 0); 142 for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) { 143 feature_extraction_net->Forward(input_vec); 144 for (int i = 0; i < num_features; ++i) { 145 const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net 146 ->blob_by_name(blob_names[i]); 147 int batch_size = feature_blob->num(); 148 int dim_features = feature_blob->count() / batch_size; 149 const Dtype* feature_blob_data; 150 for (int n = 0; n < batch_size; ++n) { 151 datum.set_height(feature_blob->height()); 152 datum.set_width(feature_blob->width()); 153 datum.set_channels(feature_blob->channels()); 154 datum.clear_data(); 155 datum.clear_float_data(); 156 feature_blob_data = feature_blob->cpu_data() + 157 feature_blob->offset(n); 158 for (int d = 0; d < dim_features; ++d) { 159 datum.add_float_data(feature_blob_data[d]); 160 } 161 int length = snprintf(key_str, kMaxKeyStrLength, "%010d", 162 image_indices[i]); 163 string out; 164 CHECK(datum.SerializeToString(&out)); 165 txns.at(i)->Put(std::string(key_str, length), out); 166 ++image_indices[i]; 167 if (image_indices[i] % 1000 == 0) { 168 txns.at(i)->Commit(); 169 txns.at(i).reset(feature_dbs.at(i)->NewTransaction()); 170 LOG(ERROR)<< "Extracted features of " << image_indices[i] << 171 " query images for feature blob " << blob_names[i]; 172 } 173 } // for (int n = 0; n < batch_size; ++n) 174 } // for (int i = 0; i < num_features; ++i) 175 } // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) 176 // write the last batch 177 for (int i = 0; i < num_features; ++i) { 178 if (image_indices[i] % 1000 != 0) { 179 txns.at(i)->Commit(); 180 } 181 LOG(ERROR)<< "Extracted features of " << image_indices[i] << 182 " query images for feature blob " << blob_names[i]; 183 feature_dbs.at(i)->Close(); 184 } 185 186 LOG(ERROR)<< "Successfully extracted the features!"; 187 return 0; 188 }