Skip to content

Commit

Permalink
refactor: 🎨 update model alias handling logic (#952)
Browse files Browse the repository at this point in the history
* refactor: 🎨 update model alias handling logic

* feat: add model specification to agents 🤖

* docs: ✏️ update model aliases section in docs

* refactor: ♻️ update footer message for AI content warning

* refactor: ♻️ update agent system references in script
  • Loading branch information
pelikhan authored Dec 16, 2024
1 parent 8633332 commit 66b0f26
Show file tree
Hide file tree
Showing 13 changed files with 97 additions and 83 deletions.
6 changes: 4 additions & 2 deletions docs/src/content/docs/reference/scripts/model-aliases.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,16 @@ A model alias can reference another alias as long as cycles are not created.

## Builtin aliases

By default, GenAIScript supports the following model aliases:
By default, GenAIScript supports the following model aliases, along with candidate models
across different LLM providers.

- `large`: a `gpt-4o`-like model
- `small`: `gpt-4o-mini` model or similar. A smaller, cheaper, faster model
- `vision`: `gpt-4o-mini`. A model that can analyze images
- `reasoning`: `o1` or `o1-preview`.
- `reasoning-small`: `o1-mini`.

The following aliases are also set so that you can override LLMs used by GenAIScript itself.

- `reasoning`: `large`. In the future, `o1` like models.
- `agent`: `large`. Model used by the Agent LLM.
- `memory`: `small`. Model used by the agent short term memory.
49 changes: 40 additions & 9 deletions packages/cli/src/nodehost.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { lstat, readFile, unlink, writeFile } from "node:fs/promises"
import { ensureDir, exists, existsSync, remove } from "fs-extra"
import { resolve, dirname } from "node:path"
import { glob } from "glob"
import { debug, error, info, isQuiet, warn } from "./log"
import { debug, error, info, warn } from "./log"
import { execa } from "execa"
import { join } from "node:path"
import { createNodePath } from "./nodepath"
Expand All @@ -17,8 +17,7 @@ import {
parseTokenFromEnv,
} from "../../core/src/connection"
import {
DEFAULT_MODEL,
DEFAULT_TEMPERATURE,
DEFAULT_LARGE_MODEL,
MODEL_PROVIDER_AZURE_OPENAI,
SHELL_EXEC_TIMEOUT,
MODEL_PROVIDER_OLLAMA,
Expand All @@ -33,6 +32,14 @@ import {
DEFAULT_VISION_MODEL,
LARGE_MODEL_ID,
SMALL_MODEL_ID,
DEFAULT_SMALL_MODEL_CANDIDATES,
DEFAULT_LARGE_MODEL_CANDIDATES,
DEFAULT_EMBEDDINGS_MODEL_CANDIDATES,
DEFAULT_VISION_MODEL_CANDIDATES,
DEFAULT_REASONING_MODEL,
DEFAULT_REASONING_SMALL_MODEL,
DEFAULT_REASONING_SMALL_MODEL_CANDIDATES,
DEFAULT_REASONING_MODEL_CANDIDATES,
} from "../../core/src/constants"
import { tryReadText } from "../../core/src/fs"
import {
Expand Down Expand Up @@ -71,7 +78,6 @@ import {
} from "../../core/src/azurecontentsafety"
import { resolveGlobalConfiguration } from "../../core/src/config"
import { HostConfiguration } from "../../core/src/hostconfiguration"
import { YAMLStringify } from "../../core/src/yaml"

class NodeServerManager implements ServerManager {
async start(): Promise<void> {
Expand Down Expand Up @@ -171,11 +177,36 @@ export class NodeHost implements RuntimeHost {
Omit<ModelConfigurations, "large" | "small" | "vision" | "embeddings">
> = {
default: {
large: { model: DEFAULT_MODEL, source: "default" },
small: { model: DEFAULT_SMALL_MODEL, source: "default" },
vision: { model: DEFAULT_VISION_MODEL, source: "default" },
embeddings: { model: DEFAULT_EMBEDDINGS_MODEL, source: "default" },
reasoning: { model: LARGE_MODEL_ID, source: "default" },
large: {
model: DEFAULT_LARGE_MODEL,
source: "default",
candidates: DEFAULT_LARGE_MODEL_CANDIDATES,
},
small: {
model: DEFAULT_SMALL_MODEL,
source: "default",
candidates: DEFAULT_SMALL_MODEL_CANDIDATES,
},
vision: {
model: DEFAULT_VISION_MODEL,
source: "default",
candidates: DEFAULT_VISION_MODEL_CANDIDATES,
},
embeddings: {
model: DEFAULT_EMBEDDINGS_MODEL,
source: "default",
candidates: DEFAULT_EMBEDDINGS_MODEL_CANDIDATES,
},
reasoning: {
model: DEFAULT_REASONING_MODEL,
source: "default",
candidates: DEFAULT_REASONING_MODEL_CANDIDATES,
},
["reasoning-small"]: {
model: DEFAULT_REASONING_SMALL_MODEL,
source: "default",
candidates: DEFAULT_REASONING_SMALL_MODEL_CANDIDATES,
},
agent: { model: LARGE_MODEL_ID, source: "default" },
memory: { model: SMALL_MODEL_ID, source: "default" },
},
Expand Down
4 changes: 2 additions & 2 deletions packages/cli/src/parse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import { YAMLParse, YAMLStringify } from "../../core/src/yaml"
import { resolveTokenEncoder } from "../../core/src/encoders"
import {
CSV_REGEX,
DEFAULT_MODEL,
DEFAULT_LARGE_MODEL,
INI_REGEX,
JSON5_REGEX,
MD_REGEX,
Expand Down Expand Up @@ -204,7 +204,7 @@ export async function parseTokens(
filesGlobs: string[],
options: { excludedFiles: string[]; model: string }
) {
const { model = DEFAULT_MODEL } = options || {}
const { model = DEFAULT_LARGE_MODEL } = options || {}
const { encode: encoder } = await resolveTokenEncoder(model)

const files = await expandFiles(filesGlobs, options?.excludedFiles)
Expand Down
24 changes: 19 additions & 5 deletions packages/core/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ export const SMALL_MODEL_ID = "small"
export const LARGE_MODEL_ID = "large"
export const VISION_MODEL_ID = "vision"
export const DEFAULT_FENCE_FORMAT: FenceFormat = "xml"
export const DEFAULT_MODEL = "openai:gpt-4o"
export const DEFAULT_MODEL_CANDIDATES = [
export const DEFAULT_LARGE_MODEL = "openai:gpt-4o"
export const DEFAULT_LARGE_MODEL_CANDIDATES = [
"azure_serverless:gpt-4o",
DEFAULT_MODEL,
DEFAULT_LARGE_MODEL,
"google:gemini-1.5-pro-latest",
"anthropic:claude-2.1",
"mistral:mistral-large-latest",
Expand All @@ -69,7 +69,7 @@ export const DEFAULT_MODEL_CANDIDATES = [
export const DEFAULT_VISION_MODEL = "openai:gpt-4o"
export const DEFAULT_VISION_MODEL_CANDIDATES = [
"azure_serverless:gpt-4o",
DEFAULT_MODEL,
DEFAULT_VISION_MODEL,
"google:gemini-1.5-flash-latest",
"anthropic:claude-2.1",
"github:gpt-4o",
Expand All @@ -91,6 +91,20 @@ export const DEFAULT_EMBEDDINGS_MODEL_CANDIDATES = [
"github:text-embedding-3-small",
"client:text-embedding-3-small",
]
export const DEFAULT_REASONING_SMALL_MODEL = "openai:o1-mini"
export const DEFAULT_REASONING_SMALL_MODEL_CANDIDATES = [
"azure_serverless:o1-mini",
DEFAULT_REASONING_SMALL_MODEL,
"github:o1-mini",
"client:o1-mini",
]
export const DEFAULT_REASONING_MODEL = "openai:o1"
export const DEFAULT_REASONING_MODEL_CANDIDATES = [
"azure_serverless:o1-preview",
DEFAULT_REASONING_MODEL,
"github:o1-preview",
"client:o1-preview",
]
export const DEFAULT_EMBEDDINGS_MODEL = "openai:text-embedding-ada-002"
export const DEFAULT_TEMPERATURE = 0.8
export const BUILTIN_PREFIX = "_builtin/"
Expand Down Expand Up @@ -329,4 +343,4 @@ export const IMAGE_DETAIL_LOW_HEIGHT = 512

export const MIN_LINE_NUMBER_LENGTH = 10

export const VSCODE_SERVER_MAX_RETRIES = 5
export const VSCODE_SERVER_MAX_RETRIES = 5
7 changes: 1 addition & 6 deletions packages/core/src/git.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@
// It includes functionality to find modified files, execute Git commands, and manage branches.

import { uniq } from "es-toolkit"
import {
DEFAULT_MODEL,
GIT_DIFF_MAX_TOKENS,
GIT_IGNORE_GENAI,
GIT_LOG_COUNT,
} from "./constants"
import { GIT_DIFF_MAX_TOKENS, GIT_IGNORE_GENAI } from "./constants"
import { llmifyDiff } from "./diff"
import { resolveFileContents } from "./file"
import { readText } from "./fs"
Expand Down
4 changes: 2 additions & 2 deletions packages/core/src/github.ts
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ export function generatedByFooter(
info: { runUrl?: string },
code?: string
) {
return `\n\n> generated by ${link(script.id, info.runUrl)}${code ? ` \`${code}\` ` : ""}\n\n`
return `\n\n> AI-generated content ${link(script.id, info.runUrl)}${code ? ` \`${code}\` ` : ""} may be incorrect\n\n`
}

export function appendGeneratedComment(
Expand Down Expand Up @@ -544,7 +544,7 @@ export class GitHubClient implements GitHub {
auth,
ref,
refName,
issueNumber: issue
issueNumber: issue,
})
}

Expand Down
1 change: 1 addition & 0 deletions packages/core/src/host.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ export interface AzureTokenResolver {
export type ModelConfiguration = Readonly<
Pick<ModelOptions, "model" | "temperature"> & {
source: "cli" | "env" | "config" | "default"
candidates?: string[]
}
>

Expand Down
71 changes: 23 additions & 48 deletions packages/core/src/models.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
import { uniq } from "es-toolkit"
import {
DEFAULT_EMBEDDINGS_MODEL_CANDIDATES,
DEFAULT_MODEL_CANDIDATES,
DEFAULT_SMALL_MODEL_CANDIDATES,
DEFAULT_VISION_MODEL_CANDIDATES,
LARGE_MODEL_ID,
MODEL_PROVIDER_LLAMAFILE,
MODEL_PROVIDER_OPENAI,
SMALL_MODEL_ID,
VISION_MODEL_ID,
} from "./constants"
import { errorMessage } from "./error"
import { LanguageModelConfiguration, host, runtimeHost } from "./host"
import { AbortSignalOptions, MarkdownTrace, TraceOptions } from "./trace"
import { arrayify, assert, logVerbose, toStringList } from "./util"
import { arrayify, assert, toStringList } from "./util"

/**
* model
Expand Down Expand Up @@ -117,55 +111,32 @@ export async function resolveModelConnectionInfo(
options?: {
model?: string
token?: boolean
candidates?: string[]
} & TraceOptions &
AbortSignalOptions
): Promise<{
info: ModelConnectionInfo
configuration?: LanguageModelConfiguration
}> {
const { trace, token: askToken, signal } = options || {}
const hint = options?.model || conn.model || ""
let candidates = options?.candidates
let m = hint
if (m === SMALL_MODEL_ID) {
m = undefined
candidates ??= [
runtimeHost.modelAliases.small.model,
...DEFAULT_SMALL_MODEL_CANDIDATES,
]
} else if (m === VISION_MODEL_ID) {
m = undefined
candidates ??= [
runtimeHost.modelAliases.vision.model,
...DEFAULT_VISION_MODEL_CANDIDATES,
]
} else if (m === LARGE_MODEL_ID) {
m = undefined
candidates ??= [
runtimeHost.modelAliases.large.model,
...DEFAULT_MODEL_CANDIDATES,
]
}
candidates ??= [
runtimeHost.modelAliases.large.model,
...DEFAULT_MODEL_CANDIDATES,
]

const { modelAliases } = runtimeHost
const hint = options?.model || conn.model
// supports candidate if no model hint or hint is a model alias
const supportsCandidates = !hint || !!modelAliases[hint]
let modelId = hint || LARGE_MODEL_ID
let candidates: string[]
// recursively resolve model aliases
if (m) {
const seen = [m]
const modelAliases = runtimeHost.modelAliases
while (modelAliases[m]) {
const alias = modelAliases[m].model
if (seen.includes(alias))
{
const seen: string[] = []
while (modelAliases[modelId]) {
const { model: id, candidates: c } = modelAliases[modelId]
if (seen.includes(id))
throw new Error(
`Circular model alias: ${alias}, seen ${[...seen].join(",")}`
`Circular model alias: ${id}, seen ${[...seen].join(",")}`
)
m = alias
seen.push(m)
seen.push(modelId)
modelId = id
if (supportsCandidates) candidates = c
}
if (seen.length > 1) logVerbose(`model_aliases: ${seen.join(" -> ")}`)
}

const resolveModel = async (
Expand Down Expand Up @@ -214,10 +185,14 @@ export async function resolveModelConnectionInfo(
}
}

if (m) {
return await resolveModel(m, { withToken: askToken, reportError: true })
if (!supportsCandidates) {
return await resolveModel(modelId, {
withToken: askToken,
reportError: true,
})
} else {
for (const candidate of uniq(candidates).filter((c) => !!c)) {
candidates = uniq([modelId, ...(candidates || [])].filter((c) => !!c))
for (const candidate of candidates) {
const res = await resolveModel(candidate, {
withToken: askToken,
reportError: false,
Expand Down
6 changes: 2 additions & 4 deletions packages/core/src/testhost.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ import {
import { TraceOptions } from "./trace"
import {
DEFAULT_EMBEDDINGS_MODEL,
DEFAULT_MODEL,
DEFAULT_LARGE_MODEL,
DEFAULT_SMALL_MODEL,
DEFAULT_TEMPERATURE,
DEFAULT_VISION_MODEL,
} from "./constants"
import {
Expand All @@ -38,7 +37,6 @@ import {
} from "node:path"
import { LanguageModel } from "./chat"
import { NotSupportedError } from "./error"
import { HostConfiguration } from "./hostconfiguration"
import { Project } from "./server/messages"

// Function to create a frozen object representing Node.js path methods
Expand Down Expand Up @@ -73,7 +71,7 @@ export class TestHost implements RuntimeHost {

// Default options for language models
readonly modelAliases: ModelConfigurations = {
large: { model: DEFAULT_MODEL, source: "default" },
large: { model: DEFAULT_LARGE_MODEL, source: "default" },
small: { model: DEFAULT_SMALL_MODEL, source: "default" },
vision: { model: DEFAULT_VISION_MODEL, source: "default" },
embeddings: { model: DEFAULT_EMBEDDINGS_MODEL, source: "default" },
Expand Down
4 changes: 0 additions & 4 deletions packages/core/src/vectorsearch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,6 @@ export async function vectorSearch(
},
{
token: true,
candidates: [
runtimeHost.modelAliases.embeddings.model,
...DEFAULT_EMBEDDINGS_MODEL_CANDIDATES,
],
}
)
if (info.error) throw new Error(info.error)
Expand Down
1 change: 1 addition & 0 deletions packages/sample/genaisrc/github-agent.genai.mts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ script({
"agent_interpreter",
"agent_docs",
],
model: "reasoning",
parameters: {
jobUrl: { type: "string" }, // URL of the job
workflow: { type: "string" }, // Workflow name
Expand Down
1 change: 1 addition & 0 deletions packages/sample/genaisrc/prd-agent.genai.mts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ script({
description: "Generate a pull request description from the git diff",
tools: ["agent_fs", "agent_git"],
temperature: 0.5,
model: "reasoning-small",
})

$`You are an expert software developer and architect.
Expand Down
2 changes: 1 addition & 1 deletion packages/sample/genaisrc/samples/copilotchat.genai.mts
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ script({
"system.agent_github",
"system.agent_interpreter",
"system.agent_docs",
"system.agent_vision",
"system.agent_web",
"system.vision_ask_image",
],
group: "copilot", // Group categorization for the script
parameters: {
Expand Down

0 comments on commit 66b0f26

Please sign in to comment.