From fef89f9d4e8a72d6aac280d754c49596baaede3a Mon Sep 17 00:00:00 2001 From: suhov Date: Sat, 8 Feb 2025 13:43:53 +0300 Subject: [PATCH 1/2] GigaChart support added --- Gemfile.lock | 501 ----------- README.md | 1 + langchain.gemspec | 4 +- lib/langchain/assistant/llm/adapter.rb | 2 + .../assistant/llm/adapters/gigachat.rb | 101 +++ .../assistant/messages/gigachat_message.rb | 141 +++ lib/langchain/llm/gigachat.rb | 211 +++++ .../llm/response/gigachat_response.rb | 63 ++ spec/fixtures/llm/gigachat/chat.json | 30 + spec/fixtures/llm/gigachat/chat_chunk.json | 15 + .../llm/gigachat/chat_with_function_call.json | 28 + spec/langchain/llm/gigachat_spec.rb | 815 ++++++++++++++++++ .../llm/response/gigachat_response_spec.rb | 63 ++ 13 files changed, 1473 insertions(+), 502 deletions(-) delete mode 100644 Gemfile.lock create mode 100644 lib/langchain/assistant/llm/adapters/gigachat.rb create mode 100644 lib/langchain/assistant/messages/gigachat_message.rb create mode 100644 lib/langchain/llm/gigachat.rb create mode 100644 lib/langchain/llm/response/gigachat_response.rb create mode 100644 spec/fixtures/llm/gigachat/chat.json create mode 100644 spec/fixtures/llm/gigachat/chat_chunk.json create mode 100644 spec/fixtures/llm/gigachat/chat_with_function_call.json create mode 100644 spec/langchain/llm/gigachat_spec.rb create mode 100644 spec/langchain/llm/response/gigachat_response_spec.rb diff --git a/Gemfile.lock b/Gemfile.lock deleted file mode 100644 index 89efb2ef0..000000000 --- a/Gemfile.lock +++ /dev/null @@ -1,501 +0,0 @@ -PATH - remote: . - specs: - langchainrb (0.19.3) - baran (~> 0.1.9) - csv - json-schema (~> 4) - matrix - pragmatic_segmenter (~> 0.3.0) - zeitwerk (~> 2.5) - -GEM - remote: https://rubygems.org/ - specs: - Ascii85 (1.1.1) - actionpack (7.2.2.1) - actionview (= 7.2.2.1) - activesupport (= 7.2.2.1) - nokogiri (>= 1.8.5) - racc - rack (>= 2.2.4, < 3.2) - rack-session (>= 1.0.1) - rack-test (>= 0.6.3) - rails-dom-testing (~> 2.2) - rails-html-sanitizer (~> 1.6) - useragent (~> 0.16) - actionview (7.2.2.1) - activesupport (= 7.2.2.1) - builder (~> 3.1) - erubi (~> 1.11) - rails-dom-testing (~> 2.2) - rails-html-sanitizer (~> 1.6) - activesupport (7.2.2.1) - base64 - benchmark (>= 0.3) - bigdecimal - concurrent-ruby (~> 1.0, >= 1.3.1) - connection_pool (>= 2.2.5) - drb - i18n (>= 1.6, < 2) - logger (>= 1.4.2) - minitest (>= 5.1) - securerandom (>= 0.3) - tzinfo (~> 2.0, >= 2.0.5) - addressable (2.8.6) - public_suffix (>= 2.0.2, < 6.0) - afm (0.2.2) - ai21 (0.2.1) - anthropic (0.3.2) - event_stream_parser (>= 0.3.0, < 2.0.0) - faraday (>= 1) - faraday-multipart (>= 1) - ast (2.4.2) - aws-eventstream (1.3.0) - aws-partitions (1.992.0) - aws-sdk-bedrockruntime (1.27.0) - aws-sdk-core (~> 3, >= 3.210.0) - aws-sigv4 (~> 1.5) - aws-sdk-core (3.210.0) - aws-eventstream (~> 1, >= 1.3.0) - aws-partitions (~> 1, >= 1.992.0) - aws-sigv4 (~> 1.9) - jmespath (~> 1, >= 1.6.1) - aws-sigv4 (1.10.0) - aws-eventstream (~> 1, >= 1.0.2) - baran (0.1.12) - base64 (0.2.0) - benchmark (0.4.0) - bigdecimal (3.1.8) - builder (3.3.0) - byebug (11.1.3) - childprocess (5.0.0) - chroma-db (0.6.0) - dry-monads (~> 1.6) - ruby-next (>= 0.15.0) - coderay (1.1.3) - cohere-ruby (0.9.10) - faraday (>= 2.0.1, < 3.0) - concurrent-ruby (1.3.4) - connection_pool (2.4.1) - crack (1.0.0) - bigdecimal - rexml - crass (1.0.6) - csv (3.3.2) - date (3.3.4) - diff-lcs (1.5.1) - docx (0.8.0) - nokogiri (~> 1.13, >= 1.13.0) - rubyzip (~> 2.0) - dotenv (3.1.6) - dotenv-rails (3.1.6) - dotenv (= 3.1.6) - railties (>= 6.1) - drb 
(2.2.1) - dry-configurable (1.1.0) - dry-core (~> 1.0, < 2) - zeitwerk (~> 2.6) - dry-core (1.0.1) - concurrent-ruby (~> 1.0) - zeitwerk (~> 2.6) - dry-inflector (1.0.0) - dry-initializer (3.1.1) - dry-logic (1.5.0) - concurrent-ruby (~> 1.0) - dry-core (~> 1.0, < 2) - zeitwerk (~> 2.6) - dry-monads (1.6.0) - concurrent-ruby (~> 1.0) - dry-core (~> 1.0, < 2) - zeitwerk (~> 2.6) - dry-schema (1.13.4) - concurrent-ruby (~> 1.0) - dry-configurable (~> 1.0, >= 1.0.1) - dry-core (~> 1.0, < 2) - dry-initializer (~> 3.0) - dry-logic (>= 1.4, < 2) - dry-types (>= 1.7, < 2) - zeitwerk (~> 2.6) - dry-struct (1.6.0) - dry-core (~> 1.0, < 2) - dry-types (>= 1.7, < 2) - ice_nine (~> 0.11) - zeitwerk (~> 2.6) - dry-types (1.7.2) - bigdecimal (~> 3.0) - concurrent-ruby (~> 1.0) - dry-core (~> 1.0) - dry-inflector (~> 1.0) - dry-logic (~> 1.4) - zeitwerk (~> 2.6) - dry-validation (1.10.0) - concurrent-ruby (~> 1.0) - dry-core (~> 1.0, < 2) - dry-initializer (~> 3.0) - dry-schema (>= 1.12, < 2) - zeitwerk (~> 2.6) - elastic-transport (8.3.2) - faraday (< 3) - multi_json - elasticsearch (8.2.2) - elastic-transport (~> 8.0) - elasticsearch-api (= 8.2.2) - elasticsearch-api (8.2.2) - multi_json - epsilla-ruby (0.0.4) - faraday (>= 1) - eqn (1.6.5) - treetop (>= 1.2.0) - erubi (1.13.0) - ethon (0.16.0) - ffi (>= 1.15.0) - event_stream_parser (1.0.0) - faraday (2.12.0) - faraday-net_http (>= 2.0, < 3.4) - json - logger - faraday-multipart (1.0.4) - multipart-post (~> 2) - faraday-net_http (3.3.0) - net-http - faraday-retry (2.2.1) - faraday (~> 2.0) - faraday-typhoeus (1.1.0) - faraday (~> 2.0) - typhoeus (~> 1.4) - ffi (1.16.3) - fiber-storage (1.0.0) - google-cloud-env (2.1.1) - faraday (>= 1.0, < 3.a) - google_search_results (2.0.1) - googleauth (1.11.0) - faraday (>= 1.0, < 3.a) - google-cloud-env (~> 2.1) - jwt (>= 1.4, < 3.0) - multi_json (~> 1.11) - os (>= 0.9, < 2.0) - signet (>= 0.16, < 2.a) - graphlient (0.8.0) - faraday (~> 2.0) - graphql-client - graphql (2.3.16) - base64 - fiber-storage - graphql-client (0.23.0) - activesupport (>= 3.0) - graphql (>= 1.13.0) - hashdiff (1.1.0) - hashery (2.1.2) - hnswlib (0.8.1) - httparty (0.21.0) - mini_mime (>= 1.0.0) - multi_xml (>= 0.5.2) - hugging-face (0.3.5) - faraday (>= 1.0) - i18n (1.14.6) - concurrent-ruby (~> 1.0) - ice_nine (0.11.2) - io-console (0.8.0) - irb (1.14.2) - rdoc (>= 4.0.0) - reline (>= 0.4.2) - jmespath (1.6.2) - json (2.7.2) - json-schema (4.3.0) - addressable (>= 2.8) - jwt (2.8.1) - base64 - language_server-protocol (3.17.0.3) - lint_roller (1.1.0) - llama_cpp (0.9.5) - logger (1.6.3) - loofah (2.23.1) - crass (~> 1.0.2) - nokogiri (>= 1.12.0) - mail (2.8.1) - mini_mime (>= 0.1.1) - net-imap - net-pop - net-smtp - matrix (0.4.2) - method_source (1.1.0) - milvus (0.10.3) - faraday (>= 2.0.1, < 3) - mini_mime (1.1.5) - mini_portile2 (2.8.8) - minitest (5.25.4) - mistral-ai (1.2.0) - event_stream_parser (~> 1.0) - faraday (~> 2.9) - faraday-typhoeus (~> 1.1) - multi_json (1.15.0) - multi_xml (0.7.1) - bigdecimal (~> 3.1) - multipart-post (2.4.1) - net-http (0.4.1) - uri - net-imap (0.4.11) - date - net-protocol - net-pop (0.1.2) - net-protocol - net-protocol (0.2.2) - timeout - net-smtp (0.5.0) - net-protocol - nokogiri (1.18.1) - mini_portile2 (~> 2.8.2) - racc (~> 1.4) - nokogiri (1.18.1-aarch64-linux-gnu) - racc (~> 1.4) - nokogiri (1.18.1-arm-linux-gnu) - racc (~> 1.4) - nokogiri (1.18.1-arm64-darwin) - racc (~> 1.4) - nokogiri (1.18.1-x86_64-darwin) - racc (~> 1.4) - nokogiri (1.18.1-x86_64-linux-gnu) - racc (~> 1.4) - nokogiri 
(1.18.1-x86_64-linux-musl) - racc (~> 1.4) - os (1.1.4) - paco (0.2.3) - parallel (1.25.1) - parser (3.3.3.0) - ast (~> 2.4.1) - racc - pdf-reader (2.12.0) - Ascii85 (~> 1.0) - afm (~> 0.2.1) - hashery (~> 2.0) - ruby-rc4 - ttfunk - pg (1.5.6) - pgvector (0.2.2) - pinecone (0.1.71) - dry-struct (~> 1.6.0) - dry-validation (~> 1.10.0) - httparty (~> 0.21.0) - polyglot (0.3.5) - power_point_pptx (0.1.0) - nokogiri (~> 1.13, >= 1.13.0) - rubyzip (~> 2.0) - pragmatic_segmenter (0.3.23) - unicode - pry (0.14.2) - coderay (~> 1.1) - method_source (~> 1.0) - pry-byebug (3.10.1) - byebug (~> 11.0) - pry (>= 0.13, < 0.15) - psych (5.2.1) - date - stringio - public_suffix (5.0.5) - qdrant-ruby (0.9.8) - faraday (>= 2.0.1, < 3) - racc (1.8.1) - rack (3.1.8) - rack-session (2.0.0) - rack (>= 3.0.0) - rack-test (2.1.0) - rack (>= 1.3) - rackup (2.2.1) - rack (>= 3) - rails-dom-testing (2.2.0) - activesupport (>= 5.0.0) - minitest - nokogiri (>= 1.6) - rails-html-sanitizer (1.6.2) - loofah (~> 2.21) - nokogiri (>= 1.15.7, != 1.16.7, != 1.16.6, != 1.16.5, != 1.16.4, != 1.16.3, != 1.16.2, != 1.16.1, != 1.16.0.rc1, != 1.16.0) - railties (7.2.2.1) - actionpack (= 7.2.2.1) - activesupport (= 7.2.2.1) - irb (~> 1.13) - rackup (>= 1.0.0) - rake (>= 12.2) - thor (~> 1.0, >= 1.2.2) - zeitwerk (~> 2.6) - rainbow (3.1.1) - rake (13.2.1) - rdiscount (2.2.7.3) - rdoc (6.9.1) - psych (>= 4.0.0) - regexp_parser (2.9.2) - reline (0.6.0) - io-console (~> 0.5) - replicate-ruby (0.2.3) - addressable - faraday (>= 1.0) - faraday-multipart - faraday-retry - require-hooks (0.2.2) - rexml (3.3.9) - roo (2.10.1) - nokogiri (~> 1) - rubyzip (>= 1.3.0, < 3.0.0) - roo-xls (1.2.0) - nokogiri - roo (>= 2.0.0, < 3) - spreadsheet (> 0.9.0) - rspec (3.13.0) - rspec-core (~> 3.13.0) - rspec-expectations (~> 3.13.0) - rspec-mocks (~> 3.13.0) - rspec-core (3.13.0) - rspec-support (~> 3.13.0) - rspec-expectations (3.13.0) - diff-lcs (>= 1.2.0, < 2.0) - rspec-support (~> 3.13.0) - rspec-mocks (3.13.1) - diff-lcs (>= 1.2.0, < 2.0) - rspec-support (~> 3.13.0) - rspec-support (3.13.1) - rubocop (1.64.1) - json (~> 2.3) - language_server-protocol (>= 3.17.0) - parallel (~> 1.10) - parser (>= 3.3.0.2) - rainbow (>= 2.2.2, < 4.0) - regexp_parser (>= 1.8, < 3.0) - rexml (>= 3.2.5, < 4.0) - rubocop-ast (>= 1.31.1, < 2.0) - ruby-progressbar (~> 1.7) - unicode-display_width (>= 2.4.0, < 3.0) - rubocop-ast (1.31.3) - parser (>= 3.3.1.0) - rubocop-performance (1.21.1) - rubocop (>= 1.48.1, < 2.0) - rubocop-ast (>= 1.31.1, < 2.0) - ruby-next (1.0.3) - paco (~> 0.2) - require-hooks (~> 0.2) - ruby-next-core (= 1.0.3) - ruby-next-parser (>= 3.2.2.0) - unparser (~> 0.6.0) - ruby-next-core (1.0.3) - ruby-next-parser (3.2.2.0) - parser (>= 3.0.3.1) - ruby-ole (1.2.13.1) - ruby-openai (7.1.0) - event_stream_parser (>= 0.3.0, < 2.0.0) - faraday (>= 1) - faraday-multipart (>= 1) - ruby-progressbar (1.13.0) - ruby-rc4 (0.1.5) - rubyzip (2.3.2) - safe_ruby (1.0.4) - childprocess (>= 0.3.9) - securerandom (0.4.1) - sequel (5.87.0) - bigdecimal - signet (0.19.0) - addressable (~> 2.8) - faraday (>= 0.17.5, < 3.a) - jwt (>= 1.5, < 3.0) - multi_json (~> 1.10) - spreadsheet (1.3.1) - bigdecimal - ruby-ole - standard (1.39.1) - language_server-protocol (~> 3.17.0.2) - lint_roller (~> 1.0) - rubocop (~> 1.64.0) - standard-custom (~> 1.0.0) - standard-performance (~> 1.4) - standard-custom (1.0.2) - lint_roller (~> 1.0) - rubocop (~> 1.50) - standard-performance (1.4.0) - lint_roller (~> 1.1) - rubocop-performance (~> 1.21.0) - stringio (3.1.2) - thor (1.3.2) - timeout 
(0.4.1) - treetop (1.6.12) - polyglot (~> 0.3) - ttfunk (1.8.0) - bigdecimal (~> 3.1) - typhoeus (1.4.1) - ethon (>= 0.9.0) - tzinfo (2.0.6) - concurrent-ruby (~> 1.0) - unicode (0.4.4.5) - unicode-display_width (2.5.0) - unparser (0.6.13) - diff-lcs (~> 1.3) - parser (>= 3.3.0) - uri (0.13.1) - useragent (0.16.11) - vcr (6.2.0) - weaviate-ruby (0.9.2) - faraday (>= 2.0.1, < 3.0) - graphlient (>= 0.7.0, < 0.9.0) - webmock (3.23.1) - addressable (>= 2.8.0) - crack (>= 0.3.2) - hashdiff (>= 0.4.0, < 2.0.0) - wikipedia-client (1.17.0) - addressable (~> 2.7) - yard (0.9.36) - zeitwerk (2.6.18) - -PLATFORMS - aarch64-linux - arm-linux - arm64-darwin - ruby - x86_64-darwin - x86_64-linux - x86_64-linux-musl - -DEPENDENCIES - ai21 (~> 0.2.1) - anthropic (~> 0.3) - aws-sdk-bedrockruntime (~> 1.1) - chroma-db (~> 0.6.0) - cohere-ruby (~> 0.9.10) - docx (~> 0.8.0) - dotenv-rails (~> 3.1.6) - elasticsearch (~> 8.2.0) - epsilla-ruby (~> 0.0.4) - eqn (~> 1.6.5) - faraday - google_search_results (~> 2.0.0) - googleauth - hnswlib (~> 0.8.1) - hugging-face (~> 0.3.4) - langchainrb! - llama_cpp (~> 0.9.4) - mail (~> 2.8) - milvus (~> 0.10.3) - mistral-ai - nokogiri (~> 1.13) - pdf-reader (~> 2.0) - pg (~> 1.5) - pgvector (~> 0.2.1) - pinecone (~> 0.1.6) - power_point_pptx (~> 0.1.0) - pry-byebug (~> 3.10.0) - qdrant-ruby (~> 0.9.8) - rake (~> 13.0) - rdiscount (~> 2.2.7) - replicate-ruby (~> 0.2.2) - roo (~> 2.10.0) - roo-xls (~> 1.2.0) - rspec (~> 3.0) - rubocop - ruby-openai (~> 7.1.0) - safe_ruby (~> 1.0.4) - sequel (~> 5.87.0) - standard (>= 1.35.1) - vcr - weaviate-ruby (~> 0.9.2) - webmock - wikipedia-client (~> 1.17.0) - yard (~> 0.9.34) - -BUNDLED WITH - 2.5.11 diff --git a/README.md b/README.md index f1e2b30b9..dcaaef923 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,7 @@ The `Langchain::LLM` module provides a unified interface for interacting with va - AWS Bedrock - Azure OpenAI - Cohere +- GigaChat - Google Gemini - Google Vertex AI - HuggingFace diff --git a/langchain.gemspec b/langchain.gemspec index 0f4b256f2..eadf059ed 100644 --- a/langchain.gemspec +++ b/langchain.gemspec @@ -58,7 +58,7 @@ Gem::Specification.new do |spec| spec.add_development_dependency "hnswlib", "~> 0.8.1" spec.add_development_dependency "hugging-face", "~> 0.3.4" spec.add_development_dependency "milvus", "~> 0.10.3" - spec.add_development_dependency "llama_cpp", "~> 0.9.4" + #spec.add_development_dependency "llama_cpp", "~> 0.9.4" spec.add_development_dependency "nokogiri", "~> 1.13" spec.add_development_dependency "mail", "~> 2.8" spec.add_development_dependency "mistral-ai" @@ -76,4 +76,6 @@ Gem::Specification.new do |spec| spec.add_development_dependency "weaviate-ruby", "~> 0.9.2" spec.add_development_dependency "wikipedia-client", "~> 1.17.0" spec.add_development_dependency "power_point_pptx", "~> 0.1.0" + spec.add_development_dependency "gigachat", "~> 0.1.0" + spec.add_development_dependency "event_stream_parser", "~> 1.0.0" end diff --git a/lib/langchain/assistant/llm/adapter.rb b/lib/langchain/assistant/llm/adapter.rb index 6a2e969f9..cf20f4361 100644 --- a/lib/langchain/assistant/llm/adapter.rb +++ b/lib/langchain/assistant/llm/adapter.rb @@ -18,6 +18,8 @@ def self.build(llm) LLM::Adapters::Ollama.new elsif llm.is_a?(Langchain::LLM::OpenAI) LLM::Adapters::OpenAI.new + elsif llm.is_a?(Langchain::LLM::Gigachat) + LLM::Adapters::Gigachat.new else raise ArgumentError, "Unsupported LLM type: #{llm.class}" end diff --git a/lib/langchain/assistant/llm/adapters/gigachat.rb 
b/lib/langchain/assistant/llm/adapters/gigachat.rb
new file mode 100644
index 000000000..65c3a90f6
--- /dev/null
+++ b/lib/langchain/assistant/llm/adapters/gigachat.rb
@@ -0,0 +1,101 @@
+# frozen_string_literal: true
+
+module Langchain
+  class Assistant
+    module LLM
+      module Adapters
+        class Gigachat < Base
+          # Build the chat parameters for the GigaChat LLM
+          #
+          # @param messages [Array] The messages
+          # @param instructions [String] The system instructions
+          # @param tools [Array] The tools to use
+          # @param tool_choice [String] The tool choice
+          # @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
+          # @return [Hash] The chat parameters
+          def build_chat_params(
+            messages:,
+            instructions:,
+            tools:,
+            tool_choice:,
+            parallel_tool_calls:
+          )
+            params = {messages: messages}
+            if tools.any?
+              params[:tools] = build_tools(tools)
+              params[:tool_choice] = build_tool_choice(tool_choice)
+              params[:parallel_tool_calls] = parallel_tool_calls
+            end
+            params
+          end
+
+          # Build a GigaChat message
+          #
+          # @param role [String] The role of the message
+          # @param content [String] The content of the message
+          # @param image_url [String] The image URL
+          # @param tool_calls [Array] The tool calls
+          # @param tool_call_id [String] The tool call ID
+          # @return [Messages::GigachatMessage] The GigaChat message
+          def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
+            Messages::GigachatMessage.new(role: role, content: content, image_url: image_url, tool_calls: tool_calls, tool_call_id: tool_call_id)
+          end
+
+          # Extract the tool call information from the GigaChat tool call hash
+          #
+          # @param tool_call [Hash] The tool call hash
+          # @return [Array] The tool call information
+          def extract_tool_call_args(tool_call:)
+            tool_call_id = tool_call.dig("functions_state_id")
+
+            function_name = tool_call.dig("function_call", "name")
+            tool_name, method_name = function_name.split("__")
+
+            tool_arguments = tool_call.dig("function_call", "arguments")
+            tool_arguments = if tool_arguments.is_a?(Hash)
+              Langchain::Utils::HashTransformer.symbolize_keys(tool_arguments)
+            else
+              JSON.parse(tool_arguments, symbolize_names: true)
+            end
+
+            [tool_call_id, tool_name, method_name, tool_arguments]
+          end
+
+          # Build the tools for the GigaChat LLM (OpenAI-compatible function schema format)
+          def build_tools(tools)
+            tools.map { |tool| tool.class.function_schemas.to_openai_format }.flatten
+          end
+
+          # Get the allowed assistant.tool_choice values for GigaChat
+          def allowed_tool_choices
+            ["auto", "none"]
+          end
+
+          # Get the available tool names for GigaChat
+          def available_tool_names(tools)
+            build_tools(tools).map { |tool| tool.dig(:function, :name) }
+          end
+
+          def tool_role
+            Messages::GigachatMessage::TOOL_ROLE
+          end
+
+          def support_system_message?
+            Messages::GigachatMessage::ROLES.include?("system")
+          end
+
+          private
+
+          def build_tool_choice(choice)
+            case choice
+            when "auto"
+              choice
+            else
+              {"type" => "function", "function" => {"name" => choice}}
+            end
+          end
+        end
+      end
+    end
+  end
+end
diff --git a/lib/langchain/assistant/messages/gigachat_message.rb b/lib/langchain/assistant/messages/gigachat_message.rb
new file mode 100644
index 000000000..3533cfc7d
--- /dev/null
+++ b/lib/langchain/assistant/messages/gigachat_message.rb
@@ -0,0 +1,141 @@
+# frozen_string_literal: true
+
+module Langchain
+  class Assistant
+    module Messages
+      class GigachatMessage < Base
+        # GigaChat uses the following roles:
+        ROLES = [
+          "system",
+          "assistant",
+          "user",
+          "function"
+        ].freeze
+
+        TOOL_ROLE = "function"
+
+        # Initialize a new GigaChat message
+        #
+        # @param role [String] The role of the message
+        # @param content [String] The content of the message
+        # @param image_url [String] The URL of the image
+        # @param tool_calls [Array] The tool calls made in the message
+        # @param tool_call_id [String] The ID of the tool call
+        def initialize(
+          role:,
+          content: nil,
+          image_url: nil,
+          tool_calls: [],
+          tool_call_id: nil
+        )
+          raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
+          raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
+
+          @role = role
+          # Some Tools return content as a JSON hence `.to_s`
+          @content = content.to_s
+          @image_url = image_url
+          @tool_calls = tool_calls
+          @tool_call_id = tool_call_id
+        end
+
+        # Check if the message came from an LLM
+        #
+        # @return [Boolean] true/false whether this message was produced by an LLM
+        def llm?
+          assistant?
+        end
+
+        # Convert the message to a GigaChat API-compatible hash
+        #
+        # @return [Hash] The message as a GigaChat API-compatible hash
+        def to_hash
+          if assistant?
+            assistant_hash
+          elsif system?
+            system_hash
+          elsif tool?
+            tool_hash
+          elsif user?
+            user_hash
+          end
+        end
+
+        # Check if the message came from the assistant
+        #
+        # @return [Boolean] true/false whether this message came from the assistant
+        def assistant?
+          role == "assistant"
+        end
+
+        # Check if the message contains system instructions
+        #
+        # @return [Boolean] true/false whether this message contains system instructions
+        def system?
+          role == "system"
+        end
+
+        # Check if the message is a tool call
+        #
+        # @return [Boolean] true/false whether this message is a tool call
+        def tool?
+          role == "function"
+        end
+
+        def user?
+          role == "user"
+        end
+
+        # Convert the message to a GigaChat API-compatible hash
+        # @return [Hash] The message as a GigaChat API-compatible hash, with the role as "assistant"
+        def assistant_hash
+          if tool_calls.any?
+            {
+              role: "assistant",
+              functions: tool_calls
+            }
+          else
+            {
+              role: "assistant",
+              content: build_content_array
+            }
+          end
+        end
+
+        # Convert the message to a GigaChat API-compatible hash
+        # @return [Hash] The message as a GigaChat API-compatible hash, with the role as "system"
+        def system_hash
+          {
+            role: "system",
+            content: build_content_array
+          }
+        end
+
+        # Convert the message to a GigaChat API-compatible hash
+        # @return [Hash] The message as a GigaChat API-compatible hash, with the role as "function"
+        def tool_hash
+          {
+            role: "function",
+            function_call_id: tool_call_id,
+            content: build_content_array
+          }
+        end
+
+        # Convert the message to a GigaChat API-compatible hash
+        # @return [Hash] The message as a GigaChat API-compatible hash, with the role as "user"
+        def user_hash
+          {
+            role: "user",
+            content: build_content_array
+          }
+        end
+
+        # Builds the content value for the message. GigaChat does not support a hash in content
+        # @return [String] A string of content.
+        def build_content_array
+          content
+        end
+      end
+    end
+  end
+end
diff --git a/lib/langchain/llm/gigachat.rb b/lib/langchain/llm/gigachat.rb
new file mode 100644
index 000000000..ee49b44e8
--- /dev/null
+++ b/lib/langchain/llm/gigachat.rb
@@ -0,0 +1,211 @@
+# frozen_string_literal: true
+
+module Langchain::LLM
+  # LLM interface for the GigaChat API: https://developers.sber.ru/docs/ru/gigachat/api/overview
+  #
+  # Gem requirements:
+  #     gem "gigachat", "~> 0.1.0"
+  #
+  # Usage:
+  #     llm = Langchain::LLM::Gigachat.new(
+  #       api_type: "GIGACHAT_API_CORP", # or GIGACHAT_API_PERS, GIGACHAT_API_B2B
+  #       api_key: "Yjgy...VhYw==" # your authorization data
+  #     )
+  class Gigachat < Base
+    DEFAULTS = {
+      temperature: 0.0,
+      chat_model: "GigaChat",
+      embedding_model: "GigaChat"
+    }.freeze
+
+    EMBEDDING_SIZES = {}.freeze
+
+    # Initialize a GigaChat LLM instance
+    #
+    # @param api_key [String] The API key to use
+    # @param llm_options [Hash] Options to pass to the GigaChat::Client constructor
+    def initialize(api_key:, llm_options: {}, default_options: {})
+      depends_on "gigachat"
+
+      llm_options[:log_errors] = Langchain.logger.debug? unless llm_options.key?(:log_errors)
+      llm_options[:extra_headers] = {"X-Session-ID" => SecureRandom.uuid} if llm_options.dig(:extra_headers, "X-Session-ID").nil?
+      @client = GigaChat::Client.new(client_base64: api_key, **llm_options) do |f|
+        f.response :logger, Langchain.logger, {headers: true, bodies: true, errors: true}
+      end
+      @defaults = DEFAULTS.merge(default_options)
+      chat_parameters.update(
+        model: {default: @defaults[:chat_model]},
+        # logprobs: {},
+        # top_logprobs: {},
+        temperature: {default: @defaults[:temperature]},
+        # user: {},
+        response_format: {default: @defaults[:response_format]}
+      )
+      chat_parameters.ignore(:top_k)
+    end
+
+    # Generate an embedding for a given text
+    #
+    # @param text [String] The text to generate an embedding for
+    # @param model [String] ID of the model to use
+    # @param encoding_format [String] The format to return the embeddings in. Can be either float or base64.
+    # @param user [String] A unique identifier representing your end-user
+    # @return [Langchain::LLM::GigachatResponse] Response object
+    def embed(
+      text:,
+      model: defaults[:embedding_model],
+      encoding_format: nil,
+      user: nil
+      # dimensions: @defaults[:dimensions]
+    )
+      raise ArgumentError.new("text argument is required") if text.empty?
+      raise ArgumentError.new("model argument is required") if model.empty?
+      raise ArgumentError.new("encoding_format must be either float or base64") if encoding_format && !%w[float base64].include?(encoding_format)
+
+      parameters = {
+        input: text,
+        model: model
+      }
+      parameters[:encoding_format] = encoding_format if encoding_format
+      parameters[:user] = user if user
+
+      # if dimensions
+      #   parameters[:dimensions] = dimensions
+      # elsif EMBEDDING_SIZES.key?(model)
+      #   parameters[:dimensions] = EMBEDDING_SIZES[model]
+      # end
+
+      # dimensions parameter not supported by gigachat
+      parameters.delete(:dimensions)
+
+      response = with_api_error_handling do
+        client.embeddings(parameters: parameters)
+      end
+
+      Langchain::LLM::GigachatResponse.new(response)
+    end
+
+    # rubocop:disable Style/ArgumentsForwarding
+    # Generate a completion for a given prompt
+    #
+    # @param prompt [String] The prompt to generate a completion for
+    # @param params [Hash] The parameters to pass to the `chat()` method
+    # @return [Langchain::LLM::GigachatResponse] Response object
+    def complete(prompt:, **params)
+      Langchain.logger.warn "DEPRECATED: `Langchain::LLM::Gigachat#complete` is deprecated, and will be removed in the next major version. Use `Langchain::LLM::Gigachat#chat` instead."
+
+      if params[:stop_sequences]
+        params[:stop] = params.delete(:stop_sequences)
+      end
+      # Should we still accept the `messages: []` parameter here?
+      messages = [{role: "user", content: prompt}]
+      chat(messages: messages, **params)
+    end
+
+    # rubocop:enable Style/ArgumentsForwarding
+
+    # Generate a chat completion for given messages.
+    #
+    # @param [Hash] params unified chat parameters from [Langchain::LLM::Parameters::Chat::SCHEMA]
+    # @option params [Array] :messages List of messages comprising the conversation so far
+    # @option params [String] :model ID of the model to use
+    def chat(params = {}, &block)
+      parameters = chat_parameters
+      parameters.remap(tools: :functions, tool_choice: :function_call)
+      parameters = parameters.to_params(params)
+      raise ArgumentError.new("messages argument is required") if Array(parameters[:messages]).empty?
+      raise ArgumentError.new("model argument is required") if parameters[:model].to_s.empty?
+      if parameters[:function_call] && Array(parameters[:functions]).empty?
+        raise ArgumentError.new("'tool_choice' is only allowed when 'tools' are specified.")
+      end
+
+      if block
+        @response_chunks = []
+        parameters[:stream_options] = {include_usage: true}
+        parameters[:stream] = proc do |chunk, _bytesize|
+          chunk_content = chunk.dig("choices", 0) || {}
+          @response_chunks << chunk
+          yield chunk_content
+        end
+      end
+      response = with_api_error_handling do
+        client.chat(parameters: parameters)
+      end
+      response = response_from_chunks if block
+      reset_response_chunks
+
+      Langchain::LLM::GigachatResponse.new(response)
+    end
+
+    # Generate a summary for a given text
+    #
+    # @param text [String] The text to generate a summary for
+    # @return [String] The summary
+    def summarize(text:)
+      prompt_template = Langchain::Prompt.load_from_path(
+        file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
+      )
+      prompt = prompt_template.format(text: text)
+
+      complete(prompt: prompt)
+    end
+
+    def default_dimensions
+      @defaults[:dimensions] || EMBEDDING_SIZES.fetch(defaults[:embedding_model])
+    end
+
+    private
+
+    attr_reader :response_chunks
+
+    def reset_response_chunks
+      @response_chunks = []
+    end
+
+    def with_api_error_handling
+      response = yield
+      return if response.empty?
+ + raise Langchain::LLM::ApiError.new "GigaChat API error: #{response.dig("status")}, #{response.dig("message")}" if response&.dig("status") + + response + end + + def response_from_chunks + grouped_chunks = @response_chunks + .group_by { |chunk| chunk.dig("choices", 0, "index") } + .except(nil) # the last chunk (that contains the token usage) has no index + final_choices = grouped_chunks.map do |index, chunks| + { + "index" => index, + "message" => { + "role" => "assistant", + "content" => chunks.map { |chunk| chunk.dig("choices", 0, "delta", "content") }.join, + "tool_calls" => tool_calls_from_choice_chunks(chunks) + }.compact, + "finish_reason" => chunks.last.dig("choices", 0, "finish_reason") + } + end + @response_chunks.first&.slice("id", "object", "created", "model")&.merge({"choices" => final_choices, "usage" => @response_chunks.last["usage"]}) + end + + def tool_calls_from_choice_chunks(choice_chunks) + chunks_by_index = choice_chunks.group_by { |chunk| chunk.dig("choices", 0, "index") } + res = chunks_by_index.map do |_index, chunks| + deltas = chunks.map { |chunk| chunk.dig("choices", 0, "delta") } + next unless deltas.any? { |delta| delta["function_call"] } + + func = deltas.find { |delta| delta["function_call"] } + { + "id" => deltas.find { |delta| delta["functions_state_id"]}&.dig("functions_state_id"), + "type" => func.dig("function_call", "type"), + "function" => { + "name" => func.dig("function_call", "name"), + "arguments" => chunks.map { |chunk| chunk.dig("choices", 0, "delta", "function_call", "arguments") }.join + } + } + end.compact + res.empty? ? nil : res + end + end +end diff --git a/lib/langchain/llm/response/gigachat_response.rb b/lib/langchain/llm/response/gigachat_response.rb new file mode 100644 index 000000000..b48767641 --- /dev/null +++ b/lib/langchain/llm/response/gigachat_response.rb @@ -0,0 +1,63 @@ +# frozen_string_literal: true + +module Langchain::LLM + class GigachatResponse < BaseResponse + def model + raw_response["model"] + end + + def created_at + if raw_response.dig("created") + Time.at(raw_response.dig("created")) + end + end + + def completion + completions&.dig(0, "message", "content") + end + + def role + completions&.dig(0, "message", "role") + end + + def chat_completion + completion + end + + def tool_calls + if chat_completions.dig(0, "message").has_key?("function_call") + chat_completions.dig(0, "message", "function_call") + else + [] + end + end + + def embedding + embeddings&.first + end + + def completions + raw_response.dig("choices") + end + + def chat_completions + raw_response.dig("choices") + end + + def embeddings + raw_response.dig("data")&.map { |datum| datum.dig("embedding") } + end + + def prompt_tokens + raw_response.dig("usage", "prompt_tokens") + end + + def completion_tokens + raw_response.dig("usage", "completion_tokens") + end + + def total_tokens + raw_response.dig("usage", "total_tokens") + end + end +end diff --git a/spec/fixtures/llm/gigachat/chat.json b/spec/fixtures/llm/gigachat/chat.json new file mode 100644 index 000000000..fa091dcfe --- /dev/null +++ b/spec/fixtures/llm/gigachat/chat.json @@ -0,0 +1,30 @@ +{ + "choices": [ + { + "message": { + "role": "assistant", + "content": "Здравствуйте! К сожалению, я не могу дать точный ответ на этот вопрос, так как это зависит от многих факторов. Однако обычно релиз новых функций и обновлений в GigaChat происходит постепенно и незаметно для пользователей. 
Рекомендую следить за новостями и обновлениями проекта в официальном сообществе GigaChat или на сайте разработчиков.", + "created": 1625284800, + "name": "text2image", + "functions_state_id": "77d3fb14-457a-46ba-937e-8d856156d003", + "function_call": { + "name": "string", + "arguments": {} + }, + "data_for_context": [ + {} + ] + }, + "index": 0, + "finish_reason": "stop" + } + ], + "created": 1678878333, + "model": "GigaChat", + "usage": { + "prompt_tokens": 18, + "completion_tokens": 68, + "total_tokens": 86 + }, + "object": "chat.completion" +} \ No newline at end of file diff --git a/spec/fixtures/llm/gigachat/chat_chunk.json b/spec/fixtures/llm/gigachat/chat_chunk.json new file mode 100644 index 000000000..7854dbaf6 --- /dev/null +++ b/spec/fixtures/llm/gigachat/chat_chunk.json @@ -0,0 +1,15 @@ +{ + "choices": + [ + { + "delta": + { + "content": " помощь" + }, + "index": 0 + } + ], + "created": 1683034756, + "model": "GigaChat", + "object": "chat.completion" +} \ No newline at end of file diff --git a/spec/fixtures/llm/gigachat/chat_with_function_call.json b/spec/fixtures/llm/gigachat/chat_with_function_call.json new file mode 100644 index 000000000..8c3674cc6 --- /dev/null +++ b/spec/fixtures/llm/gigachat/chat_with_function_call.json @@ -0,0 +1,28 @@ +{ + "choices": [ + { + "message": { + "role": "assistant", + "content": "", + "functions_state_id": "77d3fb14-457a-46ba-937e-8d856156d003", + "function_call": { + "name": "weather_forecast", + "arguments": { + "location": "Москва", + "format": "celsius" + } + } + }, + "index": 0, + "finish_reason": "function_call" + } + ], + "created": 1700471392, + "model": "GigaChat", + "usage": { + "prompt_tokens": 150, + "completion_tokens": 35, + "total_tokens": 185 + }, + "object": "chat.completion" +} \ No newline at end of file diff --git a/spec/langchain/llm/gigachat_spec.rb b/spec/langchain/llm/gigachat_spec.rb new file mode 100644 index 000000000..32d86cfa8 --- /dev/null +++ b/spec/langchain/llm/gigachat_spec.rb @@ -0,0 +1,815 @@ +# frozen_string_literal: true + +require "openai" + +RSpec.describe Langchain::LLM::Gigachat do + let(:subject) { described_class.new(api_key: "123", **options) } + + let(:options) { {} } + + describe "#initialize" do + it "initializes the client without any errors" do + expect { subject }.not_to raise_error + end + + it "forwards the Langchain logger to the client" do + f_mock = double("f_mock", response: nil) + + allow(::GigaChat::Client).to receive(:new) { |**, &block| block&.call(f_mock) } + + subject + + expect(f_mock).to have_received(:response).with(:logger, Langchain.logger, anything) + end + + context "when log level is DEBUG" do + before do + Langchain.logger.level = Logger::DEBUG + end + + it "configures the client to log the errors" do + allow(GigaChat::Client).to receive(:new).and_call_original + subject + expect(GigaChat::Client).to have_received(:new).with(hash_including(log_errors: true)) + end + + context "when overriding the 'log_errors' param" do + let(:options) { {llm_options: {log_errors: false}} } + + it "configures the client to NOT log the errors" do + allow(GigaChat::Client).to receive(:new).and_call_original + subject + expect(GigaChat::Client).to have_received(:new).with(hash_including(log_errors: false)) + end + end + end + + context "when log level is not DEBUG" do + before do + Langchain.logger.level = Logger::INFO + end + + it "configures the client to NOT log the errors" do + allow(GigaChat::Client).to receive(:new).and_call_original + subject + expect(GigaChat::Client).to 
have_received(:new).with(hash_including(log_errors: false)) + end + + context "when overriding the 'log_errors' param" do + let(:options) { {llm_options: {log_errors: true}} } + + it "configures the client to log the errors" do + allow(GigaChat::Client).to receive(:new).and_call_original + subject + expect(GigaChat::Client).to have_received(:new).with(hash_including(log_errors: true)) + end + end + end + + context "when llm_options are passed" do + let(:options) { {llm_options: {uri_base: "http://localhost:1234"}} } + + it "initializes the client without any errors" do + expect { subject }.not_to raise_error + end + + it "passes correct options to the client" do + # openai-ruby sets global configuration options here: https://github.com/alexrudall/ruby-openai/blob/main/lib/openai/client.rb + result = subject + expect(result.client.uri_base).to eq("http://localhost:1234") + end + end + + context "when default_options are passed" do + let(:default_options) { {response_format: {type: "json_object"}} } + + subject { described_class.new(api_key: "123", default_options: default_options) } + + it "sets the defaults options" do + expect(subject.defaults[:response_format]).to eq(type: "json_object") + end + + it "get passed to consecutive chat() call" do + subject + expect(subject.client).to receive(:chat).with(parameters: hash_including(default_options)).and_return({}) + subject.chat(messages: [{role: "user", content: "Hello json!"}]) + end + + it "can be overridden" do + subject + expect(subject.client).to receive(:chat).with(parameters: hash_including({response_format: {type: "text"}})).and_return({}) + subject.chat(messages: [{role: "user", content: "Hello json!"}], response_format: {type: "text"}) + end + end + end + + describe "#embed" do + let(:result) { [-0.007097351, 0.0035200312, -0.0069700438] } + let(:parameters) do + {parameters: { + input: "Hello World", + model: "GigaChat", + # dimensions: 1536 + }} + end + let(:response) do + { + "object" => "list", + "model" => parameters[:parameters][:model], + "data" => [ + { + "object" => "embedding", + "index" => 0, + "embedding" => result + } + ], + "usage" => { + "prompt_tokens" => 2, + "total_tokens" => 2 + } + } + end + + before do + allow(subject.client).to receive(:embeddings).with(parameters).and_return(response) + end + + it "returns valid llm response object" do + response = subject.embed(text: "Hello World") + + expect(response).to be_a(Langchain::LLM::GigachatResponse) + expect(response.model).to eq("GigaChat") + expect(response.embedding).to eq([-0.007097351, 0.0035200312, -0.0069700438]) + expect(response.prompt_tokens).to eq(2) + expect(response.completion_tokens).to eq(nil) + expect(response.total_tokens).to eq(2) + end + + context "with default parameters" do + it "returns an embedding" do + response = subject.embed(text: "Hello World") + + expect(response).to be_a(Langchain::LLM::GigachatResponse) + expect(response.embedding).to eq(result) + end + end + + context "with text and parameters" do + let(:parameters) do + {parameters: {input: "Hello World", model: "GigaChat", user: "id"}} + end + + it "returns an embedding" do + response = subject.embed(text: "Hello World", model: "GigaChat", user: "id") + + expect(response).to be_a(Langchain::LLM::GigachatResponse) + expect(response.embedding).to eq(result) + end + end + + describe "the model dimension" do + let(:model) { "GigaChat" } + let(:dimensions_size) { 1536 } + let(:parameters) do + {parameters: {input: "Hello World", model: model}} + end + + context "when dimensions is not 
provided" do + it "forwards the models default dimensions" do + subject.embed(text: "Hello World", model: model) + + expect(subject.client).to have_received(:embeddings).with(parameters) + end + end + + context "when dimensions is provided" do + let(:parameters) do + {parameters: {input: "Hello World", model: model}} + end + + let(:subject) do + described_class.new(api_key: "123", default_options: { + embedding_model: model, + }) + end + + it "forwards the model's default dimensions" do + allow(subject.client).to receive(:embeddings).with(parameters).and_return(response) + subject.embed(text: "Hello World", model: model) + + expect(subject.client).to have_received(:embeddings).with(parameters) + end + end + end + + Langchain::LLM::Gigachat::EMBEDDING_SIZES.each do |model_key, dimensions| + model = model_key.to_s + + context "when using model #{model}" do + let(:text) { "Hello World" } + let(:result) { [0.001, 0.002, 0.003] } # Ejemplo de resultado esperado + + let(:base_parameters) do + { + input: text, + model: model + } + end + + let(:expected_parameters) do + base_parameters[:dimensions] = dimensions unless model == "GigaChat" + base_parameters + end + + let(:response) do + { + "object" => "list", + "model" => model, + "data" => [{"object" => "embedding", "index" => 0, "embedding" => result}], + "usage" => {"prompt_tokens" => 2, "total_tokens" => 2} + } + end + + before do + allow(subject.client).to receive(:embeddings).with(parameters: expected_parameters).and_return(response) + end + + it "generates an embedding using #{model}" do + embedding_response = subject.embed(text: text, model: model) + + expect(embedding_response).to be_a(Langchain::LLM::GigachatResponse) + expect(embedding_response.model).to eq(model) + expect(embedding_response.embedding).to eq(result) + expect(embedding_response.prompt_tokens).to eq(2) + expect(embedding_response.total_tokens).to eq(2) + end + end + end + + context "when dimensions are explicitly provided" do + let(:parameters) do + {parameters: {input: "Hello World", model: "GigaChat"}} + end + + it "they are passed to the API" do + allow(subject.client).to receive(:embeddings).with(parameters).and_return(response) + subject.embed(text: "Hello World", model: "GigaChat") + + expect(subject.client).to have_received(:embeddings).with(parameters) + end + end + + context "when dimensions are explicitly provided to the initialize default options" do + let(:subject) { described_class.new(api_key: "123", default_options: {}) } + let(:model) { "GigaChat" } + let(:text) { "Hello World" } + let(:parameters) do + {parameters: {input: text, model: model}} + end + + it "they are passed to the API" do + allow(subject.client).to receive(:embeddings).with(parameters).and_return(response) + subject.embed(text: text, model: model) + + expect(subject.client).to have_received(:embeddings).with(parameters) + end + end + end + + describe "#complete" do + let(:response) do + { + "id" => "chatcmpl-9orgr5hNUdCsQeNWGnmNnbXQVIcPN", + "object" => "chat.completion", + "created" => 1721906887, + "model" => "GigaChat", + "choices" => [ + { + "message" => { + "role" => "assistant", + "content" => "The meaning of life is subjective and can vary from person to person." 
+ }, + "finish_reason" => "stop", + "index" => 0 + } + ], + "usage" => { + "prompt_tokens" => 7, + "completion_tokens" => 16, + "total_tokens" => 23 + } + } + end + + before do + allow(subject.client).to receive(:chat).with(parameters).and_return(response) + allow(subject.client).to receive(:chat).with(parameters).and_return(response) + end + + context "with default parameters" do + let(:parameters) do + { + parameters: { + model: "GigaChat", + messages: [{content: "Hello World", role: "user"}], + temperature: 0.0 + # max_tokens: 4087 + } + } + end + + it "returns valid llm response object" do + response = subject.complete(prompt: "Hello World") + + expect(response).to be_a(Langchain::LLM::GigachatResponse) + expect(response.model).to eq("GigaChat") + expect(response.completion).to eq("The meaning of life is subjective and can vary from person to person.") + expect(response.prompt_tokens).to eq(7) + expect(response.completion_tokens).to eq(16) + expect(response.total_tokens).to eq(23) + end + + it "returns a completion" do + response = subject.complete(prompt: "Hello World") + + expect(response).to be_a(Langchain::LLM::GigachatResponse) + expect(response.model).to eq("GigaChat") + expect(response.completions).to eq([{"message" => {"role" => "assistant", "content" => "The meaning of life is subjective and can vary from person to person."}, "finish_reason" => "stop", "index" => 0}]) + expect(response.completion).to eq("The meaning of life is subjective and can vary from person to person.") + end + end + + context "with custom default_options" do + context "with legacy model" do + let(:logger) { double("logger") } + let(:subject) { + described_class.new( + api_key: "123", + default_options: {completion_model: "GigaChat-Pro"} + ) + } + let(:parameters) do + { + parameters: + { + model: "GigaChat-Pro", + prompt: "Hello World", + temperature: 0.0 + # max_tokens: 4095 + } + } + end + + before do + allow(Langchain).to receive(:logger).and_return(logger) + allow(logger).to receive(:warn) + end + + it "passes correct options to the completions method" do + expect(subject.client).to receive(:chat).with({ + parameters: { + # n: 1, + # max_tokens: 4087, + model: "GigaChat", + messages: [{content: "Hello World", role: "user"}], + temperature: 0.0 + } + }).and_return(response) + subject.complete(prompt: "Hello World") + end + end + + context "with new model" do + let(:subject) { + described_class.new( + api_key: "123", + default_options: {completion_model: "GigaChat"} + ) + } + + let(:parameters) do + { + parameters: { + model: "GigaChat", + messages: [{content: "Hello World", role: "user"}], + temperature: 0.0 # , + # max_tokens: 4086 + } + } + end + + it "passes correct options to the chat method" do + expect(subject.client).to receive(:chat).with({ + parameters: { + # n: 1, + # max_tokens: 4087 , + model: "GigaChat", + messages: [{content: "Hello World", role: "user"}], + temperature: 0.0 + } + }).and_return(response) + subject.complete(prompt: "Hello World") + end + end + end + + context "with prompt and parameters" do + let(:parameters) do + {parameters: {model: "GigaChat", messages: [{content: "Hello World", role: "user"}], temperature: 1.0}} # , max_tokens: 4087}} + end + + it "returns a completion" do + response = subject.complete(prompt: "Hello World", model: "GigaChat", temperature: 1.0) + + expect(response.completion).to eq("The meaning of life is subjective and can vary from person to person.") + end + end + + context "with failed API call" do + let(:parameters) do + {parameters: {model: 
"GigaChat", messages: [{content: "Hello World", role: "user"}], temperature: 0.0}} # , max_tokens: 4087}} + end + let(:response) do + {"status" => 400, "message" => "User location is not supported for the API use.", "type" => "invalid_request_error"} + end + + it "raises an error" do + expect { + subject.complete(prompt: "Hello World") + }.to raise_error(Langchain::LLM::ApiError, "GigaChat API error: 400, User location is not supported for the API use.") + end + end + end + + describe "#chat" do + let(:prompt) { "What is the meaning of life?" } + let(:model) { "GigaChat" } + let(:temperature) { 0.0 } + let(:history) { [content: prompt, role: "user"] } + let(:parameters) { {parameters: {messages: history, model: model, temperature: temperature}} } # max_tokens: be_between(4014, 4096)}} } + let(:answer) { "As an AI language model, I don't have feelings, but I'm functioning well. How can I assist you today?" } + let(:answer_2) { "Alternative answer" } + let(:choices) do + [ + { + "message" => { + "role" => "assistant", + "content" => answer + }, + "finish_reason" => "stop", + "index" => 0 + } + ] + end + let(:response) do + { + "id" => "chatcmpl-9otuxUHnW84Zqu97VE1eKPmXVLAv0", + "object" => "chat.completion", + "created" => 1721918375, + "model" => "GigaChat", + "usage" => { + "prompt_tokens" => 14, + "completion_tokens" => 25, + "total_tokens" => 39 + }, + "choices" => choices + } + end + + before do + allow(subject.client).to receive(:chat).with(parameters).and_return(response) + end + + it "ignoresq any invalid parameters provided" do + response = subject.chat( + messages: [{role: "user", content: "What is the meaning of life?"}], + top_k: 5, + beep: :boop + ) + + expect(response).to be_a(Langchain::LLM::GigachatResponse) + end + + it "returns valid llm response object" do + response = subject.chat(messages: [{role: "user", content: "What is the meaning of life?"}]) + + expect(response).to be_a(Langchain::LLM::GigachatResponse) + expect(response.model).to eq("GigaChat") + expect(response.chat_completion).to eq("As an AI language model, I don't have feelings, but I'm functioning well. How can I assist you today?") + expect(response.prompt_tokens).to eq(14) + expect(response.completion_tokens).to eq(25) + expect(response.total_tokens).to eq(39) + end + + context "with prompt" do + it "sends prompt within messages" do + response = subject.chat(messages: [{role: "user", content: prompt}]) + + expect(response).to be_a(Langchain::LLM::GigachatResponse) + expect(response.model).to eq("GigaChat") + expect(response.completions).to eq(choices) + expect(response.chat_completion).to eq(answer) + end + end + + context "with messages" do + it "sends messages" do + response = subject.chat(messages: [{role: "user", content: prompt}]) + + expect(response.chat_completion).to eq(answer) + end + end + + context "with context and examples" do + let(:context) { "You are a chatbot" } + let(:examples) do + [ + {role: "user", content: "Hello"}, + {role: "assistant", content: "Hi. How can I assist you today?"} + ] + end + let(:history) do + [ + {role: "system", content: context}, + {role: "user", content: "Hello"}, + {role: "assistant", content: "Hi. How can I assist you today?"}, + {role: "user", content: prompt} + ] + end + + context "when last message is from user and prompt is present" do + let(:messages) do + [ + {role: "system", content: context}, + {role: "user", content: "Hello"}, + {role: "assistant", content: "Hi. 
How can I assist you today?"}, + {role: "user", content: "I want to ask a question"} + ] + end + let(:history) do + [ + {role: "system", content: context}, + {role: "user", content: "Hello"}, + {role: "assistant", content: "Hi. How can I assist you today?"}, + {role: "user", content: "I want to ask a question\n#{prompt}"} + ] + end + end + end + + context "with options" do + let(:temperature) { 0.75 } + let(:model) { "gpt-3.5-turbo-0301" } + + it "sends prompt as message and additional params and returns a response message" do + response = subject.complete(prompt: prompt, model: model, temperature: temperature) + + expect(response.chat_completion).to eq(answer) + end + + context "with multiple choices" do + let(:n) { 2 } + let(:choices) do + [ + { + "message" => {"role" => "assistant", "content" => answer}, + "finish_reason" => "stop", + "index" => 0 + }, + { + "message" => {"role" => "assistant", "content" => answer_2}, + "finish_reason" => "stop", + "index" => 1 + } + ] + end + + it "returns multiple response messages" do + response = subject.chat(messages: [content: prompt, role: "user"], model: model, temperature: temperature, n: 2) + + expect(response.completions).to eq(choices) + end + end + end + + context "with streaming" do + let(:streamed_response_chunk) do + { + "id" => "chatcmpl-7Hcl1sXOtsaUBKJGGhNujEIwhauaD", + "choices" => [{"index" => 0, "delta" => {"content" => answer}, "finish_reason" => nil}] + } + end + let(:token_usage) do + { + "usage" => {"prompt_tokens" => 10, "completion_tokens" => 11, "total_tokens" => 12} + } + end + + it "handles streaming responses correctly" do + allow(subject.client).to receive(:chat) do |parameters| + parameters[:parameters][:stream].call(streamed_response_chunk) + parameters[:parameters][:stream].call(token_usage) + end + response = subject.chat(messages: [content: prompt, role: "user"]) do |chunk| + chunk + end + expect(response).to be_a(Langchain::LLM::GigachatResponse) + expect(response.prompt_tokens).to eq(10) + expect(response.completion_tokens).to eq(11) + expect(response.total_tokens).to eq(12) + end + end + + context "with streaming and multiple choices n=2" do + let(:answer) { "Hello how are you?" 
} + let(:answer_2) { "Alternative answer" } + let(:streamed_response_chunk) do + { + "id" => "chatcmpl-7Hcl1sXOtsaUBKJGGhNujEIwhauaD", + "choices" => [{"index" => 0, "delta" => {"content" => answer}, "finish_reason" => "stop"}] + } + end + let(:streamed_response_chunk_2) do + { + "id" => "chatcmpl-7Hcl1sXOtsaUBKJGGhNujEIwhauaD", + "choices" => [{"index" => 1, "delta" => {"content" => answer_2}, "finish_reason" => "stop"}] + } + end + let(:token_usage) do + { + "usage" => {"prompt_tokens" => 10, "completion_tokens" => 11, "total_tokens" => 12} + } + end + + it "handles streaming responses correctly" do + allow(subject.client).to receive(:chat) do |parameters| + parameters[:parameters][:stream].call(streamed_response_chunk) + parameters[:parameters][:stream].call(streamed_response_chunk_2) + parameters[:parameters][:stream].call(token_usage) + end + response = subject.chat(messages: [content: prompt, role: "user"], n: 2) do |chunk| + chunk + end + expect(response).to be_a(Langchain::LLM::GigachatResponse) + expect(response.completions).to eq( + [ + {"index" => 0, "message" => {"role" => "assistant", "content" => answer}, "finish_reason" => "stop"}, + {"index" => 1, "message" => {"role" => "assistant", "content" => answer_2}, "finish_reason" => "stop"} + ] + ) + expect(response.prompt_tokens).to eq(10) + expect(response.completion_tokens).to eq(11) + expect(response.total_tokens).to eq(12) + end + end + + context "with streaming and tool_calls" do + let(:tools) do + [{ + "type" => "function", + "function" => { + "name" => "foo", + "parameters" => { + "type" => "object", + "properties" => { + "value" => { + "type" => "string" + } + } + }, + "required" => ["value"] + } + }] + end + let(:chunk_deltas) do + [ + {"role" => "assistant", "content" => nil}, + {"function_call" => {"index" => 0, "id" => "call_123456", "type" => "function", "name" => "foo", "arguments" => "{\"value\": \"my_string\"}"}}, + {"content" => nil, "functions_state_id" => "call_123456"}, + #{"function_call" => {"index" => 0, "function" => {"arguments" => "{\"va"}}}, + #{"function_call" => {"index" => 0, "function" => {"arguments" => "lue\":"}}}, + #{"function_call" => {"index" => 0, "function" => {"arguments" => " \"my_s"}}}, + #{"function_call" => {"index" => 0, "function" => {"arguments" => "trin"}}}, + #{"function_call" => {"index" => 0, "function" => {"arguments" => "g\"}"}}} + ] + end + let(:chunks) { chunk_deltas.map { |delta| {"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => delta}]} } } + let(:expected_tool_calls) do + [ + {"id" => "call_123456", "type" => "function", "function" => {"name" => "foo", "arguments" => "{\"value\": \"my_string\"}"}} + ] + end + + it "handles streaming responses correctly" do + allow(subject.client).to receive(:chat) do |parameters| + chunks.each do |chunk| + parameters[:parameters][:stream].call(chunk) + end + chunks.last + end + + response = subject.chat(messages: [content: prompt, role: "user"], tools:) do |chunk| + chunk + end + + expect(response).to be_a(Langchain::LLM::GigachatResponse) + expect(response.raw_response.dig("choices", 0, "message", "tool_calls")).to eq(expected_tool_calls) + end + end + + context "with failed API call" do + let(:response) do + {"status" => 400, "message" => "User location is not supported for the API use.", "type" => "invalid_request_error"} + end + + it "raises an error" do + expect { + subject.chat(messages: [content: prompt, role: "user"]) + }.to raise_error(Langchain::LLM::ApiError, "GigaChat API error: 400, User location is not 
supported for the API use.") + end + end + + context "with tool_choice" do + it "raises an error" do + expect { + subject.chat(messages: [content: prompt, role: "user"], tool_choice: "auto") + }.to raise_error(ArgumentError, "'tool_choice' is only allowed when 'tools' are specified.") + end + end +end + + describe "#summarize" do + let(:text) { "Text to summarize" } + + before do + allow(subject).to receive(:complete).and_return("Summary") + end + + it "returns a summary" do + expect(subject.summarize(text: text)).to eq("Summary") + end + end + + describe "tool_calls_from_choice_chunks" do + context "without tool_calls" do + let(:chunks) do + [ + {"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => {"role" => "assistant", "content" => nil}}]}, + {"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => {"role" => "assistant", "content" => "Hello"}}]} + ] + end + + it "returns nil" do + expect(subject.send(:tool_calls_from_choice_chunks, chunks)).to eq(nil) + end + end + + context "with tool_calls" do + let(:chunk_deltas) do + [ + {"role" => "assistant", "content" => "Мне нужно посмотреть погоду в Москве"}, + {"content" => " на"}, + {"content" => " завтра"}, + {"function_call" => {"name" => "weather_forecast", "arguments" => {"location" => "Moscow", "num_days" => 1}}}, + {"functions_state_id" => "77d3fb14-457a-46ba-937e-8d856156d003", "content" => ""} + ] + end + let(:chunks) { chunk_deltas.map { |delta| {"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => delta}]} } } + let(:expected_tool_calls) do + [{"function" => {"arguments" => "{\"location\"=>\"Moscow\", \"num_days\"=>1}", "name" => "weather_forecast"}, "id" => "77d3fb14-457a-46ba-937e-8d856156d003", "type" => nil}] + end + + it "returns the tool_calls" do + expect(subject.send(:tool_calls_from_choice_chunks, chunks)).to eq(expected_tool_calls) + end + end + + context "with multiple tool_calls" do + [ + {"role" => "assistant", "content" => nil}, + {"function_call" => {"index" => 0, "id" => "call_123456", "type" => "function", "function" => {"name" => "foo", "arguments" => "{\"value\": \"my_string\"}"}, + "role" => "assistant", "content" => nil} + }, + {"content" => nil, "functions_state_id" => "call_123456"}, + ] + let(:chunk_deltas) do + [ + {"role" => "assistant", "content" => nil}, + {"function_call" => {"name" => "weather_forecast", "arguments" => {"location" => "Moscow", "num_days" => 1}}}, + {"content" => "", "functions_state_id" => "77d3fb14-457a-46ba-937e-8d856156d003"}, + ] + end + let(:chunks) { chunk_deltas.map { |delta| {"id" => "chatcmpl-abcdefg", "choices" => [{"index" => 0, "delta" => delta}]} } } + let(:expected_tool_calls) do + [ + {"id" => "77d3fb14-457a-46ba-937e-8d856156d003", "type" => nil, "function" => {"name" => "weather_forecast", "arguments" => "{\"location\"=>\"Moscow\", \"num_days\"=>1}"}}, + # {"id" => "call_456", "type" => "function", "function" => {"name" => "bar", "arguments" => "{\"value\": \"other_string\"}"}} + ] + end + + it "returns the tool_calls" do + expect(subject.send(:tool_calls_from_choice_chunks, chunks)).to eq(expected_tool_calls) + end + end + end +end diff --git a/spec/langchain/llm/response/gigachat_response_spec.rb b/spec/langchain/llm/response/gigachat_response_spec.rb new file mode 100644 index 000000000..c788f49cd --- /dev/null +++ b/spec/langchain/llm/response/gigachat_response_spec.rb @@ -0,0 +1,63 @@ +# frozen_string_literal: true + +RSpec.describe Langchain::LLM::GigachatResponse do + subject { described_class.new(raw_response) } + + 
describe "chat completions" do + let(:raw_response) { JSON.parse File.read("spec/fixtures/llm/gigachat/chat.json") } + + it "created_at returns correct value" do + expect(subject.created_at).to eq(Time.at(raw_response.dig("created"))) + end + + it "returns chat_completion" do + expect(subject.chat_completion).to eq(raw_response.dig("choices", 0, "message", "content")) + end + + it "prompt_tokens returns correct value" do + expect(subject.prompt_tokens).to eq(18) + end + + it "completion_tokens returns correct value" do + expect(subject.completion_tokens).to eq(68) + end + + it "total_tokens return correct value" do + expect(subject.total_tokens).to eq(86) + end + + describe "streamed response chunk" do + let(:raw_response) { JSON.parse File.read("spec/fixtures/llm/gigachat/chat_chunk.json") } + + it "created_at returns correct value" do + expect(subject.created_at).to eq(Time.at(raw_response.dig("created"))) + end + + it "returns chat_completion" do + expect(subject.chat_completion).to eq(raw_response.dig("message", "content")) + end + + it "does not return prompt_tokens" do + expect(subject.prompt_tokens).to be_nil + end + + it "does not return completion_tokens" do + expect(subject.completion_tokens).to be_nil + end + + it "does not return total_tokens" do + expect(subject.total_tokens).to be_nil + end + end + + describe "#tool_calls" do + let(:raw_response) { JSON.parse File.read("spec/fixtures/llm/gigachat/chat_with_function_call.json") } + + it "returns tool_calls" do + expect(subject.tool_calls).to eq({ + "arguments" => {"format" => "celsius", "location" => "Москва"}, "name" => "weather_forecast" + }) + end + end + end +end From 5d21ca3650346588f6659eeb8e57ff09ea862166 Mon Sep 17 00:00:00 2001 From: suhov Date: Sat, 8 Feb 2025 14:02:28 +0300 Subject: [PATCH 2/2] Restore from dev env --- langchain.gemspec | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langchain.gemspec b/langchain.gemspec index eadf059ed..bde4a05e1 100644 --- a/langchain.gemspec +++ b/langchain.gemspec @@ -58,7 +58,7 @@ Gem::Specification.new do |spec| spec.add_development_dependency "hnswlib", "~> 0.8.1" spec.add_development_dependency "hugging-face", "~> 0.3.4" spec.add_development_dependency "milvus", "~> 0.10.3" - #spec.add_development_dependency "llama_cpp", "~> 0.9.4" + spec.add_development_dependency "llama_cpp", "~> 0.9.4" spec.add_development_dependency "nokogiri", "~> 1.13" spec.add_development_dependency "mail", "~> 2.8" spec.add_development_dependency "mistral-ai"