-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add more data and adapt the train flow for the additional data
- Loading branch information
1 parent
70c4a00
commit ae5d838
Showing
7 changed files
with
249 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
""" | ||
Load and process ICD-11 data to generate a dataset for training and testing. | ||
""" | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from tqdm import tqdm | ||
|
||
from chat_doc.config import BASE_DIR, logger, SEED | ||
from chat_doc.dataset_generation.chat_dataset import ChatDataset | ||
|
||
np.random.seed(SEED) | ||
|
||
|
||
class DiagnoseDataset(ChatDataset):
    """Chat dataset built from the Diagnose-Me English medical dialog corpus.

    Pipeline: ``load_data`` -> ``process_data`` -> ``build_prompts``
    (guards inherited from ChatDataset enforce the ordering).
    """

    def __init__(self, name="diagnose"):
        super().__init__(name)
        # Local JSON export of the Diagnose-Me English medical dialog data.
        self.diagnose_me_path = BASE_DIR + "/data/en_medical_dialog.json"

    def load_data(self):
        """Load Diagnose-Me data from the json file into ``self.dataset``.

        Raises:
            Exception: re-raised after logging if the file cannot be read.
        """
        try:
            self.dataset = pd.read_json(self.diagnose_me_path)
            logger.info("Diagnose-Me data loaded from json file.")
        except Exception as e:
            logger.error(f"Error loading dataset: {e}")
            raise e

    def process_data(self):
        """
        Reproduce the data processing steps shown useful in our exploration
        (data_exploration.ipynb): drop the redundant id column, rename the
        remaining columns, and strip URLs and HTML tags from every text field.
        """
        if self._is_loaded():
            # copy data to avoid changing the original in case of errors
            diagnose_data = self.dataset.copy()

            # drop irrelevant columns (id) --> is the same as pd index
            diagnose_data.drop(columns=["id"], inplace=True)
            diagnose_data.columns = ["desc", "doctor", "patient"]

            for col in diagnose_data.columns:
                # regex=True is required: pandas >= 2.0 treats the pattern as a
                # literal string by default, which would silently leave the
                # URLs and HTML tags in place.
                # remove all urls from text
                diagnose_data[col] = (
                    diagnose_data[col]
                    .str.replace(r"\s*https?://\S+(\s+|$)", " ", regex=True)
                    .str.strip()
                )
                # remove all html tags from text
                diagnose_data[col] = (
                    diagnose_data[col]
                    .str.replace(r"<[^<]+?>", " ", regex=True)
                    .str.strip()
                )

            logger.info("Diagnose-Me data processed.")
            self.processed = True
            self.dataset = diagnose_data

    def build_prompts(self):
        """Sample a fraction of the processed dialogs and build unified prompts."""
        if self._is_processed():
            diagnose_data = self.dataset.copy()

            # sample 6% of the data --> approx. 15,000 prompts
            diagnose_data = diagnose_data.sample(frac=0.06).reset_index(drop=True)

            prompts = []
            for _, row in tqdm(diagnose_data.iterrows(), total=diagnose_data.shape[0]):
                prompts.append(
                    # inherit from ChatDataset
                    self.unify_prompt(
                        instruction=f"{row['desc']}",
                        context=f"Patient: {row['patient']}",
                        response=f"{row['doctor']}",
                    )
                )

            logger.info("Diagnose-Me prompts built.")
            self.prompts = prompts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
""" | ||
Load and process ICD-11 data to generate a dataset for training and testing. | ||
""" | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from tqdm import tqdm | ||
|
||
from datasets import load_dataset | ||
|
||
from chat_doc.config import BASE_DIR, logger | ||
from chat_doc.dataset_generation.chat_dataset import ChatDataset | ||
|
||
|
||
class MedDialogueDataset(ChatDataset):
    """Chat dataset built from the Hugging Face ``medical_dialog`` corpus.

    Pipeline: ``load_data`` -> ``process_data`` -> ``build_prompts``
    (guards inherited from ChatDataset enforce the ordering).
    """

    def __init__(self, name="med_dialogue"):
        super().__init__(name)
        # Pre-processed English variant of the HF "medical_dialog" dataset.
        self.variant = "processed.en"
        self.med_dialogue_hf_id = "medical_dialog"

    def load_data(self):
        """Load the train split of the Medical Dialogue HF dataset.

        Raises:
            Exception: re-raised after logging if the download/load fails.
        """
        try:
            raw_dataset = load_dataset(self.med_dialogue_hf_id, self.variant)
            self.dataset = raw_dataset["train"]
            logger.info("Medical Dialogue data loaded from HF.")
        except Exception as e:
            logger.error(f"Error loading dataset: {e}")
            raise e

    def process_data(self):
        """
        Reproduce the data processing steps shown useful in our exploration
        (data_exploration.ipynb): flatten each record's first two utterances
        into a (patient, doctor) row and store the result as a DataFrame.
        """
        if self._is_loaded():
            med_dialogue = self.dataset

            diag_list = []
            for record in med_dialogue:
                # Each record's utterances alternate, starting with the
                # patient; only the opening exchange is kept.
                utt = record["utterances"]
                diag_list.append(
                    {
                        "patient": utt[0].replace("patient: ", ""),
                        "doctor": utt[1].replace("doctor: ", ""),
                    }
                )

            med_dialogue = pd.DataFrame(diag_list)

            logger.info("Medical Dialogue data processed.")
            self.processed = True
            self.dataset = med_dialogue

    def build_prompts(self):
        """Build one unified prompt per (patient, doctor) exchange."""
        if self._is_processed():
            med_dialogue = self.dataset.copy()

            prompts = []
            for _, row in tqdm(med_dialogue.iterrows(), total=med_dialogue.shape[0]):
                prompts.append(
                    # inherit from ChatDataset
                    self.unify_prompt(
                        instruction=f"{row['patient']}",
                        # no additional context available for this corpus
                        context="",
                        response=f"{row['doctor']}",
                    )
                )

            logger.info("Medical Dialogue prompts built.")
            self.prompts = prompts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters