New parsing logic and model definition data types (#35)
wdoppenberg authored Oct 2, 2024
Parent: d9d513b · Commit: 7672805
Showing 25 changed files with 590 additions and 448 deletions.
11 changes: 6 additions & 5 deletions Cargo.toml
@@ -7,18 +7,19 @@ resolver = "2"
 exclude = ["tests", "scripts"]

 [workspace.dependencies]
-candle-core = { version = "0.4.1" }
-candle-nn = { version = "0.4.1"}
-candle-transformers = { version = "0.4.1" }
+candle-core = { version = "0.6.0" }
+candle-nn = { version = "0.6.0" }
+candle-transformers = { version = "0.6.0" }
+tokenizers = { version = "0.20.0" }
+clap = { version = "4.5.17"}

 # Enable high optimizations for candle in dev builds
 [profile.dev.package]
 candle-core = { opt-level = 3 }
 candle-nn = { opt-level = 3 }
 candle-transformers = { opt-level = 3 }


 [workspace.package]
 license = "Apache-2.0"
-version = "0.4.1"
+version = "0.5.0"

12 changes: 7 additions & 5 deletions README.md
@@ -20,7 +20,7 @@ fn main() -> Result<(), Error> {
         "Hey, how are you doing?"
     ];

-    let embeddings = encoder.encode_batch(sentences, true, PoolingStrategy::Mean)?;
+    let embeddings = encoder.encode_batch(sentences, true)?;

     println!("{:?}", embeddings);

@@ -148,11 +148,13 @@ print(client.models.list())
 ## Disclaimer

 This is still a work-in-progress. The embedding performance is decent but can probably do with some
-benchmarking. Furthermore, for higher batch sizes, the program is killed due to a [bug](https://github.com/huggingface/candle/issues/1596).
+benchmarking. Furthermore, this is meant to be a lightweight embedding model library + server.

-Do not use this in a production environment.
+Do not use this in a production environment. If you are looking for something production-ready & in Rust,
+consider [`text-embeddings-inference`](https://github.com/huggingface/text-embeddings-inference).

 ## Credits

-* [Huggingface](https://huggingface.co) for the models and the `candle` library.
-* [`sentence-transformers`](https://www.sbert.net/index.html) for being the gold standard in sentence embeddings.
+* [Huggingface](https://huggingface.co) for the model hosting, the `candle` library, [`text-embeddings-inference`](https://github.com/huggingface/text-embeddings-inference), and
+  [`text-generation-inference`](https://github.com/huggingface/text-generation-inference) which has helped me find the right patterns.
+* [`sentence-transformers`](https://www.sbert.net/index.html) and its contributors for being the gold standard in sentence embeddings.
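
The `encode_batch` change above is the user-facing core of this commit: the pooling strategy is no longer chosen per call but resolved from the model's configuration (see the new `config` module further down). A minimal sketch of the updated call, mirroring the README snippet and `examples/simple.rs`; the model repo string is an illustrative assumption, not part of this diff:

use glowrs::{Device, Error, SentenceTransformer};

fn main() -> Result<(), Error> {
    // Hypothetical model choice; any supported embedding repo id works.
    let encoder =
        SentenceTransformer::from_repo_string("jinaai/jina-embeddings-v2-small-en", &Device::Cpu)?;

    let sentences = vec!["The cat sits outside", "Hey, how are you doing?"];

    // `true` requests normalized embeddings; no PoolingStrategy argument anymore.
    let embeddings = encoder.encode_batch(sentences, true)?;
    println!("{:?}", embeddings);

    Ok(())
}
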
10 changes: 5 additions & 5 deletions crates/glowrs-server/Cargo.toml
@@ -1,7 +1,7 @@
 [package]
 name = "glowrs-server"
-version = { workspace = true }
 edition = "2021"
+version = { workspace = true }
 license = { workspace = true }

 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -11,24 +11,24 @@ glowrs = { path = "../glowrs" }
 candle-core = { workspace = true }
 candle-nn = { workspace = true }
 candle-transformers = { workspace = true }
+tokenizers = { workspace = true }
 axum = { version = "0.7.4", features = ["macros"] }
 bytes = "1.5.0"
-console-subscriber = "0.2.0"
+console-subscriber = "0.4.0"
 futures-util = "0.3.28"
 serde = { version = "1.0.183", features = ["derive"] }
 tokio = { version = "1.31.0", features = ["full", "rt-multi-thread", "tracing"] }
 tracing = "0.1.37"
 tracing-subscriber = "0.3.18"
 uuid = { version = "1.6.1", features = ["v4"] }
 serde_json = "1.0.111"
-tokenizers = "0.15.0"
 hf-hub = { version = "0.3.2", features = ["tokio"] }
 anyhow = "1.0.79"
 thiserror = "1.0.56"
 tracing-chrome = "0.7.1"
-tower-http = { version = "0.5.1", features = ["trace", "timeout"] }
+tower-http = { version = "0.6.1", features = ["trace", "timeout"] }
 once_cell = "1.19.0"
-clap = { version = "4.4.18", features = ["derive"] }
+clap = { workspace = true, features = ["derive"] }


 [features]

1 change: 1 addition & 0 deletions crates/glowrs-server/src/server/data_models.rs
@@ -10,6 +10,7 @@ pub enum EncodingFormat {
 }

 #[derive(Debug, Deserialize, Clone)]
+#[allow(dead_code)]
 pub struct EmbeddingsRequest {
     pub input: Sentences,
     pub model: String,

12 changes: 5 additions & 7 deletions crates/glowrs-server/src/server/infer/embed.rs
@@ -1,9 +1,9 @@
-use glowrs::{Device, PoolingStrategy, SentenceTransformer};
-
 use crate::server::data_models::{EmbeddingsRequest, EmbeddingsResponse};
 use crate::server::infer::client::Client;
 use crate::server::infer::handler::RequestHandler;
 use crate::server::infer::DedicatedExecutor;
+use glowrs::model::embedder::EmbedOutput;
+use glowrs::{Device, SentenceTransformer};

 pub struct EmbeddingsHandler {
     sentence_transformer: SentenceTransformer,
@@ -39,11 +39,9 @@ impl RequestHandler for EmbeddingsHandler {
         const NORMALIZE: bool = false;

         // Infer embeddings
-        let (embeddings, usage) = self.sentence_transformer.encode_batch_with_usage(
-            sentences.into(),
-            NORMALIZE,
-            PoolingStrategy::Mean,
-        )?;
+        let EmbedOutput { embeddings, usage } = self
+            .sentence_transformer
+            .encode_batch_with_usage(sentences.into(), NORMALIZE)?;

         let response = EmbeddingsResponse::from_embeddings(embeddings, usage, request.model);
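
The same change seen from a caller's perspective: `encode_batch_with_usage` now returns the named `EmbedOutput` struct instead of a positional tuple, so call sites destructure by field name. A sketch under the assumptions visible in this diff (field names come from the destructuring above; that `Usage` derives `Debug` is assumed):

use glowrs::model::embedder::EmbedOutput;
use glowrs::{Device, SentenceTransformer};

fn embed_with_usage() -> Result<(), glowrs::Error> {
    let encoder =
        SentenceTransformer::from_repo_string("jinaai/jina-embeddings-v2-small-en", &Device::Cpu)?;

    // Named fields cannot be silently swapped, unlike a (Tensor, Usage) tuple,
    // and new fields can be added to EmbedOutput without breaking this site.
    let EmbedOutput { embeddings, usage } =
        encoder.encode_batch_with_usage(["hello", "world"].into(), false)?;

    println!("shape: {:?}, usage: {usage:?}", embeddings.dims());
    Ok(())
}
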
11 changes: 10 additions & 1 deletion crates/glowrs-server/src/server/routes/embeddings.rs
@@ -1,18 +1,27 @@
 use anyhow::Result;
-use axum::extract::State;
+use axum::extract::{Query, State};
 use axum::http::StatusCode;
 use axum::Json;
+use serde::Deserialize;
 use std::sync::Arc;
 use tokio::time::Instant;

 use crate::server::data_models::{EmbeddingsRequest, EmbeddingsResponse};
 use crate::server::state::ServerState;
 use crate::server::ServerError;

+#[derive(Debug, Deserialize)]
+pub struct QueryData {
+    api_version: Option<String>,
+}
+
 pub async fn infer_text_embeddings(
     State(server_state): State<Arc<ServerState>>,
+    Query(query): Query<QueryData>,
     Json(embeddings_request): Json<EmbeddingsRequest>,
 ) -> Result<(StatusCode, Json<EmbeddingsResponse>), ServerError> {
+    tracing::trace!("Requested API version: {:?}", query.api_version);
+
     let start = Instant::now();
     let (client, _) = server_state
         .model_map
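
Because `api_version` is an `Option<String>`, requests that omit the query parameter still deserialize, so existing clients are unaffected; the value is currently only logged at trace level. A standalone sketch of that behaviour using the `serde_urlencoded` crate (the same crate axum's `Query` extractor uses under the hood, at least in axum 0.7; pulling it in directly here is purely for illustration):

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct QueryData {
    api_version: Option<String>,
}

fn main() {
    // e.g. POST /v1/embeddings?api_version=v1 (route path hypothetical)
    let q: QueryData = serde_urlencoded::from_str("api_version=v1").unwrap();
    assert_eq!(q.api_version.as_deref(), Some("v1"));

    // Omitting the parameter still parses; the Option is simply None.
    let q: QueryData = serde_urlencoded::from_str("").unwrap();
    assert!(q.api_version.is_none());
}
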
8 changes: 5 additions & 3 deletions crates/glowrs/Cargo.toml
@@ -17,21 +17,23 @@ exclude = ["tests/fixtures/*"]
 candle-core = { workspace = true }
 candle-nn = { workspace = true }
 candle-transformers = { workspace = true }
+tokenizers = { workspace = true }
 serde = { version = "1.0.183", features = ["derive"] }
 serde_json = "1.0.111"
 tracing = "0.1.37"
 uuid = { version = "1.6.1", features = ["v4"] }
-tokenizers = "0.19.1"
 hf-hub = { version = "0.3.2", features = ["tokio"] }
 thiserror = "1.0.56"
-once_cell = "1.19.0"
-clap = { version = "4.5.4", features = ["derive"] }
+clap = { workspace = true, features = ["derive"], optional = true }
+anyhow = "1.0.86"
+once_cell = "1.20.1"

 [features]
 default = []
 metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]
 accelerate = ["candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
 cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
+clap = ["dep:clap"]

 [dev-dependencies]
 dirs = "5.0.1"

2 changes: 1 addition & 1 deletion crates/glowrs/README.md
@@ -15,7 +15,7 @@ fn main() -> Result<(), Error> {
         "Hey, how are you doing?"
     ];

-    let embeddings = encoder.encode_batch(sentences, true, PoolingStrategy::Mean)?;
+    let embeddings = encoder.encode_batch(sentences, true)?;

     println!("{:?}", embeddings);

38 changes: 23 additions & 15 deletions crates/glowrs/examples/simple.rs
@@ -1,18 +1,21 @@
-use clap::Parser;
-use glowrs::{Device, Error, PoolingStrategy, SentenceTransformer};
-use std::process::ExitCode;
-use tracing_subscriber::prelude::*;
+#[allow(dead_code, unused_imports)]
+use std::error::Error;

-#[derive(Debug, Parser)]
-pub struct App {
-    #[clap(short, long, default_value = "jinaai/jina-embeddings-v2-small-en")]
-    pub model_repo: String,
+#[cfg(feature = "clap")]
+fn main() -> Result<(), Box<dyn Error>> {
+    use clap::Parser;
+    use glowrs::{Device, PoolingStrategy, SentenceTransformer};
+    use tracing_subscriber::prelude::*;
+
+    #[derive(Debug, Parser)]
+    pub struct App {
+        #[clap(short, long, default_value = "jinaai/jina-embeddings-v2-small-en")]
+        pub model_repo: String,

-    #[clap(short, long, default_value = "debug")]
-    pub log_level: String,
-}
+        #[clap(short, long, default_value = "debug")]
+        pub log_level: String,
+    }

-fn main() -> Result<ExitCode, Error> {
     let app = App::parse();

     let sentences = [
@@ -41,8 +44,8 @@ fn main() -> Result<ExitCode, Error> {
     let device = Device::Cpu;
     let encoder = SentenceTransformer::from_repo_string(&app.model_repo, &device)?;

-    let pooling_strategy = PoolingStrategy::Mean;
-    let embeddings = encoder.encode_batch(sentences.into(), false, pooling_strategy)?;
+    let pooling_strategy = Some(PoolingStrategy::Mean);
+    let embeddings = encoder.encode_batch(sentences.into(), false)?;
     println!("Embeddings: {:?}", embeddings);

     let (n_sentences, _) = embeddings.dims2()?;
@@ -65,5 +68,10 @@ fn main() -> Result<ExitCode, Error> {
         println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
     }

-    Ok(ExitCode::SUCCESS)
+    Ok(())
 }
+
+#[cfg(not(feature = "clap"))]
+fn main() {
+    eprintln!("Enable feature 'clap' to run this example.")
+}
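
The example is now feature-gated: with `clap` enabled the full CLI runs, and without it a stub `main` explains how to enable it, so the example builds either way. Presumably it is run with `cargo run --example simple --features clap`. The pattern in miniature:

// Exactly one main() is compiled, selected by the "clap" cargo feature,
// so the example builds whether or not the optional dependency is present.
#[cfg(feature = "clap")]
fn main() {
    println!("'clap' enabled: the full CLI example would run here");
}

#[cfg(not(feature = "clap"))]
fn main() {
    eprintln!("Enable feature 'clap' to run this example.");
}
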
50 changes: 50 additions & 0 deletions crates/glowrs/src/config/mod.rs
@@ -0,0 +1,50 @@
+pub(crate) mod model;
+pub(crate) mod parse;
+
+#[cfg(test)]
+mod tests {
+    const BERT_CONFIG_PATH: &str = "tests/fixtures/all-MiniLM-L6-v2";
+    const JINABERT_CONFIG_PATH: &str = "tests/fixtures/jina-embeddings-v2-base-en";
+    const DISTILBERT_CONFIG_PATH: &str = "tests/fixtures/multi-qa-distilbert-dot-v1";
+
+    use std::path::Path;
+
+    use super::*;
+    use crate::config::parse::parse_config;
+    use crate::pooling::PoolingStrategy;
+    use crate::Result;
+
+    fn test_parse_config_helper(config_path: &str, expected_type: model::ModelType) -> Result<()> {
+        let path = Path::new(config_path);
+
+        let config = parse_config(path, None)?;
+
+        assert_eq!(config.model_type, expected_type);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_parse_config_bert() -> Result<()> {
+        test_parse_config_helper(
+            BERT_CONFIG_PATH,
+            crate::config::model::ModelType::Embedding(PoolingStrategy::Mean),
+        )
+    }
+
+    #[test]
+    fn test_parse_config_jinabert() -> Result<()> {
+        test_parse_config_helper(
+            JINABERT_CONFIG_PATH,
+            crate::config::model::ModelType::Embedding(PoolingStrategy::Mean),
+        )
+    }
+
+    #[test]
+    fn test_parse_config_distilbert() -> Result<()> {
+        test_parse_config_helper(
+            DISTILBERT_CONFIG_PATH,
+            crate::config::model::ModelType::Embedding(PoolingStrategy::Cls),
+        )
+    }
+}
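
The helper keeps per-architecture coverage to a three-line test. A hypothetical fourth case in the same style (the fixture directory is invented, and RoBERTa support is still commented out in `config/model.rs` below, so this would not pass yet):

    #[test]
    fn test_parse_config_roberta() -> Result<()> {
        test_parse_config_helper(
            "tests/fixtures/all-roberta-large-v1", // hypothetical fixture
            crate::config::model::ModelType::Embedding(PoolingStrategy::Mean),
        )
    }
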
65 changes: 65 additions & 0 deletions crates/glowrs/src/config/model.rs
@@ -0,0 +1,65 @@
+//! Embedding model configuration data model
+//!
+//! An embedder is defined by its model configuration (defined in the `config.json` in the root
+//! of a Hugging Face model repository), the model type, and the pooling strategy (optionally
+//! defined in a `1_Pooling/config.json` file in the model repository).
+use candle_transformers::models::bert::Config as _BertConfig;
+use candle_transformers::models::distilbert::Config as DistilBertConfig;
+use candle_transformers::models::jina_bert::Config as _JinaBertConfig;
+use serde::Deserialize;
+use std::collections::HashMap;
+
+use crate::pooling::PoolingStrategy;
+
+/// The base HF embedding model configuration.
+///
+/// This represents the base fields present in a `config.json` for an embedding model.
+#[derive(Debug, Deserialize)]
+#[allow(dead_code)]
+pub(crate) struct HFConfig {
+    pub architectures: Vec<String>,
+    pub model_type: String,
+    #[serde(alias = "n_positions")]
+    pub max_position_embeddings: usize,
+    #[serde(default)]
+    pub pad_token_id: usize,
+    pub id2label: Option<HashMap<String, String>>,
+    pub label2id: Option<HashMap<String, usize>>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+pub(crate) enum BertConfig {
+    Bert(_BertConfig),
+    JinaBert(_JinaBertConfig),
+}
+
+/// The given model type.
+///
+/// Based on the `model_type` key in the `config.json`, the given variant enables the parser
+/// to know what specific model configuration data model to use when deserializing the non-base
+/// keys.
+#[derive(Deserialize)]
+#[serde(tag = "model_type", rename_all = "kebab-case")]
+pub(crate) enum ModelConfig {
+    Bert(BertConfig),
+    // XlmRoberta(BertConfig),
+    // Camembert(BertConfig),
+    // Roberta(BertConfig),
+    #[serde(rename(deserialize = "distilbert"))]
+    DistilBert(DistilBertConfig),
+}
+
+/// The embedding strategy used by a given model.
+#[derive(Debug, PartialEq, Clone)]
+pub enum ModelType {
+    Classifier,
+    Embedding(PoolingStrategy),
+}
+
+/// The model definition
+pub(crate) struct ModelDefinition {
+    pub(crate) model_config: ModelConfig,
+    pub(crate) model_type: ModelType,
+}
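
The `#[serde(tag = "model_type")]` internal tagging is what routes one `config.json` to the right architecture-specific config. A self-contained sketch of the mechanism with trimmed stand-in payloads (the real variants wrap candle's full config structs; `vocab_size` is just a representative field):

use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(tag = "model_type", rename_all = "kebab-case")]
enum ModelConfig {
    // Stand-in for candle's BertConfig / _JinaBertConfig wrapper.
    Bert { vocab_size: usize },
    // kebab-case would yield "distil-bert", hence the explicit rename,
    // mirroring the rename(deserialize = "distilbert") above.
    #[serde(rename = "distilbert")]
    DistilBert { vocab_size: usize },
}

fn main() -> serde_json::Result<()> {
    // The "model_type" key in config.json selects the variant; the
    // remaining keys deserialize into that variant's payload.
    let cfg: ModelConfig =
        serde_json::from_str(r#"{"model_type": "bert", "vocab_size": 30522}"#)?;
    println!("{cfg:?}"); // Bert { vocab_size: 30522 }
    Ok(())
}
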
(Diffs for the remaining 14 changed files are not rendered here.)
