1、创建一个Fork/Join池
ForkJoinPool 逻辑
If (problem size > default size) { tasks = divide(task); execute(tasks); } else { resolve problem using another algorithm; }
实现一个任务来修改产品列表的价格。任务最初负责更新列表中的所有元素。你将使用10作为参考大小:如果一个任务需要更新10个或更多元素,就把这些元素分成两部分,并创建两个任务分别更新每一部分中产品的价格。
/**
 * Demo: raises the price of every product in a list by 20% using a
 * {@link RecursiveAction} executed in a {@link ForkJoinPool}, while the main
 * thread periodically prints the pool's status.
 */
public class MyForkJoinPool1 {

    public static void main(String[] args) {
        ProductListGenerator generator = new ProductListGenerator();
        List<Product> products = generator.generate(10000);
        Task task = new Task(products, 0, products.size(), 0.20);
        ForkJoinPool pool = new ForkJoinPool();
        pool.execute(task);
        // Poll the pool every 5 ms until the root task has finished.
        do {
            System.out.printf("Main: Thread Count: %d%n", pool.getActiveThreadCount());
            System.out.printf("Main: Thread Steal: %d%n", pool.getStealCount());
            System.out.printf("Main: Parallelism: %d%n", pool.getParallelism());
            try {
                TimeUnit.MILLISECONDS.sleep(5);
            } catch (InterruptedException e) {
                // Restore the interrupt flag and stop polling; the task keeps running.
                Thread.currentThread().interrupt();
                break;
            }
        } while (!task.isDone());
        // Make sure the task has finished even if polling stopped early (does not throw).
        task.quietlyJoin();
        pool.shutdown();
        if (task.isCompletedNormally()) {
            System.out.printf("Main: The process has completed normally.%n");
        } else {
            System.out.printf("Main: The process failed: %s%n", task.getException());
        }
        // Every product starts at 10 (see ProductListGenerator), so after a 20% raise it
        // should cost 12. Prices are doubles, so compare with a small tolerance.
        for (Product product : products) {
            if (Math.abs(product.getPrice() - 12) > 1e-9) {
                System.out.printf("Product %s: %f%n", product.getName(), product.getPrice());
            }
        }
        System.out.println("Main: End of the program. ");
    }
}

/** Simple mutable product with a name and a price. */
class Product {
    private String name;
    private double price;

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public double getPrice() {
        return price;
    }

    public void setPrice(double price) {
        this.price = price;
    }
}

/** Builds demo product lists. */
class ProductListGenerator {

    /**
     * Creates {@code size} products named "Product0".."Product{size-1}", each priced at 10.
     *
     * @param size number of products to create
     * @return a new mutable list of products
     */
    public List<Product> generate(int size) {
        List<Product> ret = new ArrayList<>();
        for (int i = 0; i < size; i++) {
            Product product = new Product();
            product.setName("Product" + i);
            product.setPrice(10);
            ret.add(product);
        }
        return ret;
    }
}

/**
 * Multiplies the price of every product in {@code products[first, last)} by
 * {@code 1 + increment}, splitting the range in two while it holds at least
 * {@value #THRESHOLD} elements.
 */
class Task extends RecursiveAction {
    private static final long serialVersionUID = 1L;

    /** Ranges with fewer elements than this are updated directly instead of being split. */
    private static final int THRESHOLD = 10;

    private final List<Product> products;
    private final int first;
    private final int last;
    private final double increment;

    /**
     * @param products  the shared product list to update in place
     * @param first     first index to update (inclusive)
     * @param last      end index (exclusive)
     * @param increment relative increment, e.g. 0.20 for +20%
     */
    public Task(List<Product> products, int first, int last, double increment) {
        this.products = products;
        this.first = first;
        this.last = last;
        this.increment = increment;
    }

    @Override
    protected void compute() {
        if (last - first < THRESHOLD) {
            updatePrices();
        } else {
            // Unsigned shift avoids int overflow of (first + last) for huge ranges.
            int middle = (first + last) >>> 1;
            System.out.printf("Task: Pending tasks:%s%n", getQueuedTaskCount());
            // The halves [first, middle+1) and [middle+1, last) are disjoint, so each
            // product is updated exactly once.
            Task t1 = new Task(products, first, middle + 1, increment);
            Task t2 = new Task(products, middle + 1, last, increment);
            invokeAll(t1, t2);
        }
    }

