Skip to content

Commit

Permalink
some final updates
Browse files Browse the repository at this point in the history
  • Loading branch information
MisterXY89 committed Jan 16, 2024
1 parent ab4fc49 commit 289aa31
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 27 deletions.
10 changes: 8 additions & 2 deletions chat_doc/app/static/js/chat.js
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,18 @@ function setICDResults(icd_matches) {

const bestMatch = icd_matches[0];

const icd11URL = "https://icd.who.int/browse11/l-m/en#/http%3A%2F%2Fid.who.int%2Ficd%2Fentity%2F" + bestMatch.metadata.id;
console.log(icd11URL);

resultsDiv.innerHTML += `
<div class="overflow-x-none">
<div class="overflow-x-none pb-3">
<a href="${icd11URL}" target="_blank">
<div class="badge badge-md badge-primary badge-outline">ICD-ID: ${bestMatch.metadata.id}</div>
</a>
<div class="badge badge-md badge-accent badge-outline">Match ${bestMatch.score.toFixed(3)}</div>
</div>
<p>
${bestMatch.id}
${bestMatch.text}
</p>
`;
// ${bestMatch.text.slice(0, 100)} ...
Expand Down
2 changes: 1 addition & 1 deletion chat_doc/app/templates/chat.html
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ <h2 class="card-title font-bold">Help</h2>
<form method="post" class="w-full">
<input type="hidden" name="csrf_token" value="{{ csrf_token() }}" />
<!-- Rest of your form -->
<textarea class="textarea w-full textarea-md w-full border-none outline-0"
<textarea class="textarea w-full textarea-md w-full border-none outline-0 text-base-content"
placeholder="Message Dr. Chad..." rows="1" contenteditable
onkeydown="autoScaleTextArea(this);" id="newMSG"></textarea>
</form>
Expand Down
16 changes: 10 additions & 6 deletions chat_doc/app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,19 @@ def generate_chat_id(req: request):

def _make_hf_request(payload):
    """POST an inference payload to the hosted HF endpoint and return the parsed JSON.

    Args:
        payload: JSON-serializable request body (as built by ``chat._payload``).

    Returns:
        dict: the endpoint's JSON response, as returned by ``response.json()``.
    """
    # Current (v2) inference endpoint; previous endpoint kept for reference.
    # API_URL_V1 = "https://chdgdfk63z6o9xd8.eu-west-1.aws.endpoints.huggingface.cloud"
    API_URL = "https://pxei8lam5mc67ngq.eu-west-1.aws.endpoints.huggingface.cloud"
    headers = {
        # SECURITY: never hard-code the HF token in source (the committed
        # literal token must be revoked) -- read it from the credentials config.
        "Authorization": f"Bearer {config['credentials']['hf_token']}",
        "Accept": "application/json",
        "Content-Type": "application/json",
    }

    # Explicit timeout so a stalled endpoint cannot hang the request forever.
    response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
    return response.json()


def hf_postprocess(prediction):
print("prediction", prediction)
try:
prediction = (
prediction.split("<</SYS>>")[1]
Expand All @@ -59,8 +61,9 @@ def hf_inference(question: str, history: str, icd_match: str):
print("final_prompt", final_prompt)

payload = chat._payload(final_prompt, qa=False)
print(payload)
result = _make_hf_request(payload)
result = "test"
# result = "test"
print("result", result)

try:
Expand All @@ -72,8 +75,9 @@ def hf_inference(question: str, history: str, icd_match: str):
# return "An error occurred while trying to fetch your answer. Please try again:)"

except Exception as e:
logger.error(f"An error occurred during HF inference: {e}")
return "An error occurred while trying to fetch your answer. Please try again:)" # or an appropriate fallback response
raise e
# logger.error(f"An error occurred during HF inference: {e}")
# return "An error occurred while trying to fetch your answer. Please try again:)" # or an appropriate fallback response


def update_chat_history(chat_id: str, question: str, answer: str):
Expand Down
2 changes: 1 addition & 1 deletion chat_doc/inference/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _payload(self, prompt: str, qa: bool) -> dict:
"max_new_tokens": 512,
"repetition_penalty": 1.2,
"length_penalty": 0.3,
"stop": "<</SYS>>",
# "stop": ["<</SYS>>"],
},
}
# override parameters for qa --> single-choice questions
Expand Down
48 changes: 33 additions & 15 deletions chat_doc/rag/document_processing.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,54 @@
from pathlib import Path

import pandas as pd
from llama_index import Document
from llama_index.node_parser.text import SentenceSplitter
from llama_index.schema import TextNode
from tqdm import tqdm

from chat_doc.config import BASE_DIR, logger


class DocumentProcessor:
    """Builds llama-index ``TextNode`` objects from the ICD-11 dataset.

    One node is created per ICD entry: the node text is the entry's
    definition, and the entry's ``id`` and ``name`` go into node metadata.
    """

    def __init__(self, loader, text_parser=SentenceSplitter()):
        # loader: any llama-index reader exposing ``.load_data(file=Path)``.
        # text_parser: retained for backward compatibility; process_documents
        #   no longer chunks text (nodes are built per CSV row), but callers
        #   may still pass/read it -- TODO confirm before removing.
        self.loader = loader
        self.text_parser = text_parser

    def load_documents(self, file_path):
        """Load raw documents from *file_path* via the configured loader."""
        return self.loader.load_data(file=Path(file_path))

    def process_documents(self, documents_df=None):
        """Convert each row of the ICD-11 dataframe into a ``TextNode``.

        Args:
            documents_df: optional pandas DataFrame with ``definition``,
                ``id`` and ``name`` columns. Falls back to reading
                ``data/icd11.csv`` under ``BASE_DIR`` when omitted.

        Returns:
            list[TextNode]: the built nodes (also cached on ``self.nodes``).
        """
        # BUG FIX: ``if not documents_df`` raises ValueError for a DataFrame
        # ("The truth value of a DataFrame is ambiguous") -- and would also
        # wrongly reload the CSV for a legitimately empty frame. Test for
        # None explicitly instead.
        if documents_df is None:
            documents_df = pd.read_csv(Path(BASE_DIR + "/data/icd11.csv"))

        def build_node(row):
            # One node per ICD entry: definition as text, id/name as metadata.
            node = TextNode(text=row["definition"])
            node.metadata = {
                "id": row["id"],
                "name": row["name"],
            }
            return node

        # tqdm only adds a progress bar; iteration order follows the frame.
        nodes = [
            build_node(row)
            for _, row in tqdm(documents_df.iterrows(), total=len(documents_df))
        ]

        self.nodes = nodes
        return nodes
5 changes: 3 additions & 2 deletions chat_doc/rag/parse_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ def parse_data():
_df = pd.read_json("chat_doc/data/pinglab-ICD11-data.json")
_df = _df.query("definition != 'Key Not found'")
_df.reset_index(inplace=True)
_df = _df["name", "definition"]
_df["text"] = "Name:" + _df["name"] + "\nDefinition: " + _df["definition"]

_df = _df[["name", "definition", "id"]]
# _df["text"] = "Name:" + _df["name"] + "\nDefinition: " + _df["definition"]

return _df

Expand Down

0 comments on commit 289aa31

Please sign in to comment.