Skip to content

Commit

Permalink
Add the include/exclude optimization
Browse files Browse the repository at this point in the history
Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Jan 28, 2025
1 parent 1a41d51 commit 16c255d
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lombok.Setter;
import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.index.StoredFieldVisitor;
import org.opensearch.index.fieldvisitor.FieldsVisitor;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceStoredFieldVisitor;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceVectorInjector;

Expand All @@ -25,7 +26,15 @@ public class DerivedSourceStoredFieldsReader extends StoredFieldsReader {

@Override
public void document(int docId, StoredFieldVisitor storedFieldVisitor) throws IOException {
if (shouldInject) {
// If the visitor has explicitly indicated it does not need the fields, we should not inject them
boolean isVisitorNeedFields = true;
if (storedFieldVisitor instanceof FieldsVisitor) {
isVisitorNeedFields = derivedSourceVectorInjector.shouldInject(
((FieldsVisitor) storedFieldVisitor).includes(),
((FieldsVisitor) storedFieldVisitor).excludes()
);
}
if (shouldInject && isVisitorNeedFields) {
delegate.document(docId, new DerivedSourceStoredFieldVisitor(storedFieldVisitor, docId, derivedSourceVectorInjector));
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ public class DerivedSourceStoredFieldVisitor extends StoredFieldVisitor {

@Override
public void binaryField(FieldInfo fieldInfo, byte[] value) throws IOException {
// TODO: Add skip condition here if the delegate specifies which fields are not required for source
if (fieldInfo.name.equals(SourceFieldMapper.NAME)) {
delegate.binaryField(fieldInfo, derivedSourceVectorInjector.injectVectors(documentId, value));
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* This class is responsible for injecting vectors into the source of a document. From a high level, it uses alternative
Expand All @@ -31,6 +33,7 @@
public class DerivedSourceVectorInjector {

private final List<PerFieldDerivedVectorInjector> perFieldDerivedVectorInjectors;
private final Set<String> fieldNames;

/**
* Constructor for DerivedSourceVectorInjector.
Expand All @@ -46,10 +49,12 @@ public DerivedSourceVectorInjector(
) throws IOException {
DerivedSourceReaders derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState);
this.perFieldDerivedVectorInjectors = new ArrayList<>();
this.fieldNames = new HashSet<>();
for (FieldInfo fieldInfo : fieldsToInjectVector) {
this.perFieldDerivedVectorInjectors.add(
PerFieldDerivedVectorInjectorFactory.create(fieldInfo, derivedSourceReaders, segmentReadState)
);
this.fieldNames.add(fieldInfo.name);
}
}

Expand Down Expand Up @@ -84,4 +89,34 @@ public byte[] injectVectors(Integer docId, byte[] sourceAsBytes) throws IOExcept
builder.close();
return BytesReference.toBytes(BytesReference.bytes(builder));
}

/**
 * Decides whether vectors need to be injected into the source, given the fields a caller has
 * explicitly asked to include or exclude.
 * <p>
 * Entries are compared against the vector field names by exact name; patterns are not expanded here.
 * The rules, in order, are:
 * <ol>
 *   <li>If any vector field is explicitly included, inject.</li>
 *   <li>Otherwise, if excludes are provided, inject only when at least one vector field is
 *       <em>not</em> excluded. If every vector field is excluded there is nothing left to inject.</li>
 *   <li>Otherwise (no excludes provided), inject.</li>
 * </ol>
 *
 * @param includes names of fields that are explicitly required; may be {@code null}
 * @param excludes names of fields that are explicitly not required; may be {@code null}
 * @return {@code true} if vectors should be injected, {@code false} otherwise
 */
public boolean shouldInject(String[] includes, String[] excludes) {
    // If any of the vector fields are explicitly required we should inject
    if (includes != null) {
        for (String includedField : includes) {
            if (fieldNames.contains(includedField)) {
                return true;
            }
        }
    }

    // If all of the vector fields are explicitly excluded we should not inject
    if (excludes != null) {
        // Track distinct names so a field repeated in excludes is not counted more than once
        Set<String> excludedVectorFields = new HashSet<>();
        for (String excludedField : excludes) {
            if (fieldNames.contains(excludedField)) {
                excludedVectorFields.add(excludedField);
            }
        }
        // Inject only while at least one vector field has not been excluded
        return excludedVectorFields.size() < fieldNames.size();
    }
    return true;
}
}

0 comments on commit 16c255d

Please sign in to comment.