Skip to content

Commit

Permalink
Implement thread deletion
Browse files Browse the repository at this point in the history
  • Loading branch information
pamella committed Jun 12, 2024
1 parent 9ec3d96 commit 8633b88
Show file tree
Hide file tree
Showing 24 changed files with 1,484 additions and 169 deletions.
1 change: 1 addition & 0 deletions django_ai_assistant/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
DEFAULTS = {
"CAN_CREATE_THREAD_FN": "django_ai_assistant.permissions.allow_all",
"CAN_VIEW_THREAD_FN": "django_ai_assistant.permissions.allow_all",
"CAN_DELETE_THREAD_FN": "django_ai_assistant.permissions.allow_all",
"CAN_CREATE_MESSAGE_FN": "django_ai_assistant.permissions.allow_all",
"CAN_RUN_ASSISTANT": "django_ai_assistant.permissions.allow_all",
}
Expand Down
19 changes: 18 additions & 1 deletion django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@
AIUserNotAllowedError,
)
from django_ai_assistant.models import Thread
from django_ai_assistant.permissions import can_create_message, can_create_thread, can_run_assistant
from django_ai_assistant.permissions import (
can_create_message,
can_create_thread,
can_delete_thread,
can_run_assistant,
)
from django_ai_assistant.tools import Tool
from django_ai_assistant.tools import tool as tool_decorator

