
Commit

add save_as_onnx and add tests
Signed-off-by: faradawn <[email protected]>
faradawn committed Nov 1, 2023
1 parent 9df65e7 commit 9315bdb
Showing 3 changed files with 199 additions and 28 deletions.
3 changes: 2 additions & 1 deletion opensearch_py_ml/ml_models/__init__.py
@@ -7,5 +7,6 @@

from .metrics_correlation.mcorr import MCorr
from .sentencetransformermodel import SentenceTransformerModel
from .question_answering_model import QuestionAnsweringModel

__all__ = ["SentenceTransformerModel", "MCorr"]
__all__ = ["SentenceTransformerModel", "MCorr", "QuestionAnsweringModel"]
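With this change, the new class is importable at the package level — a quick sketch (the folder name below is illustrative):

from opensearch_py_ml.ml_models import QuestionAnsweringModel
qa = QuestionAnsweringModel(folder_path="qa-folder", overwrite=True)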
151 changes: 124 additions & 27 deletions opensearch_py_ml/ml_models/question_answering_model.py
@@ -31,7 +31,8 @@
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
from transformers.convert_graph_to_onnx import convert
import transformers


from opensearch_py_ml.ml_commons.ml_common_utils import (
_generate_model_content_hash_value,
@@ -67,7 +68,7 @@ def __init__(
work directory
:type folder_path: string
:param overwrite: Optional, choose to overwrite the folder at folder path. Default as false. When training
different sentence transformer models, it's recommended to give designated folder path every time.
different question answering models, it's recommended to give designated folder path every time.
Users can choose to overwrite = True to overwrite previous runs
:type overwrite: bool
:return: no return value expected
@@ -115,7 +116,7 @@ def save_as_pt(
Required, for example sentences = ['today is sunny']
:type sentences: List of string [str]
:param model_id:
sentence transformer model id to download model from sentence transformers.
question answering model id to download the model from Hugging Face.
default model_id = "distilbert-base-cased-distilled-squad"
:type model_id: string
:param model_name:
@@ -208,7 +209,7 @@ def save_as_pt(
strict=False
)
torch.jit.save(compiled_model, model_path)
print("Traced model is saved to ", model_path)
print("Traced torchscript model is saved to ", model_path)

# zip model file along with tokenizer.json (and license file) as output
with ZipFile(str(zip_file_path), "w") as zipObj:
@@ -227,6 +228,124 @@ def save_as_pt(
print("zip file is saved to ", zip_file_path, "\n")
return zip_file_path

def save_as_onnx(
self,
model_id="distilbert-base-cased-distilled-squad",
model_name: str = None,
save_json_folder_path: str = None,
model_output_path: str = None,
zip_file_name: str = None,
add_apache_license: bool = False,
) -> str:
"""
Download the question answering model directly from Hugging Face, convert it to ONNX format,
and zip the model file together with its tokenizer.json file, ready to upload to the OpenSearch cluster
:param model_id:
question answering model id to download the model from Hugging Face.
default model_id = "distilbert-base-cased-distilled-squad"
:type model_id: string
:param model_name:
Optional, model name to name the model file, e.g., "sample_model.onnx". If None, default takes the
model_id and adds the extension ".onnx"
:type model_name: string
:param save_json_folder_path:
Optional, path to save model json file, e.g, "home/save_pre_trained_model_json/"). If None, default as
default_folder_path from the constructor
:type save_json_folder_path: string
:param model_output_path:
Optional, path to save the exported model zip file. If None, default as
default_folder_path from the constructor
:type model_output_path: string
:param zip_file_name:
Optional, file name for the zip file, e.g., "sample_model.zip". If None, default takes the model_id
and adds the extension ".zip"
:type zip_file_name: string
:param add_apache_license:
Optional, whether to add an Apache-2.0 license file to the model zip file
:type add_apache_license: bool
:return: model zip file path. The file path where the zip file is being saved
:rtype: string
"""

tokenizer = DistilBertTokenizer.from_pretrained(model_id)
model = DistilBertForQuestionAnswering.from_pretrained(model_id)

if model_name is None:
model_name = str(model_id.split("/")[-1] + ".onnx")

model_path = os.path.join(self.folder_path, model_name)

if save_json_folder_path is None:
save_json_folder_path = self.folder_path

if model_output_path is None:
model_output_path = self.folder_path

if zip_file_name is None:
zip_file_name = str(model_id.split("/")[-1] + ".zip")
zip_file_path = os.path.join(model_output_path, zip_file_name)

# save tokenizer.json in save_json_folder_path
tokenizer.save_pretrained(save_json_folder_path)

# Find the tokenizer.json file path recorded in tokenizer_config.json (it points into the Hugging Face cache)
config_json_path = os.path.join(save_json_folder_path, "tokenizer_config.json")
with open(config_json_path) as f:
config_json = json.load(f)
tokenizer_file_path = config_json["tokenizer_file"]

# Open the tokenizer.json and replace the truncation field
with open(tokenizer_file_path) as user_file:
parsed_json = json.load(user_file)

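# Note: OpenSearch's ML tokenizer expects an explicit truncation policy in tokenizer.json,
# but exported tokenizers often leave it null, so fill in a default when it is missing
# (assumption: LongestFirst at the model's max length is the appropriate policy here).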
if "truncation" not in parsed_json or parsed_json["truncation"] is None:
parsed_json["truncation"] = {
"direction": "Right",
"max_length": tokenizer.model_max_length,
"strategy": "LongestFirst",
"stride": 0,
}

# Save the updated tokenizer.json into save_json_folder_path
tokenizer_file_path = os.path.join(save_json_folder_path, "tokenizer.json")
with open(tokenizer_file_path, "w") as file:
json.dump(parsed_json, file, indent=2)

# load the ONNX export configuration for this model type
model_kind, model_onnx_config = transformers.onnx.FeaturesManager.check_supported_model_or_raise(
    model, feature="question-answering"
)
onnx_config = model_onnx_config(model.config)

# export
onnx_inputs, onnx_outputs = transformers.onnx.export(
preprocessor=tokenizer,
model=model,
config=onnx_config,
opset=13,
output=Path(model_path)
)
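# onnx_inputs / onnx_outputs are the exported graph's input and output names
# (for the question-answering feature: input_ids, attention_mask -> start_logits, end_logits).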

print("Traced onnx model is saved to ", model_path)

# zip model file along with tokenizer.json (and license file) as output
with ZipFile(str(zip_file_path), "w") as zipObj:
zipObj.write(
model_path,
arcname=str(model_name),
)
zipObj.write(
os.path.join(save_json_folder_path, "tokenizer.json"),
arcname="tokenizer.json",
)
if add_apache_license:
self._add_apache_license_to_model_zip_file(zip_file_path)

self.onnx_zip_file_path = zip_file_path
print("zip file is saved to ", zip_file_path, "\n")
return zip_file_path



def make_model_config_json(
self,
model_name: str = None,
@@ -247,7 +366,7 @@ def make_model_config_json(
:param model_name:
Optional, The name of the model. If None, default is model id, for example,
'sentence-transformers/msmarco-distilbert-base-tas-b'
'distilbert-base-cased-distilled-squad'
:type model_name: string
:param model_format:
Optional, the format of the model. Default is "TORCH_SCRIPT".
@@ -426,26 +545,4 @@ def make_model_config_json(
)

return model_config_file_path

def test_traced_model(self, model_path, question, context):
"""
Load a model from TorchScript and run inference on the question and text.
"""
traced_model = torch.jit.load(model_path)
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased-distilled-squad')
inputs = tokenizer(question, context, return_tensors="pt")

# Start inference
with torch.no_grad():
outputs = traced_model(**inputs)

# Get the most likely start and end positions
answer_start_index = torch.argmax(outputs["start_logits"], dim=-1).item()
answer_end_index = torch.argmax(outputs["end_logits"], dim=-1).item()

# Extract the answer tokens and convert back to text
predict_answer_tokens = inputs['input_ids'][0, answer_start_index : answer_end_index + 1]
answer = tokenizer.decode(predict_answer_tokens)

return answer
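For context, a minimal end-to-end sketch of the new export path (the folder name and the onnxruntime check below are illustrative, not part of this commit):

from onnxruntime import InferenceSession
from opensearch_py_ml.ml_models import QuestionAnsweringModel

qa = QuestionAnsweringModel(folder_path="qa-folder", overwrite=True)
zip_path = qa.save_as_onnx()  # defaults to distilbert-base-cased-distilled-squad
print("zip file:", zip_path)

# Sanity-check that the exported graph loads and exposes the expected outputs.
session = InferenceSession("qa-folder/distilbert-base-cased-distilled-squad.onnx")
print([o.name for o in session.get_outputs()])  # expect ['start_logits', 'end_logits']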
73 changes: 73 additions & 0 deletions tests/ml_models/test_question_answering.py
@@ -0,0 +1,73 @@

from opensearch_py_ml.ml_models import QuestionAnsweringModel

# Save our model as pt or onnx
model_id = "distilbert-base-cased-distilled-squad"
folder_path = "question-model-folder"
our_pre_trained_model = QuestionAnsweringModel(model_id=model_id, folder_path=folder_path, overwrite=True)
# zip_file_path = our_pre_trained_model.save_as_pt(model_id=model_id, sentences=["for example providing a small sentence", "we can add multiple sentences"])
zip_file_path = our_pre_trained_model.save_as_onnx(model_id=model_id)

# List of questions to test
questions = ["Who was Jim Henson?", "Where do I live?", "What's my name?"]
contexts = ["Jim Henson was a nice puppet", "My name is Sarah and I live in London", "My name is Clara and I live in Berkeley."]

# Obtain pytorch's official model
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
import torch
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased-distilled-squad')
official_model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-cased-distilled-squad')

def official_model_answer(question, context):
inputs = tokenizer(question, context, return_tensors="pt")
with torch.no_grad():
outputs = official_model(**inputs)
answer_start_index = torch.argmax(outputs.start_logits, dim=-1).item()
answer_end_index = torch.argmax(outputs.end_logits, dim=-1).item()
predict_answer_tokens = inputs['input_ids'][0, answer_start_index : answer_end_index + 1]
official_answer = tokenizer.decode(predict_answer_tokens)
return official_answer

def test_onnx():
from onnxruntime import InferenceSession
import numpy as np
session = InferenceSession(f"{folder_path}/{model_id}.onnx")

for i in range(len(questions)):
question = questions[i]
context = contexts[i]
inputs = tokenizer(question, context, return_tensors="pt")
print(f"=== test {i}, question: {question}, context: {context}")

inputs = tokenizer(question, context, return_tensors="np")
outputs = session.run(output_names=["start_logits", "end_logits"], input_feed=dict(inputs))

answer_start_index = np.argmax(outputs[0], axis=-1).item()
answer_end_index = np.argmax(outputs[1], axis=-1).item()
predict_answer_tokens = inputs['input_ids'][0, answer_start_index : answer_end_index + 1]
answer = tokenizer.decode(predict_answer_tokens)

print(f" Official answer: {official_model_answer(question, context)}")
print(f" Our answer: {answer}")

def test_pt():
traced_model = torch.jit.load(f"{folder_path}/{model_id}.pt")

for i in range(len(questions)):
question = questions[i]
context = contexts[i]
inputs = tokenizer(question, context, return_tensors="pt")
print(f"=== test {i}, question: {question}, context: {context}")

with torch.no_grad():
outputs = traced_model(**inputs)
answer_start_index = torch.argmax(outputs["start_logits"], dim=-1).item()
answer_end_index = torch.argmax(outputs["end_logits"], dim=-1).item()
predict_answer_tokens = inputs['input_ids'][0, answer_start_index : answer_end_index + 1]
answer = tokenizer.decode(predict_answer_tokens)

print(f" Official answer: {official_model_answer(question, context)}")
print(f" Our answer: {answer}")

if __name__ == "__main__":
    test_onnx()
