diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index 19944c11b..c2a4bb9b7 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -6,6 +6,7 @@ package org.opensearch.neuralsearch.processor; import java.util.ArrayList; +import java.util.Arrays; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -27,8 +28,10 @@ import com.google.common.collect.ImmutableMap; /** - * The abstract class for text processing use cases. Users provide a field name map and a model id. - * During ingestion, the processor will use the corresponding model to inference the input texts, + * The abstract class for text processing use cases. Users provide a field name + * map and a model id. + * During ingestion, the processor will use the corresponding model to inference + * the input texts, * and set the target fields according to the field name map. */ @Log4j2 @@ -39,7 +42,8 @@ public abstract class InferenceProcessor extends AbstractProcessor { private final String type; - // This field is used for nested knn_vector/rank_features field. The value of the field will be used as the + // This field is used for nested knn_vector/rank_features field. The value of + // the field will be used as the // default key for the nested object. 
private final String listTypeNestedMapKey; @@ -52,18 +56,18 @@ public abstract class InferenceProcessor extends AbstractProcessor { private final Environment environment; public InferenceProcessor( - String tag, - String description, - String type, - String listTypeNestedMapKey, - String modelId, - Map fieldMap, - MLCommonsClientAccessor clientAccessor, - Environment environment - ) { + String tag, + String description, + String type, + String listTypeNestedMapKey, + String modelId, + Map fieldMap, + MLCommonsClientAccessor clientAccessor, + Environment environment) { super(tag, description); this.type = type; - if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, cannot process it"); + if (StringUtils.isBlank(modelId)) + throw new IllegalArgumentException("model_id is null or empty, cannot process it"); validateEmbeddingConfiguration(fieldMap); this.listTypeNestedMapKey = listTypeNestedMapKey; @@ -75,22 +79,23 @@ public InferenceProcessor( private void validateEmbeddingConfiguration(Map fieldMap) { if (fieldMap == null - || fieldMap.size() == 0 - || fieldMap.entrySet() - .stream() - .anyMatch( - x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) - )) { + || fieldMap.size() == 0 + || fieldMap.entrySet() + .stream() + .anyMatch( + x -> StringUtils.startsWith(x.getKey(), ".") || StringUtils.endsWith(x.getKey(), ".") + || Arrays.stream(x.getKey().split("\\.")).anyMatch(y -> StringUtils.isBlank(y)) + || Objects.isNull(x.getValue()) + || StringUtils.isBlank(x.getValue().toString()))) { throw new IllegalArgumentException("Unable to create the processor as field_map has invalid key or value"); } } public abstract void doExecute( - IngestDocument ingestDocument, - Map ProcessMap, - List inferenceList, - BiConsumer handler - ); + IngestDocument ingestDocument, + Map ProcessMap, + List inferenceList, + BiConsumer handler); @Override public IngestDocument 
execute(IngestDocument ingestDocument) throws Exception { @@ -98,10 +103,14 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception { } /** - * This method will be invoked by PipelineService to make async inference and then delegate the handler to + * This method will be invoked by PipelineService to make async inference and + * then delegate the handler to * process the inference response or failure. - * @param ingestDocument {@link IngestDocument} which is the document passed to processor. - * @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. + * + * @param ingestDocument {@link IngestDocument} which is the document passed to + * processor. + * @param handler {@link BiConsumer} which is the handler which can be + * used after the inference task is done. */ @Override public void execute(IngestDocument ingestDocument, BiConsumer handler) { @@ -142,7 +151,8 @@ private void createInferenceListForMapTypeInput(Object sourceValue, List } else if (sourceValue instanceof List) { texts.addAll(((List) sourceValue)); } else { - if (sourceValue == null) return; + if (sourceValue == null) + return; texts.add(sourceValue.toString()); } } @@ -154,9 +164,20 @@ Map buildMapWithProcessorKeyAndOriginalValue(IngestDocument inge for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { String originalKey = fieldMapEntry.getKey(); Object targetKey = fieldMapEntry.getValue(); + + int nestedDotIndex = originalKey.indexOf('.'); + if (nestedDotIndex != -1) { + Map newTargetKey = new LinkedHashMap<>(); + newTargetKey.put(originalKey.substring(nestedDotIndex + 1), targetKey); + targetKey = newTargetKey; + + originalKey = originalKey.substring(0, nestedDotIndex); + } + if (targetKey instanceof Map) { Map treeRes = new LinkedHashMap<>(); - buildMapWithProcessorKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes); + buildMapWithProcessorKeyAndOriginalValueForMapType(originalKey, 
targetKey, sourceAndMetadataMap, + treeRes); mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey)); } else { mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey)); @@ -166,21 +187,20 @@ Map buildMapWithProcessorKeyAndOriginalValue(IngestDocument inge } private void buildMapWithProcessorKeyAndOriginalValueForMapType( - String parentKey, - Object processorKey, - Map sourceAndMetadataMap, - Map treeRes - ) { - if (processorKey == null || sourceAndMetadataMap == null) return; + String parentKey, + Object processorKey, + Map sourceAndMetadataMap, + Map treeRes) { + if (processorKey == null || sourceAndMetadataMap == null) + return; if (processorKey instanceof Map) { Map next = new LinkedHashMap<>(); for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { buildMapWithProcessorKeyAndOriginalValueForMapType( - nestedFieldMapEntry.getKey(), - nestedFieldMapEntry.getValue(), - (Map) sourceAndMetadataMap.get(parentKey), - next - ); + nestedFieldMapEntry.getKey(), + nestedFieldMapEntry.getValue(), + (Map) sourceAndMetadataMap.get(parentKey), + next); } treeRes.put(parentKey, next); } else { @@ -199,9 +219,11 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { validateNestedTypeValue(sourceKey, sourceValue, () -> 1); } else if (!String.class.isAssignableFrom(sourceValueClass)) { - throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it"); + throw new IllegalArgumentException( + "field [" + sourceKey + "] is neither string nor nested type, cannot process it"); } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it"); + throw new IllegalArgumentException( + "field [" + sourceKey + "] has empty string value, cannot process it"); 
} } } @@ -211,18 +233,21 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { int maxDepth = maxDepthSupplier.get(); if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it"); + throw new IllegalArgumentException( + "map type field [" + sourceKey + "] reached max depth limit, cannot process it"); } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { validateListTypeValue(sourceKey, sourceValue); } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { ((Map) sourceValue).values() - .stream() - .filter(Objects::nonNull) - .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); + .stream() + .filter(Objects::nonNull) + .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it"); + throw new IllegalArgumentException( + "map type field [" + sourceKey + "] has non-string type, cannot process it"); } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it"); + throw new IllegalArgumentException( + "map type field [" + sourceKey + "] has empty string, cannot process it"); } } @@ -232,14 +257,17 @@ private void validateListTypeValue(String sourceKey, Object sourceValue) { if (value == null) { throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it"); } else if (!(value instanceof String)) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it"); + throw 
new IllegalArgumentException( + "list type field [" + sourceKey + "] has non string value, cannot process it"); } else if (StringUtils.isBlank(value.toString())) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it"); + throw new IllegalArgumentException( + "list type field [" + sourceKey + "] has empty string, cannot process it"); } } } - protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map processorMap, List results) { + protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map processorMap, + List results) { Objects.requireNonNull(results, "embedding failed, inference returns null result!"); log.debug("Model inference result fetched, starting build vector output!"); Map nlpResult = buildNLPResult(processorMap, results, ingestDocument.getSourceAndMetadata()); @@ -248,7 +276,8 @@ protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map buildNLPResult(Map processorMap, List results, Map sourceAndMetadataMap) { + Map buildNLPResult(Map processorMap, List results, + Map sourceAndMetadataMap) { IndexWrapper indexWrapper = new IndexWrapper(0); Map result = new LinkedHashMap<>(); for (Map.Entry knnMapEntry : processorMap.entrySet()) { @@ -267,34 +296,36 @@ Map buildNLPResult(Map processorMap, List res @SuppressWarnings({ "unchecked" }) private void putNLPResultToSourceMapForMapType( - String processorKey, - Object sourceValue, - List results, - IndexWrapper indexWrapper, - Map sourceAndMetadataMap - ) { - if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return; + String processorKey, + Object sourceValue, + List results, + IndexWrapper indexWrapper, + Map sourceAndMetadataMap) { + if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) + return; if (sourceValue instanceof Map) { for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { putNLPResultToSourceMapForMapType( - 
inputNestedMapEntry.getKey(), - inputNestedMapEntry.getValue(), - results, - indexWrapper, - (Map) sourceAndMetadataMap.get(processorKey) - ); + inputNestedMapEntry.getKey(), + inputNestedMapEntry.getValue(), + results, + indexWrapper, + (Map) sourceAndMetadataMap.get(processorKey)); } } else if (sourceValue instanceof String) { sourceAndMetadataMap.put(processorKey, results.get(indexWrapper.index++)); } else if (sourceValue instanceof List) { - sourceAndMetadataMap.put(processorKey, buildNLPResultForListType((List) sourceValue, results, indexWrapper)); + sourceAndMetadataMap.put(processorKey, + buildNLPResultForListType((List) sourceValue, results, indexWrapper)); } } - private List> buildNLPResultForListType(List sourceValue, List results, IndexWrapper indexWrapper) { + private List> buildNLPResultForListType(List sourceValue, List results, + IndexWrapper indexWrapper) { List> keyToResult = new ArrayList<>(); IntStream.range(0, sourceValue.size()) - .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)))); + .forEachOrdered( + x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)))); return keyToResult; } @@ -304,10 +335,14 @@ public String getType() { } /** - * Since we need to build a {@link List} as the input for text embedding, and the result type is {@link List} of {@link List}, - * we need to map the result back to the input one by one with exactly order. For nested map type input, we're performing a pre-order - * traversal to extract the input strings, so when mapping back to the nested map, we still need a pre-order traversal to ensure the - * order. 
And we also need to ensure the index pointer goes forward in the recursive, so here the IndexWrapper is to store and increase + * Since we need to build a {@link List} as the input for text + * embedding, and the result type is {@link List} of {@link List}, + * we need to map the result back to the input one by one in exact order. + * For nested map type input, we're performing a pre-order + * traversal to extract the input strings, so when mapping back to the nested + * map, we still need a pre-order traversal to ensure the + * order. And we also need to ensure the index pointer goes forward in the + * recursive, so here the IndexWrapper is to store and increase * the index pointer during the recursive. * index: the index pointer of the text embedding result.