Add custom model support (#15)
mme authored Nov 22, 2023
1 parent 28b31e8 commit fb0f7fa
Showing 6 changed files with 25 additions and 10 deletions.
4 changes: 2 additions & 2 deletions packages/core/src/beak.ts
@@ -1,5 +1,5 @@
 import { EventEmitter } from "eventemitter3";
-import { OpenAI, OpenAIModel } from "@beakjs/openai";
+import { OpenAI, OpenAIModel, CustomModel } from "@beakjs/openai";
 import {
   LLMAdapter,
   Message,
@@ -22,7 +22,7 @@ const FORMATTING_INSTRUCTIONS =
 export interface BeakConfiguration {
   openAIApiKey?: string;
   baseUrl?: string;
-  openAIModel?: OpenAIModel;
+  openAIModel?: OpenAIModel | CustomModel;
   maxFeedback?: number;
   instructions?: string;
   temperature?: number;
2 changes: 1 addition & 1 deletion packages/core/src/index.ts
@@ -1,4 +1,4 @@
 export { BeakCore } from "./beak";
 export type { FunctionDefinition } from "./types";
 export { Message, DebugLogger } from "./types";
-export type { OpenAIModel } from "@beakjs/openai";
+export type { OpenAIModel, CustomModel } from "@beakjs/openai";
7 changes: 6 additions & 1 deletion packages/openai/src/index.ts
@@ -1,4 +1,9 @@
 export { OpenAI } from "./openai";
 export { ChatCompletion } from "./chat";
-export type { OpenAIModel, OpenAIMessage, OpenAIFunction } from "./types";
+export type {
+  OpenAIModel,
+  OpenAIMessage,
+  OpenAIFunction,
+  CustomModel,
+} from "./types";
 export type { FetchChatCompletionParams } from "./chat";
14 changes: 10 additions & 4 deletions packages/openai/src/openai.ts
@@ -8,13 +8,14 @@ import {
   DebugLogger,
   NoopDebugLogger,
   DEFAULT_MODEL,
+  CustomModel,
 } from "./types";
 import { ChatCompletion, FetchChatCompletionParams } from "./chat";
 
 interface OpenAIConfiguration {
   apiKey?: string;
   baseUrl?: string;
-  model?: OpenAIModel;
+  model?: OpenAIModel | CustomModel;
   debugLogger?: DebugLogger;
 }
 
@@ -32,7 +33,7 @@ interface OpenAIEvents {
 export class OpenAI extends EventEmitter<OpenAIEvents> {
   private apiKey?: string;
   private baseUrl?: string;
-  private model: OpenAIModel;
+  private model: OpenAIModel | CustomModel;
   private debug: DebugLogger;
 
   private completionClient: ChatCompletion | null = null;
@@ -50,9 +51,14 @@ export class OpenAI extends EventEmitter<OpenAIEvents> {
 
   public async queryChatCompletion(params: FetchChatCompletionParams) {
     params = { ...params };
-    params.maxTokens ||= maxTokensForModel(this.model);
+    if (!(this.model instanceof CustomModel)) {
+      params.maxTokens ||= maxTokensForModel(this.model);
+    } else if (!params.maxTokens) {
+      throw new Error("maxTokens must be specified for custom models.");
+    }
     params.functions ||= [];
-    params.model = this.model;
+    params.model =
+      this.model instanceof CustomModel ? this.model.name : this.model;
     params.messages = this.buildPrompt(params);
     return await this.runPrompt(params);
   }
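
Taken together, the openai.ts change means that callers who configure a CustomModel must now supply maxTokens themselves, because maxTokensForModel() only knows the built-in OpenAI models. A minimal usage sketch, not part of this commit: it assumes the OpenAI constructor accepts the OpenAIConfiguration object shown above, that the CustomModel class is importable as a value from @beakjs/openai (the index re-exports it with export type only, so the exact import site is an assumption), and the gateway URL and model name are placeholders.

import { OpenAI, CustomModel } from "@beakjs/openai";

// Hypothetical setup: baseUrl and the model name are placeholders for an
// OpenAI-compatible endpoint serving a non-OpenAI model.
const client = new OpenAI({
  baseUrl: "https://my-llm-gateway.example.com/v1",
  model: new CustomModel("my-finetuned-model"),
});

// With a CustomModel, maxTokensForModel() is bypassed, so maxTokens is now
// required; omitting it throws "maxTokens must be specified for custom models."
// (Any further FetchChatCompletionParams fields are outside this diff.)
await client.queryChatCompletion({ maxTokens: 1024 });
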
4 changes: 4 additions & 0 deletions packages/openai/src/types.ts
@@ -13,6 +13,10 @@ export type OpenAIModel =
   | "gpt-4-32k-0613"
   | "gpt-3.5-turbo-16k-0613";
 
+export class CustomModel {
+  constructor(public name: string) {}
+}
+
 export interface OpenAIChatCompletionChunk {
   choices: {
     delta: {
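
CustomModel is a class rather than a plain string on purpose: wrapping the name in a nominal type lets the runtime instanceof check in queryChatCompletion tell custom models apart from the built-in OpenAIModel string union. A small sketch of that distinction, with a hypothetical model name and an assumed value import:

import { CustomModel } from "@beakjs/openai";
import type { OpenAIModel } from "@beakjs/openai";

// Mirrors the branch added to queryChatCompletion above.
function resolveModelName(model: OpenAIModel | CustomModel): string {
  return model instanceof CustomModel ? model.name : model;
}

resolveModelName(new CustomModel("mistral-7b-instruct")); // "mistral-7b-instruct"
resolveModelName("gpt-3.5-turbo-16k-0613");               // built-in model string, passed through
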
4 changes: 2 additions & 2 deletions packages/react/src/Beak.tsx
@@ -1,5 +1,5 @@
 import React, { useMemo } from "react";
-import { BeakCore, OpenAIModel, DebugLogger } from "@beakjs/core";
+import { BeakCore, OpenAIModel, CustomModel, DebugLogger } from "@beakjs/core";
 import * as DefaultIcons from "./Icons";
 
 const DEFAULT_DEBUG_LOGGER = new DebugLogger([]);
@@ -49,7 +49,7 @@ export function useBeakContext(): BeakContext {
 interface BeakProps {
   __unsafeOpenAIApiKey__?: string;
   baseUrl?: string;
-  openAIModel?: OpenAIModel;
+  openAIModel?: OpenAIModel | CustomModel;
   temperature?: number;
   instructions?: string;
   maxFeedback?: number;
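
At the React layer the same union flows through BeakProps, so an application could point Beak at a self-hosted, OpenAI-compatible endpoint roughly as below. This is a sketch under assumptions: the <Beak> component export, its children, the API key, the URL, and the model name are placeholders inferred from BeakProps rather than defined by this commit, and the value import of CustomModel is likewise assumed.

import React from "react";
import { Beak } from "@beakjs/react";
import { CustomModel } from "@beakjs/core";

// Hypothetical wiring: the key, baseUrl, and model name are placeholders.
export function App() {
  return (
    <Beak
      __unsafeOpenAIApiKey__="sk-placeholder"
      baseUrl="https://my-llm-gateway.example.com/v1"
      openAIModel={new CustomModel("my-finetuned-model")}
      instructions="You are a helpful in-app assistant."
    >
      {/* application UI goes here */}
    </Beak>
  );
}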
