diff --git a/algorithms.py b/algorithms.py index e08ef337f..b2d90bdb3 100644 --- a/algorithms.py +++ b/algorithms.py @@ -708,7 +708,7 @@ def assert_stop_condition(self): elif self.stop_condition == "all_local_mrcas": return any(num_anc > 1 for num_anc in self.S.values()) elif self.stop_condition == "time": - return self.get_num_ancestors() > 1 + return True elif self.stop_condition == "pedigree": return True else: @@ -1143,7 +1143,7 @@ def dtwf_simulate(self, end_time): Simulates the algorithm until all loci have coalesced. """ ret = 0 - while self.ancestors_remain(): + while self.assert_stop_condition(): if self.t + 1 > end_time: ret = 2 # _msprime.EXIT_MAX_TIME break @@ -2307,7 +2307,7 @@ def run_simulate(args): else: from_ts = tskit.load(args.from_ts) tables = from_ts.dump_tables() - if args.stop_condition == "full_pedigree": + if args.stop_condition == "pedigree": end_time = np.max(from_ts.nodes_time) else: end_time = args.end_time diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 624c583b9..d62e1d16f 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -469,7 +469,15 @@ def test_stopping_condition_pedigree(self): tables.dump(ts_path) ts = self.run_script( f"0 --from-ts {ts_path} --model=fixed_pedigree -r 0.1 \ - --stop-condition=full_pedigree" + --stop-condition=pedigree" ) assert ts.num_trees > 1 assert ts.max_root_time == num_generations - 1 + + def test_stopping_condition_dtwf(self): + end_time = 20 + ts = self.run_script( + f"10 --model=dtwf --stop-condition=time --end-time={end_time}" + ) + assert ts.num_trees > 1 + assert ts.max_root_time == end_time