From a598bd11fd21544461ec4915d30c604c45e2677a Mon Sep 17 00:00:00 2001
From: zhichao-aws
Date: Fri, 10 Jan 2025 15:49:47 +0800
Subject: [PATCH] impl interface

Signed-off-by: zhichao-aws
---
 build.gradle                                 |   2 +
 .../neuralsearch/analysis/DJLUtils.java      |  94 +++++++++++++++
 .../neuralsearch/analysis/HFTokenizer.java   | 107 ++++++++++++++++++
 .../analysis/HFTokenizerFactory.java         |  58 ++++++++++
 .../neuralsearch/plugin/NeuralSearch.java    |  30 ++++-
 .../plugin-metadata/plugin-security.policy   |   6 +
 6 files changed, 296 insertions(+), 1 deletion(-)
 create mode 100644 src/main/java/org/opensearch/neuralsearch/analysis/DJLUtils.java
 create mode 100644 src/main/java/org/opensearch/neuralsearch/analysis/HFTokenizer.java
 create mode 100644 src/main/java/org/opensearch/neuralsearch/analysis/HFTokenizerFactory.java

diff --git a/build.gradle b/build.gradle
index 90b1ca1f1..f94462590 100644
--- a/build.gradle
+++ b/build.gradle
@@ -259,6 +259,8 @@ dependencies {
     api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}"
     testFixturesImplementation "org.opensearch.test:framework:${opensearch_version}"
     implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.14.0'
+    implementation group: 'ai.djl', name: 'api', version: '0.28.0'
+    implementation group: 'ai.djl.huggingface', name: 'tokenizers', version: '0.28.0'
     // ml-common excluded reflection for runtime so we need to add it by ourselves.
     // https://github.com/opensearch-project/ml-commons/commit/464bfe34c66d7a729a00dd457f03587ea4e504d9
     // TODO: Remove following three lines of dependencies if ml-common include them in their jar
diff --git a/src/main/java/org/opensearch/neuralsearch/analysis/DJLUtils.java b/src/main/java/org/opensearch/neuralsearch/analysis/DJLUtils.java
new file mode 100644
index 000000000..2a9f6d3dc
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/analysis/DJLUtils.java
@@ -0,0 +1,94 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.analysis;
+
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import ai.djl.util.Utils;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Path;
+import java.security.AccessController;
+import java.security.PrivilegedActionException;
+import java.security.PrivilegedExceptionAction;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.Callable;
+
+public class DJLUtils {
+    private static Path ML_CACHE_PATH;
+    private static final String ML_CACHE_DIR_NAME = "ml_cache";
+    private static final String HUGGING_FACE_BASE_URL = "https://huggingface.co/";
+    private static final String HUGGING_FACE_RESOLVE_PATH = "resolve/main/";
+
+    public static void buildDJLCachePath(Path opensearchDataFolder) {
+        // the logic to build cache path is consistent with ml-commons plugin
+        // see
+        // https://github.com/opensearch-project/ml-commons/blob/14b971214c488aa3f4ab150d1a6cc379df1758be/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java#L53
+        ML_CACHE_PATH = opensearchDataFolder.resolve(ML_CACHE_DIR_NAME);
+    }
+
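+    // Runs the given action in a privileged block so DJL can set system properties and reach the
+    // network and filesystem under the plugin security policy. The thread's context classloader
+    // is swapped to DJL's own classloader for the duration of the call so DJL can locate its
+    // engine classes, and is always restored afterwards.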
+    public static <T> T withDJLContext(Callable<T> action) throws PrivilegedActionException {
+        return AccessController.doPrivileged((PrivilegedExceptionAction<T>) () -> {
+            ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
+            try {
+                System.setProperty("java.library.path", ML_CACHE_PATH.toAbsolutePath().toString());
+                System.setProperty("DJL_CACHE_DIR", ML_CACHE_PATH.toAbsolutePath().toString());
+                Thread.currentThread().setContextClassLoader(ai.djl.Model.class.getClassLoader());
+
+                return action.call();
+            } finally {
+                Thread.currentThread().setContextClassLoader(contextClassLoader);
+            }
+        });
+    }
+
+    public static HuggingFaceTokenizer buildHuggingFaceTokenizer(String tokenizerId) {
+        try {
+            return withDJLContext(() -> HuggingFaceTokenizer.newInstance(tokenizerId));
+        } catch (PrivilegedActionException e) {
+            throw new RuntimeException("Failed to initialize Hugging Face tokenizer.", e);
+        }
+    }
+
+    public static Map<String, Float> parseInputStreamToTokenWeights(InputStream inputStream) {
+        try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
+            Map<String, Float> tokenWeights = new HashMap<>();
+            String line;
+            while ((line = reader.readLine()) != null) {
+                if (line.trim().isEmpty()) {
+                    continue;
+                }
+                String[] parts = line.split("\t");
+                if (parts.length != 2) {
+                    throw new IllegalArgumentException("Invalid line in token weights file: " + line);
+                }
+                String token = parts[0];
+                float weight = Float.parseFloat(parts[1]);
+                tokenWeights.put(token, weight);
+            }
+            return tokenWeights;
+        } catch (IOException e) {
+            throw new RuntimeException("Failed to parse token weights file.", e);
+        }
+    }
+
+    public static Map<String, Float> fetchTokenWeights(String tokenizerId, String fileName) {
+        String url = HUGGING_FACE_BASE_URL + tokenizerId + "/" + HUGGING_FACE_RESOLVE_PATH + fileName;
+
+        InputStream inputStream = null;
+        try {
+            inputStream = withDJLContext(() -> Utils.openUrl(url));
+        } catch (PrivilegedActionException e) {
+            throw new RuntimeException("Failed to download file from " + url, e);
+        }
+
+        return parseInputStreamToTokenWeights(inputStream);
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/analysis/HFTokenizer.java b/src/main/java/org/opensearch/neuralsearch/analysis/HFTokenizer.java
new file mode 100644
index 000000000..9a0871869
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/analysis/HFTokenizer.java
@@ -0,0 +1,107 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.analysis;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Map;
+import java.util.Objects;
+
+import com.google.common.io.CharStreams;
+import org.apache.lucene.analysis.Tokenizer;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
+import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
+
+import ai.djl.huggingface.tokenizers.Encoding;
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import org.apache.lucene.util.BytesRef;
+
+public class HFTokenizer extends Tokenizer {
+    public static final String NAME = "hf_tokenizer";
+    private static final Float DEFAULT_TOKEN_WEIGHT = 1.0f;
+
+    private final CharTermAttribute termAtt;
+    private final PayloadAttribute payloadAtt;
+    private final OffsetAttribute offsetAtt;
+    private final HuggingFaceTokenizer tokenizer;
+    private final Map<String, Float> tokenWeights;
+
+    private Encoding encoding;
+    private int tokenIdx = 0;
+    private int overflowingIdx = 0;
+
+    public HFTokenizer(HuggingFaceTokenizer huggingFaceTokenizer) {
+        this(huggingFaceTokenizer, null);
+    }
+
+    public HFTokenizer(HuggingFaceTokenizer huggingFaceTokenizer, Map<String, Float> weights) {
+        termAtt = addAttribute(CharTermAttribute.class);
+        offsetAtt = addAttribute(OffsetAttribute.class);
+        if (Objects.nonNull(weights)) {
+            payloadAtt = addAttribute(PayloadAttribute.class);
+        } else {
+            payloadAtt = null;
+        }
+        tokenizer = huggingFaceTokenizer;
+        tokenWeights = weights;
+    }
+
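+    // reset() reads the whole input and encodes it once up front; incrementToken() then emits
+    // the encoded tokens one by one, walking through any overflowing segments the HuggingFace
+    // tokenizer produced for inputs longer than the model's maximum length.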
+    @Override
+    public void reset() throws IOException {
+        super.reset();
+        tokenIdx = 0;
+        overflowingIdx = -1;
+        String inputStr = CharStreams.toString(input);
+        encoding = tokenizer.encode(inputStr, false, true);
+    }
+
+    private static boolean isLastTokenInEncodingSegment(int idx, Encoding encodingSegment) {
+        return idx >= encodingSegment.getTokens().length || encodingSegment.getAttentionMask()[idx] == 0;
+    }
+
+    private static byte[] floatToBytes(float value) {
+        return ByteBuffer.allocate(4).putFloat(value).array();
+    }
+
+    private static float bytesToFloat(byte[] bytes) {
+        return ByteBuffer.wrap(bytes).getFloat();
+    }
+
+    @Override
+    public final boolean incrementToken() throws IOException {
+        clearAttributes();
+        Encoding curEncoding = overflowingIdx == -1 ? encoding : encoding.getOverflowing()[overflowingIdx];
+
+        while (!isLastTokenInEncodingSegment(tokenIdx, curEncoding) || overflowingIdx < encoding.getOverflowing().length) {
+            if (isLastTokenInEncodingSegment(tokenIdx, curEncoding)) {
+                // reset cur segment, go to the next segment
+                // until overflowingIdx = encoding.getOverflowing().length
+                tokenIdx = 0;
+                overflowingIdx++;
+                if (overflowingIdx >= encoding.getOverflowing().length) {
+                    return false;
+                }
+                curEncoding = encoding.getOverflowing()[overflowingIdx];
+            } else {
+                termAtt.append(curEncoding.getTokens()[tokenIdx]);
+                offsetAtt.setOffset(
+                    curEncoding.getCharTokenSpans()[tokenIdx].getStart(),
+                    curEncoding.getCharTokenSpans()[tokenIdx].getEnd()
+                );
+                if (Objects.nonNull(tokenWeights)) {
+                    // for neural sparse query, write the token weight to payload field
+                    payloadAtt.setPayload(
+                        new BytesRef(floatToBytes(tokenWeights.getOrDefault(curEncoding.getTokens()[tokenIdx], DEFAULT_TOKEN_WEIGHT)))
+                    );
+                }
+                tokenIdx++;
+                return true;
+            }
+        }
+
+        return false;
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/analysis/HFTokenizerFactory.java b/src/main/java/org/opensearch/neuralsearch/analysis/HFTokenizerFactory.java
new file mode 100644
index 000000000..da96b097f
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/analysis/HFTokenizerFactory.java
@@ -0,0 +1,58 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.analysis;
+
+import org.apache.lucene.analysis.Tokenizer;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.env.Environment;
+import org.opensearch.index.IndexSettings;
+import org.opensearch.index.analysis.AbstractTokenizerFactory;
+
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+
+import java.util.Map;
+
+public class HFTokenizerFactory extends AbstractTokenizerFactory {
+    private final HuggingFaceTokenizer tokenizer;
+    private final Map<String, Float> tokenWeights;
+
+    private static final String DEFAULT_TOKENIZER_ID = "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill";
+    private static final String DEFAULT_TOKEN_WEIGHTS_FILE = "query_token_weights.txt";
+    private static volatile HuggingFaceTokenizer defaultTokenizer;
+    private static volatile Map<String, Float> defaultTokenWeights;
+
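+    // The default tokenizer is initialized lazily with double-checked locking on the volatile
+    // fields, so the remote fetch of the tokenizer and its token weights happens on first use
+    // rather than at plugin load, and the result is shared by all subsequent calls.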
+    public static Tokenizer createDefaultTokenizer() {
+        // if initialization throws here, defaultTokenizer stays null and the next call will retry
+        if (defaultTokenizer == null) {
+            synchronized (HFTokenizerFactory.class) {
+                if (defaultTokenizer == null) {
+                    defaultTokenizer = DJLUtils.buildHuggingFaceTokenizer(DEFAULT_TOKENIZER_ID);
+                    defaultTokenWeights = DJLUtils.fetchTokenWeights(DEFAULT_TOKENIZER_ID, DEFAULT_TOKEN_WEIGHTS_FILE);
+                }
+            }
+        }
+        return new HFTokenizer(defaultTokenizer, defaultTokenWeights);
+    }
+
+    public HFTokenizerFactory(IndexSettings indexSettings, Environment environment, String name, Settings settings) {
+        // For custom tokenizer, the factory is created during IndexModule.newIndexService
+        // And can be accessed via indexService.getIndexAnalyzers()
+        super(indexSettings, settings, name);
+        String tokenizerId = settings.get("tokenizer_id", DEFAULT_TOKENIZER_ID);
+        String tokenWeightsFileName = settings.get("token_weights_file", null);
+        tokenizer = DJLUtils.buildHuggingFaceTokenizer(tokenizerId);
+        if (tokenWeightsFileName != null) {
+            tokenWeights = DJLUtils.fetchTokenWeights(tokenizerId, tokenWeightsFileName);
+        } else {
+            tokenWeights = null;
+        }
+    }
+
+    @Override
+    public Tokenizer create() {
+        // the create method will be called for every single analyze request
+        return new HFTokenizer(tokenizer, tokenWeights);
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
index 1350a7963..8e57a83b0 100644
--- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
+++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
@@ -7,6 +7,7 @@
 import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED;
 import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
@@ -24,8 +25,14 @@
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.env.Environment;
 import org.opensearch.env.NodeEnvironment;
+import org.opensearch.index.analysis.PreConfiguredTokenizer;
+import org.opensearch.index.analysis.TokenizerFactory;
+import org.opensearch.indices.analysis.AnalysisModule;
 import org.opensearch.ingest.Processor;
 import org.opensearch.ml.client.MachineLearningNodeClient;
+import org.opensearch.neuralsearch.analysis.DJLUtils;
+import org.opensearch.neuralsearch.analysis.HFTokenizer;
+import org.opensearch.neuralsearch.analysis.HFTokenizerFactory;
 import org.opensearch.neuralsearch.executors.HybridQueryExecutor;
 import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
 import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
@@ -56,6 +63,7 @@
 import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
 import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
 import org.opensearch.plugins.ActionPlugin;
+import org.opensearch.plugins.AnalysisPlugin;
 import org.opensearch.plugins.ExtensiblePlugin;
 import org.opensearch.plugins.IngestPlugin;
 import org.opensearch.plugins.Plugin;
@@ -77,7 +85,14 @@
  * Neural Search plugin class
  */
 @Log4j2
-public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin, SearchPipelinePlugin {
+public class NeuralSearch extends Plugin
+    implements
+        ActionPlugin,
+        SearchPlugin,
+        IngestPlugin,
+        ExtensiblePlugin,
+        SearchPipelinePlugin,
+        AnalysisPlugin {
     private MLCommonsClientAccessor clientAccessor;
     private NormalizationProcessorWorkflow normalizationProcessorWorkflow;
     private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory();
@@ -103,6 +118,7 @@ public Collection<Object> createComponents(
         NeuralSparseQueryBuilder.initialize(clientAccessor);
         HybridQueryExecutor.initialize(threadPool);
         normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
+        DJLUtils.buildDJLCachePath(environment.dataFiles()[0]);
         return List.of(clientAccessor);
     }
 
@@ -200,4 +216,16 @@ public List<SearchPlugin.SearchExtSpec<?>> getSearchExts() {
             )
         );
     }
+
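+    // Registers the analysis components: "hf_tokenizer" as a configurable tokenizer driven by
+    // the tokenizer_id and token_weights_file index settings, plus a pre-configured variant
+    // backed by the default sparse-encoding tokenizer so it can be used without any settings.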
+    @Override
+    public Map<String, AnalysisModule.AnalysisProvider<TokenizerFactory>> getTokenizers() {
+        return Map.of(HFTokenizer.NAME, HFTokenizerFactory::new);
+    }
+
+    @Override
+    public List<PreConfiguredTokenizer> getPreConfiguredTokenizers() {
+        List<PreConfiguredTokenizer> tokenizers = new ArrayList<>();
+        tokenizers.add(PreConfiguredTokenizer.singleton(HFTokenizer.NAME, HFTokenizerFactory::createDefaultTokenizer));
+        return tokenizers;
+    }
 }
diff --git a/src/main/plugin-metadata/plugin-security.policy b/src/main/plugin-metadata/plugin-security.policy
index db2413e86..dc04de6ab 100644
--- a/src/main/plugin-metadata/plugin-security.policy
+++ b/src/main/plugin-metadata/plugin-security.policy
@@ -4,4 +4,10 @@ grant {
   permission java.lang.RuntimePermission "accessDeclaredMembers";
   permission java.lang.reflect.ReflectPermission "suppressAccessChecks";
   permission java.lang.RuntimePermission "setContextClassLoader";
+
+  permission java.net.SocketPermission "*", "connect,resolve";
+  permission java.lang.RuntimePermission "loadLibrary.*";
+  permission java.lang.RuntimePermission "setContextClassLoader";
+  permission java.util.PropertyPermission "DJL_CACHE_DIR", "read,write";
+  permission java.util.PropertyPermission "java.library.path", "read,write";
 };