diff --git a/docs/tutorials/rerank/rerank_pipeline_with_CrossEncoder_model_deployed_on_Sagemaker.md b/docs/tutorials/rerank/rerank_pipeline_with_CrossEncoder_model_deployed_on_Sagemaker.md
new file mode 100644
index 0000000000..39c8b342a2
--- /dev/null
+++ b/docs/tutorials/rerank/rerank_pipeline_with_CrossEncoder_model_deployed_on_Sagemaker.md
@@ -0,0 +1,379 @@
# Topic

The [reranking pipeline](https://opensearch.org/docs/latest/search-plugins/search-relevance/reranking-search-results/) is a feature released in OpenSearch 2.12.
It reranks search results by assigning each document in the results a relevance score with respect to the search query.
The relevance score is calculated by a cross-encoder model.

This tutorial explains how to use the [Hugging Face cross-encoder/ms-marco-MiniLM-L-6-v2](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2) model in a reranking pipeline.

Note: Replace the placeholders that start with `your_` with your own values.

# Steps

## 0. Deploy the Model on Amazon SageMaker
Use the following code to deploy the model on Amazon SageMaker:
```python
import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel

sess = sagemaker.Session()
role = sagemaker.get_execution_role()

# Hugging Face model ID and task for the SageMaker Hugging Face container
hub = {
    'HF_MODEL_ID': 'cross-encoder/ms-marco-MiniLM-L-6-v2',
    'HF_TASK': 'text-classification'
}
huggingface_model = HuggingFaceModel(
    transformers_version='4.37.0',
    pytorch_version='2.1.0',
    py_version='py310',
    env=hub,
    role=role,
)
predictor = huggingface_model.deploy(
    initial_instance_count=1,     # number of instances
    instance_type='ml.m5.xlarge'  # EC2 instance type
)
```
Note the model inference endpoint; you will use it to create the connector in the next step.

## 1. Create a Connector and Model

If you are using self-managed OpenSearch, you must supply AWS credentials:
```json
POST /_plugins/_ml/connectors/_create
{
  "name": "Sagemaker cross-encoder model",
  "description": "Test connector for Sagemaker cross-encoder model",
  "version": 1,
  "protocol": "aws_sigv4",
  "credential": {
    "access_key": "your_access_key",
    "secret_key": "your_secret_key",
    "session_token": "your_session_token"
  },
  "parameters": {
    "region": "your_sagemaker_model_region_like_us-west-2",
    "service_name": "sagemaker"
  },
  "actions": [
    {
      "action_type": "predict",
      "method": "POST",
      "url": "your_sagemaker_model_inference_endpoint_created_in_last_step",
      "headers": {
        "content-type": "application/json"
      },
      "request_body": "{ \"inputs\": ${parameters.inputs} }",
      "pre_process_function": "\n String escape(def input) { \n if (input.contains(\"\\\\\")) {\n input = input.replace(\"\\\\\", \"\\\\\\\\\");\n }\n if (input.contains(\"\\\"\")) {\n input = input.replace(\"\\\"\", \"\\\\\\\"\");\n }\n if (input.contains('\r')) {\n input = input.replace('\r', '\\\\r');\n }\n if (input.contains(\"\\\\t\")) {\n input = input.replace(\"\\\\t\", \"\\\\\\\\\\\\t\");\n }\n if (input.contains('\n')) {\n input = input.replace('\n', '\\\\n');\n }\n if (input.contains('\b')) {\n input = input.replace('\b', '\\\\b');\n }\n if (input.contains('\f')) {\n input = input.replace('\f', '\\\\f');\n }\n return input;\n }\n\n String query = params.query_text;\n StringBuilder builder = new StringBuilder('[');\n \n for (int i=0; i