Skip to content

Commit

Permalink
Fix bytes offset bug and duplicate readers and add uTs for derived so…
Browse files Browse the repository at this point in the history
…urce (#2494)

Fixes a bug in the derived source writer where we are reading the entire
bytes array from the bytes ref instead of just the offset+length.

Also reuses readers to prevent memory leak

Along with that, touches up the ParentChildHelper (no prod impact) and
also adds some unit tests.

Signed-off-by: John Mazanec <[email protected]>
(cherry picked from commit ab33538)
  • Loading branch information
jmazanec15 authored and github-actions[bot] committed Feb 6, 2025
1 parent c8b6637 commit 72a8d6e
Show file tree
Hide file tree
Showing 15 changed files with 419 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReaders;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;

Expand Down Expand Up @@ -55,11 +56,14 @@ public StoredFieldsReader fieldsReader(Directory directory, SegmentInfo segmentI
if (derivedVectorFields == null || derivedVectorFields.isEmpty()) {
return delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext);
}

SegmentReadState segmentReadState = new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext);
DerivedSourceReaders derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState);
return new DerivedSourceStoredFieldsReader(
delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext),
derivedVectorFields,
derivedSourceReadersSupplier,
new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext)
derivedSourceReaders,
segmentReadState
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,25 @@

package org.opensearch.knn.index.codec.KNN9120Codec;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.StoredFieldVisitor;
import org.apache.lucene.util.IOUtils;
import org.opensearch.index.fieldvisitor.FieldsVisitor;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReaders;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceStoredFieldVisitor;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceVectorInjector;

import java.io.IOException;
import java.util.List;

