Skip to content

Commit

Permalink
Fix/adjust typing in chat manager.
Browse files Browse the repository at this point in the history
Don't need the model manager; keep it simpler.
  • Loading branch information
dannon committed Nov 19, 2024
1 parent 41be831 commit d6aa14f
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
16 changes: 8 additions & 8 deletions lib/galaxy/managers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
InternalServerError,
RequestParameterInvalidException,
)
from galaxy.managers import base
from galaxy.managers.context import ProvidesUserContext
from galaxy.model import ChatExchange
from galaxy.model.base import transaction
Expand All @@ -35,14 +34,12 @@
]


class ChatManager(base.ModelManager[ChatExchange]):
class ChatManager:
"""
Business logic for chat exchanges.
"""

model_class = ChatExchange

def create(self, trans: ProvidesUserContext, job_id: JobIdPathParam, response: str) -> ChatExchange:
def create(self, trans: ProvidesUserContext, job_id: JobIdPathParam, message: str) -> ChatExchange:
"""
Create a new chat exchange in the DB. Currently these are *only* job-based chat exchanges, will need to generalize down the road.
:param job_id: id of the job to associate the response with
Expand All @@ -53,13 +50,13 @@ def create(self, trans: ProvidesUserContext, job_id: JobIdPathParam, response: s
:rtype: galaxy.model.ChatExchange
:raises: InternalServerError
"""
chat_exchange = ChatExchange(user=trans.user, job_id=job_id, message=response)
chat_exchange = ChatExchange(user=trans.user, job_id=job_id, message=message)
trans.sa_session.add(chat_exchange)
with transaction(trans.sa_session):
trans.sa_session.commit()
return chat_exchange

def get(self, trans: ProvidesUserContext, job_id: JobIdPathParam) -> ChatExchange:
def get(self, trans: ProvidesUserContext, job_id: JobIdPathParam) -> ChatExchange | None:
"""
Returns the chat response from the DB based on the given job id.
:param job_id: id of the job to load a response for from the DB
Expand All @@ -70,7 +67,7 @@ def get(self, trans: ProvidesUserContext, job_id: JobIdPathParam) -> ChatExchang
"""
try:
stmt = select(ChatExchange).where(ChatExchange.job_id == job_id)
chat_response = self.session().execute(stmt).scalar_one()
chat_response = trans.sa_session.execute(stmt).scalar_one()
except MultipleResultsFound:
# TODO: Unsure about this, isn't this more applicable when we're getting the response for response.id instead of response.job_id?
raise InconsistentDatabase("Multiple chat responses found with the same job id.")
Expand Down Expand Up @@ -102,6 +99,9 @@ def set_feedback_for_job(self, trans: ProvidesUserContext, job_id: JobIdPathPara

chat_exchange = self.get(trans, job_id)

if not chat_exchange:
raise RequestParameterInvalidException("No accessible response found with the id provided.")

# There is only one message in an exchange currently, so we can set the feedback on the first message
chat_exchange.messages[0].feedback = feedback

Expand Down
6 changes: 4 additions & 2 deletions lib/galaxy/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2985,10 +2985,12 @@ class ChatExchange(Base, RepresentById):
user: Mapped["User"] = relationship()
messages: Mapped[List["ChatExchangeMessage"]] = relationship(back_populates="chat_exchange", cascade_backrefs=False)

def __init__(self, user, job_id, message, **kwargs):
def __init__(self, user, job_id=None, message=None, **kwargs):
self.user = user
self.job_id = job_id
self.messages = [ChatExchangeMessage(message=message)]
self.messages = []
if message:
self.add_message(message)

def add_message(self, message):
self.messages.append(ChatExchangeMessage(message=message))
Expand Down
8 changes: 3 additions & 5 deletions lib/galaxy/webapps/galaxy/api/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,9 @@ def query(
job_id: JobIdPathParam,
payload: ChatPayload,
trans: ProvidesUserContext = DependsOnTrans,
) -> str:
) -> str | None:
"""We're off to ask the wizard"""

answer = None

if job_id:
existing_response = self.chat_manager.get(trans, job_id)
# Currently job-based chat exchanges are the only ones supported,
Expand All @@ -68,7 +66,7 @@ def query(

# TODO: Maybe we need to first check if the job_id exists (in the `job` table)?
if job_id:
self.chat_manager.create(trans, job_id, answer)
self.chat_manager.create(trans.user, job_id, answer)

return answer

Expand All @@ -78,7 +76,7 @@ def feedback(
job_id: JobIdPathParam,
feedback: int,
trans: ProvidesUserContext = DependsOnTrans,
) -> int:
) -> int | None:
"""Provide feedback on the chatbot response."""
chat_response = self.chat_manager.set_feedback_for_job(trans, job_id, feedback)
return chat_response.messages[0].feedback
Expand Down

0 comments on commit d6aa14f

Please sign in to comment.