From b574e2fb10f3b53e1bb62e0a1d68f8509f203c2d Mon Sep 17 00:00:00 2001 From: Mark Waddle Date: Fri, 8 Nov 2024 16:04:11 -0800 Subject: [PATCH] Renders populated form as markdown (#235) The populated form is available in the state inspector. When completed, it is also returned as a message. --- .../agents/form_fill_extension/extension.py | 27 ++-- .../agents/form_fill_extension/inspector.py | 31 +++- .../agents/form_fill_extension/state.py | 41 +++-- .../steps/_guided_conversation.py | 2 +- .../steps/acquire_form_step.py | 4 +- .../steps/extract_form_fields_step.py | 16 +- .../steps/fill_form_step.py | 144 ++++++++++++++---- 7 files changed, 204 insertions(+), 61 deletions(-) diff --git a/assistants/prospector-assistant/assistant/agents/form_fill_extension/extension.py b/assistants/prospector-assistant/assistant/agents/form_fill_extension/extension.py index 9c283258..78a1952b 100644 --- a/assistants/prospector-assistant/assistant/agents/form_fill_extension/extension.py +++ b/assistants/prospector-assistant/assistant/agents/form_fill_extension/extension.py @@ -53,22 +53,22 @@ def build_step_context(config: ConfigT) -> Context[ConfigT]: ), ) - async with state.agent_state(context) as agent_state: + async with state.extension_state(context) as agent_state: while True: logger.info("form-fill-agent execute loop; mode: %s", agent_state.mode) match agent_state.mode: - case state.FormFillAgentMode.acquire_form_step: + case state.FormFillExtensionMode.acquire_form_step: result = await acquire_form_step.execute( step_context=build_step_context(config.acquire_form_config), ) match result: case acquire_form_step.CompleteResult(): - await _send_message(context, result.ai_message, result.debug) + await _send_message(context, result.message, result.debug) agent_state.form_filename = result.filename - agent_state.mode = state.FormFillAgentMode.extract_form_fields + agent_state.mode = state.FormFillExtensionMode.extract_form_fields continue @@ -76,7 +76,7 @@ def build_step_context(config: ConfigT) -> Context[ConfigT]: await _handle_incomplete_result(context, result) return - case state.FormFillAgentMode.extract_form_fields: + case state.FormFillExtensionMode.extract_form_fields: file_content = await get_attachment_content(agent_state.form_filename) result = await extract_form_fields_step.execute( step_context=build_step_context(config.extract_form_fields_config), @@ -85,10 +85,11 @@ def build_step_context(config: ConfigT) -> Context[ConfigT]: match result: case extract_form_fields_step.CompleteResult(): - await _send_message(context, result.ai_message, result.debug) + await _send_message(context, result.message, result.debug) + agent_state.extracted_form_title = result.extracted_form_title agent_state.extracted_form_fields = result.extracted_form_fields - agent_state.mode = state.FormFillAgentMode.fill_form_step + agent_state.mode = state.FormFillExtensionMode.fill_form_step continue @@ -96,19 +97,21 @@ def build_step_context(config: ConfigT) -> Context[ConfigT]: await _handle_incomplete_result(context, result) return - case state.FormFillAgentMode.fill_form_step: + case state.FormFillExtensionMode.fill_form_step: result = await fill_form_step.execute( step_context=build_step_context(config.fill_form_config), form_filename=agent_state.form_filename, + form_title=agent_state.extracted_form_title, form_fields=agent_state.extracted_form_fields, ) match result: case fill_form_step.CompleteResult(): - await _send_message(context, result.ai_message, result.debug) + await _send_message(context, result.message, result.debug) + agent_state.populated_form_markdown = result.populated_form_markdown agent_state.fill_form_gc_artifact = result.artifact - agent_state.mode = state.FormFillAgentMode.generate_filled_form_step + agent_state.mode = state.FormFillExtensionMode.conversation_over continue @@ -116,10 +119,10 @@ def build_step_context(config: ConfigT) -> Context[ConfigT]: await _handle_incomplete_result(context, result) return - case state.FormFillAgentMode.generate_filled_form_step: + case state.FormFillExtensionMode.conversation_over: await _send_message( context, - "I'd love to generate the filled-out form now, but it's not yet implemented. :)", + "The form is now complete! Create a new conversation to work with another form.", {}, ) return diff --git a/assistants/prospector-assistant/assistant/agents/form_fill_extension/inspector.py b/assistants/prospector-assistant/assistant/agents/form_fill_extension/inspector.py index a8b0192f..ea4688e1 100644 --- a/assistants/prospector-assistant/assistant/agents/form_fill_extension/inspector.py +++ b/assistants/prospector-assistant/assistant/agents/form_fill_extension/inspector.py @@ -1,5 +1,6 @@ import contextlib import json +from enum import StrEnum from hashlib import md5 from pathlib import Path from typing import Callable @@ -12,6 +13,17 @@ ) +class StateProjection(StrEnum): + """ + The projection to use when displaying the state. + """ + + original_content = "original_content" + """Return the state as string content.""" + json_to_yaml = "json_to_yaml" + """Return the state as a yaml code block.""" + + class FileStateInspector(ReadOnlyAssistantConversationInspectorStateProvider): """ A conversation inspector state provider that reads the state from a file and displays it as a yaml code block. @@ -22,6 +34,8 @@ def __init__( display_name: str, file_path_source: Callable[[ConversationContext], Path], description: str = "", + projection: StateProjection = StateProjection.json_to_yaml, + select_field: str = "", ) -> None: self._state_id = md5( (type(self).__name__ + "_" + display_name).encode("utf-8"), usedforsecurity=False @@ -29,6 +43,8 @@ def __init__( self._display_name = display_name self._file_path_source = file_path_source self._description = description + self._projection = projection + self._select_field = select_field @property def state_id(self) -> str: @@ -50,7 +66,14 @@ def read_state(path: Path) -> dict: state = read_state(self._file_path_source(context)) - # return the state as a yaml code block, as it is easier to read than json - return AssistantConversationInspectorStateDataModel( - data={"content": f"```yaml\n{yaml.dump(state, sort_keys=False)}\n```"}, - ) + selected = state.get(self._select_field) if self._select_field else state + + match self._projection: + case StateProjection.original_content: + return AssistantConversationInspectorStateDataModel(data={"content": selected}) + case StateProjection.json_to_yaml: + state_as_yaml = yaml.dump(selected, sort_keys=False) + # return the state as a yaml code block, as it is easier to read than json + return AssistantConversationInspectorStateDataModel( + data={"content": f"```yaml\n{state_as_yaml}\n```"}, + ) diff --git a/assistants/prospector-assistant/assistant/agents/form_fill_extension/state.py b/assistants/prospector-assistant/assistant/agents/form_fill_extension/state.py index 3d49eae1..121f6414 100644 --- a/assistants/prospector-assistant/assistant/agents/form_fill_extension/state.py +++ b/assistants/prospector-assistant/assistant/agents/form_fill_extension/state.py @@ -2,7 +2,7 @@ from contextvars import ContextVar from enum import StrEnum from pathlib import Path -from typing import AsyncIterator, Literal +from typing import AsyncIterator from pydantic import BaseModel, Field from semantic_workbench_assistant.assistant_app.context import ConversationContext, storage_directory_for_context @@ -11,26 +11,47 @@ from .inspector import FileStateInspector +class FieldType(StrEnum): + text = "text" + date = "date" + signature = "signature" + multiple_choice = "multiple_choice" + + +class AllowedOptionSelections(StrEnum): + one = "one" + """One of the options can be selected.""" + many = "many" + """One or more of the options can be selected.""" + + class FormField(BaseModel): id: str = Field(description="The descriptive, unique identifier of the field as a snake_case_english_string.") name: str = Field(description="The name of the field.") description: str = Field(description="The description of the field.") - type: Literal["string", "bool", "multiple_choice"] = Field(description="The type of the field.") + type: FieldType = Field(description="The type of the field.") options: list[str] = Field(description="The options for multiple choice fields.") - required: bool = Field(description="Whether the field is required or not.") + option_selections_allowed: AllowedOptionSelections | None = Field( + description="The number of options that can be selected for multiple choice fields." + ) + required: bool = Field( + description="Whether the field is required or not. False indicates the field is optional and can be left blank." + ) -class FormFillAgentMode(StrEnum): +class FormFillExtensionMode(StrEnum): acquire_form_step = "acquire_form" extract_form_fields = "extract_form_fields" fill_form_step = "fill_form" - generate_filled_form_step = "generate_filled_form" + conversation_over = "conversation_over" -class FormFillAgentState(BaseModel): - mode: FormFillAgentMode = FormFillAgentMode.acquire_form_step +class FormFillExtensionState(BaseModel): + mode: FormFillExtensionMode = FormFillExtensionMode.acquire_form_step form_filename: str = "" + extracted_form_title: str = "" extracted_form_fields: list[FormField] = [] + populated_form_markdown: str = "" fill_form_gc_artifact: dict | None = None @@ -38,11 +59,11 @@ def path_for_state(context: ConversationContext) -> Path: return storage_directory_for_context(context) / "state.json" -current_state = ContextVar[FormFillAgentState | None]("current_state", default=None) +current_state = ContextVar[FormFillExtensionState | None]("current_state", default=None) @asynccontextmanager -async def agent_state(context: ConversationContext) -> AsyncIterator[FormFillAgentState]: +async def extension_state(context: ConversationContext) -> AsyncIterator[FormFillExtensionState]: """ Context manager that provides the agent state, reading it from disk, and saving back to disk after the context manager block is executed. @@ -53,7 +74,7 @@ async def agent_state(context: ConversationContext) -> AsyncIterator[FormFillAge return async with context.state_updated_event_after(inspector.state_id): - state = read_model(path_for_state(context), FormFillAgentState) or FormFillAgentState() + state = read_model(path_for_state(context), FormFillExtensionState) or FormFillExtensionState() current_state.set(state) yield state write_model(path_for_state(context), state) diff --git a/assistants/prospector-assistant/assistant/agents/form_fill_extension/steps/_guided_conversation.py b/assistants/prospector-assistant/assistant/agents/form_fill_extension/steps/_guided_conversation.py index 08dbfbfc..61baecb5 100644 --- a/assistants/prospector-assistant/assistant/agents/form_fill_extension/steps/_guided_conversation.py +++ b/assistants/prospector-assistant/assistant/agents/form_fill_extension/steps/_guided_conversation.py @@ -40,7 +40,7 @@ async def engine( given state file. """ - async with _state_locks[state_file_path], context.state_updated_event_after(state_id, focus_event=True): + async with _state_locks[state_file_path], context.state_updated_event_after(state_id): kernel, service_id = _build_kernel_with_service(openai_client, openai_model) state: dict | None = None diff --git a/assistants/prospector-assistant/assistant/agents/form_fill_extension/steps/acquire_form_step.py b/assistants/prospector-assistant/assistant/agents/form_fill_extension/steps/acquire_form_step.py index c2c2fbc7..516bcd97 100644 --- a/assistants/prospector-assistant/assistant/agents/form_fill_extension/steps/acquire_form_step.py +++ b/assistants/prospector-assistant/assistant/agents/form_fill_extension/steps/acquire_form_step.py @@ -61,7 +61,7 @@ class AcquireFormConfig(BaseModel): @dataclass class CompleteResult(Result): - ai_message: str + message: str filename: str @@ -103,7 +103,7 @@ async def execute( if form_filename and form_filename != "Unanswered": return CompleteResult( - ai_message=result.ai_message or "", + message=result.ai_message or "", filename=form_filename, debug=debug, ) diff --git a/assistants/prospector-assistant/assistant/agents/form_fill_extension/steps/extract_form_fields_step.py b/assistants/prospector-assistant/assistant/agents/form_fill_extension/steps/extract_form_fields_step.py index cfe8f156..8a4e7e5b 100644 --- a/assistants/prospector-assistant/assistant/agents/form_fill_extension/steps/extract_form_fields_step.py +++ b/assistants/prospector-assistant/assistant/agents/form_fill_extension/steps/extract_form_fields_step.py @@ -19,15 +19,17 @@ class ExtractFormFieldsConfig(BaseModel): Field(title="Instruction", description="The instruction for extracting form fields from the file content."), UISchema(widget="textarea"), ] = ( - "Extract the form fields from the provided form attachment. Any type of form is allowed, including for example" + "Read the user provided form attachment and determine what fields are in the form. Any type of form is allowed, including" " tax forms, address forms, surveys, and other official or unofficial form-types. If the content is not a form," - " or the fields cannot be determined, then set the error_message." + " or the fields cannot be determined, then explain the reason why in the error_message. If the fields can be determined," + " leave the error_message empty." ) @dataclass class CompleteResult(Result): - ai_message: str + message: str + extracted_form_title: str extracted_form_fields: list[state.FormField] @@ -62,14 +64,18 @@ async def execute( ) return CompleteResult( - ai_message="", + message="", + extracted_form_title=extracted_form_fields.title, extracted_form_fields=extracted_form_fields.fields, debug=metadata, ) class FormFields(BaseModel): - error_message: str = Field(description="The error message in the case that the form fields could not be extracted.") + error_message: str = Field( + description="The error message in the case that the form fields could not be determined." + ) + title: str = Field(description="The title of the form.") fields: list[state.FormField] = Field(description="The fields in the form.") diff --git a/assistants/prospector-assistant/assistant/agents/form_fill_extension/steps/fill_form_step.py b/assistants/prospector-assistant/assistant/agents/form_fill_extension/steps/fill_form_step.py index 1ff4fa2b..412bf9e7 100644 --- a/assistants/prospector-assistant/assistant/agents/form_fill_extension/steps/fill_form_step.py +++ b/assistants/prospector-assistant/assistant/agents/form_fill_extension/steps/fill_form_step.py @@ -1,18 +1,19 @@ import logging +from contextlib import asynccontextmanager from dataclasses import dataclass from pathlib import Path from textwrap import dedent -from typing import Annotated, Any, Literal +from typing import Annotated, Any, AsyncIterator, Literal, Optional from guided_conversation.utils.resources import ResourceConstraintMode, ResourceConstraintUnit from openai.types.chat import ChatCompletionMessageParam -from pydantic import BaseModel, Field, create_model -from semantic_workbench_assistant.assistant_app.context import ConversationContext +from pydantic import BaseModel, ConfigDict, Field, create_model +from semantic_workbench_assistant.assistant_app.context import ConversationContext, storage_directory_for_context from semantic_workbench_assistant.assistant_app.protocol import AssistantAppProtocol from semantic_workbench_assistant.config import UISchema from .. import state -from ..inspector import FileStateInspector +from ..inspector import FileStateInspector, StateProjection from . import _guided_conversation, _llm from .types import ( Context, @@ -28,7 +29,8 @@ def extend(app: AssistantAppProtocol) -> None: - app.add_inspector_state_provider(_inspector.state_id, _inspector) + app.add_inspector_state_provider(_guided_conversation_inspector.state_id, _guided_conversation_inspector) + app.add_inspector_state_provider(_populated_form_state_inspector.state_id, _populated_form_state_inspector) definition = GuidedConversationDefinition( @@ -101,15 +103,21 @@ class FieldValueCandidatesFromDocument(BaseModel): candidates: FieldValueCandidates +class FillFormState(BaseModel): + populated_form_markdown: str = "(The form has not yet been provided)" + + @dataclass class CompleteResult(Result): - ai_message: str + message: str artifact: dict + populated_form_markdown: str async def execute( step_context: Context[FillFormConfig], form_filename: str, + form_title: str, form_fields: list[state.FormField], ) -> IncompleteResult | IncompleteErrorResult | CompleteResult: """ @@ -130,11 +138,11 @@ async def execute( async with _guided_conversation.engine( definition=definition, artifact_type=artifact_type, - state_file_path=_get_state_file_path(step_context.context), + state_file_path=_get_guided_conversation_state_file_path(step_context.context), openai_client=step_context.llm_config.openai_client_factory(), openai_model=step_context.llm_config.openai_model, context=step_context.context, - state_id=_inspector.state_id, + state_id=_guided_conversation_inspector.state_id, ) as gce: try: result = await gce.step_conversation(message) @@ -152,10 +160,20 @@ async def execute( fill_form_gc_artifact = gce.artifact.artifact.model_dump(mode="json") logger.info("guided-conversation artifact: %s", gce.artifact) + populated_form_markdown = _generate_populated_form( + form_title=form_title, + form_fields=form_fields, + populated_fields=fill_form_gc_artifact, + ) + + async with step_state(step_context.context) as state: + state.populated_form_markdown = populated_form_markdown + if result.is_conversation_over: return CompleteResult( - ai_message="", + message=populated_form_markdown, artifact=fill_form_gc_artifact, + populated_form_markdown=populated_form_markdown, debug=debug, ) @@ -216,22 +234,55 @@ def _form_fields_to_artifact_basemodel(form_fields: list[state.FormField]): required_fields.append(field.id) match field.type: - case "string": - field_definitions[field.id] = (str, Field(title=field.name, description=field.description)) + case state.FieldType.text | state.FieldType.signature | state.FieldType.date: + field_type = str + + case state.FieldType.multiple_choice: + match field.option_selections_allowed: + case state.AllowedOptionSelections.one: + field_type = Literal[tuple(field.options)] + + case state.AllowedOptionSelections.many: + field_type = list[Literal[tuple(field.options)]] + + case _: + raise ValueError(f"Unsupported option_selections_allowed: {field.option_selections_allowed}") - case "bool": - field_definitions[field.id] = (bool, Field(title=field.name, description=field.description)) + case _: + raise ValueError(f"Unsupported field type: {field.type}") - case "multiple_choice": - field_definitions[field.id] = ( - Literal[tuple(field.options)], - Field(title=field.name, description=field.description), - ) + if not field.required: + field_type = Optional[field_type] + + field_definitions[field.id] = (field_type, Field(title=field.name, description=field.description)) return create_model( "FilledFormArtifact", + __config__=ConfigDict(json_schema_extra={"required": required_fields}), **field_definitions, # type: ignore - ) # type: ignore + ) + + +def _get_guided_conversation_state_file_path(context: ConversationContext) -> Path: + return _guided_conversation.path_for_state(context, "fill_form") + + +_guided_conversation_inspector = FileStateInspector( + display_name="Fill-Form Guided-Conversation", + file_path_source=_get_guided_conversation_state_file_path, +) + + +def _get_step_state_file_path(context: ConversationContext) -> Path: + return storage_directory_for_context(context, "fill_form_state.json") + + +_populated_form_state_inspector = FileStateInspector( + display_name="Populated Form", + file_path_source=_get_step_state_file_path, + projection=StateProjection.original_content, + select_field="populated_form_markdown", +) async def _extract( @@ -263,11 +314,50 @@ class _SerializationModel(BaseModel): ) -def _get_state_file_path(context: ConversationContext) -> Path: - return _guided_conversation.path_for_state(context, "fill_form") - - -_inspector = FileStateInspector( - display_name="Fill-Form Guided-Conversation", - file_path_source=_get_state_file_path, -) +def _generate_populated_form( + form_title: str, + form_fields: list[state.FormField], + populated_fields: dict, +) -> str: + def field_value(field_id: str) -> str: + value = populated_fields.get(field_id) or "" + if value == "Unanswered": + return "_" * 20 + if value == "null": + return "" + return value + + markdown_fields: list[str] = [] + for field in form_fields: + value = field_value(field.id) + match field.type: + case state.FieldType.text | state.FieldType.signature | state.FieldType.date: + markdown_fields.append(f"*{field.name}:*\n\n{value}") + + case state.FieldType.multiple_choice: + markdown_fields.append(f"*{field.name}:*\n") + for option in field.options: + if option in value: + markdown_fields.append(f"- [x] {option}\n") + continue + markdown_fields.append(f"- [ ] {option}\n") + + case _: + raise ValueError(f"Unsupported field type: {field.type}") + + all_fields = "\n\n".join(markdown_fields) + return "\n".join(( + "```markdown", + f"## {form_title}", + "", + all_fields, + "```", + )) + + +@asynccontextmanager +async def step_state(context: ConversationContext) -> AsyncIterator[FillFormState]: + step_state = state.read_model(_get_step_state_file_path(context), FillFormState) or FillFormState() + async with context.state_updated_event_after(_populated_form_state_inspector.state_id, focus_event=True): + yield step_state + state.write_model(_get_step_state_file_path(context), step_state)