zoukankan      html  css  js  c++  java
  • [源码解析] PyTorch 分布式 Autograd (3) 上下文相关

    [源码解析] PyTorch 分布式 Autograd (3) ---- 上下文相关

    0x00 摘要

    我们已经知道 dist.autograd 如何发送和接受消息,本文再来看看如何其他支撑部分,就是如何把发送接受两个动作协调起来,如何确定每个发送/接受节点,如何确定每一个消息交互Session。

    通过本文大家可以了解:AutogradMetadata 用来在不同节点间传递 autograd 元信息,DistAutogradContext 代表一个分布式autograd 相关信息,DistAutogradContainer 负责在一个worker之上存储 DistAutogradContext。

    PyTorch分布式其他文章如下:

    深度学习利器之自动微分(1)

    深度学习利器之自动微分(2)

    [源码解析]深度学习利器之自动微分(3) --- 示例解读

    [源码解析]PyTorch如何实现前向传播(1) --- 基础类(上)

    [源码解析]PyTorch如何实现前向传播(2) --- 基础类(下)

    [源码解析] PyTorch如何实现前向传播(3) --- 具体实现

    [源码解析] Pytorch 如何实现后向传播 (1)---- 调用引擎

    [源码解析] Pytorch 如何实现后向传播 (2)---- 引擎静态结构

    [源码解析] Pytorch 如何实现后向传播 (3)---- 引擎动态逻辑

    [源码解析] PyTorch 如何实现后向传播 (4)---- 具体算法

    [源码解析] PyTorch 分布式(1)------历史和概述

    [源码解析] PyTorch 分布式(2) ----- DataParallel(上)

    [源码解析] PyTorch 分布式(3) ----- DataParallel(下)

    [源码解析] PyTorch 分布式(4)------分布式应用基础概念

    [源码解析] PyTorch分布式(5) ------ DistributedDataParallel 总述&如何使用

    [源码解析] PyTorch分布式(6) ---DistributedDataParallel -- 初始化&store

    [源码解析] PyTorch 分布式(7) ----- DistributedDataParallel 之进程组

    [源码解析] PyTorch 分布式(8) -------- DistributedDataParallel之论文篇

    [源码解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化

    [源码解析] PyTorch 分布式(10)------DistributedDataParallel 之 Reducer静态架构

    [源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer和Join操作

    [源码解析] PyTorch 分布式(12) ----- DistributedDataParallel 之 前向传播

    [源码解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向传播

    [源码解析] PyTorch 分布式 Autograd (1) ---- 设计

    [源码解析] PyTorch 分布式 Autograd (2) ---- RPC基础

    为了更好的说明,本文代码会依据具体情况来进行相应精简。

    0x01 设计脉络

    1.1 前文回顾

    在前文之中当发送消息时候,我们在 sendMessageWithAutograd 通过 getMessageWithAutograd 来获得了 FORWARD_AUTOGRAD_REQ 类型的消息。

    c10::intrusive_ptr<JitFuture> sendMessageWithAutograd(
        RpcAgent& agent,
        const WorkerInfo& dst,
        torch::distributed::rpc::Message&& wrappedRpcMsg,
        bool forceGradRecording,
        const float rpcTimeoutSeconds,
        bool forceDisableProfiling) {
        
      auto msg = getMessageWithAutograd( // 这里会与上下文交互,构建了 FORWARD_AUTOGRAD_REQ
          dst.id_,
          std::move(wrappedRpcMsg),
          MessageType::FORWARD_AUTOGRAD_REQ,
          forceGradRecording,
          agent.getDeviceMap(dst));
    
      c10::intrusive_ptr<JitFuture> fut;
      if (!forceDisableProfiling && torch::autograd::profiler::profilerEnabled()) {
        auto profilerConfig = torch::autograd::profiler::getProfilerConfig();
        auto msgWithProfiling = getMessageWithProfiling(
            std::move(msg),
            rpc::MessageType::RUN_WITH_PROFILING_REQ, //构建消息
            std::move(profilerConfig));
        // 发送消息
        fut = agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds);
      } else {
        // 发送消息
        fut = agent.send(dst, std::move(msg), rpcTimeoutSeconds);
      }
    
      return fut;
    }
    

    而 getMessageWithAutograd 会与上下文交互,其代码位于 torch/csrc/distributed/autograd/utils.cpp。

    Message getMessageWithAutograd(
        const rpc::worker_id_t dstId,
        torch::distributed::rpc::Message&& wrappedRpcMsg,
        MessageType msgType,
        bool forceGradRecording,
        const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
      
      // 获取到 DistAutogradContainer
      auto& autogradContainer = DistAutogradContainer::getInstance();
    
      // If there is no valid context and no tensor requires grads, send original
      // rpc message. otherwise, attach grad info and grad functions and send
      // rpcWithAutograd message.
      auto tensorsRequireGrad =
          torch::autograd::compute_requires_grad(wrappedRpcMsg.tensors());
      if (!autogradContainer.hasValidContext() ||
          (!forceGradRecording && !tensorsRequireGrad)) {
        return std::move(wrappedRpcMsg);
      }
    
      // Retrieve the appropriate context to modify.
      auto autogradContext = autogradContainer.currentContext(); // 获取到上下文,每个worker都有自己的上下文
    
      // Wrap the original rpc with autograd information.
      // newAutogradMessageId 会生成一个messageID
      AutogradMetadata autogradMetadata( // 构建了 AutogradMetadata
          autogradContext->contextId(), autogradContainer.newAutogradMessageId());
      auto rpcWithAutograd = std::make_unique<RpcWithAutograd>(
          RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_,
          msgType,
          autogradMetadata,
          std::move(wrappedRpcMsg),
          deviceMap);
    
      if (tensorsRequireGrad) {
        // Record autograd information for 'send'.
        addSendRpcBackward( // 这里把本地上下文,autograd 的元信息等一起打包
            autogradContext, autogradMetadata, rpcWithAutograd->tensors());
      }
      // Record the workerID
      autogradContext->addKnownWorkerId(dstId);
    
      return std::move(*rpcWithAutograd).toMessage(); // 最终构建了一个message
    }
    

    因此,就引出了AutogradMetadata,DistAutogradContainer 和 DistAutogradContext 等一系列基础类,我们接下来就仔细分析一下。

    1.2 总体思路

    我们概括一下总体思路。

    先看看问题:假如一套系统包括 a,b,c 三个节点,每个节点运行一个 worker,那么当运行一个传播操作,我们涉及到在这三个节点之间互相传播。因此我们需要一个机制,来在这三个节点之中唯一标示这个传播过程,在这个传播过程之中,也要在每一个节点之上把每一个send/recv都标示出来,这样才能让节点可以支持多个操作并行

    再看看解决方案:

    • 使用上下文来唯一标示一个传播过程。DistAutogradContext 存储在一个worker之上的每一个分布式autograd的相关信息,其在分布式 autograd 之中封装前向和后向传播,累积梯度,这避免了多个worker在彼此的梯度上互相影响。每个自动微分过程被赋予一个唯一的 autograd_context_id,在容器中,这个微分过程的上下文(DistAutogradContext) 依据这个autograd_context_id 来唯一确认。
    • 使用autogradMessageId 来表示一对 send/recv autograd 函数。每send-recv对被分配一个全局唯一的autograd_message_id 以唯一地标识该send-recv对。这对于在向后传播期间查找远程节点上的相应函数很有用。
    • 最后,每个worker需要有一个地方来保持上下文和messageid,所以有了DistAutogradContainer这个类。每个worker拥有唯一一个单例DistAutogradContainer,其负责:
      • 对于每一个自动微分过程存储其分布式上下文。
      • 一旦这个自动微分过程结束,就清除其数据。

    这样,在前向传播期间,Pytorch 在上下文中存储每个 autograd 传播的sendrecv函数。这确保我们在 autograd 图中保存对适当节点的引用以使其保持活动状态。除此之外,这也使得在向后传播期间很容易查找到对应的sendrecv函数。

    0x02 AutogradMetadata

    2.1 定义

    AutogradMetadata 这个类是用来在不同节点之间传递 autograd 的元信息,就是把上下文等信息封装了一下。即,发送方通知接收方自己的上下文信息,接收方会依据收到的这些上下文信息作相应处理。

    我们提前剧透,接收方会使用 autogradContextId 和 autogradMessageId 分别作为 上下文 和 消息 的唯一标示。从注释之中可以知道。

    • autogradContextId 是全局唯一整数,用来表示一个唯一的分布式 autograd 传播过程(包括前向传播和后向传播)。一个传播过程会包括在反向传播链条上的多对send/recv autograd 函数。
    • autogradMessageId 是全局唯一整数,用来表示一对 send/recv autograd 函数。每send-recv对被分配一个全局唯一的autograd_message_id 以唯一地标识该send-recv对。这对于在向后传播期间查找远程节点上的相应函数很有用。
    // This structure represents autograd metadata that we need to pass across
    // different nodes when we call an RPC which needs autograd computation.
    struct TORCH_API AutogradMetadata {
      AutogradMetadata(int64_t autogradContextId, int64_t autogradMessageId);
    
      // autogradContextId_ is a globally unique integer that identifies a
      // particular distributed autograd pass.
      int64_t autogradContextId;
      // autogradMessageId_ is a globally unique integer that identifies a pair
      // of send/recv autograd functions.
      int64_t autogradMessageId;
    };
    

    那么问题来了,autogradContextId 和 autogradMessageId 分别怎么做到全局(包括多个节点)唯一呢?

    2.2 autogradMessageId

    我们先概括一下:autogradMessageId 是由 rank 间接生成的,然后在内部进行递增,所以可以保证全局唯一。

    我们从后往前推导。

    • 先看 newAutogradMessageId 是如何生成消息 id,原来是在 DistAutogradContainer 之中的成员变量 next_autograd_message_id_ 递增得到。
    int64_t DistAutogradContainer::newAutogradMessageId() {
      // Check for overflow into workerId_ section.
      TORCH_INTERNAL_ASSERT(next_autograd_message_id_ < max_id_);
      return next_autograd_message_id_++;
    }
    
    • 然后看如何初始化 next_autograd_message_id_?从 DistAutogradContainer 的 init 函数中可以知道,原来是依据 worker_id 来生成 next_autograd_message_id_。work_id 是 init 函数所得到的参数。
    DistAutogradContainer& DistAutogradContainer::init(int64_t worker_id) {
      std::lock_guard<std::mutex> guard(dist_container_init_lock_);
    
      auto& container = getInstanceInternal();
      container.worker_id_ = worker_id;
      container.next_context_id_ = static_cast<int64_t>(worker_id)
          << kAutoIncrementBits;
      container.next_autograd_message_id_ = static_cast<int64_t>(worker_id)
          << kAutoIncrementBits;
      container.max_id_ =
          (kAutoIncrementMask |
           (static_cast<int64_t>(worker_id) << kAutoIncrementBits));
      container.initialized_ = true;
      return container;
    }
    
    • 我们再推导,看看如何设置 worker id,找到了如下,看来需要看看 python 世界的 _init 方法。
    module.def(
        "_init",
        [](int64_t worker_id) { DistAutogradContainer::init(worker_id); },
        py::call_guard<py::gil_scoped_release>());
    

    来到 python 世界,可以看到,使用了 rank 来作为参数,而 rank 是每个 worker 唯一的,这样就保证了 worker ID 唯一,从而 消息 id 唯一。

        def init_rpc(
            name,
            backend=None,
            rank=-1,
            world_size=None,
            rpc_backend_options=None,
        ):
    			dist_autograd._init(rank) # rank是全局唯一
    

    我们把这些逻辑关系总结下来:

    worker_id = rank;
    
    container.worker_id_ = worker_id;
    
    container.next_autograd_message_id_ = static_cast<int64_t>(worker_id) << kAutoIncrementBits
    

    然后 next_autograd_message_id_ 内部递增。

    int64_t DistAutogradContainer::newAutogradMessageId() {
      // Check for overflow into workerId_ section.
      TORCH_INTERNAL_ASSERT(next_autograd_message_id_ < max_id_);
      return next_autograd_message_id_++;
    }
    

    所以,AutogradMessageId 是全局唯一的。我们用图例来看看:

    +----------------------------------------------------------------------------------------+
    | worker                                                                                 |
    |                       +-------------------------------------+                          |
    |                       | DistAutogradContainer               |                          |
    |                       |                                     |                          |
    |                       |                                     |                          |
    |              init()   |                                     |                          |
    |      rank +--------------+----> worker_id_                  |                          |
    |                1      |  |                                  |   newAutogradMessageId() |
    |                       |  +----> next_autograd_message_id_+------------------+          |
    |                       |                                     |          2    |          |
    |                       +-------------------------------------+               |          |
    |                                                                             |          |
    |                                                                             |          |
    |                                                                             |          |
    |                                                                             |          |
    |                     +---------------------------------------------------------------+  |
    |                     | getMessageWithAutograd                                |       |  |
    |                     |                                                       |       |  |
    |                     |                                                       v       |  |
    |                     |                                                               |  |
    |                     |   AutogradMetadata autogradMetadata(contextId(), MessageId()) |  |
    |                     |                           4                           3       |  |
    |                     |                                                               |  |
    |                     +---------------------------------------------------------------+  |
    |                                                                                        |
    +----------------------------------------------------------------------------------------+
    

    为了看看 autogradContextId 为什么可以保证唯一,我们需要先分析 DistAutogradContainer 和 DistAutogradContext。

    0x03 DistAutogradContainer

    每个worker拥有唯一一个单例DistAutogradContainer,其负责:

    • 对于每一个自动微分过程存储其分布式上下文。
    • 一旦这个自动微分过程结束,就清除其数据。

    每个自动微分过程被赋予一个唯一的 autograd_context_id。在每个容器中,这个微分过程的上下文(DistAutogradContext) 依据这个autograd_context_id 来唯一确认。autograd_context_id 是一个 64 bit 的全局唯一id,前 16 bis 是 worker_id,后 48 位是在每个worker内部自动递增id。所以可见,一个Container 之中,是有多个Context的。

    此容器还负责维护全局唯一的消息id,用来关联发送/接收自动微分函数对。格式类似于autograd_context_id,是一个64位整数,前16位是工作者id,后48位是worker内部自动递增的。

    因为消息 id 和 上下文 id 的前16 位是 worker_id,也就是 rank id,再加上后48位内部自增,所以可以保证 消息 id 和 上下文 id 全局唯一

    3.1 定义

    DistAutogradContainer 定义如下,其中:

    • worker_id_ : 本 worker 的 ID,其实就是本 worker 的 rank。
    • next_context_id_ :自增的上下文ID,用来给每个自动微分过程赋予一个唯一的autograd_context_id。在一个传播链条上,其实只有第一个节点的 DistAutogradContainer 用到了 next_context_id_ 来生成 Context,后续节点的 DistAutogradContainer 都是依据第一个 DistAutogradContainer 的 context id 信息来在本地生成对应 context id 的 Context。
    • next_autograd_message_id_ :维护全局唯一的消息id,用来关联 发送/接收 自动微分函数对。此变量是在本节点发送时候会使用到。
    // Singleton class per worker which is responsible for storing the distributed
    // autograd context for each autograd pass and also cleans up data for an
    // autograd pass once its done.
    //
    // Each autograd pass is assigned a unique autograd_context_id and all data for
    // that pass (DistAutogradContext) is stored in this container indexed by the
    // autograd_context_id. The autograd_context_id itself is a 64 bit globally
    // unique id. The first 16 bits is the worker_id and the next 48 bits is an
    // auto-incrementing id for each worker.
    //
    // This container is also responsible for maintaining a globally unique message
    // id, which is used to associate send/recv autograd function pairs. The format
    // is similar to the autograd_context_id where we have a 64 bit integer with
    // first 16 bits being the worker id and next 48 bits are auto-incrementing.
    class TORCH_API DistAutogradContainer {
    
     private:
      // Number of shards for the map storing autograd contexts. We'd like this
      // to be a power of 2 and we don't expect a value much higher than the
      // number of cores would provide much benefit.
      static constexpr uint32_t kNumDefaultShards = 128;
    
      // Use cache line size for alignment.
      static constexpr int kCacheLineSize = 64;
    
      // Structure holding one shard of the sharded autograd context map with its
      // associated lock. Align to cache line size to avoid contention between
      // adjacent entries.
      struct alignas(kCacheLineSize) ContextsShard {
        // Lock for this shard.
        mutable std::mutex lock;
    
        // Map storing autograd contexts for this shard.
        std::unordered_map<int64_t, ContextPtr> contexts; // 这里存储了上下文指针
      };
    
      // Auto incrementing context id used to identify unique autograd passes.
      // Initialized with the first 16 bits being the worker_id.
      std::atomic<int64_t> next_context_id_; // 新增上下文id
    
      // Unique id to identify a worker in the distributed setting.
      int16_t worker_id_;
    
      // Whether or not the container has been initialized appropriately.
      bool initialized_;
    
      // Sharded autograd context map.
      std::vector<ContextsShard> autograd_contexts_; // 存储上下文列表
    
      // Number of shards for the sharded autograd_contexts_ map.
      uint32_t num_shards_;
    
      // Autograd message id to identify unique send/recv autograd function pairs.
      std::atomic<int64_t> next_autograd_message_id_;
    
      // Maximum allowed value for autograd_context_id or autograd_message_id.
      int64_t max_id_;
    };
    

    3.2 构建

    Init 方法构建了 DistAutogradContainer,主要就是利用 worker_id 对本地成员变量进行相关赋值。

    DistAutogradContainer& DistAutogradContainer::init(int64_t worker_id) {
      std::lock_guard<std::mutex> guard(dist_container_init_lock_);
    
      TORCH_CHECK(
          worker_id >= 0 && worker_id <= kMaxWorkerId,
          "worker_id needs to be in the range [0, 65535]")
    
      auto& container = getInstanceInternal();
      TORCH_CHECK(
          !container.initialized_ || (worker_id == container.worker_id_),
          "Container is already initialized with worker_id: ",
          container.worker_id_,
          ", cannot initialize with different worker_id: ",
          worker_id);
    
      if (container.initialized_) {
        return container;
      }
    
      container.worker_id_ = worker_id;
      container.next_context_id_ = static_cast<int64_t>(worker_id)
          << kAutoIncrementBits;
      container.next_autograd_message_id_ = static_cast<int64_t>(worker_id)
          << kAutoIncrementBits;
      container.max_id_ =
          (kAutoIncrementMask |
           (static_cast<int64_t>(worker_id) << kAutoIncrementBits));
      container.initialized_ = true;
      return container;
    }
    

    0x04 DistAutogradContext

    DistAutogradContext 存储在一个worker之上的每一个分布式autograd的相关信息,其在分布式 autograd 之中封装前向和后向传播,累积梯度,这避免了多个worker在彼此的梯度上互相影响。

    由前面可知道,contextId_ 是全局唯一。

    4.1 定义

    这里仅仅给出 DistAutogradContext 成员变量,忽略其成员函数。其中成员变量最主要的有三个:

    • contextId_ 是上下文 id。
    • sendAutogradFunctions_ 是一个 map 类型变量,会收集所有发送请求对应的反向传播算子 SendRpcBackward。
    • recvAutogradFunctions_ 是一个 map 类型变量,会收集所有接受送请求对应的反向传播算子 RecvRpcBackward。

    关于 SendRpcBackward 和 RecvRpcBackward,我们后续会结合引擎进行分析。

    // DistAutogradContext which stores information for a single distributed
    // autograd pass on a worker.
    class TORCH_API DistAutogradContext {
     private:
      friend class BackwardPassCleanupGuard;
      friend class DistEngine;
      friend class RecvRpcBackward;
      friend class DistAccumulateGradCaptureHook;
    
      const int64_t contextId_;
    
      // Set containing known worker IDs, used in cleaning up autograd context.
      // Whenever a sendRpcBackward is attached to the autograd graph for this
      // context, the destination is added here.
      std::unordered_set<rpc::worker_id_t> knownWorkerIds_;
    
      // Map from autograd_message_id to appropriate 'send' autograd function.
      std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>>
          sendAutogradFunctions_;
    
      // Map from autograd_message_id to appropriate 'recv' autograd function.
      std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>>
          recvAutogradFunctions_;
    
      // Gradients accumulated in this context so far. The key is the variable on
      // which the gradient needs to be accumulated and the value is the gradient
      // that needs to be accumulated on that variable..
      c10::Dict<torch::Tensor, torch::Tensor> accumulatedGrads_;
    
      // See comments for recordGradEvent(c10::Device device);
      std::unordered_map<c10::Device, c10::Event> gradReadyEvents_;
      const c10::impl::VirtualGuardImpl impl_;
    
      // The autograd GraphTask for the backward pass on this node for this context.
      std::shared_ptr<torch::autograd::GraphTask> graphTask_;
    
      // List of futures for RPCs initiated by this node to propagate gradients to
      // other nodes. The distributed autograd engine on this node can return
      // successfully only if all these futures are done and are successful.
      std::vector<c10::intrusive_ptr<rpc::JitFuture>> outStandingRpcs_;
    
      // Lock to protect concurrent modification of the context.
      mutable std::mutex lock_;
    };
    
    

    4.2 消息

    上下文主要包括几种消息类型,比如:

    // Messages with autograd info
    FORWARD_AUTOGRAD_REQ = 0x0f | MessageTypeFlags::REQUEST_TYPE,
    FORWARD_AUTOGRAD_RESP = 0x10 | MessageTypeFlags::RESPONSE_TYPE,
    
    // Messages to propagate gradients on the backward pass.
    BACKWARD_AUTOGRAD_REQ = 0x11 | MessageTypeFlags::REQUEST_TYPE,
    BACKWARD_AUTOGRAD_RESP = 0x12 | MessageTypeFlags::RESPONSE_TYPE,
    

    4.3 构建

    我们首先看看如何构建上下文。

    4.3.1 getOrCreateContext

    getOrCreateContext 函数是用来得到上下文,如果已经有,就直接获取,如果没有,就新构建一个。这是一个被动调用,recv 端会用到这个

    ContextPtr DistAutogradContainer::getOrCreateContext(int64_t context_id) {
      auto& shard = getShard(context_id);
      std::lock_guard<std::mutex> guard(shard.lock);
      auto it = shard.contexts.find(context_id); // 根据这个context id来查找
      if (it != shard.contexts.end()) {
        return it->second; // 找到就返回
      }
    
      auto& context = // 如果没有,就构建一个 context
          shard.contexts
              .emplace(
                  std::piecewise_construct,
                  std::forward_as_tuple(context_id),
                  std::forward_as_tuple(
                      std::make_shared<DistAutogradContext>(context_id)))
              .first->second;
      return context;
    }
    

    4.3.2 newContext

    这里是主动调用,send 端会调用这个方法

    4.3.2.1 Python

    当分布式调用时候,python世界会生成一个context。

                with dist_autograd.context() as context_id:
                    output = model(indices, offsets)
                    loss = criterion(output, target)
    
                    # Run distributed backward pass
                    dist_autograd.backward(context_id, [loss])
    
                    # Run distributed optimizer. Gradients propagated all the way to the parameter servers
                    opt.step(context_id)
    

    当生成时,__enter__ 会调用 _new_context() 在C++生成一个context。

    class context(object):
        '''
        Context object to wrap forward and backward passes when using
        distributed autograd. The ``context_id`` generated in the ``with``
        statement  is required to uniquely identify a distributed backward pass
        on all workers. Each worker stores metadata associated with this
        ``context_id``, which is required to correctly execute a distributed
        autograd pass.
    
        Example::
            >>> import torch.distributed.autograd as dist_autograd
            >>> with dist_autograd.context() as context_id:
            >>>   t1 = torch.rand((3, 3), requires_grad=True)
            >>>   t2 = torch.rand((3, 3), requires_grad=True)
            >>>   loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
            >>>   dist_autograd.backward(context_id, [loss])
        '''
        def __enter__(self):
            self.autograd_context = _new_context() # 这里生成一个上下文
            return self.autograd_context._context_id()
    
        def __exit__(self, type, value, traceback):
            _release_context(self.autograd_context._context_id())
    

    具体通过如下映射,我们可以看到 C++ 世界之中对应的方法,调用到了 DistAutogradContainer::getInstance().newContext()。

      module.def(
          "_new_context",
          []() -> const ContextPtr {
            return DistAutogradContainer::getInstance().newContext();
          },
          py::return_value_policy::reference);
    
    4.3.2.2 C++

    我们来到了C++世界。每一个线程都有一个autograd_context_id。

    constexpr int64_t kInvalidContextId = -1;
    
    // Each thread has a single autograd_context_id valid at any point in time.
    static thread_local int64_t current_context_id_ = kInvalidContextId;
    

    newContext 就是生成了一个DistAutogradContext,其中通过 Container 的成员变量 next_context_id_ 的递增来指定下一个上下文的id。

    const ContextPtr DistAutogradContainer::newContext() {
    
      auto context_id = next_context_id_++; // 递增
      current_context_id_ = context_id;  // 在这里设置了本地线程的 current_context_id_
    
      // Check for overflow into workerId_ section.
      TORCH_INTERNAL_ASSERT(context_id < max_id_);
    
      auto& shard = getShard(context_id);
      std::lock_guard<std::mutex> guard(shard.lock);
      auto& context =
          shard.contexts
              .emplace(
                  std::piecewise_construct,
                  std::forward_as_tuple(context_id),
                  std::forward_as_tuple(
                      std::make_shared<DistAutogradContext>(context_id)))
              .first->second;
    
      return context;
    }
    

    4.4 如何共享上下文

    具体使用中,在with语句中生成的context_id可以用作在所有 worker 之上唯一标识一个分布式后向传播(包括前向传播和后向传播)。每个worker存储与此 context_id关联的元数据,这是正确执行分布式自动加载过程所必需的。

    因为需要在多个 worker 之中都存储这个 context_id关联的元数据,所以就需要一个 封装/发送/接受的机制来在 worker 之间传递这个元数据,封装机制就是我们前面提到的 AutogradMetadata。我们接下来看看如何发送/接受上下文元信息

    4.4.1 发送方

    当发送消息时候,getMessageWithAutograd 会使用 autogradContainer.currentContext() 获取当前上下文,进行发送。

    Message getMessageWithAutograd(
        const rpc::worker_id_t dstId,
        torch::distributed::rpc::Message&& wrappedRpcMsg,
        MessageType msgType,
        bool forceGradRecording,
        const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
      auto& autogradContainer = DistAutogradContainer::getInstance();
    
      // If there is no valid context and no tensor requires grads, send original
      // rpc message. otherwise, attach grad info and grad functions and send
      // rpcWithAutograd message.
      auto tensorsRequireGrad =
          torch::autograd::compute_requires_grad(wrappedRpcMsg.tensors());
      if (!autogradContainer.hasValidContext() ||
          (!forceGradRecording && !tensorsRequireGrad)) {
        return std::move(wrappedRpcMsg);
      }
    
      // Retrieve the appropriate context to modify.
      auto autogradContext = autogradContainer.currentContext(); // 获取当前上下文
    
      // Wrap the original rpc with autograd information.
      AutogradMetadata autogradMetadata( // 使用上下文id和消息id来构建元数据
          autogradContext->contextId(), autogradContainer.newAutogradMessageId());
      auto rpcWithAutograd = std::make_unique<RpcWithAutograd>(
          RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_,
          msgType,
          autogradMetadata,
          std::move(wrappedRpcMsg),
          deviceMap);
    
      if (tensorsRequireGrad) {
        // Record autograd information for 'send'.
        addSendRpcBackward(
            autogradContext, autogradMetadata, rpcWithAutograd->tensors());
      }
      // Record the workerID
      autogradContext->addKnownWorkerId(dstId);
    
      return std::move(*rpcWithAutograd).toMessage();
    }
    

    我们之前的图现在可以拓展,加入了上下文ID。

    +----------------------------------------------------------------------------------------+
    | worker                                                                                 |
    |                  +------------------------------------------+                          |
    |                  |DistAutogradContainer                     |                          |
    |          init()  |                                          |                          |
    |  rank +-------------+----> worker_id_                       |                          |
    |                  |  |                                       |                          |
    |                  |  +----> next_context_id_+-------------+  |                          |
    |                  |  |                                    |  |                          |
    |                  |  +----> next_autograd_message_id_ +----------------------+          |
    |                  |                                       |  |               |          |
    |                  |                                       |  |               |          |
    |                  +------------------------------------------+               |          |
    |                                                          |                  |          |
    |                                                          |                  |          |
    |                                                          |                  |          |
    |                  +------------------------------------------------------------------+  |
    |                  |getMessageWithAutograd                 |                  |       |  |
    |                  |                                       |                  |       |  |
    |                  |                                       v                  v       |  |
    |                  |                                                                  |  |
    |                  |    AutogradMetadata autogradMetadata(contextId(), MessageId())   |  |
    |                  |                                                                  |  |
    |                  |                                                                  |  |
    |                  +------------------------------------------------------------------+  |
    |                                                                                        |
    +----------------------------------------------------------------------------------------+
    

    addSendRpcBackward 就被传入当前上下文之中,后续反向传播时候,会取出这个 addSendRpcBackward。

    void addSendRpcBackward(
        const ContextPtr& autogradContext,
        const AutogradMetadata& autogradMetadata,
        std::vector<torch::Tensor>& tensors) {
      // Attach autograd information only for tensors requiring grad.
      std::vector<torch::Tensor> tensors_with_grad;
      std::copy_if(
          tensors.begin(),
          tensors.end(),
          std::back_inserter(tensors_with_grad),
          [](const torch::Tensor& t) { return t.requires_grad(); });
    
      // Attach the appropriate autograd edges.
      auto grad_fn = std::make_shared<SendRpcBackward>();
      grad_fn->set_next_edges(
          torch::autograd::collect_next_edges(tensors_with_grad));
    
      // Add the appropriate input metadata for the grad_fn.
      for (const auto& tensor : tensors_with_grad) {
        grad_fn->add_input_metadata(tensor);
      }
    
      // Record the send autograd function in our current context.
      autogradContext->addSendFunction(grad_fn, autogradMetadata.autogradMessageId);
    }
    

    4.4.2 接受方

    在 addRecvRpcBackward 之中,会依据传递过来的 autogradMetadata.autogradContextId 来构建一个上下文。

    ContextPtr addRecvRpcBackward(
        const AutogradMetadata& autogradMetadata,
        std::vector<torch::Tensor>& tensors,
        rpc::worker_id_t fromWorkerId,
        const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
      // Initialize autograd context if necessary.
      auto& autogradContainer = DistAutogradContainer::getInstance();
      // 生成或者得到一个上下文,把发送方的 autogradContextId 传入,即利用 autogradContextId 作为key后续可以查找到这个上下文
      auto autogradContext = 
          autogradContainer.getOrCreateContext(autogradMetadata.autogradContextId);
    
      if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) {
        // Attach the tensors as inputs to the autograd function.
        auto grad_fn = std::make_shared<RecvRpcBackward>(
            autogradMetadata, autogradContext, fromWorkerId, deviceMap);
        for (auto& tensor : tensors) {
          if (tensor.requires_grad()) {
            torch::autograd::set_history(tensor, grad_fn);
          }
        }
    
        // Now update the autograd context with the necessary information.
        autogradContext->addRecvFunction(
            grad_fn, autogradMetadata.autogradMessageId);
      }
    
      return autogradContext;
    }
    

    这样,发送方和接收方就共享了一个上下文,而且这个上下文的id是全局唯一的。

    具体逻辑如下,上方是发送端,下方是接收端。

    • 发送端
      • 利用本地 context_id 构建了 AutogradMetadata,AutogradMetadata含有 ctx_id, msg_id。
      • 利用 AutogradMetadata 构建了 Message。
      • 利用 agent.send 发送了 Message。
    • 接收端:
      • 收到了 Message。
      • 从 Message 之中解析出 AutogradMetadata。
      • 从 AutogradMetadata 提取出 context_id。
      • 利用 context_id 构建了本地的 DistAutogradContext。
    • 发送方和接收方就共享了一个上下文(这个上下文的id是全局唯一的)。
    +----------------------------------------------------------------------------------+
    | sendMessageWithAutograd                                                          |
    |                                                                                  |
    |  +----------------------------------------------------------------------------+  |
    |  | addSendRpcBackward                                                         |  |
    |  |                                                                            |  |
    |  |                                                                            |  |
    |  |               autogradMetadata = AutogradMetadata(context_id, message_id)  |  |
    |  |                          +                                                 |  |
    |  |                          |                                                 |  |
    |  +----------------------------------------------------------------------------+  |
    |                             |                                                    |
    |                             v                                                    |
    |        agent.send(message(autogradMetadata)                                      |
    |                             +                                                    |
    |                             |                                                    |
    +----------------------------------------------------------------------------------+
                                  |
                                  |
                                  |
                                  |                                             Sender
    +-----------------------------------------------------------------------------------+
                                  |                                             Receiver
                                  | message
                                  v
                                  |
    +----------------------------------------------------------------------------------+
    | processForwardAutogradReq   |                                                    |
    |                             |                                                    |
    |                             | message.autogradMetadata                           |
    |                             v                                                    |
    |  +----------------------------------------------------------------------------+  |
    |  | addSendRpcBackward       |                                                 |  |
    |  |                          |                                                 |  |
    |  |                          +--------------------+                            |  |
    |  |                                               |                            |  |
    |  |                                               v                            |  |
    |  |   autogradContext = getOrCreateContext(autogradMetadata.autogradContextId) |  |
    |  |                                                                            |  |
    |  |                                                                            |  |
    |  +----------------------------------------------------------------------------+  |
    |                                                                                  |
    +----------------------------------------------------------------------------------+
    

    0x05 前向传播交互过程

    前面的分享过程还是简略,我们接下来把完整的发送/接受过程详细分析一下。

    5.1 发送

    这里对应设计中的如下文字:

    在前向传播期间,我们在上下文中存储每个 autograd 传播的sendrecv函数。这确保我们在 autograd 图中保存对适当节点的引用以使其保持活动状态。除此之外,这也使得在后向传播期间很容易查找到对应的sendrecv函数。

    5.1.1 发送逻辑

    代码逻辑如下:

    • 生成一个 grad_fn,其类型是 SendRpcBackward。
    • 调用 collect_next_edges 和 set_next_edges 为 SendRpcBackward 添加后续边,这些函数我们在前面系列中有分析。
    • 调用 add_input_metadata 添加输入元数据。
    • 调用 addSendFunction 往上下文添加 grad_fn。
    void addSendRpcBackward(
        const ContextPtr& autogradContext,
        const AutogradMetadata& autogradMetadata,
        std::vector<torch::Tensor>& tensors) {
      // Attach autograd information only for tensors requiring grad.
      std::vector<torch::Tensor> tensors_with_grad;
      std::copy_if(
          tensors.begin(),
          tensors.end(),
          std::back_inserter(tensors_with_grad),
          [](const torch::Tensor& t) { return t.requires_grad(); });
    
      // Attach the appropriate autograd edges.
      auto grad_fn = std::make_shared<SendRpcBackward>();
      grad_fn->set_next_edges( // 这里会设置其输出边
          torch::autograd::collect_next_edges(tensors_with_grad));
    
      // Add the appropriate input metadata for the grad_fn.
      for (const auto& tensor : tensors_with_grad) {
        grad_fn->add_input_metadata(tensor);
      }
    
      // Record the send autograd function in our current context.
      autogradContext->addSendFunction(grad_fn, autogradMetadata.autogradMessageId);
    }
    

    5.1.2 设置上下文

    我们再回忆一下DistAutogradContext 定义,这里仅仅给出其部分成员变量。

    • contextId_ 是上下文 id。
    • sendAutogradFunctions_ 是一个 map 类型变量,会收集所有发送请求对应的反向传播算子 SendRpcBackward。
    • recvAutogradFunctions_ 是一个 map 类型变量,会收集所有接受送请求对应的反向传播算子 RecvRpcBackward。
    // DistAutogradContext which stores information for a single distributed
    // autograd pass on a worker.
    class TORCH_API DistAutogradContext {
    
      const int64_t contextId_;
    
      // Map from autograd_message_id to appropriate 'send' autograd function.
      std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>>
          sendAutogradFunctions_;
    
      // Map from autograd_message_id to appropriate 'recv' autograd function.
      std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>>
          recvAutogradFunctions_;
    };
    

    addSendFunction 就是往 sendAutogradFunctions_ 之中添加SendRpcBackward,后续可以按照 message id 来得到这个 SendRpcBackward。

    void DistAutogradContext::addSendFunction(
        const std::shared_ptr<SendRpcBackward>& func,
        int64_t autograd_message_id) {
    
      std::lock_guard<std::mutex> guard(lock_);
      TORCH_INTERNAL_ASSERT(
          sendAutogradFunctions_.find(autograd_message_id) ==
          sendAutogradFunctions_.end());
      sendAutogradFunctions_.emplace(autograd_message_id, func);
    }
    

    前面是从上下文构建的角度看,本次从上下文内容来看。

    此时发送端逻辑如下:

    +--------------------------------------------------------------+    +-------------------+
    | worker                                                       |    |SendRpcBackward    |
    | +---------------------------------------------------------+  |    |                   |
    | | DistAutogradContext                                     |  |    |   input_metadata_ |
    | |                                                 +-------------> |                   |
    | |  contextId_ = context_id_1                      |       |  |    |   next_edges_     |
    | |                                                 +       |  |    |                   |
    | |  sendAutogradFunctions_ = [msg_id_1, SendRpcBackward_1] |  |    +-------------------+
    | |                                                         |  |
    | |                                                         |  |
    | |  recvAutogradFunctions_                                 |  |
    | |                                                         |  |
    | +---------------------------------------------------------+  |
    |                                                              |
    +--------------------------------------------------------------+
    
                                                                                      sender
    +---------------------------------------------------------------------------------------+
    
    

    5.2 接受

    我们略过 agent 的发送内部处理,转而看看 FORWARD_AUTOGRAD_REQ 的业务流程。

    5.2.1 接收消息 ---> 接收方

    生成 TensorPipeAgent 时候,把 RequestCallbackImpl 配置为回调函数。这是 agent 的统一响应函数。

    前面关于代理接收逻辑时候,我们也提到了,会进入以下函数,其中可以看到有对 processForwardAutogradReq 的处理逻辑。

    void RequestCallbackNoPython::processRpc(
        RpcCommandBase& rpc,
        const MessageType& messageType,
        const int64_t messageId,
        const c10::intrusive_ptr<JitFuture>& responseFuture,
        std::shared_ptr<LazyStreamContext> ctx) const {
    
        case MessageType::FORWARD_AUTOGRAD_REQ: {
          // 会来到这里
          processForwardAutogradReq(rpc, messageId, responseFuture, std::move(ctx));
          return;
        }
        case MessageType::BACKWARD_AUTOGRAD_REQ: {
          processBackwardAutogradReq(rpc, messageId, responseFuture);
          return;
        };  
      
    }  
    

    5.2.2 处理消息

    processForwardAutogradReq 负责具体处理消息,其处理逻辑如下:

    • 虽然是收到了前向传播请求,但因为此处是接收端,后续需要进行反向传播,所以对deviceMap进行转置。
    • 使用 addRecvRpcBackward 将 rpc 消息 加入上下文。
    • 可能会有nested命令的可能,所以需要再调用一次processRpc。
    • 设置最原始的消息为处理完毕,进行相关操作。
    void RequestCallbackNoPython::processForwardAutogradReq(
        RpcCommandBase& rpc,
        const int64_t messageId,
        const c10::intrusive_ptr<JitFuture>& responseFuture,
        std::shared_ptr<LazyStreamContext> ctx) const {
      
      auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);
    
      // Need to reverse the device map for the backward pass of distributed
      // autograd.
      std::unordered_map<c10::Device, c10::Device> reverseDeviceMap;
      // 对deviceMap进行转置
      for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
        reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
      }
    
      // Attach 'recv' autograd function.
      auto autogradContext = addRecvRpcBackward( // 调用了 addRecvRpcBackward 加入上下文
          rpcWithAutograd.autogradMetadata(),
          rpcWithAutograd.tensors(),
          rpcWithAutograd.fromWorkerId(),
          reverseDeviceMap);
      // For this recv thread on server side, before processRpc(),
      // set current_context_id_ to be context_id passed from client.
      // In this way, if there is nested rpc call in python rpc call, original
      // context_id from client can be passed in the chain calls.
      DistAutogradContextGuard ctxGuard(autogradContext->contextId());
    
      // Process the original RPC.
      auto wrappedMessageType = rpcWithAutograd.wrappedMessageType();
      // Make an overall future for the wrapped response.
      auto wrappedRpcResponseFuture =
          c10::make_intrusive<JitFuture>(at::AnyClassType::get());
      // Kick off processing for the nested RPC command.
      // wrappedRpcResponseFuture will be a Future<T> to the result.
      processRpc( // 可能会有nested命令的可能,所以需要再处理一次
          rpcWithAutograd.wrappedRpc(),
          wrappedMessageType,
          messageId,
          wrappedRpcResponseFuture,
          std::move(ctx));
    
      auto fromWorkerId = rpcWithAutograd.fromWorkerId();
      // The original future needs to be marked as completed when the wrapped
      // one completes, with the autograd context information wrapped.
      wrappedRpcResponseFuture->addCallback(
          [responseFuture,
           messageId,
           fromWorkerId,
           ctxId =
               autogradContext->contextId()](JitFuture& wrappedRpcResponseFuture) {
            // As this callback can be invoked by a different thread, we have to
            // make sure that the thread_local states in the previous thread is
            // correctly propagated.
            // NB: The execution of TorchScript functions can also run on a
            // different thread, which is addressed by
            // https://github.com/pytorch/pytorch/pull/36395
            // NB: when adding async UDF support, we should also propagate
            // thread_local states there.
            // TODO: Land on a general solution for RPC ThreadLocalState. See
            // https://github.com/pytorch/pytorch/issues/38510
            DistAutogradContextGuard cbCtxGuard(ctxId);
    
            if (wrappedRpcResponseFuture.hasError()) {
              // Propagate error to responseFuture if we had one.
              responseFuture->setError(wrappedRpcResponseFuture.exception_ptr());
            } else {
              auto msg = getMessageWithAutograd(
                  fromWorkerId,
                  std::move(
                      *wrappedRpcResponseFuture.value().toCustomClass<Message>()),
                  MessageType::FORWARD_AUTOGRAD_RESP);
              msg.setId(messageId);
              responseFuture->markCompleted(
                  IValue(c10::make_intrusive<Message>(std::move(msg))));
            }
          });
    }
    

    5.2.3 上下文交互

    torch/csrc/distributed/autograd/utils.cpp 之中,addRecvRpcBackward 函数会对上下文进行处理。

    这里对应设计中的:

    在前向传播期间,我们在上下文中存储每个 autograd 传播的sendrecv函数。这确保我们在 autograd 图中保存对适当节点的引用以使其保持活动状态。除此之外,这也使得在向后传播期间很容易查找到对应的sendrecv函数。

    其具体逻辑是:

    • 根据 rpc信息中的 autogradContextId 拿到本地的上下文。
    • 生成一个 RecvRpcBackward。
    • 用 rpc 信息中的张量来对 RecvRpcBackward 进行配置,包括torch::autograd::set_history(tensor, grad_fn)。
    • 调用 addRecvFunction 把 RecvRpcBackward 加入到上下文。
    ContextPtr addRecvRpcBackward(
        const AutogradMetadata& autogradMetadata,
        std::vector<torch::Tensor>& tensors,
        rpc::worker_id_t fromWorkerId,
        const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
      // Initialize autograd context if necessary.
      auto& autogradContainer = DistAutogradContainer::getInstance();
      auto autogradContext =
          autogradContainer.getOrCreateContext(autogradMetadata.autogradContextId);
    
      if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) {
        // Attach the tensors as inputs to the autograd function.
        auto grad_fn = std::make_shared<RecvRpcBackward>(
            autogradMetadata, autogradContext, fromWorkerId, deviceMap);
        for (auto& tensor : tensors) {
          if (tensor.requires_grad()) {
            torch::autograd::set_history(tensor, grad_fn);
          }
        }
    
        // Now update the autograd context with the necessary information.
        autogradContext->addRecvFunction(
            grad_fn, autogradMetadata.autogradMessageId);
      }
    
      return autogradContext;
    }
    

    addRecvFunction 的添加操作如下,就是看看 recvAutogradFunctions_之中是否已经存在这个 message id 对应的算子,如果没有就添加 。

    void DistAutogradContext::addRecvFunction(
        std::shared_ptr<RecvRpcBackward>& func,
        int64_t autograd_message_id) {
      TORCH_INTERNAL_ASSERT(func != nullptr);
      std::lock_guard<std::mutex> guard(lock_);
      TORCH_INTERNAL_ASSERT(
          recvAutogradFunctions_.find(autograd_message_id) ==
          recvAutogradFunctions_.end());
      recvAutogradFunctions_.emplace(autograd_message_id, func);
    }
    

    至此,逻辑拓展如下,在发送端和接收端都有一个 DistAutogradContext,其 id 都是 context_id_1。

    在 每个 DistAutogradContext 之内,均以 msg_id_1 作为key,一个是 SendRpcBackward,一个建立了 RecvRpcBackward。

    这就对应了设计之中提到的:

    每个自动微分过程被赋予一个唯一的 autograd_context_id,在容器中,这个微分过程的上下文(DistAutogradContext) 依据这个autograd_context_id 来唯一确认。autograd_context_id 是一个 64 bit 的全局唯一id,前 16 bis 是 worker_id,后 48 位是在每个worker内部自动递增id。所以可见,一个Container 之中,是有多个Context的。

    此容器还负责维护全局唯一的消息id,用来关联发送/接收自动微分函数对。格式类似于autograd_context_id,是一个64位整数,前16位是工作者id,后48位是worker内部自动递增的。

    +----------------------------------------------------------------+
    | worker                                                         |    +-------------------+
    |                                                                |    |SendRpcBackward    |
    |   +---------------------------------------------------------+  |    |                   |
    |   | DistAutogradContext                                     |  |    |   input_metadata_ |
    |   |                                                 +-------------> |                   |
    |   |  contextId_ = context_id_1                      |       |  |    |   next_edges_     |
    |   |                                                 +       |  |    |                   |
    |   |  sendAutogradFunctions_ = [msg_id_1, SendRpcBackward_1] |  |    +-------------------+
    |   |                                                         |  |
    |   |  recvAutogradFunctions_                                 |  |
    |   |                                                         |  |
    |   +---------------------------------------------------------+  |
    |                                                                |
    |                             +                                  |
    |                             |                                  |
    +----------------------------------------------------------------+
                                  |
                                  |
                                  |                                                     Sender
    +-----------------------------------------------------------------------------------------+
                                  |                                                     Receiver
                                  |
                                  v
    +-----------------------------+----------------------------------+
    | worker                                                         |
    |                                                                |    +-------------------+
    |   +---------------------------------------------------------+  |    |RecvRpcBackward    |
    |   | DistAutogradContext                                     |  |    |                   |
    |   |                                                         |  |    |                   |
    |   |   contextId_ = context_id_1                 +-----------------> |   input_metadata_ |
    |   |                                             |           |  |    |                   |
    |   |   sendAutogradFunctions_                    |           |  |    |   next_edges_     |
    |   |                                             +           |  |    |                   |
    |   |   recvAutogradFunctions_ = [msg_id_1, RecvRpcBackward_1]|  |    +-------------------+
    |   |                                                         |  |
    |   +---------------------------------------------------------+  |
    |                                                                |
    +----------------------------------------------------------------+
    

    我们加入 Container,再拓展一下目前逻辑如下:

    • 每个worker 包括一个DistAutogradContainer。
    • 每个 DistAutogradContainer 包括若干个 DistAutogradContext,依据 context id 提取 DistAutogradContext。
    • 每个 DistAutogradContext 包括 sendAutogradFunctions_ 和 recvAutogradFunctions_,利用 msg id 来获取 SendRpcBackward 或者 RecvRpcBackward。

    这样这个反向传播链条就构建了出来。

    +------------------------------------------------------------------------------------------------------------------------------------+
    | worker                                                                                                                             |
    |                                                                                                                                    |
    | +---------------------------------------+     +---------------------------------------------------------+    +-------------------+ |
    | | DistAutogradContainer                 |     | DistAutogradContext                                     |    |SendRpcBackward    | |
    | |                                       |     |                                                 +----------> |                   | |
    | |   worker_id_                          |     |  contextId_ = ctx_id_1                          |       |    |   input_metadata_ | |
    | |                                       |     |                                                 +       |    |                   | |
    | |   next_autograd_message_id_     +---------> |  sendAutogradFunctions_ = [msg_id_1, SendRpcBackward_1] |    |   next_edges_     | |
    | |                                 |     |     |                                                         |    |                   | |
    | |   next_context_id_              |     |     |  recvAutogradFunctions_                                 |    +-------------------+ |
    | |                                 +     |     |                                                         |                          |
    | |   autograd_contexts_[ctx_id_1 : ctx]  |     +---------------------------------------------------------+                          |
    | |                                       |                                                                                          |
    | +----------------------------+----------+                                                                                          |
    |                              |                                                                                                     |
    +------------------------------------------------------------------------------------------------------------------------------------+
                                   |
                                   |
    +-------------------------------------------------------------------------------------------------------------------------------------+
                                   |
                                   v
    +------------------------------+-----------------------------------------------------------------------------------------------------+
    | worker                                                                                                                             |
    |                                                                                                                                    |
    | +---------------------------------------+     +---------------------------------------------------------+    +-------------------+ |
    | | DistAutogradContainer                 |     | DistAutogradContext                                     |    |RecvRpcBackward    | |
    | |                                       |     |                                                 +----------> |                   | |
    | |   worker_id_                          |     |  contextId_ = ctx_id_1                          |       |    |   input_metadata_ | |
    | |                                       |     |                                                 |       |    |                   | |
    | |   next_autograd_message_id_     +---------> |  sendAutogradFunctions_                         |       |    |   next_edges_     | |
    | |                                 |     |     |                                                 +       |    |                   | |
    | |   next_context_id_              |     |     |  recvAutogradFunctions_ = [msg_id_1, RecvRpcBackward_1] |    +-------------------+ |
    | |                                 +     |     |                                                         |                          |
    | |   autograd_contexts_[ctx_id_1 : ctx]  |     +---------------------------------------------------------+                          |
    | |                                       |                                                                                          |
    | +---------------------------------------+                                                                                          |
    |                                                                                                                                    |
    +------------------------------------------------------------------------------------------------------------------------------------+
    
    

    手机如下:

    至此,我们初步分析了上下文相关的类,下文我们把目前已经分析的内容结合起来,系统看看业务逻辑。

    0xFF 参考

  • 相关阅读:
    windows7系统下升级到IE11时无法使用F12开发人员工具的解决办法
    微信公众号在线编辑器
    solr安装使用笔记
    在windows资源管理器添加进入当前目录dos窗口的快捷菜单
    spring mvc返回jsonp内容
    oracle最大连接数相关
    redis可视化管理工具Redis Desktop Manager
    Struts2远程代码执行漏洞预警
    postman请求数据库方法(Omysql)
    Selenium+java
  • 原文地址:https://www.cnblogs.com/rossiXYZ/p/15630679.html
Copyright © 2011-2022 走看看