Skip to content

Commit

Permalink
fixing flake8 errors
Browse files Browse the repository at this point in the history
  • Loading branch information
MisterXY89 committed Dec 8, 2023
1 parent faf7daa commit a699821
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 41 deletions.
3 changes: 1 addition & 2 deletions chat_doc/dataset_generation/dataset_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def build_full_dialogue_dataset(self):
logger.error(e)

return prompts


def build_full_dataset(self):
# load both datasets, if they don't exist, build them
Expand Down Expand Up @@ -69,7 +68,7 @@ def build_full_dataset(self):
logger.error(e)

return prompts

def load_full_dialogue_dataset(self):
try:
with open(self.dialogue_path, "rb") as f:
Expand Down
6 changes: 3 additions & 3 deletions chat_doc/dataset_generation/med_dialogue_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, name="med_dialogue"):

def load_data(self: str):
"""
Load Medical Dialogue data from HF dataset.
Load Medical Dialogue data from HF dataset.
"""
try:
raw_dataset = load_dataset(self.med_dialogue_hf_id, self.variant)
Expand Down Expand Up @@ -58,13 +58,13 @@ def build_prompts(self):
med_dialogue = self.dataset.copy()

prompts = []
for _, row in tqdm(med_dialogue.iterrows(), total=med_dialogue.shape[0]):
for _, row in tqdm(med_dialogue.iterrows(), total=med_dialogue.shape[0]):

prompts.append(
# inherit from ChatDataset
self.unify_prompt(
instruction=f"{row['patient']}",
context=f"",
context="",
response=f"{row['doctor']}",
)
)
Expand Down
23 changes: 11 additions & 12 deletions chat_doc/dataset_generation/pmc_patients_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,35 +81,34 @@ def _str_sex(sex_str: str) -> str:
logger.info("PMC patients prompts built.")
self.prompts = prompts


def build_prompts(self):
if self._is_processed():
pmc_data = self.dataset.copy()

def _str_sex(sex_str: str) -> str:
return "male" if sex_str == "M" else "female"

# Sample prompt templates
templates = [
f"Summarize the medical history for a {age}-year-old {gender} patient based on the following summary: {patient_summary}",
f"Explain the treatment options for a patient with this profile: {patient_summary}",
f"Identify potential diagnoses for a {age}-year-old patient presenting these symptoms: {patient_summary}"
]
# templates = [
# "Summarize the medical history for a {age}-year-old {gender} patient based on the following summary: {patient_summary}",
# "Explain the treatment options for a patient with this profile: {patient_summary}",
# ]

prompts = []
for _, row in pmc_data.iterrows():
template = np.random.choice(templates)
for _, row in pmc_data.iterrows():
# template = np.random.choice(templates)
age = row.age
gender = _str_sex(row['sex'])
sex = _str_sex(row['sex'])
patient_summary = row['patient']
instruction_template = f"Identify potential diagnoses for a {age}-year-old {sex} patient presenting these symptoms: {patient_summary}"

prompts.append(
# inherit from ChatDataset
self.unify_prompt(
instruction=template.format(age=age, gender=gender, patient_summary=patient_summary),
instruction=instruction_template,
context="",
response=f"{patient_summary}",
)
)
)

return prompts
25 changes: 1 addition & 24 deletions chat_doc/inference/chat.py
Original file line number Diff line number Diff line change
@@ -1,24 +1 @@
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker import get_execution_role

# Deploys a Hugging Face model artifact to a SageMaker real-time endpoint
# and sends one example inference request.
#
# NOTE(review): `role` was previously referenced without ever being defined
# (NameError as soon as the module ran); resolve it from the current
# SageMaker execution context instead.
role = get_execution_role()

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    model_data="s3://models/my-bert-model/model.tar.gz",  # path to your trained SageMaker model
    role=role,  # IAM role with permissions to create an endpoint
    transformers_version="4.26",  # Transformers version used
    pytorch_version="1.13",  # PyTorch version used
    py_version="py39",  # Python version used
)

# deploy model to SageMaker Inference (blocks until the endpoint is in service)
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.m5.xlarge",
)

# example request: you always need to define "inputs"
data = {
    "inputs": "Camera - You are awarded a SiPix Digital Camera! call 09061221066 fromm landline. Delivery within 28 days."
}

# request — returns the model's prediction for the example input
predictor.predict(data)
# see inference.ipynb for now

0 comments on commit a699821

Please sign in to comment.