Skip to content

Commit

Permalink
feat: update TextSplitter implementation
Browse files Browse the repository at this point in the history
This commit updates the TextSplitter logic to use the `text-splitter` crate for splitting text. It also removes the usage of `RecursiveCharacterSplitter` and simplifies the `SplitterOptions` struct by removing unnecessary fields. Additionally, it adds the `trim_chunks` option to `SplitterOptions`. It also simplifies the implementation of `TokenSplitter` by using the newly adopted `text-splitter` crate.
  • Loading branch information
Abraxas-365 committed Mar 7, 2024
1 parent ee60f11 commit fa564bf
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 299 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ tiktoken-rs = "0.5.8"
sqlx = { version = "0.7.3", features = ["postgres", "runtime-tokio-native-tls", "json", "uuid" ], optional = true }
uuid = {version = "1.7.0", features = ["v4"], optional = true }
pgvector = {version = "0.3.2", features = ["postgres", "sqlx"], optional = true }
text-splitter = { version = "0.6", features = ["tiktoken-rs"] }


[features]
default = []
Expand Down
2 changes: 0 additions & 2 deletions src/text_splitter/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
mod options;
mod recursive_character_splitter;
mod text_splitter;
mod token_splitter;

pub use options::*;
pub use recursive_character_splitter::*;
pub use text_splitter::*;
pub use token_splitter::*;
59 changes: 4 additions & 55 deletions src/text_splitter/options.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
use super::TextSplitter;

/// Configuration options for a text splitter.
///
/// Default values for every field are assigned in `SplitterOptions::new`.
pub struct SplitterOptions {
/// Chunk size setting; `new` sets it to 512. The unit is not fixed by this
/// struct (presumably whatever the length measure in use counts — confirm).
pub chunk_size: usize,
/// Chunk overlap setting; `new` sets it to 100.
pub chunk_overlap: usize,
/// Separator strings; `new` sets `"\n\n"`, `"\n"`, `" "`, `""` in that order.
pub separators: Vec<String>,
/// Function measuring the length of a string; the default counts `char`s.
pub len_func: fn(&str) -> usize,
/// Model name; `new` sets `"gpt-3.5-turbo"`.
/// NOTE(review): presumably used to select a tokenizer — confirm against the reader.
pub model_name: String,
/// Encoding name; `new` sets `"cl100k_base"`.
pub encoding_name: String,
/// Allowed special entries; empty by default.
pub allowed_special: Vec<String>,
/// Disallowed special entries; `new` sets the single entry `"all"`.
pub disallowed_special: Vec<String>,
/// Code-blocks flag; `false` by default.
pub code_blocks: bool,
/// Reference-links flag; `false` by default.
pub reference_links: bool,
/// Optional boxed secondary splitter; `None` by default.
pub second_splitter: Option<Box<dyn TextSplitter>>,
/// Trim-chunks flag; `false` by default.
pub trim_chunks: bool,
}

impl Default for SplitterOptions {
Expand All @@ -25,16 +16,9 @@ impl SplitterOptions {
pub fn new() -> Self {
SplitterOptions {
chunk_size: 512,
chunk_overlap: 100,
separators: vec!["\n\n".into(), "\n".into(), " ".into(), "".into()],
len_func: |s| s.chars().count(),
model_name: String::from("gpt-3.5-turbo"),
encoding_name: String::from("cl100k_base"),
allowed_special: Vec::new(),
disallowed_special: Vec::from(["all".into()]),
code_blocks: false,
second_splitter: None,
reference_links: false,
trim_chunks: false,
}
}
}
Expand All @@ -46,21 +30,6 @@ impl SplitterOptions {
self
}

pub fn with_chunk_overlap(mut self, chunk_overlap: usize) -> Self {
self.chunk_overlap = chunk_overlap;
self
}

pub fn with_separators(mut self, separators: Vec<&str>) -> Self {
self.separators = separators.into_iter().map(String::from).collect();
self
}

pub fn with_len_func(mut self, len_func: fn(&str) -> usize) -> Self {
self.len_func = len_func;
self
}

pub fn with_model_name(mut self, model_name: &str) -> Self {
self.model_name = String::from(model_name);
self
Expand All @@ -71,28 +40,8 @@ impl SplitterOptions {
self
}

pub fn with_allowed_special(mut self, allowed_special: Vec<&str>) -> Self {
self.allowed_special = allowed_special.into_iter().map(String::from).collect();
self
}

pub fn with_disallowed_special(mut self, disallowed_special: Vec<&str>) -> Self {
self.disallowed_special = disallowed_special.into_iter().map(String::from).collect();
self
}

pub fn with_code_blocks(mut self, code_blocks: bool) -> Self {
self.code_blocks = code_blocks;
self
}

pub fn with_reference_links(mut self, reference_links: bool) -> Self {
self.reference_links = reference_links;
self
}

pub fn with_second_splitter<TS: TextSplitter + 'static>(mut self, second_splitter: TS) -> Self {
self.second_splitter = Some(Box::new(second_splitter));
pub fn with_trim_chunks(mut self, trim_chunks: bool) -> Self {
self.trim_chunks = trim_chunks;
self
}
}
143 changes: 0 additions & 143 deletions src/text_splitter/recursive_character_splitter.rs

This file was deleted.

71 changes: 0 additions & 71 deletions src/text_splitter/text_splitter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,74 +44,3 @@ pub trait TextSplitter {
Ok(documents)
}
}

