Merge pull request #15 from oramasearch/feat/adds-llm-module

feat: adds llm module

micheleriva authored Nov 6, 2024
2 parents 6af319a + 54e86a8 commit 71ff3b6

Showing 15 changed files with 468 additions and 58 deletions.
302 changes: 278 additions & 24 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
@@ -3,7 +3,6 @@ resolver = "2"
 members = [
   "string_index",
   "nlp",
-  "content_expander",
   "web_server",
   "collection_manager",
   "storage",
@@ -14,5 +13,6 @@ members = [
   "tanstack_example",
   "code_parser",
   "vector_index",
-  "utils"
+  "utils",
+  "llm"
 ]
17 changes: 15 additions & 2 deletions content_expander/Cargo.toml → llm/Cargo.toml
@@ -1,11 +1,19 @@
 [package]
-name = "content_expander"
+name = "llm"
 version = "0.1.0"
 edition = "2021"
 
 [[bin]]
 name = "test_image"
-path = "src/bin/test_image.rs"
+path = "../llm/src/bin/test_image.rs"
 
+[[bin]]
+name = "test_code"
+path = "../llm/src/bin/test_code.rs"
+
+[[bin]]
+name = "questions"
+path = "../llm/src/bin/questions.rs"
+
 [dependencies]
 linkify = "0.10.0"
@@ -23,6 +31,11 @@ markdown = "0.3.0"
 html_parser = "0.7.0"
 futures = "0.3.31"
 html-escape = "0.2.13"
+utils = { path = "../utils" }
+serde_json = "1.0.132"
+strum = "0.26.3"
+async-std = "1.13.0"
+async-once-cell = "0.5.4"
 
 [target.'cfg(not(target_os = "macos"))'.dependencies]
 mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git" }
30 changes: 30 additions & 0 deletions llm/src/bin/questions.rs
@@ -0,0 +1,30 @@
use futures::executor::block_on;
use llm::questions_generation::generator::generate_questions;
use textwrap::dedent;

fn main() {
    let context = dedent(
        r"
        Introduction
        When we say that Orama Cloud is batteries-included, we mean that it gives you everything you need to start searching and generating answers (conversations) without any complex configuration. Out of the box, Orama Cloud also includes:
        🧩 Native and Custom integrations to easily import your data.
        🚀 Web Components to easily integrate a full-featured Searchbox on your website in no time.
        📊 Quality checks, analytics and quality control tools to fine-tune your users experience.
        🔐 Secure proxy configuration and advanced security options.
        and much more…
        Basic concepts
        At the core of Orama Cloud, there are three simple concepts:
        📖 Index: a collection of documents that you can search through.
        📄 Schema: a set of rules that define how the documents are structured.
        🗿 Immutability: once you’ve created an index and populated it with documents, it will remain immutable. To change the content of an index, you have to perform a re-deployment.
        With your index, you can perform full-text, vector, and hybrid search queries, as well as generative conversations. Add your data, define the schema, and you’re ready to go!
        ",
    );

    let questions = block_on(generate_questions(context)).unwrap();

    dbg!(questions);
}
content_expander/src/bin/test_code.rs → llm/src/bin/test_code.rs
@@ -1,4 +1,4 @@
-use content_expander::code;
+use llm::content_expander::code;
 
 #[tokio::main]
 async fn main() -> Result<(), anyhow::Error> {
content_expander/src/bin/test_image.rs → llm/src/bin/test_image.rs
@@ -1,5 +1,5 @@
-use content_expander::prompts::Prompts;
-use content_expander::vision::describe_images;
+use llm::content_expander::prompts::Prompts;
+use llm::content_expander::vision::describe_images;
 
 #[tokio::main]
 async fn main() -> Result<(), anyhow::Error> {
@@ -9,9 +9,8 @@ async fn main() -> Result<(), anyhow::Error> {
     Foo bar baz.
     ";
 
-    let results = describe_images(example_text.to_string(), Prompts::VisionECommerce)
-        .await
-        .unwrap();
+    let results = describe_images(example_text.to_string(), Prompts::VisionECommerce).await?;
 
     dbg!(results);
 
     Ok(())
14 changes: 4 additions & 10 deletions content_expander/src/code.rs → llm/src/content_expander/code.rs
@@ -1,6 +1,7 @@
-use crate::prompts::{get_prompt, Prompts};
+use crate::content_expander::prompts::{get_prompt, Prompts};
+use crate::LocalLLM;
 use html_parser::{Dom, Node};
-use mistralrs::{IsqType, TextMessageRole, TextMessages, TextModelBuilder};
+use mistralrs::{TextMessageRole, TextMessages};
 
 type CodeBlockDescriptions = Vec<String>;
 
@@ -11,18 +12,11 @@ pub enum TextFormat {
     Plaintext,
 }
 
-const TEXT_MODEL_ID: &str = "microsoft/Phi-3.5-mini-instruct";
-
 pub async fn describe_code_blocks(
     text: String,
     format: TextFormat,
 ) -> Option<CodeBlockDescriptions> {
-    let model = TextModelBuilder::new(TEXT_MODEL_ID)
-        .with_isq(IsqType::Q8_0)
-        .with_logging()
-        .build()
-        .await
-        .unwrap();
+    let model = LocalLLM::Phi3_5MiniInstruct.try_new().await.unwrap();
 
     let code_blocks = capture_code_blocks(text, format)?;
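For orientation, the refactored describe_code_blocks can be driven the same way as the existing bins. A hedged sketch follows; the fenced input and the Plaintext variant are illustrative, and capture_code_blocks plus the other TextFormat variants sit outside this diff:

use llm::content_expander::code::{describe_code_blocks, TextFormat};

#[tokio::main]
async fn main() {
    // Illustrative input: one fenced code block in otherwise plain text.
    let text = "```\nfn add(a: i32, b: i32) -> i32 { a + b }\n```".to_string();
    // describe_code_blocks returns None when no code blocks are captured.
    if let Some(descriptions) = describe_code_blocks(text, TextFormat::Plaintext).await {
        for description in descriptions {
            println!("{description}");
        }
    }
}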
File renamed without changes.
File renamed without changes.
content_expander/src/vision.rs → llm/src/content_expander/vision.rs
@@ -1,4 +1,5 @@
-use crate::prompts::{get_prompt, Prompts};
+use crate::content_expander::prompts::{get_prompt, Prompts};
+use crate::LocalLLM;
 use anyhow::Context;
 use linkify::{LinkFinder, LinkKind};
 use mistralrs::{IsqType, TextMessageRole, VisionLoaderType, VisionMessages, VisionModelBuilder};
@@ -116,11 +117,7 @@ pub async fn describe_images(
     text: String,
     domain: Prompts,
 ) -> Result<Vec<(String, String)>, anyhow::Error> {
-    let model = VisionModelBuilder::new(VISION_MODEL_ID, VisionLoaderType::Phi3V)
-        .with_isq(IsqType::Q4K)
-        .with_logging()
-        .build()
-        .await?;
+    let model = LocalLLM::Phi3_5VisionInstruct.try_new().await?;
 
     let image_links = UrlParser::try_new(UrlParserConfig {
         domains_allow_list: vec![],
50 changes: 50 additions & 0 deletions llm/src/lib.rs
@@ -0,0 +1,50 @@
use anyhow::{Context, Result};
use async_once_cell::OnceCell;
use async_std::sync::RwLock;
use mistralrs::{IsqType, Model, TextModelBuilder};
use serde::Serialize;
use std::collections::HashMap;
use std::sync::Arc;
use strum::{Display, EnumIter};

pub mod content_expander;
pub mod questions_generation;

static MODELS: OnceCell<RwLock<HashMap<LocalLLM, Arc<Model>>>> = OnceCell::new();

#[derive(Serialize, EnumIter, Eq, PartialEq, Hash, Clone, Display)]
pub enum LocalLLM {
    #[serde(rename = "microsoft/Phi-3.5-mini-instruct")]
    #[strum(serialize = "microsoft/Phi-3.5-mini-instruct")]
    Phi3_5MiniInstruct,

    #[serde(rename = "microsoft/Phi-3.5-vision-instruct")]
    #[strum(serialize = "microsoft/Phi-3.5-vision-instruct")]
    Phi3_5VisionInstruct,
}

impl LocalLLM {
    async fn try_new(&self) -> Result<Arc<Model>> {
        MODELS
            .get_or_init(async {
                let mut models_map = HashMap::new();
                let model = TextModelBuilder::new(self)
                    .with_isq(IsqType::Q8_0)
                    .with_logging()
                    .build()
                    .await
                    .with_context(|| "Failed to build the text model")
                    .unwrap();

                models_map.insert(self.clone(), Arc::new(model));
                RwLock::new(models_map)
            })
            .await;

        let models = MODELS.get().unwrap().read().await;
        models
            .get(self)
            .cloned()
            .ok_or_else(|| anyhow::anyhow!("Model not found"))
    }
}
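A note on the cache in try_new: get_or_init runs its initializer at most once, so the MODELS map is seeded with whichever LocalLLM variant is requested first, and try_new is private to the crate root, so only modules inside llm can call it. A minimal hypothetical test, not part of this commit, sketches the intended sharing; it assumes tokio's test macro is available (the bins already use #[tokio::main]) and that the model weights can be fetched:

#[cfg(test)]
mod cache_tests {
    use super::LocalLLM;
    use std::sync::Arc;

    // Hypothetical sketch: a second request for the same variant should
    // resolve to the same Arc<Model> cached in MODELS, not a rebuilt model.
    #[tokio::test]
    async fn try_new_returns_shared_handle() {
        let first = LocalLLM::Phi3_5MiniInstruct.try_new().await.unwrap();
        let second = LocalLLM::Phi3_5MiniInstruct.try_new().await.unwrap();
        assert!(Arc::ptr_eq(&first, &second));
    }
}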
48 changes: 48 additions & 0 deletions llm/src/questions_generation/generator.rs
@@ -0,0 +1,48 @@
use crate::questions_generation::prompts::{
    get_questions_generation_prompt, QUESTIONS_GENERATION_SYSTEM_PROMPT,
};
use crate::LocalLLM;
use anyhow::{Context, Result};
use mistralrs::{IsqType, TextMessageRole, TextMessages, TextModelBuilder};
use serde_json::Value;
use textwrap::dedent;
use utils::parse_json_safely;

pub async fn generate_questions(context: String) -> Result<Vec<String>> {
    let model = LocalLLM::Phi3_5MiniInstruct.try_new().await?;

    let messages = TextMessages::new()
        .add_message(
            TextMessageRole::System,
            dedent(QUESTIONS_GENERATION_SYSTEM_PROMPT),
        )
        .add_message(
            TextMessageRole::User,
            get_questions_generation_prompt(context),
        );

    let response = model
        .send_chat_request(messages)
        .await
        .context("Failed to send chat request")?;

    if let Some(content) = response
        .choices
        .first()
        .and_then(|choice| choice.message.content.clone())
    {
        match parse_json_safely(content) {
            Ok(Value::Array(json_array)) => {
                let questions: Vec<String> = json_array
                    .iter()
                    .filter_map(|val| val.as_str().map(|s| s.to_string()))
                    .collect();
                Ok(questions)
            }
            Ok(_) => anyhow::bail!("Parsed content is not an array of strings"),
            Err(e) => Err(e).context("Failed to parse response content as JSON"),
        }
    } else {
        anyhow::bail!("No content in the response");
    }
}
2 changes: 2 additions & 0 deletions llm/src/questions_generation/mod.rs
@@ -0,0 +1,2 @@
pub mod generator;
mod prompts;
23 changes: 23 additions & 0 deletions llm/src/questions_generation/prompts.rs
@@ -0,0 +1,23 @@
use textwrap::dedent;

pub const QUESTIONS_GENERATION_SYSTEM_PROMPT: &str = r#"
Pretend you're a user searching on Google, a forum, or a blog. Your task is to generate a list of questions that relate to the context (### Context).
For example, if the context was the following:
```
At Orama, we specialize in edge-application development. This allows us to build high-performance, low-latency applications distributed via global CDNs. In other words, we prioritize performance and security when developing software.
```
Valid questions would look like the following:
```json
["What does Orama specialize in?", "Is Orama a low-latency edge application?", "Does Orama prioritize security when developing software?"]
```
Reply with a valid array of strings in JSON format and nothing more.
"#;

pub fn get_questions_generation_prompt(context: String) -> String {
    format!("### Context\n\n{}", context)
}
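For illustration (a hypothetical call, not part of this diff), the composed user message is just the context under a header:

let msg = get_questions_generation_prompt("Orama is a search engine.".to_string());
assert_eq!(msg, "### Context\n\nOrama is a search engine.");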
16 changes: 8 additions & 8 deletions utils/src/lib.rs
@@ -2,7 +2,7 @@ use anyhow::Result;
 use regex::Regex;
 use serde_json::Value;
 
-pub fn parse_json_safely(input_str: &str) -> Result<Value> {
+pub fn parse_json_safely(input_str: String) -> Result<Value> {
     if let Ok(parsed) = serde_json::from_str(input_str.trim()) {
         return Ok(parsed);
     }
@@ -30,7 +30,7 @@ mod tests {
 
     #[test]
     fn test_parse_valid_json() {
-        let input = r#"{"key": "value"}"#;
+        let input = r#"{"key": "value"}"#.to_string();
         let result = parse_json_safely(input);
         assert!(result.is_ok());
         assert_eq!(result.unwrap(), json!({"key": "value"}));
@@ -43,7 +43,7 @@
             "name": "John",
             "age": 30
         }
-        ```"#;
+        ```"#.to_string();
         let result = parse_json_safely(input);
         assert!(result.is_ok());
         assert_eq!(result.unwrap(), json!({"name": "John", "age": 30}));
@@ -55,7 +55,7 @@
         {
             "status": "ok"
         }
-        ```"#;
+        ```"#.to_string();
         let result = parse_json_safely(input);
         assert!(result.is_ok());
         assert_eq!(result.unwrap(), json!({"status": "ok"}));
@@ -69,15 +69,15 @@
             "foo": "bar"
         }
         Some text after
-        "#;
+        "#.to_string();
         let result = parse_json_safely(input);
         assert!(result.is_ok());
         assert_eq!(result.unwrap(), json!({"foo": "bar"}));
     }
 
     #[test]
     fn test_invalid_json_returns_error() {
-        let input = r#"Invalid JSON string"#;
+        let input = r#"Invalid JSON string"#.to_string();
         let result = parse_json_safely(input);
         assert!(result.is_err());
     }
@@ -87,7 +87,7 @@
         let input = r#"
         {
             "incomplete": true,
-        "#;
+        "#.to_string();
         let result = parse_json_safely(input);
         assert!(result.is_err());
     }
@@ -100,7 +100,7 @@
             "key": "value"
         }
         Followed by more text.
-        "#;
+        "#.to_string();
         let result = parse_json_safely(input);
         assert!(result.is_ok());
         assert_eq!(result.unwrap(), json!({"key": "value"}));
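The &str-to-String change ripples into callers such as generate_questions above, which now passes the model's response String straight through. A small illustrative sketch of the new call shape (the fenced input mirrors the markdown test in this file):

use utils::parse_json_safely;

fn main() {
    // parse_json_safely now takes ownership of the String; fenced model
    // output is unwrapped before parsing, as the tests above exercise.
    let raw = "```json\n[\"What is Orama?\"]\n```".to_string();
    let value = parse_json_safely(raw).expect("fenced JSON should parse");
    assert!(value.is_array());
}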
