From 666c390747456b7440ffdb2a96fd8bfbb980d592 Mon Sep 17 00:00:00 2001 From: Tanguy Leroux Date: Tue, 28 May 2024 15:32:24 +0200 Subject: [PATCH] Fix refcounting in SharedBlobCacheService --- .../shared/SharedBlobCacheService.java | 179 +++++++++--------- .../blobcache/shared/SharedBytes.java | 28 ++- .../shared/SharedBlobCacheServiceTests.java | 67 +++++++ 3 files changed, 186 insertions(+), 88 deletions(-) diff --git a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java index c5ef1d7c2bf1d..c93286802244c 100644 --- a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java +++ b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.RefCountingListener; +import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.blobcache.BlobCacheMetrics; import org.elasticsearch.blobcache.BlobCacheUtils; @@ -710,6 +711,11 @@ public void close() { sharedBytes.decRef(); } + // used by tests + SharedBytes getSharedBytes() { + return sharedBytes; + } + private record RegionKey(KeyType file, int region) { @Override public String toString() { @@ -801,7 +807,7 @@ public long physicalStartOffset() { return ioRef == null ? -1L : (long) regionKey.region * regionSize; } - public boolean tryIncRefEnsureOpen() { + private boolean tryIncRefEnsureOpen() { if (tryIncRef()) { ensureOpenOrDecRef(); return true; @@ -810,7 +816,7 @@ public boolean tryIncRefEnsureOpen() { return false; } - public void incRefEnsureOpen() { + private void incRefEnsureOpen() { incRef(); ensureOpenOrDecRef(); } @@ -878,13 +884,17 @@ private static void throwAlreadyEvicted() { */ boolean tryRead(ByteBuffer buf, long offset) throws IOException { SharedBytes.IO ioRef = this.io; - if (ioRef != null) { - int readBytes = ioRef.read(buf, getRegionRelativePosition(offset)); - if (isEvicted()) { - buf.position(buf.position() - readBytes); - return false; + if (ioRef != null && ioRef.tryIncRef()) { + try { + int readBytes = ioRef.read(buf, getRegionRelativePosition(offset)); + if (isEvicted()) { + buf.position(buf.position() - readBytes); + return false; + } + return true; + } finally { + ioRef.decRef(); } - return true; } else { // taken by someone else return false; @@ -908,29 +918,35 @@ void populate( final Executor executor, final ActionListener listener ) { - Releasable resource = null; try { incRefEnsureOpen(); - resource = Releasables.releaseOnce(this::decRef); - final List gaps = tracker.waitForRange( - rangeToWrite, - rangeToWrite, - Assertions.ENABLED ? ActionListener.releaseAfter(ActionListener.running(() -> { - assert regionOwners.get(io) == this; - }), resource) : ActionListener.releasing(resource) - ); - final var hasGapsToFill = gaps.size() > 0; - try (RefCountingListener refs = new RefCountingListener(listener.map(unused -> hasGapsToFill))) { - if (hasGapsToFill) { - final var cacheFileRegion = CacheFileRegion.this; - for (SparseFileTracker.Gap gap : gaps) { - var fillGapRunnable = fillGapRunnable(cacheFileRegion, writer, gap); - executor.execute(ActionRunnable.run(refs.acquire(), fillGapRunnable::run)); + try (var cacheRefs = new RefCountingRunnable(this::decRef)) { + final var ioRef = io; + ioRef.incRef(); + var releasable = Releasables.wrap(cacheRefs.acquire(), ioRef::decRef); + try (var ioRefs = new RefCountingRunnable(releasable::close)) { + final List gaps = tracker.waitForRange( + rangeToWrite, + rangeToWrite, + Assertions.ENABLED ? ActionListener.releaseAfter(ActionListener.running(() -> { + assert regionOwners.get(io) == this; + }), ioRefs.acquire()) : ioRefs.acquireListener() + ); + if (gaps.isEmpty()) { + listener.onResponse(false); + return; + } + try (var gapsListener = new RefCountingListener(listener.map(unused -> true))) { + for (SparseFileTracker.Gap gap : gaps) { + executor.execute( + fillGapRunnable(gap, writer, ActionListener.releaseAfter(gapsListener.acquire(), ioRefs.acquire())) + ); + } } } } } catch (Exception e) { - releaseAndFail(listener, resource, e); + listener.onFailure(e); } } @@ -942,77 +958,66 @@ void populateAndRead( final Executor executor, final ActionListener listener ) { - Releasable resource = null; try { incRefEnsureOpen(); - resource = Releasables.releaseOnce(this::decRef); - final List gaps = tracker.waitForRange( - rangeToWrite, - rangeToRead, - ActionListener.runAfter(listener, resource::close).delegateFailureAndWrap((l, success) -> { - var ioRef = io; - assert regionOwners.get(ioRef) == this; - final int start = Math.toIntExact(rangeToRead.start()); - final int read = reader.onRangeAvailable(ioRef, start, start, Math.toIntExact(rangeToRead.length())); - assert read == rangeToRead.length() - : "partial read [" - + read - + "] does not match the range to read [" - + rangeToRead.end() - + '-' - + rangeToRead.start() - + ']'; - readCount.increment(); - l.onResponse(read); - }) - ); + try (var cacheRefs = new RefCountingRunnable(this::decRef)) { + final var ioRef = io; + ioRef.incRef(); + var releasable = Releasables.wrap(cacheRefs.acquire(), ioRef::decRef); + try (var ioRefs = new RefCountingRunnable(releasable::close)) { + final List gaps = tracker.waitForRange( + rangeToWrite, + rangeToRead, + ActionListener.releaseAfter(listener, ioRefs.acquire()).delegateFailureAndWrap((l, success) -> { + assert regionOwners.get(ioRef) == this; + final int start = Math.toIntExact(rangeToRead.start()); + final int read = reader.onRangeAvailable(ioRef, start, start, Math.toIntExact(rangeToRead.length())); + assert read == rangeToRead.length() + : "partial read [" + + read + + "] does not match the range to read [" + + rangeToRead.end() + + '-' + + rangeToRead.start() + + ']'; + readCount.increment(); + l.onResponse(read); + }) + ); - if (gaps.isEmpty() == false) { - final var cacheFileRegion = CacheFileRegion.this; - for (SparseFileTracker.Gap gap : gaps) { - executor.execute(fillGapRunnable(cacheFileRegion, writer, gap)); + if (gaps.isEmpty() == false) { + for (SparseFileTracker.Gap gap : gaps) { + executor.execute(fillGapRunnable(gap, writer, ioRefs.acquireListener())); + } + } } } } catch (Exception e) { - releaseAndFail(listener, resource, e); - } - } - - private AbstractRunnable fillGapRunnable(CacheFileRegion cacheFileRegion, RangeMissingHandler writer, SparseFileTracker.Gap gap) { - return new AbstractRunnable() { - @Override - protected void doRun() throws Exception { - if (cacheFileRegion.tryIncRefEnsureOpen() == false) { - throw new AlreadyClosedException("File chunk [" + cacheFileRegion.regionKey + "] has been released"); - } - try { - final int start = Math.toIntExact(gap.start()); - var ioRef = io; - assert regionOwners.get(ioRef) == cacheFileRegion; - writer.fillCacheRange( - ioRef, - start, - start, - Math.toIntExact(gap.end() - start), - progress -> gap.onProgress(start + progress) - ); - writeCount.increment(); - } finally { - cacheFileRegion.decRef(); - } - gap.onCompletion(); - } - - @Override - public void onFailure(Exception e) { - gap.onFailure(e); - } - }; + listener.onFailure(e); + } + } + + private AbstractRunnable fillGapRunnable(SparseFileTracker.Gap gap, RangeMissingHandler writer, ActionListener listener) { + return ActionRunnable.run(listener.delegateResponse((l, e) -> failGapAndListener(gap, l, e)), () -> { + var ioRef = io; + assert regionOwners.get(ioRef) == CacheFileRegion.this; + assert CacheFileRegion.this.hasReferences() : CacheFileRegion.this; + int start = Math.toIntExact(gap.start()); + writer.fillCacheRange( + ioRef, + start, + start, + Math.toIntExact(gap.end() - start), + progress -> gap.onProgress(start + progress) + ); + writeCount.increment(); + gap.onCompletion(); + }); } - private static void releaseAndFail(ActionListener listener, Releasable decrementRef, Exception e) { + private static void failGapAndListener(SparseFileTracker.Gap gap, ActionListener listener, Exception e) { try { - Releasables.close(decrementRef); + gap.onFailure(e); } catch (Exception ex) { e.addSuppressed(ex); } diff --git a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBytes.java b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBytes.java index 051dfab1cdaa0..46e7b2c7d3889 100644 --- a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBytes.java +++ b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBytes.java @@ -9,11 +9,13 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.lucene.store.AlreadyClosedException; import org.elasticsearch.blobcache.BlobCacheUtils; import org.elasticsearch.blobcache.common.ByteBufferReference; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.AbstractRefCounted; import org.elasticsearch.core.IOUtils; +import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.Streams; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.env.Environment; @@ -285,7 +287,7 @@ public IO getFileChannel(int sharedBytesPos) { return ios[sharedBytesPos]; } - public final class IO { + public final class IO implements RefCounted { private final long pageStart; @@ -298,8 +300,31 @@ private IO(final int sharedBytesPos, MappedByteBuffer mappedByteBuffer) { this.mappedByteBuffer = mappedByteBuffer; } + @Override + public boolean tryIncRef() { + return SharedBytes.this.tryIncRef(); + } + + @Override + public void incRef() { + if (tryIncRef() == false) { + throw new AlreadyClosedException("File channel is closed"); + } + } + + @Override + public boolean decRef() { + return SharedBytes.this.decRef(); + } + + @Override + public boolean hasReferences() { + return SharedBytes.this.hasReferences(); + } + @SuppressForbidden(reason = "Use positional reads on purpose") public int read(ByteBuffer dst, int position) throws IOException { + assert hasReferences(); int remaining = dst.remaining(); checkOffsets(position, remaining); final int bytesRead; @@ -316,6 +341,7 @@ public int read(ByteBuffer dst, int position) throws IOException { @SuppressForbidden(reason = "Use positional writes on purpose") public int write(ByteBuffer src, int position) throws IOException { + assert hasReferences(); // check if writes are page size aligned for optimal performance assert position % PAGE_SIZE == 0; assert src.remaining() % PAGE_SIZE == 0; diff --git a/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java b/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java index edeed9a16034a..c71004c5c37f3 100644 --- a/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java +++ b/x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java @@ -33,7 +33,9 @@ import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; +import java.io.ByteArrayInputStream; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; @@ -41,6 +43,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.BrokenBarrierException; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -1310,4 +1313,68 @@ protected int computeCacheFileRegionSize(long fileLength, int region) { } } } + + public void testWriteAndReadCanCompleteAfterSharedBytesCloses() throws Exception { + final long regionSize = size(1L); + Settings settings = Settings.builder() + .put(NODE_NAME_SETTING.getKey(), "node") + .put(SharedBlobCacheService.SHARED_CACHE_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(size(10)).getStringRep()) + .put(SharedBlobCacheService.SHARED_CACHE_REGION_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(regionSize).getStringRep()) + .put("path.home", createTempDir()) + .build(); + + final TestThreadPool threadPool = new TestThreadPool(getTestName()); + try (NodeEnvironment environment = new NodeEnvironment(settings, TestEnvironment.newEnvironment(settings))) { + var cacheService = new SharedBlobCacheService<>( + environment, + settings, + threadPool, + ThreadPool.Names.GENERIC, + BlobCacheMetrics.NOOP + ); + + final var cacheKey = generateCacheKey(); + final var blobLength = size(12L); + + var writeBlocked = new CountDownLatch(1); + var resumeWrites = new CountDownLatch(1); + + var entry = cacheService.get(cacheKey, blobLength, 0); + final PlainActionFuture future = new PlainActionFuture<>(); + + entry.populateAndRead( + ByteRange.of(0, regionSize), + ByteRange.of(0, regionSize), + (channel, channelPos, relativePos, length) -> Math.toIntExact(regionSize), + (channel, channelPos, relativePos, length, progressUpdater) -> { + writeBlocked.countDown(); + safeAwait(resumeWrites); + SharedBytes.copyToCacheFileAligned( + channel, + new ByteArrayInputStream(randomByteArrayOfLength(SharedBytes.PAGE_SIZE)), + channelPos, + progressUpdater, + ByteBuffer.allocate(SharedBytes.PAGE_SIZE) + ); + }, + threadPool.generic(), + future + ); + + safeAwait(writeBlocked); + assertThat(cacheService.getSharedBytes().hasReferences(), is(true)); + + cacheService.close(); + assertThat(cacheService.getSharedBytes().hasReferences(), is(true)); + + resumeWrites.countDown(); + + var written = future.get(10L, TimeUnit.SECONDS); + assertThat(written, equalTo((int) regionSize)); + + assertBusy(() -> assertThat(cacheService.getSharedBytes().hasReferences(), is(false))); + } finally { + assertTrue(ThreadPool.terminate(threadPool, 10L, TimeUnit.SECONDS)); + } + } }