zoukankan      html  css  js  c++  java
  • C++20协程解糖

    在开始之前,我们先修复上一篇文章中的一个bug,SharedState::add_finish_callback中post_all_callbacks应当提前判断settled,否则会在未设置结果的情况下添加callback,callback也会被立即post


    template<class T>
    class SharedState : public SharedStateBase {
        // ...
        // private
        void add_finish_callback(std::function<void(T&)> callback) {
            finish_callbacks.push_back(std::move(callback));
            if (settled) {
                post_all_callbacks();
            }
        }
    };


    概述

    今天我们要实现的东西包括

    1. 给schedular加上timer支持
    2. 给Future和Promise补充必要设施以支持C++20协程

    如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢

    开始动手

    首先是Schedular的timer支持,我们这里使用一个简单的优先队列用来管理所有timer,并在poll函数中处理完当帧pending_state后,视情况sleep到最近的timer到期,并处理所有到期的timer


    class Schedular {
        // ...
        // public
        using timer_callback = std::function<void()>;
        using timer_item = std::tuple<bool, float, chrono::time_point<chrono::steady_clock, chrono::duration<float>>, timer_callback>;
        using timer = std::chrono::steady_clock;
        struct timer_item_cmp {
            bool operator()(const timer_item& a, const timer_item& b) const {
                return std::get<2>(a) > std::get<2>(b);
            }
        };
    
        // ...
        // public
        void poll() {
            size_t sz = pending_states.size();
            for (size_t i = 0; i != sz; i++) {
                auto state = std::move(pending_states[i]);
                state->invoke_all_callback();
            }
            pending_states.erase(pending_states.begin(), pending_states.begin() + sz);
            if (timer_queue.empty()) {
                return;
            }
            if (pending_states.empty()) { //如果pending_states为空,则可以sleep较长的时间,等待第一个将要完成的timer
                std::this_thread::sleep_until(std::get<2>(timer_queue.front()));
                auto now = timer::now();
                do {
                    deal_one_timer();
                } while (!timer_queue.empty() && std::get<2>(timer_queue.front()) <= now);
            } else { //否则只能处理当帧到期的timer,不能sleep,要及时返回给caller,让caller及时下一次poll处理剩下的pending_states
                auto now = timer::now();
                while (!timer_queue.empty() && std::get<2>(timer_queue.front()) <= now) {
                    deal_one_timer();
                }
            }
        }
    
        void add_timer(bool repeat, float delay, timer_callback callback) {
            auto cur_time = chrono::time_point_cast<chrono::duration<float>>(timer::now());
            auto timeout = cur_time + chrono::duration<float>(delay);
            timer_queue.emplace_back(repeat, delay, timeout, callback);
            std::push_heap(timer_queue.begin(), timer_queue.end(), timer_item_cmp{});
        }
    
        // ...
        // private
        void deal_one_timer() {
            std::pop_heap(timer_queue.begin(), timer_queue.end(), timer_item_cmp{});
            auto item = std::move(timer_queue.back());
            timer_queue.pop_back();
            std::get<3>(item)();
            if (std::get<0>(item)) {
                add_timer(true, std::get<1>(item), std::move(std::get<3>(item)));
            }
        }
    
        std::deque<timer_item> timer_queue;
    };


    这样之后,基于当前调度器的delay函数就可以写出来了


    class Schedular {
        // ...
        // public
        Future<float> delay(float second) {
            auto promise = Promise<float>(*this);
            add_timer(false, second, [=]() mutable {
                promise.set_result(second);
            });
            return promise.get_future();
        }
        // ...
    };


    因为之前我们设计的Future和Promise并不支持void,于是这里简单用Future<float>代替,返回的是等待的秒数。

    需要注意的是,这个delay函数虽然返回Future,但并不是协程,协程的判断标准是当且仅当函数中使用了co_await/co_yield/co_return,和返回类型无关。

    这个函数同样展示了将回调式API封装为Future的做法,就是把Promise.set_result作为回调传入给API,并返回Promise.get_future,使用者在Future这边等待就好了。

    有了这些东西之后,我们可以先把本次的测试代码写出来了

    如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢


    Future<float> func(Schedular& schedular) {
        std::cout << "start sleep
    ";
        auto r = co_await schedular.delay(1.2);
        co_return r;
    }
    
    Future<int> func2(Schedular& schedular) {
        auto r = co_await func(schedular);
        std::cout << "slept for " << r << "s
    ";
        co_return 42;
    }


    这里需要注意的是C++协程在编译器实现中,会自动构造一个Promise对象,而我们的Promise并不支持默认构造,必须传入一个Schedular参数。好在C++会替我们自动将协程参数作为作为构造函数参数来构造Promise,因此要在协程参数中指定Schedular,相当于指定Schedular构造了Promise。为一个协程显式指定调度器,是一个很合理的设计,python也是类似的设计。C#将协程调度器隐藏进了Task,因为它有一个全局的默认调度器。如果我们的实现中提供一个全局构造的Schedular,让Promise自动去找他调度,那这里的协程也可以没有参数。

    如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢

    为了让Future支持协程,代码中还需要补充一系列的内容,列举在下面


    template<class T>
    class Future {
        // ...
        // public
        // 协程接口
        using promise_type = Promise<T>;
    
        bool await_ready() { return _state->settled; }
    
        void await_suspend(exp::coroutine_handle<> handle) {
            add_finish_callback([=](T&) mutable { handle.resume(); });
        }
    
        T await_resume() { return _state->value; }
        // 协程接口
        // ...
    };

    • promise_type用来指定本future对应的promise,结果的输入端
    • await_ready检查future是否已经完成
    • await_suspend用来通知future,协程为了等待它完成,已经暂停,future需要在自己完成的时候,主动恢复协程
    • await_resume用来通知future,协程已经恢复执行,需要从future中取出结果,用作co_await表达式的结果。这里我们直接返回拷贝,实现中较为合理的是把future持有的对象移动出去,但这样的话被await的future就不能再单独获取结果了。

    为了让Promise支持协程,需要补充的内容在下面


    // ...
    // 在最开头
    // 如果你的编译器已经不需要std::experimental了,那就去掉这行,后面使用std而不是exp
    namespace exp = std::experimental;
    
    template<class T>
    class Promise {
        // ...
        // public
        // 协程接口
        Future<T> get_return_object();
    
        exp::suspend_never initial_suspend() { return {}; }
    
        exp::suspend_never final_suspend() noexcept { return {}; }
    
        void return_value(T v) { set_result(v); }
    
        void unhandled_exception() { std::terminate(); }
        // 协程接口
        // ...
    };
    
    // ...
    // 在Future定义后面
    template<class T>
    Future<T> Promise<T>::get_return_object() {
        return get_future();
    }

    • initial_suspend用来表明协程是否在调用时暂停,异步任务一般返回suepend_never,调用时立即启动
    • final_suspend用来表明协程是否在co_return后暂停(延迟销毁),我们是使用shared_state的异步任务,因此可以不暂停协程,直接自动销毁协程,让shared_state留在空中靠引用计数清零销毁
    • return_value用于co_return将结果传入
    • unhandled_exception用于协程中出现了未处理异常的情况,这里面可以通过std::current_exception来获取当前异常,我们的简化版不可能出现异常,出了就直接terminate
    • get_return_object就是get_future。大家不要忘记一个协程是先构造promise,后从promise获取future的

    有了这些东西之后,编译就不应该再出现错误了,我的编译选项是 clang++-9 test.cpp -stdlib=libc++ -std=c++2a 

    运行?还差最后一点

    为了方便,我们效法python,给Schedular补一个run_until_compete的方法

    如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢


    class Schedular {
        // ...
        // public
        template<class F, class... Args>
        auto run_until_complete(F&& fn, Args&&... args) -> typename std::invoke_result_t<F&&, Args&&...>::result_type {
            auto future = std::forward<F>(fn)(std::forward<Args>(args)...);
            while (!future.await_ready()) {
                poll();
            }
            return future.await_resume();
        }
    };


    然后main

    int main() {
        Schedular schedular;
    
        auto r = schedular.run_until_complete(func2, schedular);
    
        std::cout << "run complete with " << r << "
    ";
    }


    运行结果就有了

    start sleep
    slept for 1.2s
    run complete with 42


    怎么样,是不是很简单呢,赶紧自己写一个吧!

    如果你看到这行文字,说明这篇文章被无耻的盗用了(或者你正在选中文字),请前往 cnblogs.com/pointer-smq 支持原作者,谢谢

    附录 - 全部代码



    #include <vector>
    #include <deque>
    #include <memory>
    #include <iostream>
    #include <functional>
    #include <chrono>
    #include <thread>
    #include <algorithm>
    #include <experimental/coroutine>
    
    namespace exp = std::experimental;
    
    template<class T>
    class Future;
    
    template<class T>
    class Promise;
    
    class Schedular;
    
    class SharedStateBase : public std::enable_shared_from_this<SharedStateBase> {
        friend class Schedular;
    public:
        virtual ~SharedStateBase() = default;
    private:
        virtual void invoke_all_callback() = 0;
    };
    
    template<class T>
    class SharedState : public SharedStateBase {
        friend class Future<T>;
        friend class Promise<T>;
    public:
        SharedState(Schedular& schedular)
            : schedular(&schedular)
        {}
        SharedState(const SharedState&) = delete;
        SharedState(SharedState&&) = delete;
        SharedState& operator=(const SharedState&) = delete;
        SharedState& operator=(SharedState&&) = delete;
    
    private:
        template<class U>
        void set(U&& v) {
            if (settled) {
                return;
            }
            settled = true;
            value = std::forward<U>(v);
            post_all_callbacks();
        }
    
        T& get() { return value; }
    
        void add_finish_callback(std::function<void(T&)> callback) {
            finish_callbacks.push_back(std::move(callback));
            if (settled) {
                post_all_callbacks();
            }
        }
    
        void post_all_callbacks();
    
        virtual void invoke_all_callback() override {
            callback_posted = false;
            size_t sz = finish_callbacks.size();
            for (size_t i = 0; i != sz; i++) {
                auto v = std::move(finish_callbacks[i]);
                v(value);
            }
            finish_callbacks.erase(finish_callbacks.begin(), finish_callbacks.begin()+sz);
        }
    
        bool settled = false;
        bool callback_posted = false;
        Schedular* schedular = nullptr;
        T value;
        std::vector<std::function<void(T&)>> finish_callbacks;
    };
    
    template<class T>
    class Promise {
    public:
        Promise(Schedular& schedular)
            : _schedular(&schedular)
            , _state(std::make_shared<SharedState<T>>(*_schedular))
        {}
    
        Future<T> get_future();
    
        // 协程接口
        Future<T> get_return_object();
        exp::suspend_never initial_suspend() { return {}; }
        exp::suspend_never final_suspend() noexcept { return {}; }
        void return_value(T v) { set_result(v); }
        void unhandled_exception() { std::terminate(); }
        // 协程接口
    
        template<class U>
        void set_result(U&& value) {
            if (_state->settled) {
                throw std::invalid_argument("already set result");
            }
            _state->set(std::forward<U>(value));
        }
    private:
        Schedular* _schedular;
        std::shared_ptr<SharedState<T>> _state;
    };
    
    template<class T>
    class Future {
    public:
        using result_type = T;
        using promise_type = Promise<T>;
        friend class Promise<T>;
    private:
        Future(std::shared_ptr<SharedState<T>> state)
            : _state(std::move(state))
        {
        }
    public:
        // 协程接口
        bool await_ready() { return _state->settled; }
        void await_suspend(exp::coroutine_handle<> handle) {
            add_finish_callback([=](T&) mutable { handle.resume(); });
        }
        T await_resume() { return _state->value; }
        // 协程接口
    
        void add_finish_callback(std::function<void(T&)> callback) {
            _state->add_finish_callback(std::move(callback));
        }
    private:
        std::shared_ptr<SharedState<T>> _state;
    };
    
    template<class T>
    Future<T> Promise<T>::get_future() {
        return Future<T>(_state);
    }
    
    template<class T>
    Future<T> Promise<T>::get_return_object() {
        return get_future();
    }
    
    namespace chrono = std::chrono;
    
    class Schedular {
        template<class T>
        friend class SharedState;
    public:
        using timer_callback = std::function<void()>;
        using timer_item = std::tuple<bool, float, chrono::time_point<chrono::steady_clock, chrono::duration<float>>, timer_callback>;
        using timer = std::chrono::steady_clock;
        struct timer_item_cmp {
            bool operator()(const timer_item& a, const timer_item& b) const {
                return std::get<2>(a) > std::get<2>(b);
            }
        };
    
        Schedular() = default;
        Schedular(Schedular&&) = delete;
        Schedular(const Schedular&) = delete;
        Schedular& operator=(Schedular&&) = delete;
        Schedular& operator=(const Schedular&) = delete;
    
        void poll() {
            size_t sz = pending_states.size();
            for (size_t i = 0; i != sz; i++) {
                auto state = std::move(pending_states[i]);
                state->invoke_all_callback();
            }
            pending_states.erase(pending_states.begin(), pending_states.begin() + sz);
            if (timer_queue.empty()) {
                return;
            }
            if (pending_states.empty()) { //如果pending_states为空,则可以sleep较长的时间,等待第一个将要完成的timer
                std::this_thread::sleep_until(std::get<2>(timer_queue.front()));
                auto now = timer::now();
                do {
                    deal_one_timer();
                } while (!timer_queue.empty() && std::get<2>(timer_queue.front()) <= now);
            } else { //否则只能处理当帧到期的timer,不能sleep,要及时返回给caller,让caller及时下一次poll处理剩下的pending_states
                auto now = timer::now();
                while (!timer_queue.empty() && std::get<2>(timer_queue.front()) <= now) {
                    deal_one_timer();
                }
            }
        }
    
        template<class F, class... Args>
        auto run_until_complete(F&& fn, Args&&... args) -> typename std::invoke_result_t<F&&, Args&&...>::result_type {
            auto future = std::forward<F>(fn)(std::forward<Args>(args)...);
            while (!future.await_ready()) {
                poll();
            }
            return future.await_resume();
        }
    
        void add_timer(bool repeat, float delay, timer_callback callback) {
            auto cur_time = chrono::time_point_cast<chrono::duration<float>>(timer::now());
            auto timeout = cur_time + chrono::duration<float>(delay);
            timer_queue.emplace_back(repeat, delay, timeout, callback);
            std::push_heap(timer_queue.begin(), timer_queue.end(), timer_item_cmp{});
        }
    
        Future<float> delay(float second) {
            auto promise = Promise<float>(*this);
            add_timer(false, second, [=]() mutable {
                promise.set_result(second);
            });
            return promise.get_future();
        }
    private:
        void deal_one_timer() {
            std::pop_heap(timer_queue.begin(), timer_queue.end(), timer_item_cmp{});
            auto item = std::move(timer_queue.back());
            timer_queue.pop_back();
            std::get<3>(item)();
            if (std::get<0>(item)) {
                add_timer(true, std::get<1>(item), std::move(std::get<3>(item)));
            }
        }
    
        void post_call_state(std::shared_ptr<SharedStateBase> state) {
            pending_states.push_back(std::move(state));
        }
        std::vector<std::shared_ptr<SharedStateBase>> pending_states;
        std::deque<timer_item> timer_queue;
    };
    
    template<class T>
    void SharedState<T>::post_all_callbacks() {
        if (callback_posted) {
            return;
        }
        callback_posted = true;
        schedular->post_call_state(shared_from_this());
    }
    
    Future<float> func(Schedular& schedular) {
        std::cout << "start sleep
    ";
        auto r = co_await schedular.delay(1.2);
        co_return r;
    }
    
    Future<int> func2(Schedular& schedular) {
        auto r = co_await func(schedular);
        std::cout << "slept for " << r << "s
    ";
        co_return 42;
    }
    
    int main() {
        Schedular schedular;
    
        auto r = schedular.run_until_complete(func2, schedular);
    
        std::cout << "run complete with " << r << "
    ";
    }


  • 相关阅读:
    程序员佛祖保佑无bug、发发发 -注释代码
    27 友盟项目--azkaban资源调度
    26 友盟项目--数据可视化
    25 友盟项目--sqoop从hive导出数据到mysql
    24 友盟项目--优化-flume限速拦截、flume自定义源防丢失--改造exec源守护线程监控目录(防丢失)redis维护key(去重)
    23 友盟项目--sparkstreaming对接kafka、集成redis--从redis中查询月留存率
    22 友盟项目--sparkstreaming对接kafka、集成redis--从redis中存储用户使用app的最小时间戳min , 最大时间戳max
    21 友盟项目--统计连续活跃用户、近期流失用户、留存用户--创建表并插入选择出的数据
    20 友盟项目--统计月活率、沉默用户、周回流用户--创建表并插入选择出的数据
    19 友盟项目--统计新增用户---日新增、周新增、月新增--创建表并插入选择出的数据
  • 原文地址:https://www.cnblogs.com/pointer-smq/p/12940360.html
Copyright © 2011-2022 走看看