zoukankan      html  css  js  c++  java
  • Caffe:深入分析(怎么训练)

    main() 

      首先入口函数caffe.cpp

     1 int main(int argc, char** argv) {
     2   ......
     3   if (argc == 2) {
     4 #ifdef WITH_PYTHON_LAYER
     5     try {
     6 #endif
     7       return GetBrewFunction(caffe::string(argv[1]))(); //根据输入参数确定是train还是test,采用string到函数指针的映射实现,非常巧妙
     8 #ifdef WITH_PYTHON_LAYER
     9     } catch (bp::error_already_set) {
    10       PyErr_Print();
    11       return 1;
    12     }
    13 #endif
    14   } else {
    15     gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe");
    16   }
    17 }

      在main函数中GetBrewFunction函数调用了通过工厂模式生成的由string到函数指针的map

    1 typedef int (*BrewFunction)();
    2 typedef std::map<caffe::string, BrewFunction> BrewMap;
    3 BrewMap g_brew_map;

      在train、test、device_query、time函数后面都可以看到对这些函数的register,相当于这些函数指针已经在map中存在了

    1 RegisterBrewFunction(train);
    2 RegisterBrewFunction(test);
    3 RegisterBrewFunction(device_query);
    4 RegisterBrewFunction(time);

    train()

      接着是train过程

     1 // Train / Finetune a model.
     2 int train() {
     3   ......
     4   caffe::SolverParameter solver_param;
     5   caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);//从-solver参数读取solver_param
     6   ......
     7   shared_ptr<caffe::Solver<float> >
     8       solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));//从参数创建solver,同样采用string到函数指针的映射实现,用到了工厂模式
     9 
    10   if (FLAGS_snapshot.size()) {//迭代snapshot次后保存模型一次
    11     LOG(INFO) << "Resuming from " << FLAGS_snapshot;
    12     solver->Restore(FLAGS_snapshot.c_str());
    13   } else if (FLAGS_weights.size()) {//若采用finetuning,则拷贝weight到指定模型
    14     CopyLayers(solver.get(), FLAGS_weights);
    15   }
    16 
    17   if (gpus.size() > 1) {
    18     caffe::P2PSync<float> sync(solver, NULL, solver->param());
    19     sync.Run(gpus);
    20   } else {
    21     LOG(INFO) << "Starting Optimization";
    22     solver->Solve();//开始训练网络
    23   }
    24   LOG(INFO) << "Optimization Done.";
    25   return 0;
    26 }

    Solver()

      看CreateSolver函数是如何构建solver和net的,CreateSolver定义在solver_factory.hpp中,首先需要知道的是solver是一个基类,继承自它的类有SGD等,下面的实现就可以根据param的type构造一个指向特定solver的指针,比如SGD。

    1 static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
    2     const string& type = param.type();
    3     CreatorRegistry& registry = Registry();
    4     CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
    5         << " (known types: " << SolverTypeListString() << ")";
    6     return registry[type](param);
    7   }

      关键之处在于上面代码最后一行语句,它的作用是根据配置文件创建对应的Solver对象(默认为SGDSolver子类对象)。此处工厂模式和一个关键的宏REGISTER_SOLVER_CLASS(SGD)发挥了重要作用。

    1 #define REGISTER_SOLVER_CLASS(type)                                              
    2   template <typename Dtype>                                                      
    3   Solver<Dtype>* Creator_##type##Solver(                                         
    4       const SolverParameter& param)                                              
    5   {                                                                              
    6     return new type##Solver<Dtype>(param);                                       
    7   }                                                                              
    8   REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)    
    9 }   

      这样一个SGDSolver对象就调用其构造函数被构造出来了。

    1 explicit SGDSolver(const SolverParameter& param)
    2       : Solver<Dtype>(param) { PreSolve(); }

      同时,Solver这个基类也被构造出来了,在solver.hpp里

    1 explicit Solver(const SolverParameter& param,
    2       const Solver* root_solver = NULL);

      Solver构造函数又会调用Init进行训练网络和测试网络的初始化,Init函数没有被声明为虚函数,不能被覆写,也就是说所有的solver都调用这个函数进行初始化。

     1 template <typename Dtype>
     2 void Solver<Dtype>::Init(const SolverParameter& param) {
     3   ......
     4   // Scaffolding code
     5   InitTrainNet();//初始化训练网络
     6   if (Caffe::root_solver()) {
     7     InitTestNets();//初始化测试网络
     8     LOG(INFO) << "Solver scaffolding done.";
     9   }
    10   iter_ = 0;//迭代次数设为0
    11   current_step_ = 0;
    12 }

    InitTrainNet()

      接下来看训练网络初始化函数InitTrainNet,具体的内容见Net的网络层的构建(源码分析)

      caffe是如何来solve的:在成员函数Solve()内部,

     1 template <typename Dtype>
     2 void Solver<Dtype>::Solve(const char* resume_file) {
     3   ......
     4   // For a network that is trained by the solver, no bottom or top vecs
     5   // should be given, and we will just provide dummy vecs.
     6   int start_iter = iter_;
     7   //开始迭代
     8   Step(param_.max_iter() - iter_);
     9   ......
    10 }

    Step()

      下面我们看一下Solver::Step()函数内部实现情况,具体的一次迭代过程。见Caffe参数交换源码分析

      这就是整个网络的训练过程。 

  • 相关阅读:
    Visual Studio的框选代码区块功能
    序列化和反序列化
    使用C#采集Shibor数据到Excel
    LiveCharts文档-4基本绘图-3其他
    LiveCharts文档-4基本绘图-2基本柱形图
    LiveCharts文档-4基本绘图-1基本线条图
    LiveCharts文档-3开始-8自定义工具提示
    LiveCharts文档-3开始-7标签
    LiveCharts文档-3开始-6轴Axes
    LeetCode题解汇总(包括剑指Offer和程序员面试金典,持续更新)
  • 原文地址:https://www.cnblogs.com/liuzhongfeng/p/8126194.html
Copyright © 2011-2022 走看看