    /** Applies the increment to every product in this task's range. */
    private void updatePrices() {
        for (int i = first; i < last; i++) {
            Product product = products.get(i);
            product.setPrice(product.getPrice() * (1 + increment));
        }
    }
}
2. 合并任务的结果
Fork/Join框架还能执行返回结果的任务。这类任务需要继承RecursiveTask类;RecursiveTask继承自ForkJoinTask类,并实现了执行器(Executor)框架提供的Future接口。
If (problem size > size) { tasks = divide(task); execute(tasks); result = groupResults(); return result; } else { resolve problem; return result; }
如果这个任务需要解决的问题超过预定义的大小,就应该把它分解成多个子任务,并用Fork/Join框架执行这些子任务。这些子任务执行完毕后,发起它们的任务会获取所有子任务产生的结果,对这些结果进行合并,并返回最终结果。最终,当池中执行的初始任务完成时,你就得到了整个问题的最终结果。
/**
 * Demo: counts the occurrences of a word in a randomly generated document with
 * {@link RecursiveTask}s (one level splits lines, a second level splits words),
 * printing the pool status every second until the count is ready.
 */
class MyForkJoinPool2 {

    public static void main(String[] args) {
        DocumentMock mock = new DocumentMock();
        String[][] document = mock.generateDocument(100, 1000, "the");
        DocumentTask task = new DocumentTask(document, 0, document.length, "the");
        ForkJoinPool pool = new ForkJoinPool();
        pool.execute(task);
        do {
            System.out.printf("******************************************%n");
            System.out.printf("Main: Parallelism: %d%n", pool.getParallelism());
            System.out.printf("Main: Active Threads: %d%n", pool.getActiveThreadCount());
            System.out.printf("Main: Task Count: %d%n", pool.getQueuedTaskCount());
            System.out.printf("Main: Steal Count: %d%n", pool.getStealCount());
            System.out.printf("******************************************%n");
            try {
                TimeUnit.SECONDS.sleep(1);
            } catch (InterruptedException e) {
                // Restore the interrupt flag and stop polling; join() below still waits.
                Thread.currentThread().interrupt();
                break;
            }
        } while (!task.isDone());
        pool.shutdown();
        try {
            // join() waits for completion and rethrows a task failure as an unchecked exception.
            System.out.printf("Main: The word appears %d times in the document%n", task.join());
        } catch (RuntimeException e) {
            System.err.printf("Main: The count failed: %s%n", e);
        }
    }
}

/** Generates random demo documents made of words from a small fixed vocabulary. */
class DocumentMock {
    private final String[] words = { "the", "hello", "goodbye", "packt", "java", "thread",
            "pool", "random", "class", "main" };

    /**
     * Builds a {@code numLines x numWords} document of random words and prints how
     * many times {@code word} occurs in it (useful to cross-check the parallel count).
     *
     * @param numLines number of lines
     * @param numWords number of words per line
     * @param word     the word whose occurrences are reported
     * @return the generated document
     */
    public String[][] generateDocument(int numLines, int numWords, String word) {
        int counter = 0;
        String[][] document = new String[numLines][numWords];
        Random random = new Random();
        for (int i = 0; i < numLines; i++) {
            for (int j = 0; j < numWords; j++) {
                document[i][j] = words[random.nextInt(words.length)];
                if (document[i][j].equals(word)) {
                    counter++;
                }
            }
        }
        System.out.println("DocumentMock: The word appears " + counter + " times in the document");
        return document;
    }
}

/**
 * Counts occurrences of {@code word} in lines {@code [start, end)} of a document,
 * splitting the line range in half while it holds at least {@value #LINE_THRESHOLD} lines.
 */
class DocumentTask extends RecursiveTask<Integer> {
    private static final long serialVersionUID = -7632107634821261866L;

