Skip to content

Commit

Permalink
Renders populated form as markdown (microsoft#235)
Browse files Browse the repository at this point in the history
The populated form is available in the state inspector. When completed,
it is also returned as a message.
  • Loading branch information
markwaddle authored Nov 9, 2024
1 parent 300d3b7 commit b574e2f
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,30 +53,30 @@ def build_step_context(config: ConfigT) -> Context[ConfigT]:
),
)

async with state.agent_state(context) as agent_state:
async with state.extension_state(context) as agent_state:
while True:
logger.info("form-fill-agent execute loop; mode: %s", agent_state.mode)

match agent_state.mode:
case state.FormFillAgentMode.acquire_form_step:
case state.FormFillExtensionMode.acquire_form_step:
result = await acquire_form_step.execute(
step_context=build_step_context(config.acquire_form_config),
)

match result:
case acquire_form_step.CompleteResult():
await _send_message(context, result.ai_message, result.debug)
await _send_message(context, result.message, result.debug)

agent_state.form_filename = result.filename
agent_state.mode = state.FormFillAgentMode.extract_form_fields
agent_state.mode = state.FormFillExtensionMode.extract_form_fields

continue

case _:
await _handle_incomplete_result(context, result)
return

case state.FormFillAgentMode.extract_form_fields:
case state.FormFillExtensionMode.extract_form_fields:
file_content = await get_attachment_content(agent_state.form_filename)
result = await extract_form_fields_step.execute(
step_context=build_step_context(config.extract_form_fields_config),
Expand All @@ -85,41 +85,44 @@ def build_step_context(config: ConfigT) -> Context[ConfigT]:

match result:
case extract_form_fields_step.CompleteResult():
await _send_message(context, result.ai_message, result.debug)
await _send_message(context, result.message, result.debug)

agent_state.extracted_form_title = result.extracted_form_title
agent_state.extracted_form_fields = result.extracted_form_fields
agent_state.mode = state.FormFillAgentMode.fill_form_step
agent_state.mode = state.FormFillExtensionMode.fill_form_step

continue

case _:
await _handle_incomplete_result(context, result)
return

case state.FormFillAgentMode.fill_form_step:
case state.FormFillExtensionMode.fill_form_step:
result = await fill_form_step.execute(
step_context=build_step_context(config.fill_form_config),
form_filename=agent_state.form_filename,
form_title=agent_state.extracted_form_title,
form_fields=agent_state.extracted_form_fields,
)

match result:
case fill_form_step.CompleteResult():
await _send_message(context, result.ai_message, result.debug)
await _send_message(context, result.message, result.debug)

agent_state.populated_form_markdown = result.populated_form_markdown
agent_state.fill_form_gc_artifact = result.artifact
agent_state.mode = state.FormFillAgentMode.generate_filled_form_step
agent_state.mode = state.FormFillExtensionMode.conversation_over

continue

case _:
await _handle_incomplete_result(context, result)
return

case state.FormFillAgentMode.generate_filled_form_step:
case state.FormFillExtensionMode.conversation_over:
await _send_message(
context,
"I'd love to generate the filled-out form now, but it's not yet implemented. :)",
"The form is now complete! Create a new conversation to work with another form.",
{},
)
return
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import json
from enum import StrEnum
from hashlib import md5
from pathlib import Path
from typing import Callable
Expand All @@ -12,6 +13,17 @@
)


class StateProjection(StrEnum):
"""
The projection to use when displaying the state.
"""

original_content = "original_content"
"""Return the state as string content."""
json_to_yaml = "json_to_yaml"
"""Return the state as a yaml code block."""


