From 9315bdbf09f62730e36ac91ced5023865a6a7641 Mon Sep 17 00:00:00 2001
From: faradawn
Date: Tue, 31 Oct 2023 23:17:00 -0500
Subject: [PATCH] add save_as_onnx and tests

Signed-off-by: faradawn
---
 opensearch_py_ml/ml_models/__init__.py        |   3 +-
 .../ml_models/question_answering_model.py     | 151 ++++++++++++++----
 tests/ml_models/test_question_answering.py    |  73 +++++++++
 3 files changed, 199 insertions(+), 28 deletions(-)
 create mode 100644 tests/ml_models/test_question_answering.py

diff --git a/opensearch_py_ml/ml_models/__init__.py b/opensearch_py_ml/ml_models/__init__.py
index 3ec96ebd5..eacea92f0 100644
--- a/opensearch_py_ml/ml_models/__init__.py
+++ b/opensearch_py_ml/ml_models/__init__.py
@@ -7,5 +7,6 @@
 
 from .metrics_correlation.mcorr import MCorr
+from .question_answering_model import QuestionAnsweringModel
 from .sentencetransformermodel import SentenceTransformerModel
 
-__all__ = ["SentenceTransformerModel", "MCorr"]
+__all__ = ["SentenceTransformerModel", "MCorr", "QuestionAnsweringModel"]
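For orientation, the re-export above makes the new class available at package level. A minimal usage sketch (the folder path is an illustrative assumption):

```python
from opensearch_py_ml.ml_models import QuestionAnsweringModel

# The constructor mirrors SentenceTransformerModel: a model id, a working
# folder, and an overwrite flag for repeated runs.
qa_model = QuestionAnsweringModel(
    model_id="distilbert-base-cased-distilled-squad",
    folder_path="qa-model-folder",  # illustrative path
    overwrite=True,
)
```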
diff --git a/opensearch_py_ml/ml_models/question_answering_model.py b/opensearch_py_ml/ml_models/question_answering_model.py
index f227916cf..ab472a43f 100644
--- a/opensearch_py_ml/ml_models/question_answering_model.py
+++ b/opensearch_py_ml/ml_models/question_answering_model.py
@@ -31,7 +31,8 @@
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
-from transformers.convert_graph_to_onnx import convert
+import transformers
+
 
 from opensearch_py_ml.ml_commons.ml_common_utils import (
     _generate_model_content_hash_value,
@@ -67,7 +68,7 @@ def __init__(
         work directory
         :type folder_path: string
         :param overwrite: Optional, choose to overwrite the folder at folder path. Default as false. When training
-            different sentence transformer models, it's recommended to give designated folder path every time.
+            different question answering models, it's recommended to give a designated folder path every time.
             Users can choose to overwrite = True to overwrite previous runs
         :type overwrite: bool
         :return: no return value expected
@@ -115,7 +116,7 @@ def save_as_pt(
         Required, for example sentences = ['today is sunny']
         :type sentences: List of string [str]
         :param model_id:
-            sentence transformer model id to download model from sentence transformers.
+            question answering model id to download the model from Hugging Face.
             default model_id = "distilbert-base-cased-distilled-squad"
         :type model_id: string
         :param model_name:
@@ -208,7 +209,7 @@ def save_as_pt(
             strict=False
         )
         torch.jit.save(compiled_model, model_path)
-        print("Traced model is saved to ", model_path)
+        print("Traced TorchScript model is saved to ", model_path)
 
         # zip model file along with tokenizer.json (and license file) as output
         with ZipFile(str(zip_file_path), "w") as zipObj:
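The hunk below adds `save_as_onnx`, built on the `transformers.onnx` export API that replaces the deprecated `convert_graph_to_onnx` helper removed above. A condensed, standalone sketch of that export flow (the output path `model.onnx` is illustrative):

```python
from pathlib import Path

import transformers
from transformers import DistilBertForQuestionAnswering, DistilBertTokenizer

model_id = "distilbert-base-cased-distilled-squad"
tokenizer = DistilBertTokenizer.from_pretrained(model_id)
model = DistilBertForQuestionAnswering.from_pretrained(model_id)

# Resolve the ONNX config class for the question-answering task, then export.
_, onnx_config_cls = transformers.onnx.FeaturesManager.check_supported_model_or_raise(
    model, feature="question-answering"
)
transformers.onnx.export(
    preprocessor=tokenizer,
    model=model,
    config=onnx_config_cls(model.config),
    opset=13,
    output=Path("model.onnx"),
)
```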
@@ -227,6 +228,124 @@
         print("zip file is saved to ", zip_file_path, "\n")
         return zip_file_path
 
+    def save_as_onnx(
+        self,
+        model_id="distilbert-base-cased-distilled-squad",
+        model_name: str = None,
+        save_json_folder_path: str = None,
+        model_output_path: str = None,
+        zip_file_name: str = None,
+        add_apache_license: bool = False,
+    ) -> str:
+        """
+        Download the question answering model directly from Hugging Face, convert it to ONNX format,
+        and zip the model file together with its tokenizer.json file, ready for upload to the OpenSearch cluster
+
+        :param model_id:
+            question answering model id to download the model from Hugging Face.
+            default model_id = "distilbert-base-cased-distilled-squad"
+        :type model_id: string
+        :param model_name:
+            Optional, model name to name the model file, e.g., "sample_model.onnx". If None, default takes the
+            model_id and adds the extension ".onnx"
+        :type model_name: string
+        :param save_json_folder_path:
+            Optional, path to save model json file, e.g., "home/save_pre_trained_model_json/". If None, default as
+            default_folder_path from the constructor
+        :type save_json_folder_path: string
+        :param model_output_path:
+            Optional, path to save the model zip file. If None, default as
+            default_folder_path from the constructor
+        :type model_output_path: string
+        :param zip_file_name:
+            Optional, file name for zip file, e.g., "sample_model.zip". If None, default takes the model_id
+            and adds the extension ".zip"
+        :type zip_file_name: string
+        :param add_apache_license:
+            Optional, whether to add an Apache-2.0 license file to the model zip file
+        :type add_apache_license: bool
+        :return: model zip file path. The file path where the zip file is being saved
+        :rtype: string
+        """
+
+        tokenizer = DistilBertTokenizer.from_pretrained(model_id)
+        model = DistilBertForQuestionAnswering.from_pretrained(model_id)
+
+        if model_name is None:
+            model_name = str(model_id.split("/")[-1] + ".onnx")
+
+        model_path = os.path.join(self.folder_path, model_name)
+
+        if save_json_folder_path is None:
+            save_json_folder_path = self.folder_path
+
+        if model_output_path is None:
+            model_output_path = self.folder_path
+
+        if zip_file_name is None:
+            zip_file_name = str(model_id.split("/")[-1] + ".zip")
+        zip_file_path = os.path.join(model_output_path, zip_file_name)
+
+        # save the tokenizer files in save_json_folder_path
+        tokenizer.save_pretrained(save_json_folder_path)
+
+        # locate the tokenizer.json file in the local Hugging Face cache
+        config_json_path = os.path.join(save_json_folder_path, "tokenizer_config.json")
+        with open(config_json_path) as f:
+            config_json = json.load(f)
+        tokenizer_file_path = config_json["tokenizer_file"]
+
+        # open the cached tokenizer.json and fill in the truncation field if unset
+        with open(tokenizer_file_path) as user_file:
+            parsed_json = json.load(user_file)
+
+        if "truncation" not in parsed_json or parsed_json["truncation"] is None:
+            parsed_json["truncation"] = {
+                "direction": "Right",
+                "max_length": tokenizer.model_max_length,
+                "strategy": "LongestFirst",
+                "stride": 0,
+            }
+
+        # save the patched tokenizer.json alongside the model
+        tokenizer_file_path = os.path.join(save_json_folder_path, "tokenizer.json")
+        with open(tokenizer_file_path, "w") as file:
+            json.dump(parsed_json, file, indent=2)
+
+        # look up the ONNX config for the question-answering feature
+        model_kind, model_onnx_config = transformers.onnx.FeaturesManager.check_supported_model_or_raise(
+            model, feature="question-answering"
+        )
+        onnx_config = model_onnx_config(model.config)
+
+        # export the model to ONNX
+        onnx_inputs, onnx_outputs = transformers.onnx.export(
+            preprocessor=tokenizer,
+            model=model,
+            config=onnx_config,
+            opset=13,
+            output=Path(model_path),
+        )
+
+        print("Exported ONNX model is saved to ", model_path)
+
+        # zip model file along with tokenizer.json (and license file) as output
+        with ZipFile(str(zip_file_path), "w") as zipObj:
+            zipObj.write(
+                model_path,
+                arcname=str(model_name),
+            )
+            zipObj.write(
+                os.path.join(save_json_folder_path, "tokenizer.json"),
+                arcname="tokenizer.json",
+            )
+        if add_apache_license:
+            self._add_apache_license_to_model_zip_file(zip_file_path)
+
+        self.onnx_zip_file_path = zip_file_path
+        print("zip file is saved to ", zip_file_path, "\n")
+        return zip_file_path
+
     def make_model_config_json(
         self,
         model_name: str = None,
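Taken together with `make_model_config_json` below, the intended flow looks roughly like this. A sketch only: the paths are illustrative, and `model_format="ONNX"` is an assumption based on the docstring's `TORCH_SCRIPT` default:

```python
from opensearch_py_ml.ml_models import QuestionAnsweringModel

model_id = "distilbert-base-cased-distilled-squad"
qa_model = QuestionAnsweringModel(
    model_id=model_id, folder_path="qa-model-folder", overwrite=True
)

# Export + zip the ONNX artifact, then describe it for the cluster.
zip_path = qa_model.save_as_onnx(model_id=model_id)
config_path = qa_model.make_model_config_json(model_format="ONNX")

# zip_path and config_path can then be registered with an OpenSearch
# cluster, e.g. via opensearch_py_ml.ml_commons (not shown here).
```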
@@ -247,7 +366,7 @@ def make_model_config_json(
 
         :param model_name:
             Optional, The name of the model. If None, default is model id, for example,
-            'sentence-transformers/msmarco-distilbert-base-tas-b'
+            'distilbert-base-cased-distilled-squad'
         :type model_name: string
         :param model_format:
             Optional, the format of the model. Default is "TORCH_SCRIPT".
@@ -426,26 +545,4 @@ def make_model_config_json(
             )
 
         return model_config_file_path
-
-    def test_traced_model(self, model_path, question, context):
-        """
-        Load a model from TorchScript and run inference on the question and text.
-
-        """
-        traced_model = torch.jit.load(model_path)
-        tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased-distilled-squad')
-        inputs = tokenizer(question, context, return_tensors="pt")
-
-        # Start inference
-        with torch.no_grad():
-            outputs = traced_model(**inputs)
-
-        # Get the most likely start and end positions
-        answer_start_index = torch.argmax(outputs["start_logits"], dim=-1).item()
-        answer_end_index = torch.argmax(outputs["end_logits"], dim=-1).item()
-
-        # Extract the answer tokens and convert back to text
-        predict_answer_tokens = inputs['input_ids'][0, answer_start_index : answer_end_index + 1]
-        answer = tokenizer.decode(predict_answer_tokens)
-        return answer
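Before relying on the new test file below, one quick way to sanity-check the exported artifact is to inspect its output names. A sketch assuming the default paths used in the test:

```python
from onnxruntime import InferenceSession

# Illustrative path: folder_path/model_id.onnx, as produced by save_as_onnx.
session = InferenceSession(
    "question-model-folder/distilbert-base-cased-distilled-squad.onnx"
)
# The QA export is expected to expose start/end logits, matching the
# output_names requested in test_onnx below.
print([out.name for out in session.get_outputs()])
```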
diff --git a/tests/ml_models/test_question_answering.py b/tests/ml_models/test_question_answering.py
new file mode 100644
index 000000000..3399b292e
--- /dev/null
+++ b/tests/ml_models/test_question_answering.py
@@ -0,0 +1,73 @@
+import numpy as np
+import torch
+from onnxruntime import InferenceSession
+from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
+
+from opensearch_py_ml.ml_models import QuestionAnsweringModel
+
+# Save our model as pt or onnx
+model_id = "distilbert-base-cased-distilled-squad"
+folder_path = "question-model-folder"
+our_pre_trained_model = QuestionAnsweringModel(model_id=model_id, folder_path=folder_path, overwrite=True)
+# zip_file_path = our_pre_trained_model.save_as_pt(model_id=model_id, sentences=["for example providing a small sentence", "we can add multiple sentences"])
+zip_file_path = our_pre_trained_model.save_as_onnx(model_id=model_id)
+
+# List of question/context pairs to test
+questions = ["Who was Jim Henson?", "Where do I live?", "What's my name?"]
+contexts = [
+    "Jim Henson was a nice puppet",
+    "My name is Sarah and I live in London",
+    "My name is Clara and I live in Berkeley.",
+]
+
+# Load the reference Hugging Face model to compare answers against
+tokenizer = DistilBertTokenizer.from_pretrained(model_id)
+official_model = DistilBertForQuestionAnswering.from_pretrained(model_id)
+
+
+def official_model_answer(question, context):
+    inputs = tokenizer(question, context, return_tensors="pt")
+    with torch.no_grad():
+        outputs = official_model(**inputs)
+    answer_start_index = torch.argmax(outputs.start_logits, dim=-1).item()
+    answer_end_index = torch.argmax(outputs.end_logits, dim=-1).item()
+    predict_answer_tokens = inputs["input_ids"][0, answer_start_index : answer_end_index + 1]
+    return tokenizer.decode(predict_answer_tokens)
+
+
+def test_onnx():
+    session = InferenceSession(f"{folder_path}/{model_id}.onnx")
+
+    for i in range(len(questions)):
+        question = questions[i]
+        context = contexts[i]
+        print(f"=== test {i}, question: {question}, context: {context}")
+
+        inputs = tokenizer(question, context, return_tensors="np")
+        outputs = session.run(output_names=["start_logits", "end_logits"], input_feed=dict(inputs))
+
+        answer_start_index = np.argmax(outputs[0], axis=-1).item()
+        answer_end_index = np.argmax(outputs[1], axis=-1).item()
+        predict_answer_tokens = inputs["input_ids"][0, answer_start_index : answer_end_index + 1]
+        answer = tokenizer.decode(predict_answer_tokens)
+
+        official_answer = official_model_answer(question, context)
+        print(f"  Official answer: {official_answer}")
+        print(f"  Our answer: {answer}")
+        assert answer == official_answer
+
+
+def test_pt():
+    traced_model = torch.jit.load(f"{folder_path}/{model_id}.pt")
+
+    for i in range(len(questions)):
+        question = questions[i]
+        context = contexts[i]
+        print(f"=== test {i}, question: {question}, context: {context}")
+
+        inputs = tokenizer(question, context, return_tensors="pt")
+        with torch.no_grad():
+            outputs = traced_model(**inputs)
+        answer_start_index = torch.argmax(outputs["start_logits"], dim=-1).item()
+        answer_end_index = torch.argmax(outputs["end_logits"], dim=-1).item()
+        predict_answer_tokens = inputs["input_ids"][0, answer_start_index : answer_end_index + 1]
+        answer = tokenizer.decode(predict_answer_tokens)
+
+        official_answer = official_model_answer(question, context)
+        print(f"  Official answer: {official_answer}")
+        print(f"  Our answer: {answer}")
+        assert answer == official_answer
+
+
+# Only the ONNX path is exercised by default; run test_pt() after save_as_pt.
+test_onnx()
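The module only invokes `test_onnx()`. A hedged sketch of how the TorchScript path could be exercised as well, once the commented-out `save_as_pt` call above is enabled (sentence list shortened for illustration):

```python
# Sketch: enable the TorchScript round-trip alongside ONNX (assumes the
# commented-out save_as_pt call above has been uncommented first).
zip_file_path_pt = our_pre_trained_model.save_as_pt(
    model_id=model_id,
    sentences=["for example providing a small sentence"],
)
test_pt()
```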