How Caffe::Snapshot Works

Storing a snapshot

Overview

Snapshots can be stored in two formats: BINARYPROTO and HDF5. BINARYPROTO is a binary file format; which one is used is selected through snapshot_format in the solver settings, and the default is BINARYPROTO. Whichever format is chosen, the flow is the same: it enters through Solver<Dtype>::Snapshot(), which calls into the Net, which in turn walks every layer of the network, which walks every blob in that layer, and finally a write function emits the file. The entry point in the source:

    void Solver<Dtype>::Snapshot() {
      CHECK(Caffe::root_solver());
      string model_filename;
      switch (param_.snapshot_format()) {
      case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
        model_filename = SnapshotToBinaryProto();
        break;
      case caffe::SolverParameter_SnapshotFormat_HDF5:
        model_filename = SnapshotToHDF5();
        break;
      default:
        LOG(FATAL) << "Unsupported snapshot format.";
      }
      // the solver state (.solverstate) is written as well
      SnapshotSolverState(model_filename);
    }
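The snapshot_format value tested in the switch above, together with the other snapshot-related fields, lives in the SolverParameter that is parsed from solver.prototxt. A minimal sketch of setting those fields programmatically (the values and the helper function name here are purely illustrative):

    #include "caffe/proto/caffe.pb.h"

    // Illustrative only: normally these fields come from solver.prototxt.
    void ConfigureSnapshotting(caffe::SolverParameter* param) {
      param->set_snapshot(5000);                        // snapshot every 5000 iterations
      param->set_snapshot_prefix("snapshots/lenet");    // file name prefix
      param->set_snapshot_format(
          caffe::SolverParameter_SnapshotFormat_HDF5);  // default is BINARYPROTO
      param->set_snapshot_diff(false);                  // whether diffs are also written
    }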

     

The BINARYPROTO format

For the BINARYPROTO storage format, the following code runs:

    string Solver<Dtype>::SnapshotToBinaryProto() {
      string model_filename = SnapshotFilename(".caffemodel");
      LOG(INFO) << "Snapshotting to binary proto file " << model_filename;
      NetParameter net_param;
      net_->ToProto(&net_param, param_.snapshot_diff());
      WriteProtoToBinaryFile(net_param, model_filename);
      return model_filename;
    }
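The SnapshotFilename(".caffemodel") call on the second line builds the output path from the snapshot_prefix and the current iteration. A hypothetical stand-alone helper that mirrors the resulting "prefix_iter_N.extension" naming scheme (the real member function reads the prefix and iteration from the solver itself):

    #include <sstream>
    #include <string>

    // Hypothetical helper: MakeSnapshotName("snapshots/lenet", 10000, ".caffemodel")
    // yields "snapshots/lenet_iter_10000.caffemodel".
    std::string MakeSnapshotName(const std::string& prefix, int iter,
                                 const std::string& extension) {
      std::ostringstream oss;
      oss << prefix << "_iter_" << iter << extension;
      return oss.str();
    }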

     

SnapshotFilename(".caffemodel") picks up the snapshot_prefix setting from solver.prototxt and uses it as the prefix of the snapshot's file name. Then net_->ToProto() is called; its code is:

    void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) const {
      param->Clear();
      param->set_name(name_);
      for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
        param->add_input(blob_names_[net_input_blob_indices_[i]]);
      }
      for (int i = 0; i < layers_.size(); ++i) {
        LayerParameter* layer_param = param->add_layer();
        layers_[i]->ToProto(layer_param, write_diff);
      }
    }

After recording the network-level parameters such as its name, it invokes layers_[i]->ToProto(), i.e. each layer's own ToProto method:

    void Layer<Dtype>::ToProto(LayerParameter* param, bool write_diff) {
      param->Clear();
      param->CopyFrom(layer_param_);
      param->clear_blobs();
      for (int i = 0; i < blobs_.size(); ++i) {
        blobs_[i]->ToProto(param->add_blobs(), write_diff);
      }
    }

Each layer then calls ToProto() on every blob it holds (shown here for the Blob<double> specialization):

    void Blob<double>::ToProto(BlobProto* proto, bool write_diff) const {
      proto->clear_shape();
      for (int i = 0; i < shape_.size(); ++i) {
        proto->mutable_shape()->add_dim(shape_[i]);
      }
      proto->clear_double_data();
      proto->clear_double_diff();
      const double* data_vec = cpu_data();
      for (int i = 0; i < count_; ++i) {
        proto->add_double_data(data_vec[i]);
      }
      if (write_diff) {
        const double* diff_vec = cpu_diff();
        for (int i = 0; i < count_; ++i) {
          proto->add_double_diff(diff_vec[i]);
        }
      }
    }

Inside each blob, add_double_data() appends the blob's data to the BlobProto. If write_diff is set (it comes from snapshot_diff in the solver settings, as seen in SnapshotToBinaryProto above), add_double_diff() appends the blob's diff as well.
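On the reading side, those repeated fields are reachable through the generated BlobProto accessors. A small hedged example (the helper name is made up) that prints a blob's shape and element counts; float nets fill data(), double nets fill double_data():

    #include <iostream>
    #include "caffe/proto/caffe.pb.h"

    // Hypothetical helper: inspect a BlobProto filled by Blob::ToProto().
    void DumpBlobProto(const caffe::BlobProto& proto) {
      std::cout << "shape:";
      for (int i = 0; i < proto.shape().dim_size(); ++i) {
        std::cout << " " << proto.shape().dim(i);
      }
      std::cout << "  float elements: " << proto.data_size()
                << "  double elements: " << proto.double_data_size() << std::endl;
    }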

Once ToProto() has run for every blob, WriteProtoToBinaryFile() writes the assembled NetParameter out to disk:

    void WriteProtoToBinaryFile(const Message& proto, const char* filename) {
      fstream output(filename, ios::out | ios::trunc | ios::binary);
      CHECK(proto.SerializeToOstream(&output));
    }

This function opens an fstream and serializes the proto message into it with SerializeToOstream().
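To check what ended up in the file, the same proto machinery works in reverse. A hedged sketch (the file name is only an example) that loads a saved .caffemodel and lists each layer's blob count:

    #include <iostream>
    #include "caffe/proto/caffe.pb.h"
    #include "caffe/util/io.hpp"

    // Example only: read a snapshot written by SnapshotToBinaryProto() and walk
    // the NetParameter -> LayerParameter -> BlobProto hierarchy.
    int main() {
      caffe::NetParameter net_param;
      if (!caffe::ReadProtoFromBinaryFile("lenet_iter_10000.caffemodel", &net_param))
        return 1;
      for (int i = 0; i < net_param.layer_size(); ++i) {
        const caffe::LayerParameter& layer = net_param.layer(i);
        std::cout << layer.name() << ": " << layer.blobs_size() << " blob(s)\n";
      }
      return 0;
    }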

The HDF5 format

The HDF5 path is very similar to the BINARYPROTO one. It starts with SnapshotToHDF5():

    string Solver<Dtype>::SnapshotToHDF5() {
      string model_filename = SnapshotFilename(".caffemodel.h5");
      LOG(INFO) << "Snapshotting to HDF5 file " << model_filename;
      net_->ToHDF5(model_filename, param_.snapshot_diff());
      return model_filename;
    }

Again, SnapshotFilename(".caffemodel.h5") picks up snapshot_prefix from solver.prototxt to form the file name, and then net_->ToHDF5() is called:

    void Net<Dtype>::ToHDF5(const string& filename, bool write_diff) const {
      hid_t file_hid = H5Fcreate(filename.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT,
          H5P_DEFAULT);
      hid_t data_hid = H5Gcreate2(file_hid, "data", H5P_DEFAULT, H5P_DEFAULT,
          H5P_DEFAULT);
      hid_t diff_hid = -1;
      if (write_diff) {
        diff_hid = H5Gcreate2(file_hid, "diff", H5P_DEFAULT, H5P_DEFAULT,
            H5P_DEFAULT);
      }
      for (int layer_id = 0; layer_id < layers_.size(); ++layer_id) {
        const LayerParameter& layer_param = layers_[layer_id]->layer_param();
        string layer_name = layer_param.name();
        // one group per layer under "data" (and under "diff" if requested)
        hid_t layer_data_hid = H5Gcreate2(data_hid, layer_name.c_str(),
            H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT);
        hid_t layer_diff_hid = -1;
        if (write_diff) {
          layer_diff_hid = H5Gcreate2(diff_hid, layer_name.c_str(),
              H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT);
        }
        int num_params = layers_[layer_id]->blobs().size();
        for (int param_id = 0; param_id < num_params; ++param_id) {
          ostringstream dataset_name;
          dataset_name << param_id;
          const int net_param_id = param_id_vecs_[layer_id][param_id];
          if (param_owners_[net_param_id] == -1) {
            // only the owner of a (possibly shared) parameter writes its data
            hdf5_save_nd_dataset<Dtype>(layer_data_hid, dataset_name.str(),
                *params_[net_param_id]);
          }
          if (write_diff) {
            hdf5_save_nd_dataset<Dtype>(layer_diff_hid, dataset_name.str(),
                *params_[net_param_id], true);
          }
        }
        // ... (the per-layer groups are closed here)
      }
      // ... (the "data" and "diff" groups are closed here)
      H5Fclose(file_hid);
    }

This function first calls H5Fcreate() to create the output file. It then loops over the layers, creating one HDF5 group per layer under the "data" group via H5Gcreate2() (and another under the "diff" group when write_diff is set). For every parameter blob in the layer, hdf5_save_nd_dataset() writes the blob's data (and, when write_diff is set, its diff) into the HDF5 file. Finally H5Fclose(file_hid) closes the file.
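The resulting .caffemodel.h5 therefore holds a "data" group (plus an optional "diff" group) with one subgroup per layer and one dataset per parameter blob. A rough sketch, using the same helpers that appear in the code above, of counting those per-layer groups (the file name is only an example):

    #include <iostream>
    #include "hdf5.h"
    #include "caffe/util/hdf5.hpp"

    // Example only: open a snapshot written by Net::ToHDF5() and count the
    // per-layer groups under "data".
    int main() {
      hid_t file_hid = H5Fopen("lenet_iter_10000.caffemodel.h5",
                               H5F_ACC_RDONLY, H5P_DEFAULT);
      if (file_hid < 0) return 1;
      hid_t data_hid = H5Gopen2(file_hid, "data", H5P_DEFAULT);
      std::cout << "layers with saved parameters: "
                << caffe::hdf5_get_num_links(data_hid) << std::endl;
      H5Gclose(data_hid);
      H5Fclose(file_hid);
      return 0;
    }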

     

Restoring from a snapshot

Overview

To continue training an already trained network, Restore() is called to rebuild the solver and network state from a snapshot file, which saves training time. The entry point is Solver<Dtype>::Restore(const char* state_file):

    void Solver<Dtype>::Restore(const char* state_file) {
      CHECK(Caffe::root_solver());
      string state_filename(state_file);
      if (state_filename.size() >= 3 &&
          state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {
        RestoreSolverStateFromHDF5(state_filename);
      } else {
        RestoreSolverStateFromBinaryProto(state_filename);
      }
    }

This function decides whether the snapshot file is BINARYPROTO or HDF5 from its name: if it ends in ".h5" it calls RestoreSolverStateFromHDF5(), otherwise it calls RestoreSolverStateFromBinaryProto().
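In practice Restore() is reached by handing the solver a .solverstate file when training is (re)started; the caffe command-line tool does this through its --snapshot flag. A hedged sketch of the same thing from C++ (the file names are placeholders):

    #include "caffe/caffe.hpp"

    // Example only: rebuild the solver from solver.prototxt and resume training
    // from a previously saved solver state.
    int main() {
      caffe::SolverParameter solver_param;
      caffe::ReadProtoFromTextFileOrDie("solver.prototxt", &solver_param);
      boost::shared_ptr<caffe::Solver<float> > solver(
          caffe::SolverRegistry<float>::CreateSolver(solver_param));
      solver->Restore("lenet_iter_5000.solverstate");  // dispatches on the ".h5" suffix
      solver->Solve();                                 // continue training
      return 0;
    }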

The BINARYPROTO format

For the BINARYPROTO format, the following code runs:

     

    void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
        const string& state_file) {
      SolverState state;
      ReadProtoFromBinaryFile(state_file, &state);
      this->iter_ = state.iter();
      if (state.has_learned_net()) {
        NetParameter net_param;
        ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
        this->net_->CopyTrainedLayersFrom(net_param);
      }
      this->current_step_ = state.current_step();
      CHECK_EQ(state.history_size(), history_.size())
          << "Incorrect length of history blobs.";
      for (int i = 0; i < history_.size(); ++i) {
        history_[i]->FromProto(state.history(i));
      }
    }

     

This function leans heavily on the Google protobuf API. It first reads the BINARYPROTO solver state into a SolverState message with ReadProtoFromBinaryFile(). If the state records a previously learned net (learned_net), ReadNetParamsFromBinaryFileOrDie() loads that network's NetParameter and CopyTrainedLayersFrom(net_param) restores every layer and every blob inside each layer; the actual copying is done by the blobs' FromProto() calls inside CopyTrainedLayersFrom(). It then restores the training position (how many iterations had been run) from state.iter() and state.current_step(), and finally loops over the solver's history blobs, copying each one back with FromProto().
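The fields read above (iter, learned_net, current_step, history) all come from the SolverState message, so a .solverstate file can also be inspected directly. A short sketch (the file name is only an example):

    #include <iostream>
    #include "caffe/proto/caffe.pb.h"
    #include "caffe/util/io.hpp"

    // Example only: peek into a BINARYPROTO solver state without building a net.
    int main() {
      caffe::SolverState state;
      if (!caffe::ReadProtoFromBinaryFile("lenet_iter_5000.solverstate", &state))
        return 1;
      std::cout << "iter: " << state.iter()
                << "  current_step: " << state.current_step()
                << "  learned_net: " << state.learned_net()
                << "  history blobs: " << state.history_size() << std::endl;
      return 0;
    }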

The HDF5 format

If the solver state file ends in ".h5", the following code runs:

    void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
      hid_t file_hid = H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);
      CHECK_GE(file_hid, 0) << "Couldn't open solver state file " << state_file;
      this->iter_ = hdf5_load_int(file_hid, "iter");
      if (H5LTfind_dataset(file_hid, "learned_net")) {
        string learned_net = hdf5_load_string(file_hid, "learned_net");
        this->net_->CopyTrainedLayersFrom(learned_net);
      }
      this->current_step_ = hdf5_load_int(file_hid, "current_step");
      hid_t history_hid = H5Gopen2(file_hid, "history", H5P_DEFAULT);
      CHECK_GE(history_hid, 0) << "Error reading history from " << state_file;
      int state_history_size = hdf5_get_num_links(history_hid);
      CHECK_EQ(state_history_size, history_.size())
          << "Incorrect length of history blobs.";
      for (int i = 0; i < history_.size(); ++i) {
        ostringstream oss;
        oss << i;
        hdf5_load_nd_dataset<Dtype>(history_hid, oss.str().c_str(), 0,
                                    kMaxBlobAxes, history_[i].get());
      }
      H5Gclose(history_hid);
      H5Fclose(file_hid);
    }

This function opens the HDF5 snapshot and obtains its file_hid handle, then checks whether a previously learned net is recorded (the learned_net dataset); if so, CopyTrainedLayersFrom() restores every layer and every blob inside each layer. It then reads back the training position (iter and current_step), loops over the history group restoring each history blob with hdf5_load_nd_dataset(), and finally closes the handles with H5Gclose() and H5Fclose().
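The scalar entries of the HDF5 solver state can likewise be read back outside the solver with the helpers shown above. A rough sketch (the file name is only an example):

    #include <iostream>
    #include "hdf5.h"
    #include "caffe/util/hdf5.hpp"

    // Example only: read the scalar datasets of an HDF5 solver state.
    int main() {
      hid_t file_hid = H5Fopen("lenet_iter_5000.solverstate.h5",
                               H5F_ACC_RDONLY, H5P_DEFAULT);
      if (file_hid < 0) return 1;
      std::cout << "iter: " << caffe::hdf5_load_int(file_hid, "iter")
                << "  current_step: "
                << caffe::hdf5_load_int(file_hid, "current_step") << std::endl;
      H5Fclose(file_hid);
      return 0;
    }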

     

     