Expand Down Expand Up @@ -372,6 +377,18 @@ def get_threads(
return list(Thread.objects.filter(created_by=user))


def delete_thread(
thread: Thread,
user: Any,
request: HttpRequest | None = None,
view: View | None = None,
):
if not can_delete_thread(thread=thread, user=user, request=request, view=view):
raise AIUserNotAllowedError("User is not allowed to delete this thread")

return thread.delete()


def get_thread_messages(
thread_id: str,
user: Any,
Expand Down
14 changes: 14 additions & 0 deletions django_ai_assistant/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,20 @@ def can_create_thread(
)


def can_delete_thread(
thread,
user: Any,
request: HttpRequest | None = None,
view: View | None = None,
**kwargs,
) -> bool:
return app_settings.call_fn(
"CAN_DELETE_THREAD_FN",
**_get_default_kwargs(user, request, view),
thread=thread,
)


def can_create_message(
thread,
user: Any,
Expand Down
20 changes: 16 additions & 4 deletions django_ai_assistant/views.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from typing import List

from django.shortcuts import get_object_or_404

from langchain_core.messages import message_to_dict
from ninja import NinjaAPI

from django_ai_assistant import __package__, __version__

from .exceptions import AIUserNotAllowedError
from .helpers import assistants
from .helpers.assistants import (
create_message,
get_assistants_info,
get_thread_messages,
get_threads,
)
from .helpers.assistants import (
create_thread as ai_create_thread,
)
from .models import Thread
from .schemas import (
AssistantSchema,
Expand Down Expand Up @@ -50,7 +50,19 @@ def list_threads(request):
@api.post("threads/", response=ThreadSchema, url_name="threads_list_create")
def create_thread(request, payload: ThreadSchemaIn):
name = payload.name
return ai_create_thread(name=name, user=request.user, request=request, view=None)
return assistants.create_thread(name=name, user=request.user, request=request, view=None)


@api.delete("threads/{thread_id}/", response={204: None}, url_name="threads_delete")
def delete_thread(request, thread_id: str):
thread = get_object_or_404(Thread, id=thread_id)
assistants.delete_thread(
thread=thread,
user=request.user,
request=request,
view=None,
)
return 204, None


@api.get(
Expand Down
6 changes: 4 additions & 2 deletions example/assets/js/Chat/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import classes from "./Chat.module.css";
import { useCallback, useEffect, useRef, useState } from "react";
import { IconSend2 } from "@tabler/icons-react";
import { getHotkeyHandler } from "@mantine/hooks";
import Markdown from "react-markdown";

import {
ThreadMessagesSchemaOut,
Expand All @@ -28,7 +29,7 @@ function ChatMessage({ message }: { message: ThreadMessagesSchemaOut }) {
return (
<Box mb="md">
<Text fw={700}>{message.type === "ai" ? "AI" : "User"}</Text>
<Text>{message.content}</Text>
<Markdown className={classes.mdMessage}>{message.content}</Markdown>
</Box>
);
}
Expand Down Expand Up @@ -58,7 +59,7 @@ export function Chat() {
const [inputValue, setInputValue] = useState<string>("");

const { fetchAssistants, assistants } = useAssistant();
const { fetchThreads, threads, createThread } = useThread();
const { fetchThreads, threads, createThread, deleteThread } = useThread();
const {
fetchMessages,
messages,
Expand Down Expand Up @@ -132,6 +133,7 @@ export function Chat() {
selectedThreadId={activeThread?.id}
selectThread={setActiveThread}
createThread={createThread}
deleteThread={deleteThread}
/>
<main className={classes.main}>
<Container className={classes.chatContainer}>
Expand Down
26 changes: 0 additions & 26 deletions example/assets/js/Chat/ThreadsNav.module.css
Original file line number Diff line number Diff line change
Expand Up @@ -38,32 +38,6 @@
margin-bottom: rem(5px);
}

.threadLink {
display: block;
padding: rem(8px) var(--mantine-spacing-xs);
text-decoration: none;
border-radius: var(--mantine-radius-sm);
font-size: var(--mantine-font-size-sm);
color: light-dark(var(--mantine-color-gray-7), var(--mantine-color-dark-0));
line-height: 1;
font-weight: 500;

&:hover {
background-color: light-dark(
var(--mantine-color-gray-0),
var(--mantine-color-dark-6)
);
color: light-dark(var(--mantine-color-gray-7), var(--mantine-color-dark-0));
}
}

.threadLinkSelected {
background-color: light-dark(
var(--mantine-color-gray-0),
var(--mantine-color-dark-4)
);
}

.threadLinkInfo {
padding: rem(8px) var(--mantine-spacing-xs);
font-size: var(--mantine-font-size-sm);
Expand Down
79 changes: 55 additions & 24 deletions example/assets/js/Chat/ThreadsNav.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
import { Text, Group, ActionIcon, Tooltip, rem, Loader } from "@mantine/core";
import { IconPlus } from "@tabler/icons-react";
import {
Text,
Group,
ActionIcon,
Tooltip,
Loader,
NavLink,
} from "@mantine/core";
import { useHover } from "@mantine/hooks";
import { IconPlus, IconTrash } from "@tabler/icons-react";
import classes from "./ThreadsNav.module.css";

import { ThreadSchema } from "django-ai-assistant-client";
Expand All @@ -9,31 +17,57 @@ export function ThreadsNav({
selectedThreadId,
selectThread,
createThread,
deleteThread,
}: {
threads: ThreadSchema[] | null;
selectedThreadId: number | null | undefined;
selectThread: (thread: ThreadSchema | null) => void;
createThread: () => Promise<ThreadSchema>;
deleteThread: ({ threadId }: { threadId: string }) => Promise<void>;
}) {
const threadLinks = threads?.map((thread) => {
const isThreadSelected = selectedThreadId && selectedThreadId === thread.id;
const ThreadNavLink = ({ thread }: { thread: ThreadSchema }) => {
const { hovered, ref } = useHover();

return (
<a
href="#"
onClick={(event) => {
selectThread(thread);
event.preventDefault();
}}
key={thread.id}
className={
classes.threadLink +
` ${isThreadSelected ? classes.threadLinkSelected : ""}`
}
>
{thread.name}
</a>
<div ref={ref} key={thread.id}>
<NavLink
href="#"
onClick={(event) => {
selectThread(thread);
event.preventDefault();
}}
label={thread.name}
active={selectedThreadId === thread.id}
variant="filled"
rightSection={
hovered ? (
<Tooltip label="Delete thread" withArrow>
<ActionIcon
variant="light"
color="red"
size="sm"
onClick={async () => {
await deleteThread({ threadId: thread.id });
window.location.reload();
}}
aria-label="Delete thread"
>
<IconTrash
style={{ width: "70%", height: "70%" }}
stroke={1.5}
/>
</ActionIcon>
</Tooltip>
) : null
}
/>
</div>
);
});
};

const threadLinks = threads?.map((thread) => (
<ThreadNavLink key={thread.id} thread={thread} />
));

return (
<nav className={classes.navbar}>
Expand All @@ -45,17 +79,14 @@ export function ThreadsNav({
<Tooltip label="Create thread" withArrow position="right">
<ActionIcon
variant="default"
size={18}
size="sm"
onClick={async (e) => {
const thread = await createThread();
selectThread(thread);
e.preventDefault();
}}
>
<IconPlus
style={{ width: rem(12), height: rem(12) }}
stroke={1.5}
/>
<IconPlus style={{ width: "70%", height: "70%" }} stroke={1.5} />
</ActionIcon>
</Tooltip>
</Group>
Expand Down
1 change: 1 addition & 0 deletions example/example/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
AI_ASSISTANT_CAN_CREATE_THREAD_FN = "django_ai_assistant.permissions.allow_all"
AI_ASSISTANT_CAN_VIEW_THREAD_FN = "django_ai_assistant.permissions.allow_all"
AI_ASSISTANT_CAN_DELETE_THREAD_FN = "django_ai_assistant.permissions.allow_all"
AI_ASSISTANT_CAN_CREATE_MESSAGE_FN = "django_ai_assistant.permissions.allow_all"
AI_ASSISTANT_CAN_RUN_ASSISTANT = "django_ai_assistant.permissions.allow_all"

Expand Down
Loading

0 comments on commit 8633b88

Please sign in to comment.