    /** Line ranges shorter than this are processed directly, one LineTask per line. */
    private static final int LINE_THRESHOLD = 10;

    private final String[][] document;
    private final int start, end;
    private final String word;

    public DocumentTask(String[][] document, int start, int end, String word) {
        this.document = document;
        this.start = start;
        this.end = end;
        this.word = word;
    }

    @Override
    protected Integer compute() {
        if (end - start < LINE_THRESHOLD) {
            return processLines(document, start, end, word);
        }
        int mid = (start + end) >>> 1;
        DocumentTask task1 = new DocumentTask(document, start, mid, word);
        DocumentTask task2 = new DocumentTask(document, mid, end, word);
        invokeAll(task1, task2);
        // join() propagates a subtask failure instead of silently contributing a wrong count.
        return groupResults(task1.join(), task2.join());
    }

    /** Counts the word in each line of {@code [start, end)} with one LineTask per line. */
    private Integer processLines(String[][] document, int start, int end, String word) {
        List<LineTask> tasks = new ArrayList<>();
        for (int i = start; i < end; i++) {
            tasks.add(new LineTask(document[i], 0, document[i].length, word));
        }
        invokeAll(tasks);
        int result = 0;
        for (LineTask task : tasks) {
            result += task.join();
        }
        return result;
    }

    /** Combines two partial counts. */
    private Integer groupResults(Integer number1, Integer number2) {
        return number1 + number2;
    }
}

/**
 * Counts occurrences of {@code word} in {@code line[start, end)}, splitting the range
 * in half while it holds at least {@value #WORD_THRESHOLD} words.
 */
class LineTask extends RecursiveTask<Integer> {
    private static final long serialVersionUID = 1L;

    /** Word ranges shorter than this are counted directly. */
    private static final int WORD_THRESHOLD = 100;

    private final String[] line;
    private final int start, end;
    private final String word;

    public LineTask(String[] line, int start, int end, String word) {
        this.line = line;
        this.start = start;
        this.end = end;
        this.word = word;
    }

    @Override
    protected Integer compute() {
        if (end - start < WORD_THRESHOLD) {
            return count(line, start, end, word);
        }
        int mid = (start + end) >>> 1;
        LineTask task1 = new LineTask(line, start, mid, word);
        LineTask task2 = new LineTask(line, mid, end, word);
        invokeAll(task1, task2);
        // join() never yields null, unlike the old get()+catch path that left result == null.
        return groupResults(task1.join(), task2.join());
    }

    /** Counts matches of {@code word} in {@code line[start, end)}. */
    private Integer count(String[] line, int start, int end, String word) {
        int counter = 0;
        for (int i = start; i < end; i++) {
            if (line[i].equals(word)) {
                counter++;
            }
        }
        // Artificial 10 ms delay; presumably simulates work so main's monitor has
        // something to report. Restore the interrupt flag instead of swallowing it.
        try {
            Thread.sleep(10);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return counter;
    }

    /** Combines two partial counts. */
    private Integer groupResults(Integer number1, Integer number2) {
        return number1 + number2;
    }
}
3. 异步方式
在ForkJoinPool中执行ForkJoinTask时,可以采用同步或异步方式。使用同步方式时,提交任务的方法要等到所提交的任务执行完成后才会返回。使用异步方式时,提交任务的方法会立即返回,因此发起提交的任务可以继续执行。
你应该意识到这两种方式有很大的区别。使用同步方法(比如invokeAll())时,调用该方法的任务会被阻塞,直到它提交给池的任务执行完成。这使得ForkJoinPool类可以使用work-stealing(工作窃取)算法,把新任务分配给正在执行这个等待中任务的工作线程。相反,使用异步方法(比如fork())时,该任务会继续执行,因此ForkJoinPool类无法利用work-stealing算法来提高应用程序的性能。在这种情况下,只有当你调用join()或get()方法等待任务完成时,ForkJoinPool才能使用work-stealing算法。
/**
 * Demo: searches three directory trees in parallel for files with a given
 * extension, using asynchronous {@code fork()}/{@code join()} in
 * {@link FolderProcessor}, and prints the pool status while the search runs.
 */
public class MyForkJoinPool3 {

    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool();
        // Backslashes must be escaped in Java string literals ("\W" is not a legal escape).
        FolderProcessor system = new FolderProcessor("C:\\Windows", "log");
        FolderProcessor apps = new FolderProcessor("C:\\Program Files", "log");
        FolderProcessor documents = new FolderProcessor("C:\\Documents And Settings", "log");
        pool.execute(system);
        pool.execute(apps);
        pool.execute(documents);
        do {
            System.out.printf("******************************************%n");
            System.out.printf("Main: Parallelism: %d%n", pool.getParallelism());
            System.out.printf("Main: Active Threads: %d%n", pool.getActiveThreadCount());
            System.out.printf("Main: Task Count: %d%n", pool.getQueuedTaskCount());
            System.out.printf("Main: Steal Count: %d%n", pool.getStealCount());
            System.out.printf("******************************************%n");
            try {
                TimeUnit.SECONDS.sleep(1);
            } catch (InterruptedException e) {
                // Restore the interrupt flag and stop polling; join() below still waits.
                Thread.currentThread().interrupt();
                break;
            }
        } while (!system.isDone() || !apps.isDone() || !documents.isDone());
        pool.shutdown();
        List<String> results;
        results = system.join();
        System.out.printf("System: %d files found.%n", results.size());
        results = apps.join();
        System.out.printf("Apps: %d files found.%n", results.size());
        results = documents.join();
        System.out.printf("Documents: %d files found.%n", results.size());
    }
}

