Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(js/ai/prompt): added prepare fn option to definePrompt #1779

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions js/ai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,14 @@ export {
isExecutablePrompt,
loadPromptFolder,
prompt,
type DocsResolver,
type ExecutablePrompt,
type MessagesResolver,
type PartsResolver,
type PromptAction,
type PromptConfig,
type PromptGenerateOptions,
type PromptPrepare,
} from './prompt.js';
export {
rerank,
Expand Down
43 changes: 42 additions & 1 deletion js/ai/src/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ export interface PromptConfig<
toolChoice?: ToolChoice;
use?: ModelMiddleware[];
context?: ActionContext;
prepare?: PromptPrepare<I, O, CustomOptions>;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should mark this with @beta in docs and add an assertUnstable to the code that handles it. The main concern I have about marking this stable is that I wonder if it should be allowed to manipulate the input, not just context, and that would cause weirdness with types.

}

/**
Expand Down Expand Up @@ -208,6 +209,34 @@ export type DocsResolver<I, S = any> = (
}
) => DocumentData[] | Promise<DocumentData[]>;

/**
* A function that can be passes to the prompt that can produce additional propmpt render options
* like context, docs, or any other propmpt option really:
Comment on lines +213 to +214
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* A function that can be passes to the prompt that can produce additional propmpt render options
* like context, docs, or any other propmpt option really:
* A function that can validate, fetch, and populate additional information necessary for the prompt to execute.
* The `prepare` function can be used to retrieve documents, add personalized user information to the context,
* or generally perform any other work necessary to configure the prompt based on runtime input.

*
* ```ts
* const ragPrompt = ai.definePrompt({
* prepare: async (input, {context}) => ({
* context: {...context, userInfo: await fetchUserInfo(context.auth.uid)},
* docs: await myRetriever({query: input.query}),
* })
* })
* ```
*/
export type PromptPrepare<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
S = any,
> = (
input: I,
options: {
state?: S;
context: ActionContext;
}
) =>
| PromptGenerateOptions<O, CustomOptions>
| PromiseLike<PromptGenerateOptions<O, CustomOptions>>;

interface PromptCache {
userPrompt?: PromptFunction;
system?: PromptFunction;
Expand Down Expand Up @@ -250,10 +279,22 @@ function definePromptAsync<
renderOptions: PromptGenerateOptions<O, CustomOptions> | undefined
): Promise<GenerateOptions> => {
const messages: MessageData[] = [];
renderOptions = { ...renderOptions }; // make a copy, we will be trimming
const session = getCurrentSession(registry);
const resolvedOptions = await optionsPromise;

// if prepare option is set, invoke it and merge with renderOptions
const preparedOptions = resolvedOptions?.prepare
? await resolvedOptions?.prepare(input, {
state: session?.state,
context: renderOptions?.context || getContext(registry) || {},
})
: {};
// make a copy of renderOptions, we will be trimming
renderOptions = {
...renderOptions,
...preparedOptions,
};

// order of these matters:
await renderSystemPrompt(
registry,
Expand Down
42 changes: 42 additions & 0 deletions js/ai/tests/prompt/prompt_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,48 @@ describe('prompt', () => {
tools: ['toolA'],
},
},
{
name: 'invoked prepare fn and merges with render options',
prompt: {
model: 'echoModel',
name: 'prompt1',
config: { banana: 'ripe' },
input: { schema: z.object({ name: z.string() }) },
prompt: 'hello {{@foo}} {{@baz}} ({{@state.name}})',
tools: ['toolA'],
prepare: async (input, { context }) => ({
context: { ...context, foo: 'bar' },
docs: [Document.fromText('doc txt')],
}),
},
input: { name: 'foo' },
state: { name: 'bar' },
inputOptions: { config: { temperature: 11 }, context: { baz: 'aux' } },
wantTextOutput:
'Echo: hello bar aux (bar),\n' +
'\n' +
'Use the following information to complete your task:\n' +
'\n' +
'- [0]: doc txt\n' +
'\n' +
'; config: {"banana":"ripe","temperature":11}',
wantRendered: {
context: {
baz: 'aux',
foo: 'bar',
},
docs: [Document.fromText('doc txt')],
config: {
banana: 'ripe',
temperature: 11,
},
messages: [
{ content: [{ text: 'hello bar aux (bar)' }], role: 'user' },
],
model: 'echoModel',
tools: ['toolA'],
},
},
];

basicTests = basicTests.find((t) => t.only)
Expand Down
4 changes: 4 additions & 0 deletions js/genkit/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ export {
indexerRef,
rerankerRef,
retrieverRef,
type DocsResolver,
type DocumentData,
type EmbedderAction,
type EmbedderArgument,
Expand Down Expand Up @@ -79,13 +80,16 @@ export {
type LlmStats,
type MediaPart,
type MessageData,
type MessagesResolver,
type ModelArgument,
type ModelReference,
type ModelRequest,
type ModelResponseData,
type Part,
type PartsResolver,
type PromptAction,
type PromptConfig,
type PromptPrepare,
type RankedDocument,
type RerankerAction,
type RerankerArgument,
Expand Down
Loading