Skip to content

Commit

Permalink
Merge pull request #389 from miurla/feat/enabled-load-chat
Browse files Browse the repository at this point in the history
Save and load annotation data
  • Loading branch information
miurla authored Jan 9, 2025
2 parents 49475ac + e01f811 commit e5c3056
Show file tree
Hide file tree
Showing 9 changed files with 232 additions and 48 deletions.
41 changes: 20 additions & 21 deletions app/api/chat/route.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import {
streamText,
createDataStreamResponse,
Message,
convertToCoreMessages,
generateId,
JSONValue,
ToolInvocation
} from 'ai'
import { researcher } from '@/lib/agents/researcher'
import { generateRelatedQuestions } from '@/lib/agents/generate-related-questions'
import { cookies } from 'next/headers'
import { getChat, saveChat } from '@/lib/actions/chat'
import { ExtendedCoreMessage } from '@/lib/types'
import { convertToExtendedCoreMessages } from '@/lib/utils'

export const maxDuration = 30

Expand All @@ -19,7 +19,11 @@ const DEFAULT_MODEL = 'openai:gpt-4o-mini'
export async function POST(req: Request) {
const { messages, id: chatId } = await req.json()

// streamText requires core messages
const coreMessages = convertToCoreMessages(messages)
// convertToExtendedCoreMessages for saving annotations
const extendedCoreMessages = convertToExtendedCoreMessages(messages)

const cookieStore = await cookies()
const modelFromCookie = cookieStore.get('selected-model')?.value
const model = modelFromCookie || DEFAULT_MODEL
Expand All @@ -31,24 +35,16 @@ export async function POST(req: Request) {
model
})

