zoukankan      html  css  js  c++  java
  • PaddlePaddle inference 源码分析(三)

    本节介绍Operator定义注册机制

    简介

    所有的Op都继承自OperatorBase(operator.h),且所有的Op都是无状态的,每个Op包含的 成员变量只有四个:string& type、const VariableNameMap& inputs、const VariableNameMap& outputs、const AttributeMap& attribute。在创建时传入。

    Op的核心方法是Run,Run方法需要两方面的资源:数据资源Scope和计算资源Place。框架内部有一个全局的DeviceContextPool,用来记录Place和 DeviceContext之间的对应的关系,即每个Place有且仅有一个DeviceContext与之对应, DeviceContext中存放了当前设备的计算资源,比如对于GPU,这些资源包括cudnn_handle、 cublas_handle、stream等,​所有的计算(数据拷贝和CUDA Kernel)都必须在绑定到 DeviceContext中的stream上。

    Fluid框架的设计理念是可以在多种设备及第三方库上运行,有些Op的实现可能会因为设 备或者第三方库的不同而不同,为此Fluid引入了OpKernel的方式,即一个Op可以有多个 OpKernel,这类Op继承自OperatorWithKernel,这类Op的代表是conv,conv_op的OpKerne 有:GemmConvKernel、CUDNNConvOpKernel、ConvMKLDNNOpKernel,且每个 OpKernel都有double和float两种数据类型。不需要OpKernel的代表有WhileOp等。

    下边以conv2d为例来介绍注册及使用流程

    一、注册

    1、新建OP注册

    OP注册时通过宏来进行,一般形式如下:

    REGISTER_OPERATOR(op_type,
                      OperatorBase,
                      op_maker_and_checker_maker,
                      op_grad_opmaker, 
                      op_infer_var_shape,
                      op_infer_var_type​)            

    例如:

    REGISTER_OPERATOR(conv2d, ops::ConvOp, ops::Conv2DOpMaker,
                      ops::ConvOpInferVarType,
                      ops::Conv2DGradMaker<paddle::framework::OpDesc>,
                     ops::Conv2DGradMaker<paddle::imperative::OpBase>);

    对于一般的OP,前三个参数是必须的。实际使用时,不必按照特定顺序填写。注册器会根据模板特化逐个注册。

    • op_type:op的名字
    • OpeartorBase:该OP的对象
    • op_maker_and_checker_maker是op的maker和op中attr的checker。
    • op_grad_opmaker:创建当前Op对应的反向OP。如果Op有反向,必须要有op_grad_opmaker,因为backward会从正向的Op中获取反向Op的Maker。默认的op_grad_opmaker:DefaultGradOpMaker(grad_op_desc_maker.h)。它会将前向Op的输入和输出都作为反向Op的输入,将前向Op的输入的剃度作为反向Op的输出,并将前向Op的属性拷贝过来。使用DefaultGradOpMaker带来的问题是会将前向Op的所有输入输出都作为反向OP的输入,即使这个输入是非必须的,这会导致无法作内存优化,排除无用变量。
    • 框架没有默认的op_infer_var_shape提供。因此在保证shape不会出问题的情况下,OP可以不对optput的shape作推断,即可以不提供op_infer_var_shape,但是如果shape出问题会导致后续OP的shape都出错。如果Op是继承自OperatorWithKernel,可以通过覆盖OperatorWithKernel中的 InferShape方法的方式不去注册op_infer_var_shape,这也是大多数带Kernel的Op的 做法。
    • 建议每个OP都注册op_infer_var_type。在InferVarType中根据输入的Var的 type和dtype推断输出Var的type和dtype

    对于继承自OperatorWithKernel的Op,需要分别注册OpKernle

    REGISTER_OP_CPU_KERNEL
    REGISTER_OP_CUDA_KERNEL

    2、REGISTER_OPERATOR详解

    1. /*
        The variadic arguments should be class types derived from one of the
        following classes:
          OpProtoAndCheckerMaker
          GradOpDescMakerBase
          VarTypeInference
          InferShapeBase
      */
      #define REGISTER_OPERATOR(op_type, op_class, ...)                        \
        STATIC_ASSERT_GLOBAL_NAMESPACE(                                        \
            __reg_op__##op_type,                                               \
            "REGISTER_OPERATOR must be called in global namespace");           \
        static ::paddle::framework::OperatorRegistrar<op_class, ##__VA_ARGS__> \
            __op_registrar_##op_type##__(#op_type);                            \
        int TouchOpRegistrar_##op_type() {                                     \
          __op_registrar_##op_type##__.Touch();                                \
          return 0;                                                            \
        }

      paddle/fluid/framework/op_registry.h 该宏定义用于注册新建OP。第一步会检查op_type是否已存在。第二步执行具体的注册逻辑。第三步

    2. STATIC_ASSERT_GLOBAL_NAMESPACE(fluid/extension/include/ext_op_meta_info.h)检查op_type是否已经存在,声明一个特定名称的结构体,然后比较全局作用域中是否存在同名类型,以此来判断名称是否存在
      #define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg)                        \
        struct __test_global_namespace_##uniq_name##__ {};                          \
        static_assert(std::is_same<::__test_global_namespace_##uniq_name##__,       \
                                   __test_global_namespace_##uniq_name##__>::value, \
                      msg)
    3. OpeartorRegistrar(framework/op_registry.h)这里是实际进行注册的逻辑。主体逻辑为将注册的各项函数添加到OpInfo中,然后将info存放到单例的OpInfoMap中。下面着重梳理OperatorRegistrarRecursive的逻辑。
      template <typename... ARGS>
      struct OperatorRegistrar : public Registrar {
        explicit OperatorRegistrar(const char* op_type) {
          PADDLE_ENFORCE_EQ(
              OpInfoMap::Instance().Has(op_type), false,
              platform::errors::AlreadyExists(
                  "Operator '%s' is registered more than once.", op_type));
          static_assert(sizeof...(ARGS) != 0,
                        "OperatorRegistrar should be invoked at least by OpClass");
          OpInfo info;
          details::OperatorRegistrarRecursive<0, false, ARGS...>(op_type, &info);
          OpInfoMap::Instance().Insert(op_type, info);
        }
      };
    4. OperatorRegistrarRecursive(frame/details/op_registry.h)。模板类,有两个特化,一个at_end=false,一个at_end=true。其处理的基本逻辑为递归方式。初始调用时,I=0,at_end=false,并传入ARGS。而后使用std::typle_element取出ARGS[I],然后调用对应特化的OpInfoFiller将ARGS[I]放入info中。接着I+1,如果I+1!=ARGS.size(),则会继续调用I=1,at_end=false,ARGS 的自身。这样依次遍历ARGS,直到处理完所有的注册函数后,会进入at_end=true结束递归。
      template <size_t I, bool at_end, typename... ARGS>
      class OperatorRegistrarRecursive;
      
      template <size_t I, typename... ARGS>
      class OperatorRegistrarRecursive<I, false, ARGS...> {
       public:
        using T = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
        OperatorRegistrarRecursive(const char* op_type, OpInfo* info) {
          OpInfoFiller<T> fill;
          fill(op_type, info);
          constexpr auto size = sizeof...(ARGS);
          OperatorRegistrarRecursive<I + 1, I + 1 == size, ARGS...> reg(op_type,
                                                                        info);
          (void)(reg);
        }
      };
      
      template <size_t I, typename... ARGS>
      class OperatorRegistrarRecursive<I, true, ARGS...> {
       public:
        OperatorRegistrarRecursive(const char* op_type, OpInfo* info) {}
      };
    5. 例如对于ops::ConvOp这个OperatorBase类型,会调用对应OpInfoFiller<T,KOperator>

      template <typename T>
      struct OpInfoFiller<T, kOperator> {
        void operator()(const char* op_type, OpInfo* info) const {
          PADDLE_ENFORCE_EQ(info->creator_, nullptr,
                            platform::errors::AlreadyExists(
                                "OpCreator of %s has been registered", op_type));
          // 将OperatorBase构造函数放入info->creator这个函数指针中
          info->creator_ = [](const std::string& type, const VariableNameMap& inputs,
                              const VariableNameMap& outputs,
                              const AttributeMap& attrs) {
            return new T(type, inputs, outputs, attrs);
          };
      
          // 如果T为OperatorWithKernel类型,会有更多操作
          if (std::is_base_of<OperatorWithKernel, T>::value) {
            PADDLE_ENFORCE_EQ(
                info->infer_shape_, nullptr,
                platform::errors::AlreadyExists(
                    "Duplicate InferShapeFN of %s has been registered", op_type));
      
            OperatorWithKernel* op = dynamic_cast<OperatorWithKernel*>(info->creator_(
                std::string{}, VariableNameMap{}, VariableNameMap{}, AttributeMap{}));
            PADDLE_ENFORCE_NOT_NULL(op, platform::errors::InvalidArgument(
                                            "%s should have kernels", op_type));
            info->infer_shape_ = [op](InferShapeContext* ctx) {
              op->InferShape(ctx);
            };
          }
        }
      };
    6. 这里OpInfoFiller进行类型推导的逻辑是:
      1. 首先有一个枚举类与ARGS类型对应
        enum OpInfoFillType {
          kOperator = 0,
          kOpProtoAndCheckerMaker = 1,
          kGradOpDescMaker = 2,
          kVarTypeInference = 3,
          kShapeInference = 4,
          kInplaceOpInference = 5,
          kNoNeedBufferVarsInference = 6,
          kGradOpBaseMaker = 7,
          kUnknown = -1
        };
      2. OpInfoFiller会调用OpInfoFillTypeID对T进行类型推导
        template <typename T, OpInfoFillType = OpInfoFillTypeID<T>::ID()>
        struct OpInfoFiller;
      3. 推导的方式也是类似用模板特化的方式,ID函数会获得特化类的kType,这里调用的时候也是递归的方式
        template <typename T>
        struct OpInfoFillTypeID {
          static constexpr OpInfoFillType ID() {
            return internal::OpInfoFillTypeGetter<T>::kType;
          }
        };
        template <typename T>
        using OpInfoFillTypeGetter =
            OpInfoFillTypeGetterImpl<T, 0, kOpRegistryClassNumber,
                                     kOpRegistryClassNumber == 0,
                                     IsMatchedBaseType<T, 0>()>;

        这里kOpRegistryClassNumber是枚举列表的长度

        using OpRegistryClasses = std::tuple<                                // NOLINT
            TypePair<OperatorBase, kOperator>,                               // NOLINT
            TypePair<OpProtoAndCheckerMaker, kOpProtoAndCheckerMaker>,       // NOLINT
            TypePair<GradOpDescMakerBase, kGradOpDescMaker>,                 // NOLINT
            TypePair<imperative::GradOpBaseMakerBase, kGradOpBaseMaker>,     // NOLINT
            TypePair<VarTypeInference, kVarTypeInference>,                   // NOLINT
            TypePair<InferShapeBase, kShapeInference>,                       // NOLINT
            TypePair<InplaceOpInference, kInplaceOpInference>,               // NOLINT
            TypePair<NoNeedBufferVarsInference, kNoNeedBufferVarsInference>  // NOLINT
            >;
        
        static constexpr int kOpRegistryClassNumber =
            std::tuple_size<OpRegistryClasses>::value;

        IsMatchedBaseType用于判断T与kPos的类型是否一致。

        template <typename T, int kPos>
        static inline constexpr bool IsMatchedBaseType() {
          return IsMatchedBaseTypeImpl<
              T, kPos, (kPos >= 0 && kPos < kOpRegistryClassNumber)>::kValue;
        }

        这里如果kPos超出了列表长度或者设置了非法值,会直接返回false

        template <typename T, int kPos>
        struct IsMatchedBaseTypeImpl<T, kPos, false> {
          static constexpr bool kValue = false;
        };

        否则,会比较T与OpRegistryClasses的类型是否一致

        // TypePair定义,T与枚举OpInfoFillerType一一对应
        template <typename T, OpInfoFillType kType>
        struct TypePair {
          using Type = T;
          static constexpr OpInfoFillType kFillType = kType;
        };
        // 比较传入的T与OpRegistryClasses中T的类型
        template <typename T, int kPos, bool kIsBounded /* = true*/>
        struct IsMatchedBaseTypeImpl {
          using PairType = typename std::tuple_element<kPos, OpRegistryClasses>::type;
          static constexpr bool kValue =
              std::is_base_of<typename PairType::Type, T>::value;
        };
      4. OpInfoFillTypeGetterImpl是一个递归调用,如果IsMatchedBaseType返回false,会将kStart+1继续比较,直到匹配上后返回kType也就是枚举中的序号
        // 没匹配上就kStart+1
        template <typename T, int kStart, int kEnd>
        struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, false, false> {
          static constexpr OpInfoFillType kType =
              OpInfoFillTypeGetterImpl<T, kStart + 1, kEnd, kStart + 1 == kEnd,
                                       IsMatchedBaseType<T, kStart + 1>()>::kType;
        };
        // 匹配上直接返回kType,也就是OpRegistryClasses中T对应的枚举值
        template <typename T, int kStart, int kEnd>
        struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, false, true> {
          using PairType = typename std::tuple_element<kStart, OpRegistryClasses>::type;
          static constexpr OpInfoFillType kType = PairType::kFillType;
        };
    7. 宏中的第三步是定义一个函数,并且调用一下第二步创建的静态OperatorRegistrar变量。这是因为在framework打包时,这些注册时声明的变量并没有被调用过,会被编译器移除。因此创建一个空函数调用一下,保证framework编译打包时能保存该变量
      // 创建一个函数调用下创建的变量
      int TouchOpRegistrar_##op_type() {                                     \
          __op_registrar_##op_type##__.Touch();                                \
          return 0;                                                            \
        }
      
      // Touch实际是空函数
      class Registrar {
       public:
        // In our design, various kinds of classes, e.g., operators and kernels,
        // have their corresponding registry and registrar. The action of
        // registration is in the constructor of a global registrar variable, which
        // are not used in the code that calls package framework, and would
        // be removed from the generated binary file by the linker. To avoid such
        // removal, we add Touch to all registrar classes and make USE_OP macros to
        // call this method. So, as long as the callee code calls USE_OP, the global
        // registrar variable won't be removed by the linker.
        void Touch() {}
      };

    二、创建OP

    《PaddlePaddle Inference源码分析(二)》 中介绍到,PrepareExecutor会创建OP对象并放入Executor中。我们从准备阶段开始看起。

    1. NaiveExecutor::Prepare,会调用CreateOps
      void NaiveExecutor::Prepare(Scope *scope, const ProgramDesc &program_desc,
                                  int block_id, bool with_feed_fetch_ops) {
        if (!scope) {
          scope_ = new framework::Scope;
        } else {
          scope_ = scope;
        }
      
        VLOG(3) << "NaiveExecutor init with scope " << scope;
        CreateOps(program_desc, block_id, with_feed_fetch_ops);
      }
    2. CreateOps会从ProgramDesc中保存的模型文件信息中获取到所有OP的信息OpDesc(framework/block_desc.cc),然后使用OpDesc创建OP对象
      void NaiveExecutor::CreateOps(const ProgramDesc &desc, int block_id,
                                    bool with_feed_fetch_ops) {
        for (const auto &op_desc : desc.Block(block_id).AllOps()) {
          if (!with_feed_fetch_ops &&
              (op_desc->Type() == "feed" || op_desc->Type() == "fetch")) {
            LOG(INFO) << "---  skip [" << op_desc->Input("X")[0] << "], "
                      << op_desc->Type() << " -> " << op_desc->Output("Out")[0];
            continue;
          }
          ops_.emplace_back(OpRegistry::CreateOp(*op_desc));
        }
      }
    3. 这里OpRegistry::CreateOp均为静态函数。实际逻辑为从OpInfoMap(单例,全局共享)中根据op_type取出对应的OpInfo,再调用OpInfo中的Createor(2.5节中创建的lambda函数,放入了构造函数)创建OperatorBase对象。
      // 接口函数,取出OpDesc信息后进行实际调用
      std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
        return CreateOp(op_desc.Type(), op_desc.Inputs(), op_desc.Outputs(),
                        op_desc.GetAttrMap());
      }
      
      //实际调用
      std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
          const std::string& type, const VariableNameMap& inputs,
          const VariableNameMap& outputs, AttributeMap attrs, bool attr_check) {
        auto& info = OpInfoMap::Instance().Get(type);
        if (attr_check && info.Checker() != nullptr) {
          info.Checker()->Check(&attrs);
        }
        auto op = info.Creator()(type, inputs, outputs, attrs);
        return std::unique_ptr<OperatorBase>(op);
      }

    三、调用OP

    待完成predictor调用部分后补充

    联系方式:emhhbmdfbGlhbmcxOTkxQDEyNi5jb20=
  • 相关阅读:
    CentOS6.7 mysql5.6.33修改数据文件位置
    win8 win10 安装msi 提示2502、2503的错误代码
    2016年国内开源镜像站点汇总
    eclipse创建本地maven
    maven添加sqlserver的jdbc驱动包
    CentOS6.5下RPM方式安装mysql5.6.33
    linux中~和/的区别
    Linux命令学习(22):ss命令
    Linux命令学习(21):netstat命令
    Linux命令学习(20):traceroute命令
  • 原文地址:https://www.cnblogs.com/zl1991/p/15712940.html
Copyright © 2011-2022 走看看