-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: zhichao-aws <[email protected]>
- Loading branch information
1 parent
ec6b30d
commit a598bd1
Showing
6 changed files
with
296 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
94 changes: 94 additions & 0 deletions
94
src/main/java/org/opensearch/neuralsearch/analysis/DJLUtils.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.neuralsearch.analysis; | ||
|
||
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; | ||
import ai.djl.util.Utils; | ||
|
||
import java.io.BufferedReader; | ||
import java.io.IOException; | ||
import java.io.InputStream; | ||
import java.io.InputStreamReader; | ||
import java.nio.charset.StandardCharsets; | ||
import java.nio.file.Path; | ||
import java.security.AccessController; | ||
import java.security.PrivilegedActionException; | ||
import java.security.PrivilegedExceptionAction; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
import java.util.concurrent.Callable; | ||
|
||
public class DJLUtils { | ||
static private Path ML_CACHE_PATH; | ||
static private String ML_CACHE_DIR_NAME = "ml_cache"; | ||
static private String HUGGING_FACE_BASE_URL = "https://huggingface.co/"; | ||
static private String HUGGING_FACE_RESOLVE_PATH = "resolve/main/"; | ||
|
||
static public void buildDJLCachePath(Path opensearchDataFolder) { | ||
// the logic to build cache path is consistent with ml-commons plugin | ||
// see | ||
// https://github.com/opensearch-project/ml-commons/blob/14b971214c488aa3f4ab150d1a6cc379df1758be/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java#L53 | ||
ML_CACHE_PATH = opensearchDataFolder.resolve(ML_CACHE_DIR_NAME); | ||
} | ||
|
||
public static <T> T withDJLContext(Callable<T> action) throws PrivilegedActionException { | ||
return AccessController.doPrivileged((PrivilegedExceptionAction<T>) () -> { | ||
ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader(); | ||
try { | ||
System.setProperty("java.library.path", ML_CACHE_PATH.toAbsolutePath().toString()); | ||
System.setProperty("DJL_CACHE_DIR", ML_CACHE_PATH.toAbsolutePath().toString()); | ||
Thread.currentThread().setContextClassLoader(ai.djl.Model.class.getClassLoader()); | ||
|
||
return action.call(); | ||
} finally { | ||
Thread.currentThread().setContextClassLoader(contextClassLoader); | ||
} | ||
}); | ||
} | ||
|
||
public static HuggingFaceTokenizer buildHuggingFaceTokenizer(String tokenizerId) { | ||
try { | ||
return withDJLContext(() -> HuggingFaceTokenizer.newInstance(tokenizerId)); | ||
} catch (PrivilegedActionException e) { | ||
throw new RuntimeException("Failed to initialize Hugging Face tokenizer. " + e); | ||
} | ||
} | ||
|
||
public static Map<String, Float> parseInputStreamToTokenWeights(InputStream inputStream) { | ||
try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) { | ||
Map<String, Float> tokenWeights = new HashMap<>(); | ||
String line; | ||
while ((line = reader.readLine()) != null) { | ||
if (line.trim().isEmpty()) { | ||
continue; | ||
} | ||
String[] parts = line.split("\t"); | ||
if (parts.length != 2) { | ||
throw new IllegalArgumentException("Invalid line in token weights file: " + line); | ||
} | ||
String token = parts[0]; | ||
float weight = Float.parseFloat(parts[1]); | ||
tokenWeights.put(token, weight); | ||
} | ||
return tokenWeights; | ||
} catch (IOException e) { | ||
throw new RuntimeException("Failed to parse token weights file. " + e); | ||
} | ||
} | ||
|
||
public static Map<String, Float> fetchTokenWeights(String tokenizerId, String fileName) { | ||
Map<String, Float> tokenWeights = new HashMap<>(); | ||
String url = HUGGING_FACE_BASE_URL + tokenizerId + "/" + HUGGING_FACE_RESOLVE_PATH + fileName; | ||
|
||
InputStream inputStream = null; | ||
try { | ||
inputStream = withDJLContext(() -> Utils.openUrl(url)); | ||
} catch (PrivilegedActionException e) { | ||
throw new RuntimeException("Failed to download file from " + url, e); | ||
} | ||
|
||
return parseInputStreamToTokenWeights(inputStream); | ||
} | ||
} |
107 changes: 107 additions & 0 deletions
107
src/main/java/org/opensearch/neuralsearch/analysis/HFTokenizer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.neuralsearch.analysis; | ||
|
||
import java.io.IOException; | ||
import java.nio.ByteBuffer; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
|
||
import com.google.common.io.CharStreams; | ||
import org.apache.lucene.analysis.Tokenizer; | ||
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; | ||
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute; | ||
import org.apache.lucene.analysis.tokenattributes.PayloadAttribute; | ||
|
||
import ai.djl.huggingface.tokenizers.Encoding; | ||
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; | ||
import org.apache.lucene.util.BytesRef; | ||
|
||
/**
 * Lucene {@link Tokenizer} backed by a DJL {@link HuggingFaceTokenizer}.
 *
 * <p>Emits one Lucene token per model token. When a token-weights map is
 * supplied, each token's weight is written into the payload attribute as a
 * 4-byte big-endian float (used by neural sparse queries). Inputs longer than
 * the model's max length are handled via the encoding's overflowing segments,
 * which are emitted in order after the main encoding.</p>
 */
