zoukankan      html  css  js  c++  java
  • PaddlePaddle inference 源码分析(一)

    本文针对代码版本为Paddle/2.2,主要针对预测流程的梳理。

    1、paddle inference的使用较为简单,其基本代码如下:

    // 创建predictor
    std::shared_ptr<Predictor> InitPredictor() {
      Config config;
      if (FLAGS_model_dir != "") {
        config.SetModel(FLAGS_model_dir);
      }
      config.SetModel(FLAGS_model_file, FLAGS_params_file);
      if (FLAGS_use_gpu) {
        config.EnableUseGpu(100, 0);
      } else {
        config.EnableMKLDNN();
      }
    
      // Open the memory optim.
      config.EnableMemoryOptim();
      return CreatePredictor(config);
    }
    
    // 执行预测
    void run(Predictor *predictor, const std::vector<float> &input,
             const std::vector<int> &input_shape, std::vector<float> *out_data) {
      int input_num = std::accumulate(input_shape.begin(), input_shape.end(), 1,
                                      std::multiplies<int>());
    
      auto input_names = predictor->GetInputNames();
      auto output_names = predictor->GetOutputNames();
      auto input_t = predictor->GetInputHandle(input_names[0]);
      input_t->Reshape(input_shape);
      input_t->CopyFromCpu(input.data());
    
      for (size_t i = 0; i < FLAGS_warmup; ++i)
        CHECK(predictor->Run());
    
      auto st = time();
      for (size_t i = 0; i < FLAGS_repeats; ++i) {
        CHECK(predictor->Run());
        auto output_t = predictor->GetOutputHandle(output_names[0]);
        std::vector<int> output_shape = output_t->shape();
        int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
                                      std::multiplies<int>());
        out_data->resize(out_num);
        output_t->CopyToCpu(out_data->data());
      }
      LOG(INFO) << "run avg time is " << time_diff(st, time()) / FLAGS_repeats
                << " ms";
    }

    2、代码库地址:https://github.com/PaddlePaddle/Paddle

    目录结构如下:

    --cmake #cmake编译脚本以及编译链接的第三方库等
    --doc
    --paddle #c++代码
        -fluid
            -distributed #分布式相关代码,主要为训练使用,包括模型内all_reduce进行跨卡通信、跨机通信等
            -extension #
            -framework #基础组件代码
            -imperative #分布式通信相关代码,包括nccl、all_reduce、bkcl等
            -inference #预测相关代码以及api定义
            -memory
            -operators #算子
            -platform #平台相关代码
            -pybind #pybind接口定义
            -string
        -scripts
        -testing
        -utils
    --patches
    --python #python部分代码
    --r
    --tools
    --CMakeLists.txt #编译脚本,包括大部分编译参数、三方库依赖等逻辑

    3、编译产出

      产出目录如下:

      

    build
        -python #whl安装包
        -paddle_install_dir #产出的所有头文件及库
        -paddle_inference_install_dir #预测c++依赖库
        -paddle_inference_c_install_dir #预测c依赖库
    联系方式:emhhbmdfbGlhbmcxOTkxQDEyNi5jb20=
  • 相关阅读:
    简单工厂笔记
    P3369 【模板】普通平衡树 Treap树堆学习笔记
    tp5阿里云短信验证码
    centos 安装php
    tp6.0.2开启多应用模式
    linux navicat最新版过期
    git commit之后 取消commit
    服务器重置之后ssh root@报错
    git pull push 每次都需要输入账号和密码
    跨域问题 php
  • 原文地址:https://www.cnblogs.com/zl1991/p/15688005.html
Copyright © 2011-2022 走看看