
impl interface
Signed-off-by: zhichao-aws <[email protected]>
zhichao-aws committed Jan 10, 2025
1 parent ec6b30d commit a598bd1
Showing 6 changed files with 296 additions and 1 deletion.
2 changes: 2 additions & 0 deletions build.gradle
@@ -259,6 +259,8 @@ dependencies {
api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}"
testFixturesImplementation "org.opensearch.test:framework:${opensearch_version}"
implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.14.0'
implementation group: 'ai.djl', name: 'api', version: '0.28.0'
implementation group: 'ai.djl.huggingface', name: 'tokenizers', version: '0.28.0'
// ml-common excluded reflection for runtime so we need to add it by ourselves.
// https://github.com/opensearch-project/ml-commons/commit/464bfe34c66d7a729a00dd457f03587ea4e504d9
// TODO: Remove following three lines of dependencies if ml-common include them in their jar
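For orientation: the two new ai.djl dependencies provide the Hugging Face tokenizer bindings used by the classes below. A minimal, self-contained sketch of that API, assuming network access on first use (the tokenizer id is illustrative, not part of this change):

    import ai.djl.huggingface.tokenizers.Encoding;
    import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

    public class DJLTokenizerSketch {
        public static void main(String[] args) {
            // Downloads the tokenizer artifacts into DJL's cache on first use.
            HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-uncased");
            Encoding encoding = tokenizer.encode("hello world");
            // Prints the sub-word tokens; with default options special tokens
            // such as [CLS] and [SEP] are included.
            System.out.println(String.join(" ", encoding.getTokens()));
        }
    }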
94 changes: 94 additions & 0 deletions src/main/java/org/opensearch/neuralsearch/analysis/DJLUtils.java
@@ -0,0 +1,94 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.analysis;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.util.Utils;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Callable;

public class DJLUtils {
private static Path ML_CACHE_PATH;
private static final String ML_CACHE_DIR_NAME = "ml_cache";
private static final String HUGGING_FACE_BASE_URL = "https://huggingface.co/";
private static final String HUGGING_FACE_RESOLVE_PATH = "resolve/main/";

public static void buildDJLCachePath(Path opensearchDataFolder) {
// the logic to build cache path is consistent with ml-commons plugin
// see
// https://github.com/opensearch-project/ml-commons/blob/14b971214c488aa3f4ab150d1a6cc379df1758be/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java#L53
ML_CACHE_PATH = opensearchDataFolder.resolve(ML_CACHE_DIR_NAME);
}

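// Runs the given action in a privileged block with DJL's cache directory and
// native library path redirected into the node's ml_cache folder, and the
// context class loader switched to DJL's so it can find its own resources.
// The class loader is restored in the finally block; the two system
// properties are not.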
public static <T> T withDJLContext(Callable<T> action) throws PrivilegedActionException {
return AccessController.doPrivileged((PrivilegedExceptionAction<T>) () -> {
ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
try {
System.setProperty("java.library.path", ML_CACHE_PATH.toAbsolutePath().toString());
System.setProperty("DJL_CACHE_DIR", ML_CACHE_PATH.toAbsolutePath().toString());
Thread.currentThread().setContextClassLoader(ai.djl.Model.class.getClassLoader());

return action.call();
} finally {
Thread.currentThread().setContextClassLoader(contextClassLoader);
}
});
}

public static HuggingFaceTokenizer buildHuggingFaceTokenizer(String tokenizerId) {
try {
return withDJLContext(() -> HuggingFaceTokenizer.newInstance(tokenizerId));
} catch (PrivilegedActionException e) {
throw new RuntimeException("Failed to initialize Hugging Face tokenizer. " + e);
}
}

public static Map<String, Float> parseInputStreamToTokenWeights(InputStream inputStream) {
try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
Map<String, Float> tokenWeights = new HashMap<>();
String line;
while ((line = reader.readLine()) != null) {
if (line.trim().isEmpty()) {
continue;
}
String[] parts = line.split("\t");
if (parts.length != 2) {
throw new IllegalArgumentException("Invalid line in token weights file: " + line);
}
String token = parts[0];
float weight = Float.parseFloat(parts[1]);
tokenWeights.put(token, weight);
}
return tokenWeights;
} catch (IOException e) {
throw new RuntimeException("Failed to parse token weights file. " + e);
}
}

public static Map<String, Float> fetchTokenWeights(String tokenizerId, String fileName) {
Map<String, Float> tokenWeights = new HashMap<>();
String url = HUGGING_FACE_BASE_URL + tokenizerId + "/" + HUGGING_FACE_RESOLVE_PATH + fileName;

InputStream inputStream = null;
try {
inputStream = withDJLContext(() -> Utils.openUrl(url));
} catch (PrivilegedActionException e) {
throw new RuntimeException("Failed to download file from " + url, e);
}

return parseInputStreamToTokenWeights(inputStream);
}
}
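The token weights file parsed above is plain text: one tab-separated token<TAB>weight pair per non-empty line; anything else is rejected. A short usage sketch with made-up weights:

    import java.io.ByteArrayInputStream;
    import java.nio.charset.StandardCharsets;
    import java.util.Map;

    // Two valid lines; a malformed line would throw IllegalArgumentException.
    String content = "hello\t1.25\nworld\t0.50\n";
    Map<String, Float> weights = DJLUtils.parseInputStreamToTokenWeights(
        new ByteArrayInputStream(content.getBytes(StandardCharsets.UTF_8)));
    // weights = {hello=1.25, world=0.5}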
107 changes: 107 additions & 0 deletions src/main/java/org/opensearch/neuralsearch/analysis/HFTokenizer.java
@@ -0,0 +1,107 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.analysis;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.Objects;

import com.google.common.io.CharStreams;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import org.apache.lucene.util.BytesRef;

public class HFTokenizer extends Tokenizer {
public static final String NAME = "hf_tokenizer";
private static final Float DEFAULT_TOKEN_WEIGHT = 1.0f;

private final CharTermAttribute termAtt;
private final PayloadAttribute payloadAtt;
private final OffsetAttribute offsetAtt;
private final HuggingFaceTokenizer tokenizer;
private final Map<String, Float> tokenWeights;

private Encoding encoding;
private int tokenIdx = 0;
private int overflowingIdx = 0;

public HFTokenizer(HuggingFaceTokenizer huggingFaceTokenizer) {
this(huggingFaceTokenizer, null);
}

public HFTokenizer(HuggingFaceTokenizer huggingFaceTokenizer, Map<String, Float> weights) {
termAtt = addAttribute(CharTermAttribute.class);
offsetAtt = addAttribute(OffsetAttribute.class);
if (Objects.nonNull(weights)) {
payloadAtt = addAttribute(PayloadAttribute.class);
} else {
payloadAtt = null;
}
tokenizer = huggingFaceTokenizer;
tokenWeights = weights;
}

@Override
public void reset() throws IOException {
super.reset();
tokenIdx = 0;
overflowingIdx = -1;
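// Encode without adding special tokens, keeping overflowing segments so
// inputs longer than the model's max length are tokenized in full.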
String inputStr = CharStreams.toString(input);
encoding = tokenizer.encode(inputStr, false, true);
}

private static boolean isLastTokenInEncodingSegment(int idx, Encoding encodingSegment) {
return idx >= encodingSegment.getTokens().length || encodingSegment.getAttentionMask()[idx] == 0;
}

private static byte[] floatToBytes(float value) {
return ByteBuffer.allocate(4).putFloat(value).array();
}

private static float bytesToFloat(byte[] bytes) {
return ByteBuffer.wrap(bytes).getFloat();
}

@Override
public final boolean incrementToken() throws IOException {
clearAttributes();
Encoding curEncoding = overflowingIdx == -1 ? encoding : encoding.getOverflowing()[overflowingIdx];

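// Iterate tokens in the current segment; once it is exhausted, advance to
// the next overflowing segment until every segment has been consumed.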
while (!isLastTokenInEncodingSegment(tokenIdx, curEncoding) || overflowingIdx < encoding.getOverflowing().length) {
if (isLastTokenInEncodingSegment(tokenIdx, curEncoding)) {
// reset cur segment, go to the next segment
// until overflowingIdx = encoding.getOverflowing().length
tokenIdx = 0;
overflowingIdx++;
if (overflowingIdx >= encoding.getOverflowing().length) {
return false;
}
curEncoding = encoding.getOverflowing()[overflowingIdx];
} else {
termAtt.append(curEncoding.getTokens()[tokenIdx]);
offsetAtt.setOffset(
curEncoding.getCharTokenSpans()[tokenIdx].getStart(),
curEncoding.getCharTokenSpans()[tokenIdx].getEnd()
);
if (Objects.nonNull(tokenWeights)) {
// for neural sparse query, write the token weight to payload field
payloadAtt.setPayload(
new BytesRef(floatToBytes(tokenWeights.getOrDefault(curEncoding.getTokens()[tokenIdx], DEFAULT_TOKEN_WEIGHT)))
);
}
tokenIdx++;
return true;
}
}

return false;
}
}
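To make the payload format concrete: each payload is the 4-byte big-endian float written by floatToBytes (and decoded by bytesToFloat). A hedged consumer sketch, with an illustrative tokenizer id:

    import java.io.StringReader;
    import java.nio.ByteBuffer;
    import java.util.Map;

    import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
    import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
    import org.apache.lucene.util.BytesRef;

    import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

    public class HFTokenizerSketch {
        public static void main(String[] args) throws Exception {
            try (HuggingFaceTokenizer hf = HuggingFaceTokenizer.newInstance("bert-base-uncased");
                 HFTokenizer tokenizer = new HFTokenizer(hf, Map.of("hello", 2.0f))) {
                tokenizer.setReader(new StringReader("hello world"));
                tokenizer.reset();
                CharTermAttribute term = tokenizer.getAttribute(CharTermAttribute.class);
                PayloadAttribute payload = tokenizer.getAttribute(PayloadAttribute.class);
                while (tokenizer.incrementToken()) {
                    BytesRef bytes = payload.getPayload();
                    float weight = ByteBuffer.wrap(bytes.bytes, bytes.offset, bytes.length).getFloat();
                    System.out.println(term + " -> " + weight); // unmapped tokens fall back to 1.0
                }
                tokenizer.end();
            }
        }
    }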
58 changes: 58 additions & 0 deletions src/main/java/org/opensearch/neuralsearch/analysis/HFTokenizerFactory.java
@@ -0,0 +1,58 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.analysis;

import org.apache.lucene.analysis.Tokenizer;
import org.opensearch.common.settings.Settings;
import org.opensearch.env.Environment;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.analysis.AbstractTokenizerFactory;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

import java.util.Map;

public class HFTokenizerFactory extends AbstractTokenizerFactory {
private final HuggingFaceTokenizer tokenizer;
private final Map<String, Float> tokenWeights;

private static final String DEFAULT_TOKENIZER_ID = "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill";
private static final String DEFAULT_TOKEN_WEIGHTS_FILE = "query_token_weights.txt";
private static volatile HuggingFaceTokenizer defaultTokenizer;
private static volatile Map<String, Float> defaultTokenWeights;

public static Tokenizer createDefaultTokenizer() {
// TODO: handle exceptions thrown during initialization; a failure here leaves
// defaultTokenizer null, and initialization is retried on the next call
if (defaultTokenizer == null) {
synchronized (HFTokenizerFactory.class) {
if (defaultTokenizer == null) {
defaultTokenizer = DJLUtils.buildHuggingFaceTokenizer(DEFAULT_TOKENIZER_ID);
defaultTokenWeights = DJLUtils.fetchTokenWeights(DEFAULT_TOKENIZER_ID, DEFAULT_TOKEN_WEIGHTS_FILE);
}
}
}
return new HFTokenizer(defaultTokenizer, defaultTokenWeights);
}

public HFTokenizerFactory(IndexSettings indexSettings, Environment environment, String name, Settings settings) {
// For custom tokenizer, the factory is created during IndexModule.newIndexService
// And can be accessed via indexService.getIndexAnalyzers()
super(indexSettings, settings, name);
String tokenizerId = settings.get("tokenizer_id", DEFAULT_TOKENIZER_ID);
String tokenWeightsFileName = settings.get("token_weights_file", null);
tokenizer = DJLUtils.buildHuggingFaceTokenizer(tokenizerId);
if (tokenWeightsFileName != null) {
tokenWeights = DJLUtils.fetchTokenWeights(tokenizerId, tokenWeightsFileName);
} else {
tokenWeights = null;
}
}

@Override
public Tokenizer create() {
// the create method will be called for every single analyze request
return new HFTokenizer(tokenizer, tokenWeights);
}
}
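The factory reads two optional settings, tokenizer_id and token_weights_file. A hedged sketch of the Settings a custom hf_tokenizer declaration would hand to this constructor (key names are the ones read above; values shown are the defaults; the type key is consumed by OpenSearch's analysis registry, not this class):

    Settings settings = Settings.builder()
        .put("type", "hf_tokenizer") // name registered in NeuralSearch#getTokenizers
        .put("tokenizer_id", "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill")
        .put("token_weights_file", "query_token_weights.txt")
        .build();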
30 changes: 29 additions & 1 deletion src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
@@ -7,6 +7,7 @@
import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED;
import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.RERANKER_MAX_DOC_FIELDS;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
@@ -24,8 +25,14 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.index.analysis.PreConfiguredTokenizer;
import org.opensearch.index.analysis.TokenizerFactory;
import org.opensearch.indices.analysis.AnalysisModule;
import org.opensearch.ingest.Processor;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.neuralsearch.analysis.DJLUtils;
import org.opensearch.neuralsearch.analysis.HFTokenizer;
import org.opensearch.neuralsearch.analysis.HFTokenizerFactory;
import org.opensearch.neuralsearch.executors.HybridQueryExecutor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
@@ -56,6 +63,7 @@
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.AnalysisPlugin;
import org.opensearch.plugins.ExtensiblePlugin;
import org.opensearch.plugins.IngestPlugin;
import org.opensearch.plugins.Plugin;
@@ -77,7 +85,14 @@
* Neural Search plugin class
*/
@Log4j2
public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin, SearchPipelinePlugin {
public class NeuralSearch extends Plugin
implements
ActionPlugin,
SearchPlugin,
IngestPlugin,
ExtensiblePlugin,
SearchPipelinePlugin,
AnalysisPlugin {
private MLCommonsClientAccessor clientAccessor;
private NormalizationProcessorWorkflow normalizationProcessorWorkflow;
private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory();
@@ -103,6 +118,7 @@ public Collection<Object> createComponents(
NeuralSparseQueryBuilder.initialize(clientAccessor);
HybridQueryExecutor.initialize(threadPool);
normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
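// Point DJL's cache at the node data folder, mirroring ml-commons' layout.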
DJLUtils.buildDJLCachePath(environment.dataFiles()[0]);
return List.of(clientAccessor);
}

@@ -200,4 +216,16 @@ public List<SearchPlugin.SearchExtSpec<?>> getSearchExts() {
)
);
}

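// "hf_tokenizer" is registered twice on purpose: below as a configurable
// tokenizer type, and in getPreConfiguredTokenizers as a ready-to-use
// singleton built from the default tokenizer and token weights.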
@Override
public Map<String, AnalysisModule.AnalysisProvider<TokenizerFactory>> getTokenizers() {
return Map.of("hf_tokenizer", HFTokenizerFactory::new);
}

@Override
public List<PreConfiguredTokenizer> getPreConfiguredTokenizers() {
List<PreConfiguredTokenizer> tokenizers = new ArrayList<>();
tokenizers.add(PreConfiguredTokenizer.singleton(HFTokenizer.NAME, HFTokenizerFactory::createDefaultTokenizer));
return tokenizers;
}
}
6 changes: 6 additions & 0 deletions src/main/plugin-metadata/plugin-security.policy
@@ -4,4 +4,10 @@ grant {
permission java.lang.RuntimePermission "accessDeclaredMembers";
permission java.lang.reflect.ReflectPermission "suppressAccessChecks";
permission java.lang.RuntimePermission "setContextClassLoader";

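// Added for DJL: it downloads tokenizer artifacts from Hugging Face, loads
// a native tokenizers library, and reads/writes the cache-related system
// properties set in DJLUtils.withDJLContext.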
permission java.net.SocketPermission "*", "connect,resolve";
permission java.lang.RuntimePermission "loadLibrary.*";
permission java.util.PropertyPermission "DJL_CACHE_DIR", "read,write";
permission java.util.PropertyPermission "java.library.path", "read,write";
};