/// Joins `docs` with `separator` and trims surrounding whitespace from the result.
///
/// Returns `None` when the trimmed text is empty, otherwise `Some` with the
/// trimmed text as an owned `String`.
pub(crate) fn join_documents(docs: &[String], separator: &str) -> Option<String> {
    let combined = docs.join(separator);
    let trimmed = combined.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_string())
    }
}

/// Merges consecutive `splits` into chunks joined by `separator`.
///
/// Pieces are accumulated in a window. When adding the next piece (plus a
/// separator, if the window is non-empty) would push the measured length past
/// `chunk_size`, the window is emitted as a chunk via `join_documents` and
/// pieces are dropped from its front for as long as `should_pop` says so,
/// which leaves at most `chunk_overlap` of carried-over length (and enough room
/// for the incoming piece). All lengths are measured with `len_func`.
///
/// Returns the emitted chunks in order; chunks that are empty after trimming
/// are omitted.
pub(crate) fn merge_splits(
    splits: &[String],
    separator: &str,
    chunk_size: usize,
    chunk_overlap: usize,
    len_func: fn(&str) -> usize,
) -> Vec<String> {
    let separator_len = len_func(separator);
    let mut merged: Vec<String> = Vec::new();
    let mut window: Vec<String> = Vec::new();
    // Measured length of `window` as it would be joined: pieces plus separators.
    let mut window_len = 0;

    for piece in splits {
        let piece_len = len_func(piece);
        let joiner_len = if window.is_empty() { 0 } else { separator_len };
        let projected_len = window_len + piece_len + joiner_len;

        if projected_len > chunk_size && !window.is_empty() {
            // Emit the current window, then shrink it from the front.
            merged.extend(join_documents(&window, separator));
            while should_pop(
                chunk_overlap,
                chunk_size,
                window_len,
                piece_len,
                separator_len,
                window.len(),
            ) {
                let dropped = window.remove(0);
                window_len -= len_func(&dropped);
                // A separator only existed if something followed the dropped piece.
                if !window.is_empty() {
                    window_len -= separator_len;
                }
            }
        }

        window.push(piece.to_string());
        window_len += piece_len;
        if window.len() > 1 {
            window_len += separator_len;
        }
    }

    // Flush whatever is left in the window.
    merged.extend(join_documents(&window, separator));
    merged
}

/// Decides whether the front piece of the current window should be dropped.
///
/// Returns `false` for an empty window (`current_doc_len == 0`). Otherwise
/// returns `true` when `total` exceeds `chunk_overlap`, or when `total` is
/// non-zero and `total + split_len` (plus `separator_len`, counted only when
/// the window holds at least two pieces) exceeds `chunk_size`.
pub(crate) fn should_pop(
    chunk_overlap: usize,
    chunk_size: usize,
    total: usize,
    split_len: usize,
    separator_len: usize,
    current_doc_len: usize,
) -> bool {
    if current_doc_len == 0 {
        return false;
    }

    // A separator is only counted once the window holds at least this many pieces.
    const PIECES_NEEDED_FOR_SEPARATOR: usize = 2;
    let effective_separator_len = if current_doc_len >= PIECES_NEEDED_FOR_SEPARATOR {
        separator_len
    } else {
        0
    };

    total > chunk_overlap
        || (total + split_len + effective_separator_len > chunk_size && total > 0)
}
Loading

0 comments on commit fa564bf

Please sign in to comment.