  • Rust Source Code Analysis: channel's upgrade

    https://zhuanlan.zhihu.com/p/50101525

    std::sync::mpsc::channel

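    This article analyzes the channel in the Rust standard library. Channels are widely used as a way for threads to communicate.

    Rust provides a multi-producer, single-consumer channel. We focus on the case with multiple producers.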
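
As a quick orientation, this is what the multi-producer, single-consumer pattern looks like from the user's side. It is a minimal sketch using only the public `std::sync::mpsc` API. The helper name `collect_from_producers` is made up for illustration, and nothing here depends on the internal flavors discussed in this article.

```rust
use std::sync::mpsc;
use std::thread;

/// Spawns `n` producer threads that each send their index, then collects
/// everything on the single receiver. The result is sorted because the
/// arrival order across threads is not deterministic.
pub fn collect_from_producers(n: usize) -> Vec<usize> {
    let (tx, rx) = mpsc::channel();
    let mut handles = Vec::new();
    for id in 0..n {
        // Each producer gets its own clone of the sender.
        let tx = tx.clone();
        handles.push(thread::spawn(move || {
            tx.send(id).unwrap();
        }));
    }
    // Drop the original sender so the receiver sees the channel close
    // once every clone has been dropped as well.
    drop(tx);
    for h in handles {
        h.join().unwrap();
    }
    let mut got: Vec<usize> = rx.iter().collect();
    got.sort();
    got
}
```

Calling `collect_from_producers(4)` gathers one value from each of the four producer threads.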

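    Its implementation is quite interesting. I split it into two parts: channel upgrade and the concurrent queue.

    This article covers channel upgrade.

    The (sender, receiver) pair returned by a channel() call starts out as a oneshot channel, as the source code suggests: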

    #[stable(feature = "rust1", since = "1.0.0")]
    pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
        let a = Arc::new(oneshot::Packet::new());
        (Sender::new(Flavor::Oneshot(a.clone())), Receiver::new(Flavor::Oneshot(a)))
    }

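    At least four types are involved here:

    • oneshot::Packet: the Packet, where the data is actually stored. Here it holds a single value (other flavors may use a queue).
    • Flavor::Oneshot.
    • Sender/Receiver.

    Let's look at the source of each data structure in turn, starting with oneshot::Packet, which lives in mpsc/oneshot.rs: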

    pub struct Packet<T> {
        // Internal state of the chan/port pair (stores the blocked thread as well)
        state: AtomicUsize,
        // One-shot data slot location
        data: UnsafeCell<Option<T>>,
        // when used for the second time, a oneshot channel must be upgraded, and
        // this contains the slot for the upgrade
        upgrade: UnsafeCell<MyUpgrade<T>>,
    }

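    We can see that data is a slot for a single value. The upgrade field is used for channel upgrade.

    There are other kinds of Packet as well. The same directory contains shared::Packet/stream::Packet/sync::Packet, located in shared.rs/stream.rs/sync.rs respectively. We focus on shared::Packet: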

    pub struct Packet<T> {
        queue: mpsc::Queue<T>,
        cnt: AtomicIsize, // How many items are on this channel
        steals: UnsafeCell<isize>, // How many times has a port received without blocking?
        to_wake: AtomicUsize, // SignalToken for wake up
    
        // The number of channels which are currently using this packet.
        channels: AtomicUsize,
    
        // See the discussion in Port::drop and the channel send methods for what
        // these are used for
        port_dropped: AtomicBool,
        sender_drain: AtomicIsize,
    
        // this lock protects various portions of this implementation during
        // select()
        select_lock: Mutex<()>,
    }

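    The queue field is clearly visible; it stores the data. We will not look into the data fields for now.

    To distinguish these four kinds of Packet, the standard library provides enum Flavor<T>: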

    enum Flavor<T> {
        Oneshot(Arc<oneshot::Packet<T>>),
        Stream(Arc<stream::Packet<T>>),
        Shared(Arc<shared::Packet<T>>),
        Sync(Arc<sync::Packet<T>>),
    }

    Our Sender/Receiver objects are tied to a Packet very simply, by storing a Flavor<T>:

    pub struct Sender<T> {
        inner: UnsafeCell<Flavor<T>>,
    }
    pub struct Receiver<T> {
        inner: UnsafeCell<Flavor<T>>,
    }
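
To make this relationship concrete, here is a simplified model, not the standard library code. `OneSlot`, `Queue`, `Handle`, `oneshot_pair` and `same_packet` are hypothetical names, and `Mutex` stands in for the atomics and `UnsafeCell` that the real packets use. It only illustrates that both ends hold the same `Arc`-shared packet and that the enum variant identifies the packet kind.

```rust
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};

// Hypothetical stand-ins for oneshot::Packet and shared::Packet.
pub struct OneSlot<T> {
    pub data: Mutex<Option<T>>,
}
pub struct Queue<T> {
    pub items: Mutex<VecDeque<T>>,
}

// The variant records which kind of packet a handle points at.
pub enum Flavor<T> {
    Oneshot(Arc<OneSlot<T>>),
    Shared(Arc<Queue<T>>),
}

// Stand-in for both Sender and Receiver: just a wrapped Flavor.
pub struct Handle<T> {
    pub inner: Flavor<T>,
}

impl<T> Handle<T> {
    pub fn kind(&self) -> &'static str {
        match self.inner {
            Flavor::Oneshot(_) => "oneshot",
            Flavor::Shared(_) => "shared",
        }
    }
}

/// Mirrors the shape of `channel()`: one packet, two handles sharing it.
pub fn oneshot_pair<T>() -> (Handle<T>, Handle<T>) {
    let a = Arc::new(OneSlot { data: Mutex::new(None) });
    (
        Handle { inner: Flavor::Oneshot(a.clone()) },
        Handle { inner: Flavor::Oneshot(a) },
    )
}

/// True when both handles point at the very same packet allocation.
pub fn same_packet<T>(x: &Handle<T>, y: &Handle<T>) -> bool {
    match (&x.inner, &y.inner) {
        (Flavor::Oneshot(a), Flavor::Oneshot(b)) => Arc::ptr_eq(a, b),
        (Flavor::Shared(a), Flavor::Shared(b)) => Arc::ptr_eq(a, b),
        _ => false,
    }
}
```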

    Let's look at fn channel once more:

    pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
        let a = Arc::new(oneshot::Packet::new());
        (Sender::new(Flavor::Oneshot(a.clone())), Receiver::new(Flavor::Oneshot(a)))
    }

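    Now it is clear that both Sender and Receiver store a Flavor, that the Flavor variant identifies the kind of Packet, and that the Packet is shared safely between them through Arc.

    This is what a call to channel gives us. Since we focus on the multi-producer case, let's look at the Clone for Sender implementation next: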

    impl<T> Clone for Sender<T> {
        fn clone(&self) -> Sender<T> {
            let packet = match *unsafe { self.inner() } {
                Flavor::Oneshot(ref p) => {
                    let a = Arc::new(shared::Packet::new());
                    {
                        let guard = a.postinit_lock();
                        let rx = Receiver::new(Flavor::Shared(a.clone()));
                        let sleeper = match p.upgrade(rx) {
                            oneshot::UpSuccess |
                            oneshot::UpDisconnected => None,
                            oneshot::UpWoke(task) => Some(task),
                        };
                        a.inherit_blocker(sleeper, guard);
                    }
                    a
                }
                Flavor::Stream(ref p) => {
                    let a = Arc::new(shared::Packet::new());
                    {
                        let guard = a.postinit_lock();
                        let rx = Receiver::new(Flavor::Shared(a.clone()));
                        let sleeper = match p.upgrade(rx) {
                            stream::UpSuccess |
                            stream::UpDisconnected => None,
                            stream::UpWoke(task) => Some(task),
                        };
                        a.inherit_blocker(sleeper, guard);
                    }
                    a
                }
                Flavor::Shared(ref p) => {
                    p.clone_chan();
                    return Sender::new(Flavor::Shared(p.clone()));
                }
                Flavor::Sync(..) => unreachable!(),
            };
    
            unsafe {
                let tmp = Sender::new(Flavor::Shared(packet.clone()));
                mem::swap(self.inner_mut(), tmp.inner_mut());
            }
            Sender::new(Flavor::Shared(packet))
        }
    }

    That is a lot of code, but we focus on the Flavor::Oneshot case. First look at self.inner(), an interface provided through trait UnsafeFlavor:

    trait UnsafeFlavor<T> {
        fn inner_unsafe(&self) -> &UnsafeCell<Flavor<T>>;
        unsafe fn inner_mut(&self) -> &mut Flavor<T> {
            &mut *self.inner_unsafe().get()
        }
        unsafe fn inner(&self) -> &Flavor<T> {
            &*self.inner_unsafe().get()
        }
    }
    impl<T> UnsafeFlavor<T> for Sender<T> {
        fn inner_unsafe(&self) -> &UnsafeCell<Flavor<T>> {
            &self.inner
        }
    }
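
The pattern of mutating through `&self` can be sketched in isolation. The following is a single-threaded illustration with hypothetical names (`Mode`, `Handle`), not the standard library code: `UnsafeCell::get` returns a raw pointer, and dereferencing that pointer is what requires `unsafe`.

```rust
use std::cell::UnsafeCell;

#[derive(Debug, PartialEq, Clone, Copy)]
pub enum Mode {
    Oneshot,
    Shared,
}

// Hypothetical handle whose inner state can change behind `&self`.
pub struct Handle {
    inner: UnsafeCell<Mode>,
}

impl Handle {
    pub fn new() -> Self {
        Handle { inner: UnsafeCell::new(Mode::Oneshot) }
    }

    /// Reads the current mode through the raw pointer.
    pub fn mode(&self) -> Mode {
        // SAFETY: single-threaded sketch; no mutable reference into the
        // cell is alive while we read.
        unsafe { *self.inner.get() }
    }

    /// Overwrites the mode even though we only hold `&self`.
    pub fn upgrade(&self) {
        // SAFETY: `UnsafeCell` makes `Handle` non-`Sync`, and no reference
        // into the cell outlives a single method call in this sketch.
        unsafe {
            *self.inner.get() = Mode::Shared;
        }
    }
}
```

The caller carries the burden of proving that no aliasing reference exists, which is why the accessors in `UnsafeFlavor` are themselves marked `unsafe`.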

    Since Sender stores inner: UnsafeCell<Flavor<T>>, this obtains a reference to the inner Flavor<T> through an unsafe pointer operation, and then matches the Flavor::Oneshot case:

    impl<T> Clone for Sender<T> {
        fn clone(&self) -> Sender<T> {
            let packet = match *unsafe { self.inner() } {
                Flavor::Oneshot(ref p) => {
                    let a = Arc::new(shared::Packet::new());
                    {
                        let guard = a.postinit_lock();
                        let rx = Receiver::new(Flavor::Shared(a.clone()));
                        let sleeper = match p.upgrade(rx) {
                            oneshot::UpSuccess |
                            oneshot::UpDisconnected => None,
                            oneshot::UpWoke(task) => Some(task),
                        };
                        a.inherit_blocker(sleeper, guard);
                    }
                    a
                }
                ............
            };
    
            unsafe {
                let tmp = Sender::new(Flavor::Shared(packet.clone()));
                mem::swap(self.inner_mut(), tmp.inner_mut());
            }
            Sender::new(Flavor::Shared(packet))
        }
    }

    Next, Arc::new(shared::Packet::new()) creates a brand-new shared::Packet, a.

    Then a.postinit_lock() is called. Here is its code:

        pub fn postinit_lock(&self) -> MutexGuard<()> {
            self.select_lock.lock().unwrap()
        }

    Together with the new function of shared::Packet:

        pub fn new() -> Packet<T> {
            Packet {
                queue: mpsc::Queue::new(),
                cnt: AtomicIsize::new(0),
                steals: UnsafeCell::new(0),
                to_wake: AtomicUsize::new(0),
                channels: AtomicUsize::new(2),
                port_dropped: AtomicBool::new(false),
                sender_drain: AtomicIsize::new(0),
                select_lock: Mutex::new(()),
            }
        }

    We find that it is just a lock operation; the returned guard is used later to release the lock.
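
The guard behaviour can be checked with a plain `Mutex<()>`. This is a minimal sketch, with `guard_lifecycle` as a hypothetical helper name: while the guard is alive `try_lock` fails, and dropping the guard releases the lock.

```rust
use std::sync::Mutex;

/// Returns (lock was busy while the guard lived, lock was free after drop).
pub fn guard_lifecycle() -> (bool, bool) {
    let select_lock = Mutex::new(());
    // Same shape as `postinit_lock`: lock and keep the guard around.
    let guard = select_lock.lock().unwrap();
    // A second, non-blocking attempt fails while the guard is alive.
    let busy = select_lock.try_lock().is_err();
    // Dropping the guard is the unlock operation.
    drop(guard);
    let free = select_lock.try_lock().is_ok();
    (busy, free)
}
```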

    Back to the original code. This line is the key:

    let rx = Receiver::new(Flavor::Shared(a.clone()));

    From the newly created a we build a Receiver, rx. Creating rx here looks rather odd, but all we can do is keep reading:

                        let sleeper = match p.upgrade(rx) {
                            oneshot::UpSuccess |
                            oneshot::UpDisconnected => None,
                            oneshot::UpWoke(task) => Some(task),
                        };

    Here p is the original oneshot::Packet. We pass in the new rx and call its upgrade method:

        pub fn upgrade(&self, up: Receiver<T>) -> UpgradeResult {
            unsafe {
                let prev = match *self.upgrade.get() {
                    NothingSent => NothingSent,
                    SendUsed => SendUsed,
                    _ => panic!("upgrading again"),
                };
                ptr::write(self.upgrade.get(), GoUp(up));
    
                match self.state.swap(DISCONNECTED, Ordering::SeqCst) {
                    // If the channel is empty or has data on it, then we're good to go.
                    // Senders will check the data before the upgrade (in case we
                    // plastered over the DATA state).
                    DATA | EMPTY => UpSuccess,
    
                    // If the other end is already disconnected, then we failed the
                    // upgrade. Be sure to trash the port we were given.
                    DISCONNECTED => { ptr::replace(self.upgrade.get(), prev); UpDisconnected }
    
                    // If someone's waiting, we gotta wake them up
                    ptr => UpWoke(SignalToken::cast_from_usize(ptr))
                }
            }
        }

    Assuming nothing has been sent on the channel yet, the upgrade field still holds its initial value, which can only be NothingSent:

        pub fn new() -> Packet<T> {
            Packet {
                data: UnsafeCell::new(None),
                upgrade: UnsafeCell::new(NothingSent),
                state: AtomicUsize::new(EMPTY),
            }
        }

    Then GoUp(up) is written into the upgrade field, so our newly created rx: Receiver now lives inside the upgrade field. Here is the code that defines the GoUp variant:

    enum MyUpgrade<T> {
        NothingSent,
        SendUsed,
        GoUp(Receiver<T>),
    }

    Next, self.state.swap changes the state to DISCONNECTED, because this oneshot::Packet is about to be retired. In our case the state simply goes from EMPTY to DISCONNECTED. The relevant code:

    // Various states you can find a port in.
    const EMPTY: usize = 0;          // initial state: no data, no blocked receiver
    const DATA: usize = 1;           // data ready for receiver to take
    const DISCONNECTED: usize = 2;   // channel is disconnected OR upgraded
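
The effect of `swap` on the state word can be sketched on its own. This is a reduction under assumptions, not the library code: `Outcome` and `mark_upgraded` are hypothetical names, and the blocked-thread token is modelled as a bare `usize` instead of a `SignalToken`.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

const EMPTY: usize = 0;
const DATA: usize = 1;
const DISCONNECTED: usize = 2;

#[derive(Debug, PartialEq)]
pub enum Outcome {
    Success,
    Disconnected,
    Woke(usize),
}

/// Stores DISCONNECTED unconditionally and classifies the previous state,
/// following the shape of the `match` inside `upgrade`.
pub fn mark_upgraded(state: &AtomicUsize) -> Outcome {
    match state.swap(DISCONNECTED, Ordering::SeqCst) {
        // Nothing was blocked: the upgrade went through.
        DATA | EMPTY => Outcome::Success,
        // The other side was already gone (or already upgraded).
        DISCONNECTED => Outcome::Disconnected,
        // Any other value is treated as a token for a blocked receiver.
        token => Outcome::Woke(token),
    }
}
```

`swap` returns the previous value, so a single atomic operation both publishes the new state and tells the caller what was there before.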
    

    Finally, upgrade returns UpSuccess as its UpgradeResult. Back to the original clone code:

    impl<T> Clone for Sender<T> {
        fn clone(&self) -> Sender<T> {
            let packet = match *unsafe { self.inner() } {
                Flavor::Oneshot(ref p) => {
                    let a = Arc::new(shared::Packet::new());
                    {
                        let guard = a.postinit_lock();
                        let rx = Receiver::new(Flavor::Shared(a.clone()));
                        let sleeper = match p.upgrade(rx) {
                            oneshot::UpSuccess |
                            oneshot::UpDisconnected => None,
                            oneshot::UpWoke(task) => Some(task),
                        };
                        a.inherit_blocker(sleeper, guard);
                    }
                    a
                }
                ............
            };
            ..................
        }
    }

    Here the result of p.upgrade(rx) is UpSuccess, so sleeper is None.

    Next, the implementation of a.inherit_blocker(sleeper, guard):

        pub fn inherit_blocker(&self,
                               token: Option<SignalToken>,
                               guard: MutexGuard<()>) {
            token.map(|token| {
                assert_eq!(self.cnt.load(Ordering::SeqCst), 0);
                assert_eq!(self.to_wake.load(Ordering::SeqCst), 0);
                self.to_wake.store(unsafe { token.cast_to_usize() }, Ordering::SeqCst);
                self.cnt.store(-1, Ordering::SeqCst);
    
                unsafe { *self.steals.get() = -1; }
            });
    
            drop(guard);
        }
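
That `Option::map` skips its closure for `None` is easy to confirm. A tiny sketch (`closure_runs` is a hypothetical helper):

```rust
/// Returns (how many times the closure ran, whether the result is `Some`).
pub fn closure_runs(token: Option<u32>) -> (u32, bool) {
    let mut runs = 0;
    // The closure only executes when `token` is `Some`.
    let mapped = token.map(|_t| {
        runs += 1;
    });
    (runs, mapped.is_some())
}
```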

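    The token passed in, that is sleeper, is None, and calling map on None just returns None without running the closure, so all this call does is release the lock by dropping guard. At this point we return a, which becomes packet: Arc<shared::Packet<T>>. Back to the clone code again: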

    impl<T> Clone for Sender<T> {
        fn clone(&self) -> Sender<T> {
            let packet = match *unsafe { self.inner() } {
                Flavor::Oneshot(ref p) => {
                    let a = Arc::new(shared::Packet::new());
                    {
                        let guard = a.postinit_lock();
                        let rx = Receiver::new(Flavor::Shared(a.clone()));
                        let sleeper = match p.upgrade(rx) {
                            oneshot::UpSuccess |
                            oneshot::UpDisconnected => None,
                            oneshot::UpWoke(task) => Some(task),
                        };
                        a.inherit_blocker(sleeper, guard);
                    }
                    a
                }
                ............
            };
    
            unsafe {
                let tmp = Sender::new(Flavor::Shared(packet.clone()));
                mem::swap(self.inner_mut(), tmp.inner_mut());
            }
            Sender::new(Flavor::Shared(packet))
        }
    }

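    Note that Sender::new(Flavor::Shared(packet)) returns a new Sender object backed by the shared::Packet. At the same time we construct a temporary Sender, tmp, and use mem::swap on the unsafely obtained inner references to replace the inner of the current object, which is an UnsafeCell<Flavor<T>>: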

    Flavor::Oneshot(Arc<oneshot::Packet<T>>)
    => Flavor::Shared(Arc<shared::Packet<T>>)
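
The swap-then-drop behaviour can be illustrated with a type that records when it is dropped. This is a safe, single-threaded sketch with hypothetical names (`Tagged`, `swap_then_drop_tmp`); it swaps whole values with `mem::swap` rather than reaching through an `UnsafeCell`.

```rust
use std::cell::RefCell;
use std::mem;
use std::rc::Rc;

// A value that logs its tag when dropped.
pub struct Tagged {
    tag: &'static str,
    log: Rc<RefCell<Vec<&'static str>>>,
}

impl Drop for Tagged {
    fn drop(&mut self) {
        self.log.borrow_mut().push(self.tag);
    }
}

/// Returns (tag held by `current` after the swap, tags dropped so far).
pub fn swap_then_drop_tmp() -> (&'static str, Vec<&'static str>) {
    let log = Rc::new(RefCell::new(Vec::new()));
    let mut current = Tagged { tag: "oneshot", log: log.clone() };
    {
        let mut tmp = Tagged { tag: "shared", log: log.clone() };
        mem::swap(&mut current, &mut tmp);
        // `tmp` now owns the old "oneshot" value and is dropped here.
    }
    let dropped = log.borrow().clone();
    (current.tag, dropped)
}
```

After the swap the temporary owns the old value, so it is the old value's destructor that runs when the temporary goes out of scope.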

    As for the tmp object, look at its drop method. Because of the swap, tmp now holds the old flavor and takes the Flavor::Oneshot path:

    impl<T> Drop for Sender<T> {
        fn drop(&mut self) {
            match *unsafe { self.inner() } {
                Flavor::Oneshot(ref p) => p.drop_chan(),
                Flavor::Stream(ref p) => p.drop_chan(),
                Flavor::Shared(ref p) => p.drop_chan(),
                Flavor::Sync(..) => unreachable!(),
            }
        }
    }
        pub fn drop_chan(&self) {
            match self.state.swap(DISCONNECTED, Ordering::SeqCst) {
                DATA | DISCONNECTED | EMPTY => {}
    
                // If someone's waiting, we gotta wake them up
                ptr => unsafe {
                    SignalToken::cast_from_usize(ptr).signal();
                }
            }
        }

    The self.state field is already DISCONNECTED, so nothing more happens when tmp is dropped.

    That covers clone for Flavor::Oneshot. Now let's see what happens when clone is called again:

        fn clone(&self) -> Sender<T> {
            let packet = match *unsafe { self.inner() } {
                ............
                Flavor::Shared(ref p) => {
                    p.clone_chan();
                    return Sender::new(Flavor::Shared(p.clone()));
                }
                Flavor::Sync(..) => unreachable!(),
            };
            ............
        }

    Note that it now only takes the Flavor::Shared path and simply returns a new Sender wrapping Flavor::Shared(..).

    Here is the implementation of clone_chan:

        pub fn clone_chan(&self) {
            let old_count = self.channels.fetch_add(1, Ordering::SeqCst);
    
            // See comments on Arc::clone() on why we do this (for `mem::forget`).
            if old_count > MAX_REFCOUNT {
                unsafe {
                    abort();
                }
            }
        }

    It just increments the count of channels attached to the packet.
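
`fetch_add` returns the value before the increment, which is why the code above names it `old_count`. A minimal sketch, with `register_sender` as a hypothetical helper and the counter starting at 2 as in `shared::Packet::new`:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

/// Returns (value before the increment, value after it).
pub fn register_sender(channels: &AtomicUsize) -> (usize, usize) {
    let old = channels.fetch_add(1, Ordering::SeqCst);
    (old, channels.load(Ordering::SeqCst))
}
```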

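    Putting this together, we now have two Senders:

    • The original Sender, self in the code, whose inner now points at Flavor::Shared.
    • The cloned Sender, which also points at Flavor::Shared and shares the same shared::Packet with the first one.

    We also have two Receivers:

    • The original Receiver, whose inner still points at the initial Flavor::Oneshot, wrapping the initial oneshot::Packet.
    • The Receiver created inside the Sender.clone() call, which points at Flavor::Shared. It is stored inside the initial oneshot::Packet.

    In other words, the first Receiver leads to the oneshot::Packet, and the oneshot::Packet leads to the Flavor::Shared, which is what makes it possible to upgrade the Receiver.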
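
The handoff described above can be modelled in a few dozen lines. This is a single-threaded toy under stated assumptions, not the standard library implementation: all names (`OneSlot`, `Tx`, `Rx`, `clone_sender`, `try_recv`, `kind`) are hypothetical, `Rc`/`RefCell` replace `Arc`, atomics and `UnsafeCell`, values are fixed to `i32`, and there is no blocking or disconnection handling.

```rust
use std::cell::RefCell;
use std::collections::VecDeque;
use std::rc::Rc;

type Queue = Rc<RefCell<VecDeque<i32>>>;

// Toy one-slot packet; `upgrade` plays the role of GoUp(Receiver).
pub struct OneSlot {
    data: RefCell<Option<i32>>,
    upgrade: RefCell<Option<Rx>>,
}

pub enum Flavor {
    Oneshot(Rc<OneSlot>),
    Shared(Queue),
}

pub struct Tx { inner: RefCell<Flavor> }
pub struct Rx { inner: RefCell<Flavor> }

pub fn channel() -> (Tx, Rx) {
    let a = Rc::new(OneSlot { data: RefCell::new(None), upgrade: RefCell::new(None) });
    (
        Tx { inner: RefCell::new(Flavor::Oneshot(a.clone())) },
        Rx { inner: RefCell::new(Flavor::Oneshot(a)) },
    )
}

impl Tx {
    // Toy send: a second send on the one-slot packet would overwrite the first.
    pub fn send(&self, v: i32) {
        match &*self.inner.borrow() {
            Flavor::Oneshot(p) => { *p.data.borrow_mut() = Some(v); }
            Flavor::Shared(q) => { q.borrow_mut().push_back(v); }
        }
    }

    pub fn clone_sender(&self) -> Tx {
        let mut inner = self.inner.borrow_mut();
        let q = match &*inner {
            Flavor::Shared(q) => q.clone(),
            Flavor::Oneshot(p) => {
                let q: Queue = Rc::new(RefCell::new(VecDeque::new()));
                // Park a receiver for the new queue inside the old packet.
                let parked = Rx { inner: RefCell::new(Flavor::Shared(q.clone())) };
                *p.upgrade.borrow_mut() = Some(parked);
                q
            }
        };
        // Re-point this sender at the queue, then hand out a second one.
        *inner = Flavor::Shared(q.clone());
        Tx { inner: RefCell::new(Flavor::Shared(q)) }
    }
}

impl Rx {
    pub fn kind(&self) -> &'static str {
        match &*self.inner.borrow() {
            Flavor::Oneshot(_) => "oneshot",
            Flavor::Shared(_) => "shared",
        }
    }

    pub fn try_recv(&self) -> Option<i32> {
        loop {
            let new_port = match &*self.inner.borrow() {
                Flavor::Shared(q) => return q.borrow_mut().pop_front(),
                Flavor::Oneshot(p) => {
                    // Drain the single slot before looking at the upgrade.
                    let data = p.data.borrow_mut().take();
                    if data.is_some() {
                        return data;
                    }
                    let parked = p.upgrade.borrow_mut().take();
                    match parked {
                        Some(rx) => rx,
                        None => return None,
                    }
                }
            };
            // Adopt the parked receiver's flavor, then retry the receive.
            self.inner.swap(&new_port.inner);
        }
    }
}
```

In this model the receiver keeps reporting `"oneshot"` after `clone_sender` and only switches to `"shared"` during a later `try_recv`, once the single slot has been drained.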

    However, at the moment all the Sender clone operations have completed, the Receiver has not been upgraded yet. To see when the Receiver is upgraded, look at its recv function:

        pub fn recv(&self) -> Result<T, RecvError> {
            loop {
                let new_port = match *unsafe { self.inner() } {
                    Flavor::Oneshot(ref p) => {
                        match p.recv(None) {
                            Ok(t) => return Ok(t),
                            Err(oneshot::Disconnected) => return Err(RecvError),
                            Err(oneshot::Upgraded(rx)) => rx,
                            Err(oneshot::Empty) => unreachable!(),
                        }
                    }
                    Flavor::Stream(ref p) => {
                        match p.recv(None) {
                            Ok(t) => return Ok(t),
                            Err(stream::Disconnected) => return Err(RecvError),
                            Err(stream::Upgraded(rx)) => rx,
                            Err(stream::Empty) => unreachable!(),
                        }
                    }
                    Flavor::Shared(ref p) => {
                        match p.recv(None) {
                            Ok(t) => return Ok(t),
                            Err(shared::Disconnected) => return Err(RecvError),
                            Err(shared::Empty) => unreachable!(),
                        }
                    }
                    Flavor::Sync(ref p) => return p.recv(None).map_err(|_| RecvError),
                };
                unsafe {
                    mem::swap(self.inner_mut(), new_port.inner_mut());
                }
            }
        }

    We only look at the Flavor::Oneshot case: the inner oneshot::Packet is bound to p, and p.recv(None) is called:

        pub fn recv(&self, deadline: Option<Instant>) -> Result<T, Failure<T>> {
            // Attempt to not block the thread (it's a little expensive). If it looks
            // like we're not empty, then immediately go through to `try_recv`.
            if self.state.load(Ordering::SeqCst) == EMPTY {
                let (wait_token, signal_token) = blocking::tokens();
                let ptr = unsafe { signal_token.cast_to_usize() };
    
                // race with senders to enter the blocking state
                if self.state.compare_and_swap(EMPTY, ptr, Ordering::SeqCst) == EMPTY {
                    if let Some(deadline) = deadline {
                        let timed_out = !wait_token.wait_max_until(deadline);
                        // Try to reset the state
                        if timed_out {
                            self.abort_selection().map_err(Upgraded)?;
                        }
                    } else {
                        wait_token.wait();
                        debug_assert!(self.state.load(Ordering::SeqCst) != EMPTY);
                    }
                } else {
                    // drop the signal token, since we never blocked
                    drop(unsafe { SignalToken::cast_from_usize(ptr) });
                }
            }
    
            self.try_recv()
        }

    At this point, because of the earlier Sender.clone() call, self.state is already DISCONNECTED, so the blocking branch is skipped and we move on to self.try_recv():

        pub fn try_recv(&self) -> Result<T, Failure<T>> {
            unsafe {
                match self.state.load(Ordering::SeqCst) {
                    EMPTY => Err(Empty),
                    DATA => {
                        self.state.compare_and_swap(DATA, EMPTY, Ordering::SeqCst);
                        match (&mut *self.data.get()).take() {
                            Some(data) => Ok(data),
                            None => unreachable!(),
                        }
                    }
                    DISCONNECTED => {
                        match (&mut *self.data.get()).take() {
                            Some(data) => Ok(data),
                            None => {
                                match ptr::replace(self.upgrade.get(), SendUsed) {
                                    SendUsed | NothingSent => Err(Disconnected),
                                    GoUp(upgrade) => Err(Upgraded(upgrade))
                                }
                            }
                        }
                    }
                    // We are the sole receiver; there cannot be a blocking
                    // receiver already.
                    _ => unreachable!()
                }
            }
        }

    Clearly this takes the DISCONNECTED path. self.data still has its initial value None, so take() goes down the None branch. The key is the following code:

                            None => {
                                match ptr::replace(self.upgrade.get(), SendUsed) {
                                    SendUsed | NothingSent => Err(Disconnected),
                                    GoUp(upgrade) => Err(Upgraded(upgrade))
                                }
                            }

    We replace the value stored in self.upgrade with SendUsed and get the previous value back.
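
Taking a value out of a slot while leaving a marker behind can be shown with the safe `mem::replace`, used here in place of the `ptr::replace` call on a raw pointer. `Slot` and `take_upgrade` are hypothetical names mirroring `MyUpgrade`:

```rust
use std::mem;

#[derive(Debug, PartialEq)]
pub enum Slot<R> {
    NothingSent,
    SendUsed,
    GoUp(R),
}

/// Takes whatever is in the slot and leaves `SendUsed` behind.
/// Only a parked `GoUp` value is handed back to the caller.
pub fn take_upgrade<R>(slot: &mut Slot<R>) -> Option<R> {
    match mem::replace(slot, Slot::SendUsed) {
        Slot::GoUp(rx) => Some(rx),
        Slot::NothingSent | Slot::SendUsed => None,
    }
}
```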

    Note that the value obtained here is GoUp(upgrade), where upgrade is the Receiver<T> whose creation puzzled us earlier. It is returned via Err(Upgraded(upgrade)), where Upgraded is:

    pub enum Failure<T> {
        Empty,
        Disconnected,
        Upgraded(Receiver<T>),
    }

    This value propagates all the way back up into Receiver.recv():

        pub fn recv(&self) -> Result<T, RecvError> {
            loop {
                let new_port = match *unsafe { self.inner() } {
                    Flavor::Oneshot(ref p) => {
                        match p.recv(None) {
                            Ok(t) => return Ok(t),
                            Err(oneshot::Disconnected) => return Err(RecvError),
                            Err(oneshot::Upgraded(rx)) => rx,
                            Err(oneshot::Empty) => unreachable!(),
                        }
                    }
                    ............
                };
                unsafe {
                    mem::swap(self.inner_mut(), new_port.inner_mut());
                }
            }
        }

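    Matching Err(oneshot::Upgraded(rx)) yields rx, the Receiver created during clone. rx then becomes new_port, and finally the same mem::swap operation replaces the Receiver's inner Flavor<T> with the Flavor::Shared one.

    So the Receiver has now been upgraded to a channel attached to Flavor::Shared(Arc<shared::Packet<T>>).

    At this point, Sender/Receiver have been upgraded from a channel that holds a single element to an unbounded MPSC channel.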
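
From the user's side the upgrade is invisible. A minimal sketch using only the public API (`upgrade_is_transparent` is a hypothetical helper): a value sent before the first `clone`, and values sent afterwards through both senders, all arrive on the original `Receiver` in order. Which internal flavor is in use cannot be observed here, and newer standard library versions may implement channels differently.

```rust
use std::sync::mpsc;

/// Sends once before cloning and twice after, all from one thread,
/// then drains the original receiver.
pub fn upgrade_is_transparent() -> Vec<i32> {
    let (tx, rx) = mpsc::channel();
    tx.send(1).unwrap(); // sent before any clone exists
    let tx2 = tx.clone(); // in the implementation analyzed here, this upgrades
    tx.send(2).unwrap();
    tx2.send(3).unwrap();
    // Drop every sender so that the iterator below terminates.
    drop(tx);
    drop(tx2);
    rx.iter().collect()
}
```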

  • Original address: https://www.cnblogs.com/dhcn/p/12957380.html