GigaChat LLM support #917

Open
wants to merge 3 commits into base: main
1 change: 1 addition & 0 deletions README.md
@@ -62,6 +62,7 @@ The `Langchain::LLM` module provides a unified interface for interacting with va
- AWS Bedrock
- Azure OpenAI
- Cohere
- GigaChat
- Google Gemini
- Google Vertex AI
- HuggingFace
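For context, a minimal usage sketch of the new backend through the unified `Langchain::LLM` interface. The constructor arguments and response reader are assumptions modeled on the existing LLM wrappers; they are not shown in this diff:

```ruby
require "langchain"

# Hypothetical setup: the exact constructor options for
# Langchain::LLM::Gigachat are not part of this diff.
llm = Langchain::LLM::Gigachat.new(api_key: ENV["GIGACHAT_API_KEY"])

# Assumed to mirror the other chat-capable LLM classes
response = llm.chat(messages: [{role: "user", content: "Hello!"}])
puts response.chat_completion
```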
2 changes: 2 additions & 0 deletions langchain.gemspec
@@ -76,4 +76,6 @@ Gem::Specification.new do |spec|
spec.add_development_dependency "weaviate-ruby", "~> 0.9.2"
spec.add_development_dependency "wikipedia-client", "~> 1.17.0"
spec.add_development_dependency "power_point_pptx", "~> 0.1.0"
spec.add_development_dependency "gigachat", "~> 0.1.0"
spec.add_development_dependency "event_stream_parser", "~> 1.0.0"
end
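Since `gigachat` and `event_stream_parser` are added only as development dependencies, an application that wants the GigaChat backend would presumably declare them itself, following the gem's usual pattern for optional LLM clients (a sketch, not taken from this diff):

```ruby
# Gemfile (assumed application-side setup)
gem "langchainrb"
gem "gigachat", "~> 0.1.0"
gem "event_stream_parser", "~> 1.0.0"
```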
2 changes: 2 additions & 0 deletions lib/langchain/assistant/llm/adapter.rb
@@ -18,6 +18,8 @@ def self.build(llm)
LLM::Adapters::Ollama.new
elsif llm.is_a?(Langchain::LLM::OpenAI)
LLM::Adapters::OpenAI.new
elsif llm.is_a?(Langchain::LLM::Gigachat)
LLM::Adapters::Gigachat.new
else
raise ArgumentError, "Unsupported LLM type: #{llm.class}"
end
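With this change, `Adapter.build` resolves GigaChat LLM instances to the new adapter. A sketch of the dispatch, where the `Langchain::LLM::Gigachat` constructor arguments are assumed:

```ruby
llm = Langchain::LLM::Gigachat.new(api_key: ENV["GIGACHAT_API_KEY"]) # assumed constructor

adapter = Langchain::Assistant::LLM::Adapter.build(llm)
adapter.class # => Langchain::Assistant::LLM::Adapters::Gigachat

# A class not listed in the conditional still raises ArgumentError,
# per the else branch above.
```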
101 changes: 101 additions & 0 deletions lib/langchain/assistant/llm/adapters/gigachat.rb
@@ -0,0 +1,101 @@
# frozen_string_literal: true

module Langchain
class Assistant
module LLM
module Adapters
class Gigachat < Base
# Build the chat parameters for the GigaChat LLM
#
# @param messages [Array] The messages
# @param instructions [String] The system instructions
# @param tools [Array] The tools to use
# @param tool_choice [String] The tool choice
# @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
# @return [Hash] The chat parameters
def build_chat_params(
messages:,
instructions:,
tools:,
tool_choice:,
parallel_tool_calls:
)
params = {messages: messages}
if tools.any?
params[:tools] = build_tools(tools)
params[:tool_choice] = build_tool_choice(tool_choice)
params[:parallel_tool_calls] = parallel_tool_calls
end
params
end

# Build a GigaChat message
#
# @param role [String] The role of the message
# @param content [String] The content of the message
# @param image_url [String] The image URL
# @param tool_calls [Array] The tool calls
# @param tool_call_id [String] The tool call ID
# @return [Messages::GigachatMessage] The GigaChat message
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
Messages::GigachatMessage.new(role: role, content: content, image_url: image_url, tool_calls: tool_calls, tool_call_id: tool_call_id)
end

# Extract the tool call information from the GigaChat tool call hash
#
# @param tool_call [Hash] The tool call hash
# @return [Array] The tool call information
def extract_tool_call_args(tool_call:)
tool_call_id = tool_call.dig("functions_state_id")

function_name = tool_call.dig("function_call", "name")
tool_name, method_name = function_name.split("__")

tool_arguments = tool_call.dig("function_call", "arguments")
tool_arguments = if tool_arguments.is_a?(Hash)
Langchain::Utils::HashTransformer.symbolize_keys(tool_arguments)
else
JSON.parse(tool_arguments, symbolize_names: true)
end

[tool_call_id, tool_name, method_name, tool_arguments]
end

# Build the tools for the GigaChat LLM, reusing the OpenAI function schema format
def build_tools(tools)
tools.map { |tool| tool.class.function_schemas.to_openai_format }.flatten
end

# Get the allowed assistant.tool_choice values for GigaChat
def allowed_tool_choices
["auto", "none"]
end

# Get the available tool names for GigaChat
def available_tool_names(tools)
build_tools(tools).map { |tool| tool.dig(:function, :name) }
end

def tool_role
Messages::GigachatMessage::TOOL_ROLE
end

def support_system_message?
Messages::GigachatMessage::ROLES.include?("system")
end

private

def build_tool_choice(choice)
case choice
when "auto"
choice
else
{"type" => "function", "function" => {"name" => choice}}
end
end
end
end
end
end
end
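To make the adapter's behavior concrete, here is a sketch of `build_chat_params` and `extract_tool_call_args` called directly. The tool-call hash is hypothetical, shaped only after the keys the adapter digs for (`functions_state_id`, `function_call`):

```ruby
adapter = Langchain::Assistant::LLM::Adapters::Gigachat.new

# With no tools, only :messages ends up in the request params
adapter.build_chat_params(
  messages: [{role: "user", content: "What's the weather in Moscow?"}],
  instructions: nil,
  tools: [],
  tool_choice: "auto",
  parallel_tool_calls: false
)
# => {messages: [{role: "user", content: "What's the weather in Moscow?"}]}

# Hypothetical GigaChat tool-call payload
tool_call = {
  "functions_state_id" => "abc123",
  "function_call" => {"name" => "weather__current", "arguments" => {"city" => "Moscow"}}
}

adapter.extract_tool_call_args(tool_call: tool_call)
# => ["abc123", "weather", "current", {city: "Moscow"}]
```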
141 changes: 141 additions & 0 deletions lib/langchain/assistant/messages/gigachat_message.rb
@@ -0,0 +1,141 @@
# frozen_string_literal: true

module Langchain
class Assistant
module Messages
class GigachatMessage < Base
# GigaChat uses the following roles:
ROLES = [
"system",
"assistant",
"user",
"function"
].freeze

TOOL_ROLE = "function"

# Initialize a new GigaChat message
#
# @param role [String] The role of the message
# @param content [String] The content of the message
# @param image_url [String] The URL of the image
# @param tool_calls [Array<Hash>] The tool calls made in the message
# @param tool_call_id [String] The ID of the tool call
def initialize(
role:,
content: nil,
image_url: nil,
tool_calls: [],
tool_call_id: nil
)
raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }

@role = role
# Some tools return content as JSON, hence `.to_s`
@content = content.to_s
@image_url = image_url
@tool_calls = tool_calls
@tool_call_id = tool_call_id
end

# Check if the message came from an LLM
#
# @return [Boolean] true/false whether this message was produced by an LLM
def llm?
assistant?
end

# Convert the message to a GigaChat API-compatible hash
#
# @return [Hash] The message as a GigaChat API-compatible hash
def to_hash
if assistant?
assistant_hash
elsif system?
system_hash
elsif tool?
tool_hash
elsif user?
user_hash
end
end

# Check if the message came from an LLM
#
# @return [Boolean] true/false whether this message was produced by an LLM
def assistant?
role == "assistant"
end

# Check if the message is system instructions
#
# @return [Boolean] true/false whether this message is system instructions
def system?
role == "system"
end

# Check if the message is a tool call
#
# @return [Boolean] true/false whether this message is a tool call
def tool?
role == "function"
end

# Check if the message is a user message
#
# @return [Boolean] true/false whether this message came from the user
def user?
role == "user"
end

# Convert the message to a GigaChat API-compatible hash
# @return [Hash] The message as a GigaChat API-compatible hash, with the role as "assistant"
def assistant_hash
if tool_calls.any?
{
role: "assistant",
functions: tool_calls
}
else
{
role: "assistant",
content: build_content_array
}
end
end

# Convert the message to a GigaChat API-compatible hash
# @return [Hash] The message as a GigaChat API-compatible hash, with the role as "system"
def system_hash
{
role: "system",
content: build_content_array
}
end

# Convert the message to a GigaChat API-compatible hash
# @return [Hash] The message as a GigaChat API-compatible hash, with the role as "function"
def tool_hash
{
role: "function",
function_call_id: tool_call_id,
content: build_content_array
}
end

# Convert the message to a GigaChat API-compatible hash
# @return [Hash] The message as a GigaChat API-compatible hash, with the role as "user"
def user_hash
{
role: "user",
content: build_content_array
}
end

# Builds the content value for the message. GigaChat does not accept a hash in the content field, so the plain string is returned.
# @return [String] A string of content.
def build_content_array
content
end
end
end
end
end
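A short sketch of how `GigachatMessage` serializes, based only on the branches above; the tool-call ID is a placeholder:

```ruby
user_msg = Langchain::Assistant::Messages::GigachatMessage.new(
  role: "user",
  content: "Translate 'hello' into Russian"
)
user_msg.to_hash
# => {role: "user", content: "Translate 'hello' into Russian"}

tool_msg = Langchain::Assistant::Messages::GigachatMessage.new(
  role: "function",                    # GigaChat's tool role
  content: {result: "привет"}.to_json, # tool output is stringified
  tool_call_id: "abc123"               # placeholder ID
)
tool_msg.to_hash
# => {role: "function", function_call_id: "abc123", content: "{\"result\":\"привет\"}"}
```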