Skip to content

Commit

Permalink
Merge pull request #12 from oramasearch/feat/improve-tanstack-result-…
Browse files Browse the repository at this point in the history
…code-quality

Use Counter as Scorer for Code
  • Loading branch information
allevo authored Nov 5, 2024
2 parents 97b2e12 + b883ed1 commit 9300580
Show file tree
Hide file tree
Showing 9 changed files with 390 additions and 60 deletions.
34 changes: 32 additions & 2 deletions code_parser/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ impl CodeParser {
let _module = parser.parse_module();
// We ignore errors, so we can parse as much as possible
// TODO: should we collect this information and return it to the user?
// .map_err(|e| e.into_diagnostic(&handler).emit())
// .expect("Failed to parse module.");
// .map_err(|e| e.into_diagnostic(&handler).emit())
// .expect("Failed to parse module.");

let tokens: Vec<_> = parser.input().take();

Expand Down Expand Up @@ -264,4 +264,34 @@ export default function RootLayout({ children }) {
vec![("th".to_string(), vec![]), ("key".to_string(), vec![])]
);
}

#[test]
fn test_1() {
    // swc cannot parse this snippet: the parser stops early and the only
    // token that comes back is "initialState".
    let snippet = r###"initialState?: Partial<
VisibilityTableState &
ColumnOrderTableState &
ColumnPinningTableState &
FiltersTableState &
SortingTableState &
ExpandedTableState &
GroupingTableState &
ColumnSizingTableState &
PaginationTableState &
RowSelectionTableState
>"###;

    // Expected output: the lowercased identifier paired with its sub-words.
    let expected = vec![(
        String::from("initialstate"),
        vec![String::from("initial"), String::from("state")],
    )];

    let tsx_parser = CodeParser::from_language(CodeLanguage::TSX);
    let tokens = tsx_parser.tokenize_and_stem(snippet).unwrap();

    assert_eq!(tokens, expected);
}
}
98 changes: 95 additions & 3 deletions collection_manager/src/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ use nlp::{locales::Locale, TextParser};
use ordered_float::NotNan;
use serde_json::Value;
use storage::Storage;
use string_index::StringIndex;
use string_index::{
scorer::{bm25::BM25Score, counter::CounterScore},
StringIndex,
};
use types::{
CollectionId, DocumentId, DocumentList, FieldId, ScalarType, SearchResult, SearchResultHit,
StringParser, TokenScore, ValueType,
Expand Down Expand Up @@ -97,7 +100,6 @@ impl Collection {

let parser = self.parsers.get(&field_id).unwrap_or(&self.default_parser);

println!("tokenizing doc {doc:?}");
let tokens = parser.tokenize_str_and_stem(&value)?;

strings
Expand All @@ -119,6 +121,24 @@ impl Collection {
}

pub fn search(&self, search_params: SearchParams) -> Result<SearchResult, anyhow::Error> {
// TODO: handle search_params.properties

let boost: HashMap<_, _> = search_params
.boost
.into_iter()
.map(|(field_name, boost)| {
let field_id = self.get_field_id(field_name);
(field_id, boost)
})
.collect();
let properties: Vec<_> = match search_params.properties {
Some(properties) => properties
.into_iter()
.map(|p| self.get_field_id(p))
.collect(),
None => self.string_fields.iter().map(|e| *e.value()).collect(),
};

let tokens: Vec<_> = self
.default_parser
.tokenize_str_and_stem(&search_params.term)?
Expand All @@ -129,7 +149,79 @@ impl Collection {
terms
})
.collect();
let token_scores = self.string_index.search(tokens, None, None)?;

let fields_on_search_with_default_parser: Vec<_> = self
.string_fields
.iter()
.filter(|field_id| !self.parsers.contains_key(field_id.value()))
.filter(|field_id| properties.contains(field_id.value()))
.map(|field_id| *field_id.value())
.collect();
println!(
"searching for tokens {:?}",
fields_on_search_with_default_parser,
);
let mut token_scores = self.string_index.search(
tokens,
Some(fields_on_search_with_default_parser),
boost.clone(),
BM25Score,
)?;
println!(
"Element found with default parser: {:?}",
token_scores.len()
);

// Depending on the size of self.parsers, this loop could be optimized by parallelizing the search.
// But for now, we will keep it simple.
// TODO: think about how to parallelize this search
for (field_id, parser) in &self.parsers {
if !properties.contains(field_id) {
continue;
}
let tokens: Vec<_> = parser
.tokenize_str_and_stem(&search_params.term)?
.into_iter()
.flat_map(|(token, stemmed)| {
let mut terms = vec![token];
terms.extend(stemmed);
terms
})
.collect();

let field_token_scores = self.string_index.search(
tokens,
Some(vec![*field_id]),
boost.clone(),
CounterScore,
)?;

let field_name = self
.string_fields
.iter()
.find(|v| v.value() == field_id)
.unwrap();
let field_name = field_name.key();
println!(
"Element found with parser: {field_id:?} ({:?}) {:?}",
field_name,
field_token_scores.len()
);

// Merging scores that come from different parsers is hard.
// Because we are focused on a PoC with tanstack, this case happens only with the "code" field.
// We just use a simple counter to merge the scores.
// Anyway, this is not a good solution for a real-world application.
// TODO: think about how to merge scores from different parsers
for (document_id, score) in field_token_scores {
if let Some(existing_score) = token_scores.get(&document_id) {
token_scores.insert(document_id, existing_score + score);
} else {
token_scores.insert(document_id, score);
}
}
}

let count = token_scores.len();

let token_scores = top_n(token_scores, search_params.limit.0);
Expand Down
4 changes: 4 additions & 0 deletions collection_manager/src/dto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,8 @@ pub struct SearchParams {
pub term: String,
#[serde(default)]
pub limit: Limit,
#[serde(default)]
pub boost: HashMap<String, f32>,
#[serde(default)]
pub properties: Option<Vec<String>>,
}
91 changes: 90 additions & 1 deletion collection_manager/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ mod tests {
use serde_json::json;
use storage::Storage;
use tempdir::TempDir;
use types::CodeLanguage;

use crate::dto::{CreateCollectionOptionDTO, Limit, SearchParams};
use crate::dto::{CreateCollectionOptionDTO, Limit, SearchParams, TypedField};

use super::CollectionManager;

Expand Down Expand Up @@ -260,6 +261,8 @@ mod tests {
let search_params = SearchParams {
term: "Tommaso".to_string(),
limit: Limit(10),
boost: Default::default(),
properties: Default::default(),
};
collection.search(search_params)
});
Expand Down Expand Up @@ -309,6 +312,8 @@ mod tests {
let search_params = SearchParams {
term: "text".to_string(),
limit: Limit(10),
boost: Default::default(),
properties: Default::default(),
};
collection.search(search_params)
});
Expand Down Expand Up @@ -357,6 +362,8 @@ mod tests {
let search_params = SearchParams {
term: "text".to_string(),
limit: Limit(10),
boost: Default::default(),
properties: Default::default(),
};
collection.search(search_params)
});
Expand All @@ -372,4 +379,86 @@ mod tests {
assert_eq!(output.hits[3].id, "96");
assert_eq!(output.hits[4].id, "95");
}

