Skip to content

Commit

Permalink
enhance connector helper notebook to support 2.9 (#2202)
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored Mar 13, 2024
1 parent 4af7e78 commit c233356
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 5 deletions.
56 changes: 51 additions & 5 deletions docs/tutorials/aws/AIConnectorHelper.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"from requests_aws4auth import AWS4Auth\n",
"import time\n",
"\n",
"# This python code works for AWS OpenSearch 2.11\n",
"# This Python code is compatible with AWS OpenSearch versions 2.9 and higher.\n",
"class AIConnectorHelper:\n",
" \n",
" def __init__(self, region, opensearch_domain_name, opensearch_domain_username, opensearch_domain_password, aws_user_name):\n",
Expand Down Expand Up @@ -265,24 +265,70 @@
" connector_id = json.loads(r.text)['connector_id']\n",
" return connector_id\n",
" \n",
" def search_model_group(self, model_group_name):\n",
" payload = {\n",
" \"query\": {\n",
" \"term\": {\n",
" \"name.keyword\": {\n",
" \"value\": model_group_name\n",
" }\n",
" }\n",
" }\n",
" }\n",
" headers = {\"Content-Type\": \"application/json\"}\n",
" r = requests.post(f'{self.opensearch_domain_url}/_plugins/_ml/model_groups/_search',\n",
" auth=HTTPBasicAuth(self.opensearch_domain_username, self.opensearch_domain_opensearch_domain_password),\n",
" json=payload,\n",
" headers=headers)\n",
" #print(r.text)\n",
" response = json.loads(r.text)\n",
" return response\n",
" \n",
" def create_model_group(self, model_group_name, description):\n",
" search_model_group_response = self.search_model_group(model_group_name)\n",
" if search_model_group_response['hits']['total']['value'] > 0:\n",
" return search_model_group_response['hits']['hits'][0]['_id']\n",
" payload = {\n",
" \"name\": model_group_name,\n",
" \"description\": description\n",
" }\n",
" headers = {\"Content-Type\": \"application/json\"}\n",
" r = requests.post(f'{self.opensearch_domain_url}/_plugins/_ml/model_groups/_register',\n",
" auth=HTTPBasicAuth(self.opensearch_domain_username, self.opensearch_domain_opensearch_domain_password),\n",
" json=payload,\n",
" headers=headers)\n",
" print(r.text)\n",
" response = json.loads(r.text)\n",
" return response['model_group_id']\n",
" \n",
" def get_task(self, task_id):\n",
" return requests.get(f'{self.opensearch_domain_url}/_plugins/_ml/tasks/{task_id}',\n",
" auth=HTTPBasicAuth(self.opensearch_domain_username, self.opensearch_domain_opensearch_domain_password))\n",
" \n",
" def create_model(self, model_name, description, connector_id, deploy=True):\n",
" model_group_id = self.create_model_group(model_name, description)\n",
" payload = {\n",
" \"name\": model_name,\n",
" \"function_name\": \"remote\",\n",
" \"description\": description,\n",
" \"model_group_id\": model_group_id,\n",
" \"connector_id\": connector_id\n",
" }\n",
"\n",
" headers = {\"Content-Type\": \"application/json\"}\n",
"\n",
" deploy_str = str(deploy).lower()\n",
" r = requests.post(f'{self.opensearch_domain_url}/_plugins/_ml/models/_register?deploy={deploy_str}',\n",
" auth=HTTPBasicAuth(self.opensearch_domain_username, self.opensearch_domain_opensearch_domain_password),\n",
" json=payload,\n",
" headers=headers)\n",
" print(r.text)\n",
" model_id = json.loads(r.text)['model_id']\n",
" return model_id\n",
" response = json.loads(r.text)\n",
" if 'model_id' in response:\n",
" return response['model_id']\n",
" else:\n",
" time.sleep(2) # sleep two seconds for task complete\n",
" r = self.get_task(response['task_id'])\n",
" print(r.text)\n",
" return json.loads(r.text)['model_id']\n",
" \n",
" def deploy_model(self, model_id):\n",
" return requests.post(f'{self.opensearch_domain_url}/_plugins/_ml/models/{model_id}/_deploy',\n",
Expand Down

0 comments on commit c233356

Please sign in to comment.