From 84141b8846fbf7732f2486eaee4a7862b9c66621 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Thu, 26 Dec 2024 15:07:51 -0700 Subject: [PATCH] update sem_agg with new max_workers setting --- examples/op_examples/agg_with_grouping.py | 4 ---- lotus/sem_ops/sem_agg.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/examples/op_examples/agg_with_grouping.py b/examples/op_examples/agg_with_grouping.py index bbd34cf..05f925b 100644 --- a/examples/op_examples/agg_with_grouping.py +++ b/examples/op_examples/agg_with_grouping.py @@ -8,10 +8,6 @@ lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) -lotus.settings.configure(enable_multithreading=True) - -# turn on lotus debug logging -lotus.logger.setLevel("DEBUG") data = { "Course Name": [ diff --git a/lotus/sem_ops/sem_agg.py b/lotus/sem_ops/sem_agg.py index 3362000..7bb67ab 100644 --- a/lotus/sem_ops/sem_agg.py +++ b/lotus/sem_ops/sem_agg.py @@ -191,7 +191,7 @@ def __call__( group_args = [(group, user_instruction, all_cols, suffix, progress_bar_desc) for _, group in grouped] from concurrent.futures import ThreadPoolExecutor - with ThreadPoolExecutor() as executor: + with ThreadPoolExecutor(max_workers=lotus.settings.parallel_groupby_max_threads) as executor: return pd.concat(list(executor.map(SemAggDataframe.process_group, group_args))) # Sort df by partition_id if it exists