#[test]
fn test_foo() {
    // Smoke test for searching a collection whose "code" field is declared as
    // TSX code. It makes no assertions on the hits; it only requires that the
    // search succeeds, and prints the result for manual inspection.
    let manager = create_manager();

    let code_field = ("code".to_string(), TypedField::Code(CodeLanguage::TSX));
    let collection_id = manager
        .create_collection(CreateCollectionOptionDTO {
            id: "my-test-collection".to_string(),
            description: Some("Collection of songs".to_string()),
            language: None,
            typed_fields: std::iter::once(code_field).collect(),
        })
        .expect("insertion should be successful");

    // NOTE(review): the outcome of the insertion is discarded here, so a
    // failed insert would go unnoticed -- confirm whether it should be checked.
    let _ = manager.get(collection_id.clone(), |collection| {
        let lit_sorting_doc = json!({
            "id": "1",
            "code": r#"
import { TableController, type SortingState } from '@tanstack/lit-table'
//...
@state()
private _sorting: SortingState = [
{
id: 'age', //you should get autocomplete for the `id` and `desc` properties
desc: true,
}
]
"#,
        });
        let row_selection_doc = json!({
            "id": "2",
            "code": r#"export type RowSelectionState = Record<string, boolean>
export type RowSelectionTableState = {
rowSelection: RowSelectionState
}"#,
        });
        let initial_state_doc = json!({
            "id": "3",
            "code": r#"initialState?: Partial<
VisibilityTableState &
ColumnOrderTableState &
ColumnPinningTableState &
FiltersTableState &
SortingTableState &
ExpandedTableState &
GroupingTableState &
ColumnSizingTableState &
PaginationTableState &
RowSelectionTableState
>"#,
        });
        let column_visibility_doc = json!({
            "id": "4",
            "code": r#"setColumnVisibility: (updater: Updater<VisibilityState>) => void"#,
        });

        let documents = vec![
            lit_sorting_doc,
            row_selection_doc,
            initial_state_doc,
            column_visibility_doc,
        ];
        collection.insert_batch(documents.try_into().unwrap())
    });

    // Search for an identifier fragment that appears in documents "2" and "3".
    let output = manager
        .get(collection_id, |collection| {
            collection.search(SearchParams {
                term: "SelectionTableState".to_string(),
                limit: Limit(10),
                boost: Default::default(),
                properties: Default::default(),
            })
        })
        .unwrap()
        .unwrap();

    println!("{:#?}", output);
}
}
Loading

0 comments on commit 9300580

Please sign in to comment.