zoukankan      html  css  js  c++  java
  • [源码解析] PyTorch 如何实现后向传播 (4)---- 具体算法

    [源码解析] PyTorch 如何实现后向传播 (4)---- 具体算法

    0x00 摘要

    前文中我们介绍了反向传播引擎的动态逻辑,因为具体反向传播算法是在设备线程中完成的,所以我们单独用一章来讲解。

    img

    本系列其他文章如下:

    深度学习利器之自动微分(1)

    深度学习利器之自动微分(2)

    [源码解析]深度学习利器之自动微分(3) --- 示例解读

    [源码解析]PyTorch如何实现前向传播(1) --- 基础类(上)

    [源码解析]PyTorch如何实现前向传播(2) --- 基础类(下)

    [源码解析] PyTorch如何实现前向传播(3) --- 具体实现

    [源码解析] Pytorch 如何实现后向传播 (1)---- 调用引擎

    [源码解析] Pytorch 如何实现后向传播 (2)---- 引擎静态结构

    [源码解析] Pytorch 如何实现后向传播 (3)---- 引擎动态逻辑

    0x01 工作线程主体

    thread_main是工作线程的主体函数,主要逻辑就是围绕着 ReadyQueue 执行一个 while 循环,工作线程阻塞在 ReadyQueue -> pop 这里,如果主线程或者其他线程插入了一个 NodeTask,则 pop 会返回取出一个 NodeTask,工作线程处理这个 NodeTask,完成后向计算的一个环节,如果有需要就继续往某一ReadyQueue插入新的 NodeTask,驱动引擎继续执行后向计算其他环节。

    thread_main 从如下途径被调用:

    1. CUDA, XLA 设备的 autograd threads 会调用。
    2. CPU 之上的反向传播主线程会调用。
    3. 前两个case 进行可重入反向传播,也会调用。

    1.1 线程主体代码

    工作线程的计算始于动态图的GraphRoot函数,反向传播就以 Node 的edge为纽带,层层从前向后计算,直到来到了leaf节点,最终完成了反向计算,具体如下:

    • local_graph_task表示我们从队列中检索的graph_task。外部graph_ 任务表示我们需要执行的可重入执行的总体 graph_任务。
    • 从自己的ReadyQueue之中取出NodeTask实例,使用 local_graph_task 为参数来执行evaluate_function(反向传播函数)。
    • outstanding_tasks 自减 1。
    • 如果本 local_graph_task 已经结束(可重入反向传播会运行多个 GraphTask),即:
      • 执行后续操作 exec_post_processing,然后使用 future_result_->markCompleted。
      • 如果这个task是来自其它worker thread,即 worker_device != base_owner,则向那个worker thread的queue发送一个dummy function task,让那个工作线程也执行起来。

    具体代码如下:

    // thread_main is used by:
    // 1). autograd threads for devices (i.e. CUDA, XLA)
    // 2). the caller/owning thread of the backward call on CPU (sync mode)
    // 3). Renetrant backward that invoked by either 1) or 2)
    // The exit conditions are different for the above three cases.
    // For 1), we are spinning on running the thread_main on device autograd
    //         threads throughout the Engine lifetime, thread_main will get
    //         terminated during Engine destruction by pushing shutdown tasks
    // For 2), the owning thread of the backward call drives the thread_main
    //         synchronously until the graph_task of that owning thread is
    //         completed and exit the thread_main to continue executing the
    //         result of caller's code.
    // For 3), the reentrant backward that invokes
    //         thread_main, either from 1) or 2), will not spin and will exit as
    //         long as graph_task is completed and notify the owning thread as
    //         needed.
    auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
      // When graph_task is nullptr, this is a long running thread that processes
      // tasks (ex: device threads). When graph_task is non-null (ex: reentrant
      // backwards, user thread), this function is expected to exit once that
      // graph_task complete.
    
      // local_ready_queue should already been initialized when we get into thread_main
      while (graph_task == nullptr || !graph_task->future_result_->completed()) {
        // local_graph_task represents the graph_task we retrieve from the queue.
        // The outer graph_task represents the overall graph_task we need to execute
        // for reentrant execution.
        std::shared_ptr<GraphTask> local_graph_task;
        {
          // Scope this block of execution since NodeTask is not needed after this
          // block and can be deallocated (release any references to grad tensors
          // as part of inputs_).
          NodeTask task = local_ready_queue->pop(); // 阻塞等待
          // This will only work if the worker is running a non backward task
          // TODO Needs to be fixed this to work in all cases
          if (task.isShutdownTask_) {
            break;
          }
    
          if (!(local_graph_task = task.base_.lock())) {
            // GraphTask for function is no longer valid, skipping further
            // execution.
            continue;
          }
    
          if (task.fn_ && !local_graph_task->has_error_.load()) {
           // 利用grad_mode_来配置AutoGradMode,整个反向计算期间的代码都靠GradMode::is_enabled()来判断当前是否是要计算grad  
            AutoGradMode grad_mode(local_graph_task->grad_mode_);
            try {
              // The guard sets the thread_local current_graph_task on construction
              // and restores it on exit. The current_graph_task variable helps
              // queue_callback() to find the target GraphTask to append final
              // callbacks.
              GraphTaskGuard guard(local_graph_task);
              NodeGuard ndguard(task.fn_);
              // 执行后向计算
              evaluate_function(local_graph_task, task.fn_.get(), task.inputs_, local_graph_task->cpu_ready_queue_);
            } catch (std::exception& e) {
              thread_on_exception(local_graph_task, task.fn_, e);
            }
          }
        }
    
        // Decrement the outstanding tasks.
        --local_graph_task->outstanding_tasks_;
    
        // Check if we've completed execution.
        if (local_graph_task->completed()) { // 已经结束了,进行后续处理
          local_graph_task->mark_as_completed_and_run_post_processing();
    
          auto base_owner = local_graph_task->owner_; // 后续是需要在 GraphTask 的 owner_ 处理
          // The current worker thread finish the graph_task, but the owning thread
          // of the graph_task might be sleeping on pop() if it does not have work.
          // So we need to send a dummy function task to the owning thread just to
          // ensure that it's not sleeping, so that we can exit the thread_main.
          // If it has work, it might see that graph_task->outstanding_tasks_ == 0
          // before it gets to the task, but it's a no-op anyway.
          //
          // NB: This is not necessary if the current thread is the owning thread.
          if (worker_device != base_owner) {
            // Synchronize outstanding_tasks_ with queue mutex
            std::atomic_thread_fence(std::memory_order_release);
            // 获取后续工作的queue
            ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
                ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
          }
        }
      }
    }
    

    1.2 使用 Ready Queue

    上述代码之中,最后使用 ready_queue_by_index 获取到后续工作对应的queue。

    ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
        ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
    

    如何获取Ready Queue?具体策略是:

    • 如果下一个 需要执行的设备是 CPU,则选用cpu_ready_queue。
    • 否则从device_ready_queues_选取一个GPU对应的 ReadyQueue。

    代码如下:

    auto Engine::ready_queue_by_index(std::shared_ptr<ReadyQueue> cpu_ready_queue, int device_index) -> std::shared_ptr<ReadyQueue> {
      if (device_index == CPU_DEVICE) {
        // return the cpu ready queue passed in
        TORCH_INTERNAL_ASSERT(cpu_ready_queue);
        return cpu_ready_queue;
      } else {
        // Static cast is ok here as the number of device should never overflow an int.
        TORCH_INTERNAL_ASSERT(0 <= device_index && device_index < static_cast<int>(device_ready_queues_.size()));
        // See Note [Allocating GPUs to autograd threads]
        // NB: This function would become obsolete if we truly allocated a CPU thread
        // per device, rather than colocate.
        return device_ready_queues_.at(device_index);
      }
    }
    

    逻辑如下:

    +---------------------------------------------------------------------+
    |  Main Thread                                                        |
    |                                                                     |
    |            push(NodeTask)+--------------+                           |
    |                                         |                           |
    +---------------------------------------------------------------------+
                                              |
                                              |
                                              v
                                       +------+-----+
                                       |            |
                                       | ReadyQueue |
                                       |            |
                                       +------+-----+
                                              |
                                              |
                                              |
    +---------------------------------------------------------------------+
    | Worker Thread 1                         |                           |
    |                                         |                           |
    |  thread_main{                           |                           |
    |                                         v                           |
    |     NodeTask task = local_ready_queue->pop()                        |
    |                                                                     |
    |     evaluate_function(task.fn_.get(),task.inputs_)                  |
    |  }                                                                  |
    +---------------------------------------------------------------------+
    
    

    0x02 反向计算总体逻辑

    evaluate_function 方法完成了反向计算的逻辑,总体逻辑如下:

    • 准备工作:如果exec_info需要处理,则处理 captured_vars_。
    • 反向计算:调用 call_function(graph_task, func, inputs),这是反向传播中计算相关的核心逻辑:
      • 调用pre hooks。
      • 调用fn进行计算。
      • 调用post hooks。
    • 扫尾工作:
      • 如果不需要keep graph,则fn.release_variables();
      • 依据 call_function的输出 outputs,进行计算 num_outputs = outputs.size(),得到 num_outputs的元素数量(该数量等同于当前fn的next_edge()返回的list中的元素数量)。
    • 准备下一步工作,具体就是查找后续需要计算的NodeTask,num_outputs 就是在这里被用到。这部分比较复杂。

    总体代码如下:

    void Engine::evaluate_function(
        std::shared_ptr<GraphTask>& graph_task,
        Node* func, // 导数计算方法
        InputBuffer& inputs, // 当前Node的输入梯度
        const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
        
      // 进行准备工作  
      // If exec_info_ is not empty, we have to instrument the execution
      auto& exec_info_ = graph_task->exec_info_;
      if (!exec_info_.empty()) {
        auto& fn_info = exec_info_.at(func); // 取出当前的进行处理
        if (auto* capture_vec = fn_info.captures_.get()) {
          // Lock mutex for writing to graph_task->captured_vars_.
          std::lock_guard<std::mutex> lock(graph_task->mutex_);
          for (const auto& capture : *capture_vec) {
            // captured_grad 就是临时存储下,每次node计算都会更新,最终输出给调用者,相当于引用
            // 1. captured_grad 引用了captured_vars_[capture.output_idx_],
            auto& captured_grad = graph_task->captured_vars_[capture.output_idx_];
            // 2. 给 captured_vars_[capture.output_idx_] 赋值 inputs[capture.input_idx_]
            captured_grad = inputs[capture.input_idx_];
            // 遍历hooks,链式调用hook进行计算,captured_grad 不停的作为输入和输出在流水线中流淌
            // 就是针对 captured_vars_[capture.output_idx_]不停的计算,最终结果还是在 captured_vars_[capture.output_idx_] 之中。
            for (auto& hook : capture.hooks_) {
              captured_grad = (*hook)(captured_grad);
            }
          }
        }
        if (!fn_info.needed_) {
          // Skip execution if we don't need to execute the function.
          return;
        }
      }
    
      // Set the ThreadLocalState before calling the function.
      // NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask
      // always saves ThreadLocalState without grad_mode.
      at::ThreadLocalStateGuard tls_guard(graph_task->thread_locals_);
    
      // Switches to a function's CUDA stream (if applicable) before calling it
      const auto opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
      c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};
    
      // 进行反向计算
      auto outputs = call_function(graph_task, func, inputs);
    
      // 如果不需要保持计算图,则本节点释放变量
      auto& fn = *func;
      if (!graph_task->keep_graph_) {
        fn.release_variables();
      }
    
      // 得到 num_outputs的元素数量(该数量等同于当前fn的next_edge()返回的list中的元素数量),后续遍历本节点输出时候会用到
      int num_outputs = outputs.size();
      if (num_outputs == 0) { // Note: doesn't acquire the mutex
        // Records leaf stream (if applicable)
        // See note "Streaming backwards"
        if (opt_parent_stream) {
          std::lock_guard<std::mutex> lock(graph_task->mutex_);
          graph_task->leaf_streams.emplace(*opt_parent_stream);
        }
        return;
      }
    
      if (AnomalyMode::is_enabled()) {
        AutoGradMode grad_mode(false);
        for (int i = 0; i < num_outputs; ++i) {
          auto& output = outputs[i];
          at::OptionalDeviceGuard guard(device_of(output));
          if (output.defined() && isnan(output).any().item<uint8_t>()) {
            std::stringstream ss;
          }
        }
      }
    
      // 准备下一步工作
      // Lock mutex for the accesses to GraphTask dependencies_, not_ready_ and cpu_ready_queue_ below
      std::lock_guard<std::mutex> lock(graph_task->mutex_);
      for (int i = 0; i < num_outputs; ++i) {
        auto& output = outputs[i];
        const auto& next = fn.next_edge(i); // next_edge是该node在前向传播图中的输入,在反向传播时候就是本节点的输出,所以next就是下一个可能运算的节点
    
        if (!next.is_valid()) continue;
    
        // Check if the next function is ready to be computed
        bool is_ready = false;
        auto& dependencies = graph_task->dependencies_;
        auto it = dependencies.find(next.function.get()); // 找到下一个节点的依赖
    
        if (it == dependencies.end()) {
          auto name = next.function->name();
          throw std::runtime_error(std::string("dependency not found for ") + name);
        } else if (--it->second == 0) {
          dependencies.erase(it);
          is_ready = true; // 下一个节点没有入度了,那么说明计算该节点梯度依赖的其他节点梯度都已经计算完成
        }
    
        // 要去 not_ready里面看看,是否已经存储了
        auto& not_ready = graph_task->not_ready_;
        auto not_ready_it = not_ready.find(next.function.get());
        if (not_ready_it == not_ready.end()) {
          // 下一个节点的梯度还没有进行计算
          // Skip functions that aren't supposed to be executed
          // 跳过不需要计算的节点
          if (!exec_info_.empty()) {
            auto it = exec_info_.find(next.function.get());
            if (it == exec_info_.end() || !it->second.should_execute()) {
              continue;
            }
          }
          // No buffers have been allocated for the function
          InputBuffer input_buffer(next.function->num_inputs()); // 下一个节点前置梯度的buffer,就是下一个节点的输入梯度
    
          // Accumulates into buffer
          // 下一个节点的输入梯度就是当前节点的输出,所以要拷贝过去
          const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
          input_buffer.add(next.input_nr,
                           std::move(output),
                           opt_parent_stream,
                           opt_next_stream);
    
          if (is_ready) {
            auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
            // 既然依赖全部完成,就插入到ReadyQueue 之中
            queue->push(
                NodeTask(graph_task, next.function, std::move(input_buffer)));
          } else {
            // 下一个节点的输入依赖还没有完成,就放到not_ready之中。
            not_ready.emplace(next.function.get(), std::move(input_buffer));
          }
        } else {
          // 如果下一个节点已经开始计算,但是没有完成(就是依赖梯度还有),此时应该在not_ready之中
          // The function already has a buffer
          auto &input_buffer = not_ready_it->second;
    
          // Accumulates into buffer
          const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
          input_buffer.add(next.input_nr,
                           std::move(output),
                           opt_parent_stream,
                           opt_next_stream);
            
          // Graph中每一个node(fn)的输出是下一个node(fn)的输入,下面4句代码来将前一个fn的输出转化为下一个fn的输入  
          if (is_ready) {
            // 如果此时已经没有输入依赖,就放入新的NodeTask,就是下一个需要计算梯度的NodeTask
            auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
            queue->push(
                NodeTask(graph_task, next.function, std::move(input_buffer)));
            //已经完成下一个节点前置梯度计算,从not_ready中移除相应的buffer
            not_ready.erase(not_ready_it);
          }
        }
      }
    }
    

    因为这部分代码十分复杂,我们逐一进行分析。

    0x03 准备工作

    首先我们看看准备工作,具体如下:

    • 取出当前 Node 的 ExecInfo。
    • 取出其 captures_,遍历其中每一个 Capture。
    • 遍历Capture 的 hooks,链式调用hook进行计算。
      • captured_grad 不停的作为输入和输出在流水线中流淌,针对 captured_vars_[capture.output_idx_]陆续计算。
      • 最终结果保存在 captured_vars_[capture.output_idx_] 之中。

    代码中有一个细节,就是captured_grad 只是临时存储,每次node计算都会更新,最终输出给调用者,相当于引用

    void Engine::evaluate_function(
        std::shared_ptr<GraphTask>& graph_task,
        Node* func, // 导数计算方法
        InputBuffer& inputs, // 当前Node的输入梯度
        const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
        
      // 进行准备工作  
      // If exec_info_ is not empty, we have to instrument the execution
      auto& exec_info_ = graph_task->exec_info_;
      if (!exec_info_.empty()) {
        auto& fn_info = exec_info_.at(func); // 取出当前的进行处理
        if (auto* capture_vec = fn_info.captures_.get()) {
          // Lock mutex for writing to graph_task->captured_vars_.
          std::lock_guard<std::mutex> lock(graph_task->mutex_);
          for (const auto& capture : *capture_vec) {
            // captured_grad 就是临时存储下,每次node计算都会更新,最终输出给调用者,相当于引用
            // 1. captured_grad 引用了captured_vars_[capture.output_idx_],
            auto& captured_grad = graph_task->captured_vars_[capture.output_idx_];
            // 2. 给 captured_vars_[capture.output_idx_] 赋值 inputs[capture.input_idx_]
            captured_grad = inputs[capture.input_idx_];
            // 遍历hooks,链式调用hook进行计算,captured_grad 不停的作为输入和输出在流水线中流淌
            // 就是针对 captured_vars_[capture.output_idx_]不停的计算,最终结果还是在 captured_vars_[capture.output_idx_] 之中。
            for (auto& hook : capture.hooks_) {
              captured_grad = (*hook)(captured_grad);
            }
          }
        }
        if (!fn_info.needed_) {
          // Skip execution if we don't need to execute the function.
          return;
        }
      }
    

    0x04 核心逻辑

    call_function是反向传播中计算相关的核心逻辑。

    • 调用注册在本 node上的pre_hooks;
    • 调用node本身,比如MeanBackward0、MulBackward0等。
      • 输入是InputBuffer::variables(std::move(inputBuffer)),一组Variable的实例。当动态图刚开始进行反向计算时,引擎首先执行的是图的根节点——graph_root,它的输入是task.inputs——InputBuffer(0)。
      • 调用的是fn的apply(),apply是多态实现,针对不同的operation会dispatch到operation对应的apply实现上。
      • 输出也是一组Variable的实例 outputs = fn(std::move(inputs_copy)),outputs 要作为下一个fn的输入。
    • 调用注册在node上的post hooks。
    • 返回当前节点对应的导数,这是一个variable_list。

    具体代码如下:

    static variable_list call_function(
        std::shared_ptr<GraphTask>& graph_task,
        Node* func,
        InputBuffer& inputBuffer) {
      CheckpointValidGuard cpvguard(graph_task);
      auto& fn = *func;
      auto inputs =
          call_pre_hooks(fn, InputBuffer::variables(std::move(inputBuffer)));
    
      if (!graph_task->keep_graph_) {
        fn.will_release_variables();
      }
    
      const auto has_post_hooks = !fn.post_hooks().empty();
      variable_list outputs;
    
      if (has_post_hooks) {
        // In functions/accumulate_grad.cpp, there is some logic to check the
        // conditions under which the incoming gradient can be stolen directly
        // (which elides a deep copy) instead of cloned. One of these conditions
        // is that the incoming gradient's refcount must be 1 (nothing else is
        // referencing the same data).  Stashing inputs_copy here bumps the
        // refcount, so if post hooks are employed, it's actually still ok for
        // accumulate_grad.cpp to steal the gradient if the refcount is 2.
        //
        // "new_grad.use_count() <= 1 + !post_hooks().empty()" in
        // accumulate_grad.cpp accounts for this, but also creates a silent
        // dependency between engine.cpp (ie, this particular engine
        // implementation) and accumulate_grad.cpp.
        //
        // If you change the logic here, make sure it's compatible with
        // accumulate_grad.cpp.
        auto inputs_copy = inputs;
        outputs = fn(std::move(inputs_copy));
      } else {
        outputs = fn(std::move(inputs));
      }
    
      validate_outputs(fn.next_edges(), outputs, [&](const std::string& msg) {
        std::ostringstream ss;
        return ss.str();
      });
    
      if(has_post_hooks){
        return call_post_hooks(fn, std::move(outputs), inputs);
      }
      return outputs;
    }
    

    0x05 准备下一步工作

    这部分是反向传播的复杂之处。

    现在调用 call_function,得到了后向传播的输出,记录到了 outputs 之中。

    auto outputs = call_function(graph_task, func, inputs);
    

    所以,后半部分就是从 outputs 之中寻找后续可以计算的Node

    总体思路就是:遍历后向传播的输出节点(就是该节点在前向计算图中的入边连接的节点),逐一衡量输出节点。遍历循环中分为两段代码,对于每一个输出节点做如下操作:

    • 第一段是依据依赖排查这个节点,得到这个节点是否就绪。核心就是看看这个输出节点在GraphTask的dependencies的计数是否降为0
      • 如果是0,就说明这个节点就绪了,说明这个node不会被未来的计算所依赖了。
      • 如果非0,就说明这个节点有多个输入,即,被多个node连接,而且有的输入还没有计算完成梯度。
    • 第二段是依据是否就绪来处理这个节点,比如放入哪一个queue

    5.1 依据依赖排查节点

    第一段代码功能是依据依赖关系来 排查节点,得到这个节点是否就绪,具体如下:

    • 假定某一个节点是 output,我们得到对应的边,遍历输出边。

      • 每次把一个输出边记录为 next,func 是 NodeTask 之中的函数。

      • 利用 dependencies_ 的信息,next 是否可以计算。dependencies_ 里面记录的是图中所有节点的依赖。

      • 从 dependencies_ 之中找到 next 对应的依赖数目,把依赖数目减一(通常因为有多个 input)。

        • 如果--it->second == 0,说明该前置节点计算梯度所依赖的其他节点梯度都已经完成计算。则
          • 把该前置节点对应的信息GraphTask中移除,即从GraphTask的dependencies中移除(后续也会从GraphTask的 not_ready 成员变量之中移除)。
          • 将is_ready 置为true,后续会依据这个 is_ready 的数值进行操作。
      • 从 not_ready_ 之中得到 next 对应的输入buffer(后续代码就是对此进行操作);

        • std::unordered_map<Node*, InputBuffer> not_ready_;
          

      代码如下:

      for (int i = 0; i < num_outputs; ++i) { // 遍历输出节点,逐一衡量
        auto& output = outputs[i];
        const auto& next = fn.next_edge(i); // 获得一个输出节点
          
        if (!next.is_valid()) continue;
    
        // Check if the next function is ready to be computed
        bool is_ready = false;
        auto& dependencies = graph_task->dependencies_; // 拿到GraphTask的依赖关系
        auto it = dependencies.find(next.function.get()); // 找到输出节点的依赖项
    
        if (it == dependencies.end()) {
          auto name = next.function->name(); // 没找到
          throw std::runtime_error(std::string("dependency not found for ") + name);
        } else if (--it->second == 0) {
          dependencies.erase(it);  // 找到了,并且已经计算完毕
          is_ready = true;
        }
    
        auto& not_ready = graph_task->not_ready_; 
        auto not_ready_it = not_ready.find(next.function.get()); // 找到输入buffer     
    

    现在已经找到了某一个输出节点,也知道其是否计算完毕(依据有没有依赖项),也拿到了其存在"未就绪队列"的输入buffer(如果存在的话)。

    5.2 处理这个节点

    第二段是依据是否就绪来处理这个节点,比如放入哪一个queue,是就绪队列?还是未就绪队列?核心是:

    • 如果就绪,就放到该节点对应的 ReadyQueue 去处理。
    • 如果没有就绪,就新建立一个NodeTask放到 GraphTask的 not_ready 等待后续处理。需要注意的是,这个新的NodeTask 是在 worker thread 之中创建的。
    • 如何找到 ReadyQueue?需要看这个 Node 节点的 input_buffer.device() ,即,这个新 NodeTask 应该发送到 input_buffer.device() 那个 device 对应的 ReadyQueue。

    我们具体看看如何依据 is_ready 的数值来对 not_ready 进行操作。

    • 如果在 未就绪队列 not_ready 之中 没有找到 next_edge 对应的元素,则:
      • 如果 exec_info_ 不为空,则在 exec_info_ 之中查找 next_edge 对应的元素,如果有元素且注明了不需要执行,就跳到for循环的下一个。
      • 用 next_edge 的流,inut_nr 等信息构建一个 input_buffer。
      • 如果 is_ready 是 True,就用 本 GraphTask,next.function,input_buffer构建一个NodeTask,放入 ReadyQueue(利用 input_buffer.device() 来得到对应的 queue)。这就要唤醒下一个 worker 线程
      • 如果 is_ready 是 False,这通常表明这个node有多个输入(被更多的node连接,使用num_inputs()可以获得数量),也说明此次处理的是这个node的第一个输入,后续还需要使用这个 next_edge,所以这个 next_edge 需要被放到 not_ready 之中。则把 next.function,input_buffer 放入到 not_ready 之中,这个input_buffer 就是 next_edge 后续执行时候需要的各种输入。
    • 如果在 未就绪队列 not_ready 之中找到了 next_edge 对应的元素,则:
      • 拿出来该元素对应的 input_buffer,把信息累积到 input_buffer 之中。此次累积的是该节点的其他输入。 input_buffer.add(next.input_nr, std::move(output), opt_parent_stream, opt_next_stream) 完成了累积操作,next.input_nr 就表明当前的node是反向传播中要流向的node(next)的第几个输入。
      • 如果is_ready 是 True,就用 本 GraphTask,next.function,input_buffer构建一个NodeTask,放入 ReadyQueue。这就要唤醒下一个 worker 线程
      • 从 not_ready 之中移除此元素,就是从 GraphTask 的依赖关系之中去除。

    代码如下:

        if (not_ready_it == not_ready.end()) {
          // Skip functions that aren't supposed to be executed
          if (!exec_info_.empty()) {
            auto it = exec_info_.find(next.function.get());
            if (it == exec_info_.end() || !it->second.should_execute()) {
              continue;
            }
          }
          // No buffers have been allocated for the function
          InputBuffer input_buffer(next.function->num_inputs());
    
          // Accumulates into buffer
          const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
          input_buffer.add(next.input_nr,
                           std::move(output),
                           opt_parent_stream,
                           opt_next_stream);
    
          if (is_ready) {
            // 找出了下一个Node的queue
            auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
            queue->push( //
                NodeTask(graph_task, next.function, std::move(input_buffer)));
          } else {
            not_ready.emplace(next.function.get(), std::move(input_buffer));
          }
        } else {
          // The function already has a buffer
          auto &input_buffer = not_ready_it->second;
    
          // Accumulates into buffer
          const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
          input_buffer.add(next.input_nr,
                           std::move(output),
                           opt_parent_stream,
                           opt_next_stream);
          if (is_ready) {
            // 找出了下一个Node的queue
            auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
            queue->push(
                NodeTask(graph_task, next.function, std::move(input_buffer)));
            not_ready.erase(not_ready_it);
          }
        }
    

    具体逻辑图如下:

    1. func 指向了目前正在进行反向计算的 Node。
    2. func 调用自己的 apply 方法进行计算,得出了 outputs,假设有3个输出,遍历,我们选择第三个为 output。
    3. func 的边是 next_edges_ 成员变量,遍历,我们选择第三个边为next。
    4. 用 next 和 GraphTask 的 dependencies_ 来判断 next 是不是就绪。
    5. 如果就绪,把 output 构建一个 input_buffer,然后生成一个 NodeTask,插入到对应的 ReadyQuieue。
    6. 如果没就绪,把 output 构建一个 input_buffer,和 next 一起放入 GraphTask 的 not_ready_,后续会使用。
           1  +---------------+
    func +--> | Node          |              +---> ...
              |               |              |
              |               |              |
              |  apply() +------> outputs +------> ...  2
              |               |              |
              |               |              |
              |               |              |                 +--------------+
              |               |              +---> output +--> | input_buffer +--+
              |               |                                +--------------+  |
              |               |                                                  |
              |               |                                                  |
              |               |                                                  | 5
              |               |                                                  |
              |               |                                                  |
              |               |   +----> ...                                     |
              |               |   |                                              +---------+
              |               |   |                                              |         |
              |  next_edges_+---> +----> ...  3                                  |         |
              |               |   |                                              |         |
              |               |   |                                              |         |
              |               |   |                                         5    v         |
              |               |   +----> next +------>+              YES                   |     +------------+
              +---------------+                       |             +---> push(NodeTask) +-----> | ReadyQueue |
                                                      |      4      |                      |     +------------+
                                                      |             |                      |
              +---------------+                       +--> Ready? +-+                      |
              | GraphTask     |                       |             |       6              |
              |               |                       |             | NO                   | 6
              |               |                       |             +----> next.function   |
              | dependencies_+--> map<Node*, int> +-->+                          +         |
              |               |                                                  |         |
              |               |                                                  |         |
              |               |                              6                   v         v
              | not_ready_ +--------------------------------------------->  map<Node*, InputBuffer>
              |               |
              +---------------+
    
    

    手机如下:

    0x06 扫尾操作

    在 thread_main 之中,如果本task已经结束,即做后续操作,具体代码如下。

    auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
      
        // 忽略前面代码
      
        // Check if we've completed execution.
    	  if (local_graph_task->completed()) { // 判断是否结束
          // 如果结束了,就进行后续操作
          local_graph_task->mark_as_completed_and_run_post_processing();
    
          auto base_owner = local_graph_task->owner_;
          // The current worker thread finish the graph_task, but the owning thread
          // of the graph_task might be sleeping on pop() if it does not have work.
          // So we need to send a dummy function task to the owning thread just to
          // ensure that it's not sleeping, so that we can exit the thread_main.
          // If it has work, it might see that graph_task->outstanding_tasks_ == 0
          // before it gets to the task, but it's a no-op anyway.
          //
          // NB: This is not necessary if the current thread is the owning thread.
          if (worker_device != base_owner) {
            // Synchronize outstanding_tasks_ with queue mutex
            std::atomic_thread_fence(std::memory_order_release);
            ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
                ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
          }
        }
    

    我们接下来分析这些扫尾工作。注意,这里是 thread_main 之中的扫尾工作

    6.1 判断结束

    以下代码用来判断本 GraphTask是否结束,其实就是 ReadyQueue 之中是否还有待运行的 NodeTask。

    outstanding_tasks_ 是待处理 NodeTask的数量,用来判断该GrapTask是否还需要执行,其数值总是先加再减,如果数目为0,则说明任务结束了。

    • 当 GraphTask 被创建出来时候,此数值为0。
    • 如果有一个NodeTask被送入到 ReadyQueue,则outstanding_tasks_ 增加 1。
    • 如果在工作线程作执行一次 evaluate_function(task)后,outstanding_tasks的值减 1。
    • 如果这个数量不为0,则此GraphTask依然需要运行。
    bool GraphTask::completed() {
      // outstanding_tasks在evaluate_function中可能会被改变
      return outstanding_tasks_.load() == 0 ||
          (exit_on_error_ && has_error_.load());
    }
    

    6.2 后续&通知

    mark_as_completed_and_run_post_processing 就是进行后续处理。

    执行后续操作 exec_post_processing,然后使用 future_result_->markCompleted 通知主线程。

    void GraphTask::mark_as_completed_and_run_post_processing() {
      // Allow only one thread one attempt to process this logic.
      if (future_completed_.exchange(true)) {
        // Future is already marked complete, or being marked as such.
        // In case the marking complete is only in progress, we add a
        // wait() to guarantee the future is marked complete on exit.
        future_result_->wait();
        return;
      }
    
      try {
        // Run post processing, before marking the future as complete.
        // Drop lock prior to completing, to avoid holding across callbacks.
        std::unique_lock<std::mutex> lock(mutex_);
    
        exec_post_processing(); // 进行后续操作
        std::vector<Variable> vars = std::move(captured_vars_);
    
        // Need to unlock before we call markCompleted to avoid holding locks
        // when the callbacks are called.
        lock.unlock();
        future_result_->markCompleted(std::move(vars));  // 通知主线程
      } catch (std::exception& e) {
        future_result_->setErrorIfNeeded(std::current_exception());
      }
    }
    

    6.2.1 后续操作

    后续操作,如果之前有注册了 callback,则进行调用。也会进行流同步。

    void GraphTask::exec_post_processing() {
      if (!not_ready_.empty()) {
        throw std::runtime_error("could not compute gradients for some functions");
      }
    
      // set the thread_local current_graph_task_ as more callbacks can be installed
      // by existing final callbacks.
      GraphTaskGuard guard(shared_from_this());
      // Lock mutex during each iteration for accessing final_callbacks.size()
      // Unlocking is necessary, because the callback can register
      // more callbacks (or they can be registered from other threads
      // while it's waiting.
      std::unique_lock<std::mutex> cb_lock(final_callbacks_lock_);
      // WARNING: Don't use a range-for loop here because more callbacks may be
      // added in between callback calls, so iterators may become invalidated.
      for (size_t i = 0; i < final_callbacks_.size(); ++i) {
        cb_lock.unlock();
        final_callbacks_[i]();
        cb_lock.lock();
      }
    
      // Syncs leaf streams with default streams (if necessary)
      // See note "Streaming backwards"
      for (const auto& leaf_stream : leaf_streams) {
        const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
        const auto default_stream = guard.getDefaultStream(leaf_stream.device());
        if (leaf_stream != default_stream) {
          auto event = c10::Event{c10::DeviceType::CUDA};
          event.record(leaf_stream);
          default_stream.wait(event);
        }
      }
    }
    

    6.2.2 通知主线程

    之前在 execute 之中会用 fut->wait() 来等待任务完成。下面我们省略了部分代码。

    auto Engine::execute(const edge_list& roots,
                         const variable_list& inputs,
                         bool keep_graph,
                         bool create_graph,
                         bool accumulate_grad,
                         const edge_list& outputs) -> variable_list {
    
      
      // Queue the root
      if (skip_dummy_node) {
        execute_with_graph_task(graph_task, graph_root, std::move(input_buffer));
      } else {
        execute_with_graph_task(graph_task, graph_root, InputBuffer(variable_list()));
      }
      auto& fut = graph_task->future_result_;
      fut->wait();
      return fut->value().toTensorVector();
    }
    

    在 mark_as_completed_and_run_post_processing 会用如下代码来通知主线程。

    future_result_->markCompleted(std::move(vars));  // 通知主线程
    

    6.3 通知其他线程

    如果这个task是来自其它work thread,即 worker_device != base_owner,则向那个worker thread的queue发送一个dummy function task,让那个工作线程也执行起来。

    local_graph_task 表示我们从队列中检索的 graph_task。外部graph_ 任务表示我们需要执行的可重入执行的总体graph_任务。

    在 thread_main 之中,有一个 work around。就是:当前工作线程完成 graph_task,但此时,拥有graph_task的线程可能正在pop()上等待休眠。因此,我们需要向所属线程发送一个仿造的函数任务,以唤醒它,这样我们可以退出thread_main。

    这种情况发生在可重入反向传播的情形。

    // If worker_device is any devices (i.e. CPU, CUDA): this is a re-entrant
    //    backward call from that device.
    graph_task->owner_ = worker_device;
    

    具体代码如下:

        // Check if we've completed execution.
        if (local_graph_task->completed()) {
          local_graph_task->mark_as_completed_and_run_post_processing();
          auto base_owner = local_graph_task->owner_; // 当前设备
            
          if (worker_device != base_owner) {
              
            // 不是同一个设备
              
            // Synchronize outstanding_tasks_ with queue mutex
            std::atomic_thread_fence(std::memory_order_release);
            ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
                ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0))); // dummy task
          }
        }
    

    其他线程当收到了 dummy task 之后,不会处理,因为 function 是 nullptr,然后就调用 local_ready_queue->pop() 继续从自己的queue 中读取下一个 task

    具体如下:

    1. 主线程等待。
    2. 如果工作线程发现GraphTask 已经结束,就通知主线程。
    3. 如果需要唤醒其他线程,就向该线程对应的 queue 插入 NodeTask。
    4. 对应线程取出 NodeTask 进行执行。
                                             +------------------------------------------------+
                                             | Worker Thread 1                                |
                                             |                                                |
                                             |  thread_main{                                  |
                                             |                                                |
                                             |     mark_as_completed_and_run_post_processing  |
                           2 markCompleted() |     {                                          |
                                     +-------------------+                                    |
                                     |       |     }                                          |
                                     |       |                                                |
    +---------------+                |       |     push(NodeTask) +-----+                     |
    | Main Thread   |                |       |                          |                     |
    |               |                |       |   }                      |                     |
    |               |                |       |                          |                     |
    |               |                |       +------------------------------------------------+
    |               |                |                                  |
    |               |                |                                3 |
    |               |                v                                  v
    |               |                                           +-------+-------+
    |               |   1      +----------------+               |               |
    |               | wait()   |                |               |  ReadyQueue   |
    |           +------------> | future_result_ |               |               |
    |               |          |                |               +-------+-------+
    |               |          +----------------+                       |
    |               |                                                   |
    |               |                                                 4 | pop(NodeTask)
    |               |                                                   |
    |               |                                                   v
    |               |                                          +--------+---------------------+
    |               |                                          | Worker Thread 2              |
    |               |                                          |                              |
    |               |                                          |                              |
    +---------------+                                          |                              |
                                                               |                              |
                                                               |                              |
                                                               +------------------------------+
    
    

    至此,后向传播已经分析完毕,从下一篇开始,我们正式进入 PyTorch 分布式训练。

    0xFF 参考

    https://www.zhihu.com/column/gemfield

    【PyTorch】聊聊 backward 背后的代码

    pytorch笔记(计算图+autograd)-Node(1)

    详解Pytorch中的网络构造

    PyTorch的优化器

    PyTorch的分布式

    PyTorch的Tensor(下)

    PyTorch的Tensor(中)

    PyTorch的Tensor(上)

    PyTorch的动态图(下)

    PyTorch的动态图(上)

    PyTorch Internals 5:Autograd的实现

    A GENTLE INTRODUCTION TO TORCH.AUTOGRAD

    PyTorch学习笔记(12)——PyTorch中的Autograd机制介绍

    PyTorch 的 Autograd

  • 相关阅读:
    读取美团购
    获取enum的Description
    获取手机号码所在地
    手动添加XA/XD的端口和磁盘映射
    无法使用SQL Server Management Studio的找到Network Server
    [XenDesktop5.5]+HyperV上的Win7+VDA无法启用Aero效果
    傻瓜式设置WANem配置 (点对点网络设置)
    [XD5.5]如何关闭XD的Audio UDP通道
    使用TCP方式登陆OCS
    在Linux上建立文件夹指向在Win共享的文件夹
  • 原文地址:https://www.cnblogs.com/rossiXYZ/p/15491941.html
Copyright © 2011-2022 走看看