Path: blob/master/src/java.base/share/classes/java/util/ArrayPrefixHelpers.java
41152 views
/*
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation. Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

/*
 * This file is available under and governed by the GNU General Public
 * License version 2 only, as published by the Free Software Foundation.
 * However, the following notice accompanied the original version of this
 * file:
 *
 * Written by Doug Lea with assistance from members of JCP JSR-166
 * Expert Group and released to the public domain, as explained at
 * http://creativecommons.org/publicdomain/zero/1.0/
 */

package java.util;

import java.util.concurrent.CountedCompleter;
import java.util.concurrent.ForkJoinPool;
import java.util.function.BinaryOperator;
import java.util.function.DoubleBinaryOperator;
import java.util.function.IntBinaryOperator;
import java.util.function.LongBinaryOperator;

/**
 * ForkJoin tasks to perform Arrays.parallelPrefix operations.
 *
 * <p>The class is a non-instantiable holder for four nested
 * {@link CountedCompleter} task types (object, {@code long},
 * {@code double} and {@code int} arrays) plus the phase-bit constants
 * they share.
 *
 * @author Doug Lea
 * @since 1.8
 */
class ArrayPrefixHelpers {
    private ArrayPrefixHelpers() {} // non-instantiable

    /*
     * Parallel prefix (aka cumulate, scan) task classes
     * are based loosely on Guy Blelloch's original
     * algorithm (http://www.cs.cmu.edu/~scandal/alg/scan.html):
     *  Keep dividing by two to threshold segment size, and then:
     *   Pass 1: Create tree of partial sums for each segment
     *   Pass 2: For each segment, cumulate with offset of left sibling
     *
     * This version improves performance within FJ framework mainly by
     * allowing the second pass of ready left-hand sides to proceed
     * even if some right-hand side first passes are still executing.
     * It also combines first and second pass for leftmost segment,
     * and skips the first pass for rightmost segment (whose result is
     * not needed for second pass).  It similarly manages to avoid
     * requiring that users supply an identity basis for accumulations
     * by tracking those segments/subtasks for which the first
     * existing element is used as base.
     *
     * Managing this relies on ORing some bits in the pendingCount for
     * phases/states: CUMULATE, SUMMED, and FINISHED. CUMULATE is the
     * main phase bit. When false, segments compute only their sum.
     * When true, they cumulate array elements. CUMULATE is set at
     * root at beginning of second pass and then propagated down. But
     * it may also be set earlier for subtrees with lo==origin (the left
     * spine of tree). SUMMED is a one bit join count. For leaves, it
     * is set when summed. For internal nodes, it becomes true when
     * one child is summed.  When the second child finishes summing,
     * it then moves up tree to trigger the cumulate phase. FINISHED
     * is also a one bit join count. For leaves, it is set when
     * cumulated. For internal nodes, it becomes true when one child
     * is cumulated.  When the second child finishes cumulating, it
     * then moves up tree, completing at the root.
     *
     * To better exploit locality and reduce overhead, the compute
     * method loops starting with the current task, moving if possible
     * to one of its subtasks rather than forking.
     *
     * As usual for this sort of utility, there are 4 versions, that
     * are simple copy/paste/adapt variants of each other.  (The
     * double and int versions differ from long version solely by
     * replacing "long" (with case-matching)).
     */

    // see above
    /** Phase bit: when set in a task's pending count, the task cumulates (writes) elements. */
    static final int CUMULATE = 1;
    /** One-bit join count: set on a leaf once summed, on an internal node once one child is summed. */
    static final int SUMMED   = 2;
    /** One-bit join count: set on a leaf once cumulated, on an internal node once one child is cumulated. */
    static final int FINISHED = 4;

    /** The smallest subtask array partition size to use as threshold */
    static final int MIN_PARTITION = 16;

