Skip to content

Commit

Permalink
8347274: Gatherers.mapConcurrent exhibits undesired behavior under va…
Browse files Browse the repository at this point in the history
…riable delays, interruption, and finishing

Reviewed-by: alanb
  • Loading branch information
Viktor Klang committed Jan 13, 2025
1 parent 82e2a79 commit 450636a
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 55 deletions.
124 changes: 71 additions & 53 deletions src/java.base/share/classes/java/util/stream/Gatherers.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand Down Expand Up @@ -30,6 +30,7 @@
import java.util.ArrayDeque;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
Expand Down Expand Up @@ -350,86 +351,103 @@ public static <T, R> Gatherer<T,?,R> mapConcurrent(
final int maxConcurrency,
final Function<? super T, ? extends R> mapper) {
if (maxConcurrency < 1)
throw new IllegalArgumentException(
"'maxConcurrency' must be greater than 0");
throw new IllegalArgumentException("'maxConcurrency' must be greater than 0");

Objects.requireNonNull(mapper, "'mapper' must not be null");

class State {
// ArrayDeque default initial size is 16
final ArrayDeque<Future<R>> window =
new ArrayDeque<>(Math.min(maxConcurrency, 16));
final Semaphore windowLock = new Semaphore(maxConcurrency);

final boolean integrate(T element,
Downstream<? super R> downstream) {
if (!downstream.isRejecting())
createTaskFor(element);
return flush(0, downstream);
final class MapConcurrentTask extends FutureTask<R> {
final Thread thread;
private MapConcurrentTask(Callable<R> callable) {
super(callable);
this.thread = Thread.ofVirtual().unstarted(this);
}
}

final void createTaskFor(T element) {
windowLock.acquireUninterruptibly();
final class State {
private final ArrayDeque<MapConcurrentTask> wip =
new ArrayDeque<>(Math.min(maxConcurrency, 16));

var task = new FutureTask<R>(() -> {
try {
return mapper.apply(element);
} finally {
windowLock.release();
}
});
boolean integrate(T element, Downstream<? super R> downstream) {
// Prepare the next task and add it to the work-in-progress
final var task = new MapConcurrentTask(() -> mapper.apply(element));
wip.addLast(task);

assert wip.peekLast() == task;
assert wip.size() <= maxConcurrency;

var wasAddedToWindow = window.add(task);
assert wasAddedToWindow;
// Start the next task
task.thread.start();

Thread.startVirtualThread(task);
// Flush at least 1 element if we're at capacity
return flush(wip.size() < maxConcurrency ? 0 : 1, downstream);
}

final boolean flush(long atLeastN,
Downstream<? super R> downstream) {
boolean proceed = !downstream.isRejecting();
boolean interrupted = false;
boolean flush(long atLeastN, Downstream<? super R> downstream) {
boolean success = false, interrupted = false;
try {
Future<R> current;
while (proceed
&& (current = window.peek()) != null
&& (current.isDone() || atLeastN > 0)) {
proceed &= downstream.push(current.get());
boolean proceed = !downstream.isRejecting();
MapConcurrentTask current;
while (
proceed
&& (current = wip.peekFirst()) != null
&& (current.isDone() || atLeastN > 0)
) {
R result;

// Ensure that the task is done before proceeding
for (;;) {
try {
result = current.get();
break;
} catch (InterruptedException ie) {
interrupted = true; // ignore for now, and restore later
}
}

proceed &= downstream.push(result);
atLeastN -= 1;

var correctRemoval = window.pop() == current;
final var correctRemoval = wip.pollFirst() == current;
assert correctRemoval;
}
} catch(InterruptedException ie) {
proceed = false;
interrupted = true;
return (success = proceed); // Ensure that cleanup occurs if needed
} catch (ExecutionException e) {
proceed = false; // Ensure cleanup
final var cause = e.getCause();
throw (cause instanceof RuntimeException re)
? re
: new RuntimeException(cause == null ? e : cause);
} finally {
// Clean up
if (!proceed) {
Future<R> next;
while ((next = window.pollFirst()) != null) {
next.cancel(true);
// Clean up work-in-progress
if (!success && !wip.isEmpty()) {
// First signal cancellation for all tasks in progress
for (var task : wip)
task.cancel(true);

// Then wait for all in progress task Threads to exit
MapConcurrentTask next;
while ((next = wip.pollFirst()) != null) {
while (next.thread.isAlive()) {
try {
next.thread.join();
} catch (InterruptedException ie) {
interrupted = true; // ignore, for now, and restore later
}
}
}
}
}

if (interrupted)
Thread.currentThread().interrupt();

return proceed;
// integrate(..) could be called from different threads each time
// so we need to restore the interrupt on the calling thread
if (interrupted)
Thread.currentThread().interrupt();
}
}
}

return Gatherer.ofSequential(
State::new,
Integrator.<State, T, R>ofGreedy(State::integrate),
(state, downstream) -> state.flush(Long.MAX_VALUE, downstream)
State::new,
Integrator.<State, T, R>ofGreedy(State::integrate),
(state, downstream) -> state.flush(Long.MAX_VALUE, downstream)
);
}

Expand Down
63 changes: 61 additions & 2 deletions test/jdk/java/util/stream/GatherersMapConcurrentTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand All @@ -24,6 +24,9 @@
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.LockSupport;
import java.util.function.Function;
import java.util.stream.Gatherer;
import java.util.stream.Gatherers;
import java.util.stream.Stream;
Expand Down Expand Up @@ -298,7 +301,7 @@ public void behavesAsExpected(ConcurrencyConfig cc) {

@ParameterizedTest
@MethodSource("concurrencyConfigurations")
public void behavesAsExpectedWhenShortCircuited(ConcurrencyConfig cc) {
public void shortCircuits(ConcurrencyConfig cc) {
final var limitTo = Math.max(cc.config().streamSize() / 2, 1);

final var expectedResult = cc.config().stream()
Expand All @@ -313,4 +316,60 @@ public void behavesAsExpectedWhenShortCircuited(ConcurrencyConfig cc) {

assertEquals(expectedResult, result);
}

@ParameterizedTest
@MethodSource("concurrencyConfigurations")
public void ignoresAndRestoresCallingThreadInterruption(ConcurrencyConfig cc) {
final var limitTo = Math.max(cc.config().streamSize() / 2, 1);

final var expectedResult = cc.config().stream()
.map(x -> x * x)
.limit(limitTo)
.toList();

// Ensure calling thread is interrupted
Thread.currentThread().interrupt();

final var result = cc.config().stream()
.gather(Gatherers.mapConcurrent(cc.concurrencyLevel(), x -> {
LockSupport.parkNanos(10000); // 10 us
return x * x;
}))
.limit(limitTo)
.toList();

// Ensure calling thread remains interrupted
assertEquals(true, Thread.interrupted());

assertEquals(expectedResult, result);
}

@ParameterizedTest
@MethodSource("concurrencyConfigurations")
public void limitsWorkInProgressToMaxConcurrency(ConcurrencyConfig cc) {
final var elementNum = new AtomicLong(0);
final var wipCount = new AtomicLong(0);
final var limitTo = Math.max(cc.config().streamSize() / 2, 1);

final var expectedResult = cc.config().stream()
.map(x -> x * x)
.limit(limitTo)
.toList();

Function<Integer, Integer> fun = x -> {
if (wipCount.incrementAndGet() > cc.concurrencyLevel)
throw new IllegalStateException("Too much wip!");
if (elementNum.getAndIncrement() == 0)
LockSupport.parkNanos(500_000_000); // 500 ms
return x * x;
};

final var result = cc.config().stream()
.gather(Gatherers.mapConcurrent(cc.concurrencyLevel(), fun))
.gather(Gatherer.of((v, e, d) -> wipCount.decrementAndGet() >= 0 && d.push(e)))
.limit(limitTo)
.toList();

assertEquals(expectedResult, result);
}
}

0 comments on commit 450636a

Please sign in to comment.