From 289aa31a166443e97464ce4d56b4b07e929b16fc Mon Sep 17 00:00:00 2001
From: Tilman Kerl
Date: Tue, 16 Jan 2024 23:12:11 +0100
Subject: [PATCH] some final updates

---
 chat_doc/app/static/js/chat.js      | 10 ++++--
 chat_doc/app/templates/chat.html    |  2 +-
 chat_doc/app/utils.py               | 16 ++++++----
 chat_doc/inference/chat.py          |  2 +-
 chat_doc/rag/document_processing.py | 48 ++++++++++++++++++++---------
 chat_doc/rag/parse_data.py          |  5 +--
 6 files changed, 56 insertions(+), 27 deletions(-)

diff --git a/chat_doc/app/static/js/chat.js b/chat_doc/app/static/js/chat.js
index 1d96e31..9a80238 100644
--- a/chat_doc/app/static/js/chat.js
+++ b/chat_doc/app/static/js/chat.js
@@ -72,12 +72,18 @@ function setICDResults(icd_matches) {
     const bestMatch = icd_matches[0];
 
+    const icd11URL = "https://icd.who.int/browse11/l-m/en#/http%3A%2F%2Fid.who.int%2Ficd%2Fentity%2F" + bestMatch.metadata.id;
+    console.log(icd11URL);
+
     resultsDiv.innerHTML += `
-        ${bestMatch.id}
+        ICD-ID: ${bestMatch.metadata.id}
+        Match ${bestMatch.score.toFixed(3)}
+        ${bestMatch.text}
     `; // ${bestMatch.text.slice(0, 100)} ...

diff --git a/chat_doc/app/templates/chat.html b/chat_doc/app/templates/chat.html
index 9e57c41..087cb5f 100644
--- a/chat_doc/app/templates/chat.html
+++ b/chat_doc/app/templates/chat.html
@@ -68,7 +68,7 @@
     Help
-
+
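For reference, the deep link assembled in the chat.js hunk above follows the WHO ICD-11 browser's convention of percent-encoding the full entity URI behind the "#/" fragment. A minimal Python sketch of the same construction (not part of the patch; the entity id is a made-up example value):

from urllib.parse import quote

def icd11_browse_url(entity_id: str) -> str:
    # the browser expects the canonical entity URI, fully percent-encoded,
    # appended after the "#/" fragment
    entity_uri = f"http://id.who.int/icd/entity/{entity_id}"
    return "https://icd.who.int/browse11/l-m/en#/" + quote(entity_uri, safe="")

# icd11_browse_url("12345")
# -> https://icd.who.int/browse11/l-m/en#/http%3A%2F%2Fid.who.int%2Ficd%2Fentity%2F12345

chat.js hardcodes the encoded prefix and only appends the raw id, which is equivalent as long as the id itself contains no reserved characters.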
diff --git a/chat_doc/app/utils.py b/chat_doc/app/utils.py
index d0c68de..46e006d 100644
--- a/chat_doc/app/utils.py
+++ b/chat_doc/app/utils.py
@@ -24,17 +24,19 @@ def generate_chat_id(req: request):
 def _make_hf_request(payload):
     # API_URL_V1 = "https://chdgdfk63z6o9xd8.eu-west-1.aws.endpoints.huggingface.cloud"
-    API_URL_V2 = "https://pxei8lam5mc67ngq.eu-west-1.aws.endpoints.huggingface.cloud"
+    API_URL = "https://pxei8lam5mc67ngq.eu-west-1.aws.endpoints.huggingface.cloud"
     headers = {
+        "Accept": "application/json",
         "Authorization": f"Bearer {config['credentials']['hf_token']}",  # keep the token in config, never hardcoded
         "Content-Type": "application/json",
     }
-    response = requests.post(API_URL_V2, headers=headers, json=payload)
+    response = requests.post(API_URL, headers=headers, json=payload)
     return response.json()
 
 
 def hf_postprocess(prediction):
+    logger.debug(f"hf prediction: {prediction}")
     try:
         prediction = (
             prediction.split("<>")[1]
@@ -59,8 +61,9 @@ def hf_inference(question: str, history: str, icd_match: str):
     print("final_prompt", final_prompt)
 
     payload = chat._payload(final_prompt, qa=False)
+    logger.debug(f"hf payload: {payload}")
     result = _make_hf_request(payload)
-    result = "test"
     print("result", result)
 
     try:
@@ -72,8 +75,9 @@
 
     # return "An error occurred while trying to fetch your answer. Please try again:)"
 
     except Exception as e:
         logger.error(f"An error occurred during HF inference: {e}")
-        return "An error occurred while trying to fetch your answer. Please try again:)"  # or an appropriate fallback response
+        raise  # surface errors while debugging instead of returning a canned fallback
 
 
 def update_chat_history(chat_id: str, question: str, answer: str):
diff --git a/chat_doc/inference/chat.py b/chat_doc/inference/chat.py
index 8600f5c..047fee0 100644
--- a/chat_doc/inference/chat.py
+++ b/chat_doc/inference/chat.py
@@ -97,7 +97,7 @@ def _payload(self, prompt: str, qa: bool) -> dict:
                 "max_new_tokens": 512,
                 "repetition_penalty": 1.2,
                 "length_penalty": 0.3,
-                "stop": "<>",
+                # "stop": ["<>"],
             },
         }
         # override parameters for qa --> single-choice questions
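The utils.py hunk above is the whole round trip to the Hugging Face inference endpoint. A runnable sketch of that flow with the token kept in config (the payload keys follow the standard text-generation schema; the exact structure chat._payload emits, and the result[0]["generated_text"] response shape, are assumptions here):

import requests

API_URL = "https://pxei8lam5mc67ngq.eu-west-1.aws.endpoints.huggingface.cloud"

def make_hf_request(payload: dict, hf_token: str) -> dict:
    headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {hf_token}",  # read from config; never commit a real token
        "Content-Type": "application/json",
    }
    response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
    response.raise_for_status()  # surface 4xx/5xx instead of trying to parse an error body
    return response.json()

payload = {
    "inputs": "Patient asks: what is hypertension?",
    "parameters": {"max_new_tokens": 512, "repetition_penalty": 1.2, "length_penalty": 0.3},
}
# result = make_hf_request(payload, hf_token=config["credentials"]["hf_token"])
# answer = result[0]["generated_text"]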
diff --git a/chat_doc/rag/document_processing.py b/chat_doc/rag/document_processing.py
index 526e721..8545d40 100644
--- a/chat_doc/rag/document_processing.py
+++ b/chat_doc/rag/document_processing.py
@@ -1,36 +1,54 @@
 from pathlib import Path
 
+import pandas as pd
 from llama_index.node_parser.text import SentenceSplitter
 from llama_index.schema import TextNode
+from tqdm import tqdm
 
 from chat_doc.config import BASE_DIR, logger
 
 
 class DocumentProcessor:
-    def __init__(self, loader, text_parser=SentenceSplitter(chunk_size=1024)):
+    def __init__(self, loader, text_parser=SentenceSplitter()):
         self.loader = loader
         self.text_parser = text_parser
 
     def load_documents(self, file_path):
         return self.loader.load_data(file=Path(file_path))
 
-    def process_documents(self, documents=None):
-        if not documents:
-            documents = self.load_documents(file_path=Path(BASE_DIR + "/data/icd11.csv"))
+    def process_documents(self, documents_df=None):
+        if documents_df is None:  # truth-testing a DataFrame raises ValueError
+            documents_df = pd.read_csv(Path(BASE_DIR + "/data/icd11.csv"))
 
-        text_chunks = []
-        doc_idxs = []
-        for doc_idx, doc in enumerate(documents):
-            cur_text_chunks = self.text_parser.split_text(doc.text)
-            text_chunks.extend(cur_text_chunks)
-            doc_idxs.extend([doc_idx] * len(cur_text_chunks))
+        def build_node(row):
+            # one node per ICD-11 entry: the definition is the retrievable text,
+            # id and name ride along as metadata for display and deep links
+            node = TextNode(text=row["definition"])
+            node.metadata = {
+                "id": row["id"],
+                "name": row["name"],
+            }
+            return node
 
         nodes = []
-        for idx, text_chunk in enumerate(text_chunks):
-            node = TextNode(text=text_chunk)
-            src_doc = documents[doc_idxs[idx]]
-            node.metadata = src_doc.metadata
-            nodes.append(node)
+        for _, row in tqdm(documents_df.iterrows(), total=len(documents_df)):
+            nodes.append(build_node(row))
 
         self.nodes = nodes
         return nodes
diff --git a/chat_doc/rag/parse_data.py b/chat_doc/rag/parse_data.py
index ff94f23..3f0b2c7 100644
--- a/chat_doc/rag/parse_data.py
+++ b/chat_doc/rag/parse_data.py
@@ -5,8 +5,9 @@ def parse_data():
     _df = pd.read_json("chat_doc/data/pinglab-ICD11-data.json")
     _df = _df.query("definition != 'Key Not found'")
     _df.reset_index(inplace=True)
-    _df = _df["name", "definition"]
-    _df["text"] = "Name:" + _df["name"] + "\nDefinition: " + _df["definition"]
+
+    _df = _df[["name", "definition", "id"]]
     return _df
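Taken together, parse_data and DocumentProcessor now map one ICD-11 entry to one retrieval node instead of sentence-splitting whole documents. A self-contained sketch of that mapping (the two rows are placeholder data, not real ICD-11 records):

import pandas as pd
from llama_index.schema import TextNode

df = pd.DataFrame(
    {
        "id": ["123456789", "987654321"],
        "name": ["Cholera", "Typhoid fever"],
        "definition": [
            "An acute diarrhoeal infection caused by Vibrio cholerae ...",
            "A systemic infection caused by Salmonella Typhi ...",
        ],
    }
)

nodes = []
for _, row in df.iterrows():
    node = TextNode(text=row["definition"])  # the definition is what gets embedded and searched
    node.metadata = {"id": row["id"], "name": row["name"]}  # shown in the UI as ICD-ID / name
    nodes.append(node)

assert nodes[0].metadata["name"] == "Cholera"

Keeping id in the metadata is what lets chat.js build the ICD-11 browser link for the best match.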