From 09b625f102e12049d8c7c86b6df6b3741acdff3e Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Thu, 26 Dec 2024 09:37:08 -0700 Subject: [PATCH 1/9] parallelized sem_agg --- examples/op_examples/agg.py | 92 ++++++++++++++++++++++++++++++------- lotus/sem_ops/sem_agg.py | 32 ++++++++----- lotus/settings.py | 12 ++++- tests/test_settings.py | 38 +++++++++++++++ 4 files changed, 146 insertions(+), 28 deletions(-) create mode 100644 tests/test_settings.py diff --git a/examples/op_examples/agg.py b/examples/op_examples/agg.py index f04f84a..164fdcd 100644 --- a/examples/op_examples/agg.py +++ b/examples/op_examples/agg.py @@ -1,21 +1,81 @@ import pandas as pd +import time import lotus from lotus.models import LM -lm = LM(model="gpt-4o-mini") - -lotus.settings.configure(lm=lm) -data = { - "Course Name": [ - "Probability and Random Processes", - "Optimization Methods in Engineering", - "Digital Design and Integrated Circuits", - "Computer Security", - "Cooking", - "Food Sciences", - ] -} -df = pd.DataFrame(data) -df = df.sem_agg("Summarize all {Course Name}") -print(df._output[0]) +def main(): + lm = LM(model="gpt-4o-mini") + + lotus.settings.configure(lm=lm) + lotus.settings.configure(enable_multithreading=True) + + start_time = time.time() + + # turn on lotus debug logging + lotus.logger.setLevel("DEBUG") + + data = { + "Course Name": [ + "Probability and Random Processes", "Optimization Methods in Engineering", + "Digital Design and Integrated Circuits", "Computer Security", "Cooking", + "Food Sciences", "Machine Learning", "Data Structures and Algorithms", + "Quantum Mechanics", "Organic Chemistry", "Artificial Intelligence", "Robotics", + "Thermodynamics", "Fluid Mechanics", "Molecular Biology", "Genetics", + "Astrophysics", "Neuroscience", "Microeconomics", "Macroeconomics", + "Linear Algebra", "Calculus", "Statistics", "Differential Equations", + "Discrete Mathematics", "Number Theory", "Graph Theory", "Topology", + "Complex Analysis", "Real Analysis", "Abstract Algebra", "Numerical Methods", + "Cryptography", "Network Security", "Operating Systems", "Databases", + "Computer Networks", "Software Engineering", "Compilers", "Computer Architecture", + "Parallel Computing", "Distributed Systems", "Cloud Computing", "Big Data Analytics", + "Natural Language Processing", "Computer Vision", "Reinforcement Learning", + "Deep Learning", "Bioinformatics", "Computational Biology", "Systems Biology", + "Biochemistry", "Physical Chemistry", "Inorganic Chemistry", "Analytical Chemistry", + "Environmental Chemistry", "Materials Science", "Nanotechnology", "Optics", + "Electromagnetism", "Nuclear Physics", "Particle Physics", "Cosmology", + "Planetary Science", "Geophysics", "Atmospheric Science", "Oceanography", + "Ecology", "Evolutionary Biology", "Botany", "Zoology", "Microbiology", + "Immunology", "Virology", "Pharmacology", "Physiology", "Anatomy", + "Neurobiology", "Cognitive Science", "Psychology", "Sociology", "Anthropology", + "Archaeology", "Linguistics", "Philosophy", "Ethics", "Logic", + "Political Science", "International Relations", "Public Policy", "Economics", + "Finance", "Accounting", "Marketing", "Management", "Entrepreneurship", + "Law", "Criminal Justice", "Human Rights", "Environmental Studies", + "Sustainability", "Urban Planning", "Architecture", "Civil Engineering", + "Mechanical Engineering", "Electrical Engineering", "Chemical Engineering", + "Aerospace Engineering", "Biomedical Engineering", "Environmental Engineering" + ], + "Grade Level": [ + "High School", "Graduate", "Graduate", "High School", "Undergraduate", + "Undergraduate", "High School", "Undergraduate", "High School", "Undergraduate", + "High School", "Graduate", "Undergraduate", "Undergraduate", "Graduate", + "Undergraduate", "Graduate", "Graduate", "Undergraduate", "Undergraduate", + "Undergraduate", "Undergraduate", "High School", "High School", "Undergraduate", + "Graduate", "Graduate", "Graduate", "High School", "Graduate", "Graduate", "Graduate", + "Graduate", "High School", "Undergraduate", "High School", "Undergraduate", + "Undergraduate", "Graduate", "Undergraduate", "Undergraduate", "Graduate", "Graduate", + "Graduate", "Graduate", "Graduate", "Graduate", "Graduate", "Graduate", "Graduate", + "Undergraduate", "Graduate", "Undergraduate", "High School", "Graduate", "Graduate", + "Graduate", "High School", "Graduate", "High School", "Graduate", "Graduate", + "Graduate", "Graduate", "Graduate", "Graduate", "Graduate", "Graduate", + "High School", "High School", "High School", "Undergraduate", "Graduate", + "Graduate", "Graduate", "High School", "Undergraduate", "Undergraduate", + "Graduate", "Graduate", "Undergraduate", "Undergraduate", "Undergraduate", + "High School", "High School", "Graduate", "Graduate", "High School", "Graduate", + "Graduate", "Graduate", "Undergraduate", "Undergraduate", "Undergraduate", "Undergraduate", + "High School", "High School", "Graduate", "Undergraduate", "Undergraduate", "Undergraduate", + "Undergraduate", "Undergraduate", "Undergraduate", "Graduate", "Graduate", + "Graduate", "Graduate", "Graduate", "Graduate" + ], + } + + df = pd.DataFrame(data) + df = df.sem_agg("Summarize all {Course Name}", group_by=["Grade Level"]) + print(df._output[0]) + + end_time = time.time() + print(f"Total execution time: {end_time - start_time:.2f} seconds") + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/lotus/sem_ops/sem_agg.py b/lotus/sem_ops/sem_agg.py index dfb934b..da507ae 100644 --- a/lotus/sem_ops/sem_agg.py +++ b/lotus/sem_ops/sem_agg.py @@ -1,4 +1,5 @@ from typing import Any +import os import pandas as pd @@ -6,6 +7,9 @@ from lotus.templates import task_instructions from lotus.types import LMOutput, SemanticAggOutput +def initializer(settings, log_level): + lotus.logger.setLevel(log_level) + lotus.settings.clone(settings) def sem_agg( docs: list[str], @@ -142,6 +146,12 @@ def __init__(self, pandas_obj: Any): @staticmethod def _validate(obj: Any) -> None: pass + + @staticmethod + def process_group(args): + lotus.logger.debug(f"Processing in PID: {os.getpid()}") + group, user_instruction, all_cols, suffix, progress_bar_desc = args + return group.sem_agg(user_instruction, all_cols, suffix, None, progress_bar_desc=progress_bar_desc) def __call__( self, @@ -163,7 +173,7 @@ def __call__( Returns: pd.DataFrame: The dataframe with the aggregated answer. """ - + # print all the settings values if lotus.settings.lm is None: raise ValueError( "The language model must be an instance of LM. Please configure a valid language model using lotus.settings.configure()" @@ -181,18 +191,18 @@ def __call__( if column not in self._obj.columns: raise ValueError(f"column {column} not found in DataFrame. Given usr instruction: {user_instruction}") - - - if group_by: grouped = self._obj.groupby(group_by) - new_df = pd.DataFrame() - for name, group in grouped: - res = group.sem_agg(user_instruction, all_cols, suffix, None, progress_bar_desc=progress_bar_desc) - new_df = pd.concat([new_df, res]) - return new_df - - + group_args = [(group, user_instruction, all_cols, suffix, progress_bar_desc) for _, group in grouped] + if lotus.settings.enable_multithreading: + lotus.logger.debug("Using multithreading") + from multiprocessing import Pool + with Pool(initializer=initializer, initargs=(lotus.settings, lotus.logger.getEffectiveLevel())) as pool: + return pd.concat(pool.map(SemAggDataframe.process_group, group_args)) + else: + lotus.logger.debug("Not using multithreading") + return pd.concat([SemAggDataframe.process_group(group_arg) for group_arg in group_args]) + # Sort df by partition_id if it exists if "_lotus_partition_id" in self._obj.columns: diff --git a/lotus/settings.py b/lotus/settings.py index a39be43..04d6c66 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -1,6 +1,7 @@ import lotus.models from lotus.types import SerializationFormat +# NOTE: Settings class is not thread-safe class Settings: # Models @@ -15,11 +16,20 @@ class Settings: # Serialization setting serialization_format: SerializationFormat = SerializationFormat.DEFAULT + # Multithreading settings + enable_multithreading: bool = False + def configure(self, **kwargs): for key, value in kwargs.items(): if not hasattr(self, key): raise ValueError(f"Invalid setting: {key}") setattr(self, key, value) + def clone(self, other_settings): + for key in vars(other_settings): + setattr(self, key, getattr(other_settings, key)) + + def __str__(self): + return str(vars(self)) -settings = Settings() +settings = Settings() \ No newline at end of file diff --git a/tests/test_settings.py b/tests/test_settings.py new file mode 100644 index 0000000..f028e04 --- /dev/null +++ b/tests/test_settings.py @@ -0,0 +1,38 @@ +import pytest +from lotus.settings import Settings, SerializationFormat +from lotus.models import LM, RM, Reranker, SentenceTransformersRM + +class TestSettings: + @pytest.fixture + def settings(self): + return Settings() + + def test_initial_values(self, settings): + assert settings.lm is None + assert settings.rm is None + assert settings.helper_lm is None + assert settings.reranker is None + assert settings.enable_cache is False + assert settings.serialization_format == SerializationFormat.DEFAULT + assert settings.enable_multithreading is False + + def test_configure_method(self, settings): + settings.configure( + enable_multithreading=True + ) + assert settings.enable_multithreading is True + + def test_invalid_setting(self, settings): + with pytest.raises(ValueError, match="Invalid setting: invalid_setting"): + settings.configure(invalid_setting=True) + + def test_clone_method(self, settings): + other_settings = Settings() + lm = LM(model="test-model") + other_settings.lm = lm + other_settings.enable_cache = True + + settings.clone(other_settings) + + assert settings.lm == lm + assert settings.enable_cache is True \ No newline at end of file From 3f6c8c37724611fe42a70925fd9b9bc09a756feb Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Thu, 26 Dec 2024 12:22:57 -0700 Subject: [PATCH 2/9] parallelized sem_topk --- examples/op_examples/agg.py | 6 ++-- examples/op_examples/top_k.py | 57 +++++++++++++++++++++++------------ lotus/sem_ops/sem_agg.py | 2 +- lotus/sem_ops/sem_topk.py | 52 +++++++++++++++++++------------- 4 files changed, 71 insertions(+), 46 deletions(-) diff --git a/examples/op_examples/agg.py b/examples/op_examples/agg.py index 164fdcd..7d805de 100644 --- a/examples/op_examples/agg.py +++ b/examples/op_examples/agg.py @@ -10,8 +10,6 @@ def main(): lotus.settings.configure(lm=lm) lotus.settings.configure(enable_multithreading=True) - start_time = time.time() - # turn on lotus debug logging lotus.logger.setLevel("DEBUG") @@ -71,10 +69,10 @@ def main(): } df = pd.DataFrame(data) + start_time = time.time() df = df.sem_agg("Summarize all {Course Name}", group_by=["Grade Level"]) - print(df._output[0]) - end_time = time.time() + print(df._output[0]) print(f"Total execution time: {end_time - start_time:.2f} seconds") if __name__ == '__main__': diff --git a/examples/op_examples/top_k.py b/examples/op_examples/top_k.py index 8654ea1..3a8fca4 100644 --- a/examples/op_examples/top_k.py +++ b/examples/op_examples/top_k.py @@ -1,27 +1,44 @@ import pandas as pd +import time import lotus from lotus.models import LM -lm = LM(model="gpt-4o-mini") +def main(): + lm = LM(model="gpt-4o-mini") -lotus.settings.configure(lm=lm) -data = { - "Course Name": [ - "Probability and Random Processes", - "Optimization Methods in Engineering", - "Digital Design and Integrated Circuits", - "Computer Security", - ] -} -df = pd.DataFrame(data) + lotus.settings.configure(lm=lm) + lotus.settings.configure(enable_multithreading=True) + + data = { + "Department": ["Math", "Physics", "Computer Science", "Biology"] * 7, + "Course Name": [ + "Calculus", "Quantum Mechanics", "Data Structures", "Genetics", + "Linear Algebra", "Thermodynamics", "Algorithms", "Ecology", + "Statistics", "Optics", "Machine Learning", "Molecular Biology", + "Number Theory", "Relativity", "Computer Networks", "Evolutionary Biology", + "Differential Equations", "Particle Physics", "Operating Systems", "Biochemistry", + "Complex Analysis", "Fluid Dynamics", "Artificial Intelligence", "Microbiology", + "Topology", "Astrophysics", "Cybersecurity", "Immunology" + ] + } -for method in ["quick", "heap", "naive"]: - sorted_df, stats = df.sem_topk( - "Which {Course Name} requires the least math?", - K=2, - method=method, - return_stats=True, - ) - print(sorted_df) - print(stats) + df = pd.DataFrame(data) + + for method in ["quick", "heap", "naive"]: + + start_time = time.time() + sorted_df, stats = df.sem_topk( + "Which {Course Name} is the most challenging?", + K=2, + method=method, + return_stats=True, + group_by=["Department"] + ) + end_time = time.time() + print(sorted_df) + print(stats) + print(f"Total execution time: {end_time - start_time:.2f} seconds") + +if __name__ == '__main__': + main() diff --git a/lotus/sem_ops/sem_agg.py b/lotus/sem_ops/sem_agg.py index da507ae..a012d0a 100644 --- a/lotus/sem_ops/sem_agg.py +++ b/lotus/sem_ops/sem_agg.py @@ -173,7 +173,7 @@ def __call__( Returns: pd.DataFrame: The dataframe with the aggregated answer. """ - # print all the settings values + if lotus.settings.lm is None: raise ValueError( "The language model must be an instance of LM. Please configure a valid language model using lotus.settings.configure()" diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py index 507d844..92f5181 100644 --- a/lotus/sem_ops/sem_topk.py +++ b/lotus/sem_ops/sem_topk.py @@ -11,6 +11,9 @@ from lotus.types import LMOutput, SemanticTopKOutput from lotus.utils import show_safe_mode +def initializer(settings, log_level): + lotus.logger.setLevel(log_level) + lotus.settings.clone(settings) def get_match_prompt_binary( doc1: dict[str, Any], doc2: dict[str, Any], user_instruction: str, strategy: str | None = None @@ -373,6 +376,19 @@ def __init__(self, pandas_obj: Any) -> None: def _validate(obj: Any) -> None: pass + @staticmethod + def process_group(args): + group, user_instruction, K, method, strategy, group_by, cascade_threshold, return_stats = args + return group.sem_topk( + user_instruction, + K, + method=method, + strategy=strategy, + group_by=None, + cascade_threshold=cascade_threshold, + return_stats=return_stats, + ) + def __call__( self, user_instruction: str, @@ -416,30 +432,24 @@ def __call__( # Separate code path for grouping if group_by: grouped = self._obj.groupby(group_by) - new_df = pd.DataFrame() - stats = {} - for name, group in grouped: - res = group.sem_topk( - user_instruction, - K, - method=method, - strategy=strategy, - group_by=None, - cascade_threshold=cascade_threshold, - return_stats=return_stats, - ) - - if return_stats: - sorted_group, group_stats = res - stats[name] = group_stats - else: - sorted_group = res - - new_df = pd.concat([new_df, sorted_group]) - + group_args = [ + (group, user_instruction, K, method, strategy, None, cascade_threshold, return_stats) + for _, group in grouped + ] + if lotus.settings.enable_multithreading: + from multiprocessing import Pool + + with Pool(initializer=initializer, initargs=(lotus.settings, lotus.logger.getEffectiveLevel())) as pool: + results = pool.map(SemTopKDataframe.process_group, group_args) + else: + results = [SemTopKDataframe.process_group(group_arg) for group_arg in group_args] + + new_df = pd.concat([res[0] for res in results]) + stats = {name: res[1] for name, res in zip(grouped.groups.keys(), results)} if return_stats: return new_df, stats return new_df + if method == "quick-sem": assert len(col_li) == 1, "Only one column can be used for embedding optimization" From 7c085b092aa254c94f30add8dda44ee1d5b03be6 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Thu, 26 Dec 2024 12:26:41 -0700 Subject: [PATCH 3/9] ran ruff check and ruff format --- examples/op_examples/agg.py | 277 +++++++++++++++++++++++++++------- examples/op_examples/top_k.py | 49 ++++-- lotus/sem_ops/sem_agg.py | 12 +- lotus/sem_ops/sem_topk.py | 29 ++-- lotus/settings.py | 6 +- tests/test_settings.py | 12 +- 6 files changed, 294 insertions(+), 91 deletions(-) diff --git a/examples/op_examples/agg.py b/examples/op_examples/agg.py index 7d805de..dd47f19 100644 --- a/examples/op_examples/agg.py +++ b/examples/op_examples/agg.py @@ -1,9 +1,11 @@ -import pandas as pd import time +import pandas as pd + import lotus from lotus.models import LM + def main(): lm = LM(model="gpt-4o-mini") @@ -15,56 +17,228 @@ def main(): data = { "Course Name": [ - "Probability and Random Processes", "Optimization Methods in Engineering", - "Digital Design and Integrated Circuits", "Computer Security", "Cooking", - "Food Sciences", "Machine Learning", "Data Structures and Algorithms", - "Quantum Mechanics", "Organic Chemistry", "Artificial Intelligence", "Robotics", - "Thermodynamics", "Fluid Mechanics", "Molecular Biology", "Genetics", - "Astrophysics", "Neuroscience", "Microeconomics", "Macroeconomics", - "Linear Algebra", "Calculus", "Statistics", "Differential Equations", - "Discrete Mathematics", "Number Theory", "Graph Theory", "Topology", - "Complex Analysis", "Real Analysis", "Abstract Algebra", "Numerical Methods", - "Cryptography", "Network Security", "Operating Systems", "Databases", - "Computer Networks", "Software Engineering", "Compilers", "Computer Architecture", - "Parallel Computing", "Distributed Systems", "Cloud Computing", "Big Data Analytics", - "Natural Language Processing", "Computer Vision", "Reinforcement Learning", - "Deep Learning", "Bioinformatics", "Computational Biology", "Systems Biology", - "Biochemistry", "Physical Chemistry", "Inorganic Chemistry", "Analytical Chemistry", - "Environmental Chemistry", "Materials Science", "Nanotechnology", "Optics", - "Electromagnetism", "Nuclear Physics", "Particle Physics", "Cosmology", - "Planetary Science", "Geophysics", "Atmospheric Science", "Oceanography", - "Ecology", "Evolutionary Biology", "Botany", "Zoology", "Microbiology", - "Immunology", "Virology", "Pharmacology", "Physiology", "Anatomy", - "Neurobiology", "Cognitive Science", "Psychology", "Sociology", "Anthropology", - "Archaeology", "Linguistics", "Philosophy", "Ethics", "Logic", - "Political Science", "International Relations", "Public Policy", "Economics", - "Finance", "Accounting", "Marketing", "Management", "Entrepreneurship", - "Law", "Criminal Justice", "Human Rights", "Environmental Studies", - "Sustainability", "Urban Planning", "Architecture", "Civil Engineering", - "Mechanical Engineering", "Electrical Engineering", "Chemical Engineering", - "Aerospace Engineering", "Biomedical Engineering", "Environmental Engineering" + "Probability and Random Processes", + "Optimization Methods in Engineering", + "Digital Design and Integrated Circuits", + "Computer Security", + "Cooking", + "Food Sciences", + "Machine Learning", + "Data Structures and Algorithms", + "Quantum Mechanics", + "Organic Chemistry", + "Artificial Intelligence", + "Robotics", + "Thermodynamics", + "Fluid Mechanics", + "Molecular Biology", + "Genetics", + "Astrophysics", + "Neuroscience", + "Microeconomics", + "Macroeconomics", + "Linear Algebra", + "Calculus", + "Statistics", + "Differential Equations", + "Discrete Mathematics", + "Number Theory", + "Graph Theory", + "Topology", + "Complex Analysis", + "Real Analysis", + "Abstract Algebra", + "Numerical Methods", + "Cryptography", + "Network Security", + "Operating Systems", + "Databases", + "Computer Networks", + "Software Engineering", + "Compilers", + "Computer Architecture", + "Parallel Computing", + "Distributed Systems", + "Cloud Computing", + "Big Data Analytics", + "Natural Language Processing", + "Computer Vision", + "Reinforcement Learning", + "Deep Learning", + "Bioinformatics", + "Computational Biology", + "Systems Biology", + "Biochemistry", + "Physical Chemistry", + "Inorganic Chemistry", + "Analytical Chemistry", + "Environmental Chemistry", + "Materials Science", + "Nanotechnology", + "Optics", + "Electromagnetism", + "Nuclear Physics", + "Particle Physics", + "Cosmology", + "Planetary Science", + "Geophysics", + "Atmospheric Science", + "Oceanography", + "Ecology", + "Evolutionary Biology", + "Botany", + "Zoology", + "Microbiology", + "Immunology", + "Virology", + "Pharmacology", + "Physiology", + "Anatomy", + "Neurobiology", + "Cognitive Science", + "Psychology", + "Sociology", + "Anthropology", + "Archaeology", + "Linguistics", + "Philosophy", + "Ethics", + "Logic", + "Political Science", + "International Relations", + "Public Policy", + "Economics", + "Finance", + "Accounting", + "Marketing", + "Management", + "Entrepreneurship", + "Law", + "Criminal Justice", + "Human Rights", + "Environmental Studies", + "Sustainability", + "Urban Planning", + "Architecture", + "Civil Engineering", + "Mechanical Engineering", + "Electrical Engineering", + "Chemical Engineering", + "Aerospace Engineering", + "Biomedical Engineering", + "Environmental Engineering", ], "Grade Level": [ - "High School", "Graduate", "Graduate", "High School", "Undergraduate", - "Undergraduate", "High School", "Undergraduate", "High School", "Undergraduate", - "High School", "Graduate", "Undergraduate", "Undergraduate", "Graduate", - "Undergraduate", "Graduate", "Graduate", "Undergraduate", "Undergraduate", - "Undergraduate", "Undergraduate", "High School", "High School", "Undergraduate", - "Graduate", "Graduate", "Graduate", "High School", "Graduate", "Graduate", "Graduate", - "Graduate", "High School", "Undergraduate", "High School", "Undergraduate", - "Undergraduate", "Graduate", "Undergraduate", "Undergraduate", "Graduate", "Graduate", - "Graduate", "Graduate", "Graduate", "Graduate", "Graduate", "Graduate", "Graduate", - "Undergraduate", "Graduate", "Undergraduate", "High School", "Graduate", "Graduate", - "Graduate", "High School", "Graduate", "High School", "Graduate", "Graduate", - "Graduate", "Graduate", "Graduate", "Graduate", "Graduate", "Graduate", - "High School", "High School", "High School", "Undergraduate", "Graduate", - "Graduate", "Graduate", "High School", "Undergraduate", "Undergraduate", - "Graduate", "Graduate", "Undergraduate", "Undergraduate", "Undergraduate", - "High School", "High School", "Graduate", "Graduate", "High School", "Graduate", - "Graduate", "Graduate", "Undergraduate", "Undergraduate", "Undergraduate", "Undergraduate", - "High School", "High School", "Graduate", "Undergraduate", "Undergraduate", "Undergraduate", - "Undergraduate", "Undergraduate", "Undergraduate", "Graduate", "Graduate", - "Graduate", "Graduate", "Graduate", "Graduate" + "High School", + "Graduate", + "Graduate", + "High School", + "Undergraduate", + "Undergraduate", + "High School", + "Undergraduate", + "High School", + "Undergraduate", + "High School", + "Graduate", + "Undergraduate", + "Undergraduate", + "Graduate", + "Undergraduate", + "Graduate", + "Graduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "High School", + "High School", + "Undergraduate", + "Graduate", + "Graduate", + "Graduate", + "High School", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "High School", + "Undergraduate", + "High School", + "Undergraduate", + "Undergraduate", + "Graduate", + "Undergraduate", + "Undergraduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Undergraduate", + "Graduate", + "Undergraduate", + "High School", + "Graduate", + "Graduate", + "Graduate", + "High School", + "Graduate", + "High School", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "High School", + "High School", + "High School", + "Undergraduate", + "Graduate", + "Graduate", + "Graduate", + "High School", + "Undergraduate", + "Undergraduate", + "Graduate", + "Graduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "High School", + "High School", + "Graduate", + "Graduate", + "High School", + "Graduate", + "Graduate", + "Graduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "High School", + "High School", + "Graduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", ], } @@ -75,5 +249,6 @@ def main(): print(df._output[0]) print(f"Total execution time: {end_time - start_time:.2f} seconds") -if __name__ == '__main__': - main() \ No newline at end of file + +if __name__ == "__main__": + main() diff --git a/examples/op_examples/top_k.py b/examples/op_examples/top_k.py index 3a8fca4..1d1eabd 100644 --- a/examples/op_examples/top_k.py +++ b/examples/op_examples/top_k.py @@ -1,44 +1,67 @@ -import pandas as pd import time +import pandas as pd + import lotus from lotus.models import LM + def main(): lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) lotus.settings.configure(enable_multithreading=True) - + data = { "Department": ["Math", "Physics", "Computer Science", "Biology"] * 7, "Course Name": [ - "Calculus", "Quantum Mechanics", "Data Structures", "Genetics", - "Linear Algebra", "Thermodynamics", "Algorithms", "Ecology", - "Statistics", "Optics", "Machine Learning", "Molecular Biology", - "Number Theory", "Relativity", "Computer Networks", "Evolutionary Biology", - "Differential Equations", "Particle Physics", "Operating Systems", "Biochemistry", - "Complex Analysis", "Fluid Dynamics", "Artificial Intelligence", "Microbiology", - "Topology", "Astrophysics", "Cybersecurity", "Immunology" - ] + "Calculus", + "Quantum Mechanics", + "Data Structures", + "Genetics", + "Linear Algebra", + "Thermodynamics", + "Algorithms", + "Ecology", + "Statistics", + "Optics", + "Machine Learning", + "Molecular Biology", + "Number Theory", + "Relativity", + "Computer Networks", + "Evolutionary Biology", + "Differential Equations", + "Particle Physics", + "Operating Systems", + "Biochemistry", + "Complex Analysis", + "Fluid Dynamics", + "Artificial Intelligence", + "Microbiology", + "Topology", + "Astrophysics", + "Cybersecurity", + "Immunology", + ], } df = pd.DataFrame(data) for method in ["quick", "heap", "naive"]: - start_time = time.time() sorted_df, stats = df.sem_topk( "Which {Course Name} is the most challenging?", K=2, method=method, return_stats=True, - group_by=["Department"] + group_by=["Department"], ) end_time = time.time() print(sorted_df) print(stats) print(f"Total execution time: {end_time - start_time:.2f} seconds") -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/lotus/sem_ops/sem_agg.py b/lotus/sem_ops/sem_agg.py index a012d0a..d5003d0 100644 --- a/lotus/sem_ops/sem_agg.py +++ b/lotus/sem_ops/sem_agg.py @@ -1,5 +1,5 @@ -from typing import Any import os +from typing import Any import pandas as pd @@ -7,10 +7,12 @@ from lotus.templates import task_instructions from lotus.types import LMOutput, SemanticAggOutput + def initializer(settings, log_level): lotus.logger.setLevel(log_level) lotus.settings.clone(settings) + def sem_agg( docs: list[str], model: lotus.models.LM, @@ -146,7 +148,7 @@ def __init__(self, pandas_obj: Any): @staticmethod def _validate(obj: Any) -> None: pass - + @staticmethod def process_group(args): lotus.logger.debug(f"Processing in PID: {os.getpid()}") @@ -197,13 +199,13 @@ def __call__( if lotus.settings.enable_multithreading: lotus.logger.debug("Using multithreading") from multiprocessing import Pool + with Pool(initializer=initializer, initargs=(lotus.settings, lotus.logger.getEffectiveLevel())) as pool: - return pd.concat(pool.map(SemAggDataframe.process_group, group_args)) + return pd.concat(pool.map(SemAggDataframe.process_group, group_args)) else: lotus.logger.debug("Not using multithreading") return pd.concat([SemAggDataframe.process_group(group_arg) for group_arg in group_args]) - - + # Sort df by partition_id if it exists if "_lotus_partition_id" in self._obj.columns: self._obj = self._obj.sort_values(by="_lotus_partition_id") diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py index 92f5181..937d3da 100644 --- a/lotus/sem_ops/sem_topk.py +++ b/lotus/sem_ops/sem_topk.py @@ -11,10 +11,12 @@ from lotus.types import LMOutput, SemanticTopKOutput from lotus.utils import show_safe_mode + def initializer(settings, log_level): lotus.logger.setLevel(log_level) lotus.settings.clone(settings) + def get_match_prompt_binary( doc1: dict[str, Any], doc2: dict[str, Any], user_instruction: str, strategy: str | None = None ) -> list[dict[str, Any]]: @@ -380,14 +382,14 @@ def _validate(obj: Any) -> None: def process_group(args): group, user_instruction, K, method, strategy, group_by, cascade_threshold, return_stats = args return group.sem_topk( - user_instruction, - K, - method=method, - strategy=strategy, - group_by=None, - cascade_threshold=cascade_threshold, - return_stats=return_stats, - ) + user_instruction, + K, + method=method, + strategy=strategy, + group_by=None, + cascade_threshold=cascade_threshold, + return_stats=return_stats, + ) def __call__( self, @@ -433,23 +435,22 @@ def __call__( if group_by: grouped = self._obj.groupby(group_by) group_args = [ - (group, user_instruction, K, method, strategy, None, cascade_threshold, return_stats) - for _, group in grouped - ] + (group, user_instruction, K, method, strategy, None, cascade_threshold, return_stats) + for _, group in grouped + ] if lotus.settings.enable_multithreading: from multiprocessing import Pool - + with Pool(initializer=initializer, initargs=(lotus.settings, lotus.logger.getEffectiveLevel())) as pool: results = pool.map(SemTopKDataframe.process_group, group_args) else: results = [SemTopKDataframe.process_group(group_arg) for group_arg in group_args] - + new_df = pd.concat([res[0] for res in results]) stats = {name: res[1] for name, res in zip(grouped.groups.keys(), results)} if return_stats: return new_df, stats return new_df - if method == "quick-sem": assert len(col_li) == 1, "Only one column can be used for embedding optimization" diff --git a/lotus/settings.py b/lotus/settings.py index 04d6c66..3fe3250 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -3,6 +3,7 @@ # NOTE: Settings class is not thread-safe + class Settings: # Models lm: lotus.models.LM | None = None @@ -28,8 +29,9 @@ def configure(self, **kwargs): def clone(self, other_settings): for key in vars(other_settings): setattr(self, key, getattr(other_settings, key)) - + def __str__(self): return str(vars(self)) -settings = Settings() \ No newline at end of file + +settings = Settings() diff --git a/tests/test_settings.py b/tests/test_settings.py index f028e04..a578b3c 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,6 +1,8 @@ import pytest -from lotus.settings import Settings, SerializationFormat -from lotus.models import LM, RM, Reranker, SentenceTransformersRM + +from lotus.models import LM +from lotus.settings import SerializationFormat, Settings + class TestSettings: @pytest.fixture @@ -17,9 +19,7 @@ def test_initial_values(self, settings): assert settings.enable_multithreading is False def test_configure_method(self, settings): - settings.configure( - enable_multithreading=True - ) + settings.configure(enable_multithreading=True) assert settings.enable_multithreading is True def test_invalid_setting(self, settings): @@ -35,4 +35,4 @@ def test_clone_method(self, settings): settings.clone(other_settings) assert settings.lm == lm - assert settings.enable_cache is True \ No newline at end of file + assert settings.enable_cache is True From 6c8a647d805d38d3af6fe44c72fe208b3512fbd7 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Thu, 26 Dec 2024 12:59:42 -0700 Subject: [PATCH 4/9] reverted agg and topk examples and replaced with two new files for testing multithreading --- examples/op_examples/agg.py | 265 ++------------------ examples/op_examples/agg_with_grouping.py | 248 ++++++++++++++++++ examples/op_examples/top_k.py | 84 ++----- examples/op_examples/top_k_with_grouping.py | 61 +++++ 4 files changed, 347 insertions(+), 311 deletions(-) create mode 100644 examples/op_examples/agg_with_grouping.py create mode 100644 examples/op_examples/top_k_with_grouping.py diff --git a/examples/op_examples/agg.py b/examples/op_examples/agg.py index dd47f19..f04f84a 100644 --- a/examples/op_examples/agg.py +++ b/examples/op_examples/agg.py @@ -1,254 +1,21 @@ -import time - import pandas as pd import lotus from lotus.models import LM - -def main(): - lm = LM(model="gpt-4o-mini") - - lotus.settings.configure(lm=lm) - lotus.settings.configure(enable_multithreading=True) - - # turn on lotus debug logging - lotus.logger.setLevel("DEBUG") - - data = { - "Course Name": [ - "Probability and Random Processes", - "Optimization Methods in Engineering", - "Digital Design and Integrated Circuits", - "Computer Security", - "Cooking", - "Food Sciences", - "Machine Learning", - "Data Structures and Algorithms", - "Quantum Mechanics", - "Organic Chemistry", - "Artificial Intelligence", - "Robotics", - "Thermodynamics", - "Fluid Mechanics", - "Molecular Biology", - "Genetics", - "Astrophysics", - "Neuroscience", - "Microeconomics", - "Macroeconomics", - "Linear Algebra", - "Calculus", - "Statistics", - "Differential Equations", - "Discrete Mathematics", - "Number Theory", - "Graph Theory", - "Topology", - "Complex Analysis", - "Real Analysis", - "Abstract Algebra", - "Numerical Methods", - "Cryptography", - "Network Security", - "Operating Systems", - "Databases", - "Computer Networks", - "Software Engineering", - "Compilers", - "Computer Architecture", - "Parallel Computing", - "Distributed Systems", - "Cloud Computing", - "Big Data Analytics", - "Natural Language Processing", - "Computer Vision", - "Reinforcement Learning", - "Deep Learning", - "Bioinformatics", - "Computational Biology", - "Systems Biology", - "Biochemistry", - "Physical Chemistry", - "Inorganic Chemistry", - "Analytical Chemistry", - "Environmental Chemistry", - "Materials Science", - "Nanotechnology", - "Optics", - "Electromagnetism", - "Nuclear Physics", - "Particle Physics", - "Cosmology", - "Planetary Science", - "Geophysics", - "Atmospheric Science", - "Oceanography", - "Ecology", - "Evolutionary Biology", - "Botany", - "Zoology", - "Microbiology", - "Immunology", - "Virology", - "Pharmacology", - "Physiology", - "Anatomy", - "Neurobiology", - "Cognitive Science", - "Psychology", - "Sociology", - "Anthropology", - "Archaeology", - "Linguistics", - "Philosophy", - "Ethics", - "Logic", - "Political Science", - "International Relations", - "Public Policy", - "Economics", - "Finance", - "Accounting", - "Marketing", - "Management", - "Entrepreneurship", - "Law", - "Criminal Justice", - "Human Rights", - "Environmental Studies", - "Sustainability", - "Urban Planning", - "Architecture", - "Civil Engineering", - "Mechanical Engineering", - "Electrical Engineering", - "Chemical Engineering", - "Aerospace Engineering", - "Biomedical Engineering", - "Environmental Engineering", - ], - "Grade Level": [ - "High School", - "Graduate", - "Graduate", - "High School", - "Undergraduate", - "Undergraduate", - "High School", - "Undergraduate", - "High School", - "Undergraduate", - "High School", - "Graduate", - "Undergraduate", - "Undergraduate", - "Graduate", - "Undergraduate", - "Graduate", - "Graduate", - "Undergraduate", - "Undergraduate", - "Undergraduate", - "Undergraduate", - "High School", - "High School", - "Undergraduate", - "Graduate", - "Graduate", - "Graduate", - "High School", - "Graduate", - "Graduate", - "Graduate", - "Graduate", - "High School", - "Undergraduate", - "High School", - "Undergraduate", - "Undergraduate", - "Graduate", - "Undergraduate", - "Undergraduate", - "Graduate", - "Graduate", - "Graduate", - "Graduate", - "Graduate", - "Graduate", - "Graduate", - "Graduate", - "Graduate", - "Undergraduate", - "Graduate", - "Undergraduate", - "High School", - "Graduate", - "Graduate", - "Graduate", - "High School", - "Graduate", - "High School", - "Graduate", - "Graduate", - "Graduate", - "Graduate", - "Graduate", - "Graduate", - "Graduate", - "Graduate", - "High School", - "High School", - "High School", - "Undergraduate", - "Graduate", - "Graduate", - "Graduate", - "High School", - "Undergraduate", - "Undergraduate", - "Graduate", - "Graduate", - "Undergraduate", - "Undergraduate", - "Undergraduate", - "High School", - "High School", - "Graduate", - "Graduate", - "High School", - "Graduate", - "Graduate", - "Graduate", - "Undergraduate", - "Undergraduate", - "Undergraduate", - "Undergraduate", - "High School", - "High School", - "Graduate", - "Undergraduate", - "Undergraduate", - "Undergraduate", - "Undergraduate", - "Undergraduate", - "Undergraduate", - "Graduate", - "Graduate", - "Graduate", - "Graduate", - "Graduate", - "Graduate", - ], - } - - df = pd.DataFrame(data) - start_time = time.time() - df = df.sem_agg("Summarize all {Course Name}", group_by=["Grade Level"]) - end_time = time.time() - print(df._output[0]) - print(f"Total execution time: {end_time - start_time:.2f} seconds") - - -if __name__ == "__main__": - main() +lm = LM(model="gpt-4o-mini") + +lotus.settings.configure(lm=lm) +data = { + "Course Name": [ + "Probability and Random Processes", + "Optimization Methods in Engineering", + "Digital Design and Integrated Circuits", + "Computer Security", + "Cooking", + "Food Sciences", + ] +} +df = pd.DataFrame(data) +df = df.sem_agg("Summarize all {Course Name}") +print(df._output[0]) diff --git a/examples/op_examples/agg_with_grouping.py b/examples/op_examples/agg_with_grouping.py new file mode 100644 index 0000000..bbd34cf --- /dev/null +++ b/examples/op_examples/agg_with_grouping.py @@ -0,0 +1,248 @@ +import time + +import pandas as pd + +import lotus +from lotus.models import LM + +lm = LM(model="gpt-4o-mini") + +lotus.settings.configure(lm=lm) +lotus.settings.configure(enable_multithreading=True) + +# turn on lotus debug logging +lotus.logger.setLevel("DEBUG") + +data = { + "Course Name": [ + "Probability and Random Processes", + "Optimization Methods in Engineering", + "Digital Design and Integrated Circuits", + "Computer Security", + "Cooking", + "Food Sciences", + "Machine Learning", + "Data Structures and Algorithms", + "Quantum Mechanics", + "Organic Chemistry", + "Artificial Intelligence", + "Robotics", + "Thermodynamics", + "Fluid Mechanics", + "Molecular Biology", + "Genetics", + "Astrophysics", + "Neuroscience", + "Microeconomics", + "Macroeconomics", + "Linear Algebra", + "Calculus", + "Statistics", + "Differential Equations", + "Discrete Mathematics", + "Number Theory", + "Graph Theory", + "Topology", + "Complex Analysis", + "Real Analysis", + "Abstract Algebra", + "Numerical Methods", + "Cryptography", + "Network Security", + "Operating Systems", + "Databases", + "Computer Networks", + "Software Engineering", + "Compilers", + "Computer Architecture", + "Parallel Computing", + "Distributed Systems", + "Cloud Computing", + "Big Data Analytics", + "Natural Language Processing", + "Computer Vision", + "Reinforcement Learning", + "Deep Learning", + "Bioinformatics", + "Computational Biology", + "Systems Biology", + "Biochemistry", + "Physical Chemistry", + "Inorganic Chemistry", + "Analytical Chemistry", + "Environmental Chemistry", + "Materials Science", + "Nanotechnology", + "Optics", + "Electromagnetism", + "Nuclear Physics", + "Particle Physics", + "Cosmology", + "Planetary Science", + "Geophysics", + "Atmospheric Science", + "Oceanography", + "Ecology", + "Evolutionary Biology", + "Botany", + "Zoology", + "Microbiology", + "Immunology", + "Virology", + "Pharmacology", + "Physiology", + "Anatomy", + "Neurobiology", + "Cognitive Science", + "Psychology", + "Sociology", + "Anthropology", + "Archaeology", + "Linguistics", + "Philosophy", + "Ethics", + "Logic", + "Political Science", + "International Relations", + "Public Policy", + "Economics", + "Finance", + "Accounting", + "Marketing", + "Management", + "Entrepreneurship", + "Law", + "Criminal Justice", + "Human Rights", + "Environmental Studies", + "Sustainability", + "Urban Planning", + "Architecture", + "Civil Engineering", + "Mechanical Engineering", + "Electrical Engineering", + "Chemical Engineering", + "Aerospace Engineering", + "Biomedical Engineering", + "Environmental Engineering", + ], + "Grade Level": [ + "High School", + "Graduate", + "Graduate", + "High School", + "Undergraduate", + "Undergraduate", + "High School", + "Undergraduate", + "High School", + "Undergraduate", + "High School", + "Graduate", + "Undergraduate", + "Undergraduate", + "Graduate", + "Undergraduate", + "Graduate", + "Graduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "High School", + "High School", + "Undergraduate", + "Graduate", + "Graduate", + "Graduate", + "High School", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "High School", + "Undergraduate", + "High School", + "Undergraduate", + "Undergraduate", + "Graduate", + "Undergraduate", + "Undergraduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Undergraduate", + "Graduate", + "Undergraduate", + "High School", + "Graduate", + "Graduate", + "Graduate", + "High School", + "Graduate", + "High School", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "High School", + "High School", + "High School", + "Undergraduate", + "Graduate", + "Graduate", + "Graduate", + "High School", + "Undergraduate", + "Undergraduate", + "Graduate", + "Graduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "High School", + "High School", + "Graduate", + "Graduate", + "High School", + "Graduate", + "Graduate", + "Graduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "High School", + "High School", + "Graduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "Undergraduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + "Graduate", + ], +} + +df = pd.DataFrame(data) +start_time = time.time() +df = df.sem_agg("Summarize all {Course Name}", group_by=["Grade Level"]) +end_time = time.time() +print(df._output[0]) +print(f"Total execution time: {end_time - start_time:.2f} seconds") diff --git a/examples/op_examples/top_k.py b/examples/op_examples/top_k.py index 1d1eabd..8654ea1 100644 --- a/examples/op_examples/top_k.py +++ b/examples/op_examples/top_k.py @@ -1,67 +1,27 @@ -import time - import pandas as pd import lotus from lotus.models import LM - -def main(): - lm = LM(model="gpt-4o-mini") - - lotus.settings.configure(lm=lm) - lotus.settings.configure(enable_multithreading=True) - - data = { - "Department": ["Math", "Physics", "Computer Science", "Biology"] * 7, - "Course Name": [ - "Calculus", - "Quantum Mechanics", - "Data Structures", - "Genetics", - "Linear Algebra", - "Thermodynamics", - "Algorithms", - "Ecology", - "Statistics", - "Optics", - "Machine Learning", - "Molecular Biology", - "Number Theory", - "Relativity", - "Computer Networks", - "Evolutionary Biology", - "Differential Equations", - "Particle Physics", - "Operating Systems", - "Biochemistry", - "Complex Analysis", - "Fluid Dynamics", - "Artificial Intelligence", - "Microbiology", - "Topology", - "Astrophysics", - "Cybersecurity", - "Immunology", - ], - } - - df = pd.DataFrame(data) - - for method in ["quick", "heap", "naive"]: - start_time = time.time() - sorted_df, stats = df.sem_topk( - "Which {Course Name} is the most challenging?", - K=2, - method=method, - return_stats=True, - group_by=["Department"], - ) - end_time = time.time() - print(sorted_df) - print(stats) - print(f"Total execution time: {end_time - start_time:.2f} seconds") - - -if __name__ == "__main__": - main() +lm = LM(model="gpt-4o-mini") + +lotus.settings.configure(lm=lm) +data = { + "Course Name": [ + "Probability and Random Processes", + "Optimization Methods in Engineering", + "Digital Design and Integrated Circuits", + "Computer Security", + ] +} +df = pd.DataFrame(data) + +for method in ["quick", "heap", "naive"]: + sorted_df, stats = df.sem_topk( + "Which {Course Name} requires the least math?", + K=2, + method=method, + return_stats=True, + ) + print(sorted_df) + print(stats) diff --git a/examples/op_examples/top_k_with_grouping.py b/examples/op_examples/top_k_with_grouping.py new file mode 100644 index 0000000..88c7683 --- /dev/null +++ b/examples/op_examples/top_k_with_grouping.py @@ -0,0 +1,61 @@ +import time + +import pandas as pd + +import lotus +from lotus.models import LM + +lm = LM(model="gpt-4o-mini") + +lotus.settings.configure(lm=lm) +lotus.settings.configure(enable_multithreading=True) + +data = { + "Department": ["Math", "Physics", "Computer Science", "Biology"] * 7, + "Course Name": [ + "Calculus", + "Quantum Mechanics", + "Data Structures", + "Genetics", + "Linear Algebra", + "Thermodynamics", + "Algorithms", + "Ecology", + "Statistics", + "Optics", + "Machine Learning", + "Molecular Biology", + "Number Theory", + "Relativity", + "Computer Networks", + "Evolutionary Biology", + "Differential Equations", + "Particle Physics", + "Operating Systems", + "Biochemistry", + "Complex Analysis", + "Fluid Dynamics", + "Artificial Intelligence", + "Microbiology", + "Topology", + "Astrophysics", + "Cybersecurity", + "Immunology", + ], +} + +df = pd.DataFrame(data) + +for method in ["quick", "heap", "naive"]: + start_time = time.time() + sorted_df, stats = df.sem_topk( + "Which {Course Name} is the most challenging?", + K=2, + method=method, + return_stats=True, + group_by=["Department"], + ) + end_time = time.time() + print(sorted_df) + print(stats) + print(f"Total execution time: {end_time - start_time:.2f} seconds") From 69955bf68704baecbfd745a1cdf266d4b53dbbd3 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Thu, 26 Dec 2024 13:16:56 -0700 Subject: [PATCH 5/9] replaced multiprocessing with multithreading --- lotus/sem_ops/sem_agg.py | 18 +++--------------- lotus/sem_ops/sem_topk.py | 14 +++----------- lotus/settings.py | 7 ------- tests/test_settings.py | 17 ++--------------- 4 files changed, 8 insertions(+), 48 deletions(-) diff --git a/lotus/sem_ops/sem_agg.py b/lotus/sem_ops/sem_agg.py index d5003d0..3362000 100644 --- a/lotus/sem_ops/sem_agg.py +++ b/lotus/sem_ops/sem_agg.py @@ -1,4 +1,3 @@ -import os from typing import Any import pandas as pd @@ -8,11 +7,6 @@ from lotus.types import LMOutput, SemanticAggOutput -def initializer(settings, log_level): - lotus.logger.setLevel(log_level) - lotus.settings.clone(settings) - - def sem_agg( docs: list[str], model: lotus.models.LM, @@ -151,7 +145,6 @@ def _validate(obj: Any) -> None: @staticmethod def process_group(args): - lotus.logger.debug(f"Processing in PID: {os.getpid()}") group, user_instruction, all_cols, suffix, progress_bar_desc = args return group.sem_agg(user_instruction, all_cols, suffix, None, progress_bar_desc=progress_bar_desc) @@ -196,15 +189,10 @@ def __call__( if group_by: grouped = self._obj.groupby(group_by) group_args = [(group, user_instruction, all_cols, suffix, progress_bar_desc) for _, group in grouped] - if lotus.settings.enable_multithreading: - lotus.logger.debug("Using multithreading") - from multiprocessing import Pool + from concurrent.futures import ThreadPoolExecutor - with Pool(initializer=initializer, initargs=(lotus.settings, lotus.logger.getEffectiveLevel())) as pool: - return pd.concat(pool.map(SemAggDataframe.process_group, group_args)) - else: - lotus.logger.debug("Not using multithreading") - return pd.concat([SemAggDataframe.process_group(group_arg) for group_arg in group_args]) + with ThreadPoolExecutor() as executor: + return pd.concat(list(executor.map(SemAggDataframe.process_group, group_args))) # Sort df by partition_id if it exists if "_lotus_partition_id" in self._obj.columns: diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py index 937d3da..2e2e918 100644 --- a/lotus/sem_ops/sem_topk.py +++ b/lotus/sem_ops/sem_topk.py @@ -12,11 +12,6 @@ from lotus.utils import show_safe_mode -def initializer(settings, log_level): - lotus.logger.setLevel(log_level) - lotus.settings.clone(settings) - - def get_match_prompt_binary( doc1: dict[str, Any], doc2: dict[str, Any], user_instruction: str, strategy: str | None = None ) -> list[dict[str, Any]]: @@ -438,13 +433,10 @@ def __call__( (group, user_instruction, K, method, strategy, None, cascade_threshold, return_stats) for _, group in grouped ] - if lotus.settings.enable_multithreading: - from multiprocessing import Pool + from concurrent.futures import ThreadPoolExecutor - with Pool(initializer=initializer, initargs=(lotus.settings, lotus.logger.getEffectiveLevel())) as pool: - results = pool.map(SemTopKDataframe.process_group, group_args) - else: - results = [SemTopKDataframe.process_group(group_arg) for group_arg in group_args] + with ThreadPoolExecutor() as executor: + results = list(executor.map(SemTopKDataframe.process_group, group_args)) new_df = pd.concat([res[0] for res in results]) stats = {name: res[1] for name, res in zip(grouped.groups.keys(), results)} diff --git a/lotus/settings.py b/lotus/settings.py index 3fe3250..1841df6 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -17,19 +17,12 @@ class Settings: # Serialization setting serialization_format: SerializationFormat = SerializationFormat.DEFAULT - # Multithreading settings - enable_multithreading: bool = False - def configure(self, **kwargs): for key, value in kwargs.items(): if not hasattr(self, key): raise ValueError(f"Invalid setting: {key}") setattr(self, key, value) - def clone(self, other_settings): - for key in vars(other_settings): - setattr(self, key, getattr(other_settings, key)) - def __str__(self): return str(vars(self)) diff --git a/tests/test_settings.py b/tests/test_settings.py index a578b3c..dc6f871 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,6 +1,5 @@ import pytest -from lotus.models import LM from lotus.settings import SerializationFormat, Settings @@ -16,23 +15,11 @@ def test_initial_values(self, settings): assert settings.reranker is None assert settings.enable_cache is False assert settings.serialization_format == SerializationFormat.DEFAULT - assert settings.enable_multithreading is False def test_configure_method(self, settings): - settings.configure(enable_multithreading=True) - assert settings.enable_multithreading is True + settings.configure(enable_cache=True) + assert settings.enable_cache is True def test_invalid_setting(self, settings): with pytest.raises(ValueError, match="Invalid setting: invalid_setting"): settings.configure(invalid_setting=True) - - def test_clone_method(self, settings): - other_settings = Settings() - lm = LM(model="test-model") - other_settings.lm = lm - other_settings.enable_cache = True - - settings.clone(other_settings) - - assert settings.lm == lm - assert settings.enable_cache is True From 324c81a16b6703b5dbdd65d2508000f4b08f3e29 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Thu, 26 Dec 2024 15:04:52 -0700 Subject: [PATCH 6/9] fix failing CI test --- lotus/sem_ops/sem_topk.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py index 2e2e918..c92e81d 100644 --- a/lotus/sem_ops/sem_topk.py +++ b/lotus/sem_ops/sem_topk.py @@ -433,16 +433,18 @@ def __call__( (group, user_instruction, K, method, strategy, None, cascade_threshold, return_stats) for _, group in grouped ] + from concurrent.futures import ThreadPoolExecutor - with ThreadPoolExecutor() as executor: + with ThreadPoolExecutor(max_workers=lotus.settings.parallel_groupby_max_threads) as executor: results = list(executor.map(SemTopKDataframe.process_group, group_args)) - - new_df = pd.concat([res[0] for res in results]) - stats = {name: res[1] for name, res in zip(grouped.groups.keys(), results)} + if return_stats: + new_df = pd.concat([res[0] for res in results]) + stats = {name: res[1] for name, res in zip(grouped.groups.keys(), results)} return new_df, stats - return new_df + else: + return pd.concat(results) if method == "quick-sem": assert len(col_li) == 1, "Only one column can be used for embedding optimization" From 259b357a7210ec253c8965ea7ad1f2b326874ec1 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Thu, 26 Dec 2024 15:06:07 -0700 Subject: [PATCH 7/9] added max_threads setting --- lotus/settings.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lotus/settings.py b/lotus/settings.py index 1841df6..ce12363 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -17,6 +17,9 @@ class Settings: # Serialization setting serialization_format: SerializationFormat = SerializationFormat.DEFAULT + # Parallel groupby settings + parallel_groupby_max_threads: int = 8 + def configure(self, **kwargs): for key, value in kwargs.items(): if not hasattr(self, key): From 4b0ec969dd7ea1d13ac6dd0e9cad977de13dbcf4 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Thu, 26 Dec 2024 15:06:40 -0700 Subject: [PATCH 8/9] fix failing example --- examples/op_examples/top_k_with_grouping.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/op_examples/top_k_with_grouping.py b/examples/op_examples/top_k_with_grouping.py index 88c7683..ab1cf06 100644 --- a/examples/op_examples/top_k_with_grouping.py +++ b/examples/op_examples/top_k_with_grouping.py @@ -8,7 +8,6 @@ lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) -lotus.settings.configure(enable_multithreading=True) data = { "Department": ["Math", "Physics", "Computer Science", "Biology"] * 7, From 84141b8846fbf7732f2486eaee4a7862b9c66621 Mon Sep 17 00:00:00 2001 From: dhruviyer Date: Thu, 26 Dec 2024 15:07:51 -0700 Subject: [PATCH 9/9] update sem_agg with new max_workers setting --- examples/op_examples/agg_with_grouping.py | 4 ---- lotus/sem_ops/sem_agg.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/examples/op_examples/agg_with_grouping.py b/examples/op_examples/agg_with_grouping.py index bbd34cf..05f925b 100644 --- a/examples/op_examples/agg_with_grouping.py +++ b/examples/op_examples/agg_with_grouping.py @@ -8,10 +8,6 @@ lm = LM(model="gpt-4o-mini") lotus.settings.configure(lm=lm) -lotus.settings.configure(enable_multithreading=True) - -# turn on lotus debug logging -lotus.logger.setLevel("DEBUG") data = { "Course Name": [ diff --git a/lotus/sem_ops/sem_agg.py b/lotus/sem_ops/sem_agg.py index 3362000..7bb67ab 100644 --- a/lotus/sem_ops/sem_agg.py +++ b/lotus/sem_ops/sem_agg.py @@ -191,7 +191,7 @@ def __call__( group_args = [(group, user_instruction, all_cols, suffix, progress_bar_desc) for _, group in grouped] from concurrent.futures import ThreadPoolExecutor - with ThreadPoolExecutor() as executor: + with ThreadPoolExecutor(max_workers=lotus.settings.parallel_groupby_max_threads) as executor: return pd.concat(list(executor.map(SemAggDataframe.process_group, group_args))) # Sort df by partition_id if it exists