    /**
     * Prefix task over an object array, combining elements with a
     * {@link BinaryOperator}.
     *
     * <p>Each task covers the half-open index range {@code [lo, hi)} of
     * {@code array}; the whole operation covers {@code [origin, fence)}.
     * Tasks whose range is longer than {@code threshold} split into a
     * {@code left} and a {@code right} child; the others act as leaves.
     * Phase information is kept in the inherited pending count using the
     * {@code CUMULATE}, {@code SUMMED} and {@code FINISHED} bits.
     *
     * @param <T> the array element type
     */
    static final class CumulateTask<T> extends CountedCompleter<Void> {
        /** The array whose elements are read and overwritten in place. */
        @SuppressWarnings("serial") // Not statically typed as Serializable
        final T[] array;
        /** The operator used to combine two values. */
        @SuppressWarnings("serial") // Not statically typed as Serializable
        final BinaryOperator<T> function;
        /** Child tasks; both are created together on the first visit of an internal node. */
        CumulateTask<T> left, right;
        /** Incoming prefix value, assigned by the parent before this task cumulates. */
        @SuppressWarnings("serial") // Not statically typed as Serializable
        T in;
        /** Result recorded for this segment (leaf sum, or the combination of the children's results). */
        @SuppressWarnings("serial") // Not statically typed as Serializable
        T out;
        /**
         * {@code lo}/{@code hi}: this task's segment; {@code origin}/{@code fence}:
         * the range of the whole operation; {@code threshold}: segments longer
         * than this are split.
         */
        final int lo, hi, origin, fence, threshold;

        /**
         * Root task constructor.
         *
         * <p>The root's segment is the whole range, so {@code lo}/{@code hi}
         * also become {@code origin}/{@code fence}. The threshold is
         * {@code (hi - lo) / (common pool parallelism * 8)}, but never less
         * than {@code MIN_PARTITION}.
         *
         * @param parent   completer passed to the superclass constructor
         * @param function operator used to combine elements
         * @param array    the array to process
         * @param lo       first index (inclusive)
         * @param hi       last index (exclusive)
         */
        public CumulateTask(CumulateTask<T> parent,
                            BinaryOperator<T> function,
                            T[] array, int lo, int hi) {
            super(parent);
            this.function = function; this.array = array;
            this.lo = this.origin = lo; this.hi = this.fence = hi;
            int p;
            this.threshold =
                (p = (hi - lo) / (ForkJoinPool.getCommonPoolParallelism() << 3))
                <= MIN_PARTITION ? MIN_PARTITION : p;
        }

        /**
         * Subtask constructor.
         *
         * <p>Copies the operation-wide settings ({@code origin},
         * {@code fence}, {@code threshold}) and assigns the subtask its own
         * segment {@code [lo, hi)}.
         */
        CumulateTask(CumulateTask<T> parent, BinaryOperator<T> function,
                     T[] array, int origin, int fence, int threshold,
                     int lo, int hi) {
            super(parent);
            this.function = function; this.array = array;
            this.origin = origin; this.fence = fence;
            this.threshold = threshold;
            this.lo = lo; this.hi = hi;
        }

