diff --git a/src/document_loaders/error.rs b/src/document_loaders/error.rs index a5ca4e05..85917d91 100644 --- a/src/document_loaders/error.rs +++ b/src/document_loaders/error.rs @@ -22,6 +22,7 @@ pub enum LoaderError { CSVError(#[from] csv::Error), #[cfg(any(feature = "lopdf"))] + #[cfg(not(feature = "pdf-extract"))] #[error(transparent)] LoPdfError(#[from] lopdf::Error), diff --git a/src/language_models/options.rs b/src/language_models/options.rs index 60642116..1313ea47 100644 --- a/src/language_models/options.rs +++ b/src/language_models/options.rs @@ -2,7 +2,7 @@ use futures::Future; use std::{pin::Pin, sync::Arc}; use tokio::sync::Mutex; -use crate::schemas::{FunctionCallBehavior, FunctionDefinition}; +use crate::schemas::{FunctionCallBehavior, FunctionDefinition, ResponseFormat}; #[derive(Clone)] pub struct CallOptions { @@ -26,6 +26,7 @@ pub struct CallOptions { pub presence_penalty: Option<f32>, pub functions: Option<Vec<FunctionDefinition>>, pub function_call_behavior: Option<FunctionCallBehavior>, + pub response_format: Option<ResponseFormat>, pub stream_usage: Option<bool>, } @@ -53,6 +54,7 @@ impl CallOptions { presence_penalty: None, functions: None, function_call_behavior: None, + response_format: None, stream_usage: None, } } @@ -149,6 +151,11 @@ impl CallOptions { self } + pub fn with_response_format(mut self, response_format: ResponseFormat) -> Self { + self.response_format = Some(response_format); + self + } + pub fn with_stream_usage(mut self, stream_usage: bool) -> Self { self.stream_usage = Some(stream_usage); self @@ -175,6 +182,9 @@ impl CallOptions { self.function_call_behavior = incoming_options .function_call_behavior .or(self.function_call_behavior.clone()); + self.response_format = incoming_options .response_format .or(self.response_format.clone()); self.stream_usage = incoming_options.stream_usage.or(self.stream_usage); // For `Vec`, merge if both are Some; prefer incoming if only incoming is Some diff --git a/src/llm/openai/mod.rs b/src/llm/openai/mod.rs index 4ab391f4..2d4ec9e2 100644 --- 
a/src/llm/openai/mod.rs +++ b/src/llm/openai/mod.rs @@ -1,6 +1,8 @@ use std::pin::Pin; pub use async_openai::config::{AzureConfig, Config, OpenAIConfig}; + +use async_openai::types::{ChatCompletionTool, ChatCompletionToolChoiceOption, ResponseFormat}; use async_openai::{ error::OpenAIError, types::{ @@ -9,19 +11,19 @@ use async_openai::{ ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs, ChatCompletionRequestUserMessageArgs, ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart, ChatCompletionStreamOptions, - ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest, - CreateChatCompletionRequestArgs, FunctionObjectArgs, + CreateChatCompletionRequest, CreateChatCompletionRequestArgs, }, Client, }; use async_trait::async_trait; use futures::{Stream, StreamExt}; +use crate::schemas::convert::{LangchainIntoOpenAI, TryLangchainIntoOpenAI}; use crate::{ language_models::{llm::LLM, options::CallOptions, GenerateResult, LLMError, TokenUsage}, schemas::{ messages::{Message, MessageType}, - FunctionCallBehavior, StreamData, + StreamData, }, }; @@ -299,38 +301,31 @@ impl OpenAI { request_builder.stop(stop_words); } - if let Some(behavior) = &self.options.functions { - let mut functions = Vec::new(); - for f in behavior.iter() { - let tool = FunctionObjectArgs::default() - .name(f.name.clone()) - .description(f.description.clone()) - .parameters(f.parameters.clone()) - .build()?; - functions.push( - ChatCompletionToolArgs::default() - .r#type(ChatCompletionToolType::Function) - .function(tool) - .build()?, - ) - } - request_builder.tools(functions); + if let Some(functions) = &self.options.functions { + let functions: Result<Vec<ChatCompletionTool>, OpenAIError> = functions + .clone() + .into_iter() + .map(|f| f.try_into_openai()) + .collect(); + request_builder.tools(functions?); } if let Some(behavior) = &self.options.function_call_behavior { - match behavior { - FunctionCallBehavior::Auto => request_builder.tool_choice("auto"), - 
FunctionCallBehavior::None => request_builder.tool_choice("none"), - FunctionCallBehavior::Named(name) => request_builder.tool_choice(name.as_str()), - }; + request_builder + .tool_choice::<ChatCompletionToolChoiceOption>(behavior.clone().into_openai()); + } + + if let Some(response_format) = &self.options.response_format { + request_builder + .response_format::<ResponseFormat>(response_format.clone().into_openai()); } + request_builder.messages(messages); Ok(request_builder.build()?) } } #[cfg(test)] mod tests { - use crate::schemas::FunctionDefinition; use super::*; diff --git a/src/schemas/convert.rs b/src/schemas/convert.rs new file mode 100644 index 00000000..a9414b47 --- /dev/null +++ b/src/schemas/convert.rs @@ -0,0 +1,81 @@ +pub trait LangchainIntoOpenAI<T>: Sized { + fn into_openai(self) -> T; +} + +pub trait LangchainFromOpenAI<T>: Sized { + fn from_openai(openai: T) -> Self; +} + +pub trait OpenAiIntoLangchain<T>: Sized { + fn into_langchain(self) -> T; +} + +pub trait OpenAIFromLangchain<T>: Sized { + fn from_langchain(langchain: T) -> Self; +} + +impl<T, U> LangchainIntoOpenAI<U> for T +where + U: OpenAIFromLangchain<T>, +{ + fn into_openai(self) -> U { + U::from_langchain(self) + } +} + +impl<T, U> OpenAiIntoLangchain<U> for T +where + U: LangchainFromOpenAI<T>, +{ + fn into_langchain(self) -> U { + U::from_openai(self) + } +} + +// Try into and from OpenAI + +pub trait TryLangchainIntoOpenAI<T>: Sized { + type Error; + + fn try_into_openai(self) -> Result<T, Self::Error>; +} + +pub trait TryLangchainFromOpenAI<T>: Sized { + type Error; + + fn try_from_openai(openai: T) -> Result<Self, Self::Error>; +} + +pub trait TryOpenAiIntoLangchain<T>: Sized { + type Error; + + fn try_into_langchain(self) -> Result<T, Self::Error>; +} + +pub trait TryOpenAiFromLangchain<T>: Sized { + type Error; + + fn try_from_langchain(langchain: T) -> Result<Self, Self::Error>; +} + +impl<T, U> TryLangchainIntoOpenAI<U> for T +where + U: TryOpenAiFromLangchain<T>, +{ + type Error = U::Error; + + fn try_into_openai(self) -> Result<U, Self::Error> { + U::try_from_langchain(self) + } +} + +impl<T, U> TryOpenAiIntoLangchain<U> for T +where + U: TryLangchainFromOpenAI<T>, +{ + type 
Error = U::Error; + + fn try_into_langchain(self) -> Result<U, Self::Error> { + U::try_from_openai(self) + } +} diff --git a/src/schemas/mod.rs b/src/schemas/mod.rs index 6939fc09..750e4e51 100644 --- a/src/schemas/mod.rs +++ b/src/schemas/mod.rs @@ -19,5 +19,10 @@ pub use retrievers::*; mod tools_openai_like; pub use tools_openai_like::*; +pub mod response_format_openai_like; +pub use response_format_openai_like::*; + +pub mod convert; mod stream; + pub use stream::*; diff --git a/src/schemas/response_format_openai_like.rs b/src/schemas/response_format_openai_like.rs new file mode 100644 index 00000000..d32811ee --- /dev/null +++ b/src/schemas/response_format_openai_like.rs @@ -0,0 +1,35 @@ +use crate::schemas::convert::OpenAIFromLangchain; + +#[derive(Clone, Debug)] +pub enum ResponseFormat { + Text, + JsonObject, + JsonSchema { + description: Option<String>, + name: String, + schema: Option<serde_json::Value>, + strict: Option<bool>, + }, +} + +impl OpenAIFromLangchain<ResponseFormat> for async_openai::types::ResponseFormat { + fn from_langchain(langchain: ResponseFormat) -> Self { + match langchain { + ResponseFormat::Text => async_openai::types::ResponseFormat::Text, + ResponseFormat::JsonObject => async_openai::types::ResponseFormat::JsonObject, + ResponseFormat::JsonSchema { + name, + description, + schema, + strict, + } => async_openai::types::ResponseFormat::JsonSchema { + json_schema: async_openai::types::ResponseFormatJsonSchema { + name, + description, + schema, + strict, + }, + }, + } + } +} diff --git a/src/schemas/tools_openai_like.rs b/src/schemas/tools_openai_like.rs index 616b6639..306f220f 100644 --- a/src/schemas/tools_openai_like.rs +++ b/src/schemas/tools_openai_like.rs @@ -1,9 +1,12 @@ -use std::ops::Deref; - +use crate::schemas::convert::{OpenAIFromLangchain, TryOpenAiFromLangchain}; +use crate::tools::Tool; +use async_openai::types::{ + ChatCompletionNamedToolChoice, ChatCompletionTool, ChatCompletionToolArgs, + ChatCompletionToolChoiceOption, ChatCompletionToolType, FunctionName, FunctionObjectArgs, +}; 
use serde::{Deserialize, Serialize}; use serde_json::Value; - -use crate::tools::Tool; +use std::ops::Deref; #[derive(Clone, Debug)] pub enum FunctionCallBehavior { @@ -12,6 +15,23 @@ pub enum FunctionCallBehavior { Named(String), } +impl OpenAIFromLangchain<FunctionCallBehavior> for ChatCompletionToolChoiceOption { + fn from_langchain(langchain: FunctionCallBehavior) -> Self { + match langchain { + FunctionCallBehavior::Auto => ChatCompletionToolChoiceOption::Auto, + FunctionCallBehavior::None => ChatCompletionToolChoiceOption::None, + FunctionCallBehavior::Named(name) => { + ChatCompletionToolChoiceOption::Named(ChatCompletionNamedToolChoice { + r#type: ChatCompletionToolType::Function, + function: FunctionName { + name: name.to_owned(), + }, + }) + } + } + } +} + #[derive(Clone, Debug)] pub struct FunctionDefinition { pub name: String, @@ -41,6 +61,22 @@ impl FunctionDefinition { } } +impl TryOpenAiFromLangchain<FunctionDefinition> for ChatCompletionTool { + type Error = async_openai::error::OpenAIError; + fn try_from_langchain(langchain: FunctionDefinition) -> Result<Self, Self::Error> { + let tool = FunctionObjectArgs::default() + .name(langchain.name) + .description(langchain.description) + .parameters(langchain.parameters) + .build()?; + + ChatCompletionToolArgs::default() + .r#type(ChatCompletionToolType::Function) + .function(tool) + .build() + } +} + #[derive(Serialize, Deserialize, Debug)] pub struct FunctionCallResponse { pub id: String,