diff --git a/lotus/sem_ops/sem_agg.py b/lotus/sem_ops/sem_agg.py index d5003d0..3362000 100644 --- a/lotus/sem_ops/sem_agg.py +++ b/lotus/sem_ops/sem_agg.py @@ -1,4 +1,3 @@ -import os from typing import Any import pandas as pd @@ -8,11 +7,6 @@ from lotus.types import LMOutput, SemanticAggOutput -def initializer(settings, log_level): - lotus.logger.setLevel(log_level) - lotus.settings.clone(settings) - - def sem_agg( docs: list[str], model: lotus.models.LM, @@ -151,7 +145,6 @@ def _validate(obj: Any) -> None: @staticmethod def process_group(args): - lotus.logger.debug(f"Processing in PID: {os.getpid()}") group, user_instruction, all_cols, suffix, progress_bar_desc = args return group.sem_agg(user_instruction, all_cols, suffix, None, progress_bar_desc=progress_bar_desc) @@ -196,15 +189,10 @@ def __call__( if group_by: grouped = self._obj.groupby(group_by) group_args = [(group, user_instruction, all_cols, suffix, progress_bar_desc) for _, group in grouped] - if lotus.settings.enable_multithreading: - lotus.logger.debug("Using multithreading") - from multiprocessing import Pool + from concurrent.futures import ThreadPoolExecutor - with Pool(initializer=initializer, initargs=(lotus.settings, lotus.logger.getEffectiveLevel())) as pool: - return pd.concat(pool.map(SemAggDataframe.process_group, group_args)) - else: - lotus.logger.debug("Not using multithreading") - return pd.concat([SemAggDataframe.process_group(group_arg) for group_arg in group_args]) + with ThreadPoolExecutor() as executor: + return pd.concat(list(executor.map(SemAggDataframe.process_group, group_args))) # Sort df by partition_id if it exists if "_lotus_partition_id" in self._obj.columns: diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py index 937d3da..2e2e918 100644 --- a/lotus/sem_ops/sem_topk.py +++ b/lotus/sem_ops/sem_topk.py @@ -12,11 +12,6 @@ from lotus.utils import show_safe_mode -def initializer(settings, log_level): - lotus.logger.setLevel(log_level) - lotus.settings.clone(settings) - - def get_match_prompt_binary( doc1: dict[str, Any], doc2: dict[str, Any], user_instruction: str, strategy: str | None = None ) -> list[dict[str, Any]]: @@ -438,13 +433,10 @@ def __call__( (group, user_instruction, K, method, strategy, None, cascade_threshold, return_stats) for _, group in grouped ] - if lotus.settings.enable_multithreading: - from multiprocessing import Pool + from concurrent.futures import ThreadPoolExecutor - with Pool(initializer=initializer, initargs=(lotus.settings, lotus.logger.getEffectiveLevel())) as pool: - results = pool.map(SemTopKDataframe.process_group, group_args) - else: - results = [SemTopKDataframe.process_group(group_arg) for group_arg in group_args] + with ThreadPoolExecutor() as executor: + results = list(executor.map(SemTopKDataframe.process_group, group_args)) new_df = pd.concat([res[0] for res in results]) stats = {name: res[1] for name, res in zip(grouped.groups.keys(), results)} diff --git a/lotus/settings.py b/lotus/settings.py index 3fe3250..1841df6 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -17,19 +17,12 @@ class Settings: # Serialization setting serialization_format: SerializationFormat = SerializationFormat.DEFAULT - # Multithreading settings - enable_multithreading: bool = False - def configure(self, **kwargs): for key, value in kwargs.items(): if not hasattr(self, key): raise ValueError(f"Invalid setting: {key}") setattr(self, key, value) - def clone(self, other_settings): - for key in vars(other_settings): - setattr(self, key, getattr(other_settings, key)) - def __str__(self): return str(vars(self)) diff --git a/tests/test_settings.py b/tests/test_settings.py index a578b3c..dc6f871 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,6 +1,5 @@ import pytest -from lotus.models import LM from lotus.settings import SerializationFormat, Settings @@ -16,23 +15,11 @@ def test_initial_values(self, settings): assert settings.reranker is None assert settings.enable_cache is False assert settings.serialization_format == SerializationFormat.DEFAULT - assert settings.enable_multithreading is False def test_configure_method(self, settings): - settings.configure(enable_multithreading=True) - assert settings.enable_multithreading is True + settings.configure(enable_cache=True) + assert settings.enable_cache is True def test_invalid_setting(self, settings): with pytest.raises(ValueError, match="Invalid setting: invalid_setting"): settings.configure(invalid_setting=True) - - def test_clone_method(self, settings): - other_settings = Settings() - lm = LM(model="test-model") - other_settings.lm = lm - other_settings.enable_cache = True - - settings.clone(other_settings) - - assert settings.lm == lm - assert settings.enable_cache is True