let toolResults: ToolInvocation[] = []
const result = streamText({
...researcherConfig,
onStepFinish(event) {
// onFinish's event.toolResults is empty. Use onStepFinish to get the tool results.
if (event.stepType === 'initial') {
toolResults = event.toolResults
}
},
onFinish: async event => {
const responseMessages = event.response.messages

let annotation: JSONValue = {
type: 'related-questions',
data: {
items: []
},
status: 'loading'
}
}

// Notify related questions loading
Expand All @@ -63,21 +59,22 @@ export async function POST(req: Request) {
// Update the annotation with the related questions
annotation = {
...annotation,
data: relatedQuestions.object,
status: 'done'
data: relatedQuestions.object
}

// Send related questions to client
dataStream.writeMessageAnnotation(annotation)

// Create the message to save
const generatedMessage: Message = {
role: 'assistant',
content: event.text,
toolInvocations: toolResults,
annotations: [annotation],
id: generateId()
}
const generatedMessages = [
...extendedCoreMessages,
...responseMessages.slice(0, -1),
{
role: 'data',
content: annotation
},
responseMessages[responseMessages.length - 1]
] as ExtendedCoreMessage[]

// Get the chat from the database if it exists, otherwise create a new one
const savedChat = (await getChat(chatId)) ?? {
Expand All @@ -89,10 +86,12 @@ export async function POST(req: Request) {
id: chatId
}

console.log('generatedMessages', generatedMessages)

// Save chat with complete response and related questions
await saveChat({
...savedChat,
messages: [...savedChat.messages, generatedMessage]
messages: generatedMessages
})
}
})
Expand Down
8 changes: 7 additions & 1 deletion app/search/[id]/page.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { notFound, redirect } from 'next/navigation'
import { Chat } from '@/components/chat'
import { getChat } from '@/lib/actions/chat'
import { convertToCoreMessages } from 'ai'
import { convertToUIMessages } from '@/lib/utils'

export const maxDuration = 60

Expand All @@ -20,6 +22,8 @@ export default async function SearchPage(props: {
const userId = 'anonymous'
const { id } = await props.params
const chat = await getChat(id, userId)
// convertToUIMessages for useChat hook
const messages = convertToUIMessages(chat?.messages || [])

if (!chat) {
redirect('/')
Expand All @@ -29,5 +33,7 @@ export default async function SearchPage(props: {
notFound()
}

return <Chat id={id} />
console.log(chat)

return <Chat id={id} savedMessages={messages} />
}
1 change: 1 addition & 0 deletions components/answer-section.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ export function AnswerSection({
header={header}
isOpen={isOpen}
onOpenChange={onOpenChange}
showBorder={false}
>
{content ? <BotMessage message={content} /> : <DefaultSkeleton />}
</CollapsibleMessage>
Expand Down
14 changes: 11 additions & 3 deletions components/chat.tsx
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
'use client'

import { useChat } from 'ai/react'
import { Message, useChat } from 'ai/react'
import { ChatMessages } from './chat-messages'
import { ChatPanel } from './chat-panel'

export function Chat({ id }: { id: string }) {
export function Chat({
id,
savedMessages = []
}: {
id: string
savedMessages?: Message[]
}) {
const {
messages,
input,
Expand All @@ -15,11 +21,13 @@ export function Chat({ id }: { id: string }) {
stop,
append
} = useChat({
initialMessages: savedMessages,
id: 'chat',
body: {
id
},
onFinish: () => {
// window.history.replaceState({}, '', `/search/${id}`)
window.history.replaceState({}, '', `/search/${id}`)
}
})

Expand Down
20 changes: 15 additions & 5 deletions components/collapsible-message.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
} from './ui/collapsible'
import { Separator } from './ui/separator'
import { cn } from '@/lib/utils'
import { IconLogo } from './ui/icons'

interface CollapsibleMessageProps {
children: React.ReactNode
Expand All @@ -14,6 +15,7 @@ interface CollapsibleMessageProps {
isOpen?: boolean
header?: React.ReactNode
onOpenChange?: (open: boolean) => void
showBorder?: boolean
}

export function CollapsibleMessage({
Expand All @@ -22,22 +24,30 @@ export function CollapsibleMessage({
isCollapsible = false,
isOpen = true,
header,
onOpenChange
onOpenChange,
showBorder = true
}: CollapsibleMessageProps) {
const content = <div className="py-2 flex-1">{children}</div>

return (
<div className="flex gap-3">
<div className="relative flex flex-col items-center">
<div className={cn('mt-[10px]', role === 'assistant' && 'pl-4')}>
{role === 'user' && (
<div className={cn('mt-[10px]', role === 'assistant' && 'mt-4')}>
{role === 'user' ? (
<UserCircle2 size={20} className="text-muted-foreground" />
) : (
<IconLogo className="size-5" />
)}
</div>
</div>

{isCollapsible ? (
<div className="flex-1 border border-border/50 rounded-2xl p-4">
<div
className={cn(
'flex-1 rounded-2xl p-4',
showBorder && 'border border-border/50'
)}
>
<Collapsible
open={isOpen}
onOpenChange={onOpenChange}
Expand All @@ -50,7 +60,7 @@ export function CollapsibleMessage({
</div>
</CollapsibleTrigger>
<CollapsibleContent className="data-[state=closed]:animate-collapse-up data-[state=open]:animate-collapse-down">
<Separator className="my-4" />
<Separator className="my-4 border-border/50" />
{content}
</CollapsibleContent>
</Collapsible>
Expand Down
23 changes: 13 additions & 10 deletions components/related-questions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { ArrowRight, Repeat2 } from 'lucide-react'
import { Skeleton } from './ui/skeleton'
import { JSONValue } from 'ai'
import { CollapsibleMessage } from './collapsible-message'
import { useChat } from 'ai/react'

export interface RelatedQuestionsProps {
annotations: JSONValue[]
Expand All @@ -19,7 +20,6 @@ interface RelatedQuestionsAnnotation extends Record<string, JSONValue> {
data: {
items: Array<{ query: string }>
}
status: 'loading' | 'done'
}

export const RelatedQuestions: React.FC<RelatedQuestionsProps> = ({
Expand All @@ -28,18 +28,17 @@ export const RelatedQuestions: React.FC<RelatedQuestionsProps> = ({
isOpen,
onOpenChange
}) => {
const { isLoading } = useChat({
id: 'chat'
})

if (!annotations) {
return null
}

const lastRelatedQuestionsAnnotation = annotations.find(
(a): a is RelatedQuestionsAnnotation =>
a !== null &&
typeof a === 'object' &&
'type' in a &&
a.type === 'related-questions' &&
a.status === 'done'
)
const lastRelatedQuestionsAnnotation = annotations[
annotations.length - 1
] as RelatedQuestionsAnnotation

const header = (
<div className="flex items-center gap-1">
Expand All @@ -49,7 +48,11 @@ export const RelatedQuestions: React.FC<RelatedQuestionsProps> = ({
)

const relatedQuestions = lastRelatedQuestionsAnnotation?.data
if (!relatedQuestions) {
if (!relatedQuestions && !isLoading) {
return null
}

if (!relatedQuestions || isLoading) {
return (
<CollapsibleMessage
role="assistant"
Expand Down
14 changes: 9 additions & 5 deletions components/search-section.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { Section, ToolArgsSection } from './section'
import type { SearchResults as TypeSearchResults } from '@/lib/types'
import { ToolInvocation } from 'ai'
import { CollapsibleMessage } from './collapsible-message'
import { useChat } from 'ai/react'

interface SearchSectionProps {
tool: ToolInvocation
Expand All @@ -19,7 +20,10 @@ export function SearchSection({
isOpen,
onOpenChange
}: SearchSectionProps) {
const isLoading = tool.state === 'call'
const { isLoading } = useChat({
id: 'chat'
})
const isToolLoading = tool.state === 'call'
const searchResults: TypeSearchResults =
tool.state === 'result' ? tool.result : undefined
const query = tool.args.query as string | undefined
Expand Down Expand Up @@ -50,13 +54,13 @@ export function SearchSection({
/>
</Section>
)}
{!isLoading && searchResults.results ? (
{isLoading && isToolLoading ? (
<DefaultSkeleton />
) : searchResults?.results ? (
<Section title="Sources">
<SearchResults results={searchResults.results} />
</Section>
) : (
<DefaultSkeleton />
)}
) : null}
</CollapsibleMessage>
)
}
10 changes: 8 additions & 2 deletions lib/types/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Message } from 'ai'
import { CoreMessage, JSONValue, Message } from 'ai'

export type SearchResults = {
images: SearchResultImage[]
Expand Down Expand Up @@ -63,10 +63,16 @@ export interface Chat extends Record<string, any> {
createdAt: Date
userId: string
path: string
messages: Message[] // Note: Changed from AIMessage to Message
messages: ExtendedCoreMessage[] // Note: Changed from AIMessage to ExtendedCoreMessage
sharePath?: string
}

// ExtendedCoreMessage for saveing annotations
export type ExtendedCoreMessage = Omit<CoreMessage, 'role' | 'content'> & {
role: CoreMessage['role'] | 'data'
content: CoreMessage['content'] | JSONValue
}

export type AIMessage = {
role: 'user' | 'assistant' | 'system' | 'function' | 'data' | 'tool'
content: string
Expand Down
Loading

0 comments on commit e5c3056

Please sign in to comment.