Skip to content

Commit

Permalink
Fix bytes bug and add uTs for derived source
Browse files Browse the repository at this point in the history
Fixes a bug in the derived source writer where we are reading the entire
bytes array from the bytes ref instead of just the offset+length.

Along with that, touches up the ParentChildHelper (no prod impact) and
also adds some unit tests.

Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Feb 6, 2025
1 parent 48488a8 commit 99dd73c
Show file tree
Hide file tree
Showing 11 changed files with 444 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.StoredFieldsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.codec.KNN9120Codec.DerivedSourceStoredFieldsFormat;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.codec.KNNFormatFacade;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier;

/**
* KNN Codec that wraps the Lucene Codec which is part of Lucene 10.0.1
Expand All @@ -24,12 +28,15 @@ public class KNN10010Codec extends FilterCodec {
private static final KNNCodecVersion VERSION = KNNCodecVersion.V_10_01_0;
private final KNNFormatFacade knnFormatFacade;
private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat;
private final StoredFieldsFormat storedFieldsFormat;

private final MapperService mapperService;

/**
* No arg constructor that uses Lucene99 as the delegate
*/
public KNN10010Codec() {
this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat());
this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat(), null);
}

/**
Expand All @@ -40,10 +47,12 @@ public KNN10010Codec() {
* @param knnVectorsFormat per field format for KnnVector
*/
@Builder
protected KNN10010Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) {
protected KNN10010Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat, MapperService mapperService) {
super(VERSION.getCodecName(), delegate);
knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate);
perFieldKnnVectorsFormat = knnVectorsFormat;
this.mapperService = mapperService;
this.storedFieldsFormat = getStoredFieldsFormat();
}

