diff --git a/bun.lockb b/bun.lockb index 45369b4..6fd8d44 100755 Binary files a/bun.lockb and b/bun.lockb differ diff --git a/docs/developers/01_getting_started.md b/docs/developers/01_getting_started.md index f5cd725..e732354 100644 --- a/docs/developers/01_getting_started.md +++ b/docs/developers/01_getting_started.md @@ -22,9 +22,9 @@ - [`packages/storage`](#packagesstorage) - [`packages/ai`](#packagesai) - [`packages/ai-provider`](#packagesai-provider) - - [`examples/cli`](#samplescli) - - [`examples/web`](#samplesweb) - - [`examples/ngraph`](#samplesngraph) + - [`examples/cli`](#examplescli) + - [`examples/web`](#examplesweb) + - [`examples/ngraph`](#examplesngraph) # Developer Getting Started @@ -51,13 +51,13 @@ After this, plese read [Architecture](02_architecture.md) before attempting to [ ```ts import { TaskGraphBuilder } from "ellmers-core"; -import { registerHuggingfaceLocalTasksInMemory } from "ellmers-ai-provider/hf-transformers/server"; +import { registerHuggingfaceLocalTasksInMemory } from "ellmers-test"; // config and start up registerHuggingfaceLocalTasksInMemory(); const builder = new TaskGraphBuilder(); builder - .DownloadModel({ model: "Xenova/LaMini-Flan-T5-783M" }) + .DownloadModel({ model: "ONNX Xenova/LaMini-Flan-T5-783M q8" }) .TextRewriter({ text: "The quick brown fox jumps over the lazy dog.", prompt: ["Rewrite the following text in reverse:", "Rewrite this to sound like a pirate:"], @@ -79,15 +79,17 @@ import { DataFlow, TaskGraph, TaskGraphRunner, - registerHuggingfaceLocalTasksInMemory, } from "ellmers-core"; +import { registerHuggingfaceLocalTasksInMemory } from "ellmers-test"; // config and start up registerHuggingfaceLocalTasksInMemory(); // build and run graph const graph = new TaskGraph(); -graph.addTask(new DownloadModel({ id: "1", input: { model: "Xenova/LaMini-Flan-T5-783M" } })); +graph.addTask( + new DownloadModel({ id: "1", input: { model: "ONNX Xenova/LaMini-Flan-T5-783M q8" } }) +); graph.addTask( new TextRewriterCompoundTask({ id: "2", @@ -284,7 +286,7 @@ There is a JSONTask that can be used to build a graph. This is useful for saving "id": "1", "type": "DownloadModelCompoundTask", "input": { - "model": ["Xenova/LaMini-Flan-T5-783M", "Xenova/m2m100_418M"] + "model": ["ONNX Xenova/LaMini-Flan-T5-783M q8", "ONNX Xenova/m2m100_418M q8"] } }, { @@ -305,7 +307,7 @@ There is a JSONTask that can be used to build a graph. This is useful for saving "id": "3", "type": "TextTranslationCompoundTask", "input": { - "model": "Xenova/m2m100_418M", + "model": "ONNX Xenova/m2m100_418M q8", "source": "en", "target": "es" }, diff --git a/docs/developers/02_architecture.md b/docs/developers/02_architecture.md index 167e5f0..57c0994 100644 --- a/docs/developers/02_architecture.md +++ b/docs/developers/02_architecture.md @@ -123,7 +123,7 @@ classDiagram static TaskOutputDefinition[] outputs$ static readonly sideeffects = false$ run() TaskOutput - runSyncOnly() TaskOutput + runReactive() TaskOutput } <> TaskBase style TaskBase type:abstract,stroke-dasharray: 5 5 diff --git a/docs/developers/03_extending.md b/docs/developers/03_extending.md index ad7ec18..128563a 100644 --- a/docs/developers/03_extending.md +++ b/docs/developers/03_extending.md @@ -147,4 +147,4 @@ Compound Tasks are not cached (though any or all of their children may be). ## Reactive Task UIs -Tasks can be reactive at a certain level. This means that they can be triggered by changes in the data they depend on, without "running" the expensive job based task runs. This is useful for a UI node editor. 
For example, you change a color in one task and it is propagated downstream without incurring costs for re-running the entire graph. It is like a spreadsheet where changing a cell can trigger a recalculation of other cells. This is implemented via a `runSyncOnly()` method that is called when the data changes. Typically, the `run()` will call `runSyncOnly()` on itself at the end of the method. +Tasks can be reactive at a certain level. This means that they can be triggered by changes in the data they depend on, without "running" the expensive job based task runs. This is useful for a UI node editor. For example, you change a color in one task and it is propagated downstream without incurring costs for re-running the entire graph. It is like a spreadsheet where changing a cell can trigger a recalculation of other cells. This is implemented via a `runReactive()` method that is called when the data changes. Typically, the `run()` will call `runReactive()` on itself at the end of the method. diff --git a/examples/cli/src/TaskCLI.ts b/examples/cli/src/TaskCLI.ts index d29cb28..0a2b692 100644 --- a/examples/cli/src/TaskCLI.ts +++ b/examples/cli/src/TaskCLI.ts @@ -9,38 +9,23 @@ import { Command } from "commander"; import { runTask } from "./TaskStreamToListr2"; import "@huggingface/transformers"; import { TaskGraph, JsonTask, TaskGraphBuilder, JsonTaskItem } from "ellmers-core"; - -import { - DownloadModelTask, - DownloadModelCompoundTask, - findAllModels, - findModelByName, - findModelByUseCase, - ModelUseCaseEnum, -} from "ellmers-ai"; +import { DownloadModelTask, getGlobalModelRepository } from "ellmers-ai"; import "ellmers-task"; export function AddBaseCommands(program: Command) { program .command("download") .description("download models") - .option("--model ", "model to download") + .requiredOption("--model ", "model to download") .action(async (options) => { - const models = findAllModels(); const graph = new TaskGraph(); if (options.model) { - const model = findModelByName(options.model); + const model = await getGlobalModelRepository().findByName(options.model); if (model) { graph.addTask(new DownloadModelTask({ input: { model: model.name } })); } else { program.error(`Unknown model ${options.model}`); } - } else { - graph.addTask( - new DownloadModelCompoundTask({ - input: { model: models.map((m) => m.name) }, - }) - ); } await runTask(graph); }); @@ -52,8 +37,11 @@ export function AddBaseCommands(program: Command) { .option("--model ", "model to use") .action(async (text: string, options) => { const model = options.model - ? findModelByName(options.model)?.name - : findModelByUseCase(ModelUseCaseEnum.TEXT_EMBEDDING).map((m) => m.name); + ? (await getGlobalModelRepository().findByName(options.model))?.name + : (await getGlobalModelRepository().findModelsByTask("TextEmbeddingTask"))?.map( + (m) => m.name + ); + if (!model) { program.error(`Unknown model ${options.model}`); } else { @@ -70,8 +58,10 @@ export function AddBaseCommands(program: Command) { .option("--model ", "model to use") .action(async (text, options) => { const model = options.model - ? findModelByName(options.model)?.name - : findModelByUseCase(ModelUseCaseEnum.TEXT_SUMMARIZATION).map((m) => m.name); + ? 
(await getGlobalModelRepository().findByName(options.model))?.name + : (await getGlobalModelRepository().findModelsByTask("TextSummaryTask"))?.map( + (m) => m.name + ); if (!model) { program.error(`Unknown model ${options.model}`); } else { @@ -89,8 +79,10 @@ export function AddBaseCommands(program: Command) { .option("--model ", "model to use") .action(async (text, options) => { const model = options.model - ? findModelByName(options.model)?.name - : findModelByUseCase(ModelUseCaseEnum.TEXT_REWRITING).map((m) => m.name); + ? (await getGlobalModelRepository().findByName(options.model))?.name + : (await getGlobalModelRepository().findModelsByTask("TextRewriterTask"))?.map( + (m) => m.name + ); if (!model) { program.error(`Unknown model ${options.model}`); } else { @@ -111,7 +103,7 @@ export function AddBaseCommands(program: Command) { id: "1", type: "DownloadModelTask", input: { - model: "Xenova/LaMini-Flan-T5-783M", + model: "ONNX Xenova/LaMini-Flan-T5-783M q8", }, }, { diff --git a/examples/cli/src/TaskHelper.ts b/examples/cli/src/TaskHelper.ts index 0700220..8a3d426 100644 --- a/examples/cli/src/TaskHelper.ts +++ b/examples/cli/src/TaskHelper.ts @@ -8,6 +8,18 @@ import chalk from "chalk"; import { ListrTaskWrapper } from "listr2"; +/** + * TaskHelper provides CLI progress visualization utilities. + * + * Features: + * - Unicode-based progress bar generation + * - Customizable bar length and progress indication + * - Color-coded output using chalk + * + * Used to create visual feedback for long-running tasks in the CLI interface, + * with smooth progress transitions and clear visual indicators. + */ + export function createBar(progress: number, length: number): string { let distance = progress * length; let bar = ""; @@ -47,11 +59,7 @@ export function createBar(progress: number, length: number): string { // Extend empty bar bar += "\u258F".repeat(length > bar.length ? 
length - bar.length : 0); - return chalk.rgb( - 70, - 70, - 240 - )("\u2595" + chalk.bgRgb(20, 20, 70)(bar) + "\u258F"); + return chalk.rgb(70, 70, 240)("\u2595" + chalk.bgRgb(20, 20, 70)(bar) + "\u258F"); } export class TaskHelper { diff --git a/examples/cli/src/ellmers.ts b/examples/cli/src/ellmers.ts index b16a53a..98e0e7f 100755 --- a/examples/cli/src/ellmers.ts +++ b/examples/cli/src/ellmers.ts @@ -4,13 +4,21 @@ import { program } from "commander"; import { argv } from "process"; import { AddBaseCommands } from "./TaskCLI"; import { getProviderRegistry } from "ellmers-ai"; -import { registerHuggingfaceLocalTasksInMemory } from "ellmers-ai-provider/hf-transformers/server"; -import { registerMediaPipeTfJsLocalInMemory } from "ellmers-ai-provider/tf-mediapipe/server"; +import { + registerHuggingfaceLocalModels, + registerHuggingfaceLocalTasksInMemory, + registerMediaPipeTfJsLocalInMemory, + registerMediaPipeTfJsLocalModels, +} from "ellmers-test"; +import "ellmers-test"; program.version("1.0.0").description("A CLI to run Ellmers."); AddBaseCommands(program); +await registerHuggingfaceLocalModels(); +await registerMediaPipeTfJsLocalModels(); + registerHuggingfaceLocalTasksInMemory(); registerMediaPipeTfJsLocalInMemory(); diff --git a/examples/web/package.json b/examples/web/package.json index 0d1fa09..738d4f7 100644 --- a/examples/web/package.json +++ b/examples/web/package.json @@ -3,13 +3,13 @@ "version": "0.0.0", "type": "module", "scripts": { - "dev": "concurrently --kill-others -c 'auto' -n app,types 'bunx --bun vite' 'tsc -w --noEmit'", + "dev": "concurrently --kill-others -c 'auto' -n app,types 'bunx --bun vite' 'tsc -w --noEmit --preserveWatchOutput'", "build": "vite build && tsc --noEmit", "lint": "eslint . --ext ts,tsx --report-unused-disable-directives --max-warnings 0", "preview": "vite preview" }, "dependencies": { - "@xyflow/react": "^12.3.6", + "@xyflow/react": "^12.4.1", "react": "^19.0.0", "react-dom": "^19.0.0", "@uiw/react-codemirror": "^4.23.7", @@ -24,20 +24,21 @@ "ellmers-core": "workspace:packages/core", "ellmers-storage": "workspace:packages/storage", "ellmers-ai-provider": "workspace:packages/ai-provider", - "ellmers-ai": "workspace:packages/ai" + "ellmers-ai": "workspace:packages/ai", + "ellmers-test": "workspace:packages/test" }, "devDependencies": { - "@types/react": "^19.0.4", - "@types/react-dom": "^19.0.2", - "@typescript-eslint/eslint-plugin": "^8.19.1", - "@typescript-eslint/parser": "^8.19.1", + "@types/react": "^19.0.7", + "@types/react-dom": "^19.0.3", + "@typescript-eslint/eslint-plugin": "^8.20.0", + "@typescript-eslint/parser": "^8.20.0", "@vitejs/plugin-react": "^4.3.4", - "eslint": "^9.17.0", + "eslint": "^9.18.0", "eslint-plugin-react-hooks": "^5.1.0", - "eslint-plugin-react-refresh": "^0.4.16", + "eslint-plugin-react-refresh": "^0.4.18", "vite": "^6.0.7", "tailwindcss": "3.4.17", - "postcss": "8.4.49", + "postcss": "8.5.1", "autoprefixer": "10.4.20" }, "engines": { diff --git a/examples/web/src/App.tsx b/examples/web/src/App.tsx index f84fc50..3c65522 100644 --- a/examples/web/src/App.tsx +++ b/examples/web/src/App.tsx @@ -2,7 +2,15 @@ import React, { useCallback, useEffect, useState } from "react"; import { ReactFlowProvider } from "@xyflow/react"; import { RunGraphFlow } from "./RunGraphFlow"; import { JsonEditor } from "./JsonEditor"; -import { JsonTask, JsonTaskItem, TaskGraph, TaskGraphBuilder } from "ellmers-core"; +import { + ConcurrencyLimiter, + JsonTask, + JsonTaskItem, + TaskGraph, + TaskGraphBuilder, + TaskInput, + TaskOutput, +} from 
"ellmers-core"; import { IndexedDbTaskGraphRepository, IndexedDbTaskOutputRepository, @@ -11,10 +19,37 @@ import { ResizableHandle, ResizablePanel, ResizablePanelGroup } from "./Resize"; import { QueuesStatus } from "./QueueSatus"; import { OutputRepositoryStatus } from "./OutputRepositoryStatus"; import { GraphStoreStatus } from "./GraphStoreStatus"; -import { registerHuggingfaceLocalTasksInMemory } from "ellmers-ai-provider/hf-transformers/browser"; +import { InMemoryJobQueue } from "ellmers-storage/inmemory"; +import { getProviderRegistry } from "ellmers-ai"; +import { + LOCAL_ONNX_TRANSFORMERJS, + registerHuggingfaceLocalTasks, +} from "ellmers-ai-provider/hf-transformers/browser"; +import { + MEDIA_PIPE_TFJS_MODEL, + registerMediaPipeTfJsLocalTasks, +} from "ellmers-ai-provider/tf-mediapipe/browser"; import "ellmers-task"; +import "ellmers-test"; +import { registerMediaPipeTfJsLocalModels } from "ellmers-test"; +import { registerHuggingfaceLocalModels } from "ellmers-test"; + +const ProviderRegistry = getProviderRegistry(); -registerHuggingfaceLocalTasksInMemory(); +registerHuggingfaceLocalTasks(); +ProviderRegistry.registerQueue( + LOCAL_ONNX_TRANSFORMERJS, + new InMemoryJobQueue("local_hft", new ConcurrencyLimiter(1, 10), 10) +); + +registerMediaPipeTfJsLocalTasks(); +ProviderRegistry.registerQueue( + MEDIA_PIPE_TFJS_MODEL, + new InMemoryJobQueue("local_mp", new ConcurrencyLimiter(1, 10), 10) +); + +ProviderRegistry.clearQueues(); +ProviderRegistry.startQueues(); const taskOutputCache = new IndexedDbTaskOutputRepository(); const builder = new TaskGraphBuilder(taskOutputCache); @@ -31,13 +66,13 @@ const graph = await taskGraphRepo.getTaskGraph("default"); const resetGraph = () => { builder .reset() - .DownloadModel({ model: ["Xenova/LaMini-Flan-T5-783M", "Xenova/m2m100_418M"] }) + .DownloadModel({ model: ["ONNX Xenova/LaMini-Flan-T5-783M q8", "ONNX Xenova/m2m100_418M q8"] }) .TextRewriter({ text: "The quick brown fox jumps over the lazy dog.", prompt: ["Rewrite the following text in reverse:", "Rewrite this to sound like a pirate:"], }) .TextTranslation({ - model: "Xenova/m2m100_418M", + model: "ONNX Xenova/m2m100_418M q8", source: "en", target: "es", }) @@ -76,6 +111,12 @@ export const App = () => { // changes coming from builder in console useEffect(() => { + async function init() { + await registerHuggingfaceLocalModels(); + await registerMediaPipeTfJsLocalModels(); + } + init(); + function listen() { setJsonData(JSON.stringify(builder.toDependencyJSON(), null, 2)); setGraph(builder.graph); diff --git a/examples/web/src/QueueSatus.tsx b/examples/web/src/QueueSatus.tsx index b12c738..1af21b2 100644 --- a/examples/web/src/QueueSatus.tsx +++ b/examples/web/src/QueueSatus.tsx @@ -1,8 +1,8 @@ import { JobStatus } from "ellmers-core"; -import { ModelProcessorEnum, getProviderRegistry } from "ellmers-ai"; +import { getProviderRegistry } from "ellmers-ai"; import { useCallback, useEffect, useState } from "react"; -export function QueueStatus({ queueType }: { queueType: ModelProcessorEnum }) { +export function QueueStatus({ queueType }: { queueType: string }) { const queue = getProviderRegistry().getQueue(queueType); const [pending, setPending] = useState(0); const [processing, setProcessing] = useState(0); diff --git a/examples/web/src/RunGraphFlow.tsx b/examples/web/src/RunGraphFlow.tsx index a87d9d6..cd20101 100644 --- a/examples/web/src/RunGraphFlow.tsx +++ b/examples/web/src/RunGraphFlow.tsx @@ -14,16 +14,11 @@ import { TurboNodeData, SingleNode, CompoundNode } from "./TurboNode"; 
import TurboEdge from "./TurboEdge"; import { FiFileText, FiClipboard, FiDownload, FiUpload } from "react-icons/fi"; import { Task, TaskGraph } from "ellmers-core"; -import { registerHuggingfaceLocalTasksInMemory } from "ellmers-ai-provider/hf-transformers/browser"; -import { registerMediaPipeTfJsLocalInMemory } from "ellmers-ai-provider/tf-mediapipe/browser"; import { GraphPipelineCenteredLayout, GraphPipelineLayout, computeLayout } from "./layout"; import "@xyflow/react/dist/base.css"; import "./RunGraphFlow.css"; -registerHuggingfaceLocalTasksInMemory(); -registerMediaPipeTfJsLocalInMemory(); - const categoryIcons = { "Text Model": , Input: , diff --git a/examples/web/src/main.tsx b/examples/web/src/main.tsx index 4f3c534..0e3a945 100644 --- a/examples/web/src/main.tsx +++ b/examples/web/src/main.tsx @@ -1,6 +1,6 @@ import ReactDOM from "react-dom/client"; import { App } from "./App"; -import { TaskGraphBuilder } from "ellmers-core"; +import { TaskGraphBuilder, TaskRegistry } from "ellmers-core"; import "./main.css"; import { TaskConsoleFormatter, @@ -8,6 +8,7 @@ import { TaskGraphBuilderHelperConsoleFormatter, isDarkMode, } from "./ConsoleFormatters"; +import { getGlobalModelRepository } from "ellmers-ai"; ReactDOM.createRoot(document.getElementById("root")!).render( // @@ -39,7 +40,7 @@ console.log( ` %cbuilder.%creset%c(); - builder.%cDownloadModel%c({ %cmodel%c: [%c'Xenova/LaMini-Flan-T5-783M']%c }); + builder.%cDownloadModel%c({ %cmodel%c: [%c'ONNX Xenova/LaMini-Flan-T5-783M q8']%c }); builder.%cTextRewriter%c({ %ctext%c: %c'The quick brown fox jumps over the lazy dog.'%c, %cprompt%c: [%c'Rewrite the following text in reverse:'%c, %c'Rewrite this to sound like a pirate:'%c] }); builder.%crename%c(%c'text'%c, %c'message'%c); builder.%cDebugLog%c({ %clevel%c: %c'info'%c }); @@ -85,3 +86,8 @@ console.log( `color: ${grey}; font-weight: normal;` ); console.log(window["builder"]); + +console.log( + "Tasks Available: ", + Array.from(TaskRegistry.all.entries()).map(([name]) => name) +); diff --git a/examples/web/tsconfig.json b/examples/web/tsconfig.json index 3201c0a..f8d7029 100644 --- a/examples/web/tsconfig.json +++ b/examples/web/tsconfig.json @@ -22,7 +22,8 @@ "ellmers-core": ["../../packages/core/src"], "ellmers-ai-provider": ["../../packages/ai-provider/src"], "ellmers-storage": ["../../packages/storage/src"], - "ellmers-task": ["../../packages/task/src"] + "ellmers-task": ["../../packages/task/src"], + "ellmers-test": ["../../packages/test/src"] } }, "include": ["src"], @@ -31,6 +32,7 @@ { "path": "../../packages/core" }, { "path": "../../packages/task" }, { "path": "../../packages/ai-provider" }, - { "path": "../../packages/storage" } + { "path": "../../packages/storage" }, + { "path": "../../packages/test" } ] } diff --git a/package.json b/package.json index 21ce6be..521ceb8 100644 --- a/package.json +++ b/package.json @@ -9,30 +9,31 @@ ], "scripts": { "build": "bun run build:packages && bun run build:examples", - "build:packages": "bun run build:core && bun run build:storage && bun run build:task && bun run build:ai && bun run build:ai-provider", + "build:packages": "bun run build:core && bun run build:ai && bun run build:storage && bun run build:task && bun run build:ai-provider && bun run build:test", "build:core": "cd packages/core && bun run build", "build:ai": "cd packages/ai && bun run build", "build:ai-provider": "cd packages/ai-provider && bun run build", "build:storage": "cd packages/storage && bun run build", "build:task": "cd packages/task && bun run build", + 
"build:test": "cd packages/test && bun run build", "build:examples": "bun run bun run build:cli && bun run build:web", "build:cli": "cd examples/cli && bun run build", "build:web": "cd examples/web && bun run build", "clean": "rm -rf node_modules packages/*/node_modules packages/*/dist packages/*/src/**/*\\.d\\.ts packages/*/src/**/*\\.map examples/*/node_modules examples/*/dist examples/*/src/**/*\\.d\\.ts examples/*/src/**/*\\.map", - "watch:packages": "concurrently --kill-others -c 'auto' -n core,task,storage,ai,provider 'cd packages/core && bun run watch' 'sleep 3 && cd packages/task && bun run watch' 'sleep 3 && cd packages/storage && bun run watch' 'sleep 3 && cd packages/ai && bun run watch' 'sleep 6 && cd packages/ai-provider && bun run watch'", + "watch:packages": "concurrently --kill-others -c 'auto' -n core,task,storage,ai,provider,test 'cd packages/core && bun run watch' 'sleep 3 && cd packages/task && bun run watch' 'sleep 3 && cd packages/storage && bun run watch' 'sleep 3 && cd packages/ai && bun run watch' 'sleep 6 && cd packages/ai-provider && bun run watch' 'sleep 10 && cd packages/test && bun run watch'", "docs": "typedoc", "format": "eslint \"{packages|examples}/*/src/**/*.{js,ts,tsx,json}\" --fix && prettier \"{packages|examples}/*/src/**/*.{js,ts,tsx,json}\" --check --write", "release": "bun run build && bun publish", "test": "jest" }, "dependencies": { + "@huggingface/transformers": "^3.3.1", "@mediapipe/tasks-text": "^0.10.20", - "@huggingface/transformers": "^3.2.4", "@sroussey/typescript-graph": "^0.3.14", "@types/better-sqlite3": "^7.6.12", "@types/pg": "^8.11.10", "better-sqlite3": "^9.4.3", - "bun-types": "^1.1.43", + "bun-types": "^1.1.45", "chalk": "^5.4.1", "commander": "=11.1.0", "eventemitter3": "^5.0.1", @@ -40,8 +41,9 @@ "nanoid": "^5.0.9", "pg": "^8.13.1", "postgres": "^3.4.5", + "reflect-metadata": "^0.2.2", "rxjs": "^7.8.1", - "storybook": "^8.4.7", + "storybook": "^8.5.0", "uuid": "^9.0.1" }, "devDependencies": { @@ -50,7 +52,7 @@ "typescript": "^5.7.3" }, "engines": { - "bun": "^1.1.43" + "bun": "^1.1.45" }, "trustedDependencies": [ "better-sqlite3", diff --git a/packages/ai-provider/src/ggml/model/GgmlLocalModel.ts b/packages/ai-provider/src/ggml/model/GgmlLocalModel.ts index e24729f..bf5fd30 100644 --- a/packages/ai-provider/src/ggml/model/GgmlLocalModel.ts +++ b/packages/ai-provider/src/ggml/model/GgmlLocalModel.ts @@ -5,16 +5,4 @@ // * Licensed under the Apache License, Version 2.0 (the "License"); * // ******************************************************************************* -import { - Model, - ModelOptions, - ModelProcessorEnum, - ModelUseCaseEnum, -} from "../../../../ai/src/model/Model"; - -export class GgmlLocalModel extends Model { - constructor(name: string, useCase: ModelUseCaseEnum[], options?: ModelOptions) { - super(name, useCase, options); - } - readonly type = ModelProcessorEnum.LOCAL_LLAMACPP; -} +export const LOCAL_LLAMACPP = "LOCAL_LLAMACPP"; diff --git a/packages/ai-provider/src/hf-transformers/bindings/all_inmemory.ts b/packages/ai-provider/src/hf-transformers/bindings/all_inmemory.ts deleted file mode 100644 index fec0069..0000000 --- a/packages/ai-provider/src/hf-transformers/bindings/all_inmemory.ts +++ /dev/null @@ -1,17 +0,0 @@ -import { ConcurrencyLimiter, TaskInput, TaskOutput } from "ellmers-core"; -import { getProviderRegistry, ModelProcessorEnum } from "ellmers-ai"; -import { InMemoryJobQueue } from "ellmers-storage/inmemory"; -import { registerHuggingfaceLocalTasks } from "./local_hf"; -import 
"../model/ONNXModelSamples"; - -export async function registerHuggingfaceLocalTasksInMemory() { - registerHuggingfaceLocalTasks(); - const ProviderRegistry = getProviderRegistry(); - const jobQueue = new InMemoryJobQueue( - "local_hf", - new ConcurrencyLimiter(1, 10), - 10 - ); - ProviderRegistry.registerQueue(ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS, jobQueue); - jobQueue.start(); -} diff --git a/packages/ai-provider/src/hf-transformers/bindings/all_sqlite.ts b/packages/ai-provider/src/hf-transformers/bindings/all_sqlite.ts deleted file mode 100644 index 3535d37..0000000 --- a/packages/ai-provider/src/hf-transformers/bindings/all_sqlite.ts +++ /dev/null @@ -1,36 +0,0 @@ -// import { registerHuggingfaceLocalTasks } from "./local_hf"; -// import { registerMediaPipeTfJsLocalTasks } from "./local_mp"; -// import { getProviderRegistry } from "../provider/ProviderRegistry"; -// import { ModelProcessorEnum } from "../model/Model"; -// import { ConcurrencyLimiter } from "../job/ConcurrencyLimiter"; -// import { SqliteJobQueue } from "../job/SqliteJobQueue"; -// import { getDatabase } from "../util/db_sqlite"; -// import { TaskInput, TaskOutput } from "../task/base/Task"; -// import { mkdirSync } from "node:fs"; - -// mkdirSync("./.cache", { recursive: true }); -// const db = getDatabase("./.cache/local.db"); - -// export async function registerHuggingfaceLocalTasksSqlite() { -// registerHuggingfaceLocalTasks(); -// const ProviderRegistry = getProviderRegistry(); -// const jobQueue = new SqliteJobQueue( -// db, -// "local_hf", -// new ConcurrencyLimiter(1, 10) -// ); -// ProviderRegistry.registerQueue(ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS, jobQueue); -// jobQueue.start(); -// } - -// export async function registerMediaPipeTfJsLocalSqlite() { -// registerMediaPipeTfJsLocalTasks(); -// const ProviderRegistry = getProviderRegistry(); -// const jobQueue = new SqliteJobQueue( -// db, -// "local_media_pipe", -// new ConcurrencyLimiter(1, 10) -// ); -// ProviderRegistry.registerQueue(ModelProcessorEnum.MEDIA_PIPE_TFJS_MODEL, jobQueue); -// jobQueue.start(); -// } diff --git a/packages/ai-provider/src/hf-transformers/bindings/local_hf.ts b/packages/ai-provider/src/hf-transformers/bindings/registerTasks.ts similarity index 78% rename from packages/ai-provider/src/hf-transformers/bindings/local_hf.ts rename to packages/ai-provider/src/hf-transformers/bindings/registerTasks.ts index 9054c7a..77571c5 100644 --- a/packages/ai-provider/src/hf-transformers/bindings/local_hf.ts +++ b/packages/ai-provider/src/hf-transformers/bindings/registerTasks.ts @@ -1,5 +1,4 @@ import { - ModelProcessorEnum, getProviderRegistry, DownloadModelTask, TextEmbeddingTask, @@ -18,49 +17,50 @@ import { HuggingFaceLocal_TextSummaryRun, HuggingFaceLocal_TextTranslationRun, } from "../provider/HuggingFaceLocal_TaskRun"; +import { LOCAL_ONNX_TRANSFORMERJS } from "../model/ONNXTransformerJsModel"; export async function registerHuggingfaceLocalTasks() { const ProviderRegistry = getProviderRegistry(); ProviderRegistry.registerRunFn( DownloadModelTask.type, - ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS, + LOCAL_ONNX_TRANSFORMERJS, HuggingFaceLocal_DownloadRun ); ProviderRegistry.registerRunFn( TextEmbeddingTask.type, - ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS, + LOCAL_ONNX_TRANSFORMERJS, HuggingFaceLocal_EmbeddingRun ); ProviderRegistry.registerRunFn( TextGenerationTask.type, - ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS, + LOCAL_ONNX_TRANSFORMERJS, HuggingFaceLocal_TextGenerationRun ); ProviderRegistry.registerRunFn( 
TextTranslationTask.type, - ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS, + LOCAL_ONNX_TRANSFORMERJS, HuggingFaceLocal_TextTranslationRun ); ProviderRegistry.registerRunFn( TextRewriterTask.type, - ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS, + LOCAL_ONNX_TRANSFORMERJS, HuggingFaceLocal_TextRewriterRun ); ProviderRegistry.registerRunFn( TextSummaryTask.type, - ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS, + LOCAL_ONNX_TRANSFORMERJS, HuggingFaceLocal_TextSummaryRun ); ProviderRegistry.registerRunFn( TextQuestionAnswerTask.type, - ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS, + LOCAL_ONNX_TRANSFORMERJS, HuggingFaceLocal_TextQuestionAnswerRun ); } diff --git a/packages/ai-provider/src/hf-transformers/browser.ts b/packages/ai-provider/src/hf-transformers/browser.ts index b9657be..caa4d0b 100644 --- a/packages/ai-provider/src/hf-transformers/browser.ts +++ b/packages/ai-provider/src/hf-transformers/browser.ts @@ -7,5 +7,4 @@ export * from "./provider/HuggingFaceLocal_TaskRun"; export * from "./model/ONNXTransformerJsModel"; -export * from "./bindings/local_hf"; -export * from "./bindings/all_inmemory"; +export * from "./bindings/registerTasks"; diff --git a/packages/ai-provider/src/hf-transformers/model/ONNXModelSamples.ts b/packages/ai-provider/src/hf-transformers/model/ONNXModelSamples.ts deleted file mode 100644 index f57ccf0..0000000 --- a/packages/ai-provider/src/hf-transformers/model/ONNXModelSamples.ts +++ /dev/null @@ -1,120 +0,0 @@ -import { DATA_TYPES, ONNXTransformerJsModel } from "./ONNXTransformerJsModel"; -import { ModelUseCaseEnum } from "ellmers-ai"; - -export const supabaseGteSmall = new ONNXTransformerJsModel( - "Supabase/gte-small", - [ModelUseCaseEnum.TEXT_EMBEDDING], - "feature-extraction", - { dimensions: 384 } -); - -export const baaiBgeBaseEnV15 = new ONNXTransformerJsModel( - "Xenova/bge-base-en-v1.5", - [ModelUseCaseEnum.TEXT_EMBEDDING], - "feature-extraction", - { dimensions: 768 } -); - -export const xenovaMiniL6v2 = new ONNXTransformerJsModel( - "Xenova/all-MiniLM-L6-v2", - [ModelUseCaseEnum.TEXT_EMBEDDING], - "feature-extraction", - { dimensions: 384 } -); - -export const whereIsAIUAELargeV1 = new ONNXTransformerJsModel( - "WhereIsAI/UAE-Large-V1", - [ModelUseCaseEnum.TEXT_EMBEDDING], - "feature-extraction", - { dimensions: 1024 } -); - -export const baaiBgeSmallEnV15 = new ONNXTransformerJsModel( - "Xenova/bge-small-en-v1.5", - [ModelUseCaseEnum.TEXT_EMBEDDING], - "feature-extraction", - { dimensions: 384 } -); - -export const xenovaDistilbert = new ONNXTransformerJsModel( - "Xenova/distilbert-base-uncased-distilled-squad", - [ModelUseCaseEnum.TEXT_QUESTION_ANSWERING], - "question-answering" -); - -export const xenovaDistilbertMnli = new ONNXTransformerJsModel( - "Xenova/distilbert-base-uncased-mnli", - [ModelUseCaseEnum.TEXT_CLASSIFICATION], - "zero-shot-classification" -); - -export const stentancetransformerMultiQaMpnetBaseDotV1 = new ONNXTransformerJsModel( - "Xenova/multi-qa-mpnet-base-dot-v1", - [ModelUseCaseEnum.TEXT_EMBEDDING], - "feature-extraction", - { dimensions: 768 } -); - -export const gpt2 = new ONNXTransformerJsModel( - "Xenova/gpt2", - [ModelUseCaseEnum.TEXT_GENERATION], - "text-generation" -); - -export const distillgpt2 = new ONNXTransformerJsModel( - "Xenova/distilgpt2", - [ModelUseCaseEnum.TEXT_GENERATION], - "text-generation" -); - -export const flanT5small = new ONNXTransformerJsModel( - "Xenova/flan-t5-small", - [ModelUseCaseEnum.TEXT_GENERATION], - "text2text-generation" -); - -export const flanT5p786m = new ONNXTransformerJsModel( - 
"Xenova/LaMini-Flan-T5-783M", - [ModelUseCaseEnum.TEXT_GENERATION, ModelUseCaseEnum.TEXT_REWRITING], - "text2text-generation" -); - -export const text_summarization = new ONNXTransformerJsModel( - "Falconsai/text_summarization", - [ModelUseCaseEnum.TEXT_SUMMARIZATION], - "summarization", - { dtype: DATA_TYPES.fp32 } -); - -// export const distilbartCnn = new ONNXTransformerJsModel( -// "Xenova/distilbart-cnn-6-6", -// [ModelUseCaseEnum.TEXT_SUMMARIZATION], -// "summarization" -// ); - -// export const bartLargeCnn = new ONNXTransformerJsModel( -// "Xenova/bart-large-cnn", -// [ModelUseCaseEnum.TEXT_SUMMARIZATION], -// "summarization" -// ); - -export const nllb200distilled600m = new ONNXTransformerJsModel( - "Xenova/nllb-200-distilled-600M", - [ModelUseCaseEnum.TEXT_TRANSLATION], - "translation", - { languageStyle: "FLORES-200" } -); - -export const m2m100_418M = new ONNXTransformerJsModel( - "Xenova/m2m100_418M", - [ModelUseCaseEnum.TEXT_TRANSLATION], - "translation", - { languageStyle: "ISO-639" } -); - -export const mbartLarge50many2manyMmt = new ONNXTransformerJsModel( - "Xenova/mbart-large-50-many-to-many-mmt", - [ModelUseCaseEnum.TEXT_TRANSLATION], - "translation", - { languageStyle: "ISO-639_ISO-3166-1-alpha-2" } -); diff --git a/packages/ai-provider/src/hf-transformers/model/ONNXTransformerJsModel.ts b/packages/ai-provider/src/hf-transformers/model/ONNXTransformerJsModel.ts index ec3e154..1d21430 100644 --- a/packages/ai-provider/src/hf-transformers/model/ONNXTransformerJsModel.ts +++ b/packages/ai-provider/src/hf-transformers/model/ONNXTransformerJsModel.ts @@ -5,9 +5,9 @@ // * Licensed under the Apache License, Version 2.0 (the "License"); * // ******************************************************************************* -import { Model, ModelProcessorEnum, ModelUseCaseEnum, type ModelOptions } from "ellmers-ai"; +export const LOCAL_ONNX_TRANSFORMERJS = "LOCAL_ONNX_TRANSFORMERJS"; -export enum DATA_TYPES { +export enum QUANTIZATION_DATA_TYPES { auto = "auto", // Auto-detect based on environment fp32 = "fp32", fp16 = "fp16", @@ -18,24 +18,3 @@ export enum DATA_TYPES { bnb4 = "bnb4", q4f16 = "q4f16", // fp16 model with int4 block weight quantization } - -export interface ONNXTransformerJsModelOptions extends ModelOptions { - dtype?: DATA_TYPES | { [key: string]: DATA_TYPES }; -} - -export class ONNXTransformerJsModel extends Model implements ONNXTransformerJsModelOptions { - constructor( - name: string, - useCase: ModelUseCaseEnum[], - public pipeline: string, - options?: Pick< - ONNXTransformerJsModelOptions, - "dimensions" | "parameters" | "languageStyle" | "dtype" - > - ) { - super(name, useCase, options); - this.dtype = options?.dtype ?? 
DATA_TYPES.q8; - } - readonly type = ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS; - dtype?: DATA_TYPES | { [key: string]: DATA_TYPES } | undefined; -} diff --git a/packages/ai-provider/src/hf-transformers/provider/HuggingFaceLocal_TaskRun.ts b/packages/ai-provider/src/hf-transformers/provider/HuggingFaceLocal_TaskRun.ts index c0bde04..b0c45ec 100644 --- a/packages/ai-provider/src/hf-transformers/provider/HuggingFaceLocal_TaskRun.ts +++ b/packages/ai-provider/src/hf-transformers/provider/HuggingFaceLocal_TaskRun.ts @@ -21,8 +21,8 @@ import { TextStreamer, } from "@huggingface/transformers"; import { ElVector } from "ellmers-core"; -import { ONNXTransformerJsModel } from "../model/ONNXTransformerJsModel"; -import { findModelByName } from "ellmers-ai"; + +import { getGlobalModelRepository } from "ellmers-ai"; import type { JobQueueLlmTask, DownloadModelTask, @@ -46,7 +46,9 @@ import type { TextTranslationTask, TextTranslationTaskInput, TextTranslationTaskOutput, + Model, } from "ellmers-ai"; +import { QUANTIZATION_DATA_TYPES } from "../browser"; env.cacheDir = "./.cache"; @@ -83,7 +85,7 @@ type StatusFile = StatusFileBookends | StatusFileProgress; type StatusRun = StatusRunReady | StatusRunUpdate | StatusRunComplete; export type CallbackStatus = StatusFile | StatusRun; -const pipelines = new Map(); +const pipelines = new Map(); /** * @@ -94,16 +96,12 @@ const pipelines = new Map(); * @param model * @param options */ -const getPipeline = async ( - task: JobQueueLlmTask, - model: ONNXTransformerJsModel, - options: any = {} -) => { +const getPipeline = async (task: JobQueueLlmTask, model: Model, options: any = {}) => { if (!pipelines.has(model)) { pipelines.set( model, - await pipeline(model.pipeline as PipelineType, model.name, { - dtype: model.dtype || "q8", + await pipeline(model.pipeline as PipelineType, model.url, { + dtype: (model.quantization as QUANTIZATION_DATA_TYPES) || "q8", session_options: options?.session_options, progress_callback: downloadProgressCallback(task), }) @@ -114,10 +112,10 @@ const getPipeline = async ( function downloadProgressCallback(task: JobQueueLlmTask) { return (status: CallbackStatus) => { - const progess = status.status === "progress" ? Math.round(status.progress) : 0; + const progress = status.status === "progress" ? Math.round(status.progress) : 0; if (status.status === "progress") { - task.progress = progess; - task.emit("progress", progess, status.file); + task.progress = progress; + task.emit("progress", progress, status.file); } }; } @@ -142,9 +140,9 @@ export async function HuggingFaceLocal_DownloadRun( task: DownloadModelTask, runInputData: DownloadModelTaskInput ): Promise> { - const model = findModelByName(runInputData.model)! 
as ONNXTransformerJsModel; + const model = (await getGlobalModelRepository().findByName(runInputData.model))!; await getPipeline(task, model); - return { model: model.name, dimensions: model.dimensions || 0, normalize: model.normalize }; + return { model: model.name, dimensions: model.nativeDimensions || 0, normalize: model.normalize }; } /** @@ -156,7 +154,7 @@ export async function HuggingFaceLocal_EmbeddingRun( task: TextEmbeddingTask, runInputData: TextEmbeddingTaskInput ): Promise { - const model = findModelByName(runInputData.model) as ONNXTransformerJsModel; + const model = (await getGlobalModelRepository().findByName(runInputData.model))!; const generateEmbedding: FeatureExtractionPipeline = await getPipeline(task, model); const hfVector = await generateEmbedding(runInputData.text, { @@ -164,15 +162,15 @@ export async function HuggingFaceLocal_EmbeddingRun( normalize: model.normalize, }); - if (hfVector.size !== model.dimensions) { + if (hfVector.size !== model.nativeDimensions) { console.warn( - `HuggingFaceLocal Embedding vector length does not match model dimensions v${hfVector.size} != m${model.dimensions}`, + `HuggingFaceLocal Embedding vector length does not match model dimensions v${hfVector.size} != m${model.nativeDimensions}`, runInputData, hfVector ); - throw `HuggingFaceLocal Embedding vector length does not match model dimensions v${hfVector.size} != m${model.dimensions}`; + throw `HuggingFaceLocal Embedding vector length does not match model dimensions v${hfVector.size} != m${model.nativeDimensions}`; } - const vector = new ElVector(hfVector.data, model.normalize); + const vector = new ElVector(hfVector.data, model.normalize ?? true); return { vector }; } @@ -185,7 +183,7 @@ export async function HuggingFaceLocal_TextGenerationRun( task: TextGenerationTask, runInputData: TextGenerationTaskInput ): Promise { - const model = findModelByName(runInputData.model) as ONNXTransformerJsModel; + const model = (await getGlobalModelRepository().findByName(runInputData.model))!; const generateText: TextGenerationPipeline = await getPipeline(task, model); @@ -221,7 +219,7 @@ export async function HuggingFaceLocal_TextTranslationRun( task: TextTranslationTask, runInputData: TextTranslationTaskInput ): Promise> { - const model = findModelByName(runInputData.model) as ONNXTransformerJsModel; + const model = (await getGlobalModelRepository().findByName(runInputData.model))!; const translate: TranslationPipeline = await getPipeline(task, model); @@ -253,7 +251,7 @@ export async function HuggingFaceLocal_TextRewriterRun( task: TextRewriterTask, runInputData: TextRewriterTaskInput ): Promise { - const model = findModelByName(runInputData.model) as ONNXTransformerJsModel; + const model = (await getGlobalModelRepository().findByName(runInputData.model))!; const generateText: TextGenerationPipeline = await getPipeline(task, model); const streamer = new TextStreamer(generateText.tokenizer, { @@ -292,7 +290,7 @@ export async function HuggingFaceLocal_TextSummaryRun( task: TextSummaryTask, runInputData: TextSummaryTaskInput ): Promise { - const model = findModelByName(runInputData.model) as ONNXTransformerJsModel; + const model = (await getGlobalModelRepository().findByName(runInputData.model))!; const generateSummary: SummarizationPipeline = await getPipeline(task, model); const streamer = new TextStreamer(generateSummary.tokenizer, { @@ -321,7 +319,7 @@ export async function HuggingFaceLocal_TextQuestionAnswerRun( task: TextQuestionAnswerTask, runInputData: TextQuestionAnswerTaskInput ): 
Promise { - const model = findModelByName(runInputData.model) as ONNXTransformerJsModel; + const model = (await getGlobalModelRepository().findByName(runInputData.model))!; const generateAnswer: QuestionAnsweringPipeline = await getPipeline(task, model); const streamer = new TextStreamer(generateAnswer.tokenizer, { diff --git a/packages/ai-provider/src/hf-transformers/test/HFTransformersBinding.test.ts b/packages/ai-provider/src/hf-transformers/test/HFTransformersBinding.test.ts new file mode 100644 index 0000000..23e242a --- /dev/null +++ b/packages/ai-provider/src/hf-transformers/test/HFTransformersBinding.test.ts @@ -0,0 +1,108 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { describe, expect, it } from "bun:test"; +import { ConcurrencyLimiter, TaskGraphBuilder, TaskInput, TaskOutput } from "ellmers-core"; +import { + getProviderRegistry, + getGlobalModelRepository, + setGlobalModelRepository, +} from "ellmers-ai"; +import { InMemoryJobQueue, InMemoryModelRepository } from "ellmers-storage/inmemory"; +import { getDatabase, SqliteJobQueue } from "ellmers-storage/bun/sqlite"; +import { registerHuggingfaceLocalTasks } from "../bindings/registerTasks"; +import { sleep } from "bun"; +import { LOCAL_ONNX_TRANSFORMERJS } from "../model/ONNXTransformerJsModel"; + +const HFQUEUE = "local_hf"; + +describe("HFTransformersBinding", () => { + describe("InMemoryJobQueue", () => { + it("Should have an item queued", async () => { + const providerRegistry = getProviderRegistry(); + const jobQueue = new InMemoryJobQueue(HFQUEUE, new ConcurrencyLimiter(1, 10), 10); + providerRegistry.registerQueue(LOCAL_ONNX_TRANSFORMERJS, jobQueue); + + registerHuggingfaceLocalTasks(); + setGlobalModelRepository(new InMemoryModelRepository()); + await getGlobalModelRepository().addModel({ + name: "ONNX Xenova/LaMini-Flan-T5-783M q8", + url: "Xenova/LaMini-Flan-T5-783M", + availableOnBrowser: true, + availableOnServer: true, + provider: LOCAL_ONNX_TRANSFORMERJS, + pipeline: "text2text-generation", + }); + await getGlobalModelRepository().connectTaskToModel( + "TextGenerationTask", + "ONNX Xenova/LaMini-Flan-T5-783M q8" + ); + await getGlobalModelRepository().connectTaskToModel( + "TextRewritingTask", + "ONNX Xenova/LaMini-Flan-T5-783M q8" + ); + + const queue = providerRegistry.getQueue(LOCAL_ONNX_TRANSFORMERJS); + expect(queue).toBeDefined(); + expect(queue?.queue).toEqual(HFQUEUE); + + const builder = new TaskGraphBuilder(); + builder.DownloadModel({ + model: "ONNX Xenova/LaMini-Flan-T5-783M q8", + }); + builder.run(); + await sleep(1); + expect(await queue?.size()).toEqual(1); + await queue?.clear(); + }); + }); + + describe("SqliteJobQueue", () => { + it("Should have an item queued", async () => { + registerHuggingfaceLocalTasks(); + setGlobalModelRepository(new InMemoryModelRepository()); + await getGlobalModelRepository().addModel({ + name: "ONNX Xenova/LaMini-Flan-T5-783M q8", + url: "Xenova/LaMini-Flan-T5-783M", + availableOnBrowser: true, + availableOnServer: true, + provider: LOCAL_ONNX_TRANSFORMERJS, + pipeline: "text2text-generation", + }); + await getGlobalModelRepository().connectTaskToModel( + "TextGenerationTask", + "ONNX Xenova/LaMini-Flan-T5-783M q8" + ); + await 
getGlobalModelRepository().connectTaskToModel( + "TextRewritingTask", + "ONNX Xenova/LaMini-Flan-T5-783M q8" + ); + const providerRegistry = getProviderRegistry(); + const jobQueue = new SqliteJobQueue( + getDatabase(":memory:"), + HFQUEUE, + new ConcurrencyLimiter(1, 10), + 10 + ); + jobQueue.ensureTableExists(); + providerRegistry.registerQueue(LOCAL_ONNX_TRANSFORMERJS, jobQueue); + const queue = providerRegistry.getQueue(LOCAL_ONNX_TRANSFORMERJS); + expect(queue).toBeDefined(); + expect(queue?.queue).toEqual(HFQUEUE); + + const builder = new TaskGraphBuilder(); + builder.DownloadModel({ + model: "ONNX Xenova/LaMini-Flan-T5-783M q8", + }); + builder.run(); + await sleep(1); + expect(await queue?.size()).toEqual(1); + builder.reset(); + await queue?.clear(); + }); + }); +}); diff --git a/packages/ai-provider/src/tf-mediapipe/bindings/all_inmemory.ts b/packages/ai-provider/src/tf-mediapipe/bindings/all_inmemory.ts deleted file mode 100644 index c1929fe..0000000 --- a/packages/ai-provider/src/tf-mediapipe/bindings/all_inmemory.ts +++ /dev/null @@ -1,19 +0,0 @@ -import { getProviderRegistry } from "ellmers-ai"; -import { InMemoryJobQueue } from "ellmers-storage/inmemory"; -import { ModelProcessorEnum } from "ellmers-ai"; -import { ConcurrencyLimiter } from "ellmers-core"; -import { TaskInput, TaskOutput } from "ellmers-core"; -import { registerMediaPipeTfJsLocalTasks } from "./local_mp"; -import "../model/MediaPipeModelSamples"; - -export async function registerMediaPipeTfJsLocalInMemory() { - registerMediaPipeTfJsLocalTasks(); - const ProviderRegistry = getProviderRegistry(); - const jobQueue = new InMemoryJobQueue( - "local_media_pipe", - new ConcurrencyLimiter(1, 10), - 10 - ); - ProviderRegistry.registerQueue(ModelProcessorEnum.MEDIA_PIPE_TFJS_MODEL, jobQueue); - jobQueue.start(); -} diff --git a/packages/ai-provider/src/tf-mediapipe/bindings/all_sqlite.ts b/packages/ai-provider/src/tf-mediapipe/bindings/all_sqlite.ts deleted file mode 100644 index 3535d37..0000000 --- a/packages/ai-provider/src/tf-mediapipe/bindings/all_sqlite.ts +++ /dev/null @@ -1,36 +0,0 @@ -// import { registerHuggingfaceLocalTasks } from "./local_hf"; -// import { registerMediaPipeTfJsLocalTasks } from "./local_mp"; -// import { getProviderRegistry } from "../provider/ProviderRegistry"; -// import { ModelProcessorEnum } from "../model/Model"; -// import { ConcurrencyLimiter } from "../job/ConcurrencyLimiter"; -// import { SqliteJobQueue } from "../job/SqliteJobQueue"; -// import { getDatabase } from "../util/db_sqlite"; -// import { TaskInput, TaskOutput } from "../task/base/Task"; -// import { mkdirSync } from "node:fs"; - -// mkdirSync("./.cache", { recursive: true }); -// const db = getDatabase("./.cache/local.db"); - -// export async function registerHuggingfaceLocalTasksSqlite() { -// registerHuggingfaceLocalTasks(); -// const ProviderRegistry = getProviderRegistry(); -// const jobQueue = new SqliteJobQueue( -// db, -// "local_hf", -// new ConcurrencyLimiter(1, 10) -// ); -// ProviderRegistry.registerQueue(ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS, jobQueue); -// jobQueue.start(); -// } - -// export async function registerMediaPipeTfJsLocalSqlite() { -// registerMediaPipeTfJsLocalTasks(); -// const ProviderRegistry = getProviderRegistry(); -// const jobQueue = new SqliteJobQueue( -// db, -// "local_media_pipe", -// new ConcurrencyLimiter(1, 10) -// ); -// ProviderRegistry.registerQueue(ModelProcessorEnum.MEDIA_PIPE_TFJS_MODEL, jobQueue); -// jobQueue.start(); -// } diff --git 
a/packages/ai-provider/src/tf-mediapipe/bindings/local_mp.ts b/packages/ai-provider/src/tf-mediapipe/bindings/registerTasks.ts similarity index 75% rename from packages/ai-provider/src/tf-mediapipe/bindings/local_mp.ts rename to packages/ai-provider/src/tf-mediapipe/bindings/registerTasks.ts index 3c2ff3e..06cdf81 100644 --- a/packages/ai-provider/src/tf-mediapipe/bindings/local_mp.ts +++ b/packages/ai-provider/src/tf-mediapipe/bindings/registerTasks.ts @@ -1,22 +1,23 @@ -import { ModelProcessorEnum, getProviderRegistry } from "ellmers-ai"; +import { getProviderRegistry } from "ellmers-ai"; import { DownloadModelTask, TextEmbeddingTask } from "ellmers-ai"; import { MediaPipeTfJsLocal_Download, MediaPipeTfJsLocal_Embedding, } from "../provider/MediaPipeLocalTaskRun"; +import { MEDIA_PIPE_TFJS_MODEL } from "../browser"; export const registerMediaPipeTfJsLocalTasks = () => { const ProviderRegistry = getProviderRegistry(); ProviderRegistry.registerRunFn( DownloadModelTask.type, - ModelProcessorEnum.MEDIA_PIPE_TFJS_MODEL, + MEDIA_PIPE_TFJS_MODEL, MediaPipeTfJsLocal_Download ); ProviderRegistry.registerRunFn( TextEmbeddingTask.type, - ModelProcessorEnum.MEDIA_PIPE_TFJS_MODEL, + MEDIA_PIPE_TFJS_MODEL, MediaPipeTfJsLocal_Embedding ); }; diff --git a/packages/ai-provider/src/tf-mediapipe/browser.ts b/packages/ai-provider/src/tf-mediapipe/browser.ts index ad31311..6fc38a1 100644 --- a/packages/ai-provider/src/tf-mediapipe/browser.ts +++ b/packages/ai-provider/src/tf-mediapipe/browser.ts @@ -7,5 +7,4 @@ export * from "./provider/MediaPipeLocalTaskRun"; export * from "./model/MediaPipeModel"; -export * from "./bindings/local_mp"; -export * from "./bindings/all_inmemory"; +export * from "./bindings/registerTasks"; diff --git a/packages/ai-provider/src/tf-mediapipe/model/MediaPipeModel.ts b/packages/ai-provider/src/tf-mediapipe/model/MediaPipeModel.ts index 5cb3c70..5280a7e 100644 --- a/packages/ai-provider/src/tf-mediapipe/model/MediaPipeModel.ts +++ b/packages/ai-provider/src/tf-mediapipe/model/MediaPipeModel.ts @@ -5,16 +5,4 @@ // * Licensed under the Apache License, Version 2.0 (the "License"); * // ******************************************************************************* -import { Model, ModelOptions, ModelProcessorEnum, ModelUseCaseEnum } from "ellmers-ai"; - -export class MediaPipeTfJsModel extends Model { - constructor( - name: string, - useCase: ModelUseCaseEnum[], - public url: string, - options?: Pick - ) { - super(name, useCase, options); - } - readonly type = ModelProcessorEnum.MEDIA_PIPE_TFJS_MODEL; -} +export const MEDIA_PIPE_TFJS_MODEL = "MEDIA_PIPE_TFJS_MODEL"; diff --git a/packages/ai-provider/src/tf-mediapipe/model/MediaPipeModelSamples.ts b/packages/ai-provider/src/tf-mediapipe/model/MediaPipeModelSamples.ts deleted file mode 100644 index 3e6e30a..0000000 --- a/packages/ai-provider/src/tf-mediapipe/model/MediaPipeModelSamples.ts +++ /dev/null @@ -1,16 +0,0 @@ -import { ModelUseCaseEnum } from "ellmers-ai"; -import { MediaPipeTfJsModel } from "./MediaPipeModel"; - -export const universal_sentence_encoder = new MediaPipeTfJsModel( - "Universal Sentence Encoder", - [ModelUseCaseEnum.TEXT_EMBEDDING], - "https://storage.googleapis.com/mediapipe-tasks/text_embedder/universal_sentence_encoder.tflite", - { dimensions: 100, browserOnly: true } -); - -export const kerasSdTextEncoder = new MediaPipeTfJsModel( - "keras-sd/text-encoder-tflite", - [ModelUseCaseEnum.TEXT_EMBEDDING], - "https://huggingface.co/keras-sd/text-encoder-tflite/resolve/main/text_encoder.tflite?download=true", - { 
dimensions: 100, browserOnly: true } -); diff --git a/packages/ai-provider/src/tf-mediapipe/provider/MediaPipeLocalTaskRun.ts b/packages/ai-provider/src/tf-mediapipe/provider/MediaPipeLocalTaskRun.ts index c23bbf9..b32a307 100644 --- a/packages/ai-provider/src/tf-mediapipe/provider/MediaPipeLocalTaskRun.ts +++ b/packages/ai-provider/src/tf-mediapipe/provider/MediaPipeLocalTaskRun.ts @@ -8,13 +8,12 @@ import { FilesetResolver, TextEmbedder } from "@mediapipe/tasks-text"; import { ElVector } from "ellmers-core"; import { - findModelByName, DownloadModelTask, DownloadModelTaskInput, TextEmbeddingTask, TextEmbeddingTaskInput, + getGlobalModelRepository, } from "ellmers-ai"; -import { MediaPipeTfJsModel } from "../model/MediaPipeModel"; /** * This is a task that downloads and caches a MediaPipe TFJS model. @@ -26,10 +25,13 @@ export async function MediaPipeTfJsLocal_Download( const textFiles = await FilesetResolver.forTextTasks( "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-text@latest/wasm" ); - const model = findModelByName(runInputData.model) as MediaPipeTfJsModel; + const model = (await getGlobalModelRepository().findByName(runInputData.model))!; + if (!model) { + throw `MediaPipeTfJsLocal_Download: Model ${runInputData.model} not found`; + } const results = await TextEmbedder.createFromOptions(textFiles, { baseOptions: { - modelAssetPath: model.url, + modelAssetPath: model.url!, }, quantize: true, }); @@ -48,10 +50,13 @@ export async function MediaPipeTfJsLocal_Embedding( const textFiles = await FilesetResolver.forTextTasks( "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-text@latest/wasm" ); - const model = findModelByName(runInputData.model) as MediaPipeTfJsModel; + const model = (await getGlobalModelRepository().findByName(runInputData.model))!; + if (!model) { + throw `MediaPipeTfJsLocal_Embedding: Model ${runInputData.model} not found`; + } const textEmbedder = await TextEmbedder.createFromOptions(textFiles, { baseOptions: { - modelAssetPath: model.url, + modelAssetPath: model.url!, }, quantize: true, }); @@ -59,8 +64,8 @@ export async function MediaPipeTfJsLocal_Embedding( const output = textEmbedder.embed(runInputData.text); const vector = output.embeddings[0].floatEmbedding; - if (vector?.length !== model.dimensions) { - throw `MediaPipeTfJsLocal Embedding vector length does not match model dimensions v${vector?.length} != m${model.dimensions}`; + if (vector?.length !== model.nativeDimensions) { + throw `MediaPipeTfJsLocal Embedding vector length does not match model dimensions v${vector?.length} != m${model.nativeDimensions}`; } return { vector: vector ? 
new ElVector(vector, true) : null }; } diff --git a/packages/ai-provider/src/tf-mediapipe/test/TfMediaPipeBinding.test.ts b/packages/ai-provider/src/tf-mediapipe/test/TfMediaPipeBinding.test.ts new file mode 100644 index 0000000..3d08631 --- /dev/null +++ b/packages/ai-provider/src/tf-mediapipe/test/TfMediaPipeBinding.test.ts @@ -0,0 +1,102 @@ +// // ******************************************************************************* +// // * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// // * * +// // * Copyright Steven Roussey * +// // * Licensed under the Apache License, Version 2.0 (the "License"); * +// // ******************************************************************************* + +// import { describe, expect, it } from "bun:test"; +// import { ConcurrencyLimiter, TaskGraphBuilder, TaskInput, TaskOutput } from "ellmers-core"; +// import { getGlobalModelRepository, getProviderRegistry, Model } from "ellmers-ai"; +// import { InMemoryJobQueue } from "ellmers-storage/inmemory"; +// import { SqliteJobQueue } from "../../../../storage/dist/bun/sqlite"; +// import { registerMediaPipeTfJsLocalTasks } from "../bindings/registerTasks"; +// import { sleep } from "ellmers-core"; +// import { MEDIA_PIPE_TFJS_MODEL } from "../model/MediaPipeModel"; +// import { getDatabase } from "../../../../storage/src/util/db_sqlite"; + +// const TFQUEUE = "local_tf-mediapipe"; + +// describe("TfMediaPipeBinding", () => { +// describe("InMemoryJobQueue", () => { +// it("should not fail", async () => { +// // register on creation +// const universal_sentence_encoder: Model = { +// name: "Universal Sentence Encoder", +// url: "https://storage.googleapis.com/mediapipe-tasks/text_embedder/universal_sentence_encoder.tflite", +// nativeDimensions: 100, +// availableOnBrowser: true, +// availableOnServer: false, +// provider: MEDIA_PIPE_TFJS_MODEL, +// }; +// getGlobalModelRepository().addModel(universal_sentence_encoder); +// getGlobalModelRepository().connectTaskToModel( +// "TextEmbeddingTask", +// universal_sentence_encoder.name +// ); +// registerMediaPipeTfJsLocalTasks(); +// const ProviderRegistry = getProviderRegistry(); +// const jobQueue = new InMemoryJobQueue( +// TFQUEUE, +// new ConcurrencyLimiter(1, 10), +// 10 +// ); +// ProviderRegistry.registerQueue(MEDIA_PIPE_TFJS_MODEL, jobQueue); +// const queue = ProviderRegistry.getQueue(MEDIA_PIPE_TFJS_MODEL); +// expect(queue).toBeDefined(); +// expect(queue?.queue).toEqual(TFQUEUE); + +// const builder = new TaskGraphBuilder(); +// builder.DownloadModel({ +// model: "Universal Sentence Encoder", +// }); +// builder.run(); +// await sleep(1); +// // we are not in a browser context, so the model should not be registered +// expect(await queue?.size()).toEqual(0); +// builder.reset(); +// await queue?.clear(); +// }); +// }); +// describe("SqliteJobQueue", () => { +// it("should not fail", async () => { +// const universal_sentence_encoder: Model = { +// name: "Universal Sentence Encoder", +// url: "https://storage.googleapis.com/mediapipe-tasks/text_embedder/universal_sentence_encoder.tflite", +// nativeDimensions: 100, +// availableOnBrowser: true, +// availableOnServer: false, +// provider: MEDIA_PIPE_TFJS_MODEL, +// }; +// getGlobalModelRepository().addModel(universal_sentence_encoder); +// getGlobalModelRepository().connectTaskToModel( +// "TextEmbeddingTask", +// universal_sentence_encoder.name +// ); +// registerMediaPipeTfJsLocalTasks(); +// const ProviderRegistry = getProviderRegistry(); +// const jobQueue = new 
SqliteJobQueue( +// getDatabase(":memory:"), +// TFQUEUE, +// new ConcurrencyLimiter(1, 10), +// 10 +// ); +// jobQueue.ensureTableExists(); +// ProviderRegistry.registerQueue(MEDIA_PIPE_TFJS_MODEL, jobQueue); +// const queue = ProviderRegistry.getQueue(MEDIA_PIPE_TFJS_MODEL); +// expect(queue).toBeDefined(); +// expect(queue?.queue).toEqual(TFQUEUE); + +// const builder = new TaskGraphBuilder(); +// builder.DownloadModel({ +// model: "Universal Sentence Encoder", +// }); +// builder.run(); +// await sleep(1); +// // we are not in a browser context, so the model should not be registered +// expect(await queue?.size()).toEqual(0); +// builder.reset(); +// await queue?.clear(); +// }); +// }); +// }); diff --git a/packages/ai/src/index.ts b/packages/ai/src/index.ts index 85dbe56..f9a3441 100644 --- a/packages/ai/src/index.ts +++ b/packages/ai/src/index.ts @@ -1,4 +1,5 @@ export * from "./task"; export * from "./model/Model"; -export * from "./model/InMemoryStorage"; +export * from "./model/ModelRegistry"; +export * from "./model/ModelRepository"; export * from "./provider/ProviderRegistry"; diff --git a/packages/ai/src/model/InMemoryStorage.ts b/packages/ai/src/model/InMemoryStorage.ts deleted file mode 100644 index d0f87ed..0000000 --- a/packages/ai/src/model/InMemoryStorage.ts +++ /dev/null @@ -1,21 +0,0 @@ -// ******************************************************************************* -// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * -// * * -// * Copyright Steven Roussey * -// * Licensed under the Apache License, Version 2.0 (the "License"); * -// ******************************************************************************* - -import { Model, ModelUseCaseEnum } from "./Model"; - -export function findModelByName(name: string) { - if (typeof name != "string") return undefined; - return Model.all.find((m) => m.name.toLowerCase() == name.toLowerCase()); -} - -export function findModelByUseCase(usecase: ModelUseCaseEnum) { - return Model.all.filter((m) => m.useCase.includes(usecase)); -} - -export function findAllModels() { - return Model.all.slice(); -} diff --git a/packages/ai/src/model/Model.ts b/packages/ai/src/model/Model.ts index 3baa9be..348d220 100644 --- a/packages/ai/src/model/Model.ts +++ b/packages/ai/src/model/Model.ts @@ -5,58 +5,27 @@ // * Licensed under the Apache License, Version 2.0 (the "License"); * // ******************************************************************************* -export enum ModelProcessorEnum { - LOCAL_ONNX_TRANSFORMERJS = "LOCAL_ONNX_TRANSFORMERJS", - MEDIA_PIPE_TFJS_MODEL = "MEDIA_PIPE_TFJS_MODEL", - LOCAL_MLC = "LOCAL_MLC", - LOCAL_LLAMACPP = "LOCAL_LLAMACPP", - ONLINE_HUGGINGFACE = "ONLINE_HUGGINGFACE", - ONLINE_OPENAI = "ONLINE_OPENAI", - ONLINE_REPLICATE = "ONLINE_REPLICATE", -} +export type ModelPrimaryKey = { + name: string; +}; -export enum ModelUseCaseEnum { - TEXT_EMBEDDING = "TEXT_EMBEDDING", - TEXT_REWRITING = "TEXT_REWRITING", - TEXT_GENERATION = "TEXT_GENERATION", - TEXT_SUMMARIZATION = "TEXT_SUMMARIZATION", - TEXT_QUESTION_ANSWERING = "TEXT_QUESTION_ANSWERING", - TEXT_CLASSIFICATION = "TEXT_CLASSIFICATION", - TEXT_TRANSLATION = "TEXT_TRANSLATION", -} +export const ModelPrimaryKeySchema = { + name: "string", +} as const; -const runningOnServer = typeof (globalThis as any).window === "undefined"; - -export interface ModelOptions { - nativeDimensions?: number; // Matryoshka Representation Learning (MRL) -- can truncate embedding dimensions from native number - dimensions?: number; +export type ModelDetail 
= { + url: string; + provider: string; + availableOnBrowser: boolean; + availableOnServer: boolean; + quantization?: string; + pipeline?: string; + normalize?: boolean; + nativeDimensions?: number; + usingDimensions?: number; contextWindow?: number; - extras?: Record; - browserOnly?: boolean; - parameters?: number; + numParameters?: number; languageStyle?: string; -} - -export abstract class Model implements ModelOptions { - public static readonly all: ModelList = []; - public dimensions?: number; - public nativeDimensions?: number; - public contextWindow?: number; - public normalize: boolean = true; - public browserOnly: boolean = false; - public extras: Record = {}; - public parameters?: number; - constructor( - public name: string, - public useCase: ModelUseCaseEnum[] = [], - options?: ModelOptions - ) { - Object.assign(this, options); - if (!(runningOnServer && this.browserOnly)) { - Model.all.push(this); - } - } - abstract readonly type: ModelProcessorEnum; -} +}; -export type ModelList = Model[]; +export type Model = ModelPrimaryKey & ModelDetail; diff --git a/packages/ai/src/model/ModelRegistry.ts b/packages/ai/src/model/ModelRegistry.ts new file mode 100644 index 0000000..1200da2 --- /dev/null +++ b/packages/ai/src/model/ModelRegistry.ts @@ -0,0 +1,48 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { Model } from "./Model"; +import { ModelRepository, Task2ModelPrimaryKey } from "./ModelRepository"; + +// temporary model registry that is synchronous until we have a proper model repository + +class FallbackModelRegistry { + models: Model[] = []; + task2models: Task2ModelPrimaryKey[] = []; + + public async addModel(model: Model) { + if (this.models.some((m) => m.name === model.name)) { + this.models = this.models.filter((m) => m.name !== model.name); + } + + this.models.push(model); + } + public async findModelsByTask(task: string) { + return this.task2models + .filter((t2m) => t2m.task === task) + .map((t2m) => this.models.find((m) => m.name === t2m.model)) + .filter((m) => m !== undefined); + } + public async findTasksByModel(name: string) { + return this.task2models.filter((t2m) => t2m.model === name).map((t2m) => t2m.task); + } + public async findByName(name: string) { + return this.models.find((m) => m.name === name); + } + public async connectTaskToModel(task: string, model: string) { + this.task2models.push({ task, model }); + } +} + +let modelRegistry: FallbackModelRegistry | ModelRepository; +export function getGlobalModelRepository() { + if (!modelRegistry) modelRegistry = new FallbackModelRegistry(); + return modelRegistry; +} +export function setGlobalModelRepository(pr: FallbackModelRegistry | ModelRepository) { + modelRegistry = pr; +} diff --git a/packages/ai/src/model/ModelRepository.ts b/packages/ai/src/model/ModelRepository.ts new file mode 100644 index 0000000..531238f --- /dev/null +++ b/packages/ai/src/model/ModelRepository.ts @@ -0,0 +1,176 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// 
******************************************************************************* + +import EventEmitter from "eventemitter3"; +import { DefaultValueType, KVRepository } from "ellmers-core"; +import { Model, ModelPrimaryKey, ModelPrimaryKeySchema } from "./Model"; + +/** + * Events that can be emitted by the ModelRepository + * @typedef {string} ModelEvents + */ +export type ModelEvents = + | "model_added" + | "model_removed" + | "task_model_connected" + | "task_model_disconnected" + | "model_updated"; + +/** + * Represents the primary key structure for mapping tasks to models + * @interface Task2ModelPrimaryKey + */ +export type Task2ModelPrimaryKey = { + /** The task identifier */ + task: string; + /** The model name identifier */ + model: string; +}; + +export const Task2ModelPrimaryKeySchema = { + task: "string", + model: "string", +} as const; + +/** + * Schema definition for Task2ModelDetail + */ +export type Task2ModelDetail = { + /** Optional details about the task-model relationship */ + details: string | null; +}; + +export const Task2ModelDetailSchema = { + details: "string", +} as const; + +/** + * Abstract base class for managing AI models and their relationships with tasks. + * Provides functionality for storing, retrieving, and managing the lifecycle of models + * and their associations with specific tasks. + */ +export abstract class ModelRepository { + /** Repository type identifier */ + public type = "ModelRepository"; + + /** + * Repository for storing and managing Model instances + */ + abstract modelKvRepository: KVRepository< + ModelPrimaryKey, + DefaultValueType, + typeof ModelPrimaryKeySchema + >; + + /** + * Repository for managing relationships between tasks and models + */ + abstract task2ModelKvRepository: KVRepository< + Task2ModelPrimaryKey, + Task2ModelDetail, + typeof Task2ModelPrimaryKeySchema, + typeof Task2ModelDetailSchema + >; + + /** Event emitter for repository events */ + private events = new EventEmitter(); + + /** + * Registers an event listener for the specified event + * @param name - The event name to listen for + * @param fn - The callback function to execute when the event occurs + */ + on(name: ModelEvents, fn: (...args: any[]) => void) { + this.events.on.call(this.events, name, fn); + } + + /** + * Removes an event listener for the specified event + * @param name - The event name to stop listening for + * @param fn - The callback function to remove + */ + off(name: ModelEvents, fn: (...args: any[]) => void) { + this.events.off.call(this.events, name, fn); + } + + /** + * Emits an event with the specified name and arguments + * @param name - The event name to emit + * @param args - Arguments to pass to the event listeners + */ + emit(name: ModelEvents, ...args: any[]) { + this.events.emit.call(this.events, name, ...args); + } + + /** + * Adds a new model to the repository + * @param model - The model instance to add + */ + async addModel(model: Model) { + await this.modelKvRepository.put({ name: model.name }, { "kv-value": JSON.stringify(model) }); + this.emit("model_added", model); + } + + /** + * Finds all models associated with a specific task + * @param task - The task identifier to search for + * @returns Promise resolving to an array of associated models, or undefined if none found + */ + async findModelsByTask(task: string) { + if (typeof task != "string") return undefined; + const junctions = await this.task2ModelKvRepository.search({ task }); + if (!junctions || junctions.length === 0) return undefined; + const models = []; + for 
(const junction of junctions) { + const model = await this.modelKvRepository.getKeyValue({ name: junction.model }); + if (model) models.push(JSON.parse(model["kv-value"])); + } + return models; + } + + /** + * Finds all tasks associated with a specific model + * @param model - The model identifier to search for + * @returns Promise resolving to an array of associated tasks, or undefined if none found + */ + async findTasksByModel(model: string) { + if (typeof model != "string") return undefined; + const junctions = await this.task2ModelKvRepository.search({ model }); + if (!junctions || junctions.length === 0) return undefined; + return junctions.map((junction) => junction.task); + } + + /** + * Creates an association between a task and a model + * @param task - The task identifier + * @param model - The model to associate with the task + */ + async connectTaskToModel(task: string, model: string) { + await this.task2ModelKvRepository.putKeyValue({ task, model }, { details: null }); + this.emit("task_model_connected", task, model); + } + + /** + * Retrieves a model by its name + * @param name - The name of the model to find + * @returns Promise resolving to the found model or undefined if not found + */ + async findByName(name: string) { + if (typeof name != "string") return undefined; + const modelstr = await this.modelKvRepository.getKeyValue({ name }); + if (!modelstr) return undefined; + return JSON.parse(modelstr["kv-value"]); + } + + /** + * Gets the total number of models in the repository + * @returns Promise resolving to the number of stored models + */ + async size(): Promise { + return await this.modelKvRepository.size(); + } +} diff --git a/packages/ai/src/provider/ProviderRegistry.ts b/packages/ai/src/provider/ProviderRegistry.ts index 3347c84..feb7586 100644 --- a/packages/ai/src/provider/ProviderRegistry.ts +++ b/packages/ai/src/provider/ProviderRegistry.ts @@ -5,7 +5,6 @@ // * Licensed under the Apache License, Version 2.0 (the "License"); * // ******************************************************************************* -import type { ModelProcessorEnum } from "../model/Model"; import { Job, type JobQueue, @@ -15,11 +14,18 @@ import { JobQueueTask, } from "ellmers-core"; +/** + * Enum to define the types of job queue execution + */ export enum JobQueueRunType { local = "local", api = "api", } +/** + * Extends the base Job class to provide custom execution functionality + * through a provided function. + */ class ProviderJob extends Job { constructor( details: JobConstructorDetails & { @@ -36,19 +42,36 @@ class ProviderJob extends Job { } } +/** + * Registry that manages provider-specific task execution functions and job queues. + * Handles the registration, retrieval, and execution of task processing functions + * for different model providers and task types. 
+ */ export class ProviderRegistry { + // Registry of task execution functions organized by task type and model provider runFnRegistry: Record Promise>> = {}; + + /** + * Registers a task execution function for a specific task type and model provider + * @param taskType - The type of task (e.g., 'text-generation', 'embedding') + * @param modelProvider - The provider of the model (e.g., 'hf-transformers', 'tf-mediapipe', 'openai', etc) + * @param runFn - The function that executes the task + */ registerRunFn( taskType: string, - modelType: ModelProcessorEnum, + modelProvider: string, runFn: (task: any, runInputData: any) => Promise ) { if (!this.runFnRegistry[taskType]) this.runFnRegistry[taskType] = {}; - this.runFnRegistry[taskType][modelType] = runFn; + this.runFnRegistry[taskType][modelProvider] = runFn; } - jobAsRunFn(runtype: string, modelType: ModelProcessorEnum) { + /** + * Creates a job wrapper around a task execution function + * This allows the task to be queued and executed asynchronously + */ + jobAsRunFn(runtype: string, modelType: string) { const fn = this.runFnRegistry[runtype]?.[modelType]; return async (task: JobQueueTask, input: Input) => { const queue = this.queues.get(modelType)!; @@ -69,16 +92,26 @@ export class ProviderRegistry { }; } - getDirectRunFn(taskType: string, modelType: ModelProcessorEnum) { + /** + * Retrieves the direct execution function for a task type and model + * Bypasses the job queue system for immediate execution + */ + getDirectRunFn(taskType: string, modelType: string) { return this.runFnRegistry[taskType]?.[modelType]; } - queues: Map> = new Map(); - registerQueue(modelType: ModelProcessorEnum, jobQueue: JobQueue) { + // Map of model types to their corresponding job queues + queues: Map> = new Map(); + + /** + * Queue management methods for starting, stopping, and clearing job queues + * These methods help control the execution flow of tasks across all providers + */ + registerQueue(modelType: string, jobQueue: JobQueue) { this.queues.set(modelType, jobQueue); } - getQueue(modelType: ModelProcessorEnum) { + getQueue(modelType: string) { return this.queues.get(modelType); } @@ -101,6 +134,7 @@ export class ProviderRegistry { } } +// Singleton instance management for the ProviderRegistry let providerRegistry: ProviderRegistry; export function getProviderRegistry() { if (!providerRegistry) providerRegistry = new ProviderRegistry(); diff --git a/packages/ai/src/task/DocumentSplitterTask.ts b/packages/ai/src/task/DocumentSplitterTask.ts index 91378a4..cd51b66 100644 --- a/packages/ai/src/task/DocumentSplitterTask.ts +++ b/packages/ai/src/task/DocumentSplitterTask.ts @@ -63,7 +63,7 @@ export class DocumentSplitterTask extends SingleTask { } } - runSyncOnly(): DocumentSplitterTaskOutput { + async runReactive(): Promise { return { texts: this.flattenFragmentsToTexts(this.runInputData.file) }; } } diff --git a/packages/ai/src/task/DownloadModelTask.ts b/packages/ai/src/task/DownloadModelTask.ts index a21be84..5a716e2 100644 --- a/packages/ai/src/task/DownloadModelTask.ts +++ b/packages/ai/src/task/DownloadModelTask.ts @@ -16,8 +16,7 @@ import { TaskOutput, JobQueueTaskConfig, } from "ellmers-core"; -import { ModelUseCaseEnum } from "../model/Model"; -import { findModelByName } from "../model/InMemoryStorage"; +import { getGlobalModelRepository } from "../model/ModelRegistry"; import { JobQueueLlmTask } from "./base/JobQueueLlmTask"; export type DownloadModelTaskInput = CreateMappedType; @@ -80,29 +79,29 @@ export class DownloadModelTask extends 
JobQueueLlmTask { constructor(config: JobQueueTaskConfig & { input?: DownloadModelTaskInput } = {}) { super(config); } - runSyncOnly(): TaskOutput { - const model = findModelByName(this.runInputData.model); + async runReactive(): Promise { + const model = await getGlobalModelRepository().findByName(this.runInputData.model); if (model) { - model.useCase.forEach((useCase) => { - // @ts-expect-error -- we really can use this an index - this.runOutputData[String(useCase).toLowerCase()] = model.name; + const tasks = (await getGlobalModelRepository().findTasksByModel(model.name)) || []; + tasks.forEach((task) => { + // this.runOutputData[String(task).toLowerCase()] = model.name; }); this.runOutputData.model = model.name; - this.runOutputData.dimensions = model.dimensions!; - this.runOutputData.normalize = model.normalize; - if (model.useCase.includes(ModelUseCaseEnum.TEXT_EMBEDDING)) { + this.runOutputData.dimensions = model.usingDimensions!; + this.runOutputData.normalize = model.normalize!; + if (tasks.includes("TextEmbeddingTask")) { this.runOutputData.text_embedding_model = model.name; } - if (model.useCase.includes(ModelUseCaseEnum.TEXT_GENERATION)) { + if (tasks.includes("TextGenerationTask")) { this.runOutputData.text_generation_model = model.name; } - if (model.useCase.includes(ModelUseCaseEnum.TEXT_SUMMARIZATION)) { + if (tasks.includes("TextSummaryTask")) { this.runOutputData.text_summarization_model = model.name; } - if (model.useCase.includes(ModelUseCaseEnum.TEXT_QUESTION_ANSWERING)) { + if (tasks.includes("TextQuestionAnswerTask")) { this.runOutputData.text_question_answering_model = model.name; } - if (model.useCase.includes(ModelUseCaseEnum.TEXT_TRANSLATION)) { + if (tasks.includes("TextTranslationTask")) { this.runOutputData.text_translation_model = model.name; } } diff --git a/packages/ai/src/task/SimilarityTask.ts b/packages/ai/src/task/SimilarityTask.ts index 747556f..23bcb13 100644 --- a/packages/ai/src/task/SimilarityTask.ts +++ b/packages/ai/src/task/SimilarityTask.ts @@ -115,7 +115,7 @@ export class SimilarityTask extends SingleTask { } } - runSyncOnly() { + async runReactive() { const query = this.runInputData.query as ElVector; let similarities = []; const fns = { cosine_similarity }; diff --git a/packages/ai/src/task/TextEmbeddingTask.ts b/packages/ai/src/task/TextEmbeddingTask.ts index c89a96d..c71fb5d 100644 --- a/packages/ai/src/task/TextEmbeddingTask.ts +++ b/packages/ai/src/task/TextEmbeddingTask.ts @@ -22,7 +22,11 @@ export type TextEmbeddingTaskInput = CreateMappedType; /** - * This is a task that generates an embedding for a single piece of text + * A task that generates vector embeddings for text using a specified embedding model. + * Embeddings are numerical representations of text that capture semantic meaning, + * useful for similarity comparisons and semantic search. 
+ * + * @extends JobQueueLlmTask */ export class TextEmbeddingTask extends JobQueueLlmTask { public static inputs = [ @@ -41,6 +45,7 @@ export class TextEmbeddingTask extends JobQueueLlmTask { constructor(config: JobQueueTaskConfig & { input?: TextEmbeddingTaskInput } = {}) { super(config); } + declare runInputData: TextEmbeddingTaskInput; declare runOutputData: TextEmbeddingTaskOutput; declare defaults: Partial; @@ -52,11 +57,20 @@ TaskRegistry.registerTask(TextEmbeddingTask); type TextEmbeddingCompoundTaskOutput = ConvertAllToArrays; type TextEmbeddingCompoundTaskInput = ConvertSomeToOptionalArray; +/** + * A compound task factory that creates a task capable of processing multiple texts + * and generating embeddings in parallel + */ export const TextEmbeddingCompoundTask = arrayTaskFactory< TextEmbeddingCompoundTaskInput, TextEmbeddingCompoundTaskOutput >(TextEmbeddingTask, ["model", "text"]); +/** + * Convenience function to create and run a TextEmbeddingCompoundTask + * @param {TextEmbeddingCompoundTaskInput} input - Input containing text(s) and model(s) for embedding + * @returns {Promise} Promise resolving to the generated embeddings + */ export const TextEmbedding = (input: TextEmbeddingCompoundTaskInput) => { return new TextEmbeddingCompoundTask({ input }).run(); }; diff --git a/packages/ai/src/task/TextGenerationTask.ts b/packages/ai/src/task/TextGenerationTask.ts index aac4f47..6e801dd 100644 --- a/packages/ai/src/task/TextGenerationTask.ts +++ b/packages/ai/src/task/TextGenerationTask.ts @@ -55,11 +55,22 @@ type TextGenerationCompoundTaskInput = ConvertSomeToOptionalArray< TextGenerationTaskInput, "model" | "prompt" >; + +/** + * Factory-generated task class for handling batch text generation operations. + * Created using arrayTaskFactory to support processing multiple prompts/models simultaneously. + */ export const TextGenerationCompoundTask = arrayTaskFactory< TextGenerationCompoundTaskInput, TextGenerationCompoundOutput >(TextGenerationTask, ["model", "prompt"]); +/** + * Convenience function to run text generation tasks. + * Creates and executes a TextGenerationCompoundTask with the provided input. 
+ * @param input The input parameters for text generation (prompts and models) + * @returns Promise resolving to the generated text output(s) + */ export const TextGeneration = (input: TextGenerationCompoundTaskInput) => { return new TextGenerationCompoundTask({ input }).run(); }; diff --git a/packages/ai/src/task/base/JobQueueLlmTask.ts b/packages/ai/src/task/base/JobQueueLlmTask.ts index 3e52088..f2e4070 100644 --- a/packages/ai/src/task/base/JobQueueLlmTask.ts +++ b/packages/ai/src/task/base/JobQueueLlmTask.ts @@ -10,8 +10,8 @@ */ import { JobQueueTask, JobQueueTaskConfig, type TaskOutput } from "ellmers-core"; -import { findModelByName } from "../../model/InMemoryStorage"; import { getProviderRegistry } from "../../provider/ProviderRegistry"; +import { getGlobalModelRepository } from "../../model/ModelRegistry"; export class JobQueueLlmTask extends JobQueueTask { static readonly type: string = "JobQueueLlmTask"; @@ -35,9 +35,12 @@ export class JobQueueLlmTask extends JobQueueTask { const ProviderRegistry = getProviderRegistry(); const modelname = this.runInputData["model"]; if (!modelname) throw new Error("JobQueueTaskTask: No model name found"); - const model = findModelByName(modelname); - if (!model) throw new Error("JobQueueTaskTask: No model found"); - const runFn = ProviderRegistry.jobAsRunFn(runtype, model.type); + const model = await getGlobalModelRepository().findByName(modelname); + + if (!model) { + throw new Error(`JobQueueTaskTask: No model ${modelname} found`); + } + const runFn = ProviderRegistry.jobAsRunFn(runtype, model.provider); if (!runFn) throw new Error("JobQueueTaskTask: No run function found for " + runtype); results = await runFn(this, this.runInputData); } catch (err) { @@ -47,10 +50,10 @@ } this.emit("complete"); this.runOutputData = results ?? {}; - this.runOutputData = this.runSyncOnly(); + this.runOutputData = await this.runReactive(); return this.runOutputData; } - runSyncOnly(): TaskOutput { + async runReactive(): Promise { return this.runOutputData ?? {}; } } diff --git a/packages/core/src/job/base/JobQueue.ts b/packages/core/src/job/base/JobQueue.ts index 930b78c..62949b2 100644 --- a/packages/core/src/job/base/JobQueue.ts +++ b/packages/core/src/job/base/JobQueue.ts @@ -124,11 +124,13 @@ export abstract class JobQueue { this.running = true; this.events.emit("queue_start", this.queue); this.processJobs(); + return this; } async stop() { this.running = false; this.events.emit("queue_stop", this.queue); + return this; } async restart() { @@ -138,5 +140,6 @@ this.waits.forEach(({ reject }) => reject("Queue Restarted")); this.waits.clear(); await this.start(); + return this; } } diff --git a/packages/core/src/source/MasterDocument.ts b/packages/core/src/source/MasterDocument.ts index 00bc2c0..ef98ab5 100644 --- a/packages/core/src/source/MasterDocument.ts +++ b/packages/core/src/source/MasterDocument.ts @@ -8,6 +8,19 @@ import { Document, DocumentMetadata, TextFragment } from "./Document"; import { DocumentConverter } from "./DocumentConverter"; +/** + * MasterDocument represents a container for managing multiple versions/variants of a document. + * It maintains the original document and its transformed variants for different use cases.
+ * + * Key features: + * - Stores original document and metadata + * - Maintains a master version and variants + * - Automatically creates paragraph-split variant + * + * The paragraph variant splits text fragments by newlines while preserving other fragment types, + * which is useful for more granular text processing. + */ + export class MasterDocument { public metadata: DocumentMetadata; public original: DocumentConverter; diff --git a/packages/core/src/storage/base/KVRepository.ts b/packages/core/src/storage/base/KVRepository.ts index 3bc1ea6..b59fef1 100644 --- a/packages/core/src/storage/base/KVRepository.ts +++ b/packages/core/src/storage/base/KVRepository.ts @@ -6,15 +6,48 @@ // ******************************************************************************* import EventEmitter from "eventemitter3"; +import { makeFingerprint } from "../../util/Misc"; -export type KVEvents = "put" | "get" | "clear"; +/** + * Type definitions for key-value repository events + */ +export type KVEvents = "put" | "get" | "search" | "delete" | "clearall"; -export type DiscriminatorSchema = Record; +/** + * Schema definitions for primary keys and values + */ +export type BasicKeyType = string | number | bigint; +export type BasicValueType = string | number | bigint | boolean | null; +export type BasePrimaryKeySchema = Record; +export type BaseValueSchema = Record; +/** + * Default schema types for simple string key-value pairs + */ +export type DefaultPrimaryKeyType = { "kv-key": string }; +export const DefaultPrimaryKeySchema: BasePrimaryKeySchema = { "kv-key": "string" } as const; + +export type DefaultValueType = { "kv-value": string }; +export const DefaultValueSchema: BaseValueSchema = { "kv-value": "string" } as const; + +/** + * Abstract base class for key-value storage repositories. + * Provides a flexible interface for storing and retrieving data with typed + * keys and values, and supports compound keys and partial key lookup. + * Has a basic event emitter for listening to repository events. + * + * @typeParam Key - Type for the primary key structure + * @typeParam Value - Type for the value structure + * @typeParam PrimaryKeySchema - Schema definition for the primary key + * @typeParam ValueSchema - Schema definition for the value + * @typeParam Combined - Combined type of Key & Value + */ export abstract class KVRepository< - Key, - Value, - Discriminators extends DiscriminatorSchema = DiscriminatorSchema, + Key extends Record = DefaultPrimaryKeyType, + Value extends Record = DefaultValueType, + PrimaryKeySchema extends BasePrimaryKeySchema = typeof DefaultPrimaryKeySchema, + ValueSchema extends BaseValueSchema = typeof DefaultValueSchema, + Combined extends Record = Key & Value > { // KV repository event emitter private events = new EventEmitter(); @@ -36,46 +69,180 @@ export abstract class KVRepository< ) { this.events.emit.call(this.events, name, ...args); } + /** + * Indexes for primary key and value columns which are _only_ populated if the + * key or value schema has a single field.
+ */ + protected primaryKeyIndex: string | undefined = undefined; + protected valueIndex: string | undefined = undefined; + /** + * Creates a new KVRepository instance + * @param primaryKeySchema - Schema defining the structure of primary keys + * @param valueSchema - Schema defining the structure of values + * @param searchable - Array of columns to make searchable + */ + constructor( + protected primaryKeySchema: PrimaryKeySchema = DefaultPrimaryKeySchema as PrimaryKeySchema, + protected valueSchema: ValueSchema = DefaultValueSchema as ValueSchema, + protected searchable: Array = [] + ) { + this.primaryKeySchema = primaryKeySchema; + this.valueSchema = valueSchema; + if (this.primaryKeyColumns().length === 1) { + this.primaryKeyIndex = this.primaryKeyColumns()[0] as string; + } + if (this.valueColumns().length === 1) { + this.valueIndex = this.valueColumns()[0] as string; + } + const firstKeyPart = this.primaryKeyColumns()[0] as keyof Combined; + if (!searchable.includes(firstKeyPart)) { + searchable.push(firstKeyPart); + } + this.searchable = searchable; - // discriminators for KV repository store - protected discriminatorsSchema: Discriminators = {} as Discriminators; + // make sure all the searchable columns are in the primary key schema or value schema + for (const column of this.searchable) { + if (!(column in this.primaryKeySchema) && !(column in this.valueSchema)) { + throw new Error( + `Searchable column ${column as string} is not in the primary key schema or value schema` + ); + } + } + } - // Abstract methods for KV repository store - abstract put(key: Key, value: Value): Promise; - abstract get(key: Key): Promise; - abstract clear(): Promise; + /** + * Core abstract methods that must be implemented by concrete repositories + */ + abstract putKeyValue(key: Key, value: Value): Promise; + abstract getKeyValue(key: Key): Promise; + abstract deleteKeyValue(key: Key | Combined): Promise; + abstract deleteAll(): Promise; abstract size(): Promise; - // Discriminator helper methods - protected primaryKeyColumnList(): string { - return this.primaryKeyColumns().join(", "); + /** + * Stores a key-value pair in the repository. + * Automatically converts simple types to structured format if using default schema. + * + * @param key - Primary key (can be simple type if using a single property key like default schema) + * @param value - Value to store (can be simple type if using a single property value like default schema) + */ + public put(key: BasicKeyType | Key, value: Value | BasicValueType): Promise { + if (typeof key !== "object" && this.primaryKeyIndex) { + key = { [this.primaryKeyIndex]: key } as Key; + if (typeof value !== "object" && this.valueIndex) { + value = { [this.valueIndex]: value } as Value; + } + } + return this.putKeyValue(key as Key, value as Value); + } + + /** + * Retrieves a value by its key. + * For default schema, returns the simple value type directly. 
+ * + * @param key - Primary key to look up (can be simple type if using a single property key like default schema) + * @returns The stored value or undefined if not found + */ + public async get(key: BasicKeyType | Key): Promise { + /* if the key is not an object, and there is a primary key index, then we need to convert the key to an object + * this allows us to do simple "key" / "value" situations without having to use objects like a compound key + * would require */ + const isKeySimple = !!(typeof key !== "object" && this.primaryKeyIndex); + if (isKeySimple) { + key = { [this.primaryKeyIndex!]: key } as Key; + } + const value = await this.getKeyValue(key as Key); + if (typeof value !== "object") return value; + if (isKeySimple && this.valueIndex) { + /* if it looks like we are doing a simple "key" / "value" situation, then we need to return + the value as a simple type as well. */ + return value[this.valueIndex] as BasicValueType; + } + return value as Value; + } + + /** + * Abstract method to be implemented by concrete repositories to search for key-value pairs + * based on a partial key or value. + * + * @param key - Partial key or value to search for + * @returns Promise resolving to an array of combined key-value objects or undefined if not found + */ + public abstract search(key: Partial): Promise; + + /** + * Retrieves both key and value as a combined object. + * + * @param key - Primary key to look up (can be simple type if using a single property key like default schema) + * @returns Combined key-value object or undefined if not found + */ + public async getCombined(key: Key): Promise { + const value = await this.getKeyValue(key); + if (typeof value !== "object") return undefined; + return Object.assign({}, key, value) as Combined; } - protected primaryKeyColumns(): string[] { - return Object.keys(this.discriminatorsSchema).concat("key"); + /** + * Deletes a key-value pair from the repository. + * Automatically converts simple types to structured format if using default schema. + * + * @param key - Primary key to delete (can be simple type if using a single property key like default schema) + */ + public delete(key: Key | BasicKeyType): Promise { + if (typeof key !== "object" && this.primaryKeyIndex) { + key = { [this.primaryKeyIndex]: key } as Key; + } + return this.deleteKeyValue(key as Key); } - protected extractDiscriminators(keySimpleOrObject: any): { - discriminators: Record; - key: any; - } { - const discriminatorKeys = Object.keys(this.discriminatorsSchema); - const discriminators: DiscriminatorSchema = {}; - if (typeof keySimpleOrObject !== "object") { - return { discriminators, key: keySimpleOrObject }; + protected primaryKeyColumns(): Array { + return Object.keys(this.primaryKeySchema); + } + + protected valueColumns(): Array { + return Object.keys(this.valueSchema); + } + + /** + * Utility method to separate a combined object into its key and value components + * based on the defined schemas. 
+ * + * @param obj - Combined key-value object + * @returns Separated key and value objects + */ + protected separateKeyValueFromCombined(obj: Combined): { value: Value; key: Key } { + if (obj === null) { + console.warn("Key is null"); + return { value: {} as Value, key: {} as Key }; + } + if (typeof obj !== "object") { + console.warn("Object is not an object"); + return { value: {} as Value, key: {} as Key }; + } + const primaryKeyNames = this.primaryKeyColumns(); + const valueNames = this.valueColumns(); + const value: Partial = {}; + const key: Partial = {}; + for (const k of primaryKeyNames) { + key[k] = obj[k as keyof Combined]; } - let keyClone: any = { ...keySimpleOrObject }; - if (discriminatorKeys.length > 0) { - discriminatorKeys.forEach((k) => { - if (Object.prototype.hasOwnProperty.call(keyClone, k)) { - discriminators[k] = keyClone[k]; - delete keyClone[k]; - } - }); + for (const k of valueNames) { + value[k] = obj[k as keyof Combined]; } - if (Object.keys(keyClone).length === 1) { - keyClone = keyClone[Object.keys(keyClone)[0]]; + + return { value: value as Value, key: key as Key }; + } + + /** + * Generates a consistent string identifier for a given key. + * + * @param key - Primary key to convert + * @returns Promise resolving to a string fingerprint of the key + */ + protected async getKeyAsIdString(key: Key | BasicKeyType): Promise { + if (this.primaryKeyIndex && typeof key === "object") { + key = key[this.primaryKeyIndex]; } - return { discriminators, key: keyClone }; + return await makeFingerprint(key); } } diff --git a/packages/core/src/storage/taskgraph/TaskGraphRepository.ts b/packages/core/src/storage/taskgraph/TaskGraphRepository.ts index bcf7b6c..6a901bb 100644 --- a/packages/core/src/storage/taskgraph/TaskGraphRepository.ts +++ b/packages/core/src/storage/taskgraph/TaskGraphRepository.ts @@ -11,22 +11,53 @@ import { KVRepository } from "../base/KVRepository"; import { CompoundTask } from "../../task/base/Task"; import { TaskRegistry } from "../../task/base/TaskRegistry"; +/** + * Events that can be emitted by the TaskGraphRepository + */ export type TaskGraphEvents = "graph_saved" | "graph_retrieved" | "graph_cleared"; +/** + * Abstract repository class for managing task graphs persistence and retrieval. + * Provides functionality to save, load, and manipulate task graphs with their associated tasks and data flows. 
+ */ export abstract class TaskGraphRepository { public type = "TaskGraphRepository"; - abstract kvRepository: KVRepository; + abstract kvRepository: KVRepository; private events = new EventEmitter(); + + /** + * Registers an event listener for the specified event + * @param name The event name to listen for + * @param fn The callback function to execute when the event occurs + */ on(name: TaskGraphEvents, fn: (...args: any[]) => void) { this.events.on.call(this.events, name, fn); } + + /** + * Removes an event listener for the specified event + * @param name The event name to stop listening for + * @param fn The callback function to remove + */ off(name: TaskGraphEvents, fn: (...args: any[]) => void) { this.events.off.call(this.events, name, fn); } + + /** + * Emits an event with the given arguments + * @param name The event name to emit + * @param args Additional arguments to pass to the event listeners + */ emit(name: TaskGraphEvents, ...args: any[]) { this.events.emit.call(this.events, name, ...args); } + /** + * Creates a task instance from a task graph item JSON representation + * @param item The JSON representation of the task + * @returns A new task instance + * @throws Error if required fields are missing or invalid + */ private createTask(item: TaskGraphItemJson) { if (!item.id) throw new Error("Task id required"); if (!item.type) throw new Error("Task type required"); @@ -51,6 +82,11 @@ export abstract class TaskGraphRepository { return task; } + /** + * Creates a TaskGraph instance from its JSON representation + * @param graphJsonObj The JSON representation of the task graph + * @returns A new TaskGraph instance with all tasks and data flows + */ public createSubGraph(graphJsonObj: TaskGraphJson) { const subGraph = new TaskGraph(); for (const subitem of graphJsonObj.nodes) { @@ -69,29 +105,50 @@ export abstract class TaskGraphRepository { return subGraph; } - async saveTaskGraph(id: unknown, output: TaskGraph): Promise { - const jsonObj = output.toJSON(); - await this.kvRepository.put(id, jsonObj); - this.emit("graph_saved", id); + /** + * Saves a task graph to persistent storage + * @param key The unique identifier for the task graph + * @param output The task graph to save + * @emits graph_saved when the operation completes + */ + async saveTaskGraph(key: string, output: TaskGraph): Promise { + const value = JSON.stringify(output.toJSON()); + await this.kvRepository.put(key, value); + this.emit("graph_saved", key); } - async getTaskGraph(id: unknown): Promise { - const jsonObj = await this.kvRepository.get(id); - if (!jsonObj) { + /** + * Retrieves a task graph from persistent storage + * @param key The unique identifier of the task graph to retrieve + * @returns The retrieved task graph, or undefined if not found + * @emits graph_retrieved when the operation completes successfully + */ + async getTaskGraph(key: string): Promise { + const jsonStr = (await this.kvRepository.get(key)) as string; + if (!jsonStr) { return undefined; } + const jsonObj = JSON.parse(jsonStr); const graph = this.createSubGraph(jsonObj); - this.emit("graph_retrieved", id); + this.emit("graph_retrieved", key); return graph; } + /** + * Clears all task graphs from the repository + * @emits graph_cleared when the operation completes + */ async clear(): Promise { - await this.kvRepository.clear(); + await this.kvRepository.deleteAll(); this.emit("graph_cleared"); } + /** + * Returns the number of task graphs stored in the repository + * @returns The count of stored task graphs + */ async size(): Promise { 
return await this.kvRepository.size(); } diff --git a/packages/core/src/storage/taskoutput/TaskOutputRepository.ts b/packages/core/src/storage/taskoutput/TaskOutputRepository.ts index a8a8686..b80a25c 100644 --- a/packages/core/src/storage/taskoutput/TaskOutputRepository.ts +++ b/packages/core/src/storage/taskoutput/TaskOutputRepository.ts @@ -7,44 +7,99 @@ import EventEmitter from "eventemitter3"; import { TaskInput, TaskOutput } from "../../task/base/Task"; -import { KVRepository } from "../base/KVRepository"; +import { DefaultValueType, KVRepository } from "../base/KVRepository"; +import { makeFingerprint } from "../../util/Misc"; export type TaskOutputEvents = "output_saved" | "output_retrieved" | "output_cleared"; -export const TaskOutputDiscriminator = { +export type TaskOutputPrimaryKey = { + key: string; + taskType: string; +}; +export const TaskOutputPrimaryKeySchema = { + key: "string", taskType: "string", } as const; +/** + * Abstract class for managing task outputs in a repository + * Provides methods for saving, retrieving, and clearing task outputs + */ export abstract class TaskOutputRepository { public type = "TaskOutputRepository"; - abstract kvRepository: KVRepository; + abstract kvRepository: KVRepository< + TaskOutputPrimaryKey, + DefaultValueType, + typeof TaskOutputPrimaryKeySchema + >; private events = new EventEmitter(); + + /** + * Registers an event listener for a specific event + * @param name The event name to listen for + * @param fn The callback function to execute when the event occurs + */ on(name: TaskOutputEvents, fn: (...args: any[]) => void) { this.events.on.call(this.events, name, fn); } + + /** + * Removes an event listener for a specific event + * @param name The event name to stop listening for + * @param fn The callback function to remove + */ off(name: TaskOutputEvents, fn: (...args: any[]) => void) { this.events.off.call(this.events, name, fn); } + + /** + * Emits an event with the given arguments + * @param name The event name to emit + * @param args Additional arguments to pass to the event listeners + */ emit(name: TaskOutputEvents, ...args: any[]) { this.events.emit.call(this.events, name, ...args); } + /** + * Saves a task output to the repository + * @param taskType The type of task to save the output for + * @param inputs The input parameters for the task + * @param output The task output to save + */ async saveOutput(taskType: string, inputs: TaskInput, output: TaskOutput): Promise { - await this.kvRepository.put({ taskType, inputs }, output); + const key = await makeFingerprint(inputs); + const value = JSON.stringify(output); + await this.kvRepository.putKeyValue({ key, taskType }, { "kv-value": value }); this.emit("output_saved", taskType); } + /** + * Retrieves a task output from the repository + * @param taskType The type of task to retrieve the output for + * @param inputs The input parameters for the task + * @returns The retrieved task output, or undefined if not found + */ async getOutput(taskType: string, inputs: TaskInput): Promise { - const output = await this.kvRepository.get({ taskType, inputs }); + const key = await makeFingerprint(inputs); + const output = await this.kvRepository.getKeyValue({ key, taskType }); this.emit("output_retrieved", taskType); - return output as TaskOutput; + return output ? 
(JSON.parse(output["kv-value"]) as TaskOutput) : undefined; } + /** + * Clears all task outputs from the repository + * @emits output_cleared when the operation completes + */ async clear(): Promise { - await this.kvRepository.clear(); + await this.kvRepository.deleteAll(); this.emit("output_cleared"); } + /** + * Returns the number of task outputs stored in the repository + * @returns The count of stored task outputs + */ async size(): Promise { return await this.kvRepository.size(); } diff --git a/packages/core/src/task/DebugLogTask.ts b/packages/core/src/task/DebugLogTask.ts index 5c3e296..78da3b1 100644 --- a/packages/core/src/task/DebugLogTask.ts +++ b/packages/core/src/task/DebugLogTask.ts @@ -13,6 +13,17 @@ import { TaskRegistry } from "./base/TaskRegistry"; export type DebugLogTaskInput = CreateMappedType; export type DebugLogTaskOutput = CreateMappedType; +/** + * DebugLogTask provides console logging functionality as a task within the system. + * + * Features: + * - Supports multiple log levels (info, warn, error, dir) + * - Passes through the logged message as output + * - Configurable logging format and depth + * + * This task is particularly useful for debugging task graphs and monitoring + * data flow between tasks during development and testing. + */ export class DebugLogTask extends OutputTask { static readonly type: string = "DebugLogTask"; static readonly category = "Output"; @@ -32,7 +43,7 @@ export class DebugLogTask extends OutputTask { }, ] as const; public static outputs = [{ id: "output", name: "Output", valueType: "any" }] as const; - runSyncOnly() { + async runReactive() { const level = this.runInputData.level || "log"; if (level == "dir") { console.dir(this.runInputData.message, { depth: null }); diff --git a/packages/core/src/task/JsonTask.ts b/packages/core/src/task/JsonTask.ts index 88ca020..917c04b 100644 --- a/packages/core/src/task/JsonTask.ts +++ b/packages/core/src/task/JsonTask.ts @@ -11,42 +11,52 @@ import { TaskGraphBuilder, TaskGraphBuilderHelper } from "./base/TaskGraphBuilde import { CreateMappedType } from "./base/TaskIOTypes"; import { TaskRegistry } from "./base/TaskRegistry"; +/** + * Represents a single task item in the JSON configuration. + * This structure defines how tasks should be configured in JSON format. + */ export type JsonTaskItem = { - id: unknown; - type: string; - name?: string; - input?: TaskInput; + id: unknown; // Unique identifier for the task + type: string; // Type of task to create + name?: string; // Optional display name for the task + input?: TaskInput; // Input configuration for the task dependencies?: { - [x: string]: - | { - id: unknown; - output: string; + // Defines data flow between tasks + [x: string]: // Input parameter name + | { + id: unknown; // ID of the source task + output: string; // Output parameter name from source task } | { id: unknown; output: string; }[]; }; - provenance?: TaskInput; - subtasks?: JsonTaskItem[]; + provenance?: TaskInput; // Optional metadata about task origin + subtasks?: JsonTaskItem[]; // Nested tasks for compound operations }; type JsonTaskInput = CreateMappedType; type JsonTaskOutput = CreateMappedType; +/** + * JsonTask is a specialized task that creates and manages task graphs from JSON configurations. + * It allows dynamic creation of task networks by parsing JSON definitions of tasks and their relationships. 
+ */ export class JsonTask extends RegenerativeCompoundTask { public static inputs = [ { id: "json", name: "JSON", - valueType: "text", + valueType: "text", // Expects JSON string input }, ] as const; + public static outputs = [ { id: "output", name: "Output", - valueType: "any", + valueType: "any", // Output type depends on the generated task graph }, ] as const; @@ -61,6 +71,9 @@ export class JsonTask extends RegenerativeCompoundTask { } } + /** + * Updates the task's input data and regenerates the graph if JSON input changes + */ public addInputData(overrides: Partial | undefined) { let changed = false; if (overrides?.json != this.runInputData.json) changed = true; @@ -69,6 +82,10 @@ export class JsonTask extends RegenerativeCompoundTask { return this; } + /** + * Creates a task instance from a JSON task item configuration + * Validates required fields and resolves task type from registry + */ private createTask(item: JsonTaskItem) { if (!item.id) throw new Error("Task id required"); if (!item.type) throw new Error("Task type required"); @@ -93,6 +110,10 @@ export class JsonTask extends RegenerativeCompoundTask { return task; } + /** + * Creates a task graph from an array of JSON task items + * Recursively processes subtasks for compound tasks + */ private createSubGraph(jsonItems: JsonTaskItem[]) { const subGraph = new TaskGraph(); for (const subitem of jsonItems) { @@ -101,14 +122,20 @@ export class JsonTask extends RegenerativeCompoundTask { return subGraph; } + /** + * Regenerates the entire task graph based on the current JSON input + * Creates task nodes and establishes data flow connections between them + */ public regenerateGraph() { if (!this.runInputData.json) return; let data = JSON.parse(this.runInputData.json) as JsonTaskItem[] | JsonTaskItem; if (!Array.isArray(data)) data = [data]; const jsonItems: JsonTaskItem[] = data as JsonTaskItem[]; - // create the task nodes + + // Create task nodes this.subGraph = this.createSubGraph(jsonItems); - // create the data flow edges + + // Establish data flow connections for (const item of jsonItems) { if (!item.dependencies) continue; for (const [input, dependency] of Object.entries(item.dependencies)) { @@ -130,16 +157,24 @@ export class JsonTask extends RegenerativeCompoundTask { static readonly category = "Utility"; } +// Register JsonTask with the task registry TaskRegistry.registerTask(JsonTask); +/** + * Helper function to create and configure a JsonTask instance + */ const JsonBuilder = (input: JsonTaskInput) => { return new JsonTask({ input }); }; +/** + * Convenience function to create and run a JsonTask + */ export const Json = (input: JsonTaskInput) => { return JsonBuilder(input).run(); }; +// Add Json task builder to TaskGraphBuilder interface declare module "./base/TaskGraphBuilder" { interface TaskGraphBuilder { Json: TaskGraphBuilderHelper; diff --git a/packages/core/src/task/LambdaTask.ts b/packages/core/src/task/LambdaTask.ts index 57d0883..36473e7 100644 --- a/packages/core/src/task/LambdaTask.ts +++ b/packages/core/src/task/LambdaTask.ts @@ -10,40 +10,64 @@ import { TaskGraphBuilder, TaskGraphBuilderHelper } from "./base/TaskGraphBuilde import { CreateMappedType } from "./base/TaskIOTypes"; import { TaskRegistry } from "./base/TaskRegistry"; -// =============================================================================== - +/** + * Type definitions for LambdaTask input and output + * These types are generated from the static input/output definitions + */ export type LambdaTaskInput = CreateMappedType; export type 
LambdaTaskOutput = CreateMappedType; +/** + * LambdaTask provides a way to execute arbitrary functions within the task framework + * It wraps a provided function and its input into a task that can be integrated + * into task graphs and workflows + */ export class LambdaTask extends SingleTask { static readonly type = "LambdaTask"; declare runInputData: LambdaTaskInput; declare defaults: Partial; declare runOutputData: TaskOutput; + + /** + * Input definition for LambdaTask + * - fn: The function to execute + * - input: Optional input data to pass to the function + */ public static inputs = [ { id: "fn", name: "Function", - valueType: "function", + valueType: "function", // Expects a callable function }, { id: "input", name: "Input", - valueType: "any", + valueType: "any", // Can accept any type of input defaultValue: null, }, ] as const; + + /** + * Output definition for LambdaTask + * The output will be whatever the provided function returns + */ public static outputs = [ { id: "output", name: "Output", - valueType: "any", + valueType: "any", // Can return any type of value }, ] as const; + constructor(config: TaskConfig & { input?: LambdaTaskInput } = {}) { super(config); } - runSyncOnly() { + + /** + * Executes the provided function with the given input + * Throws an error if no function is provided or if the provided value is not callable + */ + async runReactive() { if (!this.runInputData.fn) { throw new Error("No runner provided"); } @@ -55,16 +79,25 @@ export class LambdaTask extends SingleTask { return this.runOutputData; } } + +// Register LambdaTask with the task registry TaskRegistry.registerTask(LambdaTask); +/** + * Helper function to create and configure a LambdaTask instance + */ const LambdaBuilder = (input: LambdaTaskInput) => { return new LambdaTask({ input }); }; +/** + * Convenience function to create and run a LambdaTask + */ export const Lambda = (input: LambdaTaskInput) => { return LambdaBuilder(input).run(); }; +// Add Lambda task builder to TaskGraphBuilder interface declare module "./base/TaskGraphBuilder" { interface TaskGraphBuilder { Lambda: TaskGraphBuilderHelper; diff --git a/packages/core/src/task/base/ArrayTask.ts b/packages/core/src/task/base/ArrayTask.ts index 8286be6..7622dfe 100644 --- a/packages/core/src/task/base/ArrayTask.ts +++ b/packages/core/src/task/base/ArrayTask.ts @@ -19,24 +19,35 @@ import { TaskGraph } from "./TaskGraph"; import { CreateMappedType, TaskInputDefinition, TaskOutputDefinition } from "./TaskIOTypes"; import { TaskRegistry } from "./TaskRegistry"; +// Type utilities for array transformations +// Makes specified properties optional arrays export type ConvertSomeToOptionalArray = { [P in keyof T]: P extends K ? Array | T[P] : T[P]; }; +// Makes all properties optional arrays export type ConvertAllToOptionalArray = { [P in keyof T]: Array | T[P]; }; +// Makes specified properties required arrays export type ConvertSomeToArray = { [P in keyof T]: P extends K ? 
Array : T[P]; }; +// Makes all properties required arrays export type ConvertAllToArrays = { [P in keyof T]: Array; }; +// Removes readonly modifiers from object properties type Writeable = { -readonly [P in keyof T]: T[P] }; +/** + * Takes an array of objects and collects values for each property into arrays + * @param input Array of objects to process + * @returns Object with arrays of values for each property + */ function collectPropertyValues(input: T[]): { [K in keyof T]?: T[K][] } { const output: { [K in keyof T]?: T[K][] } = {}; @@ -54,6 +65,12 @@ function collectPropertyValues(input: T[]): { [K in keyof T]?: return output; } +/** + * Converts specified IO definitions to array type + * @param io Array of input/output definitions + * @param id Optional ID to target specific definition + * @returns Modified array of definitions with isArray set to true + */ function convertToArray( io: D[], id?: string | number | symbol @@ -69,6 +86,12 @@ function convertToArray( return results as D[]; } +/** + * Converts multiple IO definitions to array type based on provided IDs + * @param io Array of input/output definitions + * @param ids Array of IDs to target specific definitions + * @returns Modified array of definitions with isArray set to true for matching IDs + */ function convertMultipleToArray( io: D[], ids: Array @@ -84,6 +107,12 @@ function convertMultipleToArray(input: T, inputMakeArray: (keyof T)[]): T[] { // Helper function to check if a property is an array const isArray = (value: any): value is Array => Array.isArray(value); @@ -104,8 +133,7 @@ function generateCombinations(input: T, inputMakeArray: (ke // Move to the next combination of indices for (let i = indices.length - 1; i >= 0; i--) { if (++indices[i] < arraysToCombine[i].length) break; // Increment current index if possible - if (i === 0) - done = true; // All combinations have been generated + if (i === 0) done = true; // All combinations have been generated else indices[i] = 0; // Reset current index and move to the next position } } @@ -124,9 +152,17 @@ function generateCombinations(input: T, inputMakeArray: (ke }); } +/** + * Factory function to create array-based task classes + * Creates a task that can process arrays of inputs in parallel + * @param taskClass Base task class to wrap + * @param inputMakeArray Array of input keys to process as arrays + * @param name Optional name for the generated task class + * @returns New task class that handles array inputs + */ export function arrayTaskFactory< PluralInputType extends TaskInput = TaskInput, - PluralOutputType extends TaskOutput = TaskOutput, + PluralOutputType extends TaskOutput = TaskOutput >( taskClass: typeof SingleTask | typeof CompoundTask, inputMakeArray: Array, @@ -142,6 +178,10 @@ export function arrayTaskFactory< const nameWithoutTask = taskClass.type.slice(0, -4); name ??= nameWithoutTask + "CompoundTask"; + /** + * A task class that handles array-based processing by creating subtasks for each combination of inputs + * Extends RegenerativeCompoundTask to manage a collection of child tasks running in parallel + */ class ArrayTask extends RegenerativeCompoundTask { static readonly type: TaskTypeName = name!; static readonly runtype = taskClass.type; @@ -155,10 +195,16 @@ export function arrayTaskFactory< static inputs = inputs; static override outputs = outputs; + constructor(config: TaskConfig & { input?: Partial } = {}) { super(config); this.regenerateGraph(); } + + /** + * Regenerates the task graph by creating child tasks for each input combination + * 
Each child task processes a single combination of the array inputs + */ regenerateGraph() { this.subGraph = new TaskGraph(); const combinations = generateCombinations(this.runInputData, inputMakeArray); @@ -169,19 +215,32 @@ super.regenerateGraph(); } + /** + * Adds new input data and regenerates the task graph to handle the updated inputs + * @param overrides Partial input data to merge with existing inputs + */ addInputData(overrides: Partial) { super.addInputData(overrides); this.regenerateGraph(); return this; } - runSyncOnly(): PluralOutputType { - const runDataOut = super.runSyncOnly(); + /** + * Runs the task reactively, collecting outputs from all child tasks into arrays + * @returns Combined output with arrays of values from all child tasks + */ + async runReactive(): Promise { + const runDataOut = await super.runReactive(); this.runOutputData = collectPropertyValues( runDataOut.outputs ) as PluralOutputType; return this.runOutputData; } + + /** + * Runs the full task, collecting outputs from all child tasks into arrays + * @returns Combined output with arrays of values from all child tasks + */ async run(...args: any[]): Promise { const runDataOut = await super.run(...args); this.runOutputData = collectPropertyValues( @@ -189,10 +248,12 @@ ) as PluralOutputType; return this.runOutputData; } + toJSON(): JsonTaskItem { const { subgraph, ...result } = super.toJSON(); return result; } + toDependencyJSON(): JsonTaskItem { const { subtasks, ...result } = super.toDependencyJSON(); return result; diff --git a/packages/core/src/task/base/JobQueueTask.ts b/packages/core/src/task/base/JobQueueTask.ts index 9e72d16..b30cac4 100644 --- a/packages/core/src/task/base/JobQueueTask.ts +++ b/packages/core/src/task/base/JobQueueTask.ts @@ -7,11 +7,17 @@ import { SingleTask, TaskConfig } from "./Task"; +/** + * Configuration interface for job queue tasks + */ export interface JobQueueTaskConfig extends TaskConfig { queue?: string; currentJobId?: unknown; } +/** + * Base class for job queue tasks + */ export abstract class JobQueueTask extends SingleTask { static readonly type: string = "JobQueueTask"; declare config: JobQueueTaskConfig & { id: unknown }; diff --git a/packages/core/src/task/base/OutputTask.ts b/packages/core/src/task/base/OutputTask.ts index 318875c..5c26526 100644 --- a/packages/core/src/task/base/OutputTask.ts +++ b/packages/core/src/task/base/OutputTask.ts @@ -7,6 +7,10 @@ import { SingleTask, TaskInput } from "./Task"; +/** + * Base class for output tasks. Extends SingleTask, belongs to the "Output" category, + * and tracks provenance for the data it outputs. + */ export class OutputTask extends SingleTask { static readonly category = "Output"; provenance: TaskInput = {}; diff --git a/packages/core/src/task/base/Task.ts b/packages/core/src/task/base/Task.ts index 1b9bf88..dbb927b 100644 --- a/packages/core/src/task/base/Task.ts +++ b/packages/core/src/task/base/Task.ts @@ -62,6 +62,9 @@ export interface IConfig { provenance?: TaskInput; } +/** + * Base class for all tasks + */ export abstract class TaskBase { // information about the task that should be overriden by the subclasses static readonly type: TaskTypeName = "TaskBase"; @@ -211,9 +214,22 @@ export abstract class TaskBase { return this; } + /** + * Validates an item against the task's input definition + * @param valueType The type of the item + * @param item The item to validate + *
@returns True if the item is valid, false otherwise + */ validateItem(valueType: string, item: any) { return validateItem(valueType as ValueTypesIndex, item); } + + /** + * Validates an input item against the task's input definition + * @param input The input to validate + * @param inputId The id of the input to validate + * @returns True if the input is valid, false otherwise + */ validateInputItem(input: Partial, inputId: keyof TaskInput) { const classRef = this.constructor as typeof TaskBase; const inputdef = this.inputs.find((def) => def.id === inputId); @@ -258,6 +274,11 @@ export abstract class TaskBase { return true; } + /** + * Validates an input data object against the task's input definition + * @param input The input to validate + * @returns True if the input is valid, false otherwise + */ validateInputData(input: Partial) { for (const inputdef of this.inputs) { if (this.validateInputItem(input, inputdef.id) === false) { @@ -267,20 +288,32 @@ export abstract class TaskBase { return true; } + /** + * Runs the task + * @returns The output of the task + */ async run(): Promise { if (!this.validateInputData(this.runInputData)) { throw new Error("Invalid input data"); } this.emit("start"); - const result = this.runSyncOnly(); + const result = await this.runReactive(); this.emit("complete"); this.runOutputData = result; return result; } - runSyncOnly(): TaskOutput { + /** + * Runs the task reactively + * @returns The output of the task + */ + async runReactive(): Promise { return this.runOutputData; } + /** + * Converts the task to a JSON format suitable for dependency tracking + * @returns The task in JSON format + */ toJSON(): JsonTaskItem { const p = this.getProvenance(); return { @@ -290,6 +323,10 @@ export abstract class TaskBase { ...(Object.keys(p).length ? 
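Under the new contract, `TaskBase.run()` validates the input, emits events, and awaits `runReactive()`, so a cheap task only needs to override `runReactive()`. A minimal sketch of a custom `SingleTask`, modeled on the test tasks later in this diff; the class and its IO definitions are illustrative, not part of the library.

```ts
import { SingleTask, TaskOutput } from "ellmers-core";

// Illustrative task: declares one text input/output and overrides only runReactive().
class UppercaseTask extends SingleTask {
  static readonly type = "UppercaseTask";
  static readonly inputs = [{ id: "text", name: "Text", valueType: "text" }] as const;
  static readonly outputs = [{ id: "text", name: "Text", valueType: "text" }] as const;

  async runReactive(): Promise<TaskOutput> {
    return { text: String(this.runInputData.text ?? "").toUpperCase() };
  }
}

const task = new UppercaseTask({ input: { text: "hello" } });
// run() validates the input, emits start/complete, and awaits runReactive().
console.log(await task.run()); // { text: "HELLO" }
```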
{ provenance: p } : {}), }; } + /** + * Converts the task to a JSON format suitable for dependency tracking + * @returns The task in JSON format + */ toDependencyJSON(): JsonTaskItem { return this.toJSON(); } @@ -297,32 +334,55 @@ export abstract class TaskBase { export type TaskIdType = TaskBase["config"]["id"]; +/** + * Represents a single task, which is a basic unit of work in the task graph + */ export class SingleTask extends TaskBase implements ITaskSimple { static readonly type: TaskTypeName = "SingleTask"; readonly isCompound = false; } +/** + * Represents a compound task, which is a task that contains other tasks + */ export class CompoundTask extends TaskBase implements ITaskCompound { static readonly type: TaskTypeName = "CompoundTask"; declare runOutputData: TaskOutput; readonly isCompound = true; _subGraph: TaskGraph | null = null; + /** + * Sets the subtask graph for the compound task + * @param subGraph The subtask graph to set + */ set subGraph(subGraph: TaskGraph) { this._subGraph = subGraph; } + /** + * Gets the subtask graph for the compound task + * @returns The subtask graph + */ get subGraph() { if (!this._subGraph) { this._subGraph = new TaskGraph(); } return this._subGraph; } + /** + * Resets the input data for the compound task and its subtasks + */ resetInputData() { super.resetInputData(); this.subGraph.getNodes().forEach((node) => { node.resetInputData(); }); } + /** + * Runs the compound task + * @param nodeProvenance The provenance for the subtasks + * @param repository The repository to use for caching task outputs + * @returns The output of the compound task + */ async run( nodeProvenance: TaskInput = {}, repository?: TaskOutputRepository @@ -334,9 +394,9 @@ export class CompoundTask extends TaskBase implements ITaskCompound { this.emit("complete"); return this.runOutputData; } - runSyncOnly(): TaskOutput { + async runReactive(): Promise { const runner = new TaskGraphRunner(this.subGraph); - this.runOutputData.outputs = runner.runGraphSyncOnly(); + this.runOutputData.outputs = await runner.runGraphReactive(); return this.runOutputData; } @@ -348,15 +408,24 @@ export class CompoundTask extends TaskBase implements ITaskCompound { this.resetInputData(); return { ...super.toJSON(), subgraph: this.subGraph.toJSON() }; } - + /** + * Converts the task to a JSON format suitable for dependency tracking + * @returns The task in JSON format + */ toDependencyJSON(): JsonTaskItem { this.resetInputData(); return { ...super.toDependencyJSON(), subtasks: this.subGraph.toDependencyJSON() }; } } +/** + * Represents a regenerative compound task, which is a task that contains other tasks and can regenerate its subtasks + */ export class RegenerativeCompoundTask extends CompoundTask { static readonly type: TaskTypeName = "CompoundTask"; + /** + * Emits a "regenerate" event when the subtask graph is regenerated + */ public regenerateGraph() { this.emit("regenerate", this.subGraph); } diff --git a/packages/core/src/task/base/TaskGraph.ts b/packages/core/src/task/base/TaskGraph.ts index 8de938e..fb1c86c 100644 --- a/packages/core/src/task/base/TaskGraph.ts +++ b/packages/core/src/task/base/TaskGraph.ts @@ -12,6 +12,9 @@ import type { JsonTaskItem } from "../JsonTask"; export type DataFlowIdType = string; +/** + * Represents a data flow between two tasks, indicating how one task's output is used as input for another task + */ export class DataFlow { constructor( public sourceTaskId: TaskIdType, @@ -35,6 +38,9 @@ export class DataFlow { } } +/** + * Represents a task graph item, 
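A `CompoundTask` owns a `TaskGraph` of children and delegates `runReactive()` to a `TaskGraphRunner` over that subgraph, surfacing the children's combined results on `runOutputData.outputs`. A sketch of assembling one by hand; the child task is illustrative, and inputs are supplied via the constructor config, the same pattern the array-task factory uses for its children.

```ts
import { CompoundTask, SingleTask, TaskOutput } from "ellmers-core";

// Illustrative child task.
class DoubleTask extends SingleTask {
  static readonly type = "DoubleTask";
  async runReactive(): Promise<TaskOutput> {
    return { output: Number(this.runInputData.input ?? 0) * 2 };
  }
}

// The compound task's lazily created subGraph is an ordinary TaskGraph of children.
const compound = new CompoundTask({ id: "parent" });
compound.subGraph.addTasks([
  new DoubleTask({ id: "a", input: { input: 2 } }),
  new DoubleTask({ id: "b", input: { input: 5 } }),
]);

// runReactive() drives the children through a TaskGraphRunner and exposes
// their outputs on runOutputData.outputs.
const result = await compound.runReactive();
console.log(result.outputs); // [{ output: 4 }, { output: 10 }]
```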
which can be a task or a subgraph + */ export type TaskGraphItemJson = { id: unknown; type: string; @@ -56,6 +62,9 @@ export type DataFlowJson = { targetTaskInputId: string; }; +/** + * Represents a task graph, a directed acyclic graph of tasks and data flows + */ export class TaskGraph extends DirectedAcyclicGraph { constructor() { super( @@ -63,24 +72,60 @@ export class TaskGraph extends DirectedAcyclicGraph dataFlow.id ); } + + /** + * Retrieves a task from the task graph by its id + * @param id The id of the task to retrieve + * @returns The task with the given id, or undefined if not found + */ public getTask(id: TaskIdType): Task | undefined { return super.getNode(id); } + + /** + * Adds a task to the task graph + * @param task The task to add + * @returns The current task graph + */ public addTask(task: Task) { return super.addNode(task); } + + /** + * Adds multiple tasks to the task graph + * @param tasks The tasks to add + * @returns The current task graph + */ public addTasks(tasks: Task[]) { return super.addNodes(tasks); } + + /** + * Adds a data flow to the task graph + * @param dataflow The data flow to add + * @returns The current task graph + */ public addDataFlow(dataflow: DataFlow) { return super.addEdge(dataflow.sourceTaskId, dataflow.targetTaskId, dataflow); } + + /** + * Adds multiple data flows to the task graph + * @param dataflows The data flows to add + * @returns The current task graph + */ public addDataFlows(dataflows: DataFlow[]) { const addedEdges = dataflows.map<[s: unknown, t: unknown, e: DataFlow]>((edge) => { return [edge.sourceTaskId, edge.targetTaskId, edge]; }); return super.addEdges(addedEdges); } + + /** + * Retrieves a data flow from the task graph by its id + * @param id The id of the data flow to retrieve + * @returns The data flow with the given id, or undefined if not found + */ public getDataFlow(id: DataFlowIdType): DataFlow | undefined { for (const i in this.adjacency) { for (const j in this.adjacency[i]) { @@ -99,22 +144,46 @@ export class TaskGraph extends DirectedAcyclicGraph edge[2]); } + /** + * Retrieves the data flows that are sources of a given task + * @param taskId The id of the task to retrieve sources for + * @returns An array of data flows that are sources of the given task + */ public getSourceDataFlows(taskId: unknown): DataFlow[] { return this.inEdges(taskId).map(([, , dataFlow]) => dataFlow); } + /** + * Retrieves the data flows that are targets of a given task + * @param taskId The id of the task to retrieve targets for + * @returns An array of data flows that are targets of the given task + */ public getTargetDataFlows(taskId: unknown): DataFlow[] { return this.outEdges(taskId).map(([, , dataFlow]) => dataFlow); } + /** + * Retrieves the tasks that are sources of a given task + * @param taskId The id of the task to retrieve sources for + * @returns An array of tasks that are sources of the given task + */ public getSourceTasks(taskId: unknown): Task[] { return this.getSourceDataFlows(taskId).map((dataFlow) => this.getNode(dataFlow.sourceTaskId)!); } + /** + * Retrieves the tasks that are targets of a given task + * @param taskId The id of the task to retrieve targets for + * @returns An array of tasks that are targets of the given task + */ public getTargetTasks(taskId: unknown): Task[] { return this.getTargetDataFlows(taskId).map((dataFlow) => this.getNode(dataFlow.targetTaskId)!); } + /** + * Converts the task graph to a JSON format suitable for dependency tracking + * @returns An array of JsonTaskItem objects, each 
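The `DataFlow` edge is what carries one task's output into another task's input when the runner executes the graph. A small wiring sketch, assuming the `DataFlow` constructor takes `(sourceTaskId, sourceTaskOutputId, targetTaskId, targetTaskInputId)` as its JSON shape suggests; the task classes are illustrative.

```ts
import { TaskGraph, DataFlow, TaskGraphRunner, SingleTask, TaskOutput } from "ellmers-core";

// Illustrative tasks.
class ProduceTask extends SingleTask {
  static readonly type = "ProduceTask";
  async runReactive(): Promise<TaskOutput> {
    return { output: 21 };
  }
}
class DoubleTask extends SingleTask {
  static readonly type = "DoubleTask";
  async runReactive(): Promise<TaskOutput> {
    return { output: Number(this.runInputData.input ?? 0) * 2 };
  }
}

// Route task "1"'s `output` into task "2"'s `input`.
const graph = new TaskGraph();
graph.addTasks([new ProduceTask({ id: "1" }), new DoubleTask({ id: "2" })]);
graph.addDataFlow(new DataFlow("1", "output", "2", "input"));

console.log(graph.getSourceTasks("2").length); // 1 — ProduceTask feeds DoubleTask
console.log(await new TaskGraphRunner(graph).runGraphReactive()); // [{ output: 42 }]
```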
representing a task and its dependencies + */ public toJSON(): TaskGraphJson { const nodes = this.getNodes().map((node) => node.toJSON()); const edges = this.getDataFlows().map((df) => df.toJSON()); @@ -124,6 +193,10 @@ export class TaskGraph extends DirectedAcyclicGraph node.toDependencyJSON()); this.getDataFlows().forEach((edge) => { diff --git a/packages/core/src/task/base/TaskGraphBuilder.ts b/packages/core/src/task/base/TaskGraphBuilder.ts index 670d3e0..71f94a5 100644 --- a/packages/core/src/task/base/TaskGraphBuilder.ts +++ b/packages/core/src/task/base/TaskGraphBuilder.ts @@ -73,7 +73,9 @@ export function TaskGraphBuilderHelper( if (matches.size === 0) { this._error = `Could not find a match between the outputs of ${ (parent.constructor as any).type - } and the inputs of ${(task.constructor as any).type}. You now need to connect the outputs to the inputs via connect() manually before adding this task. Task not added.`; + } and the inputs of ${ + (task.constructor as any).type + }. You now need to connect the outputs to the inputs via connect() manually before adding this task. Task not added.`; console.error(this._error); this.graph.removeNode(task.config.id); } @@ -89,6 +91,10 @@ export function TaskGraphBuilderHelper( type BuilderEvents = GraphEvents | "changed" | "reset" | "error" | "start" | "complete"; +/** + * Class for building and managing a task graph + * Provides methods for adding tasks, connecting outputs to inputs, and running the task graph + */ export class TaskGraphBuilder { private _graph: TaskGraph = new TaskGraph(); private _runner: TaskGraphRunner; @@ -148,6 +154,10 @@ export class TaskGraphBuilder { this._graph.events.off("edge-removed", this._onChanged); } + /** + * Runs the task graph + * @returns The output of the task graph + */ async run() { this.emit("start"); const out = await this._runner.runGraph(); @@ -155,6 +165,10 @@ export class TaskGraphBuilder { return out; } + /** + * Removes the last task from the task graph + * @returns The current task graph builder + */ pop() { this._error = ""; const nodes = this._graph.getNodes(); @@ -175,6 +189,11 @@ export class TaskGraphBuilder { return this._graph.toDependencyJSON(); } + /** + * Creates a new task graph builder that runs multiple task graph builders in parallel + * @param args The task graph builders to run in parallel + * @returns The current task graph builder + */ parallel(...args: Array<(b: TaskGraphBuilder) => void>) { this._error = ""; const group = new TaskGraphBuilder(); @@ -188,6 +207,13 @@ export class TaskGraphBuilder { } _dataFlows: DataFlow[] = []; + /** + * Renames an output of a task to a new target input + * @param source The id of the output to rename + * @param target The id of the input to rename to + * @param index The index of the task to rename the output of, defaults to the last task + * @returns The current task graph builder + */ rename(source: string, target: string, index: number = -1) { this._error = ""; const nodes = this._graph.getNodes(); @@ -205,6 +231,10 @@ export class TaskGraphBuilder { return this; } + /** + * Resets the task graph builder to its initial state + * @returns The current task graph builder + */ reset() { id = 0; this.clearEvents(); diff --git a/packages/core/src/task/base/TaskGraphRunner.ts b/packages/core/src/task/base/TaskGraphRunner.ts index 512bd56..2f5dc73 100644 --- a/packages/core/src/task/base/TaskGraphRunner.ts +++ b/packages/core/src/task/base/TaskGraphRunner.ts @@ -9,18 +9,38 @@ import { TaskOutputRepository } from 
"../../storage/taskoutput/TaskOutputReposit import { TaskInput, Task, TaskOutput } from "./Task"; import { TaskGraph } from "./TaskGraph"; +/** + * Class for running a task graph + * Manages the execution of tasks in a task graph, including provenance tracking and caching + */ export class TaskGraphRunner { + /** + * Map of layers, where each layer contains an array of tasks + * @type {Map} + */ public layers: Map; + + /** + * Map of provenance input for each task + * @type {Map} + */ public provenanceInput: Map; - constructor( - public dag: TaskGraph, - public repository?: TaskOutputRepository - ) { + /** + * Constructor for TaskGraphRunner + * @param dag The task graph to run + * @param repository The task output repository to use for caching task outputs + */ + constructor(public dag: TaskGraph, public repository?: TaskOutputRepository) { this.layers = new Map(); this.provenanceInput = new Map(); } + /** + * Assigns layers to tasks based on their dependencies. Each layer is a set of tasks + * that can be run in parallel as a set, the next layer is run after the previous layer has completed. + * @param sortedNodes The topologically sorted list of tasks + */ public assignLayers(sortedNodes: Task[]) { this.layers = new Map(); const nodeToLayer = new Map(); @@ -59,6 +79,11 @@ export class TaskGraphRunner { }); } + /** + * Retrieves the provenance input for a task + * @param node The task to retrieve provenance input for + * @returns The provenance input for the task + */ private getInputProvenance(node: Task): TaskInput { const nodeProvenance: TaskInput = {}; this.dag.getSourceDataFlows(node.config.id).forEach((dataFlow) => { @@ -67,6 +92,12 @@ export class TaskGraphRunner { return nodeProvenance; } + /** + * Pushes the output of a task to its target tasks + * @param node The task that produced the output + * @param results The output of the task + * @param nodeProvenance The provenance input for the task + */ private pushOutputFromNodeToEdges(node: Task, results: TaskOutput, nodeProvenance?: TaskInput) { this.dag.getTargetDataFlows(node.config.id).forEach((dataFlow) => { if (results[dataFlow.sourceTaskOutputId] !== undefined) { @@ -76,6 +107,12 @@ export class TaskGraphRunner { }); } + /** + * Runs a task with provenance input + * @param task The task to run + * @param parentProvenance The provenance input for the task + * @returns The output of the task + */ private async runTaskWithProvenance( task: Task, parentProvenance: TaskInput @@ -99,7 +136,7 @@ export class TaskGraphRunner { task.emit("start"); task.emit("progress", 100, Object.values(results)[0]); task.runOutputData = results; - task.runSyncOnly(); + await task.runReactive(); task.emit("complete"); } } @@ -118,6 +155,11 @@ export class TaskGraphRunner { return results; } + /** + * Runs the task graph + * @param parentProvenance The provenance input for the task graph + * @returns The output of the task graph + */ public async runGraph(parentProvenance: TaskInput = {}) { this.provenanceInput = new Map(); this.dag.getNodes().forEach((node) => node.resetInputData()); @@ -135,23 +177,34 @@ export class TaskGraphRunner { return results; } - private runTasksSync() { + /** + * Runs the task graph in a reactive manner + * @returns The output of the task graph + */ + private async runTasksReactive() { let results: TaskOutput[] = []; for (const [_layerNumber, nodes] of this.layers.entries()) { - results = nodes.map((node) => { - this.copyInputFromEdgesToNode(node); - const results = node.runSyncOnly(); - 
this.pushOutputFromNodeToEdges(node, results); - return results; - }); + const settledResults = await Promise.allSettled( + nodes.map(async (node) => { + this.copyInputFromEdgesToNode(node); + const results = await node.runReactive(); + this.pushOutputFromNodeToEdges(node, results); + return results; + }) + ); + results = settledResults.map((r) => (r.status === "fulfilled" ? r.value : {})); } return results; } - public runGraphSyncOnly() { + /** + * Runs the task graph in a reactive manner + * @returns The output of the task graph + */ + public async runGraphReactive() { this.dag.getNodes().forEach((node) => node.resetInputData()); const sortedNodes = this.dag.topologicallySortedNodes(); this.assignLayers(sortedNodes); - return this.runTasksSync(); + return await this.runTasksReactive(); } } diff --git a/packages/core/src/task/test/Task.test.ts b/packages/core/src/task/test/Task.test.ts index 683f920..a493885 100644 --- a/packages/core/src/task/test/Task.test.ts +++ b/packages/core/src/task/test/Task.test.ts @@ -26,7 +26,7 @@ class TestTask extends SingleTask { ] as const; static readonly outputs = [ { - id: "syncOnly", + id: "reactiveOnly", name: "Output", valueType: "boolean", }, @@ -41,11 +41,11 @@ class TestTask extends SingleTask { valueType: "text", }, ] as const; - runSyncOnly(): TestTaskOutput { - return { all: false, key: this.runInputData.key, syncOnly: true }; + async runReactive(): Promise { + return { all: false, key: this.runInputData.key, reactiveOnly: true }; } async run(): Promise { - return { all: true, key: this.runInputData.key, syncOnly: false }; + return { all: true, key: this.runInputData.key, reactiveOnly: false }; } } @@ -62,7 +62,7 @@ class TestCompoundTask extends CompoundTask { ] as const; static readonly outputs = [ { - id: "syncOnly", + id: "reactiveOnly", name: "Output", valueType: "boolean", }, @@ -78,12 +78,12 @@ class TestCompoundTask extends CompoundTask { }, ] as const; static readonly type = "TestCompoundTask"; - runSyncOnly(): TestTaskOutput { - this.runOutputData = { key: this.runInputData.key, all: false, syncOnly: true }; + async runReactive(): Promise { + this.runOutputData = { key: this.runInputData.key, all: false, reactiveOnly: true }; return this.runOutputData; } async run(): Promise { - this.runOutputData = { key: this.runInputData.key, all: true, syncOnly: false }; + this.runOutputData = { key: this.runInputData.key, all: true, reactiveOnly: false }; return this.runOutputData; } } @@ -95,14 +95,14 @@ describe("Task", () => { const input = { key: "value" }; node.addInputData(input); const output = await node.run(); - expect(output).toEqual({ ...input, syncOnly: false, all: true }); + expect(output).toEqual({ ...input, reactiveOnly: false, all: true }); expect(node.runInputData).toEqual(input); }); - it("should run the task synchronously", () => { + it("should run the task reactively", async () => { const node = new TestTask(); - const output = node.runSyncOnly(); - expect(output).toEqual({ key: "", syncOnly: true, all: false }); + const output = await node.runReactive(); + expect(output).toEqual({ key: "", reactiveOnly: true, all: false }); }); }); @@ -123,14 +123,14 @@ describe("Task", () => { const input = { key: "value" }; node.addInputData(input); const output = await node.run(); - expect(output).toEqual({ key: "value", all: true, syncOnly: false }); + expect(output).toEqual({ key: "value", all: true, reactiveOnly: false }); expect(node.runInputData).toEqual(input); }); - it("should run the task synchronously", () => { + it("should run the 
task synchronously", async () => { const node = new TestCompoundTask({ input: { key: "value2" } }); - const output = node.runSyncOnly(); - expect(output).toEqual({ key: "value2", syncOnly: true, all: false }); + const output = await node.runReactive(); + expect(output).toEqual({ key: "value2", reactiveOnly: true, all: false }); }); }); }); diff --git a/packages/core/src/task/test/TaskGraph.test.ts b/packages/core/src/task/test/TaskGraph.test.ts index 1371e03..42462ce 100644 --- a/packages/core/src/task/test/TaskGraph.test.ts +++ b/packages/core/src/task/test/TaskGraph.test.ts @@ -11,7 +11,7 @@ import { TaskGraph, DataFlow, serialGraph } from "../base/TaskGraph"; class TestTask extends SingleTask { static readonly type = "TestTask"; - runSyncOnly(): TaskOutput { + async runReactive(): Promise { return {}; } } diff --git a/packages/core/src/task/test/TaskGraphRunner.test.ts b/packages/core/src/task/test/TaskGraphRunner.test.ts index 42cbffe..b4f0495 100644 --- a/packages/core/src/task/test/TaskGraphRunner.test.ts +++ b/packages/core/src/task/test/TaskGraphRunner.test.ts @@ -13,7 +13,7 @@ import { CreateMappedType } from "../base/TaskIOTypes"; class TestTask extends SingleTask { static readonly type = "TestTask"; - runSyncOnly(): TaskOutput { + async runReactive(): Promise { return {}; } } @@ -40,7 +40,7 @@ class TestSquareTask extends SingleTask { valueType: "number", }, ] as const; - runSyncOnly(): TestSquareTaskOutput { + async runReactive(): Promise { return { output: this.runInputData.input * this.runInputData.input }; } } @@ -66,7 +66,7 @@ class TestDoubleTask extends SingleTask { valueType: "number", }, ] as const; - runSyncOnly(): TestDoubleTaskOutput { + async runReactive(): Promise { return { output: this.runInputData.input * 2 }; } } @@ -98,7 +98,7 @@ class TestAddTask extends SingleTask { valueType: "number", }, ] as const; - runSyncOnly(): TaskOutput { + async runReactive(): Promise { const input = this.runInputData; return { output: input.a + input.b }; } @@ -146,13 +146,13 @@ describe("TaskGraphRunner", () => { }); }); - describe("runGraphSyncOnly", () => { - it("should run nodes in each layer synchronously", () => { - const runSyncOnlySpy = spyOn(nodes[0], "runSyncOnly"); + describe("runGraphReactive", () => { + it("should run nodes in each layer synchronously", async () => { + const runReactiveSpy = spyOn(nodes[0], "runReactive"); - runner.runGraphSyncOnly(); + await runner.runGraphReactive(); - expect(runSyncOnlySpy).toHaveBeenCalledTimes(1); + expect(runReactiveSpy).toHaveBeenCalledTimes(1); }); }); diff --git a/packages/core/src/task/test/TaskSubGraphRunner.test.ts b/packages/core/src/task/test/TaskSubGraphRunner.test.ts index 3157d3a..67654d9 100644 --- a/packages/core/src/task/test/TaskSubGraphRunner.test.ts +++ b/packages/core/src/task/test/TaskSubGraphRunner.test.ts @@ -37,7 +37,7 @@ class TestSquareTask extends SingleTask { valueType: "number", }, ] as const; - runSyncOnly(): TestSquareTaskOutput { + async runReactive(): Promise { return { output: this.runInputData.input * this.runInputData.input }; } } @@ -63,7 +63,7 @@ class TestDoubleTask extends SingleTask { valueType: "number", }, ] as const; - runSyncOnly(): TestDoubleTaskOutput { + async runReactive(): Promise { return { output: this.runInputData.input * 2 }; } } @@ -90,7 +90,7 @@ class TestAddTask extends SingleTask { valueType: "number", }, ] as const; - runSyncOnly(): TaskOutput { + async runReactive(): Promise { const inputs = Array.isArray(this.runInputData.input) ? 
this.runInputData.input : [this.runInputData.input ?? 0]; diff --git a/packages/storage/package.json b/packages/storage/package.json index cc3da7e..6a1a52a 100644 --- a/packages/storage/package.json +++ b/packages/storage/package.json @@ -5,15 +5,15 @@ "description": "Ellmers is a tool for building and running DAG pipelines of AI tasks.", "scripts": { "watch": "concurrently -c 'auto' 'bun:watch-*'", - "watch-browser": "bun build --watch --no-clear-screen--target=browser --sourcemap=external --external ellmers-core --outdir ./dist/browser ./src/browser/*/index.ts", - "watch-node": "bun build --watch --no-clear-screen--target=node --sourcemap=external --external ellmers-core --outdir ./dist ./src/node/*/index.ts", - "watch-bun": "bun build --watch --no-clear-screen--target=bun --sourcemap=external --external ellmers-core --outdir ./dist ./src/bun/*/index.ts", + "watch-browser": "bun build --watch --no-clear-screen --target=browser --sourcemap=external --external ellmers-core --external ellmers-ai --outdir ./dist --entry-naming \"browser/[dir]/[name].[ext]\" ./src/browser/*/index.ts", + "watch-node": "bun build --watch --no-clear-screen --target=node --sourcemap=external --external ellmers-core --external ellmers-ai --outdir ./dist --entry-naming \"node/[dir]/[name].[ext]\" ./src/node/*/index.ts", + "watch-bun": "bun build --watch --no-clear-screen --target=bun --sourcemap=external --external ellmers-core --external ellmers-ai --outdir ./dist --entry-naming \"bun/[dir]/[name].[ext]\" ./src/bun/*/index.ts", "watch-types": "tsc --watch --preserveWatchOutput", "build": "bun run build-clean && bun run build-types && bun run build-browser && bun run build-node && bun run build-bun", "build-clean": "rm -fr dist/* tsconfig.tsbuildinfo", - "build-browser": "bun build --target=browser --sourcemap=external --external ellmers-core --outdir ./dist --entry-naming \"browser/[dir]/[name].[ext]\" ./src/browser/*/index.ts", - "build-node": "bun build --target=node --sourcemap=external --external ellmers-core --outdir ./dist --entry-naming \"node/[dir]/[name].[ext]\" ./src/node/*/index.ts", - "build-bun": "bun build --target=bun --sourcemap=external --external ellmers-core --outdir ./dist --entry-naming \"bun/[dir]/[name].[ext]\" ./src/bun/*/index.ts", + "build-browser": "bun build --target=browser --sourcemap=external --external ellmers-core --external ellmers-ai --outdir ./dist --entry-naming \"browser/[dir]/[name].[ext]\" ./src/browser/*/index.ts", + "build-node": "bun build --target=node --sourcemap=external --external ellmers-core --external ellmers-ai --outdir ./dist --entry-naming \"node/[dir]/[name].[ext]\" ./src/node/*/index.ts", + "build-bun": "bun build --target=bun --sourcemap=external --external ellmers-core --external ellmers-ai --outdir ./dist --entry-naming \"bun/[dir]/[name].[ext]\" ./src/bun/*/index.ts", "build-types": "tsc", "lint": "eslint . 
--ext ts,tsx --report-unused-disable-directives --max-warnings 0", "test": "bun test" @@ -48,6 +48,7 @@ "dist" ], "dependencies": { - "ellmers-core": "workspace:packages/core" + "ellmers-core": "workspace:packages/core", + "ellmers-ai": "workspace:packages/ai" } } diff --git a/packages/storage/src/browser/indexeddb/IndexedDbTaskGraphRepository.ts b/packages/storage/src/browser/indexeddb/IndexedDbTaskGraphRepository.ts index a0d7b43..a0932f5 100644 --- a/packages/storage/src/browser/indexeddb/IndexedDbTaskGraphRepository.ts +++ b/packages/storage/src/browser/indexeddb/IndexedDbTaskGraphRepository.ts @@ -5,14 +5,14 @@ // * Licensed under the Apache License, Version 2.0 (the "License"); * // ******************************************************************************* -import { TaskGraphJson, TaskGraphRepository } from "ellmers-core"; +import { TaskGraphRepository } from "ellmers-core"; import { IndexedDbKVRepository } from "./base/IndexedDbKVRepository"; export class IndexedDbTaskGraphRepository extends TaskGraphRepository { - kvRepository: IndexedDbKVRepository; + kvRepository: IndexedDbKVRepository; public type = "IndexedDbTaskGraphRepository" as const; constructor() { super(); - this.kvRepository = new IndexedDbKVRepository("task_graphs"); + this.kvRepository = new IndexedDbKVRepository("task_graphs"); } } diff --git a/packages/storage/src/browser/indexeddb/IndexedDbTaskOutputRepository.ts b/packages/storage/src/browser/indexeddb/IndexedDbTaskOutputRepository.ts index 0f83e0d..809a625 100644 --- a/packages/storage/src/browser/indexeddb/IndexedDbTaskOutputRepository.ts +++ b/packages/storage/src/browser/indexeddb/IndexedDbTaskOutputRepository.ts @@ -5,18 +5,27 @@ // * Licensed under the Apache License, Version 2.0 (the "License"); * // ******************************************************************************* -import { TaskInput, TaskOutput, TaskOutputDiscriminator, TaskOutputRepository } from "ellmers-core"; +import { + DefaultValueType, + TaskOutputPrimaryKey, + TaskOutputPrimaryKeySchema, + TaskOutputRepository, +} from "ellmers-core"; import { IndexedDbKVRepository } from "./base/IndexedDbKVRepository"; export class IndexedDbTaskOutputRepository extends TaskOutputRepository { - kvRepository: IndexedDbKVRepository; + kvRepository: IndexedDbKVRepository< + TaskOutputPrimaryKey, + DefaultValueType, + typeof TaskOutputPrimaryKeySchema + >; public type = "IndexedDbTaskOutputRepository" as const; constructor() { super(); this.kvRepository = new IndexedDbKVRepository< - TaskInput, - TaskOutput, - typeof TaskOutputDiscriminator - >("task_outputs"); + TaskOutputPrimaryKey, + DefaultValueType, + typeof TaskOutputPrimaryKeySchema + >("task_outputs", TaskOutputPrimaryKeySchema); } } diff --git a/packages/storage/src/browser/indexeddb/base/IndexedDbKVRepository.ts b/packages/storage/src/browser/indexeddb/base/IndexedDbKVRepository.ts index 0895590..5c5bb75 100644 --- a/packages/storage/src/browser/indexeddb/base/IndexedDbKVRepository.ts +++ b/packages/storage/src/browser/indexeddb/base/IndexedDbKVRepository.ts @@ -5,29 +5,70 @@ // * Licensed under the Apache License, Version 2.0 (the "License"); * // ******************************************************************************* -import { DiscriminatorSchema, KVRepository } from "ellmers-core"; +import { + BaseValueSchema, + BasePrimaryKeySchema, + BasicKeyType, + DefaultValueType, + DefaultValueSchema, + DefaultPrimaryKeyType, + DefaultPrimaryKeySchema, + KVRepository, +} from "ellmers-core"; import { ensureIndexedDbTable } from 
"./IndexedDbTable"; import { makeFingerprint } from "../../../util/Misc"; // IndexedDbKVRepository is a key-value store that uses IndexedDB as the backend for // simple browser-based examples with no server-side component. It does not support di +/** + * A key-value repository implementation using IndexedDB for browser-based storage. + * This class provides a simple persistent storage solution for web applications + * without requiring a server component. + * + * @template Key - The type of the primary key object + * @template Value - The type of the value object to be stored + * @template PrimaryKeySchema - Schema definition for the primary key + * @template ValueSchema - Schema definition for the value + * @template Combined - Combined type of Key & Value + */ export class IndexedDbKVRepository< - Key = string, - Value = string, - Discriminator extends DiscriminatorSchema = DiscriminatorSchema -> extends KVRepository { + Key extends Record = DefaultPrimaryKeyType, + Value extends Record = DefaultValueType, + PrimaryKeySchema extends BasePrimaryKeySchema = typeof DefaultPrimaryKeySchema, + ValueSchema extends BaseValueSchema = typeof DefaultValueSchema, + Combined extends Record = Key & Value +> extends KVRepository { + /** Promise that resolves to the IndexedDB database instance */ private dbPromise: Promise; - constructor(public table: string = "kv_store") { - super(); + /** + * Creates a new IndexedDB-based key-value repository + * @param table - Name of the IndexedDB store to use + * @param primaryKeySchema - Schema defining the structure of primary keys + * @param valueSchema - Schema defining the structure of values + * @param searchable - Array of properties that can be searched (Note: search not implemented) + */ + constructor( + public table: string = "kv_store", + primaryKeySchema: PrimaryKeySchema = DefaultPrimaryKeySchema as PrimaryKeySchema, + valueSchema: ValueSchema = DefaultValueSchema as ValueSchema, + protected searchable: Array = [] + ) { + super(primaryKeySchema, valueSchema, searchable); this.dbPromise = ensureIndexedDbTable(this.table, (db) => { db.createObjectStore(table, { keyPath: "id" }); }); } - async put(key: Key, value: Value): Promise { - const id = typeof key === "object" ? await makeFingerprint(key) : String(key); + /** + * Stores a key-value pair in the repository + * @param key - The key object + * @param value - The value object to store + * @emits put - Emitted when the value is successfully stored + */ + async putKeyValue(key: Key, value: Value): Promise { + const id = await makeFingerprint(key); const db = await this.dbPromise; return new Promise((resolve, reject) => { @@ -43,8 +84,14 @@ export class IndexedDbKVRepository< }); } - async get(key: Key): Promise { - const id = typeof key === "object" ? 
await makeFingerprint(key) : String(key); + /** + * Retrieves a value by its key + * @param key - The key object to look up + * @returns The stored value or undefined if not found + * @emits get - Emitted when a value is retrieved + */ + async getKeyValue(key: Key): Promise { + const id = await makeFingerprint(key); const db = await this.dbPromise; return new Promise((resolve, reject) => { @@ -64,7 +111,41 @@ export class IndexedDbKVRepository< }); } - async clear(): Promise { + /** + * Search functionality is not supported in this implementation + * @throws Error indicating search is not supported + */ + async search(key: Partial): Promise { + throw new Error("Search not supported for IndexedDbKVRepository"); + } + + /** + * Deletes a key-value pair from the repository + * @param key - The key object to delete + * @emits delete - Emitted when a value is deleted + */ + async deleteKeyValue(key: Key): Promise { + const id = await makeFingerprint(key); + const db = await this.dbPromise; + + return new Promise((resolve, reject) => { + const transaction = db.transaction(this.table, "readwrite"); + const store = transaction.objectStore(this.table); + const request = store.delete(id); + + request.onerror = () => reject(request.error); + request.onsuccess = () => { + this.emit("delete", id); + resolve(); + }; + }); + } + + /** + * Deletes all key-value pairs from the repository + * @emits clearall - Emitted when all values are deleted + */ + async deleteAll(): Promise { const db = await this.dbPromise; return new Promise((resolve, reject) => { @@ -74,12 +155,16 @@ export class IndexedDbKVRepository< request.onerror = () => reject(request.error); request.onsuccess = () => { - this.emit("clear"); + this.emit("clearall"); resolve(); }; }); } + /** + * Returns the total number of key-value pairs in the repository + * @returns The count of stored items + */ async size(): Promise { const db = await this.dbPromise; diff --git a/packages/storage/src/browser/inmemory/InMemoryJobQueue.ts b/packages/storage/src/browser/inmemory/InMemoryJobQueue.ts index 9aef9e2..bc5e350 100644 --- a/packages/storage/src/browser/inmemory/InMemoryJobQueue.ts +++ b/packages/storage/src/browser/inmemory/InMemoryJobQueue.ts @@ -10,7 +10,12 @@ import { Job, JobStatus, JobQueue, ILimiter } from "ellmers-core"; import { makeFingerprint } from "../../util/Misc"; export class InMemoryJobQueue extends JobQueue { - constructor(queue: string, limiter: ILimiter, waitDurationInMilliseconds = 100) { + constructor( + queue: string, + limiter: ILimiter, + waitDurationInMilliseconds = 100, + protected jobClass: typeof Job = Job + ) { super(queue, limiter, waitDurationInMilliseconds); this.jobQueue = []; } diff --git a/packages/storage/src/browser/inmemory/InMemoryKVRepository.ts b/packages/storage/src/browser/inmemory/InMemoryKVRepository.ts deleted file mode 100644 index 53078e6..0000000 --- a/packages/storage/src/browser/inmemory/InMemoryKVRepository.ts +++ /dev/null @@ -1,42 +0,0 @@ -// ******************************************************************************* -// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * -// * * -// * Copyright Steven Roussey * -// * Licensed under the Apache License, Version 2.0 (the "License"); * -// ******************************************************************************* - -import { DiscriminatorSchema, KVRepository } from "ellmers-core"; -import { makeFingerprint } from "../../util/Misc"; - -// InMemoryKVRepository is a simple in-memory key-value store that can be used for 
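For the IndexedDB-backed store, the base repository's `put()`/`get()` convenience methods (used with the default schemas in the in-memory tests later in this diff) should apply as well, since they live on the shared `KVRepository` base; that, and the import path, are assumptions. A browser-only sketch:

```ts
import { IndexedDbKVRepository } from "ellmers-storage/browser/indexeddb";

// Browser-only: IndexedDB must be available in the environment.
const repo = new IndexedDbKVRepository("settings");

// With the default key/value schemas, the base class's put()/get() helpers
// accept simple values, mirroring the in-memory tests later in this diff.
await repo.put("theme", "dark");
console.log(await repo.get("theme")); // "dark"
console.log(await repo.size()); // 1
await repo.deleteAll();
```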
testing or as a cache -// It does not support discriminators - -export class InMemoryKVRepository< - Key = string, - Value = string, - Discriminator extends DiscriminatorSchema = DiscriminatorSchema -> extends KVRepository { - values = new Map(); - - async put(key: Key, value: Value): Promise { - const id = typeof key === "object" ? await makeFingerprint(key) : String(key); - this.values.set(id, value); - this.emit("put", id); - } - - async get(key: Key): Promise { - const id = typeof key === "object" ? await makeFingerprint(key) : String(key); - const out = this.values.get(id); - this.emit("get", id); - return out; - } - - async clear(): Promise { - this.values.clear(); - this.emit("clear"); - } - - async size(): Promise { - return this.values.size; - } -} diff --git a/packages/storage/src/browser/inmemory/InMemoryModelRepository.ts b/packages/storage/src/browser/inmemory/InMemoryModelRepository.ts new file mode 100644 index 0000000..a4fd0d0 --- /dev/null +++ b/packages/storage/src/browser/inmemory/InMemoryModelRepository.ts @@ -0,0 +1,49 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { + ModelPrimaryKey, + ModelRepository, + Task2ModelDetail, + Task2ModelPrimaryKey, +} from "ellmers-ai"; +import { InMemoryKVRepository } from "./base/InMemoryKVRepository"; +import { + ModelPrimaryKeySchema, + Task2ModelPrimaryKeySchema, + Task2ModelDetailSchema, +} from "ellmers-ai"; +import { DefaultValueType } from "ellmers-core"; + +export class InMemoryModelRepository extends ModelRepository { + modelKvRepository: InMemoryKVRepository< + ModelPrimaryKey, + DefaultValueType, + typeof ModelPrimaryKeySchema + >; + task2ModelKvRepository: InMemoryKVRepository< + Task2ModelPrimaryKey, + Task2ModelDetail, + typeof Task2ModelPrimaryKeySchema, + typeof Task2ModelDetailSchema + >; + public type = "InMemoryModelRepository" as const; + constructor() { + super(); + this.modelKvRepository = new InMemoryKVRepository< + ModelPrimaryKey, + DefaultValueType, + typeof ModelPrimaryKeySchema + >(ModelPrimaryKeySchema); + this.task2ModelKvRepository = new InMemoryKVRepository< + Task2ModelPrimaryKey, + Task2ModelDetail, + typeof Task2ModelPrimaryKeySchema, + typeof Task2ModelDetailSchema + >(Task2ModelPrimaryKeySchema, Task2ModelDetailSchema, ["model"]); + } +} diff --git a/packages/storage/src/browser/inmemory/InMemoryTaskGraphRepository.ts b/packages/storage/src/browser/inmemory/InMemoryTaskGraphRepository.ts index 92b7da9..84eb836 100644 --- a/packages/storage/src/browser/inmemory/InMemoryTaskGraphRepository.ts +++ b/packages/storage/src/browser/inmemory/InMemoryTaskGraphRepository.ts @@ -5,14 +5,14 @@ // * Licensed under the Apache License, Version 2.0 (the "License"); * // ******************************************************************************* -import { TaskGraphJson, TaskGraphRepository } from "ellmers-core"; -import { InMemoryKVRepository } from "./InMemoryKVRepository"; +import { TaskGraphRepository } from "ellmers-core"; +import { InMemoryKVRepository } from "./base/InMemoryKVRepository"; export class InMemoryTaskGraphRepository extends TaskGraphRepository { - kvRepository: InMemoryKVRepository; + kvRepository: InMemoryKVRepository; public type = "InMemoryTaskGraphRepository" as const; 
constructor() { super(); - this.kvRepository = new InMemoryKVRepository(); + this.kvRepository = new InMemoryKVRepository(); } } diff --git a/packages/storage/src/browser/inmemory/InMemoryTaskOutputRepository.ts b/packages/storage/src/browser/inmemory/InMemoryTaskOutputRepository.ts index 9f030f7..66b294a 100644 --- a/packages/storage/src/browser/inmemory/InMemoryTaskOutputRepository.ts +++ b/packages/storage/src/browser/inmemory/InMemoryTaskOutputRepository.ts @@ -5,18 +5,29 @@ // * Licensed under the Apache License, Version 2.0 (the "License"); * // ******************************************************************************* -import { TaskInput, TaskOutput, TaskOutputDiscriminator, TaskOutputRepository } from "ellmers-core"; -import { InMemoryKVRepository } from "./InMemoryKVRepository"; +import { + DefaultValueType, + TaskInput, + TaskOutput, + TaskOutputPrimaryKey, + TaskOutputPrimaryKeySchema, + TaskOutputRepository, +} from "ellmers-core"; +import { InMemoryKVRepository } from "./base/InMemoryKVRepository"; export class InMemoryTaskOutputRepository extends TaskOutputRepository { - kvRepository: InMemoryKVRepository; + kvRepository: InMemoryKVRepository< + TaskOutputPrimaryKey, + DefaultValueType, + typeof TaskOutputPrimaryKeySchema + >; public type = "InMemoryTaskOutputRepository" as const; constructor() { super(); this.kvRepository = new InMemoryKVRepository< - TaskInput, - TaskOutput, - typeof TaskOutputDiscriminator - >(); + TaskOutputPrimaryKey, + DefaultValueType, + typeof TaskOutputPrimaryKeySchema + >(TaskOutputPrimaryKeySchema); } } diff --git a/packages/storage/src/browser/inmemory/base/InMemoryKVRepository.ts b/packages/storage/src/browser/inmemory/base/InMemoryKVRepository.ts new file mode 100644 index 0000000..5defaa0 --- /dev/null +++ b/packages/storage/src/browser/inmemory/base/InMemoryKVRepository.ts @@ -0,0 +1,129 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { + BaseValueSchema, + BasePrimaryKeySchema, + BasicKeyType, + DefaultValueType, + DefaultValueSchema, + DefaultPrimaryKeyType, + DefaultPrimaryKeySchema, + KVRepository, +} from "ellmers-core"; +import { makeFingerprint } from "../../../util/Misc"; + +// InMemoryKVRepository is a simple in-memory key-value store that can be used for testing or as a cache + +/** + * A generic in-memory key-value repository implementation. + * Provides a simple, non-persistent storage solution suitable for testing and caching scenarios. 
+ * + * @template Key - The type of the primary key object, must be a record of basic types + * @template Value - The type of the value object being stored + * @template PrimaryKeySchema - Schema definition for the primary key + * @template ValueSchema - Schema definition for the value + * @template Combined - The combined type of Key & Value + */ +export class InMemoryKVRepository< + Key extends Record = DefaultPrimaryKeyType, + Value extends Record = DefaultValueType, + PrimaryKeySchema extends BasePrimaryKeySchema = typeof DefaultPrimaryKeySchema, + ValueSchema extends BaseValueSchema = typeof DefaultValueSchema, + Combined extends Record = Key & Value +> extends KVRepository { + /** Internal storage using a Map with fingerprint strings as keys */ + values = new Map(); + + /** + * Creates a new InMemoryKVRepository instance + * @param primaryKeySchema - Schema defining the structure of primary keys + * @param valueSchema - Schema defining the structure of values + * @param searchable - Array of field names that can be searched + */ + constructor( + primaryKeySchema: PrimaryKeySchema = DefaultPrimaryKeySchema as PrimaryKeySchema, + valueSchema: ValueSchema = DefaultValueSchema as ValueSchema, + searchable: Array = [] + ) { + super(primaryKeySchema, valueSchema, searchable); + } + + /** + * Stores a key-value pair in the repository + * @param key - The primary key object + * @param value - The value object to store + * @emits 'put' event with the fingerprint ID when successful + */ + async putKeyValue(key: Key, value: Value): Promise { + const id = await makeFingerprint(key); + this.values.set(id, Object.assign({}, key, value) as Combined); + this.emit("put", id); + } + + /** + * Retrieves a value by its key + * @param key - The primary key object to look up + * @returns The value object if found, undefined otherwise + * @emits 'get' event with the fingerprint ID and value when found + */ + async getKeyValue(key: Key): Promise { + const id = await makeFingerprint(key); + const out = this.values.get(id); + if (out === undefined) { + return undefined; + } + this.emit("get", id, out); + const { value } = this.separateKeyValueFromCombined(out); + return value; + } + + /** + * Searches for entries matching a partial key + * @param key - Partial key object to search for + * @returns Array of matching combined objects + * @throws Error if search criteria contains more than one key + */ + async search(key: Partial): Promise { + const search = Object.keys(key); + if (search.length !== 1) { + throw new Error("Search must be a single key"); + } + this.emit("search", key); + return Array.from(this.values.entries()) + .filter(([_fingerprint, value]) => value[search[0]] === key[search[0]]) + .map(([_id, value]) => value); + } + + /** + * Deletes an entry by its key + * @param key - The primary key object of the entry to delete + * @emits 'delete' event with the fingerprint ID when successful + */ + async deleteKeyValue(key: Key): Promise { + const id = await makeFingerprint(key); + this.values.delete(id); + this.emit("delete", id); + } + + /** + * Removes all entries from the repository + * @emits 'clearall' event when successful + */ + async deleteAll(): Promise { + this.values.clear(); + this.emit("clearall"); + } + + /** + * Returns the number of entries in the repository + * @returns The total count of stored key-value pairs + */ + async size(): Promise { + return this.values.size; + } +} diff --git a/packages/storage/src/browser/inmemory/index.ts b/packages/storage/src/browser/inmemory/index.ts 
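Beyond the tests that follow, the schema-driven API also supports `search()` over a single property, returning the matching combined rows. A sketch using the same key/value shapes as the tests below; the `searchable` constructor argument is optional for the in-memory variant.

```ts
import { InMemoryKVRepository } from "ellmers-storage/inmemory";
import { BasePrimaryKeySchema, BaseValueSchema } from "ellmers-core";

type Key = { name: string; type: string };
type Value = { option: string; success: boolean };
const KeySchema: BasePrimaryKeySchema = { name: "string", type: "string" } as const;
const ValueSchema: BaseValueSchema = { option: "string", success: "boolean" } as const;

const repo = new InMemoryKVRepository<Key, Value, typeof KeySchema, typeof ValueSchema>(
  KeySchema,
  ValueSchema
);

await repo.putKeyValue({ name: "a", type: "x" }, { option: "one", success: true });
await repo.putKeyValue({ name: "b", type: "x" }, { option: "two", success: false });

// search() accepts exactly one property and returns the matching combined rows.
console.log(await repo.search({ type: "x" })); // two { name, type, option, success } rows
```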
index 5e9a4fe..c892320 100644 --- a/packages/storage/src/browser/inmemory/index.ts +++ b/packages/storage/src/browser/inmemory/index.ts @@ -1,5 +1,6 @@ -export * from "./InMemoryKVRepository"; +export * from "./base/InMemoryKVRepository"; export * from "./InMemoryTaskOutputRepository"; export * from "./InMemoryTaskGraphRepository"; export * from "./InMemoryJobQueue"; export * from "./InMemoryRateLimiter"; +export * from "./InMemoryModelRepository"; diff --git a/packages/core/src/job/test/JobQueue-InMemory.test.ts b/packages/storage/src/browser/inmemory/test/InMemoryJobQueue.test.ts similarity index 96% rename from packages/core/src/job/test/JobQueue-InMemory.test.ts rename to packages/storage/src/browser/inmemory/test/InMemoryJobQueue.test.ts index 26eb4a4..857c745 100644 --- a/packages/core/src/job/test/JobQueue-InMemory.test.ts +++ b/packages/storage/src/browser/inmemory/test/InMemoryJobQueue.test.ts @@ -6,10 +6,9 @@ // ******************************************************************************* import { describe, it, expect, beforeEach, afterEach, spyOn } from "bun:test"; -import { Job, JobStatus } from "../base/Job"; +import { Job, JobStatus, TaskInput, TaskOutput } from "ellmers-core"; import { InMemoryJobQueue, InMemoryRateLimiter } from "ellmers-storage/inmemory"; -import { sleep } from "../../util/Misc"; -import { TaskInput, TaskOutput } from "../../task/base/Task"; +import { sleep } from "ellmers-core"; class TestJob extends Job { public async execute() { @@ -17,7 +16,7 @@ class TestJob extends Job { } } -describe("LocalJobQueue", () => { +describe("InMemoryJobQueue", () => { let jobQueue: InMemoryJobQueue; beforeEach(() => { diff --git a/packages/storage/src/browser/inmemory/test/InMemoryKVRepository.test.ts b/packages/storage/src/browser/inmemory/test/InMemoryKVRepository.test.ts new file mode 100644 index 0000000..debfe92 --- /dev/null +++ b/packages/storage/src/browser/inmemory/test/InMemoryKVRepository.test.ts @@ -0,0 +1,73 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { describe, expect, it, beforeEach } from "bun:test"; +import { InMemoryKVRepository } from "../base/InMemoryKVRepository"; +import { BaseValueSchema, BasePrimaryKeySchema } from "ellmers-core"; + +type PrimaryKey = { + name: string; + type: string; +}; +type Value = { + option: string; + success: boolean; +}; + +export const PrimaryKeySchema: BasePrimaryKeySchema = { name: "string", type: "string" } as const; +export const ValueSchema: BaseValueSchema = { option: "string", success: "boolean" } as const; + +describe("InMemoryKVRepository", () => { + describe("with default schemas (key and value)", () => { + let repository: InMemoryKVRepository; + + beforeEach(() => { + repository = new InMemoryKVRepository(); + }); + + it("should store and retrieve values for a key", async () => { + const key = "key1"; + const value = "value1"; + await repository.put(key, value); + const output = await repository.get(key); + + expect(output).toEqual(value); + }); + it("should get undefined for a key that doesn't exist", async () => { + const key = "key"; + const value = "value"; + await repository.put(key, value); + const output = await repository.get("not-a-key"); + + expect(output == 
undefined).toEqual(true); + }); + }); + + describe("with complex schemas", () => { + let repository: InMemoryKVRepository; + + beforeEach(() => { + repository = new InMemoryKVRepository(PrimaryKeySchema, ValueSchema); + }); + + it("should store and retrieve values for a key", async () => { + const key = { name: "key", type: "string" }; + const value = { option: "value", success: true }; + await repository.put(key, value); + const output = await repository.getKeyValue(key); + + expect(output?.option).toEqual("value"); + expect(!!output?.success).toEqual(true); // TODO need some conversion to boolean from 1 + }); + it("should get undefined for a key that doesn't exist", async () => { + const key = { name: "key", type: "string" }; + const output = await repository.get(key); + + expect(output == undefined).toEqual(true); + }); + }); +}); diff --git a/packages/storage/src/browser/inmemory/test/InMemoryModelRepository.test.ts b/packages/storage/src/browser/inmemory/test/InMemoryModelRepository.test.ts new file mode 100644 index 0000000..496bc95 --- /dev/null +++ b/packages/storage/src/browser/inmemory/test/InMemoryModelRepository.test.ts @@ -0,0 +1,56 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { describe, expect, it, beforeEach } from "bun:test"; +import { setGlobalModelRepository, getGlobalModelRepository } from "ellmers-ai"; +import { InMemoryModelRepository } from "../InMemoryModelRepository"; +import { LOCAL_ONNX_TRANSFORMERJS } from "ellmers-ai-provider/hf-transformers/server"; + +describe("InMemoryModelRepository", () => { + it("store and find model by task", async () => { + setGlobalModelRepository(new InMemoryModelRepository()); + await getGlobalModelRepository().addModel({ + name: "ONNX Xenova/LaMini-Flan-T5-783M q8", + url: "Xenova/LaMini-Flan-T5-783M", + availableOnBrowser: true, + availableOnServer: true, + provider: LOCAL_ONNX_TRANSFORMERJS, + pipeline: "text2text-generation", + }); + await getGlobalModelRepository().connectTaskToModel( + "TextGenerationTask", + "ONNX Xenova/LaMini-Flan-T5-783M q8" + ); + await getGlobalModelRepository().connectTaskToModel( + "TextRewritingTask", + "ONNX Xenova/LaMini-Flan-T5-783M q8" + ); + const tasks = await getGlobalModelRepository().findTasksByModel( + "ONNX Xenova/LaMini-Flan-T5-783M q8" + ); + expect(tasks).toBeDefined(); + expect(tasks?.length).toEqual(2); + const models = await getGlobalModelRepository().findModelsByTask("TextGenerationTask"); + expect(models).toBeDefined(); + expect(models?.length).toEqual(1); + }); + it("store and find model by name", async () => { + setGlobalModelRepository(new InMemoryModelRepository()); + await getGlobalModelRepository().addModel({ + name: "ONNX Xenova/LaMini-Flan-T5-783M q8", + url: "Xenova/LaMini-Flan-T5-783M", + availableOnBrowser: true, + availableOnServer: true, + provider: LOCAL_ONNX_TRANSFORMERJS, + pipeline: "text2text-generation", + }); + + const model = await getGlobalModelRepository().findByName("ONNX Xenova/LaMini-Flan-T5-783M q8"); + expect(model).toBeDefined(); + expect(model?.name).toEqual("ONNX Xenova/LaMini-Flan-T5-783M q8"); + }); +}); diff --git a/packages/storage/src/browser/inmemory/test/InMemoryTaskGraphRepository.test.ts 
b/packages/storage/src/browser/inmemory/test/InMemoryTaskGraphRepository.test.ts index 92fb8ec..70f8360 100644 --- a/packages/storage/src/browser/inmemory/test/InMemoryTaskGraphRepository.test.ts +++ b/packages/storage/src/browser/inmemory/test/InMemoryTaskGraphRepository.test.ts @@ -6,13 +6,12 @@ // ******************************************************************************* import { describe, expect, it, beforeEach } from "bun:test"; -import { rmdirSync } from "fs"; import { SingleTask, TaskOutput, DataFlow, TaskGraph, TaskRegistry } from "ellmers-core"; -import { InMemoryTaskGraphRepository } from "../InMemoryTaskGraphRepository"; +import { InMemoryTaskGraphRepository } from "ellmers-storage/inmemory"; class TestTask extends SingleTask { static readonly type = "TestTask"; - runSyncOnly(): TaskOutput { + async runReactive(): Promise { return {}; } } @@ -22,7 +21,6 @@ describe("FileTaskGraphRepository", () => { let repository: InMemoryTaskGraphRepository; beforeEach(() => { - rmdirSync(".cache/test/file-task-graph", { recursive: true }); repository = new InMemoryTaskGraphRepository(); }); diff --git a/packages/storage/src/browser/inmemory/test/InMemoryTaskOutputRepository.test.ts b/packages/storage/src/browser/inmemory/test/InMemoryTaskOutputRepository.test.ts index 8fcfb29..3144c23 100644 --- a/packages/storage/src/browser/inmemory/test/InMemoryTaskOutputRepository.test.ts +++ b/packages/storage/src/browser/inmemory/test/InMemoryTaskOutputRepository.test.ts @@ -6,8 +6,8 @@ // ******************************************************************************* import { describe, expect, it, beforeEach } from "bun:test"; -import { InMemoryTaskOutputRepository } from "../InMemoryTaskOutputRepository"; import { TaskInput, TaskOutput } from "ellmers-core"; +import { InMemoryTaskOutputRepository } from "ellmers-storage/inmemory"; describe("InMemoryTaskOutputRepository", () => { let repository: InMemoryTaskOutputRepository; diff --git a/packages/storage/src/bun/sqlite/SqliteJobQueue.ts b/packages/storage/src/bun/sqlite/SqliteJobQueue.ts index dec6b6f..6df5ed5 100644 --- a/packages/storage/src/bun/sqlite/SqliteJobQueue.ts +++ b/packages/storage/src/bun/sqlite/SqliteJobQueue.ts @@ -16,14 +16,14 @@ export class SqliteJobQueue extends JobQueue { protected db: Database, queue: string, limiter: ILimiter, - protected jobClass: typeof Job = Job, - waitDurationInMilliseconds = 100 + waitDurationInMilliseconds = 100, + protected jobClass: typeof Job = Job ) { super(queue, limiter, waitDurationInMilliseconds); } public ensureTableExists() { - this.db.exec(` + const a = this.db.exec(` CREATE TABLE IF NOT EXISTS job_queue ( id INTEGER PRIMARY KEY, diff --git a/packages/storage/src/bun/sqlite/SqliteModelRepository.ts b/packages/storage/src/bun/sqlite/SqliteModelRepository.ts new file mode 100644 index 0000000..673f08d --- /dev/null +++ b/packages/storage/src/bun/sqlite/SqliteModelRepository.ts @@ -0,0 +1,47 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { + ModelRepository, + ModelPrimaryKeySchema, + ModelPrimaryKey, + Task2ModelDetailSchema, + Task2ModelPrimaryKey, + Task2ModelDetail, + Task2ModelPrimaryKeySchema, +} from "ellmers-ai"; +import { SqliteKVRepository } from 
"./base/SqliteKVRepository"; +import { DefaultValueType } from "ellmers-core"; + +export class SqliteModelRepository extends ModelRepository { + public type = "SqliteModelRepository" as const; + modelKvRepository: SqliteKVRepository< + ModelPrimaryKey, + DefaultValueType, + typeof ModelPrimaryKeySchema + >; + task2ModelKvRepository: SqliteKVRepository< + Task2ModelPrimaryKey, + Task2ModelDetail, + typeof Task2ModelPrimaryKeySchema, + typeof Task2ModelDetailSchema + >; + constructor(dbOrPath: string) { + super(); + this.modelKvRepository = new SqliteKVRepository< + ModelPrimaryKey, + DefaultValueType, + typeof ModelPrimaryKeySchema + >(dbOrPath, "aimodel", ModelPrimaryKeySchema); + this.task2ModelKvRepository = new SqliteKVRepository< + Task2ModelPrimaryKey, + Task2ModelDetail, + typeof Task2ModelPrimaryKeySchema, + typeof Task2ModelDetailSchema + >(dbOrPath, "aitask2aimodel", Task2ModelPrimaryKeySchema, Task2ModelDetailSchema); + } +} diff --git a/packages/storage/src/bun/sqlite/SqliteTaskGraphRepository.ts b/packages/storage/src/bun/sqlite/SqliteTaskGraphRepository.ts index 6bf7c92..941880a 100644 --- a/packages/storage/src/bun/sqlite/SqliteTaskGraphRepository.ts +++ b/packages/storage/src/bun/sqlite/SqliteTaskGraphRepository.ts @@ -5,14 +5,14 @@ // * Licensed under the Apache License, Version 2.0 (the "License"); * // ******************************************************************************* -import { TaskGraphJson, TaskGraphRepository } from "ellmers-core"; +import { TaskGraphRepository } from "ellmers-core"; import { SqliteKVRepository } from "./base/SqliteKVRepository"; export class SqliteTaskGraphRepository extends TaskGraphRepository { - kvRepository: SqliteKVRepository; + kvRepository: SqliteKVRepository; public type = "SqliteTaskGraphRepository" as const; constructor(dbOrPath: string) { super(); - this.kvRepository = new SqliteKVRepository(dbOrPath, "task_graphs"); + this.kvRepository = new SqliteKVRepository(dbOrPath, "task_graphs"); } } diff --git a/packages/storage/src/bun/sqlite/SqliteTaskOutputRepository.ts b/packages/storage/src/bun/sqlite/SqliteTaskOutputRepository.ts index c8577ed..2ea177e 100644 --- a/packages/storage/src/bun/sqlite/SqliteTaskOutputRepository.ts +++ b/packages/storage/src/bun/sqlite/SqliteTaskOutputRepository.ts @@ -5,18 +5,27 @@ // * Licensed under the Apache License, Version 2.0 (the "License"); * // ******************************************************************************* -import { TaskOutputDiscriminator, TaskOutputRepository, TaskInput, TaskOutput } from "ellmers-core"; +import { + TaskOutputPrimaryKeySchema, + TaskOutputRepository, + TaskOutputPrimaryKey, + DefaultValueType, +} from "ellmers-core"; import { SqliteKVRepository } from "./base/SqliteKVRepository"; export class SqliteTaskOutputRepository extends TaskOutputRepository { - kvRepository: SqliteKVRepository; + kvRepository: SqliteKVRepository< + TaskOutputPrimaryKey, + DefaultValueType, + typeof TaskOutputPrimaryKeySchema + >; public type = "SqliteTaskOutputRepository" as const; constructor(dbOrPath: string) { super(); this.kvRepository = new SqliteKVRepository< - TaskInput, - TaskOutput, - typeof TaskOutputDiscriminator - >(dbOrPath, "task_outputs", TaskOutputDiscriminator); + TaskOutputPrimaryKey, + DefaultValueType, + typeof TaskOutputPrimaryKeySchema + >(dbOrPath, "task_outputs", TaskOutputPrimaryKeySchema); } } diff --git a/packages/storage/src/bun/sqlite/base/SqliteKVRepository.ts b/packages/storage/src/bun/sqlite/base/SqliteKVRepository.ts index 50a70a7..81c2921 100644 --- 
a/packages/storage/src/bun/sqlite/base/SqliteKVRepository.ts +++ b/packages/storage/src/bun/sqlite/base/SqliteKVRepository.ts @@ -6,58 +6,92 @@ // ******************************************************************************* import { Database } from "bun:sqlite"; -import { DiscriminatorSchema, KVRepository } from "ellmers-core"; -import { makeFingerprint } from "../../../util/Misc"; +import { + BaseValueSchema, + BasicKeyType, + BasePrimaryKeySchema, + DefaultValueType, + DefaultValueSchema, + DefaultPrimaryKeyType, + DefaultPrimaryKeySchema, +} from "ellmers-core"; +import { BaseSqlKVRepository } from "../../../util/base/BaseSqlKVRepository"; + // SqliteKVRepository is a key-value store that uses SQLite as the backend for -// in app data. It supports discriminators. +// in app data. +/** + * A SQLite-based key-value repository implementation. + * @template Key - The type of the primary key object, must be a record of basic types + * @template Value - The type of the value object being stored + * @template PrimaryKeySchema - Schema definition for the primary key + * @template ValueSchema - Schema definition for the value + * @template Combined - Combined type of Key & Value + */ export class SqliteKVRepository< - Key = string, - Value = string, - Discriminator extends DiscriminatorSchema = DiscriminatorSchema -> extends KVRepository { + Key extends Record = DefaultPrimaryKeyType, + Value extends Record = DefaultValueType, + PrimaryKeySchema extends BasePrimaryKeySchema = typeof DefaultPrimaryKeySchema, + ValueSchema extends BaseValueSchema = typeof DefaultValueSchema, + Combined extends Record = Key & Value +> extends BaseSqlKVRepository { + /** The SQLite database instance */ private db: Database; + + /** + * Creates a new SQLite key-value repository + * @param dbOrPath - Either a Database instance or a path to the SQLite database file + * @param table - The name of the table to use for storage (defaults to 'kv_store') + * @param primaryKeySchema - Schema defining the structure of the primary key + * @param valueSchema - Schema defining the structure of the values + * @param searchable - Array of columns to make searchable + */ constructor( dbOrPath: string, - public table: string = "kv_store", - discriminatorsSchema: Discriminator = {} as Discriminator + table: string = "kv_store", + primaryKeySchema: PrimaryKeySchema = DefaultPrimaryKeySchema as PrimaryKeySchema, + valueSchema: ValueSchema = DefaultValueSchema as ValueSchema, + searchable: Array = [] ) { - super(); + super(table, primaryKeySchema, valueSchema, searchable); if (typeof dbOrPath === "string") { this.db = new Database(dbOrPath); } else { this.db = dbOrPath; } - this.discriminatorsSchema = discriminatorsSchema; this.setupDatabase(); } - private setupDatabase(): void { - this.db.exec(` - CREATE TABLE IF NOT EXISTS ${this.table} ( - ${this.constructDiscriminatorColumns()} - key TEXT NOT NULL, - value TEXT NOT NULL, + /** + * Creates the database table if it doesn't exist with the defined schema + */ + public setupDatabase(): void { + const sql = ` + CREATE TABLE IF NOT EXISTS \`${this.table}\` ( + ${this.constructPrimaryKeyColumns()}, + ${this.constructValueColumns()}, PRIMARY KEY (${this.primaryKeyColumnList()}) ) - `); - } - - private constructDiscriminatorColumns(): string { - const cols = Object.entries(this.discriminatorsSchema) - .map(([key, type]) => { - // Convert the provided type to a SQL type, assuming simple mappings; adjust as necessary - const sqlType = this.mapTypeToSQL(type); - return `${key} ${sqlType} NOT 
NULL`; - }) - .join(", "); - if (cols.length > 0) { - return `${cols}, `; + `; + this.db.exec(sql); + for (const column of this.searchable) { + /* Makes other columns searchable, but excludes the first column + of a primary key (which would be redundant) */ + if (column !== this.primaryKeyColumns()[0]) { + this.db.exec( + `CREATE INDEX IF NOT EXISTS \`${this.table}_${column as string}\` + ON \`${this.table}\` (\`${column as string}\`)` + ); + } } - return ""; } - private mapTypeToSQL(type: string): string { + /** + * Maps TypeScript/JavaScript types to their SQLite column type equivalents + * @param type - The TypeScript/JavaScript type to map + * @returns The corresponding SQLite column type + */ + protected mapTypeToSQL(type: string): string { // Basic type mapping; extend according to your needs switch (type) { case "string": @@ -70,46 +104,114 @@ export class SqliteKVRepository< } } - async put(keySimpleOrObject: Key, value: Value): Promise { - const { discriminators, key } = this.extractDiscriminators(keySimpleOrObject); - const id = typeof key === "object" ? await makeFingerprint(key) : String(key); - const stmt = this.db.prepare(` - INSERT OR REPLACE INTO ${this.table} (${this.primaryKeyColumnList()}, value) - VALUES (${this.primaryKeyColumns().map((i) => "?")}, ?) - `); - const values = Object.values(discriminators).concat(id, JSON.stringify(value)); - stmt.run(...values); - this.emit("put", id, discriminators); - } + /** + * Stores a key-value pair in the database + * @param key - The primary key object + * @param value - The value object to store + * @emits 'put' event when successful + */ + async putKeyValue(key: Key, value: Value): Promise { + const sql = ` + INSERT OR REPLACE INTO \`${ + this.table + }\` (${this.primaryKeyColumnList()}, ${this.valueColumnList()}) + VALUES ( + ${this.primaryKeyColumns().map((i) => "?")}, + ${this.valueColumns().map((i) => "?")} + ) + `; + const stmt = this.db.prepare(sql); + + const primaryKeyParams = this.getPrimaryKeyAsOrderedArray(key); + const valueParams = this.getValueAsOrderedArray(value); + const params = [...primaryKeyParams, ...valueParams]; + + const result = stmt.run(...params); - async get(keySimpleOrObject: Key): Promise { - const { discriminators, key } = this.extractDiscriminators(keySimpleOrObject); - const id = typeof key === "object" ? await makeFingerprint(key) : String(key); + this.emit("put", key); + } - const whereClauses = this.primaryKeyColumns() - .map((discriminatorKey) => `${discriminatorKey} = ?`) + /** + * Retrieves a value from the database by its key + * @param key - The primary key object to look up + * @returns The stored value or undefined if not found + * @emits 'get' event when successful + */ + async getKeyValue(key: Key): Promise { + const whereClauses = (this.primaryKeyColumns() as string[]) + .map((key) => `\`${key}\` = ?`) .join(" AND "); - const stmt = this.db.prepare<{ value: string }, [key: string]>(` - SELECT value FROM ${this.table} WHERE ${whereClauses} - `); + const sql = ` + SELECT ${this.valueColumnList()} FROM \`${this.table}\` WHERE ${whereClauses} + `; + const stmt = this.db.prepare(sql); + const params = this.getPrimaryKeyAsOrderedArray(key); + const value = stmt.get(...params); + if (value) { + this.emit("get", key, value); + return value; + } else { + return undefined; + } + } - const values = Object.values(discriminators).concat(id); + /** + * Method to be implemented by concrete repositories to search for key-value pairs + * based on a partial key. 
+ * + * @param key - Partial key to search for + * @returns Promise resolving to an array of combined key-value objects or undefined if not found + */ + public async search(key: Partial): Promise { + const search = Object.keys(key); + if (search.length !== 1) { + //TODO: make this work with any prefix of primary key + throw new Error("Search must be a single key"); + } - const row = stmt.get(...(values as [string])) as { value: string } | undefined; - if (row) { - this.emit("get", id, discriminators); - return JSON.parse(row.value) as Value; + const sql = ` + SELECT * FROM \`${this.table}\` + WHERE \`${search[0]}\` = ? + `; + const stmt = this.db.prepare(sql); + const value = stmt.all(key[search[0]]); + if (value) { + this.emit("search"); + return value; } else { return undefined; } } - async clear(): Promise { + /** + * Deletes a key-value pair from the database + * @param key - The primary key object to delete + * @emits 'delete' event when successful + */ + async deleteKeyValue(key: Key): Promise { + const whereClauses = (this.primaryKeyColumns() as string[]) + .map((key) => `${key} = ?`) + .join(" AND "); + const params = this.getPrimaryKeyAsOrderedArray(key); + const stmt = this.db.prepare(`DELETE FROM ${this.table} WHERE ${whereClauses}`); + stmt.run(...params); + this.emit("delete", key); + } + + /** + * Deletes all entries from the database table + * @emits 'clearall' event when successful + */ + async deleteAll(): Promise { this.db.exec(`DELETE FROM ${this.table}`); - this.emit("clear"); + this.emit("clearall"); } + /** + * Gets the total number of entries in the database table + * @returns The count of entries + */ async size(): Promise { const stmt = this.db.prepare<{ count: number }, []>(` SELECT COUNT(*) AS count FROM ${this.table} diff --git a/packages/storage/src/bun/sqlite/index.ts b/packages/storage/src/bun/sqlite/index.ts index b671f30..6550cbb 100644 --- a/packages/storage/src/bun/sqlite/index.ts +++ b/packages/storage/src/bun/sqlite/index.ts @@ -1,3 +1,5 @@ +export { getDatabase } from "../../util/db_sqlite"; export * from "./SqliteJobQueue"; export * from "./SqliteTaskGraphRepository"; export * from "./SqliteTaskOutputRepository"; +export * from "./SqliteModelRepository"; diff --git a/packages/storage/src/bun/sqlite/test/JobQueue-Sqlite.test.ts b/packages/storage/src/bun/sqlite/test/SqliteJobQueue.test.ts similarity index 99% rename from packages/storage/src/bun/sqlite/test/JobQueue-Sqlite.test.ts rename to packages/storage/src/bun/sqlite/test/SqliteJobQueue.test.ts index c5730e8..48169bc 100644 --- a/packages/storage/src/bun/sqlite/test/JobQueue-Sqlite.test.ts +++ b/packages/storage/src/bun/sqlite/test/SqliteJobQueue.test.ts @@ -26,8 +26,8 @@ describe("SqliteJobQueue", () => { db, queueName, new SqliteRateLimiter(db, queueName, 4, 1).ensureTableExists(), - TestJob, - 0 + 0, + TestJob ).ensureTableExists(); afterEach(() => { diff --git a/packages/storage/src/bun/sqlite/test/SqliteKVRepository.test.ts b/packages/storage/src/bun/sqlite/test/SqliteKVRepository.test.ts new file mode 100644 index 0000000..ca6d397 --- /dev/null +++ b/packages/storage/src/bun/sqlite/test/SqliteKVRepository.test.ts @@ -0,0 +1,78 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + 
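Before the tests, a brief usage sketch of the reworked schema-based API (the table name and searchable column below are illustrative only, and `search()` is currently limited to a single column, per the TODO above):

```ts
import { SqliteKVRepository } from "../base/SqliteKVRepository";
import { BasePrimaryKeySchema, BaseValueSchema } from "ellmers-core";

// Illustrative key/value shapes, mirroring the test below.
type Key = { name: string; type: string };
type Val = { option: string; success: boolean };
const KeySchema: BasePrimaryKeySchema = { name: "string", type: "string" } as const;
const ValSchema: BaseValueSchema = { option: "string", success: "boolean" } as const;

// A compound primary key plus value columns replaces the old discriminator-based API.
const repo = new SqliteKVRepository<Key, Val, typeof KeySchema, typeof ValSchema>(
  ":memory:",
  "example_store",
  KeySchema,
  ValSchema,
  ["type"] // extra index for a non-leading, searchable column
);

await repo.putKeyValue({ name: "a", type: "widget" }, { option: "x", success: true });
const row = await repo.getKeyValue({ name: "a", type: "widget" }); // -> { option: "x", success: 1 }
const rows = await repo.search({ type: "widget" }); // all rows whose `type` column matches
```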
+import { describe, expect, it, beforeEach } from "bun:test"; +import { SqliteKVRepository } from "../base/SqliteKVRepository"; +import { BaseValueSchema, BasePrimaryKeySchema } from "ellmers-core"; + +type PrimaryKey = { + name: string; + type: string; +}; +type Value = { + option: string; + success: boolean; +}; + +export const PrimaryKeySchema: BasePrimaryKeySchema = { name: "string", type: "string" } as const; +export const ValueSchema: BaseValueSchema = { option: "string", success: "boolean" } as const; + +describe("SqliteKVRepository", () => { + describe("with default schemas (key and value)", () => { + let repository: SqliteKVRepository; + + beforeEach(() => { + repository = new SqliteKVRepository(":memory:"); + }); + + it("should store and retrieve values for a key", async () => { + const key = "key"; + const value = "value"; + await repository.put(key, value); + const output = await repository.get(key); + + expect(output).toEqual(value); + }); + it("should get undefined for a key that doesn't exist", async () => { + const key = "key"; + const value = "value"; + await repository.put(key, value); + const output = await repository.get("not-a-key"); + + expect(output == undefined).toEqual(true); + }); + }); + + describe("with complex schemas", () => { + let repository: SqliteKVRepository; + + beforeEach(() => { + repository = new SqliteKVRepository( + ":memory:", + "complex_store", + PrimaryKeySchema, + ValueSchema + ); + }); + + it("should store and retrieve values for a key", async () => { + const key = { name: "key", type: "string" }; + const value = { option: "value", success: true }; + await repository.putKeyValue(key, value); + const output = await repository.getKeyValue(key); + + expect(output?.option).toEqual("value"); + expect(!!output?.success).toEqual(true); // TODO need some conversion to boolean from 1 + }); + it("should get undefined for a key that doesn't exist", async () => { + const key = { name: "key", type: "string" }; + const output = await repository.get(key); + + expect(output == undefined).toEqual(true); + }); + }); +}); diff --git a/packages/storage/src/bun/sqlite/test/SqliteModelRepository.test.ts b/packages/storage/src/bun/sqlite/test/SqliteModelRepository.test.ts new file mode 100644 index 0000000..74edc6c --- /dev/null +++ b/packages/storage/src/bun/sqlite/test/SqliteModelRepository.test.ts @@ -0,0 +1,62 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { describe, expect, it, beforeEach } from "bun:test"; +import { setGlobalModelRepository, getGlobalModelRepository } from "ellmers-ai"; +import { SqliteModelRepository } from "../SqliteModelRepository"; +import { LOCAL_ONNX_TRANSFORMERJS } from "ellmers-ai-provider/hf-transformers/server"; + +describe("SqliteModelRepository", () => { + it("store and find model by task", async () => { + setGlobalModelRepository(new SqliteModelRepository(":memory:")); + await getGlobalModelRepository().addModel({ + name: "ONNX Xenova/LaMini-Flan-T5-783M q8", + url: "Xenova/LaMini-Flan-T5-783M", + availableOnBrowser: true, + availableOnServer: true, + provider: LOCAL_ONNX_TRANSFORMERJS, + pipeline: "text2text-generation", + }); + await getGlobalModelRepository().connectTaskToModel( + "TextGenerationTask", + 
"ONNX Xenova/LaMini-Flan-T5-783M q8" + ); + await getGlobalModelRepository().connectTaskToModel( + "TextRewritingTask", + "ONNX Xenova/LaMini-Flan-T5-783M q8" + ); + const tasks = await getGlobalModelRepository().findTasksByModel( + "ONNX Xenova/LaMini-Flan-T5-783M q8" + ); + expect(tasks).toBeDefined(); + expect(tasks?.length).toEqual(2); + const models = await getGlobalModelRepository().findModelsByTask("TextGenerationTask"); + expect(models).toBeDefined(); + expect(models?.length).toEqual(1); + }); + it("store and find model by name", async () => { + setGlobalModelRepository(new SqliteModelRepository(":memory:")); + await getGlobalModelRepository().addModel({ + name: "ONNX Xenova/LaMini-Flan-T5-783M q8", + url: "Xenova/LaMini-Flan-T5-783M", + availableOnBrowser: true, + availableOnServer: true, + provider: LOCAL_ONNX_TRANSFORMERJS, + pipeline: "text2text-generation", + }); + + const model = await getGlobalModelRepository().findByName("ONNX Xenova/LaMini-Flan-T5-783M q8"); + expect(model).toBeDefined(); + expect(model?.name).toEqual("ONNX Xenova/LaMini-Flan-T5-783M q8"); + const tasks = await getGlobalModelRepository().findTasksByModel( + "ONNX Xenova/LaMini-Flan-T5-783M q8" + ); + expect(tasks).toBeUndefined(); + const model2 = await getGlobalModelRepository().findByName("ONNX Xenova/no-exist"); + expect(model2).toBeUndefined(); + }); +}); diff --git a/packages/storage/src/bun/sqlite/test/SqliteTaskGraphRepository.test.ts b/packages/storage/src/bun/sqlite/test/SqliteTaskGraphRepository.test.ts index 2e6c705..fe4ee07 100644 --- a/packages/storage/src/bun/sqlite/test/SqliteTaskGraphRepository.test.ts +++ b/packages/storage/src/bun/sqlite/test/SqliteTaskGraphRepository.test.ts @@ -11,7 +11,7 @@ import { SingleTask, TaskOutput, DataFlow, TaskGraph, TaskRegistry } from "ellme class TestTask extends SingleTask { static readonly type = "TestTask"; - runSyncOnly(): TaskOutput { + async runReactive(): Promise { return {}; } } diff --git a/packages/storage/src/node/filesystem/FileTaskGraphRepository.ts b/packages/storage/src/node/filesystem/FileTaskGraphRepository.ts index 5dac80b..378974d 100644 --- a/packages/storage/src/node/filesystem/FileTaskGraphRepository.ts +++ b/packages/storage/src/node/filesystem/FileTaskGraphRepository.ts @@ -5,14 +5,14 @@ // * Licensed under the Apache License, Version 2.0 (the "License"); * // ******************************************************************************* -import { TaskGraphJson, TaskGraphRepository } from "ellmers-core"; +import { TaskGraphRepository } from "ellmers-core"; import { FileKVRepository } from "./base/FileKVRepository"; export class FileTaskGraphRepository extends TaskGraphRepository { - kvRepository: FileKVRepository; + kvRepository: FileKVRepository; public type = "FileTaskGraphRepository" as const; constructor(folderPath: string) { super(); - this.kvRepository = new FileKVRepository(folderPath); + this.kvRepository = new FileKVRepository(folderPath); } } diff --git a/packages/storage/src/node/filesystem/FileTaskOutputRepository.ts b/packages/storage/src/node/filesystem/FileTaskOutputRepository.ts index 054fc5b..af762ae 100644 --- a/packages/storage/src/node/filesystem/FileTaskOutputRepository.ts +++ b/packages/storage/src/node/filesystem/FileTaskOutputRepository.ts @@ -5,17 +5,28 @@ // * Licensed under the Apache License, Version 2.0 (the "License"); * // ******************************************************************************* -import { TaskInput, TaskOutput, TaskOutputDiscriminator, TaskOutputRepository } from "ellmers-core"; 
+import { + DefaultValueType, + TaskInput, + TaskOutputPrimaryKey, + TaskOutputPrimaryKeySchema, + TaskOutputRepository, +} from "ellmers-core"; import { FileKVRepository } from "./base/FileKVRepository"; export class FileTaskOutputRepository extends TaskOutputRepository { - kvRepository: FileKVRepository; + kvRepository: FileKVRepository< + TaskOutputPrimaryKey, + DefaultValueType, + typeof TaskOutputPrimaryKeySchema + >; public type = "FileTaskOutputRepository" as const; constructor(folderPath: string) { super(); - this.kvRepository = new FileKVRepository( - folderPath, - TaskOutputDiscriminator - ); + this.kvRepository = new FileKVRepository< + TaskOutputPrimaryKey, + DefaultValueType, + typeof TaskOutputPrimaryKeySchema + >(folderPath, TaskOutputPrimaryKeySchema); } } diff --git a/packages/storage/src/node/filesystem/base/FileKVRepository.ts b/packages/storage/src/node/filesystem/base/FileKVRepository.ts index a0be385..5b7aa8b 100644 --- a/packages/storage/src/node/filesystem/base/FileKVRepository.ts +++ b/packages/storage/src/node/filesystem/base/FileKVRepository.ts @@ -6,57 +6,122 @@ // ******************************************************************************* import path from "node:path"; -import { readFile, writeFile, unlink, mkdir } from "node:fs/promises"; -import { DiscriminatorSchema, KVRepository } from "ellmers-core"; -import { makeFingerprint } from "../../../util/Misc"; +import { readFile, writeFile, rm } from "node:fs/promises"; +import { mkdirSync } from "node:fs"; import { glob } from "glob"; +import { + BaseValueSchema, + BasePrimaryKeySchema, + BasicKeyType, + DefaultValueType, + DefaultValueSchema, + DefaultPrimaryKeyType, + DefaultPrimaryKeySchema, + KVRepository, +} from "ellmers-core"; -// FileKVRepository is a key-value store that uses the file system as the backend for -// simple scenarios. It does support discriminators. - +/** + * A key-value repository implementation that uses the filesystem for storage. + * Each key-value pair is stored as a separate JSON file in the specified directory. + * + * @template Key - The type of the primary key object, defaults to DefaultPrimaryKeyType + * @template Value - The type of the value object, defaults to DefaultValueType + * @template PrimaryKeySchema - The schema for the primary key, defaults to DefaultPrimaryKeySchema + * @template ValueSchema - The schema for the value, defaults to DefaultValueSchema + * @template Combined - The combined type of Key & Value + */ export class FileKVRepository< - Key = string, - Value = string, - Discriminator extends DiscriminatorSchema = DiscriminatorSchema -> extends KVRepository { + Key extends Record = DefaultPrimaryKeyType, + Value extends Record = DefaultValueType, + PrimaryKeySchema extends BasePrimaryKeySchema = typeof DefaultPrimaryKeySchema, + ValueSchema extends BaseValueSchema = typeof DefaultValueSchema, + Combined extends Key & Value = Key & Value +> extends KVRepository { private folderPath: string; - constructor(folderPath: string, discriminatorsSchema: Discriminator = {} as Discriminator) { - super(); - this.discriminatorsSchema = discriminatorsSchema; - this.folderPath = folderPath; - mkdir(this.folderPath, { recursive: true }); + /** + * Creates a new FileKVRepository instance. 
+ * + * @param folderPath - The directory path where the JSON files will be stored + * @param primaryKeySchema - Schema defining the structure of the primary key + * @param valueSchema - Schema defining the structure of the values + * @param searchable - Array of keys that can be used for searching (Note: search is not supported in this implementation) + */ + constructor( + folderPath: string, + primaryKeySchema: PrimaryKeySchema = DefaultPrimaryKeySchema as PrimaryKeySchema, + valueSchema: ValueSchema = DefaultValueSchema as ValueSchema, + searchable: Array = [] + ) { + super(primaryKeySchema, valueSchema, searchable); + this.folderPath = path.dirname(folderPath); + mkdirSync(this.folderPath, { recursive: true }); } - async put(keySimpleOrObject: Key, value: Value): Promise { - const { discriminators, key } = this.extractDiscriminators(keySimpleOrObject); - const id = typeof key === "object" ? await makeFingerprint(key) : String(key); - const filePath = await this.getFilePath(key, discriminators); - await writeFile(filePath, JSON.stringify(value)); + /** + * Stores a key-value pair in the repository + * @param key - The primary key object + * @param value - The value object to store + * @emits 'put' event with the fingerprint ID when successful + */ + async putKeyValue(key: Key, value: Value): Promise { + const filePath = await this.getFilePath(key); + try { + await writeFile(filePath, JSON.stringify(value)); + } catch (error) { + console.error("Error writing file", filePath, error); + } this.emit("put", key); } - async get(keySimpleOrObject: Key): Promise { - const { discriminators, key } = this.extractDiscriminators(keySimpleOrObject); - const id = typeof key === "object" ? await makeFingerprint(key) : String(key); - const filePath = await this.getFilePath(key, discriminators); + /** + * Retrieves a value by its key + * @param key - The primary key object to look up + * @returns The value object if found, undefined otherwise + * @emits 'get' event with the fingerprint ID and value when found + */ + async getKeyValue(key: Key): Promise { + const filePath = await this.getFilePath(key); try { const data = await readFile(filePath, "utf-8"); - this.emit("get", key); - return JSON.parse(data); + const value = JSON.parse(data) as Value; + this.emit("get", key, value); + return value; } catch (error) { + // console.info("Error getting file (may not exist)", filePath); return undefined; // File not found or read error } } - async clear(): Promise { + /** + * Deletes an entry by its key + * @param key - The primary key object of the entry to delete + * @emits 'delete' event with the fingerprint ID when successful + */ + async deleteKeyValue(key: Key): Promise { + const filePath = await this.getFilePath(key); + try { + await rm(filePath); + } catch (error) { + // console.error("Error deleting file", filePath, error); + } + this.emit("delete", key); + } + + /** + * Removes all entries from the repository + * @emits 'clearall' event when successful + */ + async deleteAll(): Promise { // Delete all files in the folder ending in .json - const globPattern = path.join(this.folderPath, "*.json"); - const filesToDelete = await glob(globPattern); - await Promise.all(filesToDelete.map((file) => unlink(file))); - this.emit("clear"); + await rm(this.folderPath, { recursive: true, force: true }); + this.emit("clearall"); } + /** + * Returns the number of entries in the repository + * @returns The total count of stored key-value pairs + */ async size(): Promise { // Count all files in the folder ending in .json 
const globPattern = path.join(this.folderPath, "*.json"); @@ -64,12 +129,21 @@ export class FileKVRepository< return files.length; } - private async getFilePath( - key: Key, - discriminators: Record - ): Promise { - const id = typeof key === "object" ? await makeFingerprint(key) : String(key); - const filename = Object.values(discriminators).concat(id).join("_"); - return path.join(this.folderPath, `${filename}.json`); + /** + * Search is not supported in the filesystem implementation. + * @throws {Error} Always throws an error indicating search is not supported + */ + async search(key: Partial): Promise { + throw new Error("Search not supported for FileKVRepository"); + } + + /** + * Generates the full filesystem path for a given key. + * @private + */ + private async getFilePath(key: Key | BasicKeyType): Promise { + const filename = await this.getKeyAsIdString(key); + const fullPath = path.join(this.folderPath, `${filename}.json`); + return fullPath; } } diff --git a/packages/storage/src/node/filesystem/test/FileKVRepository.test.ts b/packages/storage/src/node/filesystem/test/FileKVRepository.test.ts new file mode 100644 index 0000000..36c76d2 --- /dev/null +++ b/packages/storage/src/node/filesystem/test/FileKVRepository.test.ts @@ -0,0 +1,94 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { describe, expect, it, beforeEach, afterEach } from "bun:test"; +import { rmdirSync } from "fs"; +import { FileKVRepository } from "../base/FileKVRepository"; +import { BaseValueSchema, BasePrimaryKeySchema } from "ellmers-core"; + +type PrimaryKey = { + name: string; + type: string; +}; +type Value = { + option: string; + success: boolean; +}; + +export const PrimaryKeySchema: BasePrimaryKeySchema = { name: "string", type: "string" } as const; +export const ValueSchema: BaseValueSchema = { option: "string", success: "boolean" } as const; + +const testDir = ".cache/test/testing"; + +describe("FileKVRepository", () => { + let repository: FileKVRepository; + rmdirSync(testDir, { recursive: true }); + + beforeEach(() => { + repository = new FileKVRepository(testDir); + }); + afterEach(() => { + repository.deleteAll(); + }); + + describe("with default schemas (key and value)", () => { + let repository: FileKVRepository; + + beforeEach(() => { + repository = new FileKVRepository(testDir); + }); + + it("should store and retrieve values for a key", async () => { + const key = "key"; + const value = "value"; + await repository.put(key, value); + const output = await repository.get(key); + + expect(output).toEqual(value); + }); + it("should get undefined for a key that doesn't exist", async () => { + const key = "key"; + const value = "value"; + await repository.put(key, value); + const output = await repository.get("not-a-key"); + + expect(output == undefined).toEqual(true); + }); + }); + + describe("with complex schemas", () => { + let repository: FileKVRepository; + + beforeEach(async () => { + repository = new FileKVRepository(testDir, PrimaryKeySchema, ValueSchema); + }); + afterEach(async () => { + // await repository.deleteAll(); + }); + + it("should store and retrieve values for a key", async () => { + const key = { name: "key", type: "string" }; + const value = { option: "value", 
success: true }; + await repository.put(key, value); + const output = await repository.getKeyValue(key); + + expect(output?.option).toEqual("value"); + expect(!!output?.success).toEqual(true); // TODO need some conversion to boolean from 1 + + await repository.delete(key); + + const output2 = await repository.get(key); + expect(output2 == undefined).toEqual(true); + }); + it("should get undefined for a key that doesn't exist", async () => { + const key = { name: "key-unknown", type: "string" }; + const output = await repository.get(key); + + expect(output == undefined).toEqual(true); + }); + }); +}); diff --git a/packages/storage/src/node/filesystem/test/FileTaskGraphRepository.test.ts b/packages/storage/src/node/filesystem/test/FileTaskGraphRepository.test.ts index 72e6741..9566cf0 100644 --- a/packages/storage/src/node/filesystem/test/FileTaskGraphRepository.test.ts +++ b/packages/storage/src/node/filesystem/test/FileTaskGraphRepository.test.ts @@ -12,7 +12,7 @@ import { SingleTask, TaskOutput, TaskRegistry, DataFlow, TaskGraph } from "ellme class TestTask extends SingleTask { static readonly type = "TestTask"; - runSyncOnly(): TaskOutput { + async runReactive(): Promise { return {}; } } diff --git a/packages/storage/src/node/postgres/PostgresTaskGraphRepository.ts b/packages/storage/src/node/postgres/PostgresTaskGraphRepository.ts index f0e3ce7..855e44a 100644 --- a/packages/storage/src/node/postgres/PostgresTaskGraphRepository.ts +++ b/packages/storage/src/node/postgres/PostgresTaskGraphRepository.ts @@ -5,17 +5,14 @@ // * Licensed under the Apache License, Version 2.0 (the "License"); * // ******************************************************************************* -import { TaskGraphJson, TaskGraphRepository } from "ellmers-core"; +import { TaskGraphRepository } from "ellmers-core"; import { PostgresKVRepository } from "./base/PostgresKVRepository"; export class PostgresTaskGraphRepository extends TaskGraphRepository { - kvRepository: PostgresKVRepository; + kvRepository: PostgresKVRepository; public type = "PostgresTaskGraphRepository" as const; constructor(connectionString: string) { super(); - this.kvRepository = new PostgresKVRepository( - connectionString, - "task_graphs" - ); + this.kvRepository = new PostgresKVRepository(connectionString, "task_graphs"); } } diff --git a/packages/storage/src/node/postgres/PostgresTaskOutputRepository.ts b/packages/storage/src/node/postgres/PostgresTaskOutputRepository.ts index cb75b2c..e69e24b 100644 --- a/packages/storage/src/node/postgres/PostgresTaskOutputRepository.ts +++ b/packages/storage/src/node/postgres/PostgresTaskOutputRepository.ts @@ -5,18 +5,27 @@ // * Licensed under the Apache License, Version 2.0 (the "License"); * // ******************************************************************************* -import { TaskInput, TaskOutput, TaskOutputDiscriminator, TaskOutputRepository } from "ellmers-core"; +import { + DefaultValueType, + TaskOutputPrimaryKey, + TaskOutputPrimaryKeySchema, + TaskOutputRepository, +} from "ellmers-core"; import { PostgresKVRepository } from "./base/PostgresKVRepository"; export class PostgresTaskOutputRepository extends TaskOutputRepository { - kvRepository: PostgresKVRepository; + kvRepository: PostgresKVRepository< + TaskOutputPrimaryKey, + DefaultValueType, + typeof TaskOutputPrimaryKeySchema + >; public type = "PostgresTaskOutputRepository" as const; constructor(connectionString: string) { super(); this.kvRepository = new PostgresKVRepository< - TaskInput, - TaskOutput, - typeof 
TaskOutputDiscriminator - >(connectionString, "task_outputs", TaskOutputDiscriminator); + TaskOutputPrimaryKey, + DefaultValueType, + typeof TaskOutputPrimaryKeySchema + >(connectionString, "task_outputs", TaskOutputPrimaryKeySchema); } } diff --git a/packages/storage/src/node/postgres/base/PostgresKVRepository.ts b/packages/storage/src/node/postgres/base/PostgresKVRepository.ts index 3973cea..f807a38 100644 --- a/packages/storage/src/node/postgres/base/PostgresKVRepository.ts +++ b/packages/storage/src/node/postgres/base/PostgresKVRepository.ts @@ -6,56 +6,99 @@ // ******************************************************************************* import { Pool } from "pg"; -import { DiscriminatorSchema, KVRepository } from "ellmers-core"; -import { makeFingerprint } from "../../../util/Misc"; +import { + BaseValueSchema, + BasePrimaryKeySchema, + BasicKeyType, + DefaultValueSchema, + DefaultPrimaryKeySchema, + DefaultPrimaryKeyType, + DefaultValueType, +} from "ellmers-core"; +import { BaseSqlKVRepository } from "../../../util/base/BaseSqlKVRepository"; -// PostgresKVRepository is a key-value store that uses PostgreSQL as the backend for -// multi-user scenarios. It supports discriminators. +/// ****************************************************************** +/// * +/// ****************************************************************** +/// ********************** NOT TESTED YET *********************** +/// ****************************************************************** +/// * +/// ****************************************************************** +/// really... i wrote it and it passes the linter only! +/** +/** + * A PostgreSQL-based key-value repository implementation that extends BaseSqlKVRepository. + * This class provides persistent storage for key-value pairs in a PostgreSQL database, + * making it suitable for multi-user scenarios. + * + * @template Key - The type of the primary key, must be a record of basic types + * @template Value - The type of the stored value, can be any record type + * @template PrimaryKeySchema - Schema definition for the primary key + * @template ValueSchema - Schema definition for the value + * @template Combined - Combined type of Key & Value + */ export class PostgresKVRepository< - Key = string, - Value = string, - Discriminator extends DiscriminatorSchema = DiscriminatorSchema -> extends KVRepository { + Key extends Record = DefaultPrimaryKeyType, + Value extends Record = DefaultValueType, + PrimaryKeySchema extends BasePrimaryKeySchema = typeof DefaultPrimaryKeySchema, + ValueSchema extends BaseValueSchema = typeof DefaultValueSchema, + Combined extends Record = Key & Value +> extends BaseSqlKVRepository { private pool: Pool; + /** + * Creates a new PostgresKVRepository instance. 
+ * + * @param connectionString - PostgreSQL connection string + * @param table - Name of the table to store key-value pairs (defaults to "kv_store") + * @param primaryKeySchema - Schema definition for primary key columns + * @param valueSchema - Schema definition for value columns + * @param searchable - Array of columns to make searchable + */ constructor( connectionString: string, public table: string = "kv_store", - discriminatorsSchema: Discriminator = {} as Discriminator + primaryKeySchema: PrimaryKeySchema = DefaultPrimaryKeySchema as PrimaryKeySchema, + valueSchema: ValueSchema = DefaultValueSchema as ValueSchema, + searchable: Array = [] ) { - super(); - this.discriminatorsSchema = discriminatorsSchema; + super(table, primaryKeySchema, valueSchema, searchable); this.pool = new Pool({ connectionString }); - this.setupDatabase(table); + this.setupDatabase(); } - private async setupDatabase(table: string): Promise { + /** + * Initializes the database table with the required schema. + * Creates the table if it doesn't exist with primary key and value columns. + */ + private async setupDatabase(): Promise { await this.pool.query(` - CREATE TABLE IF NOT EXISTS ${this.table} ( - ${this.constructDiscriminatorColumns()} - key TEXT NOT NULL, - value JSONB NOT NULL, + CREATE TABLE IF NOT EXISTS \`${this.table}\` ( + ${this.constructPrimaryKeyColumns()}, + ${this.constructValueColumns()}, PRIMARY KEY (${this.primaryKeyColumnList()}) ) `); - } - - private constructDiscriminatorColumns(): string { - const cols = Object.entries(this.discriminatorsSchema) - .map(([key, type]) => { - // Convert the provided type to a SQL type, assuming simple mappings; adjust as necessary - const sqlType = this.mapTypeToSQL(type); - return `${key} ${sqlType} NOT NULL`; - }) - .join(", "); - if (cols.length > 0) { - return `${cols}, `; + for (const column of this.searchable) { + if (column !== this.primaryKeyColumns()[0]) { + /* Makes other columns searchable, but excludes the first column + of a primary key (which would be redundant) */ + await this.pool.query( + `CREATE INDEX IF NOT EXISTS \`${this.table}_${column as string}\` + ON \`${this.table}\` (\`${column as string}\`)` + ); + } } - return ""; } - private mapTypeToSQL(type: string): string { + /** + * Maps TypeScript/JavaScript types to corresponding PostgreSQL data types. + * + * @param type - The TypeScript/JavaScript type to map + * @returns The corresponding PostgreSQL data type + */ + protected mapTypeToSQL(type: string): string { // Basic type mapping; extend according to your needs switch (type) { case "string": @@ -68,50 +111,121 @@ export class PostgresKVRepository< } } - async put(keySimpleOrObject: Key, value: Value): Promise { - const { discriminators, key } = this.extractDiscriminators(keySimpleOrObject); - const id = typeof key === "object" ? await makeFingerprint(key) : String(key); - const values = Object.values(discriminators).concat(id, JSON.stringify(value)); - await this.pool.query( - `INSERT INTO ${this.table} (${this.primaryKeyColumnList()}, value) - VALUES (${this.primaryKeyColumns().map((i) => "?")}, ?) - ON CONFLICT (key) DO UPDATE - SET value = EXCLUDED.value`, - values - ); - this.emit("put", id, discriminators); - } + /** + * Stores or updates a key-value pair in the database. + * Uses UPSERT (INSERT ... ON CONFLICT DO UPDATE) for atomic operations. 
+ * + * @param key - The primary key object + * @param value - The value object to store + * @emits "put" event with the key when successful + */ + async putKeyValue(key: Key, value: Value): Promise { + const sql = ` + INSERT INTO \`${this.table}\` ( + ${this.primaryKeyColumnList()}, + ${this.valueColumnList()} + ) + VALUES ( + ${this.primaryKeyColumns().map((i) => "?")} + ) + ON CONFLICT (${this.primaryKeyColumnList()}) DO UPDATE + SET + ${(this.valueColumns() as string[]).map((col) => `\`${col}\` = EXCLUDED.\`${col}\``).join(", ")} + `; - async get(keySimpleOrObject: Key): Promise { - const { discriminators, key } = this.extractDiscriminators(keySimpleOrObject); - const id = typeof key === "object" ? await makeFingerprint(key) : String(key); + const primaryKeyParams = this.getPrimaryKeyAsOrderedArray(key); + const valueParams = this.getValueAsOrderedArray(value); + const params = [...primaryKeyParams, ...valueParams]; + await this.pool.query(sql, params); + this.emit("put", key); + } - const whereClauses = this.primaryKeyColumns() - .map((discriminatorKey, i) => `${discriminatorKey} = $${i + 1}`) + /** + * Retrieves a value from the database by its primary key. + * + * @param key - The primary key object to look up + * @returns The stored value or undefined if not found + * @emits "get" event with the key when successful + */ + async getKeyValue(key: Key): Promise { + const whereClauses = (this.primaryKeyColumns() as string[]) + .map((discriminatorKey, i) => `\`${discriminatorKey}\` = $${i + 1}`) .join(" AND "); - const values = Object.values(discriminators).concat(id); + const params = this.getPrimaryKeyAsOrderedArray(key); const result = await this.pool.query( - `SELECT value FROM ${this.table} WHERE ${whereClauses}`, - values + `SELECT ${this.valueColumnList()} FROM \`${this.table}\` WHERE ${whereClauses}`, + params ); if (result.rows.length > 0) { - this.emit("get", id, discriminators); - return result.rows[0].value as Value; + this.emit("get", key); + return result.rows[0] as Value; } else { return undefined; } } - async clear(): Promise { - await this.pool.query(`DELETE FROM ${this.table}`); - this.emit("clear"); + /** + * Method to be implemented by concrete repositories to search for key-value pairs + * based on a partial key. + * + * @param key - Partial key to search for + * @returns Promise resolving to an array of combined key-value objects or undefined if not found + */ + public async search(key: Partial): Promise { + const search = Object.keys(key); + if (search.length !== 1) { + //TODO: make this work with any prefix of primary key + throw new Error("Search must be a single key"); + } + + const sql = ` + SELECT * FROM \`${this.table}\` + WHERE \`${search[0]}\` = ? + `; + const result = await this.pool.query(sql, [key[search[0]]]); + if (result.rows.length > 0) { + this.emit("search"); + return result.rows; + } else { + return undefined; + } + } + + /** + * Deletes a key-value pair from the database. + * + * @param key - The primary key object to delete + * @emits "delete" event with the key when successful + */ + async deleteKeyValue(key: Key): Promise { + const whereClauses = (this.primaryKeyColumns() as string[]) + .map((key, i) => `\`${key}\` = $${i + 1}`) + .join(" AND "); + + const params = this.getPrimaryKeyAsOrderedArray(key); + await this.pool.query(`DELETE FROM \`${this.table}\` WHERE ${whereClauses}`, params); + this.emit("delete", key); + } + + /** + * Deletes all key-value pairs from the database table. 
+ * @emits "clearall" event when successful + */ + async deleteAll(): Promise { + await this.pool.query(`DELETE FROM \`${this.table}\``); + this.emit("clearall"); } + /** + * Returns the total number of key-value pairs in the database. + * + * @returns Promise resolving to the count of stored items + */ async size(): Promise { - const result = await this.pool.query(`SELECT COUNT(*) FROM ${this.table}`); + const result = await this.pool.query(`SELECT COUNT(*) FROM \`${this.table}\``); return parseInt(result.rows[0].count, 10); } } diff --git a/packages/storage/src/util/base/BaseSqlKVRepository.ts b/packages/storage/src/util/base/BaseSqlKVRepository.ts new file mode 100644 index 0000000..6913617 --- /dev/null +++ b/packages/storage/src/util/base/BaseSqlKVRepository.ts @@ -0,0 +1,183 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { + BaseValueSchema, + BasicKeyType, + BasePrimaryKeySchema, + DefaultValueType, + DefaultValueSchema, + DefaultPrimaryKeyType, + DefaultPrimaryKeySchema, + KVRepository, + BasicValueType, +} from "ellmers-core"; + +// BaseKVRepository is a key-value store that uses SQLite and Postgres use as common code + +/** + * Base class for SQL-based key-value repositories that implements common functionality + * for both SQLite and PostgreSQL database implementations. + * + * @template Key - The type of the primary key object, must be a record of basic types + * @template Value - The type of the value object being stored + * @template PrimaryKeySchema - Schema definition for the primary key + * @template ValueSchema - Schema definition for the value + * @template Combined - Combined type of Key & Value in case just combining them is not enough + */ +export abstract class BaseSqlKVRepository< + Key extends Record = DefaultPrimaryKeyType, + Value extends Record = DefaultValueType, + PrimaryKeySchema extends BasePrimaryKeySchema = typeof DefaultPrimaryKeySchema, + ValueSchema extends BaseValueSchema = typeof DefaultValueSchema, + Combined extends Record = Key & Value +> extends KVRepository { + /** + * Creates a new instance of BaseSqlKVRepository + * @param table - The name of the database table to use for storage + * @param primaryKeySchema - Schema defining the structure of the primary key + * @param valueSchema - Schema defining the structure of the stored values + * @param searchable - Array of columns to make searchable + */ + constructor( + public table: string = "kv_store", + primaryKeySchema: PrimaryKeySchema = DefaultPrimaryKeySchema as PrimaryKeySchema, + valueSchema: ValueSchema = DefaultValueSchema as ValueSchema, + protected searchable: Array = [] + ) { + super(primaryKeySchema, valueSchema, searchable); + this.validateTableAndSchema(); + } + + /** + * Maps JavaScript/TypeScript types to their corresponding SQL type + * Must be implemented by derived classes for specific SQL dialects + */ + protected abstract mapTypeToSQL(type: string): string; + + /** + * Generates the SQL column definitions for primary key fields + * @returns SQL string containing primary key column definitions + */ + protected constructPrimaryKeyColumns(): string { + const cols = Object.entries(this.primaryKeySchema) + .map(([key, type]) => { + const sqlType = 
this.mapTypeToSQL(type); + return `\`${key}\` ${sqlType} NOT NULL`; + }) + .join(", "); + return cols; + } + + /** + * Generates the SQL column definitions for value fields + * @returns SQL string containing value column definitions + */ + protected constructValueColumns(): string { + const cols = Object.entries(this.valueSchema) + .map(([key, type]) => { + const sqlType = this.mapTypeToSQL(type); + return `\`${key}\` ${sqlType} NULL`; + }) + .join(", "); + return cols; + } + + /** + * Returns a comma-separated list of primary key column names + * @returns Formatted string of primary key column names + */ + protected primaryKeyColumnList(): string { + return "`" + this.primaryKeyColumns().join("`, `") + "`"; + } + + /** + * Returns a comma-separated list of value column names + * @returns Formatted string of value column names + */ + protected valueColumnList(): string { + return "`" + this.valueColumns().join("`, `") + "`"; + } + + /** + * Converts a value object into an ordered array based on the valueSchema + * This ensures consistent parameter ordering for SQL queries + * @param value - The value object to convert + * @returns Array of values ordered according to the schema + * @throws Error if a required field is missing + */ + protected getValueAsOrderedArray(value: Value): BasicValueType[] { + const orderedParams: BasicValueType[] = []; + for (const [key, type] of Object.entries(this.valueSchema)) { + if (key in value) { + orderedParams.push(value[key]); + } else { + throw new Error(`Missing required value field: ${key}`); + } + } + return orderedParams; + } + + /** + * Converts a primary key object into an ordered array based on the primaryKeySchema + * This ensures consistent parameter ordering for SQL queries + * @param key - The primary key object to convert + * @returns Array of key values ordered according to the schema + * @throws Error if a required primary key field is missing + */ + protected getPrimaryKeyAsOrderedArray(key: Key): BasicKeyType[] { + const orderedParams: BasicKeyType[] = []; + for (const [k, type] of Object.entries(this.primaryKeySchema)) { + if (k in key) { + orderedParams.push(key[k]); + } else { + throw new Error(`Missing required primary key field: ${k}`); + } + } + return orderedParams; + } + + /** + * Validates table name and schema configurations + * Checks for: + * 1. Valid table name format + * 2. Valid schema key names + * 3. No duplicate keys between primary key and value schemas + * This is a sanity check to make sure the table and schema are valid, + * and to prevent dumb mistakes and mischievous behavior. + * @throws Error if validation fails + */ + protected validateTableAndSchema(): void { + // Check for invalid characters in table name + if (!/^[a-zA-Z_][a-zA-Z0-9_-]*$/.test(this.table)) { + throw new Error( + `Invalid table name: ${this.table}. Must start with a letter/underscore and contain only alphanumeric, underscore, or hyphen characters` + ); + } + + // Validate schema key naming + const validateSchemaKeys = (schema: Record) => { + Object.keys(schema).forEach((key) => { + if (!/^[a-zA-Z_][a-zA-Z0-9_-]*$/.test(key)) { + throw new Error( + `Invalid schema key: ${key}. Must start with a letter/underscore and contain only alphanumeric, underscore, or hyphen characters` + ); + } + }); + }; + validateSchemaKeys(this.primaryKeySchema); + validateSchemaKeys(this.valueSchema); + + // Check for key name collisions between schemas + const primaryKeys = new Set(Object.keys(this.primaryKeySchema)); + const valueKeys = Object.keys(this.valueSchema); + const duplicates = valueKeys.filter((key) => primaryKeys.has(key)); + if (duplicates.length > 0) { + throw new Error(`Duplicate keys found in schemas: ${duplicates.join(", ")}`); + } + } +}
diff --git a/packages/storage/tsconfig.json b/packages/storage/tsconfig.json index 69bf5c5..2e09cad 100644 --- a/packages/storage/tsconfig.json +++ b/packages/storage/tsconfig.json @@ -18,5 +18,5 @@ "ellmers-core": ["../core/src"] } }, - "references": [{ "path": "../core" }] + "references": [{ "path": "../core" }, { "path": "../ai" }] }
diff --git a/packages/task/src/task/JavaScriptTask.ts b/packages/task/src/task/JavaScriptTask.ts index 0fc23ad..8264a10 100644 --- a/packages/task/src/task/JavaScriptTask.ts +++ b/packages/task/src/task/JavaScriptTask.ts @@ -46,7 +46,7 @@ export class JavaScriptTask extends SingleTask { constructor(config: TaskConfig & { input?: JavaScriptTaskInput } = {}) { super(config); } - runSyncOnly() { + async runReactive() { if (this.runInputData.code) { try { const myInterpreter = new Interpreter(this.runInputData.code);
diff --git a/packages/test/package.json b/packages/test/package.json new file mode 100644 index 0000000..e20b518 --- /dev/null +++ b/packages/test/package.json @@ -0,0 +1,32 @@ +{ + "name": "ellmers-test", + "type": "module", + "version": "0.0.1", + "description": "Ellmers is a tool for building and running DAG pipelines of AI tasks.", + "scripts": { + "watch": "concurrently -c 'auto' 'bun:watch-*'", + "watch-code": "bun build --watch --no-clear-screen --target=browser --sourcemap=external --external @mediapipe/tasks-text --external @huggingface/transformers --external ellmers-core --external ellmers-ai --external ellmers-ai-provider --external ellmers-storage --outdir ./dist/ ./src/index.ts", + "watch-types": "tsc --watch --preserveWatchOutput", + "build": "bun run build-clean && bun run build-types && bun run build-code", + "build-clean": "rm -fr dist/* tsconfig.tsbuildinfo", + "build-code": "bun build --target=browser --sourcemap=external --external @mediapipe/tasks-text --external @huggingface/transformers --external ellmers-core --external ellmers-ai --external ellmers-ai-provider --external ellmers-storage --outdir ./dist/ ./src/index.ts", + "build-types": "tsc", + "lint": "eslint .
--ext ts,tsx --report-unused-disable-directives --max-warnings 0", + "test": "bun test" + }, + "exports": { + ".": { + "import": "./dist/index.js", + "types": "./dist/index.d.ts" + } + }, + "files": [ + "dist" + ], + "dependencies": { + "ellmers-core": "workspace:packages/core", + "ellmers-ai": "workspace:packages/ai", + "ellmers-ai-provider": "workspace:packages/ai-provider", + "ellmers-storage": "workspace:packages/storage" + } +} diff --git a/packages/test/src/index.ts b/packages/test/src/index.ts new file mode 100644 index 0000000..27d9f56 --- /dev/null +++ b/packages/test/src/index.ts @@ -0,0 +1,38 @@ +import { getProviderRegistry } from "ellmers-ai"; +import { + LOCAL_ONNX_TRANSFORMERJS, + registerHuggingfaceLocalTasks, +} from "ellmers-ai-provider/hf-transformers/browser"; +import { + MEDIA_PIPE_TFJS_MODEL, + registerMediaPipeTfJsLocalTasks, +} from "ellmers-ai-provider/tf-mediapipe/browser"; +import { ConcurrencyLimiter, TaskInput, TaskOutput } from "ellmers-core"; +import { InMemoryJobQueue } from "ellmers-storage/inmemory"; + +export * from "./sample/MediaPipeModelSamples"; +export * from "./sample/ONNXModelSamples"; + +export async function registerHuggingfaceLocalTasksInMemory() { + registerHuggingfaceLocalTasks(); + const ProviderRegistry = getProviderRegistry(); + const jobQueue = new InMemoryJobQueue( + "local_hf", + new ConcurrencyLimiter(1, 10), + 10 + ); + ProviderRegistry.registerQueue(LOCAL_ONNX_TRANSFORMERJS, jobQueue); + jobQueue.start(); +} + +export async function registerMediaPipeTfJsLocalInMemory() { + registerMediaPipeTfJsLocalTasks(); + const ProviderRegistry = getProviderRegistry(); + const jobQueue = new InMemoryJobQueue( + "local_media_pipe", + new ConcurrencyLimiter(1, 10), + 10 + ); + ProviderRegistry.registerQueue(MEDIA_PIPE_TFJS_MODEL, jobQueue); + jobQueue.start(); +} diff --git a/packages/test/src/sample/MediaPipeModelSamples.ts b/packages/test/src/sample/MediaPipeModelSamples.ts new file mode 100644 index 0000000..f15ae35 --- /dev/null +++ b/packages/test/src/sample/MediaPipeModelSamples.ts @@ -0,0 +1,49 @@ +import { MEDIA_PIPE_TFJS_MODEL } from "ellmers-ai-provider/tf-mediapipe/browser"; +import { getGlobalModelRepository, Model } from "ellmers-ai"; + +async function addMediaPipeModel(info: Partial, tasks: string[]) { + const name = "MEDIAPIPE " + info.name; + + const model = Object.assign( + { + provider: MEDIA_PIPE_TFJS_MODEL, + quantization: null, + normalize: true, + contextWindow: 4096, + availableOnBrowser: true, + availableOnServer: false, + parameters: null, + languageStyle: null, + usingDimensions: info.nativeDimensions ?? 
null, + }, + info, + { name } + ) as Model; + + await getGlobalModelRepository().addModel(model); + await Promise.allSettled( + tasks.map((task) => getGlobalModelRepository().connectTaskToModel(task, name)) + ); +} + +export async function registerMediaPipeTfJsLocalModels(): Promise { + await addMediaPipeModel( + { + name: "Universal Sentence Encoder", + pipeline: "text_embedder", + nativeDimensions: 100, + url: "https://storage.googleapis.com/mediapipe-tasks/text_embedder/universal_sentence_encoder.tflite", + }, + ["TextEmbeddingTask"] + ); + + await addMediaPipeModel( + { + name: "Text Encoder", + pipeline: "text_embedder", + nativeDimensions: 100, + url: "https://huggingface.co/keras-sd/text-encoder-tflite/resolve/main/text_encoder.tflite?download=true", + }, + ["TextEmbeddingTask"] + ); +} diff --git a/packages/test/src/sample/ONNXModelSamples.ts b/packages/test/src/sample/ONNXModelSamples.ts new file mode 100644 index 0000000..cc127d0 --- /dev/null +++ b/packages/test/src/sample/ONNXModelSamples.ts @@ -0,0 +1,178 @@ +import { + LOCAL_ONNX_TRANSFORMERJS, + QUANTIZATION_DATA_TYPES, +} from "ellmers-ai-provider/hf-transformers/browser"; +import { getGlobalModelRepository, Model } from "ellmers-ai"; + +async function addONNXModel(info: Partial, tasks: string[]) { + const name = info.name + ? info.name + : "ONNX " + info.url + " " + (info.quantization ?? QUANTIZATION_DATA_TYPES.q8); + + const model = Object.assign( + { + provider: LOCAL_ONNX_TRANSFORMERJS, + quantization: QUANTIZATION_DATA_TYPES.q8, + normalize: true, + contextWindow: 4096, + availableOnBrowser: true, + availableOnServer: true, + parameters: null, + languageStyle: null, + usingDimensions: info.nativeDimensions ?? null, + }, + info, + { name } + ) as Model; + + await getGlobalModelRepository().addModel(model); + await Promise.allSettled( + tasks.map((task) => getGlobalModelRepository().connectTaskToModel(task, name)) + ); +} + +export async function registerHuggingfaceLocalModels(): Promise { + await addONNXModel( + { + pipeline: "feature-extraction", + nativeDimensions: 384, + url: "Supabase/gte-small", + }, + ["TextEmbeddingTask"] + ); + + await addONNXModel( + { + pipeline: "feature-extraction", + nativeDimensions: 768, + url: "Xenova/bge-base-en-v1.5", + }, + ["TextEmbeddingTask"] + ); + + await addONNXModel( + { + pipeline: "feature-extraction", + nativeDimensions: 384, + url: "Xenova/all-MiniLM-L6-v2", + }, + ["TextEmbeddingTask"] + ); + + await addONNXModel( + { + pipeline: "feature-extraction", + nativeDimensions: 1024, + url: "WhereIsAI/UAE-Large-V1", + }, + ["TextEmbeddingTask"] + ); + + await addONNXModel( + { + pipeline: "feature-extraction", + nativeDimensions: 384, + url: "Xenova/bge-small-en-v1.5", + }, + ["TextEmbeddingTask"] + ); + await addONNXModel( + { + pipeline: "question-answering", + url: "Xenova/distilbert-base-uncased-distilled-squad", + }, + ["TextQuestionAnsweringTask"] + ); + + await addONNXModel( + { + pipeline: "zero-shot-classification", + url: "Xenova/distilbert-base-uncased-mnli", + }, + ["TextClassificationTask"] + ); + + await addONNXModel( + { + pipeline: "fill-mask", + url: "answerdotai/ModernBERT-base", + }, + ["TextClassificationTask"] + ); + + await addONNXModel( + { + pipeline: "feature-extraction", + nativeDimensions: 768, + url: "Xenova/multi-qa-mpnet-base-dot-v1", + }, + ["TextEmbeddingTask"] + ); + + await addONNXModel( + { + pipeline: "text-generation", + url: "Xenova/gpt2", + }, + ["TextGenerationTask"] + ); + + await addONNXModel( + { + pipeline: "text-generation", + url: 
"Xenova/distilgpt2", + }, + ["TextGenerationTask"] + ); + + await addONNXModel( + { + pipeline: "text2text-generation", + url: "Xenova/flan-t5-small", + }, + ["TextGenerationTask"] + ); + + await addONNXModel( + { + pipeline: "text2text-generation", + url: "Xenova/LaMini-Flan-T5-783M", + }, + ["TextGenerationTask"] + ); + + await addONNXModel( + { + pipeline: "summarization", + url: "Falconsai/text_summarization", + }, + ["TextSummaryTask"] + ); + + await addONNXModel( + { + pipeline: "translation", + url: "Xenova/nllb-200-distilled-600M", + languageStyle: "FLORES-200", + }, + ["TextTranslationTask"] + ); + + await addONNXModel( + { + pipeline: "translation", + url: "Xenova/m2m100_418M", + languageStyle: "ISO-639", + }, + ["TextTranslationTask"] + ); + + await addONNXModel( + { + pipeline: "translation", + url: "Xenova/mbart-large-50-many-to-many-mmt", + languageStyle: "ISO-639_ISO-3166-1-alpha-2", + }, + ["TextTranslationTask"] + ); +} diff --git a/packages/test/src/util/db_sqlite.ts b/packages/test/src/util/db_sqlite.ts new file mode 100644 index 0000000..fa89038 --- /dev/null +++ b/packages/test/src/util/db_sqlite.ts @@ -0,0 +1,21 @@ +const wrapper = function () { + if (process["isBun"]) { + // eslint-disable-next-line @typescript-eslint/no-var-requires + return require("bun:sqlite").Database; + } + + return require("better-sqlite3"); +}; + +const module = wrapper(); + +let db: any; + +export function getDatabase(name = ":memory:"): any { + if (!db) { + db = new module(name); + } + return db; +} + +export default module; diff --git a/packages/test/tsconfig.json b/packages/test/tsconfig.json new file mode 100644 index 0000000..ad55c7f --- /dev/null +++ b/packages/test/tsconfig.json @@ -0,0 +1,24 @@ +{ + "extends": "../../tsconfig.json", + "include": ["src/**/*"], + "files": ["src/index.ts"], + "exclude": ["**/*.test.ts", "dist"], + "compilerOptions": { + "outDir": "./dist", + "baseUrl": "./src", + "rootDir": "./src", + "paths": { + "#/*": ["./src/*"], + "ellmers-core": ["../core/src"], + "ellmers-ai": ["../ai/src"], + "ellmers-ai-provider": ["../ai-provider/src"], + "ellmers-storage": ["../storage/src"] + } + }, + "references": [ + { "path": "../core" }, + { "path": "../ai" }, + { "path": "../ai-provider" }, + { "path": "../storage" } + ] +} diff --git a/tsconfig.json b/tsconfig.json index 376192d..2116cb3 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -12,6 +12,8 @@ "allowSyntheticDefaultImports": true, "forceConsistentCasingInFileNames": true, "allowJs": true, + "emitDecoratorMetadata": true, + "experimentalDecorators": true, "declaration": true, "emitDeclarationOnly": true, "declarationMap": true,