/**
 * Recursively collects the absolute paths of files under {@code path} whose name
 * ends with the given extension. Each subdirectory is handled by a forked subtask
 * whose result is joined into this task's list.
 */
class FolderProcessor extends RecursiveTask<List<String>> {
    private static final long serialVersionUID = 1L;

    private final String path;
    private final String extension;

    /**
     * @param path      directory to search
     * @param extension file extension to match, with or without the leading dot
     *                  (e.g. "log" or ".log"); an empty string matches every file
     */
    public FolderProcessor(String path, String extension) {
        this.path = path;
        this.extension = extension;
    }

    @Override
    protected List<String> compute() {
        List<String> list = new ArrayList<>();
        List<FolderProcessor> tasks = new ArrayList<>();
        File file = new File(path);
        // listFiles() returns null if path is not a directory or cannot be read.
        File[] content = file.listFiles();
        if (content != null) {
            for (File entry : content) {
                if (entry.isDirectory()) {
                    FolderProcessor task = new FolderProcessor(entry.getAbsolutePath(), extension);
                    task.fork();
                    tasks.add(task);
                } else if (checkFile(entry.getName())) {
                    list.add(entry.getAbsolutePath());
                }
            }
            if (tasks.size() > 50) {
                System.out.printf("%s: %d tasks ran.%n", file.getAbsolutePath(), tasks.size());
            }
            addResultsFromTasks(list, tasks);
        }
        return list;
    }

    /** Waits for each forked subtask and appends its matches to {@code list}. */
    private void addResultsFromTasks(List<String> list, List<FolderProcessor> tasks) {
        for (FolderProcessor item : tasks) {
            list.addAll(item.join());
        }
    }

    /**
     * Returns whether {@code name} has the configured extension. The match requires the
     * dot, so "log" matches "a.log" but not "catalog".
     */
    private boolean checkFile(String name) {
        if (extension.isEmpty()) {
            return true;
        }
        String suffix = extension.startsWith(".") ? extension : "." + extension;
        return name.endsWith(suffix);
    }
}
4. 在任务中抛出异常
在ForkJoinTask类的compute()方法中,你不能抛出任何已检查异常,因为该方法的声明中没有throws子句,你必须在方法内部编写处理这些异常的代码。但是,你可以抛出未检查异常(它也可能由compute()内部调用的方法或对象抛出)。此时ForkJoinTask和ForkJoinPool类的行为可能与你预期的不同:程序不会结束执行,控制台上也看不到任何关于该异常的信息——异常被吞掉了,就好像它没有被抛出一样。直到你对该任务调用join()(会重新抛出该异常)或get()(会将其包装为ExecutionException抛出)时,异常才会再次出现。你可以使用ForkJoinTask类的一些方法(例如isCompletedAbnormally()和getException())来得知任务是否抛出了异常以及异常的类型。
/**
 * Demo: shows that an unchecked exception thrown inside a fork/join task does not
 * stop the program; it is recorded on the task and can be inspected afterwards.
 */
