From a77a580e3f0cc1140cf5dfd5b457c9f022d20f14 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Thu, 9 Jan 2025 19:35:46 -0800 Subject: [PATCH 01/10] cot and zs-cot support for semantic filter --- examples/op_examples/filter.py | 1 + examples/op_examples/filter_cascade.py | 1 + lotus/sem_ops/postprocessors.py | 88 ++++++++++----------- lotus/sem_ops/sem_filter.py | 4 +- lotus/templates/task_instructions.py | 101 ++++++++----------------- 5 files changed, 75 insertions(+), 120 deletions(-) diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index a1acc00d..f20aa593 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -6,6 +6,7 @@ lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) + data = { "Course Name": [ "Probability and Random Processes", diff --git a/examples/op_examples/filter_cascade.py b/examples/op_examples/filter_cascade.py index 104c8410..a1b94f4d 100644 --- a/examples/op_examples/filter_cascade.py +++ b/examples/op_examples/filter_cascade.py @@ -8,6 +8,7 @@ gpt_4o = LM("gpt-4o") lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini) + data = { "Course Name": [ "Probability and Random Processes", diff --git a/lotus/sem_ops/postprocessors.py b/lotus/sem_ops/postprocessors.py index d531099c..e11dbcf2 100644 --- a/lotus/sem_ops/postprocessors.py +++ b/lotus/sem_ops/postprocessors.py @@ -7,6 +7,33 @@ SemanticMapPostprocessOutput, ) +def cot_postprocessor(llm_answers: list[str]): + outputs: list[str | None] = [] + explanations: list[str | None] = [] + for llm_answer in llm_answers: + import xml.etree.ElementTree as ET + try: + root = ET.fromstring(f"{llm_answer}") + reasoning = root.find('Reasoning') + answer = root.find('Answer') + + if reasoning is None or answer is None: + raise ValueError("Failed to parse reasoning or answer") + + reasoning = reasoning.text.strip() if reasoning.text else None + answer = answer.text.strip() if answer.text else "" + + explanations.append(reasoning) + 
outputs.append(answer) + + lotus.logger.debug(f"{llm_answer}") + + except (ET.ParseError, ValueError): + lotus.logger.debug(f"\t Failed to parse reasoning and answer from: {llm_answer}") + explanations.append(None) + outputs.append("") + + return outputs, explanations def map_postprocess_cot(llm_answers: list[str]) -> SemanticMapPostprocessOutput: """ @@ -79,49 +106,9 @@ def extract_postprocess(llm_answers: list[str]) -> SemanticExtractPostprocessOut return SemanticExtractPostprocessOutput(raw_outputs=llm_answers, outputs=extract_data) - -def filter_postprocess_cot(llm_answers: list[str], default: bool) -> SemanticFilterPostprocessOutput: - """ - Postprocess the output of the filter operator with CoT reasoning. - - Args: - llm_answers (list[str]): The list of llm answers. - default (bool): The default value to use if we fail to parse the answer. - - Returns: - SemanticFilterPostprocessOutput - """ - outputs: list[bool] = [] - explanations: list[str | None] = [] - - for llm_answer in llm_answers: - reasoning_idx = llm_answer.find("Reasoning:\n") - if reasoning_idx == -1: - reasoning_idx = 0 - else: - reasoning_idx += len("Reasoning:\n") - - answer_idx = llm_answer.find("Answer:") - reasoning = llm_answer[reasoning_idx:answer_idx].rstrip("\n").lstrip("\n") - answer = llm_answer[answer_idx + len("Answer:") :] - - explanations.append(reasoning) - - if "True" in answer: - outputs.append(True) - elif "False" in answer: - outputs.append(False) - else: - lotus.logger.info(f"\t Failed to parse: defaulting to {default}") - outputs.append(default) - - return SemanticFilterPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations) - - def filter_postprocess( llm_answers: list[str], default: bool = True, - cot_reasoning: bool = False, ) -> SemanticFilterPostprocessOutput: """ Postprocess the output of the filter operator. 
@@ -134,18 +121,21 @@ def filter_postprocess( Returns: SemanticFilterPostprocessOutput """ - if cot_reasoning: - return filter_postprocess_cot(llm_answers, default) + outputs, explanations = cot_postprocessor(llm_answers) + + def process_outputs(answer): + if answer is None: + lotus.logger.info(f"\t Failed to parse {answer}: defaulting to {default}") + return default - outputs: list[bool] = [] - explanations: list[str | None] = [None] * len(llm_answers) - for answer in llm_answers: if "True" in answer: - outputs.append(True) + return True elif "False" in answer: - outputs.append(False) + return False else: - lotus.logger.info(f"\t Failed to parse: defaulting to {default}") - outputs.append(default) + lotus.logger.info(f"\t Failed to parse {answer}: defaulting to {default}") + return default + + outputs = [process_outputs(answer) for answer in outputs] return SemanticFilterPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations) diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index d6253b8d..fba85c5e 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -47,7 +47,7 @@ def sem_filter( inputs = [] for doc in docs: prompt = lotus.templates.task_instructions.filter_formatter( - doc, user_instruction, examples_multimodal_data, examples_answers, cot_reasoning, strategy + doc, user_instruction, examples_multimodal_data, examples_answers, cot_reasoning ) lotus.logger.debug(f"input to model: {prompt}") inputs.append(prompt) @@ -63,7 +63,7 @@ def sem_filter( ) postprocess_output = filter_postprocess( - lm_output.outputs, default=default, cot_reasoning=strategy in ["cot", "zs-cot"] + lm_output.outputs, default=default ) lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"raw_outputs: {postprocess_output.raw_outputs}") diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index fc30efd9..53393027 100644 --- 
a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -8,6 +8,16 @@ from lotus.types import SerializationFormat +def cot_formatter(reasoning, answer): + return f"""{reasoning}{answer}""" + +def cot_prompt_formatter(reasoning_instructions: str = "", answer_instructions: str = "") -> str: + reasoning_instructions = f"Provide your reasoning here.{reasoning_instructions}" + answer_instructions = f"Provide your answer here. {answer_instructions}" + return f"""Let's think step by step. Use the following format to provide your answer: + {cot_formatter(reasoning_instructions, answer_instructions)} + """ + def context_formatter( multimodal_data: dict[str, Any] | str, ) -> tuple[str, list[dict[str, str]]]: @@ -54,79 +64,22 @@ def user_message_formatter( "content": content, } - -def filter_formatter_cot( - multimodal_data: dict[str, Any], - user_instruction: str, - examples_multimodal_data: list[dict[str, Any]], - examples_answer: list[bool], - cot_reasoning: list[str], -) -> list[dict[str, str]]: - sys_instruction = ( - "The user will provide a claim and some relevant context.\n" - "Your job is to determine whether the claim is true for the given context.\n" - 'First give your reasoning. 
Then you MUST end your output with "Answer: True or False"' - ) - messages = [ - {"role": "system", "content": sys_instruction}, - ] - - for idx in range(len(examples_multimodal_data)): - ex_multimodal_data = examples_multimodal_data[idx] - ex_ans = examples_answer[idx] - cot = cot_reasoning[idx] - messages.extend( - [ - user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"), - { - "role": "assistant", - "content": f"Reasoning:\n{cot}\n\nAnswer: {ex_ans}", - }, - ] - ) - - messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}")) - return messages - - -def filter_formatter_zs_cot( - multimodal_data: dict[str, Any], - user_instruction: str, -) -> list[dict[str, str]]: - sys_instruction = ( - "The user will provide a claim and some relevant context.\n" - "Your job is to determine whether the claim is true for the given context.\n" - 'First give your reasoning. Then you MUST end your output with "Answer: True or False"' - ) - messages = [ - {"role": "system", "content": sys_instruction}, - ] - - messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}")) - return messages - - def filter_formatter( multimodal_data: dict[str, Any], user_instruction: str, examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answer: list[bool] | None = None, - cot_reasoning: list[str] | None = None, - strategy: str | None = None, + cot_reasoning: list[str] | None = None ) -> list[dict[str, str]]: - if cot_reasoning: - assert examples_multimodal_data is not None and examples_answer is not None - return filter_formatter_cot( - multimodal_data, user_instruction, examples_multimodal_data, examples_answer, cot_reasoning - ) - elif strategy == "zs-cot": - return filter_formatter_zs_cot(multimodal_data, user_instruction) - + sys_instruction = ( - "The user will provide a claim and some relevant context.\n" - "Your job is to determine whether the claim is true for the given context.\n" - 'You must answer 
with a single word, "True" or "False".' + f"""The user will provide a claim and some relevant context. + Your job is to determine whether the claim is true for the given context. + + {cot_prompt_formatter(answer_instructions="The answer should be either True or False")} + """ ) + messages = [ {"role": "system", "content": sys_instruction}, ] @@ -134,13 +87,23 @@ def filter_formatter( if examples_multimodal_data: assert examples_answer is not None assert isinstance(examples_multimodal_data, list) and isinstance(examples_answer, list) - for i in range(len(examples_multimodal_data)): - ex_multimodal_data = examples_multimodal_data[i] - ex_ans = examples_answer[i] + assert len(examples_multimodal_data) == len(examples_answer) + + if cot_reasoning: + assert isinstance(cot_reasoning, list) + assert len(examples_multimodal_data) == len(examples_answer) == len(cot_reasoning) + + for idx in range(len(examples_multimodal_data)): + ex_multimodal_data = examples_multimodal_data[idx] + ex_ans = examples_answer[idx] + cot = cot_reasoning[idx] if cot_reasoning else "" messages.extend( [ user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"), - {"role": "assistant", "content": str(ex_ans)}, + { + "role": "assistant", + "content": f"""{cot_formatter(cot, ex_ans)}""", + }, ] ) From 2999d584ae64a013d0887896b27d3d349f739cdd Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Thu, 9 Jan 2025 23:24:08 -0800 Subject: [PATCH 02/10] made cot optional --- examples/op_examples/filter.py | 4 ++-- lotus/sem_ops/postprocessors.py | 24 +++++++++++-------- lotus/sem_ops/sem_filter.py | 2 +- lotus/templates/task_instructions.py | 35 ++++++++++++++++++++-------- 4 files changed, 42 insertions(+), 23 deletions(-) diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index f20aa593..69c19a15 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -6,7 +6,7 @@ lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) - 
+lotus.logger.setLevel("DEBUG") data = { "Course Name": [ "Probability and Random Processes", @@ -17,5 +17,5 @@ } df = pd.DataFrame(data) user_instruction = "{Course Name} requires a lot of math" -df = df.sem_filter(user_instruction) +df = df.sem_filter(user_instruction, strategy="cot") print(df) diff --git a/lotus/sem_ops/postprocessors.py b/lotus/sem_ops/postprocessors.py index e11dbcf2..0e377cb8 100644 --- a/lotus/sem_ops/postprocessors.py +++ b/lotus/sem_ops/postprocessors.py @@ -14,21 +14,25 @@ def cot_postprocessor(llm_answers: list[str]): import xml.etree.ElementTree as ET try: root = ET.fromstring(f"{llm_answer}") - reasoning = root.find('Reasoning') - answer = root.find('Answer') + reasoning = root.find('.//Reasoning') # Use XPath to find nested tags + answer = root.find('.//Answer') # Use XPath to find nested tags - if reasoning is None or answer is None: - raise ValueError("Failed to parse reasoning or answer") - - reasoning = reasoning.text.strip() if reasoning.text else None - answer = answer.text.strip() if answer.text else "" + if answer is not None and answer.text: + answer = answer.text.strip() + else: + lotus.logger.error(f"\t Failed to parse answer from: {llm_answer}") + answer = "" + + if reasoning is not None and reasoning.text: + reasoning = reasoning.text.strip() + else: + lotus.logger.debug(f"\t Failed to parse reasoning from: {llm_answer}") + reasoning = None explanations.append(reasoning) outputs.append(answer) - - lotus.logger.debug(f"{llm_answer}") - except (ET.ParseError, ValueError): + except (ET.ParseError): lotus.logger.debug(f"\t Failed to parse reasoning and answer from: {llm_answer}") explanations.append(None) outputs.append("") diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index fba85c5e..e4ec5939 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -47,7 +47,7 @@ def sem_filter( inputs = [] for doc in docs: prompt = lotus.templates.task_instructions.filter_formatter( - doc, 
user_instruction, examples_multimodal_data, examples_answers, cot_reasoning + doc, user_instruction, examples_multimodal_data, examples_answers, cot_reasoning, strategy ) lotus.logger.debug(f"input to model: {prompt}") inputs.append(prompt) diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index 53393027..94454299 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -7,7 +7,6 @@ from lotus.dtype_extensions import ImageDtype from lotus.types import SerializationFormat - def cot_formatter(reasoning, answer): return f"""{reasoning}{answer}""" @@ -17,6 +16,9 @@ def cot_prompt_formatter(reasoning_instructions: str = "", answer_instructions: return f"""Let's think step by step. Use the following format to provide your answer: {cot_formatter(reasoning_instructions, answer_instructions)} """ +def non_cot_prompt_formatter(answer_instructions: str = "") -> str: + answer_instructions = f"Provide your answer here. {answer_instructions}" + return f"""{answer_instructions}""" def context_formatter( multimodal_data: dict[str, Any] | str, @@ -45,7 +47,6 @@ def context_formatter( raise ValueError("multimodal_data must be a dictionary or a string") return text, image_inputs - def user_message_formatter( multimodal_data: dict[str, Any] | str, user_instruction_with_tag: str | None = None, @@ -69,16 +70,30 @@ def filter_formatter( user_instruction: str, examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answer: list[bool] | None = None, - cot_reasoning: list[str] | None = None + cot_reasoning: list[str] | None = None, + strategy: str | None = None, + reasoning_instructions: str = "", ) -> list[dict[str, str]]: + answer_instructions="The answer should be either True or False" - sys_instruction = ( - f"""The user will provide a claim and some relevant context. - Your job is to determine whether the claim is true for the given context. 
+ if strategy == "cot": + sys_instruction = ( + f"""The user will provide a claim and some relevant context. + Your job is to determine whether the claim is true for the given context. + + {cot_prompt_formatter( + reasoning_instructions=reasoning_instructions, + answer_instructions=answer_instructions)} + """ + ) + else: + sys_instruction = ( + f"""The user will provide a claim and some relevant context. + Your job is to determine whether the claim is true for the given context. - {cot_prompt_formatter(answer_instructions="The answer should be either True or False")} - """ - ) + {non_cot_prompt_formatter(answer_instructions=answer_instructions)} + """ + ) messages = [ {"role": "system", "content": sys_instruction}, @@ -96,7 +111,7 @@ def filter_formatter( for idx in range(len(examples_multimodal_data)): ex_multimodal_data = examples_multimodal_data[idx] ex_ans = examples_answer[idx] - cot = cot_reasoning[idx] if cot_reasoning else "" + cot = cot_reasoning[idx] if cot_reasoning else "Reasoning for this example has not been provided" messages.extend( [ user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"), From 1ef7446a2b9a6fa4add6173a20df8eeb0257f812 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Thu, 9 Jan 2025 23:43:32 -0800 Subject: [PATCH 03/10] linting and formatting --- examples/op_examples/filter.py | 2 +- lotus/sem_ops/postprocessors.py | 22 +++++++----- lotus/sem_ops/sem_filter.py | 4 +-- lotus/templates/task_instructions.py | 51 ++++++++++++++++------------ 4 files changed, 44 insertions(+), 35 deletions(-) diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index 69c19a15..ebc17088 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -17,5 +17,5 @@ } df = pd.DataFrame(data) user_instruction = "{Course Name} requires a lot of math" -df = df.sem_filter(user_instruction, strategy="cot") +df = df.sem_filter(user_instruction, strategy="") print(df) diff --git 
a/lotus/sem_ops/postprocessors.py b/lotus/sem_ops/postprocessors.py index 0e377cb8..8bec4180 100644 --- a/lotus/sem_ops/postprocessors.py +++ b/lotus/sem_ops/postprocessors.py @@ -7,18 +7,20 @@ SemanticMapPostprocessOutput, ) + def cot_postprocessor(llm_answers: list[str]): outputs: list[str | None] = [] explanations: list[str | None] = [] for llm_answer in llm_answers: import xml.etree.ElementTree as ET + try: root = ET.fromstring(f"{llm_answer}") - reasoning = root.find('.//Reasoning') # Use XPath to find nested tags - answer = root.find('.//Answer') # Use XPath to find nested tags + reasoning = root.find(".//Reasoning") # Use XPath to find nested tags + answer = root.find(".//Answer") # Use XPath to find nested tags if answer is not None and answer.text: - answer = answer.text.strip() + answer = answer.text.strip() else: lotus.logger.error(f"\t Failed to parse answer from: {llm_answer}") answer = "" @@ -26,19 +28,20 @@ def cot_postprocessor(llm_answers: list[str]): if reasoning is not None and reasoning.text: reasoning = reasoning.text.strip() else: - lotus.logger.debug(f"\t Failed to parse reasoning from: {llm_answer}") + lotus.logger.debug(f"\t Unable to extract reasoning from: {llm_answer}. Was CoT used?") reasoning = None explanations.append(reasoning) outputs.append(answer) - - except (ET.ParseError): - lotus.logger.debug(f"\t Failed to parse reasoning and answer from: {llm_answer}") + + except ET.ParseError: + lotus.logger.debug(f"\t XML error parsing: {llm_answer}") explanations.append(None) outputs.append("") - + return outputs, explanations + def map_postprocess_cot(llm_answers: list[str]) -> SemanticMapPostprocessOutput: """ Postprocess the output of the map operator with CoT reasoning. 
@@ -110,6 +113,7 @@ def extract_postprocess(llm_answers: list[str]) -> SemanticExtractPostprocessOut return SemanticExtractPostprocessOutput(raw_outputs=llm_answers, outputs=extract_data) + def filter_postprocess( llm_answers: list[str], default: bool = True, @@ -139,7 +143,7 @@ def process_outputs(answer): else: lotus.logger.info(f"\t Failed to parse {answer}: defaulting to {default}") return default - + outputs = [process_outputs(answer) for answer in outputs] return SemanticFilterPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations) diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index e4ec5939..ea8605cf 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -62,9 +62,7 @@ def sem_filter( inputs, show_progress_bar=show_progress_bar, progress_bar_desc=progress_bar_desc, **kwargs ) - postprocess_output = filter_postprocess( - lm_output.outputs, default=default - ) + postprocess_output = filter_postprocess(lm_output.outputs, default=default) lotus.logger.debug(f"outputs: {postprocess_output.outputs}") lotus.logger.debug(f"raw_outputs: {postprocess_output.raw_outputs}") lotus.logger.debug(f"explanations: {postprocess_output.explanations}") diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index 94454299..c177a48d 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -7,18 +7,29 @@ from lotus.dtype_extensions import ImageDtype from lotus.types import SerializationFormat + def cot_formatter(reasoning, answer): return f"""{reasoning}{answer}""" + +def answer_only_formatter(answer): + return f"""{answer}""" + + def cot_prompt_formatter(reasoning_instructions: str = "", answer_instructions: str = "") -> str: - reasoning_instructions = f"Provide your reasoning here.{reasoning_instructions}" + reasoning_instructions = f"Provide your reasoning here. 
{reasoning_instructions}" answer_instructions = f"Provide your answer here. {answer_instructions}" return f"""Let's think step by step. Use the following format to provide your answer: {cot_formatter(reasoning_instructions, answer_instructions)} """ + + def non_cot_prompt_formatter(answer_instructions: str = "") -> str: answer_instructions = f"Provide your answer here. {answer_instructions}" - return f"""{answer_instructions}""" + return f"""Use the following format to provide your answer: + {answer_only_formatter(answer_instructions)} + """ + def context_formatter( multimodal_data: dict[str, Any] | str, @@ -47,6 +58,7 @@ def context_formatter( raise ValueError("multimodal_data must be a dictionary or a string") return text, image_inputs + def user_message_formatter( multimodal_data: dict[str, Any] | str, user_instruction_with_tag: str | None = None, @@ -65,6 +77,7 @@ def user_message_formatter( "content": content, } + def filter_formatter( multimodal_data: dict[str, Any], user_instruction: str, @@ -74,26 +87,18 @@ def filter_formatter( strategy: str | None = None, reasoning_instructions: str = "", ) -> list[dict[str, str]]: - answer_instructions="The answer should be either True or False" - - if strategy == "cot": - sys_instruction = ( - f"""The user will provide a claim and some relevant context. - Your job is to determine whether the claim is true for the given context. + answer_instructions = "The answer should be either True or False" - {cot_prompt_formatter( - reasoning_instructions=reasoning_instructions, - answer_instructions=answer_instructions)} - """ - ) - else: - sys_instruction = ( - f"""The user will provide a claim and some relevant context. - Your job is to determine whether the claim is true for the given context. + sys_instruction = """The user will provide a claim and some relevant context. + Your job is to determine whether the claim is true for the given context. 
+ """ - {non_cot_prompt_formatter(answer_instructions=answer_instructions)} - """ + if strategy == "cot": + sys_instruction += cot_prompt_formatter( + reasoning_instructions=reasoning_instructions, answer_instructions=answer_instructions ) + else: + sys_instruction += non_cot_prompt_formatter(answer_instructions=answer_instructions) messages = [ {"role": "system", "content": sys_instruction}, @@ -107,17 +112,19 @@ def filter_formatter( if cot_reasoning: assert isinstance(cot_reasoning, list) assert len(examples_multimodal_data) == len(examples_answer) == len(cot_reasoning) - + for idx in range(len(examples_multimodal_data)): ex_multimodal_data = examples_multimodal_data[idx] ex_ans = examples_answer[idx] - cot = cot_reasoning[idx] if cot_reasoning else "Reasoning for this example has not been provided" + messages.extend( [ user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"), { "role": "assistant", - "content": f"""{cot_formatter(cot, ex_ans)}""", + "content": cot_formatter(cot_reasoning[idx], str(ex_ans)) + if cot_reasoning + else answer_only_formatter(str(ex_ans)), }, ] ) From ced50e076ffd7d4cf02dfa4a21e4bab1643d65cd Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Thu, 9 Jan 2025 23:49:18 -0800 Subject: [PATCH 04/10] cleaning up for code review --- examples/op_examples/filter.py | 4 ++-- lotus/sem_ops/postprocessors.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index ebc17088..f20aa593 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -6,7 +6,7 @@ lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) -lotus.logger.setLevel("DEBUG") + data = { "Course Name": [ "Probability and Random Processes", @@ -17,5 +17,5 @@ } df = pd.DataFrame(data) user_instruction = "{Course Name} requires a lot of math" -df = df.sem_filter(user_instruction, strategy="") +df = df.sem_filter(user_instruction) print(df) diff --git 
a/lotus/sem_ops/postprocessors.py b/lotus/sem_ops/postprocessors.py index 8bec4180..70baedfd 100644 --- a/lotus/sem_ops/postprocessors.py +++ b/lotus/sem_ops/postprocessors.py @@ -16,8 +16,8 @@ def cot_postprocessor(llm_answers: list[str]): try: root = ET.fromstring(f"{llm_answer}") - reasoning = root.find(".//Reasoning") # Use XPath to find nested tags - answer = root.find(".//Answer") # Use XPath to find nested tags + reasoning = root.find(".//Reasoning") + answer = root.find(".//Answer") if answer is not None and answer.text: answer = answer.text.strip() From fb6bdf3bcf702c75d656c679336575570c6e7fc4 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Fri, 10 Jan 2025 00:14:14 -0800 Subject: [PATCH 05/10] exposed ability to add custom reasoning instructions and disaggregated providing examples and requiring the model to use CoT --- examples/op_examples/filter.py | 1 + lotus/sem_ops/sem_filter.py | 12 +++++++++++- lotus/templates/task_instructions.py | 20 +++++++++++++++----- 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index f20aa593..1580c7f4 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -6,6 +6,7 @@ lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) +lotus.logger.setLevel("DEBUG") data = { "Course Name": [ diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index ea8605cf..dfb3f05f 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -27,6 +27,7 @@ def sem_filter( safe_mode: bool = False, show_progress_bar: bool = True, progress_bar_desc: str = "Filtering", + additional_cot_instructions: str = "" ) -> SemanticFilterOutput: """ Filters a list of documents based on a given user instruction using a language model. @@ -40,6 +41,7 @@ def sem_filter( examples_answers (list[bool] | None): The answers for examples. Defaults to None. cot_reasoning (list[str] | None): The reasoning for CoT. 
Defaults to None. logprobs (bool): Whether to return log probabilities. Defaults to False. + additional_cot_instructions (str): Additional instructions for the CoT. Defaults to "". Returns: SemanticFilterOutput: The True/False outputs, raw outputs, and explanations, and log probabilities. @@ -47,7 +49,7 @@ def sem_filter( inputs = [] for doc in docs: prompt = lotus.templates.task_instructions.filter_formatter( - doc, user_instruction, examples_multimodal_data, examples_answers, cot_reasoning, strategy + doc, user_instruction, examples_multimodal_data, examples_answers, cot_reasoning, strategy, reasoning_instructions=additional_cot_instructions ) lotus.logger.debug(f"input to model: {prompt}") inputs.append(prompt) @@ -85,6 +87,7 @@ def learn_filter_cascade_thresholds( examples_answers: list[bool] | None = None, cot_reasoning: list[str] | None = None, strategy: str | None = None, + additional_cot_instructions: str = "", ) -> tuple[float, float]: """Automatically learns the cascade thresholds for a cascade filter given a sample of data and doing a search across threshold @@ -102,6 +105,7 @@ def learn_filter_cascade_thresholds( strategy=strategy, safe_mode=False, progress_bar_desc="Running oracle for threshold learning", + additional_cot_instructions=additional_cot_instructions, ).outputs best_combination, _ = learn_cascade_thresholds( @@ -148,6 +152,7 @@ def __call__( return_stats: bool = False, safe_mode: bool = False, progress_bar_desc: str = "Filtering", + additional_cot_instructions: str = "", ) -> pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]: """ Applies semantic filter over a dataframe. @@ -166,6 +171,7 @@ def __call__( sampling_percentage (float): The percentage of the data to sample when cascading. Defaults to 0.1. failure_probability (float): The failure probability when cascading. Defaults to 0.2. return_stats (bool): Whether to return statistics. Defaults to False. + additional_cot_instructions (str): Additional instructions for the CoT. 
Defaults to "". Returns: pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]: The filtered dataframe or a tuple containing the filtered dataframe and statistics. @@ -245,6 +251,7 @@ def __call__( safe_mode=safe_mode, show_progress_bar=True, progress_bar_desc="Running helper LM", + additional_cot_instructions=additional_cot_instructions, ) helper_outputs, helper_logprobs = helper_output.outputs, helper_output.logprobs assert helper_logprobs is not None @@ -271,6 +278,7 @@ def __call__( examples_answers=examples_answers, cot_reasoning=cot_reasoning, strategy=strategy, + additional_cot_instructions=additional_cot_instructions, ) stats["pos_cascade_threshold"] = pos_cascade_threshold @@ -325,6 +333,7 @@ def __call__( strategy=strategy, safe_mode=safe_mode, progress_bar_desc="Running predicate evals with oracle LM", + additional_cot_instructions=additional_cot_instructions, ) for idx, large_idx in enumerate(low_conf_idxs): @@ -348,6 +357,7 @@ def __call__( safe_mode=safe_mode, show_progress_bar=True, progress_bar_desc=progress_bar_desc, + additional_cot_instructions=additional_cot_instructions, ) outputs = output.outputs raw_outputs = output.raw_outputs diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index c177a48d..680da4ba 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -108,23 +108,33 @@ def filter_formatter( assert examples_answer is not None assert isinstance(examples_multimodal_data, list) and isinstance(examples_answer, list) assert len(examples_multimodal_data) == len(examples_answer) - + if cot_reasoning: + # users don't have to provide cot reasoning examples + # but if they do, the number of examples must match assert isinstance(cot_reasoning, list) assert len(examples_multimodal_data) == len(examples_answer) == len(cot_reasoning) for idx in range(len(examples_multimodal_data)): ex_multimodal_data = examples_multimodal_data[idx] ex_ans = examples_answer[idx] - + content = 
"" + + # if cot reasoning is provided, use it. Otherwise, supply a default + # reasoning as filler if the user wants cot reasoning + if cot_reasoning: + content = cot_formatter(cot_reasoning[idx], str(ex_ans)) + elif strategy == "cot": + content = cot_formatter("Reasoning omitted", str(ex_ans)) + else: + content = answer_only_formatter(str(ex_ans)) + messages.extend( [ user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"), { "role": "assistant", - "content": cot_formatter(cot_reasoning[idx], str(ex_ans)) - if cot_reasoning - else answer_only_formatter(str(ex_ans)), + "content": content, }, ] ) From 436882e96e6f07cfb0111121d825f17a87eedd18 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Fri, 10 Jan 2025 00:28:39 -0800 Subject: [PATCH 06/10] remove debug level in filter example --- examples/op_examples/filter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index 1580c7f4..f20aa593 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -6,7 +6,6 @@ lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) -lotus.logger.setLevel("DEBUG") data = { "Course Name": [ From e0640e9a00e6602b88298d5bbfe34b6acf680a11 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Fri, 10 Jan 2025 00:35:30 -0800 Subject: [PATCH 07/10] fix mypy errors --- lotus/sem_ops/postprocessors.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lotus/sem_ops/postprocessors.py b/lotus/sem_ops/postprocessors.py index 70baedfd..0472a010 100644 --- a/lotus/sem_ops/postprocessors.py +++ b/lotus/sem_ops/postprocessors.py @@ -20,19 +20,19 @@ def cot_postprocessor(llm_answers: list[str]): answer = root.find(".//Answer") if answer is not None and answer.text: - answer = answer.text.strip() + answer_str = answer.text.strip() else: lotus.logger.error(f"\t Failed to parse answer from: {llm_answer}") - answer = "" + answer_str = "" if reasoning is not None and 
reasoning.text: - reasoning = reasoning.text.strip() + reasoning_str= reasoning.text.strip() else: lotus.logger.debug(f"\t Unable to extract reasoning from: {llm_answer}. Was CoT used?") - reasoning = None + reasoning_str = None - explanations.append(reasoning) - outputs.append(answer) + explanations.append(reasoning_str) + outputs.append(answer_str) except ET.ParseError: lotus.logger.debug(f"\t XML error parsing: {llm_answer}") From 06857687d9f048795bf6700a0d1a80379f22b702 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Fri, 10 Jan 2025 17:53:09 -0800 Subject: [PATCH 08/10] prompt LLM to generate valid XML --- lotus/templates/task_instructions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index 680da4ba..9351866f 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -21,6 +21,8 @@ def cot_prompt_formatter(reasoning_instructions: str = "", answer_instructions: answer_instructions = f"Provide your answer here. {answer_instructions}" return f"""Let's think step by step. Use the following format to provide your answer: {cot_formatter(reasoning_instructions, answer_instructions)} + + Your response must be valid XML format. """ @@ -28,7 +30,8 @@ def non_cot_prompt_formatter(answer_instructions: str = "") -> str: answer_instructions = f"Provide your answer here. 
{answer_instructions}" return f"""Use the following format to provide your answer: {answer_only_formatter(answer_instructions)} - """ + + Your response must be valid XML format.""" def context_formatter( From 2d7df7a97a38eec7ad41b4bde1f77b800566c453 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Mon, 13 Jan 2025 17:34:39 -0800 Subject: [PATCH 09/10] revert using XML for CoT --- lotus/sem_ops/postprocessors.py | 36 +++++++++------------------- lotus/templates/task_instructions.py | 15 +++++------- 2 files changed, 17 insertions(+), 34 deletions(-) diff --git a/lotus/sem_ops/postprocessors.py b/lotus/sem_ops/postprocessors.py index 0472a010..a361a166 100644 --- a/lotus/sem_ops/postprocessors.py +++ b/lotus/sem_ops/postprocessors.py @@ -12,32 +12,18 @@ def cot_postprocessor(llm_answers: list[str]): outputs: list[str | None] = [] explanations: list[str | None] = [] for llm_answer in llm_answers: - import xml.etree.ElementTree as ET + reasoning_idx = llm_answer.find("Reasoning:\n") + if reasoning_idx == -1: + reasoning_idx = 0 + else: + reasoning_idx += len("Reasoning:\n") - try: - root = ET.fromstring(f"{llm_answer}") - reasoning = root.find(".//Reasoning") - answer = root.find(".//Answer") - - if answer is not None and answer.text: - answer_str = answer.text.strip() - else: - lotus.logger.error(f"\t Failed to parse answer from: {llm_answer}") - answer_str = "" - - if reasoning is not None and reasoning.text: - reasoning_str= reasoning.text.strip() - else: - lotus.logger.debug(f"\t Unable to extract reasoning from: {llm_answer}. 
Was CoT used?") - reasoning_str = None - - explanations.append(reasoning_str) - outputs.append(answer_str) - - except ET.ParseError: - lotus.logger.debug(f"\t XML error parsing: {llm_answer}") - explanations.append(None) - outputs.append("") + answer_idx = llm_answer.find("Answer:") + reasoning = llm_answer[reasoning_idx:answer_idx].rstrip("\n").lstrip("\n") + answer = llm_answer[answer_idx + len("Answer:") :] + + explanations.append(reasoning) + outputs.append(answer) return outputs, explanations diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index 9351866f..fbef1ea2 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -9,29 +9,26 @@ def cot_formatter(reasoning, answer): - return f"""{reasoning}{answer}""" + return f"""Reasoning:\n{reasoning}\n\nAnswer: {answer}""" def answer_only_formatter(answer): - return f"""{answer}""" + return f"""Answer: {answer}""" def cot_prompt_formatter(reasoning_instructions: str = "", answer_instructions: str = "") -> str: - reasoning_instructions = f"Provide your reasoning here. {reasoning_instructions}" - answer_instructions = f"Provide your answer here. {answer_instructions}" + reasoning_instructions = f"" + answer_instructions = f"" return f"""Let's think step by step. Use the following format to provide your answer: {cot_formatter(reasoning_instructions, answer_instructions)} - - Your response must be valid XML format. """ def non_cot_prompt_formatter(answer_instructions: str = "") -> str: - answer_instructions = f"Provide your answer here. 
{answer_instructions}" + answer_instructions = f"" return f"""Use the following format to provide your answer: {answer_only_formatter(answer_instructions)} - - Your response must be valid XML format.""" + """ def context_formatter( From 13d01618ec3f0e3c55730e64af225f24899c5304 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Mon, 13 Jan 2025 18:03:54 -0800 Subject: [PATCH 10/10] ruff format and removed excess changes to minimize PR --- examples/op_examples/filter.py | 1 - examples/op_examples/filter_cascade.py | 1 - lotus/sem_ops/sem_filter.py | 10 ++++++++-- lotus/templates/task_instructions.py | 8 ++++---- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/examples/op_examples/filter.py b/examples/op_examples/filter.py index f20aa593..a1acc00d 100644 --- a/examples/op_examples/filter.py +++ b/examples/op_examples/filter.py @@ -6,7 +6,6 @@ lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) - data = { "Course Name": [ "Probability and Random Processes", diff --git a/examples/op_examples/filter_cascade.py b/examples/op_examples/filter_cascade.py index a1b94f4d..104c8410 100644 --- a/examples/op_examples/filter_cascade.py +++ b/examples/op_examples/filter_cascade.py @@ -8,7 +8,6 @@ gpt_4o = LM("gpt-4o") lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini) - data = { "Course Name": [ "Probability and Random Processes", diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index dfb3f05f..7a8cf4b0 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -27,7 +27,7 @@ def sem_filter( safe_mode: bool = False, show_progress_bar: bool = True, progress_bar_desc: str = "Filtering", - additional_cot_instructions: str = "" + additional_cot_instructions: str = "", ) -> SemanticFilterOutput: """ Filters a list of documents based on a given user instruction using a language model.
@@ -49,7 +49,13 @@ def sem_filter( inputs = [] for doc in docs: prompt = lotus.templates.task_instructions.filter_formatter( - doc, user_instruction, examples_multimodal_data, examples_answers, cot_reasoning, strategy, reasoning_instructions=additional_cot_instructions + doc, + user_instruction, + examples_multimodal_data, + examples_answers, + cot_reasoning, + strategy, + reasoning_instructions=additional_cot_instructions, ) lotus.logger.debug(f"input to model: {prompt}") inputs.append(prompt) diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index fbef1ea2..a71acd8c 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -108,10 +108,10 @@ def filter_formatter( assert examples_answer is not None assert isinstance(examples_multimodal_data, list) and isinstance(examples_answer, list) assert len(examples_multimodal_data) == len(examples_answer) - + if cot_reasoning: - # users don't have to provide cot reasoning examples - # but if they do, the number of examples must match + # users don't have to provide cot reasoning examples + # but if they do, the number of examples must match assert isinstance(cot_reasoning, list) assert len(examples_multimodal_data) == len(examples_answer) == len(cot_reasoning) @@ -128,7 +128,7 @@ def filter_formatter( content = cot_formatter("Reasoning omitted", str(ex_ans)) else: content = answer_only_formatter(str(ex_ans)) - + messages.extend( [ user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"),