@Log4j2
public class DerivedSourceStoredFieldsReader extends StoredFieldsReader {
private final StoredFieldsReader delegate;
private final List<FieldInfo> derivedVectorFields;
private final DerivedSourceReadersSupplier derivedSourceReadersSupplier;
private final DerivedSourceReaders derivedSourceReaders;
private final SegmentReadState segmentReadState;
private final boolean shouldInject;

Expand All @@ -31,36 +33,36 @@ public class DerivedSourceStoredFieldsReader extends StoredFieldsReader {
*
* @param delegate delegate StoredFieldsReader
* @param derivedVectorFields List of fields that are derived source fields
* @param derivedSourceReadersSupplier Supplier for the derived source readers
* @param derivedSourceReaders Derived source readers
* @param segmentReadState SegmentReadState for the segment
* @throws IOException in case of I/O error
*/
public DerivedSourceStoredFieldsReader(
StoredFieldsReader delegate,
List<FieldInfo> derivedVectorFields,
DerivedSourceReadersSupplier derivedSourceReadersSupplier,
DerivedSourceReaders derivedSourceReaders,
SegmentReadState segmentReadState
) throws IOException {
this(delegate, derivedVectorFields, derivedSourceReadersSupplier, segmentReadState, true);
this(delegate, derivedVectorFields, derivedSourceReaders, segmentReadState, true);
}

private DerivedSourceStoredFieldsReader(
StoredFieldsReader delegate,
List<FieldInfo> derivedVectorFields,
DerivedSourceReadersSupplier derivedSourceReadersSupplier,
DerivedSourceReaders derivedSourceReaders,
SegmentReadState segmentReadState,
boolean shouldInject
) throws IOException {
this.delegate = delegate;
this.derivedVectorFields = derivedVectorFields;
this.derivedSourceReadersSupplier = derivedSourceReadersSupplier;
this.derivedSourceReaders = derivedSourceReaders;
this.segmentReadState = segmentReadState;
this.shouldInject = shouldInject;
this.derivedSourceVectorInjector = createDerivedSourceVectorInjector();
}

private DerivedSourceVectorInjector createDerivedSourceVectorInjector() throws IOException {
return new DerivedSourceVectorInjector(derivedSourceReadersSupplier, segmentReadState, derivedVectorFields);
private DerivedSourceVectorInjector createDerivedSourceVectorInjector() {
return new DerivedSourceVectorInjector(derivedSourceReaders, segmentReadState, derivedVectorFields);
}

@Override
Expand All @@ -86,7 +88,7 @@ public StoredFieldsReader clone() {
return new DerivedSourceStoredFieldsReader(
delegate.clone(),
derivedVectorFields,
derivedSourceReadersSupplier,
derivedSourceReaders.clone(),
segmentReadState,
shouldInject
);
Expand All @@ -102,6 +104,7 @@ public void checkIntegrity() throws IOException {

@Override
public void close() throws IOException {
log.debug("Closing derived source stored fields reader for segment: " + segmentReadState.segmentInfo.name);
IOUtils.close(delegate, derivedSourceVectorInjector);
}

Expand All @@ -117,7 +120,7 @@ private StoredFieldsReader cloneForMerge() {
return new DerivedSourceStoredFieldsReader(
delegate.getMergeInstance(),
derivedVectorFields,
derivedSourceReadersSupplier,
derivedSourceReaders.clone(),
segmentReadState,
false
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public void writeField(FieldInfo fieldInfo, BytesRef bytesRef) throws IOExceptio
// Reference:
// https://github.com/opensearch-project/OpenSearch/blob/2.18.0/server/src/main/java/org/opensearch/index/mapper/SourceFieldMapper.java#L322
Tuple<? extends MediaType, Map<String, Object>> mapTuple = XContentHelper.convertToMap(
BytesReference.fromByteBuffer(ByteBuffer.wrap(bytesRef.bytes)),
BytesReference.fromByteBuffer(ByteBuffer.wrap(bytesRef.bytes, bytesRef.offset, bytesRef.length)),
true,
MediaTypeRegistry.JSON
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public class KNN9120Codec extends FilterCodec {
private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_12_0;
private final KNNFormatFacade knnFormatFacade;
private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat;
private final StoredFieldsFormat storedFieldsFormat;

private final MapperService mapperService;

Expand All @@ -48,6 +49,7 @@ protected KNN9120Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat
knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate);
perFieldKnnVectorsFormat = knnVectorsFormat;
this.mapperService = mapperService;
this.storedFieldsFormat = getStoredFieldsFormat();
}

@Override
Expand All @@ -67,6 +69,10 @@ public KnnVectorsFormat knnVectorsFormat() {

@Override
public StoredFieldsFormat storedFieldsFormat() {
return storedFieldsFormat;
}

private StoredFieldsFormat getStoredFieldsFormat() {
DerivedSourceReadersSupplier derivedSourceReadersSupplier = new DerivedSourceReadersSupplier((segmentReadState) -> {
if (segmentReadState.fieldInfos.hasVectorValues()) {
return knnVectorsFormat().fieldsReader(segmentReadState);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.codecs.FieldsProducer;
import org.apache.lucene.codecs.KnnVectorsReader;
Expand All @@ -23,7 +24,8 @@
*/
@RequiredArgsConstructor
@Getter
public class DerivedSourceReaders implements Closeable {
@Log4j2
public class DerivedSourceReaders implements Cloneable, Closeable {
@Nullable
private final KnnVectorsReader knnVectorsReader;
@Nullable
Expand All @@ -33,8 +35,17 @@ public class DerivedSourceReaders implements Closeable {
@Nullable
private final NormsProducer normsProducer;

private final boolean isCloned;

@Override
public void close() throws IOException {
IOUtils.close(knnVectorsReader, docValuesProducer, fieldsProducer, normsProducer);
if (isCloned == false) {
IOUtils.close(knnVectorsReader, docValuesProducer, fieldsProducer, normsProducer);
}
}

@Override
public DerivedSourceReaders clone() {
return new DerivedSourceReaders(knnVectorsReader, docValuesProducer, fieldsProducer, normsProducer, true);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class DerivedSourceReadersSupplier {
private final DerivedSourceReaderSupplier<NormsProducer> normsProducer;

/**
* Get the readers for the segment
* Get the readers for the segment.
*
* @param state SegmentReadState
* @return DerivedSourceReaders
Expand All @@ -38,7 +38,8 @@ public DerivedSourceReaders getReaders(SegmentReadState state) throws IOExceptio
knnVectorsReaderSupplier.apply(state),
docValuesProducerSupplier.apply(state),
fieldsProducerSupplier.apply(state),
normsProducer.apply(state)
normsProducer.apply(state),
false
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ public class DerivedSourceVectorInjector implements Closeable {
/**
* Constructor for DerivedSourceVectorInjector.
*
* @param derivedSourceReadersSupplier Supplier for the derived source readers.
* @param derivedSourceReaders Derived source readers.
* @param segmentReadState Segment read state
* @param fieldsToInjectVector List of fields to inject vectors into
*/
public DerivedSourceVectorInjector(
DerivedSourceReadersSupplier derivedSourceReadersSupplier,
DerivedSourceReaders derivedSourceReaders,
SegmentReadState segmentReadState,
List<FieldInfo> fieldsToInjectVector
) throws IOException {
this.derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState);
) {
this.derivedSourceReaders = derivedSourceReaders;
this.perFieldDerivedVectorInjectors = new ArrayList<>();
this.fieldNames = new HashSet<>();
for (FieldInfo fieldInfo : fieldsToInjectVector) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ public class ParentChildHelper {
* this would return "parent.to".
*
* @param field nested field path
* @return parent field path without the child
* @return parent field path without the child. Null if no parent exists
*/
public static String getParentField(String field) {
if (field == null) {
return null;
}
int lastDot = field.lastIndexOf('.');
if (lastDot == -1) {
return null;
Expand All @@ -30,10 +33,16 @@ public static String getParentField(String field) {
* return "child".
*
* @param field nested field path
* @return child field path without the parent path
* @return child field path without the parent path. Null if no child exists
*/
public static String getChildField(String field) {
if (field == null) {
return null;
}
int lastDot = field.lastIndexOf('.');
if (lastDot == -1) {
return null;
}
return field.substring(lastDot + 1);
}

Expand All @@ -46,7 +55,11 @@ public static String getChildField(String field) {
* @return sibling field path
*/
public static String constructSiblingField(String field, String sibling) {
return getParentField(field) + "." + sibling;
String parent = getParentField(field);
if (parent == null) {
return sibling;
}
return parent + "." + sibling;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
/**
* Provides different strategies to extract the vectors from different {@link KNNVectorValuesIterator}
*/
interface VectorValueExtractorStrategy {
public interface VectorValueExtractorStrategy {

/**
* Extract a float vector from KNNVectorValuesIterator.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN9120Codec;

import lombok.SneakyThrows;
import org.apache.lucene.codecs.StoredFieldsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.util.BytesRef;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.codec.KNNCodecTestUtil;

import java.util.List;
import java.util.Map;

import static org.mockito.Mockito.mock;

public class DerivedSourceStoredFieldsWriterTests extends KNNTestCase {

@SneakyThrows
public void testWriteField() {
StoredFieldsWriter delegate = mock(StoredFieldsWriter.class);
FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("_source").build();
List<String> fields = List.of("test");

DerivedSourceStoredFieldsWriter derivedSourceStoredFieldsWriter = new DerivedSourceStoredFieldsWriter(delegate, fields);

Map<String, Object> source = Map.of("test", new float[] { 1.0f, 2.0f, 3.0f }, "text_field", "text_value");
BytesStreamOutput bStream = new BytesStreamOutput();
XContentBuilder builder = MediaTypeRegistry.contentBuilder(MediaTypeRegistry.JSON, bStream).map(source);
builder.close();
byte[] originalBytes = bStream.bytes().toBytesRef().bytes;
byte[] shiftedBytes = new byte[originalBytes.length + 2];
System.arraycopy(originalBytes, 0, shiftedBytes, 1, originalBytes.length);
derivedSourceStoredFieldsWriter.writeField(fieldInfo, new BytesRef(shiftedBytes, 1, originalBytes.length));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.derivedsource;

import org.apache.lucene.index.StoredFieldVisitor;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.codec.KNNCodecTestUtil;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class DerivedSourceStoredFieldVisitorTests extends KNNTestCase {

public void testBinaryField() throws Exception {
StoredFieldVisitor delegate = mock(StoredFieldVisitor.class);
doAnswer(invocationOnMock -> null).when(delegate).binaryField(any(), any());
DerivedSourceVectorInjector derivedSourceVectorInjector = mock(DerivedSourceVectorInjector.class);
when(derivedSourceVectorInjector.injectVectors(anyInt(), any())).thenReturn(new byte[0]);
DerivedSourceStoredFieldVisitor derivedSourceStoredFieldVisitor = new DerivedSourceStoredFieldVisitor(
delegate,
0,
derivedSourceVectorInjector
);

// When field is not _source, then do not call the injector
derivedSourceStoredFieldVisitor.binaryField(KNNCodecTestUtil.FieldInfoBuilder.builder("test").build(), null);
verify(derivedSourceVectorInjector, times(0)).injectVectors(anyInt(), any());
verify(delegate, times(1)).binaryField(any(), any());

// When field is not _source, then do call the injector
derivedSourceStoredFieldVisitor.binaryField(KNNCodecTestUtil.FieldInfoBuilder.builder("_source").build(), null);
verify(derivedSourceVectorInjector, times(1)).injectVectors(anyInt(), any());
verify(delegate, times(2)).binaryField(any(), any());
}
}
Loading

0 comments on commit 72a8d6e

Please sign in to comment.