        /**
         * Runs the task, looping over the current task {@code t} instead of
         * recursing: an internal node either creates its children (first
         * visit) or pushes the cumulate phase down to them (later visit); a
         * leaf sums and/or cumulates its segment and then propagates its
         * state up through its completers.
         *
         * @throws NullPointerException if {@code function} or {@code array}
         *         is null
         */
        public final void compute() {
            final BinaryOperator<T> fn;
            final T[] a;
            if ((fn = this.function) == null || (a = this.array) == null)
                throw new NullPointerException();    // hoist checks
            int th = threshold, org = origin, fnc = fence, l, h;
            CumulateTask<T> t = this;
            // The loop re-reads lo/hi of the current task and stops if they
            // fall outside the array.
            // NOTE(review): presumably this also helps hoist array bounds
            // checks out of the inner loops -- confirm.
            outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) {
                if (h - l > th) {                    // internal node
                    CumulateTask<T> lt = t.left, rt = t.right, f;
                    if (lt == null) {                // first pass
                        // Split at the midpoint; the right child is forked
                        // below and this thread continues with the left one.
                        int mid = (l + h) >>> 1;
                        f = rt = t.right =
                            new CumulateTask<T>(t, fn, a, org, fnc, th, mid, h);
                        t = lt = t.left =
                            new CumulateTask<T>(t, fn, a, org, fnc, th, l, mid);
                    }
                    else {                           // possibly refork
                        // Hand the incoming prefix to the children: the left
                        // child gets this node's value unchanged.
                        T pin = t.in;
                        lt.in = pin;
                        f = t = null;
                        if (rt != null) {
                            // The right child's prefix additionally includes
                            // the left child's result; a node starting at the
                            // origin has no incoming prefix of its own.
                            T lout = lt.out;
                            rt.in = (l == org ? lout :
                                     fn.apply(pin, lout));
                            // Set CUMULATE on the right child; only the
                            // thread whose CAS sets the bit takes the child.
                            for (int c;;) {
                                if (((c = rt.getPendingCount()) & CUMULATE) != 0)
                                    break;
                                if (rt.compareAndSetPendingCount(c, c|CUMULATE)){
                                    t = rt;
                                    break;
                                }
                            }
                        }
                        // Same for the left child. If both were claimed, the
                        // right one is forked and the left one is run here.
                        for (int c;;) {
                            if (((c = lt.getPendingCount()) & CUMULATE) != 0)
                                break;
                            if (lt.compareAndSetPendingCount(c, c|CUMULATE)) {
                                if (t != null)
                                    f = t;
                                t = lt;
                                break;
                            }
                        }
                        if (t == null)               // neither child claimed
                            break;
                    }
                    if (f != null)
                        f.fork();
                }
                else {                               // leaf
                    int state; // Transition to sum, cumulate, or both
                    for (int b;;) {
                        if (((b = t.getPendingCount()) & FINISHED) != 0)
                            break outer;                      // already done
                        // CUMULATE already set: cumulate only (FINISHED).
                        // Otherwise a segment right of the origin is only
                        // summed, while the segment at the origin is summed
                        // and cumulated in a single pass.
                        state = ((b & CUMULATE) != 0 ? FINISHED :
                                 (l > org) ? SUMMED : (SUMMED|FINISHED));
                        if (t.compareAndSetPendingCount(b, b|state))
                            break;
                    }

                    T sum;
                    if (state != SUMMED) {
                        int first;
                        if (l == org) {                       // leftmost; no in
                            // The first element itself is the base value.
                            sum = a[org];
                            first = org + 1;
                        }
                        else {
                            sum = t.in;
                            first = l;
                        }
                        for (int i = first; i < h; ++i)       // cumulate
                            a[i] = sum = fn.apply(sum, a[i]);
                    }
                    else if (h < fnc) {                       // skip rightmost
                        sum = a[l];
                        for (int i = l + 1; i < h; ++i)       // sum only
                            sum = fn.apply(sum, a[i]);
                    }
                    else
                        // Rightmost segment in the summing phase: nothing is
                        // summed; the parent ignores rt.out when rt.hi == fence.
                        sum = t.in;
                    t.out = sum;
                    for (CumulateTask<T> par;;) {             // propagate
                        @SuppressWarnings("unchecked") CumulateTask<T> partmp
                            = (CumulateTask<T>)t.getCompleter();
                        if ((par = partmp) == null) {
                            // t has no completer, i.e. it is the top task.
                            if ((state & FINISHED) != 0)      // enable join
                                t.quietlyComplete();
                            break outer;
                        }
                        int b = par.getPendingCount();
                        if ((b & state & FINISHED) != 0)
                            t = par;                          // both done
                        else if ((b & state & SUMMED) != 0) { // both summed
                            int nextState; CumulateTask<T> lt, rt;
                            if ((lt = par.left) != null &&
                                (rt = par.right) != null) {
                                // Parent result combines both children,
                                // except that a rightmost right child
                                // contributes nothing (it was never summed).
                                T lout = lt.out;
                                par.out = (rt.hi == fnc ? lout :
                                           fn.apply(lout, rt.out));
                            }
                            // A parent starting at the origin that is not yet
                            // cumulating can begin its cumulate phase now.
                            int refork = (((b & CUMULATE) == 0 &&
                                           par.lo == org) ? CUMULATE : 0);
                            if ((nextState = b|state|refork) == b ||
                                par.compareAndSetPendingCount(b, nextState)) {
                                state = SUMMED;               // drop finished
                                t = par;
                                if (refork != 0)
                                    par.fork();
                            }
                            // On CAS failure the loop re-reads the count and retries.
                        }
                        else if (par.compareAndSetPendingCount(b, b|state))
                            break outer;                      // sib not ready
                    }
                }
            }
        }
        @java.io.Serial
        private static final long serialVersionUID = 5293554502939613543L;
    }

    /**
     * Prefix task over a {@code long} array, combining elements with a
     * {@link LongBinaryOperator}. Same structure and phase handling as
     * {@code CumulateTask}, specialised to primitive {@code long}.
     */
    static final class LongCumulateTask extends CountedCompleter<Void> {
        /** The array whose elements are read and overwritten in place. */
        final long[] array;
        /** The operator used to combine two values. */
        @SuppressWarnings("serial") // Not statically typed as Serializable
        final LongBinaryOperator function;
        /** Child tasks; both are created together on the first visit of an internal node. */
        LongCumulateTask left, right;
        /** {@code in}: incoming prefix set by the parent; {@code out}: result recorded for this segment. */
        long in, out;
        /** Segment bounds, overall range bounds, and the split threshold. */
        final int lo, hi, origin, fence, threshold;

        /**
         * Root task constructor; the segment is the whole range and the
         * threshold is {@code (hi - lo) / (common pool parallelism * 8)},
         * but never less than {@code MIN_PARTITION}.
         */
        public LongCumulateTask(LongCumulateTask parent,
                                LongBinaryOperator function,
                                long[] array, int lo, int hi) {
            super(parent);
            this.function = function; this.array = array;
            this.lo = this.origin = lo; this.hi = this.fence = hi;
            int p;
            this.threshold =
                (p = (hi - lo) / (ForkJoinPool.getCommonPoolParallelism() << 3))
                <= MIN_PARTITION ? MIN_PARTITION : p;
        }

        /** Subtask constructor; copies the operation-wide settings and takes its own segment. */
        LongCumulateTask(LongCumulateTask parent, LongBinaryOperator function,
                         long[] array, int origin, int fence, int threshold,
                         int lo, int hi) {
            super(parent);
            this.function = function; this.array = array;
            this.origin = origin; this.fence = fence;
            this.threshold = threshold;
            this.lo = lo; this.hi = hi;
        }

        /**
         * Runs the task: internal nodes split or push the cumulate phase
         * down to their children, leaves sum and/or cumulate their segment
         * and then propagate their state up through their completers.
         *
         * @throws NullPointerException if {@code function} or {@code array}
         *         is null
         */
        public final void compute() {
            final LongBinaryOperator fn;
            final long[] a;
            if ((fn = this.function) == null || (a = this.array) == null)
                throw new NullPointerException();    // hoist checks
            int th = threshold, org = origin, fnc = fence, l, h;
            LongCumulateTask t = this;
            // Stops if the current task's bounds fall outside the array.
            outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) {
                if (h - l > th) {                    // internal node
                    LongCumulateTask lt = t.left, rt = t.right, f;
                    if (lt == null) {                // first pass
                        // Fork the right half, continue here with the left half.
                        int mid = (l + h) >>> 1;
                        f = rt = t.right =
                            new LongCumulateTask(t, fn, a, org, fnc, th, mid, h);
                        t = lt = t.left =
                            new LongCumulateTask(t, fn, a, org, fnc, th, l, mid);
                    }
                    else {                           // possibly refork
                        long pin = t.in;
                        lt.in = pin;
                        f = t = null;
                        if (rt != null) {
                            // Right child's prefix includes the left child's result.
                            long lout = lt.out;
                            rt.in = (l == org ? lout :
                                     fn.applyAsLong(pin, lout));
                            // Claim the right child by setting CUMULATE via CAS.
                            for (int c;;) {
                                if (((c = rt.getPendingCount()) & CUMULATE) != 0)
                                    break;
                                if (rt.compareAndSetPendingCount(c, c|CUMULATE)){
                                    t = rt;
                                    break;
                                }
                            }
                        }
                        // Claim the left child; if both were claimed, fork the right.
                        for (int c;;) {
                            if (((c = lt.getPendingCount()) & CUMULATE) != 0)
                                break;
                            if (lt.compareAndSetPendingCount(c, c|CUMULATE)) {
                                if (t != null)
                                    f = t;
                                t = lt;
                                break;
                            }
                        }
                        if (t == null)               // neither child claimed
                            break;
                    }
                    if (f != null)
                        f.fork();
                }
                else {                               // leaf
                    int state; // Transition to sum, cumulate, or both
                    for (int b;;) {
                        if (((b = t.getPendingCount()) & FINISHED) != 0)
                            break outer;                      // already done
                        // Cumulate only, sum only, or (at the origin) both at once.
                        state = ((b & CUMULATE) != 0 ? FINISHED :
                                 (l > org) ? SUMMED : (SUMMED|FINISHED));
                        if (t.compareAndSetPendingCount(b, b|state))
                            break;
                    }

                    long sum;
                    if (state != SUMMED) {
                        int first;
                        if (l == org) {                       // leftmost; no in
                            sum = a[org];
                            first = org + 1;
                        }
                        else {
                            sum = t.in;
                            first = l;
                        }
                        for (int i = first; i < h; ++i)       // cumulate
                            a[i] = sum = fn.applyAsLong(sum, a[i]);
                    }
                    else if (h < fnc) {                       // skip rightmost
                        sum = a[l];
                        for (int i = l + 1; i < h; ++i)       // sum only
                            sum = fn.applyAsLong(sum, a[i]);
                    }
                    else
                        // Rightmost segment is not summed; the parent ignores its out.
                        sum = t.in;
                    t.out = sum;
                    for (LongCumulateTask par;;) {            // propagate
                        if ((par = (LongCumulateTask)t.getCompleter()) == null) {
                            if ((state & FINISHED) != 0)      // enable join
                                t.quietlyComplete();
                            break outer;
                        }
                        int b = par.getPendingCount();
                        if ((b & state & FINISHED) != 0)
                            t = par;                          // both done
                        else if ((b & state & SUMMED) != 0) { // both summed
                            int nextState; LongCumulateTask lt, rt;
                            if ((lt = par.left) != null &&
                                (rt = par.right) != null) {
                                long lout = lt.out;
                                par.out = (rt.hi == fnc ? lout :
                                           fn.applyAsLong(lout, rt.out));
                            }
                            // A parent at the origin may start cumulating right away.
                            int refork = (((b & CUMULATE) == 0 &&
                                           par.lo == org) ? CUMULATE : 0);
                            if ((nextState = b|state|refork) == b ||
                                par.compareAndSetPendingCount(b, nextState)) {
                                state = SUMMED;               // drop finished
                                t = par;
                                if (refork != 0)
                                    par.fork();
                            }
                        }
                        else if (par.compareAndSetPendingCount(b, b|state))
                            break outer;                      // sib not ready
                    }
                }
            }
        }
        @java.io.Serial
        private static final long serialVersionUID = -5074099945909284273L;
    }

    /**
     * Prefix task over a {@code double} array, combining elements with a
     * {@link DoubleBinaryOperator}. Same structure and phase handling as
     * {@code CumulateTask}, specialised to primitive {@code double}.
     */
    static final class DoubleCumulateTask extends CountedCompleter<Void> {
        /** The array whose elements are read and overwritten in place. */
        final double[] array;
        /** The operator used to combine two values. */
        @SuppressWarnings("serial") // Not statically typed as Serializable
        final DoubleBinaryOperator function;
        /** Child tasks; both are created together on the first visit of an internal node. */
        DoubleCumulateTask left, right;
        /** {@code in}: incoming prefix set by the parent; {@code out}: result recorded for this segment. */
        double in, out;
        /** Segment bounds, overall range bounds, and the split threshold. */
        final int lo, hi, origin, fence, threshold;

        /**
         * Root task constructor; the segment is the whole range and the
         * threshold is {@code (hi - lo) / (common pool parallelism * 8)},
         * but never less than {@code MIN_PARTITION}.
         */
        public DoubleCumulateTask(DoubleCumulateTask parent,
                                  DoubleBinaryOperator function,
                                  double[] array, int lo, int hi) {
            super(parent);
            this.function = function; this.array = array;
            this.lo = this.origin = lo; this.hi = this.fence = hi;
            int p;
            this.threshold =
                (p = (hi - lo) / (ForkJoinPool.getCommonPoolParallelism() << 3))
                <= MIN_PARTITION ? MIN_PARTITION : p;
        }

        /** Subtask constructor; copies the operation-wide settings and takes its own segment. */
        DoubleCumulateTask(DoubleCumulateTask parent, DoubleBinaryOperator function,
                           double[] array, int origin, int fence, int threshold,
                           int lo, int hi) {
            super(parent);
            this.function = function; this.array = array;
            this.origin = origin; this.fence = fence;
            this.threshold = threshold;
            this.lo = lo; this.hi = hi;
        }

        /**
         * Runs the task: internal nodes split or push the cumulate phase
         * down to their children, leaves sum and/or cumulate their segment
         * and then propagate their state up through their completers.
         *
         * @throws NullPointerException if {@code function} or {@code array}
         *         is null
         */
        public final void compute() {
            final DoubleBinaryOperator fn;
            final double[] a;
            if ((fn = this.function) == null || (a = this.array) == null)
                throw new NullPointerException();    // hoist checks
            int th = threshold, org = origin, fnc = fence, l, h;
            DoubleCumulateTask t = this;
            // Stops if the current task's bounds fall outside the array.
            outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) {
                if (h - l > th) {                    // internal node
                    DoubleCumulateTask lt = t.left, rt = t.right, f;
                    if (lt == null) {                // first pass
                        // Fork the right half, continue here with the left half.
                        int mid = (l + h) >>> 1;
                        f = rt = t.right =
                            new DoubleCumulateTask(t, fn, a, org, fnc, th, mid, h);
                        t = lt = t.left =
                            new DoubleCumulateTask(t, fn, a, org, fnc, th, l, mid);
                    }
                    else {                           // possibly refork
                        double pin = t.in;
                        lt.in = pin;
                        f = t = null;
                        if (rt != null) {
                            // Right child's prefix includes the left child's result.
                            double lout = lt.out;
                            rt.in = (l == org ? lout :
                                     fn.applyAsDouble(pin, lout));
                            // Claim the right child by setting CUMULATE via CAS.
                            for (int c;;) {
                                if (((c = rt.getPendingCount()) & CUMULATE) != 0)
                                    break;
                                if (rt.compareAndSetPendingCount(c, c|CUMULATE)){
                                    t = rt;
                                    break;
                                }
                            }
                        }
                        // Claim the left child; if both were claimed, fork the right.
                        for (int c;;) {
                            if (((c = lt.getPendingCount()) & CUMULATE) != 0)
                                break;
                            if (lt.compareAndSetPendingCount(c, c|CUMULATE)) {
                                if (t != null)
                                    f = t;
                                t = lt;
                                break;
                            }
                        }
                        if (t == null)               // neither child claimed
                            break;
                    }
                    if (f != null)
                        f.fork();
                }
                else {                               // leaf
                    int state; // Transition to sum, cumulate, or both
                    for (int b;;) {
                        if (((b = t.getPendingCount()) & FINISHED) != 0)
                            break outer;                      // already done
                        // Cumulate only, sum only, or (at the origin) both at once.
                        state = ((b & CUMULATE) != 0 ? FINISHED :
                                 (l > org) ? SUMMED : (SUMMED|FINISHED));
                        if (t.compareAndSetPendingCount(b, b|state))
                            break;
                    }

                    double sum;
                    if (state != SUMMED) {
                        int first;
                        if (l == org) {                       // leftmost; no in
                            sum = a[org];
                            first = org + 1;
                        }
                        else {
                            sum = t.in;
                            first = l;
                        }
                        for (int i = first; i < h; ++i)       // cumulate
                            a[i] = sum = fn.applyAsDouble(sum, a[i]);
                    }
                    else if (h < fnc) {                       // skip rightmost
                        sum = a[l];
                        for (int i = l + 1; i < h; ++i)       // sum only
                            sum = fn.applyAsDouble(sum, a[i]);
                    }
                    else
                        // Rightmost segment is not summed; the parent ignores its out.
                        sum = t.in;
                    t.out = sum;
                    for (DoubleCumulateTask par;;) {          // propagate
                        if ((par = (DoubleCumulateTask)t.getCompleter()) == null) {
                            if ((state & FINISHED) != 0)      // enable join
                                t.quietlyComplete();
                            break outer;
                        }
                        int b = par.getPendingCount();
                        if ((b & state & FINISHED) != 0)
                            t = par;                          // both done
                        else if ((b & state & SUMMED) != 0) { // both summed
                            int nextState; DoubleCumulateTask lt, rt;
                            if ((lt = par.left) != null &&
                                (rt = par.right) != null) {
                                double lout = lt.out;
                                par.out = (rt.hi == fnc ? lout :
                                           fn.applyAsDouble(lout, rt.out));
                            }
                            // A parent at the origin may start cumulating right away.
                            int refork = (((b & CUMULATE) == 0 &&
                                           par.lo == org) ? CUMULATE : 0);
                            if ((nextState = b|state|refork) == b ||
                                par.compareAndSetPendingCount(b, nextState)) {
                                state = SUMMED;               // drop finished
                                t = par;
                                if (refork != 0)
                                    par.fork();
                            }
                        }
                        else if (par.compareAndSetPendingCount(b, b|state))
                            break outer;                      // sib not ready
                    }
                }
            }
        }
        @java.io.Serial
        private static final long serialVersionUID = -586947823794232033L;
    }

    /**
     * Prefix task over an {@code int} array, combining elements with an
     * {@link IntBinaryOperator}. Same structure and phase handling as
     * {@code CumulateTask}, specialised to primitive {@code int}.
     */
    static final class IntCumulateTask extends CountedCompleter<Void> {
        /** The array whose elements are read and overwritten in place. */
        final int[] array;
        /** The operator used to combine two values. */
        @SuppressWarnings("serial") // Not statically typed as Serializable
        final IntBinaryOperator function;
        /** Child tasks; both are created together on the first visit of an internal node. */
        IntCumulateTask left, right;
        /** {@code in}: incoming prefix set by the parent; {@code out}: result recorded for this segment. */
        int in, out;
        /** Segment bounds, overall range bounds, and the split threshold. */
        final int lo, hi, origin, fence, threshold;

        /**
         * Root task constructor; the segment is the whole range and the
         * threshold is {@code (hi - lo) / (common pool parallelism * 8)},
         * but never less than {@code MIN_PARTITION}.
         */
        public IntCumulateTask(IntCumulateTask parent,
                               IntBinaryOperator function,
                               int[] array, int lo, int hi) {
            super(parent);
            this.function = function; this.array = array;
            this.lo = this.origin = lo; this.hi = this.fence = hi;
            int p;
            this.threshold =
                (p = (hi - lo) / (ForkJoinPool.getCommonPoolParallelism() << 3))
                <= MIN_PARTITION ? MIN_PARTITION : p;
        }

        /** Subtask constructor; copies the operation-wide settings and takes its own segment. */
        IntCumulateTask(IntCumulateTask parent, IntBinaryOperator function,
                        int[] array, int origin, int fence, int threshold,
                        int lo, int hi) {
            super(parent);
            this.function = function; this.array = array;
            this.origin = origin; this.fence = fence;
            this.threshold = threshold;
            this.lo = lo; this.hi = hi;
        }

        /**
         * Runs the task: internal nodes split or push the cumulate phase
         * down to their children, leaves sum and/or cumulate their segment
         * and then propagate their state up through their completers.
         *
         * @throws NullPointerException if {@code function} or {@code array}
         *         is null
         */
        public final void compute() {
            final IntBinaryOperator fn;
            final int[] a;
            if ((fn = this.function) == null || (a = this.array) == null)
                throw new NullPointerException();    // hoist checks
            int th = threshold, org = origin, fnc = fence, l, h;
            IntCumulateTask t = this;
            // Stops if the current task's bounds fall outside the array.
            outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) {
                if (h - l > th) {                    // internal node
                    IntCumulateTask lt = t.left, rt = t.right, f;
                    if (lt == null) {                // first pass
                        // Fork the right half, continue here with the left half.
                        int mid = (l + h) >>> 1;
                        f = rt = t.right =
                            new IntCumulateTask(t, fn, a, org, fnc, th, mid, h);
                        t = lt = t.left =
                            new IntCumulateTask(t, fn, a, org, fnc, th, l, mid);
                    }
                    else {                           // possibly refork
                        int pin = t.in;
                        lt.in = pin;
                        f = t = null;
                        if (rt != null) {
                            // Right child's prefix includes the left child's result.
                            int lout = lt.out;
                            rt.in = (l == org ? lout :
                                     fn.applyAsInt(pin, lout));
                            // Claim the right child by setting CUMULATE via CAS.
                            for (int c;;) {
                                if (((c = rt.getPendingCount()) & CUMULATE) != 0)
                                    break;
                                if (rt.compareAndSetPendingCount(c, c|CUMULATE)){
                                    t = rt;
                                    break;
                                }
                            }
                        }
                        // Claim the left child; if both were claimed, fork the right.
                        for (int c;;) {
                            if (((c = lt.getPendingCount()) & CUMULATE) != 0)
                                break;
                            if (lt.compareAndSetPendingCount(c, c|CUMULATE)) {
                                if (t != null)
                                    f = t;
                                t = lt;
                                break;
                            }
                        }
                        if (t == null)               // neither child claimed
                            break;
                    }
                    if (f != null)
                        f.fork();
                }
                else {                               // leaf
                    int state; // Transition to sum, cumulate, or both
                    for (int b;;) {
                        if (((b = t.getPendingCount()) & FINISHED) != 0)
                            break outer;                      // already done
                        // Cumulate only, sum only, or (at the origin) both at once.
                        state = ((b & CUMULATE) != 0 ? FINISHED :
                                 (l > org) ? SUMMED : (SUMMED|FINISHED));
                        if (t.compareAndSetPendingCount(b, b|state))
                            break;
                    }

                    int sum;
                    if (state != SUMMED) {
                        int first;
                        if (l == org) {                       // leftmost; no in
                            sum = a[org];
                            first = org + 1;
                        }
                        else {
                            sum = t.in;
                            first = l;
                        }
                        for (int i = first; i < h; ++i)       // cumulate
                            a[i] = sum = fn.applyAsInt(sum, a[i]);
                    }
                    else if (h < fnc) {                       // skip rightmost
                        sum = a[l];
                        for (int i = l + 1; i < h; ++i)       // sum only
                            sum = fn.applyAsInt(sum, a[i]);
                    }
                    else
                        // Rightmost segment is not summed; the parent ignores its out.
                        sum = t.in;
                    t.out = sum;
                    for (IntCumulateTask par;;) {             // propagate
                        if ((par = (IntCumulateTask)t.getCompleter()) == null) {
                            if ((state & FINISHED) != 0)      // enable join
                                t.quietlyComplete();
                            break outer;
                        }
                        int b = par.getPendingCount();
                        if ((b & state & FINISHED) != 0)
                            t = par;                          // both done
                        else if ((b & state & SUMMED) != 0) { // both summed
                            int nextState; IntCumulateTask lt, rt;
                            if ((lt = par.left) != null &&
                                (rt = par.right) != null) {
                                int lout = lt.out;
                                par.out = (rt.hi == fnc ? lout :
                                           fn.applyAsInt(lout, rt.out));
                            }
                            // A parent at the origin may start cumulating right away.
                            int refork = (((b & CUMULATE) == 0 &&
                                           par.lo == org) ? CUMULATE : 0);
                            if ((nextState = b|state|refork) == b ||
                                par.compareAndSetPendingCount(b, nextState)) {
                                state = SUMMED;               // drop finished
                                t = par;
                                if (refork != 0)
                                    par.fork();
                            }
                        }
                        else if (par.compareAndSetPendingCount(b, b|state))
                            break outer;                      // sib not ready
                    }
                }
            }
        }
        @java.io.Serial
        private static final long serialVersionUID = 3731755594596840961L;
    }
}