zoukankan      html  css  js  c++  java
  • JAVA多线程(九) ForkJoin框架

    Fork/Join框架是Java7提供的一个用于并行执行任务的框架,是一个把大任务分割成若干个小任务,最终汇总每个小任务结果后得到大任务结果的框架。
    Fork就是把一个大任务切分为若干子任务并行的执行,Join就是合并这些子任务的执行结果,最后得到这个大任务的结果。比如处理100个任务,可以分割成20个子任务,每个子任务分别处理5个,最终汇总这20个子任务的结果。

    工作窃取算法 

    工作窃取(work-stealing)算法是指某个线程从其他队列里窃取任务来执行。那么,为什么需要使用工作窃取算法呢?假如我们需要做一个比较大的任务,
    可以把这个任务分割为若干互不依赖的子任务,为了减少线程间的竞争,把这些子任务分别放到不同的队列里,并为每个队列创建一个单独的线程来执行队列里的任务,
    线程和队列一一对应。比如A线程负责处理A队列里的任务。但是,有的线程会先把自己队列里的任务干完,而其他线程对应的队列里还有任务等待处理。
    干完活的线程与其等着,不如去帮其他线程干活,于是它就去其他线程的队列里窃取一个任务来执行。而在这时它们会访问同一个队列,
    所以为了减少窃取任务线程和被窃取任务线程之间的竞争,通常会使用双端队列,被窃取任务线程永远从双端队列的头部拿任务执行,
    而窃取任务的线程永远从双端队列的尾部拿任务执行。

     

    工作窃取算法的优点:充分利用线程进行并行计算,减少了线程间的竞争。

    工作窃取算法的缺点:在某些情况下还是存在竞争,比如双端队列里只有一个任务时。并且该算法会消耗了更多的系统资源,比如创建多个线程和多个双端队列。

    ForkJoin框架的设计

    分割任务

    首先我们需要有一个fork类来把大任务分割成子任务,有可能子任务还是很大,所以还需要不停地分割,直到分割出的子任务足够小

    执行任务并合并结果

    分割的子任务分别放在双端队列里,然后几个启动线程分别从双端队列里获取任务执行。子任务执行完的结果都统一放在一个队列里,启动一个线程从队列里拿数据,然后合并这些数据

    Fork/Join使用两个类来完成以上两件事情
    1. ForkJoinTask:我们要使用ForkJoin框架,必须首先创建一个ForkJoin任务。它提供在任务中执行fork()和join()操作的机制。通常情况下,我们不需要直接继承ForkJoinTask类,只需要继承它的子类,Fork/Join框架提供了以下两个子类
       - RecursiveAction:用于没有返回结果的任务
       - RecursiveTask:用于有返回结果的任务
    2. ForkJoinPool:ForkJoinTask需要通过ForkJoinPool来执行。任务分割出的子任务会添加到当前工作线程所维护的双端队列中,进入队列的头部。当一个工作线程的队列里暂时没有任务时,它会随机从其他工作线程的队列的尾部获取一个任务

    ForkJoin Demo演示

    代码地址: https://gitee.com/showkawa_admin/netty-annotation/blob/master/src/main/java/com/brian/mutilthread/forkjoin/service/AutomationTask.java

    假设有个复杂的批量自动化任务要分割为单个子任务去执行,跑完全部子任务后要汇总每一个任务的结果到一个集合中统一返回

    自动化任务类

    package com.brian.mutilthread.forkjoin.service;
    
    import lombok.extern.slf4j.Slf4j;
    
    import java.util.*;
    import java.util.concurrent.RecursiveTask;
    
    
    @Slf4j
    public class AutomationTask extends RecursiveTask<List<Map<String, String>>> {
    
        static List<Map<String, String>> resultList;
        private List<String> list;
        private int start;
        private int end;
    
        static {
            resultList = new ArrayList<>(8);
        }
    
        public AutomationTask(List<String> list, int start, int end) {
            this.list = list;
            this.start = start;
            this.end = end;
            if (resultList.size() >= 8) {
                resultList.clear();
            }
        }
    
        @Override
        protected List<Map<String, String>> compute() {
            if ((end - start) < 1) {
                log.info("=== {} === {}-{}", Thread.currentThread().getName(), start, list.get(start));
                Map<String, String> result = new HashMap<>();
                result.put("region", list.get(start));
                try {
              // 任务2的定义 ServiceTask serviceTask
    = uuid -> { int rad = (int) (Math.random() * 100); if (rad > 80) { throw new Exception("getTransitions exception"); } Thread.sleep(rad); return UUID.randomUUID().toString().replace("-", ""); }; // step 1 String parameter = serviceTask.getParameter(list.get(start)); // step 2 String task = serviceTask.createTask(parameter); // step 3 String transitions = serviceTask.getTransitions(task); // step 4 String s = serviceTask.transferStatus(transitions); result.put("status", s); } catch (Exception e) { result.put("error", e.toString()); } resultList.add(result); } else { int middle = (start + end) / 2; AutomationTask leftTask = new AutomationTask(list, start, middle); log.info("=== {} === fork left {}-{}", Thread.currentThread().getName(), start, middle); AutomationTask rightTask = new AutomationTask(list, middle + 1, end); log.info("=== {} === fork right {}-{}", Thread.currentThread().getName(), middle + 1, end); leftTask.fork(); rightTask.compute(); List<Map<String, String>> leftList = leftTask.join(); log.info("=== {} === leftList {}", Thread.currentThread().getName(), leftList.toArray()); } return resultList; } }

    复杂的业务类

    模拟一个复杂的业务,中途可能会出现异常

    package com.brian.mutilthread.forkjoin.service;
    
    
    import java.util.UUID;
    
    @FunctionalInterface
    public interface ServiceTask {
        // 1
        default String getParameter(String region) throws Exception {
            int rad = (int) (Math.random() * 100);
            if (rad > 80) {
                throw new Exception("getParameter exception");
            }
            Thread.sleep(rad);
            return UUID.randomUUID().toString().replace("-", "");
        }
        // 2
        String createTask(String uuid) throws Exception;
        // 3
        default String getTransitions(String jid) throws Exception {
            int rad = (int) (Math.random() * 100);
            if (rad > 80) {
                throw new Exception("getTransitions exception");
            }
            Thread.sleep(rad);
            return UUID.randomUUID().toString().replace("-", "");
        }
        // 4
        default String transferStatus(String tid) throws Exception {
            int rad = (int) (Math.random() * 100);
            if (rad > 80) {
                throw new Exception("transferStatus exception");
            }
            Thread.sleep(rad);
            return UUID.randomUUID().toString().replace("-", "");
        }
    }

    测试类

    package com.brian.mutilthread.forkjoin.controller;
    
    
    import com.brian.mutilthread.forkjoin.service.AutomationTask;
    import lombok.extern.slf4j.Slf4j;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.util.StopWatch;
    import org.springframework.web.bind.annotation.GetMapping;
    import org.springframework.web.bind.annotation.RequestParam;
    import org.springframework.web.bind.annotation.RestController;
    import reactor.core.publisher.Flux;
    
    import java.util.Arrays;
    import java.util.List;
    import java.util.Map;
    import java.util.concurrent.ExecutionException;
    import java.util.concurrent.ForkJoinPool;
    import java.util.concurrent.Future;
    import java.util.stream.Collectors;
    
    @RestController
    @Slf4j
    public class ForkJoinTestController {
    
        @Autowired
        public ForkJoinPool forkJoinPool;
    
        @GetMapping("/getForkJoinResult")
        public Flux<Map<String, String>> getForkJoinResult(@RequestParam(value = "country", defaultValue = "CN,HK,JP,KR,SG,TH,TW,ER") String country)
                throws ExecutionException, InterruptedException {
            String[] countries = country.split(",");
            if (countries.length < 8) {
                return Flux.empty();
            }
            StopWatch stopWatch = new StopWatch();
            stopWatch.start();
            AutomationTask computeTask = new AutomationTask(Arrays.stream(countries).collect(Collectors.toList()), 0, countries.length - 1);
            Future<List<Map<String, String>>> results = forkJoinPool.submit(computeTask);
            if(computeTask.isCompletedAbnormally()){
                log.info("<><><><><><><><><><> automationTask exception: {}", computeTask.getException());
            }
    
            List<Map<String, String>> res = results.get();
            log.info(">>>>>>>>>>>>>>>>>>>>result size : {}", res.size());
            Flux<Map<String, String>> mapFlux = Flux.fromIterable(res);
            stopWatch.stop();
            log.info(">>>>>>>>>>>total handle time: {} ms", stopWatch.getTotalTimeMillis());
            return mapFlux;
        }
    
    }

     

    ForkJoinTask在执行的时候可能会抛出异常,但是我们没办法在主线程里直接捕获异常,所以ForkJoinTask提供了isCompletedAbnormally()方法来检查任务是否已经抛出异常或已经被取消,并且可以通过ForkJoinTask的getException方法获取异常。如上面的测试类中有个如下的代码片段

            if(computeTask.isCompletedAbnormally()){
                log.info("<><><><><><><><><><> automationTask exception: {}", computeTask.getException());
            }

    getException方法返回Throwable对象,如果任务被取消了则返回CancellationException。如果任务没有完成或者没有抛出异常则返回null

    ForkJoin框架的原理

    ForkJoinPool类

    //继承AbstractExecutorService 类
    public class ForkJoinPool extends AbstractExecutorService{
    
        //任务队列数组,存储了所有任务队列,包括内部队列和外部队列
        volatile WorkQueue[] workQueues;     // main registry
    
        //一个静态常量,ForkJoinPool 提供的内部公用的线程池
        static final ForkJoinPool common;
    
        //默认的线程工厂类
      public static final ForkJoinWorkerThreadFactory defaultForkJoinWorkerThreadFactory;
    
    }

    ForkJoinWorkerThread类

    //继承Thread 类
    public class ForkJoinWorkerThread extends Thread {
    
      //线程工作的线程池,即此线程所属的线程池
      final ForkJoinPool pool;
    
      // 线程的内部队列
      final ForkJoinPool.WorkQueue workQueue;
    
    //.....
    }

    ForkJoinPool中线程的创建

    默认的线程工厂类,ForkJoinPool 中的线程是由默认的线程工厂类 defaultForkJoinWorkerThreadFactory 创建的

    //默认的工厂类
      public static final ForkJoinWorkerThreadFactory defaultForkJoinWorkerThreadFactory;
    
    defaultForkJoinWorkerThreadFactory =
                new DefaultForkJoinWorkerThreadFactory();

    defaultForkJoinWorkerThreadFactory 创建线程的方法 newThread(),其实就是传入当前的线程池,直接创建

        /**
         * Default ForkJoinWorkerThreadFactory implementation; creates a
         * new ForkJoinWorkerThread using the system class loader as the
         * thread context class loader.
         */
        private static final class DefaultForkJoinWorkerThreadFactory
            implements ForkJoinWorkerThreadFactory {
            private static final AccessControlContext ACC = contextWithPermissions(
                new RuntimePermission("getClassLoader"),
                new RuntimePermission("setContextClassLoader"));
    
            public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
                return AccessController.doPrivileged(
                    new PrivilegedAction<>() {
                        public ForkJoinWorkerThread run() {
                            return new ForkJoinWorkerThread(
                                pool, ClassLoader.getSystemClassLoader()); }},
                    ACC);
            }
        }

    ForkJoinWorkerThread 的构造方法

     protected ForkJoinWorkerThread(ForkJoinPool pool) {
            // Use a placeholder until a useful name can be set in registerWorker
            super("aForkJoinWorkerThread");
            //线程工作的线程池,即创建这个线程的线程池
            this.pool = pool;
            //注册线程到线程池中,并返回此线程的内部任务队列
            this.workQueue = pool.registerWorker(this);
        }

    创建一个工作线程,最后一步还要注册到其所属的线程池中, registerWorker这里不展开了

    ForkJoinTask的fork()方法

    public final ForkJoinTask<V> fork() {
            Thread t;
            //判断是否是一个工作线程
            if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
                //加入到内部队列中
                ((ForkJoinWorkerThread)t).workQueue.push(this);
            else//由common线程池来执行任务
                ForkJoinPool.common.externalPush(this);
            return this;
        }


    fork()方法先判断当前线程(调用fork()来提交任务的线程)是不是一个 ForkJoinWorkerThread 的工作线程,如果是,则将任务加入到内部队列中,否则,由 ForkJoinPool 提供的内部公用的线程池common线程池 来执行这个任务。我们可以在普通线程池中直接调用 fork() 方法来提交任务到一个默认提供的线程池中。这将非常方便。假如,你要在程序中处理大任务,需要分治编程,但你仅仅只处理一次,以后就不会用到,而且任务不算太大,不需要设置特定的参数,那么你肯定不想为此创建一个线程池,这时默认的提供的线程池将会很有用。

    ForkJoinTask的join()方法

     public final V join() {
            int s;
            if ((s = doJoin() & DONE_MASK) != NORMAL)
                reportException(s);
            return getRawResult();//直接返回结果
        }
    private int doJoin() {
            int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
            return 
                //如果完成,直接返回s
                (s = status) < 0 ? s : 
                //没有完成,判断是不是池中的 ForkJoinWorkerThread 工作线程
                ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
                //如果是池中线程,执行这里
                (w = (wt = (ForkJoinWorkerThread)t).workQueue).
                tryUnpush(this) && (s = doExec()) < 0 ? s :
                wt.pool.awaitJoin(w, this, 0L) :
                //如果不是池中的线程池,则执行这里
                externalAwaitDone();
        }

    join()方法有执行一个重要的方法doJoin(), 当dojoin()方法发现任务没有完成且当前线程是池中线程时,执行了 tryUnpush()方法。tryUnpush()方法尝试去执行此任务:如果要join的任务正好在当前任务队列的顶端,那么pop出这个任务,然后调用 doExec() 让当前线程去执行这个任务

    final boolean tryUnpush(ForkJoinTask<?> t) {
                ForkJoinTask<?>[] a; int s;
                if ((a = array) != null && (s = top) != base &&
                    U.compareAndSwapObject
                    (a, (((a.length - 1) & --s) << ASHIFT) + ABASE, t, null)) {
                    U.putOrderedInt(this, QTOP, s);
                    return true;
                }
                return false;
            }



    final int doExec() { int s; boolean completed; if ((s = status) >= 0) { try { completed = exec(); } catch (Throwable rex) { return setExceptionalCompletion(rex); } if (completed) s = setCompletion(NORMAL); } return s; }

    如果任务不是处于队列的顶端,那么就会执行 awaitJoin() 方法

     final int awaitJoin(WorkQueue w, ForkJoinTask<?> task, long deadline) {
            int s = 0;
            if (task != null && w != null) {
                ForkJoinTask<?> prevJoin = w.currentJoin;
                U.putOrderedObject(w, QCURRENTJOIN, task);
                CountedCompleter<?> cc = (task instanceof CountedCompleter) ?
                    (CountedCompleter<?>)task : null;
                for (;;) {
                    if ((s = task.status) < 0)//如果任务完成了,跳出死循环
                        break;
                    if (cc != null)//当前任务是CountedCompleter类型,则尝试从任务队列中获取当前任务的派生子任务来执行;
                        helpComplete(w, cc, 0);
                    else if (w.base == w.top || w.tryRemoveAndExec(task))//如果当前线程的内部队列为空,或者成功完成了任务,帮助某个线程完成任务。
                        helpStealer(w, task);
                    if ((s = task.status) < 0)//任务完成,跳出死循环
                        break;
                    long ms, ns;
                    if (deadline == 0L)
                        ms = 0L;
                    else if ((ns = deadline - System.nanoTime()) <= 0L)
                        break;
                    else if ((ms = TimeUnit.NANOSECONDS.toMillis(ns)) <= 0L)
                        ms = 1L;
                    if (tryCompensate(w)) {
                        task.internalWait(ms);
                        U.getAndAddLong(this, CTL, AC_UNIT);
                    }
                }
                U.putOrderedObject(w, QCURRENTJOIN, prevJoin);
            }
            return s;
        }
        /**
         * Tries to locate and execute tasks for a stealer of the given
         * task, or in turn one of its stealers, Traces currentSteal ->
         * currentJoin links looking for a thread working on a descendant
         * of the given task and with a non-empty queue to steal back and
         * execute tasks from. The first call to this method upon a
         * waiting join will often entail scanning/search, (which is OK
         * because the joiner has nothing better to do), but this method
         * leaves hints in workers to speed up subsequent calls.
         *
         * @param w caller
         * @param task the task to join
         */
        private void helpStealer(WorkQueue w, ForkJoinTask<?> task) {
            WorkQueue[] ws = workQueues;
            int oldSum = 0, checkSum, m;
            if (ws != null && (m = ws.length - 1) >= 0 && w != null &&
                task != null) {
                do {                                       // restart point
                    checkSum = 0;                          // for stability check
                    ForkJoinTask<?> subtask;
                    WorkQueue j = w, v;                    // v is subtask stealer
                    descent: for (subtask = task; subtask.status >= 0; ) {
                        for (int h = j.hint | 1, k = 0, i; ; k += 2) {
                            if (k > m)                     // can't find stealer
                                break descent;
                            if ((v = ws[i = (h + k) & m]) != null) {
                                if (v.currentSteal == subtask) {
                                    j.hint = i;
                                    break;
                                }
                                checkSum += v.base;
                            }
                        }
                        for (;;) {                         // help v or descend
                            ForkJoinTask<?>[] a; int b;
                            checkSum += (b = v.base);
                            ForkJoinTask<?> next = v.currentJoin;
                            if (subtask.status < 0 || j.currentJoin != subtask ||
                                v.currentSteal != subtask) // stale
                                break descent;
                            if (b - v.top >= 0 || (a = v.array) == null) {
                                if ((subtask = next) == null)
                                    break descent;
                                j = v;
                                break;
                            }
                            int i = (((a.length - 1) & b) << ASHIFT) + ABASE;
                            ForkJoinTask<?> t = ((ForkJoinTask<?>)
                                                 U.getObjectVolatile(a, i));
                            if (v.base == b) {
                                if (t == null)             // stale
                                    break descent;
                                if (U.compareAndSwapObject(a, i, t, null)) {
                                    v.base = b + 1;
                                    ForkJoinTask<?> ps = w.currentSteal;
                                    int top = w.top;
                                    do {
                                        U.putOrderedObject(w, QCURRENTSTEAL, t);
                                        t.doExec();        // clear local tasks too
                                    } while (task.status >= 0 &&
                                             w.top != top &&
                                             (t = w.pop()) != null);
                                    U.putOrderedObject(w, QCURRENTSTEAL, ps);
                                    if (w.base != w.top)
                                        return;            // can't further help
                                }
                            }
                        }
                    }
                } while (task.status >= 0 && oldSum != (oldSum = checkSum));
            }
        }

    上面的helpStealer()方法,原则是你帮助我执行任务,我也帮你执行任务。

    1.遍历奇数下标,如果发现队列对象currentSteal放置的刚好是自己要找的任务,则说明自己的任务被该队列a的owner线程偷来执行
    2.如果队列a队列中有任务,则从队尾(base)取出执行;
    3.如果发现队列b队列为空,则根据它正在join的任务,在拓扑找到相关的队列B去偷取任务执行。在执行的过程中要注意,我们应该完整的把任务完成

    参考链接:

    1. Fork/Join框架解析 - JDK1.8

    2. Fork/Join 框架-设计与实现(翻译自论文《A Java Fork/Join Framework》原作者 Doug Lea)

  • 相关阅读:
    Java中的subList方法
    某同学工作之后的感悟
    存放80000000学生成绩的集合,怎么统计平均分性能高
    为了金秋那沉甸甸的麦穗,我绝不辜负春天
    subList?? subString???
    "爸妈没多大本事"……
    中秋节支付宝口令红包解析
    算法>动态规划(一) 小强斋
    数据结构>优先队列(堆) 小强斋
    算法>贪心算法 小强斋
  • 原文地址:https://www.cnblogs.com/hlkawa/p/15115064.html
Copyright © 2011-2022 走看看