public class MyForkJoinPool4 {

    public static void main(String[] args) {
        int[] array = new int[100];
        MyTask task = new MyTask(array, 0, 100);
        ForkJoinPool pool = new ForkJoinPool();
        pool.execute(task);
        pool.shutdown();
        try {
            pool.awaitTermination(1, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        if (task.isCompletedAbnormally()) {
            System.out.printf("Main: An exception has occurred%n");
            System.out.printf("Main: %s%n", task.getException());
        } else {
            // join() rethrows the task's exception, so only call it when the task did not fail.
            System.out.printf("Main: Result: %d%n", task.join());
        }
    }
}

/**
 * Splits {@code [start, end)} in half until ranges hold fewer than 10 elements.
 * The leaf whose range strictly contains index 3 throws a RuntimeException. Every
 * other leaf sleeps for one second. A task that completes normally returns 0.
 */
class MyTask extends RecursiveTask<Integer> {
    private static final long serialVersionUID = 1L;

    private final int[] array;
    private final int start, end;

    public MyTask(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Integer compute() {
        System.out.printf("Task: Start from %d to %d%n", start, end);
        if (end - start < 10) {
            // Deliberate failure to demonstrate exception handling: start < 3 < end.
            if ((3 > start) && (3 < end)) {
                throw new RuntimeException("This task throws an Exception: Task from "
                        + start + " to " + end);
            }
            try {
                TimeUnit.SECONDS.sleep(1);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        } else {
            int mid = (start + end) >>> 1;
            MyTask task1 = new MyTask(array, start, mid);
            MyTask task2 = new MyTask(array, mid, end);
            invokeAll(task1, task2);
        }
        System.out.printf("Task: End from %d to %d%n", start, end);
        return 0;
    }
}
5.取消任务
在ForkJoinPool中执行ForkJoinTask对象时,可以在它们开始执行之前取消它们。ForkJoinTask类为此提供了cancel()方法。取消任务时,有以下几点需要注意:
- ForkJoinPool类没有提供任何方法来一次性取消池中所有正在运行或等待的任务。
- 取消一个任务时,并不会同时取消该任务已经提交执行的子任务(因此下面的示例用TaskManager记录所有子任务,以便逐个取消)。
/**
 * Demo: searches a random array for a number with fork/join tasks and, once a leaf
 * finds it, cancels the remaining tasks through a shared {@link TaskManager}.
 */
public final class MyForkJoinPool5 {

    public static void main(String[] args) {
        ArrayGenerator generator = new ArrayGenerator();
        int[] array = generator.generateArray(1000);
        TaskManager manager = new TaskManager();
        ForkJoinPool pool = new ForkJoinPool();
        SearchNumberTask task = new SearchNumberTask(array, 0, array.length, 5, manager);
        pool.execute(task);
        pool.shutdown();
        try {
            pool.awaitTermination(1, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        System.out.printf("Main: The program has finished%n");
    }
}

/** Produces arrays of random digits. */
class ArrayGenerator {

    /**
     * @param size length of the array
     * @return a new array whose elements are uniformly random in [0, 10)
     */
    public int[] generateArray(int size) {
        int[] array = new int[size];
        Random random = new Random();
        for (int i = 0; i < size; i++) {
            array[i] = random.nextInt(10);
        }
        return array;
    }
}

/**
 * Registry of the subtasks created by a search, used to cancel them once one task
 * has found the number (ForkJoinPool offers no "cancel all" operation).
 */
class TaskManager {
    /**
     * Registered tasks. Worker threads add to this list while another worker may be
     * iterating it in cancelTasks, so it must be thread-safe. A copy-on-write list
     * iterates over a snapshot and never throws ConcurrentModificationException.
     */
    private final List<ForkJoinTask<Integer>> tasks;

    public TaskManager() {
        tasks = new java.util.concurrent.CopyOnWriteArrayList<>();
    }

    /** Registers a task so that it can later be cancelled. */
    public void addTask(ForkJoinTask<Integer> task) {
        tasks.add(task);
    }

    /**
     * Cancels every registered task except {@code cancelTask} and the tasks whose range
     * contains it. Those tasks are its ancestors in the split tree and are waiting for its
     * result. A cancel message is printed only for tasks that were actually cancelled.
     */
    public void cancelTasks(ForkJoinTask<Integer> cancelTask) {
        for (ForkJoinTask<Integer> task : tasks) {
            if (task == cancelTask || isAncestor(task, cancelTask)) {
                continue;
            }
            // cancel() returns false for tasks that already completed.
            if (task.cancel(true) && task instanceof SearchNumberTask) {
                ((SearchNumberTask) task).writeCancelMessage();
            }
        }
    }

    /** True when {@code candidate}'s range contains {@code finder}'s range. */
    private static boolean isAncestor(ForkJoinTask<Integer> candidate, ForkJoinTask<Integer> finder) {
        return candidate instanceof SearchNumberTask
                && finder instanceof SearchNumberTask
                && ((SearchNumberTask) candidate).covers((SearchNumberTask) finder);
    }
}

/**
 * Searches {@code numbers[start, end)} for {@code number}. Ranges with more than
 * {@value #THRESHOLD} elements are split into two forked subtasks. Returns the index
 * of a match, or -1 if none was found or the search was cancelled.
 */
class SearchNumberTask extends RecursiveTask<Integer> {
    private static final long serialVersionUID = 1L;

    private final static int NOT_FOUND = -1;

    /** Ranges with more elements than this are split. */
    private static final int THRESHOLD = 10;

    private final int[] numbers;
    private final int start, end;
    private final int number;
    private final TaskManager manager;

    public SearchNumberTask(int[] numbers, int start, int end, int number, TaskManager manager) {
        this.numbers = numbers;
        this.start = start;
        this.end = end;
        this.number = number;
        this.manager = manager;
    }

    @Override
    protected Integer compute() {
        System.out.println("Task: " + start + ":" + end);
        return (end - start > THRESHOLD) ? launchTasks() : lookForNumber();
    }

    /** Scans this task's range sequentially, pausing one second per non-matching element. */
    private int lookForNumber() {
        for (int i = start; i < end; i++) {
            // cancel() does not stop a running task, so check for cancellation cooperatively.
            if (isCancelled()) {
                return NOT_FOUND;
            }
            if (numbers[i] == number) {
                System.out.printf("Task: Number %d found in position %d%n", number, i);
                manager.cancelTasks(this);
                return i;
            }
            try {
                TimeUnit.SECONDS.sleep(1);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return NOT_FOUND;
            }
        }
        return NOT_FOUND;
    }

    /** Splits the range, forks both halves and returns the first match found. */
    private int launchTasks() {
        int mid = (start + end) >>> 1;
        SearchNumberTask task1 = new SearchNumberTask(numbers, start, mid, number, manager);
        SearchNumberTask task2 = new SearchNumberTask(numbers, mid, end, number, manager);
        manager.addTask(task1);
        manager.addTask(task2);
        task1.fork();
        task2.fork();
        int returnValue = resultOf(task1);
        if (returnValue != NOT_FOUND) {
            return returnValue;
        }
        return resultOf(task2);
    }

    /** Joins {@code task}, treating a cancelled task as "not found" instead of failing. */
    private static int resultOf(SearchNumberTask task) {
        try {
            return task.join();
        } catch (java.util.concurrent.CancellationException e) {
            return NOT_FOUND;
        }
    }

    /** True when this task's range contains {@code other}'s range. */
    boolean covers(SearchNumberTask other) {
        return start <= other.start && other.end <= end;
    }

    public void writeCancelMessage() {
        System.out.printf("Task: Canceled task from %d to %d%n", start, end);
    }
}