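This section covers how Operators are defined and registered.
Introduction
The core method of an Op is Run. Run needs two kinds of resources: data resources (Scope) and compute resources (Place). The framework keeps a global DeviceContextPool that records the mapping between Place and DeviceContext: every Place corresponds to exactly one DeviceContext. The DeviceContext holds the compute resources of its device. For a GPU these include cudnn_handle, cublas_handle and stream, and all computation (data copies and CUDA kernels) must be issued on the stream bound to that DeviceContext.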
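The one-context-per-Place rule can be pictured as a process-wide map from Place to DeviceContext. This is a minimal sketch under stated assumptions, not Paddle's implementation: Place, DeviceContext and DeviceContextPool below are simplified stand-ins with hypothetical fields, and contexts are created lazily on first lookup for brevity.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>

// Hypothetical, simplified stand-in for a Place (device kind + device id).
struct Place {
  std::string device;  // e.g. "cpu" or "gpu"
  int id;
  bool operator<(const Place& o) const {
    return device != o.device ? device < o.device : id < o.id;
  }
};

// Hypothetical stand-in: a real GPU context would own the stream and handles.
struct DeviceContext {
  explicit DeviceContext(Place p) : place(p) {}
  Place place;
};

// Process-wide pool: each Place maps to exactly one DeviceContext.
class DeviceContextPool {
 public:
  static DeviceContextPool& Instance() {
    static DeviceContextPool pool;  // constructed on first use
    return pool;
  }

  // Returns the unique context for `place`, creating it on first request.
  DeviceContext* Get(const Place& place) {
    auto it = contexts_.find(place);
    if (it == contexts_.end()) {
      it = contexts_.emplace(place, std::make_unique<DeviceContext>(place)).first;
    }
    return it->second.get();
  }

 private:
  DeviceContextPool() = default;
  std::map<Place, std::unique_ptr<DeviceContext>> contexts_;
};
```

Because the pool hands out the same pointer for equal Places, every Op that runs on gpu:0 would share one stream and one set of library handles, which is the property the paragraph above relies on.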
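Fluid is designed to run on many devices and third-party libraries, and some Ops need different implementations depending on the device or library. Fluid therefore introduces OpKernels: one Op can have several OpKernels. Such Ops inherit from OperatorWithKernel. conv is a typical example: its OpKernels include GemmConvKernel, CUDNNConvOpKernel and ConvMKLDNNOpKernel, and each OpKernel is registered for both float and double. WhileOp is a typical example of an Op that needs no OpKernel.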
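A rough way to picture "one Op, several OpKernels" is a table keyed by device, data type and library that the Op consults when it runs. In this sketch KernelKey, OpWithKernels and MakeConvLikeOp are hypothetical names; Paddle's own kernel key carries more information, and the kernels here only return their names so the selection is observable.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>

// Hypothetical kernel key: which device, data type and library a kernel targets.
struct KernelKey {
  std::string place;    // "cpu", "gpu", ...
  std::string dtype;    // "float", "double", ...
  std::string library;  // "plain", "cudnn", "mkldnn", ...
  bool operator<(const KernelKey& o) const {
    return std::tie(place, dtype, library) <
           std::tie(o.place, o.dtype, o.library);
  }
};

using KernelFn = std::function<std::string()>;

// One op type owns several kernels; Run picks one by key at execution time.
class OpWithKernels {
 public:
  void AddKernel(const KernelKey& key, KernelFn fn) {
    kernels_[key] = std::move(fn);
  }
  std::string Run(const KernelKey& key) const {
    auto it = kernels_.find(key);
    if (it == kernels_.end()) throw std::runtime_error("no kernel for key");
    return it->second();
  }

 private:
  std::map<KernelKey, KernelFn> kernels_;
};

// Registers conv-like kernels for float and double, mirroring the text above.
OpWithKernels MakeConvLikeOp() {
  OpWithKernels op;
  for (const std::string dtype : {"float", "double"}) {
    op.AddKernel({"cpu", dtype, "plain"}, [] { return std::string("GemmConvKernel"); });
    op.AddKernel({"gpu", dtype, "cudnn"}, [] { return std::string("CUDNNConvOpKernel"); });
    op.AddKernel({"cpu", dtype, "mkldnn"}, [] { return std::string("ConvMKLDNNOpKernel"); });
  }
  return op;
}
```

An op without kernels, like WhileOp, would simply implement its logic directly instead of owning such a table.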
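The rest of this section uses conv2d as an example to walk through registration and usage.
I. Registration
1. Registering a new OP
An OP is registered through a macro, whose general form is: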
```cpp
REGISTER_OPERATOR(op_type,
                  OperatorBase,
                  op_maker_and_checker_maker,
                  op_grad_opmaker,
                  op_infer_var_shape,
                  op_infer_var_type)
```
For example:
```cpp
REGISTER_OPERATOR(conv2d, ops::ConvOp, ops::Conv2DOpMaker,
                  ops::ConvOpInferVarType,
                  ops::Conv2DGradMaker<paddle::framework::OpDesc>,
                  ops::Conv2DGradMaker<paddle::imperative::OpBase>);
```
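For an ordinary OP, the first three arguments are required. Apart from op_type and the OP class, the arguments do not have to follow a fixed order: the registrar identifies each one by its base class through template specialization and registers them one at a time.
- op_type: the name of the OP.
- OperatorBase: the class implementing the OP (derived from OperatorBase, e.g. ops::ConvOp).
- op_maker_and_checker_maker: the OP's maker, which also supplies the checker for the OP's attributes.
- op_grad_opmaker: creates the backward (gradient) OP of the current OP. Every OP that has a backward pass must provide one, because backward construction obtains the gradient OP's maker from the forward OP. The default is DefaultGradOpMaker (grad_op_desc_maker.h): it passes all inputs and outputs of the forward OP, together with the gradients of the forward outputs, as inputs of the backward OP, makes the gradients of the forward inputs its outputs, and copies the forward OP's attributes. The drawback of DefaultGradOpMaker is that every forward input and output becomes an input of the backward OP even when it is not needed, which prevents memory optimization from releasing unused variables.
- op_infer_var_shape: the framework provides no default. If an OP can guarantee that its shapes will not go wrong, it may skip inferring output shapes, i.e. omit op_infer_var_shape; however, a wrong shape then makes the shapes of all downstream OPs wrong as well. An OP derived from OperatorWithKernel can instead override OperatorWithKernel::InferShape and skip registering op_infer_var_shape, which is what most kernel-based OPs do.
- op_infer_var_type: registering one for every OP is recommended. InferVarType infers the type and dtype of the output Vars from those of the input Vars.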
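To make the cost of DefaultGradOpMaker concrete, the following sketch builds a gradient op description the default way on a simplified, hypothetical SimpleOpDesc type (the real code works on OpDesc). It assumes the "@GRAD" suffix for gradient variable names and, like the default maker in grad_op_desc_maker.h, also feeds the gradients of the forward outputs to the gradient op.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical, simplified op description: parameter name -> variable names.
using VarMap = std::map<std::string, std::vector<std::string>>;
struct SimpleOpDesc {
  std::string type;
  VarMap inputs, outputs;
  std::map<std::string, int> attrs;
};

// Gradient variables are named by appending "@GRAD".
std::string GradName(const std::string& name) { return name + "@GRAD"; }

std::vector<std::string> GradNames(const std::vector<std::string>& names) {
  std::vector<std::string> out;
  for (const auto& n : names) out.push_back(GradName(n));
  return out;
}

// Default behaviour: every forward input and output, plus the output
// gradients, feeds the grad op; the grad op produces the input gradients.
SimpleOpDesc MakeDefaultGradOp(const SimpleOpDesc& fwd) {
  SimpleOpDesc grad;
  grad.type = fwd.type + "_grad";
  for (const auto& kv : fwd.inputs) {
    grad.inputs[kv.first] = kv.second;                        // X
    grad.outputs[GradName(kv.first)] = GradNames(kv.second);  // X@GRAD
  }
  for (const auto& kv : fwd.outputs) {
    grad.inputs[kv.first] = kv.second;                        // Out
    grad.inputs[GradName(kv.first)] = GradNames(kv.second);   // Out@GRAD
  }
  grad.attrs = fwd.attrs;  // attributes are copied unchanged
  return grad;
}
```

For conv2d the gradient op ends up depending on the forward Output even though conv2d's input and filter gradients can be computed from Input, Filter and Output@GRAD alone; that unconditional dependency is what keeps otherwise unused variables alive.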
For Ops derived from OperatorWithKernel, the OpKernels must be registered separately.
2. REGISTER_OPERATOR in detail
```cpp
/*
  The variadic arguments should be class types derived from one of the
  following classes:
    OpProtoAndCheckerMaker
    GradOpDescMakerBase
    VarTypeInference
    InferShapeBase
*/
#define REGISTER_OPERATOR(op_type, op_class, ...)                        \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                        \
      __reg_op__##op_type,                                               \
      "REGISTER_OPERATOR must be called in global namespace");           \
  static ::paddle::framework::OperatorRegistrar<op_class, ##__VA_ARGS__> \
      __op_registrar_##op_type##__(#op_type);                            \
  int TouchOpRegistrar_##op_type() {                                     \
    __op_registrar_##op_type##__.Touch();                                \
    return 0;                                                            \
  }
```
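This macro, defined in paddle/fluid/framework/op_registry.h, registers a new OP in three steps. Step one checks, at compile time, that the macro is expanded in the global namespace. Step two defines a static OperatorRegistrar object whose constructor performs the actual registration, including a runtime check that op_type has not already been registered. Step three defines a TouchOpRegistrar_##op_type function that references the registrar.
- STATIC_ASSERT_GLOBAL_NAMESPACE (fluid/extension/include/ext_op_meta_info.h) declares a struct whose name is derived from uniq_name, then uses static_assert to check that the global-scope type ::NAME and the unqualified NAME are the same type. This only holds when the macro is expanded at global scope: inside any other namespace the unqualified name refers to the namespace-local struct while ::NAME does not exist, so compilation fails. As a side effect, expanding it twice for the same op_type in one translation unit also fails, because the struct would be redefined.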
```cpp
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg)                        \
  struct __test_global_namespace_##uniq_name##__ {};                          \
  static_assert(std::is_same<::__test_global_namespace_##uniq_name##__,       \
                             __test_global_namespace_##uniq_name##__>::value, \
                msg)
```
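The trick can be tried in isolation. The sketch below copies the macro under a hypothetical name, DEMO_ASSERT_GLOBAL_NAMESPACE, and drops the leading double underscores, which are reserved identifiers in user code. Expanded at global scope it compiles; the commented-out expansion inside namespace ops would not, because the unqualified name would then refer to ops::demo_test_global_namespace_relu while the ::-qualified name finds nothing.

```cpp
#include <type_traits>

// Hypothetical copy of the macro, renamed to avoid reserved identifiers.
#define DEMO_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg)                        \
  struct demo_test_global_namespace_##uniq_name {};                         \
  static_assert(std::is_same<::demo_test_global_namespace_##uniq_name,      \
                             demo_test_global_namespace_##uniq_name>::value, \
                msg)

// Global scope: ::name and name denote the same struct, so this compiles.
DEMO_ASSERT_GLOBAL_NAMESPACE(conv2d, "must be called in global namespace");

namespace ops {
// Uncommenting the next line breaks the build: the struct is declared as
// ops::demo_test_global_namespace_relu, and ::demo_test_global_namespace_relu
// does not exist.
// DEMO_ASSERT_GLOBAL_NAMESPACE(relu, "must be called in global namespace");
}  // namespace ops
```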
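- OperatorRegistrar (framework/op_registry.h) contains the actual registration logic: it adds each registered component to an OpInfo and then stores that info in the singleton OpInfoMap. The logic of OperatorRegistrarRecursive, which does the filling, is examined next.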
```cpp
template <typename... ARGS>
struct OperatorRegistrar : public Registrar {
  explicit OperatorRegistrar(const char* op_type) {
    // Reject a second registration of the same op_type at runtime.
    PADDLE_ENFORCE_EQ(
        OpInfoMap::Instance().Has(op_type), false,
        platform::errors::AlreadyExists(
            "Operator '%s' is registered more than once.", op_type));
    static_assert(sizeof...(ARGS) != 0,
                  "OperatorRegistrar should be invoked at least by OpClass");
    OpInfo info;
    // Fill info from every entry of ARGS, one at a time.
    details::OperatorRegistrarRecursive<0, false, ARGS...>(op_type, &info);
    OpInfoMap::Instance().Insert(op_type, info);
  }
};
```
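The registration pattern itself, a global object whose constructor inserts into a singleton map after a duplicate check, can be sketched without Paddle's types. OpInfo, OpInfoMap and DemoOpRegistrar here are hypothetical, cut-down versions, and a thrown exception stands in for PADDLE_ENFORCE_EQ.

```cpp
#include <cassert>
#include <functional>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical, cut-down OpInfo: only a creator callback is kept.
struct OpInfo {
  std::function<std::string()> creator;
};

// Process-wide registry keyed by op_type, modelled on OpInfoMap.
class OpInfoMap {
 public:
  static OpInfoMap& Instance() {
    static OpInfoMap instance;  // constructed on first use
    return instance;
  }
  bool Has(const std::string& type) const { return map_.count(type) > 0; }
  void Insert(const std::string& type, const OpInfo& info) { map_.emplace(type, info); }
  const OpInfo& Get(const std::string& type) const { return map_.at(type); }

 private:
  std::unordered_map<std::string, OpInfo> map_;
};

// Registration happens in the constructor, with a duplicate check first.
struct DemoOpRegistrar {
  DemoOpRegistrar(const char* op_type, std::function<std::string()> creator) {
    if (OpInfoMap::Instance().Has(op_type)) {
      throw std::runtime_error(std::string("Operator '") + op_type +
                               "' is registered more than once.");
    }
    OpInfoMap::Instance().Insert(op_type, OpInfo{std::move(creator)});
  }
};

// A static registrar runs during static initialization, before main().
static DemoOpRegistrar demo_registrar_conv2d("conv2d", [] { return std::string("ConvOp"); });
```

Keeping the map in a function-local static means it is constructed on first use, which sidesteps the unspecified initialization order of globals defined in different translation units.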
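- OperatorRegistrarRecursive (framework/details/op_registry.h) is a class template with two partial specializations, one for at_end=false and one for at_end=true, and it works recursively. The initial instantiation uses I=0, at_end=false and the full ARGS. It extracts ARGS[I] with std::tuple_element and calls the matching OpInfoFiller specialization to put ARGS[I] into info. It then instantiates itself with I+1 and at_end = (I+1 == sizeof...(ARGS)); while I+1 is still less than the number of arguments, at_end stays false and the recursion continues, so every entry of ARGS is visited in turn. Once the last registration argument has been processed, the at_end=true specialization, whose constructor does nothing, ends the recursion.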
```cpp
template <size_t I, bool at_end, typename... ARGS>
class OperatorRegistrarRecursive;

// Recursive case: fill info from ARGS[I], then move on to I + 1.
template <size_t I, typename... ARGS>
class OperatorRegistrarRecursive<I, false, ARGS...> {
 public:
  using T = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
  OperatorRegistrarRecursive(const char* op_type, OpInfo* info) {
    OpInfoFiller<T> fill;
    fill(op_type, info);
    constexpr auto size = sizeof...(ARGS);
    OperatorRegistrarRecursive<I + 1, I + 1 == size, ARGS...> reg(op_type,
                                                                  info);
    (void)(reg);
  }
};

// Terminal case: all ARGS have been processed, nothing left to do.
template <size_t I, typename... ARGS>
class OperatorRegistrarRecursive<I, true, ARGS...> {
 public:
  OperatorRegistrarRecursive(const char* op_type, OpInfo* info) {}
};
```
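The same compile-time recursion can be reduced to a few lines. In this hypothetical sketch, DemoFiller stands in for OpInfoFiller and simply records each argument's name, which makes the left-to-right visiting order observable. As in the original, ARGS must be non-empty, otherwise std::tuple_element<0, std::tuple<>> would not compile.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <tuple>
#include <vector>

// Hypothetical info object: records the order in which arguments were filled.
struct DemoInfo {
  std::vector<std::string> filled;
};

// Hypothetical filler: each registration argument type provides Name().
template <typename T>
struct DemoFiller {
  void operator()(DemoInfo* info) const { info->filled.push_back(T::Name()); }
};

template <std::size_t I, bool at_end, typename... ARGS>
struct DemoRegistrarRecursive;

// Recursive case: fill from ARGS[I], then instantiate the next step.
template <std::size_t I, typename... ARGS>
struct DemoRegistrarRecursive<I, false, ARGS...> {
  using T = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
  explicit DemoRegistrarRecursive(DemoInfo* info) {
    DemoFiller<T>{}(info);
    DemoRegistrarRecursive<I + 1, I + 1 == sizeof...(ARGS), ARGS...> next(info);
    (void)next;
  }
};

// Terminal case: I == sizeof...(ARGS), nothing left to fill.
template <std::size_t I, typename... ARGS>
struct DemoRegistrarRecursive<I, true, ARGS...> {
  explicit DemoRegistrarRecursive(DemoInfo*) {}
};

// Hypothetical registration arguments.
struct DemoOp { static std::string Name() { return "op"; } };
struct DemoMaker { static std::string Name() { return "maker"; } };
struct DemoGradMaker { static std::string Name() { return "grad_maker"; } };
```

With C++17, a fold expression such as `(DemoFiller<ARGS>{}(info), ...)` would express the same loop more directly; the recursive form works with C++11.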
For example, ops::ConvOp is an OperatorBase subclass, so it is handled by the OpInfoFiller<T, kOperator> specialization:
```cpp
template <typename T>
struct OpInfoFiller<T, kOperator> {
  void operator()(const char* op_type, OpInfo* info) const {
    PADDLE_ENFORCE_EQ(info->creator_, nullptr,
                      platform::errors::AlreadyExists(
                          "OpCreator of %s has been registered", op_type));
    // Store a lambda wrapping T's constructor in the info->creator_ callback.
    info->creator_ = [](const std::string& type, const VariableNameMap& inputs,
                        const VariableNameMap& outputs,
                        const AttributeMap& attrs) {
      return new T(type, inputs, outputs, attrs);
    };

    // If T derives from OperatorWithKernel, register its InferShape as well.
    if (std::is_base_of<OperatorWithKernel, T>::value) {
      PADDLE_ENFORCE_EQ(
          info->infer_shape_, nullptr,
          platform::errors::AlreadyExists(
              "Duplicate InferShapeFN of %s has been registered", op_type));

      // Create a prototype instance and forward infer_shape_ to it.
      OperatorWithKernel* op = dynamic_cast<OperatorWithKernel*>(info->creator_(
          std::string{}, VariableNameMap{}, VariableNameMap{}, AttributeMap{}));
      PADDLE_ENFORCE_NOT_NULL(op, platform::errors::InvalidArgument(
                                      "%s should have kernels", op_type));
      info->infer_shape_ = [op](InferShapeContext* ctx) {
        op->InferShape(ctx);
      };
    }
  }
};
```
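The two ideas in this filler, type-erasing T's constructor behind a uniform creator and keeping one prototype instance to serve InferShape, can be reproduced on a toy hierarchy. Every class below is a hypothetical stand-in and a "shape" is reduced to a single int. Unlike the original, which keeps a raw pointer to the prototype, the sketch holds it in a std::shared_ptr so it is released together with the lambda.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>

// Hypothetical stand-ins for the real class hierarchy.
struct OperatorBase {
  explicit OperatorBase(std::string type) : type_(std::move(type)) {}
  virtual ~OperatorBase() = default;
  std::string type_;
};

struct OperatorWithKernel : OperatorBase {
  explicit OperatorWithKernel(std::string type) : OperatorBase(std::move(type)) {}
  // Toy shape inference: a "shape" is a single int here.
  virtual int InferShape(int in_dim) const = 0;
};

struct ConvLikeOp : OperatorWithKernel {
  explicit ConvLikeOp(std::string type) : OperatorWithKernel(std::move(type)) {}
  int InferShape(int in_dim) const override { return in_dim - 2; }
};

struct WhileLikeOp : OperatorBase {
  explicit WhileLikeOp(std::string type) : OperatorBase(std::move(type)) {}
};

struct DemoOpInfo {
  std::function<OperatorBase*(const std::string&)> creator;
  std::function<int(int)> infer_shape;  // stays empty for non-kernel ops
};

template <typename T>
void FillOperator(DemoOpInfo* info) {
  // Type-erase T's constructor behind one creator signature.
  info->creator = [](const std::string& type) -> OperatorBase* {
    return new T(type);
  };
  if (std::is_base_of<OperatorWithKernel, T>::value) {
    // Build one prototype instance and route infer_shape through it.
    std::shared_ptr<OperatorBase> proto(info->creator(std::string{}));
    std::shared_ptr<OperatorWithKernel> op =
        std::dynamic_pointer_cast<OperatorWithKernel>(proto);
    info->infer_shape = [op](int in_dim) { return op->InferShape(in_dim); };
  }
}
```

The cast goes through the creator's OperatorBase* result rather than through T directly, so the same code compiles for T types that are not kernel ops; the runtime `if` simply skips it for them.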
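- OpInfoFiller selects its specialization through the following type deduction:
- First, an enum lists one value per kind of registration argument: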
```cpp
enum OpInfoFillType {
  kOperator = 0,
  kOpProtoAndCheckerMaker = 1,
  kGradOpDescMaker = 2,
  kVarTypeInference = 3,
  kShapeInference = 4,
  kInplaceOpInference = 5,
  kNoNeedBufferVarsInference = 6,
  kGradOpBaseMaker = 7,
  kUnknown = -1
};
```
- OpInfoFiller calls OpInfoFillTypeID to deduce the fill type of T:
```cpp
template <typename T, OpInfoFillType = OpInfoFillTypeID<T>::ID()>
struct OpInfoFiller;
```
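- The deduction again relies on template specialization: ID() returns the kType of the matching OpInfoFillTypeGetter specialization, which is itself computed recursively: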
```cpp
template <typename T>
struct OpInfoFillTypeID {
  static constexpr OpInfoFillType ID() {
    return internal::OpInfoFillTypeGetter<T>::kType;
  }
};

template <typename T>
using OpInfoFillTypeGetter =
    OpInfoFillTypeGetterImpl<T, 0, kOpRegistryClassNumber,
                             kOpRegistryClassNumber == 0,
                             IsMatchedBaseType<T, 0>()>;
```
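Here kOpRegistryClassNumber is the number of entries in OpRegistryClasses, the tuple that pairs each registry base class with its enum value: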
```cpp
using OpRegistryClasses = std::tuple<                                // NOLINT
    TypePair<OperatorBase, kOperator>,                               // NOLINT
    TypePair<OpProtoAndCheckerMaker, kOpProtoAndCheckerMaker>,       // NOLINT
    TypePair<GradOpDescMakerBase, kGradOpDescMaker>,                 // NOLINT
    TypePair<imperative::GradOpBaseMakerBase, kGradOpBaseMaker>,     // NOLINT
    TypePair<VarTypeInference, kVarTypeInference>,                   // NOLINT
    TypePair<InferShapeBase, kShapeInference>,                       // NOLINT
    TypePair<InplaceOpInference, kInplaceOpInference>,               // NOLINT
    TypePair<NoNeedBufferVarsInference, kNoNeedBufferVarsInference>  // NOLINT
    >;

static constexpr int kOpRegistryClassNumber =
    std::tuple_size<OpRegistryClasses>::value;
```
IsMatchedBaseType checks whether T matches the entry at position kPos, i.e. whether T derives from the base class stored there:
```cpp
template <typename T, int kPos>
static inline constexpr bool IsMatchedBaseType() {
  return IsMatchedBaseTypeImpl<
      T, kPos, (kPos >= 0 && kPos < kOpRegistryClassNumber)>::kValue;
}
```
If kPos is out of range (negative or past the end of the list), it returns false directly:
```cpp
template <typename T, int kPos>
struct IsMatchedBaseTypeImpl<T, kPos, false> {
  static constexpr bool kValue = false;
};
```
Otherwise, it uses std::is_base_of to test T against the base class recorded in the TypePair at that position:
```cpp
// TypePair binds a base class T to its OpInfoFillType enum value.
template <typename T, OpInfoFillType kType>
struct TypePair {
  using Type = T;
  static constexpr OpInfoFillType kFillType = kType;
};

// Checks whether T derives from the base class stored at
// OpRegistryClasses[kPos].
template <typename T, int kPos, bool kIsBounded /* = true*/>
struct IsMatchedBaseTypeImpl {
  using PairType = typename std::tuple_element<kPos, OpRegistryClasses>::type;
  static constexpr bool kValue =
      std::is_base_of<typename PairType::Type, T>::value;
};
```
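- OpInfoFillTypeGetterImpl is recursive: if IsMatchedBaseType returns false, it moves on to kStart+1 and compares again, until a match is found; it then returns kType, the enum value paired with the matched base class.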
```cpp
// No match at kStart: try kStart + 1.
template <typename T, int kStart, int kEnd>
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, false, false> {
  static constexpr OpInfoFillType kType =
      OpInfoFillTypeGetterImpl<T, kStart + 1, kEnd, kStart + 1 == kEnd,
                               IsMatchedBaseType<T, kStart + 1>()>::kType;
};

// Match at kStart: return kType, the enum value that OpRegistryClasses
// pairs with the matched base class.
template <typename T, int kStart, int kEnd>
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, false, true> {
  using PairType = typename std::tuple_element<kStart, OpRegistryClasses>::type;
  static constexpr OpInfoFillType kType = PairType::kFillType;
};
```
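Putting these pieces together, here is a self-contained, compile-time version of the dispatch, reduced to three base classes. All names are hypothetical stand-ins, and the terminal specialization that yields kUnknown when no base class matches is this sketch's own assumption, since that case is not shown above.

```cpp
#include <tuple>
#include <type_traits>

// Hypothetical base classes standing in for the registry's base classes.
struct OperatorBase {};
struct OpProtoAndCheckerMaker {};
struct VarTypeInference {};

enum FillType {
  kOperator = 0,
  kOpProtoAndCheckerMaker = 1,
  kVarTypeInference = 3,
  kUnknown = -1
};

template <typename T, FillType kType>
struct TypePair {
  using Type = T;
  static constexpr FillType kFillType = kType;
};

using RegistryClasses =
    std::tuple<TypePair<OperatorBase, kOperator>,
               TypePair<OpProtoAndCheckerMaker, kOpProtoAndCheckerMaker>,
               TypePair<VarTypeInference, kVarTypeInference>>;
constexpr int kClassNumber = std::tuple_size<RegistryClasses>::value;

// kIsBounded keeps tuple_element from being instantiated out of range.
template <typename T, int kPos, bool kIsBounded>
struct IsMatchedImpl {
  using PairType = typename std::tuple_element<kPos, RegistryClasses>::type;
  static constexpr bool kValue = std::is_base_of<typename PairType::Type, T>::value;
};
template <typename T, int kPos>
struct IsMatchedImpl<T, kPos, false> {
  static constexpr bool kValue = false;
};
template <typename T, int kPos>
constexpr bool IsMatched() {
  return IsMatchedImpl<T, kPos, (kPos >= 0 && kPos < kClassNumber)>::kValue;
}

template <typename T, int kStart, int kEnd, bool kStop, bool kIsMatched>
struct GetterImpl;

// No match yet: try the next position.
template <typename T, int kStart, int kEnd>
struct GetterImpl<T, kStart, kEnd, false, false> {
  static constexpr FillType kType =
      GetterImpl<T, kStart + 1, kEnd, kStart + 1 == kEnd,
                 IsMatched<T, kStart + 1>()>::kType;
};
// Match: return the enum value paired with the matched base class.
template <typename T, int kStart, int kEnd>
struct GetterImpl<T, kStart, kEnd, false, true> {
  static constexpr FillType kType =
      std::tuple_element<kStart, RegistryClasses>::type::kFillType;
};
// Ran past the end without a match (assumed terminal case).
template <typename T, int kStart, int kEnd>
struct GetterImpl<T, kStart, kEnd, true, false> {
  static constexpr FillType kType = kUnknown;
};

template <typename T>
constexpr FillType FillTypeOf() {
  return GetterImpl<T, 0, kClassNumber, kClassNumber == 0, IsMatched<T, 0>()>::kType;
}

// Hypothetical registration arguments.
struct ConvOp : OperatorBase {};
struct Conv2DOpMaker : OpProtoAndCheckerMaker {};
struct ConvOpInferVarType : VarTypeInference {};
struct Unrelated {};
```

Because each argument is classified by its base class alone, the variadic arguments of REGISTER_OPERATOR can appear in any order. The search stops at the first match, so a class deriving from two registry bases would be classified by whichever base comes first in the tuple.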
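- The third step of the macro defines a function that calls into the static OperatorRegistrar variable created in step two. Registration happens in that global variable's constructor, but code that links against the framework never refers to the variable, so the linker may drop it from the generated binary. Defining TouchOpRegistrar_##op_type, which calls the registrar's empty Touch() method, gives calling code (through the USE_OP macros) a symbol to reference; as long as that symbol is referenced, the registrar is kept.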
```cpp
// Step three of REGISTER_OPERATOR: define a function that calls Touch() on
// the registrar variable created in step two.
int TouchOpRegistrar_##op_type() {      \
  __op_registrar_##op_type##__.Touch(); \
  return 0;                             \
}

// Touch() itself is an empty function.
class Registrar {
 public:
  // In our design, various kinds of classes, e.g., operators and kernels,
  // have their corresponding registry and registrar. The action of
  // registration is in the constructor of a global registrar variable, which
  // are not used in the code that calls package framework, and would
  // be removed from the generated binary file by the linker. To avoid such
  // removal, we add Touch to all registrar classes and make USE_OP macros to
  // call this method. So, as long as the callee code calls USE_OP, the global
  // registrar variable won't be removed by the linker.
  void Touch() {}
};
```
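The Touch pattern can be sketched in a single file, although the linker behaviour it guards against only shows up when the registrar lives in a static library that the application links against. DEMO_REGISTER_OP and DEMO_USE_OP are hypothetical macros modelled on the registration side and on the USE_OP side described in the comment above.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical registry of op names, standing in for OpInfoMap.
inline std::set<std::string>& DemoRegistry() {
  static std::set<std::string> ops;
  return ops;
}

struct DemoRegistrar {
  explicit DemoRegistrar(const char* op_type) { DemoRegistry().insert(op_type); }
  void Touch() {}  // empty on purpose: calling it only creates a reference
};

// Library side: a static registrar plus a function that touches it.
#define DEMO_REGISTER_OP(op_type)                             \
  static DemoRegistrar demo_op_registrar_##op_type(#op_type); \
  int DemoTouchOpRegistrar_##op_type() {                      \
    demo_op_registrar_##op_type.Touch();                      \
    return 0;                                                 \
  }

// User side: referencing the touch function forces the linker to keep the
// object file that defines the registrar.
#define DEMO_USE_OP(op_type)                   \
  extern int DemoTouchOpRegistrar_##op_type(); \
  static int demo_use_op_##op_type = DemoTouchOpRegistrar_##op_type()

DEMO_REGISTER_OP(conv2d)
DEMO_USE_OP(conv2d);
```

In a real build the two macro expansions would sit in different files: the registration in the library's source, the use in the application, so the extern reference is what pulls the library's object file in.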
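II. Creating an OP
As covered in "PaddlePaddle Inference Source Code Analysis (II)" (《PaddlePaddle Inference源码分析(二)》), PrepareExecutor creates the OP objects and places them in the Executor. We start from that preparation stage.
- NaiveExecutor::Prepare sets up the Scope and then calls CreateOps: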
```cpp
void NaiveExecutor::Prepare(Scope *scope, const ProgramDesc &program_desc,
                            int block_id, bool with_feed_fetch_ops) {
  if (!scope) {
    scope_ = new framework::Scope;
  } else {
    scope_ = scope;
  }

  VLOG(3) << "NaiveExecutor init with scope " << scope;
  CreateOps(program_desc, block_id, with_feed_fetch_ops);
}
```
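- CreateOps takes the descriptions of all OPs (OpDesc, framework/block_desc.cc) that the ProgramDesc holds from the model file, and creates an OP object from each OpDesc. Unless with_feed_fetch_ops is set, feed and fetch OPs are skipped: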
```cpp
void NaiveExecutor::CreateOps(const ProgramDesc &desc, int block_id,
                              bool with_feed_fetch_ops) {
  for (const auto &op_desc : desc.Block(block_id).AllOps()) {
    if (!with_feed_fetch_ops &&
        (op_desc->Type() == "feed" || op_desc->Type() == "fetch")) {
      LOG(INFO) << "---  skip [" << op_desc->Input("X")[0] << "], "
                << op_desc->Type() << " -> " << op_desc->Output("Out")[0];
      continue;
    }
    ops_.emplace_back(OpRegistry::CreateOp(*op_desc));
  }
}
```
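- The OpRegistry::CreateOp functions are all static. They look up the OpInfo for the op_type in the OpInfoMap (a global singleton), run the attribute checker if one is registered, and then call the OpInfo's Creator, the lambda wrapping the OP's constructor that OpInfoFiller<T, kOperator> stored during registration, to create the OperatorBase object.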
```cpp
// Entry point: extract the fields of the OpDesc and forward them.
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
  return CreateOp(op_desc.Type(), op_desc.Inputs(), op_desc.Outputs(),
                  op_desc.GetAttrMap());
}

// Does the actual work.
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
    const std::string& type, const VariableNameMap& inputs,
    const VariableNameMap& outputs, AttributeMap attrs, bool attr_check) {
  auto& info = OpInfoMap::Instance().Get(type);
  if (attr_check && info.Checker() != nullptr) {
    info.Checker()->Check(&attrs);
  }
  auto op = info.Creator()(type, inputs, outputs, attrs);
  return std::unique_ptr<OperatorBase>(op);
}
```
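The whole creation path, from descriptions to operator objects, fits in a small sketch. OpDesc, OpInfo, Infos, CreateOp and CreateOps here are hypothetical simplifications: attributes are plain ints, the checker is modelled as an assumed callback that may fill in default attributes, and the feed/fetch skipping follows NaiveExecutor::CreateOps.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

using AttributeMap = std::map<std::string, int>;  // simplified attributes

struct OperatorBase {
  OperatorBase(std::string t, AttributeMap a)
      : type(std::move(t)), attrs(std::move(a)) {}
  virtual ~OperatorBase() = default;
  std::string type;
  AttributeMap attrs;
};

struct OpInfo {
  std::function<void(AttributeMap*)> checker;  // may be empty
  std::function<OperatorBase*(const std::string&, const AttributeMap&)> creator;
};

// Hypothetical global registry, standing in for OpInfoMap.
std::map<std::string, OpInfo>& Infos() {
  static std::map<std::string, OpInfo> infos;
  return infos;
}

struct OpDesc {
  std::string type;
  AttributeMap attrs;
};

// Look up the info, run the checker, then call the creator.
std::unique_ptr<OperatorBase> CreateOp(const OpDesc& desc) {
  auto it = Infos().find(desc.type);
  if (it == Infos().end()) throw std::runtime_error("unregistered op: " + desc.type);
  AttributeMap attrs = desc.attrs;  // copied so the checker may modify it
  if (it->second.checker) it->second.checker(&attrs);
  return std::unique_ptr<OperatorBase>(it->second.creator(desc.type, attrs));
}

// Create every op of a block, optionally skipping feed/fetch ops.
std::vector<std::unique_ptr<OperatorBase>> CreateOps(
    const std::vector<OpDesc>& descs, bool with_feed_fetch_ops) {
  std::vector<std::unique_ptr<OperatorBase>> ops;
  for (const auto& d : descs) {
    if (!with_feed_fetch_ops && (d.type == "feed" || d.type == "fetch")) continue;
    ops.push_back(CreateOp(d));
  }
  return ops;
}
```

Returning std::unique_ptr from the factory makes ownership explicit, which matches how the executor stores the created ops.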
III. Invoking an OP
To be added once the predictor invocation part has been written.