Skip to content

Commit

Permalink
some fixes + removed lazystatic with std lib
Browse files Browse the repository at this point in the history
  • Loading branch information
itsmeadarsh2008 committed Jul 9, 2024
1 parent 25a64d0 commit 4c49edc
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use pyo3::wrap_pyfunction;
use regex::{Captures, Regex, RegexBuilder};
use std::collections::HashMap;
use std::sync::Mutex;
use lazy_static::lazy_static;
use std::sync::OnceLock;

#[pyclass]
struct Pattern {
Expand All @@ -16,6 +16,7 @@ struct Match {
#[allow(dead_code)]
mat: regex::Match<'static>,
captures: Captures<'static>,
text: String,
}

#[pyclass]
Expand All @@ -35,9 +36,10 @@ struct Constants;
#[pyclass]
struct Sre;

// Global cache for compiled regex patterns
lazy_static! {
static ref REGEX_CACHE: Mutex<HashMap<(String, u32), Regex>> = Mutex::new(HashMap::new());
static REGEX_CACHE: OnceLock<Mutex<HashMap<(String, u32), Regex>>> = OnceLock::new();

fn get_regex_cache() -> &'static Mutex<HashMap<(String, u32), Regex>> {
REGEX_CACHE.get_or_init(|| Mutex::new(HashMap::new()))
}

#[pymethods]
Expand All @@ -56,31 +58,29 @@ impl Match {

fn start(&self, idx: usize) -> Option<usize> {
self.captures.get(idx).map(|m| {
self.captures.get(0).unwrap().as_str()[..m.start()].chars().count()
self.text[..m.start()].chars().count()
})
}

fn end(&self, idx: usize) -> Option<usize> {
self.captures.get(idx).map(|m| {
self.captures.get(0).unwrap().as_str()[..m.end()].chars().count()
self.text[..m.end()].chars().count()
})
}

fn span(&self, idx: usize) -> Option<(usize, usize)> {
self.captures.get(idx).map(|m| {
let full_match = self.captures.get(0).unwrap().as_str();
let start = full_match[..m.start()].chars().count();
let end = full_match[..m.end()].chars().count();
let start = self.text[..m.start()].chars().count();
let end = self.text[..m.end()].chars().count();
(start, end)
})
}
}

#[pyfunction]
#[pyo3(signature = (pattern, flags=None))]
fn compile(pattern: &str, flags: Option<u32>) -> PyResult<Pattern> {
let flags = flags.unwrap_or(0);
let mut cache = REGEX_CACHE.lock().unwrap();
let mut cache = get_regex_cache().lock().unwrap();

if let Some(regex) = cache.get(&(pattern.to_string(), flags)) {
return Ok(Pattern { regex: regex.clone() });
Expand Down Expand Up @@ -113,6 +113,7 @@ fn search(pattern: &Pattern, text: &str) -> PyResult<Option<Match>> {
Ok(Some(Match {
mat: unsafe { std::mem::transmute(mat) },
captures: unsafe { std::mem::transmute(captures) },
text: text.to_string(),
}))
}).unwrap_or(Ok(None))
}
Expand All @@ -125,6 +126,7 @@ fn fmatch(pattern: &Pattern, text: &str) -> PyResult<Option<Match>> {
Some(Ok(Some(Match {
mat: unsafe { std::mem::transmute(mat) },
captures: unsafe { std::mem::transmute(captures) },
text: text.to_string(),
})))
} else {
None
Expand All @@ -140,6 +142,7 @@ fn fullmatch(pattern: &Pattern, text: &str) -> PyResult<Option<Match>> {
Some(Ok(Some(Match {
mat: unsafe { std::mem::transmute(mat) },
captures: unsafe { std::mem::transmute(captures) },
text: text.to_string(),
})))
} else {
None
Expand Down Expand Up @@ -171,6 +174,7 @@ fn finditer(pattern: &Pattern, text: &str) -> PyResult<Vec<Match>> {
Match {
mat: unsafe { std::mem::transmute(mat) },
captures: unsafe { std::mem::transmute(captures) },
text: text.to_string(),
}
})
.collect())
Expand All @@ -195,7 +199,7 @@ fn escape(text: &str) -> PyResult<String> {

#[pyfunction]
fn purge() -> PyResult<()> {
REGEX_CACHE.lock().unwrap().clear();
get_regex_cache().lock().unwrap().clear();
Ok(())
}

Expand Down

0 comments on commit 4c49edc

Please sign in to comment.