Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix up conversation initialization #6430

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions frontend/__tests__/initial-query.test.tsx
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
import { describe, it, expect } from "vitest";
import store from "../src/store";
import {
setInitialQuery,
clearInitialQuery,
setInitialPrompt,
clearInitialPrompt,
} from "../src/state/initial-query-slice";

describe("Initial Query Behavior", () => {
it("should clear initial query when clearInitialQuery is dispatched", () => {
it("should clear initial query when clearInitialPrompt is dispatched", () => {
// Set up initial query in the store
store.dispatch(setInitialQuery("test query"));
expect(store.getState().initialQuery.initialQuery).toBe("test query");
store.dispatch(setInitialPrompt("test query"));
expect(store.getState().initialQuery.initialPrompt).toBe("test query");

// Clear the initial query
store.dispatch(clearInitialQuery());
store.dispatch(clearInitialPrompt());

// Verify initial query is cleared
expect(store.getState().initialQuery.initialQuery).toBeNull();
expect(store.getState().initialQuery.initialPrompt).toBeNull();
});
});
4 changes: 4 additions & 0 deletions frontend/src/api/open-hands.ts
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,14 @@ class OpenHands {
static async createConversation(
githubToken?: string,
selectedRepository?: string,
initialUserMsg?: string,
imageUrls?: string[],
): Promise<Conversation> {
const body = {
github_token: githubToken,
selected_repository: selectedRepository,
initial_user_msg: initialUserMsg,
image_urls: imageUrls,
};

const { data } = await openHands.post<Conversation>(
Expand Down
2 changes: 1 addition & 1 deletion frontend/src/components/agent-status-map.constant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ export const AGENT_STATUS_MAP: {
},
[AgentState.AWAITING_USER_INPUT]: {
message: I18nKey.CHAT_INTERFACE$AGENT_AWAITING_USER_INPUT_MESSAGE,
indicator: IndicatorColor.ORANGE,
indicator: IndicatorColor.BLUE,
},
[AgentState.PAUSED]: {
message: I18nKey.CHAT_INTERFACE$AGENT_PAUSED_MESSAGE,
Expand Down
9 changes: 6 additions & 3 deletions frontend/src/hooks/mutation/use-create-conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { useNavigate } from "react-router";
import posthog from "posthog-js";
import { useDispatch, useSelector } from "react-redux";
import OpenHands from "#/api/open-hands";
import { setInitialQuery } from "#/state/initial-query-slice";
import { setInitialPrompt } from "#/state/initial-query-slice";
import { RootState } from "#/store";
import { useAuth } from "#/context/auth-context";

Expand All @@ -18,7 +18,7 @@ export const useCreateConversation = () => {
);

return useMutation({
mutationFn: (variables: { q?: string }) => {
mutationFn: async (variables: { q?: string }) => {
if (
!variables.q?.trim() &&
!selectedRepository &&
Expand All @@ -28,10 +28,13 @@ export const useCreateConversation = () => {
throw new Error("No query provided");
}

if (variables.q) dispatch(setInitialQuery(variables.q));
if (variables.q) dispatch(setInitialPrompt(variables.q));

return OpenHands.createConversation(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that we marked this function async, shouldn't this now be:
return await OpenHands.createConversation(

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Otherwise the result will be a promise wrapped in a promise?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this works just fine--promises are pretty intelligent about nesting IIRC.

gitHubToken || undefined,
selectedRepository || undefined,
variables.q,
files,
);
},
onSuccess: async ({ conversation_id: conversationId }, { q }) => {
Expand Down
2 changes: 0 additions & 2 deletions frontend/src/routes/_oh.app/event-handler.tsx
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import React from "react";
import { useWSStatusChange } from "./hooks/use-ws-status-change";
import { useHandleWSEvents } from "./hooks/use-handle-ws-events";
import { useHandleRuntimeActive } from "./hooks/use-handle-runtime-active";

export function EventHandler({ children }: React.PropsWithChildren) {
useWSStatusChange();
useHandleWSEvents();
useHandleRuntimeActive();

Expand Down
68 changes: 0 additions & 68 deletions frontend/src/routes/_oh.app/hooks/use-ws-status-change.ts

This file was deleted.

21 changes: 19 additions & 2 deletions frontend/src/routes/_oh.app/route.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { useDisclosure } from "@nextui-org/react";
import React from "react";
import { Outlet } from "react-router";
import { useDispatch } from "react-redux";
import { useDispatch, useSelector } from "react-redux";
import { FaServer } from "react-icons/fa";
import toast from "react-hot-toast";
import { useTranslation } from "react-i18next";
Expand All @@ -11,7 +11,7 @@ import {
useConversation,
} from "#/context/conversation-context";
import { Controls } from "#/components/features/controls/controls";
import { clearMessages } from "#/state/chat-slice";
import { clearMessages, addUserMessage } from "#/state/chat-slice";
import { clearTerminal } from "#/state/command-slice";
import { useEffectOnce } from "#/hooks/use-effect-once";
import CodeIcon from "#/icons/code.svg?react";
Expand All @@ -36,6 +36,8 @@ import { ServedAppLabel } from "#/components/layout/served-app-label";
import { TerminalStatusLabel } from "#/components/features/terminal/terminal-status-label";
import { useSettings } from "#/hooks/query/use-settings";
import { MULTI_CONVERSATION_UI } from "#/utils/feature-flags";
import { clearFiles, clearInitialPrompt } from "#/state/initial-query-slice";
import { RootState } from "#/store";

function AppContent() {
useConversationConfig();
Expand All @@ -46,6 +48,9 @@ function AppContent() {
const { data: conversation, isFetched } = useUserConversation(
conversationId || null,
);
const { initialPrompt, files } = useSelector(
(state: RootState) => state.initialQuery,
);
const dispatch = useDispatch();
const endSession = useEndSession();

Expand Down Expand Up @@ -74,6 +79,18 @@ function AppContent() {
dispatch(clearMessages());
dispatch(clearTerminal());
dispatch(clearJupyter());
if (conversationId && (initialPrompt || files.length > 0)) {
dispatch(
addUserMessage({
content: initialPrompt || "",
imageUrls: files || [],
timestamp: new Date().toISOString(),
pending: true,
}),
);
dispatch(clearInitialPrompt());
dispatch(clearFiles());
}
}, [conversationId]);

useEffectOnce(() => {
Expand Down
16 changes: 8 additions & 8 deletions frontend/src/state/initial-query-slice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ import { createSlice, PayloadAction } from "@reduxjs/toolkit";

type SliceState = {
files: string[]; // base64 encoded images
initialQuery: string | null;
initialPrompt: string | null;
selectedRepository: string | null;
importedProjectZip: string | null; // base64 encoded zip
};

const initialState: SliceState = {
files: [],
initialQuery: null,
initialPrompt: null,
selectedRepository: null,
importedProjectZip: null,
};
Expand All @@ -27,11 +27,11 @@ export const selectedFilesSlice = createSlice({
clearFiles(state) {
state.files = [];
},
setInitialQuery(state, action: PayloadAction<string>) {
state.initialQuery = action.payload;
setInitialPrompt(state, action: PayloadAction<string>) {
state.initialPrompt = action.payload;
},
clearInitialQuery(state) {
state.initialQuery = null;
clearInitialPrompt(state) {
state.initialPrompt = null;
},
setSelectedRepository(state, action: PayloadAction<string | null>) {
state.selectedRepository = action.payload;
Expand All @@ -49,8 +49,8 @@ export const {
addFile,
removeFile,
clearFiles,
setInitialQuery,
clearInitialQuery,
setInitialPrompt,
clearInitialPrompt,
setSelectedRepository,
clearSelectedRepository,
setImportedProjectZip,
Expand Down
4 changes: 0 additions & 4 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,10 +501,6 @@ async def set_agent_state_to(self, new_state: AgentState) -> None:
EventSource.ENVIRONMENT,
)

if new_state == AgentState.INIT and self.state.resume_state:
await self.set_agent_state_to(self.state.resume_state)
self.state.resume_state = None

def get_agent_state(self) -> AgentState:
"""Returns the current state of the agent.
Expand Down
4 changes: 0 additions & 4 deletions openhands/core/schema/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@


class ActionTypeSchema(BaseModel):
INIT: str = Field(default='initialize')
"""Initializes the agent. Only sent by client.
"""

MESSAGE: str = Field(default='message')
"""Represents a message.
"""
Expand Down
4 changes: 0 additions & 4 deletions openhands/core/schema/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@ class AgentState(str, Enum):
"""The agent is loading.
"""

INIT = 'init'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think INIT exists in the event stream, at least as value of an agent change observation. It might be worth to test restoring a session, to see if we can recreate the events.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ooh good point

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It works!

"""The agent is initialized.
"""

RUNNING = 'running'
"""The agent is running.
"""
Expand Down
3 changes: 0 additions & 3 deletions openhands/server/listen_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from socketio.exceptions import ConnectionRefusedError

from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.action import (
NullAction,
)
Expand Down Expand Up @@ -86,8 +85,6 @@ async def connect(connection_id: str, environ, auth):
):
continue
elif isinstance(event, AgentStateChangedObservation):
if event.agent_state == AgentState.INIT:
await sio.emit('oh_event', event_to_dict(event), to=connection_id)
agent_state_changed = event
else:
await sio.emit('oh_event', event_to_dict(event), to=connection_id)
Expand Down
5 changes: 0 additions & 5 deletions openhands/server/mock/listen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from fastapi import FastAPI, WebSocket

from openhands.core.logger import openhands_logger as logger
from openhands.core.schema import ActionType
from openhands.utils.shutdown_listener import should_continue

app = FastAPI()
Expand All @@ -11,10 +10,6 @@
@app.websocket('/ws')
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
# send message to mock connection
await websocket.send_json(
{'action': ActionType.INIT, 'message': 'Control loop started.'}
)

try:
while should_continue():
Expand Down
19 changes: 17 additions & 2 deletions openhands/server/routes/manage_conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pydantic import BaseModel

from openhands.core.logger import openhands_logger as logger
from openhands.events.action.message import MessageAction
from openhands.events.stream import EventStreamSubscriber
from openhands.runtime import get_runtime_cls
from openhands.server.auth import get_user_id
Expand Down Expand Up @@ -34,13 +35,15 @@ class InitSessionRequest(BaseModel):
github_token: str | None = None
selected_repository: str | None = None
initial_user_msg: str | None = None
image_urls: list[str] | None = None


async def _create_new_conversation(
user_id: str | None,
token: str | None,
selected_repository: str | None,
initial_user_msg: str | None,
image_urls: list[str] | None,
):
logger.info('Loading settings')
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
Expand Down Expand Up @@ -94,8 +97,14 @@ async def _create_new_conversation(
)

logger.info(f'Starting agent loop for conversation {conversation_id}')
initial_message_action = None
if initial_user_msg or image_urls:
initial_message_action = MessageAction(
content=initial_user_msg or '',
image_urls=image_urls or [],
)
event_stream = await session_manager.maybe_start_agent_loop(
conversation_id, conversation_init_data, user_id, initial_user_msg
conversation_id, conversation_init_data, user_id, initial_message_action
)
try:
event_stream.subscribe(
Expand All @@ -121,10 +130,16 @@ async def new_conversation(request: Request, data: InitSessionRequest):
github_token = getattr(request.state, 'github_token', '') or data.github_token
selected_repository = data.selected_repository
initial_user_msg = data.initial_user_msg
image_urls = data.image_urls or []

try:
# Create conversation with initial message
conversation_id = await _create_new_conversation(
user_id, github_token, selected_repository, initial_user_msg
user_id,
github_token,
selected_repository,
initial_user_msg,
image_urls,
)

return JSONResponse(
Expand Down
Loading
Loading