diff --git a/gai-frontend/lib/chat/chat.dart b/gai-frontend/lib/chat/chat.dart index 1110f6927..2166d31e6 100644 --- a/gai-frontend/lib/chat/chat.dart +++ b/gai-frontend/lib/chat/chat.dart @@ -3,6 +3,7 @@ import 'package:orchid/api/orchid_eth/orchid_account.dart'; import 'package:orchid/api/orchid_eth/orchid_account_detail.dart'; import 'package:orchid/api/orchid_keys.dart'; import 'package:orchid/chat/model.dart'; +import 'package:orchid/chat/provider_connection.dart'; import 'package:orchid/chat/scripting/chat_scripting.dart'; import 'package:orchid/common/app_sizes.dart'; import 'package:orchid/chat/chat_settings_button.dart'; @@ -86,11 +87,12 @@ class _ChatViewState extends State { log('Error initializing from params: $e, $stack'); } - /* // Initialize scripting extension + /* ChatScripting.init( // url: 'lib/extensions/test.js', - url: 'lib/extensions/party_mode.js', + // url: 'lib/extensions/party_mode.js', + url: 'lib/extensions/filter_example.js', debugMode: true, providerManager: _providerManager, chatHistory: _chatHistory, @@ -201,8 +203,8 @@ class _ChatViewState extends State { String? modelName, }) { final message = ChatMessage( - source, - msg, + source: source, + message: msg, metadata: metadata, sourceName: sourceName, modelId: modelId, @@ -212,7 +214,7 @@ class _ChatViewState extends State { } void _addChatMessage(ChatMessage message) { - log('Adding message: ${message.msg.truncate(64)}'); + log('Adding message: ${message.message.truncate(64)}'); setState(() { _chatHistory.addMessage(message); }); @@ -302,6 +304,15 @@ class _ChatViewState extends State { return; } + // Debug hack + if (_selectedModelIds.isEmpty && + ChatScripting.enabled && + ChatScripting.instance.debugMode) { + setState(() { + _selectedModelIds = ['gpt-4o']; + }); + } + // Validate the selected models if (_selectedModelIds.isEmpty) { _addMessage( @@ -344,8 +355,14 @@ class _ChatViewState extends State { ? 
_chatHistory.getConversation() : _chatHistory.getConversation(withModelId: modelId); - await _providerManager.sendMessagesToModel( + final chatResponse = await _providerManager.sendMessagesToModel( selectedMessages, modelId, _maxTokens); + + if (chatResponse != null) { + _handleChatResponseDefaultBehavior(chatResponse); + } else { + // The provider connection should have logged the issue. Do nothing. + } } catch (e) { _addMessage( ChatMessageSource.system, 'Error querying model $modelId: $e'); @@ -353,6 +370,20 @@ class _ChatViewState extends State { } } + // The default handler for chat responses from the models (simply adds response to the chat history). + void _handleChatResponseDefaultBehavior(ChatInferenceResponse chatResponse) { + final metadata = chatResponse.metadata; + final modelId = metadata['model_id']; // or request.modelId? + log('Handle response: ${chatResponse.message}, $metadata'); + _addMessage( + ChatMessageSource.provider, + chatResponse.message, + metadata: metadata, + modelId: modelId, + modelName: _modelsState.getModelOrDefaultNullable(modelId)?.name, + ); + } + void scrollMessagesDown() { // Dispatching it to the next frame seems to mitigate overlapping scrolls. Future.delayed(millis(50), () { @@ -600,4 +631,3 @@ Future _launchURL(String urlString) async { enum AuthTokenMethod { manual, walletConnect } enum OrchataMenuItem { debug } - diff --git a/gai-frontend/lib/chat/chat_message.dart b/gai-frontend/lib/chat/chat_message.dart index eea794c2d..7fadcd296 100644 --- a/gai-frontend/lib/chat/chat_message.dart +++ b/gai-frontend/lib/chat/chat_message.dart @@ -5,7 +5,7 @@ enum ChatMessageSource { client, provider, system, internal } class ChatMessage { final ChatMessageSource source; final String sourceName; - final String msg; + final String message; final Map? metadata; // The modelId of the model that generated this message @@ -14,17 +14,15 @@ class ChatMessage { // The name of the model that generated this message final String? 
modelName; - ChatMessage( - this.source, - this.msg, { + ChatMessage({ + required this.source, + required this.message, this.metadata, this.sourceName = '', this.modelId, this.modelName, }); - String get message => msg; - String? get displayName { if (source == ChatMessageSource.provider && modelName != null) { return modelName; @@ -58,6 +56,6 @@ class ChatMessage { @override String toString() { - return 'ChatMessage(source: $source, modelId: $modelId, model: $modelName, msg: ${msg.substring(0, msg.length.clamp(0, 50))}...)'; + return 'ChatMessage(source: $source, modelId: $modelId, model: $modelName, msg: ${message.substring(0, message.length.clamp(0, 50))}...)'; } } diff --git a/gai-frontend/lib/chat/provider_connection.dart b/gai-frontend/lib/chat/provider_connection.dart index d2bade9e0..a310b6e61 100644 --- a/gai-frontend/lib/chat/provider_connection.dart +++ b/gai-frontend/lib/chat/provider_connection.dart @@ -5,32 +5,52 @@ import 'package:web_socket_channel/web_socket_channel.dart'; import 'package:orchid/api/orchid_crypto.dart'; import 'package:orchid/api/orchid_eth/orchid_ticket.dart'; import 'package:orchid/api/orchid_eth/orchid_account_detail.dart'; -import 'inference_client.dart'; import 'chat_message.dart'; -import 'package:orchid/api/orchid_log.dart'; +import 'inference_client.dart'; typedef MessageCallback = void Function(String message); -typedef ChatCallback = void Function( - String message, Map metadata); typedef VoidCallback = void Function(); typedef ErrorCallback = void Function(String error); typedef AuthTokenCallback = void Function(String token, String inferenceUrl); -class _PendingRequest { - final String requestId; +class ChatInferenceRequest { final String modelId; - final List messages; - final Map? params; + final List> preparedMessages; + final Map? 
requestParams; final DateTime timestamp; - _PendingRequest({ - required this.requestId, + ChatInferenceRequest({ required this.modelId, - required this.messages, - required this.params, + required this.preparedMessages, + required this.requestParams, }) : timestamp = DateTime.now(); } +class ChatInferenceResponse { + // Request + final ChatInferenceRequest request; + + // Result + final String message; + final Map metadata; + + ChatInferenceResponse({ + required this.request, + required this.message, + required this.metadata, + }); + + ChatMessage toChatMessage() { + return ChatMessage( + source: ChatMessageSource.provider, + message: message, + // sourceName: request.modelId, + metadata: metadata, + modelId: request.modelId, + ); + } +} + class ProviderConnection { final maxuint256 = BigInt.two.pow(256) - BigInt.one; final maxuint64 = BigInt.two.pow(64) - BigInt.one; @@ -40,7 +60,7 @@ class ProviderConnection { InferenceClient? get inferenceClient => _inferenceClient; InferenceClient? _inferenceClient; final MessageCallback onMessage; - final ChatCallback onChat; + final VoidCallback onConnect; final ErrorCallback onError; final VoidCallback onDisconnect; @@ -51,8 +71,7 @@ class ProviderConnection { final String? authToken; final AccountDetail? accountDetail; final AuthTokenCallback? onAuthToken; - final Map _requestModels = {}; - final Map _pendingRequests = {}; + bool _usingDirectAuth = false; String _generateRequestId() { @@ -62,7 +81,7 @@ class ProviderConnection { ProviderConnection({ required this.onMessage, required this.onConnect, - required this.onChat, + // required this.onChat, required this.onDisconnect, required this.onError, required this.onSystemMessage, @@ -104,7 +123,7 @@ class ProviderConnection { AccountDetail? accountDetail, String? 
authToken, required MessageCallback onMessage, - required ChatCallback onChat, + // required ChatCallback onChat, required VoidCallback onConnect, required ErrorCallback onError, required VoidCallback onDisconnect, @@ -119,7 +138,7 @@ class ProviderConnection { final connection = ProviderConnection( onMessage: onMessage, onConnect: onConnect, - onChat: onChat, + // onChat: onChat, onDisconnect: onDisconnect, onError: onError, onSystemMessage: onSystemMessage, @@ -222,16 +241,6 @@ class ProviderConnection { onMessage('Provider: $message'); switch (data['type']) { - case 'job_complete': - final requestId = data['request_id']; - final pendingRequest = - requestId != null ? _pendingRequests.remove(requestId) : null; - - onChat(data['output'], { - ...data, - 'model_id': pendingRequest?.modelId, - }); - break; case 'invoice': payInvoice(data); break; @@ -255,11 +264,16 @@ class ProviderConnection { _sendProviderMessage(message); } - Future requestInference( + Future requestInference( String modelId, List> preparedMessages, { Map? params, }) async { + var request = ChatInferenceRequest( + modelId: modelId, + preparedMessages: preparedMessages, + requestParams: params, + ); /* Requesting inference for model gpt-4o-mini Prepared messages: [{role: user, content: Hello!}, {role: assistant, content: Hello! 
How can I assist you today?}, {role: user, content: How are you?}] @@ -271,49 +285,46 @@ class ProviderConnection { if (_inferenceClient == null) { onError('No inference connection available'); - return; + return null; } } try { final requestId = _generateRequestId(); - _pendingRequests[requestId] = _PendingRequest( - requestId: requestId, - modelId: modelId, - messages: [], // Empty since we're using preparedMessages now - params: params, - ); - final allParams = { ...?params, 'request_id': requestId, }; onInternalMessage('Sending inference request:\n' - 'Model: $modelId\n' - 'Messages: ${preparedMessages}\n' - 'Params: $allParams' - ); + 'Model: $modelId\n' + 'Messages: ${preparedMessages}\n' + 'Params: $allParams'); final Map result = await _inferenceClient!.inference( messages: preparedMessages, model: modelId, params: allParams, ); - - _pendingRequests.remove(requestId); - - onChat(result['response'], { - 'type': 'job_complete', - 'output': result['response'], - 'usage': result['usage'], - 'model_id': modelId, - 'request_id': requestId, - 'estimated_prompt_tokens': result['estimated_prompt_tokens'], - }); + + final chatResult = ChatInferenceResponse( + request: request, + message: result['response'], + metadata: { + 'type': 'job_complete', + 'output': result['response'], + 'usage': result['usage'], + 'model_id': modelId, + 'request_id': requestId, + 'estimated_prompt_tokens': result['estimated_prompt_tokens'], + }); + + return chatResult; + } catch (e, stack) { onError('Failed to send inference request: $e\n$stack'); + return null; } } @@ -328,7 +339,6 @@ class ProviderConnection { void dispose() { _providerChannel?.sink.close(); - _pendingRequests.clear(); onDisconnect(); } diff --git a/gai-frontend/lib/chat/provider_manager.dart b/gai-frontend/lib/chat/provider_manager.dart index 41ccd4c2d..b5e81a224 100644 --- a/gai-frontend/lib/chat/provider_manager.dart +++ b/gai-frontend/lib/chat/provider_manager.dart @@ -11,6 +11,8 @@ class ProviderManager { late final 
Map> _providers; final VoidCallback onProviderConnected; final VoidCallback onProviderDisconnected; + + // Callback for the UI to add to the chat history final void Function(ChatMessage) onChatMessage; final ModelManager modelsState; @@ -78,21 +80,21 @@ class ProviderManager { void _addMessage( ChatMessageSource source, - String msg, { + String message, { Map? metadata, String sourceName = '', String? modelId, String? modelName, }) { - final message = ChatMessage( - source, - msg, + final chatMessage = ChatMessage( + source: source, + message: message, metadata: metadata, sourceName: sourceName, modelId: modelId, modelName: modelName, ); - onChatMessage(message); + onChatMessage(chatMessage); } // TODO: review duplication between auth modes @@ -116,17 +118,6 @@ class ProviderManager { onConnect: () { _providerConnected('Direct Auth'); }, - onChat: (String msg, Map metadata) { - _addMessage( - ChatMessageSource.provider, - msg, - metadata: metadata, - modelId: metadata['model_id'], - modelName: modelsState - .getModelOrDefaultNullable(metadata['model_id']) - ?.name, - ); - }, onDisconnect: _providerDisconnected, onError: (msg) { _addMessage(ChatMessageSource.system, 'Provider error: $msg'); @@ -202,19 +193,6 @@ class ProviderManager { onConnect: () { _providerConnected(name); }, - onChat: (String msg, Map metadata) { - log('onChat received metadata: $metadata'); - final modelId = metadata['model_id']; - log('Found model_id: $modelId'); - - _addMessage( - ChatMessageSource.provider, - msg, - metadata: metadata, - modelId: modelId, - modelName: modelsState.getModelOrDefaultNullable(modelId)?.name, - ); - }, onDisconnect: _providerDisconnected, onError: (msg) { _addMessage(ChatMessageSource.system, 'Provider error: $msg'); @@ -257,7 +235,7 @@ class ProviderManager { } // Note: This method is exposed to the scripting environment. - Future sendMessagesToModel( + Future sendMessagesToModel( List messages, String modelId, int? 
maxTokens, @@ -273,7 +251,7 @@ class ProviderManager { } // Note: This method is exposed to the scripting environment. - Future sendFormattedMessagesToModel( + Future sendFormattedMessagesToModel( List> formattedMessages, String modelId, int? maxTokens, @@ -292,7 +270,7 @@ class ProviderManager { modelName: modelInfo.name, ); - await providerConnection?.requestInference( + return providerConnection?.requestInference( modelInfo.id, formattedMessages, params: params, diff --git a/gai-frontend/lib/chat/scripting/chat_message_js.dart b/gai-frontend/lib/chat/scripting/chat_message_js.dart index b5d2eb3c1..2d1897a21 100644 --- a/gai-frontend/lib/chat/scripting/chat_message_js.dart +++ b/gai-frontend/lib/chat/scripting/chat_message_js.dart @@ -21,7 +21,7 @@ class ChatMessageJS { return ChatMessageJS( source: chatMessage.source.name, // enum name not toString() sourceName: chatMessage.sourceName, - msg: chatMessage.msg, + msg: chatMessage.message, metadata: jsonEncode(chatMessage.metadata).toJS, modelId: chatMessage.modelId, modelName: chatMessage.modelName, @@ -35,8 +35,8 @@ class ChatMessageJS { static ChatMessage toChatMessage(ChatMessageJS chatMessageJS) { return ChatMessage( - ChatMessageSource.values.byName(chatMessageJS.source), - chatMessageJS.msg, + source: ChatMessageSource.values.byName(chatMessageJS.source), + message: chatMessageJS.msg, // TODO: // metadata: jsonDecode((chatMessageJS.metadata ?? 
"").toString()), diff --git a/gai-frontend/lib/chat/scripting/chat_scripting.dart b/gai-frontend/lib/chat/scripting/chat_scripting.dart index 4459355e9..4bbc4c481 100644 --- a/gai-frontend/lib/chat/scripting/chat_scripting.dart +++ b/gai-frontend/lib/chat/scripting/chat_scripting.dart @@ -1,6 +1,7 @@ import 'package:orchid/chat/chat_history.dart'; import 'package:orchid/chat/chat_message.dart'; import 'package:orchid/chat/model.dart'; +import 'package:orchid/chat/provider_connection.dart'; import 'package:orchid/chat/provider_manager.dart'; import 'package:orchid/gui-orchid/lib/orchid/orchid.dart'; import 'dart:js_interop'; @@ -96,9 +97,11 @@ class ChatScripting { // Items that need to be copied before each invocation of the JS scripting extension void updatePerCallBindings({List? userSelectedModels}) { - chatHistoryJS = ChatMessageJS.fromChatMessages(chatHistory.messages).jsify() as JSArray; + chatHistoryJS = + ChatMessageJS.fromChatMessages(chatHistory.messages).jsify() as JSArray; if (userSelectedModels != null) { - userSelectedModelsJS = ModelInfoJS.fromModelInfos(userSelectedModels).jsify() as JSArray; + userSelectedModelsJS = + ModelInfoJS.fromModelInfos(userSelectedModels).jsify() as JSArray; } if (debugMode) { evalExtensionScript(); @@ -121,38 +124,47 @@ class ChatScripting { void addChatMessageFromJS(ChatMessageJS message) { log("Add chat message: ${message.source}, ${message.msg}"); addChatMessageToUI(ChatMessageJS.toChatMessage(message)); - updatePerCallBindings(); // History has changed + // TODO: This can cause looping, let's invert the relevant calls (e.g. history) so that this is unnecessary. + // updatePerCallBindings(); // History has changed } // Implementation of sendMessagesToModel callback function invoked from JS - // Send a list of ChatMessage to a model for inference + // Send a list of ChatMessage to a model for inference and return a promise of ChatMessageJS JSPromise sendMessagesToModelFromJS( JSArray messagesJS, String modelId, int? 
maxTokens) { log("dart: Send messages to model called."); - // We must capture the Future and return convert it to a JSPromise + return (() async { try { final listJS = (messagesJS.toDart).cast(); - final messages = ChatMessageJS.toChatMessages(listJS); - log("messages = ${messages}"); + final List messages = ChatMessageJS.toChatMessages(listJS); + log("messages = $messages"); + if (messages.isEmpty) { - return []; + log("No messages to send."); + // Wow: If you forget the .jsify() here the promise will crash, even if this + // code path is not executed. + return null.jsify(); } - // Simulate delay // log("dart: simulate delay"); - // await Future.delayed(const Duration(seconds: 3)); + // await Future.delayed(const Duration(seconds: 2)); // log("dart: after delay response from sendMessagesToModel sent."); + // return ["message 1", "message 2"].jsify(); // Don't forget value to JS // Send the messages to the model - await providerManager.sendMessagesToModel(messages, modelId, maxTokens); + final ChatInferenceResponse? 
response = await providerManager + .sendMessagesToModel(messages, modelId, maxTokens); + if (response == null) { + log("No response from model."); + return null.jsify(); + } - // TODO: Fake return - return ["message 1", "message 2"].jsify(); // Don't forget value to JS + return ChatMessageJS.fromChatMessage(response.toChatMessage()).jsify(); } catch (e, stack) { log("Failed to send messages to model: $e"); log(stack.toString()); - return ["error: $e"].jsify(); + return null.jsify(); } })() .toJS; diff --git a/gai-frontend/lib/chat/scripting/chat_scripting_api.ts b/gai-frontend/lib/chat/scripting/chat_scripting_api.ts index 165df6acb..43cdf421e 100644 --- a/gai-frontend/lib/chat/scripting/chat_scripting_api.ts +++ b/gai-frontend/lib/chat/scripting/chat_scripting_api.ts @@ -45,15 +45,15 @@ declare let userSelectedModels: ReadonlyArray; declare function sendMessagesToModel( messages: Array, modelId: string, - maxTokens?: number | null, -): Promise> + maxTokens: number | null, +): Promise; // Send a list of formatted messages to a model for inference declare function sendFormattedMessagesToModel( formattedMessages: Array, modelId: string, maxTokens?: number, -): void +): Promise; // Add a chat message to the history declare function addChatMessage(chatMessage: ChatMessage): void @@ -61,7 +61,13 @@ declare function addChatMessage(chatMessage: ChatMessage): void // Extension entry point: The user has hit enter on a new prompt. 
declare function onUserPrompt(userPrompt: string): void -// Extension entry point: A response came back from inference -declare function onChatResponse(chatResponse: ChatMessage): void +function getConversation(): Array { + // Gather messages of source type 'client' or 'provider', irrespective of the provider model + return chatHistory.filter( + (message) => + message.source === ChatMessageSource.CLIENT || + message.source === ChatMessageSource.PROVIDER + ); +} console.log('Chat Scripting API loaded'); \ No newline at end of file diff --git a/gai-frontend/lib/chat/scripting/extensions/filter_example.ts b/gai-frontend/lib/chat/scripting/extensions/filter_example.ts new file mode 100644 index 000000000..0f8137ab5 --- /dev/null +++ b/gai-frontend/lib/chat/scripting/extensions/filter_example.ts @@ -0,0 +1,31 @@ +/// An example of sending the user prompt to another model for validation before the user-selected model. + +/// Let the IDE see the types from the chat_scripting_api during development. +/// + +function onUserPrompt(userPrompt: string): void { + (async () => { + // Log a system message to the chat + addChatMessage(new ChatMessage(ChatMessageSource.SYSTEM, 'Extension: Filter Example', {})); + + // Send the message to 'gpt-4o' and ask it to characterize the user prompt. + let message = new ChatMessage(ChatMessageSource.CLIENT, + `The following is a user-generated prompt. 
Please characterize it as either friendly or ` + + `unfriendly and respond with just your one word decision: ` + + `{BEGIN_PROMPT}${userPrompt}{END_PROMPT}`, + {}); + let response = await sendMessagesToModel([message], 'gpt-4o', null); + let decision = response.msg.trim().toLowerCase(); + + // Log the decision to the chat + addChatMessage(new ChatMessage(ChatMessageSource.SYSTEM, + `Extension: User prompt evaluated as: ${decision}`, {})); + + // Now send the prompt to the first user-selected model + const modelId = userSelectedModels[0].id; + message = new ChatMessage(ChatMessageSource.CLIENT, userPrompt, {}); + addChatMessage(await sendMessagesToModel([message], modelId, null)); + + })(); +} + diff --git a/gai-frontend/lib/chat/scripting/extensions/party_mode.ts b/gai-frontend/lib/chat/scripting/extensions/party_mode.ts index fddf8bfa8..f7b49d608 100644 --- a/gai-frontend/lib/chat/scripting/extensions/party_mode.ts +++ b/gai-frontend/lib/chat/scripting/extensions/party_mode.ts @@ -9,14 +9,17 @@ function onUserPrompt(userPrompt: string): void { addChatMessage(new ChatMessage(ChatMessageSource.SYSTEM, 'Extension: Party mode invoked', {})); addChatMessage(new ChatMessage(ChatMessageSource.CLIENT, userPrompt, {})); + throw new Error('History is not currently updated, fix this...'); + // Gather messages of source type 'client' or 'provider', irrespective of the model + // [See getConversation()] const filteredMessages = chatHistory.filter( (message) => message.source === ChatMessageSource.CLIENT || message.source === ChatMessageSource.PROVIDER ); - // Send them to all user-selected models + // Send to each user-selected model for (const model of userSelectedModels) { console.log(`party_mode: Sending messages to model: ${model.name}`); await sendMessagesToModel(filteredMessages, model.id, null); diff --git a/gai-frontend/lib/chat/scripting/extensions/test.ts b/gai-frontend/lib/chat/scripting/extensions/test.ts index 89ef881f7..95f4b7bce 100644 --- 
a/gai-frontend/lib/chat/scripting/extensions/test.ts +++ b/gai-frontend/lib/chat/scripting/extensions/test.ts @@ -5,11 +5,7 @@ console.log('test_script: Evaluating JavaScript code from Dart...'); console.log('test_script: Chat History:', chatHistory); - const chatMessage = new ChatMessage( - ChatMessageSource.SYSTEM, - 'Extension: Test Script', - {'foo': 'bar'} - ); + const chatMessage = new ChatMessage(ChatMessageSource.SYSTEM, 'Extension: Test Script', {'foo': 'bar'}); addChatMessage(chatMessage); const promise = sendMessagesToModel([chatMessage], 'test-model', null); diff --git a/gai-frontend/lib/chat/scripting/extensions/test2.ts b/gai-frontend/lib/chat/scripting/extensions/test_load_lib.ts similarity index 100% rename from gai-frontend/lib/chat/scripting/extensions/test2.ts rename to gai-frontend/lib/chat/scripting/extensions/test_load_lib.ts