public class HFTokenizer extends Tokenizer {
    public static final String NAME = "hf_tokenizer";
    // Payload weight used for tokens absent from tokenWeights.
    private static final Float DEFAULT_TOKEN_WEIGHT = 1.0f;

    private final CharTermAttribute termAtt;
    private final PayloadAttribute payloadAtt;
    private final OffsetAttribute offsetAtt;
    private final HuggingFaceTokenizer tokenizer;
    // Optional token -> weight map; null means "do not emit payloads".
    private final Map<String, Float> tokenWeights;

    // Encoding of the current input, computed in reset().
    private Encoding encoding;
    // Index of the next token within the current encoding segment.
    private int tokenIdx = 0;
    // Index into encoding.getOverflowing(); -1 means the main encoding is being consumed.
    private int overflowingIdx = 0;

    /**
     * Creates a tokenizer without token weights (no payloads are emitted).
     */
    public HFTokenizer(HuggingFaceTokenizer huggingFaceTokenizer) {
        this(huggingFaceTokenizer, null);
    }

    /**
     * Creates a tokenizer, optionally with token weights.
     *
     * @param huggingFaceTokenizer the DJL tokenizer to delegate encoding to
     * @param weights token -> weight map, or null to skip payloads entirely
     */
    public HFTokenizer(HuggingFaceTokenizer huggingFaceTokenizer, Map<String, Float> weights) {
        termAtt = addAttribute(CharTermAttribute.class);
        offsetAtt = addAttribute(OffsetAttribute.class);
        // Only register the payload attribute when weights are provided, so
        // weight-less usage pays no payload cost.
        if (Objects.nonNull(weights)) {
            payloadAtt = addAttribute(PayloadAttribute.class);
        } else {
            payloadAtt = null;
        }
        tokenizer = huggingFaceTokenizer;
        tokenWeights = weights;
    }

    @Override
    public void reset() throws IOException {
        super.reset();
        tokenIdx = 0;
        // -1 selects the main encoding in incrementToken(); overflow segments
        // start at index 0 once the main encoding is exhausted.
        overflowingIdx = -1;
        // Reads the entire input reader eagerly; the model needs the full text.
        String inputStr = CharStreams.toString(input);
        // encode(text, addSpecialTokens=false, withOverflowingTokens=true) —
        // no [CLS]/[SEP]-style specials, but keep overflow segments.
        encoding = tokenizer.encode(inputStr, false, true);
    }

    // True when idx is past the segment's tokens or points at a padding
    // position (attention mask 0), i.e. the segment has no more real tokens.
    private static boolean isLastTokenInEncodingSegment(int idx, Encoding encodingSegment) {
        return idx >= encodingSegment.getTokens().length || encodingSegment.getAttentionMask()[idx] == 0;
    }

    // Serializes a float as 4 big-endian bytes for the payload attribute.
    private static byte[] floatToBytes(float value) {
        return ByteBuffer.allocate(4).putFloat(value).array();
    }

    // Inverse of floatToBytes. Currently unused within this class — presumably
    // intended for payload consumers; confirm before removing.
    private static float bytesToFloat(byte[] bytes) {
        return ByteBuffer.wrap(bytes).getFloat();
    }

