From 5bee71256d990d0246adbf35d77917270635c3e4 Mon Sep 17 00:00:00 2001 From: Michele Riva Date: Tue, 5 Nov 2024 15:56:53 +0100 Subject: [PATCH 1/3] feat: adds custom models --- .gitignore | 1 + Cargo.lock | 1 + collection_manager/src/collection.rs | 2 +- collection_manager/src/dto.rs | 2 +- embeddings/.gitignore | 3 +- embeddings/Cargo.toml | 5 + embeddings/src/bin/embeddings.rs | 17 +++ embeddings/src/bin/pq.rs | 54 ++++++--- embeddings/src/custom_models.rs | 167 +++++++++++++++++++++++++++ embeddings/src/lib.rs | 30 ++++- nlp/src/lib.rs | 2 +- types/src/lib.rs | 2 +- 12 files changed, 262 insertions(+), 24 deletions(-) create mode 100644 embeddings/src/bin/embeddings.rs create mode 100644 embeddings/src/custom_models.rs diff --git a/.gitignore b/.gitignore index e6bee7e..1b54bd9 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ posting_storage .idea +.custom_models \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 4d61c0b..275277b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1501,6 +1501,7 @@ dependencies = [ "rand_chacha", "rayon", "reductive", + "reqwest", "serde", "serde_json", "strum", diff --git a/collection_manager/src/collection.rs b/collection_manager/src/collection.rs index 5e79b31..51ad9d4 100644 --- a/collection_manager/src/collection.rs +++ b/collection_manager/src/collection.rs @@ -6,8 +6,8 @@ use std::{ use anyhow::anyhow; use dashmap::DashMap; use document_storage::DocumentStorage; -use nlp::Parser; use nlp::locales::Locale; +use nlp::Parser; use serde_json::Value; use storage::Storage; use string_index::StringIndex; diff --git a/collection_manager/src/dto.rs b/collection_manager/src/dto.rs index 14f1965..8faab82 100644 --- a/collection_manager/src/dto.rs +++ b/collection_manager/src/dto.rs @@ -20,7 +20,7 @@ impl From for LanguageDTO { fn from(language: Locale) -> Self { match language { Locale::EN => LanguageDTO::English, - _ => LanguageDTO::English + _ => LanguageDTO::English, } } } diff --git a/embeddings/.gitignore b/embeddings/.gitignore index d3f5752..b15deac 100644 --- a/embeddings/.gitignore +++ b/embeddings/.gitignore @@ -1 +1,2 @@ -.fastembed_cache \ No newline at end of file +.fastembed_cache +.custom_models \ No newline at end of file diff --git a/embeddings/Cargo.toml b/embeddings/Cargo.toml index 3da1959..c5451e8 100644 --- a/embeddings/Cargo.toml +++ b/embeddings/Cargo.toml @@ -7,6 +7,10 @@ edition = "2021" name = "pq" path = "./src/bin/pq.rs" +[[bin]] +name = "embeddings" +path = "./src/bin/embeddings.rs" + [dependencies] anyhow = "1.0.92" fastembed = { version = "4.1.0", features = ["ort-download-binaries"] } @@ -20,3 +24,4 @@ rand = "0.8.5" num-traits = "0.2.19" reductive = { version = "0.9.0" } rand_chacha = "0.3.1" +reqwest = { version = "0.12.9", features = ["blocking"] } diff --git a/embeddings/src/bin/embeddings.rs b/embeddings/src/bin/embeddings.rs new file mode 100644 index 0000000..6cdc40e --- /dev/null +++ b/embeddings/src/bin/embeddings.rs @@ -0,0 +1,17 @@ +use anyhow::Result; +use embeddings::custom_models::{ConfigCustomModelFiles, CustomModelFiles}; + +fn main() -> Result<()> { + let model_files = CustomModelFiles::new( + "jinaai/jina-embeddings-v2-base-code".to_string(), + ConfigCustomModelFiles { + onnx_model: "onnx/model.onnx".to_string(), + config: "config.json".to_string(), + tokenizer_config: "tokenizer_config.json".to_string(), + tokenizer: "tokenizer.json".to_string(), + special_tokens_map: "special_tokens_map.json".to_string(), + }, + ); + + model_files.download() +} diff --git a/embeddings/src/bin/pq.rs b/embeddings/src/bin/pq.rs index 1f10f53..ca459ba 100644 --- a/embeddings/src/bin/pq.rs +++ b/embeddings/src/bin/pq.rs @@ -5,23 +5,49 @@ use embeddings::pq; fn main() -> Result<()> { let models = load_models(); - let vectors = models.embed(embeddings::OramaModels::MultilingualE5Small, vec![ - "CASUAL COMFORT, SPORTY STYLE.Slide into comfort in the lightweight and sporty Nike Benassi JDI Slide. It features the Nike logo on the foot strap, which is lined in super soft fabric. The foam midsole brings that beach feeling to your feet and adds spring to your kicked-back style.Benefits1-piece, synthetic leather strap is lined with super soft, towel-like fabric.The foam midsole doubles as an outsole, adding lightweight cushioning.Flex grooves let you move comfortably.Shown: Black/WhiteStyle: 343880-090".to_string(), - "CLASSIC SUPPORT AND COMFORT.The Nike Air Monarch IV gives you classic style with real leather and plenty of lightweight Nike Air cushioning to keep you moving in comfort.BenefitsLeather and synthetic leather team up for durability and classic comfort.An Air-Sole unit runs the length of your foot for cushioning, comfort and support.Rubber sole is durable and provides traction".to_string(), - "STAY TRUE TO YOUR TEAM ALL DAY, EVERY DAY, GAME DAY.\nRep your favorite team and player anytime in the NFL Baltimore Ravens Game Jersey, inspired by what they're wearing on the field and designed for total comfort.\nTAILORED FIT\nThis jersey features a tailored fit designed for movement.\n\nLIGHT, SOFT FEEL\nScreen-print numbers provide a light and soft feel".to_string(), - "STAY TRUE TO YOUR TEAM ALL DAY, EVERY DAY, GAME DAY.\nRep your favorite team and player anytime in the NFL Indianapolis Colts Game Jersey, inspired by what they're wearing on the field and designed for total comfort.\nTAILORED FIT\nThis jersey features a tailored fit designed for movement.\n\nCLEAN COMFORT\nThe no-tag neck label offers clean comfort.\n\nLIGHT, SOFT FEEL\nScreen-print numbers provide a light and soft feel.\n\nAdditional Details\n\n\nStrategic ventilation for breathability\nWoven jock tag at front lower left\nTPU shield at V-neck\n\n\n\nFabric: 100% recycled polyester\nMachine Wash\nImportedShown: Gym BlueStyle: 468955-442".to_string(), - "A GAME-DAY ESSENTIAL.Featuring comfortable, absorbent fabric, the Nike Swoosh Wristbands stretch with you and keep your hands dry, so you can play your best even when the game heat up.Product DetailsWidth: 3"Sold in pairsSwoosh design embroideryFabric: 72% cotton/12% nylon/11% polyester/4% rubber/1% spandexMachine washImportedShown: White/BlackStyle: NNN04-101".to_string(), - "MATCH-READY COMFORT FOR YOUR FEET.The Nike Academy Socks are designed to keep you comfortable during play with soft, sweat-wicking fabric with arch support.BenefitsNike Dri-FIT technology moves sweat away from your skin for quicker evaporation, helping you stay dry and comfortable.Reinforced heel and toe add durability in high-wear areas.Snug band wraps around the arch for a supportive feel.Product DetailsLeft/right specific98% nylon/2% spandexMachine washImportedShown: Varsity Royal/WhiteStyle: SX4120-402".to_string() - ], None)?; + // let vectors = models.embed(embeddings::OramaModels::MultilingualE5Small, vec![ + // "CASUAL COMFORT, SPORTY STYLE.Slide into comfort in the lightweight and sporty Nike Benassi JDI Slide. It features the Nike logo on the foot strap, which is lined in super soft fabric. The foam midsole brings that beach feeling to your feet and adds spring to your kicked-back style.Benefits1-piece, synthetic leather strap is lined with super soft, towel-like fabric.The foam midsole doubles as an outsole, adding lightweight cushioning.Flex grooves let you move comfortably.Shown: Black/WhiteStyle: 343880-090".to_string(), + // "CLASSIC SUPPORT AND COMFORT.The Nike Air Monarch IV gives you classic style with real leather and plenty of lightweight Nike Air cushioning to keep you moving in comfort.BenefitsLeather and synthetic leather team up for durability and classic comfort.An Air-Sole unit runs the length of your foot for cushioning, comfort and support.Rubber sole is durable and provides traction".to_string(), + // "STAY TRUE TO YOUR TEAM ALL DAY, EVERY DAY, GAME DAY.\nRep your favorite team and player anytime in the NFL Baltimore Ravens Game Jersey, inspired by what they're wearing on the field and designed for total comfort.\nTAILORED FIT\nThis jersey features a tailored fit designed for movement.\n\nLIGHT, SOFT FEEL\nScreen-print numbers provide a light and soft feel".to_string(), + // "STAY TRUE TO YOUR TEAM ALL DAY, EVERY DAY, GAME DAY.\nRep your favorite team and player anytime in the NFL Indianapolis Colts Game Jersey, inspired by what they're wearing on the field and designed for total comfort.\nTAILORED FIT\nThis jersey features a tailored fit designed for movement.\n\nCLEAN COMFORT\nThe no-tag neck label offers clean comfort.\n\nLIGHT, SOFT FEEL\nScreen-print numbers provide a light and soft feel.\n\nAdditional Details\n\n\nStrategic ventilation for breathability\nWoven jock tag at front lower left\nTPU shield at V-neck\n\n\n\nFabric: 100% recycled polyester\nMachine Wash\nImportedShown: Gym BlueStyle: 468955-442".to_string(), + // "A GAME-DAY ESSENTIAL.Featuring comfortable, absorbent fabric, the Nike Swoosh Wristbands stretch with you and keep your hands dry, so you can play your best even when the game heat up.Product DetailsWidth: 3"Sold in pairsSwoosh design embroideryFabric: 72% cotton/12% nylon/11% polyester/4% rubber/1% spandexMachine washImportedShown: White/BlackStyle: NNN04-101".to_string(), + // "MATCH-READY COMFORT FOR YOUR FEET.The Nike Academy Socks are designed to keep you comfortable during play with soft, sweat-wicking fabric with arch support.BenefitsNike Dri-FIT technology moves sweat away from your skin for quicker evaporation, helping you stay dry and comfortable.Reinforced heel and toe add durability in high-wear areas.Snug band wraps around the arch for a supportive feel.Product DetailsLeft/right specific98% nylon/2% spandexMachine washImportedShown: Varsity Royal/WhiteStyle: SX4120-402".to_string() + // ], None)?; + // + // let new_vector = models.embed(embeddings::OramaModels::MultilingualE5Small, vec![ + // "COMFORTABLE COVERAGE FOR YOUR SHINS.Designed to take the impacts of the game, the Nike J Shin Guards are made with a tough composite shell and perforations for ventilated comfort.BenefitsAnatomical left/right construction contours for comfort.Perforations enhance ventilation.EVA foam provides soft cushioning.Product DetailsMaterials: 80% polyethylene/20% EVAImportedShown: Black/WhiteStyle: SP0040-009".to_string() + // ], None)?; + // + // let quantizer = pq::ProductQuantizer::try_new(vectors)?; + // let quantized = quantizer.quantize(new_vector); - let new_vector = models.embed(embeddings::OramaModels::MultilingualE5Small, vec![ - "COMFORTABLE COVERAGE FOR YOUR SHINS.Designed to take the impacts of the game, the Nike J Shin Guards are made with a tough composite shell and perforations for ventilated comfort.BenefitsAnatomical left/right construction contours for comfort.Perforations enhance ventilation.EVA foam provides soft cushioning.Product DetailsMaterials: 80% polyethylene/20% EVAImportedShown: Black/WhiteStyle: SP0040-009".to_string() - ], None)?; + let vector = models.embed( + embeddings::OramaModels::JinaV2BaseCode, + vec![r" + import { create, insert, search } from '@orama/orama' + + const db = create({ + schema: { + title: 'string', + description: 'string' + } + }) - let quantizer = pq::ProductQuantizer::try_new(vectors)?; - let quantized = quantizer.quantize(new_vector); + insert(db, { + title: 'foo', + description: 'bar' + }) + + search(db, { + term: 'foo' + }) - dbg!(quantized); + " + .to_string()], + Some(1), + )?; + + dbg!(vector); Ok(()) } diff --git a/embeddings/src/custom_models.rs b/embeddings/src/custom_models.rs new file mode 100644 index 0000000..8a48637 --- /dev/null +++ b/embeddings/src/custom_models.rs @@ -0,0 +1,167 @@ +use anyhow::{Context, Result}; +use reqwest::blocking::Client; +use reqwest::header::{HeaderMap, HeaderValue, USER_AGENT}; +use std::fs::File; +use std::io::{Read, Write}; +use std::path::Path; + +pub struct CustomModelFiles { + pub model_name: String, + pub config: String, + pub tokenizer: String, + pub tokenizer_config: String, + pub special_tokens_map: String, + pub onnx_model: String, +} + +pub struct ConfigCustomModelFiles { + pub config: String, + pub tokenizer: String, + pub tokenizer_config: String, + pub special_tokens_map: String, + pub onnx_model: String, +} + +impl CustomModelFiles { + pub fn new(model_name: String, config: ConfigCustomModelFiles) -> Self { + Self { + model_name, + onnx_model: config.onnx_model, + special_tokens_map: config.special_tokens_map, + tokenizer: config.tokenizer, + tokenizer_config: config.tokenizer_config, + config: config.config, + } + } + + pub fn download(&self) -> Result<()> { + let onnx_model = self.get_huggingface_absolute_path(&self.onnx_model); + let special_tokens_map = self.get_huggingface_absolute_path(&self.special_tokens_map); + let tokenizer = self.get_huggingface_absolute_path(&self.tokenizer); + let tokenizer_config = self.get_huggingface_absolute_path(&self.tokenizer_config); + let config = self.get_huggingface_absolute_path(&self.config); + + let client = Self::create_client()?; + + self.download_file(&client, onnx_model, "onnx/model.onnx".to_string())?; + self.download_file( + &client, + special_tokens_map, + "special_tokens_map.json".to_string(), + )?; + self.download_file(&client, tokenizer, "tokenizer.json".to_string())?; + self.download_file( + &client, + tokenizer_config, + "tokenizer_config.json".to_string(), + )?; + self.download_file(&client, config, "config.json".to_string())?; + + Ok(()) + } + + fn get_huggingface_absolute_path(&self, file: &String) -> String { + format!( + "https://huggingface.co/{}/resolve/main/{}", + self.model_name, file + ) + } + + fn create_client() -> Result { + let mut headers = HeaderMap::new(); + headers.insert( + USER_AGENT, + HeaderValue::from_static("Mozilla/5.0 (compatible; RustBot/1.0)"), + ); + + Client::builder() + .user_agent("Mozilla/5.0 (compatible; RustBot/1.0)") + .default_headers(headers) + .connect_timeout(std::time::Duration::from_secs(10)) + .timeout(std::time::Duration::from_secs(3600)) + .build() + .context("Failed to create HTTP client") + } + + fn download_file(&self, client: &Client, url: String, file_name: String) -> Result<()> { + println!("Downloading {} from {}", file_name, url); + + let base_path = Path::new(".custom_models").join(&self.model_name); + let destination_path = base_path.join(&file_name); + + if let Some(parent) = destination_path.parent() { + std::fs::create_dir_all(parent).context("Failed to create directories")?; + } + + let mut response = client + .get(&url) + .send() + .context("Failed to send HTTP request")?; + + let mut final_response = response; + while final_response.status().is_redirection() { + let new_url = final_response + .headers() + .get("location") + .and_then(|h| h.to_str().ok()) + .context("Missing or invalid Location header in redirect")?; + + final_response = client + .get(new_url) + .send() + .context("Failed to follow redirect")?; + } + + if final_response.status().is_success() { + let mut dest_file = + File::create(&destination_path).context("Failed to create destination file")?; + + let total_size = final_response + .headers() + .get("content-length") + .and_then(|ct_len| ct_len.to_str().ok()) + .and_then(|ct_len| ct_len.parse::().ok()); + + let mut downloaded = 0u64; + let mut buffer = vec![0u8; 8192]; + let mut last_print = std::time::Instant::now(); + + loop { + let bytes_read = match final_response.read(&mut buffer) { + Ok(0) => break, + Ok(n) => n, + Err(e) => return Err(e.into()), + }; + + dest_file + .write_all(&buffer[..bytes_read]) + .context("Failed to write to file")?; + + downloaded += bytes_read as u64; + + if last_print.elapsed() >= std::time::Duration::from_secs(1) { + if let Some(total) = total_size { + println!( + "Downloaded: {:.1}MB / {:.1}MB ({:.1}%)", + downloaded as f64 / 1_000_000.0, + total as f64 / 1_000_000.0, + (downloaded as f64 / total as f64) * 100.0 + ); + } else { + println!("Downloaded: {:.1}MB", downloaded as f64 / 1_000_000.0); + } + last_print = std::time::Instant::now(); + } + } + + println!("Downloaded {} to {:?}", file_name, destination_path); + } else { + return Err(anyhow::anyhow!( + "Failed to download file: {:?}", + final_response.status() + )); + } + + Ok(()) + } +} diff --git a/embeddings/src/lib.rs b/embeddings/src/lib.rs index ba28159..4acc95f 100644 --- a/embeddings/src/lib.rs +++ b/embeddings/src/lib.rs @@ -1,3 +1,4 @@ +pub mod custom_models; pub mod pq; use anyhow::{anyhow, Result}; @@ -43,6 +44,8 @@ pub enum OramaModels { MultilingualE5Base, #[serde(rename = "multilingual-e5-large")] MultilingualE5Large, + #[serde(rename = "jinaai/jina-embeddings-v2-base-code")] + JinaV2BaseCode, } pub struct LoadedModels(HashMap); @@ -66,6 +69,7 @@ impl LoadedModels { impl Into for OramaModels { fn into(self) -> EmbeddingModel { match self { + OramaModels::JinaV2BaseCode => EmbeddingModel::BGESmallENV15, // @todo: understand how to use other models OramaModels::GTESmall => EmbeddingModel::BGESmallENV15, OramaModels::GTEBase => EmbeddingModel::BGEBaseENV15, OramaModels::GTELarge => EmbeddingModel::BGELargeENV15, @@ -89,8 +93,16 @@ impl OramaModels { } } + pub fn is_custom_model(self) -> bool { + match self { + OramaModels::JinaV2BaseCode => true, + _ => false, + } + } + pub fn max_input_tokens(self) -> usize { match self { + OramaModels::JinaV2BaseCode => 512, OramaModels::GTESmall => 512, OramaModels::GTEBase => 512, OramaModels::GTELarge => 512, @@ -102,6 +114,7 @@ impl OramaModels { pub fn dimensions(self) -> usize { match self { + OramaModels::JinaV2BaseCode => 768, OramaModels::GTESmall => 384, OramaModels::GTEBase => 768, OramaModels::GTELarge => 1024, @@ -128,16 +141,23 @@ pub fn load_models() -> LoadedModels { // OramaModels::MultilingualE5Large, OramaModels::GTESmall, // OramaModels::GTEBase, - // OramaModels::GTELarge + // OramaModels::GTELarge, + OramaModels::JinaV2BaseCode, ]; let model_map: HashMap = models .into_par_iter() .map(|model| { - let initialized_model = TextEmbedding::try_new( - InitOptions::new(model.into()).with_show_download_progress(true), - ) - .unwrap(); + if !model.is_custom_model() { + let initialized_model = TextEmbedding::try_new( + InitOptions::new(model.into()).with_show_download_progress(true), + ) + .unwrap(); + + return (model, initialized_model); + } + + let initialized_model = unimplemented!(); return (model, initialized_model); }) diff --git a/nlp/src/lib.rs b/nlp/src/lib.rs index c56b6f6..53df92a 100644 --- a/nlp/src/lib.rs +++ b/nlp/src/lib.rs @@ -3,10 +3,10 @@ pub mod locales; pub mod stop_words; pub mod tokenizer; +use crate::locales::Locale; use rust_stemmers::Algorithm; pub use rust_stemmers::Stemmer; use tokenizer::Tokenizer; -use crate::locales::Locale; pub struct Parser { pub tokenizer: Tokenizer, diff --git a/types/src/lib.rs b/types/src/lib.rs index d6c8036..54aed2e 100644 --- a/types/src/lib.rs +++ b/types/src/lib.rs @@ -1,6 +1,6 @@ -use std::collections::HashMap; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; +use std::collections::HashMap; #[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] pub struct CollectionId(pub String); From ff78101b4591758fba9e61013fd0d75e80a280f6 Mon Sep 17 00:00:00 2001 From: Michele Riva Date: Tue, 5 Nov 2024 17:28:11 +0100 Subject: [PATCH 2/3] wip --- Cargo.lock | 1 + embeddings/Cargo.toml | 1 + embeddings/src/custom_models.rs | 356 ++++++++++++++++++++------------ embeddings/src/lib.rs | 58 +++++- 4 files changed, 278 insertions(+), 138 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 275277b..a3b59e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1505,6 +1505,7 @@ dependencies = [ "serde", "serde_json", "strum", + "strum_macros", "tokio", ] diff --git a/embeddings/Cargo.toml b/embeddings/Cargo.toml index c5451e8..6ee3e3e 100644 --- a/embeddings/Cargo.toml +++ b/embeddings/Cargo.toml @@ -25,3 +25,4 @@ num-traits = "0.2.19" reductive = { version = "0.9.0" } rand_chacha = "0.3.1" reqwest = { version = "0.12.9", features = ["blocking"] } +strum_macros = "0.26.4" diff --git a/embeddings/src/custom_models.rs b/embeddings/src/custom_models.rs index 8a48637..bc2127c 100644 --- a/embeddings/src/custom_models.rs +++ b/embeddings/src/custom_models.rs @@ -1,20 +1,29 @@ use anyhow::{Context, Result}; -use reqwest::blocking::Client; -use reqwest::header::{HeaderMap, HeaderValue, USER_AGENT}; -use std::fs::File; -use std::io::{Read, Write}; -use std::path::Path; - -pub struct CustomModelFiles { - pub model_name: String, - pub config: String, - pub tokenizer: String, - pub tokenizer_config: String, - pub special_tokens_map: String, - pub onnx_model: String, +use fastembed::{QuantizationMode, TokenizerFiles, UserDefinedEmbeddingModel}; +use reqwest::{ + blocking::Client, + header::{HeaderMap, HeaderValue, USER_AGENT}, +}; +use std::{ + fs::{self, File}, + io::{Read, Write}, + path::{Path, PathBuf}, + time::{Duration, Instant}, +}; + +const BUFFER_SIZE: usize = 8192; +const USER_AGENT_STRING: &str = "Mozilla/5.0 (compatible; RustBot/1.0)"; +const BASE_URL: &str = "https://huggingface.co"; +const MODEL_BASE_DIR: &str = ".custom_models"; + +#[derive(Debug, Clone)] +pub struct CustomModel { + model_name: String, + files: ModelFileConfig, } -pub struct ConfigCustomModelFiles { +#[derive(Debug, Clone)] +pub struct ModelFileConfig { pub config: String, pub tokenizer: String, pub tokenizer_config: String, @@ -22,146 +31,225 @@ pub struct ConfigCustomModelFiles { pub onnx_model: String, } -impl CustomModelFiles { - pub fn new(model_name: String, config: ConfigCustomModelFiles) -> Self { - Self { - model_name, - onnx_model: config.onnx_model, - special_tokens_map: config.special_tokens_map, - tokenizer: config.tokenizer, - tokenizer_config: config.tokenizer_config, - config: config.config, - } +impl CustomModel { + pub fn try_new(model_name: String, files: ModelFileConfig) -> Result { + let model = Self { + model_name: model_name.clone(), + files, + }; + + if !model.exists() { + println!( + "Cannot find model {} locally. Starting download", + model_name + ); + model.download()?; + }; + + Ok(model) + } + + pub fn exists(&self) -> bool { + self.get_file_mappings().iter().all(|(_, destination)| { + self.get_destination_path(destination) + .unwrap_or_default() + .exists() + }) + } + + pub fn load(&self) -> Result { + let full_path = |file: &str| format!("{}/{}", MODEL_BASE_DIR, file); + let onnx_file = fs::read(full_path(&self.files.onnx_model))?; + let tokenizer_files = TokenizerFiles { + tokenizer_file: fs::read(full_path(&self.files.tokenizer))?, + config_file: fs::read(full_path(&self.files.config))?, + special_tokens_map_file: fs::read(full_path(&self.files.special_tokens_map))?, + tokenizer_config_file: fs::read(full_path(&self.files.tokenizer_config))?, + }; + + Ok(UserDefinedEmbeddingModel::new(onnx_file, tokenizer_files)) + } + + pub fn missing_files(&self) -> Vec { + self.get_file_mappings() + .iter() + .filter_map(|(_, destination)| { + let path = self.get_destination_path(destination).ok()?; + if !path.exists() { + Some(destination.to_string()) + } else { + None + } + }) + .collect() + } + + pub fn get_model_dir(&self) -> PathBuf { + Path::new(MODEL_BASE_DIR).join(&self.model_name) } pub fn download(&self) -> Result<()> { - let onnx_model = self.get_huggingface_absolute_path(&self.onnx_model); - let special_tokens_map = self.get_huggingface_absolute_path(&self.special_tokens_map); - let tokenizer = self.get_huggingface_absolute_path(&self.tokenizer); - let tokenizer_config = self.get_huggingface_absolute_path(&self.tokenizer_config); - let config = self.get_huggingface_absolute_path(&self.config); - - let client = Self::create_client()?; - - self.download_file(&client, onnx_model, "onnx/model.onnx".to_string())?; - self.download_file( - &client, - special_tokens_map, - "special_tokens_map.json".to_string(), - )?; - self.download_file(&client, tokenizer, "tokenizer.json".to_string())?; - self.download_file( - &client, - tokenizer_config, - "tokenizer_config.json".to_string(), - )?; - self.download_file(&client, config, "config.json".to_string())?; + if self.exists() { + println!("Model {} already exists locally", self.model_name); + return Ok(()); + } + + let missing = self.missing_files(); + if !missing.is_empty() { + println!( + "Model {} is partially downloaded. Missing files: {:?}", + self.model_name, missing + ); + } + + let client = create_client()?; + let files_to_download = self.get_file_mappings(); + + for (url, destination) in files_to_download { + // Skip if file already exists + if self + .get_destination_path(&destination) + .map(|p| p.exists()) + .unwrap_or(false) + { + println!("Skipping existing file: {}", destination); + continue; + } + + self.download_file(&client, &url, &destination) + .with_context(|| format!("Failed to download {}", url))?; + } Ok(()) } - fn get_huggingface_absolute_path(&self, file: &String) -> String { - format!( - "https://huggingface.co/{}/resolve/main/{}", - self.model_name, file - ) + fn get_file_mappings(&self) -> Vec<(String, String)> { + vec![ + (self.files.onnx_model.clone(), "onnx/model.onnx"), + ( + self.files.special_tokens_map.clone(), + "special_tokens_map.json", + ), + (self.files.tokenizer.clone(), "tokenizer.json"), + (self.files.tokenizer_config.clone(), "tokenizer_config.json"), + (self.files.config.clone(), "config.json"), + ] + .into_iter() + .map(|(file, dest)| (self.get_huggingface_url(&file), dest.to_string())) + .collect() } - fn create_client() -> Result { - let mut headers = HeaderMap::new(); - headers.insert( - USER_AGENT, - HeaderValue::from_static("Mozilla/5.0 (compatible; RustBot/1.0)"), - ); - - Client::builder() - .user_agent("Mozilla/5.0 (compatible; RustBot/1.0)") - .default_headers(headers) - .connect_timeout(std::time::Duration::from_secs(10)) - .timeout(std::time::Duration::from_secs(3600)) - .build() - .context("Failed to create HTTP client") + fn get_huggingface_url(&self, file: &str) -> String { + format!("{}/{}/resolve/main/{}", BASE_URL, self.model_name, file) } - fn download_file(&self, client: &Client, url: String, file_name: String) -> Result<()> { - println!("Downloading {} from {}", file_name, url); + fn download_file(&self, client: &Client, url: &str, filename: &str) -> Result<()> { + println!("Downloading {} from {}", filename, url); - let base_path = Path::new(".custom_models").join(&self.model_name); - let destination_path = base_path.join(&file_name); + let dest_path = self.get_destination_path(filename)?; + let response = follow_redirects(client, url)?; - if let Some(parent) = destination_path.parent() { - std::fs::create_dir_all(parent).context("Failed to create directories")?; + if response.status().is_success() { + self.save_file(response, &dest_path)?; + println!("Downloaded {} to {:?}", filename, dest_path); + Ok(()) + } else { + Err(anyhow::anyhow!( + "Download failed with status: {}", + response.status() + )) } + } - let mut response = client - .get(&url) - .send() - .context("Failed to send HTTP request")?; - - let mut final_response = response; - while final_response.status().is_redirection() { - let new_url = final_response - .headers() - .get("location") - .and_then(|h| h.to_str().ok()) - .context("Missing or invalid Location header in redirect")?; - - final_response = client - .get(new_url) - .send() - .context("Failed to follow redirect")?; + fn get_destination_path(&self, filename: &str) -> Result { + let dest_path = self.get_model_dir().join(filename); + + if let Some(parent) = dest_path.parent() { + fs::create_dir_all(parent).context("Failed to create directories")?; } - if final_response.status().is_success() { - let mut dest_file = - File::create(&destination_path).context("Failed to create destination file")?; - - let total_size = final_response - .headers() - .get("content-length") - .and_then(|ct_len| ct_len.to_str().ok()) - .and_then(|ct_len| ct_len.parse::().ok()); - - let mut downloaded = 0u64; - let mut buffer = vec![0u8; 8192]; - let mut last_print = std::time::Instant::now(); - - loop { - let bytes_read = match final_response.read(&mut buffer) { - Ok(0) => break, - Ok(n) => n, - Err(e) => return Err(e.into()), - }; - - dest_file - .write_all(&buffer[..bytes_read]) - .context("Failed to write to file")?; - - downloaded += bytes_read as u64; - - if last_print.elapsed() >= std::time::Duration::from_secs(1) { - if let Some(total) = total_size { - println!( - "Downloaded: {:.1}MB / {:.1}MB ({:.1}%)", - downloaded as f64 / 1_000_000.0, - total as f64 / 1_000_000.0, - (downloaded as f64 / total as f64) * 100.0 - ); - } else { - println!("Downloaded: {:.1}MB", downloaded as f64 / 1_000_000.0); - } - last_print = std::time::Instant::now(); - } - } + Ok(dest_path) + } - println!("Downloaded {} to {:?}", file_name, destination_path); - } else { - return Err(anyhow::anyhow!( - "Failed to download file: {:?}", - final_response.status() - )); + fn save_file(&self, mut response: reqwest::blocking::Response, dest_path: &Path) -> Result<()> { + let mut file = File::create(dest_path).context("Failed to create destination file")?; + let total_size = get_content_length(&response); + let mut downloaded = 0u64; + let mut buffer = vec![0u8; BUFFER_SIZE]; + let mut last_print = Instant::now(); + + loop { + let bytes_read = match response.read(&mut buffer) { + Ok(0) => break, + Ok(n) => n, + Err(e) => return Err(e).context("Failed to read from response"), + }; + + file.write_all(&buffer[..bytes_read]) + .context("Failed to write to file")?; + + downloaded += bytes_read as u64; + print_progress(downloaded, total_size, &mut last_print); } Ok(()) } } + +fn create_client() -> Result { + let mut headers = HeaderMap::new(); + headers.insert(USER_AGENT, HeaderValue::from_static(USER_AGENT_STRING)); + + Client::builder() + .user_agent(USER_AGENT_STRING) + .default_headers(headers) + .connect_timeout(Duration::from_secs(10)) + .timeout(Duration::from_secs(3600)) + .build() + .context("Failed to create HTTP client") +} + +fn follow_redirects(client: &Client, initial_url: &str) -> Result { + let mut response = client + .get(initial_url) + .send() + .context("Failed to send HTTP request")?; + + while response.status().is_redirection() { + let new_url = response + .headers() + .get("location") + .and_then(|h| h.to_str().ok()) + .context("Missing or invalid Location header in redirect")?; + + response = client + .get(new_url) + .send() + .context("Failed to follow redirect")?; + } + + Ok(response) +} + +fn get_content_length(response: &reqwest::blocking::Response) -> Option { + response + .headers() + .get("content-length") + .and_then(|ct_len| ct_len.to_str().ok()) + .and_then(|ct_len| ct_len.parse::().ok()) +} + +fn print_progress(downloaded: u64, total_size: Option, last_print: &mut Instant) { + if last_print.elapsed() >= Duration::from_secs(1) { + match total_size { + Some(total) => println!( + "Downloaded: {:.1}MB / {:.1}MB ({:.1}%)", + downloaded as f64 / 1_000_000.0, + total as f64 / 1_000_000.0, + (downloaded as f64 / total as f64) * 100.0 + ), + None => println!("Downloaded: {:.1}MB", downloaded as f64 / 1_000_000.0), + } + *last_print = Instant::now(); + } +} diff --git a/embeddings/src/lib.rs b/embeddings/src/lib.rs index 4acc95f..c3379e6 100644 --- a/embeddings/src/lib.rs +++ b/embeddings/src/lib.rs @@ -1,14 +1,18 @@ pub mod custom_models; pub mod pq; -use anyhow::{anyhow, Result}; -use fastembed::{EmbeddingModel, InitOptions, TextEmbedding}; +use crate::custom_models::{CustomModel, ModelFileConfig}; +use anyhow::{anyhow, Context, Result}; +use fastembed::{ + EmbeddingModel, InitOptions, InitOptionsUserDefined, TextEmbedding, UserDefinedEmbeddingModel, +}; use rayon::prelude::*; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fmt; use strum::EnumIter; use strum::IntoEnumIterator; +use strum_macros::{AsRefStr, Display}; #[derive(Deserialize, Debug)] pub struct EmbeddingsParams { @@ -30,7 +34,8 @@ pub struct EmbeddingsResponse { embeddings: Vec>, } -#[derive(Deserialize, Debug, Hash, PartialEq, Eq, Copy, Clone, EnumIter)] +#[derive(Deserialize, Debug, Hash, PartialEq, Eq, Copy, Clone, EnumIter, Display, AsRefStr)] +#[strum(serialize_all = "kebab-case")] pub enum OramaModels { #[serde(rename = "gte-small")] GTESmall, @@ -123,6 +128,19 @@ impl OramaModels { OramaModels::MultilingualE5Large => 1024, } } + + pub fn files(self) -> Option { + match self { + OramaModels::JinaV2BaseCode => Some(ModelFileConfig { + onnx_model: "onnx.model.onnx".to_string(), + special_tokens_map: "special_tokens_map.json".to_string(), + tokenizer: "tokenizer.json".to_string(), + tokenizer_config: "tokenizer_config.json".to_string(), + config: "config.json".to_string(), + }), + _ => None, + } + } } impl fmt::Display for EncodingIntent { @@ -157,7 +175,39 @@ pub fn load_models() -> LoadedModels { return (model, initialized_model); } - let initialized_model = unimplemented!(); + let custom_model = CustomModel::try_new(model.to_string(), model.files().unwrap()) + .with_context(|| format!("Unable to initialize custom model {}", model.to_string())) + .unwrap(); + + if custom_model.exists() { + custom_model + .download() + .with_context(|| { + format!("Unable to download custom model {}", model.to_string()) + }) + .unwrap(); + }; + + let init_model = custom_model + .load() + .with_context(|| { + format!( + "Unable to load local files for custom model {}", + model.to_string() + ) + }) + .unwrap(); + let initialized_model = TextEmbedding::try_new_from_user_defined( + init_model, + InitOptionsUserDefined::default(), + ) + .with_context(|| { + format!( + "Unable to initialize a new TextEmbedding instance from custom model {}", + model.to_string() + ) + }) + .unwrap(); return (model, initialized_model); }) From 604c96fa0d13ca42c17769a69498da177329e684 Mon Sep 17 00:00:00 2001 From: Michele Riva Date: Tue, 5 Nov 2024 18:47:32 +0100 Subject: [PATCH 3/3] fixes everything yo --- embeddings/src/bin/embeddings.rs | 44 +++++++++++++++++++++++--------- embeddings/src/custom_models.rs | 6 ++--- embeddings/src/lib.rs | 32 ++++++++++++++--------- 3 files changed, 55 insertions(+), 27 deletions(-) diff --git a/embeddings/src/bin/embeddings.rs b/embeddings/src/bin/embeddings.rs index 6cdc40e..0a360d8 100644 --- a/embeddings/src/bin/embeddings.rs +++ b/embeddings/src/bin/embeddings.rs @@ -1,17 +1,37 @@ use anyhow::Result; -use embeddings::custom_models::{ConfigCustomModelFiles, CustomModelFiles}; +use embeddings::custom_models::{CustomModel, ModelFileConfig}; +use embeddings::{load_models, OramaModels}; fn main() -> Result<()> { - let model_files = CustomModelFiles::new( - "jinaai/jina-embeddings-v2-base-code".to_string(), - ConfigCustomModelFiles { - onnx_model: "onnx/model.onnx".to_string(), - config: "config.json".to_string(), - tokenizer_config: "tokenizer_config.json".to_string(), - tokenizer: "tokenizer.json".to_string(), - special_tokens_map: "special_tokens_map.json".to_string(), - }, - ); + let models = load_models(); - model_files.download() + let embedding = models.embed( + OramaModels::JinaV2BaseCode, + vec![r" + /** + * This method is needed to used because of issues like: https://github.com/askorama/orama/issues/301 + * that issue is caused because the array that is pushed is huge (>100k) + * + * @example + * ```ts + * safeArrayPush(myArray, [1, 2]) + * ``` + */ + export function safeArrayPush(arr: T[], newArr: T[]): void { + if (newArr.length < MAX_ARGUMENT_FOR_STACK) { + Array.prototype.push.apply(arr, newArr) + } else { + const newArrLength = newArr.length + for (let i = 0; i < newArrLength; i += MAX_ARGUMENT_FOR_STACK) { + Array.prototype.push.apply(arr, newArr.slice(i, i + MAX_ARGUMENT_FOR_STACK)) + } + } + } + ".to_string()], + Some(1), + )?; + + dbg!(embedding.first().unwrap()); + + Ok(()) } diff --git a/embeddings/src/custom_models.rs b/embeddings/src/custom_models.rs index bc2127c..3956e46 100644 --- a/embeddings/src/custom_models.rs +++ b/embeddings/src/custom_models.rs @@ -1,5 +1,5 @@ use anyhow::{Context, Result}; -use fastembed::{QuantizationMode, TokenizerFiles, UserDefinedEmbeddingModel}; +use fastembed::{TokenizerFiles, UserDefinedEmbeddingModel}; use reqwest::{ blocking::Client, header::{HeaderMap, HeaderValue, USER_AGENT}, @@ -58,7 +58,8 @@ impl CustomModel { } pub fn load(&self) -> Result { - let full_path = |file: &str| format!("{}/{}", MODEL_BASE_DIR, file); + let full_path = |file: &str| format!("{}/{}/{}", MODEL_BASE_DIR, self.model_name, file); + let onnx_file = fs::read(full_path(&self.files.onnx_model))?; let tokenizer_files = TokenizerFiles { tokenizer_file: fs::read(full_path(&self.files.tokenizer))?, @@ -106,7 +107,6 @@ impl CustomModel { let files_to_download = self.get_file_mappings(); for (url, destination) in files_to_download { - // Skip if file already exists if self .get_destination_path(&destination) .map(|p| p.exists()) diff --git a/embeddings/src/lib.rs b/embeddings/src/lib.rs index c3379e6..1768fb0 100644 --- a/embeddings/src/lib.rs +++ b/embeddings/src/lib.rs @@ -35,21 +35,27 @@ pub struct EmbeddingsResponse { } #[derive(Deserialize, Debug, Hash, PartialEq, Eq, Copy, Clone, EnumIter, Display, AsRefStr)] -#[strum(serialize_all = "kebab-case")] pub enum OramaModels { #[serde(rename = "gte-small")] + #[strum(serialize = "gte-small")] GTESmall, #[serde(rename = "gte-base")] + #[strum(serialize = "gte-base")] GTEBase, #[serde(rename = "gte-large")] + #[strum(serialize = "gte-large")] GTELarge, #[serde(rename = "multilingual-e5-small")] + #[strum(serialize = "multilingual-e5-small")] MultilingualE5Small, #[serde(rename = "multilingual-e5-base")] + #[strum(serialize = "multilingual-e5-base")] MultilingualE5Base, #[serde(rename = "multilingual-e5-large")] + #[strum(serialize = "multilingual-e5-large")] MultilingualE5Large, #[serde(rename = "jinaai/jina-embeddings-v2-base-code")] + #[strum(serialize = "jinaai/jina-embeddings-v2-base-code")] JinaV2BaseCode, } @@ -71,16 +77,18 @@ impl LoadedModels { } } -impl Into for OramaModels { - fn into(self) -> EmbeddingModel { +impl TryInto for OramaModels { + type Error = anyhow::Error; + + fn try_into(self) -> std::result::Result { match self { - OramaModels::JinaV2BaseCode => EmbeddingModel::BGESmallENV15, // @todo: understand how to use other models - OramaModels::GTESmall => EmbeddingModel::BGESmallENV15, - OramaModels::GTEBase => EmbeddingModel::BGEBaseENV15, - OramaModels::GTELarge => EmbeddingModel::BGELargeENV15, - OramaModels::MultilingualE5Small => EmbeddingModel::MultilingualE5Small, - OramaModels::MultilingualE5Base => EmbeddingModel::MultilingualE5Base, - OramaModels::MultilingualE5Large => EmbeddingModel::MultilingualE5Large, + OramaModels::GTESmall => Ok(EmbeddingModel::BGESmallENV15), + OramaModels::GTEBase => Ok(EmbeddingModel::BGEBaseENV15), + OramaModels::GTELarge => Ok(EmbeddingModel::BGELargeENV15), + OramaModels::MultilingualE5Small => Ok(EmbeddingModel::MultilingualE5Small), + OramaModels::MultilingualE5Base => Ok(EmbeddingModel::MultilingualE5Base), + OramaModels::MultilingualE5Large => Ok(EmbeddingModel::MultilingualE5Large), + OramaModels::JinaV2BaseCode => Err(anyhow!("JinaV2BaseCode is a custom model")), } } } @@ -132,7 +140,7 @@ impl OramaModels { pub fn files(self) -> Option { match self { OramaModels::JinaV2BaseCode => Some(ModelFileConfig { - onnx_model: "onnx.model.onnx".to_string(), + onnx_model: "onnx/model.onnx".to_string(), special_tokens_map: "special_tokens_map.json".to_string(), tokenizer: "tokenizer.json".to_string(), tokenizer_config: "tokenizer_config.json".to_string(), @@ -168,7 +176,7 @@ pub fn load_models() -> LoadedModels { .map(|model| { if !model.is_custom_model() { let initialized_model = TextEmbedding::try_new( - InitOptions::new(model.into()).with_show_download_progress(true), + InitOptions::new(model.try_into().unwrap()).with_show_download_progress(true), ) .unwrap();