一、计算图简介
在pytorch的官网上,可以看到一个简单的计算图示意图, 如下。
import torch
from torch.autograd import Variable x = Variable(torch.randn(1, 10)) prev_h = Variable(torch.randn(1, 20)) W_h = Variable(torch.randn(20, 20)) W_x = Variable(torch.randn(20, 10)) i2h = torch.mm(W_x, x.t()) h2h = torch.mm(W_h, prev_h.t()) next_h = i2h + h2h next_h = next_h.tanh()
这个图里有两种节点:Variable节点和Function节点,Variable记录运算数据,Function记录运算操作。其中Variable节点又可以分为叶节点和非叶节点两类。叶节点由用户直接创建产生,而非叶节点则由Variable节点之间的运算操作产生,在图的代码中,x、prev_h、W_h、W_x属于叶节点,i2h、h2h、next_h属于非叶节点。
在这个图上,节点之间的关系是很明确的:Variable非叶节点指向产生它的Function,因为产生某个Variable的Function只可能有一个,因此一个Variable只指向一个Function。Function的指向则是可以一对多的,因为一个运算函数往往可以接受大量的参数。Function指向两种节点,当Function接受一个叶节点的Variable输入时,Function需指向此Variable,当Function接受一个非叶节点Variable输入时,Function需指向此Variable所指向的那个Function。
那这个计算图是怎么建立的,具体实现又是怎么样的呢?我们通过从顶向下、从底至上的两个视角分别切入,来研究计算图的形成过程。
二、框架性视角
2.1 Variable与Function
首先从顶至下,大略框架性地了解一下pytorch自动求导模块几个类,其中最重要的便是Variable类和Function类。此处注意的是,因为在C++代码中与python代码中均有名为Variable、Function的类,为示区别,如在不同语言中的类名有重复,在C++中的类称为如Variable(C++),在python中的类称为如Variable(py),当不刻意去分辨两者的区别时则不特意加后缀括号,以此类推。
这里提到的类是Varaible(C++)和Function(C++),分别定义在torch/csrc/autograd/variable.h、torch/csrc/autograd/function.h,此处不复制粘贴代码了。Variable(C++)类为自动求导过程中的核心数据类,与gif中的Variable节点可对应;Fucntion(C++)为自动求导过程中的函数类,与gif中的Fcuntion节点可对应。
首先解释Variable(C++)类的几个成员变量,
std::unique_ptr<thpp::Tensor> data:这是具体的底层数据,Variable(C++)为了完成自动求导的任务,在持有这个数据对象的条件下进行了一些包装。
bool requires_grad,bool is_volatile:是两个求导选项,在Variable(python)的构造函数中,可以传中两个选项参数。对于叶节点,如果有x.requires_grad==False或者x.volatile==True,则不需要对x进行求导。然而,这两个选项的不同在于,对于非叶节点,当产生非叶节点的所有Variable节点的requires_grad都是False时,它的requires_grad才是False,而只要有一个产生它的Variable节点的volatile是True,那它的volatile就是True。前者通常用于固定某个不需或暂时不需迭代参数的模块,如迁移学习、GAN训练等场景中的常见情形;后者则通常用于明确地确认此部分任务不需要执行反向求导的情形,如一个深度学习模型的测试过程。这两个选项可能会产生冲突,当冲突时,以volatile为准。其它阅读资料可参考http://pytorch.org/docs/master/notes/autograd.html
std::shared_ptr<Function> grad_fn:这就是计算图中,由Vairbale节点指向Function节点的连接。看命名的方式,是grad_fn,即gradient function。说明实际上,Variable节点与Function节点之间的连接,并不完全像gif中所示,连接前向计算时调用的函数对象,而是连接对应的梯度函数。这个连接是如何被建立的,将是文章后半段重点探索的内容。
然后考虑Function(C++)类的几个成员变量,
bool is_executable:这是Funciton节点中,类似于Varibale节点里requires_grad、is_volatile标识的一个成员变量。如果给某个Function节点输入的所有Vairbale节点都有requires_grad==False,或者少有一个的volatile==True,那这个Function的is_executable就会为False。结合Variable(C++)和Function(C++)求导标识的逻辑,可以在Function::flags方法里看见,逻辑正如此前描述一样。
function_list next_functions:在Function(C++)定义的同文件内,也有function_list的定义,using function_list = std::vector<std::pair<std::shared_ptr<Function>, int>>。可见,通过next_functions可以访问到一系列的Function(C++)对象,直观地推断,它就是gif图中,Function节点与Function节点连接的关键。这个连接是如何被建立的,将是文章后半段重点探索的内容。
这里我们留了一个小疑问,在之前对gif的非严谨分析里,有叶节点输入的Function节点,其会有一个指向叶节点对象的连接,但是在Function类里没有发现有对应的成员变量,那这里是如何实现的呢,可留待具体实现分析的时候查看。
2.2 从C++到Python,以Function类为例
这一段希望解释Function(C++)类与THPFunction等类的关系,不可避免地涉及部分python-C API的内容,但讲解得较粗略,关于这部分详细的内容可以去阅读 专门讲解python的C扩展的文章。
C++中,除了拥有底层逻辑的类以外,还有一层向python包装的中间类,比如,Function(C++)类就是通过一个THPFunction类、一个PyTypeObject类实例THPFunctionType,包装成一个python里可以访问的torch._C._FunctionBase类的。这几个类(或实例)之间的关系是什么呢?
首先看THPFunction类,它定义在torch/csrc/autograd/python_fucntion.h,THPFunction类持有一个PyFunction对象,而PyFunction类在同文件内定义并继承Function类。THPFunction类的其它成员变量中,有部分是PyObject*类的,这部分通常被设计于暴露给python层,还有一部分不是PyObject*类的,它们在python层中不可见,仅在C++层的代码逻辑中进行运作。
struct THPFunction { PyObject_HEAD PyObject *needs_input_grad; // Python tuple of tensors whose variables we should save. Set // by Python with 'save_for_backward'. If NULL, no tensors were // saved. PyObject *to_save; // Python pairs of distinct tensors which share storage. Set by // Python with 'mark_shared_storage'. If NULL, no tensors share // storage. PyObject *shared_pairs; // Python tuple of tensors which are not differentiable. Set by // Python with 'mark_non_differentiable'. If NULL, no tensors were // non-differentiable. PyObject *non_differentiable; // Python tuple of tensors which had inplace updates in the forward() // pass. Set by Python with 'mark_dirty'. If NULL, no tensors were // modified inplace. PyObject *dirty_tensors; std::vector<output_info_type> *output_info; std::vector<torch::autograd::SavedVariable> *saved_variables; // For each input, true if the input is a THPVariable std::vector<bool> *is_variable_input; char has_freed_buffers; // The C++ wrapper for this Python function. // See a comment in THPFunction_asFunction for details about this field. torch::autograd::PyFunction cdata; };
再看PyTypeObject类的实例THPFunctionType,它定义在torch/csrc/autograd/python_fucntion.cpp中,从注释上可以看出来,它定义了一个python类的诸多基本操作。比如,如果python层创建一个对象的时候,要知道需要分配多大的空间,就到PyTypeObject负责tp_basicsize的那个slot里面去找,在这个个例里,它的值是sizeof(THPFunction);又如,这个类封装到python层以后,有哪些方法呢,这个可以在tp_methods的这个域找到,在这个个例里,它的值是THPFunction_properties,THPFunction_properties这个变量也定义在同样的文件夹下,它负责把C++的函数映射成python的类方法。
PyTypeObject THPFunctionType = { PyVarObject_HEAD_INIT(NULL, 0) "torch._C._FunctionBase", /* tp_name */ sizeof(THPFunction), /* tp_basicsize */ 0, /* tp_itemsize */ (destructor)THPFunction_dealloc, /* tp_dealloc */ 0, /* tp_print */ 0, /* tp_getattr */ 0, /* tp_setattr */ 0, /* tp_reserved */ 0, /* tp_repr */ 0, /* tp_as_number */ 0, /* tp_as_sequence */ 0, /* tp_as_mapping */ 0, /* tp_hash */ 0, /* tp_call */ 0, /* tp_str */ 0, /* tp_getattro */ 0, /* tp_setattro */ 0, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /* tp_flags */ NULL, /* tp_doc */ (traverseproc)THPFunction_traverse, /* tp_traverse */ (inquiry)THPFunction_clear, /* tp_clear */ 0, /* tp_richcompare */ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ THPFunction_methods, /* tp_methods */ 0, /* tp_members */ THPFunction_properties, /* tp_getset */ 0, /* tp_base */ 0, /* tp_dict */ 0, /* tp_descr_get */ 0, /* tp_descr_set */ 0, /* tp_dictoffset */ 0, /* tp_init */ 0, /* tp_alloc */ THPFunction_new /* tp_new */ };
_FunctionBase这个类是在哪里创建的呢,实际就在上述同样的文件里,THPFunctionType定义下方,可见函数THPFunction_initModule,其中有一句调用PyModule_AddObject,如此便注册了一个新的python类_FunctionBase
bool THPFunction_initModule(PyObject *module) { if (PyType_Ready(&THPFunctionType) < 0) return false; Py_INCREF(&THPFunctionType); PyModule_AddObject(module, "_FunctionBase", (PyObject *)&THPFunctionType); return true; }
我们看到PyModue_AddObject调用的参数里,只传了THPFunctionType这个对象,却没见到THPFunction相关的信息,那_FunctionBase是怎么样会与THPFunction扯上关系的呢?答案是通过THPFunctionType的各个slot下变量的具体定义。比如,tp_new这个slot下,值为THPFunction_new,THPFunction_new在同样的文件下定义,它是一个函数
PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) { PyObject* obj = type->tp_alloc(type, 0); if (!obj) return NULL; // Python zero-initializes the object memory, so there's no need to initialize // most fields THPFunction* self = (THPFunction*)obj; new (&self->cdata) PyFunction(obj); self->cdata.num_inputs = -1; self->cdata.is_stochastic = PyObject_IsInstance(obj, THPStochasticFunctionClass); return obj; }
在这个函数体的第5行,它把新分配的内存强转为了THPFunction *并赋值出去,这就实现了THPFunctionType对象和THPFunction类的联系,在其它的属性操作上也是如此。
总结来说,Function(C++)类定义了一个类的底层逻辑;THPFunction类会持有一个Function(C++)类对象,并暴露了一些可以在python层访问的数据;PyTypeObject类的对象THPFunctionType里,各个slot定义了在python层里这个类的诸多基本操作(包括构造、析构、成员变量、方法等等等等);_FunctionBase是被包装好后的python类,在python中可以通过import torch._C._FunctionBase访问到它。以此类推地,Variable(C++)类、THPVariable类、PyTypeObject类的对象THPVariableType、_VariableBase也是类似的关系。
至于python层的Variable(py)类、Function(py)类,分别被定义在torch/autograd/variable.py、torch/autograd/function.py里,可以看到它们分别跟_VariableBase、_FunctionBase有一定的继承关系
class Variable(_C._VariableBase): ............................ ............................ ............................
class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin)):
............................
............................
............................
三、实现细节
我们了解pytorch中几个关键类的关系与结构之后,通过自底向下,从具体代码追踪调用的方法,探索整个计算图是怎么样形成的。
3.1 python层代码
首先,假设有下面的python代码,并得到结果
import torch from torch.autograd import Variable x = Variable(torch.Tensor([[1, 2, 3], [4, 5, 6]]), requires_grad=True) y = x.prod(dim=1, keepdim=True) print(y.grad_fn) print(y.grad_fn.next_functions)
<torch.autograd.function.ProdBackward object at 0x000001F18FB912F8>
((<AccumulateGrad object at 0x000001F18FBC2710>, 0),)
可以看到,通过简单的创建数据、运算赋值这两行代码,计算图已经建立起来了。这个操作的背后具体是怎么建立起来的呢,先从x.prod()这一方法的定义追踪起。在torch/autograd/variable.py里可以看见
class Variable(_C._VariableBase): ……………………… def prod(self, dim=None, keepdim=None): return Prod.apply(self, dim, keepdim) ………………………
根据追踪,可以发现这里提及的Prod类定义在torch/autograd/_functions/reduce.py里,继承Function(py)类,有两个类方法forward、backward,但是没有显式地定义apply方法。不管如何,此前的赋值语句可以视为
from torch.autograd._functions import Prod y = Prod.apply(x, 1, True) # same as y = x.prod(dim=1, keepdim=True)
因为Prod没有显示地定义apply方法,所以我们需要到它的父类里找apply方法,Prod继承Function(py)类,Function(py)类的定义可以在torch/autograd/function.py里面找到
class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin)): # only for backward compatibility __call__ = _C._FunctionBase._do_forward @staticmethod def forward(*args, **kwargs): raise NotImplementedError @staticmethod def backward(*grad_outputs): raise NotImplementedError
Function(py)类有两个待子类实现的方法,但也没有定义apply方法,它的父类是以一种元类的形式动态产生的,这里不详细解释元类的具体产生知识,但我们至少看代码上可以知道,可以到FunctionMeta、_C._FunctionBase、_ContextMethodMixin、_HookMixin几个类里面去找apply方法。其中除了第二个类以外,另外三个都在python代码中进行定义(和Function(py)类都在同一个文件内),很容易就发现都没有定义apply方法。_FunctionBase类此前第一部分详细提过,由C++源代码包装而得。
在这里,先暂缓一下apply方法的追踪,分神看一看FunctionMeta这个元类的实现
class FunctionMeta(type): """Function metaclass. This metaclass sets up the following properties: _is_legacy: True if forward is not defined as a static method. _backward_cls: The Function class corresponding to the differentiated version of this function (which is generated on the fly by this metaclass). """ def __init__(cls, name, bases, attrs): for super_cls in cls.mro(): forward = super_cls.__dict__.get('forward') if forward is not None: has_static_forward = isinstance(forward, staticmethod) or isinstance(forward, classmethod) break setattr(cls, '_is_legacy', not has_static_forward) # old-style functions if not has_static_forward: return super(FunctionMeta, cls).__init__(name, bases, attrs) backward_fn = type(name + 'Backward', (BackwardCFunction,), {'_forward_cls': cls}) setattr(cls, '_backward_cls', backward_fn) return super(FunctionMeta, cls).__init__(name, bases, attrs)
这个元类动态给生成的类设置了一个_backward_cls属性,这个属性的值是backward_fn,而backward_fn又是一个动态生成的类,这个类由type函数创建,第一个参数是类名,第二个参数是继承的父类集合,第三个参数是类属性名与具体对象的映射字典。我们在交互界面Prod来进行一些操作验证一下。
from torch.autograd._functions import Prod Prod._backward_cls Out[2]: torch.autograd.function.ProdBackward Prod._backward_cls() Out[3]: <torch.autograd.function.ProdBackward at 0x1caee4ae450> Prod._backward_cls_.forward_cls Out[4]: torch.autograd._functions.reduce.Prod Prod._backward_cls._forward_cls.apply? Docstring: <no docstring> Type: builtin_function_or_method Prod._backward_cls.apply? Signature: Prod._backward_cls.apply(self, *args) Docstring: <no docstring> File: c:anaconda3libsite-packages orchautogradfunction.py Type: function
可以看到,FunctionMeta元类动态地给Prod类生成了一个_backward_cls属性,这个属性的值是一个类,类的名字叫ProdBackward, 符合源代码中类名为name + 'Backward’的构造形式。将Prod._backward_cls实例化以后可以得到一个对应的对象。因为这个动态类建立的时候给它定义了一个_forward_cls的属性,映射回类本身,所以Prod._backward_cls._forward_cls又能访问回Prod类。
那像这类动态生成的Backward类的继承关系又是怎么样的呢?从type的第二个参数可以看到,它的父类是BackwardCFunction,定义在和Function类同样的文件夹里。
class BackwardCFunction(_C._FunctionBase, _ContextMethodMixin, _HookMixin): _is_legacy = False def apply(self, *args): return self._forward_cls.backward(self, *args)
可以看到,BackwardCFunction类除了FunctionMeta类以外,也继承了另外三个重要的类,然后覆写了apply方法。所以ProdBackward和Prod的整体操作都是很像的,主要区别只有两个。其一是Prod在创建的时候,会动态生成一个PordBackward类;其二是,Prod从_FunctionBase继承apply方法,而ProdBackward继承的apply则是其是父类覆写后的apply。
3.2 C++层代码 - THPFunction_apply
于是,现在回过头来继续追踪Prod的apply方法。_FunctionBase在torch/csrc/autograd/python_function.cpp里被注册,在同样文件里,变量THPFunction_methods指明了C++函数与python对象方法的映射,可以看到_FunctionBase.apply相当于就是调用了THPFunction_apply函数。看一下这个函数的具体定义
PyObject *THPFunction_apply(PyObject *cls, PyObject *_inputs) { HANDLE_TH_ERRORS THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls")); if (!backward_cls) return NULL; THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, NULL)); if (!ctx_obj) return NULL; THPFunction* ctx = (THPFunction*)ctx_obj.get(); // Prepare inputs and allocate context (grad fn) auto info_pair = unpack_input<false>(_inputs); auto& unpacked_input = info_pair.first; auto& input_info = info_pair.second; bool is_volatile = input_info.flags.is_volatile; ctx->cdata.set_flags(std::move(input_info.flags)); ctx->needs_input_grad = input_info.needs_input_grad.release(); ctx->is_variable_input = new std::vector<bool>(std::move(input_info.is_variable_input)); // Prepend ctx to tensor_input, in preparation for static method call auto num_args = PyTuple_GET_SIZE(_inputs); THPObjectPtr ctx_tensor_input(PyTuple_New(num_args + 1)); PyTuple_SET_ITEM(ctx_tensor_input.get(), 0, ctx_obj.release()); for (int i = 0; i < num_args; ++i) { PyObject *arg = PyTuple_GET_ITEM(unpacked_input.tensor_input.get(), i); Py_INCREF(arg); PyTuple_SET_ITEM(ctx_tensor_input.get(), i + 1, arg); } // Call forward THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward")); if (!forward_fn) return NULL; THPObjectPtr tensor_outputs(PyObject_CallObject(forward_fn, ctx_tensor_input)); if (!tensor_outputs) return NULL; return process_outputs(ctx, unpacked_input, std::move(tensor_outputs), is_volatile); END_HANDLE_TH_ERRORS }
函数比较长,源代码书写的时候也明确地分成了几个部分,每写完一个部分就用一个空行来分隔,一共可以算作分了五段。每段代码都简单总结如下:
a.传入参数PyObject* cls和PyObject* _inputs。cls代表调用这个函数的类本身,即python中的Prod类。_inputs则代表这个函数在python中调用时的所有参数,以tuple的形式打包成_inputs,对于最初的示例代码而言,则相当于python中的(x, 1, True)。第一段写成类python代码时,形如
ctx = Prod._backward_cls()
b.参数通过unpack_input函数解析参数_inputs,得到unpacked_input和input_info两个对象,该函数与返回值的定义将稍后细查,先略带剧透地进行类比,unpacked_input.tensor_input的值类似于在python中的(x.data, 1, True),把原来的input_的值(x, 1, True)中的Variable换成了Tensor,其余参数保持一致。
c.重处理输入参数,将这一段翻译成类python代码如下所示,基本目的就是生成一个比_inputs长1的tuple,把这个tuple的首位赋值为ctx,剩余的按unpacked_input.tensor_input依序填入。
num_args = len(_inputs) ctx_tensor_input = (None, ) * (num_args + 1) ctx_tensor_input[0] = ctx for i in range(num_args): arg = unpacked_input.tensor_input[i] ctx_tensor_input[i + 1] = arg
d.运行forward计算,得到一个Tensor对象,这一段写成类python代码如下所示。PyObject_CallObject调用了python代码中的函数,也就是Prod类中的forward方法,稍后将回来追踪此方法的实现。
forward_fn = Prod.forward tensor_outputs = forward_fn(*ctx_tensor_input) #ctx_tensor_input = (ctx, x.data, 1, True)
e.调用process_outputs,将返回的Tensor对象包装成一个Variable对象,并返回
在五个部分中,a、c部分相对简单,b、d、e部分均调用了其它效果不甚显然的函数,计算图是在哪一部分形成的连接呢?以下进行详细的解析
3.3 C++层代码 - unpack_input与参数解析
unpack_input是THPFunction_apply的b部分解析参数_inputs的重要函数,它返回由一个UnpackedInput实例,和一个InputFlags实例组成的std::pair,这两个类的定义恰好在unpack_input的定义之前
struct UnpackedInput { PyObject *raw_input; THPObjectPtr tensor_input; variable_list input_vars; }; struct InputFlags { FunctionFlags flags; THPObjectPtr needs_input_grad; std::vector<bool> is_variable_input; };
从类及类成员变量的命名上可猜测,UnpackedInput主要保存_inputs解析后的数据,InputFlags类通过逐个解析_inputs的分量,来判断每个变量的求导标识。在这种基本先验思想指导之下,查看unpack_input的代码
template<bool enforce_variables> std::pair<UnpackedInput, InputFlags> unpack_input(PyObject *args) { UnpackedInput unpacked; InputFlags flags; auto num_args = PyTuple_GET_SIZE(args); unpacked.tensor_input = PyTuple_New(num_args); flags.needs_input_grad = PyTuple_New(num_args); for (int i = 0; i < num_args; i++) { PyObject *arg = PyTuple_GET_ITEM(args, i); PyObject *new_arg; bool is_variable = THPVariable_Check(arg); flags.is_variable_input.push_back(is_variable); if (!is_variable) { if (enforce_variables) { THPUtils_setError("expected a Variable argument, but got %s", THPUtils_typename(arg)); throw python_error(); } Py_INCREF(arg); new_arg = arg; Py_INCREF(Py_False); PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, Py_False); } else { THPVariable* variable = (THPVariable*)arg; new_arg = THPVariable_get_data(variable); unpacked.input_vars.push_back(variable->cdata); PyObject* needs_grad = variable->cdata->requires_grad ? Py_True : Py_False; Py_INCREF(needs_grad); PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, needs_grad); } PyTuple_SET_ITEM(unpacked.tensor_input.get(), i, new_arg); } flags.flags = Function::flags(unpacked.input_vars); return std::make_pair(std::move(unpacked), std::move(flags)); }
由于代码冗长,不如单独追踪两个类的每个成员变量,如对于UnpackedInput类的tensor_input域而言,循环体实际如下。如果某变量不是一个THPVariable,则直接添加到unpacked.tensor_input中,如果某变量是一个THPVariable,则对unpacked.tensor_input添加这个变量的data域,相当于python中的Tensor。
template<bool enforce_variables> std::pair<UnpackedInput, InputFlags> unpack_input(PyObject *args) { .................................. for (int i = 0; i < num_args; i++) { PyObject *arg = PyTuple_GET_ITEM(args, i); PyObject *new_arg; bool is_variable = THPVariable_Check(arg) ..............................if (!is_variable) { .................................... new_arg = arg; .................................... } else { THPVariable* variable = (THPVariable*)arg; new_arg = THPVariable_get_data(variable); ................................................ } PyTuple_SET_ITEM(unpacked.tensor_input.get(), i, new_arg); } ................................................... }
对两个类各个成员变量进行分析,可以得知解析函数后具体产生的数值,简易但不太严谨的简述如下:
UnpackedInput.tensor_input:和_inputs一样长的tuple,如果_inputs[i]不是THPVariable,那就维持tensor_input[i] = inputs[i],如果是THPVariable,则tensor_input[i] = input[i].data。值得注意的是,虽然这个成员变量的名字叫做tensor_input,但并不代表tensor_input的每一个分量都是tensor,实际上,它有可能包括大量的计算参数。
UnpackedInput.input_vars:如果_inputs[i]是一个THPVariable,添加_inputs[i].cdata(Variable(C++)类)到input_vars中,否则忽略
InputFlags.flags:UnpackedInput.input_vars把输入的Variable(C++)收集完毕后,调用Function::flags来判断求导标识。基本规则按照2.1中提过的进行。
InputFlags.needs_input_grad:和_inputs一样长的tuple,如果_inputs[i]需要求导则为PyTrue,不需要求导则为PyFalse。不需要求导的可能性有两种,一是_inputs[i]本身就不是THPVariable,二是虽然_inputs[i]是THPVariable但其求导标识不需要计算图对它求导。
InputFlags.is_variable_input:和inputs一样长的vector,如果_inputs[i]是THPVariable则为True,否则为False
对于本节开始的python代码,几个域的值以类python的风格写出来大致如下:
_inputs:(x, 1, True)
UnpackedInput.tensor_input:(x.data, 1, True)
UnpackedInput.input_vars:(x,)
InputFlags.needs_input_grad:(True, False, False)
InputFlags.is_variable_input:(True, False, False)
非常值得注意的是InputsFlags.flags这个域,它调用了Function::flags,看一下Function::flags的源代码,在torch/csrc/autograd/function.cpp
auto Function::flags(const variable_list& inputs) -> FunctionFlags { int num_inputs = inputs.size(); FunctionFlags f; f.is_executable = false; f.is_volatile = false; f.next_functions.resize(num_inputs); for (int i = 0; i != num_inputs; ++i) { auto& var = inputs[i]; if (var) { f.is_executable |= var->requires_grad; f.is_volatile |= var->is_volatile; if (var->grad_fn) { f.next_functions[i] = std::make_pair<>(var->grad_fn, var->output_nr); } else { f.next_functions[i] = std::make_pair<>(var->get_grad_accumulator(), 0); } } } f.is_executable &= !f.is_volatile; return f; }
可以看到,除了一直有提到的is_volatile、requires_grad、is_executable之间的关系,很令人感兴趣的一点是FunctionFlags的next_functions域也有被操作,当某个input有grad_fn属性的时候(换句话说,不是叶节点的时候),则FunctionFlag的next_functions域的某分量会指向这个grad_fn;当某个input没有grad_fn的时候,则next_functions的这个分量会指向一个Function(C++)对象。这是一个从FunctionFlags类实例到Fucntion(C++)类实例的连接,实际上,它已经与Function(C++)类实例与Function(C++)类实例的连接非常接近了。是在哪里做出这个转换的呢?可以从调用完iunpack_input后,THPFunction_apply的代码往下查阅。
PyObject *THPFunction_apply(PyObject *cls, PyObject *_inputs) { ................................................... // Prepare inputs and allocate context (grad fn) auto info_pair = unpack_input<false>(_inputs); auto& unpacked_input = info_pair.first; auto& input_info = info_pair.second; bool is_volatile = input_info.flags.is_volatile; ctx->cdata.set_flags(std::move(input_info.flags)); ctx->needs_input_grad = input_info.needs_input_grad.release(); ctx->is_variable_input = new std::vector<bool>(std::move(input_info.is_variable_input)); .................................................... }
代码摘抄区域的第5行,ctx->cdata调用下成员函数set_flags,参数是由_inputs经过unpack_input得到的input_info中的flags域。查看Function::set_flags这个方法,在torch/csrc/autograd/function.h中,显而易见地,经过这个方法的调用,计算图最终实现了Function(C++)与Function(C++)的连接。更详细地,ctx是当前output即将会指向的梯度函数(在下面的process_outputs中实现),被连接的,即处于ctx的next_functions域中的,则是所有计算出output的inputs对象各种指向的梯度函数,这是一个由output梯度函数指向inputs梯度函数而形成相连的链条。
struct Function { ....................................... inline void set_flags(FunctionFlags&& flags) { is_executable = flags.is_executable; next_functions = std::move(flags.next_functions); } ....................................... }
3.4 python层代码 - forward函数调用
在THPFunction_apply函数的d部分,C++代码通过PyObject_CallObject调用python中的函数,进行前向运算。对于初始的python示例代码,它相当于在python中调用了
Prod.forward(ctx, x.data, dim=1, keepdim=True)
观察Prod的forward方法
class Prod(Function): @staticmethod def forward(ctx, input, dim=None, keepdim=None): ctx.dim = dim ctx.keepdim = False if keepdim is None else keepdim ctx.input_size = input.size() if dim is None: ctx.result = input.prod() ctx.save_for_backward(input) return input.new((ctx.result,)) else: if keepdim is not None: output = input.prod(dim, keepdim=keepdim) else: output = input.prod(dim) ctx.save_for_backward(input, output) return output
....................................
整个函数的定义相对明确且简单,有两个值得提的点。第一,输入的input形参与返回值output,在python中均为Tensor类,而非Variable类;第二,将input和output控制为Tensor类的原因在于,底层的Tensor类已经设计好了一套数据运算方法,如果不调用基于Tensor的方法,而在Variable上建立新的运算规则,不利于分层维护的原则,也会造成较大的资源浪费。
3.4 C++层代码 - process_outputs
THPFunction_apply的d部分,数据通过前向运算,得到了output,但是这个output只是一个Tensor,还未被包装为Variable。从计算图的角度看,现在虽然已经有了从Function节点到Function节点的连接,但是从Variable节点到Function节点的连接却还未建立。e部分,process_outputs就是处理这种后续工作的。
PyObject *THPFunction_apply(PyObject *cls, PyObject *_inputs) { HANDLE_TH_ERRORS ..................................
return process_outputs(ctx, unpacked_input, std::move(tensor_outputs), is_volatile); END_HANDLE_TH_ERRORS }
看process_outputs函数的定义
PyObject* process_outputs(THPFunction* grad_fn, const UnpackedInput& unpacked, THPObjectPtr&& raw_output, bool is_volatile) { bool unpack_output = _ensure_tuple(raw_output); auto num_outputs = PyTuple_GET_SIZE(raw_output.get()); THPObjectPtr outputs(PyTuple_New(num_outputs)); if (!outputs) throw python_error(); grad_fn->cdata.num_inputs = num_outputs; // Initialize t2var map t2var_type t2var; for (auto& c_var : unpacked.input_vars) { THPVariable* py_var = (THPVariable*)c_var->pyobj; t2var.emplace(py_var->data, py_var); } std::unordered_set<PyObject *> dirty_inputs; _mark_dirty(grad_fn, t2var, dirty_inputs); _wrap_outputs(grad_fn, t2var, dirty_inputs, raw_output, outputs, is_volatile); _join_version_counters(grad_fn, t2var); if (grad_fn->cdata.is_executable) { _mark_non_differentiable(grad_fn, t2var); _save_variables(grad_fn, t2var); } else { // Remove unnecessary attributes Py_XDECREF(grad_fn->to_save); grad_fn->to_save = NULL; Py_XDECREF(grad_fn->non_differentiable); grad_fn->non_differentiable = NULL; } // Unpack the output, unless .forward() returned a tuple if (unpack_output) { PyObject *output = PyTuple_GET_ITEM(outputs.get(), 0); Py_INCREF(output);_sa return output; } return outputs.release(); }
看到这个函数一样继续往下调用了很多其它的函数,按顺序包括_mark_dirty、_wrap_outputs、_join_version_counters、_mark_non_differentiable、_save_variables。这几个函数可以分成两类,_wrap_outputs作为核心的包装函数属于一类,其它的四个函数属于另一类。转到_mark_diry等四个函数的定义一看,函数内部开始时,都会先检查某个grad_fn的成员变量。
static void _mark_dirty(THPFunction *self, t2var_type &t2var, std::unordered_set<PyObject *> &dirty_inputs) { // Increase versions of modified tensors if (!self->dirty_tensors) return; ................................ } static void _save_variables(THPFunction* self, t2var_type &t2var) { if (!self->to_save) return; ................................. } static void _join_version_counters(THPFunction *self, t2var_type &t2var) { if (!self->shared_pairs) return; ................................ } static void _mark_non_differentiable(THPFunction *self, t2var_type &t2var) { if (!self->non_differentiable) return; ................................ }
这些成员变量都是从哪里来的呢?检查THPFunction的类定义dirty_tensors、to_save、shared_pairs、non_fidderentiable均有PyObject*,可以在python中访问到。在torch/autograd/function.py的python源代码中,可以看见一个 _ContextMethodMixin类,它是Function类的元类组成部分之一,它的方法操作这这些变量的赋值。另外,从这些方法的注释里可以发现,这些方法只允许在Funciton(py)类的forward方法中被调用,各自解决一些特定的问题,如mark_dirty处理数据原地操作后如何正确建立计算图的问题。
class _ContextMethodMixin(object): def save_for_backward(self, *tensors): """Saves given tensors for a future call to :func:`~Function.backward`. **This should be called at most once, and only from inside the** :func:`forward` **method.** Later, saved tensors can be accessed through the :attr:`saved_tensors` attribute; or, if the corresponding Variable is needed (e.g. for double backwards), those can be accessed through the :attr:`saved_variables` attribute. Before returning them to the user, a check is made, to ensure they weren't used in any in-place operation that modified their content. Arguments can also be ``None``. """ self.to_save = tensors def mark_dirty(self, *args): """Marks given tensors as modified in an in-place operation. **This should be called at most once, only from inside the** :func:`forward` **method, and all arguments should be inputs.** Every tensor that's been modified in-place in a call to :func:`forward` should be given to this function, to ensure correctness of our checks. It doesn't matter whether the function is called before or after modification. """ self.dirty_tensors = args def mark_shared_storage(self, *pairs): """Marks that given pairs of distinct tensors are sharing storage. **This should be called at most once, only from inside the** :func:`forward` **method, and all arguments should be pairs of (input, output).** If some of the outputs are going to be tensors sharing storage with some of the inputs, all pairs of (input_arg, output_arg) should be given to this function, to ensure correctness checking of in-place modification. The only exception is when an output is exactly the same tensor as input (e.g. in-place ops). In such case it's easy to conclude that they're sharing data, so we don't require specifying such dependencies. This function is not needed in most functions. It's primarily used in indexing and transpose ops. """ self.shared_pairs = pairs def mark_non_differentiable(self, *args): """Marks outputs as non-differentiable. **This should be called at most once, only from inside the** :func:`forward` **method, and all arguments should be outputs.** This will mark outputs as not requiring gradients, increasing the efficiency of backward computation. You still need to accept a gradient for each output in :meth:`~Function.backward`, but it's always going to be ``None``. This is used e.g. for indices returned from a max :class:`Function`. """ self.non_differentiable = args
因为C++里对应的处理函数相对比较繁琐复杂,这里就不一一详解了。仅以当前举例的Prod类为主,它的forward函数只调用了save_for_backward函数,则在几个功能函数中,只看对应的其中一个函数。
了解这些功能函数,可以注意看_wrap_outputs这个从Tensor向Variable包装的核心函数。
static void _wrap_outputs(THPFunction *self, t2var_type &t2var,
std::unordered_set<PyObject *> &dirty_inputs, PyObject *raw_output,
PyObject *outputs, bool is_volatile){
// Wrap outputs in Variables auto cdata = is_volatile ? nullptr : THPFunction_asFunction(self); Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output); if (self->cdata.is_executable) { self->output_info = new std::vector<output_info_type>(); self->output_info->reserve(num_outputs); } for (int i = 0; i < num_outputs; i++) { PyObject *output = PyTuple_GET_ITEM(raw_output, i); THPVariable *output_var; auto it = t2var.find(output); if (it == t2var.end()) { // A completely new tensor - just wrap it and continue if (is_volatile) { output_var = (THPVariable*)THPVariable_NewVolatile(output); } else { output_var = (THPVariable*)THPVariable_NewWithFunction(output, cdata); } } else {
.....................................
.....................................
.....................................
//long long long code
} if (!output_var) throw python_error(); if (self->output_info) { auto& output_tensor = *output_var->cdata->data; self->output_info->emplace_back( (PyObject *)getPyTypeObject(output_tensor), output_tensor.getDevice(), output_tensor.sizes() ); } t2var[output] = output_var; output_var->cdata->output_nr = i; PyTuple_SET_ITEM(outputs, i, (PyObject*)output_var); } }
第一个参数self,在当前情况下,实参为一路传进来的THPFunction* ctx;第二个参数t2var,是一个由输入Tensor到对应输入Variable的无序映射,在process_outputs作用域中生成,在当前情况下,t2var的值类似于python中的字典{x.data: x};第三个参数dirty_inputs,因为在_mark_dirty中直接return了,所以是一个空的集合,在本次调用中也不起作用;第四个参数raw_output,就是由python层forward方法计算得到的输出Tensor,需要进行包装的数据;第五个参数outputs用于返回包装好后的值;第六个参数is_volatile作为求导选项标识传入。
函数开始先做一些is_volatile、is_executable的检查,如果确实需要求导,则开一个输出变量大小的空间,然后进入循环。