diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f5ddbed..72263c75 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ - Document that the `zarr-vcf` dataset can be either a path or an in-memory zarr group. (feature introduced in {pr}`966`, documented in {pr}`974`, {user}`hyanwong`) +- Allow a contig to be selected by name (`contig_id`), and get the `sequence_length` + of the contig associated with the unmasked sites, if contig lengths are provided + ({pr}`964`, {user}`hyanwong`) + **Fixes** - Properly account for "N" as an unknown ancestral state, and ban "" from being diff --git a/docs/usage.md b/docs/usage.md index ed2c044d..2aa2b4a7 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -107,10 +107,11 @@ onto branches by {meth}`parsimony`. It is also possible to *completely* exclude sites and samples, by specifying a boolean `site_mask` and/or a `sample_mask` when creating the `VariantData` object. Sites or samples with a mask value of `True` will be completely omitted both from inference and the final tree sequence. -This can be useful, for example, if your VCF file contains multiple chromosomes (in which case -`tsinfer` will need to be run separately on each chromosome) or if you wish to select only a subset -of the chromosome for inference (e.g. to reduce computational load). If a `site_mask` is provided, -note that the ancestral alleles array only specifies alleles for the unmasked sites. +This can be useful, for example, if you wish to select only a subset of the chromosome for +inference, e.g. to reduce computational load. You can also use it to subset inference to a +particular contig, if your dataset contains multiple contigs (although this can be more easily +done using the `contig_id` parameter). Note that if a `site_mask` is provided, +the ancestral states array should only specify alleles for the unmasked sites. 
Below, for instance, is an example of including only sites up to position six in the contig labelled "chr1" in the `example_data.vcz` file: diff --git a/tests/test_variantdata.py b/tests/test_variantdata.py index 550ae1a5..8667aa0e 100644 --- a/tests/test_variantdata.py +++ b/tests/test_variantdata.py @@ -38,7 +38,7 @@ from tsinfer import formats -def ts_to_dataset(ts, chunks=None, samples=None): +def ts_to_dataset(ts, chunks=None, samples=None, contigs=None): """ # From https://github.com/sgkit-dev/sgkit/blob/main/sgkit/tests/test_popgen.py#L63 Convert the specified tskit tree sequence into an sgkit dataset. @@ -63,7 +63,7 @@ def ts_to_dataset(ts, chunks=None, samples=None): genotypes = np.expand_dims(genotypes, axis=2) ds = sgkit.create_genotype_call_dataset( - variant_contig_names=["1"], + variant_contig_names=["1"] if contigs is None else contigs, variant_contig=np.zeros(len(tables.sites), dtype=int), variant_position=tables.sites.position.astype(int), variant_allele=alleles, @@ -289,9 +289,83 @@ def test_simulate_genotype_call_dataset(tmp_path): assert np.all(v.genotypes == sd_v) +class TestMultiContig: + def make_two_ts_dataset(self, path): + # split ts into 2; put them as different contigs in the same dataset + ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123) + ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123) + split_at_site = 7 + assert ts.num_sites > 10 + site_break = ts.site(split_at_site).position + ts1 = ts.keep_intervals([(0, site_break)]).rtrim() + ts2 = ts.keep_intervals([(site_break, ts.sequence_length)]).ltrim() + ds = ts_to_dataset(ts, contigs=["chr1", "chr2"]) + ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]}) + variant_contig = ds["variant_contig"][:] + variant_contig[split_at_site:] = 1 + ds.update({"variant_contig": variant_contig}) + variant_position = ds["variant_position"].values + variant_position[split_at_site:] -= int(site_break) + ds.update({"variant_position": ds["variant_position"]}) + 
ds.update( + {"contig_length": np.array([ts1.sequence_length, ts2.sequence_length])} + ) + ds.to_zarr(path, mode="w") + return ts1, ts2 + + def test_unmasked(self, tmp_path): + self.make_two_ts_dataset(tmp_path) + with pytest.raises(ValueError, match=r'multiple contigs \("chr1", "chr2"\)'): + tsinfer.VariantData(tmp_path, "variant_ancestral_allele") + + def test_mask(self, tmp_path): + ts1, ts2 = self.make_two_ts_dataset(tmp_path) + vdata = tsinfer.VariantData( + tmp_path, + "variant_ancestral_allele", + site_mask=np.array(ts1.num_sites * [True] + ts2.num_sites * [False]), + ) + assert np.all(ts2.sites_position == vdata.sites_position) + assert vdata.contig_id == "chr2" + assert vdata.sequence_length == ts2.sequence_length + + @pytest.mark.parametrize("contig_id", ["chr1", "chr2"]) + def test_contig_id_param(self, contig_id, tmp_path): + tree_seqs = {} + tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path) + vdata = tsinfer.VariantData( + tmp_path, "variant_ancestral_allele", contig_id=contig_id + ) + assert np.all(tree_seqs[contig_id].sites_position == vdata.sites_position) + assert vdata.contig_id == contig_id + assert vdata.sequence_length == tree_seqs[contig_id].sequence_length + + def test_contig_id_param_and_mask(self, tmp_path): + ts1, ts2 = self.make_two_ts_dataset(tmp_path) + vdata = tsinfer.VariantData( + tmp_path, + "variant_ancestral_allele", + site_mask=np.array( + (ts1.num_sites + 1) * [True] + (ts2.num_sites - 1) * [False] + ), + contig_id="chr2", + ) + assert np.all(ts2.sites_position[1:] == vdata.sites_position) + assert vdata.contig_id == "chr2" + + @pytest.mark.parametrize("contig_id", ["chr1", "chr2"]) + def test_contig_length(self, contig_id, tmp_path): + tree_seqs = {} + tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path) + vdata = tsinfer.VariantData( + tmp_path, "variant_ancestral_allele", contig_id=contig_id + ) + assert vdata.sequence_length == tree_seqs[contig_id].sequence_length + + 
@pytest.mark.skipif(sys.platform == "win32", reason="File permission errors on Windows") class TestSgkitMask: - @pytest.mark.parametrize("sites", [[1, 2, 3, 5, 9, 27], [0], []]) + @pytest.mark.parametrize("sites", [[1, 2, 3, 5, 9, 27], [0]]) def test_sgkit_variant_mask(self, tmp_path, sites): ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) ds = sgkit.load_dataset(zarr_path) @@ -831,3 +905,51 @@ def test_unimplemented_from_tree_sequence(self): # Requires e.g. https://github.com/tskit-dev/tsinfer/issues/924 with pytest.raises(NotImplementedError): tsinfer.VariantData.from_tree_sequence(None) + + def test_multiple_contigs(self, tmp_path): + path = tmp_path / "data.zarr" + ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True) + ds["contig_id"] = ( + ds["contig_id"].dims, + np.array(["c10", "c11"], dtype=" 1: + raise ValueError(f'Multiple contigs named "{contig_id}"') + contig_index = contig_index[0] + site_mask = np.logical_or( + site_mask, self.data["variant_contig"][:] != contig_index + ) + # We negate the mask as it is much easier in numpy to have True=keep self.sites_select = ~site_mask.astype(bool) + if np.sum(self.sites_select) == 0: + raise ValueError("All sites have been masked out. Please unmask some") if sample_mask is None: sample_mask = np.full(self._num_individuals_before_mask, False, dtype=bool) @@ -2413,6 +2436,20 @@ def __init__( " zarr dataset, indicating that all the genotypes are" " unphased" ) + + used_contigs = self.data.variant_contig[:][self.sites_select] + self._contig_index = used_contigs[0] + self._contig_id = self.data.contig_id[self._contig_index] + + if np.any(used_contigs != self._contig_index): + contig_names = ", ".join( + f'"{self.data.contig_id[c]}"' for c in np.unique(used_contigs) + ) + raise ValueError( + f"Sites belong to multiple contigs ({contig_names}). Please restrict " + "sites to one contig e.g. via the `contig_id` argument." 
+ ) + if np.any(np.diff(self.sites_position) <= 0): raise ValueError( "Values taken from the variant_position array are not strictly " @@ -2517,10 +2554,30 @@ def finalised(self): @functools.cached_property def sequence_length(self): + """ + The sequence length of the contig associated with sites used in the dataset. + If the dataset has a "sequence_length" attribute, this is always used, otherwise + if the dataset has recorded contig lengths, the appropriate length is taken, + otherwise the length is calculated from the maximum variant position plus one. + """ try: return self.data.attrs["sequence_length"] except KeyError: - return int(np.max(self.data["variant_position"])) + 1 + if self._contig_index is not None: + try: + if self._contig_index < len(self.data.contig_length): + return self.data.contig_length[self._contig_index] + except AttributeError: + pass # contig_length is optional, fall back to calculating length + return int(np.max(self.data["variant_position"])) + 1 + + @property + def contig_id(self): + """ + The contig ID (name) for all used sites, or None if no + contig IDs were provided + """ + return self._contig_id @property def num_sites(self):