Skip to content

Commit

Permalink
Fix refcounting in SharedBlobCacheService
Browse files Browse the repository at this point in the history
  • Loading branch information
tlrx committed May 28, 2024
1 parent 8329a09 commit 666c390
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.RefCountingListener;
import org.elasticsearch.action.support.RefCountingRunnable;
import org.elasticsearch.action.support.UnsafePlainActionFuture;
import org.elasticsearch.blobcache.BlobCacheMetrics;
import org.elasticsearch.blobcache.BlobCacheUtils;
Expand Down Expand Up @@ -710,6 +711,11 @@ public void close() {
sharedBytes.decRef();
}

// used by tests
SharedBytes getSharedBytes() {
return sharedBytes;
}

private record RegionKey<KeyType>(KeyType file, int region) {
@Override
public String toString() {
Expand Down Expand Up @@ -801,7 +807,7 @@ public long physicalStartOffset() {
return ioRef == null ? -1L : (long) regionKey.region * regionSize;
}

public boolean tryIncRefEnsureOpen() {
private boolean tryIncRefEnsureOpen() {
if (tryIncRef()) {
ensureOpenOrDecRef();
return true;
Expand All @@ -810,7 +816,7 @@ public boolean tryIncRefEnsureOpen() {
return false;
}

public void incRefEnsureOpen() {
private void incRefEnsureOpen() {
incRef();
ensureOpenOrDecRef();
}
Expand Down Expand Up @@ -878,13 +884,17 @@ private static void throwAlreadyEvicted() {
*/
boolean tryRead(ByteBuffer buf, long offset) throws IOException {
SharedBytes.IO ioRef = this.io;
if (ioRef != null) {
int readBytes = ioRef.read(buf, getRegionRelativePosition(offset));
if (isEvicted()) {
buf.position(buf.position() - readBytes);
return false;
if (ioRef != null && ioRef.tryIncRef()) {
try {
int readBytes = ioRef.read(buf, getRegionRelativePosition(offset));
if (isEvicted()) {
buf.position(buf.position() - readBytes);
return false;
}
return true;
} finally {
ioRef.decRef();
}
return true;
} else {
// taken by someone else
return false;
Expand All @@ -908,29 +918,35 @@ void populate(
final Executor executor,
final ActionListener<Boolean> listener
) {
Releasable resource = null;
try {
incRefEnsureOpen();
resource = Releasables.releaseOnce(this::decRef);
final List<SparseFileTracker.Gap> gaps = tracker.waitForRange(
rangeToWrite,
rangeToWrite,
Assertions.ENABLED ? ActionListener.releaseAfter(ActionListener.running(() -> {
assert regionOwners.get(io) == this;
}), resource) : ActionListener.releasing(resource)
);
final var hasGapsToFill = gaps.size() > 0;
try (RefCountingListener refs = new RefCountingListener(listener.map(unused -> hasGapsToFill))) {
if (hasGapsToFill) {
final var cacheFileRegion = CacheFileRegion.this;
for (SparseFileTracker.Gap gap : gaps) {
var fillGapRunnable = fillGapRunnable(cacheFileRegion, writer, gap);
executor.execute(ActionRunnable.run(refs.acquire(), fillGapRunnable::run));
try (var cacheRefs = new RefCountingRunnable(this::decRef)) {
final var ioRef = io;
ioRef.incRef();
var releasable = Releasables.wrap(cacheRefs.acquire(), ioRef::decRef);
try (var ioRefs = new RefCountingRunnable(releasable::close)) {
final List<SparseFileTracker.Gap> gaps = tracker.waitForRange(
rangeToWrite,
rangeToWrite,
Assertions.ENABLED ? ActionListener.releaseAfter(ActionListener.running(() -> {
assert regionOwners.get(io) == this;
}), ioRefs.acquire()) : ioRefs.acquireListener()
);
if (gaps.isEmpty()) {
listener.onResponse(false);
return;
}
try (var gapsListener = new RefCountingListener(listener.map(unused -> true))) {
for (SparseFileTracker.Gap gap : gaps) {
executor.execute(
fillGapRunnable(gap, writer, ActionListener.releaseAfter(gapsListener.acquire(), ioRefs.acquire()))
);
}
}
}
}
} catch (Exception e) {
releaseAndFail(listener, resource, e);
listener.onFailure(e);
}
}

Expand All @@ -942,77 +958,66 @@ void populateAndRead(
final Executor executor,
final ActionListener<Integer> listener
) {
Releasable resource = null;
try {
incRefEnsureOpen();
resource = Releasables.releaseOnce(this::decRef);
final List<SparseFileTracker.Gap> gaps = tracker.waitForRange(
rangeToWrite,
rangeToRead,
ActionListener.runAfter(listener, resource::close).delegateFailureAndWrap((l, success) -> {
var ioRef = io;
assert regionOwners.get(ioRef) == this;
final int start = Math.toIntExact(rangeToRead.start());
final int read = reader.onRangeAvailable(ioRef, start, start, Math.toIntExact(rangeToRead.length()));
assert read == rangeToRead.length()
: "partial read ["
+ read
+ "] does not match the range to read ["
+ rangeToRead.end()
+ '-'
+ rangeToRead.start()
+ ']';
readCount.increment();
l.onResponse(read);
})
);
try (var cacheRefs = new RefCountingRunnable(this::decRef)) {
final var ioRef = io;
ioRef.incRef();
var releasable = Releasables.wrap(cacheRefs.acquire(), ioRef::decRef);
try (var ioRefs = new RefCountingRunnable(releasable::close)) {
final List<SparseFileTracker.Gap> gaps = tracker.waitForRange(
rangeToWrite,
rangeToRead,
ActionListener.releaseAfter(listener, ioRefs.acquire()).delegateFailureAndWrap((l, success) -> {
assert regionOwners.get(ioRef) == this;
final int start = Math.toIntExact(rangeToRead.start());
final int read = reader.onRangeAvailable(ioRef, start, start, Math.toIntExact(rangeToRead.length()));
assert read == rangeToRead.length()
: "partial read ["
+ read
+ "] does not match the range to read ["
+ rangeToRead.end()
+ '-'
+ rangeToRead.start()
+ ']';
readCount.increment();
l.onResponse(read);
})
);

if (gaps.isEmpty() == false) {
final var cacheFileRegion = CacheFileRegion.this;
for (SparseFileTracker.Gap gap : gaps) {
executor.execute(fillGapRunnable(cacheFileRegion, writer, gap));
if (gaps.isEmpty() == false) {
for (SparseFileTracker.Gap gap : gaps) {
executor.execute(fillGapRunnable(gap, writer, ioRefs.acquireListener()));
}
}
}
}
} catch (Exception e) {
releaseAndFail(listener, resource, e);
}
}

private AbstractRunnable fillGapRunnable(CacheFileRegion cacheFileRegion, RangeMissingHandler writer, SparseFileTracker.Gap gap) {
return new AbstractRunnable() {
@Override
protected void doRun() throws Exception {
if (cacheFileRegion.tryIncRefEnsureOpen() == false) {
throw new AlreadyClosedException("File chunk [" + cacheFileRegion.regionKey + "] has been released");
}
try {
final int start = Math.toIntExact(gap.start());
var ioRef = io;
assert regionOwners.get(ioRef) == cacheFileRegion;
writer.fillCacheRange(
ioRef,
start,
start,
Math.toIntExact(gap.end() - start),
progress -> gap.onProgress(start + progress)
);
writeCount.increment();
} finally {
cacheFileRegion.decRef();
}
gap.onCompletion();
}

@Override
public void onFailure(Exception e) {
gap.onFailure(e);
}
};
listener.onFailure(e);
}
}

private AbstractRunnable fillGapRunnable(SparseFileTracker.Gap gap, RangeMissingHandler writer, ActionListener<Void> listener) {
return ActionRunnable.run(listener.delegateResponse((l, e) -> failGapAndListener(gap, l, e)), () -> {
var ioRef = io;
assert regionOwners.get(ioRef) == CacheFileRegion.this;
assert CacheFileRegion.this.hasReferences() : CacheFileRegion.this;
int start = Math.toIntExact(gap.start());
writer.fillCacheRange(
ioRef,
start,
start,
Math.toIntExact(gap.end() - start),
progress -> gap.onProgress(start + progress)
);
writeCount.increment();
gap.onCompletion();
});
}

private static void releaseAndFail(ActionListener<?> listener, Releasable decrementRef, Exception e) {
private static void failGapAndListener(SparseFileTracker.Gap gap, ActionListener<?> listener, Exception e) {
try {
Releasables.close(decrementRef);
gap.onFailure(e);
} catch (Exception ex) {
e.addSuppressed(ex);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.store.AlreadyClosedException;
import org.elasticsearch.blobcache.BlobCacheUtils;
import org.elasticsearch.blobcache.common.ByteBufferReference;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.core.Streams;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.env.Environment;
Expand Down Expand Up @@ -285,7 +287,7 @@ public IO getFileChannel(int sharedBytesPos) {
return ios[sharedBytesPos];
}

public final class IO {
public final class IO implements RefCounted {

private final long pageStart;

Expand All @@ -298,8 +300,31 @@ private IO(final int sharedBytesPos, MappedByteBuffer mappedByteBuffer) {
this.mappedByteBuffer = mappedByteBuffer;
}

@Override
public boolean tryIncRef() {
return SharedBytes.this.tryIncRef();
}

@Override
public void incRef() {
if (tryIncRef() == false) {
throw new AlreadyClosedException("File channel is closed");
}
}

@Override
public boolean decRef() {
return SharedBytes.this.decRef();
}

@Override
public boolean hasReferences() {
return SharedBytes.this.hasReferences();
}

@SuppressForbidden(reason = "Use positional reads on purpose")
public int read(ByteBuffer dst, int position) throws IOException {
assert hasReferences();
int remaining = dst.remaining();
checkOffsets(position, remaining);
final int bytesRead;
Expand All @@ -316,6 +341,7 @@ public int read(ByteBuffer dst, int position) throws IOException {

@SuppressForbidden(reason = "Use positional writes on purpose")
public int write(ByteBuffer src, int position) throws IOException {
assert hasReferences();
// check if writes are page size aligned for optimal performance
assert position % PAGE_SIZE == 0;
assert src.remaining() % PAGE_SIZE == 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,17 @@
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -1310,4 +1313,68 @@ protected int computeCacheFileRegionSize(long fileLength, int region) {
}
}
}

public void testWriteAndReadCanCompleteAfterSharedBytesCloses() throws Exception {
final long regionSize = size(1L);
Settings settings = Settings.builder()
.put(NODE_NAME_SETTING.getKey(), "node")
.put(SharedBlobCacheService.SHARED_CACHE_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(size(10)).getStringRep())
.put(SharedBlobCacheService.SHARED_CACHE_REGION_SIZE_SETTING.getKey(), ByteSizeValue.ofBytes(regionSize).getStringRep())
.put("path.home", createTempDir())
.build();

final TestThreadPool threadPool = new TestThreadPool(getTestName());
try (NodeEnvironment environment = new NodeEnvironment(settings, TestEnvironment.newEnvironment(settings))) {
var cacheService = new SharedBlobCacheService<>(
environment,
settings,
threadPool,
ThreadPool.Names.GENERIC,
BlobCacheMetrics.NOOP
);

final var cacheKey = generateCacheKey();
final var blobLength = size(12L);

var writeBlocked = new CountDownLatch(1);
var resumeWrites = new CountDownLatch(1);

var entry = cacheService.get(cacheKey, blobLength, 0);
final PlainActionFuture<Integer> future = new PlainActionFuture<>();

entry.populateAndRead(
ByteRange.of(0, regionSize),
ByteRange.of(0, regionSize),
(channel, channelPos, relativePos, length) -> Math.toIntExact(regionSize),
(channel, channelPos, relativePos, length, progressUpdater) -> {
writeBlocked.countDown();
safeAwait(resumeWrites);
SharedBytes.copyToCacheFileAligned(
channel,
new ByteArrayInputStream(randomByteArrayOfLength(SharedBytes.PAGE_SIZE)),
channelPos,
progressUpdater,
ByteBuffer.allocate(SharedBytes.PAGE_SIZE)
);
},
threadPool.generic(),
future
);

safeAwait(writeBlocked);
assertThat(cacheService.getSharedBytes().hasReferences(), is(true));

cacheService.close();
assertThat(cacheService.getSharedBytes().hasReferences(), is(true));

resumeWrites.countDown();

var written = future.get(10L, TimeUnit.SECONDS);
assertThat(written, equalTo((int) regionSize));

assertBusy(() -> assertThat(cacheService.getSharedBytes().hasReferences(), is(false)));
} finally {
assertTrue(ThreadPool.terminate(threadPool, 10L, TimeUnit.SECONDS));
}
}
}

0 comments on commit 666c390

Please sign in to comment.