Skip to content

Commit

Permalink
Merge pull request #233 from YosefLab/single-thread-hybrid
Browse files Browse the repository at this point in the history
Single thread hybrid solver
  • Loading branch information
mattjones315 authored Feb 6, 2024
2 parents 5b8a452 + 42cd03f commit 97accd7
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 25 deletions.
60 changes: 41 additions & 19 deletions cassiopeia/solver/HybridSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class HybridSolver(CassiopeiaSolver.CassiopeiaSolver):
"inverse": Transforms each probability p by taking 1/p
"square_root_inverse": Transforms each probability by the
the square root of 1/p
progress_bar: Indicates if a progress bar should be shown when
`solve` is called.
"""

def __init__(
Expand All @@ -69,6 +71,7 @@ def __init__(
cell_cutoff: int = None,
threads: int = 1,
prior_transformation: str = "negative_log",
progress_bar: bool = True,
):

if lca_cutoff is None and cell_cutoff is None:
Expand All @@ -89,6 +92,7 @@ def __init__(
self.cell_cutoff = cell_cutoff

self.threads = threads
self.progress_bar = progress_bar

def solve(
self,
Expand Down Expand Up @@ -144,27 +148,45 @@ def solve(
logfile_names = iter([i for i in range(1, len(subproblems) + 1)])

# multi-threaded bottom solver approach
with multiprocessing.Pool(processes=self.threads) as pool:

results = list(
tqdm(
pool.starmap(
self.apply_bottom_solver,
[
(
cassiopeia_tree,
subproblem[0],
subproblem[1],
f"{logfile.split('.log')[0]}-"
if self.threads > 1:
with multiprocessing.Pool(processes=self.threads) as pool:

results = list(
tqdm(
pool.starmap(
self.apply_bottom_solver,
[
(
cassiopeia_tree,
subproblem[0],
subproblem[1],
None if logfile is None else
f"{logfile.split('.log')[0]}-"
f"{next(logfile_names)}.log",
layer,
)
for subproblem in subproblems
],
),
total=len(subproblems),
layer,
)
for subproblem in subproblems
],
),
total=len(subproblems),
disable=not self.progress_bar
)
)
)
# single-threaded bottom solver approach
else:
results = [
self.apply_bottom_solver(
cassiopeia_tree,
subproblem[0],
subproblem[1],
None if logfile is None else
f"{logfile.split('.log')[0]}-"
f"{next(logfile_names)}.log",
layer,
)
for subproblem in tqdm(subproblems,
total=len(subproblems),disable=not self.progress_bar)
]

for result in results:

Expand Down
20 changes: 14 additions & 6 deletions cassiopeia/solver/ILPSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,14 @@ def solve(
" analysis."
)

# setup logfile config
handler = logging.FileHandler(logfile)
handler.setLevel(logging.INFO)
logger.addHandler(handler)
# configure logger
if logfile is not None:
file_handler = logging.FileHandler(logfile)
file_handler.setLevel(logging.INFO)
logger.addHandler(file_handler)
logger.ch.setLevel(logging.getLogger().level)

# add to logger
logger.info("Solving tree with the following parameters.")
logger.info(f"Convergence time limit: {self.convergence_time_limit}")
logger.info(
Expand Down Expand Up @@ -272,7 +276,10 @@ def solve(
cassiopeia_tree.collapse_mutationless_edges(
infer_ancestral_characters=True
)
logger.removeHandler(handler)

if logfile is not None:
logger.removeHandler(file_handler)
file_handler.close()

def infer_potential_graph(
self,
Expand Down Expand Up @@ -534,7 +541,8 @@ def solve_steiner_instance(

# Add user-defined parameters
model.params.MIPGAP = self.mip_gap
model.params.LogFile = logfile
if logfile is not None:
model.params.LogFile = logfile

if self.seed is not None:
model.params.Seed = self.seed
Expand Down

0 comments on commit 97accd7

Please sign in to comment.