class FileStateInspector(ReadOnlyAssistantConversationInspectorStateProvider):
"""
A conversation inspector state provider that reads the state from a file and displays it as a yaml code block.
Expand All @@ -22,13 +34,17 @@ def __init__(
display_name: str,
file_path_source: Callable[[ConversationContext], Path],
description: str = "",
projection: StateProjection = StateProjection.json_to_yaml,
select_field: str = "",
) -> None:
self._state_id = md5(
(type(self).__name__ + "_" + display_name).encode("utf-8"), usedforsecurity=False
).hexdigest()
self._display_name = display_name
self._file_path_source = file_path_source
self._description = description
self._projection = projection
self._select_field = select_field

@property
def state_id(self) -> str:
Expand All @@ -50,7 +66,14 @@ def read_state(path: Path) -> dict:

state = read_state(self._file_path_source(context))

# return the state as a yaml code block, as it is easier to read than json
return AssistantConversationInspectorStateDataModel(
data={"content": f"```yaml\n{yaml.dump(state, sort_keys=False)}\n```"},
)
selected = state.get(self._select_field) if self._select_field else state

match self._projection:
case StateProjection.original_content:
return AssistantConversationInspectorStateDataModel(data={"content": selected})
case StateProjection.json_to_yaml:
state_as_yaml = yaml.dump(selected, sort_keys=False)
# return the state as a yaml code block, as it is easier to read than json
return AssistantConversationInspectorStateDataModel(
data={"content": f"```yaml\n{state_as_yaml}\n```"},
)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from contextvars import ContextVar
from enum import StrEnum
from pathlib import Path
from typing import AsyncIterator, Literal
from typing import AsyncIterator

from pydantic import BaseModel, Field
from semantic_workbench_assistant.assistant_app.context import ConversationContext, storage_directory_for_context
Expand All @@ -11,38 +11,59 @@
from .inspector import FileStateInspector


class FieldType(StrEnum):
text = "text"
date = "date"
signature = "signature"
multiple_choice = "multiple_choice"


class AllowedOptionSelections(StrEnum):
one = "one"
"""One of the options can be selected."""
many = "many"
"""One or more of the options can be selected."""


class FormField(BaseModel):
id: str = Field(description="The descriptive, unique identifier of the field as a snake_case_english_string.")
name: str = Field(description="The name of the field.")
description: str = Field(description="The description of the field.")
type: Literal["string", "bool", "multiple_choice"] = Field(description="The type of the field.")
type: FieldType = Field(description="The type of the field.")
options: list[str] = Field(description="The options for multiple choice fields.")
required: bool = Field(description="Whether the field is required or not.")
option_selections_allowed: AllowedOptionSelections | None = Field(
description="The number of options that can be selected for multiple choice fields."
)
required: bool = Field(
description="Whether the field is required or not. False indicates the field is optional and can be left blank."
)


class FormFillAgentMode(StrEnum):
class FormFillExtensionMode(StrEnum):
acquire_form_step = "acquire_form"
extract_form_fields = "extract_form_fields"
fill_form_step = "fill_form"
generate_filled_form_step = "generate_filled_form"
conversation_over = "conversation_over"


class FormFillAgentState(BaseModel):
mode: FormFillAgentMode = FormFillAgentMode.acquire_form_step
class FormFillExtensionState(BaseModel):
mode: FormFillExtensionMode = FormFillExtensionMode.acquire_form_step
form_filename: str = ""
extracted_form_title: str = ""
extracted_form_fields: list[FormField] = []
populated_form_markdown: str = ""
fill_form_gc_artifact: dict | None = None


def path_for_state(context: ConversationContext) -> Path:
return storage_directory_for_context(context) / "state.json"


current_state = ContextVar[FormFillAgentState | None]("current_state", default=None)
current_state = ContextVar[FormFillExtensionState | None]("current_state", default=None)


@asynccontextmanager
async def agent_state(context: ConversationContext) -> AsyncIterator[FormFillAgentState]:
async def extension_state(context: ConversationContext) -> AsyncIterator[FormFillExtensionState]:
"""
Context manager that provides the agent state, reading it from disk, and saving back
to disk after the context manager block is executed.
Expand All @@ -53,7 +74,7 @@ async def agent_state(context: ConversationContext) -> AsyncIterator[FormFillAge
return

async with context.state_updated_event_after(inspector.state_id):
state = read_model(path_for_state(context), FormFillAgentState) or FormFillAgentState()
state = read_model(path_for_state(context), FormFillExtensionState) or FormFillExtensionState()
current_state.set(state)
yield state
write_model(path_for_state(context), state)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def engine(
given state file.
"""

async with _state_locks[state_file_path], context.state_updated_event_after(state_id, focus_event=True):
async with _state_locks[state_file_path], context.state_updated_event_after(state_id):
kernel, service_id = _build_kernel_with_service(openai_client, openai_model)

state: dict | None = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class AcquireFormConfig(BaseModel):

@dataclass
class CompleteResult(Result):
ai_message: str
message: str
filename: str


Expand Down Expand Up @@ -103,7 +103,7 @@ async def execute(

if form_filename and form_filename != "Unanswered":
return CompleteResult(
ai_message=result.ai_message or "",
message=result.ai_message or "",
filename=form_filename,
debug=debug,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,17 @@ class ExtractFormFieldsConfig(BaseModel):
Field(title="Instruction", description="The instruction for extracting form fields from the file content."),
UISchema(widget="textarea"),
] = (
"Extract the form fields from the provided form attachment. Any type of form is allowed, including for example"
"Read the user provided form attachment and determine what fields are in the form. Any type of form is allowed, including"
" tax forms, address forms, surveys, and other official or unofficial form-types. If the content is not a form,"
" or the fields cannot be determined, then set the error_message."
" or the fields cannot be determined, then explain the reason why in the error_message. If the fields can be determined,"
" leave the error_message empty."
)


@dataclass
class CompleteResult(Result):
ai_message: str
message: str
extracted_form_title: str
extracted_form_fields: list[state.FormField]


Expand Down Expand Up @@ -62,14 +64,18 @@ async def execute(
)

return CompleteResult(
ai_message="",
message="",
extracted_form_title=extracted_form_fields.title,
extracted_form_fields=extracted_form_fields.fields,
debug=metadata,
)


class FormFields(BaseModel):
error_message: str = Field(description="The error message in the case that the form fields could not be extracted.")
error_message: str = Field(
description="The error message in the case that the form fields could not be determined."
)
title: str = Field(description="The title of the form.")
fields: list[state.FormField] = Field(description="The fields in the form.")


Expand Down
Loading

0 comments on commit b574e2f

Please sign in to comment.