
Commit

adding more data & adapting the train flow for the additional data
MisterXY89 committed Dec 8, 2023
1 parent 70c4a00 commit ae5d838
Showing 7 changed files with 249 additions and 13 deletions.
2 changes: 2 additions & 0 deletions chat_doc/config.py
@@ -17,6 +17,8 @@
# parent dir of chat_doc (BASE_DIR)
DATA_DIR = ROOT_DIR + "/data"

SEED = 1160

# Change if you renamed your config file
CONFIG_FILE_PATH = f"{BASE_DIR}/config.yml"
CREDENTIAL_FILE_PATH = f"{BASE_DIR}/.env"
62 changes: 53 additions & 9 deletions chat_doc/dataset_generation/dataset_factory.py
@@ -7,44 +7,76 @@
from chat_doc.config import DATA_DIR, ROOT_DIR, logger
from chat_doc.dataset_generation.icd11_dataset import ICD11Dataset
from chat_doc.dataset_generation.pmc_patients_dataset import PMCPatientsDataset
from chat_doc.dataset_generation.diagnose_dataset import DiagnoseDataset
from chat_doc.dataset_generation.med_dialogue_dataset import MedDialogueDataset


class DatasetFactory:
def __init__(self):
self.dataset = None
self.path = ROOT_DIR + "/data/full_prompts.pkl"
self.available_datasets = ["icd", "pmc", "full"]
self.full_path = ROOT_DIR + "/data/full_prompts.pkl"
self.dialogue_path = ROOT_DIR + "/data/full_dialogue.pkl"
self.available_datasets = ["icd", "pmc", "diagnose", "med-dialogue", "dialogue-full", "full"]

def build_full_dialogue_dataset(self):
dialogue_prompts = self.load_dataset("med-dialogue")
diagnose_prompts = self.load_dataset("diagnose")

if dialogue_prompts is None:
dialogue_prompts = self.build_dataset("med-dialogue")
if diagnose_prompts is None:
diagnose_prompts = self.build_dataset("diagnose")

# combine them
prompts = dialogue_prompts + diagnose_prompts

try:
with open(self.dialogue_path, "wb") as f:
pickle.dump(prompts, f)
logger.info(f"Full prompts saved to {self.dialogue_path}")
except Exception as e:
logger.error(f"Could not save full prompts to {self.dialogue_path}")
logger.error(e)

return prompts


def build_full_dataset(self):
# load both datasets, if they don't exist, build them
icd_prompts = self.load_dataset("icd")
pmc_prompts = self.load_dataset("pmc")
diagnose_prompts = self.load_dataset("diagnose")
dialogue_prompts = self.load_dataset("med-dialogue")

if icd_prompts is None:
icd_prompts = self.build_dataset("icd")
if pmc_prompts is None:
pmc_prompts = self.build_dataset("pmc")
if diagnose_prompts is None:
diagnose_prompts = self.build_dataset("diagnose")
if dialogue_prompts is None:
dialogue_prompts = self.build_dataset("med-dialogue")

# combine them
prompts = icd_prompts + pmc_prompts
prompts = icd_prompts + pmc_prompts + diagnose_prompts

try:
with open(self.path, "wb") as f:
with open(self.full_path, "wb") as f:
pickle.dump(prompts, f)
logger.info(f"Full prompts saved to {self.path}")
logger.info(f"Full prompts saved to {self.full_path}")
except Exception as e:
logger.error(f"Could not save full prompts to {self.path}")
logger.error(f"Could not save full prompts to {self.full_path}")
logger.error(e)

return prompts

def load_full_dataset(self):
try:
with open(self.path, "rb") as f:
with open(self.full_path, "rb") as f:
prompts = pickle.load(f)
logger.info(f"Full prompts loaded from {self.path}")
logger.info(f"Full prompts loaded from {self.full_path}")
except Exception as e:
logger.error(f"Could not load full prompts from {self.path}")
logger.error(f"Could not load full prompts from {self.full_path}")
logger.error(e)
prompts = None

@@ -55,6 +87,12 @@ def build_dataset(self, name):
self.dataset = ICD11Dataset()
elif name == "pmc":
self.dataset = PMCPatientsDataset()
elif name == "diagnose":
self.dataset = DiagnoseDataset()
elif name == "med-dialogue":
self.dataset = MedDialogueDataset()
elif name == "dialogue-full":
self.build_full_dialogue_dataset()
elif name == "full":
self.build_full_dataset()
else:
@@ -82,6 +120,12 @@ def load_dataset(self, name, is_prompts=True):
self.dataset = ICD11Dataset()
elif name == "pmc":
self.dataset = PMCPatientsDataset()
elif name == "diagnose":
self.dataset = DiagnoseDataset()
elif name == "med-dialogue":
self.dataset = MedDialogueDataset()
elif name == "dialogue-full":
return self.load_full_dialogue_dataset()
elif name == "full":
return self.load_full_dataset()
else:
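For orientation, here is a minimal usage sketch of the extended factory. The call pattern is inferred from this diff; the exact behaviour of load_full_dialogue_dataset (not shown above) and the return values are assumed, so treat it as illustrative rather than definitive.

from chat_doc.dataset_generation.dataset_factory import DatasetFactory

factory = DatasetFactory()

# Build the combined dialogue dataset (med-dialogue + diagnose prompts);
# missing individual datasets are built on the fly before being combined
# and pickled to data/full_dialogue.pkl.
dialogue_prompts = factory.build_full_dialogue_dataset()

# On later runs the pickled result can be loaded via the new "dialogue-full" name.
cached_prompts = factory.load_dataset("dialogue-full")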
74 changes: 74 additions & 0 deletions chat_doc/dataset_generation/diagnose_dataset.py
@@ -0,0 +1,74 @@
"""
Load and process ICD-11 data to generate a dataset for training and testing.
"""

import numpy as np
import pandas as pd
from tqdm import tqdm

from chat_doc.config import BASE_DIR, logger, SEED
from chat_doc.dataset_generation.chat_dataset import ChatDataset

np.random.seed(SEED)


class DiagnoseDataset(ChatDataset):
def __init__(self, name="diagnose"):
super().__init__(name)
self.diagnose_me_path = BASE_DIR + "/data/en_medical_dialog.json"

def load_data(self):
"""
Load Diagnose-Me data from json file.
"""
try:
self.dataset = pd.read_json(self.diagnose_me_path)
logger.info("Diagnose-Me data loaded from json file.")
except Exception as e:
logger.error(f"Error loading dataset: {e}")
raise e

def process_data(self):
"""
Here we reproduce the data processing steps shown to be useful in our exploration (data_exploration.ipynb).
"""
if self._is_loaded():
# copy data to avoid changing the original in case of errors
diagnose_data = self.dataset.copy()

# drop irrelevant columns (id) --> is the same as pd index
diagnose_data.drop(columns=["id"], inplace=True)
diagnose_data.columns = ["desc", "doctor", "patient"]

for col in diagnose_data.columns:
# remove all URLs from the text (regex=True is required for pattern matching in pandas >= 2.0)
diagnose_data[col] = diagnose_data[col].str.replace(r'\s*https?://\S+(\s+|$)', ' ', regex=True).str.strip()
# remove all HTML tags from the text
diagnose_data[col] = diagnose_data[col].str.replace(r'<[^<]+?>', ' ', regex=True).str.strip()


logger.info("Diagnose-Me data processed.")
self.processed = True
self.dataset = diagnose_data

def build_prompts(self):
if self._is_processed():
diagnose_data = self.dataset.copy()

# sample ~6% of the data --> approx. 15,000 prompts
diagnose_data = diagnose_data.sample(frac=0.06).reset_index(drop=True)

prompts = []
for _, row in tqdm(diagnose_data.iterrows(), total=diagnose_data.shape[0]):

prompts.append(
# inherit from ChatDataset
self.unify_prompt(
instruction=f"{row['desc']}",
context=f"Patient: {row['patient']}",
response=f"{row['doctor']}",
)
)

logger.info("Diagnose-Me prompts built.")
self.prompts = prompts
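To make the cleanup step concrete, here is a small self-contained sketch of the same two regex replacements applied to a made-up string (regex=True is spelled out because recent pandas versions default to literal replacement):

import pandas as pd

s = pd.Series(["Hello <b>doctor</b>, see https://example.com/report please"])
s = s.str.replace(r'\s*https?://\S+(\s+|$)', ' ', regex=True).str.strip()
s = s.str.replace(r'<[^<]+?>', ' ', regex=True).str.strip()

# prints roughly: "Hello  doctor , see please" (URL and tags removed;
# leftover whitespace is not collapsed)
print(s[0])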
73 changes: 73 additions & 0 deletions chat_doc/dataset_generation/med_dialogue_dataset.py
@@ -0,0 +1,73 @@
"""
Load and process ICD-11 data to generate a dataset for training and testing.
"""

import numpy as np
import pandas as pd
from tqdm import tqdm

from datasets import load_dataset

from chat_doc.config import BASE_DIR, logger
from chat_doc.dataset_generation.chat_dataset import ChatDataset


class MedDialogueDataset(ChatDataset):
def __init__(self, name="med_dialogue"):
super().__init__(name)
self.variant = "processed.en"
self.med_dialogue_hf_id = "medical_dialog"

def load_data(self):
"""
Load Medical Dialogue data from HF dataset.
"""
try:
raw_dataset = load_dataset(self.med_dialogue_hf_id, self.variant)
self.dataset = raw_dataset["train"]
# = raw_dataset.to_pandas()
logger.info("Medical Dialogue data loaded from HF.")
except Exception as e:
logger.error(f"Error loading dataset: {e}")
raise e

def process_data(self):
"""
Here we reproduce the data processing steps shown to be useful in our exploration (data_exploration.ipynb).
"""
if self._is_loaded():
med_dialogue = self.dataset

diag_list = []
for record in med_dialogue:
utt = record["utterances"]
diag_list.append({
"patient": utt[0].replace("patient: ", ""),
"doctor": utt[1].replace("doctor: ", "")
})

med_dialogue = pd.DataFrame(diag_list)


logger.info("Medical Dialogue data processed.")
self.processed = True
self.dataset = med_dialogue

def build_prompts(self):
if self._is_processed():
med_dialogue = self.dataset.copy()

prompts = []
for _, row in tqdm(med_dialogue.iterrows(), total=med_dialogue.shape[0]):

prompts.append(
# inherit from ChatDataset
self.unify_prompt(
instruction=f"{row['patient']}",
context=f"",
response=f"{row['doctor']}",
)
)

logger.info("Medical Dialogue prompts built.")
self.prompts = prompts
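For reference, a short sketch of what process_data expects from the HF records. It mirrors the load_data call above; the field layout of the processed.en variant is assumed from the code, and running it needs network access to the Hugging Face hub.

from datasets import load_dataset

ds = load_dataset("medical_dialog", "processed.en")["train"]

record = ds[0]
# each record carries an "utterances" list: ["patient: ...", "doctor: ..."]
patient_turn = record["utterances"][0].replace("patient: ", "")
doctor_turn = record["utterances"][1].replace("doctor: ", "")

print(patient_turn[:80])
print(doctor_turn[:80])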
37 changes: 36 additions & 1 deletion chat_doc/dataset_generation/pmc_patients_dataset.py
@@ -9,6 +9,8 @@
from datasets import load_dataset
from tqdm import tqdm

from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer

from chat_doc.config import logger
from chat_doc.dataset_generation.chat_dataset import ChatDataset

@@ -52,7 +54,7 @@ def process_data(self):
self.processed = True
self.dataset = pmc_data

def build_prompts(self):
def __build_prompts(self):
if self._is_processed():
pmc_data = self.dataset.copy()

@@ -78,3 +80,36 @@ def _str_sex(sex_str: str) -> str:

logger.info("PMC patients prompts built.")
self.prompts = prompts


def build_prompts(self):
if self._is_processed():
pmc_data = self.dataset.copy()

def _str_sex(sex_str: str) -> str:
return "male" if sex_str == "M" else "female"

# Prompt templates with named placeholders, kept as plain strings so they can
# be filled per row via str.format() below (f-strings would raise a NameError
# here because age, gender and patient_summary are not defined yet)
templates = [
"Summarize the medical history for a {age}-year-old {gender} patient based on the following summary: {patient_summary}",
"Explain the treatment options for a patient with this profile: {patient_summary}",
"Identify potential diagnoses for a {age}-year-old patient presenting these symptoms: {patient_summary}",
]

prompts = []
for _, row in pmc_data.iterrows():
template = np.random.choice(templates)
age = row.age
gender = _str_sex(row['sex'])
patient_summary = row['patient']

prompts.append(
# inherit from ChatDataset
self.unify_prompt(
instruction=template.format(age=age, gender=gender, patient_summary=patient_summary),
context="",
response=f"{patient_summary}",
)
)

logger.info("PMC patients prompts built.")
self.prompts = prompts
return prompts
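A quick, self-contained illustration of the template pattern used above: the templates are plain strings with named placeholders and are filled per row with str.format, so they must not be f-strings (age, gender and patient_summary do not exist when the list is defined). The values below are made up.

import numpy as np

templates = [
    "Summarize the medical history for a {age}-year-old {gender} patient based on the following summary: {patient_summary}",
    "Explain the treatment options for a patient with this profile: {patient_summary}",
]

# pick a template at random and fill in the row values
prompt = np.random.choice(templates).format(
    age=34,
    gender="female",
    patient_summary="34-year-old woman with a persistent dry cough for three weeks.",
)
print(prompt)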
4 changes: 2 additions & 2 deletions cli.py
@@ -23,7 +23,7 @@
# "generate" subcommand
generate_parser = subparsers.add_parser("generate", help="Generate data")
generate_parser.add_argument(
"--dataset", choices=["pmc", "icd", "full"], required=True, help="Dataset to generate"
"--dataset", choices=["pmc", "icd", "diagnose", "med-dialogue", "dialogue-full", "full"], required=True, help="Dataset to generate"
)
# generate_parser.add_argument(
# "--output_path", default="./data", help="Output path (default: ./data)"
@@ -32,7 +32,7 @@
# "train" subcommand
train_parser = subparsers.add_parser("train", help="Train the model")
train_parser.add_argument(
"--dataset", choices=["pmc", "icd", "full"], required=True, help="Dataset to train on"
"--dataset", choices=["pmc", "icd", "diagnose", "med-dialogue", "dialogue-full", "full"], required=True, help="Dataset to train on"
)
train_parser.add_argument(
"--base_model",
10 changes: 9 additions & 1 deletion sage_maker_training.ipynb
@@ -45,7 +45,15 @@
"metadata": {},
"outputs": [],
"source": [
"trainer = Trainer(dataset_name=\"icd\")\n",
"trainer = Trainer(\n",
" dataset_name=\"meta-llama/Llama-2-7b-hf\",\n",
" base_model=\"llama7\",\n",
" hyperparams= {\n",
" \"epochs\": 2,\n",
" \"per_device_train_batch_size\": 3,\n",
" }\n",
"\n",
")\n",
"trainer._initialize()\n",
"trainer._build_training_job()"
]
