/*
* Copyright (c) 2012, 2013, Oracle and/or its affiliates. All rights reserved.
* ORACLE PROPRIETARY/CONFIDENTIAL. Use is subject to license terms.
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*
*/
package java.util;
/*
* Written by Doug Lea with assistance from members of JCP JSR-166
* Expert Group and released to the public domain, as explained at
* http://creativecommons.org/publicdomain/zero/1.0/
*/
import java.util.concurrent.
ForkJoinPool;
import java.util.concurrent.
CountedCompleter;
import java.util.function.
BinaryOperator;
import java.util.function.
IntBinaryOperator;
import java.util.function.
LongBinaryOperator;
import java.util.function.
DoubleBinaryOperator;
/**
* ForkJoin tasks to perform Arrays.parallelPrefix operations.
*
* @author Doug Lea
* @since 1.8
*/
class
ArrayPrefixHelpers {
/** Not instantiable: this class only hosts the nested task classes and shared constants. */
private ArrayPrefixHelpers() {}
/*
* Parallel prefix (aka cumulate, scan) task classes
* are based loosely on Guy Blelloch's original
* algorithm (http://www.cs.cmu.edu/~scandal/alg/scan.html):
* Keep dividing by two to threshold segment size, and then:
* Pass 1: Create tree of partial sums for each segment
* Pass 2: For each segment, cumulate with offset of left sibling
*
* This version improves performance within FJ framework mainly by
* allowing the second pass of ready left-hand sides to proceed
* even if some right-hand side first passes are still executing.
* It also combines first and second pass for leftmost segment,
* and skips the first pass for rightmost segment (whose result is
* not needed for second pass). It similarly manages to avoid
* requiring that users supply an identity basis for accumulations
* by tracking those segments/subtasks for which the first
* existing element is used as base.
*
* Managing this relies on ORing some bits in the pendingCount for
* phases/states: CUMULATE, SUMMED, and FINISHED. CUMULATE is the
* main phase bit. When false, segments compute only their sum.
* When true, they cumulate array elements. CUMULATE is set at
* root at beginning of second pass and then propagated down. But
* it may also be set earlier for subtrees with lo==0 (the left
* spine of tree). SUMMED is a one bit join count. For leafs, it
* is set when summed. For internal nodes, it becomes true when
* one child is summed. When the second child finishes summing,
* it then moves up tree to trigger the cumulate phase. FINISHED
* is also a one bit join count. For leafs, it is set when
* cumulated. For internal nodes, it becomes true when one child
* is cumulated. When the second child finishes cumulating, it
* then moves up tree, completing at the root.
*
* To better exploit locality and reduce overhead, the compute
* method loops starting with the current task, moving if possible
* to one of its subtasks rather than forking.
*
* As usual for this sort of utility, there are 4 versions, that
* are simple copy/paste/adapt variants of each other. (The
* double and int versions differ from long version solely by
* replacing "long" (with case-matching)).
*/
/*
 * Phase/state bits that the task classes OR into their pendingCount.
 */

/** Main phase bit: when clear a segment only computes its sum; when set it cumulates array elements. */
static final int CUMULATE = 1;

/** One-bit join count for summing: set on a leaf once summed, on an internal node once one child is summed. */
static final int SUMMED = 1 << 1;

/** One-bit join count for cumulating: set on a leaf once cumulated, on an internal node once one child is cumulated. */
static final int FINISHED = 1 << 2;

/** The smallest subtask array partition size to use as threshold */
static final int MIN_PARTITION = 16;
static final class CumulateTask<T> extends CountedCompleter<Void> {
    /** Array that is read and, during cumulation, overwritten in place. */
    final T[] array;
    /** Operator used to combine two values. */
    final BinaryOperator<T> function;
    /** Child tasks; both are null until compute() first splits this node. */
    CumulateTask<T> left, right;
    /** in: value carried into this segment by its parent; out: this segment's result. */
    T in, out;
    /** [lo, hi) is this task's segment, [origin, fence) the whole range; threshold is the split cut-off. */
    final int lo, hi, origin, fence, threshold;

    /**
     * Root task constructor. The root's segment is the whole range, so
     * origin/fence are the same as lo/hi. The threshold is the range
     * length divided by eight times the common pool parallelism, but
     * never less than MIN_PARTITION.
     */
    public CumulateTask(CumulateTask<T> parent,
                        BinaryOperator<T> function,
                        T[] array, int lo, int hi) {
        super(parent);
        this.function = function;
        this.array = array;
        this.origin = lo;
        this.lo = lo;
        this.fence = hi;
        this.hi = hi;
        int share = (hi - lo) / (ForkJoinPool.getCommonPoolParallelism() << 3);
        this.threshold = Math.max(share, MIN_PARTITION);
    }

    /**
     * Subtask constructor. Copies the range-wide settings (origin, fence,
     * threshold) handed down by the creating task and records this
     * subtask's own segment [lo, hi).
     */
    CumulateTask(CumulateTask<T> parent, BinaryOperator<T> function,
                 T[] array, int origin, int fence, int threshold,
                 int lo, int hi) {
        super(parent);
        this.function = function;
        this.array = array;
        this.origin = origin;
        this.fence = fence;
        this.threshold = threshold;
        this.lo = lo;
        this.hi = hi;
    }

    /**
     * Tries to set the CUMULATE bit in this task's pending count.
     *
     * @return true if this call set the bit, false if it was already set
     */
    private boolean claimCumulate() {
        for (;;) {
            int pending = getPendingCount();
            if ((pending & CUMULATE) != 0) {
                return false;
            }
            if (compareAndSetPendingCount(pending, pending | CUMULATE)) {
                return true;
            }
        }
    }

    /**
     * Runs this task and then keeps going with whichever related task
     * (a child when descending, the completer when propagating) can be
     * processed next, instead of forking everything.
     *
     * @throws NullPointerException if the function or the array is null
     */
    @SuppressWarnings("unchecked")
    public final void compute() {
        final BinaryOperator<T> op = this.function;
        final T[] arr = this.array;
        if (op == null || arr == null) {
            throw new NullPointerException(); // checked once, before the loop
        }
        final int grain = threshold;
        final int base = origin;
        final int limit = fence;
        int from;
        int to;
        CumulateTask<T> task = this;

        walk:
        while ((from = task.lo) >= 0 && (to = task.hi) <= arr.length) {
            if (to - from > grain) {
                // Internal node: segment is still larger than the threshold.
                final CumulateTask<T> leftChild = task.left;
                final CumulateTask<T> rightChild = task.right;
                CumulateTask<T> toFork;
                if (leftChild == null) {
                    // First pass: split in half, fork the right half and
                    // continue with the left half in this loop.
                    final int mid = (from + to) >>> 1;
                    final CumulateTask<T> newRight =
                        new CumulateTask<>(task, op, arr, base, limit, grain, mid, to);
                    task.right = newRight;
                    toFork = newRight;
                    final CumulateTask<T> newLeft =
                        new CumulateTask<>(task, op, arr, base, limit, grain, from, mid);
                    task.left = newLeft; // left is assigned after right
                    task = newLeft;
                } else {
                    // Later visit: hand the incoming values to the children
                    // and start cumulating in those not yet marked CUMULATE.
                    final T carried = task.in;
                    leftChild.in = carried;
                    toFork = null;
                    task = null;
                    if (rightChild != null) {
                        final T leftOut = leftChild.out;
                        if (from == base) {
                            rightChild.in = leftOut; // segment starts at origin: nothing carried in
                        } else {
                            rightChild.in = op.apply(carried, leftOut);
                        }
                        if (rightChild.claimCumulate()) {
                            task = rightChild;
                        }
                    }
                    if (leftChild.claimCumulate()) {
                        if (task != null) {
                            toFork = task; // right was claimed too: fork it, run left here
                        }
                        task = leftChild;
                    }
                    if (task == null) {
                        break; // both children were already marked
                    }
                }
                if (toFork != null) {
                    toFork.fork();
                }
            } else {
                // Leaf: choose the transition (sum, cumulate, or both) and
                // record it in the pending count.
                int state;
                for (;;) {
                    final int bits = task.getPendingCount();
                    if ((bits & FINISHED) != 0) {
                        break walk; // already done
                    }
                    if ((bits & CUMULATE) != 0) {
                        state = FINISHED;
                    } else if (from > base) {
                        state = SUMMED;
                    } else {
                        state = SUMMED | FINISHED; // starts at origin: both passes at once
                    }
                    if (task.compareAndSetPendingCount(bits, bits | state)) {
                        break;
                    }
                }

                T sum;
                if (state != SUMMED) {
                    // Cumulate: overwrite each element with its running total.
                    int first;
                    if (from == base) {
                        sum = arr[base]; // leftmost segment has no incoming value
                        first = base + 1;
                    } else {
                        sum = task.in;
                        first = from;
                    }
                    for (int i = first; i < to; ++i) {
                        sum = op.apply(sum, arr[i]);
                        arr[i] = sum;
                    }
                } else if (to < limit) {
                    // Sum only; the array is left untouched.
                    sum = arr[from];
                    for (int i = from + 1; i < to; ++i) {
                        sum = op.apply(sum, arr[i]);
                    }
                } else {
                    sum = task.in; // segment ends at fence: summing is skipped
                }
                task.out = sum;

                // Propagate the new state towards the root.
                for (;;) {
                    final CumulateTask<T> above = (CumulateTask<T>) task.getCompleter();
                    if (above == null) {
                        if ((state & FINISHED) != 0) {
                            task.quietlyComplete(); // enable join
                        }
                        break walk;
                    }
                    final int bits = above.getPendingCount();
                    if ((bits & state & FINISHED) != 0) {
                        task = above; // both children done
                    } else if ((bits & state & SUMMED) != 0) {
                        // Both children summed: combine their results.
                        final CumulateTask<T> lc = above.left;
                        if (lc != null) {
                            final CumulateTask<T> rc = above.right;
                            if (rc != null) {
                                final T leftOut = lc.out;
                                if (rc.hi == limit) {
                                    above.out = leftOut;
                                } else {
                                    above.out = op.apply(leftOut, rc.out);
                                }
                            }
                        }
                        final int refork =
                            ((bits & CUMULATE) == 0 && above.lo == base) ? CUMULATE : 0;
                        final int nextState = bits | state | refork;
                        if (nextState == bits
                                || above.compareAndSetPendingCount(bits, nextState)) {
                            state = SUMMED; // drop finished
                            task = above;
                            if (refork != 0) {
                                above.fork();
                            }
                        }
                    } else if (above.compareAndSetPendingCount(bits, bits | state)) {
                        break walk; // sibling not ready
                    }
                }
            }
        }
    }
}
static final class LongCumulateTask extends CountedCompleter<Void> {
    /** Array that is read and, during cumulation, overwritten in place. */
    final long[] array;
    /** Operator used to combine two values. */
    final LongBinaryOperator function;
    /** Child tasks; both are null until compute() first splits this node. */
    LongCumulateTask left, right;
    /** in: value carried into this segment by its parent; out: this segment's result. */
    long in, out;
    /** [lo, hi) is this task's segment, [origin, fence) the whole range; threshold is the split cut-off. */
    final int lo, hi, origin, fence, threshold;

    /**
     * Root task constructor. The root's segment is the whole range, so
     * origin/fence are the same as lo/hi. The threshold is the range
     * length divided by eight times the common pool parallelism, but
     * never less than MIN_PARTITION.
     */
    public LongCumulateTask(LongCumulateTask parent,
                            LongBinaryOperator function,
                            long[] array, int lo, int hi) {
        super(parent);
        this.function = function;
        this.array = array;
        this.origin = lo;
        this.lo = lo;
        this.fence = hi;
        this.hi = hi;
        int share = (hi - lo) / (ForkJoinPool.getCommonPoolParallelism() << 3);
        this.threshold = Math.max(share, MIN_PARTITION);
    }

    /**
     * Subtask constructor. Copies the range-wide settings (origin, fence,
     * threshold) handed down by the creating task and records this
     * subtask's own segment [lo, hi).
     */
    LongCumulateTask(LongCumulateTask parent, LongBinaryOperator function,
                     long[] array, int origin, int fence, int threshold,
                     int lo, int hi) {
        super(parent);
        this.function = function;
        this.array = array;
        this.origin = origin;
        this.fence = fence;
        this.threshold = threshold;
        this.lo = lo;
        this.hi = hi;
    }

    /**
     * Tries to set the CUMULATE bit in this task's pending count.
     *
     * @return true if this call set the bit, false if it was already set
     */
    private boolean claimCumulate() {
        for (;;) {
            int pending = getPendingCount();
            if ((pending & CUMULATE) != 0) {
                return false;
            }
            if (compareAndSetPendingCount(pending, pending | CUMULATE)) {
                return true;
            }
        }
    }

    /**
     * Runs this task and then keeps going with whichever related task
     * (a child when descending, the completer when propagating) can be
     * processed next, instead of forking everything.
     *
     * @throws NullPointerException if the function or the array is null
     */
    public final void compute() {
        final LongBinaryOperator op = this.function;
        final long[] arr = this.array;
        if (op == null || arr == null) {
            throw new NullPointerException(); // checked once, before the loop
        }
        final int grain = threshold;
        final int base = origin;
        final int limit = fence;
        int from;
        int to;
        LongCumulateTask task = this;

        walk:
        while ((from = task.lo) >= 0 && (to = task.hi) <= arr.length) {
            if (to - from > grain) {
                // Internal node: segment is still larger than the threshold.
                final LongCumulateTask leftChild = task.left;
                final LongCumulateTask rightChild = task.right;
                LongCumulateTask toFork;
                if (leftChild == null) {
                    // First pass: split in half, fork the right half and
                    // continue with the left half in this loop.
                    final int mid = (from + to) >>> 1;
                    final LongCumulateTask newRight =
                        new LongCumulateTask(task, op, arr, base, limit, grain, mid, to);
                    task.right = newRight;
                    toFork = newRight;
                    final LongCumulateTask newLeft =
                        new LongCumulateTask(task, op, arr, base, limit, grain, from, mid);
                    task.left = newLeft; // left is assigned after right
                    task = newLeft;
                } else {
                    // Later visit: hand the incoming values to the children
                    // and start cumulating in those not yet marked CUMULATE.
                    final long carried = task.in;
                    leftChild.in = carried;
                    toFork = null;
                    task = null;
                    if (rightChild != null) {
                        final long leftOut = leftChild.out;
                        if (from == base) {
                            rightChild.in = leftOut; // segment starts at origin: nothing carried in
                        } else {
                            rightChild.in = op.applyAsLong(carried, leftOut);
                        }
                        if (rightChild.claimCumulate()) {
                            task = rightChild;
                        }
                    }
                    if (leftChild.claimCumulate()) {
                        if (task != null) {
                            toFork = task; // right was claimed too: fork it, run left here
                        }
                        task = leftChild;
                    }
                    if (task == null) {
                        break; // both children were already marked
                    }
                }
                if (toFork != null) {
                    toFork.fork();
                }
            } else {
                // Leaf: choose the transition (sum, cumulate, or both) and
                // record it in the pending count.
                int state;
                for (;;) {
                    final int bits = task.getPendingCount();
                    if ((bits & FINISHED) != 0) {
                        break walk; // already done
                    }
                    if ((bits & CUMULATE) != 0) {
                        state = FINISHED;
                    } else if (from > base) {
                        state = SUMMED;
                    } else {
                        state = SUMMED | FINISHED; // starts at origin: both passes at once
                    }
                    if (task.compareAndSetPendingCount(bits, bits | state)) {
                        break;
                    }
                }

                long sum;
                if (state != SUMMED) {
                    // Cumulate: overwrite each element with its running total.
                    int first;
                    if (from == base) {
                        sum = arr[base]; // leftmost segment has no incoming value
                        first = base + 1;
                    } else {
                        sum = task.in;
                        first = from;
                    }
                    for (int i = first; i < to; ++i) {
                        sum = op.applyAsLong(sum, arr[i]);
                        arr[i] = sum;
                    }
                } else if (to < limit) {
                    // Sum only; the array is left untouched.
                    sum = arr[from];
                    for (int i = from + 1; i < to; ++i) {
                        sum = op.applyAsLong(sum, arr[i]);
                    }
                } else {
                    sum = task.in; // segment ends at fence: summing is skipped
                }
                task.out = sum;

                // Propagate the new state towards the root.
                for (;;) {
                    final LongCumulateTask above = (LongCumulateTask) task.getCompleter();
                    if (above == null) {
                        if ((state & FINISHED) != 0) {
                            task.quietlyComplete(); // enable join
                        }
                        break walk;
                    }
                    final int bits = above.getPendingCount();
                    if ((bits & state & FINISHED) != 0) {
                        task = above; // both children done
                    } else if ((bits & state & SUMMED) != 0) {
                        // Both children summed: combine their results.
                        final LongCumulateTask lc = above.left;
                        if (lc != null) {
                            final LongCumulateTask rc = above.right;
                            if (rc != null) {
                                final long leftOut = lc.out;
                                if (rc.hi == limit) {
                                    above.out = leftOut;
                                } else {
                                    above.out = op.applyAsLong(leftOut, rc.out);
                                }
                            }
                        }
                        final int refork =
                            ((bits & CUMULATE) == 0 && above.lo == base) ? CUMULATE : 0;
                        final int nextState = bits | state | refork;
                        if (nextState == bits
                                || above.compareAndSetPendingCount(bits, nextState)) {
                            state = SUMMED; // drop finished
                            task = above;
                            if (refork != 0) {
                                above.fork();
                            }
                        }
                    } else if (above.compareAndSetPendingCount(bits, bits | state)) {
                        break walk; // sibling not ready
                    }
                }
            }
        }
    }
}
static final class DoubleCumulateTask extends CountedCompleter<Void> {
    /** Array that is read and, during cumulation, overwritten in place. */
    final double[] array;
    /** Operator used to combine two values. */
    final DoubleBinaryOperator function;
    /** Child tasks; both are null until compute() first splits this node. */
    DoubleCumulateTask left, right;
    /** in: value carried into this segment by its parent; out: this segment's result. */
    double in, out;
    /** [lo, hi) is this task's segment, [origin, fence) the whole range; threshold is the split cut-off. */
    final int lo, hi, origin, fence, threshold;

    /**
     * Root task constructor. The root's segment is the whole range, so
     * origin/fence are the same as lo/hi. The threshold is the range
     * length divided by eight times the common pool parallelism, but
     * never less than MIN_PARTITION.
     */
    public DoubleCumulateTask(DoubleCumulateTask parent,
                              DoubleBinaryOperator function,
                              double[] array, int lo, int hi) {
        super(parent);
        this.function = function;
        this.array = array;
        this.origin = lo;
        this.lo = lo;
        this.fence = hi;
        this.hi = hi;
        int share = (hi - lo) / (ForkJoinPool.getCommonPoolParallelism() << 3);
        this.threshold = Math.max(share, MIN_PARTITION);
    }

    /**
     * Subtask constructor. Copies the range-wide settings (origin, fence,
     * threshold) handed down by the creating task and records this
     * subtask's own segment [lo, hi).
     */
    DoubleCumulateTask(DoubleCumulateTask parent, DoubleBinaryOperator function,
                       double[] array, int origin, int fence, int threshold,
                       int lo, int hi) {
        super(parent);
        this.function = function;
        this.array = array;
        this.origin = origin;
        this.fence = fence;
        this.threshold = threshold;
        this.lo = lo;
        this.hi = hi;
    }

    /**
     * Tries to set the CUMULATE bit in this task's pending count.
     *
     * @return true if this call set the bit, false if it was already set
     */
    private boolean claimCumulate() {
        for (;;) {
            int pending = getPendingCount();
            if ((pending & CUMULATE) != 0) {
                return false;
            }
            if (compareAndSetPendingCount(pending, pending | CUMULATE)) {
                return true;
            }
        }
    }

    /**
     * Runs this task and then keeps going with whichever related task
     * (a child when descending, the completer when propagating) can be
     * processed next, instead of forking everything.
     *
     * @throws NullPointerException if the function or the array is null
     */
    public final void compute() {
        final DoubleBinaryOperator op = this.function;
        final double[] arr = this.array;
        if (op == null || arr == null) {
            throw new NullPointerException(); // checked once, before the loop
        }
        final int grain = threshold;
        final int base = origin;
        final int limit = fence;
        int from;
        int to;
        DoubleCumulateTask task = this;

        walk:
        while ((from = task.lo) >= 0 && (to = task.hi) <= arr.length) {
            if (to - from > grain) {
                // Internal node: segment is still larger than the threshold.
                final DoubleCumulateTask leftChild = task.left;
                final DoubleCumulateTask rightChild = task.right;
                DoubleCumulateTask toFork;
                if (leftChild == null) {
                    // First pass: split in half, fork the right half and
                    // continue with the left half in this loop.
                    final int mid = (from + to) >>> 1;
                    final DoubleCumulateTask newRight =
                        new DoubleCumulateTask(task, op, arr, base, limit, grain, mid, to);
                    task.right = newRight;
                    toFork = newRight;
                    final DoubleCumulateTask newLeft =
                        new DoubleCumulateTask(task, op, arr, base, limit, grain, from, mid);
                    task.left = newLeft; // left is assigned after right
                    task = newLeft;
                } else {
                    // Later visit: hand the incoming values to the children
                    // and start cumulating in those not yet marked CUMULATE.
                    final double carried = task.in;
                    leftChild.in = carried;
                    toFork = null;
                    task = null;
                    if (rightChild != null) {
                        final double leftOut = leftChild.out;
                        if (from == base) {
                            rightChild.in = leftOut; // segment starts at origin: nothing carried in
                        } else {
                            rightChild.in = op.applyAsDouble(carried, leftOut);
                        }
                        if (rightChild.claimCumulate()) {
                            task = rightChild;
                        }
                    }
                    if (leftChild.claimCumulate()) {
                        if (task != null) {
                            toFork = task; // right was claimed too: fork it, run left here
                        }
                        task = leftChild;
                    }
                    if (task == null) {
                        break; // both children were already marked
                    }
                }
                if (toFork != null) {
                    toFork.fork();
                }
            } else {
                // Leaf: choose the transition (sum, cumulate, or both) and
                // record it in the pending count.
                int state;
                for (;;) {
                    final int bits = task.getPendingCount();
                    if ((bits & FINISHED) != 0) {
                        break walk; // already done
                    }
                    if ((bits & CUMULATE) != 0) {
                        state = FINISHED;
                    } else if (from > base) {
                        state = SUMMED;
                    } else {
                        state = SUMMED | FINISHED; // starts at origin: both passes at once
                    }
                    if (task.compareAndSetPendingCount(bits, bits | state)) {
                        break;
                    }
                }

                double sum;
                if (state != SUMMED) {
                    // Cumulate: overwrite each element with its running total.
                    int first;
                    if (from == base) {
                        sum = arr[base]; // leftmost segment has no incoming value
                        first = base + 1;
                    } else {
                        sum = task.in;
                        first = from;
                    }
                    for (int i = first; i < to; ++i) {
                        sum = op.applyAsDouble(sum, arr[i]);
                        arr[i] = sum;
                    }
                } else if (to < limit) {
                    // Sum only; the array is left untouched.
                    sum = arr[from];
                    for (int i = from + 1; i < to; ++i) {
                        sum = op.applyAsDouble(sum, arr[i]);
                    }
                } else {
                    sum = task.in; // segment ends at fence: summing is skipped
                }
                task.out = sum;

                // Propagate the new state towards the root.
                for (;;) {
                    final DoubleCumulateTask above = (DoubleCumulateTask) task.getCompleter();
                    if (above == null) {
                        if ((state & FINISHED) != 0) {
                            task.quietlyComplete(); // enable join
                        }
                        break walk;
                    }
                    final int bits = above.getPendingCount();
                    if ((bits & state & FINISHED) != 0) {
                        task = above; // both children done
                    } else if ((bits & state & SUMMED) != 0) {
                        // Both children summed: combine their results.
                        final DoubleCumulateTask lc = above.left;
                        if (lc != null) {
                            final DoubleCumulateTask rc = above.right;
                            if (rc != null) {
                                final double leftOut = lc.out;
                                if (rc.hi == limit) {
                                    above.out = leftOut;
                                } else {
                                    above.out = op.applyAsDouble(leftOut, rc.out);
                                }
                            }
                        }
                        final int refork =
                            ((bits & CUMULATE) == 0 && above.lo == base) ? CUMULATE : 0;
                        final int nextState = bits | state | refork;
                        if (nextState == bits
                                || above.compareAndSetPendingCount(bits, nextState)) {
                            state = SUMMED; // drop finished
                            task = above;
                            if (refork != 0) {
                                above.fork();
                            }
                        }
                    } else if (above.compareAndSetPendingCount(bits, bits | state)) {
                        break walk; // sibling not ready
                    }
                }
            }
        }
    }
}
static final class IntCumulateTask extends CountedCompleter<Void> {
    /** Array that is read and, during cumulation, overwritten in place. */
    final int[] array;
    /** Operator used to combine two values. */
    final IntBinaryOperator function;
    /** Child tasks; both are null until compute() first splits this node. */
    IntCumulateTask left, right;
    /** in: value carried into this segment by its parent; out: this segment's result. */
    int in, out;
    /** [lo, hi) is this task's segment, [origin, fence) the whole range; threshold is the split cut-off. */
    final int lo, hi, origin, fence, threshold;

    /**
     * Root task constructor. The root's segment is the whole range, so
     * origin/fence are the same as lo/hi. The threshold is the range
     * length divided by eight times the common pool parallelism, but
     * never less than MIN_PARTITION.
     */
    public IntCumulateTask(IntCumulateTask parent,
                           IntBinaryOperator function,
                           int[] array, int lo, int hi) {
        super(parent);
        this.function = function;
        this.array = array;
        this.origin = lo;
        this.lo = lo;
        this.fence = hi;
        this.hi = hi;
        int share = (hi - lo) / (ForkJoinPool.getCommonPoolParallelism() << 3);
        this.threshold = Math.max(share, MIN_PARTITION);
    }

    /**
     * Subtask constructor. Copies the range-wide settings (origin, fence,
     * threshold) handed down by the creating task and records this
     * subtask's own segment [lo, hi).
     */
    IntCumulateTask(IntCumulateTask parent, IntBinaryOperator function,
                    int[] array, int origin, int fence, int threshold,
                    int lo, int hi) {
        super(parent);
        this.function = function;
        this.array = array;
        this.origin = origin;
        this.fence = fence;
        this.threshold = threshold;
        this.lo = lo;
        this.hi = hi;
    }

    /**
     * Tries to set the CUMULATE bit in this task's pending count.
     *
     * @return true if this call set the bit, false if it was already set
     */
    private boolean claimCumulate() {
        for (;;) {
            int pending = getPendingCount();
            if ((pending & CUMULATE) != 0) {
                return false;
            }
            if (compareAndSetPendingCount(pending, pending | CUMULATE)) {
                return true;
            }
        }
    }

    /**
     * Runs this task and then keeps going with whichever related task
     * (a child when descending, the completer when propagating) can be
     * processed next, instead of forking everything.
     *
     * @throws NullPointerException if the function or the array is null
     */
    public final void compute() {
        final IntBinaryOperator op = this.function;
        final int[] arr = this.array;
        if (op == null || arr == null) {
            throw new NullPointerException(); // checked once, before the loop
        }
        final int grain = threshold;
        final int base = origin;
        final int limit = fence;
        int from;
        int to;
        IntCumulateTask task = this;

        walk:
        while ((from = task.lo) >= 0 && (to = task.hi) <= arr.length) {
            if (to - from > grain) {
                // Internal node: segment is still larger than the threshold.
                final IntCumulateTask leftChild = task.left;
                final IntCumulateTask rightChild = task.right;
                IntCumulateTask toFork;
                if (leftChild == null) {
                    // First pass: split in half, fork the right half and
                    // continue with the left half in this loop.
                    final int mid = (from + to) >>> 1;
                    final IntCumulateTask newRight =
                        new IntCumulateTask(task, op, arr, base, limit, grain, mid, to);
                    task.right = newRight;
                    toFork = newRight;
                    final IntCumulateTask newLeft =
                        new IntCumulateTask(task, op, arr, base, limit, grain, from, mid);
                    task.left = newLeft; // left is assigned after right
                    task = newLeft;
                } else {
                    // Later visit: hand the incoming values to the children
                    // and start cumulating in those not yet marked CUMULATE.
                    final int carried = task.in;
                    leftChild.in = carried;
                    toFork = null;
                    task = null;
                    if (rightChild != null) {
                        final int leftOut = leftChild.out;
                        if (from == base) {
                            rightChild.in = leftOut; // segment starts at origin: nothing carried in
                        } else {
                            rightChild.in = op.applyAsInt(carried, leftOut);
                        }
                        if (rightChild.claimCumulate()) {
                            task = rightChild;
                        }
                    }
                    if (leftChild.claimCumulate()) {
                        if (task != null) {
                            toFork = task; // right was claimed too: fork it, run left here
                        }
                        task = leftChild;
                    }
                    if (task == null) {
                        break; // both children were already marked
                    }
                }
                if (toFork != null) {
                    toFork.fork();
                }
            } else {
                // Leaf: choose the transition (sum, cumulate, or both) and
                // record it in the pending count.
                int state;
                for (;;) {
                    final int bits = task.getPendingCount();
                    if ((bits & FINISHED) != 0) {
                        break walk; // already done
                    }
                    if ((bits & CUMULATE) != 0) {
                        state = FINISHED;
                    } else if (from > base) {
                        state = SUMMED;
                    } else {
                        state = SUMMED | FINISHED; // starts at origin: both passes at once
                    }
                    if (task.compareAndSetPendingCount(bits, bits | state)) {
                        break;
                    }
                }

                int sum;
                if (state != SUMMED) {
                    // Cumulate: overwrite each element with its running total.
                    int first;
                    if (from == base) {
                        sum = arr[base]; // leftmost segment has no incoming value
                        first = base + 1;
                    } else {
                        sum = task.in;
                        first = from;
                    }
                    for (int i = first; i < to; ++i) {
                        sum = op.applyAsInt(sum, arr[i]);
                        arr[i] = sum;
                    }
                } else if (to < limit) {
                    // Sum only; the array is left untouched.
                    sum = arr[from];
                    for (int i = from + 1; i < to; ++i) {
                        sum = op.applyAsInt(sum, arr[i]);
                    }
                } else {
                    sum = task.in; // segment ends at fence: summing is skipped
                }
                task.out = sum;

                // Propagate the new state towards the root.
                for (;;) {
                    final IntCumulateTask above = (IntCumulateTask) task.getCompleter();
                    if (above == null) {
                        if ((state & FINISHED) != 0) {
                            task.quietlyComplete(); // enable join
                        }
                        break walk;
                    }
                    final int bits = above.getPendingCount();
                    if ((bits & state & FINISHED) != 0) {
                        task = above; // both children done
                    } else if ((bits & state & SUMMED) != 0) {
                        // Both children summed: combine their results.
                        final IntCumulateTask lc = above.left;
                        if (lc != null) {
                            final IntCumulateTask rc = above.right;
                            if (rc != null) {
                                final int leftOut = lc.out;
                                if (rc.hi == limit) {
                                    above.out = leftOut;
                                } else {
                                    above.out = op.applyAsInt(leftOut, rc.out);
                                }
                            }
                        }
                        final int refork =
                            ((bits & CUMULATE) == 0 && above.lo == base) ? CUMULATE : 0;
                        final int nextState = bits | state | refork;
                        if (nextState == bits
                                || above.compareAndSetPendingCount(bits, nextState)) {
                            state = SUMMED; // drop finished
                            task = above;
                            if (refork != 0) {
                                above.fork();
                            }
                        }
                    } else if (above.compareAndSetPendingCount(bits, bits | state)) {
                        break walk; // sibling not ready
                    }
                }
            }
        }
    }
}
}