Skip to content

Commit

Permalink
[Security Assistant] Fix system prompts (elastic#189130)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephmilovic authored Jul 30, 2024
1 parent 64c61e0 commit dc11f75
Show file tree
Hide file tree
Showing 9 changed files with 217 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { FormattedMessage } from '@kbn/i18n-react';
import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/public/common';
import { noop } from 'lodash/fp';
import { PromptResponse } from '@kbn/elastic-assistant-common';
import { QueryObserverResult } from '@tanstack/react-query';
import { Conversation } from '../../../..';
import * as i18n from './translations';
import * as i18nModel from '../../../connectorland/models/model_selector/translations';
Expand All @@ -37,6 +38,7 @@ export interface ConversationSettingsEditorProps {
React.SetStateAction<ConversationsBulkActions>
>;
onSelectedConversationChange: (conversation?: Conversation) => void;
refetchConversations?: () => Promise<QueryObserverResult<Record<string, Conversation>, unknown>>;
}

/**
Expand All @@ -53,6 +55,7 @@ export const ConversationSettingsEditor: React.FC<ConversationSettingsEditorProp
conversationsSettingsBulkActions,
setConversationsSettingsBulkActions,
onSelectedConversationChange,
refetchConversations,
}) => {
const { data: connectors, isSuccess: areConnectorsFetched } = useLoadConnectors({
http,
Expand Down Expand Up @@ -276,6 +279,7 @@ export const ConversationSettingsEditor: React.FC<ConversationSettingsEditorProp
conversation={selectedConversation}
isDisabled={isDisabled}
onSystemPromptSelectionChange={handleOnSystemPromptSelectionChange}
refetchConversations={refetchConversations}
selectedPrompt={selectedSystemPrompt}
isSettingsModalVisible={true}
setIsSettingsModalVisible={noop} // noop, already in settings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ const ConversationSettingsManagementComponent: React.FC<Props> = ({
conversationsSettingsBulkActions={conversationsSettingsBulkActions}
http={http}
isDisabled={isDisabled}
refetchConversations={refetchConversations}
selectedConversation={selectedConversation}
setConversationSettings={setConversationSettings}
setConversationsSettingsBulkActions={setConversationsSettingsBulkActions}
Expand Down
15 changes: 13 additions & 2 deletions x-pack/packages/kbn-elastic-assistant/impl/assistant/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,13 @@ const AssistantComponent: React.FC<Props> = ({
conversations[WELCOME_CONVERSATION_TITLE] ??
getDefaultConversation({ cTitle: WELCOME_CONVERSATION_TITLE });

// updated selected system prompt
setEditingSystemPromptId(
getDefaultSystemPrompt({
allSystemPrompts,
conversation: conversationToReturn,
})?.id
);
if (
prev &&
prev.id === conversationToReturn.id &&
Expand All @@ -273,6 +280,7 @@ const AssistantComponent: React.FC<Props> = ({
});
}
}, [
allSystemPrompts,
areConnectorsFetched,
conversationTitle,
conversations,
Expand Down Expand Up @@ -647,6 +655,7 @@ const AssistantComponent: React.FC<Props> = ({
actionTypeId: (defaultConnector?.actionTypeId as string) ?? '.gen-ai',
provider: apiConfig?.apiProvider,
model: apiConfig?.defaultModel,
defaultSystemPromptId: allSystemPrompts.find((sp) => sp.isNewConversationDefault)?.id,
},
});
},
Expand All @@ -665,14 +674,14 @@ const AssistantComponent: React.FC<Props> = ({

useEffect(() => {
(async () => {
if (areConnectorsFetched && currentConversation?.id === '') {
if (areConnectorsFetched && currentConversation?.id === '' && !isLoadingPrompts) {
const conversation = await mutateAsync(currentConversation);
if (currentConversation.id === '' && conversation) {
setCurrentConversationId(conversation.id);
}
}
})();
}, [areConnectorsFetched, currentConversation, mutateAsync]);
}, [areConnectorsFetched, currentConversation, isLoadingPrompts, mutateAsync]);

const handleCreateConversation = useCallback(async () => {
const newChatExists = find(conversations, ['title', NEW_CHAT]);
Expand Down Expand Up @@ -791,6 +800,7 @@ const AssistantComponent: React.FC<Props> = ({
isSettingsModalVisible={isSettingsModalVisible}
setIsSettingsModalVisible={setIsSettingsModalVisible}
allSystemPrompts={allSystemPrompts}
refetchConversations={refetchResults}
/>
</EuiFlexItem>
<EuiFlexItem grow={false}>
Expand Down Expand Up @@ -823,6 +833,7 @@ const AssistantComponent: React.FC<Props> = ({
handleOnSystemPromptSelectionChange,
isSettingsModalVisible,
isWelcomeSetup,
refetchResults,
]);

return (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,29 @@ import { getOptions, getOptionFromPrompt } from './helpers';
describe('helpers', () => {
describe('getOptionFromPrompt', () => {
it('returns an EuiSuperSelectOption with the correct value', () => {
const option = getOptionFromPrompt({ ...mockSystemPrompt });
const option = getOptionFromPrompt({ ...mockSystemPrompt, isCleared: false });

expect(option.value).toBe(mockSystemPrompt.id);
});

it('returns an EuiSuperSelectOption with the correct inputDisplay', () => {
const option = getOptionFromPrompt({ ...mockSystemPrompt });
const option = getOptionFromPrompt({ ...mockSystemPrompt, isCleared: false });

render(<>{option.inputDisplay}</>);

expect(screen.getByTestId('systemPromptText')).toHaveTextContent(mockSystemPrompt.name);
});

it('shows the expected name in the dropdownDisplay', () => {
const option = getOptionFromPrompt({ ...mockSystemPrompt });
const option = getOptionFromPrompt({ ...mockSystemPrompt, isCleared: false });

render(<TestProviders>{option.dropdownDisplay}</TestProviders>);

expect(screen.getByTestId('name')).toHaveTextContent(mockSystemPrompt.name);
});

it('shows the expected prompt content in the dropdownDisplay', () => {
const option = getOptionFromPrompt({ ...mockSystemPrompt });
const option = getOptionFromPrompt({ ...mockSystemPrompt, isCleared: false });

render(<TestProviders>{option.dropdownDisplay}</TestProviders>);

Expand All @@ -51,7 +51,7 @@ describe('helpers', () => {
const prompts = [mockSystemPrompt, mockSuperheroSystemPrompt];
const promptIds = prompts.map(({ id }) => id);

const options = getOptions({ prompts });
const options = getOptions({ prompts, isCleared: false });
const optionValues = options.map(({ value }) => value);

expect(optionValues).toEqual(promptIds);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,38 @@ import styled from '@emotion/styled';
import { isEmpty } from 'lodash/fp';
import { euiThemeVars } from '@kbn/ui-theme';
import { PromptResponse } from '@kbn/elastic-assistant-common';
import { css } from '@emotion/react';
import { EMPTY_PROMPT } from './translations';

const Strong = styled.strong`
margin-right: ${euiThemeVars.euiSizeS};
`;

interface GetOptionFromPromptProps extends PromptResponse {
content: string;
id: string;
name: string;
isCleared: boolean;
}

export const getOptionFromPrompt = ({
content,
id,
isCleared,
name,
}: PromptResponse): EuiSuperSelectOption<string> => ({
}: GetOptionFromPromptProps): EuiSuperSelectOption<string> => ({
value: id,
inputDisplay: <span data-test-subj="systemPromptText">{name}</span>,
inputDisplay: (
<span
data-test-subj="systemPromptText"
// @ts-ignore
css={css`
color: ${isCleared ? euiThemeVars.euiColorLightShade : euiThemeVars.euiColorDarkestShade};
`}
>
{name}
</span>
),
dropdownDisplay: (
<>
<Strong data-test-subj="name">{name}</Strong>
Expand All @@ -41,6 +60,10 @@ export const getOptionFromPrompt = ({

interface GetOptionsProps {
prompts: PromptResponse[] | undefined;
isCleared: boolean;
}
export const getOptions = ({ prompts }: GetOptionsProps): Array<EuiSuperSelectOption<string>> =>
prompts?.map(getOptionFromPrompt) ?? [];
export const getOptions = ({
prompts,
isCleared,
}: GetOptionsProps): Array<EuiSuperSelectOption<string>> =>
prompts?.map((p) => getOptionFromPrompt({ ...p, isCleared })) ?? [];
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
* 2.0.
*/

import React, { useCallback, useMemo } from 'react';
import React, { useCallback, useMemo, useState } from 'react';
import { PromptResponse } from '@kbn/elastic-assistant-common';
import { QueryObserverResult } from '@tanstack/react-query';
import { Conversation } from '../../../..';
import { SelectSystemPrompt } from './select_system_prompt';

Expand All @@ -17,6 +18,7 @@ interface Props {
onSystemPromptSelectionChange: (systemPromptId: string | undefined) => void;
setIsSettingsModalVisible: React.Dispatch<React.SetStateAction<boolean>>;
allSystemPrompts: PromptResponse[];
refetchConversations?: () => Promise<QueryObserverResult<Record<string, Conversation>, unknown>>;
}

const SystemPromptComponent: React.FC<Props> = ({
Expand All @@ -26,20 +28,34 @@ const SystemPromptComponent: React.FC<Props> = ({
onSystemPromptSelectionChange,
setIsSettingsModalVisible,
allSystemPrompts,
refetchConversations,
}) => {
const [isCleared, setIsCleared] = useState(false);
const selectedPrompt = useMemo(() => {
if (editingSystemPromptId !== undefined) {
setIsCleared(false);
return allSystemPrompts.find((p) => p.id === editingSystemPromptId);
} else {
return allSystemPrompts.find((p) => p.id === conversation?.apiConfig?.defaultSystemPromptId);
}
}, [allSystemPrompts, conversation?.apiConfig?.defaultSystemPromptId, editingSystemPromptId]);

const handleClearSystemPrompt = useCallback(() => {
if (conversation) {
if (editingSystemPromptId === undefined) {
setIsCleared(false);
onSystemPromptSelectionChange(
allSystemPrompts.find((p) => p.id === conversation?.apiConfig?.defaultSystemPromptId)?.id
);
} else {
setIsCleared(true);
onSystemPromptSelectionChange(undefined);
}
}, [conversation, onSystemPromptSelectionChange]);
}, [
allSystemPrompts,
conversation?.apiConfig?.defaultSystemPromptId,
editingSystemPromptId,
onSystemPromptSelectionChange,
]);

return (
<SelectSystemPrompt
Expand All @@ -48,6 +64,8 @@ const SystemPromptComponent: React.FC<Props> = ({
conversation={conversation}
data-test-subj="systemPrompt"
isClearable={true}
isCleared={isCleared}
refetchConversations={refetchConversations}
isSettingsModalVisible={isSettingsModalVisible}
onSystemPromptSelectionChange={onSystemPromptSelectionChange}
selectedPrompt={selectedPrompt}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
PromptResponse,
PromptTypeEnum,
} from '@kbn/elastic-assistant-common/impl/schemas/prompts/bulk_crud_prompts_route.gen';
import { QueryObserverResult } from '@tanstack/react-query';
import { Conversation } from '../../../../..';
import { getOptions } from '../helpers';
import * as i18n from '../translations';
Expand All @@ -38,6 +39,7 @@ export interface Props {
selectedPrompt: PromptResponse | undefined;
clearSelectedSystemPrompt?: () => void;
isClearable?: boolean;
isCleared?: boolean;
isDisabled?: boolean;
isOpen?: boolean;
isSettingsModalVisible: boolean;
Expand All @@ -46,6 +48,7 @@ export interface Props {
onSelectedConversationChange?: (result: Conversation) => void;
setConversationSettings?: React.Dispatch<React.SetStateAction<Record<string, Conversation>>>;
setConversationsSettingsBulkActions?: React.Dispatch<Record<string, Conversation>>;
refetchConversations?: () => Promise<QueryObserverResult<Record<string, Conversation>, unknown>>;
}

const ADD_NEW_SYSTEM_PROMPT = 'ADD_NEW_SYSTEM_PROMPT';
Expand All @@ -57,8 +60,10 @@ const SelectSystemPromptComponent: React.FC<Props> = ({
selectedPrompt,
clearSelectedSystemPrompt,
isClearable = false,
isCleared = false,
isDisabled = false,
isOpen = false,
refetchConversations,
isSettingsModalVisible,
onSystemPromptSelectionChange,
setIsSettingsModalVisible,
Expand Down Expand Up @@ -89,10 +94,11 @@ const SelectSystemPromptComponent: React.FC<Props> = ({
defaultSystemPromptId: promptId,
},
});
await refetchConversations?.();
return result;
}
},
[conversation, setApiConfig]
[conversation, refetchConversations, setApiConfig]
);

const addNewSystemPrompt = useMemo(() => {
Expand All @@ -116,7 +122,10 @@ const SelectSystemPromptComponent: React.FC<Props> = ({
}, []);

// SuperSelect State/Actions
const options = useMemo(() => getOptions({ prompts: allSystemPrompts }), [allSystemPrompts]);
const options = useMemo(
() => getOptions({ prompts: allSystemPrompts, isCleared }),
[allSystemPrompts, isCleared]
);

const onChange = useCallback(
async (selectedSystemPromptId) => {
Expand Down Expand Up @@ -160,9 +169,8 @@ const SelectSystemPromptComponent: React.FC<Props> = ({
);

const clearSystemPrompt = useCallback(() => {
setSelectedSystemPrompt(undefined);
clearSelectedSystemPrompt?.();
}, [clearSelectedSystemPrompt, setSelectedSystemPrompt]);
}, [clearSelectedSystemPrompt]);

return (
<EuiFlexGroup
Expand Down Expand Up @@ -226,10 +234,14 @@ const SelectSystemPromptComponent: React.FC<Props> = ({
inline-size: 16px;
block-size: 16px;
border-radius: 16px;
background: ${euiThemeVars.euiColorMediumShade};
background: ${isCleared
? euiThemeVars.euiColorLightShade
: euiThemeVars.euiColorMediumShade};
:hover:not(:disabled) {
background: ${euiThemeVars.euiColorMediumShade};
background: ${isCleared
? euiThemeVars.euiColorLightShade
: euiThemeVars.euiColorMediumShade};
transform: none;
}
Expand Down
Loading

0 comments on commit dc11f75

Please sign in to comment.