Skip to content

Commit

Permalink
parallelized sem_topk
Browse files Browse the repository at this point in the history
  • Loading branch information
dhruviyer committed Dec 26, 2024
1 parent 09b625f commit 3f6c8c3
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 46 deletions.
6 changes: 2 additions & 4 deletions examples/op_examples/agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ def main():
lotus.settings.configure(lm=lm)
lotus.settings.configure(enable_multithreading=True)

start_time = time.time()

# turn on lotus debug logging
lotus.logger.setLevel("DEBUG")

Expand Down Expand Up @@ -71,10 +69,10 @@ def main():
}

df = pd.DataFrame(data)
start_time = time.time()
df = df.sem_agg("Summarize all {Course Name}", group_by=["Grade Level"])
print(df._output[0])

end_time = time.time()
print(df._output[0])
print(f"Total execution time: {end_time - start_time:.2f} seconds")

if __name__ == '__main__':
Expand Down
57 changes: 37 additions & 20 deletions examples/op_examples/top_k.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,44 @@
import pandas as pd
import time

import lotus
from lotus.models import LM

lm = LM(model="gpt-4o-mini")
def main():
lm = LM(model="gpt-4o-mini")

lotus.settings.configure(lm=lm)
data = {
"Course Name": [
"Probability and Random Processes",
"Optimization Methods in Engineering",
"Digital Design and Integrated Circuits",
"Computer Security",
]
}
df = pd.DataFrame(data)
lotus.settings.configure(lm=lm)
lotus.settings.configure(enable_multithreading=True)

data = {
"Department": ["Math", "Physics", "Computer Science", "Biology"] * 7,
"Course Name": [
"Calculus", "Quantum Mechanics", "Data Structures", "Genetics",
"Linear Algebra", "Thermodynamics", "Algorithms", "Ecology",
"Statistics", "Optics", "Machine Learning", "Molecular Biology",
"Number Theory", "Relativity", "Computer Networks", "Evolutionary Biology",
"Differential Equations", "Particle Physics", "Operating Systems", "Biochemistry",
"Complex Analysis", "Fluid Dynamics", "Artificial Intelligence", "Microbiology",
"Topology", "Astrophysics", "Cybersecurity", "Immunology"
]
}

for method in ["quick", "heap", "naive"]:
sorted_df, stats = df.sem_topk(
"Which {Course Name} requires the least math?",
K=2,
method=method,
return_stats=True,
)
print(sorted_df)
print(stats)
df = pd.DataFrame(data)

for method in ["quick", "heap", "naive"]:

start_time = time.time()
sorted_df, stats = df.sem_topk(
"Which {Course Name} is the most challenging?",
K=2,
method=method,
return_stats=True,
group_by=["Department"]
)
end_time = time.time()
print(sorted_df)
print(stats)
print(f"Total execution time: {end_time - start_time:.2f} seconds")

if __name__ == '__main__':
main()
2 changes: 1 addition & 1 deletion lotus/sem_ops/sem_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def __call__(
Returns:
pd.DataFrame: The dataframe with the aggregated answer.
"""
# print all the settings values

if lotus.settings.lm is None:
raise ValueError(
"The language model must be an instance of LM. Please configure a valid language model using lotus.settings.configure()"
Expand Down
52 changes: 31 additions & 21 deletions lotus/sem_ops/sem_topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from lotus.types import LMOutput, SemanticTopKOutput
from lotus.utils import show_safe_mode

def initializer(settings, log_level):
lotus.logger.setLevel(log_level)
lotus.settings.clone(settings)

def get_match_prompt_binary(
doc1: dict[str, Any], doc2: dict[str, Any], user_instruction: str, strategy: str | None = None
Expand Down Expand Up @@ -373,6 +376,19 @@ def __init__(self, pandas_obj: Any) -> None:
def _validate(obj: Any) -> None:
pass

@staticmethod
def process_group(args):
group, user_instruction, K, method, strategy, group_by, cascade_threshold, return_stats = args
return group.sem_topk(
user_instruction,
K,
method=method,
strategy=strategy,
group_by=None,
cascade_threshold=cascade_threshold,
return_stats=return_stats,
)

def __call__(
self,
user_instruction: str,
Expand Down Expand Up @@ -416,30 +432,24 @@ def __call__(
# Separate code path for grouping
if group_by:
grouped = self._obj.groupby(group_by)
new_df = pd.DataFrame()
stats = {}
for name, group in grouped:
res = group.sem_topk(
user_instruction,
K,
method=method,
strategy=strategy,
group_by=None,
cascade_threshold=cascade_threshold,
return_stats=return_stats,
)

if return_stats:
sorted_group, group_stats = res
stats[name] = group_stats
else:
sorted_group = res

new_df = pd.concat([new_df, sorted_group])

group_args = [
(group, user_instruction, K, method, strategy, None, cascade_threshold, return_stats)
for _, group in grouped
]
if lotus.settings.enable_multithreading:
from multiprocessing import Pool

with Pool(initializer=initializer, initargs=(lotus.settings, lotus.logger.getEffectiveLevel())) as pool:
results = pool.map(SemTopKDataframe.process_group, group_args)
else:
results = [SemTopKDataframe.process_group(group_arg) for group_arg in group_args]

new_df = pd.concat([res[0] for res in results])
stats = {name: res[1] for name, res in zip(grouped.groups.keys(), results)}
if return_stats:
return new_df, stats
return new_df


if method == "quick-sem":
assert len(col_li) == 1, "Only one column can be used for embedding optimization"
Expand Down

0 comments on commit 3f6c8c3

Please sign in to comment.