Skip to content

Commit

Permalink
Test stored numpy arrays in batch match
Browse files Browse the repository at this point in the history
  • Loading branch information
benjeffery committed Feb 14, 2025
1 parent 629419b commit e5fa97f
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
40 changes: 40 additions & 0 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1658,6 +1658,46 @@ def test_force_sample_times(self, tmp_path, tmpdir):
ts = tsinfer.match_samples(samples, anc_ts, force_sample_times=True)
ts.tables.assert_equals(ts_batch.tables, ignore_provenance=True)

def test_array_args(self, tmp_path, tmpdir):
    """
    Check that batch sample matching gives the same tables as a direct
    ``match_samples`` call when the masks are supplied as numpy arrays.
    """
    ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
    work_dir = tmpdir / "working"
    anc_ts_path = tmpdir / "anc.trees"

    # Boolean masks with exactly one individual and one site flagged.
    sample_mask = np.zeros(ts.num_individuals, dtype=bool)
    sample_mask[42] = True
    site_mask = np.zeros(ts.num_sites, dtype=bool)
    site_mask[42] = True

    # One fewer time than there are sites -- presumably because the flagged
    # site is excluded from the selection (TODO confirm against VariantData).
    random_state = np.random.RandomState(42)
    sites_time = random_state.uniform(0, 1, ts.num_sites - 1)

    samples = tsinfer.VariantData(
        zarr_path,
        "variant_ancestral_allele",
        sample_mask=sample_mask,
        site_mask=site_mask,
        sites_time=sites_time,
    )
    ancestors = tsinfer.generate_ancestors(
        samples, path=str(tmpdir / "ancestors.zarr")
    )
    anc_ts = tsinfer.match_ancestors(samples, ancestors)
    anc_ts.dump(anc_ts_path)

    # Batch workflow: initialise, run every partition, then finalise.
    work_desc = tsinfer.match_samples_batch_init(
        work_dir=work_dir,
        sample_data_path=samples.path,
        sample_mask=sample_mask,
        site_mask=site_mask,
        ancestral_state="variant_ancestral_allele",
        ancestor_ts_path=anc_ts_path,
        min_work_per_job=1e6,
    )
    for partition in range(work_desc.num_partitions):
        tsinfer.match_samples_batch_partition(
            work_dir=work_dir,
            partition_index=partition,
        )
    ts_batch = tsinfer.match_samples_batch_finalise(work_dir)

    # Reference result from the non-batch code path.
    ts_direct = tsinfer.match_samples(samples, anc_ts)
    ts_direct.tables.assert_equals(ts_batch.tables, ignore_provenance=True)


class TestAncestorGeneratorsEquivalant:
"""
Expand Down
1 change: 1 addition & 0 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2425,6 +2425,7 @@ def __init__(
if sites_time is None:
self._sites_time = np.full(self.num_sites, tskit.UNKNOWN_TIME)
elif isinstance(sites_time, np.ndarray):
print(sites_time.shape, self.num_sites)
if sites_time.shape[0] != self.num_sites:
raise ValueError(
"Sites time array must be the same length as the number of selected"
Expand Down

0 comments on commit e5fa97f

Please sign in to comment.