Skip to content

Commit

Permalink
Flip sgkit mask polarity
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben Jeffery committed Feb 6, 2024
1 parent e529ac0 commit 444e4fa
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 21 deletions.
36 changes: 18 additions & 18 deletions tests/test_sgkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,25 +435,25 @@ class TestSgkitMask:
def test_sgkit_variant_mask(self, tmp_path, sites):
ts, zarr_path = make_ts_and_zarr(tmp_path)
ds = sgkit.load_dataset(zarr_path)
sites_mask = np.zeros_like(ds["variant_position"], dtype=bool)
sites_mask = np.ones_like(ds["variant_position"], dtype=bool)
for i in sites:
sites_mask[i] = True
sites_mask[i] = False
add_array_to_dataset("variant_mask", sites_mask, zarr_path)
samples = tsinfer.SgkitSampleData(zarr_path)
assert samples.num_sites == len(sites)
assert np.array_equal(samples.sites_mask, sites_mask)
assert np.array_equal(samples.sites_mask, ~sites_mask)
assert np.array_equal(
samples.sites_position, ts.tables.sites.position[sites_mask]
samples.sites_position, ts.tables.sites.position[~sites_mask]
)
inf_ts = tsinfer.infer(samples)
assert np.array_equal(
ts.genotype_matrix()[sites_mask], inf_ts.genotype_matrix()
ts.genotype_matrix()[~sites_mask], inf_ts.genotype_matrix()
)
assert np.array_equal(
ts.tables.sites.position[sites_mask], inf_ts.tables.sites.position
ts.tables.sites.position[~sites_mask], inf_ts.tables.sites.position
)
assert np.array_equal(
ts.tables.sites.ancestral_state[sites_mask],
ts.tables.sites.ancestral_state[~sites_mask],
inf_ts.tables.sites.ancestral_state,
)
# TODO - site metadata needs merging not replacing
Expand All @@ -464,7 +464,7 @@ def test_sgkit_variant_mask(self, tmp_path, sites):
def test_sgkit_variant_bad_mask_length(self, tmp_path):
ts, zarr_path = make_ts_and_zarr(tmp_path)
ds = sgkit.load_dataset(zarr_path)
sites_mask = np.ones(ds.sizes["variants"] + 1, dtype=int)
sites_mask = np.zeros(ds.sizes["variants"] + 1, dtype=int)
add_array_to_dataset("variant_mask", sites_mask, zarr_path)
with pytest.raises(
ValueError,
Expand All @@ -475,7 +475,7 @@ def test_sgkit_variant_bad_mask_length(self, tmp_path):
def test_bad_mask_length_at_iterator(self, tmp_path):
ts, zarr_path = make_ts_and_zarr(tmp_path)
ds = sgkit.load_dataset(zarr_path)
sites_mask = np.ones(ds.sizes["variants"] + 1, dtype=int)
sites_mask = np.zeros(ds.sizes["variants"] + 1, dtype=int)
from tsinfer.formats import chunk_iterator

with pytest.raises(
Expand All @@ -488,33 +488,33 @@ def test_bad_mask_length_at_iterator(self, tmp_path):
def test_sgkit_sample_mask(self, tmp_path, sample_list):
ts, zarr_path = make_ts_and_zarr(tmp_path, add_optional=True)
ds = sgkit.load_dataset(zarr_path)
samples_mask = np.zeros_like(ds["sample_id"], dtype=bool)
samples_mask = np.ones_like(ds["sample_id"], dtype=bool)
for i in sample_list:
samples_mask[i] = True
samples_mask[i] = False
add_array_to_dataset("samples_mask", samples_mask, zarr_path)
samples = tsinfer.SgkitSampleData(zarr_path)
assert samples.ploidy == 3
assert samples.num_individuals == len(sample_list)
assert samples.num_samples == len(sample_list) * samples.ploidy
assert np.array_equal(samples.individuals_mask, samples_mask)
assert np.array_equal(samples.samples_mask, np.repeat(samples_mask, 3))
assert np.array_equal(samples.individuals_mask, ~samples_mask)
assert np.array_equal(samples.samples_mask, np.repeat(~samples_mask, 3))
assert np.array_equal(
samples.individuals_time, ds.individuals_time.values[samples_mask]
samples.individuals_time, ds.individuals_time.values[~samples_mask]
)
assert np.array_equal(
samples.individuals_location, ds.individuals_location.values[samples_mask]
samples.individuals_location, ds.individuals_location.values[~samples_mask]
)
assert np.array_equal(
samples.individuals_population,
ds.individuals_population.values[samples_mask],
ds.individuals_population.values[~samples_mask],
)
assert np.array_equal(
samples.individuals_flags, ds.individuals_flags.values[samples_mask]
samples.individuals_flags, ds.individuals_flags.values[~samples_mask]
)
assert np.array_equal(
samples.samples_individual, np.repeat(np.arange(len(sample_list)), 3)
)
expected_gt = ds.call_genotype.values[:, samples_mask, :].reshape(
expected_gt = ds.call_genotype.values[:, ~samples_mask, :].reshape(
samples.num_sites, len(sample_list) * 3
)
assert np.array_equal(samples.sites_genotypes, expected_gt)
Expand Down
7 changes: 4 additions & 3 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2338,7 +2338,8 @@ def num_sites(self):
@functools.cached_property
def individuals_mask(self):
try:
return self.data["samples_mask"][:].astype(bool)
# We negate the mask as it is much easier in numpy to have True=keep
return ~(self.data["samples_mask"][:].astype(bool))
except KeyError:
return np.full(self._num_unmasked_individuals, True, dtype=bool)

Expand Down Expand Up @@ -2393,8 +2394,8 @@ def sites_mask(self):
raise ValueError(
"Mask must be the same length as the number of unmasked sites"
)

return self.data["variant_mask"].astype(bool)
# We negate the mask as it is much easier in numpy to have True=keep
return ~(self.data["variant_mask"].astype(bool)[:])
except KeyError:
return np.full(self.data["variant_position"].shape, True, dtype=bool)

Expand Down

0 comments on commit 444e4fa

Please sign in to comment.