Skip to content

Commit

Permalink
some cleanup and convenience for chat history
Browse files Browse the repository at this point in the history
  • Loading branch information
eavanvalkenburg committed Jan 27, 2025
1 parent ffecfce commit b2a9a46
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 31 deletions.
10 changes: 4 additions & 6 deletions python/semantic_kernel/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,14 @@ async def reduce_history(self, history: "ChatHistory") -> bool:

if self.history_reducer is history:
logger.info("You're reducing the same object, you can call `history.reduce()` instead.")
initial_len = len(history.messages)
await history.reduce()
return len(history.messages) < initial_len
initial_len = len(self.history_reducer)
await self.history_reducer.reduce()
return len(self.history_reducer) < initial_len

self.history_reducer.messages = history.messages

new_messages = await self.history_reducer.reduce()
if new_messages is not None:
history.messages.clear()
history.messages.extend(new_messages)
history.replace(new_messages)
return True

return False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,8 @@ async def invoke_stream(self, history: ChatHistory) -> AsyncIterable[StreamingCh

def _setup_agent_chat_history(self, history: ChatHistory) -> ChatHistory:
"""Setup the agent chat history."""
chat = []

if self.instructions is not None:
chat.append(ChatMessageContent(role=AuthorRole.SYSTEM, content=self.instructions, name=self.name))

chat.extend(history.messages if history.messages else [])

return ChatHistory(messages=chat)
return (
ChatHistory(messages=history.messages)
if self.instructions is None
else ChatHistory(system_message=self.instructions, messages=history.messages)
)
61 changes: 44 additions & 17 deletions python/semantic_kernel/contents/chat_history.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Copyright (c) Microsoft. All rights reserved.

import logging
from collections.abc import Generator
from collections.abc import Generator, Iterable, MutableSequence
from functools import singledispatchmethod
from html import unescape
from typing import Any
from typing import Any, TypeVar
from xml.etree.ElementTree import Element, tostring # nosec

from defusedxml.ElementTree import XML, ParseError
Expand All @@ -19,6 +19,8 @@

logger = logging.getLogger(__name__)

_T = TypeVar("_T", bound="ChatHistory")


class ChatHistory(KernelBaseModel):
"""This class holds the history of chat messages from a chat conversation.
Expand All @@ -28,10 +30,10 @@ class ChatHistory(KernelBaseModel):
as a keyword argument, but not be part of the class definition.
Attributes:
messages (List[ChatMessageContent]): The list of chat messages in the history.
messages: The list of chat messages in the history.
"""

messages: list[ChatMessageContent]
messages: MutableSequence[ChatMessageContent]

def __init__(self, **data: Any):
"""Initializes a new instance of the ChatHistory class.
Expand All @@ -52,8 +54,8 @@ def __init__(self, **data: Any):
Args:
**data: Arbitrary keyword arguments.
The constructor looks for two optional keys:
- 'messages': Optional[List[ChatMessageContent]], a list of chat messages to include in the history.
- 'system_message' Optional[str]: An optional string representing a system-generated message to be
- 'messages': List[ChatMessageContent], a list of chat messages to include in the history.
- 'system_message' str: An optional string representing a system-generated message to be
included at the start of the chat history.
"""
Expand All @@ -72,7 +74,7 @@ def __init__(self, **data: Any):

@field_validator("messages", mode="before")
@classmethod
def _validate_messages(cls, messages: list[ChatMessageContent]) -> list[ChatMessageContent]:
def _validate_messages(cls, messages: MutableSequence[ChatMessageContent]) -> MutableSequence[ChatMessageContent]:
if not messages:
return messages
out_msgs: list[ChatMessageContent] = []
Expand All @@ -89,12 +91,12 @@ def add_system_message(self, content: str | list[KernelContent], **kwargs) -> No
raise NotImplementedError

@add_system_message.register
def add_system_message_str(self, content: str, **kwargs: Any) -> None:
def _(self, content: str, **kwargs: Any) -> None:
"""Add a system message to the chat history."""
self.add_message(message=self._prepare_for_add(role=AuthorRole.SYSTEM, content=content, **kwargs))

@add_system_message.register(list)
def add_system_message_list(self, content: list[KernelContent], **kwargs: Any) -> None:
def _(self, content: list[KernelContent], **kwargs: Any) -> None:
"""Add a system message to the chat history."""
self.add_message(message=self._prepare_for_add(role=AuthorRole.SYSTEM, items=content, **kwargs))

Expand All @@ -104,12 +106,12 @@ def add_developer_message(self, content: str | list[KernelContent], **kwargs) ->
raise NotImplementedError

@add_developer_message.register
def add_developer_message_str(self, content: str, **kwargs: Any) -> None:
def _(self, content: str, **kwargs: Any) -> None:
"""Add a system message to the chat history."""
self.add_message(message=self._prepare_for_add(role=AuthorRole.DEVELOPER, content=content, **kwargs))

@add_developer_message.register(list)
def add_developer_message_list(self, content: list[KernelContent], **kwargs: Any) -> None:
def _(self, content: list[KernelContent], **kwargs: Any) -> None:
"""Add a system message to the chat history."""
self.add_message(message=self._prepare_for_add(role=AuthorRole.DEVELOPER, items=content, **kwargs))

Expand All @@ -119,12 +121,12 @@ def add_user_message(self, content: str | list[KernelContent], **kwargs: Any) ->
raise NotImplementedError

@add_user_message.register
def add_user_message_str(self, content: str, **kwargs: Any) -> None:
def _(self, content: str, **kwargs: Any) -> None:
"""Add a user message to the chat history."""
self.add_message(message=self._prepare_for_add(role=AuthorRole.USER, content=content, **kwargs))

@add_user_message.register(list)
def add_user_message_list(self, content: list[KernelContent], **kwargs: Any) -> None:
def _(self, content: list[KernelContent], **kwargs: Any) -> None:
"""Add a user message to the chat history."""
self.add_message(message=self._prepare_for_add(role=AuthorRole.USER, items=content, **kwargs))

Expand All @@ -134,12 +136,12 @@ def add_assistant_message(self, content: str | list[KernelContent], **kwargs: An
raise NotImplementedError

@add_assistant_message.register
def add_assistant_message_str(self, content: str, **kwargs: Any) -> None:
def _(self, content: str, **kwargs: Any) -> None:
"""Add an assistant message to the chat history."""
self.add_message(message=self._prepare_for_add(role=AuthorRole.ASSISTANT, content=content, **kwargs))

@add_assistant_message.register(list)
def add_assistant_message_list(self, content: list[KernelContent], **kwargs: Any) -> None:
def _(self, content: list[KernelContent], **kwargs: Any) -> None:
"""Add an assistant message to the chat history."""
self.add_message(message=self._prepare_for_add(role=AuthorRole.ASSISTANT, items=content, **kwargs))

Expand All @@ -149,12 +151,12 @@ def add_tool_message(self, content: str | list[KernelContent], **kwargs: Any) ->
raise NotImplementedError

@add_tool_message.register
def add_tool_message_str(self, content: str, **kwargs: Any) -> None:
def _(self, content: str, **kwargs: Any) -> None:
"""Add a tool message to the chat history."""
self.add_message(message=self._prepare_for_add(role=AuthorRole.TOOL, content=content, **kwargs))

@add_tool_message.register(list)
def add_tool_message_list(self, content: list[KernelContent], **kwargs: Any) -> None:
def _(self, content: list[KernelContent], **kwargs: Any) -> None:
"""Add a tool message to the chat history."""
self.add_message(message=self._prepare_for_add(role=AuthorRole.TOOL, items=content, **kwargs))

Expand Down Expand Up @@ -245,6 +247,31 @@ def __str__(self) -> str:
chat_history_xml.append(message.to_element())
return tostring(chat_history_xml, encoding="unicode", short_empty_elements=True)

def clear(self) -> None:
"""Clear the chat history."""
self.messages.clear()

def extend(self, messages: Iterable[ChatMessageContent]) -> None:
"""Extend the chat history with a list of messages.
Args:
messages: The messages to add to the history.
Can be a list of ChatMessageContent instances or a ChatHistory itself.
"""
self.messages.extend(messages)

def replace(self, messages: Iterable[ChatMessageContent]) -> None:
"""Replace the chat history with a list of messages.
This calls clear() and then extend(messages=messages).
Args:
messages: The messages to add to the history.
Can be a list of ChatMessageContent instances or a ChatHistory itself.
"""
self.clear()
self.extend(messages=messages)

def to_prompt(self) -> str:
"""Return a string representation of the history."""
chat_history_xml = Element(CHAT_HISTORY_TAG)
Expand Down

0 comments on commit b2a9a46

Please sign in to comment.