New parsing logic and model definition data types #35

Merged (1 commit) on Oct 2, 2024
11 changes: 6 additions & 5 deletions Cargo.toml
@@ -7,18 +7,19 @@ resolver = "2"
exclude = ["tests", "scripts"]

[workspace.dependencies]
candle-core = { version = "0.4.1" }
candle-nn = { version = "0.4.1"}
candle-transformers = { version = "0.4.1" }
candle-core = { version = "0.6.0" }
candle-nn = { version = "0.6.0" }
candle-transformers = { version = "0.6.0" }
tokenizers = { version = "0.20.0" }
clap = { version = "4.5.17"}

# Enable high optimizations for candle in dev builds
[profile.dev.package]
candle-core = { opt-level = 3 }
candle-nn = { opt-level = 3 }
candle-transformers = { opt-level = 3 }


[workspace.package]
license = "Apache-2.0"
version = "0.4.1"
version = "0.5.0"

12 changes: 7 additions & 5 deletions README.md
@@ -20,7 +20,7 @@ fn main() -> Result<(), Error> {
"Hey, how are you doing?"
];

-let embeddings = encoder.encode_batch(sentences, true, PoolingStrategy::Mean)?;
+let embeddings = encoder.encode_batch(sentences, true)?;

println!("{:?}", embeddings);

@@ -148,11 +148,13 @@ print(client.models.list())
## Disclaimer

This is still a work-in-progress. The embedding performance is decent but can probably do with some
-benchmarking. Furthermore, for higher batch sizes, the program is killed due to a [bug](https://github.com/huggingface/candle/issues/1596).
+benchmarking. Furthermore, this is meant to be a lightweight embedding model library + server.

-Do not use this in a production environment.
+Do not use this in a production environment. If you are looking for something production-ready & in Rust,
+consider [`text-embeddings-inference`](https://github.com/huggingface/text-embeddings-inference).

## Credits

-* [Huggingface](https://huggingface.co) for the models and the `candle` library.
-* [`sentence-transformers`](https://www.sbert.net/index.html) for being the gold standard in sentence embeddings.
+* [Huggingface](https://huggingface.co) for the model hosting, the `candle` library, [`text-embeddings-inference`](https://github.com/huggingface/text-embeddings-inference), and
+  [`text-generation-inference`](https://github.com/huggingface/text-generation-inference) which has helped me find the right patterns.
+* [`sentence-transformers`](https://www.sbert.net/index.html) and its contributors for being the gold standard in sentence embeddings.
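After this change, pooling is resolved from the model's own configuration rather than passed at the call site. A minimal end-to-end sketch of the new call shape (the `from_repo_string` constructor, the repo string, and the `.into()` conversion for the sentence list are all taken from `examples/simple.rs` in this PR, not guaranteed by the README itself):

```rust
use glowrs::{Device, Error, SentenceTransformer};

fn main() -> Result<(), Error> {
    // Constructor and model repo as used in examples/simple.rs.
    let encoder = SentenceTransformer::from_repo_string(
        "jinaai/jina-embeddings-v2-small-en",
        &Device::Cpu,
    )?;

    let sentences = ["Hello, how are you?", "Hey, how are you doing?"];

    // No PoolingStrategy argument anymore; `true` requests normalized embeddings.
    let embeddings = encoder.encode_batch(sentences.into(), true)?;
    println!("{:?}", embeddings);

    Ok(())
}
```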
10 changes: 5 additions & 5 deletions crates/glowrs-server/Cargo.toml
@@ -1,7 +1,7 @@
[package]
name = "glowrs-server"
-version = { workspace = true }
edition = "2021"
+version = { workspace = true }
license = { workspace = true }

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -11,24 +11,24 @@ glowrs = { path = "../glowrs" }
candle-core = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
+tokenizers = { workspace = true }
axum = { version = "0.7.4", features = ["macros"] }
bytes = "1.5.0"
console-subscriber = "0.2.0"
console-subscriber = "0.4.0"
futures-util = "0.3.28"
serde = { version = "1.0.183", features = ["derive"] }
tokio = { version = "1.31.0", features = ["full", "rt-multi-thread", "tracing"] }
tracing = "0.1.37"
tracing-subscriber = "0.3.18"
uuid = { version = "1.6.1", features = ["v4"] }
serde_json = "1.0.111"
tokenizers = "0.15.0"
hf-hub = { version = "0.3.2", features = ["tokio"] }
anyhow = "1.0.79"
thiserror = "1.0.56"
tracing-chrome = "0.7.1"
tower-http = { version = "0.5.1", features = ["trace", "timeout"] }
tower-http = { version = "0.6.1", features = ["trace", "timeout"] }
once_cell = "1.19.0"
clap = { version = "4.4.18", features = ["derive"] }
clap = { workspace = true, features = ["derive"] }


[features]
1 change: 1 addition & 0 deletions crates/glowrs-server/src/server/data_models.rs
@@ -10,6 +10,7 @@ pub enum EncodingFormat {
}

#[derive(Debug, Deserialize, Clone)]
+#[allow(dead_code)]
pub struct EmbeddingsRequest {
pub input: Sentences,
pub model: String,
12 changes: 5 additions & 7 deletions crates/glowrs-server/src/server/infer/embed.rs
@@ -1,9 +1,9 @@
-use glowrs::{Device, PoolingStrategy, SentenceTransformer};

use crate::server::data_models::{EmbeddingsRequest, EmbeddingsResponse};
use crate::server::infer::client::Client;
use crate::server::infer::handler::RequestHandler;
use crate::server::infer::DedicatedExecutor;
+use glowrs::model::embedder::EmbedOutput;
+use glowrs::{Device, SentenceTransformer};

pub struct EmbeddingsHandler {
sentence_transformer: SentenceTransformer,
@@ -39,11 +39,9 @@ impl RequestHandler for EmbeddingsHandler {
const NORMALIZE: bool = false;

// Infer embeddings
-let (embeddings, usage) = self.sentence_transformer.encode_batch_with_usage(
-    sentences.into(),
-    NORMALIZE,
-    PoolingStrategy::Mean,
-)?;
+let EmbedOutput { embeddings, usage } = self
+    .sentence_transformer
+    .encode_batch_with_usage(sentences.into(), NORMALIZE)?;

let response = EmbeddingsResponse::from_embeddings(embeddings, usage, request.model);

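For callers migrating off the old tuple return, a minimal sketch of the new shape outside the server (repo string and the `.into()` conversion reused from `examples/simple.rs`; the fields of the usage value are not shown in this diff, so it is only bound here):

```rust
use glowrs::model::embedder::EmbedOutput;
use glowrs::{Device, Error, SentenceTransformer};

fn main() -> Result<(), Error> {
    let encoder = SentenceTransformer::from_repo_string(
        "jinaai/jina-embeddings-v2-small-en",
        &Device::Cpu,
    )?;

    // Previously: let (embeddings, usage) =
    //     encoder.encode_batch_with_usage(sentences, normalize, PoolingStrategy::Mean)?;
    // Now the pooling strategy comes from the parsed model configuration and
    // the result is a named struct rather than a tuple.
    let EmbedOutput { embeddings, usage } =
        encoder.encode_batch_with_usage(["Hello, how are you?"].into(), false)?;

    println!("embedding dims: {:?}", embeddings.dims());
    let _ = usage; // token accounting, forwarded into the API response by the server

    Ok(())
}
```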
11 changes: 10 additions & 1 deletion crates/glowrs-server/src/server/routes/embeddings.rs
@@ -1,18 +1,27 @@
use anyhow::Result;
-use axum::extract::State;
+use axum::extract::{Query, State};
use axum::http::StatusCode;
use axum::Json;
+use serde::Deserialize;
use std::sync::Arc;
use tokio::time::Instant;

use crate::server::data_models::{EmbeddingsRequest, EmbeddingsResponse};
use crate::server::state::ServerState;
use crate::server::ServerError;

+#[derive(Debug, Deserialize)]
+pub struct QueryData {
+    api_version: Option<String>,
+}
+
pub async fn infer_text_embeddings(
State(server_state): State<Arc<ServerState>>,
+Query(query): Query<QueryData>,
Json(embeddings_request): Json<EmbeddingsRequest>,
) -> Result<(StatusCode, Json<EmbeddingsResponse>), ServerError> {
+tracing::trace!("Requested API version: {:?}", query.api_version);
+
let start = Instant::now();
let (client, _) = server_state
.model_map
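A sketch of how this handler might be mounted; the route path, state construction, and module paths are assumptions inferred from the file layout (the actual wiring is not part of this diff), but it shows where the new `Query` extractor plugs in:

```rust
use std::sync::Arc;

use axum::routing::post;
use axum::Router;

use crate::server::routes::embeddings::infer_text_embeddings;
use crate::server::state::ServerState;

// Hypothetical wiring: a POST to /v1/embeddings?api_version=v1 reaches the
// handler with query.api_version == Some("v1") and the JSON body parsed
// into an EmbeddingsRequest.
pub fn router(state: Arc<ServerState>) -> Router {
    Router::new()
        .route("/v1/embeddings", post(infer_text_embeddings))
        .with_state(state)
}
```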
8 changes: 5 additions & 3 deletions crates/glowrs/Cargo.toml
@@ -17,21 +17,23 @@ exclude = ["tests/fixtures/*"]
candle-core = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
+tokenizers = { workspace = true }
serde = { version = "1.0.183", features = ["derive"] }
serde_json = "1.0.111"
tracing = "0.1.37"
uuid = { version = "1.6.1", features = ["v4"] }
tokenizers = "0.19.1"
hf-hub = { version = "0.3.2", features = ["tokio"] }
thiserror = "1.0.56"
once_cell = "1.19.0"
clap = { version = "4.5.4", features = ["derive"] }
clap = { workspace = true, features = ["derive"], optional = true }
anyhow = "1.0.86"
once_cell = "1.20.1"

[features]
default = []
metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]
accelerate = ["candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
clap = ["dep:clap"]

[dev-dependencies]
dirs = "5.0.1"
2 changes: 1 addition & 1 deletion crates/glowrs/README.md
@@ -15,7 +15,7 @@ fn main() -> Result<(), Error> {
"Hey, how are you doing?"
];

-let embeddings = encoder.encode_batch(sentences, true, PoolingStrategy::Mean)?;
+let embeddings = encoder.encode_batch(sentences, true)?;

println!("{:?}", embeddings);

38 changes: 23 additions & 15 deletions crates/glowrs/examples/simple.rs
@@ -1,18 +1,21 @@
-use clap::Parser;
-use glowrs::{Device, Error, PoolingStrategy, SentenceTransformer};
-use std::process::ExitCode;
-use tracing_subscriber::prelude::*;
+#[allow(dead_code, unused_imports)]
+use std::error::Error;

-#[derive(Debug, Parser)]
-pub struct App {
-    #[clap(short, long, default_value = "jinaai/jina-embeddings-v2-small-en")]
-    pub model_repo: String,
+#[cfg(feature = "clap")]
+fn main() -> Result<(), Box<dyn Error>> {
+    use clap::Parser;
+    use glowrs::{Device, PoolingStrategy, SentenceTransformer};
+    use tracing_subscriber::prelude::*;

-    #[clap(short, long, default_value = "debug")]
-    pub log_level: String,
-}
+    #[derive(Debug, Parser)]
+    pub struct App {
+        #[clap(short, long, default_value = "jinaai/jina-embeddings-v2-small-en")]
+        pub model_repo: String,
+
+        #[clap(short, long, default_value = "debug")]
+        pub log_level: String,
+    }

-fn main() -> Result<ExitCode, Error> {
let app = App::parse();

let sentences = [
@@ -41,8 +41,8 @@ fn main() -> Result<ExitCode, Error> {
let device = Device::Cpu;
let encoder = SentenceTransformer::from_repo_string(&app.model_repo, &device)?;

-let pooling_strategy = PoolingStrategy::Mean;
-let embeddings = encoder.encode_batch(sentences.into(), false, pooling_strategy)?;
+let pooling_strategy = Some(PoolingStrategy::Mean);
+let embeddings = encoder.encode_batch(sentences.into(), false)?;
println!("Embeddings: {:?}", embeddings);

let (n_sentences, _) = embeddings.dims2()?;
@@ -65,5 +68,10 @@ fn main() -> Result<ExitCode, Error> {
println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
}

-Ok(ExitCode::SUCCESS)
+Ok(())
}

+#[cfg(not(feature = "clap"))]
+fn main() {
+    eprintln!("Enable feature 'clap' to run this example.")
+}
50 changes: 50 additions & 0 deletions crates/glowrs/src/config/mod.rs
@@ -0,0 +1,50 @@
pub(crate) mod model;
pub(crate) mod parse;

#[cfg(test)]
mod tests {
    const BERT_CONFIG_PATH: &str = "tests/fixtures/all-MiniLM-L6-v2";
    const JINABERT_CONFIG_PATH: &str = "tests/fixtures/jina-embeddings-v2-base-en";
    const DISTILBERT_CONFIG_PATH: &str = "tests/fixtures/multi-qa-distilbert-dot-v1";

    use std::path::Path;

    use super::*;
    use crate::config::parse::parse_config;
    use crate::pooling::PoolingStrategy;
    use crate::Result;

    fn test_parse_config_helper(config_path: &str, expected_type: model::ModelType) -> Result<()> {
        let path = Path::new(config_path);

        let config = parse_config(path, None)?;

        assert_eq!(config.model_type, expected_type);

        Ok(())
    }

    #[test]
    fn test_parse_config_bert() -> Result<()> {
        test_parse_config_helper(
            BERT_CONFIG_PATH,
            crate::config::model::ModelType::Embedding(PoolingStrategy::Mean),
        )
    }

    #[test]
    fn test_parse_config_jinabert() -> Result<()> {
        test_parse_config_helper(
            JINABERT_CONFIG_PATH,
            crate::config::model::ModelType::Embedding(PoolingStrategy::Mean),
        )
    }

    #[test]
    fn test_parse_config_distilbert() -> Result<()> {
        test_parse_config_helper(
            DISTILBERT_CONFIG_PATH,
            crate::config::model::ModelType::Embedding(PoolingStrategy::Cls),
        )
    }
}
65 changes: 65 additions & 0 deletions crates/glowrs/src/config/model.rs
@@ -0,0 +1,65 @@
//! Embedding model configuration data model
//!
//! An embedder is defined by its model configuration (defined in the `config.json` in the root
//! of a Hugging Face model repository), the model type, and the pooling strategy (optionally
//! defined in a `1_Pooling/config.json` file in the model repository).

use candle_transformers::models::bert::Config as _BertConfig;
use candle_transformers::models::distilbert::Config as DistilBertConfig;
use candle_transformers::models::jina_bert::Config as _JinaBertConfig;
use serde::Deserialize;
use std::collections::HashMap;

use crate::pooling::PoolingStrategy;

/// The base HF embedding model configuration.
///
/// This represents the base fields present in a `config.json` for an embedding model.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub(crate) struct HFConfig {
    pub architectures: Vec<String>,
    pub model_type: String,
    #[serde(alias = "n_positions")]
    pub max_position_embeddings: usize,
    #[serde(default)]
    pub pad_token_id: usize,
    pub id2label: Option<HashMap<String, String>>,
    pub label2id: Option<HashMap<String, usize>>,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum BertConfig {
    Bert(_BertConfig),
    JinaBert(_JinaBertConfig),
}

/// The given model type.
///
/// Based on the `model_type` key in the `config.json`, the given variant enables the parser
/// to know what specific model configuration data model to use when deserializing the non-base
/// keys.
#[derive(Deserialize)]
#[serde(tag = "model_type", rename_all = "kebab-case")]
pub(crate) enum ModelConfig {
    Bert(BertConfig),
    // XlmRoberta(BertConfig),
    // Camembert(BertConfig),
    // Roberta(BertConfig),
    #[serde(rename(deserialize = "distilbert"))]
    DistilBert(DistilBertConfig),
}

/// The embedding strategy used by a given model.
#[derive(Debug, PartialEq, Clone)]
pub enum ModelType {
    Classifier,
    Embedding(PoolingStrategy),
}

/// The model definition.
pub(crate) struct ModelDefinition {
    pub(crate) model_config: ModelConfig,
    pub(crate) model_type: ModelType,
}
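The internally tagged `ModelConfig` enum is what lets a single `config.json` parse dispatch on the `model_type` key. A self-contained toy sketch of that serde mechanism (toy payload types, not the real candle configs):

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize, PartialEq)]
struct ToyBert {
    hidden_size: usize,
}

#[derive(Debug, Deserialize, PartialEq)]
struct ToyDistilBert {
    dim: usize,
}

// Mirrors ModelConfig: serde reads the `model_type` key first, picks the
// variant, then deserializes the remaining keys into the variant's payload.
#[derive(Debug, Deserialize, PartialEq)]
#[serde(tag = "model_type", rename_all = "kebab-case")]
enum ToyModelConfig {
    Bert(ToyBert),
    #[serde(rename = "distilbert")]
    DistilBert(ToyDistilBert),
}

fn main() {
    let cfg: ToyModelConfig =
        serde_json::from_str(r#"{ "model_type": "distilbert", "dim": 768 }"#).unwrap();
    assert_eq!(cfg, ToyModelConfig::DistilBert(ToyDistilBert { dim: 768 }));
}
```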