Skip to content

Commit

Permalink
revert using XML for CoT
Browse files Browse the repository at this point in the history
  • Loading branch information
dhruviyer committed Jan 14, 2025
1 parent 0685768 commit 2d7df7a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 34 deletions.
36 changes: 11 additions & 25 deletions lotus/sem_ops/postprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,18 @@ def cot_postprocessor(llm_answers: list[str]):
outputs: list[str | None] = []
explanations: list[str | None] = []
for llm_answer in llm_answers:
import xml.etree.ElementTree as ET
reasoning_idx = llm_answer.find("Reasoning:\n")
if reasoning_idx == -1:
reasoning_idx = 0
else:
reasoning_idx += len("Reasoning:\n")

try:
root = ET.fromstring(f"<root>{llm_answer}</root>")
reasoning = root.find(".//Reasoning")
answer = root.find(".//Answer")

if answer is not None and answer.text:
answer_str = answer.text.strip()
else:
lotus.logger.error(f"\t Failed to parse answer from: {llm_answer}")
answer_str = ""

if reasoning is not None and reasoning.text:
reasoning_str= reasoning.text.strip()
else:
lotus.logger.debug(f"\t Unable to extract reasoning from: {llm_answer}. Was CoT used?")
reasoning_str = None

explanations.append(reasoning_str)
outputs.append(answer_str)

except ET.ParseError:
lotus.logger.debug(f"\t XML error parsing: {llm_answer}")
explanations.append(None)
outputs.append("")
answer_idx = llm_answer.find("Answer:")
reasoning = llm_answer[reasoning_idx:answer_idx].rstrip("\n").lstrip("\n")
answer = llm_answer[answer_idx + len("Answer:") :]

explanations.append(reasoning)
outputs.append(answer)

return outputs, explanations

Expand Down
15 changes: 6 additions & 9 deletions lotus/templates/task_instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,26 @@


def cot_formatter(reasoning, answer):
return f"""<Reasoning>{reasoning}</Reasoning><Answer>{answer}</Answer>"""
return f"""Reasoning:\n{reasoning}\n\nAnswer: {answer}"""


def answer_only_formatter(answer):
return f"""<Answer>{answer}</Answer>"""
return f"""Answer: {answer}"""


def cot_prompt_formatter(reasoning_instructions: str = "", answer_instructions: str = "") -> str:
reasoning_instructions = f"Provide your reasoning here. {reasoning_instructions}"
answer_instructions = f"Provide your answer here. {answer_instructions}"
reasoning_instructions = f"<Your reasoning here. {reasoning_instructions}>"
answer_instructions = f"<Your answer here. {answer_instructions}>"
return f"""Let's think step by step. Use the following format to provide your answer:
{cot_formatter(reasoning_instructions, answer_instructions)}
Your response must be valid XML format.
"""


def non_cot_prompt_formatter(answer_instructions: str = "") -> str:
answer_instructions = f"Provide your answer here. {answer_instructions}"
answer_instructions = f"<Your answer here. {answer_instructions}>"
return f"""Use the following format to provide your answer:
{answer_only_formatter(answer_instructions)}
Your response must be valid XML format."""
"""


def context_formatter(
Expand Down

0 comments on commit 2d7df7a

Please sign in to comment.