    @Override
    final public boolean incrementToken() throws IOException {
        clearAttributes();
        // Select the segment being consumed: the main encoding (overflowingIdx == -1)
        // or one of its overflow segments.
        Encoding curEncoding = overflowingIdx == -1 ? encoding : encoding.getOverflowing()[overflowingIdx];

        while (!isLastTokenInEncodingSegment(tokenIdx, curEncoding) || overflowingIdx < encoding.getOverflowing().length) {
            if (isLastTokenInEncodingSegment(tokenIdx, curEncoding)) {
                // reset cur segment, go to the next segment
                // until overflowingIdx = encoding.getOverflowing().length
                tokenIdx = 0;
                overflowingIdx++;
                if (overflowingIdx >= encoding.getOverflowing().length) {
                    return false;
                }
                curEncoding = encoding.getOverflowing()[overflowingIdx];
            } else {
                termAtt.append(curEncoding.getTokens()[tokenIdx]);
                // Character offsets into the original input for this token.
                offsetAtt.setOffset(
                    curEncoding.getCharTokenSpans()[tokenIdx].getStart(),
                    curEncoding.getCharTokenSpans()[tokenIdx].getEnd()
                );
                if (Objects.nonNull(tokenWeights)) {
                    // for neural sparse query, write the token weight to payload field
                    payloadAtt.setPayload(
                        new BytesRef(floatToBytes(tokenWeights.getOrDefault(curEncoding.getTokens()[tokenIdx], DEFAULT_TOKEN_WEIGHT)))
                    );
                }
                tokenIdx++;
                return true;
            }
        }

        return false;
    }
}
58 changes: 58 additions & 0 deletions
58
src/main/java/org/opensearch/neuralsearch/analysis/HFTokenizerFactory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.neuralsearch.analysis; | ||
|
||
import org.apache.lucene.analysis.Tokenizer; | ||
import org.opensearch.common.settings.Settings; | ||
import org.opensearch.env.Environment; | ||
import org.opensearch.index.IndexSettings; | ||
import org.opensearch.index.analysis.AbstractTokenizerFactory; | ||
|
||
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; | ||
|
||
import java.util.Map; | ||
|
||
public class HFTokenizerFactory extends AbstractTokenizerFactory { | ||
private final HuggingFaceTokenizer tokenizer; | ||
private final Map<String, Float> tokenWeights; | ||
|
||
static private final String DEFAULT_TOKENIZER_ID = "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill"; | ||
static private final String DEFAULT_TOKEN_WEIGHTS_FILE = "query_token_weights.txt"; | ||
static private volatile HuggingFaceTokenizer defaultTokenizer; | ||
static private volatile Map<String, Float> defaultTokenWeights; | ||
|
||
static public Tokenizer createDefaultTokenizer() { | ||
// what if throw exception during init? | ||
if (defaultTokenizer == null) { | ||
synchronized (HFTokenizerFactory.class) { | ||
if (defaultTokenizer == null) { | ||
defaultTokenizer = DJLUtils.buildHuggingFaceTokenizer(DEFAULT_TOKENIZER_ID); | ||
defaultTokenWeights = DJLUtils.fetchTokenWeights(DEFAULT_TOKENIZER_ID, DEFAULT_TOKEN_WEIGHTS_FILE); | ||
} | ||
} | ||
} | ||
return new HFTokenizer(defaultTokenizer, defaultTokenWeights); | ||
} | ||
|
||
public HFTokenizerFactory(IndexSettings indexSettings, Environment environment, String name, Settings settings) { | ||
// For custom tokenizer, the factory is created during IndexModule.newIndexService | ||
// And can be accessed via indexService.getIndexAnalyzers() | ||
super(indexSettings, settings, name); | ||
String tokenizerId = settings.get("tokenizer_id", DEFAULT_TOKENIZER_ID); | ||
String tokenWeightsFileName = settings.get("token_weights_file", null); | ||
tokenizer = DJLUtils.buildHuggingFaceTokenizer(tokenizerId); | ||
if (tokenWeightsFileName != null) { | ||
tokenWeights = DJLUtils.fetchTokenWeights(tokenizerId, tokenWeightsFileName); | ||
} else { | ||
tokenWeights = null; | ||
} | ||
} | ||
|
||
@Override | ||
public Tokenizer create() { | ||
// the create method will be called for every single analyze request | ||
return new HFTokenizer(tokenizer, tokenWeights); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters