From 450636ae28b84ded083b6861c6cba85fbf87e16e Mon Sep 17 00:00:00 2001 From: Viktor Klang Date: Mon, 13 Jan 2025 10:38:02 +0000 Subject: [PATCH] 8347274: Gatherers.mapConcurrent exhibits undesired behavior under variable delays, interruption, and finishing Reviewed-by: alanb --- .../classes/java/util/stream/Gatherers.java | 124 ++++++++++-------- .../stream/GatherersMapConcurrentTest.java | 63 ++++++++- 2 files changed, 132 insertions(+), 55 deletions(-) diff --git a/src/java.base/share/classes/java/util/stream/Gatherers.java b/src/java.base/share/classes/java/util/stream/Gatherers.java index b394f6fc7d86e..0a98d7d5033e6 100644 --- a/src/java.base/share/classes/java/util/stream/Gatherers.java +++ b/src/java.base/share/classes/java/util/stream/Gatherers.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -30,6 +30,7 @@ import java.util.ArrayDeque; import java.util.List; import java.util.Objects; +import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.FutureTask; @@ -350,86 +351,103 @@ public static Gatherer mapConcurrent( final int maxConcurrency, final Function mapper) { if (maxConcurrency < 1) - throw new IllegalArgumentException( - "'maxConcurrency' must be greater than 0"); + throw new IllegalArgumentException("'maxConcurrency' must be greater than 0"); Objects.requireNonNull(mapper, "'mapper' must not be null"); - class State { - // ArrayDeque default initial size is 16 - final ArrayDeque> window = - new ArrayDeque<>(Math.min(maxConcurrency, 16)); - final Semaphore windowLock = new Semaphore(maxConcurrency); - - final boolean integrate(T element, - Downstream downstream) { - if (!downstream.isRejecting()) - createTaskFor(element); - return flush(0, downstream); + final class MapConcurrentTask extends FutureTask { + final Thread thread; + private MapConcurrentTask(Callable callable) { + super(callable); + this.thread = Thread.ofVirtual().unstarted(this); } + } - final void createTaskFor(T element) { - windowLock.acquireUninterruptibly(); + final class State { + private final ArrayDeque wip = + new ArrayDeque<>(Math.min(maxConcurrency, 16)); - var task = new FutureTask(() -> { - try { - return mapper.apply(element); - } finally { - windowLock.release(); - } - }); + boolean integrate(T element, Downstream downstream) { + // Prepare the next task and add it to the work-in-progress + final var task = new MapConcurrentTask(() -> mapper.apply(element)); + wip.addLast(task); + + assert wip.peekLast() == task; + assert wip.size() <= maxConcurrency; - var wasAddedToWindow = window.add(task); - assert wasAddedToWindow; + // Start the next task + task.thread.start(); - Thread.startVirtualThread(task); + // Flush at least 1 element if we're at capacity + return flush(wip.size() < maxConcurrency ? 0 : 1, downstream); } - final boolean flush(long atLeastN, - Downstream downstream) { - boolean proceed = !downstream.isRejecting(); - boolean interrupted = false; + boolean flush(long atLeastN, Downstream downstream) { + boolean success = false, interrupted = false; try { - Future current; - while (proceed - && (current = window.peek()) != null - && (current.isDone() || atLeastN > 0)) { - proceed &= downstream.push(current.get()); + boolean proceed = !downstream.isRejecting(); + MapConcurrentTask current; + while ( + proceed + && (current = wip.peekFirst()) != null + && (current.isDone() || atLeastN > 0) + ) { + R result; + + // Ensure that the task is done before proceeding + for (;;) { + try { + result = current.get(); + break; + } catch (InterruptedException ie) { + interrupted = true; // ignore for now, and restore later + } + } + + proceed &= downstream.push(result); atLeastN -= 1; - var correctRemoval = window.pop() == current; + final var correctRemoval = wip.pollFirst() == current; assert correctRemoval; } - } catch(InterruptedException ie) { - proceed = false; - interrupted = true; + return (success = proceed); // Ensure that cleanup occurs if needed } catch (ExecutionException e) { - proceed = false; // Ensure cleanup final var cause = e.getCause(); throw (cause instanceof RuntimeException re) ? re : new RuntimeException(cause == null ? e : cause); } finally { - // Clean up - if (!proceed) { - Future next; - while ((next = window.pollFirst()) != null) { - next.cancel(true); + // Clean up work-in-progress + if (!success && !wip.isEmpty()) { + // First signal cancellation for all tasks in progress + for (var task : wip) + task.cancel(true); + + // Then wait for all in progress task Threads to exit + MapConcurrentTask next; + while ((next = wip.pollFirst()) != null) { + while (next.thread.isAlive()) { + try { + next.thread.join(); + } catch (InterruptedException ie) { + interrupted = true; // ignore, for now, and restore later + } + } } } - } - - if (interrupted) - Thread.currentThread().interrupt(); - return proceed; + // integrate(..) could be called from different threads each time + // so we need to restore the interrupt on the calling thread + if (interrupted) + Thread.currentThread().interrupt(); + } } } return Gatherer.ofSequential( - State::new, - Integrator.ofGreedy(State::integrate), - (state, downstream) -> state.flush(Long.MAX_VALUE, downstream) + State::new, + Integrator.ofGreedy(State::integrate), + (state, downstream) -> state.flush(Long.MAX_VALUE, downstream) ); } diff --git a/test/jdk/java/util/stream/GatherersMapConcurrentTest.java b/test/jdk/java/util/stream/GatherersMapConcurrentTest.java index 557598de3eece..970a550400aa6 100644 --- a/test/jdk/java/util/stream/GatherersMapConcurrentTest.java +++ b/test/jdk/java/util/stream/GatherersMapConcurrentTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -24,6 +24,9 @@ import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.LockSupport; +import java.util.function.Function; import java.util.stream.Gatherer; import java.util.stream.Gatherers; import java.util.stream.Stream; @@ -298,7 +301,7 @@ public void behavesAsExpected(ConcurrencyConfig cc) { @ParameterizedTest @MethodSource("concurrencyConfigurations") - public void behavesAsExpectedWhenShortCircuited(ConcurrencyConfig cc) { + public void shortCircuits(ConcurrencyConfig cc) { final var limitTo = Math.max(cc.config().streamSize() / 2, 1); final var expectedResult = cc.config().stream() @@ -313,4 +316,60 @@ public void behavesAsExpectedWhenShortCircuited(ConcurrencyConfig cc) { assertEquals(expectedResult, result); } + + @ParameterizedTest + @MethodSource("concurrencyConfigurations") + public void ignoresAndRestoresCallingThreadInterruption(ConcurrencyConfig cc) { + final var limitTo = Math.max(cc.config().streamSize() / 2, 1); + + final var expectedResult = cc.config().stream() + .map(x -> x * x) + .limit(limitTo) + .toList(); + + // Ensure calling thread is interrupted + Thread.currentThread().interrupt(); + + final var result = cc.config().stream() + .gather(Gatherers.mapConcurrent(cc.concurrencyLevel(), x -> { + LockSupport.parkNanos(10000); // 10 us + return x * x; + })) + .limit(limitTo) + .toList(); + + // Ensure calling thread remains interrupted + assertEquals(true, Thread.interrupted()); + + assertEquals(expectedResult, result); + } + + @ParameterizedTest + @MethodSource("concurrencyConfigurations") + public void limitsWorkInProgressToMaxConcurrency(ConcurrencyConfig cc) { + final var elementNum = new AtomicLong(0); + final var wipCount = new AtomicLong(0); + final var limitTo = Math.max(cc.config().streamSize() / 2, 1); + + final var expectedResult = cc.config().stream() + .map(x -> x * x) + .limit(limitTo) + .toList(); + + Function fun = x -> { + if (wipCount.incrementAndGet() > cc.concurrencyLevel) + throw new IllegalStateException("Too much wip!"); + if (elementNum.getAndIncrement() == 0) + LockSupport.parkNanos(500_000_000); // 500 ms + return x * x; + }; + + final var result = cc.config().stream() + .gather(Gatherers.mapConcurrent(cc.concurrencyLevel(), fun)) + .gather(Gatherer.of((v, e, d) -> wipCount.decrementAndGet() >= 0 && d.push(e))) + .limit(limitTo) + .toList(); + + assertEquals(expectedResult, result); + } }