zoukankan      html  css  js  c++  java
  • C# Task WhenAny和WhenAll 以及TaskFactory 的ContinueWhenAny和ContinueWhenAll的实现

    个人感觉Task的WaitAny和WhenAny以及TaskFactory的ContinueWhenAny有相似之处,而WaitAll和WhenAll以及TaskFactory的ContinueWhenAll也是如此。不过WaitAny和WhenAny的返回值有所不同:WaitAny返回已完成任务在数组中的索引(int),而WhenAny返回一个Task<Task>。我们首先来看看Task的WhenAny和WhenAll的实现:

    // Excerpt of the Task class showing only the members involved in WhenAny/WhenAll.
    // Members used below but not shown here (AddCompletionAction, GetResultCore,
    // GetExceptionDispatchInfos, TrySetResult/TrySetException/TrySetCanceled,
    // s_asyncDebuggingEnabled, AddToActiveTasks/RemoveFromActiveTasks, ...) are declared
    // outside this excerpt.
    public class Task : IThreadPoolWorkItem, IAsyncResult, IDisposable
    {
        /// <summary>
        /// Creates a task that will complete when any of the supplied tasks have completed.
        /// </summary>
        /// <param name="tasks">The tasks to watch; must be non-null, non-empty and contain no null elements.</param>
        /// <returns>The task produced by <c>TaskFactory.CommonCWAnyLogic</c> for a snapshot of <paramref name="tasks"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="tasks"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="tasks"/> contains a null element, or is empty.</exception>
        public static Task<Task> WhenAny(IEnumerable<Task> tasks)
        {
            if (tasks == null) throw new ArgumentNullException("tasks");
            Contract.EndContractBlock();
            // Snapshot the sequence into a list: it is enumerated exactly once here, and later
            // changes to the caller's collection cannot affect the returned task.
            List<Task> taskList = new List<Task>();
            foreach (Task task in tasks)
            {
                if (task == null) throw new ArgumentException(Environment.GetResourceString("Task_MultiTaskContinuation_NullTask"), "tasks");
                taskList.Add(task);
            }
            if (taskList.Count == 0)
            {
                throw new ArgumentException(Environment.GetResourceString("Task_MultiTaskContinuation_EmptyTaskList"), "tasks");
            }
            // Previously implemented CommonCWAnyLogic() can handle the rest
            return TaskFactory.CommonCWAnyLogic(taskList);
        }
        
        /// <summary>
        /// Creates a task that will complete when all of the supplied tasks have completed.
        /// </summary>
        /// <param name="tasks">The tasks to wait on; must be non-null and contain no null elements (may be empty).</param>
        /// <returns>
        /// The task produced by <c>InternalWhenAll</c>; on success its result array holds each
        /// task's result at the same index as the task in <paramref name="tasks"/>.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="tasks"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="tasks"/> contains a null element.</exception>
        public static Task<TResult[]> WhenAll<TResult>(params Task<TResult>[] tasks)
        {
            if (tasks == null) throw new ArgumentNullException("tasks");
            Contract.EndContractBlock();
    
            int taskCount = tasks.Length;
            if (taskCount == 0) return InternalWhenAll<TResult>(tasks); // small optimization in the case of an empty task array
    
            // Copy the array while null-checking each element, so that later writes to the
            // caller's array do not affect the promise.
            Task<TResult>[] tasksCopy = new Task<TResult>[taskCount];
            for (int i = 0; i < taskCount; i++)
            {
                Task<TResult> task = tasks[i];
                if (task == null) throw new ArgumentException(Environment.GetResourceString("Task_MultiTaskContinuation_NullTask"), "tasks");
                tasksCopy[i] = task;
            }
            return InternalWhenAll<TResult>(tasksCopy);
        }
            
        /// <summary>
        /// Builds the task returned by <c>WhenAll</c>. An empty array yields a task constructed
        /// directly with an empty result array (NOTE(review): the (bool, TResult, ...) constructor
        /// is not shown here; presumably it creates an already-completed task — confirm).
        /// A non-empty array yields a <see cref="WhenAllPromise{T}"/>.
        /// </summary>
        /// <param name="tasks">Non-null array; callers pass either an empty array or a null-checked copy.</param>
        private static Task<TResult[]> InternalWhenAll<TResult>(Task<TResult>[] tasks)
        {
            Contract.Requires(tasks != null, "Expected a non-null tasks array");
            return (tasks.Length == 0) ? new Task<TResult[]>(false, new TResult[0], TaskCreationOptions.None, default(CancellationToken)) : new WhenAllPromise<TResult>(tasks);
        }
            
        /// <summary>
        /// Promise task behind <c>WhenAll</c>. It counts down the constituent tasks that have not yet
        /// reported completion, and it is itself the completion action (ITaskCompletionAction)
        /// attached to each of them.
        /// </summary>
        private sealed class WhenAllPromise<T> : Task<T[]>, ITaskCompletionAction
        {
            // The constituent tasks; individual entries may be set to null by Invoke once inspected.
            private readonly Task<T>[] m_tasks;
            // Number of constituent tasks that have not yet reported completion; decremented atomically.
            private int m_count;
    
            internal WhenAllPromise(Task<T>[] tasks) :base()
            {
                Contract.Requires(tasks != null, "Expected a non-null task array");
                Contract.Requires(tasks.Length > 0, "Expected a non-zero length task array");
                // Both fields must be assigned before the loop below, because Invoke (which may be
                // called synchronously from that loop) reads m_tasks and decrements m_count.
                m_tasks = tasks;
                m_count = tasks.Length;
                if (AsyncCausalityTracer.LoggingOn)
                    AsyncCausalityTracer.TraceOperationCreation(CausalityTraceLevel.Required, this.Id, "Task.WhenAll", 0);
    
                if (s_asyncDebuggingEnabled)
                {
                    AddToActiveTasks(this);
                }
                // Already-completed tasks are counted immediately; the rest get this promise attached
                // as a completion action. If every task is already complete, the promise is completed
                // before this constructor returns.
                foreach (var task in tasks)
                {
                    if (task.IsCompleted) this.Invoke(task); // short-circuit the completion action, if possible
                    else task.AddCompletionAction(this); // simple completion action
                }
            }
    
            /// <summary>
            /// Called once per constituent task completion: synchronously from the constructor for tasks
            /// that were already complete, otherwise (presumably) through the completion action attached
            /// with AddCompletionAction. Only the call that brings the countdown to zero inspects all
            /// tasks and completes the promise.
            /// </summary>
            /// <param name="ignored">The completing task; unused, since every task is inspected at the end.</param>
            public void Invoke(Task ignored)
            {
                if (AsyncCausalityTracer.LoggingOn)
                    AsyncCausalityTracer.TraceOperationRelation(CausalityTraceLevel.Important, this.Id, CausalityRelation.Join);
    
                // Decrement the count, and only continue to complete the promise if we're the last one.
                if (Interlocked.Decrement(ref m_count) == 0)
                {
                    // Outcome precedence: if any task faulted, the promise faults with the exceptions of
                    // all faulted tasks; otherwise, if any task was canceled, the promise is canceled
                    // using the first canceled task in array order; otherwise it succeeds with the results
                    // in input order.
                    T[] results = new T[m_tasks.Length];
                    List<ExceptionDispatchInfo> observedExceptions = null;
                    Task canceledTask = null;
    
                    for (int i = 0; i < m_tasks.Length; i++)
                    {
                        Task<T> task = m_tasks[i];
                        Contract.Assert(task != null, "Constituent task in WhenAll should never be null");
    
                        if (task.IsFaulted)
                        {
                            if (observedExceptions == null) observedExceptions = new List<ExceptionDispatchInfo>();
                            observedExceptions.AddRange(task.GetExceptionDispatchInfos());
                        }
                        else if (task.IsCanceled)
                        {
                            if (canceledTask == null) canceledTask = task; // use the first canceled task
                        }
                        else
                        {
                            Contract.Assert(task.Status == TaskStatus.RanToCompletion);
                            results[i] = task.GetResultCore(waitCompletionNotification: false); // avoid Result, which would trigger a debug notification
                        }                
                        // If a constituent task has debugger wait notification enabled, propagate the flag
                        // to this promise and keep the reference (ShouldNotifyDebuggerOfWaitCompletion reads
                        // m_tasks); otherwise release it.
                        if (task.IsWaitNotificationEnabled) this.SetNotificationForWaitCompletion(enabled: true);
                        else m_tasks[i] = null; // avoid holding onto tasks unnecessarily
                    }
    
                    if (observedExceptions != null)
                    {
                        Contract.Assert(observedExceptions.Count > 0, "Expected at least one exception");
                        TrySetException(observedExceptions);
                    }
                    else if (canceledTask != null)
                    {
                        TrySetCanceled(canceledTask.CancellationToken, canceledTask.GetCancellationExceptionDispatchInfo());
                    }
                    else
                    {
                        // NOTE(review): the causality completion trace and RemoveFromActiveTasks run only on
                        // this success path; confirm the fault/cancel paths cover them elsewhere (e.g. inside
                        // TrySetException/TrySetCanceled).
                        if (AsyncCausalityTracer.LoggingOn)
                            AsyncCausalityTracer.TraceOperationCompletion(CausalityTraceLevel.Required, this.Id, AsyncCausalityStatus.Completed);
    
                        if (Task.s_asyncDebuggingEnabled)
                        {
                            RemoveFromActiveTasks(this.Id);
                        }
                        TrySetResult(results);
                    }
                }
                // Sanity check that Invoke has not been called more times than there are tasks.
                // NOTE(review): this is a plain (non-Interlocked) read of m_count.
                Contract.Assert(m_count >= 0, "Count should never go below 0");
            }
    
            // Notify the debugger of wait completion only if the base implementation allows it and at
            // least one constituent task requires it. NOTE(review): entries of m_tasks may have been set
            // to null by Invoke; presumably AnyTaskRequiresNotifyDebuggerOfWaitCompletion tolerates null
            // elements — confirm.
            internal override bool ShouldNotifyDebuggerOfWaitCompletion
            {
                get
                {
                    return base.ShouldNotifyDebuggerOfWaitCompletion && Task.AnyTaskRequiresNotifyDebuggerOfWaitCompletion(m_tasks);
                }
            }
        }
    }

    首先我们来看看Task的WhenAny的实现。它非常简单:先检查参数并把传入的tasks复制到一个List中,然后调用TaskFactory.CommonCWAnyLogic方法,直接返回得到的Task;而WaitAny【也调用TaskFactory.CommonCWAnyLogic】则需要等待这个Task完成。接下来我们来看看WhenAll的实现。WhenAll的核心方法是InternalWhenAll:当任务数组为空时,它直接返回一个以空数组为结果的Task,否则返回一个WhenAllPromise的Task。WhenAllPromise里面有一个计数器m_count,初始值为任务个数;每完成一个Task,就调用一次WhenAllPromise的Invoke方法(对于已经完成的Task,会在构造函数中直接调用),通过Interlocked.Decrement将计数器m_count减1。当计数器m_count为0时,表示所有的Task都已经完成。【在Task的WaitAll里面用的是SetOnCountdownMres,和这里的WhenAllPromise相似,都有一个计数器】

    那么我们现在来看看TaskFactory的ContinueWhenAny和ContinueWhenAll的实现

    // Excerpt of TaskFactory showing ContinueWhenAny/ContinueWhenAll and their helpers.
    // NOTE(review): as shown, this class is non-generic, yet ContinueWhenAll, ContinueWhenAnyImpl and
    // ContinueWhenAllImpl use the type parameter TResult, which is not declared here. ContinueWhenAny
    // calls TaskFactory<VoidTaskResult>.ContinueWhenAnyImpl, so those members evidently belong to a
    // generic TaskFactory<TResult>. The excerpt merges the two classes and does not compile as written.
    public class TaskFactory
    {
        /// <summary>
        /// Creates a continuation task that runs <paramref name="continuationAction"/> after any one of
        /// <paramref name="tasks"/> completes.
        /// </summary>
        /// <remarks>
        /// The action-based overload is routed through <c>TaskFactory&lt;VoidTaskResult&gt;.ContinueWhenAnyImpl</c>
        /// with a null continuationFunction; the resulting task is exposed as a plain <see cref="Task"/>.
        /// The remaining argument validation happens in ContinueWhenAnyImpl.
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="continuationAction"/> is null.</exception>
        public Task ContinueWhenAny<TAntecedentResult>(Task<TAntecedentResult>[] tasks, Action<Task<TAntecedentResult>> continuationAction,
            CancellationToken cancellationToken, TaskContinuationOptions continuationOptions, TaskScheduler scheduler)
        {
            if (continuationAction == null) throw new ArgumentNullException("continuationAction");
            Contract.EndContractBlock();
    
            // Stack-crawl marker passed by ref down to ContinueWith.
            // NOTE(review): how it is consumed is not visible in this excerpt.
            StackCrawlMark stackMark = StackCrawlMark.LookForMyCaller;
            return TaskFactory<VoidTaskResult>.ContinueWhenAnyImpl<TAntecedentResult>(tasks, null, continuationAction, continuationOptions, cancellationToken, scheduler, ref stackMark);
        }
        // Creates a continuation Task<TResult> that will be started upon the completion of a set of provided Tasks.
        // Uses this factory's m_defaultContinuationOptions and DefaultScheduler (declared outside this excerpt).
        // Throws ArgumentNullException if continuationFunction is null; the other checks happen in ContinueWhenAllImpl.
         public Task<TResult> ContinueWhenAll<TAntecedentResult>(Task<TAntecedentResult>[] tasks, Func<Task<TAntecedentResult>[], TResult> continuationFunction,
            CancellationToken cancellationToken)
        {
            if (continuationFunction == null) throw new ArgumentNullException("continuationFunction");
            Contract.EndContractBlock();
    
            StackCrawlMark stackMark = StackCrawlMark.LookForMyCaller;
            return ContinueWhenAllImpl<TAntecedentResult>(tasks, continuationFunction, null, m_defaultContinuationOptions, cancellationToken, DefaultScheduler, ref stackMark);
        }
    
        /// <summary>
        /// Shared implementation of ContinueWhenAny: builds a "first task completed" starter via
        /// <c>TaskFactory.CommonCWAnyLogic</c> and chains the user delegate onto it with ContinueWith.
        /// Exactly one of <paramref name="continuationFunction"/> / <paramref name="continuationAction"/>
        /// must be non-null.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="tasks"/> or <paramref name="scheduler"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="tasks"/> is empty.</exception>
        /// <remarks>
        /// NOTE(review): <paramref name="tasks"/> is passed to CommonCWAnyLogic without a copy or a
        /// per-element null check here, unlike ContinueWhenAllImpl; presumably CommonCWAnyLogic validates
        /// the elements — confirm.
        /// </remarks>
        internal static Task<TResult> ContinueWhenAnyImpl<TAntecedentResult>(Task<TAntecedentResult>[] tasks,
            Func<Task<TAntecedentResult>, TResult> continuationFunction, Action<Task<TAntecedentResult>> continuationAction,
            TaskContinuationOptions continuationOptions, CancellationToken cancellationToken, TaskScheduler scheduler, ref StackCrawlMark stackMark)
        {
            
            TaskFactory.CheckMultiTaskContinuationOptions(continuationOptions);
            if (tasks == null) throw new ArgumentNullException("tasks");
            if (tasks.Length == 0) throw new ArgumentException(Environment.GetResourceString("Task_MultiTaskContinuation_EmptyTaskList"), "tasks");
            Contract.Requires((continuationFunction != null) != (continuationAction != null), "Expected exactly one of endFunction/endAction to be non-null");
            if (scheduler == null) throw new ArgumentNullException("scheduler");
            Contract.EndContractBlock();
            
            // NOTE(review): the starter is created before the cancellation bail-out below, whereas
            // ContinueWhenAllImpl creates its starter after that check; confirm the ordering is intended.
            var starter = TaskFactory.CommonCWAnyLogic(tasks);
    
            // Bail early if cancellation has been requested.
            // With LazyCancellation set, the early bail-out is skipped and the token is handed to ContinueWith instead.
            if (cancellationToken.IsCancellationRequested
                && ((continuationOptions & TaskContinuationOptions.LazyCancellation) == 0)
                )
            {
                return CreateCanceledTask(continuationOptions, cancellationToken);
            }
    
            // returned continuation task, off of starter
            // The user delegate is passed as the continuation state, and a cached adapter delegate is
            // used as the continuation body.
            if (continuationFunction != null)
            {
                return starter.ContinueWith<TResult>(
                    GenericDelegateCache<TAntecedentResult, TResult>.CWAnyFuncDelegate,
                    continuationFunction, scheduler, cancellationToken, continuationOptions, ref stackMark);
            }
            else
            {
                Contract.Assert(continuationAction != null);
                return starter.ContinueWith<TResult>(
                    // Use a cached delegate
                    GenericDelegateCache<TAntecedentResult,TResult>.CWAnyActionDelegate,
                    continuationAction, scheduler, cancellationToken, continuationOptions, ref stackMark);
            }
        }
        /// <summary>
        /// Shared implementation of ContinueWhenAll: copies the task array, builds a countdown starter via
        /// <c>TaskFactory.CommonCWAllLogic</c>, and chains the user delegate onto it with ContinueWith.
        /// Exactly one of <paramref name="continuationFunction"/> / <paramref name="continuationAction"/>
        /// must be non-null.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="tasks"/> or <paramref name="scheduler"/> is null.</exception>
        /// <remarks>
        /// NOTE(review): there is no explicit empty-array check here, yet CompleteOnCountdownPromise requires
        /// at least one task; presumably CheckMultiContinuationTasksAndCopy rejects empty arrays and null
        /// elements — confirm.
        /// </remarks>
         internal static Task<TResult> ContinueWhenAllImpl<TAntecedentResult>(Task<TAntecedentResult>[] tasks,
                Func<Task<TAntecedentResult>[], TResult> continuationFunction, Action<Task<TAntecedentResult>[]> continuationAction,
                TaskContinuationOptions continuationOptions, CancellationToken cancellationToken, TaskScheduler scheduler, ref StackCrawlMark stackMark)
        {
            
            TaskFactory.CheckMultiTaskContinuationOptions(continuationOptions);
            if (tasks == null) throw new ArgumentNullException("tasks");
          
            Contract.Requires((continuationFunction != null) != (continuationAction != null), "Expected exactly one of endFunction/endAction to be non-null");
            if (scheduler == null) throw new ArgumentNullException("scheduler");
            Contract.EndContractBlock();
    
            // Work on a copy so that later writes to the caller's array do not affect the countdown.
            Task<TAntecedentResult>[] tasksCopy = TaskFactory.CheckMultiContinuationTasksAndCopy<TAntecedentResult>(tasks);
    
            // Bail early if cancellation has already been requested (unless LazyCancellation is set).
            if (cancellationToken.IsCancellationRequested
                && ((continuationOptions & TaskContinuationOptions.LazyCancellation) == 0)
                )
            {
                return CreateCanceledTask(continuationOptions, cancellationToken);
            }
    
            var starter = TaskFactory.CommonCWAllLogic(tasksCopy);
            if (continuationFunction != null)
            {
                return starter.ContinueWith<TResult>(
                   // use a cached delegate
                   GenericDelegateCache<TAntecedentResult, TResult>.CWAllFuncDelegate,
                   continuationFunction, scheduler, cancellationToken, continuationOptions, ref stackMark);
            }
            else
            {
                Contract.Assert(continuationAction != null);
    
                return starter.ContinueWith<TResult>(
                   // use a cached delegate
                   GenericDelegateCache<TAntecedentResult, TResult>.CWAllActionDelegate,
                   continuationAction, scheduler, cancellationToken, continuationOptions, ref stackMark);
            }
        }
        // Creates a CompleteOnCountdownPromise over tasksCopy and connects every task to it.
        // Already-completed tasks are counted immediately through Invoke; the others get the promise
        // attached as a completion action. The returned promise's result is the tasksCopy array itself.
       internal static Task<Task<T>[]> CommonCWAllLogic<T>(Task<T>[] tasksCopy)
        {
            Contract.Requires(tasksCopy != null);
    
            // Create a promise task to be returned to the user
            CompleteOnCountdownPromise<T> promise = new CompleteOnCountdownPromise<T>(tasksCopy);
    
            for (int i = 0; i < tasksCopy.Length; i++)
            {
                if (tasksCopy[i].IsCompleted) promise.Invoke(tasksCopy[i]); // Short-circuit the completion action, if possible
                else tasksCopy[i].AddCompletionAction(promise); // simple completion action
            }
    
            return promise;
        }
        /// <summary>
        /// Countdown promise used by ContinueWhenAll. It completes successfully, with the task array as its
        /// result, once every constituent task has reported completion. Unlike Task.WhenAllPromise, it does
        /// not inspect whether the tasks faulted or were canceled, and it does not collect their results.
        /// Completion actions are attached by CommonCWAllLogic, not by this constructor.
        /// </summary>
        private sealed class CompleteOnCountdownPromise<T> : Task<Task<T>[]>, ITaskCompletionAction
        {
            // The constituent tasks; also the promise's eventual result.
            private readonly Task<T>[] _tasks;
            // Number of constituent tasks that have not yet reported completion; decremented atomically.
            private int _count;
    
            internal CompleteOnCountdownPromise(Task<T>[] tasksCopy) : base()
            {
                Contract.Requires((tasksCopy != null) && (tasksCopy.Length > 0), "Expected non-null task array with at least one element in it");
                _tasks = tasksCopy;
                _count = tasksCopy.Length;
    
                if (AsyncCausalityTracer.LoggingOn)
                    AsyncCausalityTracer.TraceOperationCreation(CausalityTraceLevel.Required, this.Id, "TaskFactory.ContinueWhenAll<>", 0);
    
                if (Task.s_asyncDebuggingEnabled)
                {
                    AddToActiveTasks(this);
                }
            }
    
            /// <summary>
            /// Called once per constituent task completion. It propagates the debugger wait-notification flag
            /// from the completing task, and the call that brings the countdown to zero completes the promise.
            /// </summary>
            /// <param name="completingTask">The constituent task that just completed.</param>
            public void Invoke(Task completingTask)
            {
                if (AsyncCausalityTracer.LoggingOn)
                    AsyncCausalityTracer.TraceOperationRelation(CausalityTraceLevel.Important, this.Id, CausalityRelation.Join);
    
                if (completingTask.IsWaitNotificationEnabled) this.SetNotificationForWaitCompletion(enabled: true);
                // Only the last completing task (count reaches zero) completes the promise.
                if (Interlocked.Decrement(ref _count) == 0)
                {
                    if (AsyncCausalityTracer.LoggingOn)
                        AsyncCausalityTracer.TraceOperationCompletion(CausalityTraceLevel.Required, this.Id, AsyncCausalityStatus.Completed);
    
                    if (Task.s_asyncDebuggingEnabled)
                    {
                        RemoveFromActiveTasks(this.Id);
                    }
    
                    TrySetResult(_tasks);
                }
                // Sanity check that Invoke has not been called more times than there are tasks.
                // NOTE(review): this is a plain (non-Interlocked) read of _count.
                Contract.Assert(_count >= 0, "Count should never go below 0");
            }
    
            // Notify the debugger of wait completion only if the base implementation allows it and at
            // least one constituent task requires it.
            internal override bool ShouldNotifyDebuggerOfWaitCompletion
            {
                get
                {
                    return base.ShouldNotifyDebuggerOfWaitCompletion && Task.AnyTaskRequiresNotifyDebuggerOfWaitCompletion(_tasks);
                }
            }
        }
    
    }

    首先我们来看看TaskFactory的ContinueWhenAny方法。ContinueWhenAny方法主要调用的是ContinueWhenAnyImpl,在ContinueWhenAnyImpl里面主要调用的是TaskFactory.CommonCWAnyLogic(tasks)方法,这个方法返回的是一个CompleteOnInvokePromise Task。Task的WhenAny方法直接返回这个CompleteOnInvokePromise task,而Task的WaitAny则需要等待这个CompleteOnInvokePromise task完成;TaskFactory的ContinueWhenAny则返回在这个CompleteOnInvokePromise task上调用ContinueWith所创建的延续任务。

    接下来我们再看看TaskFactory的ContinueWhenAll方法。ContinueWhenAll方法主要调用的是ContinueWhenAllImpl方法。ContinueWhenAllImpl先通过TaskFactory.CheckMultiContinuationTasksAndCopy复制任务数组,然后调用TaskFactory.CommonCWAllLogic(tasksCopy)方法(注意不是CommonCWAnyLogic),该方法返回一个CompleteOnCountdownPromise<T>的task,最后ContinueWhenAllImpl返回在这个task上调用ContinueWith所创建的延续任务。这里的CompleteOnCountdownPromise和Task的WhenAllPromise相似,里面都有一个计数器来标记其中的task是否全部执行完毕。不同的是,CompleteOnCountdownPromise完成时直接把任务数组本身作为结果,并不检查各个task是否出错或被取消;而WhenAllPromise会汇总异常、处理取消并收集各个task的结果。

  • 相关阅读:
    YbtOJ#573后缀表达【二分图匹配】
    CF605EIntergalaxy Trips【期望dp】
    YbtOJ#482爬上山顶【凸壳,链表】
    AT4996[AGC034F]RNG and XOR【FWT,生成函数】
    YbtOJ#903染色方案【拉格朗日插值,NTT,分治】
    YbtOJ#832鸽子饲养【凸包,Floyd】
    YbtOJ#463序列划分【二分答案,线段树,dp】
    CF618FDouble Knapsack【结论】
    P3214[HNOI2011]卡农【dp】
    YbtOJ#526折纸游戏【二分,hash】
  • 原文地址:https://www.cnblogs.com/majiang/p/7908001.html
Copyright © 2011-2022 走看看