diff --git a/core/cat/convo/messages.py b/core/cat/convo/messages.py index 130dbb07..9dd96207 100644 --- a/core/cat/convo/messages.py +++ b/core/cat/convo/messages.py @@ -63,15 +63,11 @@ class Message(BaseModelDict): ---------- user_id : str Unique identifier for the user associated with the message. - who : str - The name of the message author. - when : Optional[float] - The timestamp when the message was sent. - + when : float + The timestamp when the message was sent. """ user_id: str - who: str when: float = Field(default_factory=time.time) @@ -82,6 +78,12 @@ class ConversationMessage(Message): Attributes ---------- + user_id : str + Unique identifier for the user associated with the message. + when : float + The timestamp when the message was sent. Defaults to the current time. + who : str + The name of the message author. text : Optional[str], default=None The text content of the message. image : Optional[str], default=None @@ -90,10 +92,47 @@ class ConversationMessage(Message): Audio file URLs or base64 data URIs that represent audio associated with the message. """ + who: str text: Optional[str] = None image: Optional[str] = None audio: Optional[str] = None + # massage was used in the old history instead of text + # we need to keep it for backward compatibility + def __init__(self, **data): + + if "message" in data: + deprecation_warning("The `message` parameter is deprecated. Use `text` instead.") + data["text"] = data.pop("message") + + super().__init__(**data) + + @computed_field + @property + def message(self) -> str: + """ + This attribute is deprecated. Use `text` instead. + + The text content of the message. Use `text` instead. + + Returns + ------- + str + The text content of the message. + """ + deprecation_warning("The `message` attribute is deprecated. Use `text` instead.") + return self.text + + @message.setter + def message(self, value): + deprecation_warning("The `message` attribute is deprecated. Use `text` instead.") + self.text = value + + @property + def role(self) -> None: + """The role of the message author.""" + return None + class CatMessage(ConversationMessage): """ @@ -103,6 +142,18 @@ class CatMessage(ConversationMessage): ---------- type : str The type of message. Defaults to "chat". + user_id : str + Unique identifier for the user associated with the message. + when : float + The timestamp when the message was sent. Defaults to the current time. + who : str + The name of the message author. + text : Optional[str], default=None + The text content of the message. + image : Optional[str], default=None + Image file URLs or base64 data URIs that represent image associated with the message. + audio : Optional[str], default=None + Audio file URLs or base64 data URIs that represent audio associated with the message. why : Optional[MessageWhy] Additional contextual information related to the message. @@ -115,21 +166,22 @@ class CatMessage(ConversationMessage): type: str = "chat" # For now is always "chat" and is not used why: Optional[MessageWhy] = None - def langchainfy(self) -> AIMessage: - """ - Convert the internal CatMessage to a LangChain AIMessage. + def __init__( + self, + user_id: str, + who: str = "AI", + text: Optional[str] = None, + image: Optional[str] = None, + audio: Optional[str] = None, + why: Optional[MessageWhy] = None, + **kwargs, + ): + if "content" in kwargs: + deprecation_warning("The `content` parameter is deprecated. Use `text` instead.") + text = kwargs.pop("content") # Map 'content' to 'text' + + super().__init__(user_id=user_id, text=text, image=image, audio=audio, why=why, who=who, **kwargs) - Returns - ------- - AIMessage - The LangChain AIMessage converted from the internal CatMessage. - """ - - return AIMessage( - name=self.who, - content=self.text - ) - @computed_field @property def content(self) -> str: @@ -151,6 +203,26 @@ def content(self, value): deprecation_warning("The `content` attribute is deprecated. Use `text` instead.") self.text = value + @property + def role(self) -> Role: + """The role of the message author.""" + return Role.AI + + def langchainfy(self) -> AIMessage: + """ + Convert the internal CatMessage to a LangChain AIMessage. + + Returns + ------- + AIMessage + The LangChain AIMessage converted from the internal CatMessage. + """ + + return AIMessage( + name=self.who, + content=self.text + ) + class UserMessage(ConversationMessage): """ @@ -158,10 +230,30 @@ class UserMessage(ConversationMessage): This class is used to encapsulate the details of a message sent by a user, including the user's identifier, the text content of the message, and any associated multimedia content such as image or audio files. + + Attributes + ---------- + user_id : str + Unique identifier for the user associated with the message. + when : float + The timestamp when the message was sent. Defaults to the current time. + who : str + The name of the message author. + text : Optional[str], default=None + The text content of the message. + image : Optional[str], default=None + Image file URLs or base64 data URIs that represent image associated with the message. + audio : Optional[str], default=None + Audio file URLs or base64 data URIs that represent audio associated with the message. """ who: str = "Human" + @property + def role(self) -> Role: + """The role of the message author.""" + return Role.Human + def langchainfy(self) -> HumanMessage: """ Convert the internal UserMessage to a LangChain HumanMessage. @@ -188,7 +280,14 @@ def langchainfy(self) -> HumanMessage: ) def langchainfy_image(self) -> dict: - """Format an image to be sent as a data URI.""" + """ + Format an image to be sent as a data URI. + + Returns + ------- + dict + A dictionary containing the image data URI. + """ # If the image is a URL, download it and encode it as a data URI if self.image.startswith("http"):