@Override
Expand All @@ -60,4 +69,36 @@ public CompoundFormat compoundFormat() {
public KnnVectorsFormat knnVectorsFormat() {
return perFieldKnnVectorsFormat;
}

@Override
public StoredFieldsFormat storedFieldsFormat() {
return storedFieldsFormat;
}

private StoredFieldsFormat getStoredFieldsFormat() {
DerivedSourceReadersSupplier derivedSourceReadersSupplier = new DerivedSourceReadersSupplier((segmentReadState) -> {
if (segmentReadState.fieldInfos.hasVectorValues()) {
return knnVectorsFormat().fieldsReader(segmentReadState);
}
return null;
}, (segmentReadState) -> {
if (segmentReadState.fieldInfos.hasDocValues()) {
return docValuesFormat().fieldsProducer(segmentReadState);
}
return null;

}, (segmentReadState) -> {
if (segmentReadState.fieldInfos.hasPostings()) {
return postingsFormat().fieldsProducer(segmentReadState);
}
return null;

}, (segmentReadState -> {
if (segmentReadState.fieldInfos.hasNorms()) {
return normsFormat().normsProducer(segmentReadState);
}
return null;
}));
return new DerivedSourceStoredFieldsFormat(delegate.storedFieldsFormat(), derivedSourceReadersSupplier, mapperService);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public void writeField(FieldInfo fieldInfo, BytesRef bytesRef) throws IOExceptio
// Reference:
// https://github.com/opensearch-project/OpenSearch/blob/2.18.0/server/src/main/java/org/opensearch/index/mapper/SourceFieldMapper.java#L322
Tuple<? extends MediaType, Map<String, Object>> mapTuple = XContentHelper.convertToMap(
BytesReference.fromByteBuffer(ByteBuffer.wrap(bytesRef.bytes)),
BytesReference.fromByteBuffer(ByteBuffer.wrap(bytesRef.bytes, bytesRef.offset, bytesRef.length)),
true,
MediaTypeRegistry.JSON
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public class KNN9120Codec extends FilterCodec {
private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_12_0;
private final KNNFormatFacade knnFormatFacade;
private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat;
private final StoredFieldsFormat storedFieldsFormat;

private final MapperService mapperService;

Expand All @@ -48,6 +49,7 @@ protected KNN9120Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat
knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate);
perFieldKnnVectorsFormat = knnVectorsFormat;
this.mapperService = mapperService;
this.storedFieldsFormat = getStoredFieldsFormat();
}

@Override
Expand All @@ -67,6 +69,10 @@ public KnnVectorsFormat knnVectorsFormat() {

@Override
public StoredFieldsFormat storedFieldsFormat() {
return storedFieldsFormat;
}

private StoredFieldsFormat getStoredFieldsFormat() {
DerivedSourceReadersSupplier derivedSourceReadersSupplier = new DerivedSourceReadersSupplier((segmentReadState) -> {
if (segmentReadState.fieldInfos.hasVectorValues()) {
return knnVectorsFormat().fieldsReader(segmentReadState);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ public class ParentChildHelper {
* this would return "parent.to".
*
* @param field nested field path
* @return parent field path without the child
* @return parent field path without the child. Null if no parent exists
*/
public static String getParentField(String field) {
if (field == null) {
return null;
}
int lastDot = field.lastIndexOf('.');
if (lastDot == -1) {
return null;
Expand All @@ -30,10 +33,16 @@ public static String getParentField(String field) {
* return "child".
*
* @param field nested field path
* @return child field path without the parent path
* @return child field path without the parent path. Null if no child exists
*/
public static String getChildField(String field) {
if (field == null) {
return null;
}
int lastDot = field.lastIndexOf('.');
if (lastDot == -1) {
return null;
}
return field.substring(lastDot + 1);
}

Expand All @@ -46,7 +55,11 @@ public static String getChildField(String field) {
* @return sibling field path
*/
public static String constructSiblingField(String field, String sibling) {
return getParentField(field) + "." + sibling;
String parent = getParentField(field);
if (parent == null) {
return sibling;
}
return parent + "." + sibling;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
/**
* Provides different strategies to extract the vectors from different {@link KNNVectorValuesIterator}
*/
interface VectorValueExtractorStrategy {
public interface VectorValueExtractorStrategy {

/**
* Extract a float vector from KNNVectorValuesIterator.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN9120Codec;

import lombok.SneakyThrows;
import org.apache.lucene.codecs.StoredFieldsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.util.BytesRef;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.codec.KNNCodecTestUtil;

import java.util.List;
import java.util.Map;

import static org.mockito.Mockito.mock;

public class DerivedSourceStoredFieldsWriterTests extends KNNTestCase {

@SneakyThrows
public void testWriteField() {
StoredFieldsWriter delegate = mock(StoredFieldsWriter.class);
FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("_source").build();
List<String> fields = List.of("test");

DerivedSourceStoredFieldsWriter derivedSourceStoredFieldsWriter = new DerivedSourceStoredFieldsWriter(delegate, fields);

Map<String, Object> source = Map.of("test", new float[] { 1.0f, 2.0f, 3.0f }, "text_field", "text_value");
BytesStreamOutput bStream = new BytesStreamOutput();
XContentBuilder builder = MediaTypeRegistry.contentBuilder(MediaTypeRegistry.JSON, bStream).map(source);
builder.close();
byte[] originalBytes = bStream.bytes().toBytesRef().bytes;
byte[] shiftedBytes = new byte[originalBytes.length + 2];
System.arraycopy(originalBytes, 0, shiftedBytes, 1, originalBytes.length);
derivedSourceStoredFieldsWriter.writeField(fieldInfo, new BytesRef(shiftedBytes, 1, originalBytes.length));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.derivedsource;

import org.apache.lucene.index.StoredFieldVisitor;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.codec.KNNCodecTestUtil;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class DerivedSourceStoredFieldVisitorTests extends KNNTestCase {

public void testBinaryField() throws Exception {
StoredFieldVisitor delegate = mock(StoredFieldVisitor.class);
doAnswer(invocationOnMock -> null).when(delegate).binaryField(any(), any());
DerivedSourceVectorInjector derivedSourceVectorInjector = mock(DerivedSourceVectorInjector.class);
when(derivedSourceVectorInjector.injectVectors(anyInt(), any())).thenReturn(new byte[0]);
DerivedSourceStoredFieldVisitor derivedSourceStoredFieldVisitor = new DerivedSourceStoredFieldVisitor(
delegate,
0,
derivedSourceVectorInjector
);

// When field is not _source, then do not call the injector
derivedSourceStoredFieldVisitor.binaryField(KNNCodecTestUtil.FieldInfoBuilder.builder("test").build(), null);
verify(derivedSourceVectorInjector, times(0)).injectVectors(anyInt(), any());
verify(delegate, times(1)).binaryField(any(), any());

// When field is not _source, then do call the injector
derivedSourceStoredFieldVisitor.binaryField(KNNCodecTestUtil.FieldInfoBuilder.builder("_source").build(), null);
verify(derivedSourceVectorInjector, times(1)).injectVectors(anyInt(), any());
verify(delegate, times(2)).binaryField(any(), any());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.derivedsource;

import lombok.SneakyThrows;
import org.apache.lucene.index.FieldInfo;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.MediaType;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.codec.KNNCodecTestUtil;

import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.mockito.ArgumentMatchers.any;

public class DerivedSourceVectorInjectorTests extends KNNTestCase {

@SneakyThrows
@SuppressWarnings("unchecked")
public void testInjectVectors() {
List<FieldInfo> fields = List.of(
KNNCodecTestUtil.FieldInfoBuilder.builder("test1").build(),
KNNCodecTestUtil.FieldInfoBuilder.builder("test2").build(),
KNNCodecTestUtil.FieldInfoBuilder.builder("test3").build()
);

Map<String, float[]> fieldToVector = Collections.unmodifiableMap(new HashMap<>() {
{
put("test1", new float[] { 1.0f, 2.0f, 3.0f });
put("test2", new float[] { 4.0f, 5.0f, 6.0f, 7.0f });
put("test3", new float[] { 7.0f, 8.0f, 9.0f, 1.0f, 3.0f, 4.0f });
put("test4", null);
}
});

try (MockedStatic<PerFieldDerivedVectorInjectorFactory> factory = Mockito.mockStatic(PerFieldDerivedVectorInjectorFactory.class)) {
factory.when(() -> PerFieldDerivedVectorInjectorFactory.create(any(), any(), any())).thenAnswer(invocation -> {
FieldInfo fieldInfo = invocation.getArgument(0);
return (PerFieldDerivedVectorInjector) (docId, sourceAsMap) -> {
float[] vector = fieldToVector.get(fieldInfo.name);
if (vector != null) {
sourceAsMap.put(fieldInfo.name, vector);
}
};
});

DerivedSourceVectorInjector derivedSourceVectorInjector = new DerivedSourceVectorInjector(
new DerivedSourceReadersSupplier(s -> null, s -> null, s -> null, s -> null),
null,
fields
);

int docId = 2;
String existingFieldKey = "existingField";
String existingFieldValue = "existingField";
Map<String, Object> source = Map.of(existingFieldKey, existingFieldValue);
byte[] originalSourceBytes = mapToBytes(source);
byte[] modifiedSourceByttes = derivedSourceVectorInjector.injectVectors(docId, originalSourceBytes);
Map<String, Object> modifiedSource = bytesToMap(modifiedSourceByttes);

assertEquals(existingFieldValue, modifiedSource.get(existingFieldKey));

assertArrayEquals(fieldToVector.get("test1"), toFloatArray((List<Double>) modifiedSource.get("test1")), 0.000001f);
assertArrayEquals(fieldToVector.get("test2"), toFloatArray((List<Double>) modifiedSource.get("test2")), 0.000001f);
assertArrayEquals(fieldToVector.get("test3"), toFloatArray((List<Double>) modifiedSource.get("test3")), 0.000001f);
assertFalse(modifiedSource.containsKey("test4"));
}
}

@SneakyThrows
private byte[] mapToBytes(Map<String, Object> map) {

BytesStreamOutput bStream = new BytesStreamOutput(1024);
XContentBuilder builder = MediaTypeRegistry.contentBuilder(MediaTypeRegistry.JSON, bStream).map(map);
builder.close();
return BytesReference.toBytes(BytesReference.bytes(builder));
}

private float[] toFloatArray(List<Double> list) {
float[] array = new float[list.size()];
for (int i = 0; i < list.size(); i++) {
array[i] = list.get(i).floatValue();
}
return array;
}

private Map<String, Object> bytesToMap(byte[] bytes) {
Tuple<? extends MediaType, Map<String, Object>> mapTuple = XContentHelper.convertToMap(
BytesReference.fromByteBuffer(ByteBuffer.wrap(bytes)),
true,
MediaTypeRegistry.getDefaultMediaType()
);

return mapTuple.v2();
}

@SneakyThrows
public void testShouldInject() {

List<FieldInfo> fields = List.of(
KNNCodecTestUtil.FieldInfoBuilder.builder("test1").build(),
KNNCodecTestUtil.FieldInfoBuilder.builder("test2").build(),
KNNCodecTestUtil.FieldInfoBuilder.builder("test3").build()
);

try (
DerivedSourceVectorInjector vectorInjector = new DerivedSourceVectorInjector(
new DerivedSourceReadersSupplier(s -> null, s -> null, s -> null, s -> null),
null,
fields
)
) {
assertTrue(vectorInjector.shouldInject(null, null));
assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, null));
assertTrue(vectorInjector.shouldInject(new String[] { "test1", "test2", "test3" }, null));
assertTrue(vectorInjector.shouldInject(null, new String[] { "test2" }));
assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, new String[] { "test2" }));
assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, new String[] { "test2", "test3" }));
assertFalse(vectorInjector.shouldInject(null, new String[] { "test1", "test2", "test3" }));
}
}
}
Loading

0 comments on commit 99dd73c

Please sign in to comment.