diff --git a/_tsinfermodule.c b/_tsinfermodule.c index 0e6b62f7..25ae80dd 100644 --- a/_tsinfermodule.c +++ b/_tsinfermodule.c @@ -20,6 +20,8 @@ static PyObject *TsinfLibraryError; static PyObject *TsinfMatchImpossible; +#include "tskit_lwt_interface.h" + typedef struct { PyObject_HEAD ancestor_builder_t *builder; @@ -36,6 +38,17 @@ typedef struct { TreeSequenceBuilder *tree_sequence_builder; } AncestorMatcher; +typedef struct { + PyObject_HEAD + matcher_indexes_t *matcher_indexes; +} MatcherIndexes; + +typedef struct { + PyObject_HEAD + ancestor_matcher2_t *ancestor_matcher; + MatcherIndexes *matcher_indexes; +} AncestorMatcher2; + static void handle_library_error(int err) { @@ -49,6 +62,33 @@ handle_library_error(int err) } } +static FILE * +make_file(PyObject *fileobj, const char *mode) +{ + FILE *ret = NULL; + FILE *file = NULL; + int fileobj_fd, new_fd; + + fileobj_fd = PyObject_AsFileDescriptor(fileobj); + if (fileobj_fd == -1) { + goto out; + } + new_fd = dup(fileobj_fd); + if (new_fd == -1) { + PyErr_SetFromErrno(PyExc_OSError); + goto out; + } + file = fdopen(new_fd, mode); + if (file == NULL) { + (void) close(new_fd); + PyErr_SetFromErrno(PyExc_OSError); + goto out; + } + ret = file; +out: + return ret; +} + static int uint64_PyArray_converter(PyObject *in, PyObject **out) { @@ -1579,6 +1619,407 @@ static PyTypeObject AncestorMatcherType = { (initproc)AncestorMatcher_init, /* tp_init */ }; + +/*=================================================================== + * MatcherIndexes + *=================================================================== + */ + +static void +MatcherIndexes_dealloc(MatcherIndexes* self) +{ + if (self->matcher_indexes != NULL) { + /* matcher_indexes_free(self->matcher_indexes); */ + PyMem_Free(self->matcher_indexes); + self->matcher_indexes = NULL; + } + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +MatcherIndexes_init(MatcherIndexes *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + LightweightTableCollection *tables; + static char *kwlist[] = {"tables", NULL}; + + self->matcher_indexes = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!", kwlist, + &LightweightTableCollectionType, &tables)) { + goto out; + } + if (LightweightTableCollection_check_state(tables) != 0) { + goto out; + } + + self->matcher_indexes = PyMem_Calloc(1, sizeof(*self->matcher_indexes)); + if (self->matcher_indexes == NULL) { + PyErr_NoMemory(); + goto out; + } + err = matcher_indexes_alloc(self->matcher_indexes, tables->tables, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + return ret; +} + +static int +MatcherIndexes_check_state(MatcherIndexes *self) +{ + int ret = 0; + if (self->matcher_indexes == NULL) { + PyErr_SetString(PyExc_SystemError, "MatcherIndexes not initialised"); + ret = -1; + } + return ret; +} + + +static PyObject * +MatcherIndexes_print_state(MatcherIndexes *self, PyObject *args) +{ + PyObject *ret = NULL; + PyObject *fileobj; + FILE *file = NULL; + + if (MatcherIndexes_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "O", &fileobj)) { + goto out; + } + file = make_file(fileobj, "w"); + if (file == NULL) { + goto out; + } + matcher_indexes_print_state(self->matcher_indexes, file); + ret = Py_BuildValue(""); +out: + if (file != NULL) { + (void) fclose(file); + } + return ret; +} + + +static PyMethodDef MatcherIndexes_methods[] = { + {"print_state", (PyCFunction) MatcherIndexes_print_state, + METH_VARARGS, "Low-level debug method"}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject MatcherIndexesType = { + // clang-format off + PyVarObject_HEAD_INIT(NULL, 0) + .tp_name = "_tsinfer.MatcherIndexes", + .tp_basicsize = sizeof(MatcherIndexes), + .tp_dealloc = (destructor) MatcherIndexes_dealloc, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, + .tp_doc = "MatcherIndexes objects", + .tp_methods = MatcherIndexes_methods, + .tp_init = (initproc) MatcherIndexes_init, + .tp_new = PyType_GenericNew, + // clang-format on +}; + + +/*=================================================================== + * AncestorMatcher2 + *=================================================================== + */ + +static int +AncestorMatcher2_check_state(AncestorMatcher2 *self) +{ + int ret = 0; + if (self->ancestor_matcher == NULL) { + PyErr_SetString(PyExc_SystemError, "AncestorMatcher2 not initialised"); + ret = -1; + } + return ret; +} + +static void +AncestorMatcher2_dealloc(AncestorMatcher2* self) +{ + if (self->ancestor_matcher != NULL) { + ancestor_matcher2_free(self->ancestor_matcher); + PyMem_Free(self->ancestor_matcher); + self->ancestor_matcher = NULL; + } + Py_XDECREF(self->matcher_indexes); + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +AncestorMatcher2_init(AncestorMatcher2 *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + int extended_checks = 0; + static char *kwlist[] = {"matcher_indexes", "recombination", + "mismatch", "precision", "extended_checks", NULL}; + MatcherIndexes *matcher_indexes = NULL; + PyObject *recombination = NULL; + PyObject *mismatch = NULL; + PyArrayObject *recombination_array = NULL; + PyArrayObject *mismatch_array = NULL; + npy_intp *shape; + unsigned int precision = 22; + int flags = 0; + + self->ancestor_matcher = NULL; + self->matcher_indexes = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!OO|Ii", kwlist, + &MatcherIndexesType, &matcher_indexes, + &recombination, &mismatch, &precision, + &extended_checks)) { + goto out; + } + self->matcher_indexes = matcher_indexes; + Py_INCREF(self->matcher_indexes); + if (MatcherIndexes_check_state(self->matcher_indexes) != 0) { + goto out; + } + + recombination_array = (PyArrayObject *) PyArray_FromAny(recombination, + PyArray_DescrFromType(NPY_FLOAT64), 1, 1, + NPY_ARRAY_IN_ARRAY, NULL); + if (recombination_array == NULL) { + goto out; + } + shape = PyArray_DIMS(recombination_array); + if (shape[0] != (npy_intp) matcher_indexes->matcher_indexes->num_sites) { + PyErr_SetString(PyExc_ValueError, + "Size of recombination array must be num_sites"); + goto out; + } + mismatch_array = (PyArrayObject *) PyArray_FromAny(mismatch, + PyArray_DescrFromType(NPY_FLOAT64), 1, 1, + NPY_ARRAY_IN_ARRAY, NULL); + if (mismatch_array == NULL) { + goto out; + } + shape = PyArray_DIMS(mismatch_array); + if (shape[0] != (npy_intp) matcher_indexes->matcher_indexes->num_sites) { + PyErr_SetString(PyExc_ValueError, "Size of mismatch array must be num_sites"); + goto out; + } + + self->ancestor_matcher = PyMem_Malloc(sizeof(ancestor_matcher2_t)); + if (self->ancestor_matcher == NULL) { + PyErr_NoMemory(); + goto out; + } + if (extended_checks) { + flags = TSI_EXTENDED_CHECKS; + } + err = ancestor_matcher2_alloc(self->ancestor_matcher, + self->matcher_indexes->matcher_indexes, + PyArray_DATA(recombination_array), + PyArray_DATA(mismatch_array), + precision, flags); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + Py_XDECREF(recombination_array); + Py_XDECREF(mismatch_array); + return ret; +} + +static PyObject * +AncestorMatcher2_find_path(AncestorMatcher2 *self, PyObject *args, PyObject *kwds) +{ + int err; + PyObject *ret = NULL; + static char *kwlist[] = {"haplotype", "start", "end", NULL}; + PyObject *haplotype = NULL; + PyArrayObject *haplotype_array = NULL; + npy_intp *shape; + size_t num_edges; + int start, end; + PyArrayObject *left = NULL; + PyArrayObject *right = NULL; + PyArrayObject *parent = NULL; + PyArrayObject *match = NULL; + npy_intp dims[1]; + + if (AncestorMatcher2_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "Oii", kwlist, + &haplotype, &start, &end)) { + goto out; + } + haplotype_array = (PyArrayObject *) PyArray_FROM_OTF(haplotype, NPY_INT8, + NPY_ARRAY_IN_ARRAY); + if (haplotype_array == NULL) { + goto out; + } + if (PyArray_NDIM(haplotype_array) != 1) { + PyErr_SetString(PyExc_ValueError, "Dim != 1"); + goto out; + } + shape = PyArray_DIMS(haplotype_array); + if (shape[0] != (npy_intp) self->ancestor_matcher->num_sites) { + PyErr_SetString(PyExc_ValueError, "Incorrect size for input haplotype."); + goto out; + } + + dims[0] = self->ancestor_matcher->num_sites; + left = (PyArrayObject *) PyArray_SimpleNew(1, dims, NPY_UINT32); + right = (PyArrayObject *) PyArray_SimpleNew(1, dims, NPY_UINT32); + parent = (PyArrayObject *) PyArray_SimpleNew(1, dims, NPY_INT32); + match = (PyArrayObject *) PyArray_SimpleNew(1, dims, NPY_INT8); + if (left == NULL || right == NULL || parent == NULL || match == NULL) { + goto out; + } + + Py_BEGIN_ALLOW_THREADS + err = ancestor_matcher2_find_path(self->ancestor_matcher, + (tsk_id_t) start, (tsk_id_t) end, (allele_t *) PyArray_DATA(haplotype_array), + (allele_t *) PyArray_DATA(match), + &num_edges, PyArray_DATA(left), PyArray_DATA(right), PyArray_DATA(parent)); + Py_END_ALLOW_THREADS + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue("(kOOOO)", (unsigned long) num_edges, left, right, parent, match); + if (ret == NULL) { + goto out; + } + left = NULL; + right = NULL; + parent = NULL; + match = NULL; +out: + Py_XDECREF(match); + Py_XDECREF(left); + Py_XDECREF(right); + Py_XDECREF(parent); + return ret; +} + +static PyObject * +AncestorMatcher2_get_traceback(AncestorMatcher2 *self, PyObject *args) +{ + PyObject *ret = NULL; + unsigned long site; + node_state_list_t *list; + PyObject *dict = NULL; + PyObject *key = NULL; + PyObject *value = NULL; + int j; + + if (AncestorMatcher2_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "k", &site)) { + goto out; + } + if (site >= self->ancestor_matcher->num_sites) { + PyErr_SetString(PyExc_ValueError, "site out of range"); + goto out; + } + dict = PyDict_New(); + if (dict == NULL) { + goto out; + } + list = &self->ancestor_matcher->traceback[site]; + for (j = 0; j < list->size; j++) { + key = Py_BuildValue("k", (unsigned long) list->node[j]); + value = Py_BuildValue("i", (int) list->recombination_required[j]); + if (key == NULL || value == NULL) { + goto out; + } + if (PyDict_SetItem(dict, key, value) != 0) { + goto out; + } + Py_DECREF(key); + key = NULL; + Py_DECREF(value); + value = NULL; + } + ret = dict; + dict = NULL; +out: + Py_XDECREF(key); + Py_XDECREF(value); + Py_XDECREF(dict); + return ret; +} + +static PyObject * +AncestorMatcher2_get_mean_traceback_size(AncestorMatcher2 *self, void *closure) +{ + PyObject *ret = NULL; + + if (AncestorMatcher2_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("d", ancestor_matcher2_get_mean_traceback_size( + self->ancestor_matcher)); +out: + return ret; +} + +static PyObject * +AncestorMatcher2_get_total_memory(AncestorMatcher2 *self, void *closure) +{ + PyObject *ret = NULL; + + if (AncestorMatcher2_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("k", (unsigned long) + ancestor_matcher2_get_total_memory(self->ancestor_matcher)); +out: + return ret; +} + + +static PyGetSetDef AncestorMatcher2_getsetters[] = { + {"mean_traceback_size", (getter) AncestorMatcher2_get_mean_traceback_size, + NULL, "The mean size of the traceback per site."}, + {"total_memory", (getter) AncestorMatcher2_get_total_memory, + NULL, "The total amount of memory used by this matcher."}, + {NULL} /* Sentinel */ +}; + +static PyMethodDef AncestorMatcher2_methods[] = { + {"find_path", (PyCFunction) AncestorMatcher2_find_path, + METH_VARARGS|METH_KEYWORDS, + "Returns a best match path for the specified haplotype through the ancestors."}, + {"get_traceback", (PyCFunction) AncestorMatcher2_get_traceback, + METH_VARARGS, "Returns the traceback likelihood dictionary at the specified site."}, + {NULL} /* Sentinel */ +}; + + +static PyTypeObject AncestorMatcher2Type = { + // clang-format off + PyVarObject_HEAD_INIT(NULL, 0) + .tp_name = "_tsinfer.AncestorMatcher2", + .tp_basicsize = sizeof(AncestorMatcher2), + .tp_dealloc = (destructor) AncestorMatcher2_dealloc, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, + .tp_doc = "AncestorMatcher2 objects", + .tp_methods = AncestorMatcher2_methods, + .tp_getset = AncestorMatcher2_getsetters, + .tp_init = (initproc) AncestorMatcher2_init, + .tp_new = PyType_GenericNew, + // clang-format on +}; + /*=================================================================== * Module level code. *=================================================================== @@ -1628,6 +2069,8 @@ init_tsinfer(void) /* Initialise numpy */ import_array(); + register_lwt_class(module); + /* AncestorBuilder type */ AncestorBuilderType.tp_new = PyType_GenericNew; if (PyType_Ready(&AncestorBuilderType) < 0) { @@ -1635,6 +2078,7 @@ init_tsinfer(void) } Py_INCREF(&AncestorBuilderType); PyModule_AddObject(module, "AncestorBuilder", (PyObject *) &AncestorBuilderType); + /* AncestorMatcher type */ AncestorMatcherType.tp_new = PyType_GenericNew; if (PyType_Ready(&AncestorMatcherType) < 0) { @@ -1642,6 +2086,7 @@ init_tsinfer(void) } Py_INCREF(&AncestorMatcherType); PyModule_AddObject(module, "AncestorMatcher", (PyObject *) &AncestorMatcherType); + /* TreeSequenceBuilder type */ TreeSequenceBuilderType.tp_new = PyType_GenericNew; if (PyType_Ready(&TreeSequenceBuilderType) < 0) { @@ -1650,6 +2095,22 @@ init_tsinfer(void) Py_INCREF(&TreeSequenceBuilderType); PyModule_AddObject(module, "TreeSequenceBuilder", (PyObject *) &TreeSequenceBuilderType); + /* MatcherIndexes type */ + MatcherIndexesType.tp_new = PyType_GenericNew; + if (PyType_Ready(&MatcherIndexesType) < 0) { + INITERROR; + } + Py_INCREF(&MatcherIndexesType); + PyModule_AddObject(module, "MatcherIndexes", (PyObject *) &MatcherIndexesType); + + /* AncestorMatcher2 type */ + AncestorMatcher2Type.tp_new = PyType_GenericNew; + if (PyType_Ready(&AncestorMatcher2Type) < 0) { + INITERROR; + } + Py_INCREF(&AncestorMatcher2Type); + PyModule_AddObject(module, "AncestorMatcher2", (PyObject *) &AncestorMatcher2Type); + TsinfLibraryError = PyErr_NewException("_tsinfer.LibraryError", NULL, NULL); Py_INCREF(TsinfLibraryError); PyModule_AddObject(module, "LibraryError", TsinfLibraryError); diff --git a/lib/ancestor_matcher.c b/lib/ancestor_matcher.c index ee992264..338c3401 100644 --- a/lib/ancestor_matcher.c +++ b/lib/ancestor_matcher.c @@ -1,5 +1,5 @@ /* -** Copyright (C) 2018-2020 University of Oxford +** Copyright (C) 2018-2023 University of Oxford ** ** This file is part of tsinfer. ** @@ -742,6 +742,7 @@ insert_edge(edge_t edge, tsk_id_t *restrict parent, tsk_id_t *restrict left_chil { const tsk_id_t p = edge.parent; const tsk_id_t c = edge.child; + assert(right_child != NULL); const tsk_id_t u = right_child[p]; parent[c] = p; @@ -1011,3 +1012,1096 @@ ancestor_matcher_get_total_memory(ancestor_matcher_t *self) return total; } + +/* New implementation */ + +int +matcher_indexes_print_state(const matcher_indexes_t *self, FILE *out) +{ + size_t j; + mutation_list_node_t *u; + + fprintf(out, "Matcher indexes state\n"); + fprintf(out, "flags = %d\n", (int) self->flags); + fprintf(out, "num_sites = %d\n", (int) self->num_sites); + fprintf(out, "num_nodes = %d\n", (int) self->num_nodes); + fprintf(out, "num_edges = %d\n", (int) self->num_edges); + + fprintf(out, "Mutations = \n"); + fprintf(out, "site\t(node, derived_state),...\n"); + for (j = 0; j < self->num_sites; j++) { + if (self->sites.mutations[j] != NULL) { + fprintf(out, "%d\t", (int) j); + for (u = self->sites.mutations[j]; u != NULL; u = u->next) { + fprintf(out, "(%d, %d) ", u->node, u->derived_state); + } + fprintf(out, "\n"); + } + } + return 0; +} + +static int +matcher_indexes_copy_edge_indexes( + matcher_indexes_t *self, const tsk_table_collection_t *tables) +{ + int ret = 0; + tsk_size_t j; + tsk_id_t k; + edge_t e; + const tsk_id_t *restrict I = tables->indexes.edge_insertion_order; + const tsk_id_t *restrict O = tables->indexes.edge_removal_order; + const double *restrict edges_right = tables->edges.right; + const double *restrict edges_left = tables->edges.left; + const tsk_id_t *restrict edges_child = tables->edges.child; + const tsk_id_t *restrict edges_parent = tables->edges.parent; + + for (j = 0; j < self->num_edges; j++) { + k = I[j]; + /* TODO check that the edges can be cast */ + e.left = (tsk_id_t) edges_left[k]; + e.right = (tsk_id_t) edges_right[k]; + e.parent = edges_parent[k]; + e.child = edges_child[k]; + self->left_index_edges[j] = e; + + k = O[j]; + e.left = (tsk_id_t) edges_left[k]; + e.right = (tsk_id_t) edges_right[k]; + e.parent = edges_parent[k]; + e.child = edges_child[k]; + self->right_index_edges[j] = e; + } + return ret; +} + +static int WARN_UNUSED +matcher_indexes_add_mutation( + matcher_indexes_t *self, tsk_id_t site, tsk_id_t node, allele_t derived_state) +{ + int ret = 0; + mutation_list_node_t *list_node, *tail; + + list_node = tsk_blkalloc_get(&self->allocator, sizeof(mutation_list_node_t)); + if (list_node == NULL) { + ret = TSI_ERR_NO_MEMORY; + goto out; + } + list_node->node = node; + list_node->derived_state = derived_state; + list_node->next = NULL; + if (self->sites.mutations[site] == NULL) { + self->sites.mutations[site] = list_node; + } else { + tail = self->sites.mutations[site]; + while (tail->next != NULL) { + tail = tail->next; + } + tail->next = list_node; + } + self->num_mutations++; +out: + return ret; +} + +static int +matcher_indexes_copy_mutation_data( + matcher_indexes_t *self, const tsk_table_collection_t *tables) +{ + int ret = 0; + tsk_size_t j; + tsk_id_t site, last_site; + const double *restrict sites_position = tables->sites.position; + const tsk_id_t *restrict mutations_site = tables->mutations.site; + const tsk_id_t *restrict mutations_node = tables->mutations.node; + const tsk_size_t total_mutations = tables->mutations.num_rows; + coordinate_t *restrict converted_position = self->sites.position; + + for (j = 0; j < self->num_sites; j++) { + /* TODO check for under/overflow */ + converted_position[j] = (coordinate_t) sites_position[j]; + } + converted_position[j] = (coordinate_t) tables->sequence_length; + + last_site = -1; + for (j = 0; j < total_mutations; j++) { + site = mutations_site[j]; + if (site == last_site) { + ret = TSI_ERR_GENERIC; + goto out; + } + + self->sites.mutations[site] = NULL; + /* FIXME */ + self->sites.num_alleles[site] = 2; + ret = matcher_indexes_add_mutation(self, site, mutations_node[j], 1); + if (ret != 0) { + goto out; + } + last_site = site; + } + +out: + return ret; +} + +int +matcher_indexes_alloc( + matcher_indexes_t *self, const tsk_table_collection_t *tables, tsk_flags_t flags) +{ + int ret = 0; + + self->flags = flags; + self->num_edges = tables->edges.num_rows; + self->num_nodes = tables->nodes.num_rows; + self->num_sites = tables->sites.num_rows; + /* FIXME this is used below by the code that adds in mutations in the linked + * list, so *don't* set from the tables */ + self->num_mutations = 0; + + self->left_index_edges = malloc(self->num_edges * sizeof(*self->left_index_edges)); + self->right_index_edges = malloc(self->num_edges * sizeof(*self->right_index_edges)); + self->sites.mutations = malloc(self->num_sites * sizeof(*self->sites.mutations)); + self->sites.num_alleles = malloc(self->num_sites * sizeof(*self->sites.num_alleles)); + self->sites.position + = malloc((self->num_sites + 1) * sizeof(*self->sites.mutations)); + if (self->left_index_edges == NULL || self->right_index_edges == NULL + || self->sites.mutations == NULL || self->sites.position == NULL + || self->sites.num_alleles == NULL) { + ret = TSI_ERR_NO_MEMORY; + goto out; + } + ret = tsk_blkalloc_init(&self->allocator, 65536); + if (ret != 0) { + goto out; + } + + ret = matcher_indexes_copy_edge_indexes(self, tables); + if (ret != 0) { + goto out; + } + ret = matcher_indexes_copy_mutation_data(self, tables); + if (ret != 0) { + goto out; + } +out: + return ret; +} + +int +matcher_indexes_free(matcher_indexes_t *self) +{ + tsk_safe_free(self->left_index_edges); + tsk_safe_free(self->right_index_edges); + tsk_safe_free(self->sites.mutations); + tsk_safe_free(self->sites.position); + tsk_safe_free(self->sites.num_alleles); + tsk_blkalloc_free(&self->allocator); + return 0; +} + +static void +ancestor_matcher2_check_state(ancestor_matcher2_t *self) +{ + int num_likelihoods; + int j; + tsk_id_t u; + + /* Check the properties of the likelihood map */ + for (j = 0; j < self->num_likelihood_nodes; j++) { + u = self->likelihood_nodes[j]; + assert(self->likelihood[u] >= 0 && self->likelihood[u] <= 2); + } + /* Make sure that there are no other non null likelihoods in the array */ + num_likelihoods = 0; + for (u = 0; u < (tsk_id_t) self->num_nodes; u++) { + if (self->likelihood[u] >= 0) { + num_likelihoods++; + } + if (is_nonzero_root(u, self->parent, self->left_child)) { + assert(self->likelihood[u] == NONZERO_ROOT_LIKELIHOOD); + } + assert(self->allelic_state[u] == TSK_NULL); + } + assert(num_likelihoods == self->num_likelihood_nodes); +} + +int +ancestor_matcher2_print_state(ancestor_matcher2_t *self, FILE *out) +{ + int j, k; + tsk_id_t u; + + fprintf(out, "Ancestor matcher state\n"); + fprintf(out, "site\trecomb_rate\tmut_rate\n"); + for (j = 0; j < (int) self->num_sites; j++) { + fprintf( + out, "%d\t%f\t%f\n", j, self->recombination_rate[j], self->mismatch_rate[j]); + } + fprintf(out, "tree = \n"); + fprintf(out, "id\tparent\tlchild\trchild\tlsib\trsib\tlikelihood\n"); + for (j = 0; j < (int) self->num_nodes; j++) { + fprintf(out, "%d\t%d\t%d\t%d\t%d\t%d\t%f\n", (int) j, self->parent[j], + self->left_child[j], self->right_child[j], self->left_sib[j], + self->right_sib[j], self->likelihood[j]); + } + fprintf(out, "likelihood nodes\n"); + /* Check the properties of the likelihood map */ + for (j = 0; j < self->num_likelihood_nodes; j++) { + u = self->likelihood_nodes[j]; + fprintf(out, "\t%d -> %f\n", u, self->likelihood[u]); + } + fprintf(out, "traceback\n"); + for (j = 0; j < (int) self->num_sites; j++) { + fprintf(out, "\t%d:%d (%d)\t", (int) j, self->max_likelihood_node[j], + self->traceback[j].size); + for (k = 0; k < self->traceback[j].size; k++) { + fprintf(out, "(%d, %d)", self->traceback[j].node[k], + self->traceback[j].recombination_required[k]); + } + fprintf(out, "\n"); + } + tsk_blkalloc_print_state(&self->traceback_allocator, out); + + /* ancestor_matcher2_check_state(self); */ + return 0; +} + +int +ancestor_matcher2_alloc(ancestor_matcher2_t *self, + const matcher_indexes_t *matcher_indexes, double *recombination_rate, + double *mismatch_rate, unsigned int precision, int flags) +{ + int ret = 0; + + memset(self, 0, sizeof(ancestor_matcher2_t)); + /* All allocs for arrays related to nodes are done in expand_nodes */ + self->flags = flags; + self->precision = precision; + self->matcher_indexes = matcher_indexes; + self->num_sites = matcher_indexes->num_sites; + self->num_nodes = matcher_indexes->num_nodes; + self->recombination_rate + = malloc(self->num_sites * sizeof(*self->recombination_rate)); + self->mismatch_rate = malloc(self->num_sites * sizeof(*self->mismatch_rate)); + self->traceback = calloc(self->num_sites, sizeof(node_state_list_t)); + self->max_likelihood_node = malloc(self->num_sites * sizeof(tsk_id_t)); + + self->parent = malloc(self->num_nodes * sizeof(*self->parent)); + self->left_child = malloc(self->num_nodes * sizeof(*self->left_child)); + self->right_child = malloc(self->num_nodes * sizeof(*self->right_child)); + self->left_sib = malloc(self->num_nodes * sizeof(*self->left_sib)); + self->right_sib = malloc(self->num_nodes * sizeof(*self->right_sib)); + self->likelihood = malloc(self->num_nodes * sizeof(*self->likelihood)); + self->allelic_state = malloc(self->num_nodes * sizeof(*self->allelic_state)); + self->recombination_required + = malloc(self->num_nodes * sizeof(*self->recombination_required)); + self->likelihood_cache = malloc(self->num_nodes * sizeof(*self->likelihood_cache)); + self->likelihood_nodes = malloc(self->num_nodes * sizeof(*self->likelihood_nodes)); + self->likelihood_nodes_tmp + = malloc(self->num_nodes * sizeof(*self->likelihood_nodes_tmp)); + + if (self->recombination_rate == NULL || self->mismatch_rate == NULL + || self->traceback == NULL || self->max_likelihood_node == NULL + || self->parent == NULL || self->left_child == NULL || self->right_child == NULL + || self->left_sib == NULL || self->right_sib == NULL + || self->recombination_required == NULL || self->likelihood == NULL + || self->likelihood_cache == NULL || self->likelihood_nodes == NULL + || self->likelihood_nodes_tmp == NULL || self->allelic_state == NULL) { + ret = TSI_ERR_NO_MEMORY; + goto out; + } + /* Alloc in 64MiB blocks. */ + self->traceback_block_size = 64 * 1024 * 1024; + /* If the traceback allocator is using more than 2GiB of RAM free it, so + * that other threads can use the memory */ + self->traceback_realloc_size = 2L * 1024L * 1024L * 1024L; + ret = tsk_blkalloc_init(&self->traceback_allocator, self->traceback_block_size); + if (ret != 0) { + goto out; + } + memcpy(self->recombination_rate, recombination_rate, + self->num_sites * sizeof(*self->recombination_rate)); + memcpy(self->mismatch_rate, mismatch_rate, + self->num_sites * sizeof(*self->mismatch_rate)); +out: + return ret; +} + +int +ancestor_matcher2_free(ancestor_matcher2_t *self) +{ + tsi_safe_free(self->recombination_rate); + tsi_safe_free(self->mismatch_rate); + tsi_safe_free(self->parent); + tsi_safe_free(self->left_child); + tsi_safe_free(self->right_child); + tsi_safe_free(self->left_sib); + tsi_safe_free(self->right_sib); + tsi_safe_free(self->recombination_required); + tsi_safe_free(self->likelihood); + tsi_safe_free(self->likelihood_cache); + tsi_safe_free(self->likelihood_nodes); + tsi_safe_free(self->likelihood_nodes_tmp); + tsi_safe_free(self->allelic_state); + tsi_safe_free(self->max_likelihood_node); + tsi_safe_free(self->traceback); + tsk_blkalloc_free(&self->traceback_allocator); + return 0; +} + +static int +ancestor_matcher2_delete_likelihood( + ancestor_matcher2_t *self, const tsk_id_t node, double *restrict L) +{ + /* Remove the specified node from the list of nodes */ + int j, k; + tsk_id_t *restrict L_nodes = self->likelihood_nodes; + + k = 0; + for (j = 0; j < self->num_likelihood_nodes; j++) { + L_nodes[k] = L_nodes[j]; + if (L_nodes[j] != node) { + k++; + } + } + assert(self->num_likelihood_nodes == k + 1); + self->num_likelihood_nodes = k; + L[node] = NULL_LIKELIHOOD; + return 0; +} + +/* Store the recombination_required state in the traceback */ +static int WARN_UNUSED +ancestor_matcher2_store_traceback(ancestor_matcher2_t *self, const tsk_id_t site_id) +{ + int ret = 0; + tsk_id_t u; + int j; + int8_t *restrict list_R; + tsk_id_t *restrict list_node; + node_state_list_t *restrict list; + node_state_list_t *restrict T = self->traceback; + const tsk_id_t *restrict nodes = self->likelihood_nodes; + const int8_t *restrict R = self->recombination_required; + const int num_likelihood_nodes = self->num_likelihood_nodes; + bool match; + + /* Check to see if the previous site has the same recombination_required. If so, + * we can reuse the same list. */ + match = false; + if (site_id > 0) { + list = &T[site_id - 1]; + if (list->size == num_likelihood_nodes) { + list_node = list->node; + list_R = list->recombination_required; + match = true; + for (j = 0; j < num_likelihood_nodes; j++) { + if (list_node[j] != nodes[j] || list_R[j] != R[nodes[j]]) { + match = false; + break; + } + } + } + } + + if (match) { + T[site_id].size = T[site_id - 1].size; + T[site_id].node = T[site_id - 1].node; + T[site_id].recombination_required = T[site_id - 1].recombination_required; + } else { + list_node = tsk_blkalloc_get(&self->traceback_allocator, + (size_t) num_likelihood_nodes * sizeof(tsk_id_t)); + list_R = tsk_blkalloc_get( + &self->traceback_allocator, (size_t) num_likelihood_nodes * sizeof(int8_t)); + if (list_node == NULL || list_R == NULL) { + ret = TSI_ERR_NO_MEMORY; + goto out; + } + T[site_id].node = list_node; + T[site_id].recombination_required = list_R; + T[site_id].size = num_likelihood_nodes; + for (j = 0; j < num_likelihood_nodes; j++) { + u = nodes[j]; + list_node[j] = u; + list_R[j] = R[u]; + } + } + self->total_traceback_size += (size_t) num_likelihood_nodes; +out: + return ret; +} + +/* Sets the specified allelic state array to reflect the mutations at the + * specified site. */ +static inline void +ancestor_matcher2_set_allelic_state( + ancestor_matcher2_t *self, const tsk_id_t site, allele_t *restrict allelic_state) +{ + mutation_list_node_t *mutation; + + /* FIXME assuming that 0 is always the ancestral state */ + allelic_state[0] = 0; + + for (mutation = self->matcher_indexes->sites.mutations[site]; mutation != NULL; + mutation = mutation->next) { + allelic_state[mutation->node] = mutation->derived_state; + } +} + +/* Resets the allelic state at this site to NULL. */ +static inline void +ancestor_matcher2_unset_allelic_state( + ancestor_matcher2_t *self, const tsk_id_t site, allele_t *restrict allelic_state) +{ + mutation_list_node_t *mutation; + + allelic_state[0] = NULL_NODE; + for (mutation = self->matcher_indexes->sites.mutations[site]; mutation != NULL; + mutation = mutation->next) { + allelic_state[mutation->node] = TSK_NULL; + } +} + +static int WARN_UNUSED +ancestor_matcher2_update_site_likelihood_values(ancestor_matcher2_t *self, + const tsk_id_t site, const allele_t state, const tsk_id_t *restrict parent, + double *restrict L) +{ + int ret = 0; + const int num_likelihood_nodes = self->num_likelihood_nodes; + const tsk_id_t *restrict L_nodes = self->likelihood_nodes; + allele_t *restrict allelic_state = self->allelic_state; + int8_t *restrict recombination_required = self->recombination_required; + int j; + tsk_id_t u, v, max_L_node; + double max_L, p_last, p_no_recomb, p_recomb, p_t, p_e; + const double rho = self->recombination_rate[site]; + const double mu = self->mismatch_rate[site]; + const double n = (double) self->matcher_indexes->num_nodes; + const double num_alleles = (double) self->matcher_indexes->sites.num_alleles[site]; + + if (state >= num_alleles) { + ret = TSI_ERR_BAD_HAPLOTYPE_ALLELE; + goto out; + } + + ancestor_matcher2_set_allelic_state(self, site, allelic_state); + + max_L = -1; + max_L_node = NULL_NODE; + assert(num_likelihood_nodes > 0); + /* printf("likelihoods for node=%d, n=%d\n", mutation_node, + * self->num_likelihood_nodes); */ + for (j = 0; j < num_likelihood_nodes; j++) { + u = L_nodes[j]; + /* Get the allelic state at u. */ + /* TODO we can cache the states here to save some time. One nice way we could + * do the caching is to save the L_node index in the allelic_state array as + * we traverse upwards, and then keep an array of the L_node states which + * we then look up. This would save a second upward traversal to mark the + * array after we've found the state value. */ + v = u; + while (allelic_state[v] == TSK_NULL) { + v = parent[v]; + } + p_last = L[u]; + p_no_recomb = p_last * (1 - rho + rho / n); + p_recomb = rho / n; + recombination_required[u] = false; + if (p_no_recomb > p_recomb) { + p_t = p_no_recomb; + } else { + p_t = p_recomb; + recombination_required[u] = true; + } + p_e = mu; + if (allelic_state[v] == state || state == TSK_MISSING_DATA) { + p_e = 1 - (num_alleles - 1) * mu; + } + L[u] = p_t * p_e; + + if (L[u] > max_L) { + max_L = L[u]; + max_L_node = u; + } + } + /* ancestor_matcher2_print_state(self, stdout); */ + if (max_L <= 0) { + if (mu <= 0 || mu >= 1) { + ret = TSI_ERR_MATCH_IMPOSSIBLE_EXTREME_MUTATION_PROBA; + goto out; + } + if (rho == 0) { + ret = TSI_ERR_MATCH_IMPOSSIBLE_ZERO_RECOMB_PRECISION; + goto out; + } + ret = TSI_ERR_MATCH_IMPOSSIBLE; + goto out; + } + assert(max_L_node != NULL_NODE); + self->max_likelihood_node[site] = max_L_node; + + /* Renormalise the likelihoods. */ + for (j = 0; j < num_likelihood_nodes; j++) { + u = L_nodes[j]; + L[u] = tsk_round(L[u] / max_L, self->precision); + } + ancestor_matcher2_unset_allelic_state(self, site, allelic_state); +out: + return ret; +} + +static int WARN_UNUSED +ancestor_matcher2_coalesce_likelihoods(ancestor_matcher2_t *self, + const tsk_id_t *restrict parent, double *restrict L, double *restrict L_cache) +{ + int ret = 0; + double L_p; + tsk_id_t u, v, p; + tsk_id_t *restrict cached_paths = self->likelihood_nodes_tmp; + const int old_num_likelihood_nodes = self->num_likelihood_nodes; + tsk_id_t *restrict L_nodes = self->likelihood_nodes; + int j, num_cached_paths, num_likelihood_nodes; + + num_cached_paths = 0; + num_likelihood_nodes = 0; + assert(old_num_likelihood_nodes > 0); + for (j = 0; j < old_num_likelihood_nodes; j++) { + u = L_nodes[j]; + p = parent[u]; + if (p != NULL_NODE) { + cached_paths[num_cached_paths] = p; + num_cached_paths++; + v = p; + while ( + likely(L[v] == NULL_LIKELIHOOD) && likely(L_cache[v] == CACHE_UNSET)) { + v = parent[v]; + } + L_p = L_cache[v]; + if (unlikely(L_p == CACHE_UNSET)) { + L_p = L[v]; + } + /* Fill in the L cache */ + v = p; + while ( + likely(L[v] == NULL_LIKELIHOOD) && likely(L_cache[v] == CACHE_UNSET)) { + L_cache[v] = L_p; + v = parent[v]; + } + /* If the likelihood for the parent is equal to the child we can + * delete the child likelihood */ + if (L[u] == L_p) { + L[u] = NULL_LIKELIHOOD; + } + } + if (L[u] >= 0) { + L_nodes[num_likelihood_nodes] = L_nodes[j]; + num_likelihood_nodes++; + } + } + /* ancestor_matcher2_print_state(self, stdout); */ + assert(num_likelihood_nodes > 0); + + self->num_likelihood_nodes = num_likelihood_nodes; + /* Reset the L cache */ + for (j = 0; j < num_cached_paths; j++) { + v = cached_paths[j]; + while (likely(v != NULL_NODE) && likely(L_cache[v] != CACHE_UNSET)) { + L_cache[v] = CACHE_UNSET; + v = parent[v]; + } + } + + return ret; +} + +static int +ancestor_matcher2_update_site_state(ancestor_matcher2_t *self, const tsk_id_t site, + const allele_t state, tsk_id_t *restrict parent, double *restrict L, + double *restrict L_cache) +{ + int ret = 0; + mutation_list_node_t *mutation = self->matcher_indexes->sites.mutations[site]; + tsk_id_t u; + + assert(self->num_likelihood_nodes > 0); + + if (self->flags & TSI_EXTENDED_CHECKS) { + ancestor_matcher2_check_state(self); + } + for (mutation = self->matcher_indexes->sites.mutations[site]; mutation != NULL; + mutation = mutation->next) { + /* Insert a new L-value for the mutation node if needed */ + if (L[mutation->node] == NULL_LIKELIHOOD) { + u = mutation->node; + while (L[u] == NULL_LIKELIHOOD) { + u = parent[u]; + assert(u != NULL_NODE); + } + L[mutation->node] = L[u]; + self->likelihood_nodes[self->num_likelihood_nodes] = mutation->node; + self->num_likelihood_nodes++; + } + } + ret = ancestor_matcher2_update_site_likelihood_values(self, site, state, parent, L); + if (ret != 0) { + goto out; + } + ret = ancestor_matcher2_store_traceback(self, site); + if (ret != 0) { + goto out; + } + ret = ancestor_matcher2_coalesce_likelihoods(self, parent, L, L_cache); + if (ret != 0) { + goto out; + } +out: + return ret; +} + +static void +ancestor_matcher2_reset_tree(ancestor_matcher2_t *self) +{ + memset(self->parent, 0xff, self->num_nodes * sizeof(*self->parent)); + memset(self->left_child, 0xff, self->num_nodes * sizeof(*self->left_child)); + memset(self->right_child, 0xff, self->num_nodes * sizeof(*self->right_child)); + memset(self->left_sib, 0xff, self->num_nodes * sizeof(*self->left_sib)); + memset(self->right_sib, 0xff, self->num_nodes * sizeof(*self->right_sib)); + memset(self->recombination_required, 0xff, + self->num_nodes * sizeof(*self->recombination_required)); +} + +static int +ancestor_matcher2_reset(ancestor_matcher2_t *self) +{ + int ret = 0; + + memset(self->allelic_state, 0xff, self->num_nodes * sizeof(*self->allelic_state)); + + if (self->traceback_allocator.total_size > self->traceback_realloc_size) { + tsk_blkalloc_free(&self->traceback_allocator); + ret = tsk_blkalloc_init(&self->traceback_allocator, self->traceback_block_size); + if (ret != 0) { + goto out; + } + } else { + ret = tsk_blkalloc_reset(&self->traceback_allocator); + if (ret != 0) { + goto out; + } + } + self->total_traceback_size = 0; + self->num_likelihood_nodes = 0; + ancestor_matcher2_reset_tree(self); +out: + return ret; +} + +/* Resets the recombination_required array from the traceback at the specified site. + */ +static inline void +ancestor_matcher2_set_recombination_required( + ancestor_matcher2_t *self, tsk_id_t site, int8_t *restrict recombination_required) +{ + int j; + const int8_t *restrict R = self->traceback[site].recombination_required; + const tsk_id_t *restrict node = self->traceback[site].node; + const int size = self->traceback[site].size; + + /* We always set recombination_required for node 0 to false for the cases + * where no recombination is needed at a particular site (which are + * encoded by a traceback of size 0) */ + recombination_required[0] = 0; + for (j = 0; j < size; j++) { + recombination_required[node[j]] = R[j]; + } +} + +/* Unsets the likelihood array from the traceback at the specified site. + */ +static inline void +ancestor_matcher2_unset_recombination_required( + ancestor_matcher2_t *self, tsk_id_t site, int8_t *restrict recombination_required) +{ + int j; + const tsk_id_t *restrict node = self->traceback[site].node; + const int size = self->traceback[site].size; + + for (j = 0; j < size; j++) { + recombination_required[node[j]] = -1; + } + recombination_required[0] = -1; +} + +static int WARN_UNUSED +ancestor_matcher2_run_traceback(ancestor_matcher2_t *self, tsk_id_t start, tsk_id_t end, + const allele_t *TSK_UNUSED(haplotype), allele_t *restrict match, + size_t *path_length_out, tsk_id_t *restrict path_left, tsk_id_t *restrict path_right, + tsk_id_t *restrict path_parent) +{ + int ret = 0; + tsk_id_t site; + edge_t edge; + tsk_id_t u, v, max_likelihood_node; + coordinate_t left, right, pos, start_pos, end_pos; + tsk_id_t *restrict parent = self->parent; + allele_t *restrict allelic_state = self->allelic_state; + int8_t *restrict recombination_required = self->recombination_required; + const edge_t *restrict in = self->matcher_indexes->right_index_edges; + const edge_t *restrict out = self->matcher_indexes->left_index_edges; + const coordinate_t *restrict sites_position = self->matcher_indexes->sites.position; + const coordinate_t sequence_length = sites_position[self->num_sites]; + int_fast32_t in_index = (int_fast32_t) self->matcher_indexes->num_edges - 1; + int_fast32_t out_index = (int_fast32_t) self->matcher_indexes->num_edges - 1; + size_t path_length = 0; + + start_pos = start == 0 ? 0 : sites_position[start]; + end_pos = sites_position[end]; + + /* Prepare for the traceback and get the memory ready for recording + * the output edges. */ + path_right[path_length] = end_pos; + path_parent[path_length] = NULL_NODE; + + max_likelihood_node = self->max_likelihood_node[end - 1]; + assert(max_likelihood_node != NULL_NODE); + path_parent[path_length] = max_likelihood_node; + assert(path_parent[path_length] != NULL_NODE); + + /* Now go through the trees in reverse and run the traceback */ + memset(parent, 0xff, self->num_nodes * sizeof(*parent)); + memset( + recombination_required, 0xff, self->num_nodes * sizeof(*recombination_required)); + pos = sequence_length; + site = (tsk_id_t) self->num_sites - 1; + + while (pos > start_pos) { + while (out_index >= 0 && out[out_index].left == pos) { + edge = out[out_index]; + out_index--; + parent[edge.child] = NULL_NODE; + } + while (in_index >= 0 && in[in_index].right == pos) { + edge = in[in_index]; + in_index--; + parent[edge.child] = edge.parent; + } + right = pos; + left = 0; + if (out_index >= 0) { + left = TSK_MAX(left, out[out_index].left); + } + if (in_index >= 0) { + left = TSK_MAX(left, in[in_index].right); + } + pos = left; + + /* The tree is ready; perform the traceback at each site in this tree */ + assert(left < right); + for (; site >= 0 && left <= sites_position[site] && sites_position[site] < right; + site--) { + if (start_pos <= sites_position[site] && sites_position[site] < end_pos) { + + ancestor_matcher2_set_allelic_state(self, site, allelic_state); + u = path_parent[path_length]; + v = u; + while (allelic_state[v] == TSK_NULL) { + v = parent[v]; + } + match[site] = allelic_state[v]; + ancestor_matcher2_unset_allelic_state(self, site, allelic_state); + + /* Mark the traceback nodes on the tree */ + ancestor_matcher2_set_recombination_required( + self, site, recombination_required); + + /* Traverse up the tree from the current node. The first marked node that + * we meed tells us whether we need to recombine */ + while (u != 0 && recombination_required[u] == -1) { + u = parent[u]; + assert(u != NULL_NODE); + } + if (recombination_required[u] && site > start) { + max_likelihood_node = self->max_likelihood_node[site - 1]; + assert(max_likelihood_node != NULL_NODE); + path_left[path_length] = sites_position[site]; + path_length++; + /* Start the next output edge */ + path_right[path_length] = sites_position[site]; + path_parent[path_length] = max_likelihood_node; + } + /* Unset the values in the tree for the next site. */ + ancestor_matcher2_unset_recombination_required( + self, site, recombination_required); + } + } + } + + path_left[path_length] = start_pos; + path_length++; + assert(path_right[path_length - 1] != start); + *path_length_out = path_length; + return ret; +} + +static int +ancestor_matcher2_run_forwards_match( + ancestor_matcher2_t *self, tsk_id_t start, tsk_id_t end, const allele_t *haplotype) +{ + int ret = 0; + tsk_id_t site; + edge_t edge; + tsk_id_t u, root, last_root; + double L_child = 0; + /* Use the restrict keyword here to try to improve performance by avoiding + * unecessary loads. We must be very careful to to ensure that all references + * to this memory for the duration of this function is through these variables. + */ + double *restrict L = self->likelihood; + double *restrict L_cache = self->likelihood_cache; + tsk_id_t *restrict parent = self->parent; + tsk_id_t *restrict left_child = self->left_child; + tsk_id_t *restrict right_child = self->right_child; + tsk_id_t *restrict left_sib = self->left_sib; + tsk_id_t *restrict right_sib = self->right_sib; + coordinate_t pos, left, right; + const edge_t *restrict in = self->matcher_indexes->left_index_edges; + const edge_t *restrict out = self->matcher_indexes->right_index_edges; + const int_fast32_t M = (tsk_id_t) self->matcher_indexes->num_edges; + const coordinate_t *restrict sites_position = self->matcher_indexes->sites.position; + int_fast32_t in_index, out_index, l, remove_start; + const coordinate_t start_pos = sites_position[start]; + const coordinate_t end_pos = sites_position[end]; + const coordinate_t sequence_length = sites_position[self->num_sites]; + + /* Load the tree for start */ + left = 0; + pos = 0; + in_index = 0; + out_index = 0; + right = sequence_length; + if (in_index < M && start_pos < in[in_index].left) { + right = in[in_index].left; + } + + /* TODO don't add all edges trees but only insert edges that intersect + * with start_pos. Maybe a reasonable gain for short ancestral fragments */ + while (in_index < M && out_index < M && in[in_index].left <= start_pos) { + while (out_index < M && out[out_index].right == pos) { + remove_edge( + out[out_index], parent, left_child, right_child, left_sib, right_sib); + out_index++; + } + while (in_index < M && in[in_index].left == pos) { + insert_edge( + in[in_index], parent, left_child, right_child, left_sib, right_sib); + in_index++; + } + left = pos; + right = sequence_length; + if (in_index < M) { + right = TSK_MIN(right, in[in_index].left); + } + if (out_index < M) { + right = TSK_MIN(right, out[out_index].right); + } + pos = right; + } + + /* Insert the initial likelihoods. All non-zero roots are marked with a + * special value so we can identify them when the enter the tree */ + L_cache[0] = CACHE_UNSET; + for (u = 0; u < (tsk_id_t) self->num_nodes; u++) { + L_cache[u] = CACHE_UNSET; + if (parent[u] != NULL_NODE) { + L[u] = NULL_LIKELIHOOD; + } else { + L[u] = NONZERO_ROOT_LIKELIHOOD; + } + } + if (self->flags & TSI_EXTENDED_CHECKS) { + ancestor_matcher2_check_state(self); + } + last_root = 0; + if (left_child[0] != NULL_NODE) { + last_root = left_child[0]; + assert(right_sib[last_root] == NULL_NODE); + } + L[last_root] = 1.0; + self->likelihood_nodes[0] = last_root; + self->num_likelihood_nodes = 1; + + for (site = 0; sites_position[site] < left; site++) + ; + + remove_start = out_index; + while (left < end_pos) { + assert(left < right); + + /* Remove the likelihoods for any nonzero roots that have just left + * the tree */ + for (l = remove_start; l < out_index; l++) { + edge = out[l]; + if (unlikely(is_nonzero_root(edge.child, parent, left_child))) { + if (L[edge.child] >= 0) { + ancestor_matcher2_delete_likelihood(self, edge.child, L); + } + L[edge.child] = NONZERO_ROOT_LIKELIHOOD; + } + if (unlikely(is_nonzero_root(edge.parent, parent, left_child))) { + if (L[edge.parent] >= 0) { + ancestor_matcher2_delete_likelihood(self, edge.parent, L); + } + L[edge.parent] = NONZERO_ROOT_LIKELIHOOD; + } + } + + root = 0; + if (left_child[0] != NULL_NODE) { + root = left_child[0]; + assert(right_sib[root] == NULL_NODE); + } + if (root != last_root) { + if (last_root == 0) { + ancestor_matcher2_delete_likelihood(self, last_root, L); + L[last_root] = NONZERO_ROOT_LIKELIHOOD; + } + if (L[root] == NONZERO_ROOT_LIKELIHOOD) { + L[root] = 0; + self->likelihood_nodes[self->num_likelihood_nodes] = root; + self->num_likelihood_nodes++; + } + last_root = root; + } + + if (self->flags & TSI_EXTENDED_CHECKS) { + ancestor_matcher2_check_state(self); + } + + while (left <= sites_position[site] + && sites_position[site] < TSK_MIN(right, end_pos)) { + ret = ancestor_matcher2_update_site_state( + self, site, haplotype[site], parent, L, L_cache); + if (ret != 0) { + goto out; + } + site++; + } + + /* Move on to the next tree */ + remove_start = out_index; + while (out_index < M && out[out_index].right == right) { + edge = out[out_index]; + out_index++; + remove_edge(edge, parent, left_child, right_child, left_sib, right_sib); + assert(L[edge.child] != NONZERO_ROOT_LIKELIHOOD); + if (L[edge.child] == NULL_LIKELIHOOD) { + u = edge.parent; + while (likely(L[u] == NULL_LIKELIHOOD) + && likely(L_cache[u] == CACHE_UNSET)) { + u = parent[u]; + } + L_child = L_cache[u]; + if (unlikely(L_child == CACHE_UNSET)) { + L_child = L[u]; + } + assert(L_child >= 0); + u = edge.parent; + /* Fill in the cache by traversing back upwards */ + while (likely(L[u] == NULL_LIKELIHOOD) + && likely(L_cache[u] == CACHE_UNSET)) { + L_cache[u] = L_child; + u = parent[u]; + } + L[edge.child] = L_child; + self->likelihood_nodes[self->num_likelihood_nodes] = edge.child; + self->num_likelihood_nodes++; + } + } + /* reset the L cache */ + for (l = remove_start; l < out_index; l++) { + edge = out[l]; + u = edge.parent; + while (likely(L_cache[u] != CACHE_UNSET)) { + L_cache[u] = CACHE_UNSET; + u = parent[u]; + } + } + + left = right; + while (in_index < M && in[in_index].left == left) { + edge = in[in_index]; + in_index++; + insert_edge(edge, parent, left_child, right_child, left_sib, right_sib); + /* Insert zero likelihoods for any nonzero roots that have entered + * the tree. Note we don't bother trying to compress the tree here + * because this will be done for the next site anyway. */ + if (unlikely( + edge.parent != 0 && L[edge.parent] == NONZERO_ROOT_LIKELIHOOD)) { + L[edge.parent] = 0; + self->likelihood_nodes[self->num_likelihood_nodes] = edge.parent; + self->num_likelihood_nodes++; + } + if (unlikely(L[edge.child] == NONZERO_ROOT_LIKELIHOOD)) { + L[edge.child] = 0; + self->likelihood_nodes[self->num_likelihood_nodes] = edge.child; + self->num_likelihood_nodes++; + } + } + right = sequence_length; + if (in_index < M) { + right = TSK_MIN(right, in[in_index].left); + } + if (out_index < M) { + right = TSK_MIN(right, out[out_index].right); + } + } +out: + return ret; +} + +int +ancestor_matcher2_find_path(ancestor_matcher2_t *self, tsk_id_t start, tsk_id_t end, + const allele_t *haplotype, allele_t *matched_haplotype, size_t *path_length, + tsk_id_t *path_left, tsk_id_t *path_right, tsk_id_t *path_parent) +{ + int ret = 0; + + ret = ancestor_matcher2_reset(self); + if (ret != 0) { + goto out; + } + ret = ancestor_matcher2_run_forwards_match(self, start, end, haplotype); + if (ret != 0) { + goto out; + } + ret = ancestor_matcher2_run_traceback(self, start, end, haplotype, matched_haplotype, + path_length, path_left, path_right, path_parent); + if (ret != 0) { + goto out; + } + /* Reset some memory for the next call */ + memset( + self->traceback + start, 0, ((size_t)(end - start)) * sizeof(*self->traceback)); + memset(self->max_likelihood_node + start, 0xff, + ((size_t)(end - start)) * sizeof(*self->max_likelihood_node)); + +out: + return ret; +} + +double +ancestor_matcher2_get_mean_traceback_size(ancestor_matcher2_t *self) +{ + return (double) self->total_traceback_size / ((double) self->num_sites); +} + +size_t +ancestor_matcher2_get_total_memory(ancestor_matcher2_t *self) +{ + size_t total = self->traceback_allocator.total_size; + /* TODO add contributions from other objects */ + + return total; +} diff --git a/lib/test_data/multi_tree_example.trees b/lib/test_data/multi_tree_example.trees new file mode 100644 index 00000000..b783471d Binary files /dev/null and b/lib/test_data/multi_tree_example.trees differ diff --git a/lib/test_data/single_tree_example.trees b/lib/test_data/single_tree_example.trees new file mode 100644 index 00000000..991f94d0 Binary files /dev/null and b/lib/test_data/single_tree_example.trees differ diff --git a/lib/tests/tests.c b/lib/tests/tests.c index 5ded462c..72ee75b5 100644 --- a/lib/tests/tests.c +++ b/lib/tests/tests.c @@ -31,11 +31,39 @@ #include +/* FIXME this needs to be updated somehow to allow the tests to be run from + * different directories, i.e., with ninja -C build test + */ +#define TEST_DATA_DIR "test_data" + /* Global variables used for test in state in the test suite */ char *_tmp_file_name; FILE *_devnull; +tsk_treeseq_t _single_tree_ex_ts; +/* 3.00┊ 0 ┊ */ +/* ┊ ┃ ┊ */ +/* 2.00┊ 7 ┊ */ +/* ┊ ┏━┻━┓ ┊ */ +/* 1.00┊ 5 6 ┊ */ +/* ┊ ┏┻┓ ┏┻┓ ┊ */ +/* 0.00┊ 1 2 3 4 ┊ */ +/* 0 4 */ +tsk_treeseq_t _multi_tree_ex_ts; +/* 1.84┊ 0 ┊ 0 ┊ */ +/* ┊ ┃ ┊ ┃ ┊ */ +/* 0.84┊ 8 ┊ 8 ┊ */ +/* ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊ */ +/* 0.42┊ ┃ ┃ ┊ 7 ┃ ┊ */ +/* ┊ ┃ ┃ ┊ ┏┻┓ ┃ ┊ */ +/* 0.05┊ 6 ┃ ┊ ┃ ┃ ┃ ┊ */ +/* ┊ ┏━┻┓ ┃ ┊ ┃ ┃ ┃ ┊ */ +/* 0.04┊ ┃ 5 ┃ ┊ ┃ ┃ 5 ┊ */ +/* ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┏┻┓ ┊ */ +/* 0.00┊ 1 2 3 4 ┊ 1 4 2 3 ┊ */ +/* 0 2 4 */ + static void dump_tree_sequence_builder( tree_sequence_builder_t *tsb, tsk_table_collection_t *tables, tsk_flags_t options) @@ -985,6 +1013,117 @@ test_packbits_errors(void) CU_ASSERT_EQUAL_FATAL(ret, TSI_ERR_ONE_BIT_NON_BINARY); } +static int +run_match(const tsk_treeseq_t *ts, double rho, double mu, const allele_t *h, + allele_t *match, tsk_size_t *path_length, tsk_id_t *left, tsk_id_t *right, + tsk_id_t *parent) +{ + int ret; + ancestor_matcher2_t am; + matcher_indexes_t mi; + const size_t m = tsk_treeseq_get_num_sites(ts); + double *recombination_rate = calloc(m, sizeof(*recombination_rate)); + double *mutation_rate = calloc(m, sizeof(*mutation_rate)); + size_t j; + + CU_ASSERT_FATAL(recombination_rate != NULL); + CU_ASSERT_FATAL(mutation_rate != NULL); + for (j = 0; j < m; j++) { + mutation_rate[j] = mu; + recombination_rate[j] = rho; + } + + ret = matcher_indexes_alloc(&mi, ts->tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + /* matcher_indexes_print_state(&mi, stdout); */ + ret = ancestor_matcher2_alloc(&am, &mi, recombination_rate, mutation_rate, 14, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = ancestor_matcher2_find_path( + &am, 0, (tsk_id_t) m, h, match, path_length, left, right, parent); + CU_ASSERT_EQUAL_FATAL(ret, 0); + /* ancestor_matcher2_print_state(&am, stdout); */ + + ancestor_matcher2_free(&am); + matcher_indexes_free(&mi); + free(recombination_rate); + free(mutation_rate); + + return 0; +} + +static void +check_matching_single_site_match(const tsk_treeseq_t *ts) +{ + allele_t h[] = { 0, 0, 0, 0 }; + allele_t match[4]; + tsk_id_t j, left[4], right[4], parent[4]; + tsk_size_t path_length; + + CU_ASSERT_EQUAL_FATAL(tsk_treeseq_get_num_sites(ts), 4); + + for (j = 0; j < 4; j++) { + memset(h, 0, sizeof(h)); + h[j] = 1; + run_match(ts, 1e-8, 0, h, match, &path_length, left, right, parent); + CU_ASSERT_EQUAL_FATAL(path_length, 1); + CU_ASSERT_EQUAL_FATAL(left[0], 0); + CU_ASSERT_EQUAL_FATAL(right[0], 4); + CU_ASSERT_EQUAL_FATAL(parent[0], j + 1); + } +} + +static void +test_matching_single_tree_single_site_match(void) +{ + check_matching_single_site_match(&_single_tree_ex_ts); +} + +static void +test_matching_multi_tree_single_site_match(void) +{ + check_matching_single_site_match(&_multi_tree_ex_ts); +} + +static void +check_matching_multi_switch(const tsk_treeseq_t *ts) +{ + allele_t h[] = { 1, 1, 1, 1 }; + allele_t match[4]; + tsk_id_t left[4], right[4], parent[4]; + tsk_size_t path_length; + + CU_ASSERT_EQUAL_FATAL(tsk_treeseq_get_num_sites(ts), 4); + CU_ASSERT_EQUAL_FATAL(tsk_treeseq_get_sequence_length(ts), 4); + + run_match(ts, 1e-8, 0, h, match, &path_length, left, right, parent); + CU_ASSERT_EQUAL_FATAL(path_length, 4); + CU_ASSERT_EQUAL_FATAL(left[3], 0); + CU_ASSERT_EQUAL_FATAL(right[3], 1); + CU_ASSERT_EQUAL_FATAL(parent[3], 1); + CU_ASSERT_EQUAL_FATAL(left[2], 1); + CU_ASSERT_EQUAL_FATAL(right[2], 2); + CU_ASSERT_EQUAL_FATAL(parent[2], 2); + CU_ASSERT_EQUAL_FATAL(left[1], 2); + CU_ASSERT_EQUAL_FATAL(right[1], 3); + CU_ASSERT_EQUAL_FATAL(parent[1], 3); + CU_ASSERT_EQUAL_FATAL(left[0], 3); + CU_ASSERT_EQUAL_FATAL(right[0], 4); + CU_ASSERT_EQUAL_FATAL(parent[0], 4); +} + +static void +test_matching_single_tree_multi_switch(void) +{ + check_matching_multi_switch(&_single_tree_ex_ts); +} + +static void +test_matching_multi_tree_multi_switch(void) +{ + check_matching_multi_switch(&_multi_tree_ex_ts); +} + static void test_strerror(void) { @@ -1004,11 +1143,13 @@ test_strerror(void) static int tsinfer_suite_init(void) { - int fd; + int ret, fd; static char template[] = "/tmp/tsi_c_test_XXXXXX"; _tmp_file_name = NULL; _devnull = NULL; + memset(&_single_tree_ex_ts, 0, sizeof(_single_tree_ex_ts)); + memset(&_multi_tree_ex_ts, 0, sizeof(_multi_tree_ex_ts)); _tmp_file_name = malloc(sizeof(template)); if (_tmp_file_name == NULL) { @@ -1024,6 +1165,18 @@ tsinfer_suite_init(void) if (_devnull == NULL) { return CUE_SINIT_FAILED; } + + ret = tsk_treeseq_load( + &_single_tree_ex_ts, TEST_DATA_DIR "/single_tree_example.trees", 0); + if (ret != 0) { + return CUE_SINIT_FAILED; + } + ret = tsk_treeseq_load( + &_multi_tree_ex_ts, TEST_DATA_DIR "/multi_tree_example.trees", 0); + if (ret != 0) { + return CUE_SINIT_FAILED; + } + return CUE_SUCCESS; } @@ -1037,6 +1190,8 @@ tsinfer_suite_cleanup(void) if (_devnull != NULL) { fclose(_devnull); } + tsk_treeseq_free(&_single_tree_ex_ts); + tsk_treeseq_free(&_multi_tree_ex_ts); return CUE_SUCCESS; } @@ -1077,6 +1232,15 @@ main(int argc, char **argv) { "test_packbits_4", test_packbits_4 }, { "test_packbits_errors", test_packbits_errors }, + { "test_matching_single_tree_single_site_match", + test_matching_single_tree_single_site_match }, + { "test_matching_multi_tree_single_site_match", + test_matching_multi_tree_single_site_match }, + { "test_matching_single_tree_multi_switch", + test_matching_single_tree_multi_switch }, + { "test_matching_multi_tree_multi_switch", + test_matching_multi_tree_multi_switch }, + { "test_strerror", test_strerror }, CU_TEST_INFO_NULL, diff --git a/lib/tsinfer.h b/lib/tsinfer.h index 628a5519..684f17ac 100644 --- a/lib/tsinfer.h +++ b/lib/tsinfer.h @@ -46,10 +46,12 @@ #define TSI_NODE_IS_PC_ANCESTOR ((tsk_flags_t)(1u << 16)) typedef int8_t allele_t; +/* TODO should probably change to uint32 when we have removed the old code.*/ +typedef tsk_id_t coordinate_t; typedef struct { - tsk_id_t left; - tsk_id_t right; + coordinate_t left; + coordinate_t right; tsk_id_t parent; tsk_id_t child; } edge_t; @@ -209,6 +211,54 @@ typedef struct { } output; } ancestor_matcher_t; +typedef struct { + tsk_flags_t flags; + size_t num_sites; + size_t num_nodes; + size_t num_mutations; + size_t num_edges; + struct { + coordinate_t *position; + mutation_list_node_t **mutations; + tsk_size_t *num_alleles; + } sites; + edge_t *left_index_edges; + edge_t *right_index_edges; + tsk_blkalloc_t allocator; +} matcher_indexes_t; + +typedef struct { + int flags; + const matcher_indexes_t *matcher_indexes; + size_t num_nodes; + size_t num_sites; + /* Input LS model rates */ + unsigned int precision; + double *recombination_rate; + double *mismatch_rate; + /* The quintuply linked tree */ + tsk_id_t *parent; + tsk_id_t *left_child; + tsk_id_t *right_child; + tsk_id_t *left_sib; + tsk_id_t *right_sib; + double *likelihood; + double *likelihood_cache; + allele_t *allelic_state; + int num_likelihood_nodes; + /* At each site, record a node with the maximum likelihood. */ + tsk_id_t *max_likelihood_node; + /* Used during traceback to map nodes where recombination is required. */ + int8_t *recombination_required; + tsk_id_t *likelihood_nodes_tmp; + tsk_id_t *likelihood_nodes; + node_state_list_t *traceback; + tsk_blkalloc_t traceback_allocator; + size_t total_traceback_size; + size_t traceback_block_size; + size_t traceback_realloc_size; +} ancestor_matcher2_t; + int ancestor_builder_alloc(ancestor_builder_t *self, size_t num_samples, size_t num_sites, int mmap_fd, int flags); int ancestor_builder_free(ancestor_builder_t *self); @@ -267,6 +317,24 @@ int tree_sequence_builder_dump_edges(tree_sequence_builder_t *self, tsk_id_t *le int tree_sequence_builder_dump_mutations(tree_sequence_builder_t *self, tsk_id_t *site, tsk_id_t *node, allele_t *derived_state, tsk_id_t *parent); +/* New impelementation */ + +int matcher_indexes_alloc( + matcher_indexes_t *self, const tsk_table_collection_t *tables, tsk_flags_t options); +int matcher_indexes_print_state(const matcher_indexes_t *self, FILE *out); +int matcher_indexes_free(matcher_indexes_t *self); + +int ancestor_matcher2_alloc(ancestor_matcher2_t *self, + const matcher_indexes_t *matcher_indexes, double *recombination_rate, + double *mismatch_rate, unsigned int precision, int flags); +int ancestor_matcher2_free(ancestor_matcher2_t *self); +int ancestor_matcher2_find_path(ancestor_matcher2_t *self, tsk_id_t start, tsk_id_t end, + const allele_t *haplotype, allele_t *matched_haplotype, size_t *path_length, + tsk_id_t *path_left, tsk_id_t *path_right, tsk_id_t *path_parent); +int ancestor_matcher2_print_state(ancestor_matcher2_t *self, FILE *out); +double ancestor_matcher2_get_mean_traceback_size(ancestor_matcher2_t *self); +size_t ancestor_matcher2_get_total_memory(ancestor_matcher2_t *self); + int packbits(const allele_t *restrict source, size_t len, uint8_t *restrict dest); void unpackbits(const uint8_t *restrict source, size_t len, allele_t *restrict dest); diff --git a/lwt_interface b/lwt_interface new file mode 120000 index 00000000..30dc544d --- /dev/null +++ b/lwt_interface @@ -0,0 +1 @@ +git-submodules/tskit/python/lwt_interface/ \ No newline at end of file diff --git a/setup.py b/setup.py index f0579d28..6c0f1a31 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ tskroot = os.path.join(libdir, "subprojects", "tskit") tskdir = os.path.join(tskroot, "tskit") kasdir = os.path.join(tskroot, "subprojects", "kastore") -includes = [libdir, tskroot, tskdir, kasdir] +includes = ["lwt_interface", libdir, tskroot, tskdir, kasdir] tsi_source_files = [ "ancestor_matcher.c", @@ -24,7 +24,7 @@ ] # We're not actually using very much of tskit at the moment, so # just build the stuff we need. -tsk_source_files = ["core.c"] +tsk_source_files = ["core.c", "tables.c", "trees.c"] kas_source_files = ["kastore.c"] sources = ( diff --git a/tests/test_low_level.py b/tests/test_low_level.py index 550a8994..36fd5050 100644 --- a/tests/test_low_level.py +++ b/tests/test_low_level.py @@ -20,8 +20,10 @@ Integrity tests for the low-level module. """ import sys +import tempfile import pytest +import tskit import _tsinfer @@ -151,3 +153,30 @@ def test_add_too_many_sites(self): assert str(record.value) == msg # TODO need tester methods for the remaining methonds in the class. + + +class TestMatcherIndexes: + def test_single_tree(self): + ts = tskit.Tree.generate_balanced(4).tree_sequence + tables = ts.dump_tables() + ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length) + ll_tables.fromdict(tables.asdict()) + mi = _tsinfer.MatcherIndexes(ll_tables) + print(mi) + mi.print_state(sys.stdout) + + def test_print_state(self): + ts = tskit.Tree.generate_balanced(4).tree_sequence + tables = ts.dump_tables() + ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length) + ll_tables.fromdict(tables.asdict()) + mi = _tsinfer.MatcherIndexes(ll_tables) + with pytest.raises(TypeError): + mi.print_state() + + with tempfile.TemporaryFile("w+") as f: + mi.print_state(f) + f.seek(0) + output = f.read() + assert len(output) > 0 + assert "indexes" in output diff --git a/tests/test_matching.py b/tests/test_matching.py new file mode 100644 index 00000000..e4352c8c --- /dev/null +++ b/tests/test_matching.py @@ -0,0 +1,845 @@ +""" +Tests for the haplotype matching algorithm. +""" +import collections +import dataclasses +import io +import pickle + +import msprime +import numpy as np +import pytest +import tskit + +import _tsinfer +import tsinfer +from tsinfer import matching + + +@dataclasses.dataclass +class Edge: + left: int = dataclasses.field(default=None) + right: int = dataclasses.field(default=None) + parent: int = dataclasses.field(default=None) + child: int = dataclasses.field(default=None) + + +# Special values used to indicate compressed paths and nodes that are +# not present in the current tree. + + +def convert_edge_list(edges, order): + values = [] + for j in order: + tsk_edge = edges[j] + edge = Edge( + int(tsk_edge.left), int(tsk_edge.right), tsk_edge.parent, tsk_edge.child + ) + values.append(edge) + return values + + +class MatcherIndexes: + """ + The memory that can be shared between AncestorMatcher instances. + """ + + def __init__(self, in_tables): + ts = matching.add_vestigial_root(in_tables.tree_sequence()) + tables = ts.dump_tables() + + self.sequence_length = tables.sequence_length + self.num_nodes = len(tables.nodes) + self.num_sites = len(tables.sites) + + # Store the edges in left and right order. + self.left_index = convert_edge_list( + tables.edges, tables.indexes.edge_insertion_order + ) + self.right_index = convert_edge_list( + tables.edges, tables.indexes.edge_removal_order + ) + + # TODO fixme + self.num_alleles = np.zeros(self.num_sites, dtype=int) + 2 + self.sites_position = np.zeros(ts.num_sites + 1, dtype=np.uint32) + self.sites_position[:-1] = tables.sites.position + self.sites_position[-1] = tables.sequence_length + self.mutations = collections.defaultdict(list) + last_site = -1 + for mutation in tables.mutations: + if last_site == mutation.site: + raise ValueError("Only single mutations supported for now") + # FIXME - should be allele index + self.mutations[mutation.site].append((mutation.node, 1)) + last_site = mutation.site + + +COMPRESSED = -1 +NONZERO_ROOT = -2 + + +class AncestorMatcher: + def __init__( + self, + matcher_indexes, + recombination=None, + mismatch=None, + precision=None, + extended_checks=False, + ): + self.matcher_indexes = matcher_indexes + self.num_sites = matcher_indexes.num_sites + self.num_nodes = matcher_indexes.num_nodes + self.mismatch = mismatch + self.recombination = recombination + self.precision = 22 + self.extended_checks = extended_checks + self.parent = None + self.left_child = None + self.right_sib = None + self.traceback = None + self.max_likelihood_node = None + self.likelihood = None + self.likelihood_nodes = None + self.allelic_state = None + self.total_memory = 0 + + def print_state(self): + # TODO - don't crash when self.max_likelihood_node or self.traceback == None + print("Ancestor matcher state") + print("max_L_node\ttraceback") + for site_index in range(self.num_sites): + print( + site_index, + self.max_likelihood_node[site_index], + self.traceback[site_index], + sep="\t", + ) + + def is_root(self, u): + return self.parent[u] == tskit.NULL + + def check_likelihoods(self): + assert len(set(self.likelihood_nodes)) == len(self.likelihood_nodes) + # Every value in L_nodes must be positive. + for u in self.likelihood_nodes: + assert self.likelihood[u] >= 0 + for u, v in enumerate(self.likelihood): + # Every non-negative value in L should be in L_nodes + if v >= 0: + assert u in self.likelihood_nodes + # Roots other than 0 should have v == -2 + if u != 0 and self.is_root(u) and self.left_child[u] == -1: + # print("root: u = ", u, self.parent[u], self.left_child[u]) + assert v == -2 + + def set_allelic_state(self, site): + """ + Sets the allelic state array to reflect the mutations at this site. + """ + # We know that 0 is always a root. + # FIXME assuming for now that the ancestral state is always zero. + self.allelic_state[0] = 0 + for node, state in self.matcher_indexes.mutations[site]: + self.allelic_state[node] = state + + def unset_allelic_state(self, site): + """ + Sets the allelic state values for this site back to null. + """ + # We know that 0 is always a root. + self.allelic_state[0] = -1 + for node, _ in self.matcher_indexes.mutations[site]: + self.allelic_state[node] = -1 + assert np.all(self.allelic_state == -1) + + def update_site(self, site, haplotype_state): + n = self.num_nodes + rho = self.recombination[site] + mu = self.mismatch[site] + num_alleles = self.matcher_indexes.num_alleles[site] + assert haplotype_state < num_alleles + + self.set_allelic_state(site) + + for node, _ in self.matcher_indexes.mutations[site]: + # Insert an new L-value for the mutation node if needed. + if self.likelihood[node] == COMPRESSED: + u = node + while self.likelihood[u] == COMPRESSED: + u = self.parent[u] + self.likelihood[node] = self.likelihood[u] + self.likelihood_nodes.append(node) + + max_L = -1 + max_L_node = -1 + for u in self.likelihood_nodes: + # Get the allelic_state at u. TODO we can cache these states to + # avoid some upward traversals. + v = u + while self.allelic_state[v] == -1: + v = self.parent[v] + assert v != -1 + + p_last = self.likelihood[u] + p_no_recomb = p_last * (1 - rho + rho / n) + p_recomb = rho / n + recombination_required = False + if p_no_recomb > p_recomb: + p_t = p_no_recomb + else: + p_t = p_recomb + recombination_required = True + self.traceback[site][u] = recombination_required + p_e = mu + if haplotype_state in (tskit.MISSING_DATA, self.allelic_state[v]): + p_e = 1 - (num_alleles - 1) * mu + self.likelihood[u] = p_t * p_e + + if self.likelihood[u] > max_L: + max_L = self.likelihood[u] + max_L_node = u + + if max_L == 0: + if mu == 0: + raise _tsinfer.MatchImpossible( + "Trying to match non-existent allele with zero mismatch rate" + ) + elif mu == 1: + raise _tsinfer.MatchImpossible( + "Match impossible: mismatch prob=1 & no haplotype with other allele" + ) + elif rho == 0: + raise _tsinfer.MatchImpossible( + "Matching failed with recombination=0, potentially due to " + "rounding issues. Try increasing the precision value" + ) + raise AssertionError("Unexpected matching failure") + + for u in self.likelihood_nodes: + x = self.likelihood[u] / max_L + self.likelihood[u] = round(x, self.precision) + + self.max_likelihood_node[site] = max_L_node + self.unset_allelic_state(site) + self.compress_likelihoods() + + def compress_likelihoods(self): + L_cache = np.zeros_like(self.likelihood) - 1 + cached_paths = [] + old_likelihood_nodes = list(self.likelihood_nodes) + self.likelihood_nodes.clear() + for u in old_likelihood_nodes: + # We need to find the likelihood of the parent of u. If this is + # the same as u, we can delete it. + if not self.is_root(u): + p = self.parent[u] + cached_paths.append(p) + v = p + while self.likelihood[v] == -1 and L_cache[v] == -1: + v = self.parent[v] + L_p = L_cache[v] + if L_p == -1: + L_p = self.likelihood[v] + # Fill in the L cache + v = p + while self.likelihood[v] == -1 and L_cache[v] == -1: + L_cache[v] = L_p + v = self.parent[v] + + if self.likelihood[u] == L_p: + # Delete u from the map + self.likelihood[u] = -1 + if self.likelihood[u] >= 0: + self.likelihood_nodes.append(u) + # Reset the L cache + for u in cached_paths: + v = u + while v != -1 and L_cache[v] != -1: + L_cache[v] = -1 + v = self.parent[v] + assert np.all(L_cache == -1) + + def remove_edge(self, edge): + p = edge.parent + c = edge.child + lsib = self.left_sib[c] + rsib = self.right_sib[c] + if lsib == tskit.NULL: + self.left_child[p] = rsib + else: + self.right_sib[lsib] = rsib + if rsib == tskit.NULL: + self.right_child[p] = lsib + else: + self.left_sib[rsib] = lsib + self.parent[c] = tskit.NULL + self.left_sib[c] = tskit.NULL + self.right_sib[c] = tskit.NULL + + def insert_edge(self, edge): + p = edge.parent + c = edge.child + self.parent[c] = p + u = self.right_child[p] + if u == tskit.NULL: + self.left_child[p] = c + self.left_sib[c] = tskit.NULL + self.right_sib[c] = tskit.NULL + else: + self.right_sib[u] = c + self.left_sib[c] = u + self.right_sib[c] = tskit.NULL + self.right_child[p] = c + + def is_nonzero_root(self, u): + return u != 0 and self.is_root(u) and self.left_child[u] == -1 + + def zero_sites_path(self): + path = matching.Path([0], [self.matcher_indexes.sites_position[-1]], [0]) + return matching.Match(path, [], []) + + def find_path(self, h): + if self.num_sites == 0: + return self.zero_sites_path() + Il = self.matcher_indexes.left_index + Ir = self.matcher_indexes.right_index + sequence_length = self.matcher_indexes.sequence_length + sites_position = self.matcher_indexes.sites_position + M = len(Il) + n = self.num_nodes + m = self.num_sites + self.parent = np.zeros(n, dtype=int) - 1 + self.left_child = np.zeros(n, dtype=int) - 1 + self.right_child = np.zeros(n, dtype=int) - 1 + self.left_sib = np.zeros(n, dtype=int) - 1 + self.right_sib = np.zeros(n, dtype=int) - 1 + self.traceback = [{} for _ in range(m)] + self.max_likelihood_node = np.zeros(m, dtype=int) - 1 + self.allelic_state = np.zeros(n, dtype=int) - 1 + + self.likelihood = np.full(n, NONZERO_ROOT, dtype=float) + self.likelihood_nodes = [] + L_cache = np.zeros_like(self.likelihood) - 1 + + start = 0 + while start < m and h[start] == tskit.MISSING_DATA: + start += 1 + + end = m - 1 + while end >= 0 and h[end] == tskit.MISSING_DATA: + end -= 1 + end += 1 + + # print("MATCH: start=", start, "end = ", end, "h = ", h) + j = 0 + k = 0 + left = 0 + start_pos = 0 if start == 0 else sites_position[start] + end_pos = sites_position[end] + pos = 0 + right = sequence_length + if j < M and start_pos < Il[j].left: + right = Il[j].left + while j < M and k < M and Il[j].left <= start_pos: + while Ir[k].right == pos: + self.remove_edge(Ir[k]) + k += 1 + while j < M and Il[j].left == pos: + self.insert_edge(Il[j]) + j += 1 + left = pos + right = sequence_length + if j < M: + right = min(right, Il[j].left) + if k < M: + right = min(right, Ir[k].right) + pos = right + assert left < right + + for u in range(n): + if not self.is_root(u): + self.likelihood[u] = -1 + + last_root = 0 + if self.left_child[0] != -1: + last_root = self.left_child[0] + assert self.right_sib[last_root] == -1 + self.likelihood_nodes.append(last_root) + self.likelihood[last_root] = 1 + + current_site = 0 + while sites_position[current_site] < left: + current_site += 1 + + remove_start = k + while left < end_pos: + # print("START OF TREE LOOP", left, right) + # print("L:", {u: self.likelihood[u] for u in self.likelihood_nodes}) + assert left < right + for e in range(remove_start, k): + edge = Ir[e] + for u in [edge.parent, edge.child]: + if self.is_nonzero_root(u): + self.likelihood[u] = NONZERO_ROOT + if u in self.likelihood_nodes: + self.likelihood_nodes.remove(u) + root = 0 + if self.left_child[0] != -1: + root = self.left_child[0] + assert self.right_sib[root] == -1 + + if root != last_root: + if last_root == 0: + self.likelihood[last_root] = NONZERO_ROOT + self.likelihood_nodes.remove(last_root) + if self.likelihood[root] == NONZERO_ROOT: + self.likelihood[root] = 0 + self.likelihood_nodes.append(root) + last_root = root + + if self.extended_checks: + self.check_likelihoods() + + while left <= sites_position[current_site] < min(right, end_pos): + self.update_site(current_site, h[current_site]) + current_site += 1 + + remove_start = k + while k < M and Ir[k].right == right: + edge = Ir[k] + self.remove_edge(edge) + k += 1 + if self.likelihood[edge.child] == -1: + # If the child has an L value, traverse upwards until we + # find the parent that carries it. To avoid repeated traversals + # along the same path we make a cache of the L values. + u = edge.parent + while self.likelihood[u] == -1 and L_cache[u] == -1: + u = self.parent[u] + L_child = L_cache[u] + if L_child == -1: + L_child = self.likelihood[u] + # Fill in the L_cache + u = edge.parent + while self.likelihood[u] == -1 and L_cache[u] == -1: + L_cache[u] = L_child + u = self.parent[u] + self.likelihood[edge.child] = L_child + self.likelihood_nodes.append(edge.child) + # Clear the L cache + for e in range(remove_start, k): + edge = Ir[e] + u = edge.parent + while L_cache[u] != -1: + L_cache[u] = -1 + u = self.parent[u] + assert np.all(L_cache == -1) + + left = right + while j < M and Il[j].left == left: + edge = Il[j] + self.insert_edge(edge) + j += 1 + # There's no point in compressing the likelihood tree here as we'll be + # doing it after we update the first site anyway. + for u in [edge.parent, edge.child]: + if u != 0 and self.likelihood[u] == NONZERO_ROOT: + self.likelihood[u] = 0 + self.likelihood_nodes.append(u) + right = sequence_length + if j < M: + right = min(right, Il[j].left) + if k < M: + right = min(right, Ir[k].right) + + return self.run_traceback(start, end, h) + + def run_traceback(self, start, end, query_haplotype): + Il = self.matcher_indexes.left_index + Ir = self.matcher_indexes.right_index + L = self.matcher_indexes.sequence_length + sites_position = self.matcher_indexes.sites_position + M = len(Il) + u = self.max_likelihood_node[end - 1] + output_edge = Edge(right=end, parent=u) + output_edges = [output_edge] + recombination_required = np.zeros(self.num_nodes, dtype=int) - 1 + + # Now go back through the trees. + j = M - 1 + k = M - 1 + start_pos = 0 if start == 0 else sites_position[start] + end_pos = sites_position[end] + # Construct the matched haplotype + match = np.zeros(self.num_sites, dtype=np.int8) + match[:start] = tskit.MISSING_DATA + match[end:] = tskit.MISSING_DATA + # Reset the tree. + self.parent[:] = -1 + self.left_child[:] = -1 + self.right_child[:] = -1 + self.left_sib[:] = -1 + self.right_sib[:] = -1 + + pos = L + site_index = self.num_sites - 1 + while pos > start_pos: + # print("Top of loop: pos = ", pos) + while k >= 0 and Il[k].left == pos: + edge = Il[k] + self.remove_edge(edge) + k -= 1 + while j >= 0 and Ir[j].right == pos: + edge = Ir[j] + self.insert_edge(edge) + j -= 1 + right = pos + left = 0 + if k >= 0: + left = max(left, Il[k].left) + if j >= 0: + left = max(left, Ir[j].right) + pos = left + + assert left < right + while left <= sites_position[site_index] < right: + if start_pos <= sites_position[site_index] < end_pos: + u = output_edge.parent + self.set_allelic_state(site_index) + v = u + while self.allelic_state[v] == -1: + v = self.parent[v] + match[site_index] = self.allelic_state[v] + self.unset_allelic_state(site_index) + + for u, recombine in self.traceback[site_index].items(): + # Mark the traceback nodes on the tree. + recombination_required[u] = recombine + # Now traverse up the tree from the current node. The first + # marked node we meet tells us whether we need to + # recombine. + u = output_edge.parent + while u != 0 and recombination_required[u] == -1: + u = self.parent[u] + if recombination_required[u] and site_index > start: + output_edge.left = site_index + u = self.max_likelihood_node[site_index - 1] + output_edge = Edge(right=site_index, parent=u) + output_edges.append(output_edge) + # Reset the nodes in the recombination tree. + for u in self.traceback[site_index].keys(): + recombination_required[u] = -1 + site_index -= 1 + + output_edge.left = start + + self.mean_traceback_size = sum(len(t) for t in self.traceback) / self.num_sites + + left = np.zeros(len(output_edges), dtype=np.uint32) + right = np.zeros(len(output_edges), dtype=np.uint32) + parent = np.zeros(len(output_edges), dtype=np.int32) + for j, e in enumerate(output_edges): + assert e.left >= start + assert e.right <= end + # TODO this does happen in the C code, so if it ever happends in a Python + # instance we need to pop the last edge off the list. Or, see why we're + # generating it in the first place. + assert e.left < e.right + left[j] = sites_position[e.left] + right[j] = sites_position[e.right] + parent[j] = e.parent + + # Convert the parent node IDs back to original values + parent -= 1 + path = matching.Path(left[::-1], right[::-1], parent[::-1]) + if start == 0 and path.left[0] == sites_position[0]: + path.left[0] = 0 + return matching.Match(path, query_haplotype, match) + + +def run_match(ts, h): + h = np.array(h).astype(np.int8) + assert len(h) == ts.num_sites + recombination = np.zeros(ts.num_sites) + 1e-9 + mismatch = np.zeros(ts.num_sites) + precision = 22 + matcher_indexes = MatcherIndexes(ts.tables) + matcher = AncestorMatcher( + matcher_indexes, + recombination=recombination, + mismatch=mismatch, + precision=precision, + ) + match_py = matcher.find_path(h) + + mi = tsinfer.MatcherIndexes(ts) + am = tsinfer.AncestorMatcher2( + mi, recombination=recombination, mismatch=mismatch, precision=precision + ) + match_c = am.find_match(h) + match_py.assert_equals(match_c) + return match_py + + +class TestMatchClassUtils: + def test_pickle(self): + m1 = matching.Match( + matching.Path(np.array([0]), np.array([1]), np.array([0])), + np.array([0]), + np.array([0]), + ) + m2 = pickle.loads(pickle.dumps(m1)) + m1.assert_equals(m2) + + +# TODO the tests on these two classes are the same right now, should +# refactor. + + +def add_unique_sample_mutations(ts, start=0): + """ + Adds a mutation for each of the samples at equally spaced locations + along the genome. + """ + tables = ts.dump_tables() + L = int(ts.sequence_length) + assert L % ts.num_samples == 0 + gap = L // ts.num_samples + x = start + for u in ts.samples(): + site = tables.sites.add_row(position=x, ancestral_state="0") + tables.mutations.add_row(site=site, derived_state="1", node=u) + x += gap + return tables.tree_sequence() + + +class TestSingleBalancedTreeExample: + # 3.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 2.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 1.00┊ 0 1 2 3 ┊ + # 0 8 + + @staticmethod + def ts(): + return add_unique_sample_mutations( + tskit.Tree.generate_balanced(4, span=8).tree_sequence + ) + + @pytest.mark.parametrize("j", [0, 1, 2, 3]) + def test_match_sample(self, j): + ts = self.ts() + h = np.zeros(4) + h[j] = 1 + m = run_match(ts, h) + assert list(m.path.left) == [0] + assert list(m.path.right) == [ts.sequence_length] + assert list(m.path.parent) == [ts.samples()[j]] + np.testing.assert_array_equal(h, m.matched_haplotype) + np.testing.assert_array_equal(h, m.query_haplotype) + + @pytest.mark.parametrize("j", [1, 2]) + def test_match_sample_missing_flanks(self, j): + ts = self.ts() + h = np.zeros(4) + h[0] = -1 + h[-1] = -1 + h[j] = 1 + m = run_match(ts, h) + assert list(m.path.left) == [2] + assert list(m.path.right) == [6] + assert list(m.path.parent) == [ts.samples()[j]] + np.testing.assert_array_equal(h, m.matched_haplotype) + + def test_switch_each_sample(self): + ts = self.ts() + h = np.ones(4) + m = run_match(ts, h) + assert list(m.path.left) == [0, 2, 4, 6] + assert list(m.path.right) == [2, 4, 6, 8] + assert list(m.path.parent) == [0, 1, 2, 3] + np.testing.assert_array_equal(h, m.matched_haplotype) + + def test_switch_each_sample_missing_flanks(self): + ts = self.ts() + h = np.ones(4) + h[0] = -1 + h[-1] = -1 + m = run_match(ts, h) + assert list(m.path.left) == [2, 4] + assert list(m.path.right) == [4, 6] + assert list(m.path.parent) == [1, 2] + np.testing.assert_array_equal(h, m.matched_haplotype) + + +class TestSingleBalancedTreeExampleNonZeroFirstSite: + # 3.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 2.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 1.00┊ 0 1 2 3 ┊ + # 0 8 + + @staticmethod + def ts(): + return add_unique_sample_mutations( + tskit.Tree.generate_balanced(4, span=8).tree_sequence, start=1 + ) + + @pytest.mark.parametrize("j", [0, 1, 2, 3]) + def test_match_sample(self, j): + ts = self.ts() + h = np.zeros(4) + h[j] = 1 + m = run_match(ts, h) + assert list(m.path.left) == [0] + assert list(m.path.right) == [ts.sequence_length] + assert list(m.path.parent) == [ts.samples()[j]] + np.testing.assert_array_equal(h, m.matched_haplotype) + + def test_switch_each_sample(self): + ts = self.ts() + h = np.ones(4) + m = run_match(ts, h) + assert list(m.path.left) == [0, 3, 5, 7] + assert list(m.path.right) == [3, 5, 7, 8] + assert list(m.path.parent) == [0, 1, 2, 3] + np.testing.assert_array_equal(h, m.matched_haplotype) + + +class TestZeroSites: + @pytest.mark.parametrize("L", [1, 2, 5]) + def test_one_node_ts(self, L): + tables = tskit.TableCollection(L) + tables.nodes.add_row(time=1) + m = run_match(tables.tree_sequence(), []) + assert list(m.path.left) == [0] + assert list(m.path.right) == [L] + assert list(m.path.parent) == [0] + + +class TestMultiTreeExample: + # 0.84┊ 7 ┊ 7 ┊ + # ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊ + # 0.42┊ ┃ ┃ ┊ 6 ┃ ┊ + # ┊ ┃ ┃ ┊ ┏┻┓ ┃ ┊ + # 0.05┊ 5 ┃ ┊ ┃ ┃ ┃ ┊ + # ┊ ┏━┻┓ ┃ ┊ ┃ ┃ ┃ ┊ + # 0.04┊ ┃ 4 ┃ ┊ ┃ ┃ 4 ┊ + # ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ 0 3 1 2 ┊ + # 0 2 4 + @staticmethod + def ts(): + nodes = """\ + is_sample time + 1 0.000000 + 1 0.000000 + 1 0.000000 + 1 0.000000 + 0 0.041304 + 0 0.045967 + 0 0.416719 + 0 0.838075 + """ + edges = """\ + left right parent child + 0.000000 4.000000 4 1 + 0.000000 4.000000 4 2 + 0.000000 2.000000 5 0 + 0.000000 2.000000 5 4 + 2.000000 4.000000 6 0 + 2.000000 4.000000 6 3 + 0.000000 2.000000 7 3 + 2.000000 4.000000 7 4 + 0.000000 2.000000 7 5 + 2.000000 4.000000 7 6 + """ + ts = tskit.load_text( + nodes=io.StringIO(nodes), edges=io.StringIO(edges), strict=False + ) + return add_unique_sample_mutations(ts) + + @pytest.mark.parametrize("j", [0, 1, 2, 3]) + def test_match_sample(self, j): + ts = self.ts() + h = np.zeros(4) + h[j] = 1 + m = run_match(self.ts(), h) + assert list(m.path.left) == [0] + assert list(m.path.right) == [4] + assert list(m.path.parent) == [ts.samples()[j]] + np.testing.assert_array_equal(h, m.matched_haplotype) + + def test_switch_each_sample(self): + ts = self.ts() + h = np.ones(4) + m = run_match(ts, h) + assert list(m.path.left) == [0, 1, 2, 3] + assert list(m.path.right) == [1, 2, 3, 4] + assert list(m.path.parent) == [0, 1, 2, 3] + np.testing.assert_array_equal(h, m.matched_haplotype) + + def test_switch_each_sample_missing_flanks(self): + ts = self.ts() + h = np.ones(4) + h[0] = -1 + h[-1] = -1 + m = run_match(ts, h) + assert list(m.path.left) == [1, 2] + assert list(m.path.right) == [2, 3] + assert list(m.path.parent) == [1, 2] + np.testing.assert_array_equal(h, m.matched_haplotype) + + +class TestSimulationExamples: + def check_exact_sample_matches(self, ts): + H = ts.genotype_matrix().T + for u, h in zip(ts.samples(), H): + m = run_match(ts, h) + np.testing.assert_array_equal(h, m.matched_haplotype) + assert list(m.path.left) == [0] + assert list(m.path.right) == [ts.sequence_length] + assert list(m.path.parent) == [u] + + def check_switch_all_samples(self, ts): + h = np.ones(ts.num_sites, dtype=np.int8) + m = run_match(ts, h) + X = np.append(ts.sites_position, [ts.sequence_length]) + np.testing.assert_array_equal(h, m.matched_haplotype) + np.testing.assert_array_equal(m.path.left, X[:-1]) + np.testing.assert_array_equal(m.path.right, X[1:]) + np.testing.assert_array_equal(m.path.parent, ts.samples()) + + @pytest.mark.parametrize("n", [1, 2, 5, 10]) + def test_single_tree_exact_match(self, n): + ts = msprime.sim_ancestry(n, sequence_length=100, random_seed=2) + ts = add_unique_sample_mutations(ts) + self.check_exact_sample_matches(ts) + + @pytest.mark.parametrize("n", [1, 2, 5, 10]) + def test_multiple_trees_exact_match(self, n): + ts = msprime.sim_ancestry( + n, sequence_length=20, recombination_rate=0.1, random_seed=2234 + ) + assert ts.num_trees > 1 + ts = add_unique_sample_mutations(ts) + self.check_exact_sample_matches(ts) + + @pytest.mark.parametrize("n", [1, 2, 5, 10]) + def test_single_tree_switch_all_samples(self, n): + ts = msprime.sim_ancestry(n, sequence_length=100, random_seed=2345) + ts = add_unique_sample_mutations(ts) + self.check_switch_all_samples(ts) + + @pytest.mark.parametrize("n", [1, 2, 5, 10]) + def test_multiple_trees_switch_all_sample(self, n): + ts = msprime.sim_ancestry( + n, sequence_length=20, recombination_rate=0.1, random_seed=12234 + ) + assert ts.num_trees > 1 + ts = add_unique_sample_mutations(ts) + self.check_switch_all_samples(ts) diff --git a/tmp.py b/tmp.py new file mode 100644 index 00000000..14e4c992 --- /dev/null +++ b/tmp.py @@ -0,0 +1,142 @@ +import itertools + +import msprime +import numpy as np +import tskit + +import tsinfer + + +class Sequence: + def __init__(self, haplotype): + self.full_haplotype = haplotype + + +def run_matches(ts, positions, sequences): + match_indexes = tsinfer.MatcherIndexes(ts) + recombination = np.zeros(ts.num_sites) + 1e-9 + mismatch = np.zeros(ts.num_sites) + matcher = tsinfer.AncestorMatcher2( + match_indexes, recombination=recombination, mismatch=mismatch + ) + sites_index = np.searchsorted(positions, ts.sites_position) + assert np.all(positions[sites_index] == ts.sites_position) + sites_in_ts = np.zeros(len(positions), dtype=bool) + sites_in_ts[sites_index] = True + results = [] + for seq in sequences: + m = matcher.find_match(seq.full_haplotype[sites_in_ts]) + h = seq.full_haplotype.copy() + h[sites_in_ts] = 0 + focal_sites = np.where(h != 0)[0] + results.append((m, focal_sites)) + return results + + +def insert_matches(tables, time, all_positions, matches): + ts_sites_position = tables.sites.position + added_sites = {} + for m, new_sites in matches: + u = tables.nodes.add_row(time=time, flags=0) + for left, right, parent in m.path: + tables.edges.add_row(left, right, parent, u) + for site_index in new_sites: + if site_index not in added_sites: + s = tables.sites.add_row(all_positions[site_index], "0") + added_sites[site_index] = s + tables.mutations.add_row( + site=added_sites[site_index], node=u, derived_state="1" + ) + # print(tables) + # TODO check the matched haplotype for any mutations too. + # print(tables) + tables.sort() + ts = tables.tree_sequence() + return ts + + +def match_ancestors(ancestor_data): + tables = tskit.TableCollection(ancestor_data.sequence_length) + + all_positions = ancestor_data.sites_position[:] + + ancestors = ancestor_data.ancestors() + # Discard the "ultimate-ultimate ancestor" + next(ancestors) + ultimate_ancestor = next(ancestors) + assert np.all(ultimate_ancestor.full_haplotype == 0) + tables.nodes.add_row(time=ultimate_ancestor.time + 1) + tables.nodes.add_row(time=ultimate_ancestor.time) + tables.edges.add_row(0, tables.sequence_length, 0, 1) + ts = tables.tree_sequence() + + # TODO We don't want to use the focal sites, so we need to keep track + # of when each site gets new variation, or at least keep track of all + # the sites that are entirely ancestral, so we only add sites into the + # ts as we see variation at them. + + for time, group in itertools.groupby(ancestors, key=lambda a: a.time): + # print("EPOCH", time) + group = list(group) + matches = run_matches(ts, all_positions, group) + ts = insert_matches(tables, time, all_positions, matches) + # print(ts.draw_text()) + return ts + + +def match_samples(ts, sample_data): + all_positions = sample_data.sites_position[:] + sequences = [Sequence(h) for _, h in sample_data.haplotypes()] + matches = run_matches(ts, all_positions, sequences) + ts = insert_matches(ts.dump_tables(), 0, all_positions, matches) + tables = ts.dump_tables() + # We can have sites that are monomorphic for the ancestral state. + missing_sites = set(all_positions) - set(ts.sites_position) + for pos in missing_sites: + tables.sites.add_row(pos, ancestral_state="0") + tables.sort() + flags = tables.nodes.flags + flags[-len(sequences) :] = 1 + tables.nodes.flags = flags + print(tables) + return tables.tree_sequence() + + +if __name__ == "__main__": + for seed in range(1, 100): + ts = msprime.sim_ancestry( + 15, + population_size=1e4, # recombination_rate=1e-10, + sequence_length=1_000_000, + random_seed=seed, + ) + print(seed) + ts_orig = msprime.sim_mutations( + ts, rate=1e-8, random_seed=seed, model=msprime.BinaryMutationModel() + ) + print(ts_orig) + print(ts_orig.num_sites, ts_orig.num_mutations) + # assert ts_orig.num_sites == ts_orig.num_mutations + + # with tsinfer.SampleData(sequence_length=7, path="tmp.samples") as sample_data: + # for _ in range(5): + # sample_data.add_individual(time=0, ploidy=1) + # sample_data.add_site(0, [0, 1, 0, 0, 0], ["A", "T"]) + # sample_data.add_site(1, [0, 0, 0, 1, 1], ["G", "C"]) + # sample_data.add_site(2, [0, 1, 1, 0, 0], ["C", "A"]) + # sample_data.add_site(3, [0, 1, 1, 0, 0], ["G", "C"]) + # sample_data.add_site(4, [0, 0, 0, 1, 1], ["A", "C"]) + # sample_data.add_site(5, [0, 1, 0, 0, 0], ["T", "G"]) + # sample_data.add_site(6, [1, 1, 1, 1, 0], ["T", "G"]) + sample_data = tsinfer.SampleData.from_tree_sequence(ts_orig) + # print(sample_data) + + ad = tsinfer.generate_ancestors(sample_data) + + ts = match_ancestors(ad) + # print(sample_data) + ts = match_samples(ts, sample_data) + print(ts.draw_text()) + # print(ts.genotype_matrix()) + # print(ts_orig.genotype_matrix()) + np.testing.assert_array_equal(ts.genotype_matrix(), ts_orig.genotype_matrix()) diff --git a/tsinfer/__init__.py b/tsinfer/__init__.py index aa9d27ba..e96f36c1 100644 --- a/tsinfer/__init__.py +++ b/tsinfer/__init__.py @@ -39,4 +39,5 @@ from .eval_util import * # NOQA from .exceptions import * # NOQA from .constants import * # NOQA +from .matching import MatcherIndexes, AncestorMatcher2 # NOQA from .cli import get_cli_parser # NOQA diff --git a/tsinfer/matching.py b/tsinfer/matching.py new file mode 100644 index 00000000..0fd6c5aa --- /dev/null +++ b/tsinfer/matching.py @@ -0,0 +1,154 @@ +# +# Copyright (C) 2023 University of Oxford +# +# This file is part of tsinfer. +# +# tsinfer is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# tsinfer is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with tsinfer. If not, see . +# +import dataclasses + +import numpy as np +import tskit + +import _tsinfer + + +def add_vestigial_root(ts): + """ + Adds the nodes and edges required by tsinfer to the specified tree sequence + and returns it. + """ + if not ts.discrete_genome: + raise ValueError("Only discrete genome coords supported") + if ts.num_nodes == 0: + raise ValueError("Emtpy trees not supported") + + base_tables = ts.dump_tables() + tables = base_tables.copy() + tables.nodes.clear() + t = max(ts.nodes_time) + tables.nodes.add_row(time=t + 1) + num_additonal_nodes = 1 + tables.mutations.node += num_additonal_nodes + tables.edges.child += num_additonal_nodes + tables.edges.parent += num_additonal_nodes + for node in base_tables.nodes: + tables.nodes.append(node) + if ts.num_edges > 0: + for tree in ts.trees(): + # if tree.num_roots > 1: + # print(ts.draw_text()) + root = tree.root + num_additonal_nodes + tables.edges.add_row( + tree.interval.left, tree.interval.right, parent=0, child=root + ) + tables.edges.squash() + # FIXME probably don't need to sort here most of the time, or at least we + # can just sort almost the end of the table. + tables.sort() + return tables.tree_sequence() + + +class MatcherIndexes(_tsinfer.MatcherIndexes): + def __init__(self, ts): + # TODO make this polymorphic to accept tables as well + # This is very wasteful, but we can do better if it all basically works. + print("FIXME!") + # This is turning out to be a bit problematic for actual tsinfer'd trees + # because we have to mark things as samples to define the roots, but then + # we get multiple roots incorrectly when we mark everything as a sample. + # It's not clear that doing this is helpful for tsinfer generated trees, + # but then when we turn it off the current generator script results in + # C-level assertion trips. Hmm. + ts = add_vestigial_root(ts) + # print(ts.draw_text()) + tables = ts.dump_tables() + # print(tables) + ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length) + ll_tables.fromdict(tables.asdict()) + # TODO should really just reflect these from the low-level C values. + self.sequence_length = ts.sequence_length + self.num_sites = ts.num_sites + super().__init__(ll_tables) + + +@dataclasses.dataclass +class Path: + left: np.ndarray + right: np.ndarray + parent: np.ndarray + + def __iter__(self): + yield from zip(self.left, self.right, self.parent) + + def __len__(self): + return len(self.left) + + def assert_equals(self, other): + np.testing.assert_array_equal(self.left, other.left) + np.testing.assert_array_equal(self.right, other.right) + np.testing.assert_array_equal(self.parent, other.parent) + + +@dataclasses.dataclass +class Match: + path: Path + query_haplotype: np.ndarray + matched_haplotype: np.ndarray + + def assert_equals(self, other): + self.path.assert_equals(other.path) + np.testing.assert_array_equal(self.matched_haplotype, other.matched_haplotype) + np.testing.assert_array_equal(self.query_haplotype, other.query_haplotype) + + +class AncestorMatcher2(_tsinfer.AncestorMatcher2): + def __init__(self, matcher_indexes, **kwargs): + super().__init__(matcher_indexes, **kwargs) + self.sequence_length = matcher_indexes.sequence_length + self.num_sites = matcher_indexes.num_sites + + def zero_sites_path(self): + left = np.array([0], dtype=np.uint32) + right = np.array([self.sequence_length], dtype=np.uint32) + parent = np.array([0], dtype=np.uint32) + return Match(Path(left, right, parent), [], []) + + def find_match(self, h): + if self.num_sites == 0: + return self.zero_sites_path() + + # TODO compute these in C - taking a shortcut for now. + m = len(h) + + start = 0 + while start < m and h[start] == tskit.MISSING_DATA: + start += 1 + # if start == m: + # raise ValueError("All missing data") + end = m - 1 + while end >= 0 and h[end] == tskit.MISSING_DATA: + end -= 1 + end += 1 + + path_len, left, right, parent, matched_haplotype = self.find_path(h, start, end) + left = left[:path_len][::-1] + right = right[:path_len][::-1] + parent = parent[:path_len][::-1] + # We added a 0-root everywhere above, so convert node IDs back + parent -= 1 + # FIXME C code isn't setting match to missing as expected + matched_haplotype[:start] = tskit.MISSING_DATA + matched_haplotype[end:] = tskit.MISSING_DATA + return Match(Path(left, right, parent), h, matched_haplotype)