Skip to content

Commit

Permalink
replaced multiprocessing with multithreading
Browse files (browse the repository at this point in the history)
  • Loading branch information
dhruviyer committed Dec 26, 2024
1 parent 6c8a647 commit 69955bf
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 48 deletions.
18 changes: 3 additions & 15 deletions lotus/sem_ops/sem_agg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from typing import Any

import pandas as pd
@@ -8,11 +7,6 @@
from lotus.types import LMOutput, SemanticAggOutput


def initializer(settings, log_level):
lotus.logger.setLevel(log_level)
lotus.settings.clone(settings)


def sem_agg(
docs: list[str],
model: lotus.models.LM,
@@ -151,7 +145,6 @@ def _validate(obj: Any) -> None:

@staticmethod
def process_group(args):
lotus.logger.debug(f"Processing in PID: {os.getpid()}")
group, user_instruction, all_cols, suffix, progress_bar_desc = args
return group.sem_agg(user_instruction, all_cols, suffix, None, progress_bar_desc=progress_bar_desc)

@@ -196,15 +189,10 @@ def __call__(
if group_by:
grouped = self._obj.groupby(group_by)
group_args = [(group, user_instruction, all_cols, suffix, progress_bar_desc) for _, group in grouped]
if lotus.settings.enable_multithreading:
lotus.logger.debug("Using multithreading")
from multiprocessing import Pool
from concurrent.futures import ThreadPoolExecutor

with Pool(initializer=initializer, initargs=(lotus.settings, lotus.logger.getEffectiveLevel())) as pool:
return pd.concat(pool.map(SemAggDataframe.process_group, group_args))
else:
lotus.logger.debug("Not using multithreading")
return pd.concat([SemAggDataframe.process_group(group_arg) for group_arg in group_args])
with ThreadPoolExecutor() as executor:
return pd.concat(list(executor.map(SemAggDataframe.process_group, group_args)))

# Sort df by partition_id if it exists
if "_lotus_partition_id" in self._obj.columns:
Expand Down
14 changes: 3 additions & 11 deletions lotus/sem_ops/sem_topk.py
Original file line number Diff line number Diff line change
@@ -12,11 +12,6 @@
from lotus.utils import show_safe_mode


def initializer(settings, log_level):
lotus.logger.setLevel(log_level)
lotus.settings.clone(settings)


def get_match_prompt_binary(
doc1: dict[str, Any], doc2: dict[str, Any], user_instruction: str, strategy: str | None = None
) -> list[dict[str, Any]]:
@@ -438,13 +433,10 @@ def __call__(
(group, user_instruction, K, method, strategy, None, cascade_threshold, return_stats)
for _, group in grouped
]
if lotus.settings.enable_multithreading:
from multiprocessing import Pool
from concurrent.futures import ThreadPoolExecutor

with Pool(initializer=initializer, initargs=(lotus.settings, lotus.logger.getEffectiveLevel())) as pool:
results = pool.map(SemTopKDataframe.process_group, group_args)
else:
results = [SemTopKDataframe.process_group(group_arg) for group_arg in group_args]
with ThreadPoolExecutor() as executor:
results = list(executor.map(SemTopKDataframe.process_group, group_args))

new_df = pd.concat([res[0] for res in results])
stats = {name: res[1] for name, res in zip(grouped.groups.keys(), results)}
Expand Down
7 changes: 0 additions & 7 deletions lotus/settings.py
Original file line number Diff line number Diff line change
@@ -17,19 +17,12 @@ class Settings:
# Serialization setting
serialization_format: SerializationFormat = SerializationFormat.DEFAULT

# Multithreading settings
enable_multithreading: bool = False

def configure(self, **kwargs):
for key, value in kwargs.items():
if not hasattr(self, key):
raise ValueError(f"Invalid setting: {key}")
setattr(self, key, value)

def clone(self, other_settings):
for key in vars(other_settings):
setattr(self, key, getattr(other_settings, key))

def __str__(self):
return str(vars(self))

Expand Down
17 changes: 2 additions & 15 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pytest

from lotus.models import LM
from lotus.settings import SerializationFormat, Settings


@@ -16,23 +15,11 @@ def test_initial_values(self, settings):
assert settings.reranker is None
assert settings.enable_cache is False
assert settings.serialization_format == SerializationFormat.DEFAULT
assert settings.enable_multithreading is False

def test_configure_method(self, settings):
settings.configure(enable_multithreading=True)
assert settings.enable_multithreading is True
settings.configure(enable_cache=True)
assert settings.enable_cache is True

def test_invalid_setting(self, settings):
with pytest.raises(ValueError, match="Invalid setting: invalid_setting"):
settings.configure(invalid_setting=True)

def test_clone_method(self, settings):
other_settings = Settings()
lm = LM(model="test-model")
other_settings.lm = lm
other_settings.enable_cache = True

settings.clone(other_settings)

assert settings.lm == lm
assert settings.enable_cache is True

0 comments on commit 69955bf

Please sign in to comment.