Skip to content

Commit

Permalink
OpenAI like 'response_format' (#279)
Browse files Browse the repository at this point in the history
* add response format and convert functions

* format

* disable lopdf error when pdf-extract is used
  • Loading branch information
linusbierhoff authored Jan 23, 2025
1 parent a244a77 commit 6c925d6
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 30 deletions.
1 change: 1 addition & 0 deletions src/document_loaders/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub enum LoaderError {
CSVError(#[from] csv::Error),

#[cfg(any(feature = "lopdf"))]
#[cfg(not(feature = "pdf-extract"))]
#[error(transparent)]
LoPdfError(#[from] lopdf::Error),

Expand Down
12 changes: 11 additions & 1 deletion src/language_models/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use futures::Future;
use std::{pin::Pin, sync::Arc};
use tokio::sync::Mutex;

use crate::schemas::{FunctionCallBehavior, FunctionDefinition};
use crate::schemas::{FunctionCallBehavior, FunctionDefinition, ResponseFormat};

#[derive(Clone)]
pub struct CallOptions {
Expand All @@ -26,6 +26,7 @@ pub struct CallOptions {
pub presence_penalty: Option<f32>,
pub functions: Option<Vec<FunctionDefinition>>,
pub function_call_behavior: Option<FunctionCallBehavior>,
pub response_format: Option<ResponseFormat>,
pub stream_usage: Option<bool>,
}

Expand Down Expand Up @@ -53,6 +54,7 @@ impl CallOptions {
presence_penalty: None,
functions: None,
function_call_behavior: None,
response_format: None,
stream_usage: None,
}
}
Expand Down Expand Up @@ -149,6 +151,11 @@ impl CallOptions {
self
}

/// Sets the OpenAI-style `response_format` for this call (builder style),
/// replacing any previously configured value.
pub fn with_response_format(mut self, response_format: ResponseFormat) -> Self {
    self.response_format = Some(response_format);
    self
}

pub fn with_stream_usage(mut self, stream_usage: bool) -> Self {
self.stream_usage = Some(stream_usage);
self
Expand All @@ -175,6 +182,9 @@ impl CallOptions {
self.function_call_behavior = incoming_options
.function_call_behavior
.or(self.function_call_behavior.clone());
self.response_format = incoming_options
.response_format
.or(self.response_format.clone());
self.stream_usage = incoming_options.stream_usage.or(self.stream_usage);

// For `Vec<String>`, merge if both are Some; prefer incoming if only incoming is Some
Expand Down
45 changes: 20 additions & 25 deletions src/llm/openai/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::pin::Pin;

pub use async_openai::config::{AzureConfig, Config, OpenAIConfig};

use async_openai::types::{ChatCompletionToolChoiceOption, ResponseFormat};
use async_openai::{
error::OpenAIError,
types::{
Expand All @@ -9,19 +11,19 @@ use async_openai::{
ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
ChatCompletionRequestUserMessageArgs, ChatCompletionRequestUserMessageContent,
ChatCompletionRequestUserMessageContentPart, ChatCompletionStreamOptions,
ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest,
CreateChatCompletionRequestArgs, FunctionObjectArgs,
CreateChatCompletionRequest, CreateChatCompletionRequestArgs,
},
Client,
};
use async_trait::async_trait;
use futures::{Stream, StreamExt};

use crate::schemas::convert::{LangchainIntoOpenAI, TryLangchainIntoOpenAI};
use crate::{
language_models::{llm::LLM, options::CallOptions, GenerateResult, LLMError, TokenUsage},
schemas::{
messages::{Message, MessageType},
FunctionCallBehavior, StreamData,
StreamData,
},
};

Expand Down Expand Up @@ -299,38 +301,31 @@ impl<C: Config> OpenAI<C> {
request_builder.stop(stop_words);
}

if let Some(behavior) = &self.options.functions {
let mut functions = Vec::new();
for f in behavior.iter() {
let tool = FunctionObjectArgs::default()
.name(f.name.clone())
.description(f.description.clone())
.parameters(f.parameters.clone())
.build()?;
functions.push(
ChatCompletionToolArgs::default()
.r#type(ChatCompletionToolType::Function)
.function(tool)
.build()?,
)
}
request_builder.tools(functions);
if let Some(functions) = &self.options.functions {
let functions: Result<Vec<_>, OpenAIError> = functions
.clone()
.into_iter()
.map(|f| f.try_into_openai())
.collect();
request_builder.tools(functions?);
}

if let Some(behavior) = &self.options.function_call_behavior {
match behavior {
FunctionCallBehavior::Auto => request_builder.tool_choice("auto"),
FunctionCallBehavior::None => request_builder.tool_choice("none"),
FunctionCallBehavior::Named(name) => request_builder.tool_choice(name.as_str()),
};
request_builder
.tool_choice::<ChatCompletionToolChoiceOption>(behavior.clone().into_openai());
}

if let Some(response_format) = &self.options.response_format {
request_builder
.response_format::<ResponseFormat>(response_format.clone().into_openai());
}

request_builder.messages(messages);
Ok(request_builder.build()?)
}
}
#[cfg(test)]
mod tests {

use crate::schemas::FunctionDefinition;

use super::*;
Expand Down
81 changes: 81 additions & 0 deletions src/schemas/convert.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/// Converts a langchain-rust type into its OpenAI (`async_openai`) counterpart.
///
/// Mirrors `std::convert::Into`: prefer implementing [`OpenAIFromLangchain`]
/// instead — a blanket impl in this module then provides this trait for free.
pub trait LangchainIntoOpenAI<T>: Sized {
    /// Performs the conversion, consuming `self`.
    fn into_openai(self) -> T;
}

/// Constructs a langchain-rust type from an OpenAI (`async_openai`) value.
///
/// Mirrors `std::convert::From`; implementing this also provides
/// [`OpenAiIntoLangchain`] on the source type via a blanket impl.
pub trait LangchainFromOpenAI<T>: Sized {
    /// Performs the conversion, consuming the OpenAI value.
    fn from_openai(openai: T) -> Self;
}

/// Converts an OpenAI (`async_openai`) type into its langchain-rust counterpart.
///
/// Counterpart of `std::convert::Into`: prefer implementing
/// [`LangchainFromOpenAI`] and relying on the blanket impl in this module.
pub trait OpenAiIntoLangchain<T>: Sized {
    /// Performs the conversion, consuming `self`.
    fn into_langchain(self) -> T;
}

/// Constructs an OpenAI (`async_openai`) type from a langchain-rust value.
///
/// Counterpart of `std::convert::From`: implement this and
/// [`LangchainIntoOpenAI`] comes for free through the blanket impl.
// NOTE(review): acronym casing is inconsistent with `OpenAiIntoLangchain`
// ("OpenAI" vs "OpenAi"); unifying would be a breaking rename — flagging only.
pub trait OpenAIFromLangchain<T>: Sized {
    /// Performs the conversion, consuming the langchain value.
    fn from_langchain(langchain: T) -> Self;
}

// Blanket impl: any `U: OpenAIFromLangchain<T>` automatically yields
// `T: LangchainIntoOpenAI<U>`, mirroring how std derives `Into` from `From`.
impl<T, U> LangchainIntoOpenAI<U> for T
where
    U: OpenAIFromLangchain<T>,
{
    fn into_openai(self) -> U {
        U::from_langchain(self)
    }
}

// Blanket impl: any `U: LangchainFromOpenAI<T>` automatically yields
// `T: OpenAiIntoLangchain<U>`, mirroring how std derives `Into` from `From`.
impl<T, U> OpenAiIntoLangchain<U> for T
where
    U: LangchainFromOpenAI<T>,
{
    fn into_langchain(self) -> U {
        U::from_openai(self)
    }
}

// Try into and from OpenAI

/// Fallible conversion of a langchain-rust type into an OpenAI type.
///
/// Mirrors `std::convert::TryInto`; provided automatically for any source
/// whose target implements [`TryOpenAiFromLangchain`].
pub trait TryLangchainIntoOpenAI<T>: Sized {
    /// Error produced when the conversion fails.
    type Error;

    /// Attempts the conversion, consuming `self`.
    fn try_into_openai(self) -> Result<T, Self::Error>;
}

/// Fallible construction of a langchain-rust type from an OpenAI value.
///
/// Mirrors `std::convert::TryFrom`; implementing this also provides
/// [`TryOpenAiIntoLangchain`] on the source type via a blanket impl.
pub trait TryLangchainFromOpenAI<T>: Sized {
    /// Error produced when the conversion fails.
    type Error;

    /// Attempts the conversion, consuming the OpenAI value.
    fn try_from_openai(openai: T) -> Result<Self, Self::Error>;
}

/// Fallible conversion of an OpenAI type into a langchain-rust type.
///
/// Mirrors `std::convert::TryInto`; provided automatically for any source
/// whose target implements [`TryLangchainFromOpenAI`].
pub trait TryOpenAiIntoLangchain<T>: Sized {
    /// Error produced when the conversion fails.
    type Error;

    /// Attempts the conversion, consuming `self`.
    fn try_into_langchain(self) -> Result<T, Self::Error>;
}

/// Fallible construction of an OpenAI type from a langchain-rust value.
///
/// Mirrors `std::convert::TryFrom`: implement this and
/// [`TryLangchainIntoOpenAI`] comes for free through the blanket impl.
pub trait TryOpenAiFromLangchain<T>: Sized {
    /// Error produced when the conversion fails.
    type Error;

    /// Attempts the conversion, consuming the langchain value.
    fn try_from_langchain(langchain: T) -> Result<Self, Self::Error>;
}

// Blanket impl: `U: TryOpenAiFromLangchain<T>` automatically yields
// `T: TryLangchainIntoOpenAI<U>`, mirroring std's `TryFrom`/`TryInto` pairing.
impl<T, U> TryLangchainIntoOpenAI<U> for T
where
    U: TryOpenAiFromLangchain<T>,
{
    type Error = U::Error;

    fn try_into_openai(self) -> Result<U, U::Error> {
        U::try_from_langchain(self)
    }
}

// Blanket impl: `U: TryLangchainFromOpenAI<T>` automatically yields
// `T: TryOpenAiIntoLangchain<U>`, mirroring std's `TryFrom`/`TryInto` pairing.
impl<T, U> TryOpenAiIntoLangchain<U> for T
where
    U: TryLangchainFromOpenAI<T>,
{
    type Error = U::Error;

    fn try_into_langchain(self) -> Result<U, U::Error> {
        U::try_from_openai(self)
    }
}
5 changes: 5 additions & 0 deletions src/schemas/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,10 @@ pub use retrievers::*;
mod tools_openai_like;
pub use tools_openai_like::*;

pub mod response_format_openai_like;
pub use response_format_openai_like::*;

pub mod convert;
mod stream;

pub use stream::*;
35 changes: 35 additions & 0 deletions src/schemas/response_format_openai_like.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use crate::schemas::convert::OpenAIFromLangchain;

/// Crate-local mirror of `async_openai::types::ResponseFormat`, so that call
/// options can express a response format without exposing `async_openai`
/// types directly.
#[derive(Clone, Debug)]
pub enum ResponseFormat {
    /// Maps to `async_openai::types::ResponseFormat::Text`.
    Text,
    /// Maps to `async_openai::types::ResponseFormat::JsonObject`.
    JsonObject,
    /// Maps to `async_openai::types::ResponseFormat::JsonSchema`; the fields
    /// correspond one-to-one to `ResponseFormatJsonSchema`.
    JsonSchema {
        description: Option<String>,
        name: String,
        schema: Option<serde_json::Value>,
        strict: Option<bool>,
    },
}

impl OpenAIFromLangchain<ResponseFormat> for async_openai::types::ResponseFormat {
fn from_langchain(langchain: ResponseFormat) -> Self {
match langchain {
ResponseFormat::Text => async_openai::types::ResponseFormat::Text,
ResponseFormat::JsonObject => async_openai::types::ResponseFormat::JsonObject,
ResponseFormat::JsonSchema {
name,
description,
schema,
strict,
} => async_openai::types::ResponseFormat::JsonSchema {
json_schema: async_openai::types::ResponseFormatJsonSchema {
name,
description,
schema,
strict,
},
},
}
}
}
44 changes: 40 additions & 4 deletions src/schemas/tools_openai_like.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use std::ops::Deref;

use crate::schemas::convert::{OpenAIFromLangchain, TryOpenAiFromLangchain};
use crate::tools::Tool;
use async_openai::types::{
ChatCompletionNamedToolChoice, ChatCompletionTool, ChatCompletionToolArgs,
ChatCompletionToolChoiceOption, ChatCompletionToolType, FunctionName, FunctionObjectArgs,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::tools::Tool;
use std::ops::Deref;

#[derive(Clone, Debug)]
pub enum FunctionCallBehavior {
Expand All @@ -12,6 +15,23 @@ pub enum FunctionCallBehavior {
Named(String),
}

/// Maps the crate-local [`FunctionCallBehavior`] onto `async_openai`'s
/// tool-choice option when building a chat completion request.
impl OpenAIFromLangchain<FunctionCallBehavior> for ChatCompletionToolChoiceOption {
    fn from_langchain(langchain: FunctionCallBehavior) -> Self {
        match langchain {
            FunctionCallBehavior::Auto => ChatCompletionToolChoiceOption::Auto,
            FunctionCallBehavior::None => ChatCompletionToolChoiceOption::None,
            FunctionCallBehavior::Named(name) => {
                // `name` is an owned `String` moved out of the enum, so it can
                // be used directly; the previous `name.to_owned()` was a
                // redundant clone (clippy: redundant_clone).
                ChatCompletionToolChoiceOption::Named(ChatCompletionNamedToolChoice {
                    r#type: ChatCompletionToolType::Function,
                    function: FunctionName { name },
                })
            }
        }
    }
}

#[derive(Clone, Debug)]
pub struct FunctionDefinition {
pub name: String,
Expand Down Expand Up @@ -41,6 +61,22 @@ impl FunctionDefinition {
}
}

/// Fallibly converts a langchain [`FunctionDefinition`] into an
/// `async_openai` chat-completion tool.
///
/// # Errors
/// Returns `OpenAIError` when either builder rejects the definition.
impl TryOpenAiFromLangchain<FunctionDefinition> for ChatCompletionTool {
    type Error = async_openai::error::OpenAIError;

    fn try_from_langchain(langchain: FunctionDefinition) -> Result<Self, Self::Error> {
        // Build the function payload first, then wrap it as a "function" tool.
        let function = FunctionObjectArgs::default()
            .description(langchain.description)
            .name(langchain.name)
            .parameters(langchain.parameters)
            .build()?;

        ChatCompletionToolArgs::default()
            .function(function)
            .r#type(ChatCompletionToolType::Function)
            .build()
    }
}

#[derive(Serialize, Deserialize, Debug)]
pub struct FunctionCallResponse {
pub id: String,
Expand Down

0 comments on commit 6c925d6

Please sign in to comment.