diff --git a/tests/test_sgkit.py b/tests/test_sgkit.py index fedd4423..e88f0f2f 100644 --- a/tests/test_sgkit.py +++ b/tests/test_sgkit.py @@ -25,6 +25,7 @@ import sys import tempfile +import msprime import numcodecs import numpy as np import pytest @@ -37,6 +38,43 @@ from tsinfer import formats +def ts_to_dataset(ts, chunks=None, samples=None): + """ + # From https://github.com/sgkit-dev/sgkit/blob/main/sgkit/tests/test_popgen.py#L63 + Convert the specified tskit tree sequence into an sgkit dataset. + Note this just generates haploids for now - see the note above + in simulate_ts. + """ + if samples is None: + samples = ts.samples() + tables = ts.dump_tables() + alleles = [] + genotypes = [] + max_alleles = 0 + for var in ts.variants(samples=samples): + alleles.append(var.alleles) + max_alleles = max(max_alleles, len(var.alleles)) + genotypes.append(var.genotypes.astype(np.int8)) + padded_alleles = [ + list(site_alleles) + [""] * (max_alleles - len(site_alleles)) + for site_alleles in alleles + ] + alleles = np.array(padded_alleles).astype("S") + genotypes = np.expand_dims(genotypes, axis=2) + + ds = sgkit.create_genotype_call_dataset( + variant_contig_names=["1"], + variant_contig=np.zeros(len(tables.sites), dtype=int), + variant_position=tables.sites.position.astype(int), + variant_allele=alleles, + sample_id=np.array([f"tsk_{u}" for u in samples]).astype("U"), + call_genotype=genotypes, + ) + if chunks is not None: + ds = ds.chunk(dict(zip(["variants", "samples"], chunks))) + return ds + + @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows") def test_sgkit_dataset_roundtrip(tmp_path): ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) @@ -209,6 +247,20 @@ def test_sgkit_accessors_defaults(tmp_path): ) +def test_simulate_genotype_call_dataset(tmp_path): + # Test that byte alleles are correctly converted to string + ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123) + ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123) + ds = ts_to_dataset(ts) + ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]}) + ds.to_zarr(tmp_path, mode="w") + sd = tsinfer.SgkitSampleData(tmp_path) + ts = tsinfer.infer(sd) + for v, ds_v, sd_v in zip(ts.variants(), ds.call_genotype, sd.sites_genotypes): + assert np.all(v.genotypes == ds_v.values.flatten()) + assert np.all(v.genotypes == sd_v) + + @pytest.mark.skipif(sys.platform == "win32", reason="File permission errors on Windows") class TestSgkitMask: @pytest.mark.parametrize("sites", [[1, 2, 3, 5, 9, 27], [0], []]) diff --git a/tsinfer/formats.py b/tsinfer/formats.py index 9ba12322..102c5080 100644 --- a/tsinfer/formats.py +++ b/tsinfer/formats.py @@ -2440,13 +2440,15 @@ def sites_position(self): @functools.cached_property def sites_alleles(self): - return self.data["variant_allele"][:][self.sites_select] + return self.data["variant_allele"][:][self.sites_select].astype(str) @functools.cached_property def sites_ancestral_allele(self): unknown_alleles = collections.Counter() try: - string_allele = self.data["variant_ancestral_allele"][:][self.sites_select] + string_allele = ( + self.data["variant_ancestral_allele"][:][self.sites_select] + ).astype(str) except KeyError: raise ValueError( "variant_ancestral_allele was not found in the dataset."