zoukankan      html  css  js  c++  java
  • Caffe源码-Solver类

    Solver类简介

    Net类中实现了网络的前向/反向计算和参数更新,而Solver类中则是对此进行进一步封装,包含可用于逐次训练网络的Step()函数,和用于求解网络的优化解的Solve()函数,同时还实现了一些存储、读取网络模型快照的接口函数。

    solver.cpp源码

    /**
     * Stores the client-supplied callback in action_request_function_.
     * GetRequestedAction() invokes this callback to obtain the action
     * (SolverAction::Enum) the client is currently requesting.
     */
    template <typename Dtype>
    void Solver<Dtype>::SetActionFunction(ActionCallback func) {
      this->action_request_function_ = func;
    }
    
    /**
     * Returns the action currently requested by the client.
     *
     * @return SolverAction::NONE when no callback has been registered,
     *         otherwise whatever the registered callback returns.
     */
    template <typename Dtype>
    SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
      // No external request function has been set: nothing to ask.
      if (!action_request_function_) {
        return SolverAction::NONE;
      }
      return action_request_function_();
    }
    
    // Constructs a solver from an in-memory SolverParameter message; all of the
    // real setup work is delegated to Init().
    template <typename Dtype>
    Solver<Dtype>::Solver(const SolverParameter& param)
        : net_(),
          callbacks_(),
          requested_early_exit_(false) {
      this->Init(param);
    }
    
    // Constructs a solver from a solver-parameter file: the file is parsed into
    // a SolverParameter message, which is then handed to Init().
    template <typename Dtype>
    Solver<Dtype>::Solver(const string& param_file)
        : net_(),
          callbacks_(),
          requested_early_exit_(false) {
      SolverParameter parsed_param;
      ReadSolverParamsFromTextFileOrDie(param_file, &parsed_param);
      this->Init(parsed_param);
    }
    
    /**
     * Initializes the solver from @p param.
     *
     * Steps, in order:
     *  1. log the parameters (root solver only) and store a copy in param_;
     *  2. validate average_loss (the length of the loss-smoothing window);
     *  3. verify that snapshot files can be created;
     *  4. seed the RNG when random_seed >= 0 (offset by the solver rank);
     *  5. build the train net and all test nets;
     *  6. reset the iteration bookkeeping.
     *
     * @param param Solver configuration; copied into param_.
     */
    template <typename Dtype>
    void Solver<Dtype>::Init(const SolverParameter& param) {
      LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
        << std::endl << param.DebugString();
      param_ = param;
      // The check rejects 0 as well as negatives, so the message must say
      // "at least 1" rather than "non-negative".
      CHECK_GE(param_.average_loss(), 1) << "average_loss should be at least 1.";
      CheckSnapshotWritePermissions();
      if (param_.random_seed() >= 0) {
        // Each solver gets a distinct seed derived from the configured one.
        Caffe::set_random_seed(param_.random_seed() + Caffe::solver_rank());
      }
      // Scaffolding code
      InitTrainNet();
      InitTestNets();  // There is one train net but possibly many test nets.
      if (Caffe::root_solver()) {
        LOG(INFO) << "Solver scaffolding done.";
      }
      iter_ = 0;
      current_step_ = 0;
      // Step() computes its iterations-per-second figure from iterations_last_
      // the first time it displays; the constructors do not initialize it, so
      // give it a defined starting value here.
      iterations_last_ = 0;
    }
    
    // Copies trained weights into `net` from every model file named in
    // `model_list`. The list is comma-separated; whitespace around each name is
    // trimmed, and files are applied in the order they are listed.
    template <typename Dtype>
    void LoadNetWeights(shared_ptr<Net<Dtype> > net, const std::string& model_list) {
      std::vector<std::string> weight_files;
      boost::split(weight_files, model_list, boost::is_any_of(","));
      typedef std::vector<std::string>::iterator FileIter;
      for (FileIter file = weight_files.begin(); file != weight_files.end();
           ++file) {
        boost::trim(*file);
        LOG(INFO) << "Finetuning from " << *file;
        net->CopyTrainedLayersFrom(*file);
      }
    }
    
    /**
     * Builds the training net and stores it in net_.
     *
     * The net definition must come from exactly one of the solver fields
     * net, net_param, train_net or train_net_param. After the definition is
     * loaded, its NetState is finalized and any pretrained models listed in the
     * solver's `weights` entries are loaded into the new net.
     */
    template <typename Dtype>
    void Solver<Dtype>::InitTrainNet() {
      const string field_names = "net, net_param, train_net, train_net_param";
      // Each has_*() contributes 0 or 1, so the sum is the number of train-net
      // sources that were specified; exactly one is accepted.
      const int num_train_nets = param_.has_net() + param_.has_net_param() +
          param_.has_train_net() + param_.has_train_net_param();
      CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
          << "using one of these fields: " << field_names;
      CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
          << "one of these fields specifying a train_net: " << field_names;

      NetParameter net_param;
      if (param_.has_train_net_param()) {
        // Inline definition dedicated to training.
        LOG_IF(INFO, Caffe::root_solver())
            << "Creating training net specified in train_net_param.";
        net_param.CopyFrom(param_.train_net_param());
      } else if (param_.has_train_net()) {
        // Definition file dedicated to training.
        LOG_IF(INFO, Caffe::root_solver())
            << "Creating training net from train_net file: "
            << param_.train_net();
        ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);
      }
      if (param_.has_net_param()) {
        // Generic inline definition.
        LOG_IF(INFO, Caffe::root_solver())
            << "Creating training net specified in net_param.";
        net_param.CopyFrom(param_.net_param());
      }
      if (param_.has_net()) {
        // Generic definition file.
        LOG_IF(INFO, Caffe::root_solver())
            << "Creating training net from net file: " << param_.net();
        ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
      }

      // Resolve the NetState by increasing precedence: the solver default
      // (phase TRAIN), then the state carried by the net definition, then the
      // solver's train_state.
      NetState merged_state;
      merged_state.set_phase(TRAIN);
      merged_state.MergeFrom(net_param.state());
      merged_state.MergeFrom(param_.train_state());
      net_param.mutable_state()->CopyFrom(merged_state);
      net_.reset(new Net<Dtype>(net_param));

      // Every `weights` entry may itself name one or more model files.
      const int num_weight_entries = param_.weights_size();
      for (int entry = 0; entry < num_weight_entries; ++entry) {
        LoadNetWeights(net_, param_.weights(entry));
      }
    }
    
    /**
     * Builds all test nets and stores them in test_nets_.
     *
     * Test nets are collected in this order: every test_net_param entry, then
     * every test_net file, then as many copies of the generic net
     * (net_param or net) as there are test_iter entries left over. Each net gets
     * its NetState finalized, inherits the solver's debug_info flag, and loads
     * the pretrained models listed in the solver's `weights` entries.
     */
    template <typename Dtype>
    void Solver<Dtype>::InitTestNets() {
      const bool has_net_param = param_.has_net_param();
      const bool has_net_file = param_.has_net();
      // Number of generic net sources (net_param / net) that are set.
      const int num_generic_nets = has_net_param + has_net_file;
      CHECK_LE(num_generic_nets, 1)
          << "Both net_param and net_file may not be specified.";
      // Number of test nets given inline (test_net_param) and as files (test_net).
      const int num_test_net_params = param_.test_net_param_size();
      const int num_test_net_files = param_.test_net_size();
      const int num_test_nets = num_test_net_params + num_test_net_files;
      if (num_generic_nets) {
          // With a generic net present, extra test_iter entries beyond the
          // explicit test nets are allowed: each extra entry creates one more
          // test instance of the generic net.
          CHECK_GE(param_.test_iter_size(), num_test_nets)
              << "test_iter must be specified for each test network.";
      } else {
          // Without a generic net, every test net is explicit, so the number of
          // test_iter entries must match exactly.
          CHECK_EQ(param_.test_iter_size(), num_test_nets)
              << "test_iter must be specified for each test network.";
      }
      // If we have a generic net (specified by net or net_param, rather than
      // test_net or test_net_param), we may have an unlimited number of actual
      // test networks -- the actual number is given by the number of remaining
      // test_iters after any test nets specified by test_net_param and/or test_net
      // are evaluated.
      const int num_generic_net_instances = param_.test_iter_size() - num_test_nets;
      // Total number of test nets; equal to param_.test_iter_size().
      const int num_test_net_instances = num_test_nets + num_generic_net_instances;
      if (param_.test_state_size()) {
        // test_state is optional, but when given there must be one per test net.
        CHECK_EQ(param_.test_state_size(), num_test_net_instances)
            << "test_state must be unspecified or specified once per test net.";
      }
      if (num_test_net_instances) {
        // Having test nets only makes sense with a positive test interval.
        CHECK_GT(param_.test_interval(), 0);
      }
      // test_net_id is the next free slot in sources / net_params; it keeps
      // advancing across the loops below.
      int test_net_id = 0;
      vector<string> sources(num_test_net_instances);
      vector<NetParameter> net_params(num_test_net_instances);
      // Fill order: (1) test_net_param, (2) test_net, (3) net_param / net.
      for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) {
          // Remember where this definition came from, for the log line below.
          sources[test_net_id] = "test_net_param";
          net_params[test_net_id].CopyFrom(param_.test_net_param(i));
      }
      for (int i = 0; i < num_test_net_files; ++i, ++test_net_id) {
          sources[test_net_id] = "test_net file: " + param_.test_net(i);
          ReadNetParamsFromTextFileOrDie(param_.test_net(i),
              &net_params[test_net_id]);
      }
      // Slots not yet filled are instances of the generic net.
      const int remaining_test_nets = param_.test_iter_size() - test_net_id;
      if (has_net_param) {
        for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
          sources[test_net_id] = "net_param";
          net_params[test_net_id].CopyFrom(param_.net_param());
        }
      }
      if (has_net_file) {
        for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
          sources[test_net_id] = "net file: " + param_.net();
          ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);
        }
      }
      test_nets_.resize(num_test_net_instances);
      for (int i = 0; i < num_test_net_instances; ++i) {
        // Set the correct NetState.  We start with the solver defaults (lowest
        // precedence); then, merge in any NetState specified by the net_param
        // itself; finally, merge in any NetState specified by the test_state
        // (highest precedence).
        NetState net_state;
        net_state.set_phase(TEST);
        net_state.MergeFrom(net_params[i].state());
        if (param_.test_state_size()) {
          net_state.MergeFrom(param_.test_state(i));
        }
        // Write the resolved state back before constructing the net.
        net_params[i].mutable_state()->CopyFrom(net_state);
        LOG(INFO) << "Creating test net (#" << i << ") specified by " << sources[i];
        test_nets_[i].reset(new Net<Dtype>(net_params[i]));
        test_nets_[i]->set_debug_info(param_.debug_info());
        // Every test net attempts to load all of the listed pretrained models.
        for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {
          LoadNetWeights(test_nets_[i], param_.weights(w_idx));
        }
      }
    }
    
    /**
     * Runs the training loop until iter_ reaches its starting value plus
     * @p iters, or until an early exit is requested.
     *
     * Each pass through the loop: clears parameter diffs, optionally runs the
     * test nets, notifies callbacks, accumulates loss over iter_size
     * forward/backward passes, updates the smoothed loss, optionally logs
     * progress, calls ApplyUpdate(), and handles snapshot / stop requests.
     *
     * NOTE(review): iter_ is not incremented anywhere in this function;
     * presumably ApplyUpdate() advances it -- confirm in the concrete solver.
     *
     * @param iters Number of iterations to run from the current iter_.
     */
    template <typename Dtype>
    void Solver<Dtype>::Step(int iters) {
      const int start_iter = iter_;         // iteration count on entry
      const int stop_iter = iter_ + iters;  // iteration count at which to stop
      int average_loss = this->param_.average_loss();   // smoothing window length
      losses_.clear();          // drop the loss history from any previous call
      smoothed_loss_ = 0;
      iteration_timer_.Start();
    
      while (iter_ < stop_iter) {
        // zero-init the params
        net_->ClearParamDiffs();
        // Test when a test interval is configured and iter_ is a multiple of it.
        // test_initialization() only decides whether the pass at iter_ == 0 runs.
        if (param_.test_interval() && iter_ % param_.test_interval() == 0
            && (iter_ > 0 || param_.test_initialization())) {
          if (Caffe::root_solver()) {   // only the root solver runs the test nets
            TestAll();
          }
          if (requested_early_exit_) {
            // Break out of the while loop because stop was requested while testing.
            break;
          }
        }
    
        // Notify callbacks that an iteration is starting (presumably used to
        // synchronize solvers in multi-GPU training -- confirm).
        for (int i = 0; i < callbacks_.size(); ++i) {
          callbacks_[i]->on_start();
        }
        // Display when a display interval is configured and iter_ is a multiple.
        const bool display = param_.display() && iter_ % param_.display() == 0;
        net_->set_debug_info(display && param_.debug_info());
        // accumulate the loss and gradient
        Dtype loss = 0;
        for (int i = 0; i < param_.iter_size(); ++i) {
          loss += net_->ForwardBackward();
        }
        loss /= param_.iter_size();   // mean loss over the iter_size passes
        // average the loss across iterations for smoothed reporting
        UpdateSmoothedLoss(loss, start_iter, average_loss);
        if (display) {
          // Presumably the time elapsed since the timer was last started.
          float lapse = iteration_timer_.Seconds();
          // Iterations since the previous display divided by the elapsed time;
          // a zero lapse is replaced by 1 to avoid dividing by zero.
          float per_s = (iter_ - iterations_last_) / (lapse ? lapse : 1);
          LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
              << " (" << per_s << " iter/s, " << lapse << "s/"
              << param_.display() << " iters), loss = " << smoothed_loss_;
          iteration_timer_.Start();     // restart timing for the next interval
          iterations_last_ = iter_;     // remember where this interval began
          const vector<Blob<Dtype>*>& result = net_->output_blobs();
          // score_index numbers every printed value across all output blobs.
          int score_index = 0;
          for (int j = 0; j < result.size(); ++j) {
            const Dtype* result_vec = result[j]->cpu_data();
            // Name and loss weight of the j-th output blob.
            const string& output_name = net_->blob_names()[net_->output_blob_indices()[j]];
            const Dtype loss_weight = net_->blob_loss_weights()[net_->output_blob_indices()[j]];
            for (int k = 0; k < result[j]->count(); ++k) {
              ostringstream loss_msg_stream;
              if (loss_weight) {
                // Only blobs with a non-zero loss weight get the weighted suffix.
                loss_msg_stream << " (* " << loss_weight
                                << " = " << loss_weight * result_vec[k] << " loss)";
              }
              LOG_IF(INFO, Caffe::root_solver()) << "    Train net output #"
                  << score_index++ << ": " << output_name << " = "
                  << result_vec[k] << loss_msg_stream.str();
            }
          }
        }
        // Notify callbacks that gradients have been computed (presumably for
        // multi-GPU gradient synchronization -- confirm).
        for (int i = 0; i < callbacks_.size(); ++i) {
          callbacks_[i]->on_gradients_ready();
        }
        // Pure virtual in Solver; the concrete solver applies the update.
        ApplyUpdate();
    
        SolverAction::Enum request = GetRequestedAction();
    
        // Save a snapshot if needed.
        // Either the periodic snapshot is due (root solver only), or the client
        // explicitly requested one.
        if ((param_.snapshot()
             && iter_ % param_.snapshot() == 0
             && Caffe::root_solver()) ||
             (request == SolverAction::SNAPSHOT)) {
          Snapshot();
        }
        if (SolverAction::STOP == request) {
          requested_early_exit_ = true;
          // Break out of training loop.
          break;
        }
      }
    }
    
    /**
     * Trains the net up to max_iter, optionally resuming from a saved state.
     *
     * @param resume_file Path of a solver-state file to restore before
     *                    training, or NULL to start from the current state.
     */
    template <typename Dtype>
    void Solver<Dtype>::Solve(const char* resume_file) {
      CHECK(Caffe::root_solver());
      LOG(INFO) << "Solving " << net_->name();
      LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();

      // A stop request from a previous run must not carry over into this one.
      requested_early_exit_ = false;

      if (resume_file) {
        LOG(INFO) << "Restoring previous solver status from " << resume_file;
        Restore(resume_file);
      }

      // Train for however many iterations are left before max_iter.
      const int first_iter = iter_;
      const int iters_to_run = param_.max_iter() - iter_;
      Step(iters_to_run);

      // Step() already snapshots whenever iter_ lands on a multiple of the
      // snapshot interval, so only write the final snapshot when that did not
      // happen (and snapshot_after_train is enabled).
      const bool on_snapshot_boundary =
          param_.snapshot() && iter_ % param_.snapshot() == 0;
      if (param_.snapshot_after_train() && !on_snapshot_boundary) {
        Snapshot();
      }

      if (requested_early_exit_) {
        LOG(INFO) << "Optimization stopped early.";
        return;
      }

      // Step() ends with a parameter update, so the loss of the final weights
      // has not been computed yet. When a display is due, run one extra
      // forward-only pass to report it.
      const bool display_due =
          param_.display() && iter_ % param_.display() == 0;
      if (display_due) {
        const int average_loss = this->param_.average_loss();
        Dtype loss;
        net_->Forward(&loss);

        UpdateSmoothedLoss(loss, first_iter, average_loss);
        LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
      }

      // Likewise run a final test pass when one is due at this iteration.
      const bool test_due =
          param_.test_interval() && iter_ % param_.test_interval() == 0;
      if (test_due) {
        TestAll();
      }
      LOG(INFO) << "Optimization Done.";
    }
    
    // Runs every test net in index order, stopping as soon as an early exit
    // has been requested.
    template <typename Dtype>
    void Solver<Dtype>::TestAll() {
      int test_net_id = 0;
      while (test_net_id < test_nets_.size() && !requested_early_exit_) {
        Test(test_net_id);
        ++test_net_id;
      }
    }
    
    /**
     * Runs one test net for its configured number of test iterations and logs
     * the mean of every output value (and the mean loss when
     * test_compute_loss is set).
     *
     * Pending client actions are serviced before each iteration: SNAPSHOT
     * writes a snapshot, STOP sets requested_early_exit_ and aborts the test.
     *
     * @param test_net_id Index into test_nets_ / test_iter of the net to run.
     */
    template <typename Dtype>
    void Solver<Dtype>::Test(const int test_net_id) {
      CHECK(Caffe::root_solver());    // testing is only done by the root solver
      LOG(INFO) << "Iteration " << iter_
                << ", Testing net (#" << test_net_id << ")";
      // Make the test net use the train net's learned layers (presumably by
      // sharing the parameter storage rather than copying -- confirm).
      CHECK_NOTNULL(test_nets_[test_net_id].get())->ShareTrainedLayersWith(net_.get());
      // test_score[n] accumulates the n-th output value over all iterations;
      // test_score_output_id[n] is the index of the output blob it came from.
      vector<Dtype> test_score;
      vector<int> test_score_output_id;
      const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
      Dtype loss = 0;
      // test_iter(test_net_id) is the number of forward passes for this net.
      for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
        SolverAction::Enum request = GetRequestedAction();
        // Check to see if stoppage of testing/training has been requested.
        // Keep polling until the callback reports NONE.
        while (request != SolverAction::NONE) {
            if (SolverAction::SNAPSHOT == request) {
              Snapshot();   // write a snapshot, then carry on
            } else if (SolverAction::STOP == request) {
              requested_early_exit_ = true;
            }
            request = GetRequestedAction();
        }
        if (requested_early_exit_) {
          // break out of test loop.
          break;
        }
    
        Dtype iter_loss;
        // One forward pass; the loss is written to iter_loss and the returned
        // vector holds the net's output blobs.
        const vector<Blob<Dtype>*>& result = test_net->Forward(&iter_loss);
        if (param_.test_compute_loss()) {
          loss += iter_loss;
        }
        if (i == 0) {
          // First iteration: size both vectors by appending every output value.
          for (int j = 0; j < result.size(); ++j) {
            const Dtype* result_vec = result[j]->cpu_data();
            for (int k = 0; k < result[j]->count(); ++k) {
              test_score.push_back(result_vec[k]);
              test_score_output_id.push_back(j);
            }
          }
        } else {
          // Later iterations: add into the slots created on the first one,
          // walking the outputs in the same order.
          int idx = 0;
          for (int j = 0; j < result.size(); ++j) {
            const Dtype* result_vec = result[j]->cpu_data();
            for (int k = 0; k < result[j]->count(); ++k) {
              test_score[idx++] += result_vec[k];
            }
          }
        }
      }
      if (requested_early_exit_) {
        LOG(INFO)     << "Test interrupted.";
        return;
      }
      if (param_.test_compute_loss()) {
        // Mean loss over all test iterations of this net.
        loss /= param_.test_iter(test_net_id);
        LOG(INFO) << "Test loss: " << loss;
      }
      for (int i = 0; i < test_score.size(); ++i) {
        // Map the score back to its output blob, then to that blob's index in
        // the net's blob list, to look up its name and loss weight.
        const int output_blob_index = test_net->output_blob_indices()[test_score_output_id[i]];
        const string& output_name = test_net->blob_names()[output_blob_index];
        const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];
        ostringstream loss_msg_stream;
        // Mean of this output value over the test iterations.
        const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);
        if (loss_weight) {
          // Only outputs with a non-zero loss weight get the weighted suffix.
          loss_msg_stream << " (* " << loss_weight << " = " << loss_weight * mean_score << " loss)";
        }
        LOG(INFO) << "    Test net output #" << i << ": " << output_name << " = "
                  << mean_score << loss_msg_stream.str();
      }
    }
    
    /**
     * Writes a snapshot: first the net's learned parameters, in the format
     * selected by snapshot_format, then the solver state via
     * SnapshotSolverState(). Must be called on the root solver.
     */
    template <typename Dtype>
    void Solver<Dtype>::Snapshot() {
      CHECK(Caffe::root_solver());
      string model_filename;
      if (param_.snapshot_format() ==
          caffe::SolverParameter_SnapshotFormat_BINARYPROTO) {
        // Binary protobuf model file.
        model_filename = SnapshotToBinaryProto();
      } else if (param_.snapshot_format() ==
                 caffe::SolverParameter_SnapshotFormat_HDF5) {
        // HDF5 model file.
        model_filename = SnapshotToHDF5();
      } else {
        LOG(FATAL) << "Unsupported snapshot format.";
      }

      // The solver state is saved alongside the model file just written.
      SnapshotSolverState(model_filename);
    }
    
    /**
     * Verifies up front that snapshot files can be created.
     *
     * Only acts on the root solver and only when a snapshot interval is
     * configured. It opens a throwaway probe file at the snapshot location and
     * removes it again; failure to open is fatal.
     */
    template <typename Dtype>
    void Solver<Dtype>::CheckSnapshotWritePermissions() {
      if (!Caffe::root_solver() || !param_.snapshot()) {
        return;  // nothing to verify
      }
      CHECK(param_.has_snapshot_prefix())
          << "In solver params, snapshot is specified but snapshot_prefix is not";
      const string probe_filename = SnapshotFilename(".tempfile");
      std::ofstream probe_ofs(probe_filename.c_str());
      if (!probe_ofs.good()) {
        LOG(FATAL) << "Cannot write to snapshot prefix '"
            << param_.snapshot_prefix() << "'.  Make sure "
            << "that the directory exists and is writable.";
      } else {
        // The probe file was created; clean it up again.
        probe_ofs.close();
        std::remove(probe_filename.c_str());
      }
    }
    
    // Builds a snapshot file name of the form
    //   <snapshot_prefix>_iter_<formatted iter_><extension>
    template <typename Dtype>
    string Solver<Dtype>::SnapshotFilename(const string& extension) {
      string filename = param_.snapshot_prefix();
      filename += "_iter_";
      filename += caffe::format_int(iter_);
      filename += extension;
      return filename;
    }
    
    /**
     * Serializes the train net into a ".caffemodel" binary proto file.
     *
     * @return The name of the file that was written.
     */
    template <typename Dtype>
    string Solver<Dtype>::SnapshotToBinaryProto() {
      const string model_filename = SnapshotFilename(".caffemodel");
      LOG(INFO) << "Snapshotting to binary proto file " << model_filename;

      // snapshot_diff is forwarded to ToProto() as its second argument.
      const bool write_diff = param_.snapshot_diff();
      NetParameter net_param;
      net_->ToProto(&net_param, write_diff);
      WriteProtoToBinaryFile(net_param, model_filename);
      return model_filename;
    }
    
    /**
     * Writes the train net into a ".caffemodel.h5" HDF5 file.
     *
     * @return The name of the file that was written.
     */
    template <typename Dtype>
    string Solver<Dtype>::SnapshotToHDF5() {
      const string model_filename = SnapshotFilename(".caffemodel.h5");
      LOG(INFO) << "Snapshotting to HDF5 file " << model_filename;

      // snapshot_diff is forwarded to ToHDF5() as its second argument.
      const bool write_diff = param_.snapshot_diff();
      net_->ToHDF5(model_filename, write_diff);
      return model_filename;
    }
    
    /**
     * Restores solver state from @p state_file, choosing the reader purely by
     * file name: names ending in ".h5" go to the HDF5 reader, everything else
     * to the binary proto reader.
     */
    template <typename Dtype>
    void Solver<Dtype>::Restore(const char* state_file) {
      const string state_filename(state_file);
      const string hdf5_suffix(".h5");
      const bool is_hdf5 =
          state_filename.size() >= hdf5_suffix.size() &&
          state_filename.compare(state_filename.size() - hdf5_suffix.size(),
                                 hdf5_suffix.size(), hdf5_suffix) == 0;
      if (is_hdf5) {
        RestoreSolverStateFromHDF5(state_filename);
      } else {
        RestoreSolverStateFromBinaryProto(state_filename);
      }
    }
    
    /**
     * Folds a new loss value into the running mean smoothed_loss_.
     *
     * losses_ acts as a window of at most @p average_loss values. While the
     * window is still filling, the value is appended and the mean is
     * recomputed over the larger count. Once it is full, the slot selected by
     * (iter_ - start_iter) % average_loss is overwritten and the mean is
     * adjusted by the difference between the new and the replaced value.
     *
     * @param loss         Loss of the current iteration.
     * @param start_iter   Iteration at which the window started filling.
     * @param average_loss Window length.
     */
    template <typename Dtype>
    void Solver<Dtype>::UpdateSmoothedLoss(Dtype loss, int start_iter,
        int average_loss) {
      if (losses_.size() >= average_loss) {
        // Window full: update the mean first, then overwrite the old slot.
        const int slot = (iter_ - start_iter) % average_loss;
        smoothed_loss_ += (loss - losses_[slot]) / average_loss;
        losses_[slot] = loss;
        return;
      }
      // Window still filling: append and extend the mean by one sample.
      losses_.push_back(loss);
      const int count = losses_.size();
      smoothed_loss_ = (smoothed_loss_ * (count - 1) + loss) / count;
    }
    

    solver.hpp源码

    /**
      * @brief Enumeration of actions that a client of the Solver may request by
      * implementing the Solver's action request function, which a
      * client may optionally provide in order to request early termination
      * or saving a snapshot without exiting. In the executable caffe, this
      * mechanism is used to allow the snapshot to be saved when stopping
      * execution with a SIGINT (Ctrl-C).
      */
      namespace SolverAction {
        enum Enum {
          NONE = 0,  // Take no special action.
          STOP = 1,  // Stop training. snapshot_after_train controls whether a
                     // snapshot is created. The solver exits its loop early.
          SNAPSHOT = 2  // Take a snapshot, and keep training. The current
                        // state is written out and the ongoing work continues.
        };
      }
    
    /**
     * @brief Type of a function that returns a Solver Action enumeration.
     *
     * The callable takes no arguments and returns a SolverAction::Enum.
     */
    typedef boost::function<SolverAction::Enum()> ActionCallback;
    
    /**
     * @brief An interface for classes that perform optimization on Net%s.
     *
     * Requires implementation of ApplyUpdate to compute a parameter update
     * given the current state of the Net parameters.
     */
    template <typename Dtype>
    class Solver {
     public:
      explicit Solver(const SolverParameter& param);
      explicit Solver(const string& param_file);
      void Init(const SolverParameter& param);
      void InitTrainNet();
      void InitTestNets();
    
      // Client of the Solver optionally may call this in order to set the function
      // that the solver uses to see what action it should take (e.g. snapshot or
      // exit training early).
      void SetActionFunction(ActionCallback func);
      // Returns the action reported by the registered action function.
      SolverAction::Enum GetRequestedAction();
      // The main entry of the solver function. In default, iter will be zero. Pass
      // in a non-zero iter number to resume training for a pre-trained net.
      virtual void Solve(const char* resume_file = NULL);
      inline void Solve(const string& resume_file) { Solve(resume_file.c_str()); }
      void Step(int iters);
      // The Restore method simply dispatches to one of the
      // RestoreSolverStateFrom___ protected methods. You should implement these
      // methods to restore the state from the appropriate snapshot type.
      void Restore(const char* resume_file);
      // The Solver::Snapshot function implements the basic snapshotting utility
      // that stores the learned net. You should implement the SnapshotSolverState()
      // function that produces a SolverState protocol buffer that needs to be
      // written to disk together with the learned net.
      void Snapshot();
      virtual ~Solver() {}
      inline const SolverParameter& param() const { return param_; }
      inline shared_ptr<Net<Dtype> > net() { return net_; }
      inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
        return test_nets_;
      }
      int iter() const { return iter_; }
    
      // Invoked at specific points during an iteration
      // Callback interface with two hooks, on_start() and on_gradients_ready()
      // (presumably used for synchronization in multi-GPU training -- confirm).
      class Callback {
       protected:
        virtual void on_start() = 0;
        virtual void on_gradients_ready() = 0;
    
        template <typename T>
        friend class Solver;
      };
      const vector<Callback*>& callbacks() const { return callbacks_; }
      // Appends a callback; the solver stores the raw pointer as given.
      void add_callback(Callback* value) {
        callbacks_.push_back(value);
      }
    
      void CheckSnapshotWritePermissions();
      /**
       * @brief Returns the solver type.
       */
      virtual inline const char* type() const { return ""; }
    
      // Make and apply the update value for the current iteration.
      virtual void ApplyUpdate() = 0;
    
     protected:
      string SnapshotFilename(const string& extension);
      string SnapshotToBinaryProto();
      string SnapshotToHDF5();
      // The test routine
      void TestAll();
      void Test(const int test_net_id = 0);
      virtual void SnapshotSolverState(const string& model_filename) = 0;
      virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
      virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
      void DisplayOutputBlobs(const int net_id);
      void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);
    
      SolverParameter param_;
      int iter_;                      // current iteration count
      int current_step_;              // current step; presumably used by the
                                      // step / multistep learning-rate policies
      shared_ptr<Net<Dtype> > net_;   // the training net
      vector<shared_ptr<Net<Dtype> > > test_nets_;    // all test nets
      vector<Callback*> callbacks_;   // registered iteration callbacks
      vector<Dtype> losses_;          // the most recent loss values (at most
                                      // average_loss of them)
      Dtype smoothed_loss_;           // running mean of losses_
    
      // A function that can be set by a client of the Solver to provide indication
      // that it wants a snapshot saved and/or to exit early.
      ActionCallback action_request_function_;
    
      // True iff a request to stop early was received.
      bool requested_early_exit_;
    
      // Timing information, handy to tune e.g. nbr of GPUs
      Timer iteration_timer_;       // times the interval between two displays
      float iterations_last_;       // value of iter_ when the timer was last
                                    // restarted after a display
    
      DISABLE_COPY_AND_ASSIGN(Solver);
    };
    

    小结

    1. 求解器的动作回调函数在caffe.cpp文件中设置,为SignalHandler::CheckForSignals()的函数指针。当Unix系统中出现SIGINT或SIGHUP信号时,GotSIGINT()和GotSIGHUP()函数会返回相应标志,并清空信号。而SignalHandler::CheckForSignals()函数则会根据标志返回对应的求解器动作类型(NONE/STOP/SNAPSHOT),具体可参考signal_handler.cpp文件。
    2. Step()函数中每次迭代计算前向/反向过程时,都使用了ClearParamDiffs()函数清空梯度。这是因为caffe中每次反向传播时的梯度数据都是累加在原数据上的,所以每次迭代时都需要手动清空,这与PyTorch中需要手动将梯度清零一致。

    Caffe的源码笔者是第一次阅读,一边阅读一边记录,对代码的理解和分析可能会存在错误或遗漏,希望各位读者批评指正,谢谢支持!

  • 相关阅读:
    大道至简第一张读后感
    字符串加密
    写一个类,在任何时候都可以向它查询创建了多少个类
    类与对象动手动脑
    2016年读书清单
    2016-09-01
    Spring笔记(五)--注解方式实现AOP
    Spring笔记(三)--代理模式
    Spring笔记(四)--公共属性的配置
    表达式之谜
  • 原文地址:https://www.cnblogs.com/Relu110/p/12079937.html
Copyright © 2011-2022 走看看