Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Composable scaffolding for CoT #75

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
80 changes: 32 additions & 48 deletions lotus/sem_ops/postprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,26 @@
)


def cot_postprocessor(llm_answers: list[str]):
outputs: list[str | None] = []
explanations: list[str | None] = []
for llm_answer in llm_answers:
reasoning_idx = llm_answer.find("Reasoning:\n")
if reasoning_idx == -1:
reasoning_idx = 0
else:
reasoning_idx += len("Reasoning:\n")

answer_idx = llm_answer.find("Answer:")
reasoning = llm_answer[reasoning_idx:answer_idx].rstrip("\n").lstrip("\n")
answer = llm_answer[answer_idx + len("Answer:") :]

explanations.append(reasoning)
outputs.append(answer)

return outputs, explanations


def map_postprocess_cot(llm_answers: list[str]) -> SemanticMapPostprocessOutput:
"""
Postprocess the output of the map operator with CoT reasoning.
Expand Down Expand Up @@ -80,48 +100,9 @@ def extract_postprocess(llm_answers: list[str]) -> SemanticExtractPostprocessOut
return SemanticExtractPostprocessOutput(raw_outputs=llm_answers, outputs=extract_data)


def filter_postprocess_cot(llm_answers: list[str], default: bool) -> SemanticFilterPostprocessOutput:
    """
    Postprocess the output of the filter operator with CoT reasoning.

    Args:
        llm_answers (list[str]): The list of llm answers.
        default (bool): The default value to use if we fail to parse the answer.

    Returns:
        SemanticFilterPostprocessOutput
    """
    outputs: list[bool] = []
    explanations: list[str | None] = []

    for raw in llm_answers:
        # Locate the start of the reasoning section; fall back to the
        # beginning of the string when no "Reasoning:" header is present.
        start = raw.find("Reasoning:\n")
        start = 0 if start == -1 else start + len("Reasoning:\n")

        answer_pos = raw.find("Answer:")
        explanations.append(raw[start:answer_pos].strip("\n"))
        answer_text = raw[answer_pos + len("Answer:") :]

        # Map the free-text answer onto a boolean, defaulting on parse failure.
        if "True" in answer_text:
            verdict = True
        elif "False" in answer_text:
            verdict = False
        else:
            lotus.logger.info(f"\t Failed to parse: defaulting to {default}")
            verdict = default
        outputs.append(verdict)

    return SemanticFilterPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations)


def filter_postprocess(
llm_answers: list[str],
default: bool = True,
cot_reasoning: bool = False,
) -> SemanticFilterPostprocessOutput:
"""
Postprocess the output of the filter operator.
Expand All @@ -134,18 +115,21 @@ def filter_postprocess(
Returns:
SemanticFilterPostprocessOutput
"""
if cot_reasoning:
return filter_postprocess_cot(llm_answers, default)
outputs, explanations = cot_postprocessor(llm_answers)

def process_outputs(answer):
if answer is None:
lotus.logger.info(f"\t Failed to parse {answer}: defaulting to {default}")
return default

outputs: list[bool] = []
explanations: list[str | None] = [None] * len(llm_answers)
for answer in llm_answers:
if "True" in answer:
outputs.append(True)
return True
elif "False" in answer:
outputs.append(False)
return False
else:
lotus.logger.info(f"\t Failed to parse: defaulting to {default}")
outputs.append(default)
lotus.logger.info(f"\t Failed to parse {answer}: defaulting to {default}")
return default

outputs = [process_outputs(answer) for answer in outputs]

return SemanticFilterPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations)
22 changes: 18 additions & 4 deletions lotus/sem_ops/sem_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def sem_filter(
safe_mode: bool = False,
show_progress_bar: bool = True,
progress_bar_desc: str = "Filtering",
additional_cot_instructions: str = "",
) -> SemanticFilterOutput:
"""
Filters a list of documents based on a given user instruction using a language model.
Expand All @@ -40,14 +41,21 @@ def sem_filter(
examples_answers (list[bool] | None): The answers for examples. Defaults to None.
cot_reasoning (list[str] | None): The reasoning for CoT. Defaults to None.
logprobs (bool): Whether to return log probabilities. Defaults to False.
additional_cot_instructions (str): Additional instructions for the CoT. Defaults to "".

Returns:
SemanticFilterOutput: The True/False outputs, raw outputs, and explanations, and log probabilities.
"""
inputs = []
for doc in docs:
prompt = lotus.templates.task_instructions.filter_formatter(
doc, user_instruction, examples_multimodal_data, examples_answers, cot_reasoning, strategy
doc,
user_instruction,
examples_multimodal_data,
examples_answers,
cot_reasoning,
strategy,
reasoning_instructions=additional_cot_instructions,
)
lotus.logger.debug(f"input to model: {prompt}")
inputs.append(prompt)
Expand All @@ -62,9 +70,7 @@ def sem_filter(
inputs, show_progress_bar=show_progress_bar, progress_bar_desc=progress_bar_desc, **kwargs
)

postprocess_output = filter_postprocess(
lm_output.outputs, default=default, cot_reasoning=strategy in ["cot", "zs-cot"]
)
postprocess_output = filter_postprocess(lm_output.outputs, default=default)
lotus.logger.debug(f"outputs: {postprocess_output.outputs}")
lotus.logger.debug(f"raw_outputs: {postprocess_output.raw_outputs}")
lotus.logger.debug(f"explanations: {postprocess_output.explanations}")
Expand All @@ -87,6 +93,7 @@ def learn_filter_cascade_thresholds(
examples_answers: list[bool] | None = None,
cot_reasoning: list[str] | None = None,
strategy: str | None = None,
additional_cot_instructions: str = "",
) -> tuple[float, float]:
"""Automatically learns the cascade thresholds for a cascade
filter given a sample of data and doing a search across threshold
Expand All @@ -104,6 +111,7 @@ def learn_filter_cascade_thresholds(
strategy=strategy,
safe_mode=False,
progress_bar_desc="Running oracle for threshold learning",
additional_cot_instructions=additional_cot_instructions,
).outputs

best_combination, _ = learn_cascade_thresholds(
Expand Down Expand Up @@ -150,6 +158,7 @@ def __call__(
return_stats: bool = False,
safe_mode: bool = False,
progress_bar_desc: str = "Filtering",
additional_cot_instructions: str = "",
) -> pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]:
"""
Applies semantic filter over a dataframe.
Expand All @@ -168,6 +177,7 @@ def __call__(
sampling_percentage (float): The percentage of the data to sample when cascading. Defaults to 0.1.
failure_probability (float): The failure probability when cascading. Defaults to 0.2.
return_stats (bool): Whether to return statistics. Defaults to False.
additional_cot_instructions (str): Additional instructions for the CoT. Defaults to "".

Returns:
pd.DataFrame | tuple[pd.DataFrame, dict[str, Any]]: The filtered dataframe or a tuple containing the filtered dataframe and statistics.
Expand Down Expand Up @@ -247,6 +257,7 @@ def __call__(
safe_mode=safe_mode,
show_progress_bar=True,
progress_bar_desc="Running helper LM",
additional_cot_instructions=additional_cot_instructions,
)
helper_outputs, helper_logprobs = helper_output.outputs, helper_output.logprobs
assert helper_logprobs is not None
Expand All @@ -273,6 +284,7 @@ def __call__(
examples_answers=examples_answers,
cot_reasoning=cot_reasoning,
strategy=strategy,
additional_cot_instructions=additional_cot_instructions,
)

stats["pos_cascade_threshold"] = pos_cascade_threshold
Expand Down Expand Up @@ -327,6 +339,7 @@ def __call__(
strategy=strategy,
safe_mode=safe_mode,
progress_bar_desc="Running predicate evals with oracle LM",
additional_cot_instructions=additional_cot_instructions,
)

for idx, large_idx in enumerate(low_conf_idxs):
Expand All @@ -350,6 +363,7 @@ def __call__(
safe_mode=safe_mode,
show_progress_bar=True,
progress_bar_desc=progress_bar_desc,
additional_cot_instructions=additional_cot_instructions,
)
outputs = output.outputs
raw_outputs = output.raw_outputs
Expand Down
127 changes: 61 additions & 66 deletions lotus/templates/task_instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,29 @@
from lotus.types import SerializationFormat


def cot_formatter(reasoning, answer):
    """Render a reasoning/answer pair in the canonical CoT output layout."""
    return "Reasoning:\n{}\n\nAnswer: {}".format(reasoning, answer)


def answer_only_formatter(answer):
    """Render just the final answer, with no reasoning section."""
    return "Answer: {}".format(answer)


def cot_prompt_formatter(reasoning_instructions: str = "", answer_instructions: str = "") -> str:
    """
    Build the system-prompt snippet instructing the model to answer in the
    Reasoning/Answer CoT format, with optional extra guidance for each part.
    """
    reasoning_placeholder = f"<Your reasoning here. {reasoning_instructions}>"
    answer_placeholder = f"<Your answer here. {answer_instructions}>"
    example = cot_formatter(reasoning_placeholder, answer_placeholder)
    return (
        "Let's think step by step. Use the following format to provide your answer:\n"
        f"{example}\n"
    )


def non_cot_prompt_formatter(answer_instructions: str = "") -> str:
    """
    Build the system-prompt snippet for a direct (non-CoT) answer format,
    with optional extra guidance appended to the answer placeholder.
    """
    answer_placeholder = f"<Your answer here. {answer_instructions}>"
    return (
        "Use the following format to provide your answer:\n"
        f"{answer_only_formatter(answer_placeholder)}\n"
    )


def context_formatter(
multimodal_data: dict[str, Any] | str,
) -> tuple[str, list[dict[str, str]]]:
Expand Down Expand Up @@ -55,92 +78,64 @@ def user_message_formatter(
}


def filter_formatter_cot(
    multimodal_data: dict[str, Any],
    user_instruction: str,
    examples_multimodal_data: list[dict[str, Any]],
    examples_answer: list[bool],
    cot_reasoning: list[str],
) -> list[dict[str, str]]:
    """
    Build a chain-of-thought filter prompt: a system instruction, one
    user/assistant demonstration pair per example, then the target claim.
    """
    sys_instruction = (
        "The user will provide a claim and some relevant context.\n"
        "Your job is to determine whether the claim is true for the given context.\n"
        'First give your reasoning. Then you MUST end your output with "Answer: True or False"'
    )
    messages = [{"role": "system", "content": sys_instruction}]

    # One demonstration per example: the claim plus an assistant turn that
    # shows the expected Reasoning/Answer layout.
    for idx, ex_multimodal_data in enumerate(examples_multimodal_data):
        ex_ans = examples_answer[idx]
        cot = cot_reasoning[idx]
        messages.append(user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"))
        messages.append(
            {
                "role": "assistant",
                "content": f"Reasoning:\n{cot}\n\nAnswer: {ex_ans}",
            }
        )

    messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}"))
    return messages


def filter_formatter_zs_cot(
    multimodal_data: dict[str, Any],
    user_instruction: str,
) -> list[dict[str, str]]:
    """
    Build a zero-shot chain-of-thought filter prompt: a system instruction
    followed directly by the claim to evaluate (no demonstrations).
    """
    system_message = {
        "role": "system",
        "content": (
            "The user will provide a claim and some relevant context.\n"
            "Your job is to determine whether the claim is true for the given context.\n"
            'First give your reasoning. Then you MUST end your output with "Answer: True or False"'
        ),
    }
    return [
        system_message,
        user_message_formatter(multimodal_data, f"Claim: {user_instruction}"),
    ]


def filter_formatter(
multimodal_data: dict[str, Any],
user_instruction: str,
examples_multimodal_data: list[dict[str, Any]] | None = None,
examples_answer: list[bool] | None = None,
cot_reasoning: list[str] | None = None,
strategy: str | None = None,
reasoning_instructions: str = "",
) -> list[dict[str, str]]:
if cot_reasoning:
assert examples_multimodal_data is not None and examples_answer is not None
return filter_formatter_cot(
multimodal_data, user_instruction, examples_multimodal_data, examples_answer, cot_reasoning
answer_instructions = "The answer should be either True or False"

sys_instruction = """The user will provide a claim and some relevant context.
Your job is to determine whether the claim is true for the given context.
"""

if strategy == "cot":
sys_instruction += cot_prompt_formatter(
reasoning_instructions=reasoning_instructions, answer_instructions=answer_instructions
)
elif strategy == "zs-cot":
return filter_formatter_zs_cot(multimodal_data, user_instruction)
else:
sys_instruction += non_cot_prompt_formatter(answer_instructions=answer_instructions)

sys_instruction = (
"The user will provide a claim and some relevant context.\n"
"Your job is to determine whether the claim is true for the given context.\n"
'You must answer with a single word, "True" or "False".'
)
messages = [
{"role": "system", "content": sys_instruction},
]

if examples_multimodal_data:
assert examples_answer is not None
assert isinstance(examples_multimodal_data, list) and isinstance(examples_answer, list)
for i in range(len(examples_multimodal_data)):
ex_multimodal_data = examples_multimodal_data[i]
ex_ans = examples_answer[i]
assert len(examples_multimodal_data) == len(examples_answer)

if cot_reasoning:
# users don't have to provide cot reasoning examples
# but if they do, the number of examples must match
assert isinstance(cot_reasoning, list)
assert len(examples_multimodal_data) == len(examples_answer) == len(cot_reasoning)

for idx in range(len(examples_multimodal_data)):
ex_multimodal_data = examples_multimodal_data[idx]
ex_ans = examples_answer[idx]
content = ""

# if cot reasoning is provided, use it. Otherwise, supply a default
# reasoning as filler if the user wants cot reasoning
if cot_reasoning:
content = cot_formatter(cot_reasoning[idx], str(ex_ans))
elif strategy == "cot":
content = cot_formatter("Reasoning omitted", str(ex_ans))
else:
content = answer_only_formatter(str(ex_ans))

messages.extend(
[
user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"),
{"role": "assistant", "content": str(ex_ans)},
{
"role": "assistant",
"content": content,
},
]
)

Expand Down
Loading