Skip to content

Commit

Permalink
Add the include/exclude optimization
Browse files Browse the repository at this point in the history
Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Jan 28, 2025
1 parent 31de672 commit c7db83e
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 22 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add a new build mode, `FAISS_OPT_LEVEL=avx512_spr`, which enables the use of advanced AVX-512 instructions introduced with Intel(R) Sapphire Rapids (#2404)[https://github.com/opensearch-project/k-NN/pull/2404]
- Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376]
- Add concurrency optimizations with native memory graph loading and force eviction (#2265) [https://github.com/opensearch-project/k-NN/pull/2345]

- Add derived source feature for vector fields (#2449)[https://github.com/opensearch-project/k-NN/pull/2449]
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lombok.Setter;
import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.index.StoredFieldVisitor;
import org.opensearch.index.fieldvisitor.FieldsVisitor;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceStoredFieldVisitor;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceVectorInjector;

Expand All @@ -25,7 +26,15 @@ public class DerivedSourceStoredFieldsReader extends StoredFieldsReader {

@Override
public void document(int docId, StoredFieldVisitor storedFieldVisitor) throws IOException {
if (shouldInject) {
// If the visitor has explicitly indicated it does not need the fields, we should not inject them
boolean isVisitorNeedFields = true;
if (storedFieldVisitor instanceof FieldsVisitor) {
isVisitorNeedFields = derivedSourceVectorInjector.shouldInject(
((FieldsVisitor) storedFieldVisitor).includes(),
((FieldsVisitor) storedFieldVisitor).excludes()
);
}
if (shouldInject && isVisitorNeedFields) {
delegate.document(docId, new DerivedSourceStoredFieldVisitor(storedFieldVisitor, docId, derivedSourceVectorInjector));
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ public class DerivedSourceStoredFieldVisitor extends StoredFieldVisitor {

@Override
public void binaryField(FieldInfo fieldInfo, byte[] value) throws IOException {
// TODO: Add skip condition here if the delegate specifies which fields are not required for source
if (fieldInfo.name.equals(SourceFieldMapper.NAME)) {
delegate.binaryField(fieldInfo, derivedSourceVectorInjector.injectVectors(documentId, value));
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* This class is responsible for injecting vectors into the source of a document. From a high level, it uses alternative
Expand All @@ -31,6 +33,7 @@
public class DerivedSourceVectorInjector {

private final List<PerFieldDerivedVectorInjector> perFieldDerivedVectorInjectors;
private final Set<String> fieldNames;

/**
* Constructor for DerivedSourceVectorInjector.
Expand All @@ -46,10 +49,12 @@ public DerivedSourceVectorInjector(
) throws IOException {
DerivedSourceReaders derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState);
this.perFieldDerivedVectorInjectors = new ArrayList<>();
this.fieldNames = new HashSet<>();
for (FieldInfo fieldInfo : fieldsToInjectVector) {
this.perFieldDerivedVectorInjectors.add(
PerFieldDerivedVectorInjectorFactory.create(fieldInfo, derivedSourceReaders, segmentReadState)
);
this.fieldNames.add(fieldInfo.name);
}
}

Expand Down Expand Up @@ -84,4 +89,34 @@ public byte[] injectVectors(Integer docId, byte[] sourceAsBytes) throws IOExcept
builder.close();
return BytesReference.toBytes(BytesReference.bytes(builder));
}

/**
* Whether or not to inject vectors based on what fields are explicitly required
*
* @param includes List of fields that are required to be injected
* @param excludes List of fields that are not required to be injected
* @return true if vectors should be injected, false otherwise
*/
public boolean shouldInject(String[] includes, String[] excludes) {
// If any of the vector fields are explicitly required we should inject
if (includes != null) {
for (String includedField : includes) {
if (fieldNames.contains(includedField)) {
return true;
}
}
}

// If all of the vector fields are explicitly excluded we should not inject
if (excludes != null) {
int excludedVectorFieldCount = 0;
for (String excludedField : excludes) {
if (fieldNames.contains(excludedField)) {
excludedVectorFieldCount++;
}
}
return excludedVectorFieldCount >= fieldNames.size();
}
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.index.codec.derivedsource;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.PostingsEnum;
Expand All @@ -26,22 +27,13 @@
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

@Log4j2
@AllArgsConstructor
public class NestedPerFieldDerivedVectorInjector implements PerFieldDerivedVectorInjector {

private final FieldInfo childFieldInfo;
private final DerivedSourceReaders derivedSourceReaders;
private final SegmentReadState segmentReadState;

public NestedPerFieldDerivedVectorInjector(
FieldInfo childFieldInfo,
DerivedSourceReaders derivedSourceReaders,
SegmentReadState segmentReadState
) {
this.childFieldInfo = childFieldInfo;
this.derivedSourceReaders = derivedSourceReaders;
this.segmentReadState = segmentReadState;
}

@Override
public void inject(Integer parentDocId, Map<String, Object> sourceAsMap) throws IOException {
// Setup the iterator. Return if not-relevant
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,6 @@ public int firstChild() {
return previousParentDocId + 1;
}

/**
* Get the number of children for this parent.
*
* @return number of children for this parent
*/
public int numChildren() {
return children.size();
}

/**
* Get the next child for this parent
*
Expand Down

0 comments on commit c7db83e

Please sign in to comment.