diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 000000000..c262fbee4 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,22 @@ +--- +name: Bug report +about: Create a report to help us improve + +--- + +**Describe the bug** +A clear and concise description of what the bug is. + +**Command** +Command line to run octopus: +```shell +$ octopus +``` + +**Desktop (please complete the following information):** + - OS: [e.g. OSX High Sierra] + - Version [e.g. v0.3.3-alpha] + - Reference [e.g. hg19] + +**Additional context** +Add any other context about the problem here. diff --git a/.gitignore b/.gitignore index b052d9fd4..03d515639 100644 --- a/.gitignore +++ b/.gitignore @@ -40,6 +40,7 @@ CTestTestfile.cmake build/src build/test +resources/forests ## Core latex/pdflatex auxiliary files: *.aux diff --git a/.travis.yml b/.travis.yml index 1fd1482c0..15f6e41dc 100644 --- a/.travis.yml +++ b/.travis.yml @@ -177,7 +177,7 @@ install: - if [[ "$TRAVIS_OS_NAME" == "linux" ]]; then sudo apt-get install python3 -qy; else - brew install python3; + brew upgrade python; fi ############################################################################ @@ -187,7 +187,7 @@ install: git clone https://github.com/samtools/htslib.git; cd htslib && autoheader && autoconf && ./configure && make && sudo make install; else - brew tap homebrew/science && brew install htslib; + brew install htslib; fi before_script: @@ -197,7 +197,7 @@ before_script: - echo "BOOST_ROOT = " ${BOOST_ROOT}; script: - - ./install.py --cxx_compiler=${COMPILER} + - ./scripts/install.py --cxx_compiler=${COMPILER} notifications: email: false \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index f85700586..1069a5360 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,7 @@ cmake_minimum_required(VERSION 3.9) +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/build/cmake/modules/") + 
include(CheckIPOSupported) project(octopus) @@ -23,6 +25,17 @@ else() message(WARNING "You are using an unsupported compiler! Compilation has only been tested with Clang and GCC.") endif() +set(default_build_type "Release") + +if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) + message(STATUS "Setting build type to '${default_build_type}' as none was specified.") + set(CMAKE_BUILD_TYPE "${default_build_type}" CACHE + STRING "Choose the type of build." FORCE) + # Set the possible values of build type for cmake-gui + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS + "Debug" "Release" "MinSizeRel" "RelWithDebInfo") +endif() + message("-- Build type is " ${CMAKE_BUILD_TYPE}) # for the main octopus executable diff --git a/Dockerfile b/Dockerfile index 9a490b36a..de3d41b1f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -61,5 +61,6 @@ WORKDIR /tmp RUN git clone -b master https://github.com/luntergroup/octopus.git WORKDIR /tmp/octopus RUN ./install.py --root --threads=2 +RUN ldconfig WORKDIR /home diff --git a/README.md b/README.md index 0bdfb2c38..66f5982ca 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ [![Build Status](https://travis-ci.org/luntergroup/octopus.svg?branch=master)](https://travis-ci.org/luntergroup/octopus) [![MIT license](http://img.shields.io/badge/license-MIT-brightgreen.svg)](http://opensource.org/licenses/MIT) [![Gitter](https://badges.gitter.im/octopus-caller/Lobby.svg)](https://gitter.im/octopus-caller/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) +[![Anaconda-Server Badge](https://anaconda.org/bioconda/octopus/badges/installer/conda.svg)](https://conda.anaconda.org/bioconda) --- @@ -10,7 +11,17 @@ --- -Octopus is a mapping-based variant caller that implements several calling models within a unified haplotype-aware framework. Octopus explicitly stores allele phasing infomation which allows haplotypes to be dynamically excluded and extended. 
Primarily this means octopus can jointly consider allele sets far exceeding the cardinality of other approaches, but perhaps more importantly, it allows *marginalisation* over posterior distributions in haplotype space at specific loci. In practise this means octopus can achieve far greater allelic genotyping accuracy than other methods, but can also infer conditional or unconditional phase probabilities directly from genotype probability distributions. This allows octopus to report consistent allele event level variant calls *and* independent phase information. +Octopus is a mapping-based variant caller that implements several calling models within a unified haplotype-aware framework. Octopus takes inspiration from particle filtering by constructing a tree of haplotypes and dynamically pruning and extending the tree based on haplotype posterior probabilities in a sequential manner. This allows octopus to implicitly consider all possible haplotypes at a given loci in reasonable time. + +There are currently five calling models implemented: + +- **individual**: call germline variants in a single healthy individual. +- **population**: jointly call germline variants in small cohorts. +- **cancer**: call germline and somatic mutations tumour samples. +- **trio**: call germline and _de novo_ mutations in a parent-offspring trio. +- **polyclone**: call variants in samples with an unknown mixture of haploid clones, such a bacteria or viral samples. + +Octopus is currently able to call SNVs, small-medium sized indels, small complex rearrangements, and micro-inversions. ## Requirements * A C++14 compiler with SSE2 support @@ -22,7 +33,7 @@ Octopus is a mapping-based variant caller that implements several calling models * Optional: * Python3 or greater -#### *Obtaining requirements on OS X* +#### Obtaining requirements on OS X On OS X, Clang is recommended. 
All requirements can be installed using the package manager [Homebrew](http://brew.sh/index.html): @@ -39,7 +50,7 @@ $ brew install python3 Note if you already have any of these packages installed via Homebrew on your system the command will fail, but you can update to the latest version using `brew upgrade`. -#### *Obtaining requirements on Ubuntu* +#### Obtaining requirements on Ubuntu Depending on your Ubuntu distribution, some requirements can be installed with `apt-get`. It may be preferable to use GCC as this will simplify installing Boost: @@ -61,67 +72,70 @@ These instructions are replicated in the [user documentation](https://github.com ## Installation -Octopus can be built and installed on a wide range of operating systems including most Unix based systems (Linux, OS X) and Windows (once MSVC is C++14 feature complete). +Octopus can be built and installed on most Unix based systems (Linux, OS X). Windows has not been tested, but should be compatible. -#### *Quick installation with Python3* +#### Conda package -Installing octopus first requires obtaining a copy the source code. In the command line, move to an appropriate install directory and execute: +Octopus is available [pre-built for Linux](https://anaconda.org/bioconda/octopus) as part of [Bioconda](https://bioconda.github.io/). To [install in an isolated environment](https://bioconda.github.io/#using-bioconda): -```shell -$ git clone https://github.com/luntergroup/octopus.git && cd octopus -``` + wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh + bash Miniconda3-latest-Linux-x86_64.sh -b -p venv + venv/bin/conda install -c conda-forge -c bioconda octopus + venv/bin/octopus -h + +A package will also be available for OSX once conda-forge and bioconda move to newer versions of gcc and boost. -The default branch is develop, which is not always stable. 
You may prefer to switch to the master branch which always has the latest release: +#### Quick installation with Python3 + +First clone the git repository in your preferred directory: ```shell -$ git checkout master +$ git clone -b master https://github.com/luntergroup/octopus.git && cd octopus ``` -Installation is easy using the Python3 install script. If your default compiler satisfies the minimum requirements just execute: +The easiest way to install octopus from source is with the Python3 install script. If your default compiler satisfies the minimum requirements just execute: ```shell -$ ./install.py +$ ./scripts/install.py ``` otherwise explicitly specify the compiler to use: ```shell -$ ./install.py --cxx_compiler /path/to/cpp/compiler # or just the compiler name if on your PATH +$ ./scripts/install.py --cxx_compiler /path/to/cpp/compiler # or just the compiler name if on your PATH ``` For example, if the requirement instructions above were used: ```shell -$ ./install.py --cxx_compiler clang++-4.0 +$ ./scripts/install.py --cxx_compiler clang++-4.0 ``` On some systems, you may also need to specify a C compiler which is the same version as your C++ compiler, otherwise you'll get lots of link errors. This can be done with the `--c_compiler` option: ```shell -$ ./install.py --cxx_compiler g++-7 --c_compiler gcc-7 +$ ./scripts/install.py -cxx g++-7 -c gcc-7 ``` By default this installs to `/bin` relative to where you installed octopus. To install to a root directory (e.g. `/usr/local/bin`) use: ```shell -$ ./install.py --root +$ ./scripts/install.py --root ``` -this may prompt you to enter a `sudo` password. - -If anything goes wrong with the build process and you need to start again, be sure to add the command `--clean`! +If anything goes wrong with the build process and you need to start again, be sure to add the command `--clean`. 
-#### *Installing with CMake* +#### Installing with CMake If Python3 isn't available, the binaries can be installed manually with [CMake](https://cmake.org): ```shell -$ git clone https://github.com/luntergroup/octopus.git +$ git clone -b master https://github.com/luntergroup/octopus.git $ cd octopus/build $ cmake .. && make install ``` -By default this installs to the `/bin` directory where octopus was installed. To install to root (e.g. `/usr/local/bin`) use the `-D` option: +To install to root (e.g. `/usr/local/bin`) use the `-D` option: ```shell $ cmake -DINSTALL_ROOT=ON .. @@ -136,7 +150,7 @@ $ cmake -D CMAKE_C_COMPILER=clang-4.0 -D CMAKE_CXX_COMPILER=clang++-4.0 .. You can check installation was successful by executing the command: ```shell -$ octopus --help +$ octopus -h ``` ## Running Tests @@ -149,11 +163,11 @@ $ test/install.py ## Examples -Here are some common use-cases to get started. These examples are by no means exhaustive, please consult the documentation for explanations of all options, algorithms, and further examples. For a more in depth example, refer to the [whole genome germline calling case study](https://github.com/luntergroup/octopus/blob/master/doc/octopus_wgs_case_study.md). +Here are some common use-cases to get started. These examples are by no means exhaustive, please consult the documentation for explanations of all options, algorithms, and further examples. For more in depth examples, refer to the [case studies](https://github.com/luntergroup/octopus/wiki/Case-studies). Note by default octopus will output all calls in VCF format to standard output, in order to write calls to a file (`.vcf`, `.vcf.gz`, and `.bcf` are supported), use the command line option `--output` (`-o`). 
-#### *Calling germline variants in an individual* +#### Calling germline variants in an individual This is the simplest case, if the file `NA12878.bam` contains a single sample, octopus will default to its individual calling model: @@ -173,21 +187,23 @@ By default, octopus automatically detects and calls all samples contained in the $ octopus -R hs37d5.fa -I multi-sample.bam -S NA12878 ``` -#### *Targeted calling* +#### Targeted calling -By default, octopus will call all possible regions (as specified in the reference FASTA). In order to select a set of target regions, use the `--regions` (`-T`) option: +By default, octopus will call all regions specified in the reference index. In order to restrict calling to a subset of regions, either provide a list of zero-indexed regions in the format `chr:start-end` (`--regions`; `-T`), or a file containing a list of regions in either standard format or BED format (`--regions-file`; `-t`): ```shell $ octopus -R hs37d5.fa -I NA12878.bam -T 1 2:30,000,000- 3:10,000,000-20,000,000 +$ octopus -R hs37d5.fa -I NA12878.bam -t regions.bed ``` -Or conversely a set of regions to *exclude* can be given with `--skip-regions` (`-K`): +Conversely a set of regions to *exclude* can be given explictely (`--skip-regions`;`-K`), or with a file (`--skip-regions-file`; `-k`): ```shell $ octopus -R hs37d5.fa -I NA12878.bam -K 1 2:30,000,000- 3:10,000,000-20,000,000 +$ octopus -R hs37d5.fa -I NA12878.bam -k skip-regions.bed ``` -#### *Calling de novo mutations in a trio* +#### Calling de novo mutations in a trio To call germline and de novo mutations in a trio, either specify both maternal (`--maternal-sample`; `-M`) and paternal (`--paternal-sample`; `-F`) samples: @@ -195,13 +211,13 @@ To call germline and de novo mutations in a trio, either specify both maternal ( $ octopus -R hs37d5.fa -I NA12878.bam NA12891.bam NA12892.bam -M NA12892 -F NA12891 ``` -The trio can also be specified with a PED file: +or provide a PED file which defines the trio: 
```shell $ octopus -R hs37d5.fa -I NA12878.bam NA12891.bam NA12892.bam --pedigree ceu_trio.ped ``` -#### *Calling somatic mutations in tumours* +#### Calling somatic mutations in tumours To call germline and somatic mutations in a paired tumour-normal sample, just specify which sample is the normal (`--normal-sample`; `-N`): @@ -212,7 +228,7 @@ $ octopus -R hs37d5.fa -I normal.bam tumour.bam --normal-sample NORMAL It is also possible to genotype multiple tumours from the same individual jointly: ```shell -$ octopus -R hs37d5.fa -I normal.bam tumourA.bam tumourB --normal-sample NORMAL +$ octopus -R hs37d5.fa -I normal.bam tumourA.bam tumourB.bam --normal-sample NORMAL ``` If a normal sample is not present the cancer calling model must be invoked explicitly: @@ -221,9 +237,9 @@ If a normal sample is not present the cancer calling model must be invoked expli $ octopus -R hs37d5.fa -I tumour1.bam tumour2.bam -C cancer ``` -Note however, that without a normal sample, somatic mutation classification power is significantly reduced. +Be aware that without a normal sample, somatic mutation classification power is significantly reduced. -#### *Joint variant calling (in development)* +#### Joint variant calling (experimental) Multiple samples from the same population, without pedigree information, can be called jointly: @@ -233,17 +249,27 @@ $ octopus -R hs37d5.fa -I NA12878.bam NA12891.bam NA12892.bam Joint calling samples may increase calling power, especially for low coverage sequencing. -#### *HLA genotyping* +#### Calling variants in mixed haploid samples (experimental) + +If your sample contains an unknown mix of haploid clones (e.g. some bacteria or viral samples), use the `polyclone` calling model: + +```shell +$ octopus -R H37Rv.fa -I mycobacterium_tuberculosis.bam -C polyclone +``` + +This model will automatically detect the number of subclones in your sample (up to the maximum given by `--max-clones`). 
+ +#### HLA genotyping To call phased HLA genotypes, increase the default phase level: ```shell -$ octopus -R human.fa -I NA12878.bam -t hla-regions.txt -l aggressive +$ octopus -R hs37d5.fa -I NA12878.bam -t hla-regions.bed -l aggressive ``` -#### *Multithreaded calling* +#### Multithreaded calling -Octopus has built in multithreading capacbilities, just add the `--threads` command: +Octopus has built in multithreading capabilities, just add the `--threads` command: ```shell $ octopus -R hs37d5.fa -I NA12878.bam --threads @@ -252,12 +278,12 @@ $ octopus -R hs37d5.fa -I NA12878.bam --threads This will let octopus automatically decide how many threads to use, and is the recommended approach as octopus can dynamically juggle thread usage at an algorithm level. However, a strict upper limit on the number of threads can also be used: ```shell -$ octopus -R hs37d5.fa -I NA12878.bam --threads=4 +$ octopus -R hs37d5.fa -I NA12878.bam --threads 4 ``` -#### *Fast calling* +#### Fast calling -By default, octopus is geared towards more accurate variant calling which requires the use of complex (slow) algorithms. However, to acheive faster runtimes (at the cost of decreased calling accuray) many of these features can be disabled. There are two helper commands that setup octopus for faster variant calling, `--fast` and `--very-fast`, e.g.: +By default, octopus is geared towards more accurate variant calling which requires the use of complex (slow) algorithms. However, to achieve faster runtimes (at the cost of decreased calling accuracy) many of these features can be disabled. There are two helper commands that setup octopus for faster variant calling, `--fast` and `--very-fast`, e.g.: ```shell $ octopus -R hs37d5.fa -I NA12878.bam --fast @@ -265,6 +291,22 @@ $ octopus -R hs37d5.fa -I NA12878.bam --fast Note this does not turn on multithreading or increase buffer sizes. +#### Making evidence BAMs + +Octopus can generate 'evidence' BAMs for single sample calling. 
To generate a single BAM file containing realigned reads supporting called variants use the `--bamout` option: + +```shell +$ octopus -R hs37d5.fa -I NA12878.bam -o octopus.vcf --bamout octopus.bam +``` + +To generate split BAM files (one for each called haplotype) use the `--bamout` option, but specify only the file prefix: + +```shell +$ octopus -R hs37d5.fa -I NA12878.bam -o octopus.vcf --bamout octopus +``` + +Octopus will generate BAM files (`octopus1.bam`, `octopus2.bam`, ...) for the number of haplotypes in the sample. Note that although each split BAM is haploid, the variants in each are only phased according to the phase sets called in the output VCF. + ## Output format Octopus outputs variants using a simple but rich VCF format (see [user documentation](https://github.com/luntergroup/octopus/blob/develop/doc/manuals/user/octopus-user-manual.pdf) for full details). For example, two overlapping deletions are represented like: diff --git a/doc/manuals/user/octopus-user-manual.pdf b/doc/manuals/user/octopus-user-manual.pdf index afc066d64..9999b9cd0 100644 Binary files a/doc/manuals/user/octopus-user-manual.pdf and b/doc/manuals/user/octopus-user-manual.pdf differ diff --git a/doc/octopus_wgs_case_study.md b/doc/octopus_wgs_case_study.md deleted file mode 100644 index 02ebd7b78..000000000 --- a/doc/octopus_wgs_case_study.md +++ /dev/null @@ -1,94 +0,0 @@ -# Whole genome germline case study - -Here we will work through a real whole genome calling case study, from FASTQ to VCF evaluation. 
In addition to octopus, we make use of the following software tools: - -* [samtools](http://samtools.sourceforge.net) (version 1.6) -* [BWA](http://bio-bwa.sourceforge.net) (versio 0.7.17-r1188) -* [RTG Tools](https://www.realtimegenomics.com/products/rtg-tools) (version 3.8.4) - -## Download data files - -First download raw reads from the Illumina platinum genomes project for individual NA12878: - -``` -mkdir ~/data/fastq && cd ~/data/fastq -wget https://storage.googleapis.com/genomics-public-data/platinum-genomes/fastq/ERR194147_1.fastq.gz -wget https://storage.googleapis.com/genomics-public-data/platinum-genomes/fastq/ERR194147_2.fastq.gz -``` - -Next download a copy of the human reference sequence. In this example we use GRCh37 plus a decoy contig (recommended). If you prefer to use GRCh38, be sure to get a copy without alternative contigs or patches (but with a decoy contig), such as the one available [here](ftp://ftp.ncbi.nlm.nih.gov/genomes/all/GCA/000/001/405/GCA_000001405.15_GRCh38/seqs_for_alignment_pipelines.ucsc_ids/GCA_000001405.15_GRCh38_no_alt_plus_hs38d1_analysis_set.fna.gz). - -``` -mkdir ~/data/reference && cd ~/data/reference -wget ftp://ftp-trace.ncbi.nih.gov/1000genomes/ftp/technical/reference/phase2_reference_assembly_sequence/hs37d5.fa.gz && gzip -d hs37d5.fa.gz -``` - -To evaluate our calls we need a truth set. 
We use the Genome In a Bottle (GIAB) version 3.3.2 high confidence calls for NA12878 (HG001): - -``` -mkdir ~/data/vcf/giab && cd ~/data/vcf/giab -wget ftp://ftp-trace.ncbi.nlm.nih.gov//giab/ftp/release/NA12878_HG001/NISTv3.3.2/GRCh37 -``` - -## Map reads to reference genome - -First we need to index the reference sequence: - -``` -cd ~/data/reference -samtools faidx hs37d5.fa -bwa index hs37d5.fa -``` - -Then map our reads to the reference: - -``` -mkdir ~/data/bam -bwa mem -t 15 -R "@RG\tID:NA12878\tSM:NA12878\tLB:platinum\tPU:illumina" \ - ~/data/reference/hs37d5.fa \ - ~/data/fastq/ERR194147_1.fastq.gz ~/data/fastq/ERR194147_2.fastq.gz \ - | samtools view -bh > ~/data/bam/NA12878.platinum.b37.unsorted.bam -samtools sort -@ 15 -o ~/data/bam/NA12878.platinum.b37.bam ~/data/bam/NA12878.platinum.b37.unsorted.bam -samtools index ~/data/bam/NA12878.platinum.b37.bam -rm ~/data/bam/NA12878.platinum.b37.unsorted.bam -``` - -## Call variants - -We do not recommend pre-processing the raw BWA alignments (e.g. duplicate marking, or base quality score recalibration) as we do not find this provides consistent improvements in accuracy, and tends to slow down calling as pre-processed reads files are often considerably larger than the originals. As this is human data, the default arguments for octopus should work well. We restrict calling to the autosomes plus X as these are the only contigs present in the validation sets. We also request a 'legacy' VCF file to use for benchmarking (see section on octopus's default VCF format). - -``` -octopus -R ~/data/reference/hs37d5.fa \ - -I ~/data/bam/NA12878.platinum.b37.bam \ - -o ~/data/vcf/NA12878.platinum.b37.octopus.vcf.gz \ - -T 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 X \ - --threads 15 --legacy -``` - -## Evaluate variant calls - -Finally, we will evaluate our calls with RTG Tools `vcfeval`. 
This command requires the reference sequence to be preprocessed: - -``` -~/tools/rtgtools/rtg format -o ~/data/reference/hs37d5_sdf ~/data/reference/hs37d5.fa -``` - -Then run vcfeval: - -``` -rtg vcfeval -t ~/data/reference/hs37d5_sdf \ - -b ~/data/vcf/giab/HG001_GRCh37_truth.vcf.gz \ - --evaluation-regions ~/data/vcf/giab/HG001_GRCh37_hiconf.bed \ - -c ~/data/vcf/NA12878.platinum.b37.octopus.legacy.vcf.gz \ - -o ~/benchmarks/NA12878.platinum.b37.octopus.eval \ - --ref-overlap -f QUAL -``` - -We see the following results: - -``` -Threshold True-pos-baseline True-pos-call False-pos False-neg Precision Sensitivity F-measure ----------------------------------------------------------------------------------------------------- - 19.470 3678469 3699288 7611 12392 0.9979 0.9966 0.9973 - None 3679800 3700790 9280 11061 0.9975 0.9970 0.9973 -``` \ No newline at end of file diff --git a/install.py b/install.py deleted file mode 100755 index c9eebc0c8..000000000 --- a/install.py +++ /dev/null @@ -1,96 +0,0 @@ -#!/usr/bin/env python3 - -import os -import sys -from subprocess import call -import platform -import argparse -from shutil import move, rmtree -import multiprocessing - -def is_unix(): - system = platform.system() - return system == "Darwin" or system == "Linux" - -parser = argparse.ArgumentParser() -parser.add_argument('--clean', help='Do a clean install', action='store_true') -parser.add_argument('--root', help='Install into /usr/local/bin', action='store_true') -parser.add_argument('-c', '--c_compiler', help='C compiler path to use') -parser.add_argument('-cxx', '--cxx_compiler', help='C++ compiler path to use') -parser.add_argument('--keep_cache', help='Do not refresh CMake cache', action='store_true') -parser.add_argument('--debug', help='Builds in debug mode', action='store_true') -parser.add_argument('--sanitize', help='Builds in release mode with sanitize flags', action='store_true') -parser.add_argument('--static', help='Builds using static libraries', 
action='store_true') -parser.add_argument('--threads', help='The number of threads to use for building', type=int) -parser.add_argument('--boost', help='The Boost library root') -parser.add_argument('--verbose', help='Ouput verbose make information', action='store_true') -args = vars(parser.parse_args()) - -octopus_dir = os.path.dirname(os.path.realpath(__file__)) -root_cmake = octopus_dir + "/CMakeLists.txt" - -if not os.path.exists(root_cmake): - print("octopus source directory corrupted: root CMakeLists.txt is missing. Please re-download source code.") - exit(1) - -octopus_build_dir = octopus_dir + "/build" - -if not os.path.exists(octopus_build_dir): - print("octopus source directory corrupted: build directory is missing. Please re-download source code.") - exit(1) - -bin_dir = octopus_dir + "/bin" - -if not os.path.exists(bin_dir): - print("No bin directory found, making one") - os.makedirs(bin_dir) - -if args["clean"]: - print("Cleaning build directory") - move(octopus_build_dir + "/cmake", octopus_dir + "/cmake") - rmtree(octopus_build_dir) - os.makedirs(octopus_build_dir) - move(octopus_dir + "/cmake", octopus_build_dir + "/cmake") - -cmake_cache_file = "CMakeCache.txt" -os.chdir(octopus_build_dir) # so cmake doesn't pollute root directory - -if not args["keep_cache"] and os.path.exists(cmake_cache_file): - os.remove(cmake_cache_file) - -cmake_options = [] -if args["root"]: - cmake_options.extend(["-DINSTALL_ROOT=ON", octopus_dir]) -if args["c_compiler"]: - cmake_options.append("-DCMAKE_C_COMPILER=" + args["c_compiler"]) -if args["cxx_compiler"]: - cmake_options.append("-DCMAKE_CXX_COMPILER=" + args["cxx_compiler"]) -if args["debug"]: - cmake_options.append("-DCMAKE_BUILD_TYPE=Debug") -elif args["sanitize"]: - cmake_options.append("-DCMAKE_BUILD_TYPE=RelWithDebInfo") -else: - cmake_options.append("-DCMAKE_BUILD_TYPE=Release") -if args["static"]: - cmake_options.append("-DBUILD_SHARED_LIBS=OFF") -if args["boost"]: - cmake_options.append("-DBOOST_ROOT=" + 
args["boost"]) -if args["verbose"]: - cmake_options.append("CMAKE_VERBOSE_MAKEFILE:BOOL=ON") - -ret = call(["cmake"] + cmake_options + [".."]) - -if ret == 0: - make_options = [] - if args["threads"]: - if (args["threads"] > 1): - make_options.append("-j" + str(args["threads"])) - else: - make_options.append("-j" + str(multiprocessing.cpu_count())) - - if is_unix(): - ret = call(["make", "install"] + make_options) - else: - print("Windows make files not supported. Build files have been written to " + octopus_build_dir) - -sys.exit(ret) diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index df31d8a19..05d77a976 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1 +1,2 @@ -add_subdirectory(tandem) \ No newline at end of file +add_subdirectory(tandem) +add_subdirectory(ranger) \ No newline at end of file diff --git a/lib/ranger/CMakeLists.txt b/lib/ranger/CMakeLists.txt new file mode 100644 index 000000000..f98cfbc40 --- /dev/null +++ b/lib/ranger/CMakeLists.txt @@ -0,0 +1,40 @@ +set(RANGER_SOURCES + Data.h + Data.cpp + DataChar.h + DataChar.cpp + DataDouble.h + DataFloat.h + DataFloat.cpp + Forest.h + Forest.cpp + ForestClassification.h + ForestClassification.cpp + ForestProbability.h + ForestProbability.cpp + ForestRegression.h + ForestRegression.cpp + ForestSurvival.h + ForestSurvival.cpp + globals.h + Tree.h + Tree.cpp + TreeClassification.h + TreeClassification.cpp + TreeProbability.h + TreeProbability.cpp + TreeRegression.h + TreeRegression.cpp + TreeSurvival.h + TreeSurvival.cpp + utility.h + utility.cpp) + +add_library(ranger STATIC ${RANGER_SOURCES}) + +set(WarningIgnores + -Wno-unused-parameter + -Wno-unused-function + -Wno-missing-braces) + +add_compile_options(-Wall -Wextra -Werror ${WarningIgnores}) diff --git a/lib/ranger/Data.cpp b/lib/ranger/Data.cpp new file mode 100644 index 000000000..16cba5197 --- /dev/null +++ b/lib/ranger/Data.cpp @@ -0,0 +1,272 @@ +/*------------------------------------------------------------------------------- + 
This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. + #-------------------------------------------------------------------------------*/ + +#include +#include +#include +#include +#include + +#include "Data.h" +#include "utility.h" + +namespace ranger { + +Data::Data() : + num_rows(0), num_rows_rounded(0), num_cols(0), snp_data(0), num_cols_no_snp(0), externalData(true), index_data(0), max_num_unique_values( + 0), order_snps(false) { +} + +size_t Data::getVariableID(const std::string& variable_name) const { + auto it = std::find(variable_names.cbegin(), variable_names.cend(), variable_name); + if (it == variable_names.cend()) { + throw std::runtime_error("Variable " + variable_name + " not found."); + } + return (std::distance(variable_names.cbegin(), it)); +} + +void Data::addSnpData(unsigned char* snp_data, size_t num_cols_snp) { + num_cols = num_cols_no_snp + num_cols_snp; + num_rows_rounded = roundToNextMultiple(num_rows, 4); + this->snp_data = snp_data; +} + +// #nocov start +bool Data::loadFromFile(std::string filename) { + + bool result; + + // Open input file + std::ifstream input_file; + input_file.open(filename); + if (!input_file.good()) { + throw std::runtime_error("Could not open input file."); + } + + // Count number of rows + size_t line_count = 0; + std::string line; + while (getline(input_file, line)) { + ++line_count; + } + num_rows = line_count - 1; + input_file.close(); + input_file.open(filename); + + // Check if comma, semicolon or whitespace seperated + std::string header_line; + getline(input_file, header_line); + + // Find out if comma, semicolon or whitespace seperated and call appropriate method + if (header_line.find(",") != std::string::npos) { + result = loadFromFileOther(input_file, header_line, ','); 
+ } else if (header_line.find(";") != std::string::npos) { + result = loadFromFileOther(input_file, header_line, ';'); + } else { + result = loadFromFileWhitespace(input_file, header_line); + } + + externalData = false; + input_file.close(); + return result; +} + +bool Data::loadFromFileWhitespace(std::ifstream& input_file, std::string header_line) { + + // Read header + std::string header_token; + std::stringstream header_line_stream(header_line); + while (header_line_stream >> header_token) { + variable_names.push_back(header_token); + } + num_cols = variable_names.size(); + num_cols_no_snp = num_cols; + + // Read body + reserveMemory(); + bool error = false; + std::string line; + size_t row = 0; + while (getline(input_file, line)) { + double token; + std::stringstream line_stream(line); + size_t column = 0; + while (line_stream >> token) { + set(column, row, token, error); + ++column; + } + if (column > num_cols) { + throw std::runtime_error("Could not open input file. Too many columns in a row."); + } else if (column < num_cols) { + throw std::runtime_error("Could not open input file. Too few columns in a row. 
Are all values numeric?"); + } + ++row; + } + num_rows = row; + return error; +} + +bool Data::loadFromFileOther(std::ifstream& input_file, std::string header_line, char seperator) { + + // Read header + std::string header_token; + std::stringstream header_line_stream(header_line); + while (getline(header_line_stream, header_token, seperator)) { + variable_names.push_back(header_token); + } + num_cols = variable_names.size(); + num_cols_no_snp = num_cols; + + // Read body + reserveMemory(); + bool error = false; + std::string line; + size_t row = 0; + while (getline(input_file, line)) { + std::string token_string; + double token; + std::stringstream line_stream(line); + size_t column = 0; + while (getline(line_stream, token_string, seperator)) { + std::stringstream token_stream(token_string); + token_stream >> token; + set(column, row, token, error); + ++column; + } + ++row; + } + num_rows = row; + return error; +} +// #nocov end + +void Data::getAllValues(std::vector& all_values, std::vector& sampleIDs, size_t varID) const { + + // All values for varID (no duplicates) for given sampleIDs + if (getUnpermutedVarID(varID) < num_cols_no_snp) { + + all_values.reserve(sampleIDs.size()); + for (size_t i = 0; i < sampleIDs.size(); ++i) { + all_values.push_back(get(sampleIDs[i], varID)); + } + std::sort(all_values.begin(), all_values.end()); + all_values.erase(std::unique(all_values.begin(), all_values.end()), all_values.end()); + } else { + // If GWA data just use 0, 1, 2 + all_values = std::vector( { 0, 1, 2 }); + } +} + +void Data::getMinMaxValues(double& min, double&max, std::vector& sampleIDs, size_t varID) const { + if (sampleIDs.size() > 0) { + min = get(sampleIDs[0], varID); + max = min; + } + for (size_t i = 1; i < sampleIDs.size(); ++i) { + double value = get(sampleIDs[i], varID); + if (value < min) { + min = value; + } + if (value > max) { + max = value; + } + } +} + +void Data::sort() { + + // Reserve memory + index_data.resize(num_cols_no_snp * num_rows); + + 
// For all columns, get unique values and save index for each observation + for (size_t col = 0; col < num_cols_no_snp; ++col) { + + // Get all unique values + std::vector unique_values(num_rows); + for (size_t row = 0; row < num_rows; ++row) { + unique_values[row] = get(row, col); + } + std::sort(unique_values.begin(), unique_values.end()); + unique_values.erase(unique(unique_values.begin(), unique_values.end()), unique_values.end()); + + // Get index of unique value + for (size_t row = 0; row < num_rows; ++row) { + size_t idx = std::lower_bound(unique_values.begin(), unique_values.end(), get(row, col)) - unique_values.begin(); + index_data[col * num_rows + row] = idx; + } + + // Save unique values + unique_data_values.push_back(unique_values); + if (unique_values.size() > max_num_unique_values) { + max_num_unique_values = unique_values.size(); + } + } +} + +// TODO: Implement ordering for multiclass and survival +void Data::orderSnpLevels(std::string dependent_variable_name, bool corrected_importance) { + // Stop if now SNP data + if (snp_data == 0) { + return; + } + + size_t dependent_varID = getVariableID(dependent_variable_name); + size_t num_snps; + if (corrected_importance) { + num_snps = 2 * (num_cols - num_cols_no_snp); + } else { + num_snps = num_cols - num_cols_no_snp; + } + + // Reserve space + snp_order.resize(num_snps, std::vector(3)); + + // For each SNP + for (size_t i = 0; i < num_snps; ++i) { + size_t col = i; + if (i >= (num_cols - num_cols_no_snp)) { + // Get unpermuted SNP ID + col = i - num_cols + num_cols_no_snp; + } + + // Order by mean response + std::vector means(3, 0); + std::vector counts(3, 0); + for (size_t row = 0; row < num_rows; ++row) { + size_t row_permuted = row; + if (i >= (num_cols - num_cols_no_snp)) { + row_permuted = getPermutedSampleID(row); + } + size_t idx = col * num_rows_rounded + row_permuted; + size_t value = (((snp_data[idx / 4] & mask[idx % 4]) >> offset[idx % 4]) - 1); + + // TODO: Better way to treat missing 
values? + if (value > 2) { + value = 0; + } + + means[value] += get(row, dependent_varID); + ++counts[value]; + } + + for (size_t value = 0; value < 3; ++value) { + means[value] /= counts[value]; + } + + // Save order + snp_order[i] = order(means, false); + } + + order_snps = true; +} + +} // namespace ranger + diff --git a/lib/ranger/Data.h b/lib/ranger/Data.h new file mode 100644 index 000000000..73e923e6b --- /dev/null +++ b/lib/ranger/Data.h @@ -0,0 +1,235 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. + #-------------------------------------------------------------------------------*/ + +#ifndef DATA_H_ +#define DATA_H_ + +#include +#include +#include +#include +#include + +#include "globals.h" + +namespace ranger { + +class Data { +public: + Data(); + + Data(const Data&) = delete; + Data& operator=(const Data&) = delete; + + virtual ~Data() = default; + + virtual double get(size_t row, size_t col) const = 0; + + size_t getVariableID(const std::string& variable_name) const; + + virtual void reserveMemory() = 0; + virtual void set(size_t col, size_t row, double value, bool& error) = 0; + + void addSnpData(unsigned char* snp_data, size_t num_cols_snp); + + bool loadFromFile(std::string filename); + bool loadFromFileWhitespace(std::ifstream& input_file, std::string header_line); + bool loadFromFileOther(std::ifstream& input_file, std::string header_line, char seperator); + + void getAllValues(std::vector& all_values, std::vector& sampleIDs, size_t varID) const; + + void getMinMaxValues(double& min, double&max, std::vector& sampleIDs, size_t varID) const; + + size_t getIndex(size_t row, size_t col) const { + // Use permuted data for 
corrected impurity importance + size_t col_permuted = col; + if (col >= num_cols) { + col = getUnpermutedVarID(col); + row = getPermutedSampleID(row); + } + + if (col < num_cols_no_snp) { + return index_data[col * num_rows + row]; + } else { + return getSnp(row, col, col_permuted); + } + } + + size_t getSnp(size_t row, size_t col, size_t col_permuted) const { + // Get data out of snp storage. -1 because of GenABEL coding. + size_t idx = (col - num_cols_no_snp) * num_rows_rounded + row; + size_t result = ((snp_data[idx / 4] & mask[idx % 4]) >> offset[idx % 4]) - 1; + + // TODO: Better way to treat missing values? + if (result > 2) { + result = 0; + } + + // Order SNPs + if (order_snps) { + if (col_permuted >= num_cols) { + result = snp_order[col_permuted + no_split_variables.size() - 2 * num_cols_no_snp][result]; + } else { + result = snp_order[col - num_cols_no_snp][result]; + } + } + return result; + } + + double getUniqueDataValue(size_t varID, size_t index) const { + // Use permuted data for corrected impurity importance + if (varID >= num_cols) { + varID = getUnpermutedVarID(varID); + } + + if (varID < num_cols_no_snp) { + return unique_data_values[varID][index]; + } else { + // For GWAS data the index is the value + return (index); + } + } + + size_t getNumUniqueDataValues(size_t varID) const { + // Use permuted data for corrected impurity importance + if (varID >= num_cols) { + varID = getUnpermutedVarID(varID); + } + + if (varID < num_cols_no_snp) { + return unique_data_values[varID].size(); + } else { + // For GWAS data 0,1,2 + return (3); + } + } + + void sort(); + + void orderSnpLevels(std::string dependent_variable_name, bool corrected_importance); + + const std::vector& getVariableNames() const { + return variable_names; + } + size_t getNumCols() const { + return num_cols; + } + size_t getNumRows() const { + return num_rows; + } + + size_t getMaxNumUniqueValues() const { + if (snp_data == 0 || max_num_unique_values > 3) { + // If no snp data or one 
variable with more than 3 unique values, return that value + return max_num_unique_values; + } else { + // If snp data and no variable with more than 3 unique values, return 3 + return 3; + } + } + + const std::vector& getNoSplitVariables() const noexcept { + return no_split_variables; + } + + void addNoSplitVariable(size_t varID) { + no_split_variables.push_back(varID); + std::sort(no_split_variables.begin(), no_split_variables.end()); + } + + std::vector& getIsOrderedVariable() noexcept { + return is_ordered_variable; + } + + void setIsOrderedVariable(const std::vector& unordered_variable_names) { + is_ordered_variable.resize(num_cols, true); + for (auto& variable_name : unordered_variable_names) { + size_t varID = getVariableID(variable_name); + is_ordered_variable[varID] = false; + } + } + + void setIsOrderedVariable(std::vector& is_ordered_variable) { + this->is_ordered_variable = is_ordered_variable; + } + + bool isOrderedVariable(size_t varID) const { + // Use permuted data for corrected impurity importance + if (varID >= num_cols) { + varID = getUnpermutedVarID(varID); + } + return is_ordered_variable[varID]; + } + + void permuteSampleIDs(std::mt19937_64 random_number_generator) { + permuted_sampleIDs.resize(num_rows); + std::iota(permuted_sampleIDs.begin(), permuted_sampleIDs.end(), 0); + std::shuffle(permuted_sampleIDs.begin(), permuted_sampleIDs.end(), random_number_generator); + } + + size_t getPermutedSampleID(size_t sampleID) const { + return permuted_sampleIDs[sampleID]; + } + + size_t getUnpermutedVarID(size_t varID) const { + if (varID >= num_cols) { + varID -= num_cols; + + for (auto& skip : no_split_variables) { + if (varID >= skip) { + ++varID; + } + } + } + return varID; + } + + const std::vector>& getSnpOrder() const { + return snp_order; + } + + void setSnpOrder(std::vector>& snp_order) { + this->snp_order = snp_order; + order_snps = true; + } + +protected: + std::vector variable_names; + size_t num_rows; + size_t num_rows_rounded; + size_t 
num_cols; + + unsigned char* snp_data; + size_t num_cols_no_snp; + + bool externalData; + + std::vector index_data; + std::vector> unique_data_values; + size_t max_num_unique_values; + + // Variable to not split at (only dependent_varID for non-survival trees) + std::vector no_split_variables; + + // For each varID true if ordered + std::vector is_ordered_variable; + + // Permuted samples for corrected impurity importance + std::vector permuted_sampleIDs; + + // Order of 0/1/2 for ordered splitting + std::vector> snp_order; + bool order_snps; +}; + +} // namespace ranger + +#endif /* DATA_H_ */ diff --git a/lib/ranger/DataChar.cpp b/lib/ranger/DataChar.cpp new file mode 100644 index 000000000..24169c880 --- /dev/null +++ b/lib/ranger/DataChar.cpp @@ -0,0 +1,48 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. 
+ #-------------------------------------------------------------------------------*/ + +// Ignore in coverage report (not used in R package) +// #nocov start +#include +#include +#include + +#include "DataChar.h" + +namespace ranger { + +DataChar::DataChar(double* data_double, std::vector variable_names, size_t num_rows, size_t num_cols, + bool& error) { + this->variable_names = variable_names; + this->num_rows = num_rows; + this->num_cols = num_cols; + this->num_cols_no_snp = num_cols; + + reserveMemory(); + + // Save data and report errors + for (size_t i = 0; i < num_cols; ++i) { + for (size_t j = 0; j < num_rows; ++j) { + double value = data_double[i * num_rows + j]; + if (value > CHAR_MAX || value < CHAR_MIN) { + error = true; + } + if (floor(value) != ceil(value)) { + error = true; + } + data[i * num_rows + j] = value; + } + } +} + +} // namespace ranger + +// #nocov end diff --git a/lib/ranger/DataChar.h b/lib/ranger/DataChar.h new file mode 100644 index 000000000..b0f4e0dae --- /dev/null +++ b/lib/ranger/DataChar.h @@ -0,0 +1,71 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. 
+ #-------------------------------------------------------------------------------*/ + +// Ignore in coverage report (not used in R package) +// #nocov start +#ifndef DATACHAR_H_ +#define DATACHAR_H_ + +#include +#include + +#include "globals.h" +#include "Data.h" + +namespace ranger { + +class DataChar: public Data { +public: + DataChar() = default; + DataChar(double* data_double, std::vector variable_names, size_t num_rows, size_t num_cols, bool& error); + + DataChar(const DataChar&) = delete; + DataChar& operator=(const DataChar&) = delete; + + virtual ~DataChar() override = default; + + double get(size_t row, size_t col) const override { + // Use permuted data for corrected impurity importance + size_t col_permuted = col; + if (col >= num_cols) { + col = getUnpermutedVarID(col); + row = getPermutedSampleID(row); + } + + if (col < num_cols_no_snp) { + return data[col * num_rows + row]; + } else { + return getSnp(row, col, col_permuted); + } + } + + void reserveMemory() override { + data.resize(num_cols * num_rows); + } + + void set(size_t col, size_t row, double value, bool& error) override { + if (value > CHAR_MAX || value < CHAR_MIN) { + error = true; + } + if (floor(value) != ceil(value)) { + error = true; + } + data[col * num_rows + row] = value; + } + +private: + std::vector data; +}; + +} // namespace ranger + +#endif /* DATACHAR_H_ */ +// #nocov end diff --git a/lib/ranger/DataDouble.h b/lib/ranger/DataDouble.h new file mode 100644 index 000000000..aa8713053 --- /dev/null +++ b/lib/ranger/DataDouble.h @@ -0,0 +1,69 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. 
+ #-------------------------------------------------------------------------------*/ + +#ifndef DATADOUBLE_H_ +#define DATADOUBLE_H_ + +#include +#include + +#include "globals.h" +#include "utility.h" +#include "Data.h" + +namespace ranger { + +class DataDouble: public Data { +public: + DataDouble() = default; + DataDouble(std::vector data, std::vector variable_names, size_t num_rows, size_t num_cols) : + data { std::move(data) } { + this->variable_names = variable_names; + this->num_rows = num_rows; + this->num_cols = num_cols; + this->num_cols_no_snp = num_cols; + } + + DataDouble(const DataDouble&) = delete; + DataDouble& operator=(const DataDouble&) = delete; + + virtual ~DataDouble() override = default; + + double get(size_t row, size_t col) const override { + // Use permuted data for corrected impurity importance + size_t col_permuted = col; + if (col >= num_cols) { + col = getUnpermutedVarID(col); + row = getPermutedSampleID(row); + } + + if (col < num_cols_no_snp) { + return data[col * num_rows + row]; + } else { + return getSnp(row, col, col_permuted); + } + } + + void reserveMemory() override { + data.resize(num_cols * num_rows); + } + + void set(size_t col, size_t row, double value, bool& error) override { + data[col * num_rows + row] = value; + } + +private: + std::vector data; +}; + +} // namespace ranger + +#endif /* DATADOUBLE_H_ */ diff --git a/lib/ranger/DataFloat.cpp b/lib/ranger/DataFloat.cpp new file mode 100644 index 000000000..a8027abbb --- /dev/null +++ b/lib/ranger/DataFloat.cpp @@ -0,0 +1,36 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. 
+ #-------------------------------------------------------------------------------*/ + +// Ignore in coverage report (not used in R package) +// #nocov start +#include + +#include "DataFloat.h" + +namespace ranger { + +DataFloat::DataFloat(double* data_double, std::vector variable_names, size_t num_rows, size_t num_cols) { + this->variable_names = variable_names; + this->num_rows = num_rows; + this->num_cols = num_cols; + this->num_cols_no_snp = num_cols; + + reserveMemory(); + for (size_t i = 0; i < num_cols; ++i) { + for (size_t j = 0; j < num_rows; ++j) { + data[i * num_rows + j] = (float) data_double[i * num_rows + j]; + } + } +} + +// #nocov end + +}// namespace ranger diff --git a/lib/ranger/DataFloat.h b/lib/ranger/DataFloat.h new file mode 100644 index 000000000..a3c4da4f2 --- /dev/null +++ b/lib/ranger/DataFloat.h @@ -0,0 +1,65 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. 
+ #-------------------------------------------------------------------------------*/ + +// Ignore in coverage report (not used in R package) +// #nocov start +#ifndef DATAFLOAT_H_ +#define DATAFLOAT_H_ + +#include + +#include "globals.h" +#include "Data.h" + +namespace ranger { + +class DataFloat: public Data { +public: + DataFloat() = default; + DataFloat(double* data_double, std::vector variable_names, size_t num_rows, size_t num_cols); + + DataFloat(const DataFloat&) = delete; + DataFloat& operator=(const DataFloat&) = delete; + + virtual ~DataFloat() override = default; + + double get(size_t row, size_t col) const override { + // Use permuted data for corrected impurity importance + size_t col_permuted = col; + if (col >= num_cols) { + col = getUnpermutedVarID(col); + row = getPermutedSampleID(row); + } + + if (col < num_cols_no_snp) { + return data[col * num_rows + row]; + } else { + return getSnp(row, col, col_permuted); + } + } + + void reserveMemory() override { + data.resize(num_cols * num_rows); + } + + void set(size_t col, size_t row, double value, bool& error) override { + data[col * num_rows + row] = value; + } + +private: + std::vector data; +}; + +} // namespace ranger + +#endif /* DATAFLOAT_H_ */ +// #nocov end + diff --git a/lib/ranger/Forest.cpp b/lib/ranger/Forest.cpp new file mode 100644 index 000000000..1b109df64 --- /dev/null +++ b/lib/ranger/Forest.cpp @@ -0,0 +1,948 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. 
+ #-------------------------------------------------------------------------------*/ + +#include +#include +#include +#include +#include +#include +#ifndef OLD_WIN_R_BUILD +#include +#include +#endif + +#include "utility.h" +#include "Forest.h" +#include "DataChar.h" +#include "DataDouble.h" +#include "DataFloat.h" + +namespace ranger { + +Forest::Forest() : + verbose_out(0), num_trees(DEFAULT_NUM_TREE), mtry(0), min_node_size(0), num_variables(0), num_independent_variables( + 0), seed(0), dependent_varID(0), num_samples(0), prediction_mode(false), memory_mode(MEM_DOUBLE), sample_with_replacement( + true), memory_saving_splitting(false), splitrule(DEFAULT_SPLITRULE), predict_all(false), keep_inbag(false), sample_fraction( + { 1 }), holdout(false), prediction_type(DEFAULT_PREDICTIONTYPE), num_random_splits(DEFAULT_NUM_RANDOM_SPLITS), alpha( + DEFAULT_ALPHA), minprop(DEFAULT_MINPROP), num_threads(DEFAULT_NUM_THREADS), data { }, overall_prediction_error( + 0), importance_mode(DEFAULT_IMPORTANCE_MODE), progress(0) { +} + +// #nocov start +std::unique_ptr load_data_from_file(const std::string& data_path, const MemoryMode memory_mode, + std::ostream* verbose_out = nullptr) { + std::unique_ptr result { }; + switch (memory_mode) { + case MEM_DOUBLE: + result = std::make_unique(); + break; + case MEM_FLOAT: + result = std::make_unique(); + break; + case MEM_CHAR: + result = std::make_unique(); + break; + } + + if (verbose_out) + *verbose_out << "Loading input file: " << data_path << "." << std::endl; + bool found_rounding_error = result->loadFromFile(data_path); + if (found_rounding_error && verbose_out) { + *verbose_out << "Warning: Rounding or Integer overflow occurred. Use FLOAT or DOUBLE precision to avoid this." 
+ << std::endl; + } + return result; +} + +void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode, std::string input_file, uint mtry, + std::string output_prefix, uint num_trees, std::ostream* verbose_out, uint seed, uint num_threads, + std::string load_forest_filename, ImportanceMode importance_mode, uint min_node_size, + std::string split_select_weights_file, const std::vector& always_split_variable_names, + std::string status_variable_name, bool sample_with_replacement, + const std::vector& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule, + std::string case_weights_file, bool predict_all, double sample_fraction, double alpha, double minprop, bool holdout, + PredictionType prediction_type, uint num_random_splits) { + + this->verbose_out = verbose_out; + + // Set prediction mode + bool prediction_mode = false; + if (!load_forest_filename.empty()) { + prediction_mode = true; + } + + // Sample fraction to vector + std::vector sample_fraction_vector = { sample_fraction }; + + // Call other init function + init(dependent_variable_name, memory_mode, load_data_from_file(input_file, memory_mode, verbose_out), mtry, + output_prefix, num_trees, seed, num_threads, importance_mode, min_node_size, status_variable_name, + prediction_mode, sample_with_replacement, unordered_variable_names, memory_saving_splitting, splitrule, + predict_all, sample_fraction_vector, alpha, minprop, holdout, prediction_type, num_random_splits, false); + + if (prediction_mode) { + loadFromFile(load_forest_filename); + } + // Set variables to be always considered for splitting + if (!always_split_variable_names.empty()) { + setAlwaysSplitVariables(always_split_variable_names); + } + + // TODO: Read 2d weights for tree-wise split select weights + // Load split select weights from file + if (!split_select_weights_file.empty()) { + std::vector> split_select_weights; + split_select_weights.resize(1); + loadDoubleVectorFromFile(split_select_weights[0], 
split_select_weights_file); + if (split_select_weights[0].size() != num_variables - 1) { + throw std::runtime_error("Number of split select weights is not equal to number of independent variables."); + } + setSplitWeightVector(split_select_weights); + } + + // Load case weights from file + if (!case_weights_file.empty()) { + loadDoubleVectorFromFile(case_weights, case_weights_file); + if (case_weights.size() != num_samples) { + throw std::runtime_error("Number of case weights is not equal to number of samples."); + } + } + + // Sample from non-zero weights in holdout mode + if (holdout && !case_weights.empty()) { + size_t nonzero_weights = 0; + for (auto& weight : case_weights) { + if (weight > 0) { + ++nonzero_weights; + } + } + this->sample_fraction[0] = this->sample_fraction[0] * ((double) nonzero_weights / (double) num_samples); + } + + // Check if all catvars are coded in integers starting at 1 + if (!unordered_variable_names.empty()) { + std::string error_message = checkUnorderedVariables(*data, unordered_variable_names); + if (!error_message.empty()) { + throw std::runtime_error(error_message); + } + } +} +// #nocov end + +void Forest::initR(std::string dependent_variable_name, std::unique_ptr input_data, uint mtry, uint num_trees, + std::ostream* verbose_out, uint seed, uint num_threads, ImportanceMode importance_mode, uint min_node_size, + std::vector>& split_select_weights, const std::vector& always_split_variable_names, + std::string status_variable_name, bool prediction_mode, bool sample_with_replacement, + const std::vector& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule, + std::vector& case_weights, bool predict_all, bool keep_inbag, std::vector& sample_fraction, + double alpha, double minprop, bool holdout, PredictionType prediction_type, uint num_random_splits, + bool order_snps) { + + this->verbose_out = verbose_out; + + // Call other init function + init(dependent_variable_name, MEM_DOUBLE, std::move(input_data), mtry, 
"", num_trees, seed, num_threads, + importance_mode, min_node_size, status_variable_name, prediction_mode, sample_with_replacement, + unordered_variable_names, memory_saving_splitting, splitrule, predict_all, sample_fraction, alpha, minprop, + holdout, prediction_type, num_random_splits, order_snps); + + // Set variables to be always considered for splitting + if (!always_split_variable_names.empty()) { + setAlwaysSplitVariables(always_split_variable_names); + } + + // Set split select weights + if (!split_select_weights.empty()) { + setSplitWeightVector(split_select_weights); + } + + // Set case weights + if (!case_weights.empty()) { + if (case_weights.size() != num_samples) { + throw std::runtime_error("Number of case weights not equal to number of samples."); + } + this->case_weights = case_weights; + } + + // Keep inbag counts + this->keep_inbag = keep_inbag; +} + +void Forest::init(std::string dependent_variable_name, MemoryMode memory_mode, std::unique_ptr input_data, + uint mtry, std::string output_prefix, uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, + uint min_node_size, std::string status_variable_name, bool prediction_mode, bool sample_with_replacement, + const std::vector& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule, + bool predict_all, std::vector& sample_fraction, double alpha, double minprop, bool holdout, + PredictionType prediction_type, uint num_random_splits, bool order_snps) { + + // Initialize data with memmode + this->data = std::move(input_data); + + // Initialize random number generator and set seed + if (seed == 0) { + std::random_device random_device; + random_number_generator.seed(random_device()); + } else { + random_number_generator.seed(seed); + } + + // Set number of threads + if (num_threads == DEFAULT_NUM_THREADS) { +#ifdef OLD_WIN_R_BUILD + this->num_threads = 1; +#else + this->num_threads = std::thread::hardware_concurrency(); +#endif + } else { + this->num_threads = 
num_threads; + } + + // Set member variables + this->num_trees = num_trees; + this->mtry = mtry; + this->seed = seed; + this->output_prefix = output_prefix; + this->importance_mode = importance_mode; + this->min_node_size = min_node_size; + this->memory_mode = memory_mode; + this->prediction_mode = prediction_mode; + this->sample_with_replacement = sample_with_replacement; + this->memory_saving_splitting = memory_saving_splitting; + this->splitrule = splitrule; + this->predict_all = predict_all; + this->sample_fraction = sample_fraction; + this->holdout = holdout; + this->alpha = alpha; + this->minprop = minprop; + this->prediction_type = prediction_type; + this->num_random_splits = num_random_splits; + + // Set number of samples and variables + num_samples = data->getNumRows(); + num_variables = data->getNumCols(); + + // Convert dependent variable name to ID + if (!prediction_mode && !dependent_variable_name.empty()) { + dependent_varID = data->getVariableID(dependent_variable_name); + } + + // Set unordered factor variables + if (!prediction_mode) { + data->setIsOrderedVariable(unordered_variable_names); + } + + data->addNoSplitVariable(dependent_varID); + + initInternal(status_variable_name); + + num_independent_variables = num_variables - data->getNoSplitVariables().size(); + + // Init split select weights + split_select_weights.push_back(std::vector()); + + // Check if mtry is in valid range + if (this->mtry > num_variables - 1) { + throw std::runtime_error("mtry can not be larger than number of variables in data."); + } + + // Check if any observations samples + if ((size_t) num_samples * sample_fraction[0] < 1) { + throw std::runtime_error("sample_fraction too small, no observations sampled."); + } + + // Permute samples for corrected Gini importance + if (importance_mode == IMP_GINI_CORRECTED) { + data->permuteSampleIDs(random_number_generator); + } + + // Order SNP levels if in "order" splitting + if (!prediction_mode && order_snps) { + 
data->orderSnpLevels(dependent_variable_name, (importance_mode == IMP_GINI_CORRECTED)); + } +} + +void Forest::run(bool verbose) { + + if (prediction_mode) { + if (verbose && verbose_out) { + *verbose_out << "Predicting .." << std::endl; + } + predict(); + } else { + if (verbose && verbose_out) { + *verbose_out << "Growing trees .." << std::endl; + } + + grow(); + + if (verbose && verbose_out) { + *verbose_out << "Computing prediction error .." << std::endl; + } + computePredictionError(); + + if (importance_mode == IMP_PERM_BREIMAN || importance_mode == IMP_PERM_LIAW || importance_mode == IMP_PERM_RAW) { + if (verbose && verbose_out) { + *verbose_out << "Computing permutation variable importance .." << std::endl; + } + computePermutationImportance(); + } + } +} + +// #nocov start +void Forest::writeOutput() { + + if (verbose_out) + *verbose_out << std::endl; + writeOutputInternal(); + if (verbose_out) { + *verbose_out << "Dependent variable name: " << data->getVariableNames()[dependent_varID] << std::endl; + *verbose_out << "Dependent variable ID: " << dependent_varID << std::endl; + *verbose_out << "Number of trees: " << num_trees << std::endl; + *verbose_out << "Sample size: " << num_samples << std::endl; + *verbose_out << "Number of independent variables: " << num_independent_variables << std::endl; + *verbose_out << "Mtry: " << mtry << std::endl; + *verbose_out << "Target node size: " << min_node_size << std::endl; + *verbose_out << "Variable importance mode: " << importance_mode << std::endl; + *verbose_out << "Memory mode: " << memory_mode << std::endl; + *verbose_out << "Seed: " << seed << std::endl; + *verbose_out << "Number of threads: " << num_threads << std::endl; + *verbose_out << std::endl; + } + + if (prediction_mode) { + writePredictionFile(); + } else { + if (verbose_out) { + *verbose_out << "Overall OOB prediction error: " << overall_prediction_error << std::endl; + *verbose_out << std::endl; + } + + if (!split_select_weights.empty() & 
!split_select_weights[0].empty()) { + if (verbose_out) { + *verbose_out + << "Warning: Split select weights used. Variable importance measures are only comparable for variables with equal weights." + << std::endl; + } + } + + if (importance_mode != IMP_NONE) { + writeImportanceFile(); + } + + writeConfusionFile(); + } +} + +void Forest::writeImportanceFile() { + + // Open importance file for writing + std::string filename = output_prefix + ".importance"; + std::ofstream importance_file; + importance_file.open(filename, std::ios::out); + if (!importance_file.good()) { + throw std::runtime_error("Could not write to importance file: " + filename + "."); + } + + // Write importance to file + for (size_t i = 0; i < variable_importance.size(); ++i) { + size_t varID = i; + for (auto& skip : data->getNoSplitVariables()) { + if (varID >= skip) { + ++varID; + } + } + std::string variable_name = data->getVariableNames()[varID]; + importance_file << variable_name << ": " << variable_importance[i] << std::endl; + } + + importance_file.close(); + if (verbose_out) + *verbose_out << "Saved variable importance to file " << filename << "." << std::endl; +} + +void Forest::saveToFile() { + + // Open file for writing + std::string filename = output_prefix + ".forest"; + std::ofstream outfile; + outfile.open(filename, std::ios::binary); + if (!outfile.good()) { + throw std::runtime_error("Could not write to output file: " + filename + "."); + } + + // Write dependent_varID + outfile.write((char*) &dependent_varID, sizeof(dependent_varID)); + + // Write num_trees + outfile.write((char*) &num_trees, sizeof(num_trees)); + + // Write is_ordered_variable + saveVector1D(data->getIsOrderedVariable(), outfile); + + saveToFileInternal(outfile); + + // Write tree data for each tree + for (auto& tree : trees) { + tree->appendToFile(outfile); + } + + // Close file + outfile.close(); + if (verbose_out) + *verbose_out << "Saved forest to file " << filename << "." 
<< std::endl; +} +// #nocov end + +void Forest::grow() { + + // Create thread ranges + equalSplit(thread_ranges, 0, num_trees - 1, num_threads); + + // Call special grow functions of subclasses. There trees must be created. + growInternal(); + + // Init trees, create a seed for each tree, based on main seed + std::uniform_int_distribution udist; + for (size_t i = 0; i < num_trees; ++i) { + uint tree_seed; + if (seed == 0) { + tree_seed = udist(random_number_generator); + } else { + tree_seed = (i + 1) * seed; + } + + // Get split select weights for tree + std::vector* tree_split_select_weights; + if (split_select_weights.size() > 1) { + tree_split_select_weights = &split_select_weights[i]; + } else { + tree_split_select_weights = &split_select_weights[0]; + } + + trees[i]->init(data.get(), mtry, dependent_varID, num_samples, tree_seed, &deterministic_varIDs, + &split_select_varIDs, tree_split_select_weights, importance_mode, min_node_size, sample_with_replacement, + memory_saving_splitting, splitrule, &case_weights, keep_inbag, &sample_fraction, alpha, minprop, holdout, + num_random_splits); + } + +// Init variable importance + variable_importance.resize(num_independent_variables, 0); + +// Grow trees in multiple threads +#ifdef OLD_WIN_R_BUILD + progress = 0; + clock_t start_time = clock(); + clock_t lap_time = clock(); + for (size_t i = 0; i < num_trees; ++i) { + trees[i]->grow(&variable_importance); + progress++; + showProgress("Growing trees..", start_time, lap_time); + } +#else + progress = 0; +#ifdef R_BUILD + aborted = false; + aborted_threads = 0; +#endif + + std::vector threads; + threads.reserve(num_threads); + +// Initailize importance per thread + std::vector> variable_importance_threads(num_threads); + + for (uint i = 0; i < num_threads; ++i) { + if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { + variable_importance_threads[i].resize(num_independent_variables, 0); + } + threads.emplace_back(&Forest::growTreesInThread, this, 
i, &(variable_importance_threads[i])); + } + showProgress("Growing trees..", num_trees); + for (auto &thread : threads) { + thread.join(); + } + +#ifdef R_BUILD + if (aborted_threads > 0) { + throw std::runtime_error("User interrupt."); + } +#endif + + // Sum thread importances + if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { + variable_importance.resize(num_independent_variables, 0); + for (size_t i = 0; i < num_independent_variables; ++i) { + for (uint j = 0; j < num_threads; ++j) { + variable_importance[i] += variable_importance_threads[j][i]; + } + } + variable_importance_threads.clear(); + } + +#endif + +// Divide importance by number of trees + if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { + for (auto& v : variable_importance) { + v /= num_trees; + } + } +} + +void Forest::predict() { + +// Predict trees in multiple threads and join the threads with the main thread +#ifdef OLD_WIN_R_BUILD + progress = 0; + clock_t start_time = clock(); + clock_t lap_time = clock(); + for (size_t i = 0; i < num_trees; ++i) { + trees[i]->predict(data.get(), false); + progress++; + showProgress("Predicting..", start_time, lap_time); + } + + // For all samples get tree predictions + allocatePredictMemory(); + for (size_t sample_idx = 0; sample_idx < data->getNumRows(); ++sample_idx) { + predictInternal(sample_idx); + } +#else + progress = 0; +#ifdef R_BUILD + aborted = false; + aborted_threads = 0; +#endif + + // Predict + std::vector threads; + threads.reserve(num_threads); + for (uint i = 0; i < num_threads; ++i) { + threads.emplace_back(&Forest::predictTreesInThread, this, i, data.get(), false); + } + showProgress("Predicting..", num_trees); + for (auto &thread : threads) { + thread.join(); + } + + // Aggregate predictions + allocatePredictMemory(); + threads.clear(); + threads.reserve(num_threads); + progress = 0; + for (uint i = 0; i < num_threads; ++i) { + threads.emplace_back(&Forest::predictInternalInThread, this, 
i); + } + showProgress("Aggregating predictions..", num_samples); + for (auto &thread : threads) { + thread.join(); + } + +#ifdef R_BUILD + if (aborted_threads > 0) { + throw std::runtime_error("User interrupt."); + } +#endif +#endif +} + +void Forest::computePredictionError() { + +// Predict trees in multiple threads +#ifdef OLD_WIN_R_BUILD + progress = 0; + clock_t start_time = clock(); + clock_t lap_time = clock(); + for (size_t i = 0; i < num_trees; ++i) { + trees[i]->predict(data.get(), true); + progress++; + showProgress("Predicting..", start_time, lap_time); + } +#else + std::vector threads; + threads.reserve(num_threads); + progress = 0; + for (uint i = 0; i < num_threads; ++i) { + threads.emplace_back(&Forest::predictTreesInThread, this, i, data.get(), true); + } + showProgress("Computing prediction error..", num_trees); + for (auto &thread : threads) { + thread.join(); + } + +#ifdef R_BUILD + if (aborted_threads > 0) { + throw std::runtime_error("User interrupt."); + } +#endif +#endif + + // Call special function for subclasses + computePredictionErrorInternal(); +} + +void Forest::computePermutationImportance() { + +// Compute tree permutation importance in multiple threads +#ifdef OLD_WIN_R_BUILD + progress = 0; + clock_t start_time = clock(); + clock_t lap_time = clock(); + +// Initailize importance and variance + variable_importance.resize(num_independent_variables, 0); + std::vector variance; + if (importance_mode == IMP_PERM_BREIMAN || importance_mode == IMP_PERM_LIAW) { + variance.resize(num_independent_variables, 0); + } + +// Compute importance + for (size_t i = 0; i < num_trees; ++i) { + trees[i]->computePermutationImportance(variable_importance, variance); + progress++; + showProgress("Computing permutation importance..", start_time, lap_time); + } +#else + progress = 0; +#ifdef R_BUILD + aborted = false; + aborted_threads = 0; +#endif + + std::vector threads; + threads.reserve(num_threads); + +// Initailize importance and variance + 
std::vector> variable_importance_threads(num_threads); + std::vector> variance_threads(num_threads); + +// Compute importance + for (uint i = 0; i < num_threads; ++i) { + variable_importance_threads[i].resize(num_independent_variables, 0); + if (importance_mode == IMP_PERM_BREIMAN || importance_mode == IMP_PERM_LIAW) { + variance_threads[i].resize(num_independent_variables, 0); + } + threads.emplace_back(&Forest::computeTreePermutationImportanceInThread, this, i, + std::ref(variable_importance_threads[i]), std::ref(variance_threads[i])); + } + showProgress("Computing permutation importance..", num_trees); + for (auto &thread : threads) { + thread.join(); + } + +#ifdef R_BUILD + if (aborted_threads > 0) { + throw std::runtime_error("User interrupt."); + } +#endif + +// Sum thread importances + variable_importance.resize(num_independent_variables, 0); + for (size_t i = 0; i < num_independent_variables; ++i) { + for (uint j = 0; j < num_threads; ++j) { + variable_importance[i] += variable_importance_threads[j][i]; + } + } + variable_importance_threads.clear(); + +// Sum thread variances + std::vector variance(num_independent_variables, 0); + if (importance_mode == IMP_PERM_BREIMAN || importance_mode == IMP_PERM_LIAW) { + for (size_t i = 0; i < num_independent_variables; ++i) { + for (uint j = 0; j < num_threads; ++j) { + variance[i] += variance_threads[j][i]; + } + } + variance_threads.clear(); + } +#endif + + for (size_t i = 0; i < variable_importance.size(); ++i) { + variable_importance[i] /= num_trees; + + // Normalize by variance for scaled permutation importance + if (importance_mode == IMP_PERM_BREIMAN || importance_mode == IMP_PERM_LIAW) { + if (variance[i] != 0) { + variance[i] = variance[i] / num_trees - variable_importance[i] * variable_importance[i]; + variable_importance[i] /= sqrt(variance[i] / num_trees); + } + } + } +} + +#ifndef OLD_WIN_R_BUILD +void Forest::growTreesInThread(uint thread_idx, std::vector* variable_importance) { + if 
(thread_ranges.size() > thread_idx + 1) { + for (size_t i = thread_ranges[thread_idx]; i < thread_ranges[thread_idx + 1]; ++i) { + trees[i]->grow(variable_importance); + + // Check for user interrupt +#ifdef R_BUILD + if (aborted) { + std::unique_lock lock(mutex); + ++aborted_threads; + condition_variable.notify_one(); + return; + } +#endif + + // Increase progress by 1 tree + std::unique_lock lock(mutex); + ++progress; + condition_variable.notify_one(); + } + } +} + +void Forest::predictTreesInThread(uint thread_idx, const Data* prediction_data, bool oob_prediction) { + if (thread_ranges.size() > thread_idx + 1) { + for (size_t i = thread_ranges[thread_idx]; i < thread_ranges[thread_idx + 1]; ++i) { + trees[i]->predict(prediction_data, oob_prediction); + + // Check for user interrupt +#ifdef R_BUILD + if (aborted) { + std::unique_lock lock(mutex); + ++aborted_threads; + condition_variable.notify_one(); + return; + } +#endif + + // Increase progress by 1 tree + std::unique_lock lock(mutex); + ++progress; + condition_variable.notify_one(); + } + } +} + +void Forest::predictInternalInThread(uint thread_idx) { + // Create thread ranges + std::vector predict_ranges; + equalSplit(predict_ranges, 0, num_samples - 1, num_threads); + + if (predict_ranges.size() > thread_idx + 1) { + for (size_t i = predict_ranges[thread_idx]; i < predict_ranges[thread_idx + 1]; ++i) { + predictInternal(i); + + // Check for user interrupt +#ifdef R_BUILD + if (aborted) { + std::unique_lock lock(mutex); + ++aborted_threads; + condition_variable.notify_one(); + return; + } +#endif + + // Increase progress by 1 tree + std::unique_lock lock(mutex); + ++progress; + condition_variable.notify_one(); + } + } +} + +void Forest::computeTreePermutationImportanceInThread(uint thread_idx, std::vector& importance, + std::vector& variance) { + if (thread_ranges.size() > thread_idx + 1) { + for (size_t i = thread_ranges[thread_idx]; i < thread_ranges[thread_idx + 1]; ++i) { + 
trees[i]->computePermutationImportance(importance, variance); + + // Check for user interrupt +#ifdef R_BUILD + if (aborted) { + std::unique_lock lock(mutex); + ++aborted_threads; + condition_variable.notify_one(); + return; + } +#endif + + // Increase progress by 1 tree + std::unique_lock lock(mutex); + ++progress; + condition_variable.notify_one(); + } + } +} +#endif + +// #nocov start +void Forest::loadFromFile(std::string filename) { + if (verbose_out) + *verbose_out << "Loading forest from file " << filename << "." << std::endl; + +// Open file for reading + std::ifstream infile; + infile.open(filename, std::ios::binary); + if (!infile.good()) { + throw std::runtime_error("Could not read from input file: " + filename + "."); + } + +// Read dependent_varID and num_trees + infile.read((char*) &dependent_varID, sizeof(dependent_varID)); + infile.read((char*) &num_trees, sizeof(num_trees)); + +// Read is_ordered_variable + readVector1D(data->getIsOrderedVariable(), infile); + +// Read tree data. 
This is different for tree types -> virtual function + loadFromFileInternal(infile); + + infile.close(); + +// Create thread ranges + equalSplit(thread_ranges, 0, num_trees - 1, num_threads); +} +// #nocov end + +void Forest::setSplitWeightVector(std::vector>& split_select_weights) { + +// Size should be 1 x num_independent_variables or num_trees x num_independent_variables + if (split_select_weights.size() != 1 && split_select_weights.size() != num_trees) { + throw std::runtime_error("Size of split select weights not equal to 1 or number of trees."); + } + +// Reserve space + if (split_select_weights.size() == 1) { + this->split_select_weights[0].resize(num_independent_variables); + } else { + this->split_select_weights.clear(); + this->split_select_weights.resize(num_trees, std::vector(num_independent_variables)); + } + this->split_select_varIDs.resize(num_independent_variables); + deterministic_varIDs.reserve(num_independent_variables); + +// Split up in deterministic and weighted variables, ignore zero weights + for (size_t i = 0; i < split_select_weights.size(); ++i) { + + // Size should be 1 x num_independent_variables or num_trees x num_independent_variables + if (split_select_weights[i].size() != num_independent_variables) { + throw std::runtime_error("Number of split select weights not equal to number of independent variables."); + } + + for (size_t j = 0; j < split_select_weights[i].size(); ++j) { + double weight = split_select_weights[i][j]; + + if (i == 0) { + size_t varID = j; + for (auto& skip : data->getNoSplitVariables()) { + if (varID >= skip) { + ++varID; + } + } + + if (weight == 1) { + deterministic_varIDs.push_back(varID); + } else if (weight < 1 && weight > 0) { + this->split_select_varIDs[j] = varID; + this->split_select_weights[i][j] = weight; + } else if (weight < 0 || weight > 1) { + throw std::runtime_error("One or more split select weights not in range [0,1]."); + } + + } else { + if (weight < 1 && weight > 0) { + 
this->split_select_weights[i][j] = weight; + } else if (weight < 0 || weight > 1) { + throw std::runtime_error("One or more split select weights not in range [0,1]."); + } + } + } + } + + if (deterministic_varIDs.size() > this->mtry) { + throw std::runtime_error("Number of ones in split select weights cannot be larger than mtry."); + } + if (deterministic_varIDs.size() + split_select_varIDs.size() < mtry) { + throw std::runtime_error("Too many zeros in split select weights. Need at least mtry variables to split at."); + } +} + +void Forest::setAlwaysSplitVariables(const std::vector& always_split_variable_names) { + + deterministic_varIDs.reserve(num_independent_variables); + + for (auto& variable_name : always_split_variable_names) { + size_t varID = data->getVariableID(variable_name); + deterministic_varIDs.push_back(varID); + } + + if (deterministic_varIDs.size() + this->mtry > num_independent_variables) { + throw std::runtime_error( + "Number of variables to be always considered for splitting plus mtry cannot be larger than number of independent variables."); + } +} + +#ifdef OLD_WIN_R_BUILD +void Forest::showProgress(std::string operation, clock_t start_time, clock_t& lap_time) { + +// Check for user interrupt + if (checkInterrupt()) { + throw std::runtime_error("User interrupt."); + } + + double elapsed_time = (clock() - lap_time) / CLOCKS_PER_SEC; + if (elapsed_time > STATUS_INTERVAL) { + double relative_progress = (double) progress / (double) num_trees; + double time_from_start = (clock() - start_time) / CLOCKS_PER_SEC; + uint remaining_time = (1 / relative_progress - 1) * time_from_start; + if (verbose_out) { + *verbose_out << operation << " Progress: " << round(100 * relative_progress) + << "%. Estimated remaining time: " << beautifyTime(remaining_time) << "." 
<< std::endl; + } + lap_time = clock(); + } +} +#else +void Forest::showProgress(std::string operation, size_t max_progress) { + using std::chrono::steady_clock; + using std::chrono::duration_cast; + using std::chrono::seconds; + + steady_clock::time_point start_time = steady_clock::now(); + steady_clock::time_point last_time = steady_clock::now(); + std::unique_lock lock(mutex); + +// Wait for message from threads and show output if enough time elapsed + while (progress < max_progress) { + condition_variable.wait(lock); + seconds elapsed_time = duration_cast(steady_clock::now() - last_time); + + // Check for user interrupt +#ifdef R_BUILD + if (!aborted && checkInterrupt()) { + aborted = true; + } + if (aborted && aborted_threads >= num_threads) { + return; + } +#endif + + if (progress > 0 && elapsed_time.count() > STATUS_INTERVAL) { + double relative_progress = (double) progress / (double) max_progress; + seconds time_from_start = duration_cast(steady_clock::now() - start_time); + uint remaining_time = (1 / relative_progress - 1) * time_from_start.count(); + if (verbose_out) { + *verbose_out << operation << " Progress: " << round(100 * relative_progress) << "%. Estimated remaining time: " + << beautifyTime(remaining_time) << "." << std::endl; + } + last_time = steady_clock::now(); + } + } +} +#endif + +} // namespace ranger diff --git a/lib/ranger/Forest.h b/lib/ranger/Forest.h new file mode 100644 index 000000000..b06f418ed --- /dev/null +++ b/lib/ranger/Forest.h @@ -0,0 +1,248 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. 
+ #-------------------------------------------------------------------------------*/ + +#ifndef FOREST_H_ +#define FOREST_H_ + +#include +#include +#include +#include +#include +#ifndef OLD_WIN_R_BUILD +#include +#include +#include +#include +#endif + +#include "globals.h" +#include "Tree.h" +#include "Data.h" + +namespace ranger { + +class Forest { +public: + Forest(); + + Forest(const Forest&) = delete; + Forest& operator=(const Forest&) = delete; + + virtual ~Forest() = default; + + // Init from c++ main or Rcpp from R + void initCpp(std::string dependent_variable_name, MemoryMode memory_mode, std::string input_file, uint mtry, + std::string output_prefix, uint num_trees, std::ostream* verbose_out, uint seed, uint num_threads, + std::string load_forest_filename, ImportanceMode importance_mode, uint min_node_size, + std::string split_select_weights_file, const std::vector& always_split_variable_names, + std::string status_variable_name, bool sample_with_replacement, + const std::vector& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule, + std::string case_weights_file, bool predict_all, double sample_fraction, double alpha, double minprop, + bool holdout, PredictionType prediction_type, uint num_random_splits); + void initR(std::string dependent_variable_name, std::unique_ptr input_data, uint mtry, uint num_trees, + std::ostream* verbose_out, uint seed, uint num_threads, ImportanceMode importance_mode, uint min_node_size, + std::vector>& split_select_weights, + const std::vector& always_split_variable_names, std::string status_variable_name, + bool prediction_mode, bool sample_with_replacement, const std::vector& unordered_variable_names, + bool memory_saving_splitting, SplitRule splitrule, std::vector& case_weights, bool predict_all, + bool keep_inbag, std::vector& sample_fraction, double alpha, double minprop, bool holdout, + PredictionType prediction_type, uint num_random_splits, bool order_snps); + void init(std::string 
dependent_variable_name, MemoryMode memory_mode, std::unique_ptr input_data, uint mtry, + std::string output_prefix, uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, + uint min_node_size, std::string status_variable_name, bool prediction_mode, bool sample_with_replacement, + const std::vector& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule, + bool predict_all, std::vector& sample_fraction, double alpha, double minprop, bool holdout, + PredictionType prediction_type, uint num_random_splits, bool order_snps); + virtual void initInternal(std::string status_variable_name) = 0; + + // Grow or predict + void run(bool verbose); + + // Write results to output files + void writeOutput(); + virtual void writeOutputInternal() = 0; + virtual void writeConfusionFile() = 0; + virtual void writePredictionFile() = 0; + void writeImportanceFile(); + + // Save forest to file + void saveToFile(); + virtual void saveToFileInternal(std::ofstream& outfile) = 0; + + std::vector>> getChildNodeIDs() { + std::vector>> result; + for (auto& tree : trees) { + result.push_back(tree->getChildNodeIDs()); + } + return result; + } + std::vector> getSplitVarIDs() { + std::vector> result; + for (auto& tree : trees) { + result.push_back(tree->getSplitVarIDs()); + } + return result; + } + std::vector> getSplitValues() { + std::vector> result; + for (auto& tree : trees) { + result.push_back(tree->getSplitValues()); + } + return result; + } + const std::vector& getVariableImportance() const { + return variable_importance; + } + double getOverallPredictionError() const { + return overall_prediction_error; + } + const std::vector>>& getPredictions() const { + return predictions; + } + size_t getDependentVarId() const { + return dependent_varID; + } + size_t getNumTrees() const { + return num_trees; + } + uint getMtry() const { + return mtry; + } + uint getMinNodeSize() const { + return min_node_size; + } + size_t getNumIndependentVariables() const { 
+ return num_independent_variables; + } + + const std::vector& getIsOrderedVariable() const { + return data->getIsOrderedVariable(); + } + + std::vector> getInbagCounts() const { + std::vector> result; + for (auto& tree : trees) { + result.push_back(tree->getInbagCounts()); + } + return result; + } + + const std::vector>& getSnpOrder() const { + return data->getSnpOrder(); + } + +protected: + void grow(); + virtual void growInternal() = 0; + + // Predict using existing tree from file and data as prediction data + void predict(); + virtual void allocatePredictMemory() = 0; + virtual void predictInternal(size_t sample_idx) = 0; + + void computePredictionError(); + virtual void computePredictionErrorInternal() = 0; + + void computePermutationImportance(); + + // Multithreading methods for growing/prediction/importance, called by each thread + void growTreesInThread(uint thread_idx, std::vector* variable_importance); + void predictTreesInThread(uint thread_idx, const Data* prediction_data, bool oob_prediction); + void predictInternalInThread(uint thread_idx); + void computeTreePermutationImportanceInThread(uint thread_idx, std::vector& importance, + std::vector& variance); + + // Load forest from file + void loadFromFile(std::string filename); + virtual void loadFromFileInternal(std::ifstream& infile) = 0; + + // Set split select weights and variables to be always considered for splitting + void setSplitWeightVector(std::vector>& split_select_weights); + void setAlwaysSplitVariables(const std::vector& always_split_variable_names); + + // Show progress every few seconds +#ifdef OLD_WIN_R_BUILD + void showProgress(std::string operation, clock_t start_time, clock_t& lap_time); +#else + void showProgress(std::string operation, size_t max_progress); +#endif + + // Verbose output stream, cout if verbose==true, logfile if not + std::ostream* verbose_out; + + size_t num_trees; + uint mtry; + uint min_node_size; + size_t num_variables; + size_t num_independent_variables; + uint 
seed; + size_t dependent_varID; + size_t num_samples; + bool prediction_mode; + MemoryMode memory_mode; + bool sample_with_replacement; + bool memory_saving_splitting; + SplitRule splitrule; + bool predict_all; + bool keep_inbag; + std::vector sample_fraction; + bool holdout; + PredictionType prediction_type; + uint num_random_splits; + + // MAXSTAT splitrule + double alpha; + double minprop; + + // Multithreading + uint num_threads; + std::vector thread_ranges; +#ifndef OLD_WIN_R_BUILD + std::mutex mutex; + std::condition_variable condition_variable; +#endif + + std::vector> trees; + std::unique_ptr data; + + std::vector>> predictions; + double overall_prediction_error; + + // Weight vector for selecting possible split variables, one weight between 0 (never select) and 1 (always select) for each variable + // Deterministic variables are always selected + std::vector deterministic_varIDs; + std::vector split_select_varIDs; + std::vector> split_select_weights; + + // Bootstrap weights + std::vector case_weights; + + // Random number generator + std::mt19937_64 random_number_generator; + + std::string output_prefix; + ImportanceMode importance_mode; + + // Variable importance for all variables in forest + std::vector variable_importance; + + // Computation progress (finished trees) + size_t progress; +#ifdef R_BUILD + size_t aborted_threads; + bool aborted; +#endif +}; + +} // namespace ranger + +#endif /* FOREST_H_ */ diff --git a/lib/ranger/ForestClassification.cpp b/lib/ranger/ForestClassification.cpp new file mode 100644 index 000000000..d69fb94ed --- /dev/null +++ b/lib/ranger/ForestClassification.cpp @@ -0,0 +1,334 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. 
+ + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. + #-------------------------------------------------------------------------------*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "utility.h" +#include "ForestClassification.h" +#include "TreeClassification.h" +#include "Data.h" + +namespace ranger { + +void ForestClassification::loadForest(size_t dependent_varID, size_t num_trees, + std::vector> >& forest_child_nodeIDs, + std::vector>& forest_split_varIDs, std::vector>& forest_split_values, + std::vector& class_values, std::vector& is_ordered_variable) { + + this->dependent_varID = dependent_varID; + this->num_trees = num_trees; + this->class_values = class_values; + data->setIsOrderedVariable(is_ordered_variable); + + // Create trees + trees.reserve(num_trees); + for (size_t i = 0; i < num_trees; ++i) { + trees.push_back( + std::make_unique(forest_child_nodeIDs[i], forest_split_varIDs[i], forest_split_values[i], + &this->class_values, &response_classIDs)); + } + + // Create thread ranges + equalSplit(thread_ranges, 0, num_trees - 1, num_threads); +} + +void ForestClassification::initInternal(std::string status_variable_name) { + + // If mtry not set, use floored square root of number of independent variables. + if (mtry == 0) { + unsigned long temp = sqrt((double) (num_variables - 1)); + mtry = std::max((unsigned long) 1, temp); + } + + // Set minimal node size + if (min_node_size == 0) { + min_node_size = DEFAULT_MIN_NODE_SIZE_CLASSIFICATION; + } + + // Create class_values and response_classIDs + if (!prediction_mode) { + for (size_t i = 0; i < num_samples; ++i) { + double value = data->get(i, dependent_varID); + + // If classID is already in class_values, use ID. Else create a new one. 
+ uint classID = find(class_values.begin(), class_values.end(), value) - class_values.begin(); + if (classID == class_values.size()) { + class_values.push_back(value); + } + response_classIDs.push_back(classID); + } + } + + // Create sampleIDs_per_class if required + if (sample_fraction.size() > 1) { + sampleIDs_per_class.resize(sample_fraction.size()); + for (auto& v : sampleIDs_per_class) { + v.reserve(num_samples); + } + for (size_t i = 0; i < num_samples; ++i) { + size_t classID = response_classIDs[i]; + sampleIDs_per_class[classID].push_back(i); + } + } + + // Set class weights all to 1 + class_weights = std::vector(class_values.size(), 1.0); + + // Sort data if memory saving mode + if (!memory_saving_splitting) { + data->sort(); + } +} + +void ForestClassification::growInternal() { + trees.reserve(num_trees); + for (size_t i = 0; i < num_trees; ++i) { + trees.push_back( + std::make_unique(&class_values, &response_classIDs, &sampleIDs_per_class, &class_weights)); + } +} + +void ForestClassification::allocatePredictMemory() { + size_t num_prediction_samples = data->getNumRows(); + if (predict_all || prediction_type == TERMINALNODES) { + predictions = std::vector>>(1, + std::vector>(num_prediction_samples, std::vector(num_trees))); + } else { + predictions = std::vector>>(1, + std::vector>(1, std::vector(num_prediction_samples))); + } +} + +void ForestClassification::predictInternal(size_t sample_idx) { + if (predict_all || prediction_type == TERMINALNODES) { + // Get all tree predictions + for (size_t tree_idx = 0; tree_idx < num_trees; ++tree_idx) { + if (prediction_type == TERMINALNODES) { + predictions[0][sample_idx][tree_idx] = getTreePredictionTerminalNodeID(tree_idx, sample_idx); + } else { + predictions[0][sample_idx][tree_idx] = getTreePrediction(tree_idx, sample_idx); + } + } + } else { + // Count classes over trees and save class with maximum count + std::unordered_map class_count; + for (size_t tree_idx = 0; tree_idx < num_trees; ++tree_idx) { + 
++class_count[getTreePrediction(tree_idx, sample_idx)]; + } + predictions[0][0][sample_idx] = mostFrequentValue(class_count, random_number_generator); + } +} + +void ForestClassification::computePredictionErrorInternal() { + + // Class counts for samples + std::vector> class_counts; + class_counts.reserve(num_samples); + for (size_t i = 0; i < num_samples; ++i) { + class_counts.push_back(std::unordered_map()); + } + + // For each tree loop over OOB samples and count classes + for (size_t tree_idx = 0; tree_idx < num_trees; ++tree_idx) { + for (size_t sample_idx = 0; sample_idx < trees[tree_idx]->getNumSamplesOob(); ++sample_idx) { + size_t sampleID = trees[tree_idx]->getOobSampleIDs()[sample_idx]; + ++class_counts[sampleID][getTreePrediction(tree_idx, sample_idx)]; + } + } + + // Compute majority vote for each sample + predictions = std::vector>>(1, + std::vector>(1, std::vector(num_samples))); + for (size_t i = 0; i < num_samples; ++i) { + if (!class_counts[i].empty()) { + predictions[0][0][i] = mostFrequentValue(class_counts[i], random_number_generator); + } else { + predictions[0][0][i] = NAN; + } + } + + // Compare predictions with true data + size_t num_missclassifications = 0; + size_t num_predictions = 0; + for (size_t i = 0; i < predictions[0][0].size(); ++i) { + double predicted_value = predictions[0][0][i]; + if (!std::isnan(predicted_value)) { + ++num_predictions; + double real_value = data->get(i, dependent_varID); + if (predicted_value != real_value) { + ++num_missclassifications; + } + ++classification_table[std::make_pair(real_value, predicted_value)]; + } + } + overall_prediction_error = (double) num_missclassifications / (double) num_predictions; +} + +// #nocov start +void ForestClassification::writeOutputInternal() { + if (verbose_out) { + *verbose_out << "Tree type: " << "Classification" << std::endl; + } +} + +void ForestClassification::writeConfusionFile() { + + // Open confusion file for writing + std::string filename = output_prefix + 
".confusion"; + std::ofstream outfile; + outfile.open(filename, std::ios::out); + if (!outfile.good()) { + throw std::runtime_error("Could not write to confusion file: " + filename + "."); + } + + // Write confusion to file + outfile << "Overall OOB prediction error (Fraction missclassified): " << overall_prediction_error << std::endl; + outfile << std::endl; + outfile << "Class specific prediction errors:" << std::endl; + outfile << " "; + for (auto& class_value : class_values) { + outfile << " " << class_value; + } + outfile << std::endl; + for (auto& predicted_value : class_values) { + outfile << "predicted " << predicted_value << " "; + for (auto& real_value : class_values) { + size_t value = classification_table[std::make_pair(real_value, predicted_value)]; + outfile << value; + if (value < 10) { + outfile << " "; + } else if (value < 100) { + outfile << " "; + } else if (value < 1000) { + outfile << " "; + } else if (value < 10000) { + outfile << " "; + } else if (value < 100000) { + outfile << " "; + } + } + outfile << std::endl; + } + + outfile.close(); + if (verbose_out) + *verbose_out << "Saved confusion matrix to file " << filename << "." 
<< std::endl; +} + +void ForestClassification::writePredictionFile() { + + // Open prediction file for writing + std::string filename = output_prefix + ".prediction"; + std::ofstream outfile; + outfile.open(filename, std::ios::out); + if (!outfile.good()) { + throw std::runtime_error("Could not write to prediction file: " + filename + "."); + } + + // Write + outfile << "Predictions: " << std::endl; + if (predict_all) { + for (size_t k = 0; k < num_trees; ++k) { + outfile << "Tree " << k << ":" << std::endl; + for (size_t i = 0; i < predictions.size(); ++i) { + for (size_t j = 0; j < predictions[i].size(); ++j) { + outfile << predictions[i][j][k] << std::endl; + } + } + outfile << std::endl; + } + } else { + for (size_t i = 0; i < predictions.size(); ++i) { + for (size_t j = 0; j < predictions[i].size(); ++j) { + for (size_t k = 0; k < predictions[i][j].size(); ++k) { + outfile << predictions[i][j][k] << std::endl; + } + } + } + } + + if (verbose_out) + *verbose_out << "Saved predictions to file " << filename << "." << std::endl; +} + +void ForestClassification::saveToFileInternal(std::ofstream& outfile) { + + // Write num_variables + outfile.write((char*) &num_variables, sizeof(num_variables)); + + // Write treetype + TreeType treetype = TREE_CLASSIFICATION; + outfile.write((char*) &treetype, sizeof(treetype)); + + // Write class_values + saveVector1D(class_values, outfile); +} + +void ForestClassification::loadFromFileInternal(std::ifstream& infile) { + + // Read number of variables + size_t num_variables_saved; + infile.read((char*) &num_variables_saved, sizeof(num_variables_saved)); + + // Read treetype + TreeType treetype; + infile.read((char*) &treetype, sizeof(treetype)); + if (treetype != TREE_CLASSIFICATION) { + throw std::runtime_error("Wrong treetype. 
Loaded file is not a classification forest."); + } + + // Read class_values + readVector1D(class_values, infile); + + for (size_t i = 0; i < num_trees; ++i) { + + // Read data + std::vector> child_nodeIDs; + readVector2D(child_nodeIDs, infile); + std::vector split_varIDs; + readVector1D(split_varIDs, infile); + std::vector split_values; + readVector1D(split_values, infile); + + // If dependent variable not in test data, change variable IDs accordingly + if (num_variables_saved > num_variables) { + for (auto& varID : split_varIDs) { + if (varID >= dependent_varID) { + --varID; + } + } + } + + // Create tree + trees.push_back( + std::make_unique(child_nodeIDs, split_varIDs, split_values, &class_values, &response_classIDs)); + } +} + +double ForestClassification::getTreePrediction(size_t tree_idx, size_t sample_idx) const { + const auto& tree = dynamic_cast(*trees[tree_idx]); + return tree.getPrediction(sample_idx); +} + +size_t ForestClassification::getTreePredictionTerminalNodeID(size_t tree_idx, size_t sample_idx) const { + const auto& tree = dynamic_cast(*trees[tree_idx]); + return tree.getPredictionTerminalNodeID(sample_idx); +} + +// #nocov end + +}// namespace ranger diff --git a/lib/ranger/ForestClassification.h b/lib/ranger/ForestClassification.h new file mode 100644 index 000000000..fc23c70ea --- /dev/null +++ b/lib/ranger/ForestClassification.h @@ -0,0 +1,77 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. 
+ #-------------------------------------------------------------------------------*/ + +#ifndef FORESTCLASSIFICATION_H_ +#define FORESTCLASSIFICATION_H_ + +#include +#include +#include +#include + +#include "globals.h" +#include "Forest.h" + +namespace ranger { + +class ForestClassification: public Forest { +public: + ForestClassification() = default; + + ForestClassification(const ForestClassification&) = delete; + ForestClassification& operator=(const ForestClassification&) = delete; + + virtual ~ForestClassification() override = default; + + void loadForest(size_t dependent_varID, size_t num_trees, + std::vector> >& forest_child_nodeIDs, + std::vector>& forest_split_varIDs, std::vector>& forest_split_values, + std::vector& class_values, std::vector& is_ordered_variable); + + const std::vector& getClassValues() const { + return class_values; + } + + void setClassWeights(std::vector& class_weights) { + this->class_weights = class_weights; + } + +protected: + void initInternal(std::string status_variable_name) override; + void growInternal() override; + void allocatePredictMemory() override; + void predictInternal(size_t sample_idx) override; + void computePredictionErrorInternal() override; + void writeOutputInternal() override; + void writeConfusionFile() override; + void writePredictionFile() override; + void saveToFileInternal(std::ofstream& outfile) override; + void loadFromFileInternal(std::ifstream& infile) override; + + // Classes of the dependent variable and classIDs for responses + std::vector class_values; + std::vector response_classIDs; + std::vector> sampleIDs_per_class; + + // Splitting weights + std::vector class_weights; + + // Table with classifications and true classes + std::map, size_t> classification_table; + +private: + double getTreePrediction(size_t tree_idx, size_t sample_idx) const; + size_t getTreePredictionTerminalNodeID(size_t tree_idx, size_t sample_idx) const; +}; + +} // namespace ranger + +#endif /* FORESTCLASSIFICATION_H_ */ diff 
--git a/lib/ranger/ForestProbability.cpp b/lib/ranger/ForestProbability.cpp new file mode 100644 index 000000000..ab5eac13c --- /dev/null +++ b/lib/ranger/ForestProbability.cpp @@ -0,0 +1,341 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. + #-------------------------------------------------------------------------------*/ + +#include + +#include "utility.h" +#include "ForestProbability.h" +#include "TreeProbability.h" +#include "Data.h" + +namespace ranger { + +void ForestProbability::loadForest(size_t dependent_varID, size_t num_trees, + std::vector> >& forest_child_nodeIDs, + std::vector>& forest_split_varIDs, std::vector>& forest_split_values, + std::vector& class_values, std::vector>>& forest_terminal_class_counts, + std::vector& is_ordered_variable) { + + this->dependent_varID = dependent_varID; + this->num_trees = num_trees; + this->class_values = class_values; + data->setIsOrderedVariable(is_ordered_variable); + + // Create trees + trees.reserve(num_trees); + for (size_t i = 0; i < num_trees; ++i) { + trees.push_back( + std::make_unique(forest_child_nodeIDs[i], forest_split_varIDs[i], forest_split_values[i], + &this->class_values, &response_classIDs, forest_terminal_class_counts[i])); + } + + // Create thread ranges + equalSplit(thread_ranges, 0, num_trees - 1, num_threads); +} + +std::vector>> ForestProbability::getTerminalClassCounts() const { + std::vector>> result; + result.reserve(num_trees); + for (const auto& tree : trees) { + const auto& temp = dynamic_cast(*tree); + result.push_back(temp.getTerminalClassCounts()); + } + return result; +} + +void ForestProbability::initInternal(std::string status_variable_name) { + + // If 
mtry not set, use floored square root of number of independent variables. + if (mtry == 0) { + unsigned long temp = sqrt((double) (num_variables - 1)); + mtry = std::max((unsigned long) 1, temp); + } + + // Set minimal node size + if (min_node_size == 0) { + min_node_size = DEFAULT_MIN_NODE_SIZE_PROBABILITY; + } + + // Create class_values and response_classIDs + if (!prediction_mode) { + for (size_t i = 0; i < num_samples; ++i) { + double value = data->get(i, dependent_varID); + + // If classID is already in class_values, use ID. Else create a new one. + uint classID = find(class_values.begin(), class_values.end(), value) - class_values.begin(); + if (classID == class_values.size()) { + class_values.push_back(value); + } + response_classIDs.push_back(classID); + } + } + + // Create sampleIDs_per_class if required + if (sample_fraction.size() > 1) { + sampleIDs_per_class.resize(sample_fraction.size()); + for (auto& v : sampleIDs_per_class) { + v.reserve(num_samples); + } + for (size_t i = 0; i < num_samples; ++i) { + size_t classID = response_classIDs[i]; + sampleIDs_per_class[classID].push_back(i); + } + } + + // Set class weights all to 1 + class_weights = std::vector(class_values.size(), 1.0); + + // Sort data if memory saving mode + if (!memory_saving_splitting) { + data->sort(); + } +} + +void ForestProbability::growInternal() { + trees.reserve(num_trees); + for (size_t i = 0; i < num_trees; ++i) { + trees.push_back( + std::make_unique(&class_values, &response_classIDs, &sampleIDs_per_class, &class_weights)); + } +} + +void ForestProbability::allocatePredictMemory() { + size_t num_prediction_samples = data->getNumRows(); + if (predict_all) { + predictions = std::vector>>(num_prediction_samples, + std::vector>(class_values.size(), std::vector(num_trees, 0))); + } else if (prediction_type == TERMINALNODES) { + predictions = std::vector>>(1, + std::vector>(num_prediction_samples, std::vector(num_trees, 0))); + } else { + predictions = std::vector>>(1, + 
std::vector>(num_prediction_samples, std::vector(class_values.size(), 0))); + } +} + +void ForestProbability::predictInternal(size_t sample_idx) { + // For each sample compute proportions in each tree + for (size_t tree_idx = 0; tree_idx < num_trees; ++tree_idx) { + if (predict_all) { + std::vector counts = getTreePrediction(tree_idx, sample_idx); + + for (size_t class_idx = 0; class_idx < counts.size(); ++class_idx) { + predictions[sample_idx][class_idx][tree_idx] += counts[class_idx]; + } + } else if (prediction_type == TERMINALNODES) { + predictions[0][sample_idx][tree_idx] = getTreePredictionTerminalNodeID(tree_idx, sample_idx); + } else { + std::vector counts = getTreePrediction(tree_idx, sample_idx); + + for (size_t class_idx = 0; class_idx < counts.size(); ++class_idx) { + predictions[0][sample_idx][class_idx] += counts[class_idx]; + } + } + } + + // Average over trees + if (!predict_all && prediction_type != TERMINALNODES) { + for (size_t class_idx = 0; class_idx < predictions[0][sample_idx].size(); ++class_idx) { + predictions[0][sample_idx][class_idx] /= num_trees; + } + } +} + +void ForestProbability::computePredictionErrorInternal() { + +// For each sample sum over trees where sample is OOB + std::vector samples_oob_count; + samples_oob_count.resize(num_samples, 0); + predictions = std::vector>>(1, + std::vector>(num_samples, std::vector(class_values.size(), 0))); + + for (size_t tree_idx = 0; tree_idx < num_trees; ++tree_idx) { + for (size_t sample_idx = 0; sample_idx < trees[tree_idx]->getNumSamplesOob(); ++sample_idx) { + size_t sampleID = trees[tree_idx]->getOobSampleIDs()[sample_idx]; + std::vector counts = getTreePrediction(tree_idx, sample_idx); + + for (size_t class_idx = 0; class_idx < counts.size(); ++class_idx) { + predictions[0][sampleID][class_idx] += counts[class_idx]; + } + ++samples_oob_count[sampleID]; + } + } + +// MSE with predicted probability and true data + size_t num_predictions = 0; + for (size_t i = 0; i < predictions[0].size(); 
++i) { + if (samples_oob_count[i] > 0) { + ++num_predictions; + for (size_t j = 0; j < predictions[0][i].size(); ++j) { + predictions[0][i][j] /= (double) samples_oob_count[i]; + } + size_t real_classID = response_classIDs[i]; + double predicted_value = predictions[0][i][real_classID]; + overall_prediction_error += (1 - predicted_value) * (1 - predicted_value); + } else { + for (size_t j = 0; j < predictions[0][i].size(); ++j) { + predictions[0][i][j] = NAN; + } + } + } + + overall_prediction_error /= (double) num_predictions; +} + +// #nocov start +void ForestProbability::writeOutputInternal() { + if (verbose_out) { + *verbose_out << "Tree type: " << "Probability estimation" << std::endl; + } +} + +void ForestProbability::writeConfusionFile() { + +// Open confusion file for writing + std::string filename = output_prefix + ".confusion"; + std::ofstream outfile; + outfile.open(filename, std::ios::out); + if (!outfile.good()) { + throw std::runtime_error("Could not write to confusion file: " + filename + "."); + } + +// Write confusion to file + outfile << "Overall OOB prediction error (MSE): " << overall_prediction_error << std::endl; + + outfile.close(); + if (verbose_out) + *verbose_out << "Saved prediction error to file " << filename << "." << std::endl; +} + +void ForestProbability::writePredictionFile() { + + // Open prediction file for writing + std::string filename = output_prefix + ".prediction"; + std::ofstream outfile; + outfile.open(filename, std::ios::out); + if (!outfile.good()) { + throw std::runtime_error("Could not write to prediction file: " + filename + "."); + } + + // Write + outfile << "Class predictions, one sample per row." 
<< std::endl; + for (auto& class_value : class_values) { + outfile << class_value << " "; + } + outfile << std::endl << std::endl; + + if (predict_all) { + for (size_t k = 0; k < num_trees; ++k) { + outfile << "Tree " << k << ":" << std::endl; + for (size_t i = 0; i < predictions.size(); ++i) { + for (size_t j = 0; j < predictions[i].size(); ++j) { + outfile << predictions[i][j][k] << " "; + } + outfile << std::endl; + } + outfile << std::endl; + } + } else { + for (size_t i = 0; i < predictions.size(); ++i) { + for (size_t j = 0; j < predictions[i].size(); ++j) { + for (size_t k = 0; k < predictions[i][j].size(); ++k) { + outfile << predictions[i][j][k] << " "; + } + outfile << std::endl; + } + } + } + + if (verbose_out) + *verbose_out << "Saved predictions to file " << filename << "." << std::endl; +} + +void ForestProbability::saveToFileInternal(std::ofstream& outfile) { + +// Write num_variables + outfile.write((char*) &num_variables, sizeof(num_variables)); + +// Write treetype + TreeType treetype = TREE_PROBABILITY; + outfile.write((char*) &treetype, sizeof(treetype)); + +// Write class_values + saveVector1D(class_values, outfile); +} + +void ForestProbability::loadFromFileInternal(std::ifstream& infile) { + +// Read number of variables + size_t num_variables_saved; + infile.read((char*) &num_variables_saved, sizeof(num_variables_saved)); + +// Read treetype + TreeType treetype; + infile.read((char*) &treetype, sizeof(treetype)); + if (treetype != TREE_PROBABILITY) { + throw std::runtime_error("Wrong treetype. 
Loaded file is not a probability estimation forest."); + } + +// Read class_values + readVector1D(class_values, infile); + + for (size_t i = 0; i < num_trees; ++i) { + + // Read data + std::vector> child_nodeIDs; + readVector2D(child_nodeIDs, infile); + std::vector split_varIDs; + readVector1D(split_varIDs, infile); + std::vector split_values; + readVector1D(split_values, infile); + + // Read Terminal node class counts + std::vector terminal_nodes; + readVector1D(terminal_nodes, infile); + std::vector> terminal_class_counts_vector; + readVector2D(terminal_class_counts_vector, infile); + + // Convert Terminal node class counts to vector with empty elemtents for non-terminal nodes + std::vector> terminal_class_counts; + terminal_class_counts.resize(child_nodeIDs[0].size(), std::vector()); + for (size_t j = 0; j < terminal_nodes.size(); ++j) { + terminal_class_counts[terminal_nodes[j]] = terminal_class_counts_vector[j]; + } + + // If dependent variable not in test data, change variable IDs accordingly + if (num_variables_saved > num_variables) { + for (auto& varID : split_varIDs) { + if (varID >= dependent_varID) { + --varID; + } + } + } + + // Create tree + trees.push_back( + std::make_unique(child_nodeIDs, split_varIDs, split_values, &class_values, &response_classIDs, + terminal_class_counts)); + } +} + +const std::vector& ForestProbability::getTreePrediction(size_t tree_idx, size_t sample_idx) const { + const auto& tree = dynamic_cast(*trees[tree_idx]); + return tree.getPrediction(sample_idx); +} + +size_t ForestProbability::getTreePredictionTerminalNodeID(size_t tree_idx, size_t sample_idx) const { + const auto& tree = dynamic_cast(*trees[tree_idx]); + return tree.getPredictionTerminalNodeID(sample_idx); +} + +// #nocov end + +}// namespace ranger diff --git a/lib/ranger/ForestProbability.h b/lib/ranger/ForestProbability.h new file mode 100644 index 000000000..aa42c893f --- /dev/null +++ b/lib/ranger/ForestProbability.h @@ -0,0 +1,77 @@ 
+/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. + #-------------------------------------------------------------------------------*/ + +#ifndef FORESTPROBABILITY_H_ +#define FORESTPROBABILITY_H_ + +#include +#include +#include + +#include "globals.h" +#include "Forest.h" +#include "TreeProbability.h" + +namespace ranger { + +class ForestProbability: public Forest { +public: + ForestProbability() = default; + + ForestProbability(const ForestProbability&) = delete; + ForestProbability& operator=(const ForestProbability&) = delete; + + virtual ~ForestProbability() override = default; + + void loadForest(size_t dependent_varID, size_t num_trees, + std::vector> >& forest_child_nodeIDs, + std::vector>& forest_split_varIDs, std::vector>& forest_split_values, + std::vector& class_values, std::vector>>& forest_terminal_class_counts, + std::vector& is_ordered_variable); + + std::vector>> getTerminalClassCounts() const; + + const std::vector& getClassValues() const { + return class_values; + } + + void setClassWeights(std::vector& class_weights) { + this->class_weights = class_weights; + } + +protected: + void initInternal(std::string status_variable_name) override; + void growInternal() override; + void allocatePredictMemory() override; + void predictInternal(size_t sample_idx) override; + void computePredictionErrorInternal() override; + void writeOutputInternal() override; + void writeConfusionFile() override; + void writePredictionFile() override; + void saveToFileInternal(std::ofstream& outfile) override; + void loadFromFileInternal(std::ifstream& infile) override; + + // Classes of the dependent variable and classIDs for responses + std::vector 
class_values; + std::vector response_classIDs; + std::vector> sampleIDs_per_class; + + // Splitting weights + std::vector class_weights; + +private: + const std::vector& getTreePrediction(size_t tree_idx, size_t sample_idx) const; + size_t getTreePredictionTerminalNodeID(size_t tree_idx, size_t sample_idx) const; +}; + +} // namespace ranger + +#endif /* FORESTPROBABILITY_H_ */ diff --git a/lib/ranger/ForestRegression.cpp b/lib/ranger/ForestRegression.cpp new file mode 100644 index 000000000..a5f986bab --- /dev/null +++ b/lib/ranger/ForestRegression.cpp @@ -0,0 +1,254 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. + #-------------------------------------------------------------------------------*/ + +#include +#include +#include + +#include "utility.h" +#include "ForestRegression.h" +#include "TreeRegression.h" +#include "Data.h" + +namespace ranger { + +void ForestRegression::loadForest(size_t dependent_varID, size_t num_trees, + std::vector> >& forest_child_nodeIDs, + std::vector>& forest_split_varIDs, std::vector>& forest_split_values, + std::vector& is_ordered_variable) { + + this->dependent_varID = dependent_varID; + this->num_trees = num_trees; + data->setIsOrderedVariable(is_ordered_variable); + + // Create trees + trees.reserve(num_trees); + for (size_t i = 0; i < num_trees; ++i) { + trees.push_back( + std::make_unique(forest_child_nodeIDs[i], forest_split_varIDs[i], forest_split_values[i])); + } + + // Create thread ranges + equalSplit(thread_ranges, 0, num_trees - 1, num_threads); +} + +void ForestRegression::initInternal(std::string status_variable_name) { + + // If mtry not set, use floored square root of number of independent 
variables + if (mtry == 0) { + unsigned long temp = sqrt((double) (num_variables - 1)); + mtry = std::max((unsigned long) 1, temp); + } + + // Set minimal node size + if (min_node_size == 0) { + min_node_size = DEFAULT_MIN_NODE_SIZE_REGRESSION; + } + + // Sort data if memory saving mode + if (!memory_saving_splitting) { + data->sort(); + } +} + +void ForestRegression::growInternal() { + trees.reserve(num_trees); + for (size_t i = 0; i < num_trees; ++i) { + trees.push_back(std::make_unique()); + } +} + +void ForestRegression::allocatePredictMemory() { + size_t num_prediction_samples = data->getNumRows(); + if (predict_all || prediction_type == TERMINALNODES) { + predictions = std::vector>>(1, + std::vector>(num_prediction_samples, std::vector(num_trees))); + } else { + predictions = std::vector>>(1, + std::vector>(1, std::vector(num_prediction_samples))); + } +} + +void ForestRegression::predictInternal(size_t sample_idx) { + if (predict_all || prediction_type == TERMINALNODES) { + // Get all tree predictions + for (size_t tree_idx = 0; tree_idx < num_trees; ++tree_idx) { + if (prediction_type == TERMINALNODES) { + predictions[0][sample_idx][tree_idx] = getTreePredictionTerminalNodeID(tree_idx, sample_idx); + } else { + predictions[0][sample_idx][tree_idx] = getTreePrediction(tree_idx, sample_idx); + } + } + } else { + // Mean over trees + double prediction_sum = 0; + for (size_t tree_idx = 0; tree_idx < num_trees; ++tree_idx) { + prediction_sum += getTreePrediction(tree_idx, sample_idx); + } + predictions[0][0][sample_idx] = prediction_sum / num_trees; + } +} + +void ForestRegression::computePredictionErrorInternal() { + +// For each sample sum over trees where sample is OOB + std::vector samples_oob_count; + predictions = std::vector>>(1, + std::vector>(1, std::vector(num_samples, 0))); + samples_oob_count.resize(num_samples, 0); + for (size_t tree_idx = 0; tree_idx < num_trees; ++tree_idx) { + for (size_t sample_idx = 0; sample_idx < 
trees[tree_idx]->getNumSamplesOob(); ++sample_idx) { + size_t sampleID = trees[tree_idx]->getOobSampleIDs()[sample_idx]; + double value = getTreePrediction(tree_idx, sample_idx); + + predictions[0][0][sampleID] += value; + ++samples_oob_count[sampleID]; + } + } + +// MSE with predictions and true data + size_t num_predictions = 0; + for (size_t i = 0; i < predictions[0][0].size(); ++i) { + if (samples_oob_count[i] > 0) { + ++num_predictions; + predictions[0][0][i] /= (double) samples_oob_count[i]; + double predicted_value = predictions[0][0][i]; + double real_value = data->get(i, dependent_varID); + overall_prediction_error += (predicted_value - real_value) * (predicted_value - real_value); + } else { + predictions[0][0][i] = NAN; + } + } + + overall_prediction_error /= (double) num_predictions; +} + +// #nocov start +void ForestRegression::writeOutputInternal() { + if (verbose_out) { + *verbose_out << "Tree type: " << "Regression" << std::endl; + } +} + +void ForestRegression::writeConfusionFile() { + +// Open confusion file for writing + std::string filename = output_prefix + ".confusion"; + std::ofstream outfile; + outfile.open(filename, std::ios::out); + if (!outfile.good()) { + throw std::runtime_error("Could not write to confusion file: " + filename + "."); + } + +// Write confusion to file + outfile << "Overall OOB prediction error (MSE): " << overall_prediction_error << std::endl; + + outfile.close(); + if (verbose_out) + *verbose_out << "Saved prediction error to file " << filename << "." 
<< std::endl; +} + +void ForestRegression::writePredictionFile() { + +// Open prediction file for writing + std::string filename = output_prefix + ".prediction"; + std::ofstream outfile; + outfile.open(filename, std::ios::out); + if (!outfile.good()) { + throw std::runtime_error("Could not write to prediction file: " + filename + "."); + } + + // Write + outfile << "Predictions: " << std::endl; + if (predict_all) { + for (size_t k = 0; k < num_trees; ++k) { + outfile << "Tree " << k << ":" << std::endl; + for (size_t i = 0; i < predictions.size(); ++i) { + for (size_t j = 0; j < predictions[i].size(); ++j) { + outfile << predictions[i][j][k] << std::endl; + } + } + outfile << std::endl; + } + } else { + for (size_t i = 0; i < predictions.size(); ++i) { + for (size_t j = 0; j < predictions[i].size(); ++j) { + for (size_t k = 0; k < predictions[i][j].size(); ++k) { + outfile << predictions[i][j][k] << std::endl; + } + } + } + } + + if (verbose_out) + *verbose_out << "Saved predictions to file " << filename << "." << std::endl; +} + +void ForestRegression::saveToFileInternal(std::ofstream& outfile) { + +// Write num_variables + outfile.write((char*) &num_variables, sizeof(num_variables)); + +// Write treetype + TreeType treetype = TREE_REGRESSION; + outfile.write((char*) &treetype, sizeof(treetype)); +} + +void ForestRegression::loadFromFileInternal(std::ifstream& infile) { + +// Read number of variables + size_t num_variables_saved; + infile.read((char*) &num_variables_saved, sizeof(num_variables_saved)); + +// Read treetype + TreeType treetype; + infile.read((char*) &treetype, sizeof(treetype)); + if (treetype != TREE_REGRESSION) { + throw std::runtime_error("Wrong treetype. 
Loaded file is not a regression forest."); + } + + for (size_t i = 0; i < num_trees; ++i) { + + // Read data + std::vector> child_nodeIDs; + readVector2D(child_nodeIDs, infile); + std::vector split_varIDs; + readVector1D(split_varIDs, infile); + std::vector split_values; + readVector1D(split_values, infile); + + // If dependent variable not in test data, change variable IDs accordingly + if (num_variables_saved > num_variables) { + for (auto& varID : split_varIDs) { + if (varID >= dependent_varID) { + --varID; + } + } + } + + // Create tree + trees.push_back(std::make_unique(child_nodeIDs, split_varIDs, split_values)); + } +} + +double ForestRegression::getTreePrediction(size_t tree_idx, size_t sample_idx) const { + const auto& tree = dynamic_cast(*trees[tree_idx]); + return tree.getPrediction(sample_idx); +} + +size_t ForestRegression::getTreePredictionTerminalNodeID(size_t tree_idx, size_t sample_idx) const { + const auto& tree = dynamic_cast(*trees[tree_idx]); + return tree.getPredictionTerminalNodeID(sample_idx); +} + +// #nocov end + +}// namespace ranger diff --git a/lib/ranger/ForestRegression.h b/lib/ranger/ForestRegression.h new file mode 100644 index 000000000..bd59d22dc --- /dev/null +++ b/lib/ranger/ForestRegression.h @@ -0,0 +1,56 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. 
+ #-------------------------------------------------------------------------------*/ + +#ifndef FORESTREGRESSION_H_ +#define FORESTREGRESSION_H_ + +#include +#include + +#include "globals.h" +#include "Forest.h" + +namespace ranger { + +class ForestRegression: public Forest { +public: + ForestRegression() = default; + + ForestRegression(const ForestRegression&) = delete; + ForestRegression& operator=(const ForestRegression&) = delete; + + virtual ~ForestRegression() override = default; + + void loadForest(size_t dependent_varID, size_t num_trees, + std::vector> >& forest_child_nodeIDs, + std::vector>& forest_split_varIDs, std::vector>& forest_split_values, + std::vector& is_ordered_variable); + +private: + void initInternal(std::string status_variable_name) override; + void growInternal() override; + void allocatePredictMemory() override; + void predictInternal(size_t sample_idx) override; + void computePredictionErrorInternal() override; + void writeOutputInternal() override; + void writeConfusionFile() override; + void writePredictionFile() override; + void saveToFileInternal(std::ofstream& outfile) override; + void loadFromFileInternal(std::ifstream& infile) override; + +private: + double getTreePrediction(size_t tree_idx, size_t sample_idx) const; + size_t getTreePredictionTerminalNodeID(size_t tree_idx, size_t sample_idx) const; +}; + +} // namespace ranger + +#endif /* FORESTREGRESSION_H_ */ diff --git a/lib/ranger/ForestSurvival.cpp b/lib/ranger/ForestSurvival.cpp new file mode 100644 index 000000000..97bcf72e4 --- /dev/null +++ b/lib/ranger/ForestSurvival.cpp @@ -0,0 +1,362 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. 
+ #-------------------------------------------------------------------------------*/ + +#include +#include +#include +#include +#include + +#include "utility.h" +#include "ForestSurvival.h" +#include "Data.h" + +namespace ranger { + +void ForestSurvival::loadForest(size_t dependent_varID, size_t num_trees, + std::vector> >& forest_child_nodeIDs, + std::vector>& forest_split_varIDs, std::vector>& forest_split_values, + size_t status_varID, std::vector> >& forest_chf, + std::vector& unique_timepoints, std::vector& is_ordered_variable) { + + this->dependent_varID = dependent_varID; + this->status_varID = status_varID; + this->num_trees = num_trees; + this->unique_timepoints = unique_timepoints; + data->setIsOrderedVariable(is_ordered_variable); + + // Create trees + trees.reserve(num_trees); + for (size_t i = 0; i < num_trees; ++i) { + trees.push_back( + std::make_unique(forest_child_nodeIDs[i], forest_split_varIDs[i], forest_split_values[i], + forest_chf[i], &this->unique_timepoints, &response_timepointIDs)); + } + + // Create thread ranges + equalSplit(thread_ranges, 0, num_trees - 1, num_threads); +} + +std::vector>> ForestSurvival::getChf() const { + std::vector>> result; + result.reserve(num_trees); + for (const auto& tree : trees) { + const auto& temp = dynamic_cast(*tree); + result.push_back(temp.getChf()); + } + return result; +} + +void ForestSurvival::initInternal(std::string status_variable_name) { + + // Convert status variable name to ID + if (!prediction_mode && !status_variable_name.empty()) { + status_varID = data->getVariableID(status_variable_name); + } + + data->addNoSplitVariable(status_varID); + + // If mtry not set, use floored square root of number of independent variables. 
+ if (mtry == 0) { + unsigned long temp = ceil(sqrt((double) (num_variables - 2))); + mtry = std::max((unsigned long) 1, temp); + } + + // Set minimal node size + if (min_node_size == 0) { + min_node_size = DEFAULT_MIN_NODE_SIZE_SURVIVAL; + } + + // Create unique timepoints + std::set unique_timepoint_set; + for (size_t i = 0; i < num_samples; ++i) { + unique_timepoint_set.insert(data->get(i, dependent_varID)); + } + unique_timepoints.reserve(unique_timepoint_set.size()); + for (auto& t : unique_timepoint_set) { + unique_timepoints.push_back(t); + } + + // Create response_timepointIDs + if (!prediction_mode) { + for (size_t i = 0; i < num_samples; ++i) { + double value = data->get(i, dependent_varID); + + // If timepoint is already in unique_timepoints, use ID. Else create a new one. + uint timepointID = find(unique_timepoints.begin(), unique_timepoints.end(), value) - unique_timepoints.begin(); + response_timepointIDs.push_back(timepointID); + } + } + + // Sort data if extratrees and not memory saving mode + if (splitrule == EXTRATREES && !memory_saving_splitting) { + data->sort(); + } +} + +void ForestSurvival::growInternal() { + trees.reserve(num_trees); + for (size_t i = 0; i < num_trees; ++i) { + trees.push_back(std::make_unique(&unique_timepoints, status_varID, &response_timepointIDs)); + } +} + +void ForestSurvival::allocatePredictMemory() { + size_t num_prediction_samples = data->getNumRows(); + size_t num_timepoints = unique_timepoints.size(); + if (predict_all) { + predictions = std::vector>>(num_prediction_samples, + std::vector>(num_timepoints, std::vector(num_trees, 0))); + } else if (prediction_type == TERMINALNODES) { + predictions = std::vector>>(1, + std::vector>(num_prediction_samples, std::vector(num_trees, 0))); + } else { + predictions = std::vector>>(1, + std::vector>(num_prediction_samples, std::vector(num_timepoints, 0))); + } +} + +void ForestSurvival::predictInternal(size_t sample_idx) { + // For each timepoint sum over trees + if 
(predict_all) { + for (size_t j = 0; j < unique_timepoints.size(); ++j) { + for (size_t k = 0; k < num_trees; ++k) { + predictions[sample_idx][j][k] = getTreePrediction(k, sample_idx)[j]; + } + } + } else if (prediction_type == TERMINALNODES) { + for (size_t k = 0; k < num_trees; ++k) { + predictions[0][sample_idx][k] = getTreePredictionTerminalNodeID(k, sample_idx); + } + } else { + for (size_t j = 0; j < unique_timepoints.size(); ++j) { + double sample_time_prediction = 0; + for (size_t k = 0; k < num_trees; ++k) { + sample_time_prediction += getTreePrediction(k, sample_idx)[j]; + } + predictions[0][sample_idx][j] = sample_time_prediction / num_trees; + } + } +} + +void ForestSurvival::computePredictionErrorInternal() { + + size_t num_timepoints = unique_timepoints.size(); + + // For each sample sum over trees where sample is OOB + std::vector samples_oob_count; + samples_oob_count.resize(num_samples, 0); + predictions = std::vector>>(1, + std::vector>(num_samples, std::vector(num_timepoints, 0))); + + for (size_t tree_idx = 0; tree_idx < num_trees; ++tree_idx) { + for (size_t sample_idx = 0; sample_idx < trees[tree_idx]->getNumSamplesOob(); ++sample_idx) { + size_t sampleID = trees[tree_idx]->getOobSampleIDs()[sample_idx]; + std::vector tree_sample_chf = getTreePrediction(tree_idx, sample_idx); + + for (size_t time_idx = 0; time_idx < tree_sample_chf.size(); ++time_idx) { + predictions[0][sampleID][time_idx] += tree_sample_chf[time_idx]; + } + ++samples_oob_count[sampleID]; + } + } + + // Divide sample predictions by number of trees where sample is oob and compute summed chf for samples + std::vector sum_chf; + sum_chf.reserve(predictions[0].size()); + std::vector oob_sampleIDs; + oob_sampleIDs.reserve(predictions[0].size()); + for (size_t i = 0; i < predictions[0].size(); ++i) { + if (samples_oob_count[i] > 0) { + double sum = 0; + for (size_t j = 0; j < predictions[0][i].size(); ++j) { + predictions[0][i][j] /= samples_oob_count[i]; + sum += 
predictions[0][i][j]; + } + sum_chf.push_back(sum); + oob_sampleIDs.push_back(i); + } + } + + // Use all samples which are OOB at least once + overall_prediction_error = 1 - computeConcordanceIndex(*data, sum_chf, dependent_varID, status_varID, oob_sampleIDs); +} + +// #nocov start +void ForestSurvival::writeOutputInternal() { + if (verbose_out) { + *verbose_out << "Tree type: " << "Survival" << std::endl; + *verbose_out << "Status variable name: " << data->getVariableNames()[status_varID] << std::endl; + *verbose_out << "Status variable ID: " << status_varID << std::endl; + } +} + +void ForestSurvival::writeConfusionFile() { + + // Open confusion file for writing + std::string filename = output_prefix + ".confusion"; + std::ofstream outfile; + outfile.open(filename, std::ios::out); + if (!outfile.good()) { + throw std::runtime_error("Could not write to confusion file: " + filename + "."); + } + + // Write confusion to file + outfile << "Overall OOB prediction error (1 - C): " << overall_prediction_error << std::endl; + + outfile.close(); + if (verbose_out) + *verbose_out << "Saved prediction error to file " << filename << "." 
<< std::endl; + +} + +void ForestSurvival::writePredictionFile() { + + // Open prediction file for writing + std::string filename = output_prefix + ".prediction"; + std::ofstream outfile; + outfile.open(filename, std::ios::out); + if (!outfile.good()) { + throw std::runtime_error("Could not write to prediction file: " + filename + "."); + } + + // Write + outfile << "Unique timepoints: " << std::endl; + for (auto& timepoint : unique_timepoints) { + outfile << timepoint << " "; + } + outfile << std::endl << std::endl; + + outfile << "Cumulative hazard function, one row per sample: " << std::endl; + if (predict_all) { + for (size_t k = 0; k < num_trees; ++k) { + outfile << "Tree " << k << ":" << std::endl; + for (size_t i = 0; i < predictions.size(); ++i) { + for (size_t j = 0; j < predictions[i].size(); ++j) { + outfile << predictions[i][j][k] << " "; + } + outfile << std::endl; + } + outfile << std::endl; + } + } else { + for (size_t i = 0; i < predictions.size(); ++i) { + for (size_t j = 0; j < predictions[i].size(); ++j) { + for (size_t k = 0; k < predictions[i][j].size(); ++k) { + outfile << predictions[i][j][k] << " "; + } + outfile << std::endl; + } + } + } + + if (verbose_out) + *verbose_out << "Saved predictions to file " << filename << "." 
<< std::endl; +} + +void ForestSurvival::saveToFileInternal(std::ofstream& outfile) { + + // Write num_variables + outfile.write((char*) &num_variables, sizeof(num_variables)); + + // Write treetype + TreeType treetype = TREE_SURVIVAL; + outfile.write((char*) &treetype, sizeof(treetype)); + + // Write status_varID + outfile.write((char*) &status_varID, sizeof(status_varID)); + + // Write unique timepoints + saveVector1D(unique_timepoints, outfile); +} + +void ForestSurvival::loadFromFileInternal(std::ifstream& infile) { + + // Read number of variables + size_t num_variables_saved; + infile.read((char*) &num_variables_saved, sizeof(num_variables_saved)); + + // Read treetype + TreeType treetype; + infile.read((char*) &treetype, sizeof(treetype)); + if (treetype != TREE_SURVIVAL) { + throw std::runtime_error("Wrong treetype. Loaded file is not a survival forest."); + } + + // Read status_varID + infile.read((char*) &status_varID, sizeof(status_varID)); + + // Read unique timepoints + unique_timepoints.clear(); + readVector1D(unique_timepoints, infile); + + for (size_t i = 0; i < num_trees; ++i) { + + // Read data + std::vector> child_nodeIDs; + readVector2D(child_nodeIDs, infile); + std::vector split_varIDs; + readVector1D(split_varIDs, infile); + std::vector split_values; + readVector1D(split_values, infile); + + // Read chf + std::vector terminal_nodes; + readVector1D(terminal_nodes, infile); + std::vector> chf_vector; + readVector2D(chf_vector, infile); + + // Convert chf to vector with empty elements for non-terminal nodes + std::vector> chf; + chf.resize(child_nodeIDs[0].size(), std::vector()); +// for (size_t i = 0; i < child_nodeIDs.size(); ++i) { +// chf.push_back(std::vector()); +// } + for (size_t j = 0; j < terminal_nodes.size(); ++j) { + chf[terminal_nodes[j]] = chf_vector[j]; + } + + // If dependent variable not in test data, change variable IDs accordingly + if (num_variables_saved > num_variables) { + for (auto& varID : split_varIDs) { + if (varID >= 
dependent_varID) { + --varID; + } + } + } + if (num_variables_saved > num_variables + 1) { + for (auto& varID : split_varIDs) { + if (varID >= status_varID) { + --varID; + } + } + } + + // Create tree + trees.push_back( + std::make_unique(child_nodeIDs, split_varIDs, split_values, chf, &unique_timepoints, + &response_timepointIDs)); + } +} + +const std::vector& ForestSurvival::getTreePrediction(size_t tree_idx, size_t sample_idx) const { + const auto& tree = dynamic_cast(*trees[tree_idx]); + return tree.getPrediction(sample_idx); +} + +size_t ForestSurvival::getTreePredictionTerminalNodeID(size_t tree_idx, size_t sample_idx) const { + const auto& tree = dynamic_cast(*trees[tree_idx]); + return tree.getPredictionTerminalNodeID(sample_idx); +} + +// #nocov end + +}// namespace ranger diff --git a/lib/ranger/ForestSurvival.h b/lib/ranger/ForestSurvival.h new file mode 100644 index 000000000..4fc69171f --- /dev/null +++ b/lib/ranger/ForestSurvival.h @@ -0,0 +1,71 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. 
+ #-------------------------------------------------------------------------------*/ + +#ifndef FORESTSURVIVAL_H_ +#define FORESTSURVIVAL_H_ + +#include +#include + +#include "globals.h" +#include "Forest.h" +#include "TreeSurvival.h" + +namespace ranger { + +class ForestSurvival: public Forest { +public: + ForestSurvival() = default; + + ForestSurvival(const ForestSurvival&) = delete; + ForestSurvival& operator=(const ForestSurvival&) = delete; + + virtual ~ForestSurvival() override = default; + + void loadForest(size_t dependent_varID, size_t num_trees, + std::vector> >& forest_child_nodeIDs, + std::vector>& forest_split_varIDs, std::vector>& forest_split_values, + size_t status_varID, std::vector> >& forest_chf, + std::vector& unique_timepoints, std::vector& is_ordered_variable); + + std::vector>> getChf() const; + + size_t getStatusVarId() const { + return status_varID; + } + const std::vector& getUniqueTimepoints() const { + return unique_timepoints; + } + +private: + void initInternal(std::string status_variable_name) override; + void growInternal() override; + void allocatePredictMemory() override; + void predictInternal(size_t sample_idx) override; + void computePredictionErrorInternal() override; + void writeOutputInternal() override; + void writeConfusionFile() override; + void writePredictionFile() override; + void saveToFileInternal(std::ofstream& outfile) override; + void loadFromFileInternal(std::ifstream& infile) override; + + size_t status_varID; + std::vector unique_timepoints; + std::vector response_timepointIDs; + +private: + const std::vector& getTreePrediction(size_t tree_idx, size_t sample_idx) const; + size_t getTreePredictionTerminalNodeID(size_t tree_idx, size_t sample_idx) const; +}; + +} // namespace ranger + +#endif /* FORESTSURVIVAL_H_ */ diff --git a/lib/ranger/Tree.cpp b/lib/ranger/Tree.cpp new file mode 100644 index 000000000..84275f039 --- /dev/null +++ b/lib/ranger/Tree.cpp @@ -0,0 +1,510 @@ 
+/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. + #-------------------------------------------------------------------------------*/ + +#include + +#include "Tree.h" +#include "utility.h" + +namespace ranger { + +Tree::Tree() : + dependent_varID(0), mtry(0), num_samples(0), num_samples_oob(0), min_node_size(0), deterministic_varIDs(0), split_select_varIDs( + 0), split_select_weights(0), case_weights(0), oob_sampleIDs(0), holdout(false), keep_inbag(false), data(0), variable_importance( + 0), importance_mode(DEFAULT_IMPORTANCE_MODE), sample_with_replacement(true), sample_fraction(0), memory_saving_splitting( + false), splitrule(DEFAULT_SPLITRULE), alpha(DEFAULT_ALPHA), minprop(DEFAULT_MINPROP), num_random_splits( + DEFAULT_NUM_RANDOM_SPLITS) { +} + +Tree::Tree(std::vector>& child_nodeIDs, std::vector& split_varIDs, + std::vector& split_values) : + dependent_varID(0), mtry(0), num_samples(0), num_samples_oob(0), min_node_size(0), deterministic_varIDs(0), split_select_varIDs( + 0), split_select_weights(0), case_weights(0), split_varIDs(split_varIDs), split_values(split_values), child_nodeIDs( + child_nodeIDs), oob_sampleIDs(0), holdout(false), keep_inbag(false), data(0), variable_importance(0), importance_mode( + DEFAULT_IMPORTANCE_MODE), sample_with_replacement(true), sample_fraction(0), memory_saving_splitting(false), splitrule( + DEFAULT_SPLITRULE), alpha(DEFAULT_ALPHA), minprop(DEFAULT_MINPROP), num_random_splits(DEFAULT_NUM_RANDOM_SPLITS) { +} + +void Tree::init(const Data* data, uint mtry, size_t dependent_varID, size_t num_samples, uint seed, + std::vector* deterministic_varIDs, std::vector* split_select_varIDs, + std::vector* 
split_select_weights, ImportanceMode importance_mode, uint min_node_size, + bool sample_with_replacement, bool memory_saving_splitting, SplitRule splitrule, std::vector* case_weights, + bool keep_inbag, std::vector* sample_fraction, double alpha, double minprop, bool holdout, + uint num_random_splits) { + + this->data = data; + this->mtry = mtry; + this->dependent_varID = dependent_varID; + this->num_samples = num_samples; + this->memory_saving_splitting = memory_saving_splitting; + + // Create root node, assign bootstrap sample and oob samples + child_nodeIDs.push_back(std::vector()); + child_nodeIDs.push_back(std::vector()); + createEmptyNode(); + + // Initialize random number generator and set seed + random_number_generator.seed(seed); + + this->deterministic_varIDs = deterministic_varIDs; + this->split_select_varIDs = split_select_varIDs; + this->split_select_weights = split_select_weights; + this->importance_mode = importance_mode; + this->min_node_size = min_node_size; + this->sample_with_replacement = sample_with_replacement; + this->splitrule = splitrule; + this->case_weights = case_weights; + this->keep_inbag = keep_inbag; + this->sample_fraction = sample_fraction; + this->holdout = holdout; + this->alpha = alpha; + this->minprop = minprop; + this->num_random_splits = num_random_splits; +} + +void Tree::grow(std::vector* variable_importance) { + // Allocate memory for tree growing + allocateMemory(); + + this->variable_importance = variable_importance; + +// Bootstrap, dependent if weighted or not and with or without replacement + if (!case_weights->empty()) { + if (sample_with_replacement) { + bootstrapWeighted(); + } else { + bootstrapWithoutReplacementWeighted(); + } + } else if (sample_fraction->size() > 1) { + if (sample_with_replacement) { + bootstrapClassWise(); + } else { + bootstrapWithoutReplacementClassWise(); + } + } else { + if (sample_with_replacement) { + bootstrap(); + } else { + bootstrapWithoutReplacement(); + } + } + +// While not all 
nodes terminal, split next node + size_t num_open_nodes = 1; + size_t i = 0; + while (num_open_nodes > 0) { + bool is_terminal_node = splitNode(i); + if (is_terminal_node) { + --num_open_nodes; + } else { + ++num_open_nodes; + } + ++i; + } + + // Delete sampleID vector to save memory + sampleIDs.clear(); + sampleIDs.shrink_to_fit(); + cleanUpInternal(); +} + +void Tree::predict(const Data* prediction_data, bool oob_prediction) { + + size_t num_samples_predict; + if (oob_prediction) { + num_samples_predict = num_samples_oob; + } else { + num_samples_predict = prediction_data->getNumRows(); + } + + prediction_terminal_nodeIDs.resize(num_samples_predict, 0); + +// For each sample start in root, drop down the tree and return final value + for (size_t i = 0; i < num_samples_predict; ++i) { + size_t sample_idx; + if (oob_prediction) { + sample_idx = oob_sampleIDs[i]; + } else { + sample_idx = i; + } + size_t nodeID = 0; + while (1) { + + // Break if terminal node + if (child_nodeIDs[0][nodeID] == 0 && child_nodeIDs[1][nodeID] == 0) { + break; + } + + // Move to child + size_t split_varID = split_varIDs[nodeID]; + + double value = prediction_data->get(sample_idx, split_varID); + if (prediction_data->isOrderedVariable(split_varID)) { + if (value <= split_values[nodeID]) { + // Move to left child + nodeID = child_nodeIDs[0][nodeID]; + } else { + // Move to right child + nodeID = child_nodeIDs[1][nodeID]; + } + } else { + size_t factorID = floor(value) - 1; + size_t splitID = floor(split_values[nodeID]); + + // Left if 0 found at position factorID + if (!(splitID & (1 << factorID))) { + // Move to left child + nodeID = child_nodeIDs[0][nodeID]; + } else { + // Move to right child + nodeID = child_nodeIDs[1][nodeID]; + } + } + } + + prediction_terminal_nodeIDs[i] = nodeID; + } +} + +void Tree::computePermutationImportance(std::vector& forest_importance, std::vector& forest_variance) { + + size_t num_independent_variables = data->getNumCols() - 
data->getNoSplitVariables().size(); + +// Compute normal prediction accuracy for each tree. Predictions already computed.. + double accuracy_normal = computePredictionAccuracyInternal(); + + prediction_terminal_nodeIDs.clear(); + prediction_terminal_nodeIDs.resize(num_samples_oob, 0); + +// Reserve space for permutations, initialize with oob_sampleIDs + std::vector permutations(oob_sampleIDs); + +// Randomly permute for all independent variables + for (size_t i = 0; i < num_independent_variables; ++i) { + + // Skip no split variables + size_t varID = i; + for (auto& skip : data->getNoSplitVariables()) { + if (varID >= skip) { + ++varID; + } + } + + // Permute and compute prediction accuracy again for this permutation and save difference + permuteAndPredictOobSamples(varID, permutations); + double accuracy_permuted = computePredictionAccuracyInternal(); + double accuracy_difference = accuracy_normal - accuracy_permuted; + forest_importance[i] += accuracy_difference; + + // Compute variance + if (importance_mode == IMP_PERM_BREIMAN) { + forest_variance[i] += accuracy_difference * accuracy_difference; + } else if (importance_mode == IMP_PERM_LIAW) { + forest_variance[i] += accuracy_difference * accuracy_difference * num_samples_oob; + } + } +} + +void Tree::appendToFile(std::ofstream& file) { + +// Save general fields + saveVector2D(child_nodeIDs, file); + saveVector1D(split_varIDs, file); + saveVector1D(split_values, file); + +// Call special functions for subclasses to save special fields. 
+ appendToFileInternal(file); +} + +void Tree::createPossibleSplitVarSubset(std::vector& result) { + + size_t num_vars = data->getNumCols(); + + // For corrected Gini importance add dummy variables + if (importance_mode == IMP_GINI_CORRECTED) { + num_vars += data->getNumCols() - data->getNoSplitVariables().size(); + } + + // Always use deterministic variables + std::copy(deterministic_varIDs->begin(), deterministic_varIDs->end(), std::inserter(result, result.end())); + + // Randomly add non-deterministic variables (according to weights if needed) + if (split_select_weights->empty()) { + drawWithoutReplacementSkip(result, random_number_generator, num_vars, data->getNoSplitVariables(), mtry); + } else { + // No corrected Gini importance supported for weighted splitting + size_t num_draws = mtry - result.size(); + drawWithoutReplacementWeighted(result, random_number_generator, *split_select_varIDs, num_draws, + *split_select_weights); + } +} + +bool Tree::splitNode(size_t nodeID) { + + // Select random subset of variables to possibly split at + std::vector possible_split_varIDs; + createPossibleSplitVarSubset(possible_split_varIDs); + + // Call subclass method, sets split_varIDs and split_values + bool stop = splitNodeInternal(nodeID, possible_split_varIDs); + if (stop) { + // Terminal node + return true; + } + + size_t split_varID = split_varIDs[nodeID]; + double split_value = split_values[nodeID]; + + // Save non-permuted variable for prediction + split_varIDs[nodeID] = data->getUnpermutedVarID(split_varID); + + // Create child nodes + size_t left_child_nodeID = sampleIDs.size(); + child_nodeIDs[0][nodeID] = left_child_nodeID; + createEmptyNode(); + + size_t right_child_nodeID = sampleIDs.size(); + child_nodeIDs[1][nodeID] = right_child_nodeID; + createEmptyNode(); + + // For each sample in node, assign to left or right child + if (data->isOrderedVariable(split_varID)) { + // Ordered: left is <= splitval and right is > splitval + for (auto& sampleID : 
sampleIDs[nodeID]) { + if (data->get(sampleID, split_varID) <= split_value) { + sampleIDs[left_child_nodeID].push_back(sampleID); + } else { + sampleIDs[right_child_nodeID].push_back(sampleID); + } + } + } else { + // Unordered: If bit at position is 1 -> right, 0 -> left + for (auto& sampleID : sampleIDs[nodeID]) { + + double level = data->get(sampleID, split_varID); + size_t factorID = floor(level) - 1; + size_t splitID = floor(split_value); + + // Left if 0 found at position factorID + if (!(splitID & (1 << factorID))) { + sampleIDs[left_child_nodeID].push_back(sampleID); + } else { + sampleIDs[right_child_nodeID].push_back(sampleID); + } + } + } + + // No terminal node + return false; +} + +void Tree::createEmptyNode() { + split_varIDs.push_back(0); + split_values.push_back(0); + child_nodeIDs[0].push_back(0); + child_nodeIDs[1].push_back(0); + sampleIDs.push_back(std::vector()); + + createEmptyNodeInternal(); +} + +size_t Tree::dropDownSamplePermuted(size_t permuted_varID, size_t sampleID, size_t permuted_sampleID) { + +// Start in root and drop down + size_t nodeID = 0; + while (child_nodeIDs[0][nodeID] != 0 || child_nodeIDs[1][nodeID] != 0) { + + // Permute if variable is permutation variable + size_t split_varID = split_varIDs[nodeID]; + size_t sampleID_final = sampleID; + if (split_varID == permuted_varID) { + sampleID_final = permuted_sampleID; + } + + // Move to child + double value = data->get(sampleID_final, split_varID); + if (data->isOrderedVariable(split_varID)) { + if (value <= split_values[nodeID]) { + // Move to left child + nodeID = child_nodeIDs[0][nodeID]; + } else { + // Move to right child + nodeID = child_nodeIDs[1][nodeID]; + } + } else { + size_t factorID = floor(value) - 1; + size_t splitID = floor(split_values[nodeID]); + + // Left if 0 found at position factorID + if (!(splitID & (1 << factorID))) { + // Move to left child + nodeID = child_nodeIDs[0][nodeID]; + } else { + // Move to right child + nodeID = child_nodeIDs[1][nodeID]; + } 
+ } + + } + return nodeID; +} + +void Tree::permuteAndPredictOobSamples(size_t permuted_varID, std::vector& permutations) { + +// Permute OOB sample +//std::vector permutations(oob_sampleIDs); + std::shuffle(permutations.begin(), permutations.end(), random_number_generator); + +// For each sample, drop down the tree and add prediction + for (size_t i = 0; i < num_samples_oob; ++i) { + size_t nodeID = dropDownSamplePermuted(permuted_varID, oob_sampleIDs[i], permutations[i]); + prediction_terminal_nodeIDs[i] = nodeID; + } +} + +void Tree::bootstrap() { + +// Use fraction (default 63.21%) of the samples + size_t num_samples_inbag = (size_t) num_samples * (*sample_fraction)[0]; + +// Reserve space, reserve a little more to be save) + sampleIDs[0].reserve(num_samples_inbag); + oob_sampleIDs.reserve(num_samples * (exp(-(*sample_fraction)[0]) + 0.1)); + + std::uniform_int_distribution unif_dist(0, num_samples - 1); + +// Start with all samples OOB + inbag_counts.resize(num_samples, 0); + +// Draw num_samples samples with replacement (num_samples_inbag out of n) as inbag and mark as not OOB + for (size_t s = 0; s < num_samples_inbag; ++s) { + size_t draw = unif_dist(random_number_generator); + sampleIDs[0].push_back(draw); + ++inbag_counts[draw]; + } + +// Save OOB samples + for (size_t s = 0; s < inbag_counts.size(); ++s) { + if (inbag_counts[s] == 0) { + oob_sampleIDs.push_back(s); + } + } + num_samples_oob = oob_sampleIDs.size(); + + if (!keep_inbag) { + inbag_counts.clear(); + inbag_counts.shrink_to_fit(); + } +} + +void Tree::bootstrapWeighted() { + +// Use fraction (default 63.21%) of the samples + size_t num_samples_inbag = (size_t) num_samples * (*sample_fraction)[0]; + +// Reserve space, reserve a little more to be save) + sampleIDs[0].reserve(num_samples_inbag); + oob_sampleIDs.reserve(num_samples * (exp(-(*sample_fraction)[0]) + 0.1)); + + std::discrete_distribution<> weighted_dist(case_weights->begin(), case_weights->end()); + +// Start with all samples OOB + 
inbag_counts.resize(num_samples, 0); + +// Draw num_samples samples with replacement (n out of n) as inbag and mark as not OOB + for (size_t s = 0; s < num_samples_inbag; ++s) { + size_t draw = weighted_dist(random_number_generator); + sampleIDs[0].push_back(draw); + ++inbag_counts[draw]; + } + + // Save OOB samples. In holdout mode these are the cases with 0 weight. + if (holdout) { + for (size_t s = 0; s < (*case_weights).size(); ++s) { + if ((*case_weights)[s] == 0) { + oob_sampleIDs.push_back(s); + } + } + } else { + for (size_t s = 0; s < inbag_counts.size(); ++s) { + if (inbag_counts[s] == 0) { + oob_sampleIDs.push_back(s); + } + } + } + num_samples_oob = oob_sampleIDs.size(); + + if (!keep_inbag) { + inbag_counts.clear(); + inbag_counts.shrink_to_fit(); + } +} + +void Tree::bootstrapWithoutReplacement() { + +// Use fraction (default 63.21%) of the samples + size_t num_samples_inbag = (size_t) num_samples * (*sample_fraction)[0]; + shuffleAndSplit(sampleIDs[0], oob_sampleIDs, num_samples, num_samples_inbag, random_number_generator); + num_samples_oob = oob_sampleIDs.size(); + + if (keep_inbag) { + // All observation are 0 or 1 times inbag + inbag_counts.resize(num_samples, 1); + for (size_t i = 0; i < oob_sampleIDs.size(); i++) { + inbag_counts[oob_sampleIDs[i]] = 0; + } + } +} + +void Tree::bootstrapWithoutReplacementWeighted() { + +// Use fraction (default 63.21%) of the samples + size_t num_samples_inbag = (size_t) num_samples * (*sample_fraction)[0]; + drawWithoutReplacementWeighted(sampleIDs[0], random_number_generator, num_samples - 1, num_samples_inbag, + *case_weights); + +// All observation are 0 or 1 times inbag + inbag_counts.resize(num_samples, 0); + for (auto& sampleID : sampleIDs[0]) { + inbag_counts[sampleID] = 1; + } + +// Save OOB samples. In holdout mode these are the cases with 0 weight. 
+ if (holdout) { + for (size_t s = 0; s < (*case_weights).size(); ++s) { + if ((*case_weights)[s] == 0) { + oob_sampleIDs.push_back(s); + } + } + } else { + for (size_t s = 0; s < inbag_counts.size(); ++s) { + if (inbag_counts[s] == 0) { + oob_sampleIDs.push_back(s); + } + } + } + num_samples_oob = oob_sampleIDs.size(); + + if (!keep_inbag) { + inbag_counts.clear(); + inbag_counts.shrink_to_fit(); + } +} + +void Tree::bootstrapClassWise() { + // Empty on purpose (virtual function only implemented in classification and probability) +} + +void Tree::bootstrapWithoutReplacementClassWise() { + // Empty on purpose (virtual function only implemented in classification and probability) +} + +} // namespace ranger diff --git a/lib/ranger/Tree.h b/lib/ranger/Tree.h new file mode 100644 index 000000000..7f4b2433a --- /dev/null +++ b/lib/ranger/Tree.h @@ -0,0 +1,172 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. 
+ #-------------------------------------------------------------------------------*/ + +#ifndef TREE_H_ +#define TREE_H_ + +#include +#include +#include +#include + +#include "globals.h" +#include "Data.h" + +namespace ranger { + +class Tree { +public: + Tree(); + + // Create from loaded forest + Tree(std::vector>& child_nodeIDs, std::vector& split_varIDs, + std::vector& split_values); + + virtual ~Tree() = default; + + Tree(const Tree&) = delete; + Tree& operator=(const Tree&) = delete; + + void init(const Data* data, uint mtry, size_t dependent_varID, size_t num_samples, uint seed, + std::vector* deterministic_varIDs, std::vector* split_select_varIDs, + std::vector* split_select_weights, ImportanceMode importance_mode, uint min_node_size, + bool sample_with_replacement, bool memory_saving_splitting, SplitRule splitrule, + std::vector* case_weights, bool keep_inbag, std::vector* sample_fraction, double alpha, + double minprop, bool holdout, uint num_random_splits); + + virtual void allocateMemory() = 0; + + void grow(std::vector* variable_importance); + + void predict(const Data* prediction_data, bool oob_prediction); + + void computePermutationImportance(std::vector& forest_importance, std::vector& forest_variance); + + void appendToFile(std::ofstream& file); + virtual void appendToFileInternal(std::ofstream& file) = 0; + + const std::vector>& getChildNodeIDs() const { + return child_nodeIDs; + } + const std::vector& getSplitValues() const { + return split_values; + } + const std::vector& getSplitVarIDs() const { + return split_varIDs; + } + + const std::vector& getOobSampleIDs() const { + return oob_sampleIDs; + } + size_t getNumSamplesOob() const { + return num_samples_oob; + } + + const std::vector& getInbagCounts() const { + return inbag_counts; + } + +protected: + void createPossibleSplitVarSubset(std::vector& result); + + bool splitNode(size_t nodeID); + virtual bool splitNodeInternal(size_t nodeID, std::vector& possible_split_varIDs) = 0; + + void 
createEmptyNode(); + virtual void createEmptyNodeInternal() = 0; + + size_t dropDownSamplePermuted(size_t permuted_varID, size_t sampleID, size_t permuted_sampleID); + void permuteAndPredictOobSamples(size_t permuted_varID, std::vector& permutations); + + virtual double computePredictionAccuracyInternal() = 0; + + void bootstrap(); + void bootstrapWithoutReplacement(); + + void bootstrapWeighted(); + void bootstrapWithoutReplacementWeighted(); + + virtual void bootstrapClassWise(); + virtual void bootstrapWithoutReplacementClassWise(); + + virtual void cleanUpInternal() = 0; + + size_t dependent_varID; + uint mtry; + + // Number of samples (all samples, not only inbag for this tree) + size_t num_samples; + + // Number of OOB samples + size_t num_samples_oob; + + // Minimum node size to split, like in original RF nodes of smaller size can be produced + uint min_node_size; + + // Weight vector for selecting possible split variables, one weight between 0 (never select) and 1 (always select) for each variable + // Deterministic variables are always selected + const std::vector* deterministic_varIDs; + const std::vector* split_select_varIDs; + const std::vector* split_select_weights; + + // Bootstrap weights + const std::vector* case_weights; + + // Splitting variable for each node + std::vector split_varIDs; + + // Value to split at for each node, for now only binary split + // For terminal nodes the prediction value is saved here + std::vector split_values; + + // Vector of left and right child node IDs, 0 for no child + std::vector> child_nodeIDs; + + // For each node a vector with IDs of samples in node + std::vector> sampleIDs; + + // IDs of OOB individuals, sorted + std::vector oob_sampleIDs; + + // Holdout mode + bool holdout; + + // Inbag counts + bool keep_inbag; + std::vector inbag_counts; + + // Random number generator + std::mt19937_64 random_number_generator; + + // Pointer to original data + const Data* data; + + // Variable importance for all variables + 
std::vector* variable_importance; + ImportanceMode importance_mode; + + // When growing here the OOB set is used + // Terminal nodeIDs for prediction samples + std::vector prediction_terminal_nodeIDs; + + bool sample_with_replacement; + const std::vector* sample_fraction; + + bool memory_saving_splitting; + SplitRule splitrule; + double alpha; + double minprop; + uint num_random_splits; +}; + +} // namespace ranger + +#endif /* TREE_H_ */ diff --git a/lib/ranger/TreeClassification.cpp b/lib/ranger/TreeClassification.cpp new file mode 100644 index 000000000..d607d4b3b --- /dev/null +++ b/lib/ranger/TreeClassification.cpp @@ -0,0 +1,751 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. 
+ #-------------------------------------------------------------------------------*/ + +#include +#include +#include +#include +#include + +#include "TreeClassification.h" +#include "utility.h" +#include "Data.h" + +namespace ranger { + +TreeClassification::TreeClassification(std::vector* class_values, std::vector* response_classIDs, + std::vector>* sampleIDs_per_class, std::vector* class_weights) : + class_values(class_values), response_classIDs(response_classIDs), sampleIDs_per_class(sampleIDs_per_class), class_weights( + class_weights), counter(0), counter_per_class(0) { +} + +TreeClassification::TreeClassification(std::vector>& child_nodeIDs, + std::vector& split_varIDs, std::vector& split_values, std::vector* class_values, + std::vector* response_classIDs) : + Tree(child_nodeIDs, split_varIDs, split_values), class_values(class_values), response_classIDs(response_classIDs), sampleIDs_per_class( + 0), class_weights(0), counter { }, counter_per_class { } { +} + +void TreeClassification::allocateMemory() { + // Init counters if not in memory efficient mode + if (!memory_saving_splitting) { + size_t num_classes = class_values->size(); + size_t max_num_splits = data->getMaxNumUniqueValues(); + + // Use number of random splits for extratrees + if (splitrule == EXTRATREES && num_random_splits > max_num_splits) { + max_num_splits = num_random_splits; + } + + counter.resize(max_num_splits); + counter_per_class.resize(num_classes * max_num_splits); + } +} + +double TreeClassification::estimate(size_t nodeID) { + + // Count classes over samples in node and return class with maximum count + std::vector class_count = std::vector(class_values->size(), 0.0); + + for (size_t i = 0; i < sampleIDs[nodeID].size(); ++i) { + size_t value = (*response_classIDs)[sampleIDs[nodeID][i]]; + class_count[value] += (*class_weights)[value]; + } + + if (sampleIDs[nodeID].size() > 0) { + size_t result_classID = mostFrequentClass(class_count, random_number_generator); + return 
((*class_values)[result_classID]); + } else { + throw std::runtime_error("Error: Empty node."); + } + +} + +void TreeClassification::appendToFileInternal(std::ofstream& file) { // #nocov start + // Empty on purpose +} // #nocov end + +bool TreeClassification::splitNodeInternal(size_t nodeID, std::vector& possible_split_varIDs) { + + // Check node size, stop if maximum reached + if (sampleIDs[nodeID].size() <= min_node_size) { + split_values[nodeID] = estimate(nodeID); + return true; + } + + // Check if node is pure and set split_value to estimate and stop if pure + bool pure = true; + double pure_value = 0; + for (size_t i = 0; i < sampleIDs[nodeID].size(); ++i) { + double value = data->get(sampleIDs[nodeID][i], dependent_varID); + if (i != 0 && value != pure_value) { + pure = false; + break; + } + pure_value = value; + } + if (pure) { + split_values[nodeID] = pure_value; + return true; + } + + // Find best split, stop if no decrease of impurity + bool stop; + if (splitrule == EXTRATREES) { + stop = findBestSplitExtraTrees(nodeID, possible_split_varIDs); + } else { + stop = findBestSplit(nodeID, possible_split_varIDs); + } + + if (stop) { + split_values[nodeID] = estimate(nodeID); + return true; + } + + return false; +} + +void TreeClassification::createEmptyNodeInternal() { + // Empty on purpose +} + +double TreeClassification::computePredictionAccuracyInternal() { + + size_t num_predictions = prediction_terminal_nodeIDs.size(); + size_t num_missclassifications = 0; + for (size_t i = 0; i < num_predictions; ++i) { + size_t terminal_nodeID = prediction_terminal_nodeIDs[i]; + double predicted_value = split_values[terminal_nodeID]; + double real_value = data->get(oob_sampleIDs[i], dependent_varID); + if (predicted_value != real_value) { + ++num_missclassifications; + } + } + return (1.0 - (double) num_missclassifications / (double) num_predictions); +} + +bool TreeClassification::findBestSplit(size_t nodeID, std::vector& possible_split_varIDs) { + + size_t 
num_samples_node = sampleIDs[nodeID].size(); + size_t num_classes = class_values->size(); + double best_decrease = -1; + size_t best_varID = 0; + double best_value = 0; + + std::vector class_counts(num_classes); + // Compute overall class counts + for (size_t i = 0; i < num_samples_node; ++i) { + size_t sampleID = sampleIDs[nodeID][i]; + uint sample_classID = (*response_classIDs)[sampleID]; + ++class_counts[sample_classID]; + } + + // For all possible split variables + for (auto& varID : possible_split_varIDs) { + // Find best split value, if ordered consider all values as split values, else all 2-partitions + if (data->isOrderedVariable(varID)) { + + // Use memory saving method if option set + if (memory_saving_splitting) { + findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } else { + // Use faster method for both cases + double q = (double) num_samples_node / (double) data->getNumUniqueDataValues(varID); + if (q < Q_THRESHOLD) { + findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } else { + findBestSplitValueLargeQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } + } + } else { + findBestSplitValueUnordered(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } + } + + // Stop if no good split found + if (best_decrease < 0) { + return true; + } + + // Save best values + split_varIDs[nodeID] = best_varID; + split_values[nodeID] = best_value; + + // Compute gini index for this node and to variable importance if needed + if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { + addGiniImportance(nodeID, best_varID, best_decrease); + } + return false; +} + +void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& 
class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease) { + +// Create possible split values + std::vector possible_split_values; + data->getAllValues(possible_split_values, sampleIDs[nodeID], varID); + + // Try next variable if all equal for this + if (possible_split_values.size() < 2) { + return; + } + + // -1 because no split possible at largest value + const size_t num_splits = possible_split_values.size() - 1; + if (memory_saving_splitting) { + std::vector class_counts_right(num_splits * num_classes), n_right(num_splits); + findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease, possible_split_values, class_counts_right, n_right); + } else { + std::fill_n(counter_per_class.begin(), num_splits * num_classes, 0); + std::fill_n(counter.begin(), num_splits, 0); + findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease, possible_split_values, counter_per_class, counter); + } +} + +void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease, const std::vector& possible_split_values, std::vector& class_counts_right, + std::vector& n_right) { + const size_t num_splits = possible_split_values.size() - 1; + + // Count samples in right child per class and possbile split + for (auto& sampleID : sampleIDs[nodeID]) { + double value = data->get(sampleID, varID); + uint sample_classID = (*response_classIDs)[sampleID]; + + // Count samples until split_value reached + for (size_t i = 0; i < num_splits; ++i) { + if (value > possible_split_values[i]) { + ++n_right[i]; + ++class_counts_right[i * num_classes + sample_classID]; + } else { + break; + } + } + } + + // Compute decrease of impurity for each possible split + for (size_t i = 0; 
i < num_splits; ++i) { + + // Stop if one child empty + size_t n_left = num_samples_node - n_right[i]; + if (n_left == 0 || n_right[i] == 0) { + continue; + } + + // Sum of squares + double sum_left = 0; + double sum_right = 0; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts_right[i * num_classes + j]; + size_t class_count_left = class_counts[j] - class_count_right; + + sum_right += (*class_weights)[j] * class_count_right * class_count_right; + sum_left += (*class_weights)[j] * class_count_left * class_count_left; + } + + // Decrease of impurity + double decrease = sum_left / (double) n_left + sum_right / (double) n_right[i]; + + // If better than before, use this + if (decrease > best_decrease) { + best_value = (possible_split_values[i] + possible_split_values[i + 1]) / 2; + best_varID = varID; + best_decrease = decrease; + + // Use smaller value if average is numerically the same as the larger value + if (best_value == possible_split_values[i + 1]) { + best_value = possible_split_values[i]; + } + } + } +} + +void TreeClassification::findBestSplitValueLargeQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease) { + + // Set counters to 0 + size_t num_unique = data->getNumUniqueDataValues(varID); + std::fill_n(counter_per_class.begin(), num_unique * num_classes, 0); + std::fill_n(counter.begin(), num_unique, 0); + + // Count values + for (auto& sampleID : sampleIDs[nodeID]) { + size_t index = data->getIndex(sampleID, varID); + size_t classID = (*response_classIDs)[sampleID]; + + ++counter[index]; + ++counter_per_class[index * num_classes + classID]; + } + + size_t n_left = 0; + std::vector class_counts_left(num_classes); + + // Compute decrease of impurity for each split + for (size_t i = 0; i < num_unique - 1; ++i) { + + // Stop if nothing here + if (counter[i] == 0) { + continue; + } + + n_left += 
counter[i]; + + // Stop if right child empty + size_t n_right = num_samples_node - n_left; + if (n_right == 0) { + break; + } + + // Sum of squares + double sum_left = 0; + double sum_right = 0; + for (size_t j = 0; j < num_classes; ++j) { + class_counts_left[j] += counter_per_class[i * num_classes + j]; + size_t class_count_right = class_counts[j] - class_counts_left[j]; + + sum_left += (*class_weights)[j] * class_counts_left[j] * class_counts_left[j]; + sum_right += (*class_weights)[j] * class_count_right * class_count_right; + } + + // Decrease of impurity + double decrease = sum_right / (double) n_right + sum_left / (double) n_left; + + // If better than before, use this + if (decrease > best_decrease) { + // Find next value in this node + size_t j = i + 1; + while (j < num_unique && counter[j] == 0) { + ++j; + } + + // Use mid-point split + best_value = (data->getUniqueDataValue(varID, i) + data->getUniqueDataValue(varID, j)) / 2; + best_varID = varID; + best_decrease = decrease; + + // Use smaller value if average is numerically the same as the larger value + if (best_value == data->getUniqueDataValue(varID, j)) { + best_value = data->getUniqueDataValue(varID, i); + } + } + } +} + +void TreeClassification::findBestSplitValueUnordered(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease) { + + // Create possible split values + std::vector factor_levels; + data->getAllValues(factor_levels, sampleIDs[nodeID], varID); + + // Try next variable if all equal for this + if (factor_levels.size() < 2) { + return; + } + + // Number of possible splits is 2^num_levels + size_t num_splits = (1 << factor_levels.size()); + + // Compute decrease of impurity for each possible split + // Split where all left (0) or all right (1) are excluded + // The second half of numbers is just left/right switched the first half -> Exclude second half + for (size_t 
local_splitID = 1; local_splitID < num_splits / 2; ++local_splitID) { + + // Compute overall splitID by shifting local factorIDs to global positions + size_t splitID = 0; + for (size_t j = 0; j < factor_levels.size(); ++j) { + if ((local_splitID & (1 << j))) { + double level = factor_levels[j]; + size_t factorID = floor(level) - 1; + splitID = splitID | (1 << factorID); + } + } + + // Initialize + std::vector class_counts_right(num_classes); + size_t n_right = 0; + + // Count classes in left and right child + for (auto& sampleID : sampleIDs[nodeID]) { + uint sample_classID = (*response_classIDs)[sampleID]; + double value = data->get(sampleID, varID); + size_t factorID = floor(value) - 1; + + // If in right child, count + // In right child, if bitwise splitID at position factorID is 1 + if ((splitID & (1 << factorID))) { + ++n_right; + ++class_counts_right[sample_classID]; + } + } + size_t n_left = num_samples_node - n_right; + + // Sum of squares + double sum_left = 0; + double sum_right = 0; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts_right[j]; + size_t class_count_left = class_counts[j] - class_count_right; + + sum_right += (*class_weights)[j] * class_count_right * class_count_right; + sum_left += (*class_weights)[j] * class_count_left * class_count_left; + } + + // Decrease of impurity + double decrease = sum_left / (double) n_left + sum_right / (double) n_right; + + // If better than before, use this + if (decrease > best_decrease) { + best_value = splitID; + best_varID = varID; + best_decrease = decrease; + } + } +} + +bool TreeClassification::findBestSplitExtraTrees(size_t nodeID, std::vector& possible_split_varIDs) { + + size_t num_samples_node = sampleIDs[nodeID].size(); + size_t num_classes = class_values->size(); + double best_decrease = -1; + size_t best_varID = 0; + double best_value = 0; + + std::vector class_counts(num_classes); + // Compute overall class counts + for (size_t i = 0; i < num_samples_node; ++i) 
{ + size_t sampleID = sampleIDs[nodeID][i]; + uint sample_classID = (*response_classIDs)[sampleID]; + ++class_counts[sample_classID]; + } + + // For all possible split variables + for (auto& varID : possible_split_varIDs) { + // Find best split value, if ordered consider all values as split values, else all 2-partitions + if (data->isOrderedVariable(varID)) { + findBestSplitValueExtraTrees(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } else { + findBestSplitValueExtraTreesUnordered(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, + best_varID, best_decrease); + } + } + + // Stop if no good split found + if (best_decrease < 0) { + return true; + } + + // Save best values + split_varIDs[nodeID] = best_varID; + split_values[nodeID] = best_value; + + // Compute gini index for this node and to variable importance if needed + if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { + addGiniImportance(nodeID, best_varID, best_decrease); + } + return false; +} + +void TreeClassification::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease) { + + // Get min/max values of covariate in node + double min; + double max; + data->getMinMaxValues(min, max, sampleIDs[nodeID], varID); + + // Try next variable if all equal for this + if (min == max) { + return; + } + + // Create possible split values: Draw randomly between min and max + std::vector possible_split_values; + std::uniform_real_distribution udist(min, max); + possible_split_values.reserve(num_random_splits); + for (size_t i = 0; i < num_random_splits; ++i) { + possible_split_values.push_back(udist(random_number_generator)); + } + + const size_t num_splits = possible_split_values.size(); + if (memory_saving_splitting) { + std::vector class_counts_right(num_splits * 
num_classes), n_right(num_splits); + findBestSplitValueExtraTrees(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease, possible_split_values, class_counts_right, n_right); + } else { + std::fill_n(counter_per_class.begin(), num_splits * num_classes, 0); + std::fill_n(counter.begin(), num_splits, 0); + findBestSplitValueExtraTrees(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease, possible_split_values, counter_per_class, counter); + } +} + +void TreeClassification::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease, const std::vector& possible_split_values, std::vector& class_counts_right, + std::vector& n_right) { + const size_t num_splits = possible_split_values.size(); + + // Count samples in right child per class and possbile split + for (auto& sampleID : sampleIDs[nodeID]) { + double value = data->get(sampleID, varID); + uint sample_classID = (*response_classIDs)[sampleID]; + + // Count samples until split_value reached + for (size_t i = 0; i < num_splits; ++i) { + if (value > possible_split_values[i]) { + ++n_right[i]; + ++class_counts_right[i * num_classes + sample_classID]; + } else { + break; + } + } + } + + // Compute decrease of impurity for each possible split + for (size_t i = 0; i < num_splits; ++i) { + + // Stop if one child empty + size_t n_left = num_samples_node - n_right[i]; + if (n_left == 0 || n_right[i] == 0) { + continue; + } + + // Sum of squares + double sum_left = 0; + double sum_right = 0; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts_right[i * num_classes + j]; + size_t class_count_left = class_counts[j] - class_count_right; + + sum_right += (*class_weights)[j] * class_count_right * class_count_right; + sum_left += (*class_weights)[j] * class_count_left 
* class_count_left; + } + + // Decrease of impurity + double decrease = sum_left / (double) n_left + sum_right / (double) n_right[i]; + + // If better than before, use this + if (decrease > best_decrease) { + best_value = possible_split_values[i]; + best_varID = varID; + best_decrease = decrease; + } + } +} + +void TreeClassification::findBestSplitValueExtraTreesUnordered(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease) { + + size_t num_unique_values = data->getNumUniqueDataValues(varID); + + // Get all factor indices in node + std::vector factor_in_node(num_unique_values, false); + for (auto& sampleID : sampleIDs[nodeID]) { + size_t index = data->getIndex(sampleID, varID); + factor_in_node[index] = true; + } + + // Vector of indices in and out of node + std::vector indices_in_node; + std::vector indices_out_node; + indices_in_node.reserve(num_unique_values); + indices_out_node.reserve(num_unique_values); + for (size_t i = 0; i < num_unique_values; ++i) { + if (factor_in_node[i]) { + indices_in_node.push_back(i); + } else { + indices_out_node.push_back(i); + } + } + + // Generate num_random_splits splits + for (size_t i = 0; i < num_random_splits; ++i) { + std::vector split_subset; + split_subset.reserve(num_unique_values); + + // Draw random subsets, sample all partitions with equal probability + if (indices_in_node.size() > 1) { + size_t num_partitions = (2 << (indices_in_node.size() - 1)) - 2; // 2^n-2 (don't allow full or empty) + std::uniform_int_distribution udist(1, num_partitions); + size_t splitID_in_node = udist(random_number_generator); + for (size_t j = 0; j < indices_in_node.size(); ++j) { + if ((splitID_in_node & (1 << j)) > 0) { + split_subset.push_back(indices_in_node[j]); + } + } + } + if (indices_out_node.size() > 1) { + size_t num_partitions = (2 << (indices_out_node.size() - 1)) - 1; // 2^n-1 (allow full or empty) + 
std::uniform_int_distribution udist(0, num_partitions); + size_t splitID_out_node = udist(random_number_generator); + for (size_t j = 0; j < indices_out_node.size(); ++j) { + if ((splitID_out_node & (1 << j)) > 0) { + split_subset.push_back(indices_out_node[j]); + } + } + } + + // Assign union of the two subsets to right child + size_t splitID = 0; + for (auto& idx : split_subset) { + splitID |= 1 << idx; + } + + // Initialize + std::vector class_counts_right(num_classes); + size_t n_right = 0; + + // Count classes in left and right child + for (auto& sampleID : sampleIDs[nodeID]) { + uint sample_classID = (*response_classIDs)[sampleID]; + double value = data->get(sampleID, varID); + size_t factorID = floor(value) - 1; + + // If in right child, count + // In right child, if bitwise splitID at position factorID is 1 + if ((splitID & (1 << factorID))) { + ++n_right; + ++class_counts_right[sample_classID]; + } + } + size_t n_left = num_samples_node - n_right; + + // Sum of squares + double sum_left = 0; + double sum_right = 0; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts_right[j]; + size_t class_count_left = class_counts[j] - class_count_right; + + sum_right += (*class_weights)[j] * class_count_right * class_count_right; + sum_left += (*class_weights)[j] * class_count_left * class_count_left; + } + + // Decrease of impurity + double decrease = sum_left / (double) n_left + sum_right / (double) n_right; + + // If better than before, use this + if (decrease > best_decrease) { + best_value = splitID; + best_varID = varID; + best_decrease = decrease; + } + } +} + +void TreeClassification::addGiniImportance(size_t nodeID, size_t varID, double decrease) { + + std::vector class_counts; + class_counts.resize(class_values->size(), 0); + + for (auto& sampleID : sampleIDs[nodeID]) { + uint sample_classID = (*response_classIDs)[sampleID]; + class_counts[sample_classID]++; + } + double sum_node = 0; + for (auto& class_count : class_counts) 
{ + sum_node += class_count * class_count; + } + double best_gini = decrease - sum_node / (double) sampleIDs[nodeID].size(); + + // No variable importance for no split variables + size_t tempvarID = data->getUnpermutedVarID(varID); + for (auto& skip : data->getNoSplitVariables()) { + if (tempvarID >= skip) { + --tempvarID; + } + } + + // Subtract if corrected importance and permuted variable, else add + if (importance_mode == IMP_GINI_CORRECTED && varID >= data->getNumCols()) { + (*variable_importance)[tempvarID] -= best_gini; + } else { + (*variable_importance)[tempvarID] += best_gini; + } +} + +void TreeClassification::bootstrapClassWise() { + // Number of samples is sum of sample fraction * number of samples + size_t num_samples_inbag = 0; + double sum_sample_fraction = 0; + for (auto& s : *sample_fraction) { + num_samples_inbag += (size_t) num_samples * s; + sum_sample_fraction += s; + } + + // Reserve space, reserve a little more to be save) + sampleIDs[0].reserve(num_samples_inbag); + oob_sampleIDs.reserve(num_samples * (exp(-sum_sample_fraction) + 0.1)); + + // Start with all samples OOB + inbag_counts.resize(num_samples, 0); + + // Draw samples for each class + for (size_t i = 0; i < sample_fraction->size(); ++i) { + // Draw samples of class with replacement as inbag and mark as not OOB + size_t num_samples_class = (*sampleIDs_per_class)[i].size(); + size_t num_samples_inbag_class = round(num_samples * (*sample_fraction)[i]); + std::uniform_int_distribution unif_dist(0, num_samples_class - 1); + for (size_t s = 0; s < num_samples_inbag_class; ++s) { + size_t draw = (*sampleIDs_per_class)[i][unif_dist(random_number_generator)]; + sampleIDs[0].push_back(draw); + ++inbag_counts[draw]; + } + } + + // Save OOB samples + for (size_t s = 0; s < inbag_counts.size(); ++s) { + if (inbag_counts[s] == 0) { + oob_sampleIDs.push_back(s); + } + } + num_samples_oob = oob_sampleIDs.size(); + + if (!keep_inbag) { + inbag_counts.clear(); + inbag_counts.shrink_to_fit(); + } +} 
+ +void TreeClassification::bootstrapWithoutReplacementClassWise() { + // Draw samples for each class + for (size_t i = 0; i < sample_fraction->size(); ++i) { + size_t num_samples_class = (*sampleIDs_per_class)[i].size(); + size_t num_samples_inbag_class = round(num_samples * (*sample_fraction)[i]); + + shuffleAndSplitAppend(sampleIDs[0], oob_sampleIDs, num_samples_class, num_samples_inbag_class, + (*sampleIDs_per_class)[i], random_number_generator); + } + + if (keep_inbag) { + // All observation are 0 or 1 times inbag + inbag_counts.resize(num_samples, 1); + for (size_t i = 0; i < oob_sampleIDs.size(); i++) { + inbag_counts[oob_sampleIDs[i]] = 0; + } + } +} + +} // namespace ranger diff --git a/lib/ranger/TreeClassification.h b/lib/ranger/TreeClassification.h new file mode 100644 index 000000000..5df1803d6 --- /dev/null +++ b/lib/ranger/TreeClassification.h @@ -0,0 +1,111 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. 
+ #-------------------------------------------------------------------------------*/ + +#ifndef TREECLASSIFICATION_H_ +#define TREECLASSIFICATION_H_ + +#include + +#include "globals.h" +#include "Tree.h" + +namespace ranger { + +class TreeClassification: public Tree { +public: + TreeClassification(std::vector* class_values, std::vector* response_classIDs, + std::vector>* sampleIDs_per_class, std::vector* class_weights); + + // Create from loaded forest + TreeClassification(std::vector>& child_nodeIDs, std::vector& split_varIDs, + std::vector& split_values, std::vector* class_values, std::vector* response_classIDs); + + TreeClassification(const TreeClassification&) = delete; + TreeClassification& operator=(const TreeClassification&) = delete; + + virtual ~TreeClassification() override = default; + + void allocateMemory() override; + + double estimate(size_t nodeID); + void computePermutationImportanceInternal(std::vector>* permutations); + void appendToFileInternal(std::ofstream& file) override; + + double getPrediction(size_t sampleID) const { + size_t terminal_nodeID = prediction_terminal_nodeIDs[sampleID]; + return split_values[terminal_nodeID]; + } + + size_t getPredictionTerminalNodeID(size_t sampleID) const { + return prediction_terminal_nodeIDs[sampleID]; + } + +private: + bool splitNodeInternal(size_t nodeID, std::vector& possible_split_varIDs) override; + void createEmptyNodeInternal() override; + + double computePredictionAccuracyInternal() override; + + // Called by splitNodeInternal(). Sets split_varIDs and split_values. 
+ bool findBestSplit(size_t nodeID, std::vector& possible_split_varIDs); + void findBestSplitValueSmallQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease); + void findBestSplitValueSmallQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease, const std::vector& possible_split_values, std::vector& class_counts_right, + std::vector& n_right); + void findBestSplitValueLargeQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease); + void findBestSplitValueUnordered(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease); + + bool findBestSplitExtraTrees(size_t nodeID, std::vector& possible_split_varIDs); + void findBestSplitValueExtraTrees(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease); + void findBestSplitValueExtraTrees(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease, const std::vector& possible_split_values, std::vector& class_counts_right, + std::vector& n_right); + void findBestSplitValueExtraTreesUnordered(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease); + + void addGiniImportance(size_t nodeID, size_t varID, double decrease); + + void bootstrapClassWise() override; + void bootstrapWithoutReplacementClassWise() override; 
+ + void cleanUpInternal() override { + counter.clear(); + counter.shrink_to_fit(); + counter_per_class.clear(); + counter_per_class.shrink_to_fit(); + } + + // Classes of the dependent variable and classIDs for responses + const std::vector* class_values; + const std::vector* response_classIDs; + const std::vector>* sampleIDs_per_class; + + // Splitting weights + const std::vector* class_weights; + + std::vector counter; + std::vector counter_per_class; +}; + +} // namespace ranger + +#endif /* TREECLASSIFICATION_H_ */ diff --git a/lib/ranger/TreeProbability.cpp b/lib/ranger/TreeProbability.cpp new file mode 100644 index 000000000..935ea8162 --- /dev/null +++ b/lib/ranger/TreeProbability.cpp @@ -0,0 +1,756 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. 
+ #-------------------------------------------------------------------------------*/ + +#include "TreeProbability.h" +#include "utility.h" +#include "Data.h" + +namespace ranger { + +TreeProbability::TreeProbability(std::vector* class_values, std::vector* response_classIDs, + std::vector>* sampleIDs_per_class, std::vector* class_weights) : + class_values(class_values), response_classIDs(response_classIDs), sampleIDs_per_class(sampleIDs_per_class), class_weights( + class_weights), counter(0), counter_per_class(0) { +} + +TreeProbability::TreeProbability(std::vector>& child_nodeIDs, std::vector& split_varIDs, + std::vector& split_values, std::vector* class_values, std::vector* response_classIDs, + std::vector>& terminal_class_counts) : + Tree(child_nodeIDs, split_varIDs, split_values), class_values(class_values), response_classIDs(response_classIDs), sampleIDs_per_class( + 0), terminal_class_counts(terminal_class_counts), class_weights(0), counter(0), counter_per_class(0) { +} + +void TreeProbability::allocateMemory() { + // Init counters if not in memory efficient mode + if (!memory_saving_splitting) { + size_t num_classes = class_values->size(); + size_t max_num_splits = data->getMaxNumUniqueValues(); + + // Use number of random splits for extratrees + if (splitrule == EXTRATREES && num_random_splits > max_num_splits) { + max_num_splits = num_random_splits; + } + + counter.resize(max_num_splits); + counter_per_class.resize(num_classes * max_num_splits); + } +} + +void TreeProbability::addToTerminalNodes(size_t nodeID) { + + size_t num_samples_in_node = sampleIDs[nodeID].size(); + terminal_class_counts[nodeID].resize(class_values->size(), 0); + + // Compute counts + for (size_t i = 0; i < num_samples_in_node; ++i) { + size_t node_sampleID = sampleIDs[nodeID][i]; + size_t classID = (*response_classIDs)[node_sampleID]; + ++terminal_class_counts[nodeID][classID]; + } + + // Compute fractions + for (size_t i = 0; i < terminal_class_counts[nodeID].size(); ++i) { + 
terminal_class_counts[nodeID][i] /= num_samples_in_node; + } +} + +void TreeProbability::appendToFileInternal(std::ofstream& file) { // #nocov start + + // Add Terminal node class counts + // Convert to vector without empty elements and save + std::vector terminal_nodes; + std::vector> terminal_class_counts_vector; + for (size_t i = 0; i < terminal_class_counts.size(); ++i) { + if (!terminal_class_counts[i].empty()) { + terminal_nodes.push_back(i); + terminal_class_counts_vector.push_back(terminal_class_counts[i]); + } + } + saveVector1D(terminal_nodes, file); + saveVector2D(terminal_class_counts_vector, file); +} // #nocov end + +bool TreeProbability::splitNodeInternal(size_t nodeID, std::vector& possible_split_varIDs) { + + // Check node size, stop if maximum reached + if (sampleIDs[nodeID].size() <= min_node_size) { + addToTerminalNodes(nodeID); + return true; + } + + // Check if node is pure and set split_value to estimate and stop if pure + bool pure = true; + double pure_value = 0; + for (size_t i = 0; i < sampleIDs[nodeID].size(); ++i) { + double value = data->get(sampleIDs[nodeID][i], dependent_varID); + if (i != 0 && value != pure_value) { + pure = false; + break; + } + pure_value = value; + } + if (pure) { + addToTerminalNodes(nodeID); + return true; + } + + // Find best split, stop if no decrease of impurity + bool stop; + if (splitrule == EXTRATREES) { + stop = findBestSplitExtraTrees(nodeID, possible_split_varIDs); + } else { + stop = findBestSplit(nodeID, possible_split_varIDs); + } + + if (stop) { + addToTerminalNodes(nodeID); + return true; + } + + return false; +} + +void TreeProbability::createEmptyNodeInternal() { + terminal_class_counts.push_back(std::vector()); +} + +double TreeProbability::computePredictionAccuracyInternal() { + + size_t num_predictions = prediction_terminal_nodeIDs.size(); + double sum_of_squares = 0; + for (size_t i = 0; i < num_predictions; ++i) { + size_t sampleID = oob_sampleIDs[i]; + size_t real_classID = 
(*response_classIDs)[sampleID]; + size_t terminal_nodeID = prediction_terminal_nodeIDs[i]; + double predicted_value = terminal_class_counts[terminal_nodeID][real_classID]; + sum_of_squares += (1 - predicted_value) * (1 - predicted_value); + } + return (1.0 - sum_of_squares / (double) num_predictions); +} + +bool TreeProbability::findBestSplit(size_t nodeID, std::vector& possible_split_varIDs) { + + size_t num_samples_node = sampleIDs[nodeID].size(); + size_t num_classes = class_values->size(); + double best_decrease = -1; + size_t best_varID = 0; + double best_value = 0; + + std::vector class_counts(num_classes); + // Compute overall class counts + for (size_t i = 0; i < num_samples_node; ++i) { + size_t sampleID = sampleIDs[nodeID][i]; + uint sample_classID = (*response_classIDs)[sampleID]; + ++class_counts[sample_classID]; + } + + // For all possible split variables + for (auto& varID : possible_split_varIDs) { + // Find best split value, if ordered consider all values as split values, else all 2-partitions + if (data->isOrderedVariable(varID)) { + + // Use memory saving method if option set + if (memory_saving_splitting) { + findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } else { + // Use faster method for both cases + double q = (double) num_samples_node / (double) data->getNumUniqueDataValues(varID); + if (q < Q_THRESHOLD) { + findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } else { + findBestSplitValueLargeQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } + } + } else { + findBestSplitValueUnordered(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } + } + + // Stop if no good split found + if (best_decrease < 0) { + return true; + } + + // Save best values + split_varIDs[nodeID] = 
best_varID; + split_values[nodeID] = best_value; + + // Compute decrease of impurity for this node and add to variable importance if needed + if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { + addImpurityImportance(nodeID, best_varID, best_decrease); + } + return false; +} + +void TreeProbability::findBestSplitValueSmallQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease) { + + // Create possible split values + std::vector possible_split_values; + data->getAllValues(possible_split_values, sampleIDs[nodeID], varID); + + // Try next variable if all equal for this + if (possible_split_values.size() < 2) { + return; + } + + // -1 because no split possible at largest value + const size_t num_splits = possible_split_values.size() - 1; + if (memory_saving_splitting) { + std::vector class_counts_right(num_splits * num_classes), n_right(num_splits); + findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease, possible_split_values, class_counts_right, n_right); + } else { + std::fill_n(counter_per_class.begin(), num_splits * num_classes, 0); + std::fill_n(counter.begin(), num_splits, 0); + findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease, possible_split_values, counter_per_class, counter); + } +} + +void TreeProbability::findBestSplitValueSmallQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease, const std::vector& possible_split_values, std::vector& class_counts_right, + std::vector& n_right) { + // -1 because no split possible at largest value + const size_t num_splits = possible_split_values.size() - 1; + + // Count samples in right child per class and possbile 
split + for (auto& sampleID : sampleIDs[nodeID]) { + double value = data->get(sampleID, varID); + uint sample_classID = (*response_classIDs)[sampleID]; + + // Count samples until split_value reached + for (size_t i = 0; i < num_splits; ++i) { + if (value > possible_split_values[i]) { + ++n_right[i]; + ++class_counts_right[i * num_classes + sample_classID]; + } else { + break; + } + } + } + + // Compute decrease of impurity for each possible split + for (size_t i = 0; i < num_splits; ++i) { + + // Stop if one child empty + size_t n_left = num_samples_node - n_right[i]; + if (n_left == 0 || n_right[i] == 0) { + continue; + } + + // Sum of squares + double sum_left = 0; + double sum_right = 0; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts_right[i * num_classes + j]; + size_t class_count_left = class_counts[j] - class_count_right; + + sum_right += (*class_weights)[j] * class_count_right * class_count_right; + sum_left += (*class_weights)[j] * class_count_left * class_count_left; + } + + // Decrease of impurity + double decrease = sum_left / (double) n_left + sum_right / (double) n_right[i]; + + // If better than before, use this + if (decrease > best_decrease) { + best_value = (possible_split_values[i] + possible_split_values[i + 1]) / 2; + best_varID = varID; + best_decrease = decrease; + + // Use smaller value if average is numerically the same as the larger value + if (best_value == possible_split_values[i + 1]) { + best_value = possible_split_values[i]; + } + } + } +} + +void TreeProbability::findBestSplitValueLargeQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease) { + + // Set counters to 0 + size_t num_unique = data->getNumUniqueDataValues(varID); + std::fill_n(counter_per_class.begin(), num_unique * num_classes, 0); + std::fill_n(counter.begin(), num_unique, 0); + + // Count values + for (auto& 
sampleID : sampleIDs[nodeID]) { + size_t index = data->getIndex(sampleID, varID); + size_t classID = (*response_classIDs)[sampleID]; + + ++counter[index]; + ++counter_per_class[index * num_classes + classID]; + } + + size_t n_left = 0; + std::vector class_counts_left(num_classes); + + // Compute decrease of impurity for each split + for (size_t i = 0; i < num_unique - 1; ++i) { + + // Stop if nothing here + if (counter[i] == 0) { + continue; + } + + n_left += counter[i]; + + // Stop if right child empty + size_t n_right = num_samples_node - n_left; + if (n_right == 0) { + break; + } + + // Sum of squares + double sum_left = 0; + double sum_right = 0; + for (size_t j = 0; j < num_classes; ++j) { + class_counts_left[j] += counter_per_class[i * num_classes + j]; + size_t class_count_right = class_counts[j] - class_counts_left[j]; + + sum_left += (*class_weights)[j] * class_counts_left[j] * class_counts_left[j]; + sum_right += (*class_weights)[j] * class_count_right * class_count_right; + } + + // Decrease of impurity + double decrease = sum_right / (double) n_right + sum_left / (double) n_left; + + // If better than before, use this + if (decrease > best_decrease) { + // Find next value in this node + size_t j = i + 1; + while (j < num_unique && counter[j] == 0) { + ++j; + } + + // Use mid-point split + best_value = (data->getUniqueDataValue(varID, i) + data->getUniqueDataValue(varID, j)) / 2; + best_varID = varID; + best_decrease = decrease; + + // Use smaller value if average is numerically the same as the larger value + if (best_value == data->getUniqueDataValue(varID, j)) { + best_value = data->getUniqueDataValue(varID, i); + } + } + } +} + +void TreeProbability::findBestSplitValueUnordered(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease) { + + // Create possible split values + std::vector factor_levels; + data->getAllValues(factor_levels, 
sampleIDs[nodeID], varID); + + // Try next variable if all equal for this + if (factor_levels.size() < 2) { + return; + } + + // Number of possible splits is 2^num_levels + size_t num_splits = (1 << factor_levels.size()); + + // Compute decrease of impurity for each possible split + // Split where all left (0) or all right (1) are excluded + // The second half of numbers is just left/right switched the first half -> Exclude second half + for (size_t local_splitID = 1; local_splitID < num_splits / 2; ++local_splitID) { + + // Compute overall splitID by shifting local factorIDs to global positions + size_t splitID = 0; + for (size_t j = 0; j < factor_levels.size(); ++j) { + if ((local_splitID & (1 << j))) { + double level = factor_levels[j]; + size_t factorID = floor(level) - 1; + splitID = splitID | (1 << factorID); + } + } + + // Initialize + std::vector class_counts_right(num_classes); + size_t n_right = 0; + + // Count classes in left and right child + for (auto& sampleID : sampleIDs[nodeID]) { + uint sample_classID = (*response_classIDs)[sampleID]; + double value = data->get(sampleID, varID); + size_t factorID = floor(value) - 1; + + // If in right child, count + // In right child, if bitwise splitID at position factorID is 1 + if ((splitID & (1 << factorID))) { + ++n_right; + ++class_counts_right[sample_classID]; + } + } + size_t n_left = num_samples_node - n_right; + + // Sum of squares + double sum_left = 0; + double sum_right = 0; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts_right[j]; + size_t class_count_left = class_counts[j] - class_count_right; + + sum_right += (*class_weights)[j] * class_count_right * class_count_right; + sum_left += (*class_weights)[j] * class_count_left * class_count_left; + } + + // Decrease of impurity + double decrease = sum_left / (double) n_left + sum_right / (double) n_right; + + // If better than before, use this + if (decrease > best_decrease) { + best_value = splitID; + best_varID = 
varID; + best_decrease = decrease; + } + } +} + +bool TreeProbability::findBestSplitExtraTrees(size_t nodeID, std::vector& possible_split_varIDs) { + + size_t num_samples_node = sampleIDs[nodeID].size(); + size_t num_classes = class_values->size(); + double best_decrease = -1; + size_t best_varID = 0; + double best_value = 0; + + std::vector class_counts(num_classes); + // Compute overall class counts + for (size_t i = 0; i < num_samples_node; ++i) { + size_t sampleID = sampleIDs[nodeID][i]; + uint sample_classID = (*response_classIDs)[sampleID]; + ++class_counts[sample_classID]; + } + + // For all possible split variables + for (auto& varID : possible_split_varIDs) { + // Find best split value, if ordered consider all values as split values, else all 2-partitions + if (data->isOrderedVariable(varID)) { + findBestSplitValueExtraTrees(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } else { + findBestSplitValueExtraTreesUnordered(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, + best_varID, best_decrease); + } + } + + // Stop if no good split found + if (best_decrease < 0) { + return true; + } + + // Save best values + split_varIDs[nodeID] = best_varID; + split_values[nodeID] = best_value; + + // Compute decrease of impurity for this node and add to variable importance if needed + if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { + addImpurityImportance(nodeID, best_varID, best_decrease); + } + return false; +} + +void TreeProbability::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease) { + + // Get min/max values of covariate in node + double min; + double max; + data->getMinMaxValues(min, max, sampleIDs[nodeID], varID); + + // Try next variable if all equal for this + if (min == max) { + return; + } + + // 
Create possible split values: Draw randomly between min and max + std::vector possible_split_values; + std::uniform_real_distribution udist(min, max); + possible_split_values.reserve(num_random_splits); + for (size_t i = 0; i < num_random_splits; ++i) { + possible_split_values.push_back(udist(random_number_generator)); + } + + const size_t num_splits = possible_split_values.size(); + if (memory_saving_splitting) { + std::vector class_counts_right(num_splits * num_classes), n_right(num_splits); + findBestSplitValueExtraTrees(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease, possible_split_values, class_counts_right, n_right); + } else { + std::fill_n(counter_per_class.begin(), num_splits * num_classes, 0); + std::fill_n(counter.begin(), num_splits, 0); + findBestSplitValueExtraTrees(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease, possible_split_values, counter_per_class, counter); + } +} + +void TreeProbability::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease, const std::vector& possible_split_values, std::vector& class_counts_right, + std::vector& n_right) { + const size_t num_splits = possible_split_values.size(); + + // Count samples in right child per class and possbile split + for (auto& sampleID : sampleIDs[nodeID]) { + double value = data->get(sampleID, varID); + uint sample_classID = (*response_classIDs)[sampleID]; + + // Count samples until split_value reached + for (size_t i = 0; i < num_splits; ++i) { + if (value > possible_split_values[i]) { + ++n_right[i]; + ++class_counts_right[i * num_classes + sample_classID]; + } else { + break; + } + } + } + + // Compute decrease of impurity for each possible split + for (size_t i = 0; i < num_splits; ++i) { + + // Stop if one child empty + size_t n_left = 
num_samples_node - n_right[i]; + if (n_left == 0 || n_right[i] == 0) { + continue; + } + + // Sum of squares + double sum_left = 0; + double sum_right = 0; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts_right[i * num_classes + j]; + size_t class_count_left = class_counts[j] - class_count_right; + + sum_right += (*class_weights)[j] * class_count_right * class_count_right; + sum_left += (*class_weights)[j] * class_count_left * class_count_left; + } + + // Decrease of impurity + double decrease = sum_left / (double) n_left + sum_right / (double) n_right[i]; + + // If better than before, use this + if (decrease > best_decrease) { + best_value = possible_split_values[i]; + best_varID = varID; + best_decrease = decrease; + } + } +} + +void TreeProbability::findBestSplitValueExtraTreesUnordered(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease) { + + size_t num_unique_values = data->getNumUniqueDataValues(varID); + + // Get all factor indices in node + std::vector factor_in_node(num_unique_values, false); + for (auto& sampleID : sampleIDs[nodeID]) { + size_t index = data->getIndex(sampleID, varID); + factor_in_node[index] = true; + } + + // Vector of indices in and out of node + std::vector indices_in_node; + std::vector indices_out_node; + indices_in_node.reserve(num_unique_values); + indices_out_node.reserve(num_unique_values); + for (size_t i = 0; i < num_unique_values; ++i) { + if (factor_in_node[i]) { + indices_in_node.push_back(i); + } else { + indices_out_node.push_back(i); + } + } + + // Generate num_random_splits splits + for (size_t i = 0; i < num_random_splits; ++i) { + std::vector split_subset; + split_subset.reserve(num_unique_values); + + // Draw random subsets, sample all partitions with equal probability + if (indices_in_node.size() > 1) { + size_t num_partitions = (2 << (indices_in_node.size() - 
1)) - 2; // 2^n-2 (don't allow full or empty) + std::uniform_int_distribution udist(1, num_partitions); + size_t splitID_in_node = udist(random_number_generator); + for (size_t j = 0; j < indices_in_node.size(); ++j) { + if ((splitID_in_node & (1 << j)) > 0) { + split_subset.push_back(indices_in_node[j]); + } + } + } + if (indices_out_node.size() > 1) { + size_t num_partitions = (2 << (indices_out_node.size() - 1)) - 1; // 2^n-1 (allow full or empty) + std::uniform_int_distribution udist(0, num_partitions); + size_t splitID_out_node = udist(random_number_generator); + for (size_t j = 0; j < indices_out_node.size(); ++j) { + if ((splitID_out_node & (1 << j)) > 0) { + split_subset.push_back(indices_out_node[j]); + } + } + } + + // Assign union of the two subsets to right child + size_t splitID = 0; + for (auto& idx : split_subset) { + splitID |= 1 << idx; + } + + // Initialize + std::vector class_counts_right(num_classes); + size_t n_right = 0; + + // Count classes in left and right child + for (auto& sampleID : sampleIDs[nodeID]) { + uint sample_classID = (*response_classIDs)[sampleID]; + double value = data->get(sampleID, varID); + size_t factorID = floor(value) - 1; + + // If in right child, count + // In right child, if bitwise splitID at position factorID is 1 + if ((splitID & (1 << factorID))) { + ++n_right; + ++class_counts_right[sample_classID]; + } + } + size_t n_left = num_samples_node - n_right; + + // Sum of squares + double sum_left = 0; + double sum_right = 0; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts_right[j]; + size_t class_count_left = class_counts[j] - class_count_right; + + sum_right += (*class_weights)[j] * class_count_right * class_count_right; + sum_left += (*class_weights)[j] * class_count_left * class_count_left; + } + + // Decrease of impurity + double decrease = sum_left / (double) n_left + sum_right / (double) n_right; + + // If better than before, use this + if (decrease > best_decrease) { + 
best_value = splitID; + best_varID = varID; + best_decrease = decrease; + } + } +} + +void TreeProbability::addImpurityImportance(size_t nodeID, size_t varID, double decrease) { + + std::vector class_counts; + class_counts.resize(class_values->size(), 0); + + for (auto& sampleID : sampleIDs[nodeID]) { + uint sample_classID = (*response_classIDs)[sampleID]; + class_counts[sample_classID]++; + } + double sum_node = 0; + for (auto& class_count : class_counts) { + sum_node += class_count * class_count; + } + double best_gini = decrease - sum_node / (double) sampleIDs[nodeID].size(); + + // No variable importance for no split variables + size_t tempvarID = data->getUnpermutedVarID(varID); + for (auto& skip : data->getNoSplitVariables()) { + if (tempvarID >= skip) { + --tempvarID; + } + } + + // Subtract if corrected importance and permuted variable, else add + if (importance_mode == IMP_GINI_CORRECTED && varID >= data->getNumCols()) { + (*variable_importance)[tempvarID] -= best_gini; + } else { + (*variable_importance)[tempvarID] += best_gini; + } +} + +void TreeProbability::bootstrapClassWise() { + // Number of samples is sum of sample fraction * number of samples + size_t num_samples_inbag = 0; + double sum_sample_fraction = 0; + for (auto& s : *sample_fraction) { + num_samples_inbag += (size_t) num_samples * s; + sum_sample_fraction += s; + } + + // Reserve space, reserve a little more to be save) + sampleIDs[0].reserve(num_samples_inbag); + oob_sampleIDs.reserve(num_samples * (exp(-sum_sample_fraction) + 0.1)); + + // Start with all samples OOB + inbag_counts.resize(num_samples, 0); + + // Draw samples for each class + for (size_t i = 0; i < sample_fraction->size(); ++i) { + // Draw samples of class with replacement as inbag and mark as not OOB + size_t num_samples_class = (*sampleIDs_per_class)[i].size(); + size_t num_samples_inbag_class = round(num_samples * (*sample_fraction)[i]); + std::uniform_int_distribution unif_dist(0, num_samples_class - 1); + for (size_t 
s = 0; s < num_samples_inbag_class; ++s) { + size_t draw = (*sampleIDs_per_class)[i][unif_dist(random_number_generator)]; + sampleIDs[0].push_back(draw); + ++inbag_counts[draw]; + } + } + + // Save OOB samples + for (size_t s = 0; s < inbag_counts.size(); ++s) { + if (inbag_counts[s] == 0) { + oob_sampleIDs.push_back(s); + } + } + num_samples_oob = oob_sampleIDs.size(); + + if (!keep_inbag) { + inbag_counts.clear(); + inbag_counts.shrink_to_fit(); + } +} + +void TreeProbability::bootstrapWithoutReplacementClassWise() { + // Draw samples for each class + for (size_t i = 0; i < sample_fraction->size(); ++i) { + size_t num_samples_class = (*sampleIDs_per_class)[i].size(); + size_t num_samples_inbag_class = round(num_samples * (*sample_fraction)[i]); + + shuffleAndSplitAppend(sampleIDs[0], oob_sampleIDs, num_samples_class, num_samples_inbag_class, + (*sampleIDs_per_class)[i], random_number_generator); + } + + if (keep_inbag) { + // All observation are 0 or 1 times inbag + inbag_counts.resize(num_samples, 1); + for (size_t i = 0; i < oob_sampleIDs.size(); i++) { + inbag_counts[oob_sampleIDs[i]] = 0; + } + } +} + +} // namespace ranger diff --git a/lib/ranger/TreeProbability.h b/lib/ranger/TreeProbability.h new file mode 100644 index 000000000..188497072 --- /dev/null +++ b/lib/ranger/TreeProbability.h @@ -0,0 +1,120 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. 
+ #-------------------------------------------------------------------------------*/ + +#ifndef TREEPROBABILITY_H_ +#define TREEPROBABILITY_H_ + +#include +#include + +#include "globals.h" +#include "Tree.h" + +namespace ranger { + +class TreeProbability: public Tree { +public: + TreeProbability(std::vector* class_values, std::vector* response_classIDs, + std::vector>* sampleIDs_per_class, std::vector* class_weights); + + // Create from loaded forest + TreeProbability(std::vector>& child_nodeIDs, std::vector& split_varIDs, + std::vector& split_values, std::vector* class_values, std::vector* response_classIDs, + std::vector>& terminal_class_counts); + + TreeProbability(const TreeProbability&) = delete; + TreeProbability& operator=(const TreeProbability&) = delete; + + virtual ~TreeProbability() override = default; + + void allocateMemory() override; + + void addToTerminalNodes(size_t nodeID); + void computePermutationImportanceInternal(std::vector>* permutations); + void appendToFileInternal(std::ofstream& file) override; + + const std::vector& getPrediction(size_t sampleID) const { + size_t terminal_nodeID = prediction_terminal_nodeIDs[sampleID]; + return terminal_class_counts[terminal_nodeID]; + } + + size_t getPredictionTerminalNodeID(size_t sampleID) const { + return prediction_terminal_nodeIDs[sampleID]; + } + + const std::vector>& getTerminalClassCounts() const { + return terminal_class_counts; + } + +private: + bool splitNodeInternal(size_t nodeID, std::vector& possible_split_varIDs) override; + void createEmptyNodeInternal() override; + + double computePredictionAccuracyInternal() override; + + // Called by splitNodeInternal(). Sets split_varIDs and split_values. 
+ bool findBestSplit(size_t nodeID, std::vector& possible_split_varIDs); + void findBestSplitValueSmallQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease); + void findBestSplitValueSmallQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease, const std::vector& possible_split_values, std::vector& class_counts_right, + std::vector& n_right); + void findBestSplitValueLargeQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease); + void findBestSplitValueUnordered(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease); + + bool findBestSplitExtraTrees(size_t nodeID, std::vector& possible_split_varIDs); + void findBestSplitValueExtraTrees(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease); + void findBestSplitValueExtraTrees(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease, const std::vector& possible_split_values, std::vector& class_counts_right, + std::vector& n_right); + void findBestSplitValueExtraTreesUnordered(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease); + + void addImpurityImportance(size_t nodeID, size_t varID, double decrease); + + void bootstrapClassWise() override; + void bootstrapWithoutReplacementClassWise() 
override; + + void cleanUpInternal() override { + counter.clear(); + counter.shrink_to_fit(); + counter_per_class.clear(); + counter_per_class.shrink_to_fit(); + } + + // Classes of the dependent variable and classIDs for responses + const std::vector* class_values; + const std::vector* response_classIDs; + const std::vector>* sampleIDs_per_class; + + // Class counts in terminal nodes. Empty for non-terminal nodes. + std::vector> terminal_class_counts; + + // Splitting weights + const std::vector* class_weights; + + std::vector counter; + std::vector counter_per_class; +}; + +} // namespace ranger + +#endif /* TREEPROBABILITY_H_ */ diff --git a/lib/ranger/TreeRegression.cpp b/lib/ranger/TreeRegression.cpp new file mode 100644 index 000000000..da9d9243b --- /dev/null +++ b/lib/ranger/TreeRegression.cpp @@ -0,0 +1,706 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. 
+ #-------------------------------------------------------------------------------*/ + +#include +#include +#include + +#include + +#include "utility.h" +#include "TreeRegression.h" +#include "Data.h" + +namespace ranger { + +TreeRegression::TreeRegression(std::vector>& child_nodeIDs, std::vector& split_varIDs, + std::vector& split_values) : + Tree(child_nodeIDs, split_varIDs, split_values), counter(0), sums(0) { +} + +void TreeRegression::allocateMemory() { + // Init counters if not in memory efficient mode + if (!memory_saving_splitting) { + size_t max_num_splits = data->getMaxNumUniqueValues(); + + // Use number of random splits for extratrees + if (splitrule == EXTRATREES && num_random_splits > max_num_splits) { + max_num_splits = num_random_splits; + } + + counter.resize(max_num_splits); + sums.resize(max_num_splits); + } +} + +double TreeRegression::estimate(size_t nodeID) { + +// Mean of responses of samples in node + double sum_responses_in_node = 0; + size_t num_samples_in_node = sampleIDs[nodeID].size(); + for (size_t i = 0; i < sampleIDs[nodeID].size(); ++i) { + sum_responses_in_node += data->get(sampleIDs[nodeID][i], dependent_varID); + } + return (sum_responses_in_node / (double) num_samples_in_node); +} + +void TreeRegression::appendToFileInternal(std::ofstream& file) { // #nocov start +// Empty on purpose +} // #nocov end + +bool TreeRegression::splitNodeInternal(size_t nodeID, std::vector& possible_split_varIDs) { + + // Check node size, stop if maximum reached + if (sampleIDs[nodeID].size() <= min_node_size) { + split_values[nodeID] = estimate(nodeID); + return true; + } + + // Check if node is pure and set split_value to estimate and stop if pure + bool pure = true; + double pure_value = 0; + for (size_t i = 0; i < sampleIDs[nodeID].size(); ++i) { + double value = data->get(sampleIDs[nodeID][i], dependent_varID); + if (i != 0 && value != pure_value) { + pure = false; + break; + } + pure_value = value; + } + if (pure) { + split_values[nodeID] = 
pure_value; + return true; + } + + // Find best split, stop if no decrease of impurity + bool stop; + if (splitrule == MAXSTAT) { + stop = findBestSplitMaxstat(nodeID, possible_split_varIDs); + } else if (splitrule == EXTRATREES) { + stop = findBestSplitExtraTrees(nodeID, possible_split_varIDs); + } else { + stop = findBestSplit(nodeID, possible_split_varIDs); + } + + if (stop) { + split_values[nodeID] = estimate(nodeID); + return true; + } + + return false; +} + +void TreeRegression::createEmptyNodeInternal() { +// Empty on purpose +} + +double TreeRegression::computePredictionAccuracyInternal() { + + size_t num_predictions = prediction_terminal_nodeIDs.size(); + double sum_of_squares = 0; + for (size_t i = 0; i < num_predictions; ++i) { + size_t terminal_nodeID = prediction_terminal_nodeIDs[i]; + double predicted_value = split_values[terminal_nodeID]; + double real_value = data->get(oob_sampleIDs[i], dependent_varID); + if (predicted_value != real_value) { + sum_of_squares += (predicted_value - real_value) * (predicted_value - real_value); + } + } + return (1.0 - sum_of_squares / (double) num_predictions); +} + +bool TreeRegression::findBestSplit(size_t nodeID, std::vector& possible_split_varIDs) { + + size_t num_samples_node = sampleIDs[nodeID].size(); + double best_decrease = -1; + size_t best_varID = 0; + double best_value = 0; + + // Compute sum of responses in node + double sum_node = 0; + for (auto& sampleID : sampleIDs[nodeID]) { + sum_node += data->get(sampleID, dependent_varID); + } + + // For all possible split variables + for (auto& varID : possible_split_varIDs) { + + // Find best split value, if ordered consider all values as split values, else all 2-partitions + if (data->isOrderedVariable(varID)) { + + // Use memory saving method if option set + if (memory_saving_splitting) { + findBestSplitValueSmallQ(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease); + } else { + // Use faster method for both cases + double q = 
(double) num_samples_node / (double) data->getNumUniqueDataValues(varID); + if (q < Q_THRESHOLD) { + findBestSplitValueSmallQ(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease); + } else { + findBestSplitValueLargeQ(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease); + } + } + } else { + findBestSplitValueUnordered(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease); + } + } + +// Stop if no good split found + if (best_decrease < 0) { + return true; + } + +// Save best values + split_varIDs[nodeID] = best_varID; + split_values[nodeID] = best_value; + +// Compute decrease of impurity for this node and add to variable importance if needed + if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { + addImpurityImportance(nodeID, best_varID, best_decrease); + } + return false; +} + +void TreeRegression::findBestSplitValueSmallQ(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node, + double& best_value, size_t& best_varID, double& best_decrease) { + + // Create possible split values + std::vector possible_split_values; + data->getAllValues(possible_split_values, sampleIDs[nodeID], varID); + + // Try next variable if all equal for this + if (possible_split_values.size() < 2) { + return; + } + + // -1 because no split possible at largest value + const size_t num_splits = possible_split_values.size() - 1; + if (memory_saving_splitting) { + std::vector sums_right(num_splits); + std::vector n_right(num_splits); + findBestSplitValueSmallQ(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease, + possible_split_values, sums_right, n_right); + } else { + std::fill_n(sums.begin(), num_splits, 0); + std::fill_n(counter.begin(), num_splits, 0); + findBestSplitValueSmallQ(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease, + possible_split_values, sums, counter); + } +} + +void 
TreeRegression::findBestSplitValueSmallQ(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node, + double& best_value, size_t& best_varID, double& best_decrease, std::vector possible_split_values, + std::vector& sums_right, std::vector& n_right) { + // -1 because no split possible at largest value + const size_t num_splits = possible_split_values.size() - 1; + + // Sum in right child and possbile split + for (auto& sampleID : sampleIDs[nodeID]) { + double value = data->get(sampleID, varID); + double response = data->get(sampleID, dependent_varID); + + // Count samples until split_value reached + for (size_t i = 0; i < num_splits; ++i) { + if (value > possible_split_values[i]) { + ++n_right[i]; + sums_right[i] += response; + } else { + break; + } + } + } + + // Compute decrease of impurity for each possible split + for (size_t i = 0; i < num_splits; ++i) { + + // Stop if one child empty + size_t n_left = num_samples_node - n_right[i]; + if (n_left == 0 || n_right[i] == 0) { + continue; + } + + double sum_right = sums_right[i]; + double sum_left = sum_node - sum_right; + double decrease = sum_left * sum_left / (double) n_left + sum_right * sum_right / (double) n_right[i]; + + // If better than before, use this + if (decrease > best_decrease) { + best_value = (possible_split_values[i] + possible_split_values[i + 1]) / 2; + best_varID = varID; + best_decrease = decrease; + + // Use smaller value if average is numerically the same as the larger value + if (best_value == possible_split_values[i + 1]) { + best_value = possible_split_values[i]; + } + } + } +} + +void TreeRegression::findBestSplitValueLargeQ(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node, + double& best_value, size_t& best_varID, double& best_decrease) { + + // Set counters to 0 + size_t num_unique = data->getNumUniqueDataValues(varID); + std::fill_n(counter.begin(), num_unique, 0); + std::fill_n(sums.begin(), num_unique, 0); + + for (auto& sampleID : sampleIDs[nodeID]) 
{ + size_t index = data->getIndex(sampleID, varID); + + sums[index] += data->get(sampleID, dependent_varID); + ++counter[index]; + } + + size_t n_left = 0; + double sum_left = 0; + + // Compute decrease of impurity for each split + for (size_t i = 0; i < num_unique - 1; ++i) { + + // Stop if nothing here + if (counter[i] == 0) { + continue; + } + + n_left += counter[i]; + sum_left += sums[i]; + + // Stop if right child empty + size_t n_right = num_samples_node - n_left; + if (n_right == 0) { + break; + } + + double sum_right = sum_node - sum_left; + double decrease = sum_left * sum_left / (double) n_left + sum_right * sum_right / (double) n_right; + + // If better than before, use this + if (decrease > best_decrease) { + // Find next value in this node + size_t j = i + 1; + while (j < num_unique && counter[j] == 0) { + ++j; + } + + // Use mid-point split + best_value = (data->getUniqueDataValue(varID, i) + data->getUniqueDataValue(varID, j)) / 2; + best_varID = varID; + best_decrease = decrease; + + // Use smaller value if average is numerically the same as the larger value + if (best_value == data->getUniqueDataValue(varID, j)) { + best_value = data->getUniqueDataValue(varID, i); + } + } + } +} + +void TreeRegression::findBestSplitValueUnordered(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node, + double& best_value, size_t& best_varID, double& best_decrease) { + +// Create possible split values + std::vector factor_levels; + data->getAllValues(factor_levels, sampleIDs[nodeID], varID); + +// Try next variable if all equal for this + if (factor_levels.size() < 2) { + return; + } + +// Number of possible splits is 2^num_levels + size_t num_splits = (1 << factor_levels.size()); + +// Compute decrease of impurity for each possible split +// Split where all left (0) or all right (1) are excluded +// The second half of numbers is just left/right switched the first half -> Exclude second half + for (size_t local_splitID = 1; local_splitID < num_splits 
/ 2; ++local_splitID) { + + // Compute overall splitID by shifting local factorIDs to global positions + size_t splitID = 0; + for (size_t j = 0; j < factor_levels.size(); ++j) { + if ((local_splitID & (1 << j))) { + double level = factor_levels[j]; + size_t factorID = floor(level) - 1; + splitID = splitID | (1 << factorID); + } + } + + // Initialize + double sum_right = 0; + size_t n_right = 0; + + // Sum in right child + for (auto& sampleID : sampleIDs[nodeID]) { + double response = data->get(sampleID, dependent_varID); + double value = data->get(sampleID, varID); + size_t factorID = floor(value) - 1; + + // If in right child, count + // In right child, if bitwise splitID at position factorID is 1 + if ((splitID & (1 << factorID))) { + ++n_right; + sum_right += response; + } + } + size_t n_left = num_samples_node - n_right; + + // Sum of squares + double sum_left = sum_node - sum_right; + double decrease = sum_left * sum_left / (double) n_left + sum_right * sum_right / (double) n_right; + + // If better than before, use this + if (decrease > best_decrease) { + best_value = splitID; + best_varID = varID; + best_decrease = decrease; + } + } +} + +bool TreeRegression::findBestSplitMaxstat(size_t nodeID, std::vector& possible_split_varIDs) { + + size_t num_samples_node = sampleIDs[nodeID].size(); + + // Compute ranks + std::vector response; + response.reserve(num_samples_node); + for (auto& sampleID : sampleIDs[nodeID]) { + response.push_back(data->get(sampleID, dependent_varID)); + } + std::vector ranks = rank(response); + + // Save split stats + std::vector pvalues; + pvalues.reserve(possible_split_varIDs.size()); + std::vector values; + values.reserve(possible_split_varIDs.size()); + std::vector candidate_varIDs; + candidate_varIDs.reserve(possible_split_varIDs.size()); + std::vector test_statistics; + test_statistics.reserve(possible_split_varIDs.size()); + + // Compute p-values + for (auto& varID : possible_split_varIDs) { + + // Get all observations + 
std::vector x; + x.reserve(num_samples_node); + for (auto& sampleID : sampleIDs[nodeID]) { + x.push_back(data->get(sampleID, varID)); + } + + // Order by x + std::vector indices = order(x, false); + //std::vector indices = orderInData(data, sampleIDs[nodeID], varID, false); + + // Compute maximally selected rank statistics + double best_maxstat; + double best_split_value; + maxstat(ranks, x, indices, best_maxstat, best_split_value, minprop, 1 - minprop); + //maxstatInData(scores, data, sampleIDs[nodeID], varID, indices, best_maxstat, best_split_value, minprop, 1 - minprop); + + if (best_maxstat > -1) { + // Compute number of samples left of cutpoints + std::vector num_samples_left = numSamplesLeftOfCutpoint(x, indices); + //std::vector num_samples_left = numSamplesLeftOfCutpointInData(data, sampleIDs[nodeID], varID, indices); + + // Compute p-values + double pvalue_lau92 = maxstatPValueLau92(best_maxstat, minprop, 1 - minprop); + double pvalue_lau94 = maxstatPValueLau94(best_maxstat, minprop, 1 - minprop, num_samples_node, num_samples_left); + + // Use minimum of Lau92 and Lau94 + double pvalue = std::min(pvalue_lau92, pvalue_lau94); + + // Save split stats + pvalues.push_back(pvalue); + values.push_back(best_split_value); + candidate_varIDs.push_back(varID); + test_statistics.push_back(best_maxstat); + } + } + + double adjusted_best_pvalue = std::numeric_limits::max(); + size_t best_varID = 0; + double best_value = 0; + double best_maxstat = 0; + + if (pvalues.size() > 0) { + // Adjust p-values with Benjamini/Hochberg + std::vector adjusted_pvalues = adjustPvalues(pvalues); + + // Use smallest p-value + double min_pvalue = std::numeric_limits::max(); + for (size_t i = 0; i < pvalues.size(); ++i) { + if (pvalues[i] < min_pvalue) { + min_pvalue = pvalues[i]; + best_varID = candidate_varIDs[i]; + best_value = values[i]; + adjusted_best_pvalue = adjusted_pvalues[i]; + best_maxstat = test_statistics[i]; + } + } + } + + // Stop if no good split found (this is terminal 
node). + if (adjusted_best_pvalue > alpha) { + return true; + } else { + // If not terminal node save best values + split_varIDs[nodeID] = best_varID; + split_values[nodeID] = best_value; + + // Compute decrease of impurity for this node and add to variable importance if needed + if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { + addImpurityImportance(nodeID, best_varID, best_maxstat); + } + + return false; + } +} + +bool TreeRegression::findBestSplitExtraTrees(size_t nodeID, std::vector& possible_split_varIDs) { + + size_t num_samples_node = sampleIDs[nodeID].size(); + double best_decrease = -1; + size_t best_varID = 0; + double best_value = 0; + + // Compute sum of responses in node + double sum_node = 0; + for (auto& sampleID : sampleIDs[nodeID]) { + sum_node += data->get(sampleID, dependent_varID); + } + + // For all possible split variables + for (auto& varID : possible_split_varIDs) { + + // Find best split value, if ordered consider all values as split values, else all 2-partitions + if (data->isOrderedVariable(varID)) { + findBestSplitValueExtraTrees(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease); + } else { + findBestSplitValueExtraTreesUnordered(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, + best_decrease); + } + } + + // Stop if no good split found + if (best_decrease < 0) { + return true; + } + + // Save best values + split_varIDs[nodeID] = best_varID; + split_values[nodeID] = best_value; + + // Compute decrease of impurity for this node and add to variable importance if needed + if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { + addImpurityImportance(nodeID, best_varID, best_decrease); + } + return false; +} + +void TreeRegression::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node, + double& best_value, size_t& best_varID, double& best_decrease) { + + // Get min/max values of covariate in node + 
double min; + double max; + data->getMinMaxValues(min, max, sampleIDs[nodeID], varID); + + // Try next variable if all equal for this + if (min == max) { + return; + } + + // Create possible split values: Draw randomly between min and max + std::vector possible_split_values; + std::uniform_real_distribution udist(min, max); + possible_split_values.reserve(num_random_splits); + for (size_t i = 0; i < num_random_splits; ++i) { + possible_split_values.push_back(udist(random_number_generator)); + } + + const size_t num_splits = possible_split_values.size(); + if (memory_saving_splitting) { + std::vector sums_right(num_splits); + std::vector n_right(num_splits); + findBestSplitValueExtraTrees(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease, + possible_split_values, sums_right, n_right); + } else { + std::fill_n(sums.begin(), num_splits, 0); + std::fill_n(counter.begin(), num_splits, 0); + findBestSplitValueExtraTrees(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease, + possible_split_values, sums, counter); + } +} + +void TreeRegression::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node, + double& best_value, size_t& best_varID, double& best_decrease, std::vector possible_split_values, + std::vector& sums_right, std::vector& n_right) { + const size_t num_splits = possible_split_values.size(); + + // Sum in right child and possbile split + for (auto& sampleID : sampleIDs[nodeID]) { + double value = data->get(sampleID, varID); + double response = data->get(sampleID, dependent_varID); + + // Count samples until split_value reached + for (size_t i = 0; i < num_splits; ++i) { + if (value > possible_split_values[i]) { + ++n_right[i]; + sums_right[i] += response; + } else { + break; + } + } + } + + // Compute decrease of impurity for each possible split + for (size_t i = 0; i < num_splits; ++i) { + + // Stop if one child empty + size_t n_left = num_samples_node - 
n_right[i]; + if (n_left == 0 || n_right[i] == 0) { + continue; + } + + double sum_right = sums_right[i]; + double sum_left = sum_node - sum_right; + double decrease = sum_left * sum_left / (double) n_left + sum_right * sum_right / (double) n_right[i]; + + // If better than before, use this + if (decrease > best_decrease) { + best_value = possible_split_values[i]; + best_varID = varID; + best_decrease = decrease; + } + } +} + +void TreeRegression::findBestSplitValueExtraTreesUnordered(size_t nodeID, size_t varID, double sum_node, + size_t num_samples_node, double& best_value, size_t& best_varID, double& best_decrease) { + + size_t num_unique_values = data->getNumUniqueDataValues(varID); + + // Get all factor indices in node + std::vector factor_in_node(num_unique_values, false); + for (auto& sampleID : sampleIDs[nodeID]) { + size_t index = data->getIndex(sampleID, varID); + factor_in_node[index] = true; + } + + // Vector of indices in and out of node + std::vector indices_in_node; + std::vector indices_out_node; + indices_in_node.reserve(num_unique_values); + indices_out_node.reserve(num_unique_values); + for (size_t i = 0; i < num_unique_values; ++i) { + if (factor_in_node[i]) { + indices_in_node.push_back(i); + } else { + indices_out_node.push_back(i); + } + } + + // Generate num_random_splits splits + for (size_t i = 0; i < num_random_splits; ++i) { + std::vector split_subset; + split_subset.reserve(num_unique_values); + + // Draw random subsets, sample all partitions with equal probability + if (indices_in_node.size() > 1) { + size_t num_partitions = (2 << (indices_in_node.size() - 1)) - 2; // 2^n-2 (don't allow full or empty) + std::uniform_int_distribution udist(1, num_partitions); + size_t splitID_in_node = udist(random_number_generator); + for (size_t j = 0; j < indices_in_node.size(); ++j) { + if ((splitID_in_node & (1 << j)) > 0) { + split_subset.push_back(indices_in_node[j]); + } + } + } + if (indices_out_node.size() > 1) { + size_t num_partitions = (2 
<< (indices_out_node.size() - 1)) - 1; // 2^n-1 (allow full or empty) + std::uniform_int_distribution udist(0, num_partitions); + size_t splitID_out_node = udist(random_number_generator); + for (size_t j = 0; j < indices_out_node.size(); ++j) { + if ((splitID_out_node & (1 << j)) > 0) { + split_subset.push_back(indices_out_node[j]); + } + } + } + + // Assign union of the two subsets to right child + size_t splitID = 0; + for (auto& idx : split_subset) { + splitID |= 1 << idx; + } + + // Initialize + double sum_right = 0; + size_t n_right = 0; + + // Sum in right child + for (auto& sampleID : sampleIDs[nodeID]) { + double response = data->get(sampleID, dependent_varID); + double value = data->get(sampleID, varID); + size_t factorID = floor(value) - 1; + + // If in right child, count + // In right child, if bitwise splitID at position factorID is 1 + if ((splitID & (1 << factorID))) { + ++n_right; + sum_right += response; + } + } + size_t n_left = num_samples_node - n_right; + + // Sum of squares + double sum_left = sum_node - sum_right; + double decrease = sum_left * sum_left / (double) n_left + sum_right * sum_right / (double) n_right; + + // If better than before, use this + if (decrease > best_decrease) { + best_value = splitID; + best_varID = varID; + best_decrease = decrease; + } + } +} + +void TreeRegression::addImpurityImportance(size_t nodeID, size_t varID, double decrease) { + + double best_decrease = decrease; + if (splitrule != MAXSTAT) { + double sum_node = 0; + for (auto& sampleID : sampleIDs[nodeID]) { + sum_node += data->get(sampleID, dependent_varID); + } + best_decrease = decrease - sum_node * sum_node / (double) sampleIDs[nodeID].size(); + } + + // No variable importance for no split variables + size_t tempvarID = data->getUnpermutedVarID(varID); + for (auto& skip : data->getNoSplitVariables()) { + if (tempvarID >= skip) { + --tempvarID; + } + } + + // Subtract if corrected importance and permuted variable, else add + if (importance_mode == 
IMP_GINI_CORRECTED && varID >= data->getNumCols()) { + (*variable_importance)[tempvarID] -= best_decrease; + } else { + (*variable_importance)[tempvarID] += best_decrease; + } +} + +} // namespace ranger diff --git a/lib/ranger/TreeRegression.h b/lib/ranger/TreeRegression.h new file mode 100644 index 000000000..244191786 --- /dev/null +++ b/lib/ranger/TreeRegression.h @@ -0,0 +1,96 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. + #-------------------------------------------------------------------------------*/ + +#ifndef TREEREGRESSION_H_ +#define TREEREGRESSION_H_ + +#include + +#include "globals.h" +#include "Tree.h" + +namespace ranger { + +class TreeRegression: public Tree { +public: + TreeRegression() = default; + + // Create from loaded forest + TreeRegression(std::vector>& child_nodeIDs, std::vector& split_varIDs, + std::vector& split_values); + + TreeRegression(const TreeRegression&) = delete; + TreeRegression& operator=(const TreeRegression&) = delete; + + virtual ~TreeRegression() override = default; + + void allocateMemory() override; + + double estimate(size_t nodeID); + void computePermutationImportanceInternal(std::vector>* permutations); + void appendToFileInternal(std::ofstream& file) override; + + double getPrediction(size_t sampleID) const { + size_t terminal_nodeID = prediction_terminal_nodeIDs[sampleID]; + return (split_values[terminal_nodeID]); + } + + size_t getPredictionTerminalNodeID(size_t sampleID) const { + return prediction_terminal_nodeIDs[sampleID]; + } + +private: + bool splitNodeInternal(size_t nodeID, std::vector& possible_split_varIDs) override; + void createEmptyNodeInternal() override; + + double 
computePredictionAccuracyInternal() override; + + // Called by splitNodeInternal(). Sets split_varIDs and split_values. + bool findBestSplit(size_t nodeID, std::vector& possible_split_varIDs); + void findBestSplitValueSmallQ(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node, + double& best_value, size_t& best_varID, double& best_decrease); + void findBestSplitValueSmallQ(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node, + double& best_value, size_t& best_varID, double& best_decrease, std::vector possible_split_values, + std::vector& sums_right, std::vector& n_right); + void findBestSplitValueLargeQ(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node, + double& best_value, size_t& best_varID, double& best_decrease); + void findBestSplitValueUnordered(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node, + double& best_value, size_t& best_varID, double& best_decrease); + + bool findBestSplitMaxstat(size_t nodeID, std::vector& possible_split_varIDs); + + bool findBestSplitExtraTrees(size_t nodeID, std::vector& possible_split_varIDs); + void findBestSplitValueExtraTrees(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node, + double& best_value, size_t& best_varID, double& best_decrease); + void findBestSplitValueExtraTrees(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node, + double& best_value, size_t& best_varID, double& best_decrease, std::vector possible_split_values, + std::vector& sums_right, std::vector& n_right); + void findBestSplitValueExtraTreesUnordered(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node, + double& best_value, size_t& best_varID, double& best_decrease); + + void addImpurityImportance(size_t nodeID, size_t varID, double decrease); + + double computePredictionMSE(); + + void cleanUpInternal() override { + counter.clear(); + counter.shrink_to_fit(); + sums.clear(); + sums.shrink_to_fit(); + } + + std::vector 
counter; + std::vector sums; +}; + +} // namespace ranger + +#endif /* TREEREGRESSION_H_ */ diff --git a/lib/ranger/TreeSurvival.cpp b/lib/ranger/TreeSurvival.cpp new file mode 100644 index 000000000..f0da9b9c4 --- /dev/null +++ b/lib/ranger/TreeSurvival.cpp @@ -0,0 +1,884 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. + #-------------------------------------------------------------------------------*/ + +#include +#include +#include +#include +#include + +#include "utility.h" +#include "TreeSurvival.h" +#include "Data.h" + +namespace ranger { + +TreeSurvival::TreeSurvival(std::vector* unique_timepoints, size_t status_varID, + std::vector* response_timepointIDs) : + status_varID(status_varID), unique_timepoints(unique_timepoints), response_timepointIDs(response_timepointIDs), num_deaths( + 0), num_samples_at_risk(0) { + this->num_timepoints = unique_timepoints->size(); +} + +TreeSurvival::TreeSurvival(std::vector>& child_nodeIDs, std::vector& split_varIDs, + std::vector& split_values, std::vector> chf, std::vector* unique_timepoints, + std::vector* response_timepointIDs) : + Tree(child_nodeIDs, split_varIDs, split_values), status_varID(0), unique_timepoints(unique_timepoints), response_timepointIDs( + response_timepointIDs), chf(chf), num_deaths(0), num_samples_at_risk(0) { + this->num_timepoints = unique_timepoints->size(); +} + +void TreeSurvival::allocateMemory() { + // Number of deaths and samples at risk for each timepoint + num_deaths.resize(num_timepoints); + num_samples_at_risk.resize(num_timepoints); +} + +void TreeSurvival::appendToFileInternal(std::ofstream& file) { // #nocov start + + // Convert to vector without empty elements 
and save + std::vector terminal_nodes; + std::vector> chf_vector; + for (size_t i = 0; i < chf.size(); ++i) { + if (!chf[i].empty()) { + terminal_nodes.push_back(i); + chf_vector.push_back(chf[i]); + } + } + saveVector1D(terminal_nodes, file); + saveVector2D(chf_vector, file); +} // #nocov end + +void TreeSurvival::createEmptyNodeInternal() { + chf.push_back(std::vector()); +} + +void TreeSurvival::computeSurvival(size_t nodeID) { + std::vector chf_temp; + chf_temp.reserve(num_timepoints); + double chf_value = 0; + for (size_t i = 0; i < num_timepoints; ++i) { + if (num_samples_at_risk[i] != 0) { + chf_value += (double) num_deaths[i] / (double) num_samples_at_risk[i]; + } + chf_temp.push_back(chf_value); + } + chf[nodeID] = chf_temp; +} + +double TreeSurvival::computePredictionAccuracyInternal() { + + // Compute summed chf for samples + std::vector sum_chf; + for (size_t i = 0; i < prediction_terminal_nodeIDs.size(); ++i) { + size_t terminal_nodeID = prediction_terminal_nodeIDs[i]; + sum_chf.push_back(std::accumulate(chf[terminal_nodeID].begin(), chf[terminal_nodeID].end(), 0.0)); + } + + // Return concordance index + return computeConcordanceIndex(*data, sum_chf, dependent_varID, status_varID, oob_sampleIDs); +} + +bool TreeSurvival::splitNodeInternal(size_t nodeID, std::vector& possible_split_varIDs) { + + if (splitrule == MAXSTAT) { + return findBestSplitMaxstat(nodeID, possible_split_varIDs); + } else if (splitrule == EXTRATREES) { + return findBestSplitExtraTrees(nodeID, possible_split_varIDs); + } else { + return findBestSplit(nodeID, possible_split_varIDs); + } +} + +bool TreeSurvival::findBestSplit(size_t nodeID, std::vector& possible_split_varIDs) { + + double best_decrease = -1; + size_t num_samples_node = sampleIDs[nodeID].size(); + size_t best_varID = 0; + double best_value = 0; + + computeDeathCounts(nodeID); + + // Stop early if no split posssible + if (num_samples_node >= 2 * min_node_size) { + + // For all possible split variables + for (auto& varID 
: possible_split_varIDs) { + + // Find best split value, if ordered consider all values as split values, else all 2-partitions + if (data->isOrderedVariable(varID)) { + if (splitrule == LOGRANK) { + findBestSplitValueLogRank(nodeID, varID, best_value, best_varID, best_decrease); + } else if (splitrule == AUC || splitrule == AUC_IGNORE_TIES) { + findBestSplitValueAUC(nodeID, varID, best_value, best_varID, best_decrease); + } + } else { + findBestSplitValueLogRankUnordered(nodeID, varID, best_value, best_varID, best_decrease); + } + + } + } + + // Stop and save CHF if no good split found (this is terminal node). + if (best_decrease < 0) { + computeSurvival(nodeID); + return true; + } else { + // If not terminal node save best values + split_varIDs[nodeID] = best_varID; + split_values[nodeID] = best_value; + + // Compute decrease of impurity for this node and add to variable importance if needed + if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { + addImpurityImportance(nodeID, best_varID, best_decrease); + } + + return false; + } +} + +bool TreeSurvival::findBestSplitMaxstat(size_t nodeID, std::vector& possible_split_varIDs) { + + size_t num_samples_node = sampleIDs[nodeID].size(); + + // Check node size, stop if maximum reached + if (num_samples_node <= min_node_size) { + computeDeathCounts(nodeID); + computeSurvival(nodeID); + return true; + } + + // Compute scores + std::vector time; + time.reserve(num_samples_node); + std::vector status; + status.reserve(num_samples_node); + for (auto& sampleID : sampleIDs[nodeID]) { + time.push_back(data->get(sampleID, dependent_varID)); + status.push_back(data->get(sampleID, status_varID)); + } + std::vector scores = logrankScores(time, status); + //std::vector scores = logrankScoresData(data, dependent_varID, status_varID, sampleIDs[nodeID]); + + // Save split stats + std::vector pvalues; + pvalues.reserve(possible_split_varIDs.size()); + std::vector values; + 
values.reserve(possible_split_varIDs.size()); + std::vector candidate_varIDs; + candidate_varIDs.reserve(possible_split_varIDs.size()); + std::vector test_statistics; + test_statistics.reserve(possible_split_varIDs.size()); + + // Compute p-values + for (auto& varID : possible_split_varIDs) { + + // Get all observations + std::vector x; + x.reserve(num_samples_node); + for (auto& sampleID : sampleIDs[nodeID]) { + x.push_back(data->get(sampleID, varID)); + } + + // Order by x + std::vector indices = order(x, false); + //std::vector indices = orderInData(data, sampleIDs[nodeID], varID, false); + + // Compute maximally selected rank statistics + double best_maxstat; + double best_split_value; + maxstat(scores, x, indices, best_maxstat, best_split_value, minprop, 1 - minprop); + //maxstatInData(scores, data, sampleIDs[nodeID], varID, indices, best_maxstat, best_split_value, minprop, 1 - minprop); + + if (best_maxstat > -1) { + // Compute number of samples left of cutpoints + std::vector num_samples_left = numSamplesLeftOfCutpoint(x, indices); + //std::vector num_samples_left = numSamplesLeftOfCutpointInData(data, sampleIDs[nodeID], varID, indices); + + // Remove largest cutpoint (all observations left) + num_samples_left.pop_back(); + + // Use unadjusted p-value if only 1 split point + double pvalue; + if (num_samples_left.size() == 1) { + pvalue = maxstatPValueUnadjusted(best_maxstat); + } else { + // Compute p-values + double pvalue_lau92 = maxstatPValueLau92(best_maxstat, minprop, 1 - minprop); + double pvalue_lau94 = maxstatPValueLau94(best_maxstat, minprop, 1 - minprop, num_samples_node, + num_samples_left); + + // Use minimum of Lau92 and Lau94 + pvalue = std::min(pvalue_lau92, pvalue_lau94); + } + + // Save split stats + pvalues.push_back(pvalue); + values.push_back(best_split_value); + candidate_varIDs.push_back(varID); + test_statistics.push_back(best_maxstat); + } + } + + double adjusted_best_pvalue = std::numeric_limits::max(); + size_t best_varID = 0; + 
double best_value = 0; + double best_maxstat = 0; + + if (pvalues.size() > 0) { + // Adjust p-values with Benjamini/Hochberg + std::vector adjusted_pvalues = adjustPvalues(pvalues); + + double min_pvalue = std::numeric_limits::max(); + for (size_t i = 0; i < pvalues.size(); ++i) { + if (pvalues[i] < min_pvalue) { + min_pvalue = pvalues[i]; + best_varID = candidate_varIDs[i]; + best_value = values[i]; + adjusted_best_pvalue = adjusted_pvalues[i]; + best_maxstat = test_statistics[i]; + } + } + } + + // Stop and save CHF if no good split found (this is terminal node). + if (adjusted_best_pvalue > alpha) { + computeDeathCounts(nodeID); + computeSurvival(nodeID); + return true; + } else { + // If not terminal node save best values + split_varIDs[nodeID] = best_varID; + split_values[nodeID] = best_value; + + // Compute decrease of impurity for this node and add to variable importance if needed + if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { + addImpurityImportance(nodeID, best_varID, best_maxstat); + } + + return false; + } +} + +void TreeSurvival::computeDeathCounts(size_t nodeID) { + + // Initialize + for (size_t i = 0; i < num_timepoints; ++i) { + num_deaths[i] = 0; + num_samples_at_risk[i] = 0; + } + + for (auto& sampleID : sampleIDs[nodeID]) { + double survival_time = data->get(sampleID, dependent_varID); + + size_t t = 0; + while (t < num_timepoints && (*unique_timepoints)[t] < survival_time) { + ++num_samples_at_risk[t]; + ++t; + } + + // Now t is the survival time, add to at risk and to death if death + if (t < num_timepoints) { + ++num_samples_at_risk[t]; + if (data->get(sampleID, status_varID) == 1) { + ++num_deaths[t]; + } + } + } +} + +void TreeSurvival::computeChildDeathCounts(size_t nodeID, size_t varID, std::vector& possible_split_values, + std::vector& num_samples_right_child, std::vector& delta_samples_at_risk_right_child, + std::vector& num_deaths_right_child, size_t num_splits) { + + // Count deaths in right child per 
timepoint and possbile split + for (auto& sampleID : sampleIDs[nodeID]) { + double value = data->get(sampleID, varID); + size_t survival_timeID = (*response_timepointIDs)[sampleID]; + + // Count deaths until split_value reached + for (size_t i = 0; i < num_splits; ++i) { + + if (value > possible_split_values[i]) { + ++num_samples_right_child[i]; + ++delta_samples_at_risk_right_child[i * num_timepoints + survival_timeID]; + if (data->get(sampleID, status_varID) == 1) { + ++num_deaths_right_child[i * num_timepoints + survival_timeID]; + } + } else { + break; + } + } + } +} + +void TreeSurvival::findBestSplitValueLogRank(size_t nodeID, size_t varID, double& best_value, size_t& best_varID, + double& best_logrank) { + + // Create possible split values + std::vector possible_split_values; + data->getAllValues(possible_split_values, sampleIDs[nodeID], varID); + + // Try next variable if all equal for this + if (possible_split_values.size() < 2) { + return; + } + + // -1 because no split possible at largest value + size_t num_splits = possible_split_values.size() - 1; + + // Initialize + std::vector num_deaths_right_child(num_splits * num_timepoints); + std::vector delta_samples_at_risk_right_child(num_splits * num_timepoints); + std::vector num_samples_right_child(num_splits); + + computeChildDeathCounts(nodeID, varID, possible_split_values, num_samples_right_child, + delta_samples_at_risk_right_child, num_deaths_right_child, num_splits); + + // Compute logrank test for all splits and use best + for (size_t i = 0; i < num_splits; ++i) { + double numerator = 0; + double denominator_squared = 0; + + // Stop if minimal node size reached + size_t num_samples_left_child = sampleIDs[nodeID].size() - num_samples_right_child[i]; + if (num_samples_right_child[i] < min_node_size || num_samples_left_child < min_node_size) { + continue; + } + + // Compute logrank test statistic for this split + size_t num_samples_at_risk_right_child = num_samples_right_child[i]; + for (size_t t = 0; 
t < num_timepoints; ++t) { + if (num_samples_at_risk[t] < 2 || num_samples_at_risk_right_child < 1) { + break; + } + + if (num_deaths[t] > 0) { + // Numerator and demoninator for log-rank test, notation from Ishwaran et al. + double di = (double) num_deaths[t]; + double di1 = (double) num_deaths_right_child[i * num_timepoints + t]; + double Yi = (double) num_samples_at_risk[t]; + double Yi1 = (double) num_samples_at_risk_right_child; + numerator += di1 - Yi1 * (di / Yi); + denominator_squared += (Yi1 / Yi) * (1.0 - Yi1 / Yi) * ((Yi - di) / (Yi - 1)) * di; + } + + // Reduce number of samples at risk for next timepoint + num_samples_at_risk_right_child -= delta_samples_at_risk_right_child[i * num_timepoints + t]; + + } + double logrank = -1; + if (denominator_squared != 0) { + logrank = fabs(numerator / sqrt(denominator_squared)); + } + + if (logrank > best_logrank) { + best_value = (possible_split_values[i] + possible_split_values[i + 1]) / 2; + best_varID = varID; + best_logrank = logrank; + + // Use smaller value if average is numerically the same as the larger value + if (best_value == possible_split_values[i + 1]) { + best_value = possible_split_values[i]; + } + } + } +} + +void TreeSurvival::findBestSplitValueLogRankUnordered(size_t nodeID, size_t varID, double& best_value, + size_t& best_varID, double& best_logrank) { + + // Create possible split values + std::vector factor_levels; + data->getAllValues(factor_levels, sampleIDs[nodeID], varID); + + // Try next variable if all equal for this + if (factor_levels.size() < 2) { + return; + } + + // Number of possible splits is 2^num_levels + size_t num_splits = (1 << factor_levels.size()); + + // Compute logrank test statistic for each possible split + // Split where all left (0) or all right (1) are excluded + // The second half of numbers is just left/right switched the first half -> Exclude second half + for (size_t local_splitID = 1; local_splitID < num_splits / 2; ++local_splitID) { + + // Compute overall 
splitID by shifting local factorIDs to global positions + size_t splitID = 0; + for (size_t j = 0; j < factor_levels.size(); ++j) { + if ((local_splitID & (1 << j))) { + double level = factor_levels[j]; + size_t factorID = floor(level) - 1; + splitID = splitID | (1 << factorID); + } + } + + // Initialize + std::vector num_deaths_right_child(num_timepoints); + std::vector delta_samples_at_risk_right_child(num_timepoints); + size_t num_samples_right_child = 0; + double numerator = 0; + double denominator_squared = 0; + + // Count deaths in right child per timepoint + for (auto& sampleID : sampleIDs[nodeID]) { + size_t survival_timeID = (*response_timepointIDs)[sampleID]; + double value = data->get(sampleID, varID); + size_t factorID = floor(value) - 1; + + // If in right child, count + // In right child, if bitwise splitID at position factorID is 1 + if ((splitID & (1 << factorID))) { + ++num_samples_right_child; + ++delta_samples_at_risk_right_child[survival_timeID]; + if (data->get(sampleID, status_varID) == 1) { + ++num_deaths_right_child[survival_timeID]; + } + } + + } + + // Stop if minimal node size reached + size_t num_samples_left_child = sampleIDs[nodeID].size() - num_samples_right_child; + if (num_samples_right_child < min_node_size || num_samples_left_child < min_node_size) { + continue; + } + + // Compute logrank test statistic for this split + size_t num_samples_at_risk_right_child = num_samples_right_child; + for (size_t t = 0; t < num_timepoints; ++t) { + if (num_samples_at_risk[t] < 2 || num_samples_at_risk_right_child < 1) { + break; + } + + if (num_deaths[t] > 0) { + // Numerator and demoninator for log-rank test, notation from Ishwaran et al. 
+ double di = (double) num_deaths[t]; + double di1 = (double) num_deaths_right_child[t]; + double Yi = (double) num_samples_at_risk[t]; + double Yi1 = (double) num_samples_at_risk_right_child; + numerator += di1 - Yi1 * (di / Yi); + denominator_squared += (Yi1 / Yi) * (1.0 - Yi1 / Yi) * ((Yi - di) / (Yi - 1)) * di; + } + + // Reduce number of samples at risk for next timepoint + num_samples_at_risk_right_child -= delta_samples_at_risk_right_child[t]; + } + double logrank = -1; + if (denominator_squared != 0) { + logrank = fabs(numerator / sqrt(denominator_squared)); + } + + if (logrank > best_logrank) { + best_value = splitID; + best_varID = varID; + best_logrank = logrank; + } + } +} + +void TreeSurvival::findBestSplitValueAUC(size_t nodeID, size_t varID, double& best_value, size_t& best_varID, + double& best_auc) { + + // Create possible split values + std::vector possible_split_values; + data->getAllValues(possible_split_values, sampleIDs[nodeID], varID); + + // Try next variable if all equal for this + if (possible_split_values.size() < 2) { + return; + } + + size_t num_node_samples = sampleIDs[nodeID].size(); + size_t num_splits = possible_split_values.size() - 1; + size_t num_possible_pairs = num_node_samples * (num_node_samples - 1) / 2; + + // Initialize + std::vector num_count(num_splits, num_possible_pairs); + std::vector num_total(num_splits, num_possible_pairs); + std::vector num_samples_left_child(num_splits); + + // For all pairs + for (size_t k = 0; k < num_node_samples; ++k) { + size_t sample_k = sampleIDs[nodeID][k]; + double time_k = data->get(sample_k, dependent_varID); + double status_k = data->get(sample_k, status_varID); + double value_k = data->get(sample_k, varID); + + // Count samples in left node + for (size_t i = 0; i < num_splits; ++i) { + double split_value = possible_split_values[i]; + if (value_k <= split_value) { + ++num_samples_left_child[i]; + } + } + + for (size_t l = k + 1; l < num_node_samples; ++l) { + size_t sample_l = 
sampleIDs[nodeID][l]; + double time_l = data->get(sample_l, dependent_varID); + double status_l = data->get(sample_l, status_varID); + double value_l = data->get(sample_l, varID); + + // Compute split + computeAucSplit(time_k, time_l, status_k, status_l, value_k, value_l, num_splits, possible_split_values, + num_count, num_total); + } + } + + for (size_t i = 0; i < num_splits; ++i) { + // Do not consider this split point if fewer than min_node_size samples in one node + size_t num_samples_right_child = num_node_samples - num_samples_left_child[i]; + if (num_samples_left_child[i] < min_node_size || num_samples_right_child < min_node_size) { + continue; + } else { + double auc = fabs((num_count[i] / 2) / num_total[i] - 0.5); + if (auc > best_auc) { + best_value = (possible_split_values[i] + possible_split_values[i + 1]) / 2; + best_varID = varID; + best_auc = auc; + + // Use smaller value if average is numerically the same as the larger value + if (best_value == possible_split_values[i + 1]) { + best_value = possible_split_values[i]; + } + } + } + } +} + +void TreeSurvival::computeAucSplit(double time_k, double time_l, double status_k, double status_l, double value_k, + double value_l, size_t num_splits, std::vector& possible_split_values, std::vector& num_count, + std::vector& num_total) { + + bool ignore_pair = false; + bool do_nothing = false; + + double value_smaller = 0; + double value_larger = 0; + double status_smaller = 0; + + if (time_k < time_l) { + value_smaller = value_k; + value_larger = value_l; + status_smaller = status_k; + } else if (time_l < time_k) { + value_smaller = value_l; + value_larger = value_k; + status_smaller = status_l; + } else { + // Tie in survival time + if (status_k == 0 || status_l == 0) { + ignore_pair = true; + } else { + if (splitrule == AUC_IGNORE_TIES) { + ignore_pair = true; + } else { + if (value_k == value_l) { + // Tie in survival time and in covariate + ignore_pair = true; + } else { + // Tie in survival time in covariate 
+ do_nothing = true; + } + } + } + } + + // Do not count if smaller time censored + if (status_smaller == 0) { + ignore_pair = true; + } + + if (ignore_pair) { + for (size_t i = 0; i < num_splits; ++i) { + --num_count[i]; + --num_total[i]; + } + } else if (do_nothing) { + // Do nothing + } else { + for (size_t i = 0; i < num_splits; ++i) { + double split_value = possible_split_values[i]; + + if (value_smaller <= split_value && value_larger > split_value) { + ++num_count[i]; + } else if (value_smaller > split_value && value_larger <= split_value) { + --num_count[i]; + } else if (value_smaller <= split_value && value_larger <= split_value) { + break; + } + } + } + +} + +bool TreeSurvival::findBestSplitExtraTrees(size_t nodeID, std::vector& possible_split_varIDs) { + + double best_decrease = -1; + size_t num_samples_node = sampleIDs[nodeID].size(); + size_t best_varID = 0; + double best_value = 0; + + computeDeathCounts(nodeID); + + // Stop early if no split posssible + if (num_samples_node >= 2 * min_node_size) { + + // For all possible split variables + for (auto& varID : possible_split_varIDs) { + + // Find best split value, if ordered consider all values as split values, else all 2-partitions + if (data->isOrderedVariable(varID)) { + findBestSplitValueExtraTrees(nodeID, varID, best_value, best_varID, best_decrease); + } else { + findBestSplitValueExtraTreesUnordered(nodeID, varID, best_value, best_varID, best_decrease); + } + + } + } + + // Stop and save CHF if no good split found (this is terminal node). 
+ if (best_decrease < 0) { + computeSurvival(nodeID); + return true; + } else { + // If not terminal node save best values + split_varIDs[nodeID] = best_varID; + split_values[nodeID] = best_value; + + // Compute decrease of impurity for this node and add to variable importance if needed + if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { + addImpurityImportance(nodeID, best_varID, best_decrease); + } + + return false; + } +} + +void TreeSurvival::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, double& best_value, size_t& best_varID, + double& best_logrank) { + + // Get min/max values of covariate in node + double min; + double max; + data->getMinMaxValues(min, max, sampleIDs[nodeID], varID); + + // Try next variable if all equal for this + if (min == max) { + return; + } + + // Create possible split values: Draw randomly between min and max + std::vector possible_split_values; + std::uniform_real_distribution udist(min, max); + possible_split_values.reserve(num_random_splits); + for (size_t i = 0; i < num_random_splits; ++i) { + possible_split_values.push_back(udist(random_number_generator)); + } + + size_t num_splits = possible_split_values.size(); + + // Initialize + std::vector num_deaths_right_child(num_splits * num_timepoints); + std::vector delta_samples_at_risk_right_child(num_splits * num_timepoints); + std::vector num_samples_right_child(num_splits); + + computeChildDeathCounts(nodeID, varID, possible_split_values, num_samples_right_child, + delta_samples_at_risk_right_child, num_deaths_right_child, num_splits); + + // Compute logrank test for all splits and use best + for (size_t i = 0; i < num_splits; ++i) { + double numerator = 0; + double denominator_squared = 0; + + // Stop if minimal node size reached + size_t num_samples_left_child = sampleIDs[nodeID].size() - num_samples_right_child[i]; + if (num_samples_right_child[i] < min_node_size || num_samples_left_child < min_node_size) { + continue; + } + + // Compute 
logrank test statistic for this split + size_t num_samples_at_risk_right_child = num_samples_right_child[i]; + for (size_t t = 0; t < num_timepoints; ++t) { + if (num_samples_at_risk[t] < 2 || num_samples_at_risk_right_child < 1) { + break; + } + + if (num_deaths[t] > 0) { + // Numerator and demoninator for log-rank test, notation from Ishwaran et al. + double di = (double) num_deaths[t]; + double di1 = (double) num_deaths_right_child[i * num_timepoints + t]; + double Yi = (double) num_samples_at_risk[t]; + double Yi1 = (double) num_samples_at_risk_right_child; + numerator += di1 - Yi1 * (di / Yi); + denominator_squared += (Yi1 / Yi) * (1.0 - Yi1 / Yi) * ((Yi - di) / (Yi - 1)) * di; + } + + // Reduce number of samples at risk for next timepoint + num_samples_at_risk_right_child -= delta_samples_at_risk_right_child[i * num_timepoints + t]; + + } + double logrank = -1; + if (denominator_squared != 0) { + logrank = fabs(numerator / sqrt(denominator_squared)); + } + + if (logrank > best_logrank) { + best_value = possible_split_values[i]; + best_varID = varID; + best_logrank = logrank; + } + } +} + +void TreeSurvival::findBestSplitValueExtraTreesUnordered(size_t nodeID, size_t varID, double& best_value, + size_t& best_varID, double& best_logrank) { + + size_t num_unique_values = data->getNumUniqueDataValues(varID); + + // Get all factor indices in node + std::vector factor_in_node(num_unique_values, false); + for (auto& sampleID : sampleIDs[nodeID]) { + size_t index = data->getIndex(sampleID, varID); + factor_in_node[index] = true; + } + + // Vector of indices in and out of node + std::vector indices_in_node; + std::vector indices_out_node; + indices_in_node.reserve(num_unique_values); + indices_out_node.reserve(num_unique_values); + for (size_t i = 0; i < num_unique_values; ++i) { + if (factor_in_node[i]) { + indices_in_node.push_back(i); + } else { + indices_out_node.push_back(i); + } + } + + // Generate num_random_splits splits + for (size_t i = 0; i < 
num_random_splits; ++i) { + std::vector split_subset; + split_subset.reserve(num_unique_values); + + // Draw random subsets, sample all partitions with equal probability + if (indices_in_node.size() > 1) { + size_t num_partitions = (2 << (indices_in_node.size() - 1)) - 2; // 2^n-2 (don't allow full or empty) + std::uniform_int_distribution udist(1, num_partitions); + size_t splitID_in_node = udist(random_number_generator); + for (size_t j = 0; j < indices_in_node.size(); ++j) { + if ((splitID_in_node & (1 << j)) > 0) { + split_subset.push_back(indices_in_node[j]); + } + } + } + if (indices_out_node.size() > 1) { + size_t num_partitions = (2 << (indices_out_node.size() - 1)) - 1; // 2^n-1 (allow full or empty) + std::uniform_int_distribution udist(0, num_partitions); + size_t splitID_out_node = udist(random_number_generator); + for (size_t j = 0; j < indices_out_node.size(); ++j) { + if ((splitID_out_node & (1 << j)) > 0) { + split_subset.push_back(indices_out_node[j]); + } + } + } + + // Assign union of the two subsets to right child + size_t splitID = 0; + for (auto& idx : split_subset) { + splitID |= 1 << idx; + } + + // Initialize + std::vector num_deaths_right_child(num_timepoints); + std::vector delta_samples_at_risk_right_child(num_timepoints); + size_t num_samples_right_child = 0; + double numerator = 0; + double denominator_squared = 0; + + // Count deaths in right child per timepoint + for (auto& sampleID : sampleIDs[nodeID]) { + size_t survival_timeID = (*response_timepointIDs)[sampleID]; + double value = data->get(sampleID, varID); + size_t factorID = floor(value) - 1; + + // If in right child, count + // In right child, if bitwise splitID at position factorID is 1 + if ((splitID & (1 << factorID))) { + ++num_samples_right_child; + ++delta_samples_at_risk_right_child[survival_timeID]; + if (data->get(sampleID, status_varID) == 1) { + ++num_deaths_right_child[survival_timeID]; + } + } + + } + + // Stop if minimal node size reached + size_t 
num_samples_left_child = sampleIDs[nodeID].size() - num_samples_right_child; + if (num_samples_right_child < min_node_size || num_samples_left_child < min_node_size) { + continue; + } + + // Compute logrank test statistic for this split + size_t num_samples_at_risk_right_child = num_samples_right_child; + for (size_t t = 0; t < num_timepoints; ++t) { + if (num_samples_at_risk[t] < 2 || num_samples_at_risk_right_child < 1) { + break; + } + + if (num_deaths[t] > 0) { + // Numerator and demoninator for log-rank test, notation from Ishwaran et al. + double di = (double) num_deaths[t]; + double di1 = (double) num_deaths_right_child[t]; + double Yi = (double) num_samples_at_risk[t]; + double Yi1 = (double) num_samples_at_risk_right_child; + numerator += di1 - Yi1 * (di / Yi); + denominator_squared += (Yi1 / Yi) * (1.0 - Yi1 / Yi) * ((Yi - di) / (Yi - 1)) * di; + } + + // Reduce number of samples at risk for next timepoint + num_samples_at_risk_right_child -= delta_samples_at_risk_right_child[t]; + } + double logrank = -1; + if (denominator_squared != 0) { + logrank = fabs(numerator / sqrt(denominator_squared)); + } + + if (logrank > best_logrank) { + best_value = splitID; + best_varID = varID; + best_logrank = logrank; + } + } +} + +void TreeSurvival::addImpurityImportance(size_t nodeID, size_t varID, double decrease) { + + // No variable importance for no split variables + size_t tempvarID = data->getUnpermutedVarID(varID); + for (auto& skip : data->getNoSplitVariables()) { + if (tempvarID >= skip) { + --tempvarID; + } + } + + // Subtract if corrected importance and permuted variable, else add + if (importance_mode == IMP_GINI_CORRECTED && varID >= data->getNumCols()) { + (*variable_importance)[tempvarID] -= decrease; + } else { + (*variable_importance)[tempvarID] += decrease; + } +} + +} // namespace ranger diff --git a/lib/ranger/TreeSurvival.h b/lib/ranger/TreeSurvival.h new file mode 100644 index 000000000..613b476ff --- /dev/null +++ b/lib/ranger/TreeSurvival.h @@ 
-0,0 +1,117 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. + #-------------------------------------------------------------------------------*/ + +#ifndef TREESURVIVAL_H_ +#define TREESURVIVAL_H_ + +#include + +#include "globals.h" +#include "Tree.h" + +namespace ranger { + +class TreeSurvival: public Tree { +public: + TreeSurvival(std::vector* unique_timepoints, size_t status_varID, std::vector* response_timepointIDs); + + // Create from loaded forest + TreeSurvival(std::vector>& child_nodeIDs, std::vector& split_varIDs, + std::vector& split_values, std::vector> chf, std::vector* unique_timepoints, + std::vector* response_timepointIDs); + + TreeSurvival(const TreeSurvival&) = delete; + TreeSurvival& operator=(const TreeSurvival&) = delete; + + virtual ~TreeSurvival() override = default; + + void allocateMemory() override; + + void appendToFileInternal(std::ofstream& file) override; + void computePermutationImportanceInternal(std::vector>* permutations); + + const std::vector>& getChf() const { + return chf; + } + + const std::vector& getPrediction(size_t sampleID) const { + size_t terminal_nodeID = prediction_terminal_nodeIDs[sampleID]; + return chf[terminal_nodeID]; + } + + size_t getPredictionTerminalNodeID(size_t sampleID) const { + return prediction_terminal_nodeIDs[sampleID]; + } + +private: + + void createEmptyNodeInternal() override; + void computeSurvival(size_t nodeID); + double computePredictionAccuracyInternal() override; + + bool splitNodeInternal(size_t nodeID, std::vector& possible_split_varIDs) override; + + bool findBestSplit(size_t nodeID, std::vector& possible_split_varIDs); + bool findBestSplitMaxstat(size_t nodeID, 
std::vector& possible_split_varIDs); + + void findBestSplitValueLogRank(size_t nodeID, size_t varID, std::vector& possible_split_values, + double& best_value, size_t& best_varID, double& best_logrank); + void findBestSplitValueLogRankUnordered(size_t nodeID, size_t varID, std::vector& factor_levels, + double& best_value, size_t& best_varID, double& best_logrank); + void findBestSplitValueAUC(size_t nodeID, size_t varID, double& best_value, size_t& best_varID, double& best_auc); + + void computeDeathCounts(size_t nodeID); + void computeChildDeathCounts(size_t nodeID, size_t varID, std::vector& possible_split_values, + std::vector& num_samples_right_child, std::vector& num_samples_at_risk_right_child, + std::vector& num_deaths_right_child, size_t num_splits); + + void computeAucSplit(double time_k, double time_l, double status_k, double status_l, double value_k, double value_l, + size_t num_splits, std::vector& possible_split_values, std::vector& num_count, + std::vector& num_total); + + void findBestSplitValueLogRank(size_t nodeID, size_t varID, double& best_value, size_t& best_varID, + double& best_logrank); + void findBestSplitValueLogRankUnordered(size_t nodeID, size_t varID, double& best_value, size_t& best_varID, + double& best_logrank); + + bool findBestSplitExtraTrees(size_t nodeID, std::vector& possible_split_varIDs); + void findBestSplitValueExtraTrees(size_t nodeID, size_t varID, double& best_value, size_t& best_varID, + double& best_logrank); + void findBestSplitValueExtraTreesUnordered(size_t nodeID, size_t varID, double& best_value, size_t& best_varID, + double& best_logrank); + + void addImpurityImportance(size_t nodeID, size_t varID, double decrease); + + void cleanUpInternal() override { + num_deaths.clear(); + num_deaths.shrink_to_fit(); + num_samples_at_risk.clear(); + num_samples_at_risk.shrink_to_fit(); + } + + size_t status_varID; + + // Unique time points for all individuals (not only this bootstrap), sorted + const std::vector* 
unique_timepoints; + size_t num_timepoints; + const std::vector* response_timepointIDs; + + // For all terminal nodes CHF for all unique timepoints. For other nodes empty vector. + std::vector> chf; + + // Fields to save to while tree growing + std::vector num_deaths; + std::vector num_samples_at_risk; +}; + +} // namespace ranger + +#endif /* TREESURVIVAL_H_ */ diff --git a/lib/ranger/globals.h b/lib/ranger/globals.h new file mode 100644 index 000000000..b844c18cb --- /dev/null +++ b/lib/ranger/globals.h @@ -0,0 +1,105 @@ +/*------------------------------------------------------------------------------- +This file is part of ranger. + +Copyright (c) [2014-2018] [Marvin N. Wright] + +This software may be modified and distributed under the terms of the MIT license. + +Please note that the C++ core of ranger is distributed under MIT license and the +R package "ranger" under GPL3 license. +#-------------------------------------------------------------------------------*/ + +#ifndef GLOBALS_H_ +#define GLOBALS_H_ + +namespace ranger { + +#ifndef M_PI +#define M_PI 3.14159265358979323846 +#endif + +// Old/new Win build +#ifdef WIN_R_BUILD + #if __cplusplus < 201103L + #define OLD_WIN_R_BUILD + #else + #define NEW_WIN_R_BUILD + #endif +#endif + +typedef unsigned int uint; + +// Tree types, probability is not selected by ID +enum TreeType { + TREE_CLASSIFICATION = 1, + TREE_REGRESSION = 3, + TREE_SURVIVAL = 5, + TREE_PROBABILITY = 9 +}; + +// Memory modes +enum MemoryMode { + MEM_DOUBLE = 0, + MEM_FLOAT = 1, + MEM_CHAR = 2 +}; +const uint MAX_MEM_MODE = 2; + +// Mask and Offset to store 2 bit values in bytes +static const int mask[4] = {192,48,12,3}; +static const int offset[4] = {6,4,2,0}; + +// Variable importance +enum ImportanceMode { + IMP_NONE = 0, + IMP_GINI = 1, + IMP_PERM_BREIMAN = 2, + IMP_PERM_LIAW = 4, + IMP_PERM_RAW = 3, + IMP_GINI_CORRECTED = 5 +}; +const uint MAX_IMP_MODE = 5; + +// Split mode +enum SplitRule { + LOGRANK = 1, + AUC = 2, + AUC_IGNORE_TIES = 
3, + MAXSTAT = 4, + EXTRATREES = 5 +}; + +// Prediction type +enum PredictionType { + RESPONSE = 1, + TERMINALNODES = 2 +}; + +// Default values +const uint DEFAULT_NUM_TREE = 500; +const uint DEFAULT_NUM_THREADS = 0; +const ImportanceMode DEFAULT_IMPORTANCE_MODE = IMP_NONE; + +const uint DEFAULT_MIN_NODE_SIZE_CLASSIFICATION = 1; +const uint DEFAULT_MIN_NODE_SIZE_REGRESSION = 5; +const uint DEFAULT_MIN_NODE_SIZE_SURVIVAL = 3; +const uint DEFAULT_MIN_NODE_SIZE_PROBABILITY = 10; + +const SplitRule DEFAULT_SPLITRULE = LOGRANK; +const double DEFAULT_ALPHA = 0.5; +const double DEFAULT_MINPROP = 0.1; + +const PredictionType DEFAULT_PREDICTIONTYPE = RESPONSE; +const uint DEFAULT_NUM_RANDOM_SPLITS = 1; + +//const std::vector DEFAULT_SAMPLE_FRACTION = std::vector({1}); + +// Interval to print progress in seconds +const double STATUS_INTERVAL = 30.0; + +// Threshold for q value split method switch +const double Q_THRESHOLD = 0.02; + +} // namespace ranger + +#endif /* GLOBALS_H_ */ diff --git a/lib/ranger/utility.cpp b/lib/ranger/utility.cpp new file mode 100644 index 000000000..9ca02a1de --- /dev/null +++ b/lib/ranger/utility.cpp @@ -0,0 +1,582 @@ +/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. 
+ #-------------------------------------------------------------------------------*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utility.h" +#include "globals.h" +#include "Data.h" + +namespace ranger { + +void equalSplit(std::vector& result, uint start, uint end, uint num_parts) { + + result.reserve(num_parts + 1); + + // Return range if only 1 part + if (num_parts == 1) { + result.push_back(start); + result.push_back(end + 1); + return; + } + + // Return vector from start to end+1 if more parts than elements + if (num_parts > end - start + 1) { + for (uint i = start; i <= end + 1; ++i) { + result.push_back(i); + } + return; + } + + uint length = (end - start + 1); + uint part_length_short = length / num_parts; + uint part_length_long = (uint) ceil(length / ((double) num_parts)); + uint cut_pos = length % num_parts; + + // Add long ranges + for (uint i = start; i < start + cut_pos * part_length_long; i = i + part_length_long) { + result.push_back(i); + } + + // Add short ranges + for (uint i = start + cut_pos * part_length_long; i <= end + 1; i = i + part_length_short) { + result.push_back(i); + } +} + +void loadDoubleVectorFromFile(std::vector& result, std::string filename) { // #nocov start + + // Open input file + std::ifstream input_file; + input_file.open(filename); + if (!input_file.good()) { + throw std::runtime_error("Could not open file: " + filename); + } + + // Read the first line, ignore the rest + std::string line; + getline(input_file, line); + std::stringstream line_stream(line); + double token; + while (line_stream >> token) { + result.push_back(token); + } +} // #nocov end + +void drawWithoutReplacementSkip(std::vector& result, std::mt19937_64& random_number_generator, size_t max, + const std::vector& skip, size_t num_samples) { + if (num_samples < max / 10) { + drawWithoutReplacementSimple(result, random_number_generator, max, skip, num_samples); + } else { + 
//drawWithoutReplacementKnuth(result, random_number_generator, max, skip, num_samples); + drawWithoutReplacementFisherYates(result, random_number_generator, max, skip, num_samples); + } +} + +void drawWithoutReplacementSimple(std::vector& result, std::mt19937_64& random_number_generator, size_t max, + const std::vector& skip, size_t num_samples) { + + result.reserve(num_samples); + + // Set all to not selected + std::vector temp; + temp.resize(max, false); + + std::uniform_int_distribution unif_dist(0, max - 1 - skip.size()); + for (size_t i = 0; i < num_samples; ++i) { + size_t draw; + do { + draw = unif_dist(random_number_generator); + for (auto& skip_value : skip) { + if (draw >= skip_value) { + ++draw; + } + } + } while (temp[draw]); + temp[draw] = true; + result.push_back(draw); + } +} + +void drawWithoutReplacementFisherYates(std::vector& result, std::mt19937_64& random_number_generator, + size_t max, const std::vector& skip, size_t num_samples) { + + // Create indices + result.resize(max); + std::iota(result.begin(), result.end(), 0); + + // Skip indices + for (size_t i = 0; i < skip.size(); ++i) { + result.erase(result.begin() + skip[skip.size() - 1 - i]); + } + + // Draw without replacement using Fisher Yates algorithm + std::uniform_real_distribution distribution(0.0, 1.0); + for (size_t i = 0; i < num_samples; ++i) { + size_t j = i + distribution(random_number_generator) * (max - skip.size() - i); + std::swap(result[i], result[j]); + } + + result.resize(num_samples); +} + +void drawWithoutReplacementWeighted(std::vector& result, std::mt19937_64& random_number_generator, + const std::vector& indices, size_t num_samples, const std::vector& weights) { + + result.reserve(num_samples); + + // Set all to not selected + std::vector temp; + temp.resize(indices.size(), false); + + std::discrete_distribution<> weighted_dist(weights.begin(), weights.end()); + for (size_t i = 0; i < num_samples; ++i) { + size_t draw; + do { + draw = 
weighted_dist(random_number_generator); + } while (temp[draw]); + temp[draw] = true; + result.push_back(indices[draw]); + } +} + +void drawWithoutReplacementWeighted(std::vector& result, std::mt19937_64& random_number_generator, + size_t max_index, size_t num_samples, const std::vector& weights) { + + result.reserve(num_samples); + + // Set all to not selected + std::vector temp; + temp.resize(max_index + 1, false); + + std::discrete_distribution<> weighted_dist(weights.begin(), weights.end()); + for (size_t i = 0; i < num_samples; ++i) { + size_t draw; + do { + draw = weighted_dist(random_number_generator); + } while (temp[draw]); + temp[draw] = true; + result.push_back(draw); + } +} + +double mostFrequentValue(const std::unordered_map& class_count, + std::mt19937_64 random_number_generator) { + std::vector major_classes; + + // Find maximum count + size_t max_count = 0; + for (auto& class_value : class_count) { + if (class_value.second > max_count) { + max_count = class_value.second; + major_classes.clear(); + major_classes.push_back(class_value.first); + } else if (class_value.second == max_count) { + major_classes.push_back(class_value.first); + } + } + + if (major_classes.size() == 1) { + return major_classes[0]; + } else { + // Choose randomly + std::uniform_int_distribution unif_dist(0, major_classes.size() - 1); + return major_classes[unif_dist(random_number_generator)]; + } +} + +double computeConcordanceIndex(const Data& data, const std::vector& sum_chf, size_t dependent_varID, + size_t status_varID, const std::vector& sample_IDs) { + + // Compute concordance index + double concordance = 0; + double permissible = 0; + for (size_t i = 0; i < sum_chf.size(); ++i) { + size_t sample_i = i; + if (!sample_IDs.empty()) { + sample_i = sample_IDs[i]; + } + double time_i = data.get(sample_i, dependent_varID); + double status_i = data.get(sample_i, status_varID); + + for (size_t j = i + 1; j < sum_chf.size(); ++j) { + size_t sample_j = j; + if (!sample_IDs.empty()) 
{ + sample_j = sample_IDs[j]; + } + double time_j = data.get(sample_j, dependent_varID); + double status_j = data.get(sample_j, status_varID); + + if (time_i < time_j && status_i == 0) { + continue; + } + if (time_j < time_i && status_j == 0) { + continue; + } + if (time_i == time_j && status_i == status_j) { + continue; + } + + permissible += 1; + + if (time_i < time_j && sum_chf[i] > sum_chf[j]) { + concordance += 1; + } else if (time_j < time_i && sum_chf[j] > sum_chf[i]) { + concordance += 1; + } else if (sum_chf[i] == sum_chf[j]) { + concordance += 0.5; + } + + } + } + + return (concordance / permissible); + +} + +std::string uintToString(uint number) { +#if WIN_R_BUILD == 1 + std::stringstream temp; + temp << number; + return temp.str(); +#else + return std::to_string(number); +#endif +} + +std::string beautifyTime(uint seconds) { // #nocov start + std::string result; + + // Add seconds, minutes, hours, days if larger than zero + uint out_seconds = (uint) seconds % 60; + result = uintToString(out_seconds) + " seconds"; + uint out_minutes = (seconds / 60) % 60; + if (seconds / 60 == 0) { + return result; + } else if (out_minutes == 1) { + result = "1 minute, " + result; + } else { + result = uintToString(out_minutes) + " minutes, " + result; + } + uint out_hours = (seconds / 3600) % 24; + if (seconds / 3600 == 0) { + return result; + } else if (out_hours == 1) { + result = "1 hour, " + result; + } else { + result = uintToString(out_hours) + " hours, " + result; + } + uint out_days = (seconds / 86400); + if (out_days == 0) { + return result; + } else if (out_days == 1) { + result = "1 day, " + result; + } else { + result = uintToString(out_days) + " days, " + result; + } + return result; +} // #nocov end + +size_t roundToNextMultiple(size_t value, uint multiple) { + + if (multiple == 0) { + return value; + } + + size_t remainder = value % multiple; + if (remainder == 0) { + return value; + } + + return value + multiple - remainder; +} + +void 
splitString(std::vector& result, const std::string& input, char split_char) { // #nocov start + + std::istringstream ss(input); + std::string token; + + while (std::getline(ss, token, split_char)) { + result.push_back(token); + } +} // #nocov end + +void shuffleAndSplit(std::vector& first_part, std::vector& second_part, size_t n_all, size_t n_first, + std::mt19937_64 random_number_generator) { + + // Reserve space + first_part.resize(n_all); + + // Fill with 0..n_all-1 and shuffle + std::iota(first_part.begin(), first_part.end(), 0); + std::shuffle(first_part.begin(), first_part.end(), random_number_generator); + + // Copy to second part + second_part.resize(n_all - n_first); + std::copy(first_part.begin() + n_first, first_part.end(), second_part.begin()); + + // Resize first part + first_part.resize(n_first); +} + +void shuffleAndSplitAppend(std::vector& first_part, std::vector& second_part, size_t n_all, + size_t n_first, const std::vector& mapping, std::mt19937_64 random_number_generator) { + // Old end is start position for new data + size_t first_old_size = first_part.size(); + size_t second_old_size = second_part.size(); + + // Reserve space + first_part.resize(first_old_size + n_all); + std::vector::iterator first_start_pos = first_part.begin() + first_old_size; + + // Fill with 0..n_all-1 and shuffle + std::iota(first_start_pos, first_part.end(), 0); + std::shuffle(first_start_pos, first_part.end(), random_number_generator); + + // Mapping + for (std::vector::iterator j = first_start_pos; j != first_part.end(); ++j) { + *j = mapping[*j]; + } + + // Copy to second part + second_part.resize(second_part.size() + n_all - n_first); + std::vector::iterator second_start_pos = second_part.begin() + second_old_size; + std::copy(first_start_pos + n_first, first_part.end(), second_start_pos); + + // Resize first part + first_part.resize(first_old_size + n_first); +} + +std::string checkUnorderedVariables(const Data& data, const std::vector& unordered_variable_names) { 
// #nocov start + size_t num_rows = data.getNumRows(); + std::vector sampleIDs(num_rows); + std::iota(sampleIDs.begin(), sampleIDs.end(), 0); + + // Check for all unordered variables + for (auto& variable_name : unordered_variable_names) { + size_t varID = data.getVariableID(variable_name); + std::vector all_values; + data.getAllValues(all_values, sampleIDs, varID); + + // Check level count + size_t max_level_count = 8 * sizeof(size_t) - 1; + if (all_values.size() > max_level_count) { + return "Too many levels in unordered categorical variable " + variable_name + ". Only " + + uintToString(max_level_count) + " levels allowed on this system."; + } + + // Check positive integers + if (!checkPositiveIntegers(all_values)) { + return "Not all values in unordered categorical variable " + variable_name + " are positive integers."; + } + } + return ""; +} // #nocov end + +bool checkPositiveIntegers(const std::vector& all_values) { // #nocov start + for (auto& value : all_values) { + if (value < 1 || !(floor(value) == value)) { + return false; + } + } + return true; +} // #nocov end + +double maxstatPValueLau92(double b, double minprop, double maxprop) { + + if (b < 1) { + return 1.0; + } + + // Compute only once (minprop/maxprop don't change during runtime) + static double logprop = log((maxprop * (1 - minprop)) / ((1 - maxprop) * minprop)); + + double db = dstdnorm(b); + double p = 4 * db / b + db * (b - 1 / b) * logprop; + + if (p > 0) { + return p; + } else { + return 0; + } +} + +double maxstatPValueLau94(double b, double minprop, double maxprop, size_t N, const std::vector& m) { + + double D = 0; + for (size_t i = 0; i < m.size() - 1; ++i) { + double m1 = m[i]; + double m2 = m[i + 1]; + + double t = sqrt(1.0 - m1 * (N - m2) / ((N - m1) * m2)); + D += 1 / M_PI * exp(-b * b / 2) * (t - (b * b / 4 - 1) * (t * t * t) / 6); + } + + return 2 * (1 - pstdnorm(b)) + D; +} + +double maxstatPValueUnadjusted(double b) { + return 2 * pstdnorm(-b); +} + +double dstdnorm(double x) { 
+ return exp(-0.5 * x * x) / sqrt(2 * M_PI); +} + +double pstdnorm(double x) { + return 0.5 * (1 + erf(x / sqrt(2.0))); +} + +std::vector adjustPvalues(std::vector& unadjusted_pvalues) { + size_t num_pvalues = unadjusted_pvalues.size(); + std::vector adjusted_pvalues(num_pvalues, 0); + + // Get order of p-values + std::vector indices = order(unadjusted_pvalues, true); + + // Compute adjusted p-values + adjusted_pvalues[indices[0]] = unadjusted_pvalues[indices[0]]; + for (size_t i = 1; i < indices.size(); ++i) { + size_t idx = indices[i]; + size_t idx_last = indices[i - 1]; + + adjusted_pvalues[idx] = std::min(adjusted_pvalues[idx_last], + (double) num_pvalues / (double) (num_pvalues - i) * unadjusted_pvalues[idx]); + } + return adjusted_pvalues; +} + +std::vector logrankScores(const std::vector& time, const std::vector& status) { + size_t n = time.size(); + std::vector scores(n); + + // Get order of timepoints + std::vector indices = order(time, false); + + // Compute scores + double cumsum = 0; + size_t last_unique = -1; + for (size_t i = 0; i < n; ++i) { + + // Continue if next value is the same + if (i < n - 1 && time[indices[i]] == time[indices[i + 1]]) { + continue; + } + + // Compute sum and scores for all non-unique values in a row + for (size_t j = last_unique + 1; j <= i; ++j) { + cumsum += status[indices[j]] / (n - i); + } + for (size_t j = last_unique + 1; j <= i; ++j) { + scores[indices[j]] = status[indices[j]] - cumsum; + } + + // Save last computed value + last_unique = i; + } + + return scores; +} + +void maxstat(const std::vector& scores, const std::vector& x, const std::vector& indices, + double& best_maxstat, double& best_split_value, double minprop, double maxprop) { + size_t n = x.size(); + + double sum_all_scores = 0; + for (size_t i = 0; i < n; ++i) { + sum_all_scores += scores[indices[i]]; + } + + // Compute sum of differences from mean for variance + double mean_scores = sum_all_scores / n; + double sum_mean_diff = 0; + for (size_t i = 0; i 
< n; ++i) { + sum_mean_diff += (scores[i] - mean_scores) * (scores[i] - mean_scores); + } + + // Get smallest and largest split to consider, -1 for compatibility with R maxstat + size_t minsplit = 0; + if (n * minprop > 1) { + minsplit = n * minprop - 1; + } + size_t maxsplit = n * maxprop - 1; + + // For all unique x-values + best_maxstat = -1; + best_split_value = -1; + double sum_scores = 0; + size_t n_left = 0; + for (size_t i = 0; i <= maxsplit; ++i) { + + sum_scores += scores[indices[i]]; + n_left++; + + // Dont consider splits smaller than minsplit for splitting (but count) + if (i < minsplit) { + continue; + } + + // Consider only unique values + if (i < n - 1 && x[indices[i]] == x[indices[i + 1]]) { + continue; + } + + // If value is largest possible value, stop + if (x[indices[i]] == x[indices[n - 1]]) { + break; + } + + double S = sum_scores; + double E = (double) n_left / (double) n * sum_all_scores; + double V = (double) n_left * (double) (n - n_left) / (double) (n * (n - 1)) * sum_mean_diff; + double T = fabs((S - E) / sqrt(V)); + + if (T > best_maxstat) { + best_maxstat = T; + + // Use mid-point split if possible + if (i < n - 1) { + best_split_value = (x[indices[i]] + x[indices[i + 1]]) / 2; + } else { + best_split_value = x[indices[i]]; + } + } + } +} + +std::vector numSamplesLeftOfCutpoint(std::vector& x, const std::vector& indices) { + std::vector num_samples_left; + num_samples_left.reserve(x.size()); + + for (size_t i = 0; i < x.size(); ++i) { + if (i == 0) { + num_samples_left.push_back(1); + } else if (x[indices[i]] == x[indices[i - 1]]) { + ++num_samples_left[num_samples_left.size() - 1]; + } else { + num_samples_left.push_back(num_samples_left[num_samples_left.size() - 1] + 1); + } + } + + return num_samples_left; +} + +} // namespace ranger diff --git a/lib/ranger/utility.h b/lib/ranger/utility.h new file mode 100644 index 000000000..578372bc4 --- /dev/null +++ b/lib/ranger/utility.h @@ -0,0 +1,497 @@ 
+/*------------------------------------------------------------------------------- + This file is part of ranger. + + Copyright (c) [2014-2018] [Marvin N. Wright] + + This software may be modified and distributed under the terms of the MIT license. + + Please note that the C++ core of ranger is distributed under MIT license and the + R package "ranger" under GPL3 license. + #-------------------------------------------------------------------------------*/ + +#ifndef UTILITY_H_ +#define UTILITY_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef R_BUILD +#include +#endif + +#include "globals.h" +#include "Data.h" + +namespace ranger { + +/** + * Split sequence start..end in num_parts parts with sizes as equal as possible. + * @param result Result vector of size num_parts+1. Ranges for the parts are then result[0]..result[1]-1, result[1]..result[2]-1, .. + * @param start minimum value + * @param end maximum value + * @param num_parts number of parts + */ +void equalSplit(std::vector& result, uint start, uint end, uint num_parts); + +// #nocov start +/** + * Write a 1d vector to filestream. First the size is written as size_t, then all vector elements. + * @param vector Vector with elements of type T to write to file. + * @param file ofstream object to write to. + */ + +/** + * Write a 1d vector to filestream. First the size is written, then all vector elements. + * @param vector Vector of type T to save + * @param file ofstream object to write to. 
+ */ +template +inline void saveVector1D(const std::vector& vector, std::ofstream& file) { + // Save length + size_t length = vector.size(); + file.write((char*) &length, sizeof(length)); + file.write((char*) vector.data(), length * sizeof(T)); +} + +template<> +inline void saveVector1D(const std::vector& vector, std::ofstream& file) { + // Save length + size_t length = vector.size(); + file.write((char*) &length, sizeof(length)); + + // Save vector + for (size_t i = 0; i < vector.size(); ++i) { + bool v = vector[i]; + file.write((char*) &v, sizeof(v)); + } +} + +/** + * Read a 1d vector written by saveVector1D() from filestream. + * @param result Result vector with elements of type T. + * @param file ifstream object to read from. + */ +template +inline void readVector1D(std::vector& result, std::ifstream& file) { + // Read length + size_t length; + file.read((char*) &length, sizeof(length)); + result.resize(length); + file.read((char*) result.data(), length * sizeof(T)); +} + +template<> +inline void readVector1D(std::vector& result, std::ifstream& file) { + // Read length + size_t length; + file.read((char*) &length, sizeof(length)); + + // Read vector. + for (size_t i = 0; i < length; ++i) { + bool temp; + file.read((char*) &temp, sizeof(temp)); + result.push_back(temp); + } +} + +/** + * Write a 2d vector to filestream. First the size of the first dim is written as size_t, then for all inner vectors the size and elements. + * @param vector Vector of vectors of type T to write to file. + * @param file ofstream object to write to. + */ +template +inline void saveVector2D(const std::vector>& vector, std::ofstream& file) { + // Save length of first dim + size_t length = vector.size(); + file.write((char*) &length, sizeof(length)); + + // Save outer vector + for (auto& inner_vector : vector) { + // Save inner vector + saveVector1D(inner_vector, file); + } +} + +/** + * Read a 2d vector written by saveVector2D() from filestream. 
+ * @param result Result vector of vectors with elements of type T. + * @param file ifstream object to read from. + */ +template +inline void readVector2D(std::vector>& result, std::ifstream& file) { + // Read length of first dim + size_t length; + file.read((char*) &length, sizeof(length)); + result.resize(length); + + // Read outer vector + for (size_t i = 0; i < length; ++i) { + // Read inner vector + readVector1D(result[i], file); + } +} + +/** + * Read a double vector from text file. Reads only the first line. + * @param result Result vector of doubles with contents + * @param filename filename of input file + */ +void loadDoubleVectorFromFile(std::vector& result, std::string filename); +// #nocov end + +/** + * Draw random numbers in a range without replacement and skip values. + * @param result Vector to add results to. Will not be cleaned before filling. + * @param random_number_generator Random number generator + * @param range_length Length of range. Interval to draw from: 0..max-1 + * @param skip Values to skip + * @param num_samples Number of samples to draw + */ +void drawWithoutReplacementSkip(std::vector& result, std::mt19937_64& random_number_generator, + size_t range_length, const std::vector& skip, size_t num_samples); + +/** + * Simple algorithm for sampling without replacement, faster for smaller num_samples + * @param result Vector to add results to. Will not be cleaned before filling. + * @param random_number_generator Random number generator + * @param range_length Length of range. Interval to draw from: 0..max-1 + * @param skip Values to skip + * @param num_samples Number of samples to draw + */ +void drawWithoutReplacementSimple(std::vector& result, std::mt19937_64& random_number_generator, size_t max, + const std::vector& skip, size_t num_samples); + +/** + * Fisher Yates algorithm for sampling without replacement. + * @param result Vector to add results to. Will not be cleaned before filling. 
+ * @param random_number_generator Random number generator + * @param max Length of range. Interval to draw from: 0..max-1 + * @param skip Values to skip + * @param num_samples Number of samples to draw + */ +void drawWithoutReplacementFisherYates(std::vector& result, std::mt19937_64& random_number_generator, + size_t max, const std::vector& skip, size_t num_samples); + +/** + * Draw random numers without replacement and with weighted probabilites from vector of indices. + * @param result Vector to add results to. Will not be cleaned before filling. + * @param random_number_generator Random number generator + * @param indices Vector with numbers to draw + * @param num_samples Number of samples to draw + * @param weights A weight for each element of indices + */ +void drawWithoutReplacementWeighted(std::vector& result, std::mt19937_64& random_number_generator, + const std::vector& indices, size_t num_samples, const std::vector& weights); + +/** + * Draw random numers without replacement and with weighted probabilites from 0..n-1. + * @param result Vector to add results to. Will not be cleaned before filling. + * @param random_number_generator Random number generator + * @param max_index Maximum index to draw + * @param num_samples Number of samples to draw + * @param weights A weight for each element of indices + */ +void drawWithoutReplacementWeighted(std::vector& result, std::mt19937_64& random_number_generator, + size_t max_index, size_t num_samples, const std::vector& weights); + +/** + * Draw random numbers of a vector without replacement. + * @param result Vector to add results to. Will not be cleaned before filling. + * @param input Vector to draw values from. 
+ * @param random_number_generator Random number generator + * @param num_samples Number of samples to draw + */ +template +void drawWithoutReplacementFromVector(std::vector& result, const std::vector& input, + std::mt19937_64& random_number_generator, size_t num_samples) { + + // Draw random indices + std::vector result_idx; + result_idx.reserve(num_samples); + std::vector skip; // Empty vector (no skip) + drawWithoutReplacementSkip(result_idx, random_number_generator, input.size(), skip, num_samples); + + // Add vector values to result + for (auto& idx : result_idx) { + result.push_back(input[idx]); + } +} + +/** + * Returns the most frequent class index of a vector with counts for the classes. Returns a random class if counts are equal. + * @param class_count Vector with class counts + * @param random_number_generator Random number generator + * @return Most frequent class index. Out of range index if all 0. + */ +template +size_t mostFrequentClass(const std::vector& class_count, std::mt19937_64 random_number_generator) { + std::vector major_classes; + +// Find maximum count + T max_count = 0; + for (size_t i = 0; i < class_count.size(); ++i) { + T count = class_count[i]; + if (count > max_count) { + max_count = count; + major_classes.clear(); + major_classes.push_back(i); + } else if (count == max_count) { + major_classes.push_back(i); + } + } + + if (max_count == 0) { + return class_count.size(); + } else if (major_classes.size() == 1) { + return major_classes[0]; + } else { + // Choose randomly + std::uniform_int_distribution unif_dist(0, major_classes.size() - 1); + return major_classes[unif_dist(random_number_generator)]; + } +} + +/** + * Returns the most frequent value of a map with counts for the values. Returns a random class if counts are equal. 
+ * @param class_count Map with classes and counts + * @param random_number_generator Random number generator + * @return Most frequent value + */ +double mostFrequentValue(const std::unordered_map& class_count, + std::mt19937_64 random_number_generator); + +/** + * Compute concordance index for given data and summed cumulative hazard function/estimate + * @param data Reference to Data object + * @param sum_chf Summed chf over timepoints for each sample + * @param dependent_varID ID of dependent variable + * @param status_varID ID of status variable + * @param sample_IDs IDs of samples, for example OOB samples + * @return concordance index + */ +double computeConcordanceIndex(const Data& data, const std::vector& sum_chf, size_t dependent_varID, + size_t status_varID, const std::vector& sample_IDs); + +/** + * Convert a unsigned integer to string + * @param number Number to convert + * @return Converted number as string + */ +std::string uintToString(uint number); + +/** + * Beautify output of time. + * @param seconds Time in seconds + * @return Time in days, hours, minutes and seconds as string + */ +std::string beautifyTime(uint seconds); + +/** + * Round up to next multiple of a number. + * @param value Value to be rounded up. + * @param multiple Number to multiply. + * @return Rounded number + */ +size_t roundToNextMultiple(size_t value, uint multiple); + +/** + * Split string in parts separated by character. + * @param result Splitted string + * @param input String to be splitted + * @param split_char Char to separate parts + */ +void splitString(std::vector& result, const std::string& input, char split_char); + +/** + * Create numbers from 0 to n_all-1, shuffle and split in two parts. 
+ * @param first_part First part + * @param second_part Second part + * @param n_all Number elements + * @param n_first Number of elements of first part + * @param random_number_generator Random number generator + */ +void shuffleAndSplit(std::vector& first_part, std::vector& second_part, size_t n_all, size_t n_first, + std::mt19937_64 random_number_generator); + +/** + * Create numbers from 0 to n_all-1, shuffle and split in two parts. Append to existing data. + * @param first_part First part + * @param second_part Second part + * @param n_all Number elements + * @param n_first Number of elements of first part + * @param mapping Values to use instead of 0...n-1 + * @param random_number_generator Random number generator + */ +void shuffleAndSplitAppend(std::vector& first_part, std::vector& second_part, size_t n_all, + size_t n_first, const std::vector& mapping, std::mt19937_64 random_number_generator); + +/** + * Check if not too many factor levels and all values in unordered categorical variables are positive integers. + * @param data Reference to data object + * @param unordered_variable_names Names of unordered variables + * @return Error message, empty if no problem occured + */ +std::string checkUnorderedVariables(const Data& data, const std::vector& unordered_variable_names); + +/** + * Check if all values in double vector are positive integers. + * @param all_values Double vector to check + * @return True if all values are positive integers + */ +bool checkPositiveIntegers(const std::vector& all_values); + +/** + * Compute p-value for maximally selected rank statistics using Lau92 approximation + * See Lausen, B. & Schumacher, M. (1992). Biometrics 48, 73-85. 
+ * @param b Quantile + * @param minprop Minimal proportion of observations left of cutpoint + * @param maxprop Maximal proportion of observations left of cutpoint + * @return p-value for quantile b + */ +double maxstatPValueLau92(double b, double minprop, double maxprop); + +/** + * Compute p-value for maximally selected rank statistics using Lau94 approximation + * See Lausen, B., Sauerbrei, W. & Schumacher, M. (1994). Computational Statistics. 483-496. + * @param b Quantile + * @param minprop Minimal proportion of observations left of cutpoint + * @param maxprop Maximal proportion of observations left of cutpoint + * @param N Number of observations + * @param m Vector with number of observations smaller or equal than cutpoint, sorted, only for unique cutpoints + * @return p-value for quantile b + */ +double maxstatPValueLau94(double b, double minprop, double maxprop, size_t N, const std::vector& m); + +/** + * Compute unadjusted p-value for rank statistics + * @param b Quantile + * @return p-value for quantile b + */ +double maxstatPValueUnadjusted(double b); + +/** + * Standard normal density + * @param x Quantile + * @return Standard normal density at quantile x + */ +double dstdnorm(double x); + +/** + * Standard normal distribution + * @param x Quantile + * @return Standard normal distribution at quantile x + */ +double pstdnorm(double x); + +/** + * Adjust p-values with Benjamini/Hochberg + * @param unadjusted_pvalues Unadjusted p-values (input) + * @param adjusted_pvalues Adjusted p-values (result) + */ +std::vector adjustPvalues(std::vector& unadjusted_pvalues); + +/** + * Get indices of sorted values + * @param values Values to sort + * @param decreasing Order decreasing + * @return Indices of sorted values + */ +template +std::vector order(const std::vector& values, bool decreasing) { +// Create index vector + std::vector indices(values.size()); + std::iota(indices.begin(), indices.end(), 0); + +// Sort index vector based on value vector + if 
(decreasing) { + std::sort(std::begin(indices), std::end(indices), [&](size_t i1, size_t i2) {return values[i1] > values[i2];}); + } else { + std::sort(std::begin(indices), std::end(indices), [&](size_t i1, size_t i2) {return values[i1] < values[i2];}); + } + return indices; +} + +/** + * Sample ranks starting from 1. Ties are given the average rank. + * @param values Values to rank + * @return Ranks of input values + */ +template +std::vector rank(const std::vector& values) { + size_t num_values = values.size(); + +// Order + std::vector indices = order(values, false); + +// Compute ranks, start at 1 + std::vector ranks(num_values); + size_t reps = 1; + for (size_t i = 0; i < num_values; i += reps) { + + // Find number of replications + reps = 1; + while (i + reps < num_values && values[indices[i]] == values[indices[i + reps]]) { + ++reps; + } + + // Assign rank to all replications + for (size_t j = 0; j < reps; ++j) + ranks[indices[i + j]] = (2 * (double) i + (double) reps - 1) / 2 + 1; + } + + return ranks; +} + +/** + * Compute Logrank scores for survival times + * @param time Survival time + * @param status Censoring indicator + * @return Logrank scores + */ +std::vector logrankScores(const std::vector& time, const std::vector& status); + +/** + * Compute maximally selected rank statistics + * @param scores Scores for dependent variable (y) + * @param x Independent variable + * @param indices Ordering of x values + * @param best_maxstat Maximally selected statistic (output) + * @param best_split_value Split value for maximally selected statistic (output) + * @param minprop Minimal proportion of observations left of cutpoint + * @param maxprop Maximal proportion of observations left of cutpoint + */ +void maxstat(const std::vector& scores, const std::vector& x, const std::vector& indices, + double& best_maxstat, double& best_split_value, double minprop, double maxprop); + +/** + * Compute number of samples smaller or equal than each unique value in x + * @param 
x Value vector + * @param indices Ordering of x + * @return Vector of number of samples smaller or equal than each unique value in x + */ +std::vector numSamplesLeftOfCutpoint(std::vector& x, const std::vector& indices); + +// User interrupt from R +#ifdef R_BUILD +static void chkIntFn(void *dummy) { + R_CheckUserInterrupt(); +} + +inline bool checkInterrupt() { + return (R_ToplevelExec(chkIntFn, NULL) == FALSE); +} +#endif + +} + // namespace ranger + +#endif /* UTILITY_H_ */ diff --git a/lib/tandem/tandem.cpp b/lib/tandem/tandem.cpp index 8f206d51e..ac41a77dc 100755 --- a/lib/tandem/tandem.cpp +++ b/lib/tandem/tandem.cpp @@ -1,24 +1,10 @@ -/* MIT License +/* tandem.cpp - Copyright (c) 2017 Daniel Cooke + Copyright (C) 2017-2018 University of Oxford. - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: + Author: Daniel Cooke - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. */ + Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
*/ #include "tandem.hpp" @@ -115,37 +101,38 @@ make_lpf_and_prev_occ_arrays(std::vector sa, std::vector> +get_init_buckets(const std::size_t n, const LMRVector& lmrs) { - std::vector> - get_init_buckets(const std::size_t n, const std::deque& lmrs) - { - std::vector counts(n, 0); - for (const auto& run : lmrs) { - ++counts[run.pos + run.length - 1]; - } - std::vector> result(n, std::vector {}); - for (std::size_t i {0}; i < n; ++i) { - result[i].reserve(counts[i]); - } - return result; + std::vector counts(n, 0); + for (const auto& run : lmrs) { + ++counts[run.pos + run.length - 1]; } - - std::vector> - get_init_buckets(const std::size_t n, const std::vector>& end_buckets) - { - std::vector counts(n, 0); - for (const auto& bucket : end_buckets) { - for (const auto& run : bucket) { - ++counts[run.pos]; - } - } - std::vector> result(n, std::vector {}); - for (std::size_t i {0}; i < n; ++i) { - result[i].reserve(counts[i]); + std::vector> result(n, std::vector {}); + for (std::size_t i {0}; i < n; ++i) { + result[i].reserve(counts[i]); + } + return result; +} + +std::vector> +get_init_buckets(const std::size_t n, const std::vector>& end_buckets) +{ + std::vector counts(n, 0); + for (const auto& bucket : end_buckets) { + for (const auto& run : bucket) { + ++counts[run.pos]; } - return result; } + std::vector> result(n, std::vector {}); + for (std::size_t i {0}; i < n; ++i) { + result[i].reserve(counts[i]); + } + return result; +} + } // namespace detail void rebase(std::vector& runs, const std::map& shift_map) diff --git a/lib/tandem/tandem.hpp b/lib/tandem/tandem.hpp index 3562d4706..1277494db 100755 --- a/lib/tandem/tandem.hpp +++ b/lib/tandem/tandem.hpp @@ -1,6 +1,6 @@ /* MIT License - Copyright (c) 2017 Daniel Cooke + Copyright (c) 2017-2018 Daniel Cooke Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal @@ -148,11 +148,8 @@ make_lcp_array(const T& str, const 
std::vector& suffix_array, const size_t extra_capacity = 0) { static_assert(std::is_integral::value, "Integer required"); - const auto rank = make_rank_array(suffix_array, extra_capacity); - std::vector result(suffix_array.size() + extra_capacity); - for (I i {0}, h {0}; i < (suffix_array.size() - extra_capacity); ++i) { if (rank[i] > 0) { h += detail::forward_lce(str, i + h, suffix_array[rank[i] - 1] + h); @@ -160,7 +157,6 @@ make_lcp_array(const T& str, const std::vector& suffix_array, if (h > 0) --h; } } - return result; } @@ -198,45 +194,39 @@ template std::vector lempel_ziv_factorisation(const T& str) { if (str.empty()) return {}; - const auto lpf = make_lpf_array(str); - std::vector result {}; result.reserve(str.size()); // max possible blocks std::uint32_t end {1}; // start at 1 because the first element of lpf is sentinel result.emplace_back(0, end); - while (end < str.size()) { const auto m = std::max(std::uint32_t {1}, lpf[end]); result.emplace_back(end, m); end += m; } - - result.shrink_to_fit(); - return result; } // Implementation of algorithm found in Crochemore et al. 
(2008) template -std::pair, std::vector> +std::pair, std::vector> lempel_ziv_factorisation_with_prev_block_occurences(const T& str) { if (str.empty()) return {{}, {}}; - std::vector lpf, prev_occ; + std::vector lpf, prev_occ; std::tie(lpf, prev_occ) = make_lpf_and_prev_occ_arrays(str); std::vector lz_blocks {}; lz_blocks.reserve(str.size()); // max possible blocks - std::vector prev_lz_block_occurrence {}; + std::vector prev_lz_block_occurrence {}; prev_lz_block_occurrence.reserve(str.size()); std::uint32_t end {1}; // start at 1 because the first element of lpf is sentinel lz_blocks.emplace_back(0, end); - prev_lz_block_occurrence.emplace_back(-1); + prev_lz_block_occurrence.push_back(std::numeric_limits::max()); while (end < str.size()) { - const auto m = std::max(uint32_t {1}, lpf[end]); + const auto m = std::max(std::uint32_t {1}, lpf[end]); lz_blocks.emplace_back(end, m); prev_lz_block_occurrence.emplace_back(prev_occ[end]); end += m; @@ -250,43 +240,52 @@ lempel_ziv_factorisation_with_prev_block_occurences(const T& str) namespace detail { +using LMRVector = std::deque; + +template +void add_maximal_periodicities(const T& str, const LZBlock& prev_block, const LZBlock& block, + const std::uint32_t min_period, const std::uint32_t max_period, + LMRVector& result) +{ + const auto u = block.pos; + const auto n = block.length; + const auto m = std::min(u, 2 * prev_block.length + n); + const auto t = u - m; + const auto end = u + n; + // rightmax periodicities + for (auto j = min_period; j <= std::min(n, max_period); ++j) { + const auto ls = backward_lce(str, u - 1, u + j - 1, t); + const auto lp = forward_lce(str, u + j, u, end); + if (ls + lp >= j && j + lp < n) { + result.emplace_back(u - ls, j + lp + ls, j); + } + } + // leftmax periodicities + for (auto j = min_period; j < std::min(m, max_period); ++j) { + const auto ls = backward_lce(str, u - j - 1, u - 1, t); + const auto lp = forward_lce(str, u, u - j, end); + if (ls + lp >= j) { + result.emplace_back(u - (ls 
+ j), j + lp + ls, j); + } + } +} + // Implements Mains algorithm found in Main (1989). Obscure notation as in paper. template -std::deque +auto find_leftmost_maximal_repetitions(const T& str, const std::vector& lz_blocks, - const std::uint32_t min_period = 1, const std::uint32_t max_period = -1) + const std::uint32_t min_period = 1, + const std::uint32_t max_period = std::numeric_limits::max()) { - std::deque result {}; - + LMRVector result {}; for (std::size_t h {1}; h < lz_blocks.size(); ++h) { - const auto u = lz_blocks[h].pos; - const auto n = lz_blocks[h].length; - const auto m = std::min(u, 2 * lz_blocks[h - 1].length + n); - const auto t = u - m; - const auto end = u + n; - // rightmax periodicities - for (auto j = min_period; j <= std::min(n, max_period); ++j) { - const auto ls = detail::backward_lce(str, u - 1, u + j - 1, t); - const auto lp = detail::forward_lce(str, u + j, u, end); - if (ls > 0 && ls + lp >= j && j + lp < n) { - result.emplace_back(u - ls, j + lp + ls, j); - } - } - // leftmax periodicities - for (auto j = min_period; j < std::min(m, max_period); ++j) { - const auto ls = detail::backward_lce(str, u - j - 1, u - 1, t); - const auto lp = detail::forward_lce(str, u, u - j, end); - if (ls + lp >= j) { - result.emplace_back(u - (ls + j), j + lp + ls, j); - } - } + add_maximal_periodicities(str, lz_blocks[h - 1], lz_blocks[h], min_period, max_period, result); } - return result; } // just reserves enough space to avoid reallocations -std::vector> get_init_buckets(std::size_t n, const std::deque& lmrs); +std::vector> get_init_buckets(std::size_t n, const LMRVector& lmrs); template std::vector> @@ -333,18 +332,21 @@ extract_maximal_repetitions(const T& str, const std::uint32_t min_period, const std::vector lz_blocks; std::vector prev_lz_block_occurrence; std::tie(lz_blocks, prev_lz_block_occurrence) = lempel_ziv_factorisation_with_prev_block_occurences(str); - auto sorted_buckets = get_sorted_buckets(str, lz_blocks, min_period, max_period); - 
for (std::size_t k {0}; k < lz_blocks.size(); ++k) { const auto& block = lz_blocks[k]; const auto block_end = block.pos + block.length; static constexpr auto sentinal = std::numeric_limits::max(); const auto delta = block.pos - ((prev_lz_block_occurrence[k] != sentinal) ? prev_lz_block_occurrence[k] : 0); - const auto v = block_end - delta; + const auto max_target_end = block_end - delta; for (auto j = block.pos; j < block_end; ++j) { - const auto& target = sorted_buckets[j - delta]; - const auto last_target_itr = std::lower_bound(std::cbegin(target), std::cend(target), v, + const auto target_start = j - delta; + const auto& target = sorted_buckets[target_start]; + auto target_end = max_target_end; + if (!sorted_buckets[j].empty()) { + target_end = std::min(target_start + sorted_buckets[j].front().length, max_target_end); + } + const auto last_target_itr = std::lower_bound(std::cbegin(target), std::cend(target), target_end, [] (const auto& run, const auto val) noexcept { return run.pos + run.length < val; }); @@ -360,7 +362,6 @@ extract_maximal_repetitions(const T& str, const std::uint32_t min_period, const } } } - return sorted_buckets; } @@ -373,13 +374,27 @@ auto count_runs(const std::vector& buckets) noexcept }); } +template +std::vector +extract_exact_tandem_repeats_lz(const T& str, const std::uint32_t min_period, const std::uint32_t max_period) +{ + auto sorted_buckets = detail::extract_maximal_repetitions(str, min_period, max_period); + std::vector result {}; + result.reserve(detail::count_runs(sorted_buckets)); + for (auto& bucket : sorted_buckets) { + result.insert(std::end(result), std::cbegin(bucket), std::cend(bucket)); + bucket.clear(); + bucket.shrink_to_fit(); + } + return result; +} + template std::vector extract_homopolymers(const ForwardIt first, const ForwardIt last, const std::size_t reserve_hint = 0) { std::vector result {}; result.reserve(reserve_hint); - for (auto curr = first; curr != last; ) { const auto it = std::adjacent_find(curr, last); 
if (it == last) break; @@ -390,9 +405,6 @@ extract_homopolymers(const ForwardIt first, const ForwardIt last, const std::siz std::uint32_t {1}); curr = it2; } - - result.shrink_to_fit(); - return result; } @@ -401,13 +413,11 @@ std::vector extract_exact_tandem_repeats(const ForwardIt first, const ForwardIt last) { std::vector result {}; - const auto length = static_cast(std::distance(first, last)); if (length < 2 * N) return result; auto it1 = std::adjacent_find(first, last, std::not_equal_to {}); if (it1 == last) return result; result.reserve(std::min(length / N, std::size_t {1024})); - for (auto it2 = std::next(it1, N); it2 < last; ) { const auto p = std::mismatch(it2, last, it1); if (p.second >= it2) { @@ -422,9 +432,6 @@ extract_exact_tandem_repeats(const ForwardIt first, const ForwardIt last) if (it1 == last) break; it2 = std::next(it1, N); } - - result.shrink_to_fit(); - return result; } @@ -458,53 +465,51 @@ void merge(Container2&& src, Container1& dst) }); } +template +std::vector +extract_exact_tandem_repeats_naive(const T& str, std::uint32_t min_period, const std::uint32_t max_period) +{ + assert(max_period <= 3); + if (min_period == max_period) { + switch(min_period) { + case 1: return detail::extract_homopolymers(str); + case 2: return detail::extract_exact_dinucleotide_tandem_repeats(str); + case 3: return detail::extract_exact_trinucleotide_tandem_repeats(str); + } + } + using detail::merge; + if (min_period == 1) { // known max_period >= 2 + auto result = detail::extract_homopolymers(str); + merge(detail::extract_exact_dinucleotide_tandem_repeats(str), result); + if (max_period == 3) { + merge(detail::extract_exact_trinucleotide_tandem_repeats(str), result); + } + return result; + } else { // min_period == 2 && max_period == 3 + auto result = detail::extract_exact_dinucleotide_tandem_repeats(str); + merge(detail::extract_exact_trinucleotide_tandem_repeats(str), result); + return result; + } +} + } // namespace detail template std::vector 
-extract_exact_tandem_repeats(const T& str, std::uint32_t min_period = 1, const std::uint32_t max_period = -1) +extract_exact_tandem_repeats(const T& str, + std::uint32_t min_period = 1, + const std::uint32_t max_period = std::numeric_limits::max()) { if (min_period == 0) ++min_period; if (str.empty() || str.size() < min_period) return {}; if (min_period > max_period) { throw std::domain_error {"find_maximal_repetitions: given unsatisfiable condition min_period > max_period"}; } - if (max_period <= 3) { // The naive algorithm is faster in these cases - if (min_period == max_period) { - switch(min_period) { - case 1: return detail::extract_homopolymers(str); - case 2: return detail::extract_exact_dinucleotide_tandem_repeats(str); - case 3: return detail::extract_exact_trinucleotide_tandem_repeats(str); - } - } - using detail::merge; - if (min_period == 1) { // known max_period >= 2 - auto result = detail::extract_homopolymers(str); - merge(detail::extract_exact_dinucleotide_tandem_repeats(str), result); - if (max_period == 3) { - merge(detail::extract_exact_trinucleotide_tandem_repeats(str), result); - } - return result; - } else { // min_period == 2 && max_period == 3 - auto result = detail::extract_exact_dinucleotide_tandem_repeats(str); - merge(detail::extract_exact_trinucleotide_tandem_repeats(str), result); - return result; - } - } - - auto sorted_buckets = detail::extract_maximal_repetitions(str, min_period, max_period); - - std::vector result {}; - result.reserve(detail::count_runs(sorted_buckets)); - - for (auto& bucket : sorted_buckets) { - result.insert(std::end(result), std::cbegin(bucket), std::cend(bucket)); - bucket.clear(); - bucket.shrink_to_fit(); + return detail::extract_exact_tandem_repeats_naive(str, min_period, max_period); + } else { + return detail::extract_exact_tandem_repeats_lz(str, min_period, max_period); } - - return result; } /** @@ -526,7 +531,6 @@ std::map collapse(SequenceType& sequence, const char c std::map result {}; const auto 
last = std::end(sequence); std::size_t position {0}, num_removed {0}; - for (auto first = std::begin(sequence); first != last;) { const auto it1 = std::adjacent_find(first, last, [c] (const char lhs, const char rhs) noexcept { @@ -539,7 +543,6 @@ std::map collapse(SequenceType& sequence, const char c result.emplace(position, num_removed); first = it2; } - if (!result.empty()) { sequence.erase(std::unique(std::next(std::begin(sequence), std::cbegin(result)->first), last, [c] (const char lhs, const char rhs) noexcept { @@ -547,7 +550,6 @@ std::map collapse(SequenceType& sequence, const char c }), last); } - return result; } diff --git a/scripts/install.py b/scripts/install.py new file mode 100755 index 000000000..fe2bff8a7 --- /dev/null +++ b/scripts/install.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 + +import os +import os.path +import sys +from subprocess import call +import platform +import argparse +from shutil import move, rmtree +import multiprocessing +import urllib.request + +google_cloud_octopus_base = "https://storage.googleapis.com/luntergroup/octopus" +forest_url_base = os.path.join(google_cloud_octopus_base, "forests") +forests = ['germline', 'somatic'] + +def is_unix(): + system = platform.system() + return system == "Darwin" or system == "Linux" + +def download_file(url, file_name): + urllib.request.urlretrieve(url, file_name) + +def download_forests(forest_dir): + if not os.path.exists(forest_dir): + print("No forest directory found, making one") + os.makedirs(forest_dir) + for forest in forests: + forest_name = forest + '.forest' + forest_url = os.path.join(forest_url_base, forest_name) + forest_file = os.path.join(forest_dir, forest_name) + try: + print("Downloading " + forest_url + " to " + forest_file) + download_file(forest_url, forest_file) + except: + print("Failed to download forest " + forest_name) + +def main(args): + script_dir = os.path.dirname(os.path.realpath(__file__)) + octopus_dir = os.path.dirname(script_dir) + root_cmake = 
os.path.join(octopus_dir, "CMakeLists.txt") + + if not os.path.exists(root_cmake): + print("octopus source directory corrupted: root CMakeLists.txt is missing. Please re-download source code.") + exit(1) + + octopus_build_dir = os.path.join(octopus_dir, "build") + + if not os.path.exists(octopus_build_dir): + print("octopus source directory corrupted: build directory is missing. Please re-download source code.") + exit(1) + + bin_dir = os.path.join(octopus_dir, "bin") + + if not os.path.exists(bin_dir): + print("No bin directory found, making one") + os.makedirs(bin_dir) + + if args["clean"]: + print("Cleaning build directory") + move(os.path.join(octopus_build_dir, "cmake"), os.path.join(octopus_dir, "cmake")) + rmtree(octopus_build_dir) + os.makedirs(octopus_build_dir) + move(os.path.join(octopus_dir, "cmake"), os.path.join(octopus_build_dir, "cmake")) + + cmake_cache_file = "CMakeCache.txt" + os.chdir(octopus_build_dir) # so cmake doesn't pollute root directory + + if not args["keep_cache"] and os.path.exists(cmake_cache_file): + os.remove(cmake_cache_file) + + cmake_options = [] + if args["root"]: + cmake_options.extend(["-DINSTALL_ROOT=ON", octopus_dir]) + if args["c_compiler"]: + cmake_options.append("-DCMAKE_C_COMPILER=" + args["c_compiler"]) + if args["cxx_compiler"]: + cmake_options.append("-DCMAKE_CXX_COMPILER=" + args["cxx_compiler"]) + if args["debug"]: + cmake_options.append("-DCMAKE_BUILD_TYPE=Debug") + elif args["sanitize"]: + cmake_options.append("-DCMAKE_BUILD_TYPE=RelWithDebInfo") + else: + cmake_options.append("-DCMAKE_BUILD_TYPE=Release") + if args["static"]: + cmake_options.append("-DBUILD_SHARED_LIBS=OFF") + if args["boost"]: + cmake_options.append("-DBOOST_ROOT=" + args["boost"]) + if args["verbose"]: + cmake_options.append("CMAKE_VERBOSE_MAKEFILE:BOOL=ON") + + ret = call(["cmake"] + cmake_options + [".."]) + + if ret == 0: + make_options = [] + if args["threads"]: + if (args["threads"] > 1): + make_options.append("-j" + str(args["threads"])) 
+ else: + make_options.append("-j" + str(multiprocessing.cpu_count())) + + if is_unix(): + ret = call(["make", "install"] + make_options) + else: + print("Windows make files not supported. Build files have been written to " + octopus_build_dir) + + if args["download"]: + if len(forests) > 0: + forest_dir = os.path.join(octopus_dir, "resources/forests") + download_forests(forest_dir) + + sys.exit(ret) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--clean', + help='Do a clean install', + action='store_true') + parser.add_argument('--root', + help='Install into /usr/local/bin', + action='store_true') + parser.add_argument('-c', '--c_compiler', + help='C compiler path to use') + parser.add_argument('-cxx', '--cxx_compiler', + help='C++ compiler path to use') + parser.add_argument('--keep_cache', + help='Do not refresh CMake cache', + action='store_true') + parser.add_argument('--debug', + help='Builds in debug mode', + action='store_true') + parser.add_argument('--sanitize', + help='Builds in release mode with sanitize flags', + action='store_true') + parser.add_argument('--static', + help='Builds using static libraries', + action='store_true') + parser.add_argument('--threads', + help='The number of threads to use for building', + type=int) + parser.add_argument('--boost', + help='The Boost library root') + parser.add_argument('--download', + help='Try to download octopus classifiers', + action='store_true') + parser.add_argument('--verbose', + help='Ouput verbose make information', + action='store_true') + args = vars(parser.parse_args()) + main(args) diff --git a/scripts/train_random_forest.py b/scripts/train_random_forest.py new file mode 100755 index 000000000..5efc0d41f --- /dev/null +++ b/scripts/train_random_forest.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 + +import argparse +from os import makedirs, remove +from os.path import join, basename, exists +from subprocess import call +import csv +from pysam import 
VariantFile +import random +import numpy as np + +def run_octopus(octopus, ref_path, bam_path, regions_bed, measures, threads, out_path): + call([octopus, '-R', ref_path, '-I', bam_path, '-t', regions_bed, '-o', out_path, '--threads', str(threads), + '--legacy', '--csr-train'] + measures) + +def get_reference_id(ref_path): + return basename(ref_path).replace(".fasta", "") + +def get_bam_id(bam_path): + return basename(bam_path).replace(".bam", "") + +def call_variants(octopus, ref_path, bam_path, regions_bed, measures, threads, out_dir): + ref_id = get_reference_id(ref_path) + bam_id = get_bam_id(bam_path) + out_vcf = join(out_dir, "octopus." + bam_id + "." + ref_id + ".vcf.gz") + run_octopus(octopus, ref_path, bam_path, regions_bed, measures, threads, out_vcf) + legacy_vcf = out_vcf.replace(".vcf.gz", ".legacy.vcf.gz") + return legacy_vcf + +def run_rtg(rtg, rtg_ref_path, truth_vcf_path, confident_bed_path, octopus_vcf_path, out_dir): + call([rtg, 'vcfeval', '-b', truth_vcf_path, '-t', rtg_ref_path, '--evaluation-regions', confident_bed_path, + '--ref-overlap', '-c', octopus_vcf_path, '-o', out_dir]) + +def eval_octopus(octopus, ref_path, bam_path, regions_bed, measures, threads, + rtg, rtg_ref_path, truth_vcf_path, confident_bed_path, out_dir): + octopus_vcf = call_variants(octopus, ref_path, bam_path, regions_bed, measures, threads, out_dir) + rtf_eval_dir = join(out_dir, basename(octopus_vcf).replace(".legacy.vcf.gz", ".eval")) + run_rtg(rtg, rtg_ref_path, truth_vcf_path, confident_bed_path, octopus_vcf, rtf_eval_dir) + return rtf_eval_dir + +def is_missing(x): + if x == '.': + return True + x = float(x) + return np.isnan(x) + +def to_str(x, missing_value): + if is_missing(x): + return str(missing_value) + else: + return str(x) + +def get_field(field, rec, missing_value): + if field == 'QUAL': + return to_str(rec.qual, missing_value) + elif field == 'GQ': + return to_str(rec.samples[0]['GQ'], missing_value) + else: + val = rec.info[field] + if type(val) == 
tuple: + val = val[0] + return to_str(val, missing_value) + +def subset(vcf_in_path, vcf_out_path, bed_regions): + call(['bcftools', 'view', '-R', bed_regions, '-O', 'z', '-o', vcf_out_path, vcf_in_path]) + +def make_ranger_data(octopus_vcf_path, measures, is_tp, out, missing_value): + vcf = VariantFile(octopus_vcf_path) + with open(out, 'w') as ranger_dat: + datwriter = csv.writer(ranger_dat, delimiter=' ') + for rec in vcf: + row = [get_field(measure, rec, missing_value) for measure in measures] + if is_tp: + row.append('1') + else: + row.append('0') + datwriter.writerow(row) + +def concat(filenames, outpath): + with open(outpath, 'w') as outfile: + for fname in filenames: + with open(fname) as infile: + for line in infile: + outfile.write(line) + +def shuffle(fname): + lines = open(fname).readlines() + random.shuffle(lines) + open(fname, 'w').writelines(lines) + +def add_header(fname, header): + lines = open(fname).readlines() + with open(fname, 'w') as f: + f.write(header + '\n') + f.writelines(lines) + +def run_ranger_training(ranger, data_path, n_trees, min_node_size, threads, out): + call([ranger, '--file', data_path, '--depvarname', 'TP', '--probability', + '--ntree', str(n_trees), '--targetpartitionsize', str(min_node_size), + '--nthreads', str(threads), '--outprefix', out, '--write', '--verbose']) + +def main(options): + if not exists(options.out): + makedirs(options.out) + rtg_eval_dirs = [] + for bam_path in options.reads: + rtg_eval_dirs.append(eval_octopus(options.octopus, options.reference, bam_path, options.regions, options.measures, + options.threads, options.rtg, options.sdf, options.truth, options.confident, + options.out)) + data_paths = [] + tmp_paths = [] + for rtg_eval in rtg_eval_dirs: + tp_vcf_path = join(rtg_eval, "tp.vcf.gz") + tp_train_vcf_path = tp_vcf_path.replace("tp.vcf", "tp.train.vcf") + subset(tp_vcf_path, tp_train_vcf_path, options.regions) + tp_data_path = tp_train_vcf_path.replace(".vcf.gz", ".dat") + 
make_ranger_data(tp_train_vcf_path, options.measures, True, tp_data_path, options.missing_value) + data_paths.append(tp_data_path) + fp_vcf_path = join(rtg_eval, "fp.vcf.gz") + fp_train_vcf_path = fp_vcf_path.replace("fp.vcf", "fp.train.vcf") + subset(fp_vcf_path, fp_train_vcf_path, options.regions) + fp_data_path = fp_train_vcf_path.replace(".vcf.gz", ".dat") + make_ranger_data(fp_train_vcf_path, options.measures, False, fp_data_path, options.missing_value) + data_paths.append(fp_data_path) + tmp_paths += [tp_train_vcf_path, fp_train_vcf_path] + master_data_path = join(options.out, options.prefix + ".dat") + concat(data_paths, master_data_path) + for path in data_paths: + remove(path) + for path in tmp_paths: + remove(path) + shuffle(master_data_path) + ranger_header = ' '.join(options.measures + ['TP']) + add_header(master_data_path, ranger_header) + ranger_out_prefix = join(options.out, options.prefix) + run_ranger_training(options.ranger, master_data_path, options.trees, options.min_node_size, options.threads, ranger_out_prefix) + remove(ranger_out_prefix + ".confusion") + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-R', '--reference', + type=str, + required=True, + help='Reference to use for calling') + parser.add_argument('-I', '--reads', + nargs='+', + type=str, + required=True, + help='Input BAM files') + parser.add_argument('-T', '--regions', + type=str, + required=True, + help='BED files containing regions to call') + parser.add_argument('--measures', + type=str, + nargs='+', + required=True, + help='Measures to use for training') + parser.add_argument('--truth', + type=str, + required=True, + help='Truth VCF file') + parser.add_argument('--confident', + type=str, + required=True, + help='BED files containing high confidence truth regions') + parser.add_argument('--octopus', + type=str, + required=True, + help='Octopus binary') + parser.add_argument('--rtg', + type=str, + required=True, + help='RTG Tools 
binary') + parser.add_argument('--sdf', + type=str, + required=True, + help='RTG Tools SDF reference index') + parser.add_argument('--ranger', + type=str, + required=True, + help='Ranger binary') + parser.add_argument('--trees', + type=int, + default=300, + help='Number of trees to use in the random forest') + parser.add_argument('--min_node_size', + type=int, + default=20, + help='Node size to stop growing trees, implicitly limiting tree depth') + parser.add_argument('-o', '--out', + type=str, + help='Output directory') + parser.add_argument('--prefix', + type=str, + default='ranger_octopus', + help='Output files prefix') + parser.add_argument('-t', '--threads', + type=int, + default=1, + help='Number of threads for octopus') + parser.add_argument('--missing_value', + type=float, + default=-1, + help='Value for missing measures') + parsed, unparsed = parser.parse_known_args() + main(parsed) diff --git a/scripts/train_somatic_random_forest.py b/scripts/train_somatic_random_forest.py new file mode 100755 index 000000000..63a8e0140 --- /dev/null +++ b/scripts/train_somatic_random_forest.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 + +import argparse +from os import makedirs, remove +from os.path import join, basename, exists +from subprocess import call +import csv +from pysam import VariantFile +import random +import numpy as np + +def is_homozygous(gt): + return all(a == 1 for a in gt) + +def to_float(val): + if val == '.': + return np.nan + else: + try: + return float(val) + except ValueError: + return val + +def get_value(rec, feature): + if feature == "POS": + return rec.pos + elif feature == "REF": + return rec.ref + elif feature == "ALT": + return list(rec.alts) + elif feature == "QUAL": + return rec.qual + elif feature == 'AS': + return max([len(allele) for allele in rec.alleles]) + elif feature in rec.info: + val = rec.info[feature] + if type(val) == tuple: + return to_float(val[0]), to_float(val[1]) + else: + return to_float(val) + elif feature in 
rec.samples[0]: + if feature == 'GT': + return int(is_homozygous(rec.samples[0][feature])) + else: + return int(rec.samples[0][feature]) + else: + return np.nan + +def is_somatic(rec): + return any(get_value(rec, 'SOMATIC')) + +def filter_somatic(in_vcf_path, out_vcf_path): + in_vcf = VariantFile(in_vcf_path) + out_vcf = VariantFile(out_vcf_path, 'w', header=in_vcf.header) + num_skipped_records = 0 + for rec in in_vcf: + if is_somatic(rec): + try: + out_vcf.write(rec) + except OSError: + num_skipped_records += 1 + print("Skipped " + str(num_skipped_records) + " bad records") + in_vcf.close() + out_vcf.close() + +def classify_calls(somatic_vcf, truth_vcf, temp_dir, regions_bed=None): + tp_cmd = ['bcftools', 'isec'] + if regions_bed is not None: + tp_cmd += ['-T', regions_bed] + tp_cmd += ['-n', '=2', '-w', '1'] + tp_vcf = join(temp_dir, 'tp.vcf.gz') + tp_cmd += ['-Oz', '-o', tp_vcf] + tp_cmd += [somatic_vcf, truth_vcf] + call(tp_cmd) + fp_cmd = ['bcftools', 'isec', '-C'] + if regions_bed is not None: + fp_cmd += ['-T', regions_bed] + fp_cmd += ['-n', '=1', '-w', '1'] + fp_vcf = join(temp_dir, 'fp.vcf.gz') + fp_cmd += ['-Oz', '-o', fp_vcf] + fp_cmd += [somatic_vcf, truth_vcf] + call(fp_cmd) + return tp_vcf, fp_vcf + +def subset(vcf_in_path, vcf_out_path, bed_regions): + call(['bcftools', 'view', '-R', bed_regions, '-O', 'z', '-o', vcf_out_path, vcf_in_path]) + +def is_missing(x): + if x == '.': + return True + x = float(x) + return np.isnan(x) + +def to_str(x, missing_value): + if is_missing(x): + return str(missing_value) + else: + return str(x) + +def get_data(rec, features, n_samples, missing_value): + result = [[] for _ in range(n_samples)] + for feature in features: + value = get_value(rec, feature) + if type(value) == tuple: + assert len(value) == n_samples + result = [curr + [to_str(v, missing_value)] for curr, v in zip(result, value)] + else: + value_str = to_str(value, missing_value) + for d in result: + d.append(value_str) + return result + +def 
make_ranger_data(octopus_vcf_path, measures, is_tp, out, missing_value): + vcf = VariantFile(octopus_vcf_path) + n_samples = len(vcf.header.samples) + n_records = 0 + with open(out, 'a') as ranger_dat: + datwriter = csv.writer(ranger_dat, delimiter=' ') + for rec in vcf: + for row in get_data(rec, measures, n_samples, missing_value): + if is_tp: + row.append('1') + else: + row.append('0') + datwriter.writerow(row) + n_records += 1 + return n_records + +def concat(filenames, outpath): + with open(outpath, 'w') as outfile: + for fname in filenames: + with open(fname) as infile: + for line in infile: + outfile.write(line) + +def shuffle(fname): + lines = open(fname).readlines() + random.shuffle(lines) + open(fname, 'w').writelines(lines) + +def add_header(fname, header): + lines = open(fname).readlines() + with open(fname, 'w') as f: + f.write(header + '\n') + f.writelines(lines) + +def run_ranger_training(ranger, data_path, n_trees, min_node_size, threads, out): + call([ranger, '--file', data_path, '--depvarname', 'TP', '--probability', + '--ntree', str(n_trees), '--targetpartitionsize', str(min_node_size), + '--nthreads', str(threads), '--outprefix', out, '--write', '--verbose']) + +def main(options): + if not exists(options.out): + makedirs(options.out) + somatic_vcf_path = join(options.out, basename(options.variants).replace('.vcf', 'SOMATIC.tmp.vcf')) + filter_somatic(options.variants, somatic_vcf_path) + call(['tabix', somatic_vcf_path]) + tp_vcf_path, fp_vcf_path = classify_calls(somatic_vcf_path, options.truth, options.out, options.regions) + remove(somatic_vcf_path) + remove(somatic_vcf_path + ".tbi") + master_data_path = join(options.out, options.prefix + ".dat") + num_tps = make_ranger_data(tp_vcf_path, options.measures, True, master_data_path, options.missing_value) + num_fps = make_ranger_data(fp_vcf_path, options.measures, False, master_data_path, options.missing_value) + remove(tp_vcf_path) + remove(fp_vcf_path) + print("Number of TP examples: " + 
str(num_tps)) + print("Number of FP examples: " + str(num_fps)) + shuffle(master_data_path) + ranger_header = ' '.join(options.measures + ['TP']) + add_header(master_data_path, ranger_header) + ranger_out_prefix = join(options.out, options.prefix) + run_ranger_training(options.ranger, master_data_path, options.trees, options.min_node_size, options.threads, ranger_out_prefix) + remove(ranger_out_prefix + ".confusion") + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-V', '--variants', + type=str, + required=True, + help='Octopus calls with CSR annotations') + parser.add_argument('-T', '--regions', + type=str, + required=True, + help='BED files containing regions to use') + parser.add_argument('--truth', + type=str, + required=True, + help='Truth VCF file') + parser.add_argument('--measures', + type=str, + nargs='+', + required=True, + help='Measures to use for training') + parser.add_argument('--ranger', + type=str, + required=True, + help='Ranger binary') + parser.add_argument('--trees', + type=int, + default=300, + help='Number of trees to use in the random forest') + parser.add_argument('--min_node_size', + type=int, + default=20, + help='Node size to stop growing trees, implicitly limiting tree depth') + parser.add_argument('-o', '--out', + type=str, + help='Output directory') + parser.add_argument('--prefix', + type=str, + default='ranger_octopus', + help='Output files prefix') + parser.add_argument('-t', '--threads', + type=int, + default=1, + help='Number of threads for octopus') + parser.add_argument('--missing_value', + type=float, + default=-1, + help='Value for missing measures') + parsed, unparsed = parser.parse_known_args() + main(parsed) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ae35ae8a4..512d1c76a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -25,6 +25,8 @@ set(EXCEPTIONS_SOURCES exceptions/missing_index_error.cpp exceptions/unwritable_file_error.hpp 
exceptions/unwritable_file_error.cpp + exceptions/unimplemented_feature_error.hpp + exceptions/unimplemented_feature_error.cpp ) set(CONCEPTS_SOURCES @@ -49,6 +51,8 @@ set(BASICS_SOURCES basics/pedigree.cpp basics/trio.hpp basics/trio.cpp + basics/read_pileup.hpp + basics/read_pileup.cpp ) set(CONTAINERS_SOURCES @@ -94,6 +98,8 @@ set(IO_SOURCES io/read/read_reader_impl.hpp io/read/read_reader.hpp io/read/read_reader.cpp + io/read/read_writer.hpp + io/read/read_writer.cpp io/variant/htslib_bcf_facade.hpp io/variant/htslib_bcf_facade.cpp @@ -156,8 +162,8 @@ set(UTILS_SOURCES utils/timing.hpp utils/type_tricks.hpp utils/coverage_tracker.hpp - utils/read_size_estimator.hpp - utils/read_size_estimator.cpp + utils/input_reads_profiler.hpp + utils/input_reads_profiler.cpp utils/kmer_mapper.hpp utils/kmer_mapper.cpp utils/memory_footprint.hpp @@ -171,6 +177,10 @@ set(UTILS_SOURCES utils/parallel_transform.hpp utils/thread_pool.hpp utils/thread_pool.cpp + utils/concat.hpp + utils/select_top_k.hpp + utils/system_utils.hpp + utils/system_utils.cpp ) set(CORE_SOURCES @@ -188,6 +198,8 @@ set(CORE_SOURCES core/callers/population_caller.cpp core/callers/trio_caller.hpp core/callers/trio_caller.cpp + core/callers/polyclone_caller.hpp + core/callers/polyclone_caller.cpp core/types/calls/call_types.hpp core/types/calls/call.hpp @@ -221,6 +233,12 @@ set(CORE_SOURCES core/csr/facets/read_assignments.cpp core/csr/facets/reference_context.hpp core/csr/facets/reference_context.cpp + core/csr/facets/genotypes.hpp + core/csr/facets/genotypes.cpp + core/csr/facets/ploidies.hpp + core/csr/facets/ploidies.cpp + core/csr/facets/pedigree.hpp + core/csr/facets/pedigree.cpp core/csr/facets/facet_factory.hpp core/csr/facets/facet_factory.cpp @@ -244,6 +262,24 @@ set(CORE_SOURCES core/csr/filters/passing_filter.cpp core/csr/filters/training_filter_factory.hpp core/csr/filters/training_filter_factory.cpp + core/csr/filters/conditional_threshold_filter.hpp + 
core/csr/filters/conditional_threshold_filter.cpp + core/csr/filters/somatic_threshold_filter.hpp + core/csr/filters/somatic_threshold_filter.cpp + core/csr/filters/denovo_threshold_filter.hpp + core/csr/filters/denovo_threshold_filter.cpp + core/csr/filters/random_forest_filter.hpp + core/csr/filters/random_forest_filter.cpp + core/csr/filters/random_forest_filter_factory.hpp + core/csr/filters/random_forest_filter_factory.cpp + core/csr/filters/conditional_random_forest_filter.hpp + core/csr/filters/conditional_random_forest_filter.cpp + core/csr/filters/somatic_random_forest_filter.hpp + core/csr/filters/somatic_random_forest_filter.cpp + core/csr/filters/denovo_random_forest_filter.hpp + core/csr/filters/denovo_random_forest_filter.cpp + core/csr/filters/variant_filter_utils.hpp + core/csr/filters/variant_filter_utils.cpp core/csr/measures/measure.hpp core/csr/measures/measure.cpp @@ -253,8 +289,8 @@ set(CORE_SOURCES core/csr/measures/depth.cpp core/csr/measures/quality_by_depth.hpp core/csr/measures/quality_by_depth.cpp - core/csr/measures/max_genotype_quality.hpp - core/csr/measures/max_genotype_quality.cpp + core/csr/measures/genotype_quality.hpp + core/csr/measures/genotype_quality.cpp core/csr/measures/mapping_quality_zero_count.hpp core/csr/measures/mapping_quality_zero_count.cpp core/csr/measures/mean_mapping_quality.hpp @@ -280,18 +316,38 @@ set(CORE_SOURCES core/csr/measures/measures_fwd.hpp core/csr/measures/measure_factory.hpp core/csr/measures/measure_factory.cpp - core/csr/measures/realignments.hpp - core/csr/measures/realignments.cpp - core/csr/measures/unassigned_read_fraction.hpp - core/csr/measures/unassigned_read_fraction.cpp - + core/csr/measures/ambiguous_read_fraction.hpp + core/csr/measures/ambiguous_read_fraction.cpp + core/csr/measures/median_base_quality.hpp + core/csr/measures/median_base_quality.cpp + core/csr/measures/mismatch_count.hpp + core/csr/measures/mismatch_count.cpp + core/csr/measures/mismatch_fraction.hpp + 
core/csr/measures/mismatch_fraction.cpp + core/csr/measures/is_refcall.hpp + core/csr/measures/is_refcall.cpp + core/csr/measures/somatic_contamination.hpp + core/csr/measures/somatic_contamination.cpp + core/csr/measures/denovo_contamination.hpp + core/csr/measures/denovo_contamination.cpp + core/csr/measures/read_position_bias.hpp + core/csr/measures/read_position_bias.cpp + core/csr/measures/alt_allele_count.hpp + core/csr/measures/alt_allele_count.cpp + core/csr/measures/overlaps_tandem_repeat.hpp + core/csr/measures/overlaps_tandem_repeat.cpp + core/csr/measures/str_length.hpp + core/csr/measures/str_length.cpp + core/csr/measures/str_period.hpp + core/csr/measures/str_period.cpp + core/models/haplotype_likelihood_cache.hpp core/models/haplotype_likelihood_cache.cpp core/models/haplotype_likelihood_model.hpp core/models/haplotype_likelihood_model.cpp - - core/models/genotype/cnv_model.hpp - core/models/genotype/cnv_model.cpp + + core/models/genotype/subclone_model.hpp + core/models/genotype/subclone_model.cpp core/models/genotype/germline_likelihood_model.hpp core/models/genotype/germline_likelihood_model.cpp core/models/genotype/individual_model.hpp @@ -314,6 +370,8 @@ set(CORE_SOURCES core/models/genotype/uniform_population_prior_model.hpp core/models/genotype/coalescent_population_prior_model.hpp core/models/genotype/coalescent_population_prior_model.cpp + core/models/genotype/hardy_weinberg_model.hpp + core/models/genotype/hardy_weinberg_model.cpp core/models/pairhmm/pair_hmm.hpp core/models/pairhmm/pair_hmm.cpp @@ -341,7 +399,13 @@ set(CORE_SOURCES core/models/mutation/coalescent_model.cpp core/models/mutation/denovo_model.hpp core/models/mutation/denovo_model.cpp + core/models/mutation/indel_mutation_model.hpp + core/models/mutation/indel_mutation_model.cpp + + core/models/reference/individual_reference_likelihood_model.hpp + core/models/reference/individual_reference_likelihood_model.cpp + core/tools/coretools.hpp core/tools/haplotype_filter.hpp 
core/tools/haplotype_filter.cpp @@ -349,6 +413,8 @@ set(CORE_SOURCES core/tools/read_assigner.cpp core/tools/read_realigner.hpp core/tools/read_realigner.cpp + core/tools/bam_realigner.hpp + core/tools/bam_realigner.cpp core/tools/hapgen/genome_walker.hpp core/tools/hapgen/genome_walker.cpp @@ -356,6 +422,8 @@ set(CORE_SOURCES core/tools/hapgen/haplotype_generator.cpp core/tools/hapgen/haplotype_tree.hpp core/tools/hapgen/haplotype_tree.cpp + core/tools/hapgen/dense_variation_detector.hpp + core/tools/hapgen/dense_variation_detector.cpp core/tools/phaser/phaser.hpp core/tools/phaser/phaser.cpp @@ -374,6 +442,8 @@ set(CORE_SOURCES core/tools/vargen/vcf_extractor.cpp core/tools/vargen/variant_generator_builder.hpp core/tools/vargen/variant_generator_builder.cpp + core/tools/vargen/active_region_generator.hpp + core/tools/vargen/active_region_generator.cpp core/tools/vargen/utils/assembler.hpp core/tools/vargen/utils/assembler.cpp @@ -381,6 +451,8 @@ set(CORE_SOURCES core/tools/vargen/utils/global_aligner.cpp core/tools/vargen/utils/assembler_active_region_generator.hpp core/tools/vargen/utils/assembler_active_region_generator.cpp + core/tools/vargen/utils/misaligned_reads_detector.hpp + core/tools/vargen/utils/misaligned_reads_detector.cpp core/types/allele.hpp core/types/allele.cpp @@ -422,8 +494,9 @@ set(OCTOPUS_SOURCES set(INCLUDE_SOURCES ${octopus_SOURCE_DIR}/lib/bioio.hpp ${octopus_SOURCE_DIR}/lib/tandem/tandem.hpp - ${octopus_SOURCE_DIR}//lib/ksp/custom_dijkstra_call.hpp - ${octopus_SOURCE_DIR}//lib/ksp/yen_ksp.hpp + ${octopus_SOURCE_DIR}/lib/ksp/custom_dijkstra_call.hpp + ${octopus_SOURCE_DIR}/lib/ksp/yen_ksp.hpp + ${octopus_SOURCE_DIR}/lib/ranger/Forest.h ) set(REQUIRED_BOOST_LIBRARIES @@ -459,14 +532,13 @@ if (BUILD_TESTING) add_library(Octopus ${OCTOPUS_SOURCES} ${INCLUDE_SOURCES}) target_compile_features(Octopus PRIVATE cxx_thread_local) target_include_directories(Octopus PUBLIC ${octopus_SOURCE_DIR}/lib ${octopus_SOURCE_DIR}/src) - 
target_link_libraries(Octopus tandem) + target_link_libraries(Octopus tandem ranger) target_compile_definitions(Octopus PRIVATE -DBOOST_LOG_DYN_LINK) # Required for log find_package (Boost 1.65 REQUIRED COMPONENTS ${REQUIRED_BOOST_LIBRARIES} REQUIRED) if (Boost_FOUND) target_include_directories (Octopus PUBLIC ${Boost_INCLUDE_DIR}) target_link_libraries (Octopus ${Boost_LIBRARIES}) endif (Boost_FOUND) - set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/build/cmake/modules/") find_package (HTSlib 1.4 REQUIRED) if (HTSlib_FOUND) target_include_directories (Octopus PUBLIC ${HTSlib_INCLUDE_DIRS}) @@ -476,7 +548,7 @@ elseif (CMAKE_BUILD_TYPE MATCHES Debug) add_executable(octopus-debug main.cpp ${OCTOPUS_SOURCES} ${INCLUDE_SOURCES}) target_compile_features(octopus-debug PRIVATE cxx_thread_local) target_include_directories(octopus-debug PUBLIC ${octopus_SOURCE_DIR}/lib ${octopus_SOURCE_DIR}/src) - target_link_libraries(octopus-debug tandem) + target_link_libraries(octopus-debug tandem ranger) if (NOT BUILD_SHARED_LIBS) message(STATUS "Linking against boost static libraries") set(Boost_USE_STATIC_LIBS ON) @@ -489,7 +561,6 @@ elseif (CMAKE_BUILD_TYPE MATCHES Debug) target_include_directories (octopus-debug PUBLIC ${Boost_INCLUDE_DIR}) target_link_libraries (octopus-debug ${Boost_LIBRARIES}) endif (Boost_FOUND) - set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/build/cmake/modules/") find_package (HTSlib 1.4 REQUIRED) if (HTSlib_FOUND) target_include_directories (octopus-debug PUBLIC ${HTSlib_INCLUDE_DIRS}) @@ -505,7 +576,7 @@ elseif (CMAKE_BUILD_TYPE MATCHES RelWithDebInfo) target_compile_features(octopus-sanitize PRIVATE cxx_thread_local) target_compile_options(octopus-sanitize PRIVATE ${SanitizeFlags}) target_include_directories(octopus-sanitize PUBLIC ${octopus_SOURCE_DIR}/lib ${octopus_SOURCE_DIR}/src) - target_link_libraries(octopus-sanitize tandem ${SanitizeFlags}) + target_link_libraries(octopus-sanitize tandem ranger ${SanitizeFlags}) if 
(NOT BUILD_SHARED_LIBS) message(STATUS "Linking against boost static libraries") set(Boost_USE_STATIC_LIBS ON) @@ -518,7 +589,6 @@ elseif (CMAKE_BUILD_TYPE MATCHES RelWithDebInfo) target_include_directories (octopus-sanitize PUBLIC ${Boost_INCLUDE_DIR}) target_link_libraries (octopus-sanitize ${Boost_LIBRARIES} ${SanitizeFlags}) endif (Boost_FOUND) - set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/build/cmake/modules/") find_package (HTSlib 1.4 REQUIRED) if (HTSlib_FOUND) target_include_directories (octopus-sanitize PUBLIC ${HTSlib_INCLUDE_DIRS}) @@ -533,7 +603,7 @@ else() target_compile_features(octopus PRIVATE cxx_thread_local) target_compile_options(octopus PRIVATE -ffast-math -funroll-loops -march=native) target_include_directories(octopus PUBLIC ${octopus_SOURCE_DIR}/lib ${octopus_SOURCE_DIR}/src) - target_link_libraries(octopus tandem) + target_link_libraries(octopus tandem ranger) if (NOT BUILD_SHARED_LIBS) message(STATUS "Linking against boost static libraries") set(Boost_USE_STATIC_LIBS ON) @@ -546,7 +616,6 @@ else() target_include_directories (octopus PUBLIC ${Boost_INCLUDE_DIR}) target_link_libraries (octopus ${Boost_LIBRARIES}) endif (Boost_FOUND) - set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/build/cmake/modules/") find_package (HTSlib 1.4 REQUIRED) if (HTSlib_FOUND) target_include_directories (octopus PUBLIC ${HTSlib_INCLUDE_DIRS}) diff --git a/src/basics/aligned_read.cpp b/src/basics/aligned_read.cpp index e7360d5f9..afcea6173 100644 --- a/src/basics/aligned_read.cpp +++ b/src/basics/aligned_read.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "aligned_read.hpp" @@ -113,6 +113,14 @@ AlignedRead::Flags AlignedRead::flags() const noexcept return decompress(flags_); } +void AlignedRead::realign(GenomicRegion new_region, CigarString new_cigar) noexcept +{ + assert(sequence_size(new_cigar) == sequence_.size()); + assert(reference_size(new_cigar) == size(new_region)); + region_ = std::move(new_region); + cigar_ = std::move(new_cigar); +} + bool AlignedRead::is_marked_all_segments_in_read_aligned() const noexcept { return flags_[0]; @@ -158,7 +166,6 @@ bool AlignedRead::is_marked_supplementary_alignment() const noexcept AlignedRead::FlagBits AlignedRead::compress(const Flags& flags) const noexcept { FlagBits result {}; - result[0] = flags.all_segments_in_read_aligned; result[1] = flags.multiple_segment_template; result[2] = flags.unmapped; @@ -167,14 +174,14 @@ AlignedRead::FlagBits AlignedRead::compress(const Flags& flags) const noexcept result[5] = flags.qc_fail; result[6] = flags.duplicate; result[7] = flags.supplementary_alignment; - + result[8] = flags.first_template_segment; + result[9] = flags.last_template_segment; return result; } AlignedRead::Flags AlignedRead::decompress(const FlagBits& flags) const noexcept { - // Note: first_template_segment and last_template_segmenet are not currently used - return {flags[0], flags[1], flags[2], flags[3], flags[4], flags[5], flags[6], flags[7], false, false}; + return {flags[0], flags[1], flags[2], flags[3], flags[4], flags[5], flags[6], flags[7], flags[8], flags[9]}; } AlignedRead::Segment::FlagBits AlignedRead::Segment::compress(const Flags& flags) @@ -188,13 +195,11 @@ AlignedRead::Segment::FlagBits AlignedRead::Segment::compress(const Flags& flags std::size_t ReadHash::operator()(const octopus::AlignedRead &read) const { std::size_t result {}; - using boost::hash_combine; hash_combine(result, std::hash()(read.mapped_region())); hash_combine(result, std::hash()(read.cigar())); hash_combine(result, boost::hash_range(std::cbegin(read.base_qualities()), 
std::cend(read.base_qualities()))); hash_combine(result, read.mapping_quality()); - return result; } @@ -254,6 +259,16 @@ AlignedRead::NucleotideSequence::size_type sequence_size(const AlignedRead& read return sequence_size(contained_cigar_copy); } +bool is_forward_strand(const AlignedRead& read) noexcept +{ + return read.direction() == AlignedRead::Direction::forward; +} + +bool is_reverse_strand(const AlignedRead& read) noexcept +{ + return !is_forward_strand(read); +} + bool is_soft_clipped(const AlignedRead& read) noexcept { return is_soft_clipped(read.cigar()); @@ -304,19 +319,57 @@ AlignedRead copy(const AlignedRead& read, const GenomicRegion& region) if (contains(region, read)) return read; const auto copy_region = *overlapped_region(read, region); const auto reference_offset = static_cast(begin_distance(read, copy_region)); - const auto uncontained_cigar_copy = copy_reference(read.cigar(), 0, reference_offset); + auto uncontained_cigar_copy = copy_reference(read.cigar(), 0, reference_offset); auto contained_cigar_copy = copy_reference(read.cigar(), reference_offset, region_size(copy_region)); - const auto sequence_offset = sequence_size(uncontained_cigar_copy); - const auto sequence_length = sequence_size(contained_cigar_copy); - assert(sequence_offset + sequence_length <= sequence_size(read)); - const auto subsequence_begin_itr = next(cbegin(read.sequence()), sequence_offset); - const auto subsequence_end_itr = next(subsequence_begin_itr, sequence_length); + if (!uncontained_cigar_copy.empty() && !contained_cigar_copy.empty() + && uncontained_cigar_copy.back() == contained_cigar_copy.front() + && is_insertion(uncontained_cigar_copy.back())) { + uncontained_cigar_copy.pop_back(); + } + const auto copy_offset = sequence_size(uncontained_cigar_copy); + const auto copy_length = sequence_size(contained_cigar_copy); + assert(copy_offset + copy_length <= sequence_size(read)); + const auto subsequence_begin_itr = next(cbegin(read.sequence()), copy_offset); + const 
auto subsequence_end_itr = next(subsequence_begin_itr, copy_length); AlignedRead::NucleotideSequence sub_sequence {subsequence_begin_itr, subsequence_end_itr}; - const auto subqualities_begin_itr = next(cbegin(read.base_qualities()), sequence_offset); - const auto subqualities_end_itr = next(subqualities_begin_itr, sequence_length); + const auto subqualities_begin_itr = next(cbegin(read.base_qualities()), copy_offset); + const auto subqualities_end_itr = next(subqualities_begin_itr, copy_length); AlignedRead::BaseQualityVector sub_qualities {subqualities_begin_itr, subqualities_end_itr}; return AlignedRead {read.name(), copy_region, std::move(sub_sequence), std::move(sub_qualities), - std::move(contained_cigar_copy), read.mapping_quality(), read.flags()}; + std::move(contained_cigar_copy), read.mapping_quality(), read.flags(), read.read_group()}; +} + +template +T copy_helper(const T& sequence, const CigarString& cigar, const GenomicRegion& sequence_region, const GenomicRegion& request_region) +{ + if (!overlaps(sequence_region, request_region)) {}; + if (contains(request_region, sequence_region)) return sequence; + const auto copy_region = *overlapped_region(sequence_region, request_region); + const auto reference_offset = static_cast(begin_distance(sequence_region, copy_region)); + auto uncontained_cigar_copy = copy_reference(cigar, 0, reference_offset); + auto contained_cigar_copy = copy_reference(cigar, reference_offset, region_size(copy_region)); + if (!uncontained_cigar_copy.empty() && !contained_cigar_copy.empty() + && uncontained_cigar_copy.back() == contained_cigar_copy.front() + && is_insertion(uncontained_cigar_copy.back())) { + uncontained_cigar_copy.pop_back(); + } + const auto copy_offset = sequence_size(uncontained_cigar_copy); + const auto copy_length = sequence_size(contained_cigar_copy); + assert(copy_offset + copy_length <= sequence.size()); + using std::cbegin; using std::next; + const auto subsequence_begin_itr = next(cbegin(sequence), 
copy_offset); + const auto subsequence_end_itr = next(subsequence_begin_itr, copy_length); + return {subsequence_begin_itr, subsequence_end_itr}; +} + +AlignedRead::NucleotideSequence copy_sequence(const AlignedRead& read, const GenomicRegion& region) +{ + return copy_helper(read.sequence(), read.cigar(), read.mapped_region(), region); +} + +AlignedRead::BaseQualityVector copy_base_qualities(const AlignedRead& read, const GenomicRegion& region) +{ + return copy_helper(read.base_qualities(), read.cigar(), read.mapped_region(), region); } bool operator==(const AlignedRead& lhs, const AlignedRead& rhs) noexcept diff --git a/src/basics/aligned_read.hpp b/src/basics/aligned_read.hpp index f0462d3fa..d7c8ae965 100644 --- a/src/basics/aligned_read.hpp +++ b/src/basics/aligned_read.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef aligned_read_hpp @@ -81,15 +81,16 @@ class AlignedRead : public Comparable, public Mappable AlignedRead() = default; - template - AlignedRead(String_&& name, GenomicRegion_&& reference_region, Seq_&& sequence, Qualities_&& qualities, - CigarString_&& cigar, MappingQuality mapping_quality, const Flags& flags); + template + AlignedRead(String1_&& name, GenomicRegion_&& reference_region, Seq_&& sequence, Qualities_&& qualities, + CigarString_&& cigar, MappingQuality mapping_quality, const Flags& flags, String2_&& read_group); - template + template AlignedRead(String1_&& name, GenomicRegion_&& reference_region, Seq_&& sequence, Qualities_&& qualities, - CigarString_&& cigar, MappingQuality mapping_quality, Flags flags, - String2_&& next_segment_contig_name, MappingDomain::Position next_segment_begin, + CigarString_&& cigar, MappingQuality mapping_quality, Flags flags, String2_&& read_group, + String3_&& next_segment_contig_name, MappingDomain::Position next_segment_begin, MappingDomain::Size 
inferred_template_length, const Segment::Flags& next_segment_flags); @@ -114,6 +115,8 @@ class AlignedRead : public Comparable, public Mappable const Segment& next_segment() const; Flags flags() const noexcept; + void realign(GenomicRegion new_region, CigarString new_cigar) noexcept; + bool is_marked_all_segments_in_read_aligned() const noexcept; bool is_marked_multiple_segment_template() const noexcept; bool is_marked_unmapped() const noexcept; @@ -124,7 +127,7 @@ class AlignedRead : public Comparable, public Mappable bool is_marked_supplementary_alignment() const noexcept; private: - static constexpr std::size_t numFlags_ = 8; + static constexpr std::size_t numFlags_ = 10; using FlagBits = std::bitset; // should be ordered by sizeof @@ -162,34 +165,35 @@ struct AlignedRead::Flags bool last_template_segment; }; -template +template AlignedRead::AlignedRead(String_&& name, GenomicRegion_&& reference_region, Seq&& sequence, Qualities_&& qualities, - CigarString_&& cigar, MappingQuality mapping_quality, const Flags& flags) + CigarString_&& cigar, MappingQuality mapping_quality, const Flags& flags, String2_&& read_group) : region_ {std::forward(reference_region)} , name_ {std::forward(name)} , sequence_ {std::forward(sequence)} , base_qualities_ {std::forward(qualities)} , cigar_ {std::forward(cigar)} -, read_group_ {} +, read_group_ {std::forward(read_group)} , next_segment_ {} , flags_ {compress(flags)} , mapping_quality_ {mapping_quality} {} template + typename String2_, typename String3_> AlignedRead::AlignedRead(String1_&& name, GenomicRegion_&& reference_region, Seq&& sequence, Qualities_&& qualities, - CigarString_&& cigar, MappingQuality mapping_quality, Flags flags, - String2_&& next_segment_contig_name, MappingDomain::Position next_segment_begin, + CigarString_&& cigar, MappingQuality mapping_quality, Flags flags, String2_&& read_group, + String3_&& next_segment_contig_name, MappingDomain::Position next_segment_begin, MappingDomain::Size 
inferred_template_length, const Segment::Flags& next_segment_flags) : region_ {std::forward(reference_region)} , name_ {std::forward(name)} , sequence_ {std::forward(sequence)} , base_qualities_ {std::forward(qualities)} , cigar_ {std::forward(cigar)} -, read_group_ {} +, read_group_ {std::forward(read_group)} , next_segment_ { - Segment {std::forward(next_segment_contig_name), next_segment_begin, + Segment {std::forward(next_segment_contig_name), next_segment_begin, inferred_template_length, next_segment_flags} } , flags_ {compress(flags)} @@ -221,6 +225,9 @@ bool is_sequence_empty(const AlignedRead& read) noexcept; AlignedRead::NucleotideSequence::size_type sequence_size(const AlignedRead& read) noexcept; AlignedRead::NucleotideSequence::size_type sequence_size(const AlignedRead& read, const GenomicRegion& region); +bool is_forward_strand(const AlignedRead& read) noexcept; +bool is_reverse_strand(const AlignedRead& read) noexcept; + bool is_soft_clipped(const AlignedRead& read) noexcept; bool is_front_soft_clipped(const AlignedRead& read) noexcept; bool is_back_soft_clipped(const AlignedRead& read) noexcept; @@ -230,9 +237,10 @@ GenomicRegion clipped_mapped_region(const AlignedRead& read); // Returns the part of the read cigar string contained by the region CigarString copy_cigar(const AlignedRead& read, const GenomicRegion& region); - // Returns the part of the read (cigar, sequence, base_qualities) contained by the region AlignedRead copy(const AlignedRead& read, const GenomicRegion& region); +AlignedRead::NucleotideSequence copy_sequence(const AlignedRead& read, const GenomicRegion& region); +AlignedRead::BaseQualityVector copy_base_qualities(const AlignedRead& read, const GenomicRegion& region); bool operator==(const AlignedRead& lhs, const AlignedRead& rhs) noexcept; bool operator<(const AlignedRead& lhs, const AlignedRead& rhs) noexcept; diff --git a/src/basics/cigar_string.cpp b/src/basics/cigar_string.cpp index c689f42f3..b302108c8 100644 --- 
a/src/basics/cigar_string.cpp +++ b/src/basics/cigar_string.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "cigar_string.hpp" @@ -20,24 +20,24 @@ CigarOperation::CigarOperation(const Size size, const Flag flag) noexcept , flag_ {flag} {} -CigarOperation::Flag CigarOperation::flag() const noexcept +void CigarOperation::set_flag(Flag type) noexcept { - return flag_; + flag_ = type; } -CigarOperation::Size CigarOperation::size() const noexcept +void CigarOperation::set_size(Size size) noexcept { - return size_; + size_ = size; } -bool CigarOperation::advances_reference() const noexcept +CigarOperation::Flag CigarOperation::flag() const noexcept { - return !(flag_ == Flag::insertion || flag_ == Flag::hardClipped || flag_ == Flag::padding); + return flag_; } -bool CigarOperation::advances_sequence() const noexcept +CigarOperation::Size CigarOperation::size() const noexcept { - return !(flag_ == Flag::deletion || flag_ == Flag::hardClipped); + return size_; } // non-member methods @@ -64,10 +64,42 @@ bool is_valid(const CigarOperation& op) noexcept return is_valid(op.flag()) && op.size() > 0; } -bool is_match(const CigarOperation& op) noexcept +void increment_size(CigarOperation& op, CigarOperation::Size n) noexcept +{ + op.set_size(op.size() + n); +} + +void decrement_size(CigarOperation& op, CigarOperation::Size n) noexcept +{ + op.set_size(op.size() - n); +} + +bool advances_reference(CigarOperation::Flag flag) noexcept +{ + using Flag = CigarOperation::Flag; + return !(flag == Flag::insertion || flag == Flag::hardClipped || flag == Flag::padding); +} + +bool advances_reference(const CigarOperation& op) noexcept +{ + return advances_reference(op.flag()); +} + +bool advances_sequence(CigarOperation::Flag flag) noexcept +{ + using Flag = CigarOperation::Flag; + return !(flag == Flag::deletion || flag == 
Flag::hardClipped); +} + +bool advances_sequence(const CigarOperation& op) noexcept +{ + return advances_sequence(op.flag()); +} + +bool is_match(CigarOperation::Flag flag) noexcept { using Flag = CigarOperation::Flag; - switch (op.flag()) { + switch (flag) { case Flag::alignmentMatch: case Flag::sequenceMatch: case Flag::substitution: return true; @@ -75,16 +107,50 @@ bool is_match(const CigarOperation& op) noexcept } } +bool is_match(const CigarOperation& op) noexcept +{ + return is_match(op.flag()); +} + +bool is_insertion(CigarOperation::Flag flag) noexcept +{ + return flag == CigarOperation::Flag::insertion; +} + +bool is_insertion(const CigarOperation& op) noexcept +{ + return is_insertion(op.flag()); +} + +bool is_deletion(CigarOperation::Flag flag) noexcept +{ + return flag == CigarOperation::Flag::deletion; +} + +bool is_deletion(const CigarOperation& op) noexcept +{ + return is_deletion(op.flag()); +} + +bool is_indel(CigarOperation::Flag flag) noexcept +{ + return is_insertion(flag) || is_deletion(flag); +} + bool is_indel(const CigarOperation& op) noexcept +{ + return is_indel(op.flag()); +} + +bool is_clipping(CigarOperation::Flag flag) noexcept { using Flag = CigarOperation::Flag; - return op.flag() == Flag::insertion || op.flag() == Flag::deletion; + return flag == Flag::softClipped || flag == Flag::hardClipped; } bool is_clipping(const CigarOperation& op) noexcept { - using Flag = CigarOperation::Flag; - return op.flag() == Flag::softClipped || op.flag() == Flag::hardClipped; + return is_clipping(op.flag()); } // CigarString @@ -153,56 +219,108 @@ get_soft_clipped_sizes(const CigarString& cigar) noexcept // non-member functions -template -CigarString copy(const CigarString& cigar, CigarOperation::Size offset, CigarOperation::Size size, Predicate pred) +namespace { + +template +CigarString copy(const CigarString& cigar, CigarOperation::Size offset, CigarOperation::Size size, + Predicate1 offset_pred, Predicate2 size_pred) { CigarString result {}; 
result.reserve(cigar.size()); - auto op_it = std::cbegin(cigar); - const auto last_op = std::cend(cigar); - - while (op_it != last_op && (offset >= op_it->size() || !pred(*op_it))) { - if (pred(*op_it)) { - offset -= op_it->size(); + auto op_itr = std::cbegin(cigar); + const auto last_op_itr = std::cend(cigar); + while (op_itr != last_op_itr && offset > 0 && (offset >= op_itr->size() || !offset_pred(*op_itr))) { + if (offset_pred(*op_itr)) { + offset -= op_itr->size(); } - ++op_it; + ++op_itr; } - if (op_it != last_op) { - const auto remainder = op_it->size() - offset; + if (op_itr != last_op_itr && size_pred(*op_itr)) { + const auto remainder = op_itr->size() - offset; if (remainder >= size) { - result.emplace_back(size, op_it->flag()); + result.emplace_back(size, op_itr->flag()); return result; } - result.emplace_back(remainder, op_it->flag()); + result.emplace_back(remainder, op_itr->flag()); size -= remainder; - ++op_it; + ++op_itr; } - while (op_it != last_op && size > 0 && (size >= op_it->size() || !pred(*op_it))) { - result.emplace_back(*op_it); - if (pred(*op_it)) { - size -= op_it->size(); + while (op_itr != last_op_itr && size > 0 && (size >= op_itr->size() || !size_pred(*op_itr))) { + result.emplace_back(*op_itr); + if (size_pred(*op_itr)) { + size -= op_itr->size(); } - ++op_it; + ++op_itr; } - if (op_it != last_op && size > 0) { - result.emplace_back(size, op_it->flag()); + if (op_itr != last_op_itr && size > 0) { + result.emplace_back(size, op_itr->flag()); } return result; } -CigarString copy(const CigarString& cigar, CigarOperation::Size offset, CigarOperation::Size size) +template +CigarString copy(const CigarString& cigar, CigarOperation::Size offset, CigarOperation::Size size, + Predicate pred) +{ + return copy(cigar, offset, size, pred, pred); +} + +struct AdvancesReferencePred +{ + bool operator()(const CigarOperation& op) const noexcept + { + return advances_reference(op); + } +}; +struct AdvancesSequencePred +{ + bool operator()(const 
CigarOperation& op) const noexcept + { + return advances_sequence(op); + } +}; + +template +CigarString copy(const CigarString& cigar, CigarOperation::Size offset, CigarOperation::Size size, + Predicate offset_pred, const CigarStringCopyPolicy size_policy) +{ + using CopyPolicy = CigarStringCopyPolicy; + switch (size_policy) { + case CopyPolicy::reference: + return copy(cigar, offset, size, offset_pred, AdvancesReferencePred {}); + case CopyPolicy::sequence: + return copy(cigar, offset, size, offset_pred, AdvancesSequencePred {}); + case CopyPolicy::both: + default: + return copy(cigar, offset, size, offset_pred, [] (const auto& op) { return true; }); + } +} + +} // namespace + +CigarString copy(const CigarString& cigar, CigarOperation::Size offset, CigarOperation::Size size, + const CigarStringCopyPolicy offset_policy, const CigarStringCopyPolicy size_policy) { - return copy(cigar, offset, size, [](const auto& op) { return true; }); + using CopyPolicy = CigarStringCopyPolicy; + switch (offset_policy) { + case CopyPolicy::reference: + return copy(cigar, offset, size, AdvancesReferencePred {}, size_policy); + case CopyPolicy::sequence: + return copy(cigar, offset, size, AdvancesSequencePred {}, size_policy); + case CopyPolicy::both: + default: + return copy(cigar, offset, size, [] (const auto& op) { return true; }, size_policy); + } } CigarString copy_reference(const CigarString& cigar, CigarOperation::Size offset, CigarOperation::Size size) { - return copy(cigar, offset, size, [](const auto& op) { return op.advances_reference(); }); + return copy(cigar, offset, size, AdvancesReferencePred {}); } CigarString copy_sequence(const CigarString& cigar, CigarOperation::Size offset, CigarOperation::Size size) { - return copy(cigar, offset, size, [](const auto& op) { return op.advances_sequence(); }); + return copy(cigar, offset, size, AdvancesSequencePred {}); } std::vector decompose(const CigarString& cigar) @@ -220,10 +338,11 @@ CigarString collapse_matches(const 
CigarString& cigar) CigarString result {}; result.reserve(cigar.size()); for (auto match_end_itr = std::begin(cigar); match_end_itr != std::cend(cigar); ) { - const auto match_begin_itr = std::find_if(match_end_itr, std::end(cigar), is_match); + const auto f_is_match = [] (const CigarOperation& op) { return is_match(op); }; + const auto match_begin_itr = std::find_if(match_end_itr, std::end(cigar), f_is_match); result.insert(std::cend(result), match_end_itr, match_begin_itr); if (match_begin_itr == std::cend(cigar)) break; - match_end_itr = std::find_if_not(std::next(match_begin_itr), std::end(cigar), is_match); + match_end_itr = std::find_if_not(std::next(match_begin_itr), std::end(cigar), f_is_match); auto match_size = std::accumulate(match_begin_itr, match_end_itr, 0, [] (auto curr, const auto& op) { return curr + op.size(); }); result.emplace_back(match_size, CigarOperation::Flag::alignmentMatch); diff --git a/src/basics/cigar_string.hpp b/src/basics/cigar_string.hpp index 81714c0df..6fc877f12 100644 --- a/src/basics/cigar_string.hpp +++ b/src/basics/cigar_string.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef cigar_string_hpp @@ -11,6 +11,7 @@ #include #include #include +#include #include @@ -47,21 +48,34 @@ class CigarOperation : public Comparable // Comparable so can co ~CigarOperation() = default; - Flag flag() const noexcept; + void set_flag(Flag type) noexcept; + void set_size(Size size) noexcept; + Flag flag() const noexcept; Size size() const noexcept; - bool advances_reference() const noexcept; - - bool advances_sequence() const noexcept; - private: Size size_; Flag flag_; }; +void increment_size(CigarOperation& op, CigarOperation::Size n = 1) noexcept; +void decrement_size(CigarOperation& op, CigarOperation::Size n = 1) noexcept; + +bool advances_reference(CigarOperation::Flag flag) noexcept; +bool advances_reference(const CigarOperation& op) noexcept; +bool advances_sequence(CigarOperation::Flag flag) noexcept; +bool advances_sequence(const CigarOperation& op) noexcept; + +bool is_match(CigarOperation::Flag flag) noexcept; bool is_match(const CigarOperation& op) noexcept; +bool is_insertion(CigarOperation::Flag flag) noexcept; +bool is_insertion(const CigarOperation& op) noexcept; +bool is_deletion(CigarOperation::Flag flag) noexcept; +bool is_deletion(const CigarOperation& op) noexcept; +bool is_indel(CigarOperation::Flag flag) noexcept; bool is_indel(const CigarOperation& op) noexcept; +bool is_clipping(CigarOperation::Flag flag) noexcept; bool is_clipping(const CigarOperation& op) noexcept; // CigarString @@ -77,9 +91,7 @@ bool is_valid(const CigarString& cigar) noexcept; bool is_minimal(const CigarString& cigar) noexcept; bool is_front_soft_clipped(const CigarString& cigar) noexcept; - bool is_back_soft_clipped(const CigarString& cigar) noexcept; - bool is_soft_clipped(const CigarString& cigar) noexcept; std::pair get_soft_clipped_sizes(const CigarString& cigar) noexcept; @@ -107,7 +119,7 @@ S reference_size(const CigarString& cigar) noexcept { return std::accumulate(std::cbegin(cigar), std::cend(cigar), S {0}, [] (const S curr, const 
CigarOperation& op) { - return curr + ((op.advances_reference()) ? op.size() : 0); + return curr + ((advances_reference(op)) ? op.size() : 0); }); } @@ -116,7 +128,7 @@ S sequence_size(const CigarString& cigar) noexcept { return std::accumulate(std::cbegin(cigar), std::cend(cigar), S {0}, [] (const S curr, const CigarOperation& op) { - return curr + ((op.advances_sequence()) ? op.size() : 0); + return curr + ((advances_sequence(op)) ? op.size() : 0); }); } @@ -131,10 +143,16 @@ CigarOperation get_operation_at_sequence_position(const CigarString& cigar, S po return *first; } -// Relative to both reference and sequence -CigarString copy(const CigarString& cigar, CigarOperation::Size offset, CigarOperation::Size size); -CigarString copy_reference(const CigarString& cigar, CigarOperation::Size offset, CigarOperation::Size size); -CigarString copy_sequence(const CigarString& cigar, CigarOperation::Size offset, CigarOperation::Size size); +enum class CigarStringCopyPolicy { reference, sequence, both }; + +CigarString copy(const CigarString& cigar, CigarOperation::Size offset, + CigarOperation::Size size = std::numeric_limits::max(), + CigarStringCopyPolicy offset_policy = CigarStringCopyPolicy::both, + CigarStringCopyPolicy size_policy = CigarStringCopyPolicy::both); +CigarString copy_reference(const CigarString& cigar, CigarOperation::Size offset, + CigarOperation::Size size = std::numeric_limits::max()); +CigarString copy_sequence(const CigarString& cigar, CigarOperation::Size offset, + CigarOperation::Size size = std::numeric_limits::max()); std::vector decompose(const CigarString& cigar); CigarString collapse_matches(const CigarString& cigar); @@ -155,41 +173,45 @@ struct CigarHash } // namespace octopus namespace std { - template <> struct hash + +template <> struct hash +{ + size_t operator()(const octopus::CigarOperation& op) const noexcept { - size_t operator()(const octopus::CigarOperation& op) const noexcept - { - return octopus::CigarHash()(op); - } - }; - - 
template <> struct hash + return octopus::CigarHash()(op); + } +}; + +template <> struct hash +{ + size_t operator()(const octopus::CigarString& cigar) const noexcept { - size_t operator()(const octopus::CigarString& cigar) const noexcept - { - return octopus::CigarHash()(cigar); - } - }; + return octopus::CigarHash()(cigar); + } +}; + } // namespace std namespace boost { - template <> - struct hash : std::unary_function + +template <> +struct hash : std::unary_function +{ + std::size_t operator()(const octopus::CigarOperation& op) const noexcept { - std::size_t operator()(const octopus::CigarOperation& op) const noexcept - { - return std::hash()(op); - } - }; - - template <> - struct hash : std::unary_function + return std::hash()(op); + } +}; + +template <> +struct hash : std::unary_function +{ + std::size_t operator()(const octopus::CigarString& cigar) const noexcept { - std::size_t operator()(const octopus::CigarString& cigar) const noexcept - { - return std::hash()(cigar); - } - }; + return std::hash()(cigar); + } +}; + } // namespace boost #endif diff --git a/src/basics/contig_region.hpp b/src/basics/contig_region.hpp index 9582b648b..32e662393 100644 --- a/src/basics/contig_region.hpp +++ b/src/basics/contig_region.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef contig_region_hpp diff --git a/src/basics/genomic_region.hpp b/src/basics/genomic_region.hpp index 37a38f71e..3ea10639e 100644 --- a/src/basics/genomic_region.hpp +++ b/src/basics/genomic_region.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef genomic_region_hpp diff --git a/src/basics/mappable_reference_wrapper.hpp b/src/basics/mappable_reference_wrapper.hpp index 8f67f2f73..0771c9966 100644 --- a/src/basics/mappable_reference_wrapper.hpp +++ b/src/basics/mappable_reference_wrapper.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef mappable_reference_wrapper_hpp diff --git a/src/basics/pedigree.cpp b/src/basics/pedigree.cpp index 5716e0ad3..736cd981a 100644 --- a/src/basics/pedigree.cpp +++ b/src/basics/pedigree.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "pedigree.hpp" diff --git a/src/basics/pedigree.hpp b/src/basics/pedigree.hpp index c3ce50f95..9a433060c 100644 --- a/src/basics/pedigree.hpp +++ b/src/basics/pedigree.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef pedigree_hpp @@ -21,7 +21,7 @@ class Pedigree public: struct Member { - enum class Sex { male, female }; + enum class Sex { male, female, hermaphroditic }; SampleName name; Sex sex; }; diff --git a/src/basics/phred.hpp b/src/basics/phred.hpp index 5789c106b..8d84b12b9 100644 --- a/src/basics/phred.hpp +++ b/src/basics/phred.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef phred_hpp diff --git a/src/basics/ploidy_map.cpp b/src/basics/ploidy_map.cpp index 72938fd6b..dfcba438f 100644 --- a/src/basics/ploidy_map.cpp +++ b/src/basics/ploidy_map.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "ploidy_map.hpp" @@ -57,4 +57,26 @@ std::vector get_ploidies(const std::vector& samples, const return result; } +unsigned get_min_ploidy(const std::vector& samples, const std::vector& contigs, const PloidyMap& ploidies) +{ + unsigned result = -1; + for (const auto& sample : samples) { + for (const auto& contig : contigs) { + result = std::min(result, ploidies.of(sample, contig)); + } + } + return result; +} + +unsigned get_max_ploidy(const std::vector& samples, const std::vector& contigs, const PloidyMap& ploidies) +{ + unsigned result {0}; + for (const auto& sample : samples) { + for (const auto& contig : contigs) { + result = std::max(result, ploidies.of(sample, contig)); + } + } + return result; +} + } // namespace octopus diff --git a/src/basics/ploidy_map.hpp b/src/basics/ploidy_map.hpp index 0d9b7829b..5108241b0 100644 --- a/src/basics/ploidy_map.hpp +++ b/src/basics/ploidy_map.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef ploidy_map_hpp @@ -40,6 +40,9 @@ class PloidyMap std::vector get_ploidies(const std::vector& samples, const ContigName& contig, const PloidyMap& ploidies); +unsigned get_min_ploidy(const std::vector& samples, const std::vector& contigs, const PloidyMap& ploidies); +unsigned get_max_ploidy(const std::vector& samples, const std::vector& contigs, const PloidyMap& ploidies); + } // namespace octopus #endif diff --git a/src/basics/read_pileup.cpp b/src/basics/read_pileup.cpp new file mode 100644 index 000000000..131300888 --- /dev/null +++ b/src/basics/read_pileup.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2016 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#include "read_pileup.hpp" + +#include +#include + +namespace octopus { + +ReadPileup::ReadPileup(ContigRegion::Position position) +: summaries_ {} +, region_ {position, position + 1} +{ + summaries_.emplace("$", ReadSummaries {}); +} + +const ContigRegion& ReadPileup::mapped_region() const noexcept +{ + return region_; +} + +unsigned ReadPileup::depth() const noexcept +{ + return std::accumulate(std::cbegin(summaries_), std::cend(summaries_), 0u, + [] (auto curr, const auto& p) { return curr + p.second.size(); }); +} + +unsigned ReadPileup::depth(const NucleotideSequence& sequence) const noexcept +{ + return this->summaries(sequence).size(); +} + +void ReadPileup::add(const AlignedRead& read) +{ + const GenomicRegion region {contig_name(read), region_}; + summaries_[copy_sequence(read, region)].push_back({copy_base_qualities(read, region), read.mapping_quality()}); +} + +const ReadPileup::ReadSummaries& ReadPileup::summaries(const NucleotideSequence& sequence) const +{ + const auto itr = summaries_.find(sequence); + if (itr != std::cend(summaries_)) { + return itr->second; + } else { + return summaries_.at("$"); + } +} + +unsigned ReadPileup::sum_base_qualities(const NucleotideSequence& sequence) const +{ + const auto& sequence_summaries = 
this->summaries(sequence); + return std::accumulate(std::cbegin(sequence_summaries), std::cend(sequence_summaries), 0u, + [] (auto curr, const ReadSummary& summary) { + return curr + std::accumulate(std::cbegin(summary.base_qualities), std::cend(summary.base_qualities), 0u); + }); +} + +namespace { + +auto overlap_range(std::vector& pileups, const AlignedRead& read) +{ + return overlap_range(std::begin(pileups), std::end(pileups), contig_region(read), BidirectionallySortedTag {}); +} + +} // namespace + +ReadPileups make_pileups(const ReadContainer& reads, const GenomicRegion& region) +{ + ReadPileups result {}; + result.reserve(size(region)); + for (auto position = region.begin(); position < region.end(); ++position) { + result.emplace_back(position); + } + for (const AlignedRead& read : overlap_range(reads, region)) { + for (ReadPileup& pileup : overlap_range(result, read)) { + pileup.add(read); + } + } + return result; +} + +} // namespace octopus diff --git a/src/basics/read_pileup.hpp b/src/basics/read_pileup.hpp new file mode 100644 index 000000000..723a44cf1 --- /dev/null +++ b/src/basics/read_pileup.hpp @@ -0,0 +1,65 @@ +// Copyright (c) 2016 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#ifndef read_pileup_hpp +#define read_pileup_hpp + +#include +#include +#include + +#include "config/common.hpp" +#include "concepts/mappable.hpp" +#include "contig_region.hpp" +#include "aligned_read.hpp" + +namespace octopus { + +class ReadPileup : public Mappable +{ +public: + using NucleotideSequence = AlignedRead::NucleotideSequence; + using BaseQuality = AlignedRead::BaseQuality; + using MappingQuality = AlignedRead::MappingQuality; + + struct ReadSummary + { + std::vector base_qualities; + MappingQuality mapping_quality; + }; + using ReadSummaries = std::vector; + + ReadPileup() = delete; + + ReadPileup(ContigRegion::Position position); + + ReadPileup(const ReadPileup&) = default; + ReadPileup& operator=(const ReadPileup&) = default; + ReadPileup(ReadPileup&&) = default; + ReadPileup& operator=(ReadPileup&&) = default; + + ~ReadPileup() = default; + + const ContigRegion& mapped_region() const noexcept; + + unsigned depth() const noexcept; + unsigned depth(const NucleotideSequence& sequence) const noexcept; + + void add(const AlignedRead& read); + + const ReadSummaries& summaries(const NucleotideSequence& sequence) const; + + unsigned sum_base_qualities(const NucleotideSequence& sequence) const; + +private: + std::map summaries_; + ContigRegion region_; +}; + +using ReadPileups = std::vector; + +ReadPileups make_pileups(const ReadContainer& reads, const GenomicRegion& region); + +} // namespace octopus + +#endif diff --git a/src/basics/trio.cpp b/src/basics/trio.cpp index 98b8582ef..e4dcf253e 100644 --- a/src/basics/trio.cpp +++ b/src/basics/trio.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "trio.hpp" diff --git a/src/basics/trio.hpp b/src/basics/trio.hpp index 57fc4162a..341543849 100644 --- a/src/basics/trio.hpp +++ b/src/basics/trio.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef trio_hpp diff --git a/src/concepts/comparable.hpp b/src/concepts/comparable.hpp index 8b9171d68..eb0956ecd 100644 --- a/src/concepts/comparable.hpp +++ b/src/concepts/comparable.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef comparable_hpp diff --git a/src/concepts/equitable.hpp b/src/concepts/equitable.hpp index 27a273271..7b1fdabbd 100644 --- a/src/concepts/equitable.hpp +++ b/src/concepts/equitable.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef equitable_hpp diff --git a/src/concepts/mappable.hpp b/src/concepts/mappable.hpp index 552ce47ac..17c10dab8 100644 --- a/src/concepts/mappable.hpp +++ b/src/concepts/mappable.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef mappable_hpp diff --git a/src/concepts/mappable_range.hpp b/src/concepts/mappable_range.hpp index 5bc36c7e1..27d8e5a77 100644 --- a/src/concepts/mappable_range.hpp +++ b/src/concepts/mappable_range.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef mappable_ranges_hpp diff --git a/src/config/cmake_config.h.in b/src/config/cmake_config.h.in index 4b4be10f1..491caf91a 100644 --- a/src/config/cmake_config.h.in +++ b/src/config/cmake_config.h.in @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef cmake_config_h diff --git a/src/config/common.cpp b/src/config/common.cpp index 79441f32c..98cc08ddf 100644 --- a/src/config/common.cpp +++ b/src/config/common.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "common.hpp" diff --git a/src/config/common.hpp b/src/config/common.hpp index 45a7ffe69..a4798c534 100644 --- a/src/config/common.hpp +++ b/src/config/common.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef common_hpp diff --git a/src/config/config.cpp b/src/config/config.cpp index 2a2cd4fdb..6c309a791 100644 --- a/src/config/config.cpp +++ b/src/config/config.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "config.hpp" @@ -7,7 +7,7 @@ namespace octopus { namespace config { -const VersionNumber Version {0, 3, 3, boost::optional {"alpha"}}; +const VersionNumber Version {0, 4, 0, boost::optional {"alpha"}}; std::ostream& operator<<(std::ostream& os, const VersionNumber& version) { @@ -23,7 +23,7 @@ const std::string BugReport {"https://github.com/luntergroup/octopus/issues"}; const std::vector Authors {"Daniel Cooke"}; -const std::string CopyrightNotice {"Copyright (c) 2017 University of Oxford"}; +const std::string CopyrightNotice {"Copyright (c) 2015-2018 University of Oxford"}; const unsigned CommandLineWidth {72}; diff --git a/src/config/config.hpp b/src/config/config.hpp index 4bde0c895..d594fab7a 100644 --- a/src/config/config.hpp +++ b/src/config/config.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef config_hpp diff --git a/src/config/octopus_vcf.cpp b/src/config/octopus_vcf.cpp index dd49b3006..7cf6e5a85 100644 --- a/src/config/octopus_vcf.cpp +++ b/src/config/octopus_vcf.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "octopus_vcf.hpp" @@ -14,30 +14,20 @@ VcfHeader::Builder make_header_template() result.set_file_format(vcfspec::version); - result.add_info("AA", "1", "String", "Ancestral allele"); result.add_info("AC", "A", "Integer", "Allele count in genotypes, for each ALT allele, in the same order as listed"); - //result.add_info("AF", "A", "Float", "Allele Frequency, for each ALT allele, in the same order as listed"); result.add_info("AN", "1", "Integer", "Total number of alleles in called genotypes"); - result.add_info("BQ", "1", "Integer", "RMS base quality at this position"); result.add_info("DP", "1", "Integer", "Combined depth across samples"); - //result.add_info("END", "1", "Integer", "End position of the variant described in this record"); result.add_info("MQ", "1", "Integer", "RMS mapping quality"); result.add_info("MQ0", "1", "Integer", "Number of MAPQ == 0 reads covering this record"); result.add_info("NS", "1", "Integer", "Number of samples with data"); - result.add_info("SB", "1", "Float", "Strand bias at this position"); result.add_format("GT", "1", "String", "Genotype"); result.add_format("DP", "1", "Integer", "Read depth at this position for this sample"); - //result.add_format("FT", "1", "String", "Sample genotype filter indicating if this genotype was “called”"); - //result.add_format("GL", "G", "Float", "log10-scaled genotype likelihoods"); - //result.add_format("GLE", "1", "Integer", "Genotype likelihoods of heterogeneous ploidy"); - //result.add_format("PL", "G", "Integer", "Phred-scaled genotype likelihoods"); - //result.add_format("GP", "G", "Float", "Phred-scaled genotype posterior probabilities"); + result.add_format("FT", "1", "String", "Sample genotype filter indicating if this genotype was “called”"); result.add_format("GQ", "1", "Integer", "Conditional genotype quality (phred-scaled)"); result.add_format("PS", "1", "String", "Phase set"); result.add_format("PQ", "1", "Integer", "Phasing quality"); result.add_format("MQ", "1", "Integer", 
"RMS mapping quality"); - result.add_format("BQ", "1", "Integer", "RMS base quality at this position"); result.add_filter("PASS", "All filters passed"); @@ -46,6 +36,8 @@ VcfHeader::Builder make_header_template() static const std::unordered_map filter_descriptions { +{spec::filter::q3, "Variant quality is below 3"}, +{spec::filter::q5, "Variant quality is below 5"}, {spec::filter::q10, "Variant quality is below 10"}, {spec::filter::q20, "Variant quality is below 20"}, {spec::filter::lowQuality, "Variant quality is low"}, @@ -61,6 +53,13 @@ static const std::unordered_map filter_descriptions {spec::filter::highGCRegion, "The GC content of the region is too high"}, {spec::filter::lowGQ, "Sample genotype quality low"}, {spec::filter::highClippedReadFraction, "High fraction of clipped reads covering position"}, +{spec::filter::bq10, "Median base quality supporting variant is less than 10"}, +{spec::filter::lowBaseQuality, "Median base quality supporting variant is low"}, +{spec::filter::highMismatchFraction, "Count of reads containing mismatch to called allele is high"}, +{spec::filter::highMismatchFraction, "Fraction of reads containing mismatch to called allele is high"}, +{spec::filter::somaticContamination, "Somatic contamination detected in a called normal sample"}, +{spec::filter::deNovoContamination, "De novo allele detected in the offsprings parents"}, +{spec::filter::readPositionBias, "Position of variant in supporting reads is biased"} }; VcfHeader::Builder& add_filter(VcfHeader::Builder& builder, const std::string& key) diff --git a/src/config/octopus_vcf.hpp b/src/config/octopus_vcf.hpp index 31747bb2c..82b73150e 100644 --- a/src/config/octopus_vcf.hpp +++ b/src/config/octopus_vcf.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef octopus_vcf_hpp @@ -13,6 +13,12 @@ namespace octopus { namespace vcf { namespace spec { +namespace allele { + +VCF_SPEC_CONSTANT nonref {""}; + +} // namespace info + namespace info { VCF_SPEC_CONSTANT modelPosterior {"MP"}; @@ -23,6 +29,8 @@ VCF_SPEC_CONSTANT reversion {"REVERSION"}; namespace filter { +VCF_SPEC_CONSTANT q3 {"q3"}; +VCF_SPEC_CONSTANT q5 {"q5"}; VCF_SPEC_CONSTANT q10 {"q10"}; VCF_SPEC_CONSTANT q20 {"q20"}; VCF_SPEC_CONSTANT lowQuality {"LQ"}; @@ -38,6 +46,13 @@ VCF_SPEC_CONSTANT filteredReadFraction {"FRF"}; VCF_SPEC_CONSTANT highGCRegion {"GC"}; VCF_SPEC_CONSTANT lowGQ {"GQ"}; VCF_SPEC_CONSTANT highClippedReadFraction {"CRF"}; +VCF_SPEC_CONSTANT bq10 {"bq10"}; +VCF_SPEC_CONSTANT lowBaseQuality {"LBQ"}; +VCF_SPEC_CONSTANT highMismatchCount {"MC"}; +VCF_SPEC_CONSTANT highMismatchFraction {"MF"}; +VCF_SPEC_CONSTANT somaticContamination {"SC"}; +VCF_SPEC_CONSTANT deNovoContamination {"DC"}; +VCF_SPEC_CONSTANT readPositionBias {"RPB"}; } // namespace filter diff --git a/src/config/option_collation.cpp b/src/config/option_collation.cpp index 56aab3b26..42e3ff2e0 100644 --- a/src/config/option_collation.cpp +++ b/src/config/option_collation.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "option_collation.hpp" @@ -30,7 +30,6 @@ #include "basics/phred.hpp" #include "basics/genomic_region.hpp" #include "basics/aligned_read.hpp" -#include "basics/ploidy_map.hpp" #include "basics/trio.hpp" #include "basics/pedigree.hpp" #include "readpipe/read_pipe_fwd.hpp" @@ -49,6 +48,7 @@ #include "exceptions/missing_file_error.hpp" #include "core/csr/filters/threshold_filter_factory.hpp" #include "core/csr/filters/training_filter_factory.hpp" +#include "core/csr/filters/random_forest_filter_factory.hpp" namespace octopus { namespace options { @@ -288,11 +288,30 @@ ReferenceGenome make_reference(const OptionMap& options) { const fs::path input_path {options.at("reference").as()}; auto resolved_path = resolve_path(input_path, options); - const auto ref_cache_size = options.at("max-reference-cache-footprint").as().num_bytes(); + auto ref_cache_size = options.at("max-reference-cache-footprint").as(); + static constexpr MemoryFootprint min_non_zero_reference_cache_size {1'000}; // 1Kb + if (ref_cache_size.num_bytes() > 0 && ref_cache_size < min_non_zero_reference_cache_size) { + static bool warned {false}; + if (!warned) { + logging::WarningLogger warn_log {}; + stream(warn_log) << "Ignoring given reference cache size of " << ref_cache_size + << " as this size is too small. 
The maximum cache size will be set to zero"; + warned = true; + } + ref_cache_size = 0; + } + static constexpr MemoryFootprint min_warn_non_zero_reference_cache_size {1'000'000}; // 1Mb + if (ref_cache_size.num_bytes() > 0 && ref_cache_size < min_warn_non_zero_reference_cache_size) { + static bool warned {false}; + if (!warned) { + logging::WarningLogger warn_log {}; + stream(warn_log) << "The given reference cache size " << ref_cache_size + << " is very small and may not result in good performance."; + warned = true; + } + } try { - return octopus::make_reference(std::move(resolved_path), - ref_cache_size, - is_threading_allowed(options)); + return octopus::make_reference(std::move(resolved_path), ref_cache_size, is_threading_allowed(options)); } catch (MissingFileError& e) { e.set_location_specified("the command line option --reference"); throw; @@ -524,7 +543,7 @@ class MissingReadPathFile : public MissingFileError MissingReadPathFile(fs::path p) : MissingFileError {std::move(p), "read path"} {}; }; -void log_and_remove_duplicates(std::vector& paths, const std::string& type) +void remove_duplicates(std::vector& paths, const std::string& type, const bool log = true) { std::sort(std::begin(paths), std::end(paths)); const auto first_duplicate = std::adjacent_find(std::begin(paths), std::end(paths)); @@ -540,24 +559,26 @@ void log_and_remove_duplicates(std::vector& paths, const std::string& paths.erase(std::unique(first_duplicate, std::end(paths)), std::end(paths)); const auto num_unique_paths = paths.size(); const auto num_duplicate_paths = num_paths - num_unique_paths; - logging::WarningLogger warn_log {}; - auto warn_log_stream = stream(warn_log); - warn_log_stream << "Ignoring " << num_duplicate_paths << " duplicate " << type << " path"; - if (num_duplicate_paths > 1) { - warn_log_stream << 's'; - } - warn_log_stream << ": "; - std::for_each(std::cbegin(duplicates), std::prev(std::cend(duplicates)), [&] (const auto& path) { - warn_log_stream << path << ", "; - 
}); - warn_log_stream << duplicates.back(); - if (num_duplicate_paths > duplicates.size()) { - warn_log_stream << " (showing unique duplicates)"; + if (log) { + logging::WarningLogger warn_log {}; + auto warn_log_stream = stream(warn_log); + warn_log_stream << "Ignoring " << num_duplicate_paths << " duplicate " << type << " path"; + if (num_duplicate_paths > 1) { + warn_log_stream << 's'; + } + warn_log_stream << ": "; + std::for_each(std::cbegin(duplicates), std::prev(std::cend(duplicates)), [&] (const auto& path) { + warn_log_stream << path << ", "; + }); + warn_log_stream << duplicates.back(); + if (num_duplicate_paths > duplicates.size()) { + warn_log_stream << " (showing unique duplicates)"; + } } } } -std::vector get_read_paths(const OptionMap& options) +std::vector get_read_paths(const OptionMap& options, const bool log = true) { using namespace utils; std::vector result {}; @@ -575,7 +596,7 @@ std::vector get_read_paths(const OptionMap& options) throw e; } auto paths = get_resolved_paths_from_file(path_to_read_paths, options); - if (paths.empty()) { + if (log && paths.empty()) { logging::WarningLogger log {}; stream(log) << "The read path file you specified " << path_to_read_paths << " in the command line option '--reads-file' is empty"; @@ -583,10 +604,15 @@ std::vector get_read_paths(const OptionMap& options) append(std::move(paths), result); } } - log_and_remove_duplicates(result, "read"); + remove_duplicates(result, "read", log); return result; } +unsigned count_read_paths(const OptionMap& options) +{ + return get_read_paths(options, false).size(); +} + ReadManager make_read_manager(const OptionMap& options) { auto read_paths = get_read_paths(options); @@ -708,11 +734,11 @@ auto make_read_filterer(const OptionMap& options) if (!options.at("allow-supplementary-alignments").as()) { result.add(make_unique()); } - if (!options.at("consider-reads-with-unmapped-segments").as()) { + if (options.at("no-reads-with-unmapped-segments").as()) { 
result.add(make_unique()); result.add(make_unique()); } - if (!options.at("consider-reads-with-distant-segments").as()) { + if (options.at("no-reads-with-distant-segments").as()) { result.add(make_unique()); } if (options.at("no-adapter-contaminated-reads").as()) { @@ -750,21 +776,37 @@ ReadPipe make_read_pipe(ReadManager& read_manager, std::vector sampl } } -auto get_default_inclusion_predicate() +auto get_default_germline_inclusion_predicate() { return coretools::DefaultInclusionPredicate {}; } +bool is_cancer_calling(const OptionMap& options) +{ + return options.at("caller").as() == "cancer" || options.count("normal-sample") == 1; +} + +auto get_default_somatic_inclusion_predicate(boost::optional normal) +{ + if (normal) { + return coretools::DefaultSomaticInclusionPredicate {*normal}; + } else { + return coretools::DefaultSomaticInclusionPredicate {}; + } +} + auto get_default_inclusion_predicate(const OptionMap& options) noexcept { using namespace coretools; using InclusionPredicate = CigarScanner::Options::InclusionPredicate; - const auto caller = options.at("caller").as(); - if (caller == "cancer") { - // TODO: specialise for this case; we need to be careful about low frequency somatics. 
- return InclusionPredicate {get_default_inclusion_predicate()}; + if (is_cancer_calling(options)) { + boost::optional normal {}; + if (is_set("normal-sample", options)) { + normal = options.at("normal-sample").as(); + } + return InclusionPredicate {get_default_somatic_inclusion_predicate(normal)}; } else { - return InclusionPredicate {get_default_inclusion_predicate()}; + return InclusionPredicate {get_default_germline_inclusion_predicate()}; } } @@ -821,14 +863,6 @@ class ConflictingSourceVariantFile : public UserError {} }; -struct DefaultRepeatGenerator -{ - std::vector operator()(const ReferenceGenome& reference, GenomicRegion region) const - { - return find_repeat_regions(reference, region); - } -}; - auto get_max_expected_heterozygosity(const OptionMap& options) { const auto snp_heterozygosity = options.at("snp-heterozygosity").as(); @@ -862,7 +896,6 @@ auto make_variant_generator_builder(const OptionMap& options) } scanner_options.match = get_default_match_predicate(); scanner_options.use_clipped_coverage_tracking = true; - scanner_options.repeat_region_generator = DefaultRepeatGenerator {}; CigarScanner::Options::MisalignmentParameters misalign_params {}; misalign_params.max_expected_mutation_rate = get_max_expected_heterozygosity(options); misalign_params.snv_threshold = as_unsigned("min-base-quality", options); @@ -914,7 +947,7 @@ auto make_variant_generator_builder(const OptionMap& options) utils::append(std::move(file_sources_paths), source_paths); } } - log_and_remove_duplicates(source_paths, "source variant"); + remove_duplicates(source_paths, "source variant"); for (const auto& source_path : source_paths) { if (!fs::exists(source_path)) { throw MissingSourceVariantFile {source_path}; @@ -927,6 +960,7 @@ auto make_variant_generator_builder(const OptionMap& options) if (is_set("min-source-quality", options)) { vcf_options.min_quality = options.at("min-source-quality").as>().score(); } + vcf_options.extract_filtered = 
options.at("extract-filtered-source-candidates").as(); result.add_vcf_extractor(std::move(source_path), vcf_options); } } @@ -949,7 +983,11 @@ auto make_variant_generator_builder(const OptionMap& options) } result.add_vcf_extractor(std::move(resolved_regenotype_path)); } - + ActiveRegionGenerator::Options active_region_options {}; + if (is_set("assemble-all", options) && options.at("assemble-all").as()) { + active_region_options.assemble_all = true; + } + result.set_active_region_generator(std::move(active_region_options)); return result; } @@ -1053,6 +1091,9 @@ class MissingPloidyFile : public MissingFileError PloidyMap get_ploidy_map(const OptionMap& options) { + if (options.at("caller").as() == "polyclone") { + return PloidyMap {1}; + } std::vector flat_plodies {}; if (is_set("contig-ploidies-file", options)) { const fs::path input_path {options.at("contig-ploidies-file").as()}; @@ -1126,14 +1167,13 @@ auto get_max_haplotypes(const OptionMap& options) } } -auto get_max_expected_log_allele_count_per_base(const OptionMap& options) +auto get_dense_variation_detector(const OptionMap& options, const boost::optional& input_reads_profile) { const auto snp_heterozygosity = options.at("snp-heterozygosity").as(); const auto indel_heterozygosity = options.at("indel-heterozygosity").as(); const auto heterozygosity = snp_heterozygosity + indel_heterozygosity; - const auto snp_heterozygosity_stdev = options.at("snp-heterozygosity-stdev").as(); - const auto max_log_allele_count_per_base = heterozygosity + 8 * snp_heterozygosity_stdev; - return max_log_allele_count_per_base; + const auto heterozygosity_stdev = options.at("snp-heterozygosity-stdev").as(); + return coretools::DenseVariationDetector {heterozygosity, heterozygosity_stdev, input_reads_profile}; } auto get_max_indicator_join_distance() noexcept @@ -1146,7 +1186,7 @@ auto get_min_flank_pad() noexcept return 2 * (2 * HaplotypeLikelihoodModel{}.pad_requirement() - 1); } -auto make_haplotype_generator_builder(const 
OptionMap& options) +auto make_haplotype_generator_builder(const OptionMap& options, const boost::optional& input_reads_profile) { const auto lagging_policy = get_lagging_policy(options); const auto max_haplotypes = get_max_haplotypes(options); @@ -1157,18 +1197,18 @@ auto make_haplotype_generator_builder(const OptionMap& options) .set_target_limit(max_haplotypes).set_holdout_limit(holdout_limit).set_overflow_limit(overflow_limit) .set_lagging_policy(lagging_policy).set_max_holdout_depth(max_holdout_depth) .set_max_indicator_join_distance(get_max_indicator_join_distance()) - .set_max_expected_log_allele_count_per_base(get_max_expected_log_allele_count_per_base(options)) + .set_dense_variation_detector(get_dense_variation_detector(options, input_reads_profile)) .set_min_flank_pad(get_min_flank_pad()); } -boost::optional get_pedigree(const OptionMap& options) +boost::optional read_ped_file(const OptionMap& options) { - if (is_set("pedigree", options)) { - const auto ped_file = resolve_path(options.at("pedigree").as(), options); - return io::read_pedigree(ped_file); - } else { - return boost::none; - } + if (is_set("pedigree", options)) { + const auto ped_file = resolve_path(options.at("pedigree").as(), options); + return io::read_pedigree(ped_file); + } else { + return boost::none; + } } class BadTrioSampleSet : public UserError @@ -1265,9 +1305,6 @@ class BadTrioSamples : public UserError auto get_caller_type(const OptionMap& options, const std::vector& samples, const boost::optional& pedigree) { - // TODO: could think about getting rid of the 'caller' option and just - // deduce the caller type directly from the options. - // Will need to report an error if conflicting caller options are given anyway. 
auto result = options.at("caller").as(); if (result == "population" && samples.size() == 1) { result = "individual"; @@ -1282,6 +1319,33 @@ auto get_caller_type(const OptionMap& options, const std::vector& sa return result; } +class BadSampleCount : public UserError +{ + std::string do_where() const override + { + return "check_caller"; + } + + std::string do_why() const override + { + return "The number of samples is not accepted by the chosen caller"; + } + + std::string do_help() const override + { + return "Check the caller documentation for the required number of samples"; + } +}; + +void check_caller(const std::string& caller, const std::vector& samples, const OptionMap& options) +{ + if (caller == "polyclone") { + if (samples.size() != 1) { + throw BadSampleCount {}; + } + } +} + auto get_child_from_trio(std::vector trio, const Pedigree& pedigree) { if (is_parent_of(trio[0], trio[1], pedigree)) return trio[1]; @@ -1332,6 +1396,22 @@ Trio make_trio(std::vector samples, const OptionMap& options, }; } +boost::optional get_pedigree(const OptionMap& options, const std::vector& samples) +{ + auto result = read_ped_file(options); + if (!result) { + if (samples.size() == 3 && is_set("maternal-sample", options) && is_set("paternal-sample", options)) { + const auto trio = make_trio(samples, options, boost::none); + result = Pedigree {}; + using Sex = Pedigree::Member::Sex; + result->add_founder(Pedigree::Member {trio.mother(), Sex::female}); + result->add_founder(Pedigree::Member {trio.father(), Sex::male}); + result->add_descendant(Pedigree::Member {trio.child(), Sex::hermaphroditic}, trio.mother(), trio.father()); + } + } + return result; +} + class UnimplementedCaller : public ProgramError { std::string do_where() const override @@ -1405,18 +1485,20 @@ auto get_normal_contamination_risk(const OptionMap& options) } CallerFactory make_caller_factory(const ReferenceGenome& reference, ReadPipe& read_pipe, - const InputRegionMap& regions, const OptionMap& options) + 
const InputRegionMap& regions, const OptionMap& options, + const boost::optional input_reads_profile) { CallerBuilder vc_builder {reference, read_pipe, make_variant_generator_builder(options), - make_haplotype_generator_builder(options)}; - const auto pedigree = get_pedigree(options); + make_haplotype_generator_builder(options, input_reads_profile)}; + const auto pedigree = read_ped_file(options); const auto caller = get_caller_type(options, read_pipe.samples(), pedigree); + check_caller(caller, read_pipe.samples(), options); vc_builder.set_caller(caller); - if (caller == "population") { + if (caller == "population" || caller == "polyclone") { logging::WarningLogger log {}; - log << "The population calling model is currently under development and may not function as expected"; + stream(log) << "The " << caller << " calling model is an experimental feature and may not function as expected"; } if (is_set("refcall", options)) { @@ -1452,6 +1534,8 @@ CallerFactory make_caller_factory(const ReferenceGenome& reference, ReadPipe& re vc_builder.set_snp_heterozygosity(options.at("snp-heterozygosity").as()); vc_builder.set_indel_heterozygosity(options.at("indel-heterozygosity").as()); } + vc_builder.set_model_based_haplotype_dedup(options.at("dedup-haplotypes-with-prior-model").as()); + vc_builder.set_independent_genotype_prior_flag(options.at("use-independent-genotype-priors").as()); if (caller == "cancer") { if (is_set("normal-sample", options)) { vc_builder.set_normal_sample(options.at("normal-sample").as()); @@ -1460,6 +1544,7 @@ CallerFactory make_caller_factory(const ReferenceGenome& reference, ReadPipe& re log << "Tumour only calling requested. 
" "Please note this feature is still under development and results and runtimes may be poor"; } + vc_builder.set_max_somatic_haplotypes(as_unsigned("max-somatic-haplotypes", options)); vc_builder.set_somatic_snv_mutation_rate(options.at("somatic-snv-mutation-rate").as()); vc_builder.set_somatic_indel_mutation_rate(options.at("somatic-indel-mutation-rate").as()); vc_builder.set_min_expected_somatic_frequency(options.at("min-expected-somatic-frequency").as()); @@ -1473,10 +1558,13 @@ CallerFactory make_caller_factory(const ReferenceGenome& reference, ReadPipe& re vc_builder.set_snv_denovo_mutation_rate(options.at("snv-denovo-mutation-rate").as()); vc_builder.set_indel_denovo_mutation_rate(options.at("indel-denovo-mutation-rate").as()); vc_builder.set_min_denovo_posterior(options.at("min-denovo-posterior").as>()); + } else if (caller == "polyclone") { + vc_builder.set_max_clones(as_unsigned("max-clones", options)); } vc_builder.set_model_filtering(allow_model_filtering(options)); - if (caller == "cancer") { - vc_builder.set_max_joint_genotypes(as_unsigned("max-cancer-genotypes", options)); + vc_builder.set_max_genotypes(as_unsigned("max-genotypes", options)); + if (is_fast_mode(options)) { + vc_builder.set_max_joint_genotypes(10'000); } else { vc_builder.set_max_joint_genotypes(as_unsigned("max-joint-genotypes", options)); } @@ -1492,11 +1580,26 @@ bool is_call_filtering_requested(const OptionMap& options) noexcept return options.at("call-filtering").as(); } -std::string get_filter_expression(const OptionMap& options) +std::string get_germline_filter_expression(const OptionMap& options) { return options.at("filter-expression").as(); } +std::string get_somatic_filter_expression(const OptionMap& options) +{ + return options.at("somatic-filter-expression").as(); +} + +std::string get_denovo_filter_expression(const OptionMap& options) +{ + return options.at("denovo-filter-expression").as(); +} + +std::string get_refcall_filter_expression(const OptionMap& options) +{ + 
return options.at("refcall-filter-expression").as(); +} + bool is_csr_training(const OptionMap& options) { return options.count("csr-training") > 0; @@ -1513,19 +1616,104 @@ std::set get_training_measures(const OptionMap& options) return result; } -std::unique_ptr make_call_filter_factory(const ReferenceGenome& reference, - ReadPipe& read_pipe, - const OptionMap& options) +class MissingForestFile : public MissingFileError +{ + std::string do_where() const override + { + return "make_call_filter_factory"; + } +public: + MissingForestFile(fs::path p, std::string type) : MissingFileError {std::move(p), std::move(type)} {}; +}; + +auto get_caller_type(const OptionMap& options, const std::vector& samples) +{ + return get_caller_type(options, samples, get_pedigree(options, samples)); +} + +std::unique_ptr +make_call_filter_factory(const ReferenceGenome& reference, ReadPipe& read_pipe, const OptionMap& options, + boost::optional temp_directory) { if (is_call_filtering_requested(options)) { - if (is_csr_training(options)) { - return std::make_unique(get_training_measures(options)); + const auto caller = get_caller_type(options, read_pipe.samples()); + if (is_set("forest-file", options)) { + auto forest_file = resolve_path(options.at("forest-file").as(), options); + if (!fs::exists(forest_file)) { + throw MissingForestFile {forest_file, "forest-file"}; + } + if (!temp_directory) temp_directory = "/tmp"; + if (caller == "cancer") { + if (is_set("somatic-forest-file", options)) { + auto somatic_forest_file = resolve_path(options.at("somatic-forest-file").as(), options); + if (!fs::exists(somatic_forest_file)) { + throw MissingForestFile {somatic_forest_file, "somatic-forest-file"}; + } + return std::make_unique(forest_file, somatic_forest_file, *temp_directory); + } else if (options.at("somatics-only").as()) { + return std::make_unique(forest_file, *temp_directory, + RandomForestFilterFactory::ForestType::somatic); + } else { + logging::WarningLogger log {}; + log << "Both 
germline and somatic forests must be provided for random forest cancer variant filtering"; + return nullptr; + } + } else if (caller == "trio") { + if (options.at("denovos-only").as()) { + return std::make_unique(forest_file, *temp_directory, + RandomForestFilterFactory::ForestType::denovo); + } else { + return std::make_unique(forest_file, *temp_directory); + } + } else { + return std::make_unique(forest_file, *temp_directory); + } + } else if (is_set("somatic-forest-file", options)) { + if (options.at("somatics-only").as()) { + auto somatic_forest_file = resolve_path(options.at("somatic-forest-file").as(), options); + if (!fs::exists(somatic_forest_file)) { + throw MissingForestFile {somatic_forest_file, "somatic-forest-file"}; + } + return std::make_unique(somatic_forest_file, *temp_directory, + RandomForestFilterFactory::ForestType::somatic); + } else { + logging::WarningLogger log {}; + log << "Both germline and somatic forests must be provided for random forest cancer variant filtering"; + return nullptr; + } } else { - return std::make_unique(get_filter_expression(options)); + if (is_csr_training(options)) { + return std::make_unique(get_training_measures(options)); + } else { + auto germline_filter_expression = get_germline_filter_expression(options); + if (caller == "cancer") { + if (options.at("somatics-only").as()) { + return std::make_unique("", get_somatic_filter_expression(options), + "", get_refcall_filter_expression(options)); + } else { + return std::make_unique("", germline_filter_expression, + "", get_somatic_filter_expression(options), + "", get_refcall_filter_expression(options)); + } + } else if (caller == "trio") { + auto denovo_filter_expression = get_denovo_filter_expression(options); + if (options.at("denovos-only").as()) { + return std::make_unique("", denovo_filter_expression, + "", get_refcall_filter_expression(options), + true, ThresholdFilterFactory::Type::denovo); + } else { + return std::make_unique("", germline_filter_expression, + 
"", denovo_filter_expression, + "", get_refcall_filter_expression(options), + ThresholdFilterFactory::Type::denovo); + } + } else { + return std::make_unique(germline_filter_expression); + } + } } - } else { - return nullptr; } + return nullptr; } bool use_calling_read_pipe_for_call_filtering(const OptionMap& options) noexcept @@ -1549,14 +1737,10 @@ ReadPipe make_default_filter_read_pipe(ReadManager& read_manager, std::vector()); filterer.add(make_unique()); filterer.add(make_unique()); - filterer.add(make_unique()); - filterer.add(make_unique>()); - filterer.add(make_unique()); return ReadPipe {read_manager, std::move(transformer), std::move(filterer), boost::none, std::move(samples)}; } -ReadPipe make_call_filter_read_pipe(ReadManager& read_manager, std::vector samples, - const OptionMap& options) +ReadPipe make_call_filter_read_pipe(ReadManager& read_manager, std::vector samples, const OptionMap& options) { if (use_calling_read_pipe_for_call_filtering(options)) { return make_read_pipe(read_manager, std::move(samples), options); @@ -1620,11 +1804,36 @@ bool is_csr_training_mode(const OptionMap& options) boost::optional filter_request(const OptionMap& options) { - if (is_set("filter-vcf", options)) { + if (is_call_filtering_requested(options) && is_set("filter-vcf", options)) { return resolve_path(options.at("filter-vcf").as(), options); } return boost::none; } +boost::optional bamout_request(const OptionMap& options) +{ + if (is_set("bamout", options)) { + return resolve_path(options.at("bamout").as(), options); + } + return boost::none; +} + +unsigned max_open_read_files(const OptionMap& options) +{ + return 2 * std::min(as_unsigned("max-open-read-files", options), count_read_paths(options)); +} + +unsigned estimate_max_open_files(const OptionMap& options) +{ + unsigned result {0}; + result += max_open_read_files(options); + if (get_output_path(options)) result += 2; + result += is_debug_mode(options); + result += is_trace_mode(options); + result += 
is_call_filtering_requested(options); + result += is_legacy_vcf_requested(options); + return result; +} + } // namespace options } // namespace octopus diff --git a/src/config/option_collation.hpp b/src/config/option_collation.hpp index 1970d0fa8..05e3e693d 100644 --- a/src/config/option_collation.hpp +++ b/src/config/option_collation.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef option_collation_hpp @@ -13,13 +13,15 @@ #include "common.hpp" #include "option_parser.hpp" -#include "utils/memory_footprint.hpp" +#include "basics/ploidy_map.hpp" +#include "core/callers/caller_factory.hpp" +#include "core/csr/filters/variant_call_filter_factory.hpp" #include "io/reference/reference_genome.hpp" #include "io/read/read_manager.hpp" #include "io/variant/vcf_writer.hpp" #include "readpipe/read_pipe.hpp" -#include "core/callers/caller_factory.hpp" -#include "core/csr/filters/variant_call_filter_factory.hpp" +#include "utils/input_reads_profiler.hpp" +#include "utils/memory_footprint.hpp" namespace fs = boost::filesystem; @@ -53,21 +55,25 @@ ReadPipe make_read_pipe(ReadManager& read_manager, std::vector sampl bool call_sites_only(const OptionMap& options); +PloidyMap get_ploidy_map(const OptionMap& options); + +boost::optional get_pedigree(const OptionMap& options, const std::vector& samples); + CallerFactory make_caller_factory(const ReferenceGenome& reference, ReadPipe& read_pipe, - const InputRegionMap& regions, const OptionMap& options); + const InputRegionMap& regions, const OptionMap& options, + boost::optional input_reads_profile = boost::none); bool is_call_filtering_requested(const OptionMap& options) noexcept; -std::unique_ptr make_call_filter_factory(const ReferenceGenome& reference, - ReadPipe& read_pipe, - const OptionMap& options); +std::unique_ptr +make_call_filter_factory(const ReferenceGenome& reference, 
ReadPipe& read_pipe, const OptionMap& options, + boost::optional temp_directory = boost::none); bool use_calling_read_pipe_for_call_filtering(const OptionMap& options) noexcept; bool keep_unfiltered_calls(const OptionMap& options) noexcept; -ReadPipe make_call_filter_read_pipe(ReadManager& read_manager, std::vector samples, - const OptionMap& options); +ReadPipe make_call_filter_read_pipe(ReadManager& read_manager, std::vector samples, const OptionMap& options); boost::optional get_output_path(const OptionMap& options); @@ -79,6 +85,10 @@ bool is_csr_training_mode(const OptionMap& options); boost::optional filter_request(const OptionMap& options); +boost::optional bamout_request(const OptionMap& options); + +unsigned estimate_max_open_files(const OptionMap& options); + } // namespace options } // namespace octopus diff --git a/src/config/option_parser.cpp b/src/config/option_parser.cpp index f349a698b..dbd556c43 100644 --- a/src/config/option_parser.cpp +++ b/src/config/option_parser.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "option_parser.hpp" @@ -171,6 +171,10 @@ OptionMap parse_options(const int argc, const char** argv) po::value(), "VCF file specifying calls to regenotype, only sites in this files will appear in the" " final output") + + ("bamout", + po::value(), + "Output a realigned BAM file") ; po::options_description transforms("Read transformations"); @@ -258,13 +262,13 @@ OptionMap parse_options(const int argc, const char** argv) po::bool_switch()->default_value(false), "Allows reads marked as supplementary alignments") - ("consider-reads-with-unmapped-segments", + ("no-reads-with-unmapped-segments", po::bool_switch()->default_value(false), - "Allows reads with unmapped template segmenets to be used for calling") + "Filter reads with unmapped template segmenets to be used for calling") - ("consider-reads-with-distant-segments", + ("no-reads-with-distant-segments", po::bool_switch()->default_value(false), - "Allows reads with template segmenets that are on different contigs") + "Filter reads with template segmenets that are on different contigs") ("no-adapter-contaminated-reads", po::bool_switch()->default_value(false), @@ -304,6 +308,10 @@ OptionMap parse_options(const int argc, const char** argv) ("min-source-quality", po::value>()->implicit_value(Phred {2.0}), "Only variants with quality above this value are considered for candidate generation") + + ("extract-filtered-source-candidates", + po::value()->default_value(false), + "Extract variants from source VCF records that have been filtered") ("min-base-quality", po::value()->default_value(20), @@ -339,7 +347,7 @@ OptionMap parse_options(const int argc, const char** argv) po::value()->default_value(200), "The maximum number of bases allowed to overlap assembly regions") - ("force-assemble", + ("assemble-all", po::bool_switch()->default_value(false), "Forces all regions to be assembled") @@ -385,9 +393,17 @@ OptionMap parse_options(const int argc, const char** argv) ("extension-level", 
po::value()->default_value(ExtensionLevel::normal), "Level of haplotype extension. Possible values are: conservative, normal, optimistic, aggressive") + + ("haplotype-extension-threshold,e", + po::value>()->default_value(Phred {100.0}, "100"), + "Haplotypes with posterior probability less than this can be filtered before extension") + + ("dedup-haplotypes-with-prior-model", + po::value()->default_value(true), + "Remove duplicate haplotypes using mutation prior model") ; - po::options_description caller("Caller (general)"); + po::options_description caller("Calling (general)"); caller.add_options() ("caller,C", po::value()->default_value("population"), @@ -425,41 +441,71 @@ OptionMap parse_options(const int argc, const char** argv) ("snp-heterozygosity,z", po::value()->default_value(0.001, "0.001"), - "SNP heterozygosity for the given samples") + "Germline SNP heterozygosity for the given samples") ("snp-heterozygosity-stdev", po::value()->default_value(0.01, "0.01"), - "Standard deviation of the SNP heterozygosity used for the given samples") + "Standard deviation of the germline SNP heterozygosity used for the given samples") ("indel-heterozygosity,y", po::value()->default_value(0.0001, "0.0001"), - "Indel heterozygosity for the given samples") + "Germline indel heterozygosity for the given samples") ("use-uniform-genotype-priors", po::bool_switch()->default_value(false), "Use a uniform prior model when calculating genotype posteriors") + ("max-genotypes", + po::value()->default_value(5000), + "The maximum number of genotypes to evaluate") + + ("max-joint-genotypes", + po::value()->default_value(1000000), + "The maximum number of joint genotype vectors to consider when computing joint" + " genotype posterior probabilities") + + ("use-independent-genotype-priors", + po::bool_switch()->default_value(false), + "Use independent genotype priors for joint calling") + ("model-posterior", po::value(), "Calculate model posteriors for every call") + + 
("inactive-flank-scoring", + po::value()->default_value(true), + "Disables additional calculation to adjust alignment score when there are inactive" + " candidates in haplotype flanking regions") + + ("model-mapping-quality", + po::value()->default_value(true), + "Include the read mapping quality in the haplotype likelihood calculation") + + ("sequence-error-model", + po::value()->default_value("HiSeq"), + "The sequencer error model to use (HiSeq or xTen)") ; - po::options_description cancer("Caller (cancer)"); + po::options_description cancer("Calling (cancer)"); cancer.add_options() ("normal-sample,N", po::value(), "Normal sample - all other samples are considered tumour") + ("max-somatic-haplotypes", + po::value()->default_value(2), + "Maximum number of somatic haplotypes that may be considered") + ("somatic-snv-mutation-rate", - po::value()->default_value(2e-05, "1e-05"), + po::value()->default_value(1e-04, "0.0001"), "Expected SNV somatic mutation rate, per megabase pair, for this sample") ("somatic-indel-mutation-rate", - po::value()->default_value(5e-06, "1e-05"), + po::value()->default_value(1e-06, "0.000001"), "Expected INDEL somatic mutation rate, per megabase pair, for this sample") ("min-expected-somatic-frequency", - po::value()->default_value(0.05, "0.05"), + po::value()->default_value(0.03, "0.03"), "Minimum expected somatic allele frequency in the sample") ("min-credible-somatic-frequency", @@ -467,27 +513,23 @@ OptionMap parse_options(const int argc, const char** argv) "Minimum credible somatic allele frequency that will be reported") ("credible-mass", - po::value()->default_value(0.99, "0.99"), + po::value()->default_value(0.9, "0.9"), "Mass of the posterior density to use for evaluating allele frequencies") ("min-somatic-posterior", po::value>()->default_value(Phred {0.5}), "Minimum posterior probability (phred scale) to emit a somatic mutation call") - ("max-cancer-genotypes", - po::value()->default_value(20000), - "The maximum number of cancer 
genotype vectors to evaluate") - ("normal-contamination-risk", po::value()->default_value(NormalContaminationRisk::low), "The risk the normal sample has contamination from the tumour") - + ("somatics-only", po::bool_switch()->default_value(false), - "Only emit somatic variant calls") + "Only emit SOMATIC mutations") ; - po::options_description trio("Caller (trio)"); + po::options_description trio("Calling (trio)"); trio.add_options() ("maternal-sample,M", po::value(), @@ -498,20 +540,27 @@ OptionMap parse_options(const int argc, const char** argv) "Paternal sample") ("snv-denovo-mutation-rate", - po::value()->default_value(1e-9, "1e-9"), + po::value()->default_value(1.3e-8, "1.3e-8"), "SNV de novo mutation rate, per base per generation") ("indel-denovo-mutation-rate", - po::value()->default_value(1e-10, "1e-10"), + po::value()->default_value(1e-9, "1e-9"), "INDEL de novo mutation rate, per base per generation") ("min-denovo-posterior", - po::value>()->default_value(Phred {0.5}), + po::value>()->default_value(Phred {3}), "Minimum posterior probability (phred scale) to emit a de novo mutation call") ("denovos-only", po::bool_switch()->default_value(false), - "Only emit de novo variant calls") + "Only emit DENOVO mutations") + ; + + po::options_description polyclone("Calling (polyclone)"); + polyclone.add_options() + ("max-clones", + po::value()->default_value(3), + "Maximum number of unique clones to consider") ; po::options_description phasing("Phasing"); @@ -522,49 +571,32 @@ OptionMap parse_options(const int argc, const char** argv) " of runtime speed. 
Possible values are: minimal, conservative, moderate, normal, aggressive") ("min-phase-score", - po::value>()->default_value(Phred {20.0}), + po::value>()->default_value(Phred {10.0}), "Minimum phase score (phred scale) required to report sites as phased") - - ("use-unconditional-phase-score", - po::bool_switch()->default_value(false), - "Computes unconditional phase scores rather than conditioning on called genotypes") - ; - - po::options_description advanced("Advanced calling algorithm"); - advanced.add_options() - ("haplotype-extension-threshold,e", - po::value>()->default_value(Phred {100.0}, "100"), - "Haplotypes with posterior probability less than this can be filtered before extension") - - ("inactive-flank-scoring", - po::value()->default_value(true), - "Disables additional calculation to adjust alignment score when there are inactive" - " candidates in haplotype flanking regions") - - ("model-mapping-quality", - po::value()->default_value(true), - "Include the read mapping quality in the haplotype likelihood calculation") - - ("max-joint-genotypes", - po::value()->default_value(1000000), - "The maximum number of joint genotype vectors to consider when computing joint" - " genotype posterior probabilities") - - ("sequence-error-model", - po::value()->default_value("HiSeq"), - "The sequencer error model to use (HiSeq or xTen)") ; - po::options_description call_filtering("Callset filtering"); + po::options_description call_filtering("CSR filtering"); call_filtering.add_options() ("call-filtering,f", po::value()->default_value(true), "Enable all variant call filtering") ("filter-expression", - po::value()->default_value("QUAL < 10 | MQ < 10 | MP < 20 | AF < 0.05 | SB > 0.98 | MQD > 0.9"), + po::value()->default_value("QUAL < 10 | MQ < 10 | MP < 10 | AF < 0.05 | SB > 0.98 | BQ < 15 | RPB > 0.99"), "Boolean expression to use to filter variant calls") + ("somatic-filter-expression", + po::value()->default_value("QUAL < 2 | GQ < 20 | MQ < 30 | SB > 0.9 | BQ < 20 | 
DP < 3 | MF > 0.2 | SC > 1 | FRF > 0.5"), + "Boolean expression to use to filter somatic variant calls") + + ("denovo-filter-expression", + po::value()->default_value("QUAL < 10 | GQ < 20 | MQ < 30 | SB > 0.9 | BQ < 20 | DP < 3 | DC > 1 | MF > 0.2 | FRF > 0.5"), + "Boolean expression to use to filter somatic variant calls") + + ("refcall-filter-expression", + po::value()->default_value("QUAL < 2 | GQ < 20 | MQ < 10 | DP < 5 | MF > 0.2"), + "Boolean expression to use to filter homozygous reference calls") + ("use-calling-reads-for-filtering", po::value()->default_value(false), "Use the original reads used for variant calling for filtering") @@ -580,12 +612,20 @@ OptionMap parse_options(const int argc, const char** argv) ("filter-vcf", po::value(), "Filter the given Octopus VCF without calling") + + ("forest-file", + po::value(), + "Trained Ranger random forest file") + + ("somatic-forest-file", + po::value(), + "Trained Ranger random forest file for somatic variants") ; po::options_description all("octopus options"); all.add(general).add(backend).add(input).add(transforms).add(filters) .add(variant_generation).add(haplotype_generation).add(caller) - .add(advanced).add(cancer).add(trio).add(phasing).add(call_filtering); + .add(cancer).add(trio).add(polyclone).add(phasing).add(call_filtering); OptionMap vm_init; po::store(run(po::command_line_parser(argc, argv).options(general).allow_unregistered()), vm_init); @@ -593,14 +633,38 @@ OptionMap parse_options(const int argc, const char** argv) if (vm_init.count("help") == 1) { po::store(run(po::command_line_parser(argc, argv).options(caller).allow_unregistered()), vm_init); if (vm_init.count("caller") == 1) { - const auto caller = vm_init.at("caller").as(); + const auto selected_caller = vm_init.at("caller").as(); validate_caller(vm_init); - if (caller == "individual") { - std::cout << all << std::endl; - } else if (caller == "population") { - std::cout << all << std::endl; - } else if (caller == "cancer") { - std::cout 
<< all << std::endl; + if (selected_caller == "individual") { + po::options_description individual_options("octopus individual calling options"); + individual_options.add(general).add(backend).add(input).add(transforms).add(filters) + .add(variant_generation).add(haplotype_generation).add(caller) + .add(phasing).add(call_filtering); + std::cout << individual_options << std::endl; + } else if (selected_caller == "trio") { + po::options_description trio_options("octopus trio calling options"); + trio_options.add(general).add(backend).add(input).add(transforms).add(filters) + .add(variant_generation).add(haplotype_generation).add(caller).add(trio) + .add(phasing).add(call_filtering); + std::cout << trio_options << std::endl; + } else if (selected_caller == "population") { + po::options_description population_options("octopus population calling options"); + population_options.add(general).add(backend).add(input).add(transforms).add(filters) + .add(variant_generation).add(haplotype_generation).add(caller) + .add(phasing).add(call_filtering); + std::cout << population_options << std::endl; + } else if (selected_caller == "cancer") { + po::options_description cancer_options("octopus cancer calling options"); + cancer_options.add(general).add(backend).add(input).add(transforms).add(filters) + .add(variant_generation).add(haplotype_generation).add(caller).add(cancer) + .add(phasing).add(call_filtering); + std::cout << cancer_options << std::endl; + } else if (selected_caller == "polyclone") { + po::options_description polyclone_options("octopus polyclone calling options"); + polyclone_options.add(general).add(backend).add(input).add(transforms).add(filters) + .add(variant_generation).add(haplotype_generation).add(caller).add(polyclone) + .add(phasing).add(call_filtering); + std::cout << polyclone_options << std::endl; } else { std::cout << all << std::endl; } @@ -894,12 +958,11 @@ void validate_caller(const OptionMap& vm) { if (vm.count("caller") == 1) { const auto caller = 
vm.at("caller").as(); - static const std::array validCallers { - "individual", "population", "cancer", "trio" + static const std::array validCallers { + "individual", "population", "cancer", "trio", "polyclone" }; if (std::find(std::cbegin(validCallers), std::cend(validCallers), caller) == std::cend(validCallers)) { - throw po::validation_error {po::validation_error::kind_t::invalid_option_value, caller, - "caller"}; + throw po::validation_error {po::validation_error::kind_t::invalid_option_value, caller, "caller"}; } } } @@ -940,7 +1003,7 @@ void validate(const OptionMap& vm) "max-open-read-files", "downsample-above", "downsample-target", "max-region-to-assemble", "fallback-kmer-gap", "organism-ploidy", "max-haplotypes", "haplotype-holdout-threshold", "haplotype-overflow", - "max-joint-genotypes" + "max-genotypes", "max-joint-genotypes", "max-somatic-haplotypes", "max-clones" }; const std::vector probability_options { "snp-heterozygosity", "snp-heterozygosity-stdev", "indel-heterozygosity", diff --git a/src/config/option_parser.hpp b/src/config/option_parser.hpp index d0ab5d2f9..f18e33b7f 100644 --- a/src/config/option_parser.hpp +++ b/src/config/option_parser.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef option_parser_hpp diff --git a/src/containers/mappable_flat_multi_set.hpp b/src/containers/mappable_flat_multi_set.hpp index e8447e52b..27d411de3 100644 --- a/src/containers/mappable_flat_multi_set.hpp +++ b/src/containers/mappable_flat_multi_set.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef mappable_set_hpp diff --git a/src/containers/mappable_flat_set.hpp b/src/containers/mappable_flat_set.hpp index ec450adef..63d3c5602 100644 --- a/src/containers/mappable_flat_set.hpp +++ b/src/containers/mappable_flat_set.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef mappable_flat_set_hpp @@ -597,15 +597,17 @@ MappableFlatSet::erase(const_iterator first, const_iter return result; } -namespace { - template - BidirIt binary_find(BidirIt first, BidirIt last, const T& value) - { - const auto it = std::lower_bound(first, last, value); - return (it != last && *it == value) ? it : last; - } +namespace detail { + +template +BidirIt binary_find(BidirIt first, BidirIt last, const T& value) +{ + const auto it = std::lower_bound(first, last, value); + return (it != last && *it == value) ? it : last; } +} // namespace detail + template template typename MappableFlatSet::size_type @@ -626,7 +628,7 @@ MappableFlatSet::erase_all(BidirIt first, const BidirIt auto last_element = std::end(elements_); while (first != last) { - const auto it = binary_find(first_contained, last_contained, *first); + const auto it = detail::binary_find(first_contained, last_contained, *first); if (it != last_contained) { const auto p = std::mismatch(std::next(it), last_contained, std::next(first), last); const auto n = std::distance(p.first, last_contained); diff --git a/src/containers/mappable_map.hpp b/src/containers/mappable_map.hpp index 3bceee166..1866db1d9 100644 --- a/src/containers/mappable_map.hpp +++ b/src/containers/mappable_map.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef mappable_map_hpp diff --git a/src/containers/matrix_map.hpp b/src/containers/matrix_map.hpp index ba5f4817a..1496a9dd7 100644 --- a/src/containers/matrix_map.hpp +++ b/src/containers/matrix_map.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef matrix_map_hpp @@ -495,6 +495,7 @@ typename KeyEqual2 = std::equal_to ZipIterator cbegin() const { return begin_; } ZipIterator cend() const { return end_; } + bool empty() const noexcept { return key2_indices_.get().empty(); } IndexSizeType size() const noexcept { return key2_indices_.get().size(); } // T& operator[](const Key2& key) diff --git a/src/containers/probability_matrix.hpp b/src/containers/probability_matrix.hpp index e7186aef6..8b10b801e 100644 --- a/src/containers/probability_matrix.hpp +++ b/src/containers/probability_matrix.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef probability_matrix_hpp diff --git a/src/core/callers/caller.cpp b/src/core/callers/caller.cpp index eccf6534a..097451a96 100644 --- a/src/core/callers/caller.cpp +++ b/src/core/callers/caller.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "caller.hpp" @@ -25,6 +25,9 @@ #include "timers.hpp" +#include "core/tools/read_assigner.hpp" +#include "core/tools/read_realigner.hpp" + namespace octopus { // public methods @@ -62,6 +65,16 @@ Caller::CallTypeSet Caller::call_types() const return do_call_types(); } +unsigned Caller::min_callable_ploidy() const +{ + return do_min_callable_ploidy(); +} + +unsigned Caller::max_callable_ploidy() const +{ + return do_max_callable_ploidy(); +} + namespace debug { template @@ -152,7 +165,7 @@ std::deque Caller::call(const GenomicRegion& call_region, ProgressMet } if (!candidate_generator_.requires_reads()) { // as we didn't fetch them earlier - reads = read_pipe_.get().fetch_reads(extract_regions(candidates)); + reads = read_pipe_.get().fetch_reads(call_region); } pause(init_timer); auto calls = call_variants(call_region, candidates, reads, progress_meter); @@ -163,12 +176,19 @@ std::deque Caller::call(const GenomicRegion& call_region, ProgressMet if (debug_log_) stream(*debug_log_) << "Converting " << calls.size() << " calls made in " << call_region << " to VCF"; return convert_to_vcf(std::move(calls), record_factory, call_region); } - + std::vector Caller::regenotype(const std::vector& variants, ProgressMeter& progress_meter) const { return {}; // TODO } +auto assign_and_realign(const std::vector& reads, const Genotype& genotype) +{ + auto result = compute_haplotype_support(genotype, reads, {AssignmentConfig::AmbiguousAction::first}); + for (auto& p : result) realign_to_reference(p.second, p.first); + return result; +} + // private methods namespace debug { @@ -263,10 +283,17 @@ std::deque Caller::call_variants(const GenomicRegion& call_region, const MappableFlatSet& candidates, const ReadMap& reads, ProgressMeter& progress_meter) const { - auto haplotype_generator = make_haplotype_generator(candidates, reads); + std::deque result {}; auto haplotype_likelihoods = make_haplotype_likelihood_cache(); + if (candidates.empty()) { + if (refcalls_requested()) { + 
utils::append(call_reference(call_region, reads), result); + } + progress_meter.log_completed(call_region); + return result; + } + auto haplotype_generator = make_haplotype_generator(candidates, reads); GeneratorStatus status; - std::deque result {}; std::vector haplotypes {}, next_haplotypes {}; GenomicRegion active_region; boost::optional next_active_region {}, prev_called_region {}; @@ -276,6 +303,11 @@ Caller::call_variants(const GenomicRegion& call_region, const MappableFlatSet& haplotypes) const { - return unique_least_complex(haplotypes, Haplotype {haplotype_region(haplotypes), reference_.get()}); + return octopus::remove_duplicates(haplotypes, Haplotype {haplotype_region(haplotypes), reference_.get()}); } Caller::GeneratorStatus @@ -622,25 +654,27 @@ void Caller::call_variants(const GenomicRegion& active_region, const auto uncalled_region = get_uncalled_region(active_region, passed_region, completed_region); auto active_candidates = extract_callable_variants(candidates, uncalled_region, prev_called_region, next_active_region, backtrack_region); - std::vector called_regions; + std::vector calls {}; if (!active_candidates.empty()) { if (debug_log_) stream(*debug_log_) << "Calling variants in region " << uncalled_region; resume(calling_timer); - auto variant_calls = wrap(call_variants(active_candidates, latents)); + calls = wrap(call_variants(active_candidates, latents)); pause(calling_timer); - if (!variant_calls.empty()) { - set_model_posteriors(variant_calls, latents, haplotypes, haplotype_likelihoods); - called_regions = extract_covered_regions(variant_calls); - set_phasing(variant_calls, latents, haplotypes, call_region); - utils::append(std::move(variant_calls), result); + if (!calls.empty()) { + set_model_posteriors(calls, latents, haplotypes, haplotype_likelihoods); + set_phasing(calls, latents, haplotypes, call_region); } } - prev_called_region = uncalled_region; if (refcalls_requested()) { - auto alleles = 
generate_candidate_reference_alleles(uncalled_region, active_candidates, called_regions); - auto reference_calls = wrap(call_reference(alleles, latents, reads)); - utils::append(std::move(reference_calls), result); + const auto refcall_region = right_overhang_region(uncalled_region, completed_region); + const auto pileups = make_pileups(reads, latents, refcall_region); + auto alleles = generate_reference_alleles(refcall_region, active_candidates, calls); + auto reference_calls = wrap(call_reference(alleles, latents, pileups)); + const auto itr = utils::append(std::move(reference_calls), calls); + std::inplace_merge(std::begin(calls), itr, std::end(calls)); } + utils::append(std::move(calls), result); + prev_called_region = uncalled_region; completed_region = encompassing_region(completed_region, passed_region); } } @@ -708,14 +742,14 @@ void set_phase(const SampleName& sample, const Phaser::PhaseSet::PhaseRegion& ph } } -void set_phasing(std::vector& calls, const Phaser::PhaseSet& phase_set, - const GenomicRegion& calling_region) +void set_phasing(std::vector& calls, const Phaser::PhaseSet& phase_set, const GenomicRegion& calling_region) { if (!calls.empty()) { const auto call_regions = extract_regions(calls); for (auto& call : calls) { const auto& call_region = mapped_region(call); for (const auto& p : phase_set.phase_regions) { + const SampleName& sample {p.first}; const auto phase = find_phase_region(p.second, call_region); if (phase && overlaps(calling_region, phase->get().region)) { if (begins_before(phase->get().region, calling_region)) { @@ -726,10 +760,10 @@ void set_phasing(std::vector& calls, const Phaser::PhaseSet& phase_ expand_lhs(phase->get().region, begin_distance(output_calls.front(), phase->get().region)), phase->get().score }; - set_phase(p.first, clipped_phase, call_regions, call); + set_phase(sample, clipped_phase, call_regions, call); } } else { - set_phase(p.first, *phase, call_regions, call); + set_phase(sample, *phase, call_regions, call); } 
} } @@ -958,150 +992,148 @@ bool Caller::done_calling(const GenomicRegion& region) const noexcept return is_empty(region); } -std::vector -Caller::generate_callable_alleles(const GenomicRegion& region, const std::vector& candidates) const -{ - using std::begin; using std::end; using std::make_move_iterator; using std::back_inserter; - auto overlapped_candidates = copy_overlapped(candidates, region); - if (is_empty(region) && overlapped_candidates.empty()) return {}; - if (overlapped_candidates.empty()) { - switch (parameters_.refcall_type) { - case RefCallType::positional: - return make_positional_reference_alleles(region, reference_); - case RefCallType::blocked: - return std::vector {make_reference_allele(region, reference_)}; - default: - return {}; - } - } - auto variant_alleles = decompose(overlapped_candidates); - if (parameters_.refcall_type == RefCallType::none) return variant_alleles; - auto covered_regions = extract_covered_regions(overlapped_candidates); - auto uncovered_regions = extract_intervening_regions(covered_regions, region); - std::vector result {}; - if (parameters_.refcall_type == Caller::RefCallType::blocked) { - auto reference_alleles = make_reference_alleles(uncovered_regions, reference_); - result.reserve(reference_alleles.size() + variant_alleles.size()); - std::merge(make_move_iterator(begin(reference_alleles)), - make_move_iterator(end(reference_alleles)), - make_move_iterator(begin(variant_alleles)), - make_move_iterator(end(variant_alleles)), - back_inserter(result)); - } else { - result.reserve(variant_alleles.size() + sum_region_sizes(uncovered_regions)); - auto uncovered_itr = begin(uncovered_regions); - auto uncovered_end = end(uncovered_regions); - for (auto&& variant_allele : variant_alleles) { - if (uncovered_itr != uncovered_end && begins_before(*uncovered_itr, variant_allele)) { - auto alleles = make_positional_reference_alleles(*uncovered_itr, reference_); - result.insert(end(result), - make_move_iterator(begin(alleles)), - 
make_move_iterator(end(alleles))); - std::advance(uncovered_itr, 1); - } - result.push_back(std::move(variant_allele)); - } - if (uncovered_itr != uncovered_end) { - auto alleles = make_positional_reference_alleles(*uncovered_itr, reference_); - result.insert(end(result), - make_move_iterator(begin(alleles)), - make_move_iterator(end(alleles))); - } - } +std::vector Caller::call_reference(const GenomicRegion& region, const ReadMap& reads) const +{ + const auto active_reads = copy_overlapped(reads, region); + const auto active_reads_region = encompassing_region(active_reads); + const auto haplotype_region = expand(active_reads_region, HaplotypeLikelihoodModel{}.pad_requirement()); + const std::vector haplotypes {{haplotype_region, reference_}}; + auto haplotype_likelihoods = make_haplotype_likelihood_cache(); + haplotype_likelihoods.populate(active_reads, haplotypes); + const auto latents = infer_latents(haplotypes, haplotype_likelihoods); + const auto pileups = make_pileups(active_reads, *latents, region); + const auto alleles = generate_reference_alleles(region); + return wrap(call_reference(alleles, *latents, pileups)); +} + +namespace { + +template +auto extract_unique_regions(const Container& mappables) +{ + auto result = extract_regions(mappables); + result.erase(std::unique(std::begin(result), std::end(result)), std::end(result)); return result; } -template -ForwardIt find_next(ForwardIt first, ForwardIt last, const Variant& candidate) +template +auto set_difference(std::vector&& first, const std::vector& second) { - return std::find_if_not(first, last, - [&] (const Allele& allele) { - return is_same_region(allele, candidate); - }); + std::vector result {}; + result.reserve(first.size()); + std::set_difference(std::make_move_iterator(std::begin(first)), + std::make_move_iterator(std::end(first)), + std::cbegin(second), std::cend(second), + std::back_inserter(result)); + return result; } -void append_allele(std::vector& alleles, const Allele& allele, - const 
Caller::RefCallType refcall_type) +auto extract_uncalled_candidate_regions(const std::vector& candidates, + const std::vector& calls) { - if (refcall_type == Caller::RefCallType::blocked && !alleles.empty() - && are_adjacent(alleles.back(), allele)) { - alleles.back() = Allele {encompassing_region(alleles.back(), allele), - alleles.back().sequence() + allele.sequence()}; - } else { - alleles.push_back(allele); + auto uncalled_regions = set_difference(extract_unique_regions(candidates), extract_unique_regions(calls)); + return extract_covered_regions(std::move(uncalled_regions)); +} + +template +auto merge(std::vector&& first, std::vector&& second) +{ + std::vector result {}; + result.reserve(first.size() + second.size()); + using std::make_move_iterator; using std::begin; using std::end; + std::merge(make_move_iterator(begin(first)), make_move_iterator(end(first)), + make_move_iterator(begin(second)), make_move_iterator(end(second)), + std::back_inserter(result)); + return result; +} + +auto extract_uncalled_reference_regions(const GenomicRegion& region, + const std::vector& candidates, + const std::vector& calls) +{ + auto uncalled_candidate_regions = extract_uncalled_candidate_regions(candidates, calls); + auto noncandidate_regions = extract_intervening_regions(extract_covered_regions(candidates), region); + auto result = merge(std::move(uncalled_candidate_regions), std::move(noncandidate_regions)); + result.erase(std::remove_if(std::begin(result), std::end(result), + [] (const auto& region) { return is_empty(region); }), + std::end(result)); + return result; +} + +auto make_positional_reference_alleles(const std::vector& regions, const ReferenceGenome& reference) +{ + std::vector result {}; + result.reserve(sum_region_sizes(regions)); + for (const auto& region : regions) { + utils::append(make_positional_reference_alleles(region, reference), result); } + return result; } -// TODO: we should catch the case where an insertion has been called and push the refcall 
-// block up a position, otherwise the returned reference allele (block) will never be called. +} // namespace + std::vector -Caller::generate_candidate_reference_alleles(const GenomicRegion& region, - const std::vector& candidates, - const std::vector& called_regions) const -{ - using std::cbegin; using std::cend; - auto callable_alleles = generate_callable_alleles(region, candidates); - if (callable_alleles.empty() || parameters_.refcall_type == RefCallType::none) return {}; - if (candidates.empty()) return callable_alleles; - auto allele_itr = cbegin(callable_alleles); - auto allele_end_itr = cend(callable_alleles); - auto called_itr = cbegin(called_regions); - auto called_end_itr = cend(called_regions); - auto candidate_itr = cbegin(candidates); - auto candidate_end_itr = cend(candidates); - std::vector result {}; - result.reserve(callable_alleles.size()); - while (allele_itr != allele_end_itr) { - if (candidate_itr == candidate_end_itr) { - append_allele(result, *allele_itr, parameters_.refcall_type); - std::copy(std::next(allele_itr), allele_end_itr, std::back_inserter(result)); - break; - } - if (called_itr == called_end_itr) { - append_allele(result, *allele_itr, parameters_.refcall_type); - if (begins_before(*allele_itr, *candidate_itr)) { - ++allele_itr; - } else { - allele_itr = find_next(allele_itr, allele_end_itr, *candidate_itr); - ++candidate_itr; - } - } else { - if (is_same_region(*called_itr, *candidate_itr)) { // called candidate - while (is_before(*allele_itr, *called_itr)) { - append_allele(result, *allele_itr, parameters_.refcall_type); - ++allele_itr; - } - allele_itr = find_next(allele_itr, allele_end_itr, *candidate_itr); - ++candidate_itr; - ++called_itr; - } else if (begins_before(*called_itr, *candidate_itr)) { // parsimonised called candidate - if (!overlaps(*allele_itr, *called_itr)) { - append_allele(result, *allele_itr, parameters_.refcall_type); - ++allele_itr; - } else { - if (begins_before(*allele_itr, *called_itr)) { // when 
variant has been left padded - append_allele(result, copy(*allele_itr, left_overhang_region(*allele_itr, *called_itr)), - parameters_.refcall_type); - } - // skip contained alleles and candidates as they include called variants - allele_itr = cend(contained_range(allele_itr, allele_end_itr, *called_itr)).base(); - candidate_itr = cend(contained_range(candidate_itr, candidate_end_itr, *called_itr)).base(); - ++called_itr; - } - } else { - append_allele(result, *allele_itr, parameters_.refcall_type); - if (begins_before(*allele_itr, *candidate_itr)) { - ++allele_itr; - } else { - allele_itr = find_next(allele_itr, allele_end_itr, *candidate_itr); - ++candidate_itr; - } +Caller::generate_reference_alleles(const GenomicRegion& region, + const std::vector& candidates, + const std::vector& calls) const +{ + auto refcall_regions = extract_uncalled_reference_regions(region, candidates, calls); + if (parameters_.refcall_type == RefCallType::positional) { + return make_positional_reference_alleles(std::move(refcall_regions), reference_); + } else { + return make_reference_alleles(std::move(refcall_regions), reference_); + } +} + +std::vector Caller::generate_reference_alleles(const GenomicRegion& region) const +{ + return generate_reference_alleles(region, {}, {}); +} + +namespace { + +auto overlap_range(std::vector& pileups, const AlignedRead& read) +{ + return overlap_range(std::begin(pileups), std::end(pileups), contig_region(read), BidirectionallySortedTag {}); +} + +} // namespace + +auto make_pileups(const std::vector& reads, const Genotype& genotype, const GenomicRegion& region) +{ + const auto realignments = assign_and_realign(reads, genotype); + ReadPileups result {}; + result.reserve(size(region)); + for (auto position = region.begin(); position < region.end(); ++position) { + result.emplace_back(position); + } + for (const auto& p : realignments) { + for (const auto& read : p.second) { + for (ReadPileup& pileup : overlap_range(result, read)) { + pileup.add(read); 
} } } return result; } +auto make_pileups(const ReadContainer& reads, const Genotype& genotype, const GenomicRegion& region) +{ + const std::vector copy {std::cbegin(reads), std::cend(reads)}; + return make_pileups(copy, genotype, region); +} + +Caller::ReadPileupMap Caller::make_pileups(const ReadMap& reads, const Latents& latents, const GenomicRegion& region) const +{ + ReadPileupMap result {}; + result.reserve(samples_.size()); + for (const auto& sample : samples_) { + const auto called_genotype = call_genotype(latents, sample); + result.emplace(sample, octopus::make_pileups(reads.at(sample), called_genotype, region)); + } + return result; +} + namespace debug { template diff --git a/src/core/callers/caller.hpp b/src/core/callers/caller.hpp index 11420ce98..646494c29 100644 --- a/src/core/callers/caller.hpp +++ b/src/core/callers/caller.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef caller_hpp @@ -28,6 +28,7 @@ #include "logging/logging.hpp" #include "io/variant/vcf_record.hpp" #include "core/tools/vcf_record_factory.hpp" +#include "basics/read_pileup.hpp" namespace octopus { @@ -63,6 +64,9 @@ class Caller CallTypeSet call_types() const; + unsigned min_callable_ploidy() const; + unsigned max_callable_ploidy() const; + std::deque call(const GenomicRegion& call_region, ProgressMeter& progress_meter) const; std::vector regenotype(const std::vector& variants, ProgressMeter& progress_meter) const; @@ -122,14 +126,20 @@ class Caller HaplotypeLikelihoodModel likelihood_model_; Phaser phaser_; Parameters parameters_; - + // virtual methods virtual std::string do_name() const = 0; virtual CallTypeSet do_call_types() const = 0; - + virtual unsigned do_min_callable_ploidy() const { return 1; } + virtual unsigned do_max_callable_ploidy() const { return max_callable_ploidy(); }; + +protected: virtual std::size_t do_remove_duplicates(std::vector& haplotypes) const; + using ReadPileupMap = std::unordered_map; + +private: virtual std::unique_ptr infer_latents(const std::vector& haplotypes, const HaplotypeLikelihoodCache& haplotype_likelihoods) const = 0; @@ -146,7 +156,7 @@ class Caller virtual std::vector> call_reference(const std::vector& alleles, const Latents& latents, - const ReadMap& reads) const = 0; + const ReadPileupMap& pileups) const = 0; // helper methods @@ -200,11 +210,13 @@ class Caller void set_phasing(std::vector& calls, const Latents& latents, const std::vector& haplotypes, const GenomicRegion& call_region) const; bool done_calling(const GenomicRegion& region) const noexcept; + std::vector call_reference(const GenomicRegion& region, const ReadMap& reads) const; std::vector - generate_callable_alleles(const GenomicRegion& region, const std::vector& candidates) const; - std::vector - generate_candidate_reference_alleles(const GenomicRegion& region, const std::vector& candidates, - const std::vector& called_regions) const; + 
generate_reference_alleles(const GenomicRegion& region, + const std::vector& candidates, + const std::vector& calls) const; + std::vector generate_reference_alleles(const GenomicRegion& region) const; + ReadPileupMap make_pileups(const ReadMap& reads, const Latents& latents, const GenomicRegion& region) const; }; } // namespace octopus diff --git a/src/core/callers/caller_builder.cpp b/src/core/callers/caller_builder.cpp index 102a4997e..352c43c42 100644 --- a/src/core/callers/caller_builder.cpp +++ b/src/core/callers/caller_builder.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "caller_builder.hpp" @@ -7,6 +7,7 @@ #include "individual_caller.hpp" #include "population_caller.hpp" #include "trio_caller.hpp" +#include "polyclone_caller.hpp" namespace octopus { @@ -150,6 +151,12 @@ CallerBuilder& CallerBuilder::set_indel_heterozygosity(double heterozygosity) no return *this; } +CallerBuilder& CallerBuilder::set_max_genotypes(unsigned max) noexcept +{ + params_.max_genotypes = max; + return *this; +} + CallerBuilder& CallerBuilder::set_max_joint_genotypes(unsigned max) noexcept { params_.max_joint_genotypes = max; @@ -162,6 +169,18 @@ CallerBuilder& CallerBuilder::set_likelihood_model(HaplotypeLikelihoodModel mode return *this; } +CallerBuilder& CallerBuilder::set_model_based_haplotype_dedup(bool use) noexcept +{ + params_.deduplicate_haplotypes_with_caller_model = use; + return *this; +} + +CallerBuilder& CallerBuilder::set_independent_genotype_prior_flag(bool use_independent) noexcept +{ + params_.use_independent_genotype_priors = use_independent; + return *this; +} + // cancer CallerBuilder& CallerBuilder::set_normal_sample(SampleName normal_sample) @@ -170,6 +189,12 @@ CallerBuilder& CallerBuilder::set_normal_sample(SampleName normal_sample) return *this; } +CallerBuilder& 
CallerBuilder::set_max_somatic_haplotypes(unsigned n) noexcept +{ + params_.max_somatic_haplotypes = n; + return *this; +} + CallerBuilder& CallerBuilder::set_somatic_snv_mutation_rate(double rate) noexcept { params_.somatic_snv_mutation_rate = rate; @@ -236,6 +261,12 @@ CallerBuilder& CallerBuilder::set_indel_denovo_mutation_rate(double rate) noexce return *this; } +CallerBuilder& CallerBuilder::set_max_clones(unsigned n) noexcept +{ + params_.max_clones = n; + return *this; +} + std::unique_ptr CallerBuilder::build(const ContigName& contig) const { if (factory_.count(caller_) == 0) { @@ -314,7 +345,8 @@ CallerBuilder::CallerFactoryMap CallerBuilder::generate_factory() const params_.ploidies.of(samples.front(), *requested_contig_), make_individual_prior_model(params_.snp_heterozygosity, params_.indel_heterozygosity), params_.min_variant_posterior, - params_.min_refcall_posterior + params_.min_refcall_posterior, + params_.deduplicate_haplotypes_with_caller_model }); }}, {"population", [this, &samples] () { @@ -326,6 +358,8 @@ CallerBuilder::CallerFactoryMap CallerBuilder::generate_factory() const get_ploidies(samples, *requested_contig_, params_.ploidies), make_population_prior_model(params_.snp_heterozygosity, params_.indel_heterozygosity), params_.max_joint_genotypes, + params_.use_independent_genotype_priors, + params_.deduplicate_haplotypes_with_caller_model }); }}, {"cancer", [this, &samples] () { @@ -342,8 +376,10 @@ CallerBuilder::CallerFactoryMap CallerBuilder::generate_factory() const params_.min_expected_somatic_frequency, params_.credible_mass, params_.min_credible_somatic_frequency, - params_.max_joint_genotypes, - params_.normal_contamination_risk + params_.max_genotypes, + params_.max_somatic_haplotypes, + params_.normal_contamination_risk, + params_.deduplicate_haplotypes_with_caller_model }); }}, {"trio", [this] () { @@ -359,8 +395,21 @@ CallerBuilder::CallerFactoryMap CallerBuilder::generate_factory() const params_.min_variant_posterior, 
params_.min_denovo_posterior, params_.min_refcall_posterior, - params_.max_joint_genotypes + params_.max_joint_genotypes, + params_.deduplicate_haplotypes_with_caller_model }); + }}, + {"polyclone", [this] () { + return std::make_unique(make_components(), + params_.general, + PolycloneCaller::Parameters { + make_individual_prior_model(params_.snp_heterozygosity, params_.indel_heterozygosity), + params_.min_variant_posterior, + params_.min_refcall_posterior, + params_.deduplicate_haplotypes_with_caller_model, + params_.max_clones, + params_.max_genotypes + }); }} }; } diff --git a/src/core/callers/caller_builder.hpp b/src/core/callers/caller_builder.hpp index 065d28820..965aa6b2b 100644 --- a/src/core/callers/caller_builder.hpp +++ b/src/core/callers/caller_builder.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef caller_builder_hpp @@ -57,11 +57,15 @@ class CallerBuilder CallerBuilder& set_min_phase_score(Phred score) noexcept; CallerBuilder& set_snp_heterozygosity(double heterozygosity) noexcept; CallerBuilder& set_indel_heterozygosity(double heterozygosity) noexcept; + CallerBuilder& set_max_genotypes(unsigned max) noexcept; CallerBuilder& set_max_joint_genotypes(unsigned max) noexcept; CallerBuilder& set_likelihood_model(HaplotypeLikelihoodModel model) noexcept; + CallerBuilder& set_model_based_haplotype_dedup(bool use) noexcept; + CallerBuilder& set_independent_genotype_prior_flag(bool use_independent) noexcept; // cancer CallerBuilder& set_normal_sample(SampleName normal_sample); + CallerBuilder& set_max_somatic_haplotypes(unsigned n) noexcept; CallerBuilder& set_somatic_snv_mutation_rate(double rate) noexcept; CallerBuilder& set_somatic_indel_mutation_rate(double rate) noexcept; CallerBuilder& set_min_expected_somatic_frequency(double frequency) noexcept; @@ -76,6 +80,9 @@ class CallerBuilder CallerBuilder& 
set_snv_denovo_mutation_rate(double rate) noexcept; CallerBuilder& set_indel_denovo_mutation_rate(double rate) noexcept; + // prokaryote + CallerBuilder& set_max_clones(unsigned n) noexcept; + // pedigree CallerBuilder& set_pedigree(Pedigree pedigree); @@ -100,10 +107,13 @@ class CallerBuilder Phred min_variant_posterior, min_refcall_posterior; boost::optional snp_heterozygosity, indel_heterozygosity; Phred min_phase_score; - unsigned max_joint_genotypes; + unsigned max_genotypes, max_joint_genotypes; + bool deduplicate_haplotypes_with_caller_model; + bool use_independent_genotype_priors; // cancer boost::optional normal_sample; + unsigned max_somatic_haplotypes; double somatic_snv_mutation_rate, somatic_indel_mutation_rate; double min_expected_somatic_frequency; double credible_mass; @@ -117,6 +127,9 @@ class CallerBuilder Phred min_denovo_posterior; boost::optional snv_denovo_mutation_rate, indel_denovo_mutation_rate; + // prokaryote + unsigned max_clones; + // pedigree boost::optional pedigree; }; diff --git a/src/core/callers/caller_factory.cpp b/src/core/callers/caller_factory.cpp index 44d75e9e5..7f2264b59 100644 --- a/src/core/callers/caller_factory.cpp +++ b/src/core/callers/caller_factory.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "caller_factory.hpp" diff --git a/src/core/callers/caller_factory.hpp b/src/core/callers/caller_factory.hpp index 6f5fa926e..297e898e2 100644 --- a/src/core/callers/caller_factory.hpp +++ b/src/core/callers/caller_factory.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef caller_factory_hpp diff --git a/src/core/callers/cancer_caller.cpp b/src/core/callers/cancer_caller.cpp index 661614f74..c2216d676 100644 --- a/src/core/callers/cancer_caller.cpp +++ b/src/core/callers/cancer_caller.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "cancer_caller.hpp" @@ -68,6 +68,8 @@ CancerCaller::CancerCaller(Caller::Components&& components, } } +// private methods + std::string CancerCaller::do_name() const { return "cancer"; @@ -81,7 +83,15 @@ CancerCaller::CallTypeSet CancerCaller::do_call_types() const }; } -// private methods +unsigned CancerCaller::do_min_callable_ploidy() const +{ + return parameters_.ploidy; +} + +unsigned CancerCaller::do_max_callable_ploidy() const +{ + return parameters_.ploidy + parameters_.max_somatic_haplotypes; +} bool CancerCaller::has_normal_sample() const noexcept { @@ -93,21 +103,36 @@ const SampleName& CancerCaller::normal_sample() const return *parameters_.normal_sample; } +std::size_t CancerCaller::do_remove_duplicates(std::vector& haplotypes) const +{ + if (parameters_.deduplicate_haplotypes_with_germline_model) { + if (haplotypes.size() < 2) return 0; + CoalescentModel::Parameters model_params {}; + if (parameters_.germline_prior_model_params) model_params = *parameters_.germline_prior_model_params; + Haplotype reference {mapped_region(haplotypes.front()), reference_.get()}; + CoalescentModel model {std::move(reference), model_params, haplotypes.size(), CoalescentModel::CachingStrategy::none}; + const CoalescentProbabilityGreater cmp {std::move(model)}; + return octopus::remove_duplicates(haplotypes, cmp); + } else { + return Caller::do_remove_duplicates(haplotypes); + } +} + std::unique_ptr CancerCaller::infer_latents(const std::vector& haplotypes, const HaplotypeLikelihoodCache& haplotype_likelihoods) const { // Store any intermediate results in 
Latents for reuse, so the order of model evaluation matters! - auto result = std::make_unique(haplotypes, samples_, get_model_priors()); + auto result = std::make_unique(haplotypes, samples_); + set_model_priors(*result); + if (has_normal_sample()) result->normal_sample_ = std::cref(normal_sample()); generate_germline_genotypes(*result, haplotypes); if (debug_log_) stream(*debug_log_) << "There are " << result->germline_genotypes_.size() << " candidate germline genotypes"; evaluate_germline_model(*result, haplotype_likelihoods); evaluate_cnv_model(*result, haplotype_likelihoods); - generate_cancer_genotypes(*result, haplotype_likelihoods); - if (debug_log_) stream(*debug_log_) << "There are " << result->cancer_genotypes_.size() << " candidate cancer genotypes"; - if (has_normal_sample()) result->normal_sample_ = std::cref(normal_sample()); - evaluate_tumour_model(*result, haplotype_likelihoods); + fit_tumour_model(*result, haplotype_likelihoods); evaluate_noise_model(*result, haplotype_likelihoods); + set_model_posteriors(*result); return result; } @@ -120,10 +145,49 @@ CancerCaller::calculate_model_posterior(const std::vector& haplotypes dynamic_cast(latents)); } +void CancerCaller::fit_tumour_model(Latents& latents, const HaplotypeLikelihoodCache& haplotype_likelihoods) const +{ + model::TumourModel::InferredLatents prev_tumour_latents; + std::vector> prev_cancer_genotypes; + boost::optional> prev_cancer_genotype_indices; + for (unsigned somatic_ploidy {1}; somatic_ploidy <= parameters_.max_somatic_haplotypes; ++somatic_ploidy) { + if (debug_log_) stream(*debug_log_) << "Fitting tumour model with somatic ploidy " << somatic_ploidy; + latents.somatic_ploidy_ = somatic_ploidy; + generate_cancer_genotypes(latents, haplotype_likelihoods); + if (debug_log_) stream(*debug_log_) << "There are " << latents.cancer_genotypes_.size() << " candidate cancer genotypes"; + evaluate_tumour_model(latents, haplotype_likelihoods); + if (somatic_ploidy > 1) { + if 
(latents.tumour_model_inferences_.approx_log_evidence <= prev_tumour_latents.approx_log_evidence) { + break; + } + } else if (latents.tumour_model_inferences_.approx_log_evidence < std::max(latents.germline_model_inferences_.log_evidence, + latents.cnv_model_inferences_.approx_log_evidence)) { + break; + } + if (latents.haplotypes_.get().size() <= somatic_ploidy + 1) break; + if (somatic_ploidy < parameters_.max_somatic_haplotypes) { + // save previous state + prev_tumour_latents = std::move(latents.tumour_model_inferences_); + prev_cancer_genotypes = std::move(latents.cancer_genotypes_); + prev_cancer_genotype_indices = std::move(latents.cancer_genotype_indices_); + } + } + if (latents.somatic_ploidy_ > 1) { + if (latents.tumour_model_inferences_.approx_log_evidence <= prev_tumour_latents.approx_log_evidence) { + // load previous state + --latents.somatic_ploidy_; + latents.tumour_model_inferences_ = std::move(prev_tumour_latents); + latents.cancer_genotypes_ = std::move(prev_cancer_genotypes); + latents.cancer_genotype_indices_ = std::move(prev_cancer_genotype_indices); + } + } + if (debug_log_) stream(*debug_log_) << "Using tumour model with somatic ploidy " << latents.somatic_ploidy_; +} + static double calculate_model_posterior(const double normal_germline_model_log_evidence, const double normal_dummy_model_log_evidence) { - constexpr double normalModelPrior {0.999}; + constexpr double normalModelPrior {0.99}; constexpr double dummyModelPrior {1.0 - normalModelPrior}; const auto normal_model_ljp = std::log(normalModelPrior) + normal_germline_model_log_evidence; const auto dummy_model_ljp = std::log(dummyModelPrior) + normal_dummy_model_log_evidence; @@ -135,7 +199,7 @@ static double calculate_model_posterior(const double germline_model_log_evidence const double dummy_model_log_evidence, const double noise_model_log_evidence) { - constexpr double normalModelPrior {0.999}; + constexpr double normalModelPrior {0.99}; constexpr double dummyModelPrior {1.0 - 
normalModelPrior}; const auto normal_model_ljp = std::log(normalModelPrior) + germline_model_log_evidence; const auto dummy_model_ljp = std::log(dummyModelPrior) + dummy_model_log_evidence; @@ -144,6 +208,19 @@ static double calculate_model_posterior(const double germline_model_log_evidence return std::exp(normal_model_ljp - norm); } +namespace { + +auto demote_each(const std::vector>& genotypes) +{ + std::vector> result {}; + result.reserve(genotypes.size()); + std::transform(std::cbegin(genotypes), std::cend(genotypes), std::back_inserter(result), + [] (const auto& genotype) { return demote(genotype); }); + return result; +} + +} // namespace + boost::optional CancerCaller::calculate_model_posterior(const std::vector& haplotypes, const HaplotypeLikelihoodCache& haplotype_likelihoods, @@ -159,14 +236,16 @@ CancerCaller::calculate_model_posterior(const std::vector& haplotypes } else { normal_inferences = germline_model.evaluate(latents.germline_genotypes_, haplotype_likelihoods); } - const auto dummy_genotypes = generate_all_genotypes(haplotypes, parameters_.ploidy + 1); + const auto dummy_genotypes = demote_each(latents.cancer_genotypes_); const auto dummy_inferences = germline_model.evaluate(dummy_genotypes, haplotype_likelihoods); - auto noise_model_priors = get_normal_noise_model_priors(germline_model.prior_model()); - const CNVModel noise_model {{normal_sample()}, std::move(noise_model_priors)}; - auto noise_inferences = noise_model.evaluate(latents.germline_genotypes_, haplotype_likelihoods); - return octopus::calculate_model_posterior(normal_inferences.log_evidence, - dummy_inferences.log_evidence, - noise_inferences.approx_log_evidence); + if (latents.noise_model_inferences_) { + return octopus::calculate_model_posterior(normal_inferences.log_evidence, + dummy_inferences.log_evidence, + latents.noise_model_inferences_->approx_log_evidence); + } else { + return octopus::calculate_model_posterior(normal_inferences.log_evidence, + 
dummy_inferences.log_evidence); + } } else { // TODO return boost::none; @@ -188,7 +267,7 @@ void CancerCaller::generate_germline_genotypes(Latents& latents, const std::vect if (haplotypes.size() < 4) { latents.germline_genotypes_ = generate_all_genotypes(haplotypes, parameters_.ploidy); } else { - std::vector> germline_genotype_indices {}; + std::vector germline_genotype_indices {}; latents.germline_genotypes_ = generate_all_genotypes(haplotypes, parameters_.ploidy, germline_genotype_indices); latents.germline_genotype_indices_ = std::move(germline_genotype_indices); } @@ -207,37 +286,37 @@ auto zip_cref(const std::vector& genotypes, const std::vector return result; } -template -auto extract_greatest_probability_genotypes(const std::vector& genotypes, - const std::vector& probabilities, - const std::size_t n, - const boost::optional min_include_probability = boost::none, - const boost::optional max_exclude_probability = boost::none) +template +auto extract_greatest_probability_values(const std::vector& values, + const std::vector& probabilities, + const std::size_t n, + const boost::optional min_include_probability = boost::none, + const boost::optional max_exclude_probability = boost::none) { - assert(genotypes.size() == probabilities.size()); - if (genotypes.size() <= n) return genotypes; - auto genotype_probabilities = zip_cref(genotypes, probabilities); - auto last_include_itr = std::next(std::begin(genotype_probabilities), n); + assert(values.size() == probabilities.size()); + if (values.size() <= n) return values; + auto value_probabilities = zip_cref(values, probabilities); + auto last_include_itr = std::next(std::begin(value_probabilities), n); const auto probability_greater = [] (const auto& lhs, const auto& rhs) noexcept { return lhs.second > rhs.second; }; - std::partial_sort(std::begin(genotype_probabilities), last_include_itr, std::end(genotype_probabilities), probability_greater); + std::partial_sort(std::begin(value_probabilities), last_include_itr, 
std::end(value_probabilities), probability_greater); if (min_include_probability) { - last_include_itr = std::upper_bound(std::begin(genotype_probabilities), last_include_itr, *min_include_probability, + last_include_itr = std::upper_bound(std::begin(value_probabilities), last_include_itr, *min_include_probability, [] (auto lhs, const auto& rhs) noexcept { return lhs > rhs.second; }); - if (last_include_itr == std::begin(genotype_probabilities)) ++last_include_itr; + if (last_include_itr == std::begin(value_probabilities)) ++last_include_itr; } if (max_exclude_probability) { - last_include_itr = std::partition(last_include_itr, std::end(genotype_probabilities), + last_include_itr = std::partition(last_include_itr, std::end(value_probabilities), [&] (const auto& p) noexcept { return p.second > *max_exclude_probability; }); } - std::vector result {}; - result.reserve(std::distance(std::begin(genotype_probabilities), last_include_itr)); - std::transform(std::begin(genotype_probabilities), last_include_itr, std::back_inserter(result), + std::vector result {}; + result.reserve(std::distance(std::begin(value_probabilities), last_include_itr)); + std::transform(std::begin(value_probabilities), last_include_itr, std::back_inserter(result), [] (const auto& p) { return p.first.get(); }); return result; } auto extract_greatest_probability_genotypes(const std::vector>& genotypes, - const std::vector>& genotype_indices, + const std::vector& genotype_indices, const std::vector& probabilities, const std::size_t n, const boost::optional min_include_probability = boost::none, @@ -245,15 +324,15 @@ auto extract_greatest_probability_genotypes(const std::vector>; - using GenotypeIndexReference = std::reference_wrapper>; + using GenotypeIndexReference = std::reference_wrapper; std::vector> zipped {}; zipped.reserve(genotypes.size()); std::transform(std::cbegin(genotypes), std::cend(genotypes), std::cbegin(genotype_indices), std::back_inserter(zipped), [] (const auto& g, const auto& 
g_idx) { return std::make_pair(std::cref(g), std::cref(g_idx)); }); - auto tmp = extract_greatest_probability_genotypes(zipped, probabilities, n, min_include_probability, max_exclude_probability); + auto tmp = extract_greatest_probability_values(zipped, probabilities, n, min_include_probability, max_exclude_probability); std::vector> result_genotypes {}; result_genotypes.reserve(tmp.size()); - std::vector> result_indices {}; + std::vector result_indices {}; result_indices.reserve(tmp.size()); for (const auto& p : tmp) { result_genotypes.push_back(p.first.get()); @@ -270,7 +349,8 @@ void CancerCaller::generate_cancer_genotypes(Latents& latents, const HaplotypeLi const auto num_haplotypes = latents.haplotypes_.get().size(); const auto num_germline_genotypes = germline_genotypes.size(); const auto max_possible_cancer_genotypes = num_haplotypes * num_germline_genotypes; - if (max_possible_cancer_genotypes <= parameters_.max_genotypes) { + const auto max_allowed_cancer_genotypes = std::max(parameters_.max_genotypes, num_germline_genotypes); + if (max_possible_cancer_genotypes <= max_allowed_cancer_genotypes) { generate_cancer_genotypes(latents, latents.germline_genotypes_); } else if (has_normal_sample()) { if (has_high_normal_contamination_risk(latents)) { @@ -283,6 +363,13 @@ void CancerCaller::generate_cancer_genotypes(Latents& latents, const HaplotypeLi } } +auto calculate_max_germline_genotype_bases(const unsigned max_genotypes, const unsigned num_haplotypes, + const unsigned somatic_ploidy) +{ + const auto num_somatic_genotypes = num_genotypes(num_haplotypes, somatic_ploidy); + return max_genotypes / num_somatic_genotypes; +} + void CancerCaller::generate_cancer_genotypes_with_clean_normal(Latents& latents, const HaplotypeLikelihoodCache& haplotype_likelihoods) const { const auto& germline_genotypes = latents.germline_genotypes_; @@ -296,47 +383,134 @@ void CancerCaller::generate_cancer_genotypes_with_clean_normal(Latents& latents, 
latents.normal_germline_inferences_ = latents.germline_model_->evaluate(germline_genotypes, haplotype_likelihoods); } const auto& germline_normal_posteriors = latents.normal_germline_inferences_->posteriors.genotype_probabilities; - const auto max_germline_genotype_bases = parameters_.max_genotypes / latents.haplotypes_.get().size(); + const auto max_allowed_cancer_genotypes = std::max(parameters_.max_genotypes, germline_genotypes.size()); + const auto max_germline_genotype_bases = calculate_max_germline_genotype_bases(max_allowed_cancer_genotypes, latents.haplotypes_.get().size(), latents.somatic_ploidy_); if (latents.germline_genotype_indices_) { std::vector> germline_bases; - std::vector> germline_bases_indices; + std::vector germline_bases_indices; std::tie(germline_bases, germline_bases_indices) = extract_greatest_probability_genotypes(germline_genotypes, *latents.germline_genotype_indices_, germline_normal_posteriors, max_germline_genotype_bases, 1e-100, 1e-2); - std::vector, unsigned>> cancer_genotype_indices {}; + std::vector cancer_genotype_indices {}; latents.cancer_genotypes_ = generate_all_cancer_genotypes(germline_bases, germline_bases_indices, - latents.haplotypes_, cancer_genotype_indices); + latents.haplotypes_, cancer_genotype_indices, + latents.somatic_ploidy_); latents.cancer_genotype_indices_ = std::move(cancer_genotype_indices); } else { - auto germline_bases = extract_greatest_probability_genotypes(germline_genotypes, germline_normal_posteriors, - max_germline_genotype_bases, 1e-100, 1e-2); - latents.cancer_genotypes_ = generate_all_cancer_genotypes(germline_bases, latents.haplotypes_); + auto germline_bases = extract_greatest_probability_values(germline_genotypes, germline_normal_posteriors, + max_germline_genotype_bases, 1e-100, 1e-2); + latents.cancer_genotypes_ = generate_all_cancer_genotypes(germline_bases, latents.haplotypes_, latents.somatic_ploidy_); } } void CancerCaller::generate_cancer_genotypes_with_contaminated_normal(Latents& 
latents, const HaplotypeLikelihoodCache& haplotype_likelihoods) const { // TODO - generate_cancer_genotypes(latents, latents.germline_genotypes_); + generate_cancer_genotypes_with_clean_normal(latents, haplotype_likelihoods); } +namespace { + +struct GenotypeReferenceEqual +{ + using GenotypeReference = std::reference_wrapper>; + std::size_t operator()(const GenotypeReference& lhs, const GenotypeReference& rhs) const + { + return lhs.get() == rhs.get(); + } +}; + +template +BidirIt binary_find(BidirIt first, BidirIt last, const T& value, Compare cmp) +{ + const auto itr = std::lower_bound(first, last, value, std::move(cmp)); + return (itr != last && *itr == value) ? itr : last; +} + +} // namespace + void CancerCaller::generate_cancer_genotypes_with_no_normal(Latents& latents, const HaplotypeLikelihoodCache& haplotype_likelihoods) const { - // TODO - generate_cancer_genotypes(latents, latents.germline_genotypes_); + const auto& haplotypes = latents.haplotypes_.get(); + const auto& germline_genotypes = latents.germline_genotypes_; + const auto& germline_genotype_posteriors = latents.germline_model_inferences_.posteriors.genotype_probabilities; + std::vector germline_model_haplotype_posteriors(haplotypes.size()); + if (latents.germline_genotype_indices_) { + GenotypeIndex buffer {}; + for (std::size_t g {0}; g < germline_genotypes.size(); ++g) { + const auto& g_indices = (*latents.germline_genotype_indices_)[g]; + for (auto idx : g_indices) { + if (std::find(std::cbegin(buffer), std::cend(buffer), idx) == std::cend(buffer)) { + germline_model_haplotype_posteriors[idx] += germline_genotype_posteriors[g]; + } + } + buffer.clear(); + } + } else { + std::unordered_map tmp {}; + tmp.reserve(haplotypes.size()); + for (std::size_t g {0}; g < germline_genotypes.size(); ++g) { + for (const auto& haplotype : germline_genotypes[g].copy_unique_ref()) { + tmp[haplotype] += germline_genotype_posteriors[g]; + } + std::transform(std::cbegin(haplotypes), std::cend(haplotypes), 
std::begin(germline_model_haplotype_posteriors), + [&tmp] (const auto& haplotype) { return tmp.at(haplotype); }); + } + } + const auto max_allowed_cancer_genotypes = std::max(parameters_.max_genotypes, germline_genotypes.size()); + const auto max_germline_genotype_bases = calculate_max_germline_genotype_bases(max_allowed_cancer_genotypes, latents.haplotypes_.get().size(), latents.somatic_ploidy_); + const auto max_germline_haplotype_bases = max_num_elements(max_germline_genotype_bases, parameters_.ploidy); + const auto top_haplotypes = extract_greatest_probability_values(haplotypes, germline_model_haplotype_posteriors, + max_germline_haplotype_bases); + auto germline_bases = generate_all_genotypes(top_haplotypes, parameters_.ploidy); + if (latents.germline_genotype_indices_) { + std::vector germline_bases_indices; + germline_bases_indices.reserve(germline_bases.size()); + if (std::is_sorted(std::cbegin(germline_genotypes), std::cend(germline_genotypes), GenotypeLess {})) { + std::sort(std::begin(germline_bases), std::end(germline_bases), GenotypeLess {}); + auto genotype_itr = std::cbegin(germline_genotypes); + for (const auto& genotype : germline_bases) { + const auto match_itr = binary_find(genotype_itr, std::cend(germline_genotypes), genotype, GenotypeLess {}); + assert(match_itr != std::cend(germline_genotypes)); + const auto idx = std::distance(std::cbegin(germline_genotypes), match_itr); + germline_bases_indices.push_back((*latents.germline_genotype_indices_)[idx]); + genotype_itr = std::next(match_itr); + } + } else { + using GenotypeReference = std::reference_wrapper>; + using GenotypeReferenceIndexMap = std::unordered_map, GenotypeReferenceEqual>; + GenotypeReferenceIndexMap genotype_indices {}; + genotype_indices.reserve(germline_genotypes.size()); + for (std::size_t i {0}; i < germline_genotypes.size(); ++i) { + genotype_indices.emplace(std::cref(germline_genotypes[i]), i); + } + for (const auto& genotype : germline_bases) { + 
germline_bases_indices.push_back((*latents.germline_genotype_indices_)[genotype_indices.at(genotype)]); + } + } + std::vector cancer_genotype_indices {}; + latents.cancer_genotypes_ = generate_all_cancer_genotypes(germline_bases, germline_bases_indices, + latents.haplotypes_, cancer_genotype_indices, + latents.somatic_ploidy_); + latents.cancer_genotype_indices_ = std::move(cancer_genotype_indices); + } else { + latents.cancer_genotypes_ = generate_all_cancer_genotypes(germline_bases, haplotypes, latents.somatic_ploidy_); + } } void CancerCaller::generate_cancer_genotypes(Latents& latents, const std::vector>& germline_genotypes) const { if (latents.germline_genotype_indices_) { - std::vector, unsigned>> cancer_genotype_indices {}; + std::vector cancer_genotype_indices {}; latents.cancer_genotypes_ = generate_all_cancer_genotypes(germline_genotypes, *latents.germline_genotype_indices_, - latents.haplotypes_, cancer_genotype_indices); + latents.haplotypes_, cancer_genotype_indices, + latents.somatic_ploidy_); latents.cancer_genotype_indices_ = std::move(cancer_genotype_indices); } else { - latents.cancer_genotypes_ = generate_all_cancer_genotypes(germline_genotypes, latents.haplotypes_); + latents.cancer_genotypes_ = generate_all_cancer_genotypes(germline_genotypes, latents.haplotypes_, latents.somatic_ploidy_); } } @@ -379,9 +553,10 @@ void CancerCaller::evaluate_tumour_model(Latents& latents, const HaplotypeLikeli assert(latents.germline_prior_model_ && !latents.cancer_genotypes_.empty()); SomaticMutationModel mutation_model {parameters_.somatic_mutation_model_params}; latents.cancer_genotype_prior_model_ = CancerGenotypePriorModel {*latents.germline_prior_model_, std::move(mutation_model)}; - auto somatic_model_priors = get_somatic_model_priors(*latents.cancer_genotype_prior_model_); + auto somatic_model_priors = get_somatic_model_priors(*latents.cancer_genotype_prior_model_, latents.somatic_ploidy_); const TumourModel somatic_model {samples_, 
somatic_model_priors}; if (latents.cancer_genotype_indices_) { + assert(latents.cancer_genotype_prior_model_->germline_model().is_primed()); latents.cancer_genotype_prior_model_->mutation_model().prime(latents.haplotypes_); latents.tumour_model_inferences_ = somatic_model.evaluate(latents.cancer_genotypes_, *latents.cancer_genotype_indices_, haplotype_likelihoods); @@ -393,20 +568,57 @@ void CancerCaller::evaluate_tumour_model(Latents& latents, const HaplotypeLikeli auto get_high_posterior_genotypes(const std::vector>& genotypes, const model::TumourModel::InferredLatents& latents) { - return extract_greatest_probability_genotypes(genotypes, latents.posteriors.genotype_probabilities, 10, 1e-3); + return extract_greatest_probability_values(genotypes, latents.posteriors.genotype_probabilities, 10, 1e-3); } void CancerCaller::evaluate_noise_model(Latents& latents, const HaplotypeLikelihoodCache& haplotype_likelihoods) const { - if (has_normal_sample()) { + if (has_normal_sample() && !has_high_normal_contamination_risk(latents)) { + if (!latents.normal_germline_inferences_) { + assert(latents.germline_model_); + haplotype_likelihoods.prime(normal_sample()); + if (latents.germline_genotype_indices_) { + latents.normal_germline_inferences_ = latents.germline_model_->evaluate(latents.germline_genotypes_, + *latents.germline_genotype_indices_, + haplotype_likelihoods); + } else { + latents.normal_germline_inferences_ = latents.germline_model_->evaluate(latents.germline_genotypes_, + haplotype_likelihoods); + } + } assert(latents.cancer_genotype_prior_model_); - auto noise_model_priors = get_noise_model_priors(*latents.cancer_genotype_prior_model_); - const TumourModel noise_model {samples_, noise_model_priors}; + auto noise_model_priors = get_noise_model_priors(*latents.cancer_genotype_prior_model_, latents.somatic_ploidy_); + const TumourModel noise_model {{*parameters_.normal_sample}, noise_model_priors}; auto noise_genotypes = 
get_high_posterior_genotypes(latents.cancer_genotypes_, latents.tumour_model_inferences_); latents.noise_model_inferences_ = noise_model.evaluate(noise_genotypes, haplotype_likelihoods); } } +void CancerCaller::set_model_priors(Latents& latents) const +{ + latents.model_priors_ = {.09, 0.01, 0.9}; +} + +void CancerCaller::set_model_posteriors(Latents& latents) const +{ + const auto& germline_inferences = latents.germline_model_inferences_; + const auto& cnv_inferences = latents.cnv_model_inferences_; + const auto& somatic_inferences = latents.tumour_model_inferences_; + const auto& model_priors = latents.model_priors_; + if (debug_log_) { + stream(*debug_log_) << "Germline model evidence: " << germline_inferences.log_evidence; + stream(*debug_log_) << "CNV model evidence: " << cnv_inferences.approx_log_evidence; + stream(*debug_log_) << "Somatic model evidence: " << somatic_inferences.approx_log_evidence; + } + const auto germline_model_jlp = std::log(model_priors.germline) + germline_inferences.log_evidence; + const auto cnv_model_jlp = std::log(model_priors.cnv) + cnv_inferences.approx_log_evidence; + const auto somatic_model_jlp = std::log(model_priors.somatic) + somatic_inferences.approx_log_evidence; + const auto norm = maths::log_sum_exp(germline_model_jlp, cnv_model_jlp, somatic_model_jlp); + latents.model_posteriors_.germline = std::exp(germline_model_jlp - norm); + latents.model_posteriors_.cnv = std::exp(cnv_model_jlp - norm); + latents.model_posteriors_.somatic = std::exp(somatic_model_jlp - norm); +} + CancerCaller::CNVModel::Priors CancerCaller::get_cnv_model_priors(const GenotypePriorModel& prior_model) const { @@ -425,34 +637,40 @@ CancerCaller::get_cnv_model_priors(const GenotypePriorModel& prior_model) const return Priors {prior_model, std::move(cnv_alphas)}; } +auto make_dirichlet_alphas(unsigned n_germline, double germline, unsigned n_somatic, double somatic) +{ + model::TumourModel::Priors::GenotypeMixturesDirichletAlphas result(n_germline + 
n_somatic); + std::fill_n(std::begin(result), n_germline, germline); + std::fill_n(std::rbegin(result), n_somatic, somatic); + return result; +} + CancerCaller::TumourModel::Priors -CancerCaller::get_somatic_model_priors(const CancerGenotypePriorModel& prior_model) const +CancerCaller::get_somatic_model_priors(const CancerGenotypePriorModel& prior_model, const unsigned somatic_ploidy) const { using Priors = TumourModel::Priors; Priors::GenotypeMixturesDirichletAlphaMap alphas {}; alphas.reserve(samples_.size()); for (const auto& sample : samples_) { if (has_normal_sample() && sample == normal_sample()) { - Priors::GenotypeMixturesDirichletAlphas sample_alphas(parameters_.ploidy + 1, parameters_.somatic_normal_germline_alpha); - sample_alphas.back() = parameters_.somatic_normal_somatic_alpha; - alphas.emplace(sample, std::move(sample_alphas)); + alphas.emplace(sample, make_dirichlet_alphas(parameters_.ploidy, parameters_.somatic_normal_germline_alpha, + somatic_ploidy, parameters_.somatic_normal_somatic_alpha)); } else { - Priors::GenotypeMixturesDirichletAlphas sample_alphas(parameters_.ploidy + 1, parameters_.somatic_tumour_germline_alpha); - sample_alphas.back() = parameters_.somatic_tumour_somatic_alpha; - alphas.emplace(sample, std::move(sample_alphas)); + alphas.emplace(sample, make_dirichlet_alphas(parameters_.ploidy, parameters_.somatic_tumour_germline_alpha, + somatic_ploidy, parameters_.somatic_tumour_somatic_alpha)); } } return Priors {prior_model, std::move(alphas)}; } CancerCaller::TumourModel::Priors -CancerCaller::get_noise_model_priors(const CancerGenotypePriorModel& prior_model) const +CancerCaller::get_noise_model_priors(const CancerGenotypePriorModel& prior_model, const unsigned somatic_ploidy) const { // The noise model is intended to capture noise that may also be present in the normal sample, // hence all samples have the same prior alphas. 
using Priors = TumourModel::Priors; - Priors::GenotypeMixturesDirichletAlphas noise_alphas(parameters_.ploidy + 1, parameters_.somatic_tumour_germline_alpha); - noise_alphas.back() = parameters_.somatic_tumour_somatic_alpha; + auto noise_alphas = make_dirichlet_alphas(parameters_.ploidy, parameters_.somatic_normal_germline_alpha, + somatic_ploidy, parameters_.somatic_tumour_somatic_alpha); Priors::GenotypeMixturesDirichletAlphaMap alphas {}; alphas.reserve(samples_.size()); for (const auto& sample : samples_) { @@ -485,6 +703,66 @@ namespace { using VariantReference = std::reference_wrapper; using VariantPosteriorVector = std::vector>>; +auto compute_marginal_credible_interval(const model::TumourModel::Priors::GenotypeMixturesDirichletAlphas& alphas, + const std::size_t k, const double mass) +{ + const auto a0 = std::accumulate(std::cbegin(alphas), std::cend(alphas), 0.0); + return maths::beta_hdi(alphas[k], a0 - alphas[k], mass); +} + +auto compute_marginal_credible_intervals(const model::TumourModel::Priors::GenotypeMixturesDirichletAlphas& alphas, + const double mass) +{ + const auto a0 = std::accumulate(std::cbegin(alphas), std::cend(alphas), 0.0); + std::vector> result {}; + result.reserve(alphas.size()); + for (const auto& alpha : alphas) { + result.push_back(maths::beta_hdi(alpha, a0 - alpha, mass)); + } + return result; +} + +using CredibleRegionMap = std::unordered_map>>; + +auto compute_marginal_credible_intervals(const model::TumourModel::Priors::GenotypeMixturesDirichletAlphaMap& alphas, + const double mass) +{ + CredibleRegionMap result {}; + result.reserve(alphas.size()); + for (const auto& p : alphas) { + result.emplace(p.first, compute_marginal_credible_intervals(p.second, mass)); + } + return result; +} + +auto compute_credible_somatic_mass(const model::TumourModel::Priors::GenotypeMixturesDirichletAlphas& alphas, + const unsigned somatic_ploidy, const double min_credible_somatic_frequency) +{ + if (somatic_ploidy == 1) { + return 
maths::dirichlet_marginal_sf(alphas, alphas.size() - 1, min_credible_somatic_frequency); + } else { + double inv_result {1.0}; + for (unsigned i {1}; i <= somatic_ploidy; ++i) { + inv_result *= maths::dirichlet_marginal_cdf(alphas, alphas.size() - i, min_credible_somatic_frequency); + } + return 1.0 - inv_result; + } +} + +auto compute_credible_somatic_mass(const model::TumourModel::Priors::GenotypeMixturesDirichletAlphaMap& alphas, + const unsigned somatic_ploidy, const double min_credible_somatic_frequency) +{ + double inv_result {1.0}; + for (const auto& p : alphas) { + if (somatic_ploidy == 1) { + inv_result *= maths::dirichlet_marginal_cdf(p.second, p.second.size() - 1, min_credible_somatic_frequency); + } else { + inv_result *= 1.0 - compute_credible_somatic_mass(p.second, somatic_ploidy, min_credible_somatic_frequency); + } + } + return 1.0 - inv_result; +} + struct GermlineVariantCall : Mappable { GermlineVariantCall() = delete; @@ -536,8 +814,7 @@ struct GermlineGenotypeCall , posterior {posterior} {} - Genotype genotype; - boost::optional somatic; + Genotype genotype, somatic; Phred posterior; }; @@ -551,7 +828,7 @@ struct CancerGenotypeCall CancerGenotype genotype; Phred posterior; - std::unordered_map>> credible_regions; + CredibleRegionMap credible_regions; }; using CancerGenotypeCalls = std::vector; @@ -603,10 +880,8 @@ auto call_candidates(const VariantPosteriorVector& candidate_posteriors, calls.reserve(candidate_posteriors.size()); std::vector uncalled {}; for (const auto& p : candidate_posteriors) { - if (p.second >= min_posterior) { - if (contains_alt(genotype_call, p.first)) { - calls.emplace_back(p.first, p.second); - } + if (p.second >= min_posterior && contains_alt(genotype_call, p.first)) { + calls.emplace_back(p.first, p.second); } else { uncalled.emplace_back(p.first); } @@ -616,6 +891,11 @@ auto call_candidates(const VariantPosteriorVector& candidate_posteriors, // somatic variant posterior +bool is_somatic(const Allele& allele, const 
CancerGenotype& genotype) +{ + return contains(genotype.somatic(), allele) && !contains(genotype.germline(), allele); +} + auto compute_somatic_variant_posteriors(const std::vector& candidates, const std::vector>& cancer_genotypes, const std::vector& cancer_genotype_posteriors, @@ -624,14 +904,12 @@ auto compute_somatic_variant_posteriors(const std::vector& can { VariantPosteriorVector result {}; result.reserve(candidates.size()); - for (const auto& candidate : candidates) { const auto& allele = candidate.get().alt_allele(); const auto p = std::inner_product(std::cbegin(cancer_genotypes), std::cend(cancer_genotypes), std::cbegin(cancer_genotype_posteriors), 0.0, std::plus<> {}, [&allele] (const auto& genotype, auto posterior) { - if (genotype.somatic_element().contains(allele) - && !contains(genotype.germline_genotype(), allele)) { + if (is_somatic(allele, genotype)) { return posterior; } else { return 0.0; @@ -640,7 +918,6 @@ auto compute_somatic_variant_posteriors(const std::vector& can const auto complement = std::min(somatic_model_posterior * p * somatic_posterior, 1.0); result.emplace_back(candidate, probability_to_phred(1.0 - complement)); } - return result; } @@ -657,44 +934,11 @@ auto call_somatic_variants(const VariantPosteriorVector& somatic_variant_posteri return result; } -template -auto compute_marginal_credible_interval(const T& alphas, const double mass) -{ - const auto a0 = std::accumulate(std::cbegin(alphas), std::cend(alphas), 0.0); - std::vector> result {}; - result.reserve(alphas.size()); - for (const auto& alpha : alphas) { - result.push_back(maths::beta_hdi(alpha, a0 - alpha, mass)); - } - return result; -} - -using CredibleRegionMap = std::unordered_map>>; - -template -auto compute_marginal_credible_intervals(const M& alphas, const double mass) -{ - CredibleRegionMap result {}; - result.reserve(alphas.size()); - for (const auto& p : alphas) { - result.emplace(p.first, compute_marginal_credible_interval(p.second, mass)); - } - return result; 
-} - -template -auto compute_somatic_mass(const T& alphas, const double c = 0.05) -{ - const auto a0 = std::accumulate(std::cbegin(alphas), std::cend(alphas), 0.0); - return maths::beta_cdf_complement(alphas.back(), a0 - alphas.back(), c); -} - -template auto call_somatic_genotypes(const CancerGenotype& called_genotype, const std::vector& called_somatic_regions, const std::vector>& genotypes, const std::vector& genotype_posteriors, - const T& credible_regions) + const CredibleRegionMap& credible_regions) { CancerGenotypeCalls result {}; result.reserve(called_somatic_regions.size()); @@ -713,29 +957,33 @@ auto call_somatic_genotypes(const CancerGenotype& called_genotype, // output -octopus::VariantCall::GenotypeCall convert(GermlineGenotypeCall call) +octopus::VariantCall::GenotypeCall demote(GermlineGenotypeCall call) { return octopus::VariantCall::GenotypeCall {std::move(call.genotype), call.posterior}; } std::unique_ptr transform_germline_call(GermlineVariantCall&& variant_call, GermlineGenotypeCall&& genotype_call, - const std::vector& samples, - const std::vector& somatic_samples) + const std::vector& samples, const std::vector& somatic_samples) { std::vector> genotypes {}; for (const auto& sample : samples) { if (std::find(std::cbegin(somatic_samples), std::cend(somatic_samples), sample) == std::cend(somatic_samples)) { - genotypes.emplace_back(sample, convert(genotype_call)); + genotypes.emplace_back(sample, demote(genotype_call)); } else { auto copy = genotype_call; - copy.genotype.emplace(*copy.somatic); - genotypes.emplace_back(sample, convert(std::move(copy))); + for (auto allele : genotype_call.somatic) copy.genotype.emplace(allele); + genotypes.emplace_back(sample, demote(std::move(copy))); } } - return std::make_unique(variant_call.variant.get(), - std::move(genotypes), - variant_call.posterior); + return std::make_unique(variant_call.variant.get(), std::move(genotypes), variant_call.posterior); +} + +template +auto find_index(const Container& values, 
const T& value) +{ + const auto itr = std::find(std::cbegin(values), std::cend(values), value); + return itr != std::cend(values) ? std::distance(std::cbegin(values), itr) : -1; } auto transform_somatic_calls(SomaticVariantCalls&& somatic_calls, CancerGenotypeCalls&& genotype_calls, @@ -743,26 +991,24 @@ auto transform_somatic_calls(SomaticVariantCalls&& somatic_calls, CancerGenotype { std::vector> result {}; result.reserve(somatic_calls.size()); - std::transform(std::make_move_iterator(std::begin(somatic_calls)), - std::make_move_iterator(std::end(somatic_calls)), - std::make_move_iterator(std::begin(genotype_calls)), - std::back_inserter(result), + std::transform(std::make_move_iterator(std::begin(somatic_calls)), std::make_move_iterator(std::end(somatic_calls)), + std::make_move_iterator(std::begin(genotype_calls)), std::back_inserter(result), [&somatic_samples] (auto&& variant_call, auto&& genotype_call) -> std::unique_ptr { std::unordered_map credible_regions {}; + const auto germline_ploidy = genotype_call.genotype.germline_ploidy(); for (const auto& p : genotype_call.credible_regions) { SomaticCall::GenotypeCredibleRegions sample_credible_regions {}; - sample_credible_regions.germline.reserve(p.second.size() - 1); + sample_credible_regions.germline.reserve(germline_ploidy); std::copy(std::cbegin(p.second), std::prev(std::cend(p.second)), std::back_inserter(sample_credible_regions.germline)); if (std::find(std::cbegin(somatic_samples), std::cend(somatic_samples), p.first) != std::cend(somatic_samples)) { - sample_credible_regions.somatic = p.second.back(); + auto somatic_idx = find_index(genotype_call.genotype.somatic(), variant_call.variant.get().alt_allele()); + sample_credible_regions.somatic = p.second[germline_ploidy + somatic_idx]; } credible_regions.emplace(p.first, std::move(sample_credible_regions)); } - return std::make_unique(variant_call.variant.get(), - std::move(genotype_call.genotype), - genotype_call.posterior, - std::move(credible_regions), + 
return std::make_unique(variant_call.variant.get(), std::move(genotype_call.genotype), + genotype_call.posterior, std::move(credible_regions), variant_call.posterior); }); return result; @@ -770,51 +1016,37 @@ auto transform_somatic_calls(SomaticVariantCalls&& somatic_calls, CancerGenotype } // namespace +namespace debug { + +template +void print_variants(S&& stream, const std::vector& variants) +{ + for (const auto& v : variants) stream << v.variant << " " << v.posterior << '\n'; +} + +} // namespace debug + std::vector> CancerCaller::call_variants(const std::vector& candidates, const Latents& latents) const { // TODO: refactor this into smaller methods! - const auto model_posteriors = calculate_model_posteriors(latents); - if (debug_log_) { - stream(*debug_log_) << "Germline model posterior: " << model_posteriors.germline; - stream(*debug_log_) << "CNV model posterior: " << model_posteriors.cnv; - stream(*debug_log_) << "Somatic model posterior: " << model_posteriors.somatic; - } - const auto sample_somatic_inv_posteriors = calculate_probability_samples_not_somatic(latents); - const auto somatic_posterior = calculate_somatic_probability(sample_somatic_inv_posteriors, model_posteriors); - const auto germline_genotype_posteriors = calculate_germline_genotype_posteriors(latents, model_posteriors); + const auto& model_posteriors = latents.model_posteriors_; + log(model_posteriors); + const auto somatic_posterior = calculate_somatic_probability(latents); + const auto germline_genotype_posteriors = calculate_germline_genotype_posteriors(latents); const auto& cancer_genotype_posteriors = latents.tumour_model_inferences_.posteriors.genotype_probabilities; + log(latents.germline_genotypes_, germline_genotype_posteriors, latents.cnv_model_inferences_, + latents.cancer_genotypes_, latents.tumour_model_inferences_); + const auto germline_candidate_posteriors = compute_candidate_posteriors(candidates, germline_genotype_posteriors); boost::optional> called_germline_genotype 
{}; boost::optional> called_cancer_genotype {}; - if (debug_log_) { - auto map_germline = find_map_genotype(germline_genotype_posteriors); - auto germline_log = stream(*debug_log_); - germline_log << "MAP germline genotype: "; - debug::print_variant_alleles(germline_log, map_germline->first); - const auto& cnv_model_genotype_posteriors = latents.cnv_model_inferences_.posteriors.genotype_probabilities; - auto cnv_posteriors = zip_cref(latents.germline_genotypes_, cnv_model_genotype_posteriors); - auto map_cnv = find_map_genotype(cnv_posteriors); - auto cnv_log = stream(*debug_log_); - cnv_log << "MAP CNV genotype: "; - debug::print_variant_alleles(cnv_log, map_cnv->first); - auto somatic_log = stream(*debug_log_); - auto cancer_posteriors = zip_cref(latents.cancer_genotypes_, cancer_genotype_posteriors); - auto map_somatic = find_map_genotype(cancer_posteriors); - called_cancer_genotype = map_somatic->first.get(); - somatic_log << "MAP cancer genotype: "; - debug::print_variant_alleles(somatic_log, *called_cancer_genotype); - somatic_log << ' ' << map_somatic->second; - } - const auto germline_candidate_posteriors = compute_candidate_posteriors(candidates, germline_genotype_posteriors); if (model_posteriors.somatic > model_posteriors.germline && somatic_posterior >= parameters_.min_somatic_posterior) { - if (debug_log_) { - *debug_log_ << "Using cancer genotype for germline genotype call"; - } + if (debug_log_) *debug_log_ << "Using cancer genotype for germline genotype call"; if (!called_cancer_genotype) { auto cancer_posteriors = zip_cref(latents.cancer_genotypes_, cancer_genotype_posteriors); called_cancer_genotype = find_map_genotype(cancer_posteriors)->first; } - called_germline_genotype = called_cancer_genotype->germline_genotype(); + called_germline_genotype = called_cancer_genotype->germline(); } else { called_germline_genotype = find_map_genotype(germline_genotype_posteriors)->first; } @@ -824,7 +1056,7 @@ CancerCaller::call_variants(const std::vector& 
candidates, const Latent *called_germline_genotype, parameters_.min_variant_posterior); std::vector> result {}; - boost::optional called_somatic_haplotype {}; + Genotype called_somatic_genotype {}; std::vector somatic_samples {}; if (somatic_posterior >= parameters_.min_somatic_posterior) { auto somatic_allele_posteriors = compute_somatic_variant_posteriors(uncalled_germline_candidates, @@ -836,7 +1068,7 @@ CancerCaller::call_variants(const std::vector& candidates, const Latent auto cancer_posteriors = zip_cref(latents.cancer_genotypes_, cancer_genotype_posteriors); called_cancer_genotype = find_map_genotype(cancer_posteriors)->first.get(); } - if (called_cancer_genotype->germline_genotype() == called_germline_genotype) { + if (called_cancer_genotype->germline() == called_germline_genotype) { auto somatic_variant_calls = call_somatic_variants(somatic_allele_posteriors, *called_cancer_genotype, parameters_.min_somatic_posterior); const auto& somatic_alphas = latents.tumour_model_inferences_.posteriors.alphas; @@ -851,28 +1083,34 @@ CancerCaller::call_variants(const std::vector& candidates, const Latent somatic_samples.push_back(p.first); } } - if (has_normal_sample() && latents.noise_model_inferences_) { - // Does the normal sample contain the called somatic variant? 
- const auto& noisy_alphas = latents.noise_model_inferences_->posteriors.alphas.at(normal_sample()); - const auto noise_credible_region = compute_marginal_credible_interval(noisy_alphas, parameters_.credible_mass).back(); - const auto somatic_mass = compute_somatic_mass(noisy_alphas, parameters_.min_expected_somatic_frequency); - if (noise_credible_region.first >= parameters_.min_credible_somatic_frequency || somatic_mass > 0.5) { - somatic_samples.clear(); + if (latents.noise_model_inferences_ && latents.normal_germline_inferences_) { + const auto noise_model_evidence = latents.noise_model_inferences_->approx_log_evidence; + const auto germline_model_evidence = latents.normal_germline_inferences_->log_evidence; + if (noise_model_evidence > germline_model_evidence) { + // Does the normal sample contain the called somatic variant? + const auto& noisy_alphas = latents.noise_model_inferences_->posteriors.alphas.at(normal_sample()); + const auto noise_mass = compute_credible_somatic_mass(noisy_alphas, latents.somatic_ploidy_, parameters_.min_expected_somatic_frequency); + if (noise_mass > 2 * parameters_.min_credible_somatic_frequency) { + somatic_samples.clear(); + } } } if (somatic_samples.empty()) { somatic_variant_calls.clear(); somatic_variant_calls.shrink_to_fit(); } else { - called_somatic_haplotype = called_cancer_genotype->somatic_element(); + called_somatic_genotype = called_cancer_genotype->somatic(); } } + if (debug_log_) { + *debug_log_ << "Called somatic variants:"; + debug::print_variants(stream(*debug_log_), somatic_variant_calls); + } const auto called_somatic_regions = extract_regions(somatic_variant_calls); auto cancer_genotype_calls = call_somatic_genotypes(*called_cancer_genotype, called_somatic_regions, latents.cancer_genotypes_, cancer_genotype_posteriors, credible_regions); - result = transform_somatic_calls(std::move(somatic_variant_calls), std::move(cancer_genotype_calls), - somatic_samples); + result = 
transform_somatic_calls(std::move(somatic_variant_calls), std::move(cancer_genotype_calls), somatic_samples); } else if (debug_log_) { stream(*debug_log_) << "Conflict between called germline genotype and called cancer genotype. Not calling somatics"; } @@ -882,20 +1120,24 @@ CancerCaller::call_variants(const std::vector& candidates, const Latent GermlineGenotypeCalls germline_genotype_calls {}; germline_genotype_calls.reserve(called_germline_regions.size()); for (const auto& region : called_germline_regions) { - auto genotype_chunk = copy(called_germline_genotype, region); + auto genotype_chunk = copy(*called_germline_genotype, region); const auto inv_posterior = std::accumulate(std::cbegin(germline_genotype_posteriors), std::cend(germline_genotype_posteriors), 0.0, [&called_germline_genotype] (const double curr, const auto& p) { return curr + (contains(p.first, *called_germline_genotype) ? 0.0 : p.second); }); - if (called_somatic_haplotype) { + if (called_somatic_genotype.ploidy() > 0) { germline_genotype_calls.emplace_back(std::move(genotype_chunk), - copy(*called_somatic_haplotype, region), + copy(called_somatic_genotype, region), probability_to_phred(inv_posterior)); } else { germline_genotype_calls.emplace_back(std::move(genotype_chunk), probability_to_phred(inv_posterior)); } } + if (debug_log_) { + *debug_log_ << "Called germline variants:"; + debug::print_variants(stream(*debug_log_), germline_variant_calls); + } result.reserve(result.size() + germline_variant_calls.size()); const auto itr = std::end(result); std::transform(std::make_move_iterator(std::begin(germline_variant_calls)), @@ -911,36 +1153,12 @@ CancerCaller::call_variants(const std::vector& candidates, const Latent return result; } -CancerCaller::ModelPriors CancerCaller::get_model_priors() const -{ - const auto s = parameters_.germline_weight + parameters_.cnv_weight + parameters_.somatic_weight; - return {parameters_.germline_weight / s, parameters_.cnv_weight / s, parameters_.somatic_weight 
/ s}; -} - -CancerCaller::ModelPosteriors -CancerCaller::calculate_model_posteriors(const Latents& latents) const -{ - const auto& germline_inferences = latents.germline_model_inferences_; - const auto& cnv_inferences = latents.cnv_model_inferences_; - const auto& somatic_inferences = latents.tumour_model_inferences_; - const auto& model_priors = latents.model_priors_; - const auto germline_model_jlp = std::log(model_priors.germline) + germline_inferences.log_evidence; - const auto cnv_model_jlp = std::log(model_priors.cnv) + cnv_inferences.approx_log_evidence; - const auto somatic_model_jlp = std::log(model_priors.somatic) + somatic_inferences.approx_log_evidence; - const auto norm = maths::log_sum_exp(germline_model_jlp, cnv_model_jlp, somatic_model_jlp); - auto germline_model_posterior = std::exp(germline_model_jlp - norm); - auto cnv_model_posterior = std::exp(cnv_model_jlp - norm); - auto somatic_model_posterior = std::exp(somatic_model_jlp - norm); - return {germline_model_posterior, cnv_model_posterior, somatic_model_posterior}; -} - CancerCaller::GermlineGenotypeProbabilityMap -CancerCaller::calculate_germline_genotype_posteriors(const Latents& latents, - const ModelPosteriors& model_posteriors) const +CancerCaller::calculate_germline_genotype_posteriors(const Latents& latents) const { + const auto& model_posteriors = latents.model_posteriors_; const auto& germline_genotypes = latents.germline_genotypes_; GermlineGenotypeProbabilityMap result {germline_genotypes.size()}; - std::transform(std::cbegin(germline_genotypes), std::cend(germline_genotypes), std::cbegin(latents.germline_model_inferences_.posteriors.genotype_probabilities), std::inserter(result, std::begin(result)), @@ -954,39 +1172,21 @@ CancerCaller::calculate_germline_genotype_posteriors(const Latents& latents, const auto& cancer_genotypes = latents.cancer_genotypes_; const auto& tumour_posteriors = latents.tumour_model_inferences_.posteriors.genotype_probabilities; for (std::size_t i {0}; i < 
cancer_genotypes.size(); ++i) { - result[cancer_genotypes[i].germline_genotype()] += model_posteriors.somatic * tumour_posteriors[i]; + result[cancer_genotypes[i].germline()] += model_posteriors.somatic * tumour_posteriors[i]; } - - return result; -} - -CancerCaller::ProbabilityVector -CancerCaller::calculate_probability_samples_not_somatic(const Latents& inferences) const -{ - std::vector result(samples_.size()); - const auto& posterior_alphas = inferences.tumour_model_inferences_.posteriors.alphas; - std::transform(std::cbegin(posterior_alphas), std::cend(posterior_alphas), - std::begin(result), [this] (const auto& p) { - const auto a0 = std::accumulate(std::cbegin(p.second), std::prev(std::cend(p.second)), 0.0); - return maths::beta_cdf(p.second.back(), a0, parameters_.min_expected_somatic_frequency); - }); return result; } -Phred CancerCaller::calculate_somatic_probability(const ProbabilityVector& sample_somatic_posteriors, - const ModelPosteriors& model_posteriors) const +Phred CancerCaller::calculate_somatic_probability(const CancerCaller::Latents& latents) const { - auto result = 1.0 - std::accumulate(std::cbegin(sample_somatic_posteriors), - std::cend(sample_somatic_posteriors), - 1.0, std::multiplies<> {}); - result *= model_posteriors.somatic; - return probability_to_phred(1 - result); + auto conditional_somatic_mass = compute_credible_somatic_mass(latents.tumour_model_inferences_.posteriors.alphas, + latents.somatic_ploidy_, + parameters_.min_expected_somatic_frequency); + return probability_to_phred(1.0 - latents.model_posteriors_.somatic * conditional_somatic_mass); } std::vector> -CancerCaller::call_reference(const std::vector& alleles, - const Caller::Latents& latents, - const ReadMap& reads) const +CancerCaller::call_reference(const std::vector& alleles, const Caller::Latents& latents, const ReadPileupMap& pileups) const { return {}; } @@ -1006,11 +1206,9 @@ std::unique_ptr CancerCaller::make_germline_prior_model(cons // CancerCaller::Latents 
CancerCaller::Latents::Latents(const std::vector& haplotypes, - const std::vector& samples, - CancerCaller::ModelPriors model_priors) + const std::vector& samples) : haplotypes_ {haplotypes} , samples_ {samples} -, model_priors_ {model_priors} {} std::shared_ptr @@ -1051,47 +1249,74 @@ auto zip(const T&... containers) -> boost::iterator_range().copy_unique_ref()) { - result.at(haplotype) += p.get<1>(); + result.at(haplotype) += model_posteriors_.germline * p.get<1>(); } } // Contribution from CNV model - Latents::HaplotypeProbabilityMap cnv_result {haplotypes_.get().size()}; - for (const auto& haplotype : haplotypes_.get()) { - cnv_result.emplace(haplotype, 0.0); - } for (const auto& p : zip(germline_genotypes_, cnv_model_inferences_.posteriors.genotype_probabilities)) { for (const auto& haplotype : p.get<0>().copy_unique_ref()) { - cnv_result.at(haplotype) += p.get<1>(); + result.at(haplotype) += model_posteriors_.cnv * p.get<1>(); } } + const auto conditional_somatic_prob = compute_credible_somatic_mass(tumour_model_inferences_.posteriors.alphas, somatic_ploidy_, 0.1); // Contribution from tumour model - Latents::HaplotypeProbabilityMap somatic_result {haplotypes_.get().size()}; - for (const auto& haplotype : haplotypes_.get()) { - somatic_result.emplace(haplotype, 0.0); - } for (const auto& p : zip(cancer_genotypes_, tumour_model_inferences_.posteriors.genotype_probabilities)) { - for (const auto& haplotype : p.get<0>().germline_genotype().copy_unique_ref()) { - somatic_result.at(haplotype) += p.get<1>(); + for (const auto& haplotype : p.get<0>().germline().copy_unique_ref()) { + result.at(haplotype) += model_posteriors_.somatic * p.get<1>(); + } + for (const auto& haplotype : p.get<0>().somatic().copy_unique_ref()) { + result.at(haplotype) += model_posteriors_.somatic * conditional_somatic_prob * p.get<1>(); } - somatic_result.at(p.get<0>().somatic_element()) += p.get<1>(); - } - for (auto& p : result) { - p.second *= model_priors_.germline; - p.second += 
model_priors_.cnv * cnv_result.at(p.first); - p.second += model_priors_.somatic * somatic_result.at(p.first); } haplotype_posteriors_ = std::make_shared(std::move(result)); } +// logging + +void CancerCaller::log(const ModelPosteriors& model_posteriors) const +{ + if (debug_log_) { + stream(*debug_log_) << "Germline model posterior: " << model_posteriors.germline; + stream(*debug_log_) << "CNV model posterior: " << model_posteriors.cnv; + stream(*debug_log_) << "Somatic model posterior: " << model_posteriors.somatic; + } +} + +void CancerCaller::log(const GenotypeVector& germline_genotypes, + const GermlineGenotypeProbabilityMap& germline_genotype_posteriors, + const CNVModel::InferredLatents& cnv_inferences, + const CancerGenotypeVector& cancer_genotypes, + const TumourModel::InferredLatents& tumour_inferences) const +{ + if (debug_log_) { + auto map_germline = find_map_genotype(germline_genotype_posteriors); + auto germline_log = stream(*debug_log_); + germline_log << "MAP germline genotype: "; + debug::print_variant_alleles(germline_log, map_germline->first); + auto cnv_posteriors = zip_cref(germline_genotypes, cnv_inferences.posteriors.genotype_probabilities); + auto map_cnv = find_map_genotype(cnv_posteriors); + auto cnv_log = stream(*debug_log_); + cnv_log << "MAP CNV genotype: "; + debug::print_variant_alleles(cnv_log, map_cnv->first); + auto somatic_log = stream(*debug_log_); + auto cancer_posteriors = zip_cref(cancer_genotypes, tumour_inferences.posteriors.genotype_probabilities); + auto map_somatic = find_map_genotype(cancer_posteriors); + auto map_cancer_genotype = map_somatic->first.get(); + somatic_log << "MAP cancer genotype: "; + debug::print_variant_alleles(somatic_log, map_cancer_genotype); + somatic_log << ' ' << map_somatic->second; + } +} + } // namespace octopus diff --git a/src/core/callers/cancer_caller.hpp b/src/core/callers/cancer_caller.hpp index 206a1ec06..b6e789bd0 100644 --- a/src/core/callers/cancer_caller.hpp +++ 
b/src/core/callers/cancer_caller.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef cancer_caller_hpp @@ -20,7 +20,7 @@ #include "core/models/mutation/coalescent_model.hpp" #include "core/models/mutation/somatic_mutation_model.hpp" #include "core/models/genotype/individual_model.hpp" -#include "core/models/genotype/cnv_model.hpp" +#include "core/models/genotype/subclone_model.hpp" #include "core/models/genotype/tumour_model.hpp" #include "basics/phred.hpp" #include "caller.hpp" @@ -46,12 +46,13 @@ class CancerCaller : public Caller boost::optional germline_prior_model_params; SomaticMutationModel::Parameters somatic_mutation_model_params; double min_expected_somatic_frequency, credible_mass, min_credible_somatic_frequency; - unsigned max_genotypes = 20000; + std::size_t max_genotypes = 20000; + unsigned max_somatic_haplotypes = 1; NormalContaminationRisk normal_contamination_risk = NormalContaminationRisk::low; - double cnv_normal_alpha = 10.0, cnv_tumour_alpha = 0.75; - double somatic_normal_germline_alpha = 10.0, somatic_normal_somatic_alpha = 0.08; - double somatic_tumour_germline_alpha = 1.0, somatic_tumour_somatic_alpha = 0.8; - double germline_weight = 70, cnv_weight = 3, somatic_weight = 2; + bool deduplicate_haplotypes_with_germline_model = true; + double cnv_normal_alpha = 50.0, cnv_tumour_alpha = 0.5; + double somatic_normal_germline_alpha = 50.0, somatic_normal_somatic_alpha = 0.05; + double somatic_tumour_germline_alpha = 1.5, somatic_tumour_somatic_alpha = 1.0; }; CancerCaller() = delete; @@ -69,7 +70,7 @@ class CancerCaller : public Caller private: using GermlineModel = model::IndividualModel; - using CNVModel = model::CNVModel; + using CNVModel = model::SubcloneModel; using TumourModel = model::TumourModel; class Latents; @@ -89,6 +90,10 @@ class CancerCaller : public Caller std::string do_name() const 
override; CallTypeSet do_call_types() const override; + unsigned do_min_callable_ploidy() const override; + unsigned do_max_callable_ploidy() const override; + + std::size_t do_remove_duplicates(std::vector& haplotypes) const override; std::unique_ptr infer_latents(const std::vector& haplotypes, @@ -112,7 +117,7 @@ class CancerCaller : public Caller std::vector> call_reference(const std::vector& alleles, const Caller::Latents& latents, - const ReadMap& reads) const override; + const ReadPileupMap& pileups) const override; bool has_normal_sample() const noexcept; const SampleName& normal_sample() const; @@ -136,19 +141,27 @@ class CancerCaller : public Caller void evaluate_tumour_model(Latents& latents, const HaplotypeLikelihoodCache& haplotype_likelihoods) const; void evaluate_noise_model(Latents& latents, const HaplotypeLikelihoodCache& haplotype_likelihoods) const; + void set_model_priors(Latents& latents) const; + void set_model_posteriors(Latents& latents) const; + + void fit_tumour_model(Latents& latents, const HaplotypeLikelihoodCache& haplotype_likelihoods) const; + std::unique_ptr make_germline_prior_model(const std::vector& haplotypes) const; CNVModel::Priors get_cnv_model_priors(const GenotypePriorModel& prior_model) const; - TumourModel::Priors get_somatic_model_priors(const CancerGenotypePriorModel& prior_model) const; - TumourModel::Priors get_noise_model_priors(const CancerGenotypePriorModel& prior_model) const; + TumourModel::Priors get_somatic_model_priors(const CancerGenotypePriorModel& prior_model, unsigned somatic_ploidy) const; + TumourModel::Priors get_noise_model_priors(const CancerGenotypePriorModel& prior_model, unsigned somatic_ploidy) const; CNVModel::Priors get_normal_noise_model_priors(const GenotypePriorModel& prior_model) const; - ModelPriors get_model_priors() const; - ModelPosteriors calculate_model_posteriors(const Latents& latents) const; - - GermlineGenotypeProbabilityMap - calculate_germline_genotype_posteriors(const Latents& 
latents, const ModelPosteriors& model_posteriors) const; - ProbabilityVector calculate_probability_samples_not_somatic(const Latents& inferences) const; - Phred calculate_somatic_probability(const ProbabilityVector& sample_somatic_posteriors, - const ModelPosteriors& model_posteriors) const; + + GermlineGenotypeProbabilityMap calculate_germline_genotype_posteriors(const Latents& latents) const; + Phred calculate_somatic_probability(const Latents& latents) const; + + // logging + void log(const ModelPosteriors& model_posteriors) const; + void log(const GenotypeVector& germline_genotypes, + const GermlineGenotypeProbabilityMap& germline_genotype_posteriors, + const CNVModel::InferredLatents& cnv_inferences, + const CancerGenotypeVector& cancer_genotypes, + const TumourModel::InferredLatents& tumour_inferences) const; }; class CancerCaller::Latents : public Caller::Latents @@ -160,8 +173,7 @@ class CancerCaller::Latents : public Caller::Latents Latents() = delete; Latents(const std::vector& haplotypes, - const std::vector& samples, - CancerCaller::ModelPriors model_priors); + const std::vector& samples); std::shared_ptr haplotype_posteriors() const override; std::shared_ptr genotype_posteriors() const override; @@ -169,9 +181,10 @@ class CancerCaller::Latents : public Caller::Latents private: std::reference_wrapper> haplotypes_; std::vector> germline_genotypes_; + unsigned somatic_ploidy_ = 1; std::vector> cancer_genotypes_; boost::optional>> germline_genotype_indices_ = boost::none; - boost::optional, unsigned>>> cancer_genotype_indices_ = boost::none; + boost::optional> cancer_genotype_indices_ = boost::none; std::reference_wrapper> samples_; boost::optional> normal_sample_ = boost::none; @@ -185,6 +198,7 @@ class CancerCaller::Latents : public Caller::Latents TumourModel::InferredLatents tumour_model_inferences_; boost::optional noise_model_inferences_ = boost::none; boost::optional normal_germline_inferences_ = boost::none; + CancerCaller::ModelPosteriors 
model_posteriors_; mutable std::shared_ptr haplotype_posteriors_ = nullptr; mutable std::shared_ptr genotype_posteriors_ = nullptr; diff --git a/src/core/callers/individual_caller.cpp b/src/core/callers/individual_caller.cpp index 2b6583859..ad5d11c7e 100644 --- a/src/core/callers/individual_caller.cpp +++ b/src/core/callers/individual_caller.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "individual_caller.hpp" @@ -53,6 +53,26 @@ IndividualCaller::CallTypeSet IndividualCaller::do_call_types() const return {std::type_index(typeid(GermlineVariantCall))}; } +unsigned IndividualCaller::do_min_callable_ploidy() const +{ + return parameters_.ploidy; +} + +std::size_t IndividualCaller::do_remove_duplicates(std::vector& haplotypes) const +{ + if (parameters_.deduplicate_haplotypes_with_germline_model) { + if (haplotypes.size() < 2) return 0; + CoalescentModel::Parameters model_params {}; + if (parameters_.prior_model_params) model_params = *parameters_.prior_model_params; + Haplotype reference {mapped_region(haplotypes.front()), reference_.get()}; + CoalescentModel model {std::move(reference), model_params, haplotypes.size(), CoalescentModel::CachingStrategy::none}; + const CoalescentProbabilityGreater cmp {std::move(model)}; + return octopus::remove_duplicates(haplotypes, cmp); + } else { + return Caller::do_remove_duplicates(haplotypes); + } +} + // IndividualCaller::Latents public methods IndividualCaller::Latents::Latents(const SampleName& sample, @@ -150,9 +170,8 @@ IndividualCaller::calculate_model_posterior(const std::vector& haplot namespace { -using GM = model::IndividualModel; using GenotypeProbabilityMap = ProbabilityMatrix>::InnerMap; -using VariantReference = std::reference_wrapper; +using VariantReference = std::reference_wrapper; using VariantPosteriorVector = std::vector>>; struct VariantCall : Mappable 
@@ -233,13 +252,10 @@ VariantCalls call_candidates(const VariantPosteriorVector& candidate_posteriors, { VariantCalls result {}; result.reserve(candidate_posteriors.size()); - - std::copy_if(std::cbegin(candidate_posteriors), std::cend(candidate_posteriors), - std::back_inserter(result), + std::copy_if(std::cbegin(candidate_posteriors), std::cend(candidate_posteriors), std::back_inserter(result), [&genotype_call, min_posterior] (const auto& p) { return p.second >= min_posterior && contains_alt(genotype_call, p.first); }); - return result; } @@ -311,13 +327,9 @@ octopus::VariantCall::GenotypeCall convert(GenotypeCall&& call) std::unique_ptr transform_call(const SampleName& sample, VariantCall&& variant_call, GenotypeCall&& genotype_call) { - std::vector> tmp { - std::make_pair(sample, convert(std::move(genotype_call))) - }; - std::unique_ptr result { - std::make_unique(variant_call.variant.get(), std::move(tmp), - variant_call.posterior) - }; + std::vector> tmp {std::make_pair(sample, convert(std::move(genotype_call)))}; + std::unique_ptr result {std::make_unique(variant_call.variant.get(), std::move(tmp), + variant_call.posterior)}; return result; } @@ -326,10 +338,8 @@ auto transform_calls(const SampleName& sample, VariantCalls&& variant_calls, { std::vector> result {}; result.reserve(variant_calls.size()); - std::transform(std::make_move_iterator(std::begin(variant_calls)), - std::make_move_iterator(std::end(variant_calls)), - std::make_move_iterator(std::begin(genotype_calls)), - std::back_inserter(result), + std::transform(std::make_move_iterator(std::begin(variant_calls)), std::make_move_iterator(std::end(variant_calls)), + std::make_move_iterator(std::begin(genotype_calls)), std::back_inserter(result), [&sample] (VariantCall&& variant_call, GenotypeCall&& genotype_call) { return transform_call(sample, std::move(variant_call), std::move(genotype_call)); }); @@ -379,81 +389,102 @@ namespace { // reference genotype calling -struct RefCall : public Mappable 
+struct RefCall { - RefCall() = default; - - template - RefCall(A&& reference_allele, double posterior) - : reference_allele {std::forward(reference_allele)} - , posterior {posterior} - {} - - const GenomicRegion& mapped_region() const noexcept { return reference_allele.mapped_region(); } - Allele reference_allele; - double posterior; + Phred posterior; }; -using RefCalls = std::vector; - -// double marginalise_reference_genotype(const Allele& reference_allele, -// const GenotypeProbabilityMap& sample_genotype_posteriors) -// { -// double result {0}; -// -// for (const auto& genotype_posterior : sample_genotype_posteriors) { -// if (is_homozygous(genotype_posterior.first, reference_allele)) { -// result += genotype_posterior.second; -// } -// } -// -// return result; -// } - -// RefCalls call_reference(const GenotypeProbabilityMap& genotype_posteriors, -// const std::vector& reference_alleles, -// const ReadMap::mapped_type& reads, const double min_call_posterior) -// { -// RefCalls result {}; -// -// if (reference_alleles.empty()) return result; -// -// result.reserve(reference_alleles.size()); -// -// for (const auto& reference_allele : reference_alleles) { -// double posterior {0}; -// -// if (has_coverage(reads, mapped_region(reference_allele))) { -// posterior = marginalise_reference_genotype(reference_allele, -// genotype_posteriors); -// } -// -// if (posterior >= min_call_posterior) { -// result.emplace_back(reference_allele, posterior); -// } -// } -// -// result.shrink_to_fit(); -// -// return result; -// } +const GenomicRegion& mapped_region(const GenotypeProbabilityMap& genotype_posteriors) +{ + return mapped_region(std::cbegin(genotype_posteriors)->first); +} + +bool has_variation(const Allele& allele, const GenotypeProbabilityMap& genotype_posteriors) +{ + if (!genotype_posteriors.empty() && !contains(mapped_region(genotype_posteriors), allele)) { + return false; + } + return std::any_of(std::cbegin(genotype_posteriors), 
std::cend(genotype_posteriors), + [&allele] (const auto& p) { + return !is_homozygous(p.first, allele); + }); +} + +auto marginalise_homozygous(const Allele& allele, const GenotypeProbabilityMap& genotype_posteriors) +{ + auto p = std::accumulate(std::cbegin(genotype_posteriors), std::cend(genotype_posteriors), 0.0, + [&] (const auto curr, const auto& p) { + return curr + (is_homozygous(p.first, allele) ? 0.0 : p.second); + }); + return probability_to_phred(p); +} + +auto mean_depth(const ReadPileups& pileups, const GenomicRegion& region) +{ + const auto overlapped = overlap_range(pileups, region.contig_region()); + return maths::mean(overlapped, [] (const auto& pileup) { return pileup.depth(); }); +} + +auto compute_homozygous_posterior(const Allele& allele, + const GenotypeProbabilityMap& genotype_posteriors, + const ReadPileups& pileups) +{ + if (has_variation(allele, genotype_posteriors)) { + return marginalise_homozygous(allele, genotype_posteriors); + } else { + const auto coverage = mean_depth(pileups, mapped_region(allele)); + return Phred {2 * static_cast(coverage)}; + } +} + +auto call_reference(const std::vector& reference_alleles, + const GenotypeProbabilityMap& genotype_posteriors, + const ReadPileups& pileups, + const Phred min_call_posterior) +{ + std::vector result {}; + for (const auto& allele : reference_alleles) { + const auto posterior = compute_homozygous_posterior(allele, genotype_posteriors, pileups); + if (posterior >= min_call_posterior) { + result.push_back({allele, posterior}); + } + } + return result; +} + +auto transform_calls(std::vector&& calls, const SampleName& sample, const unsigned ploidy) +{ + std::vector> result {}; + result.reserve(calls.size()); + std::transform(std::make_move_iterator(std::begin(calls)), std::make_move_iterator(std::end(calls)), + std::back_inserter(result), + [&] (auto&& call) { + std::map genotype {{sample, {ploidy, call.posterior}}}; + return std::make_unique(std::move(call.reference_allele), 
call.posterior, std::move(genotype)); + }); + return result; +} } // namespace std::vector> IndividualCaller::call_reference(const std::vector& alleles, const Caller::Latents& latents, - const ReadMap& reads) const + const ReadPileupMap& pileups) const { - return call_reference(alleles, dynamic_cast(latents), reads); + return call_reference(alleles, dynamic_cast(latents), pileups); } std::vector> -IndividualCaller::call_reference(const std::vector& alleles, const Latents& latents, - const ReadMap& reads) const +IndividualCaller::call_reference(const std::vector& alleles, + const Latents& latents, + const ReadPileupMap& pileups) const { - return {}; + const auto& genotype_posteriors = (*latents.genotype_posteriors_)[sample()]; + auto calls = octopus::call_reference(alleles, genotype_posteriors, pileups.at(sample()), + parameters_.min_refcall_posterior); + return transform_calls(std::move(calls), sample(), parameters_.ploidy); } const SampleName& IndividualCaller::sample() const noexcept diff --git a/src/core/callers/individual_caller.hpp b/src/core/callers/individual_caller.hpp index 2d876d6a8..8ff218600 100644 --- a/src/core/callers/individual_caller.hpp +++ b/src/core/callers/individual_caller.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef individual_caller_hpp @@ -37,6 +37,7 @@ class IndividualCaller : public Caller unsigned ploidy; boost::optional prior_model_params; Phred min_variant_posterior, min_refcall_posterior; + bool deduplicate_haplotypes_with_germline_model = false; }; IndividualCaller() = delete; @@ -59,6 +60,9 @@ class IndividualCaller : public Caller std::string do_name() const override; CallTypeSet do_call_types() const override; + unsigned do_min_callable_ploidy() const override; + + std::size_t do_remove_duplicates(std::vector& haplotypes) const override; std::unique_ptr infer_latents(const std::vector& haplotypes, @@ -82,11 +86,11 @@ class IndividualCaller : public Caller std::vector> call_reference(const std::vector& alleles, const Caller::Latents& latents, - const ReadMap& reads) const override; + const ReadPileupMap& pileups) const override; std::vector> call_reference(const std::vector& alleles, const Latents& latents, - const ReadMap& reads) const; + const ReadPileupMap& pileups) const; const SampleName& sample() const noexcept; diff --git a/src/core/callers/polyclone_caller.cpp b/src/core/callers/polyclone_caller.cpp new file mode 100644 index 000000000..77a779f4c --- /dev/null +++ b/src/core/callers/polyclone_caller.cpp @@ -0,0 +1,597 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "polyclone_caller.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "basics/genomic_region.hpp" +#include "containers/probability_matrix.hpp" +#include "core/types/allele.hpp" +#include "core/types/variant.hpp" +#include "core/types/calls/germline_variant_call.hpp" +#include "core/types/calls/reference_call.hpp" +#include "core/models/genotype/uniform_genotype_prior_model.hpp" +#include "core/models/genotype/coalescent_genotype_prior_model.hpp" +#include "utils/mappable_algorithms.hpp" +#include "utils/read_stats.hpp" +#include "utils/concat.hpp" +#include "logging/logging.hpp" + +namespace octopus { + +PolycloneCaller::PolycloneCaller(Caller::Components&& components, + Caller::Parameters general_parameters, + Parameters specific_parameters) +: Caller {std::move(components), std::move(general_parameters)} +, parameters_ {std::move(specific_parameters)} +{ + if (parameters_.max_clones < 1) { + throw std::logic_error {"PolycloneCaller: max_clones must be > 1"}; + } +} + +std::string PolycloneCaller::do_name() const +{ + return "polyclone"; +} + +PolycloneCaller::CallTypeSet PolycloneCaller::do_call_types() const +{ + return {std::type_index(typeid(GermlineVariantCall))}; +} + +unsigned PolycloneCaller::do_min_callable_ploidy() const +{ + return 1; +} + +unsigned PolycloneCaller::do_max_callable_ploidy() const +{ + return parameters_.max_clones; +} + +std::size_t PolycloneCaller::do_remove_duplicates(std::vector& haplotypes) const +{ + if (parameters_.deduplicate_haplotypes_with_germline_model) { + if (haplotypes.size() < 2) return 0; + CoalescentModel::Parameters model_params {}; + if (parameters_.prior_model_params) model_params = *parameters_.prior_model_params; + Haplotype reference {mapped_region(haplotypes.front()), reference_.get()}; + CoalescentModel model {std::move(reference), model_params, haplotypes.size(), CoalescentModel::CachingStrategy::none}; + const CoalescentProbabilityGreater 
cmp {std::move(model)}; + return octopus::remove_duplicates(haplotypes, cmp); + } else { + return Caller::do_remove_duplicates(haplotypes); + } +} + +// PolycloneCaller::Latents public methods + +PolycloneCaller::Latents::Latents(std::vector> haploid_genotypes, std::vector> polyploid_genotypes, + HaploidModelInferences haploid_model_inferences, SubloneModelInferences subclone_model_inferences, + const SampleName& sample, const std::function& clonality_prior) +: haploid_genotypes_ {std::move(haploid_genotypes)} +, polyploid_genotypes_ {std::move(polyploid_genotypes)} +, haploid_model_inferences_ {std::move(haploid_model_inferences)} +, subclone_model_inferences_ {std::move(subclone_model_inferences)} +, model_posteriors_ {} +, sample_ {sample} +{ + if (!polyploid_genotypes_.empty()) { + const auto haploid_model_prior = std::log(clonality_prior(1)); + const auto called_subclonality = polyploid_genotypes_.front().ploidy(); + const auto subclone_model_prior = std::log(clonality_prior(called_subclonality)); + const auto haploid_model_jp = haploid_model_prior + haploid_model_inferences_.log_evidence; + const auto subclone_model_jp = subclone_model_prior + subclone_model_inferences_.approx_log_evidence; + const auto norm = maths::log_sum_exp({haploid_model_jp, subclone_model_jp}); + model_posteriors_.clonal = std::exp(haploid_model_jp - norm); + model_posteriors_.subclonal = std::exp(subclone_model_jp - norm); + } else { + model_posteriors_.clonal = 1.0; + model_posteriors_.subclonal = 0.0; + } +} + +std::shared_ptr +PolycloneCaller::Latents::haplotype_posteriors() const noexcept +{ + if (haplotype_posteriors_ == nullptr) { + haplotype_posteriors_ = std::make_shared(); + for (const auto& p : (*(this->genotype_posteriors()))[sample_]) { + for (const auto& haplotype : p.first.copy_unique_ref()) { + (*haplotype_posteriors_)[haplotype] += p.second; + } + } + } + return haplotype_posteriors_; +} + +std::shared_ptr +PolycloneCaller::Latents::genotype_posteriors() const noexcept 
+{ + if (genotype_posteriors_ == nullptr) { + const auto genotypes = concat(haploid_genotypes_, polyploid_genotypes_); + auto posteriors = concat(haploid_model_inferences_.posteriors.genotype_probabilities, + subclone_model_inferences_.posteriors.genotype_probabilities); + std::for_each(std::begin(posteriors), std::next(std::begin(posteriors), haploid_genotypes_.size()), + [this] (auto& p) { p *= model_posteriors_.clonal; }); + std::for_each(std::next(std::begin(posteriors), haploid_genotypes_.size()), std::end(posteriors), + [this] (auto& p) { p *= model_posteriors_.subclonal; }); + genotype_posteriors_ = std::make_shared(std::make_move_iterator(std::begin(genotypes)), + std::make_move_iterator(std::end(genotypes))); + insert_sample(sample_, posteriors, *genotype_posteriors_); + } + return genotype_posteriors_; +} + +// PolycloneCaller::Latents private methods + +namespace { + +auto make_sublone_model_mixture_prior_map(const SampleName& sample, const unsigned num_clones, const double alpha = 0.5) +{ + model::SubcloneModel::Priors::GenotypeMixturesDirichletAlphaMap result {}; + model::SubcloneModel::Priors::GenotypeMixturesDirichletAlphas alphas(num_clones, alpha); + result.emplace(sample, std::move(alphas)); + return result; +} + +template +T nth_greatest_value(std::vector values, const std::size_t n) +{ + auto nth_itr = std::next(std::begin(values), n); + std::nth_element(std::begin(values), nth_itr, std::end(values), std::greater<> {}); + return *nth_itr; +} + +template +void erase_indices(std::vector& v, const std::vector& indices) +{ + assert(std::is_sorted(std::cbegin(indices), std::cend(indices))); + std::for_each(std::crbegin(indices), std::crend(indices), [&v] (auto idx) { v.erase(std::next(std::cbegin(v), idx)); }); +} + +void reduce(std::vector>& genotypes, const GenotypePriorModel& genotype_prior_model, + const HaplotypeLikelihoodCache& haplotype_likelihoods, const std::size_t n) +{ + if (genotypes.size() <= n) return; + const model::IndividualModel 
approx_model {genotype_prior_model}; + const auto approx_posteriors = approx_model.evaluate(genotypes, haplotype_likelihoods).posteriors.genotype_probabilities; + const auto min_posterior = nth_greatest_value(approx_posteriors, n + 1); + std::size_t idx {0}; + genotypes.erase(std::remove_if(std::begin(genotypes), std::end(genotypes), + [&] (const auto& genotype) { return approx_posteriors[idx++] <= min_posterior; }), + std::end(genotypes)); +} + +void fit_sublone_model(const std::vector& haplotypes, const HaplotypeLikelihoodCache& haplotype_likelihoods, + const GenotypePriorModel& genotype_prior_model, const SampleName& sample, const unsigned max_clones, + const double haploid_model_evidence, const std::function& clonality_prior, + const std::size_t max_genotypes, std::vector>& polyploid_genotypes, + model::SubcloneModel::InferredLatents& sublonal_inferences, + boost::optional& debug_log) +{ + const auto haploid_prior = std::log(clonality_prior(1)); + for (unsigned num_clones {2}; num_clones <= max_clones; ++num_clones) { + const auto clonal_model_prior = clonality_prior(num_clones); + if (clonal_model_prior == 0.0) break; + auto genotypes = generate_all_full_rank_genotypes(haplotypes, num_clones); + reduce(genotypes, genotype_prior_model, haplotype_likelihoods, max_genotypes); + if (debug_log) stream(*debug_log) << "Generated " << genotypes.size() << " genotypes with clonality " << num_clones; + if (genotypes.empty()) break; + model::SubcloneModel::Priors subclonal_model_priors {genotype_prior_model, make_sublone_model_mixture_prior_map(sample, num_clones)}; + model::SubcloneModel subclonal_model {{sample}, subclonal_model_priors}; + auto inferences = subclonal_model.evaluate(genotypes, haplotype_likelihoods); + if (debug_log) stream(*debug_log) << "Evidence for model with clonality " << num_clones << " is " << inferences.approx_log_evidence; + if (num_clones == 2) { + polyploid_genotypes = std::move(genotypes); + sublonal_inferences = std::move(inferences); + if 
((std::log(clonal_model_prior) + sublonal_inferences.approx_log_evidence) + < (haploid_prior + haploid_model_evidence)) break; + } else { + if ((std::log(clonal_model_prior) + inferences.approx_log_evidence) + <= (std::log(clonality_prior(num_clones - 1)) + sublonal_inferences.approx_log_evidence)) break; + polyploid_genotypes = std::move(genotypes); + sublonal_inferences = std::move(inferences); + } + } +} + +} // namespace + +std::unique_ptr +PolycloneCaller::infer_latents(const std::vector& haplotypes, const HaplotypeLikelihoodCache& haplotype_likelihoods) const +{ + auto haploid_genotypes = generate_all_genotypes(haplotypes, 1); + if (debug_log_) stream(*debug_log_) << "There are " << haploid_genotypes.size() << " candidate haploid genotypes"; + const auto genotype_prior_model = make_prior_model(haplotypes); + const model::IndividualModel haploid_model {*genotype_prior_model, debug_log_}; + haplotype_likelihoods.prime(sample()); + auto haploid_inferences = haploid_model.evaluate(haploid_genotypes, haplotype_likelihoods); + if (debug_log_) stream(*debug_log_) << "Evidence for haploid model is " << haploid_inferences.log_evidence; + std::vector> polyploid_genotypes; model::SubcloneModel::InferredLatents sublonal_inferences; + fit_sublone_model(haplotypes, haplotype_likelihoods, *genotype_prior_model, sample(), parameters_.max_clones, + haploid_inferences.log_evidence, parameters_.clonality_prior, parameters_.max_genotypes, polyploid_genotypes, + sublonal_inferences, debug_log_); + if (debug_log_) stream(*debug_log_) << "There are " << polyploid_genotypes.size() << " candidate polyploid genotypes"; + using std::move; + return std::make_unique(move(haploid_genotypes), move(polyploid_genotypes), + move(haploid_inferences), move(sublonal_inferences), + sample(), parameters_.clonality_prior); +} + +boost::optional +PolycloneCaller::calculate_model_posterior(const std::vector& haplotypes, + const HaplotypeLikelihoodCache& haplotype_likelihoods, + const Caller::Latents& 
latents) const +{ + return calculate_model_posterior(haplotypes, haplotype_likelihoods, dynamic_cast(latents)); +} + +boost::optional +PolycloneCaller::calculate_model_posterior(const std::vector& haplotypes, + const HaplotypeLikelihoodCache& haplotype_likelihoods, + const Latents& latents) const +{ + return boost::none; +} + +std::vector> +PolycloneCaller::call_variants(const std::vector& candidates, const Caller::Latents& latents) const +{ + return call_variants(candidates, dynamic_cast(latents)); +} + +namespace { + +using GenotypeProbabilityMap = ProbabilityMatrix>::InnerMap; +using VariantReference = std::reference_wrapper; +using VariantPosteriorVector = std::vector>>; + +struct VariantCall : Mappable +{ + VariantCall() = delete; + VariantCall(const std::pair>& p) + : variant {p.first} + , posterior {p.second} + {} + VariantCall(const Variant& variant, Phred posterior) + : variant {variant} + , posterior {posterior} + {} + + const GenomicRegion& mapped_region() const noexcept + { + return octopus::mapped_region(variant.get()); + } + + VariantReference variant; + Phred posterior; + bool is_dummy_filtered = false; +}; + +using VariantCalls = std::vector; + +struct GenotypeCall +{ + template GenotypeCall(T&& genotype, Phred posterior) + : genotype {std::forward(genotype)} + , posterior {posterior} + {} + + Genotype genotype; + Phred posterior; +}; + +using GenotypeCalls = std::vector; + +// allele posterior calculations + +auto compute_posterior(const Allele& allele, const GenotypeProbabilityMap& genotype_posteriors) +{ + auto p = std::accumulate(std::cbegin(genotype_posteriors), std::cend(genotype_posteriors), + 0.0, [&allele] (const auto curr, const auto& p) { + return curr + (contains(p.first, allele) ? 
0.0 : p.second); + }); + return probability_to_phred(p); +} + +auto compute_candidate_posteriors(const std::vector& candidates, + const GenotypeProbabilityMap& genotype_posteriors) +{ + VariantPosteriorVector result {}; + result.reserve(candidates.size()); + for (const auto& candidate : candidates) { + result.emplace_back(candidate, compute_posterior(candidate.alt_allele(), genotype_posteriors)); + } + return result; +} + +// variant calling + +bool has_callable(const VariantPosteriorVector& variant_posteriors, const Phred min_posterior) noexcept +{ + return std::any_of(std::cbegin(variant_posteriors), std::cend(variant_posteriors), + [=] (const auto& p) noexcept { return p.second >= min_posterior; }); +} + +bool contains_alt(const Genotype& genotype_call, const VariantReference& candidate) +{ + return includes(genotype_call, candidate.get().alt_allele()); +} + +VariantCalls call_candidates(const VariantPosteriorVector& candidate_posteriors, + const Genotype& genotype_call, + const Phred min_posterior) +{ + VariantCalls result {}; + result.reserve(candidate_posteriors.size()); + std::copy_if(std::cbegin(candidate_posteriors), std::cend(candidate_posteriors), std::back_inserter(result), + [&genotype_call, min_posterior] (const auto& p) { + return p.second >= min_posterior && contains_alt(genotype_call, p.first); + }); + return result; +} + +// variant genotype calling + +template +PairIterator find_map(PairIterator first, PairIterator last) +{ + return std::max_element(first, last, [] (const auto& lhs, const auto& rhs) { return lhs.second < rhs.second; }); +} + +template +bool is_homozygous_reference(const Genotype& g) +{ + return is_reference(g[0]) && g.is_homozygous(); +} + +auto call_genotype(const GenotypeProbabilityMap& genotype_posteriors, const bool ignore_hom_ref = false) +{ + const auto map_itr = find_map(std::cbegin(genotype_posteriors), std::cend(genotype_posteriors)); + assert(map_itr != std::cend(genotype_posteriors)); + if (!ignore_hom_ref || 
!is_homozygous_reference(map_itr->first)) { + return map_itr->first; + } else { + const auto lhs_map_itr = find_map(std::cbegin(genotype_posteriors), map_itr); + const auto rhs_map_itr = find_map(std::next(map_itr), std::cend(genotype_posteriors)); + if (lhs_map_itr != map_itr) { + if (rhs_map_itr != std::cend(genotype_posteriors)) { + return lhs_map_itr->second < rhs_map_itr->second ? rhs_map_itr->first : lhs_map_itr->first; + } else { + return lhs_map_itr->first; + } + } else { + return rhs_map_itr->first; + } + } +} + +auto compute_posterior(const Genotype& genotype, const GenotypeProbabilityMap& genotype_posteriors) +{ + auto p = std::accumulate(std::cbegin(genotype_posteriors), std::cend(genotype_posteriors), 0.0, + [&genotype] (const double curr, const auto& p) { + return curr + (contains(p.first, genotype) ? 0.0 : p.second); + }); + return probability_to_phred(p); +} + +GenotypeCalls call_genotypes(const Genotype& genotype_call, + const GenotypeProbabilityMap& genotype_posteriors, + const std::vector& variant_regions) +{ + GenotypeCalls result {}; + result.reserve(variant_regions.size()); + for (const auto& region : variant_regions) { + auto genotype_chunk = copy(genotype_call, region); + const auto posterior = compute_posterior(genotype_chunk, genotype_posteriors); + result.emplace_back(std::move(genotype_chunk), posterior); + } + return result; +} + +// output + +octopus::VariantCall::GenotypeCall convert(GenotypeCall&& call) +{ + return octopus::VariantCall::GenotypeCall {std::move(call.genotype), call.posterior}; +} + +std::unique_ptr +transform_call(const SampleName& sample, VariantCall&& variant_call, GenotypeCall&& genotype_call) +{ + std::vector> tmp {std::make_pair(sample, convert(std::move(genotype_call)))}; + std::unique_ptr result {std::make_unique(variant_call.variant.get(), std::move(tmp), + variant_call.posterior)}; + return result; +} + +auto transform_calls(const SampleName& sample, VariantCalls&& variant_calls, GenotypeCalls&& 
genotype_calls) +{ + std::vector> result {}; + result.reserve(variant_calls.size()); + std::transform(std::make_move_iterator(std::begin(variant_calls)), std::make_move_iterator(std::end(variant_calls)), + std::make_move_iterator(std::begin(genotype_calls)), std::back_inserter(result), + [&sample] (VariantCall&& variant_call, GenotypeCall&& genotype_call) { + return transform_call(sample, std::move(variant_call), std::move(genotype_call)); + }); + return result; +} + +} // namespace + +namespace debug { namespace { + +void log(const GenotypeProbabilityMap& genotype_posteriors, + boost::optional& debug_log, + boost::optional& trace_log); + +void log(const VariantPosteriorVector& candidate_posteriors, + boost::optional& debug_log, + boost::optional& trace_log, + Phred min_posterior); + +void log(const Genotype& called_genotype, + boost::optional& debug_log); + +} // namespace +} // namespace debug + +std::vector> +PolycloneCaller::call_variants(const std::vector& candidates, const Latents& latents) const +{ + log(latents); + const auto& genotype_posteriors = (*latents.genotype_posteriors())[sample()]; + debug::log(genotype_posteriors, debug_log_, trace_log_); + const auto candidate_posteriors = compute_candidate_posteriors(candidates, genotype_posteriors); + debug::log(candidate_posteriors, debug_log_, trace_log_, parameters_.min_variant_posterior); + const bool force_call_non_ref {has_callable(candidate_posteriors, parameters_.min_variant_posterior)}; + const auto genotype_call = octopus::call_genotype(genotype_posteriors, force_call_non_ref); + auto variant_calls = call_candidates(candidate_posteriors, genotype_call, parameters_.min_variant_posterior); + const auto called_regions = extract_regions(variant_calls); + auto genotype_calls = call_genotypes(genotype_call, genotype_posteriors, called_regions); + return transform_calls(sample(), std::move(variant_calls), std::move(genotype_calls)); +} + +std::vector> +PolycloneCaller::call_reference(const std::vector& 
alleles, const Caller::Latents& latents, const ReadPileupMap& pileup) const +{ + return call_reference(alleles, dynamic_cast(latents), pileup); +} + +std::vector> +PolycloneCaller::call_reference(const std::vector& alleles, const Latents& latents, const ReadPileupMap& pileup) const +{ + return {}; +} + +const SampleName& PolycloneCaller::sample() const noexcept +{ + return samples_.front(); +} + +std::unique_ptr PolycloneCaller::make_prior_model(const std::vector& haplotypes) const +{ + if (parameters_.prior_model_params) { + return std::make_unique(CoalescentModel { + Haplotype {mapped_region(haplotypes.front()), reference_}, + *parameters_.prior_model_params, haplotypes.size(), CoalescentModel::CachingStrategy::address + }); + } else { + return std::make_unique(); + } +} + +void PolycloneCaller::log(const Latents& latents) const +{ + if (debug_log_) { + stream(*debug_log_) << "Clonal model posterior is " << latents.model_posteriors_.clonal + << " and subclonal model posterior is " << latents.model_posteriors_.subclonal; + if (latents.model_posteriors_.subclonal > latents.model_posteriors_.clonal) { + stream(*debug_log_) << "Detected subclonality is " << latents.polyploid_genotypes_.front().ploidy(); + } + } +} + +namespace debug { namespace { + +template +void print_genotype_posteriors(S&& stream, + const GenotypeProbabilityMap& genotype_posteriors, + const std::size_t n) +{ + const auto m = std::min(n, genotype_posteriors.size()); + if (m == genotype_posteriors.size()) { + stream << "Printing all genotype posteriors " << '\n'; + } else { + stream << "Printing top " << m << " genotype posteriors " << '\n'; + } + using GenotypeReference = std::reference_wrapper>; + std::vector> v {}; + v.reserve(genotype_posteriors.size()); + std::copy(std::cbegin(genotype_posteriors), std::cend(genotype_posteriors), std::back_inserter(v)); + const auto mth = std::next(std::begin(v), m); + std::partial_sort(std::begin(v), mth, std::end(v), + [] (const auto& lhs, const auto& rhs) { 
return lhs.second > rhs.second; }); + std::for_each(std::begin(v), mth, + [&] (const auto& p) { + print_variant_alleles(stream, p.first); + stream << " " << p.second << '\n'; + }); +} + +void print_genotype_posteriors(const GenotypeProbabilityMap& genotype_posteriors, + const std::size_t n) +{ + print_genotype_posteriors(std::cout, genotype_posteriors, n); +} + +template +void print_candidate_posteriors(S&& stream, const VariantPosteriorVector& candidate_posteriors, + const std::size_t n) +{ + const auto m = std::min(n, candidate_posteriors.size()); + if (m == candidate_posteriors.size()) { + stream << "Printing all candidate variant posteriors " << '\n'; + } else { + stream << "Printing top " << m << " candidate variant posteriors " << '\n'; + } + std::vector>> v {}; + v.reserve(candidate_posteriors.size()); + std::copy(std::cbegin(candidate_posteriors), std::cend(candidate_posteriors), std::back_inserter(v)); + const auto mth = std::next(std::begin(v), m); + std::partial_sort(std::begin(v), mth, std::end(v), + [] (const auto& lhs, const auto& rhs) { return lhs.second > rhs.second; }); + std::for_each(std::begin(v), mth, + [&] (const auto& p) { + stream << p.first.get() << " " << p.second.probability_true() << '\n'; + }); +} + +void print_candidate_posteriors(const VariantPosteriorVector& candidate_posteriors, + const std::size_t n) +{ + print_candidate_posteriors(std::cout, candidate_posteriors, n); +} + +void log(const GenotypeProbabilityMap& genotype_posteriors, + boost::optional& debug_log, + boost::optional& trace_log) +{ + if (trace_log) { + print_genotype_posteriors(stream(*trace_log), genotype_posteriors, -1); + } + if (debug_log) { + print_genotype_posteriors(stream(*debug_log), genotype_posteriors, 5); + } +} + +void log(const VariantPosteriorVector& candidate_posteriors, + boost::optional& debug_log, + boost::optional& trace_log, + Phred min_posterior) +{ + if (trace_log) { + print_candidate_posteriors(stream(*trace_log), candidate_posteriors, -1); + } 
+ if (debug_log) { + const auto n = std::count_if(std::cbegin(candidate_posteriors), std::cend(candidate_posteriors), + [=] (const auto& p) { return p.second >= min_posterior; }); + print_candidate_posteriors(stream(*debug_log), candidate_posteriors, std::max(n, decltype(n) {5})); + } +} + +} // namespace +} // namespace debug + +} // namespace octopus diff --git a/src/core/callers/polyclone_caller.hpp b/src/core/callers/polyclone_caller.hpp new file mode 100644 index 000000000..fbacb99e1 --- /dev/null +++ b/src/core/callers/polyclone_caller.hpp @@ -0,0 +1,147 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#ifndef polyclone_caller_hpp +#define polyclone_caller_hpp + +#include +#include +#include +#include + +#include + +#include "config/common.hpp" +#include "basics/phred.hpp" +#include "core/types/haplotype.hpp" +#include "core/types/genotype.hpp" +#include "core/models/mutation/coalescent_model.hpp" +#include "core/models/genotype/genotype_prior_model.hpp" +#include "core/models/genotype/individual_model.hpp" +#include "core/models/genotype/subclone_model.hpp" +#include "utils/maths.hpp" +#include "caller.hpp" + +namespace octopus { + +class GenomicRegion; +class ReadPipe; +class Variant; +class HaplotypeLikelihoodCache; +class VariantCall; + +class PolycloneCaller : public Caller +{ +public: + using Caller::CallTypeSet; + + struct Parameters + { + boost::optional prior_model_params; + Phred min_variant_posterior, min_refcall_posterior; + bool deduplicate_haplotypes_with_germline_model = false; + unsigned max_clones = 3, max_genotypes = 10'000; + std::function clonality_prior = [] (unsigned clonality) { return maths::geometric_pdf(clonality, 0.5); }; + }; + + PolycloneCaller() = delete; + + PolycloneCaller(Caller::Components&& components, + Caller::Parameters general_parameters, + Parameters specific_parameters); + + PolycloneCaller(const PolycloneCaller&) = delete; 
+ PolycloneCaller& operator=(const PolycloneCaller&) = delete; + PolycloneCaller(PolycloneCaller&&) = delete; + PolycloneCaller& operator=(PolycloneCaller&&) = delete; + + ~PolycloneCaller() = default; + +private: + class Latents; + friend Latents; + + Parameters parameters_; + + struct ModelProbabilities + { + double clonal, subclonal; + }; + + std::string do_name() const override; + CallTypeSet do_call_types() const override; + unsigned do_min_callable_ploidy() const override; + unsigned do_max_callable_ploidy() const override; + + std::size_t do_remove_duplicates(std::vector& haplotypes) const override; + + std::unique_ptr + infer_latents(const std::vector& haplotypes, + const HaplotypeLikelihoodCache& haplotype_likelihoods) const override; + + boost::optional + calculate_model_posterior(const std::vector& haplotypes, + const HaplotypeLikelihoodCache& haplotype_likelihoods, + const Caller::Latents& latents) const override; + + boost::optional + calculate_model_posterior(const std::vector& haplotypes, + const HaplotypeLikelihoodCache& haplotype_likelihoods, + const Latents& latents) const; + + std::vector> + call_variants(const std::vector& candidates, const Caller::Latents& latents) const override; + + std::vector> + call_variants(const std::vector& candidates, const Latents& latents) const; + + std::vector> + call_reference(const std::vector& alleles, const Caller::Latents& latents, + const ReadPileupMap& pileup) const override; + + std::vector> + call_reference(const std::vector& alleles, const Latents& latents, + const ReadPileupMap& pileup) const; + + const SampleName& sample() const noexcept; + + std::unique_ptr make_prior_model(const std::vector& haplotypes) const; + + // debug + void log(const Latents& latents) const; +}; + +class PolycloneCaller::Latents : public Caller::Latents +{ +public: + using HaploidModelInferences = model::IndividualModel::InferredLatents; + using SubloneModelInferences = model::SubcloneModel::InferredLatents; + + using 
Caller::Latents::HaplotypeProbabilityMap; + using Caller::Latents::GenotypeProbabilityMap; + + Latents() = delete; + + Latents(std::vector> haploid_genotypes, std::vector> polyploid_genotypes, + HaploidModelInferences haploid_model_inferences, SubloneModelInferences subclone_model_inferences, + const SampleName& sample, const std::function& clonality_prior); + + std::shared_ptr haplotype_posteriors() const noexcept override; + std::shared_ptr genotype_posteriors() const noexcept override; + +private: + std::vector> haploid_genotypes_, polyploid_genotypes_; + HaploidModelInferences haploid_model_inferences_; + SubloneModelInferences subclone_model_inferences_; + PolycloneCaller::ModelProbabilities model_posteriors_; + SampleName sample_; + + mutable std::shared_ptr genotype_posteriors_; + mutable std::shared_ptr haplotype_posteriors_; + + friend PolycloneCaller; +}; + +} // namespace octopus + +#endif diff --git a/src/core/callers/population_caller.cpp b/src/core/callers/population_caller.cpp index 64af86c24..e2022bf1a 100644 --- a/src/core/callers/population_caller.cpp +++ b/src/core/callers/population_caller.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "population_caller.hpp" @@ -60,6 +60,31 @@ PopulationCaller::CallTypeSet PopulationCaller::do_call_types() const return {std::type_index(typeid(GermlineVariantCall))}; } +unsigned PopulationCaller::do_min_callable_ploidy() const +{ + return *std::min_element(std::cbegin(parameters_.ploidies), std::cend(parameters_.ploidies)); +} + +unsigned PopulationCaller::do_max_callable_ploidy() const +{ + return *std::max_element(std::cbegin(parameters_.ploidies), std::cend(parameters_.ploidies)); +} + +std::size_t PopulationCaller::do_remove_duplicates(std::vector& haplotypes) const +{ + if (parameters_.deduplicate_haplotypes_with_germline_model) { + if (haplotypes.size() < 2) return 0; + CoalescentModel::Parameters model_params {}; + if (parameters_.prior_model_params) model_params = *parameters_.prior_model_params; + Haplotype reference {mapped_region(haplotypes.front()), reference_.get()}; + CoalescentModel model {std::move(reference), model_params, haplotypes.size(), CoalescentModel::CachingStrategy::none}; + const CoalescentProbabilityGreater cmp {std::move(model)}; + return octopus::remove_duplicates(haplotypes, cmp); + } else { + return Caller::do_remove_duplicates(haplotypes); + } +} + // IndividualCaller::Latents public methods namespace { @@ -101,19 +126,6 @@ using GM = model::PopulationModel; using GenotypeMarginalPosteriorVector = std::vector; using GenotypeMarginalPosteriorMatrix = std::vector; -auto calculate_genotype_marginal_posteriors(const GM::Latents& posteriors, - const std::size_t num_genotypes, - const std::size_t num_samples) -{ - GenotypeMarginalPosteriorMatrix result {num_samples, GenotypeMarginalPosteriorVector(num_genotypes, 0.0)}; - for (std::size_t i {0}; i < posteriors.genotype_combinations.size(); ++i) { - for (std::size_t s {0}; s < num_samples; ++s) { - result[s][posteriors.genotype_combinations[i][s]] += posteriors.joint_genotype_probabilities[i]; - } - } - return result; -} - auto calculate_haplotype_posteriors(const std::vector& 
haplotypes, const std::vector>& genotypes, const GenotypeMarginalPosteriorMatrix& genotype_posteriors, @@ -178,12 +190,13 @@ PopulationCaller::Latents::Latents(const std::vector& samples, ModelInferences&& inferences) : model_latents_ {std::move(inferences)} { - auto genotype_marginal_posteriors = calculate_genotype_marginal_posteriors(model_latents_.posteriors, genotypes.size(), samples.size()); auto inverse_genotypes = make_inverse_genotype_table(haplotypes, genotypes); - haplotype_posteriors_ = std::make_shared(calculate_haplotype_posteriors(haplotypes, genotypes, genotype_marginal_posteriors, inverse_genotypes)); + haplotype_posteriors_ = std::make_shared(calculate_haplotype_posteriors(haplotypes, genotypes, + model_latents_.posteriors.marginal_genotype_probabilities, + inverse_genotypes)); GenotypeProbabilityMap genotype_posteriors {std::begin(genotypes), std::end(genotypes)}; for (std::size_t s {0}; s < samples.size(); ++s) { - insert_sample(samples[s], genotype_marginal_posteriors[s], genotype_posteriors); + insert_sample(samples[s], model_latents_.posteriors.marginal_genotype_probabilities[s], genotype_posteriors); } genotype_posteriors_ = std::make_shared(std::move(genotype_posteriors)); genotypes_.emplace(genotypes.front().ploidy(), std::move(genotypes)); @@ -240,20 +253,10 @@ std::unique_ptr PopulationCaller::infer_latents(const std::vector& haplotypes, const HaplotypeLikelihoodCache& haplotype_likelihoods) const { - //const auto prior_model = make_prior_model(haplotypes); - //const model::PopulationModel model {*prior_model, {parameters_.max_genotypes_per_sample}, debug_log_}; - const auto prior_model = make_independent_prior_model(haplotypes); - const model::IndependentPopulationModel model {*prior_model, debug_log_}; - if (parameters_.ploidies.size() == 1) { - auto genotypes = generate_all_genotypes(haplotypes, parameters_.ploidies.front()); - if (debug_log_) stream(*debug_log_) << "There are " << genotypes.size() << " candidate genotypes"; - auto 
inferences = model.evaluate(samples_, genotypes, haplotype_likelihoods); - return std::make_unique(samples_, haplotypes, std::move(genotypes), std::move(inferences)); + if (use_independence_model()) { + return infer_latents_with_independence_model(haplotypes, haplotype_likelihoods); } else { - auto unique_genotypes = generate_unique_genotypes(haplotypes, parameters_.ploidies); - auto sample_genotypes = assign_samples_to_genotypes(parameters_.ploidies, unique_genotypes); - auto inferences = model.evaluate(samples_, sample_genotypes, haplotype_likelihoods); - return std::make_unique(samples_, haplotypes, std::move(unique_genotypes), std::move(inferences)); + return infer_latents_with_joint_model(haplotypes, haplotype_likelihoods); } } @@ -334,7 +337,7 @@ using GenotypeCalls = std::vector>; // allele posterior calculations -using AlleleBools = std::deque; // using std::deque because std::vector is evil +using AlleleBools = std::deque; // using std::deque because std::vector is evil using GenotypePropertyBools = std::vector; auto marginalise(const GenotypeProbabilityMap& genotype_posteriors, @@ -428,21 +431,21 @@ auto compute_posteriors(const std::vector& samples, // haplotype genotype calling -auto call_genotypes(const GM::Latents& latents, const std::vector>& genotypes) -{ - const auto itr = std::max_element(std::cbegin(latents.joint_genotype_probabilities), std::cend(latents.joint_genotype_probabilities)); - const auto& called_indices = latents.genotype_combinations[std::distance(std::cbegin(latents.joint_genotype_probabilities), itr)]; - std::vector> result {}; - result.reserve(called_indices.size()); - std::transform(std::cbegin(called_indices), std::cend(called_indices), std::back_inserter(result), - [&] (auto idx) { return genotypes[idx]; }); - return result; -} - -auto call_genotypes(const GM::Latents& latents, const std::unordered_map>>& genotypes) -{ - return call_genotypes(latents, std::cbegin(genotypes)->second); -} +//auto call_genotypes(const GM::Latents& 
latents, const std::vector>& genotypes) +//{ +// const auto itr = std::max_element(std::cbegin(latents.joint_genotype_probabilities), std::cend(latents.joint_genotype_probabilities)); +// const auto& called_indices = latents.genotype_combinations[std::distance(std::cbegin(latents.joint_genotype_probabilities), itr)]; +// std::vector> result {}; +// result.reserve(called_indices.size()); +// std::transform(std::cbegin(called_indices), std::cend(called_indices), std::back_inserter(result), +// [&] (auto idx) { return genotypes[idx]; }); +// return result; +//} +// +//auto call_genotypes(const GM::Latents& latents, const std::unordered_map>>& genotypes) +//{ +// return call_genotypes(latents, std::cbegin(genotypes)->second); +//} auto call_genotype(const PopulationGenotypeProbabilityMap::InnerMap& genotype_posteriors) { @@ -624,10 +627,8 @@ auto transform_calls(const std::vector& samples, { std::vector> result {}; result.reserve(variant_calls.size()); - std::transform(std::make_move_iterator(std::begin(variant_calls)), - std::make_move_iterator(std::end(variant_calls)), - std::make_move_iterator(std::begin(genotype_calls)), - std::back_inserter(result), + std::transform(std::make_move_iterator(std::begin(variant_calls)), std::make_move_iterator(std::end(variant_calls)), + std::make_move_iterator(std::begin(genotype_calls)), std::back_inserter(result), [&samples] (auto&& variant_call, auto&& genotype_call) { return transform_call(samples, std::move(variant_call), std::move(genotype_call)); }); @@ -728,16 +729,61 @@ using RefCalls = std::vector; std::vector> PopulationCaller::call_reference(const std::vector& alleles, - const Caller::Latents& latents, - const ReadMap& reads) const + const Caller::Latents& latents, + const ReadPileupMap& pileups) const { return {}; } -std::unique_ptr PopulationCaller::make_prior_model(const std::vector& haplotypes) const +bool PopulationCaller::use_independence_model() const noexcept +{ + return 
parameters_.use_independent_genotype_priors || !parameters_.prior_model_params; +} + +std::unique_ptr +PopulationCaller::infer_latents_with_joint_model(const std::vector& haplotypes, + const HaplotypeLikelihoodCache& haplotype_likelihoods) const +{ + const auto prior_model = make_joint_prior_model(haplotypes); + prior_model->prime(haplotypes); + const model::PopulationModel model {*prior_model, {parameters_.max_joint_genotypes}, debug_log_}; + if (parameters_.ploidies.size() == 1) { + std::vector genotype_indices; + auto genotypes = generate_all_genotypes(haplotypes, parameters_.ploidies.front(), genotype_indices); + if (debug_log_) stream(*debug_log_) << "There are " << genotypes.size() << " candidate genotypes"; + auto inferences = model.evaluate(samples_, genotypes, genotype_indices, haplotypes, haplotype_likelihoods); + return std::make_unique(samples_, haplotypes, std::move(genotypes), std::move(inferences)); + } else { + auto unique_genotypes = generate_unique_genotypes(haplotypes, parameters_.ploidies); + auto sample_genotypes = assign_samples_to_genotypes(parameters_.ploidies, unique_genotypes); + auto inferences = model.evaluate(samples_, sample_genotypes, haplotype_likelihoods); + return std::make_unique(samples_, haplotypes, std::move(unique_genotypes), std::move(inferences)); + } +} + +std::unique_ptr +PopulationCaller::infer_latents_with_independence_model(const std::vector& haplotypes, + const HaplotypeLikelihoodCache& haplotype_likelihoods) const +{ + const auto prior_model = make_independent_prior_model(haplotypes); + const model::IndependentPopulationModel model {*prior_model, debug_log_}; + if (parameters_.ploidies.size() == 1) { + auto genotypes = generate_all_genotypes(haplotypes, parameters_.ploidies.front()); + if (debug_log_) stream(*debug_log_) << "There are " << genotypes.size() << " candidate genotypes"; + auto inferences = model.evaluate(samples_, genotypes, haplotype_likelihoods); + return std::make_unique(samples_, haplotypes, 
std::move(genotypes), std::move(inferences)); + } else { + auto unique_genotypes = generate_unique_genotypes(haplotypes, parameters_.ploidies); + auto sample_genotypes = assign_samples_to_genotypes(parameters_.ploidies, unique_genotypes); + auto inferences = model.evaluate(samples_, sample_genotypes, haplotype_likelihoods); + return std::make_unique(samples_, haplotypes, std::move(unique_genotypes), std::move(inferences)); + } +} + +std::unique_ptr PopulationCaller::make_joint_prior_model(const std::vector& haplotypes) const { if (parameters_.prior_model_params) { - return std::make_unique(CoalescentModel { + return std::make_unique(CoalescentModel{ Haplotype {mapped_region(haplotypes.front()), reference_}, *parameters_.prior_model_params }); diff --git a/src/core/callers/population_caller.hpp b/src/core/callers/population_caller.hpp index 7ebc0790d..112964e7a 100644 --- a/src/core/callers/population_caller.hpp +++ b/src/core/callers/population_caller.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef population_caller_hpp @@ -37,7 +37,9 @@ class PopulationCaller : public Caller Phred min_variant_posterior, min_refcall_posterior; std::vector ploidies; boost::optional prior_model_params; - unsigned max_genotypes_per_sample; + std::size_t max_joint_genotypes; + bool use_independent_genotype_priors = false; + bool deduplicate_haplotypes_with_germline_model = true; }; PopulationCaller() = delete; @@ -60,6 +62,10 @@ class PopulationCaller : public Caller std::string do_name() const override; CallTypeSet do_call_types() const override; + unsigned do_min_callable_ploidy() const override; + unsigned do_max_callable_ploidy() const override; + + std::size_t do_remove_duplicates(std::vector& haplotypes) const override; std::unique_ptr infer_latents(const std::vector& haplotypes, @@ -73,9 +79,16 @@ class PopulationCaller : public Caller std::vector> call_reference(const std::vector& alleles, const Caller::Latents& latents, - const ReadMap& reads) const override; + const ReadPileupMap& pileups) const override; - std::unique_ptr make_prior_model(const std::vector& haplotypes) const; + bool use_independence_model() const noexcept; + std::unique_ptr + infer_latents_with_joint_model(const std::vector& haplotypes, + const HaplotypeLikelihoodCache& haplotype_likelihoods) const; + std::unique_ptr + infer_latents_with_independence_model(const std::vector& haplotypes, + const HaplotypeLikelihoodCache& haplotype_likelihoods) const; + std::unique_ptr make_joint_prior_model(const std::vector& haplotypes) const; std::unique_ptr make_independent_prior_model(const std::vector& haplotypes) const; }; diff --git a/src/core/callers/trio_caller.cpp b/src/core/callers/trio_caller.cpp index de26fe25e..12dfe4b8a 100644 --- a/src/core/callers/trio_caller.cpp +++ b/src/core/callers/trio_caller.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "trio_caller.hpp" @@ -26,19 +26,34 @@ #include "utils/map_utils.hpp" #include "utils/mappable_algorithms.hpp" #include "utils/maths.hpp" - -#include "timers.hpp" +#include "exceptions/unimplemented_feature_error.hpp" namespace octopus { +class BadPloidy : public UnimplementedFeatureError +{ + std::string do_help() const override + { + return "Use the population caller and/or submit a feature request"; + } +public: + BadPloidy(unsigned max_ploidy) + : UnimplementedFeatureError {"trio calling with ploidies greater than " + std::to_string(max_ploidy), "TrioCaller"} + {} +}; + TrioCaller::TrioCaller(Caller::Components&& components, Caller::Parameters general_parameters, Parameters specific_parameters) : Caller {std::move(components), std::move(general_parameters)} , parameters_ {std::move(specific_parameters)} { - if (parameters_.maternal_ploidy == 0) { - throw std::logic_error {"IndividualCaller: ploidy must be > 0"}; + if (parameters_.maternal_ploidy == 0 || parameters_.paternal_ploidy == 0 || parameters_.child_ploidy == 0) { + throw std::logic_error {"TrioCaller: ploidy must be > 0"}; + } + const auto max_ploidy = model::TrioModel::max_ploidy(); + if (parameters_.maternal_ploidy > max_ploidy || parameters_.paternal_ploidy > max_ploidy || parameters_.child_ploidy > max_ploidy) { + throw BadPloidy {max_ploidy}; } } @@ -54,6 +69,31 @@ Caller::CallTypeSet TrioCaller::do_call_types() const std::type_index(typeid(DenovoReferenceReversionCall))}; } +unsigned TrioCaller::do_min_callable_ploidy() const +{ + return std::max({parameters_.maternal_ploidy, parameters_.paternal_ploidy, parameters_.child_ploidy}); +} + +unsigned TrioCaller::do_max_callable_ploidy() const +{ + return std::max({parameters_.maternal_ploidy, parameters_.paternal_ploidy, parameters_.child_ploidy}); +} + +std::size_t TrioCaller::do_remove_duplicates(std::vector& haplotypes) const +{ + if (parameters_.deduplicate_haplotypes_with_germline_model) { + if (haplotypes.size() < 2) return 0; + 
CoalescentModel::Parameters model_params {}; + if (parameters_.germline_prior_model_params) model_params = *parameters_.germline_prior_model_params; + Haplotype reference {mapped_region(haplotypes.front()), reference_.get()}; + CoalescentModel model {std::move(reference), model_params, haplotypes.size(), CoalescentModel::CachingStrategy::none}; + const CoalescentProbabilityGreater cmp {std::move(model)}; + return octopus::remove_duplicates(haplotypes, cmp); + } else { + return Caller::do_remove_duplicates(haplotypes); + } +} + // TrioCaller::Latents TrioCaller::Latents::Latents(const std::vector& haplotypes, @@ -295,16 +335,20 @@ TrioCaller::calculate_model_posterior(const std::vector& haplotypes, const Latents& latents) const { const auto max_ploidy = std::max({parameters_.maternal_ploidy, parameters_.paternal_ploidy, parameters_.child_ploidy}); - std::vector> genotype_indices {}; - const auto genotypes = generate_all_genotypes(haplotypes, max_ploidy + 1, genotype_indices); - const auto germline_prior_model = make_prior_model(haplotypes); - DeNovoModel denovo_model {parameters_.denovo_model_params}; - germline_prior_model->prime(haplotypes); - denovo_model.prime(haplotypes); - const model::TrioModel model {parameters_.trio, *germline_prior_model, denovo_model, - TrioModel::Options {parameters_.max_joint_genotypes}}; - const auto inferences = model.evaluate(genotypes, genotype_indices, haplotype_likelihoods); - return octopus::calculate_model_posterior(latents.model_latents.log_evidence, inferences.log_evidence); + if (max_ploidy + 1 <= model::TrioModel::max_ploidy()) { + std::vector> genotype_indices {}; + const auto genotypes = generate_all_genotypes(haplotypes, max_ploidy + 1, genotype_indices); + const auto germline_prior_model = make_prior_model(haplotypes); + DeNovoModel denovo_model {parameters_.denovo_model_params}; + germline_prior_model->prime(haplotypes); + denovo_model.prime(haplotypes); + const model::TrioModel model {parameters_.trio, 
*germline_prior_model, denovo_model, + TrioModel::Options {parameters_.max_joint_genotypes}}; + const auto inferences = model.evaluate(genotypes, genotype_indices, haplotype_likelihoods); + return octopus::calculate_model_posterior(latents.model_latents.log_evidence, inferences.log_evidence); + } else { + return boost::none; + } } std::vector> @@ -854,14 +898,14 @@ TrioCaller::call_variants(const std::vector& candidates, const Latents& std::vector> TrioCaller::call_reference(const std::vector& alleles, const Caller::Latents& latents, - const ReadMap& reads) const + const ReadPileupMap& pileups) const { - return call_reference(alleles, dynamic_cast(latents), reads); + return call_reference(alleles, dynamic_cast(latents), pileups); } std::vector> TrioCaller::call_reference(const std::vector& alleles, const Latents& latents, - const ReadMap& reads) const + const ReadPileupMap& pileups) const { return {}; } diff --git a/src/core/callers/trio_caller.hpp b/src/core/callers/trio_caller.hpp index d40c2845d..7b42a6574 100644 --- a/src/core/callers/trio_caller.hpp +++ b/src/core/callers/trio_caller.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef trio_caller_hpp @@ -34,6 +34,7 @@ class TrioCaller : public Caller DeNovoModel::Parameters denovo_model_params; Phred min_variant_posterior, min_denovo_posterior, min_refcall_posterior; unsigned max_joint_genotypes; + bool deduplicate_haplotypes_with_germline_model = true; }; TrioCaller() = delete; @@ -56,6 +57,10 @@ class TrioCaller : public Caller std::string do_name() const override; CallTypeSet do_call_types() const override; + unsigned do_min_callable_ploidy() const override; + unsigned do_max_callable_ploidy() const override; + + std::size_t do_remove_duplicates(std::vector& haplotypes) const override; std::unique_ptr infer_latents(const std::vector& haplotypes, @@ -79,11 +84,11 @@ class TrioCaller : public Caller std::vector> call_reference(const std::vector& alleles, const Caller::Latents& latents, - const ReadMap& reads) const override; + const ReadPileupMap& pileups) const override; std::vector> call_reference(const std::vector& alleles, const Latents& latents, - const ReadMap& reads) const; + const ReadPileupMap& pileups) const; std::unique_ptr make_prior_model(const std::vector& haplotypes) const; }; diff --git a/src/core/calling_components.cpp b/src/core/calling_components.cpp index 5e90b3815..29ab7767a 100644 --- a/src/core/calling_components.cpp +++ b/src/core/calling_components.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "calling_components.hpp" @@ -12,7 +12,6 @@ #include "config/config.hpp" #include "config/option_collation.hpp" -#include "utils/read_size_estimator.hpp" #include "utils/map_utils.hpp" #include "logging/logging.hpp" #include "exceptions/user_error.hpp" @@ -159,11 +158,26 @@ boost::optional GenomeCallingComponents::filter_r return components_.filter_request_; } +boost::optional GenomeCallingComponents::bamout() const +{ + return components_.bamout_; +} + bool GenomeCallingComponents::sites_only() const noexcept { return components_.sites_only; } +const PloidyMap& GenomeCallingComponents::ploidies() const noexcept +{ + return components_.ploidies; +} + +boost::optional GenomeCallingComponents::pedigree() const +{ + return components_.pedigree; +} + namespace { std::vector @@ -325,21 +339,6 @@ void drop_unused_samples(std::vector& calling_samples, ReadManager& rm.drop_samples(unused_samples); } -auto estimate_read_size(const std::vector& samples, - const InputRegionMap& input_regions, - ReadManager& read_manager) -{ - auto result = estimate_mean_read_size(samples, input_regions, read_manager); - if (!result) { - result = default_read_size_estimate(); - logging::WarningLogger log{}; - log << "Could not estimate read size from data, resorting to default"; - } - auto debug_log = logging::get_debug_log(); - if (debug_log) stream(*debug_log) << "Estimated read size is " << *result << " bytes"; - return *result; -} - bool is_multithreaded_run(const options::OptionMap& options) noexcept { const auto num_threads = options::get_num_threads(options); @@ -371,14 +370,36 @@ boost::optional get_temp_directory(const options::OptionMap& options) } } -std::size_t calculate_max_num_reads(const std::size_t max_buffer_bytes, - const std::vector& samples, - const InputRegionMap& input_regions, - ReadManager& read_manager) +auto estimate_read_size(const boost::optional& profile) noexcept { - if (samples.empty()) return 0; - static constexpr std::size_t minBufferBytes{1'000'000}; - 
return std::max(max_buffer_bytes, minBufferBytes) / estimate_read_size(samples, input_regions, read_manager); + double result; + if (profile) { + result = profile->mean_read_bytes + profile->read_bytes_stdev; + } else { + result = default_read_size_estimate(); + logging::WarningLogger log{}; + log << "Could not estimate read size from data, resorting to default"; + } + auto debug_log = logging::get_debug_log(); + if (debug_log) stream(*debug_log) << "Estimated read size is " << result << " bytes"; + return result; +} + +std::size_t calculate_max_num_reads(MemoryFootprint max_buffer_size, const boost::optional& profile) noexcept +{ + static constexpr MemoryFootprint min_buffer_size {50'000'000}; // 50Mb + if (max_buffer_size < min_buffer_size) { + static bool warned {false}; + if (!warned) { + logging::WarningLogger warn_log {}; + stream(warn_log) << "Ignoring given maximum read buffer size of " << max_buffer_size + << " as this size is too small. Setting maximum to " + << min_buffer_size << " instead."; + warned = true; + } + max_buffer_size = min_buffer_size; + } + return max_buffer_size.num_bytes() / estimate_read_size(profile); } auto add_identifier(const fs::path& base, const std::string& identifier) @@ -454,19 +475,23 @@ GenomeCallingComponents::Components::Components(ReferenceGenome&& reference, Rea , samples {extract_samples(options, this->read_manager)} , regions {get_search_regions(options, this->reference, this->read_manager)} , contigs {get_contigs(this->regions, this->reference, options::get_contig_output_order(options))} +, temp_directory {get_temp_directory(options)} +, reads_profile_ {profile_reads(this->samples, this->regions, this->read_manager)} , read_pipe {options::make_read_pipe(this->read_manager, this->samples, options)} -, caller_factory {options::make_caller_factory(this->reference, this->read_pipe, this->regions, options)} -, call_filter_factory {options::make_call_filter_factory(this->reference, this->read_pipe, options)} +, 
caller_factory {options::make_caller_factory(this->reference, this->read_pipe, this->regions, options, this->reads_profile_)} +, call_filter_factory {options::make_call_filter_factory(this->reference, this->read_pipe, options, this->temp_directory)} , filter_read_pipe {} , output {std::move(output)} , num_threads {options::get_num_threads(options)} , read_buffer_size {} -, temp_directory {get_temp_directory(options)} , progress_meter {regions} +, ploidies {options::get_ploidy_map(options)} +, pedigree {options::get_pedigree(options, samples)} , sites_only {options::call_sites_only(options)} , filtered_output {} , legacy {} , filter_request_ {} +, bamout_ {options::bamout_request(options)} { drop_unused_samples(this->samples, this->read_manager); setup_progress_meter(options); @@ -495,8 +520,7 @@ void GenomeCallingComponents::Components::setup_progress_meter(const options::Op void GenomeCallingComponents::Components::set_read_buffer_size(const options::OptionMap& options) { if (!samples.empty() && !regions.empty() && read_manager.good()) { - read_buffer_size = calculate_max_num_reads(options::get_target_read_buffer_size(options).num_bytes(), - samples, regions, read_manager); + read_buffer_size = calculate_max_num_reads(options::get_target_read_buffer_size(options), reads_profile_); } } diff --git a/src/core/calling_components.hpp b/src/core/calling_components.hpp index 8a2004fc2..62aa2901f 100644 --- a/src/core/calling_components.hpp +++ b/src/core/calling_components.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef calling_components_hpp @@ -15,12 +15,15 @@ #include "config/common.hpp" #include "config/option_parser.hpp" #include "basics/genomic_region.hpp" +#include "basics/ploidy_map.hpp" +#include "basics/pedigree.hpp" #include "io/reference/reference_genome.hpp" #include "io/read/read_manager.hpp" #include "io/variant/vcf_writer.hpp" #include "readpipe/read_pipe_fwd.hpp" #include "core/callers/caller_factory.hpp" #include "core/csr/filters/variant_call_filter_factory.hpp" +#include "utils/input_reads_profiler.hpp" #include "logging/progress_meter.hpp" namespace octopus { @@ -63,8 +66,11 @@ class GenomeCallingComponents const ReadPipe& filter_read_pipe() const noexcept; ProgressMeter& progress_meter() noexcept; bool sites_only() const noexcept; + const PloidyMap& ploidies() const noexcept; + boost::optional pedigree() const; boost::optional legacy() const; boost::optional filter_request() const; + boost::optional bamout() const; private: struct Components @@ -86,6 +92,8 @@ class GenomeCallingComponents std::vector samples; InputRegionMap regions; std::vector contigs; + boost::optional temp_directory; + boost::optional reads_profile_; ReadPipe read_pipe; CallerFactory caller_factory; std::unique_ptr call_filter_factory; @@ -93,12 +101,14 @@ class GenomeCallingComponents VcfWriter output; boost::optional num_threads; std::size_t read_buffer_size; - boost::optional temp_directory; ProgressMeter progress_meter; + PloidyMap ploidies; + boost::optional pedigree; bool sites_only; boost::optional filtered_output; boost::optional legacy; boost::optional filter_request_; + boost::optional bamout_; void setup_progress_meter(const options::OptionMap& options); void set_read_buffer_size(const options::OptionMap& options); diff --git a/src/core/csr/facets/facet.cpp b/src/core/csr/facets/facet.cpp index 696416ebb..dbd1f4e10 100644 --- a/src/core/csr/facets/facet.cpp +++ b/src/core/csr/facets/facet.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 
Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "facet.hpp" diff --git a/src/core/csr/facets/facet.hpp b/src/core/csr/facets/facet.hpp index 7395bfa66..ff3e51356 100644 --- a/src/core/csr/facets/facet.hpp +++ b/src/core/csr/facets/facet.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef facet_hpp @@ -15,17 +15,32 @@ #include "config/common.hpp" #include "core/types/haplotype.hpp" #include "core/tools/read_assigner.hpp" +#include "basics/ploidy_map.hpp" +#include "basics/pedigree.hpp" namespace octopus { namespace csr { class Facet : public Equitable { public: + using GenotypeMap = std::unordered_map>>; + using SampleSupportMap = std::unordered_map; + using LocalPloidyMap = std::unordered_map; + + struct SupportMaps + { + SampleSupportMap support; + ReadMap ambiguous; + }; + using ResultType = boost::variant, - std::reference_wrapper>, + std::reference_wrapper, std::reference_wrapper, std::reference_wrapper>, - std::reference_wrapper + std::reference_wrapper, + std::reference_wrapper, + std::reference_wrapper, + std::reference_wrapper >; Facet() = default; diff --git a/src/core/csr/facets/facet_factory.cpp b/src/core/csr/facets/facet_factory.cpp index 88d663e24..a7203b86b 100644 --- a/src/core/csr/facets/facet_factory.cpp +++ b/src/core/csr/facets/facet_factory.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "facet_factory.hpp" @@ -16,20 +16,56 @@ #include "read_assignments.hpp" #include "reference_context.hpp" #include "samples.hpp" +#include "genotypes.hpp" +#include "ploidies.hpp" +#include "pedigree.hpp" namespace octopus { namespace csr { -FacetFactory::FacetFactory(const ReferenceGenome& reference, BufferedReadPipe read_pipe) -: reference_ {reference} +FacetFactory::FacetFactory(VcfHeader input_header) +: input_header_ {std::move(input_header)} +, samples_ {input_header_.samples()} +, reference_ {} +, read_pipe_ {} +, ploidies_ {} +, pedigree_ {} +, facet_makers_ {} +{ + setup_facet_makers(); +} + +FacetFactory::FacetFactory(VcfHeader input_header, const ReferenceGenome& reference, BufferedReadPipe read_pipe, PloidyMap ploidies) +: input_header_ {std::move(input_header)} +, samples_ {input_header_.samples()} +, reference_ {reference} , read_pipe_ {std::move(read_pipe)} +, ploidies_ {std::move(ploidies)} +, pedigree_ {} +, facet_makers_ {} +{ + setup_facet_makers(); +} + +FacetFactory::FacetFactory(VcfHeader input_header, const ReferenceGenome& reference, BufferedReadPipe read_pipe, PloidyMap ploidies, + octopus::Pedigree pedigree) +: input_header_ {std::move(input_header)} +, samples_ {input_header_.samples()} +, reference_ {reference} +, read_pipe_ {std::move(read_pipe)} +, ploidies_ {std::move(ploidies)} +, pedigree_ {std::move(pedigree)} , facet_makers_ {} { setup_facet_makers(); } FacetFactory::FacetFactory(FacetFactory&& other) -: reference_ {std::move(other.reference_)} +: input_header_ {std::move(other.input_header_)} +, samples_ {std::move(other.samples_)} +, reference_ {std::move(other.reference_)} , read_pipe_ {std::move(other.read_pipe_)} +, ploidies_ {std::move(other.ploidies_)} +, pedigree_ {std::move(other.pedigree_)} , facet_makers_ {} { setup_facet_makers(); @@ -38,8 +74,12 @@ FacetFactory::FacetFactory(FacetFactory&& other) FacetFactory& FacetFactory::operator=(FacetFactory&& other) { using std::swap; + swap(input_header_, 
other.input_header_); + swap(samples_, other.samples_); swap(reference_, other.reference_); swap(read_pipe_, other.read_pipe_); + swap(ploidies_, other.ploidies_); + swap(pedigree_, other.pedigree_); setup_facet_makers(); return *this; } @@ -59,6 +99,7 @@ class UnknownFacet : public ProgramError FacetWrapper FacetFactory::make(const std::string& name, const CallBlock& block) const { + check_requirements(name); const auto block_data = make_block_data({name}, block); return make(name, block_data); } @@ -66,6 +107,7 @@ FacetWrapper FacetFactory::make(const std::string& name, const CallBlock& block) FacetFactory::FacetBlock FacetFactory::make(const std::vector& names, const CallBlock& block) const { if (names.empty()) return {}; + check_requirements(names); const auto block_data = make_block_data(names, block); return make(names, block_data); } @@ -78,6 +120,17 @@ decltype(auto) name() noexcept return Facet().name(); } +bool requires_reference(const std::string& facet) noexcept +{ + const static std::array read_facets{name(), name()}; + return std::find(std::cbegin(read_facets), std::cend(read_facets), facet) != std::cend(read_facets); +} + +bool requires_reference(const std::vector& facets) noexcept +{ + return std::any_of(std::cbegin(facets), std::cend(facets), [](const auto& facet) { return requires_reference(facet); }); +} + bool requires_reads(const std::string& facet) noexcept { const static std::array read_facets{name(), name()}; @@ -97,17 +150,56 @@ bool requires_genotypes(const std::string& facet) noexcept bool requires_genotypes(const std::vector& facets) noexcept { - return std::any_of(std::cbegin(facets), std::cend(facets), - [](const auto& facet) { return requires_genotypes(facet); }); + return std::any_of(std::cbegin(facets), std::cend(facets), [](const auto& facet) { return requires_genotypes(facet); }); +} + +bool requires_ploidies(const std::string& facet) noexcept +{ + const static std::array genotype_facets{name()}; + return 
std::find(std::cbegin(genotype_facets), std::cend(genotype_facets), facet) != std::cend(genotype_facets); +} + +bool requires_ploidies(const std::vector& facets) noexcept +{ + return std::any_of(std::cbegin(facets), std::cend(facets), [](const auto& facet) { return requires_ploidies(facet); }); +} + +bool requires_pedigree(const std::string& facet) noexcept +{ + const static std::array genotype_facets{name()}; + return std::find(std::cbegin(genotype_facets), std::cend(genotype_facets), facet) != std::cend(genotype_facets); +} + +bool requires_pedigree(const std::vector& facets) noexcept +{ + return std::any_of(std::cbegin(facets), std::cend(facets), [](const auto& facet) { return requires_pedigree(facet); }); } } // namespace -std::vector FacetFactory::make(const std::vector& names, - const std::vector& blocks, - ThreadPool& workers) const +class BadFacetFactoryRequest : public ProgramError +{ + std::string facet_; + std::string do_where() const override { return "FacetFactory"; } + std::string do_why() const override + { + return "Could not make facet " + facet_ + " due to an unmet requirement"; + } + std::string do_help() const override + { + return "submit an error report"; + } +public: + BadFacetFactoryRequest(std::string facet) : facet_ {std::move(facet)} {} +}; + +std::vector +FacetFactory::make(const std::vector& names, + const std::vector& blocks, + ThreadPool& workers) const { if (blocks.empty()) return {}; + check_requirements(names); std::vector result {}; result.reserve(blocks.size()); if (blocks.size() > 1 && !workers.empty()) { @@ -115,19 +207,18 @@ std::vector FacetFactory::make(const std::vectorfetch_reads(*data.region); } } - futures.push_back(workers.push([this, &names, data {std::move(data)}, &block, &samples, fetch_genotypes] () mutable { + futures.push_back(workers.push([this, &names, data {std::move(data)}, &block, fetch_genotypes] () mutable { if (fetch_genotypes) { - data.genotypes = extract_genotypes(block, samples, reference_); + 
data.genotypes = extract_genotypes(block, samples_, *reference_); } return this->make(names, data); })); @@ -156,21 +247,57 @@ void FacetFactory::setup_facet_makers() facet_makers_[name()] = [this] (const BlockData& block) -> FacetWrapper { assert(block.reads && block.genotypes); - return {std::make_unique(reference_, *block.genotypes, *block.reads)}; + return {std::make_unique(*reference_, *block.genotypes, *block.reads)}; }; facet_makers_[name()] = [this] (const BlockData& block) -> FacetWrapper { if (block.region) { constexpr GenomicRegion::Size context_size {50}; - return {std::make_unique(reference_, expand(*block.region, context_size))}; + return {std::make_unique(*reference_, expand(*block.region, context_size))}; } else { return {nullptr}; } }; facet_makers_[name()] = [this] (const BlockData& block) -> FacetWrapper { - return {std::make_unique(read_pipe_.source().samples())}; + return {std::make_unique(this->samples_)}; + }; + facet_makers_[name()] = [] (const BlockData& block) -> FacetWrapper + { + assert(block.genotypes); + return {std::make_unique(*block.genotypes)}; + }; + facet_makers_[name()] = [this] (const BlockData& block) -> FacetWrapper + { + return {std::make_unique(*ploidies_, *block.region, input_header_.samples())}; }; + facet_makers_[name()] = [this] (const BlockData& block) -> FacetWrapper + { + return {std::make_unique(*pedigree_)}; + }; +} + +void FacetFactory::check_requirements(const std::string& name) const +{ + if (!read_pipe_ && requires_reads(name)) { + throw BadFacetFactoryRequest {name}; + } + if (!reference_ && requires_reference(name)) { + throw BadFacetFactoryRequest {name}; + } + if (!ploidies_ && requires_ploidies(name)) { + throw BadFacetFactoryRequest {name}; + } + if (!pedigree_ && requires_pedigree(name)) { + throw BadFacetFactoryRequest {name}; + } +} + +void FacetFactory::check_requirements(const std::vector& names) const +{ + for (const auto& name : names) { + check_requirements(name); + } } FacetWrapper 
FacetFactory::make(const std::string& name, const BlockData& block) const @@ -199,10 +326,10 @@ FacetFactory::BlockData FacetFactory::make_block_data(const std::vectorfetch_reads(*result.region); } if (requires_genotypes(names)) { - result.genotypes = extract_genotypes(block, read_pipe_.source().samples(), reference_); + result.genotypes = extract_genotypes(block, samples_, *reference_); } } return result; diff --git a/src/core/csr/facets/facet_factory.hpp b/src/core/csr/facets/facet_factory.hpp index 28a5c0a40..1de9c0f84 100644 --- a/src/core/csr/facets/facet_factory.hpp +++ b/src/core/csr/facets/facet_factory.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef facet_factory_hpp @@ -12,6 +12,9 @@ #include #include "config/common.hpp" +#include "basics/pedigree.hpp" +#include "basics/ploidy_map.hpp" +#include "io/variant/vcf_header.hpp" #include "io/variant/vcf_record.hpp" #include "io/reference/reference_genome.hpp" #include "readpipe/buffered_read_pipe.hpp" @@ -29,7 +32,10 @@ class FacetFactory FacetFactory() = delete; - FacetFactory(const ReferenceGenome& reference, BufferedReadPipe read_pipe); + FacetFactory(VcfHeader input_header); + FacetFactory(VcfHeader input_header, const ReferenceGenome& reference, BufferedReadPipe read_pipe, PloidyMap ploidies); + FacetFactory(VcfHeader input_header, const ReferenceGenome& reference, BufferedReadPipe read_pipe, PloidyMap ploidies, + octopus::Pedigree pedigree); FacetFactory(const FacetFactory&) = delete; FacetFactory& operator=(const FacetFactory&) = delete; @@ -40,8 +46,7 @@ class FacetFactory FacetWrapper make(const std::string& name, const CallBlock& block) const; FacetBlock make(const std::vector& names, const CallBlock& block) const; - std::vector make(const std::vector& names, const std::vector& blocks, - ThreadPool& workers) const; + std::vector make(const 
std::vector& names, const std::vector& blocks, ThreadPool& workers) const; private: struct BlockData @@ -51,12 +56,18 @@ class FacetFactory boost::optional genotypes; }; - std::reference_wrapper reference_; - BufferedReadPipe read_pipe_; + VcfHeader input_header_; + std::vector samples_; + boost::optional> reference_; + boost::optional read_pipe_; + boost::optional ploidies_; + boost::optional pedigree_; std::unordered_map> facet_makers_; void setup_facet_makers(); + void check_requirements(const std::string& name) const; + void check_requirements(const std::vector& names) const; FacetWrapper make(const std::string& name, const BlockData& block) const; FacetBlock make(const std::vector& names, const BlockData& block) const; BlockData make_block_data(const std::vector& names, const CallBlock& block) const; diff --git a/src/core/csr/facets/genotypes.cpp b/src/core/csr/facets/genotypes.cpp new file mode 100644 index 000000000..045caee51 --- /dev/null +++ b/src/core/csr/facets/genotypes.cpp @@ -0,0 +1,20 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#include "genotypes.hpp" + +#include + +namespace octopus { namespace csr { + +const std::string Genotypes::name_ {"Genotypes"}; + +Genotypes::Genotypes(GenotypeMap genotypes) : genotypes_ {std::move(genotypes)} {} + +Facet::ResultType Genotypes::do_get() const +{ + return std::cref(genotypes_); +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/facets/genotypes.hpp b/src/core/csr/facets/genotypes.hpp new file mode 100644 index 000000000..4878a9158 --- /dev/null +++ b/src/core/csr/facets/genotypes.hpp @@ -0,0 +1,34 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#ifndef genotypes_hpp +#define genotypes_hpp + +#include +#include + +#include "facet.hpp" + +namespace octopus { namespace csr { + +class Genotypes : public Facet +{ +public: + using ResultType = std::reference_wrapper; + + Genotypes() = default; + Genotypes(GenotypeMap genotypes); + +private: + static const std::string name_; + + GenotypeMap genotypes_; + + const std::string& do_name() const noexcept override { return name_; } + Facet::ResultType do_get() const override; +}; + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/facets/overlapping_reads.cpp b/src/core/csr/facets/overlapping_reads.cpp index fd045ba67..05240a02f 100644 --- a/src/core/csr/facets/overlapping_reads.cpp +++ b/src/core/csr/facets/overlapping_reads.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "overlapping_reads.hpp" diff --git a/src/core/csr/facets/overlapping_reads.hpp b/src/core/csr/facets/overlapping_reads.hpp index 0e56fd628..043effa1e 100644 --- a/src/core/csr/facets/overlapping_reads.hpp +++ b/src/core/csr/facets/overlapping_reads.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef overlapping_reads_hpp diff --git a/src/core/csr/facets/pedigree.cpp b/src/core/csr/facets/pedigree.cpp new file mode 100644 index 000000000..3b7e15c95 --- /dev/null +++ b/src/core/csr/facets/pedigree.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "pedigree.hpp" + +#include + +namespace octopus { namespace csr { + +const std::string Pedigree::name_ {"Pedigree"}; + +Pedigree::Pedigree(octopus::Pedigree pedigree) +: pedigree_ {std::move(pedigree)} +{} + +Facet::ResultType Pedigree::do_get() const +{ + return std::cref(pedigree_); +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/facets/pedigree.hpp b/src/core/csr/facets/pedigree.hpp new file mode 100644 index 000000000..b02c2e647 --- /dev/null +++ b/src/core/csr/facets/pedigree.hpp @@ -0,0 +1,36 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#ifndef csr_pedigree_hpp +#define csr_pedigree_hpp + +#include +#include +#include + +#include "basics/pedigree.hpp" +#include "facet.hpp" + +namespace octopus { namespace csr { + +class Pedigree : public Facet +{ +public: + using ResultType = std::reference_wrapper; + + Pedigree() = default; + Pedigree(octopus::Pedigree pedigree); + +private: + static const std::string name_; + + octopus::Pedigree pedigree_; + + const std::string& do_name() const noexcept override { return name_; } + Facet::ResultType do_get() const override; +}; + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/facets/ploidies.cpp b/src/core/csr/facets/ploidies.cpp new file mode 100644 index 000000000..92cf0fa12 --- /dev/null +++ b/src/core/csr/facets/ploidies.cpp @@ -0,0 +1,29 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "ploidies.hpp" + +#include +#include +#include + +namespace octopus { namespace csr { + +const std::string Ploidies::name_ {"Ploidies"}; + +Ploidies::Ploidies(const PloidyMap& ploidies, const GenomicRegion& region, const std::vector& samples) +{ + const auto local_ploidies = get_ploidies(samples, region.contig_name(), ploidies); + ploidies_.reserve(samples.size()); + std::transform(std::cbegin(samples), std::cend(samples), std::cbegin(local_ploidies), + std::inserter(ploidies_, std::begin(ploidies_)), + [] (const auto& sample, auto ploidy) { return std::make_pair(sample, ploidy); }); +} + +Facet::ResultType Ploidies::do_get() const +{ + return std::cref(ploidies_); +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/facets/ploidies.hpp b/src/core/csr/facets/ploidies.hpp new file mode 100644 index 000000000..68ce1ed18 --- /dev/null +++ b/src/core/csr/facets/ploidies.hpp @@ -0,0 +1,38 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#ifndef ploidies_hpp +#define ploidies_hpp + +#include +#include +#include + +#include "config/common.hpp" +#include "basics/genomic_region.hpp" +#include "basics/ploidy_map.hpp" +#include "facet.hpp" + +namespace octopus { namespace csr { + +class Ploidies : public Facet +{ +public: + using ResultType = std::reference_wrapper; + + Ploidies() = default; + Ploidies(const PloidyMap& ploidies, const GenomicRegion& region, const std::vector& samples); + +private: + static const std::string name_; + + LocalPloidyMap ploidies_; + + const std::string& do_name() const noexcept override { return name_; } + Facet::ResultType do_get() const override; +}; + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/facets/read_assignments.cpp b/src/core/csr/facets/read_assignments.cpp index 34721b3ac..6752f27b1 100644 --- a/src/core/csr/facets/read_assignments.cpp +++ b/src/core/csr/facets/read_assignments.cpp @@ -1,8 +1,10 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "read_assignments.hpp" +#include "core/tools/read_realigner.hpp" + namespace octopus { namespace csr { const std::string ReadAssignments::name_ {"ReadAssignments"}; @@ -16,9 +18,9 @@ auto copy_overlapped_to_vector(const ReadContainer& reads, const Mappable& mappa return std::vector {std::cbegin(overlapped), std::cend(overlapped)}; } -bool is_homozygous_nonreference(const Genotype& genotype) +void move_insert(std::deque& reads, const SampleName& sample, ReadMap& result) { - return genotype.is_homozygous() && !is_reference(genotype[0]); + result[sample].insert(std::make_move_iterator(std::begin(reads)), std::make_move_iterator(std::end(reads))); } } // namespace @@ -26,31 +28,42 @@ bool is_homozygous_nonreference(const Genotype& genotype) ReadAssignments::ReadAssignments(const ReferenceGenome& reference, const GenotypeMap& genotypes, const ReadMap& reads) : result_ {} { - result_.reserve(genotypes.size()); + AssignmentConfig assigner_config {}; + assigner_config.ambiguous_action = AssignmentConfig::AmbiguousAction::random; + const auto num_samples = genotypes.size(); + result_.support.reserve(num_samples); + result_.ambiguous.reserve(num_samples); for (const auto& p : genotypes) { const auto& sample = p.first; - const auto& genotypes = p.second; - result_[sample].reserve(genotypes.size()); - for (const auto& genotype : genotypes) { + const auto& sample_genotypes = p.second; + result_.support[sample].reserve(sample_genotypes.size()); + for (const auto& genotype : sample_genotypes) { auto local_reads = copy_overlapped_to_vector(reads.at(sample), genotype); for (const auto& haplotype : genotype) { // So every called haplotype appears in support map, even if no read support - result_[sample][haplotype] = {}; + result_.support[sample][haplotype] = {}; } if (!local_reads.empty()) { HaplotypeSupportMap genotype_support {}; - if (!is_homozygous_nonreference(genotype)) { - genotype_support = compute_haplotype_support(genotype, local_reads); + std::deque unassigned 
{}; + if (!genotype.is_homozygous()) { + genotype_support = compute_haplotype_support(genotype, local_reads, unassigned, assigner_config); } else { - auto augmented_genotype = genotype; - Haplotype ref {mapped_region(genotype), reference}; - result_[sample][ref] = {}; - augmented_genotype.emplace(std::move(ref)); - genotype_support = compute_haplotype_support(augmented_genotype, local_reads); + if (is_reference(genotype[0])) { + genotype_support[genotype[0]] = std::move(local_reads); + } else { + auto augmented_genotype = genotype; + Haplotype ref {mapped_region(genotype), reference}; + result_.support[sample][ref] = {}; + augmented_genotype.emplace(std::move(ref)); + genotype_support = compute_haplotype_support(augmented_genotype, local_reads, unassigned, assigner_config); + } } for (auto& s : genotype_support) { - result_[sample][s.first] = std::move(s.second); + safe_realign_to_reference(s.second, s.first); + result_.support[sample][s.first] = std::move(s.second); } + move_insert(unassigned, sample, result_.ambiguous); } } } diff --git a/src/core/csr/facets/read_assignments.hpp b/src/core/csr/facets/read_assignments.hpp index 707a52082..88354e379 100644 --- a/src/core/csr/facets/read_assignments.hpp +++ b/src/core/csr/facets/read_assignments.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef read_assignments_hpp @@ -22,9 +22,7 @@ namespace octopus { namespace csr { class ReadAssignments : public Facet { public: - using GenotypeMap = std::unordered_map>>; - using SampleSupportMap = std::unordered_map; - using ResultType = std::reference_wrapper; + using ResultType = std::reference_wrapper; ReadAssignments() = default; @@ -33,7 +31,7 @@ class ReadAssignments : public Facet private: static const std::string name_; - SampleSupportMap result_; + SupportMaps result_; const std::string& do_name() const noexcept override { return name_; } Facet::ResultType do_get() const override; diff --git a/src/core/csr/facets/reference_context.cpp b/src/core/csr/facets/reference_context.cpp index 63de9266f..a830b8156 100644 --- a/src/core/csr/facets/reference_context.cpp +++ b/src/core/csr/facets/reference_context.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "reference_context.hpp" diff --git a/src/core/csr/facets/reference_context.hpp b/src/core/csr/facets/reference_context.hpp index f2747312e..6f2d29bb0 100644 --- a/src/core/csr/facets/reference_context.hpp +++ b/src/core/csr/facets/reference_context.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef reference_context_hpp diff --git a/src/core/csr/facets/samples.cpp b/src/core/csr/facets/samples.cpp index 925fc1270..126f5345e 100644 --- a/src/core/csr/facets/samples.cpp +++ b/src/core/csr/facets/samples.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "samples.hpp" diff --git a/src/core/csr/facets/samples.hpp b/src/core/csr/facets/samples.hpp index 335ad0258..647a097bd 100644 --- a/src/core/csr/facets/samples.hpp +++ b/src/core/csr/facets/samples.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef samples_hpp diff --git a/src/core/csr/filters/conditional_random_forest_filter.cpp b/src/core/csr/filters/conditional_random_forest_filter.cpp new file mode 100644 index 000000000..1c34de1a9 --- /dev/null +++ b/src/core/csr/filters/conditional_random_forest_filter.cpp @@ -0,0 +1,326 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#include "conditional_random_forest_filter.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "ranger/ForestProbability.h" + +#include "basics/phred.hpp" +#include "utils/concat.hpp" +#include "exceptions/missing_file_error.hpp" +#include "exceptions/program_error.hpp" +#include "exceptions/malformed_file_error.hpp" + +namespace octopus { namespace csr { + + +namespace { + +class MissingForestFile : public MissingFileError +{ + std::string do_where() const override { return "ConditionalRandomForestFilter"; } +public: + MissingForestFile(boost::filesystem::path p) : MissingFileError {std::move(p), ".forest"} {}; +}; + +void check_all_exists(const std::vector& forests) +{ + for (const auto& forest : forests) { + if (!boost::filesystem::exists(forest)) { + throw MissingForestFile {forest}; + } + } +} + +} // namespace + +ConditionalRandomForestFilter::ConditionalRandomForestFilter(FacetFactory facet_factory, + std::vector measures, + std::vector chooser_measures, + std::function)> chooser, + std::vector ranger_forests, + OutputOptions output_config, + ConcurrencyPolicy 
threading, + Path temp_directory, + boost::optional progress) +: DoublePassVariantCallFilter {std::move(facet_factory), concat(std::move(measures), chooser_measures), + std::move(output_config), threading, progress} +, forest_paths_ {std::move(ranger_forests)} +, temp_dir_ {std::move(temp_directory)} +, chooser_ {std::move(chooser)} +, num_chooser_measures_ {chooser_measures.size()} +, num_records_ {0} +, data_buffer_ {} +{ + check_all_exists(forest_paths_); + forests_.reserve(forest_paths_.size()); + std::generate_n(std::back_inserter(forests_), forest_paths_.size(), + [] () { return std::make_unique(); }); +} + +const std::string ConditionalRandomForestFilter::call_qual_name_ = "RFQUAL"; + +boost::optional ConditionalRandomForestFilter::call_quality_name() const +{ + return call_qual_name_; +} + +void ConditionalRandomForestFilter::annotate(VcfHeader::Builder& header) const +{ + header.add_info(call_qual_name_, "1", "Float", "Empirical quality score from random forest classifier"); + header.add_filter("RF", "Random Forest filtered"); +} + +std::int8_t ConditionalRandomForestFilter::choose_forest(const MeasureVector& measures) const +{ + const MeasureVector chooser_measures(std::prev(std::cend(measures), num_chooser_measures_), std::cend(measures)); + return chooser_(chooser_measures); +} + +template +static void write_line(const std::vector& data, std::ostream& out) +{ + std::copy(std::cbegin(data), std::prev(std::cend(data)), std::ostream_iterator {out, " "}); + out << data.back() << '\n'; +} + +void ConditionalRandomForestFilter::prepare_for_registration(const SampleList& samples) const +{ + std::vector data_header {}; + data_header.reserve(measures_.size() - num_chooser_measures_); + std::transform(std::cbegin(measures_), std::prev(std::cend(measures_), num_chooser_measures_), std::back_inserter(data_header), + [] (const auto& measure) { return measure.name(); }); + data_header.push_back("TP"); + const auto num_forests = forest_paths_.size(); + 
data_.resize(num_forests); + for (std::size_t forest_idx {0}; forest_idx < num_forests; ++forest_idx) { + data_[forest_idx].reserve(samples.size()); + for (const auto& sample : samples) { + auto data_path = temp_dir_; + Path fname {"octopus_ranger_temp_forest_data_" + std::to_string(forest_idx) + "_" + sample + ".dat"}; + data_path /= fname; + data_[forest_idx].emplace_back(data_path.string(), data_path); + write_line(data_header, data_[forest_idx].back().handle); + } + } + data_buffer_.resize(num_forests); + for (auto& buffer : data_buffer_) buffer.resize(samples.size()); + choices_.resize(samples.size()); +} + +namespace { + +struct MeasureDoubleVisitor : boost::static_visitor<> +{ + double result; + template void operator()(const T& value) + { + result = boost::lexical_cast(value); + } + template void operator()(const boost::optional& value) + { + if (value) { + (*this)(*value); + } else { + result = -1; + } + } + template void operator()(const std::vector& values) + { + throw std::runtime_error {"Vector cast not supported"}; + } + void operator()(boost::any value) + { + throw std::runtime_error {"Any cast not supported"}; + } +}; + +auto cast_to_double(const Measure::ResultType& value) +{ + MeasureDoubleVisitor vis {}; + boost::apply_visitor(vis, value); + return vis.result; +} + +class NanMeasure : public ProgramError +{ + std::string do_where() const override { return "ConditionalRandomForestFilter"; } + std::string do_why() const override { return "detected a nan measure"; } + std::string do_help() const override { return "submit an error report"; } +}; + +void check_nan(const std::vector& values) +{ + if (std::any_of(std::cbegin(values), std::cend(values), [] (auto v) { return std::isnan(v); })) { + throw NanMeasure {}; + } +} + +void skip_lines(std::istream& in, int n = 1) +{ + for (; n > 0; --n) in.ignore(std::numeric_limits::max(), '\n'); +} + +} // namespace + +void ConditionalRandomForestFilter::record(const std::size_t call_idx, std::size_t 
sample_idx, MeasureVector measures) const +{ + assert(!measures.empty()); + const auto forest_idx = choose_forest(measures); + const auto num_forests = static_cast( data_buffer_.size()); + if (forest_idx >= 0 && forest_idx < num_forests) { + auto& buffer = data_buffer_[forest_idx][sample_idx]; + std::transform(std::cbegin(measures), std::prev(std::cend(measures), num_chooser_measures_), + std::back_inserter(buffer), cast_to_double); + buffer.push_back(0); // dummy TP value + check_nan(buffer); + write_line(buffer, data_[forest_idx][sample_idx].handle); + buffer.clear(); + } else { + hard_filtered_record_indices_.push_back(call_idx); + } + if (call_idx >= num_records_) ++num_records_; + choices_[sample_idx].push_back(forest_idx); +} + +void ConditionalRandomForestFilter::close_data_files() const +{ + for (auto& forest : data_) { + for (auto& sample : forest) { + sample.handle.close(); + } + } +} + +namespace { + +bool read_header(std::ifstream& prediction_file) +{ + skip_lines(prediction_file); + std::string order; + std::getline(prediction_file, order); + skip_lines(prediction_file); + return order.front() == '1'; +} + +static double get_prob_false(std::string& prediction_line, const bool tp_first) +{ + using std::cbegin; using std::cend; + if (tp_first) { + prediction_line.erase(cbegin(prediction_line), std::next(std::find(cbegin(prediction_line), cend(prediction_line), ' '))); + prediction_line.erase(std::find(cbegin(prediction_line), cend(prediction_line), ' '), cend(prediction_line)); + } else { + prediction_line.erase(std::find(cbegin(prediction_line), cend(prediction_line), ' '), cend(prediction_line)); + } + return boost::lexical_cast(prediction_line); +} + +} // namespace + +class MalformedForestFile : public MalformedFileError +{ + std::string do_where() const override { return "ConditionalRandomForestFilter"; } + std::string do_help() const override + { + return "make sure the forest was trained with the same measures and in the same order as the 
prediction measures"; + } +public: + MalformedForestFile(boost::filesystem::path file) : MalformedFileError {std::move(file)} {} +}; + +void ConditionalRandomForestFilter::prepare_for_classification(boost::optional& log) const +{ + close_data_files(); + const Path ranger_prefix {temp_dir_ / "octopus_ranger_temp"}; + const Path ranger_prediction_fname {ranger_prefix.string() + ".prediction"}; + data_buffer_.resize(1); + auto& predictions = data_buffer_[0]; + predictions.resize(num_records_); + const auto num_samples = choices_.size(); + for (std::size_t forest_idx {0}; forest_idx < forest_paths_.size(); ++forest_idx) { + for (std::size_t sample_idx {0}; sample_idx < num_samples; ++sample_idx) { + auto forest_choice_itr = std::find(std::cbegin(choices_[sample_idx]), std::cend(choices_[sample_idx]), forest_idx); + if (forest_choice_itr != std::cend(choices_[sample_idx])) { + const auto& file = data_[forest_idx][sample_idx]; + std::vector tmp {}, cat_vars {}; + auto& forest = forests_[forest_idx]; + try { + forest->initCpp("TP", ranger::MemoryMode::MEM_DOUBLE, file.path.string(), 0, ranger_prefix.string(), + 1000, nullptr, 12, 1, forest_paths_[forest_idx].string(), ranger::ImportanceMode::IMP_GINI, 1, "", + tmp, "", true, cat_vars, false, ranger::SplitRule::LOGRANK, "", false, 1.0, + ranger::DEFAULT_ALPHA, ranger::DEFAULT_MINPROP, false, + ranger::PredictionType::RESPONSE, ranger::DEFAULT_NUM_RANDOM_SPLITS); + } catch (const std::runtime_error& e) { + throw MalformedForestFile {forest_paths_[forest_idx]}; + } + forest->run(false); + forest->writePredictionFile(); + std::ifstream prediction_file {ranger_prediction_fname.string()}; + const auto tp_first = read_header(prediction_file); + std::string line; + while (std::getline(prediction_file, line)) { + if (!line.empty()) { + const auto record_idx = std::distance(std::cbegin(choices_[sample_idx]), forest_choice_itr); + predictions[record_idx].push_back(get_prob_false(line, tp_first)); + assert(forest_choice_itr != 
std::cend(choices_[sample_idx])); + forest_choice_itr = std::find(std::next(forest_choice_itr), std::cend(choices_[sample_idx]), forest_idx); + } + } + boost::filesystem::remove(file.path); + } + } + } + boost::filesystem::remove(ranger_prediction_fname); + data_.clear(); + data_.shrink_to_fit(); + choices_.clear(); + choices_.shrink_to_fit(); + if (!hard_filtered_record_indices_.empty()) { + hard_filtered_.resize(num_records_, false); + for (auto idx : hard_filtered_record_indices_) { + hard_filtered_[idx] = true; + } + hard_filtered_record_indices_.clear(); + hard_filtered_record_indices_.shrink_to_fit(); + } +} + +std::size_t ConditionalRandomForestFilter::get_forest_choice(std::size_t call_idx, std::size_t sample_idx) const +{ + return choices_.empty() ? 0 : choices_[sample_idx][call_idx]; +} + +VariantCallFilter::Classification ConditionalRandomForestFilter::classify(const std::size_t call_idx, std::size_t sample_idx) const +{ + Classification result {}; + if (hard_filtered_.empty() || !hard_filtered_[call_idx]) { + const auto& predictions = data_buffer_[0]; + assert(call_idx < predictions.size() && sample_idx < predictions[call_idx].size()); + const auto prob_false = predictions[call_idx][sample_idx]; + if (prob_false < 0.5) { + result.category = Classification::Category::unfiltered; + } else { + result.category = Classification::Category::soft_filtered; + result.reasons.assign({"RF"}); + } + result.quality = probability_to_phred(std::max(prob_false, 1e-10)); + } else { + result.category = Classification::Category::hard_filtered; + } + return result; +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/filters/conditional_random_forest_filter.hpp b/src/core/csr/filters/conditional_random_forest_filter.hpp new file mode 100644 index 000000000..063c21e1c --- /dev/null +++ b/src/core/csr/filters/conditional_random_forest_filter.hpp @@ -0,0 +1,85 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT 
license that can be found in the LICENSE file. + +#ifndef conditional_random_forest_filter_hpp +#define conditional_random_forest_filter_hpp + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "ranger/Forest.h" + +#include "double_pass_variant_call_filter.hpp" + +namespace octopus { namespace csr { + +class ConditionalRandomForestFilter : public DoublePassVariantCallFilter +{ +public: + using Path = boost::filesystem::path; + + ConditionalRandomForestFilter() = delete; + + ConditionalRandomForestFilter(FacetFactory facet_factory, + std::vector measures, + std::vector chooser_measures, + std::function)> chooser, + std::vector ranger_forests, + OutputOptions output_config, + ConcurrencyPolicy threading, + Path temp_directory = "/tmp", + boost::optional progress = boost::none); + + ConditionalRandomForestFilter(const ConditionalRandomForestFilter&) = delete; + ConditionalRandomForestFilter& operator=(const ConditionalRandomForestFilter&) = delete; + ConditionalRandomForestFilter(ConditionalRandomForestFilter&&) = default; + ConditionalRandomForestFilter& operator=(ConditionalRandomForestFilter&&) = default; + + virtual ~ConditionalRandomForestFilter() override = default; + +private: + struct File + { + std::ofstream handle; + Path path; + template + File(F&& handle, P&& path) : handle {std::forward(handle)}, path {std::forward

(path)} {}; + }; + + std::vector forest_paths_; + Path temp_dir_; + std::vector> forests_; + std::function)> chooser_; + std::size_t num_chooser_measures_; + + mutable std::vector> data_; + mutable std::size_t num_records_; + mutable std::vector>> data_buffer_; + mutable std::vector> choices_; + mutable std::deque hard_filtered_record_indices_; + mutable std::vector hard_filtered_; + + const static std::string call_qual_name_; + + boost::optional call_quality_name() const override; + void annotate(VcfHeader::Builder& header) const override; + std::int8_t choose_forest(const MeasureVector& measures) const; + void prepare_for_registration(const SampleList& samples) const override; + void record(std::size_t call_idx, std::size_t sample_idx, MeasureVector measures) const override; + void close_data_files() const; + void prepare_for_classification(boost::optional& log) const override; + std::size_t get_forest_choice(std::size_t call_idx, std::size_t sample_idx) const; + Classification classify(std::size_t call_idx, std::size_t sample_idx) const override; +}; + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/filters/conditional_threshold_filter.cpp b/src/core/csr/filters/conditional_threshold_filter.cpp new file mode 100644 index 000000000..6b40ed501 --- /dev/null +++ b/src/core/csr/filters/conditional_threshold_filter.cpp @@ -0,0 +1,121 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "conditional_threshold_filter.hpp" + +#include +#include +#include + +#include "utils/append.hpp" + +namespace octopus { namespace csr { + +namespace { + +auto concat(std::vector conditions) +{ + ThresholdVariantCallFilter::ConditionVectorPair result {}; + for (auto& p : conditions) { + utils::append(p.hard, result.hard); + utils::append(p.soft, result.soft); + } + return result; +} + +bool are_all_unique(std::vector keys) +{ + std::sort(std::begin(keys), std::end(keys)); + return std::adjacent_find(std::cbegin(keys), std::cend(keys)) == std::cend(keys); +} + +} // namespace + +ConditionalThresholdVariantCallFilter::ConditionalThresholdVariantCallFilter(FacetFactory facet_factory, + std::vector conditions, + std::vector chooser_measures, + std::function)> chooser, + OutputOptions output_config, + ConcurrencyPolicy threading, + boost::optional progress) +: ThresholdVariantCallFilter {std::move(facet_factory), concat(conditions), output_config, threading, progress, chooser_measures} +, hard_ranges_ {} +, soft_ranges_ {} +, chooser_ {std::move(chooser)} +, unique_filter_keys_ {} +, num_chooser_measures_ {chooser_measures.size()} +{ + measures_.shrink_to_fit(); + hard_ranges_.reserve(conditions.size()); + soft_ranges_.reserve(conditions.size()); + unique_filter_keys_.reserve(conditions.size()); + std::size_t i {0}, j {0}, k {0}; + for (const auto& p : conditions) { + hard_ranges_.push_back({i, i + p.hard.size(), j}); + i += p.hard.size(); j += p.hard.size(); + soft_ranges_.push_back({i, i + p.soft.size(), k}); + i += p.soft.size(); + std::vector filter_keys {std::next(std::cbegin(vcf_filter_keys_), k), + std::next(std::cbegin(vcf_filter_keys_), k + p.soft.size())}; + unique_filter_keys_.push_back(are_all_unique(filter_keys)); + k += p.soft.size(); + } +} + +bool ConditionalThresholdVariantCallFilter::passes_all_hard_filters(const MeasureVector& measures) const +{ + return passes_all_hard_filters(measures, hard_ranges_[choose_filter(measures)]); +} + +bool 
ConditionalThresholdVariantCallFilter::passes_all_soft_filters(const MeasureVector& measures) const +{ + return passes_all_soft_filters(measures, soft_ranges_[choose_filter(measures)]); +} + +std::vector ConditionalThresholdVariantCallFilter::get_failing_vcf_filter_keys(const MeasureVector& measures) const +{ + const auto filter_idx = choose_filter(measures); + auto result = get_failing_vcf_filter_keys(measures, soft_ranges_[filter_idx]); + if (!unique_filter_keys_[filter_idx]) { + std::sort(std::begin(result), std::end(result)); + result.erase(std::unique(std::begin(result), std::end(result)), std::end(result)); + } + return result; +} + +std::size_t ConditionalThresholdVariantCallFilter::choose_filter(const MeasureVector& measures) const +{ + const MeasureVector chooser_measures(std::prev(std::cend(measures), num_chooser_measures_), std::cend(measures)); + return chooser_(chooser_measures); +} + +bool ConditionalThresholdVariantCallFilter::passes_all_hard_filters(const MeasureVector& measures, const MeasureIndexRange range) const +{ + using std::cbegin; using std::next; + return passes_all_filters(next(cbegin(measures), range.measure_begin), next(cbegin(measures), range.measure_end), + next(cbegin(hard_thresholds_), range.threshold_begin)); +} + +bool ConditionalThresholdVariantCallFilter::passes_all_soft_filters(const MeasureVector& measures, const MeasureIndexRange range) const +{ + using std::cbegin; using std::next; + return passes_all_filters(next(cbegin(measures), range.measure_begin), next(cbegin(measures), range.measure_end), + next(cbegin(soft_thresholds_), range.threshold_begin)); +} + +std::vector +ConditionalThresholdVariantCallFilter::get_failing_vcf_filter_keys(const MeasureVector& measures, const MeasureIndexRange range) const +{ + std::vector result {}; + const auto num_conditions = range.measure_end - range.measure_begin; + result.reserve(num_conditions); + for (std::size_t i {0}; i < num_conditions; ++i) { + if 
(!soft_thresholds_[range.threshold_begin + i](measures[range.measure_begin + i])) { + result.push_back(vcf_filter_keys_[range.threshold_begin + i]); + } + } + return result; +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/filters/conditional_threshold_filter.hpp b/src/core/csr/filters/conditional_threshold_filter.hpp new file mode 100644 index 000000000..b8392670f --- /dev/null +++ b/src/core/csr/filters/conditional_threshold_filter.hpp @@ -0,0 +1,65 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#ifndef conditional_threshold_filter_hpp +#define conditional_threshold_filter_hpp + +#include +#include +#include +#include + +#include + +#include "threshold_filter.hpp" +#include "logging/progress_meter.hpp" +#include "../facets/facet_factory.hpp" +#include "../measures/measure.hpp" + +namespace octopus { namespace csr { + +class ConditionalThresholdVariantCallFilter : public ThresholdVariantCallFilter +{ +public: + ConditionalThresholdVariantCallFilter() = delete; + + ConditionalThresholdVariantCallFilter(FacetFactory facet_factory, + std::vector conditions, + std::vector chooser_measures, + std::function)> chooser, + OutputOptions output_config, + ConcurrencyPolicy threading, + boost::optional progress = boost::none); + + ConditionalThresholdVariantCallFilter(const ConditionalThresholdVariantCallFilter&) = delete; + ConditionalThresholdVariantCallFilter& operator=(const ConditionalThresholdVariantCallFilter&) = delete; + ConditionalThresholdVariantCallFilter(ConditionalThresholdVariantCallFilter&&) = default; + ConditionalThresholdVariantCallFilter& operator=(ConditionalThresholdVariantCallFilter&&) = default; + + virtual ~ConditionalThresholdVariantCallFilter() override = default; + +private: + struct MeasureIndexRange + { + std::size_t measure_begin, measure_end, threshold_begin; + }; + + std::vector hard_ranges_, soft_ranges_; + 
std::function)> chooser_; + std::vector unique_filter_keys_; + std::size_t num_chooser_measures_; + + virtual bool passes_all_hard_filters(const MeasureVector& measures) const override; + virtual bool passes_all_soft_filters(const MeasureVector& measures) const override; + virtual std::vector get_failing_vcf_filter_keys(const MeasureVector& measures) const override; + + std::size_t choose_filter(const MeasureVector& measures) const; + bool passes_all_hard_filters(const MeasureVector& measures, MeasureIndexRange range) const; + bool passes_all_soft_filters(const MeasureVector& measures, MeasureIndexRange range) const; + std::vector get_failing_vcf_filter_keys(const MeasureVector& measures, MeasureIndexRange range) const; +}; + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/filters/denovo_random_forest_filter.cpp b/src/core/csr/filters/denovo_random_forest_filter.cpp new file mode 100644 index 000000000..13627ebf3 --- /dev/null +++ b/src/core/csr/filters/denovo_random_forest_filter.cpp @@ -0,0 +1,59 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "denovo_random_forest_filter.hpp" + +#include + +#include "../measures/is_denovo.hpp" +#include "../measures/is_refcall.hpp" + +namespace octopus { namespace csr { + +DeNovoRandomForestVariantCallFilter::DeNovoRandomForestVariantCallFilter(FacetFactory facet_factory, + std::vector measures, + Path germline_forest, Path denovo_forest, + OutputOptions output_config, + ConcurrencyPolicy threading, + Path temp_directory, + boost::optional progress) +: ConditionalRandomForestFilter { + std::move(facet_factory), + std::move(measures), + {make_wrapped_measure(true)}, + [] (const MeasureVector& measures) -> std::int8_t { return !boost::get(measures.front()); }, + {std::move(germline_forest), std::move(denovo_forest)}, + std::move(output_config), + std::move(threading), + std::move(temp_directory), + progress +} {} + +DeNovoRandomForestVariantCallFilter::DeNovoRandomForestVariantCallFilter(FacetFactory facet_factory, + std::vector measures, + Path denovo_forest, + OutputOptions output_config, + ConcurrencyPolicy threading, + Path temp_directory, + boost::optional progress) +: ConditionalRandomForestFilter { + std::move(facet_factory), + std::move(measures), + {make_wrapped_measure(false)}, + [] (const MeasureVector& measures) -> std::int8_t { return !boost::get(measures.front()); }, + {std::move(denovo_forest)}, + std::move(output_config), + std::move(threading), + std::move(temp_directory), + progress +} {} + +bool DeNovoRandomForestVariantCallFilter::is_soft_filtered(const ClassificationList& sample_classifications, + const MeasureVector& measures) const +{ + return std::any_of(std::cbegin(sample_classifications), std::cend(sample_classifications), + [] (const auto& c) { return c.category != Classification::Category::unfiltered; }); +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/filters/denovo_random_forest_filter.hpp b/src/core/csr/filters/denovo_random_forest_filter.hpp new file mode 100644 index 000000000..bc9f1ba5d --- /dev/null 
+++ b/src/core/csr/filters/denovo_random_forest_filter.hpp @@ -0,0 +1,56 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#ifndef denovo_random_forest_filter_hpp +#define denovo_random_forest_filter_hpp + +#include +#include + +#include +#include + +#include "threshold_filter.hpp" +#include "conditional_random_forest_filter.hpp" +#include "logging/progress_meter.hpp" +#include "../facets/facet_factory.hpp" +#include "../measures/measure.hpp" + +namespace octopus { namespace csr { + +class DeNovoRandomForestVariantCallFilter : public ConditionalRandomForestFilter +{ +public: + DeNovoRandomForestVariantCallFilter() = delete; + + DeNovoRandomForestVariantCallFilter(FacetFactory facet_factory, + std::vector measures, + Path germline_forest, Path denovo_forest, + OutputOptions output_config, + ConcurrencyPolicy threading, + Path temp_directory = "/tmp", + boost::optional progress = boost::none); + // De novo only + DeNovoRandomForestVariantCallFilter(FacetFactory facet_factory, + std::vector measures, + Path denovo_forest, + OutputOptions output_config, + ConcurrencyPolicy threading, + Path temp_directory = "/tmp", + boost::optional progress = boost::none); + + DeNovoRandomForestVariantCallFilter(const DeNovoRandomForestVariantCallFilter&) = delete; + DeNovoRandomForestVariantCallFilter& operator=(const DeNovoRandomForestVariantCallFilter&) = delete; + DeNovoRandomForestVariantCallFilter(DeNovoRandomForestVariantCallFilter&&) = default; + DeNovoRandomForestVariantCallFilter& operator=(DeNovoRandomForestVariantCallFilter&&) = default; + + virtual ~DeNovoRandomForestVariantCallFilter() override = default; + +private: + virtual bool is_soft_filtered(const ClassificationList& sample_classifications, const MeasureVector& measures) const override; +}; + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/filters/denovo_threshold_filter.cpp 
b/src/core/csr/filters/denovo_threshold_filter.cpp new file mode 100644 index 000000000..69ecedd51 --- /dev/null +++ b/src/core/csr/filters/denovo_threshold_filter.cpp @@ -0,0 +1,73 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#include "denovo_threshold_filter.hpp" + +#include +#include +#include +#include + +#include + +#include "../measures/is_denovo.hpp" +#include "../measures/is_refcall.hpp" + +namespace octopus { namespace csr { + +DeNovoThresholdVariantCallFilter::DeNovoThresholdVariantCallFilter(FacetFactory facet_factory, + ConditionVectorPair germline, + ConditionVectorPair denovo, + ConditionVectorPair reference, + OutputOptions output_config, + ConcurrencyPolicy threading, + boost::optional progress) +: ConditionalThresholdVariantCallFilter { + std::move(facet_factory), + {std::move(germline), std::move(denovo), std::move(reference)}, + {make_wrapped_measure(true), make_wrapped_measure(true)}, + [] (const MeasureVector& measures) -> std::size_t { + assert(measures.size() == 2); + if (boost::get(measures.front())) { + return 1; // DENOVO sample + } else if (boost::get(measures.back())) { + return 2; // REFCALL sample + } else { + return 0; // germline variant sample + }}, + output_config, threading, progress +} {} + +DeNovoThresholdVariantCallFilter::DeNovoThresholdVariantCallFilter(FacetFactory facet_factory, + ConditionVectorPair denovo, + ConditionVectorPair reference, + OutputOptions output_config, + ConcurrencyPolicy threading, + boost::optional progress) +: ConditionalThresholdVariantCallFilter { + std::move(facet_factory), + {{{{make_wrapped_measure(false), make_wrapped_threshold>(false)}}, {}}, std::move(denovo), std::move(reference)}, + {make_wrapped_measure(false), make_wrapped_measure(true), make_wrapped_measure(true)}, + [] (const MeasureVector& measures) -> std::size_t { + assert(measures.size() == 3); + if (!boost::get(measures[0])) { + return 
0; // Not DENOVO call + } else if (boost::get(measures[1])) { + return 1; // DENOVO sample + } else if (boost::get(measures[2])) { + return 2; // DENOVO call REFCALL parent + } else { + return 1; // DENOVO call non-REFCALL parent + }}, + output_config, threading, progress +} {} + +bool DeNovoThresholdVariantCallFilter::is_soft_filtered(const ClassificationList& sample_classifications, + const MeasureVector& measures) const +{ + return std::any_of(std::cbegin(sample_classifications), std::cend(sample_classifications), + [] (const auto& c) { return c.category != Classification::Category::unfiltered; }); +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/filters/denovo_threshold_filter.hpp b/src/core/csr/filters/denovo_threshold_filter.hpp new file mode 100644 index 000000000..7650b19b7 --- /dev/null +++ b/src/core/csr/filters/denovo_threshold_filter.hpp @@ -0,0 +1,55 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#ifndef denovo_threshold_filter_hpp +#define denovo_threshold_filter_hpp + +#include +#include + +#include + +#include "threshold_filter.hpp" +#include "conditional_threshold_filter.hpp" +#include "logging/progress_meter.hpp" +#include "../facets/facet_factory.hpp" +#include "../measures/measure.hpp" + +namespace octopus { namespace csr { + +class DeNovoThresholdVariantCallFilter : public ConditionalThresholdVariantCallFilter +{ +public: + DeNovoThresholdVariantCallFilter() = delete; + + DeNovoThresholdVariantCallFilter(FacetFactory facet_factory, + ConditionVectorPair germline, + ConditionVectorPair denovo, + ConditionVectorPair reference, + OutputOptions output_config, + ConcurrencyPolicy threading, + boost::optional progress = boost::none); + + // Hard filter germline + DeNovoThresholdVariantCallFilter(FacetFactory facet_factory, + ConditionVectorPair denovo, + ConditionVectorPair reference, + OutputOptions output_config, + ConcurrencyPolicy threading, + boost::optional progress = boost::none); + + DeNovoThresholdVariantCallFilter(const DeNovoThresholdVariantCallFilter&) = delete; + DeNovoThresholdVariantCallFilter& operator=(const DeNovoThresholdVariantCallFilter&) = delete; + DeNovoThresholdVariantCallFilter(DeNovoThresholdVariantCallFilter&&) = default; + DeNovoThresholdVariantCallFilter& operator=(DeNovoThresholdVariantCallFilter&&) = default; + + virtual ~DeNovoThresholdVariantCallFilter() override = default; + +private: + virtual bool is_soft_filtered(const ClassificationList& sample_classifications, const MeasureVector& measures) const override; +}; + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/filters/double_pass_variant_call_filter.cpp b/src/core/csr/filters/double_pass_variant_call_filter.cpp index 6a6fd5683..59bf1a9c7 100644 --- a/src/core/csr/filters/double_pass_variant_call_filter.cpp +++ b/src/core/csr/filters/double_pass_variant_call_filter.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright 
(c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "double_pass_variant_call_filter.hpp" @@ -8,8 +8,11 @@ #include #include +#include + #include "io/variant/vcf_reader.hpp" #include "io/variant/vcf_writer.hpp" +#include "utils/append.hpp" namespace octopus { namespace csr { @@ -29,48 +32,75 @@ void DoublePassVariantCallFilter::filter(const VcfReader& source, VcfWriter& des assert(dest.is_header_written()); make_registration_pass(source, samples); prepare_for_classification(info_log_); - make_filter_pass(source, dest); + make_filter_pass(source, samples, dest); } -void DoublePassVariantCallFilter::log_registration_pass_start(Log& log) const +void DoublePassVariantCallFilter::log_registration_pass(Log& log) const { log << "CSR: Starting registration pass"; } void DoublePassVariantCallFilter::make_registration_pass(const VcfReader& source, const SampleList& samples) const { - if (info_log_) log_registration_pass_start(*info_log_); + if (info_log_) log_registration_pass(*info_log_); + prepare_for_registration(samples); if (progress_) progress_->start(); - if (can_measure_single_call()) { + std::size_t record_idx {0}; + if (can_measure_multiple_blocks()) { + for (auto p = source.iterate(); p.first != p.second;) { + const auto blocks = read_next_blocks(p.first, p.second, samples); + record(blocks, record_idx, samples); + for (const auto& block : blocks) record_idx += block.size(); + } + } else if (can_measure_single_call()) { auto p = source.iterate(); - std::size_t idx {0}; - std::for_each(std::move(p.first), std::move(p.second), [&] (const VcfRecord& call) { record(call, idx++); }); + std::for_each(std::move(p.first), std::move(p.second), [&] (const VcfRecord& call) { record(call, record_idx++, samples); }); } else { - std::size_t idx {0}; for (auto p = source.iterate(); p.first != p.second;) { - const auto calls = read_next_block(p.first, p.second, samples); - record(calls, idx); - idx 
+= calls.size(); + const auto block = read_next_block(p.first, p.second, samples); + record(block, record_idx, samples); + record_idx += block.size(); } } if (progress_) progress_->stop(); } -void DoublePassVariantCallFilter::record(const VcfRecord& call, const std::size_t idx) const +void DoublePassVariantCallFilter::record(const VcfRecord& call, const std::size_t record_idx, const SampleList& samples) const +{ + record(call, measure(call), record_idx, samples); +} + +void DoublePassVariantCallFilter::record(const CallBlock& block, const std::size_t record_idx, const SampleList& samples) const { - record(idx, measure(call)); + record(block, measure(block), record_idx, samples); +} + +void DoublePassVariantCallFilter::record(const std::vector& blocks, std::size_t record_idx, const SampleList& samples) const +{ + const auto measures = measure(blocks); + assert(measures.size() == blocks.size()); + for (auto tup : boost::combine(blocks, measures)) { + const auto& block = tup.get<0>(); + record(block, tup.get<1>(), record_idx, samples); + record_idx += block.size(); + } +} + +void DoublePassVariantCallFilter::record(const VcfRecord& call, const MeasureVector& measures, + const std::size_t record_idx, const SampleList& samples) const +{ + for (std::size_t sample_idx {0}; sample_idx < samples.size(); ++sample_idx) { + this->record(record_idx, sample_idx, get_sample_values(measures, measures_, sample_idx)); + } log_progress(mapped_region(call)); } -void DoublePassVariantCallFilter::record(const std::vector& calls, std::size_t first_idx) const +void DoublePassVariantCallFilter::record(const CallBlock& block, const MeasureBlock& measures, + std::size_t record_idx, const SampleList& samples) const { - if (!calls.empty()) { - const auto measures = measure(calls); - assert(measures.size() == calls.size()); - for (const auto& m : measures) { - record(first_idx++, m); - } - log_progress(encompassing_region(calls)); + assert(measures.size() == block.size()); + for (auto tup : 
boost::combine(block, measures)) { + record(tup.get<0>(), tup.get<1>(), record_idx++, samples); } } @@ -79,7 +109,7 @@ void DoublePassVariantCallFilter::log_filter_pass_start(Log& log) const log << "CSR: Starting filtering pass"; } -void DoublePassVariantCallFilter::make_filter_pass(const VcfReader& source, VcfWriter& dest) const +void DoublePassVariantCallFilter::make_filter_pass(const VcfReader& source, const SampleList& samples, VcfWriter& dest) const { if (info_log_) log_filter_pass_start(*info_log_); if (progress_) { @@ -89,13 +119,26 @@ void DoublePassVariantCallFilter::make_filter_pass(const VcfReader& source, VcfW } auto p = source.iterate(); std::size_t idx {0}; - std::for_each(std::move(p.first), std::move(p.second), [&] (const VcfRecord& call) { filter(call, idx++, dest); }); + std::for_each(std::move(p.first), std::move(p.second), [&] (const VcfRecord& call) { filter(call, idx++, samples, dest); }); if (progress_) progress_->stop(); } -void DoublePassVariantCallFilter::filter(const VcfRecord& call, const std::size_t idx, VcfWriter& dest) const +std::vector +DoublePassVariantCallFilter::classify(std::size_t call_idx, const SampleList& samples) const +{ + std::vector result(samples.size()); + for (std::size_t sample_idx {0}; sample_idx < samples.size(); ++sample_idx) { + result[sample_idx] = this->classify(call_idx, sample_idx); + } + return result; +} + +void DoublePassVariantCallFilter::filter(const VcfRecord& call, const std::size_t call_idx, const SampleList& samples, + VcfWriter& dest) const { - write(call, classify(idx), dest); + const auto sample_classifications = classify(call_idx, samples); + const auto call_classification = merge(sample_classifications); + write(call, call_classification, samples, sample_classifications, dest); log_progress(mapped_region(call)); } diff --git a/src/core/csr/filters/double_pass_variant_call_filter.hpp b/src/core/csr/filters/double_pass_variant_call_filter.hpp index 719443b7f..cdfea56f7 100644 --- 
a/src/core/csr/filters/double_pass_variant_call_filter.hpp +++ b/src/core/csr/filters/double_pass_variant_call_filter.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef double_variant_call_filter_hpp @@ -42,19 +42,24 @@ class DoublePassVariantCallFilter : public VariantCallFilter mutable boost::optional progress_; mutable boost::optional current_contig_; - virtual void log_registration_pass_start(Log& log) const; - virtual void record(std::size_t call_idx, MeasureVector measures) const = 0; + virtual void log_registration_pass(Log& log) const; + virtual void prepare_for_registration(const SampleList& samples) const {}; + virtual void record(std::size_t call_idx, std::size_t sample_idx, MeasureVector measures) const = 0; virtual void prepare_for_classification(boost::optional& log) const = 0; virtual void log_filter_pass_start(Log& log) const; - virtual Classification classify(std::size_t call_idx) const = 0; + virtual Classification classify(std::size_t call_idx, std::size_t sample_idx) const = 0; void filter(const VcfReader& source, VcfWriter& dest, const SampleList& samples) const override; void make_registration_pass(const VcfReader& source, const SampleList& samples) const; - void record(const VcfRecord& call, std::size_t idx) const; - void record(const std::vector& calls, std::size_t first_idx) const; - void make_filter_pass(const VcfReader& source, VcfWriter& dest) const; - void filter(const VcfRecord& call, std::size_t idx, VcfWriter& dest) const; + void record(const VcfRecord& call, std::size_t record_idx, const SampleList& samples) const; + void record(const CallBlock& block, std::size_t record_idx, const SampleList& samples) const; + void record(const std::vector& blocks, std::size_t record_idx, const SampleList& samples) const; + void record(const VcfRecord& call, const MeasureVector& measures, std::size_t 
record_idx, const SampleList& samples) const; + void record(const CallBlock& block, const MeasureBlock& measures, std::size_t record_idx, const SampleList& samples) const; + void make_filter_pass(const VcfReader& source, const SampleList& samples, VcfWriter& dest) const; + std::vector classify(std::size_t call_idx, const SampleList& samples) const; + void filter(const VcfRecord& call, std::size_t idx, const SampleList& samples, VcfWriter& dest) const; void log_progress(const GenomicRegion& region) const; }; diff --git a/src/core/csr/filters/passing_filter.cpp b/src/core/csr/filters/passing_filter.cpp index fa7e2ffec..98be25acc 100644 --- a/src/core/csr/filters/passing_filter.cpp +++ b/src/core/csr/filters/passing_filter.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "passing_filter.hpp" diff --git a/src/core/csr/filters/passing_filter.hpp b/src/core/csr/filters/passing_filter.hpp index 22270186e..487f58367 100644 --- a/src/core/csr/filters/passing_filter.hpp +++ b/src/core/csr/filters/passing_filter.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef passing_filter_hpp diff --git a/src/core/csr/filters/random_forest_filter.cpp b/src/core/csr/filters/random_forest_filter.cpp new file mode 100644 index 000000000..6f3ed9072 --- /dev/null +++ b/src/core/csr/filters/random_forest_filter.cpp @@ -0,0 +1,203 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "random_forest_filter.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "ranger/ForestProbability.h" + +#include "basics/phred.hpp" + +namespace octopus { namespace csr { + +RandomForestFilter::RandomForestFilter(FacetFactory facet_factory, + std::vector measures, + OutputOptions output_config, + ConcurrencyPolicy threading, + Path ranger_forest, Path temp_directory, + boost::optional progress) +: DoublePassVariantCallFilter {std::move(facet_factory), std::move(measures), std::move(output_config), threading, progress} +, forest_ {std::make_unique()} +, ranger_forest_ {std::move(ranger_forest)} +, temp_dir_ {std::move(temp_directory)} +, num_records_ {0} +, data_buffer_ {} +{} + +const std::string RandomForestFilter::call_qual_name_ = "RFQUAL"; + +boost::optional RandomForestFilter::call_quality_name() const +{ + return call_qual_name_; +} + +void RandomForestFilter::annotate(VcfHeader::Builder& header) const +{ + header.add_info(call_qual_name_, "1", "Float", "Empirical quality score from random forest classifier"); + header.add_filter("RF", "Random Forest filtered"); +} + +namespace { + +template +static void write_line(const std::vector& data, std::ostream& out) +{ + std::copy(std::cbegin(data), std::prev(std::cend(data)), std::ostream_iterator {out, " "}); + out << data.back() << '\n'; +} + +} // namespace + +void RandomForestFilter::prepare_for_registration(const SampleList& samples) const +{ + std::vector data_header {}; + data_header.reserve(measures_.size()); + for (const auto& measure : measures_) { + data_header.push_back(measure.name()); + } + data_header.push_back("TP"); + data_.reserve(samples.size()); + for (const auto& sample : samples) { + auto data_path = temp_dir_; + Path fname {"octopus_ranger_temp_forest_data_" + sample + ".dat"}; + data_path /= fname; + data_.emplace_back(data_path.string(), data_path); + write_line(data_header, data_.back().handle); + } + 
data_buffer_.resize(samples.size()); +} + +namespace { + +struct MeasureDoubleVisitor : boost::static_visitor<> +{ + double result; + template void operator()(const T& value) + { + result = boost::lexical_cast(value); + } + template void operator()(const boost::optional& value) + { + if (value) { + (*this)(*value); + } else { + result = -1; + } + } + template void operator()(const std::vector& values) + { + throw std::runtime_error {"Vector cast not supported"}; + } + void operator()(boost::any value) + { + throw std::runtime_error {"Any cast not supported"}; + } +}; + +auto cast_to_double(const Measure::ResultType& value) +{ + MeasureDoubleVisitor vis {}; + boost::apply_visitor(vis, value); + return vis.result; +} + +void skip_lines(std::istream& in, int n = 1) +{ + for (; n > 0; --n) in.ignore(std::numeric_limits::max(), '\n'); +} + +} // namespace + +void RandomForestFilter::record(const std::size_t call_idx, std::size_t sample_idx, MeasureVector measures) const +{ + assert(!measures.empty()); + std::transform(std::cbegin(measures), std::cend(measures), std::back_inserter(data_buffer_[sample_idx]), cast_to_double); + data_buffer_[sample_idx].push_back(0); // dummy TP value + write_line(data_buffer_[sample_idx], data_[sample_idx].handle); + data_buffer_[sample_idx].clear(); + if (call_idx >= num_records_) ++num_records_; +} + +namespace { + +bool read_header(std::ifstream& prediction_file) +{ + skip_lines(prediction_file); + std::string order; + std::getline(prediction_file, order); + skip_lines(prediction_file); + return order.front() == '1'; +} + +static double get_prob_false(std::string& prediction_line, const bool tp_first) +{ + using std::cbegin; using std::cend; + if (tp_first) { + prediction_line.erase(cbegin(prediction_line), std::next(std::find(cbegin(prediction_line), cend(prediction_line), ' '))); + prediction_line.erase(std::find(cbegin(prediction_line), cend(prediction_line), ' '), cend(prediction_line)); + } else { + 
prediction_line.erase(std::find(cbegin(prediction_line), cend(prediction_line), ' '), cend(prediction_line)); + } + return boost::lexical_cast(prediction_line); +} + +} // namespace + +void RandomForestFilter::prepare_for_classification(boost::optional& log) const +{ + const Path ranger_prefix {temp_dir_ / "octopus_ranger_temp"}; + const Path ranger_prediction_fname {ranger_prefix.string() + ".prediction"}; + data_buffer_.resize(num_records_); + for (auto& file : data_) { + file.handle.close(); + std::vector tmp {}, cat_vars {}; + forest_->initCpp("TP", ranger::MemoryMode::MEM_DOUBLE, file.path.string(), 0, ranger_prefix.string(), + 1000, nullptr, 12, 1, ranger_forest_.string(), ranger::ImportanceMode::IMP_GINI, 1, "", + tmp, "", true, cat_vars, false, ranger::SplitRule::LOGRANK, "", false, 1.0, + ranger::DEFAULT_ALPHA, ranger::DEFAULT_MINPROP, false, + ranger::PredictionType::RESPONSE, ranger::DEFAULT_NUM_RANDOM_SPLITS); + forest_->run(false); + forest_->writePredictionFile(); + std::ifstream prediction_file {ranger_prediction_fname.string()}; + const auto tp_first = read_header(prediction_file); + std::string line; + std::size_t record_idx {0}; + while (std::getline(prediction_file, line)) { + if (!line.empty()) { + data_buffer_[record_idx++].push_back(get_prob_false(line, tp_first)); + } + } + boost::filesystem::remove(file.path); + } + boost::filesystem::remove(ranger_prediction_fname); + data_.clear(); + data_.shrink_to_fit(); +} + +VariantCallFilter::Classification RandomForestFilter::classify(const std::size_t call_idx, std::size_t sample_idx) const +{ + assert(call_idx < data_buffer_.size() && sample_idx < data_buffer_[call_idx].size()); + const auto prob_false = data_buffer_[call_idx][sample_idx]; + Classification result {}; + if (prob_false < 0.5) { + result.category = Classification::Category::unfiltered; + } else { + result.category = Classification::Category::soft_filtered; + result.reasons.assign({"RF"}); + } + result.quality = 
probability_to_phred(std::max(prob_false, 1e-10)); + return result; +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/filters/random_forest_filter.hpp b/src/core/csr/filters/random_forest_filter.hpp new file mode 100644 index 000000000..700abb143 --- /dev/null +++ b/src/core/csr/filters/random_forest_filter.hpp @@ -0,0 +1,72 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#ifndef random_forest_filter_hpp +#define random_forest_filter_hpp + +#include +#include +#include +#include + +#include +#include + +#include "ranger/Forest.h" + +#include "double_pass_variant_call_filter.hpp" + +namespace octopus { namespace csr { + +class RandomForestFilter : public DoublePassVariantCallFilter +{ +public: + using Path = boost::filesystem::path; + + RandomForestFilter() = delete; + + RandomForestFilter(FacetFactory facet_factory, + std::vector measures, + OutputOptions output_config, + ConcurrencyPolicy threading, + Path ranger_forest, + Path temp_directory = "/tmp", + boost::optional progress = boost::none); + + RandomForestFilter(const RandomForestFilter&) = delete; + RandomForestFilter& operator=(const RandomForestFilter&) = delete; + RandomForestFilter(RandomForestFilter&&) = default; + RandomForestFilter& operator=(RandomForestFilter&&) = default; + + virtual ~RandomForestFilter() override = default; + +private: + struct File + { + std::ofstream handle; + Path path; + template + File(F&& handle, P&& path) : handle {std::forward(handle)}, path {std::forward

(path)} {}; + }; + + std::unique_ptr forest_; + Path ranger_forest_, temp_dir_; + + mutable std::vector data_; + mutable std::size_t num_records_; + mutable std::vector> data_buffer_; + + const static std::string call_qual_name_; + + boost::optional call_quality_name() const override; + void annotate(VcfHeader::Builder& header) const override; + void prepare_for_registration(const SampleList& samples) const override; + void record(std::size_t call_idx, std::size_t sample_idx, MeasureVector measures) const override; + void prepare_for_classification(boost::optional& log) const override; + Classification classify(std::size_t call_idx, std::size_t sample_idx) const override; +}; + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/filters/random_forest_filter_factory.cpp b/src/core/csr/filters/random_forest_filter_factory.cpp new file mode 100644 index 000000000..456080df9 --- /dev/null +++ b/src/core/csr/filters/random_forest_filter_factory.cpp @@ -0,0 +1,77 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "random_forest_filter_factory.hpp" + +#include "utils/string_utils.hpp" +#include "../measures/measure_factory.hpp" +#include "somatic_random_forest_filter.hpp" +#include "denovo_random_forest_filter.hpp" + +namespace octopus { namespace csr { + +namespace { + +std::vector parse_measures(const std::vector& measure_names) +{ + std::vector result{}; + result.reserve(measure_names.size()); + std::transform(std::cbegin(measure_names), std::cend(measure_names), std::back_inserter(result), make_measure); + return result; +} + +} // namespace + +static const auto default_measure_names = utils::split("AC AF ARF BQ CRF DP FRF GC GQ MC MF MP MQ MQ0 MQD QD QUAL REFCALL RPB SB SC SOMATIC STR_LENGTH STR_PERIOD", ' '); + +RandomForestFilterFactory::RandomForestFilterFactory(Path ranger_forest, Path temp_directory, ForestType type) +: ranger_forests_ {std::move(ranger_forest)} +, forest_types_ {type} +, temp_directory_ {std::move(temp_directory)} +{ + measures_ = parse_measures(default_measure_names); +} + +RandomForestFilterFactory::RandomForestFilterFactory(Path germline_ranger_forest, Path somatic_ranger_forest, Path temp_directory) +: ranger_forests_ {std::move(germline_ranger_forest), std::move(somatic_ranger_forest)} +, forest_types_ {ForestType::germline, ForestType::somatic} +, temp_directory_ {std::move(temp_directory)} +{ + measures_ = parse_measures(default_measure_names); +} + +std::unique_ptr RandomForestFilterFactory::do_clone() const +{ + return std::make_unique(*this); +} + +std::unique_ptr +RandomForestFilterFactory::do_make(FacetFactory facet_factory, + VariantCallFilter::OutputOptions output_config, + boost::optional progress, + VariantCallFilter::ConcurrencyPolicy threading) const +{ + if (ranger_forests_.size() == 1) { + assert(forest_types_.size() == 1); + switch (forest_types_.front()) { + case ForestType::somatic: + return std::make_unique(std::move(facet_factory), measures_, ranger_forests_[0], + output_config, threading, temp_directory_, 
progress); + case ForestType::denovo: + return std::make_unique(std::move(facet_factory), measures_, ranger_forests_[0], + output_config, threading, temp_directory_, progress); + case ForestType::germline: + default: + return std::make_unique(std::move(facet_factory), measures_, output_config, threading, + ranger_forests_[0], temp_directory_, progress); + } + } else { + assert(ranger_forests_.size() == 2); + return std::make_unique(std::move(facet_factory), measures_, + ranger_forests_[0], ranger_forests_[1], + output_config, threading, temp_directory_, progress); + } +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/filters/random_forest_filter_factory.hpp b/src/core/csr/filters/random_forest_filter_factory.hpp new file mode 100644 index 000000000..02f716023 --- /dev/null +++ b/src/core/csr/filters/random_forest_filter_factory.hpp @@ -0,0 +1,61 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#ifndef random_forest_filter_factory_hpp +#define random_forest_filter_factory_hpp + +#include +#include +#include + +#include +#include + +#include "logging/progress_meter.hpp" +#include "../measures/measure.hpp" +#include "variant_call_filter_factory.hpp" +#include "variant_call_filter.hpp" +#include "random_forest_filter.hpp" + +namespace octopus { namespace csr { + +class FacetFactory; + +class RandomForestFilterFactory : public VariantCallFilterFactory +{ +public: + using Path = RandomForestFilter::Path; + enum class ForestType { germline, somatic, denovo }; + + RandomForestFilterFactory() = default; + + RandomForestFilterFactory(Path ranger_forest, Path temp_directory, ForestType type = ForestType::germline); + RandomForestFilterFactory(Path germline_ranger_forest, Path somatic_ranger_forest, Path temp_directory); + + RandomForestFilterFactory(const RandomForestFilterFactory&) = default; + RandomForestFilterFactory& operator=(const RandomForestFilterFactory&) = default; + RandomForestFilterFactory(RandomForestFilterFactory&&) = default; + RandomForestFilterFactory& operator=(RandomForestFilterFactory&&) = default; + + ~RandomForestFilterFactory() = default; + +private: + std::vector measures_; + std::vector ranger_forests_; + std::vector forest_types_; + Path temp_directory_; + + std::unique_ptr do_clone() const override; + std::unique_ptr do_make(FacetFactory facet_factory, + VariantCallFilter::OutputOptions output_config, + boost::optional progress, + VariantCallFilter::ConcurrencyPolicy threading) const override; +}; + +} // namespace csr + +using csr::RandomForestFilterFactory; + +} // namespace octopus + +#endif diff --git a/src/core/csr/filters/single_pass_variant_call_filter.cpp b/src/core/csr/filters/single_pass_variant_call_filter.cpp index de933bccf..e977c7dc6 100644 --- a/src/core/csr/filters/single_pass_variant_call_filter.cpp +++ b/src/core/csr/filters/single_pass_variant_call_filter.cpp @@ -1,8 +1,9 @@ -// Copyright (c) 2017 Daniel Cooke +// 
Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "single_pass_variant_call_filter.hpp" +#include #include #include #include @@ -31,59 +32,71 @@ void SinglePassVariantCallFilter::filter(const VcfReader& source, VcfWriter& des if (progress_) progress_->start(); if (can_measure_multiple_blocks()) { for (auto p = source.iterate(); p.first != p.second;) { - filter(read_next_blocks(p.first, p.second, samples), dest); + filter(read_next_blocks(p.first, p.second, samples), dest, samples); } } else if (can_measure_single_call()) { auto p = source.iterate(); - std::for_each(std::move(p.first), std::move(p.second), [&] (const VcfRecord& call) { filter(call, dest); }); + std::for_each(std::move(p.first), std::move(p.second), [&] (const VcfRecord& call) { filter(call, dest, samples); }); } else { for (auto p = source.iterate(); p.first != p.second;) { - filter(read_next_block(p.first, p.second, samples), dest); + filter(read_next_block(p.first, p.second, samples), dest, samples); } } if (progress_) progress_->stop(); } -void SinglePassVariantCallFilter::filter(const VcfRecord& call, VcfWriter& dest) const +void SinglePassVariantCallFilter::filter(const VcfRecord& call, VcfWriter& dest, const SampleList& samples) const { - filter(call, measure(call), dest); + filter(call, measure(call), dest, samples); } -void SinglePassVariantCallFilter::filter(const CallBlock& block, VcfWriter& dest) const +void SinglePassVariantCallFilter::filter(const CallBlock& block, VcfWriter& dest, const SampleList& samples) const { - filter(block, measure(block), dest); + filter(block, measure(block), dest, samples); } -void SinglePassVariantCallFilter::filter(const std::vector& blocks, VcfWriter& dest) const +void SinglePassVariantCallFilter::filter(const std::vector& blocks, VcfWriter& dest, const SampleList& samples) const { const auto measures = measure(blocks); assert(measures.size() == blocks.size()); for 
(auto tup : boost::combine(blocks, measures)) { - filter(tup.get<0>(), tup.get<1>(), dest); + filter(tup.get<0>(), tup.get<1>(), dest, samples); } } -void SinglePassVariantCallFilter::filter(const CallBlock& block, const MeasureBlock& measures, VcfWriter& dest) const +void SinglePassVariantCallFilter::filter(const CallBlock& block, const MeasureBlock& measures, VcfWriter& dest, const SampleList& samples) const { assert(measures.size() == block.size()); for (auto tup : boost::combine(block, measures)) { - filter(tup.get<0>(), tup.get<1>(), dest); + filter(tup.get<0>(), tup.get<1>(), dest, samples); } } -void SinglePassVariantCallFilter::filter(const VcfRecord& call, const MeasureVector& measures, VcfWriter& dest) const +void SinglePassVariantCallFilter::filter(const VcfRecord& call, const MeasureVector& measures, VcfWriter& dest, const SampleList& samples) const { + const auto sample_classifications = classify(measures, samples); + const auto call_classification = merge(sample_classifications, measures); if (annotate_measures_) { auto annotation_builder = VcfRecord::Builder {call}; annotate(annotation_builder, measures); const auto annotated_call = annotation_builder.build_once(); - write(annotated_call, classify(measures), dest); + write(annotated_call, call_classification, samples, sample_classifications, dest); } else { - write(call, classify(measures), dest); + write(call, call_classification, samples, sample_classifications, dest); } log_progress(mapped_region(call)); } +VariantCallFilter::ClassificationList +SinglePassVariantCallFilter::classify(const MeasureVector& call_measures, const SampleList& samples) const +{ + ClassificationList result(samples.size()); + for (std::size_t sample_idx {0}; sample_idx < samples.size(); ++sample_idx) { + result[sample_idx] = this->classify(get_sample_values(call_measures, measures_, sample_idx)); + } + return result; +} + static auto expand_lhs_to_zero(const GenomicRegion& region) { return GenomicRegion 
{region.contig_name(), 0, region.end()}; diff --git a/src/core/csr/filters/single_pass_variant_call_filter.hpp b/src/core/csr/filters/single_pass_variant_call_filter.hpp index 81e2474a7..d1a69890c 100644 --- a/src/core/csr/filters/single_pass_variant_call_filter.hpp +++ b/src/core/csr/filters/single_pass_variant_call_filter.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef single_pass_variant_call_filter_hpp @@ -43,12 +43,12 @@ class SinglePassVariantCallFilter : public VariantCallFilter virtual Classification classify(const MeasureVector& call_measures) const = 0; void filter(const VcfReader& source, VcfWriter& dest, const SampleList& samples) const override; - - void filter(const VcfRecord& call, VcfWriter& dest) const; - void filter(const CallBlock& block, VcfWriter& dest) const; - void filter(const std::vector& blocks, VcfWriter& dest) const; - void filter(const CallBlock& block, const MeasureBlock & measures, VcfWriter& dest) const; - void filter(const VcfRecord& call, const MeasureVector& measures, VcfWriter& dest) const; + void filter(const VcfRecord& call, VcfWriter& dest, const SampleList& samples) const; + void filter(const CallBlock& block, VcfWriter& dest, const SampleList& samples) const; + void filter(const std::vector& blocks, VcfWriter& dest, const SampleList& samples) const; + void filter(const CallBlock& block, const MeasureBlock & measures, VcfWriter& dest, const SampleList& samples) const; + void filter(const VcfRecord& call, const MeasureVector& measures, VcfWriter& dest, const SampleList& samples) const; + ClassificationList classify(const MeasureVector& call_measures, const SampleList& samples) const; void log_progress(const GenomicRegion& region) const; }; diff --git a/src/core/csr/filters/somatic_random_forest_filter.cpp b/src/core/csr/filters/somatic_random_forest_filter.cpp new file mode 
100644 index 000000000..71170a4ed --- /dev/null +++ b/src/core/csr/filters/somatic_random_forest_filter.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#include "somatic_random_forest_filter.hpp" + +#include + +#include "../measures/is_somatic.hpp" +#include "../measures/is_refcall.hpp" + +namespace octopus { namespace csr { + +SomaticRandomForestVariantCallFilter::SomaticRandomForestVariantCallFilter(FacetFactory facet_factory, + std::vector measures, + Path germline_forest, Path somatic_forest, + OutputOptions output_config, + ConcurrencyPolicy threading, + Path temp_directory, + boost::optional progress) +: ConditionalRandomForestFilter { + std::move(facet_factory), + std::move(measures), + {make_wrapped_measure(true), make_wrapped_measure(true)}, + [] (const MeasureVector& measures) -> std::int8_t { + assert(measures.size() == 2); + if (boost::get(measures.front())) { + return 1; + } else if (boost::get(measures.back())) { + return 1; + } else { + return 0; + }}, + {std::move(germline_forest), std::move(somatic_forest)}, + std::move(output_config), + std::move(threading), + std::move(temp_directory), + progress +} {} + +SomaticRandomForestVariantCallFilter::SomaticRandomForestVariantCallFilter(FacetFactory facet_factory, + std::vector measures, + Path somatic_forest, + OutputOptions output_config, + ConcurrencyPolicy threading, + Path temp_directory, + boost::optional progress) +: ConditionalRandomForestFilter { + std::move(facet_factory), + std::move(measures), + {make_wrapped_measure(false)}, + [] (const MeasureVector& measures) -> std::int8_t { + assert(measures.size() == 1); + return !boost::get(measures.front()); + }, + {std::move(somatic_forest)}, + std::move(output_config), + std::move(threading), + std::move(temp_directory), + progress +} {} + +bool SomaticRandomForestVariantCallFilter::is_soft_filtered(const ClassificationList& 
sample_classifications, + const MeasureVector& measures) const +{ + return std::any_of(std::cbegin(sample_classifications), std::cend(sample_classifications), + [] (const auto& c) { return c.category != Classification::Category::unfiltered; }); +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/filters/somatic_random_forest_filter.hpp b/src/core/csr/filters/somatic_random_forest_filter.hpp new file mode 100644 index 000000000..9ab4692a6 --- /dev/null +++ b/src/core/csr/filters/somatic_random_forest_filter.hpp @@ -0,0 +1,56 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#ifndef somatic_random_forest_filter_hpp +#define somatic_random_forest_filter_hpp + +#include +#include + +#include +#include + +#include "threshold_filter.hpp" +#include "conditional_random_forest_filter.hpp" +#include "logging/progress_meter.hpp" +#include "../facets/facet_factory.hpp" +#include "../measures/measure.hpp" + +namespace octopus { namespace csr { + +class SomaticRandomForestVariantCallFilter : public ConditionalRandomForestFilter +{ +public: + SomaticRandomForestVariantCallFilter() = delete; + + SomaticRandomForestVariantCallFilter(FacetFactory facet_factory, + std::vector measures, + Path germline_forest, Path somatic_forest, + OutputOptions output_config, + ConcurrencyPolicy threading, + Path temp_directory = "/tmp", + boost::optional progress = boost::none); + // Somatics only + SomaticRandomForestVariantCallFilter(FacetFactory facet_factory, + std::vector measures, + Path somatic_forest, + OutputOptions output_config, + ConcurrencyPolicy threading, + Path temp_directory = "/tmp", + boost::optional progress = boost::none); + + SomaticRandomForestVariantCallFilter(const SomaticRandomForestVariantCallFilter&) = delete; + SomaticRandomForestVariantCallFilter& operator=(const SomaticRandomForestVariantCallFilter&) = delete; + 
SomaticRandomForestVariantCallFilter(SomaticRandomForestVariantCallFilter&&) = default; + SomaticRandomForestVariantCallFilter& operator=(SomaticRandomForestVariantCallFilter&&) = default; + + virtual ~SomaticRandomForestVariantCallFilter() override = default; + +private: + virtual bool is_soft_filtered(const ClassificationList& sample_classifications, const MeasureVector& measures) const override; +}; + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/filters/somatic_threshold_filter.cpp b/src/core/csr/filters/somatic_threshold_filter.cpp new file mode 100644 index 000000000..37b868df2 --- /dev/null +++ b/src/core/csr/filters/somatic_threshold_filter.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#include "somatic_threshold_filter.hpp" + +#include +#include +#include +#include + +#include + +#include "../measures/is_somatic.hpp" +#include "../measures/is_refcall.hpp" + +namespace octopus { namespace csr { + +SomaticThresholdVariantCallFilter::SomaticThresholdVariantCallFilter(FacetFactory facet_factory, + ConditionVectorPair germline, + ConditionVectorPair somatic, + ConditionVectorPair reference, + OutputOptions output_config, + ConcurrencyPolicy threading, + boost::optional progress) +: ConditionalThresholdVariantCallFilter { + std::move(facet_factory), + {std::move(germline), std::move(somatic), std::move(reference)}, + {make_wrapped_measure(true), make_wrapped_measure(true)}, + [] (const MeasureVector& measures) -> std::size_t { + assert(measures.size() == 2); + if (boost::get(measures.front())) { + return 1; + } else if (boost::get(measures.back())) { + return 2; + } else { + return 0; + }}, + output_config, threading, progress +} {} + +SomaticThresholdVariantCallFilter::SomaticThresholdVariantCallFilter(FacetFactory facet_factory, + ConditionVectorPair somatic, + ConditionVectorPair reference, + OutputOptions 
output_config, + ConcurrencyPolicy threading, + boost::optional progress) +: ConditionalThresholdVariantCallFilter { + std::move(facet_factory), + {{{{make_wrapped_measure(false), make_wrapped_threshold>(false)}}, {}}, std::move(somatic), std::move(reference)}, + {make_wrapped_measure(true), make_wrapped_measure(true)}, + [] (const MeasureVector& measures) -> std::size_t { + assert(measures.size() == 2); + if (boost::get(measures.front())) { + return 1; + } else if (boost::get(measures.back())) { + return 2; + } else { + return 0; + }}, + output_config, threading, progress +} {} + +bool SomaticThresholdVariantCallFilter::is_soft_filtered(const ClassificationList& sample_classifications, + const MeasureVector& measures) const +{ + return std::any_of(std::cbegin(sample_classifications), std::cend(sample_classifications), + [] (const auto& c) { return c.category != Classification::Category::unfiltered; }); +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/filters/somatic_threshold_filter.hpp b/src/core/csr/filters/somatic_threshold_filter.hpp new file mode 100644 index 000000000..92775e9af --- /dev/null +++ b/src/core/csr/filters/somatic_threshold_filter.hpp @@ -0,0 +1,54 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#ifndef somatic_threshold_filter_hpp +#define somatic_threshold_filter_hpp + +#include +#include + +#include + +#include "threshold_filter.hpp" +#include "conditional_threshold_filter.hpp" +#include "logging/progress_meter.hpp" +#include "../facets/facet_factory.hpp" +#include "../measures/measure.hpp" + +namespace octopus { namespace csr { + +class SomaticThresholdVariantCallFilter : public ConditionalThresholdVariantCallFilter +{ +public: + SomaticThresholdVariantCallFilter() = delete; + + SomaticThresholdVariantCallFilter(FacetFactory facet_factory, + ConditionVectorPair germline, + ConditionVectorPair somatic, + ConditionVectorPair reference, + OutputOptions output_config, + ConcurrencyPolicy threading, + boost::optional progress = boost::none); + // Hard filter germline + SomaticThresholdVariantCallFilter(FacetFactory facet_factory, + ConditionVectorPair somatic, + ConditionVectorPair reference, + OutputOptions output_config, + ConcurrencyPolicy threading, + boost::optional progress = boost::none); + + SomaticThresholdVariantCallFilter(const SomaticThresholdVariantCallFilter&) = delete; + SomaticThresholdVariantCallFilter& operator=(const SomaticThresholdVariantCallFilter&) = delete; + SomaticThresholdVariantCallFilter(SomaticThresholdVariantCallFilter&&) = default; + SomaticThresholdVariantCallFilter& operator=(SomaticThresholdVariantCallFilter&&) = default; + + virtual ~SomaticThresholdVariantCallFilter() override = default; + +private: + virtual bool is_soft_filtered(const ClassificationList& sample_classifications, const MeasureVector& measures) const override; +}; + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/filters/threshold_filter.cpp b/src/core/csr/filters/threshold_filter.cpp index a7303672e..97d03ba39 100644 --- a/src/core/csr/filters/threshold_filter.cpp +++ b/src/core/csr/filters/threshold_filter.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this 
source code is governed by the MIT license that can be found in the LICENSE file. #include "threshold_filter.hpp" @@ -11,10 +11,13 @@ #include "io/variant/vcf_header.hpp" #include "utils/append.hpp" +#include "utils/concat.hpp" #include "config/octopus_vcf.hpp" namespace octopus { namespace csr { +namespace { + auto extract_measures(std::vector& conditions) { std::vector result {}; @@ -59,34 +62,35 @@ bool are_all_unique(std::vector keys) return std::adjacent_find(std::cbegin(keys), std::cend(keys)) == std::cend(keys); } +} // namespace + ThresholdVariantCallFilter::ThresholdVariantCallFilter(FacetFactory facet_factory, - std::vector hard_conditions, - std::vector soft_conditions, + ConditionVectorPair conditions, OutputOptions output_config, ConcurrencyPolicy threading, - boost::optional progress) -: SinglePassVariantCallFilter {std::move(facet_factory), extract_measures(hard_conditions, soft_conditions), + boost::optional progress, + std::vector other_measures) +: SinglePassVariantCallFilter {std::move(facet_factory), + concat(extract_measures(conditions.hard, conditions.soft), std::move(other_measures)), output_config, threading, progress} -, hard_thresholds_ {extract_thresholds(hard_conditions)} -, soft_thresholds_ {extract_thresholds(soft_conditions)} -, vcf_filter_keys_ {extract_vcf_filter_keys(soft_conditions)} +, hard_thresholds_ {extract_thresholds(conditions.hard)} +, soft_thresholds_ {extract_thresholds(conditions.soft)} +, vcf_filter_keys_ {extract_vcf_filter_keys(conditions.soft)} , all_unique_filter_keys_ {are_all_unique(vcf_filter_keys_)} {} +bool ThresholdVariantCallFilter::passes_all_filters(MeasureIterator first_measure, MeasureIterator last_measure, + ThresholdIterator first_threshold) const +{ + return std::inner_product(first_measure, last_measure, first_threshold, true, std::multiplies<> {}, + [] (const auto& measure, const auto& threshold) -> bool { return threshold(measure); }); +} + void 
ThresholdVariantCallFilter::annotate(VcfHeader::Builder& header) const { for (const auto& key : vcf_filter_keys_) { octopus::vcf::add_filter(header, key); } - for (const auto& name : measure_names_) { - if (name != "QUAL") { - if (name == "DP" || name == "MQ0") { - header.add_info(name, "1", "Integer", "CSR measure"); - } else { - header.add_info(name, "1", "Float", "CSR measure"); - } - } - } } VariantCallFilter::Classification ThresholdVariantCallFilter::classify(const MeasureVector& measures) const @@ -104,23 +108,21 @@ VariantCallFilter::Classification ThresholdVariantCallFilter::classify(const Mea bool ThresholdVariantCallFilter::passes_all_hard_filters(const MeasureVector& measures) const { - return std::inner_product(std::cbegin(measures), std::next(std::cbegin(measures), hard_thresholds_.size()), - std::cbegin(hard_thresholds_), true, std::multiplies<> {}, - [] (const auto& measure, const auto& threshold) -> bool { return threshold(measure); }); + return passes_all_filters(std::cbegin(measures), std::next(std::cbegin(measures), hard_thresholds_.size()), + std::cbegin(hard_thresholds_)); } bool ThresholdVariantCallFilter::passes_all_soft_filters(const MeasureVector& measures) const { - return std::inner_product(std::next(std::cbegin(measures), hard_thresholds_.size()), std::cend(measures), - std::cbegin(soft_thresholds_), true, std::multiplies<> {}, - [] (const auto& measure, const auto& threshold) -> bool { return threshold(measure); }); + return passes_all_filters(std::next(std::cbegin(measures), hard_thresholds_.size()), std::cend(measures), + std::cbegin(soft_thresholds_)); } std::vector ThresholdVariantCallFilter::get_failing_vcf_filter_keys(const MeasureVector& measures) const { std::vector result {}; result.reserve(soft_thresholds_.size()); - for (std::size_t i {0}; i < measures.size(); ++i) { + for (std::size_t i {0}; i < soft_thresholds_.size(); ++i) { if (!soft_thresholds_[i](measures[i + hard_thresholds_.size()])) { 
result.push_back(vcf_filter_keys_[i]); } diff --git a/src/core/csr/filters/threshold_filter.hpp b/src/core/csr/filters/threshold_filter.hpp index 3b8457c9a..bc03850f2 100644 --- a/src/core/csr/filters/threshold_filter.hpp +++ b/src/core/csr/filters/threshold_filter.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef threshold_filter_hpp @@ -52,14 +52,19 @@ class ThresholdVariantCallFilter : public SinglePassVariantCallFilter std::string vcf_filter_key = "."; }; + struct ConditionVectorPair + { + std::vector hard, soft; + }; + ThresholdVariantCallFilter() = delete; ThresholdVariantCallFilter(FacetFactory facet_factory, - std::vector hard_conditions, - std::vector soft_conditions, + ConditionVectorPair conditions, OutputOptions output_config, ConcurrencyPolicy threading, - boost::optional progress = boost::none); + boost::optional progress = boost::none, + std::vector other_measures = {}); ThresholdVariantCallFilter(const ThresholdVariantCallFilter&) = delete; ThresholdVariantCallFilter& operator=(const ThresholdVariantCallFilter&) = delete; @@ -68,17 +73,25 @@ class ThresholdVariantCallFilter : public SinglePassVariantCallFilter virtual ~ThresholdVariantCallFilter() override = default; -private: - std::vector hard_thresholds_, soft_thresholds_; +protected: + using ThresholdVector = std::vector; + using ThresholdIterator = ThresholdVector::const_iterator; + using MeasureIterator = MeasureVector::const_iterator; + + ThresholdVector hard_thresholds_, soft_thresholds_; std::vector vcf_filter_keys_; bool all_unique_filter_keys_; + bool passes_all_filters(MeasureIterator first_measure, MeasureIterator last_measure, + ThresholdIterator first_threshold) const; + +private: virtual void annotate(VcfHeader::Builder& header) const override; virtual Classification classify(const MeasureVector& measures) const override; - bool 
passes_all_hard_filters(const MeasureVector& measures) const; - bool passes_all_soft_filters(const MeasureVector& measures) const; - std::vector get_failing_vcf_filter_keys(const MeasureVector& measures) const; + virtual bool passes_all_hard_filters(const MeasureVector& measures) const; + virtual bool passes_all_soft_filters(const MeasureVector& measures) const; + virtual std::vector get_failing_vcf_filter_keys(const MeasureVector& measures) const; }; template @@ -103,7 +116,7 @@ struct UnaryThreshold { explicit UnaryVisitor(T target, Cmp cmp) : target {target}, cmp {cmp} {} template - bool operator()(T_ value) const noexcept { return cmp(target, value); } + bool operator()(T_ value) const noexcept { return !cmp(value, target); } template bool operator()(boost::optional value) const noexcept { @@ -113,6 +126,11 @@ struct UnaryThreshold { return true; } + template + bool operator()(const std::vector& values) const noexcept + { + return std::all_of(std::cbegin(values), std::cend(values), [this] (const auto& value) { return (*this)(value); }); + } T target; Cmp cmp; }; diff --git a/src/core/csr/filters/threshold_filter_factory.cpp b/src/core/csr/filters/threshold_filter_factory.cpp index 20edbcf4d..4fa28b0d5 100644 --- a/src/core/csr/filters/threshold_filter_factory.cpp +++ b/src/core/csr/filters/threshold_filter_factory.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "threshold_filter_factory.hpp" @@ -17,6 +17,8 @@ #include "config/octopus_vcf.hpp" #include "exceptions/user_error.hpp" #include "utils/maths.hpp" +#include "somatic_threshold_filter.hpp" +#include "denovo_threshold_filter.hpp" namespace octopus { namespace csr { @@ -39,6 +41,9 @@ auto make_threshold(const std::string& comparator, const T target) if (comparator == "==") { return make_wrapped_threshold>(target); } + if (comparator == "!=") { + return make_wrapped_threshold>(target); + } if (comparator == "<") { return make_wrapped_threshold>(target); } @@ -67,21 +72,32 @@ void init(MeasureToFilterKeyMap& filter_names) filter_names[name()] = lowModelPosterior; filter_names[name()] = lowQuality; filter_names[name()] = lowQualityByDepth; - filter_names[name()] = lowGQ; + filter_names[name()] = lowGQ; filter_names[name()] = strandBias; filter_names[name()] = filteredReadFraction; filter_names[name()] = highGCRegion; filter_names[name()] = highClippedReadFraction; + filter_names[name()] = lowBaseQuality; + filter_names[name()] = highMismatchCount; + filter_names[name()] = highMismatchFraction; + filter_names[name()] = somaticContamination; + filter_names[name()] = deNovoContamination; + filter_names[name()] = readPositionBias; } auto get_vcf_filter_name(const MeasureWrapper& measure, const std::string& comparator, const double threshold_target) { using namespace octopus::vcf::spec::filter; // First look for special names - if (measure.name() == Quality().name()) { + if (measure.name() == name()) { + if (maths::almost_equal(threshold_target, 3.0)) return std::string {q3}; + if (maths::almost_equal(threshold_target, 5.0)) return std::string {q5}; if (maths::almost_equal(threshold_target, 10.0)) return std::string {q10}; if (maths::almost_equal(threshold_target, 20.0)) return std::string {q20}; } + if (measure.name() == name()) { + if (maths::almost_equal(threshold_target, 10.0)) return std::string {bq10}; + } static MeasureToFilterKeyMap default_filter_names {}; 
if (default_filter_names.empty()) { init(default_filter_names); @@ -104,7 +120,11 @@ auto make_condition(const std::string& measure_name, const std::string& comparat auto make_condition(const std::string& measure, const std::string& comparator, const std::string& threshold_target) { try { - return make_condition(measure, comparator, boost::lexical_cast(threshold_target)); + if (threshold_target.find('.') == std::string::npos) { + return make_condition(measure, comparator, boost::lexical_cast(threshold_target)); + } else { + return make_condition(measure, comparator, boost::lexical_cast(threshold_target)); + } } catch (const boost::bad_lexical_cast&) { throw BadVariantFilterCondition {}; } @@ -119,7 +139,7 @@ auto parse_conditions(std::string expression) boost::split(conditions, expression, boost::is_any_of("|")); for (const auto& condition : conditions) { std::vector tokens {}; - boost::split(tokens, condition, boost::is_any_of("<,>,<=,=>,==")); + boost::split(tokens, condition, boost::is_any_of("<,>,<=,=>,==,!=")); if (tokens.size() == 2) { const auto comparitor_pos = tokens.front().size(); const auto comparitor_length = condition.size() - comparitor_pos - tokens.back().size(); @@ -138,8 +158,29 @@ ThresholdFilterFactory::ThresholdFilterFactory(std::string soft_expression) {} ThresholdFilterFactory::ThresholdFilterFactory(std::string hard_expression, std::string soft_expression) -: hard_conditions_ {parse_conditions(std::move(hard_expression))} -, soft_conditions_ {parse_conditions(std::move(soft_expression))} +: germline_ {parse_conditions(std::move(hard_expression)), parse_conditions(std::move(soft_expression))} +, somatic_ {} +, reference_ {} +{} + +ThresholdFilterFactory::ThresholdFilterFactory(std::string germline_hard_expression, std::string germline_soft_expression, + std::string somatic_hard_expression, std::string somatic_soft_expression, + std::string refcall_hard_expression, std::string refcall_soft_expression, + Type type) +: germline_ 
{parse_conditions(std::move(germline_hard_expression)), parse_conditions(std::move(germline_soft_expression))} +, somatic_ {parse_conditions(std::move(somatic_hard_expression)), parse_conditions(std::move(somatic_soft_expression))} +, reference_ {parse_conditions(std::move(refcall_hard_expression)), parse_conditions(std::move(refcall_soft_expression))} +, type_ {type} +{} + +ThresholdFilterFactory::ThresholdFilterFactory(std::string somatic_hard_expression, std::string somatic_soft_expression, + std::string refcall_hard_expression, std::string refcall_soft_expression, + bool somatics_only, Type type) +: germline_ {} +, somatic_ {parse_conditions(std::move(somatic_hard_expression)), parse_conditions(std::move(somatic_soft_expression))} +, reference_ {parse_conditions(std::move(refcall_hard_expression)), parse_conditions(std::move(refcall_soft_expression))} +, type_ {type} +, hard_filter_germline_ {somatics_only} {} std::unique_ptr ThresholdFilterFactory::do_clone() const @@ -152,8 +193,33 @@ std::unique_ptr ThresholdFilterFactory::do_make(FacetFactory boost::optional progress, VariantCallFilter::ConcurrencyPolicy threading) const { - return std::make_unique(std::move(facet_factory), hard_conditions_, soft_conditions_, - output_config, threading, progress); + if (somatic_.hard.empty() && somatic_.soft.empty()) { + return std::make_unique(std::move(facet_factory), + germline_, + output_config, threading, progress); + } else { + if (type_ == Type::somatic) { + if (hard_filter_germline_) { + return std::make_unique(std::move(facet_factory), + somatic_, reference_, + output_config, threading, progress); + } else { + return std::make_unique(std::move(facet_factory), + germline_, somatic_, reference_, + output_config, threading, progress); + } + } else { + if (hard_filter_germline_) { + return std::make_unique(std::move(facet_factory), + somatic_, reference_, + output_config, threading, progress); + } else { + return std::make_unique(std::move(facet_factory), + germline_, 
somatic_, reference_, + output_config, threading, progress); + } + } + } } } // namespace csr diff --git a/src/core/csr/filters/threshold_filter_factory.hpp b/src/core/csr/filters/threshold_filter_factory.hpp index be2d73ef5..ea5d1591a 100644 --- a/src/core/csr/filters/threshold_filter_factory.hpp +++ b/src/core/csr/filters/threshold_filter_factory.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef threshold_filter_factory_hpp @@ -22,10 +22,19 @@ class FacetFactory; class ThresholdFilterFactory : public VariantCallFilterFactory { public: + enum class Type { somatic, denovo }; + ThresholdFilterFactory() = default; ThresholdFilterFactory(std::string soft_expression); ThresholdFilterFactory(std::string hard_expression, std::string soft_expression); + ThresholdFilterFactory(std::string germline_hard_expression, std::string germline_soft_expression, + std::string somatic_hard_expression, std::string somatic_soft_expression, + std::string refcall_hard_expression, std::string refcall_soft_expression, + Type type = Type::somatic); + ThresholdFilterFactory(std::string somatic_hard_expression, std::string somatic_soft_expression, + std::string refcall_hard_expression, std::string refcall_soft_expression, + bool somatics_only = true, Type type = Type::somatic); ThresholdFilterFactory(const ThresholdFilterFactory&) = default; ThresholdFilterFactory& operator=(const ThresholdFilterFactory&) = default; @@ -36,8 +45,11 @@ class ThresholdFilterFactory : public VariantCallFilterFactory private: using Condition = ThresholdVariantCallFilter::Condition; + using ConditionVectorPair = ThresholdVariantCallFilter::ConditionVectorPair; - std::vector hard_conditions_, soft_conditions_; + ConditionVectorPair germline_, somatic_, reference_; + Type type_; + bool hard_filter_germline_; std::unique_ptr do_clone() const override; std::unique_ptr 
do_make(FacetFactory facet_factory, diff --git a/src/core/csr/filters/training_filter_factory.cpp b/src/core/csr/filters/training_filter_factory.cpp index 42869f3c0..9da49cfe5 100644 --- a/src/core/csr/filters/training_filter_factory.cpp +++ b/src/core/csr/filters/training_filter_factory.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "training_filter_factory.hpp" diff --git a/src/core/csr/filters/training_filter_factory.hpp b/src/core/csr/filters/training_filter_factory.hpp index e8e5a3da6..bb15cbd35 100644 --- a/src/core/csr/filters/training_filter_factory.hpp +++ b/src/core/csr/filters/training_filter_factory.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef training_filter_factory_hpp diff --git a/src/core/csr/filters/unsupervised_clustering_filter.cpp b/src/core/csr/filters/unsupervised_clustering_filter.cpp index 2fe8976bf..5db8db5ad 100644 --- a/src/core/csr/filters/unsupervised_clustering_filter.cpp +++ b/src/core/csr/filters/unsupervised_clustering_filter.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "unsupervised_clustering_filter.hpp" @@ -23,7 +23,7 @@ void UnsupervisedClusteringFilter::annotate(VcfHeader::Builder& header) const // TODO } -void UnsupervisedClusteringFilter::record(const std::size_t call_idx, MeasureVector measures) const +void UnsupervisedClusteringFilter::record(const std::size_t call_idx, std::size_t sample_idx, MeasureVector measures) const { if (data_.size() == call_idx) { data_.push_back(std::move(measures)); @@ -48,7 +48,7 @@ void UnsupervisedClusteringFilter::prepare_for_classification(boost::optional classifications_; void annotate(VcfHeader::Builder& header) const override; - void record(std::size_t call_idx, MeasureVector measures) const override; + void record(std::size_t call_idx, std::size_t sample_idx, MeasureVector measures) const override; void prepare_for_classification(boost::optional& log) const override; - Classification classify(std::size_t call_idx) const override; + Classification classify(std::size_t call_idx, std::size_t sample_idx) const override; bool all_missing(const MeasureVector& measures) const noexcept; void remove_missing_features() const; diff --git a/src/core/csr/filters/unsupervised_clustering_filter_factory.cpp b/src/core/csr/filters/unsupervised_clustering_filter_factory.cpp index 710059894..36e922fd9 100644 --- a/src/core/csr/filters/unsupervised_clustering_filter_factory.cpp +++ b/src/core/csr/filters/unsupervised_clustering_filter_factory.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "unsupervised_clustering_filter_factory.hpp" @@ -10,14 +10,18 @@ namespace octopus { namespace csr { +namespace { + std::vector parse_measures(const std::set& measure_names) { - std::vector result {}; + std::vector result{}; result.reserve(measure_names.size()); std::transform(std::cbegin(measure_names), std::cend(measure_names), std::back_inserter(result), make_measure); return result; } +} // namespace + UnsupervisedClusteringFilterFactory::UnsupervisedClusteringFilterFactory(const std::set& measure_names) : measures_ {parse_measures(measure_names)} {} diff --git a/src/core/csr/filters/unsupervised_clustering_filter_factory.hpp b/src/core/csr/filters/unsupervised_clustering_filter_factory.hpp index e6742ed13..5f5b01bfc 100644 --- a/src/core/csr/filters/unsupervised_clustering_filter_factory.hpp +++ b/src/core/csr/filters/unsupervised_clustering_filter_factory.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef unsupervised_clustering_filter_factory_hpp diff --git a/src/core/csr/filters/variant_call_filter.cpp b/src/core/csr/filters/variant_call_filter.cpp index 3cc84cd46..f576cce08 100644 --- a/src/core/csr/filters/variant_call_filter.cpp +++ b/src/core/csr/filters/variant_call_filter.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "variant_call_filter.hpp" @@ -6,8 +6,10 @@ #include #include #include -#include +#include +#include #include +#include #include #include @@ -26,6 +28,7 @@ #include "utils/append.hpp" #include "utils/parallel_transform.hpp" #include "io/variant/vcf_writer.hpp" +#include "io/variant/vcf_spec.hpp" namespace octopus { namespace csr { @@ -65,13 +68,24 @@ VariantCallFilter::VariantCallFilter(FacetFactory facet_factory, std::vector measures, OutputOptions output_config, ConcurrencyPolicy threading) -: debug_log_ {logging::get_debug_log()} +: measures_ {std::move(measures)} +, debug_log_ {logging::get_debug_log()} , facet_factory_ {std::move(facet_factory)} -, facet_names_ {get_facets(measures)} -, measures_ {std::move(measures)} +, facet_names_ {get_facets(measures_)} , output_config_ {output_config} +, duplicate_measures_ {} , workers_ {get_pool_size(threading)} -{} +{ + std::unordered_map measure_counts {}; + measure_counts.reserve(measures_.size()); + for (const auto& m : measures_) { + ++measure_counts[m]; + if (measure_counts[m] == 2) { + duplicate_measures_.push_back(m); + } + } + duplicate_measures_.shrink_to_fit(); +} void VariantCallFilter::filter(const VcfReader& source, VcfWriter& dest) const { @@ -84,6 +98,57 @@ void VariantCallFilter::filter(const VcfReader& source, VcfWriter& dest) const // protected methods +namespace { + +template +bool all_equal(const Range& values, BinaryPredicate pred) +{ + const auto not_pred = [&](const auto& lhs, const auto& rhs) { return !pred(lhs, rhs); }; + return std::adjacent_find(std::cbegin(values), std::cend(values), not_pred) == std::cend(values); +} + +} // namespace + +VariantCallFilter::Classification +VariantCallFilter::merge(const ClassificationList& sample_classifications, const MeasureVector& measures) const +{ + assert(!sample_classifications.empty()); + if (sample_classifications.size() == 1) { + return sample_classifications.front(); + } + Classification result {}; + if (all_equal(sample_classifications, 
[] (const auto& lhs, const auto& rhs) { return lhs.category == rhs.category; })) { + result.category = sample_classifications.front().category; + } else if (is_soft_filtered(sample_classifications, measures)) { + result.category = Classification::Category::soft_filtered; + } else { + result.category = Classification::Category::unfiltered; + } + if (result.category != Classification::Category::unfiltered) { + for (const auto& sample_classification : sample_classifications) { + utils::append(sample_classification.reasons, result.reasons); + } + std::sort(std::begin(result.reasons), std::end(result.reasons)); + result.reasons.erase(std::unique(std::begin(result.reasons), std::end(result.reasons)), std::end(result.reasons)); + result.reasons.shrink_to_fit(); + } + for (const auto& sample_classification : sample_classifications) { + if (sample_classification.quality) { + if (result.quality) { + result.quality = std::max(*result.quality, *sample_classification.quality); + } else { + result.quality = sample_classification.quality; + } + } + } + return result; +} + +VariantCallFilter::Classification VariantCallFilter::merge(const ClassificationList& sample_classifications) const +{ + return this->merge(sample_classifications, {}); +} + bool VariantCallFilter::can_measure_single_call() const noexcept { return facet_names_.empty(); @@ -163,8 +228,25 @@ VariantCallFilter::read_next_blocks(VcfIterator& first, const VcfIterator& last, VariantCallFilter::MeasureVector VariantCallFilter::measure(const VcfRecord& call) const { MeasureVector result(measures_.size()); - std::transform(std::cbegin(measures_), std::cend(measures_), std::begin(result), - [&call] (const MeasureWrapper& f) { return f(call); }); + if (duplicate_measures_.empty()) { + std::transform(std::cbegin(measures_), std::cend(measures_), std::begin(result), + [&call] (const MeasureWrapper& m) { return m(call); }); + } else { + std::unordered_map result_buffer {}; + result_buffer.reserve(duplicate_measures_.size()); 
+ for (const auto& m : duplicate_measures_) { + result_buffer.emplace(m, m(call)); + } + std::transform(std::cbegin(measures_), std::cend(measures_), std::begin(result), + [&call, &result_buffer] (const MeasureWrapper& m) -> Measure::ResultType { + auto itr = result_buffer.find(m); + if (itr != std::cend(result_buffer)) { + return itr->second; + } else { + return m(call); + } + }); + } return result; } @@ -197,9 +279,21 @@ std::vector VariantCallFilter::measure(const st void VariantCallFilter::write(const VcfRecord& call, const Classification& classification, VcfWriter& dest) const { - if (classification.category != Classification::Category::hard_filtered) { + if (!is_hard_filtered(classification)) { + auto filtered_call = construct_template(call); + annotate(filtered_call, classification); + dest << filtered_call.build_once(); + } +} + +void VariantCallFilter::write(const VcfRecord& call, const Classification& classification, + const SampleList& samples, const ClassificationList& sample_classifications, + VcfWriter& dest) const +{ + if (!is_hard_filtered(classification)) { auto filtered_call = construct_template(call); annotate(filtered_call, classification); + annotate(filtered_call, samples, sample_classifications); dest << filtered_call.build_once(); } } @@ -212,15 +306,16 @@ void VariantCallFilter::annotate(VcfRecord::Builder& call, const MeasureVector& for (auto p : boost::combine(measures_, measures)) { const MeasureWrapper& measure {p.get<0>()}; const Measure::ResultType& measured_value {p.get<1>()}; - call.set_info(measure.name(), measure.serialise(measured_value)); + measure.annotate(call, measured_value); } } // private methods -void add_info(const MeasureWrapper& measure, VcfHeader::Builder& builder) +bool VariantCallFilter::is_soft_filtered(const ClassificationList& sample_classifications, const MeasureVector& measures) const { - builder.add_info(measure.name(), "1", "String", "CSR measure"); + return std::all_of(std::cbegin(sample_classifications), 
std::cend(sample_classifications), + [] (const auto& c) { return c.category != Classification::Category::unfiltered; }); } VcfHeader VariantCallFilter::make_header(const VcfReader& source) const @@ -234,7 +329,7 @@ VcfHeader VariantCallFilter::make_header(const VcfReader& source) const } if (output_config_.annotate_measures) { for (const auto& measure : measures_) { - add_info(measure, builder); + measure.annotate(builder); } } annotate(builder); @@ -253,6 +348,53 @@ VcfRecord::Builder VariantCallFilter::construct_template(const VcfRecord& call) return result; } +bool VariantCallFilter::is_hard_filtered(const Classification& classification) const noexcept +{ + return classification.category == Classification::Category::hard_filtered; +} + +void VariantCallFilter::annotate(VcfRecord::Builder& call, const SampleList& samples, const ClassificationList& sample_classifications) const +{ + assert(samples.size() == sample_classifications.size()); + bool all_hard_filtered {true}; + auto quality_name = this->genotype_quality_name(); + if (quality_name) { + call.add_format(std::move(*quality_name)); + } + for (auto p : boost::combine(samples, sample_classifications)) { + const SampleName& sample {p.get<0>()}; + const Classification& sample_classification {p.get<1>()}; + if (!is_hard_filtered(sample_classification)) { + annotate(call, sample, sample_classification); + all_hard_filtered = false; + } else { + call.clear_format(sample); + } + } + if (all_hard_filtered) { + call.clear_format(); + } else { + call.add_format(vcfspec::format::filter); + } +} + +void VariantCallFilter::annotate(VcfRecord::Builder& call, const SampleName& sample, Classification status) const +{ + if (status.category == Classification::Category::unfiltered) { + pass(sample, call); + } else { + fail(sample, call, std::move(status.reasons)); + } + const auto quality_name = this->genotype_quality_name(); + if (quality_name) { + if (status.quality) { + call.set_format(sample, *quality_name, 
*status.quality); + } else { + call.set_format_missing(sample, *quality_name); + } + } +} + void VariantCallFilter::annotate(VcfRecord::Builder& call, const Classification status) const { if (status.category == Classification::Category::unfiltered) { @@ -260,6 +402,15 @@ void VariantCallFilter::annotate(VcfRecord::Builder& call, const Classification } else { fail(call, std::move(status.reasons)); } + auto quality_name = this->call_quality_name(); + if (quality_name) { + call.add_info(*quality_name); + if (status.quality) { + call.set_info(*quality_name, *status.quality); + } else { + call.set_info_missing(*quality_name); + } + } } auto make_map(const std::vector& names, std::vector&& facets) @@ -296,23 +447,52 @@ VariantCallFilter::MeasureBlock VariantCallFilter::measure(const CallBlock& bloc } MeasureBlock result(block.size()); std::transform(std::cbegin(block), std::cend(block), std::begin(result), - [&] (const auto& call) { return measure(call, facets); }); + [&] (const VcfRecord& call) { return this->measure(call, facets); }); return result; } VariantCallFilter::MeasureVector VariantCallFilter::measure(const VcfRecord& call, const Measure::FacetMap& facets) const { MeasureVector result(measures_.size()); - std::transform(std::cbegin(measures_), std::cend(measures_), std::begin(result), - [&] (const MeasureWrapper& measure) { return measure(call, facets); }); + if (duplicate_measures_.empty()) { + std::transform(std::cbegin(measures_), std::cend(measures_), std::begin(result), + [&] (const MeasureWrapper& m) { return m(call, facets); }); + } else { + std::unordered_map result_buffer {}; + result_buffer.reserve(duplicate_measures_.size()); + for (const auto& m : duplicate_measures_) { + result_buffer.emplace(m, m(call, facets)); + } + std::transform(std::cbegin(measures_), std::cend(measures_), std::begin(result), + [&] (const MeasureWrapper& m) -> Measure::ResultType { + auto itr = result_buffer.find(m); + if (itr != std::cend(result_buffer)) { + return 
itr->second; + } else { + return m(call, facets); + } + }); + } return result; } +void VariantCallFilter::pass(const SampleName& sample, VcfRecord::Builder& call) const +{ + call.set_passed(sample); +} + void VariantCallFilter::pass(VcfRecord::Builder& call) const { call.set_passed(); } +void VariantCallFilter::fail(const SampleName& sample, VcfRecord::Builder& call, std::vector reasons) const +{ + for (auto& reason : reasons) { + call.add_filter(sample, std::move(reason)); + } +} + void VariantCallFilter::fail(VcfRecord::Builder& call, std::vector reasons) const { for (auto& reason : reasons) { diff --git a/src/core/csr/filters/variant_call_filter.hpp b/src/core/csr/filters/variant_call_filter.hpp index 78db7a5f0..9ca789ecb 100644 --- a/src/core/csr/filters/variant_call_filter.hpp +++ b/src/core/csr/filters/variant_call_filter.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef variant_call_filter_hpp @@ -78,9 +78,14 @@ class VariantCallFilter std::vector reasons = {}; boost::optional> quality = boost::none; }; + using ClassificationList = std::vector; + std::vector measures_; mutable boost::optional debug_log_; + virtual Classification merge(const ClassificationList& sample_classifications, const MeasureVector& measures) const; + virtual Classification merge(const ClassificationList& sample_classifications) const; + bool can_measure_single_call() const noexcept; bool can_measure_multiple_blocks() const noexcept; CallBlock read_next_block(VcfIterator& first, const VcfIterator& last, const SampleList& samples) const; @@ -89,6 +94,9 @@ class VariantCallFilter MeasureBlock measure(const CallBlock& block) const; std::vector measure(const std::vector& blocks) const; void write(const VcfRecord& call, const Classification& classification, VcfWriter& dest) const; + void write(const VcfRecord& call, const Classification& classification, + const SampleList& samples, const ClassificationList& sample_classifications, + VcfWriter& dest) const; void annotate(VcfRecord::Builder& call, const MeasureVector& measures) const; private: @@ -96,13 +104,16 @@ class VariantCallFilter FacetFactory facet_factory_; FacetNameSet facet_names_; - std::vector measures_; OutputOptions output_config_; + std::vector duplicate_measures_; mutable ThreadPool workers_; virtual void annotate(VcfHeader::Builder& header) const = 0; virtual void filter(const VcfReader& source, VcfWriter& dest, const SampleList& samples) const = 0; + virtual boost::optional call_quality_name() const { return boost::none; } + virtual boost::optional genotype_quality_name() const { return boost::none; } + virtual bool is_soft_filtered(const ClassificationList& sample_classifications, const MeasureVector& measures) const; VcfHeader make_header(const VcfReader& source) const; Measure::FacetMap compute_facets(const CallBlock& block) const; @@ -110,8 +121,13 @@ class VariantCallFilter 
MeasureBlock measure(const CallBlock& block, const Measure::FacetMap& facets) const; MeasureVector measure(const VcfRecord& call, const Measure::FacetMap& facets) const; VcfRecord::Builder construct_template(const VcfRecord& call) const; + bool is_hard_filtered(const Classification& classification) const noexcept; + void annotate(VcfRecord::Builder& call, const SampleList& samples, const ClassificationList& sample_classifications) const; + void annotate(VcfRecord::Builder& call, const SampleName& sample, Classification status) const; void annotate(VcfRecord::Builder& call, Classification status) const; + void pass(const SampleName& sample, VcfRecord::Builder& call) const; void pass(VcfRecord::Builder& call) const; + void fail(const SampleName& sample, VcfRecord::Builder& call, std::vector reasons) const; void fail(VcfRecord::Builder& call, std::vector reasons) const; bool is_multithreaded() const noexcept; unsigned max_concurrent_blocks() const noexcept; diff --git a/src/core/csr/filters/variant_call_filter_factory.cpp b/src/core/csr/filters/variant_call_filter_factory.cpp index 827f4ce7c..cd4fd5d9e 100644 --- a/src/core/csr/filters/variant_call_filter_factory.cpp +++ b/src/core/csr/filters/variant_call_filter_factory.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "variant_call_filter_factory.hpp" @@ -14,14 +14,23 @@ std::unique_ptr VariantCallFilterFactory::clone() cons return do_clone(); } -std::unique_ptr VariantCallFilterFactory::make(const ReferenceGenome& reference, - BufferedReadPipe read_pipe, - VariantCallFilter::OutputOptions output_config, - boost::optional progress, - boost::optional max_threads) const +std::unique_ptr +VariantCallFilterFactory::make(const ReferenceGenome& reference, + BufferedReadPipe read_pipe, + VcfHeader input_header, + PloidyMap ploidies, + boost::optional pedigree, + VariantCallFilter::OutputOptions output_config, + boost::optional progress, + boost::optional max_threads) const { - FacetFactory facet_factory {reference, std::move(read_pipe)}; - return do_make(std::move(facet_factory), output_config, progress, {max_threads}); + if (pedigree) { + FacetFactory facet_factory {std::move(input_header), reference, std::move(read_pipe), std::move(ploidies), std::move(*pedigree)}; + return do_make(std::move(facet_factory), output_config, progress, {max_threads}); + } else { + FacetFactory facet_factory {std::move(input_header), reference, std::move(read_pipe), std::move(ploidies)}; + return do_make(std::move(facet_factory), output_config, progress, {max_threads}); + } } } // namespace csr diff --git a/src/core/csr/filters/variant_call_filter_factory.hpp b/src/core/csr/filters/variant_call_filter_factory.hpp index 1e9c62dfa..5dc8efbb2 100644 --- a/src/core/csr/filters/variant_call_filter_factory.hpp +++ b/src/core/csr/filters/variant_call_filter_factory.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef variant_call_filter_factory_hpp @@ -10,12 +10,15 @@ #include #include "logging/progress_meter.hpp" +#include "io/variant/vcf_header.hpp" #include "variant_call_filter.hpp" namespace octopus { class ReferenceGenome; class BufferedReadPipe; +class PloidyMap; +class Pedigree; namespace csr { @@ -28,18 +31,24 @@ class VariantCallFilterFactory std::unique_ptr clone() const; - std::unique_ptr make(const ReferenceGenome& reference, - BufferedReadPipe read_pipe, - VariantCallFilter::OutputOptions output_config, - boost::optional progress = boost::none, - boost::optional max_threads = 1) const; + std::unique_ptr + make(const ReferenceGenome& reference, + BufferedReadPipe read_pipe, + VcfHeader input_header, + PloidyMap ploidies, + boost::optional pedigree, + VariantCallFilter::OutputOptions output_config, + boost::optional progress = boost::none, + boost::optional max_threads = 1) const; private: virtual std::unique_ptr do_clone() const = 0; - virtual std::unique_ptr do_make(FacetFactory facet_factory, - VariantCallFilter::OutputOptions output_config, - boost::optional progress, - VariantCallFilter::ConcurrencyPolicy threading) const = 0; + virtual + std::unique_ptr + do_make(FacetFactory facet_factory, + VariantCallFilter::OutputOptions output_config, + boost::optional progress, + VariantCallFilter::ConcurrencyPolicy threading) const = 0; }; } // namespace csr diff --git a/src/core/csr/filters/variant_filter_utils.cpp b/src/core/csr/filters/variant_filter_utils.cpp new file mode 100644 index 000000000..c4b09cbc6 --- /dev/null +++ b/src/core/csr/filters/variant_filter_utils.cpp @@ -0,0 +1,57 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "variant_filter_utils.hpp" + +#include + +#include "../facets/facet_factory.hpp" +#include "../measures/is_somatic.hpp" +#include "../measures/is_denovo.hpp" +#include "threshold_filter.hpp" + +namespace octopus { namespace csr { + +namespace { + +template +void copy_flag(const VcfReader& source, VcfWriter& dest, Flag flag) +{ + FacetFactory facet_factory {source.fetch_header()}; + ThresholdVariantCallFilter::OutputOptions output_config {}; + VariantCallFilter::ConcurrencyPolicy thread_policy {}; + thread_policy.max_threads = 1; + ThresholdVariantCallFilter::ConditionVectorPair conditions {}; + conditions.hard.push_back({std::move(flag), make_wrapped_threshold>(false)}); + std::unique_ptr filter = std::make_unique(std::move(facet_factory), conditions, output_config, thread_policy); + filter->filter(source, dest); +} + +} // namesapce + +void copy_somatics(const VcfReader& source, VcfWriter& dest) +{ + copy_flag(source, dest, make_wrapped_measure(false)); +} + +void copy_somatics(VcfReader::Path source, VcfReader::Path dest) +{ + VcfReader src {std::move(source)}; + VcfWriter dst {std::move(dest)}; + copy_somatics(src, dst); +} + +void copy_denovos(const VcfReader& source, VcfWriter& dest) +{ + copy_flag(source, dest, make_wrapped_measure()); +} + +void copy_denovos(VcfReader::Path source, VcfReader::Path dest) +{ + VcfReader src {std::move(source)}; + VcfWriter dst {std::move(dest)}; + copy_denovos(src, dst); +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/filters/variant_filter_utils.hpp b/src/core/csr/filters/variant_filter_utils.hpp new file mode 100644 index 000000000..0f6f66048 --- /dev/null +++ b/src/core/csr/filters/variant_filter_utils.hpp @@ -0,0 +1,21 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#ifndef variant_filter_utils_hpp +#define variant_filter_utils_hpp + +#include "io/variant/vcf_reader.hpp" +#include "io/variant/vcf_writer.hpp" + +namespace octopus { namespace csr { + +void copy_somatics(const VcfReader& source, VcfWriter& dest); +void copy_somatics(VcfReader::Path source, VcfReader::Path dest); + +void copy_denovos(const VcfReader& source, VcfWriter& dest); +void copy_denovos(VcfReader::Path source, VcfReader::Path dest); + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/measures/allele_frequency.cpp b/src/core/csr/measures/allele_frequency.cpp index b1ad62870..ee483779e 100644 --- a/src/core/csr/measures/allele_frequency.cpp +++ b/src/core/csr/measures/allele_frequency.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "allele_frequency.hpp" @@ -13,63 +13,102 @@ #include "io/variant/vcf_record.hpp" #include "io/variant/vcf_spec.hpp" #include "utils/genotype_reader.hpp" +#include "../facets/samples.hpp" #include "../facets/read_assignments.hpp" namespace octopus { namespace csr { +const std::string AlleleFrequency::name_ = "AF"; + std::unique_ptr AlleleFrequency::do_clone() const { return std::make_unique(*this); } +namespace { + +bool is_canonical(const VcfRecord::NucleotideSequence& allele) noexcept +{ + const static VcfRecord::NucleotideSequence deleted_allele {vcfspec::deletedBase}; + return !(allele == vcfspec::missingValue || allele == deleted_allele); +} + +bool has_called_alt_allele(const VcfRecord& call, const VcfRecord::SampleName& sample) +{ + if (!call.has_genotypes()) return true; + const auto& genotype = get_genotype(call, sample); + return std::any_of(std::cbegin(genotype), std::cend(genotype), + [&] (const auto& allele) { return allele != call.ref() && is_canonical(allele); }); +} + +bool is_evaluable(const VcfRecord& call, const 
VcfRecord::SampleName& sample) +{ + return has_called_alt_allele(call, sample); +} + +} // namespace + Measure::ResultType AlleleFrequency::do_evaluate(const VcfRecord& call, const FacetMap& facets) const { - const auto& assignments = get_value(facets.at("ReadAssignments")); - boost::optional result {}; - for (const auto& p : assignments) { - std::vector alleles; bool has_ref; - std::tie(alleles, has_ref) = get_called_alleles(call, p.first, true); - std::size_t read_count {0}; - std::vector allele_counts(alleles.size()); - for (const auto& h : p.second) { - const auto& haplotype = h.first; - const auto& reads = h.second; - const auto haplotype_support_depth = count_overlapped(reads, call); - if (haplotype_support_depth > 0) { - std::transform(std::cbegin(alleles), std::cend(alleles), std::cbegin(allele_counts), std::begin(allele_counts), - [&] (const auto& allele, auto count) { - if (haplotype.includes(allele)) { - count += haplotype_support_depth; - } - return count; - }); - read_count += haplotype_support_depth; + const auto& samples = get_value(facets.at("Samples")); + const auto& assignments = get_value(facets.at("ReadAssignments")).support; + std::vector> result {}; + result.reserve(samples.size()); + for (const auto& sample : samples) { + boost::optional sample_result {}; + if (is_evaluable(call, sample)) { + const auto& sample_assignments = assignments.at(sample); + std::vector alleles; bool has_ref; + std::tie(alleles, has_ref) = get_called_alleles(call, sample, true); + assert(!alleles.empty()); + std::size_t read_count {0}; + std::vector allele_counts(alleles.size()); + for (const auto& p : sample_assignments) { + const auto& haplotype = p.first; + const auto& reads = p.second; + const auto haplotype_support_depth = count_overlapped(reads, call); + if (haplotype_support_depth > 0) { + std::transform(std::cbegin(alleles), std::cend(alleles), std::cbegin(allele_counts), std::begin(allele_counts), + [&] (const auto& allele, auto count) { + if 
(haplotype.includes(allele)) { + count += haplotype_support_depth; + } + return count; + }); + read_count += haplotype_support_depth; + } } - } - if (read_count > 0) { - auto first_called_count_itr = std::cbegin(allele_counts); - if (has_ref) ++first_called_count_itr; - assert(first_called_count_itr != std::cend(allele_counts)); - const auto min_count_itr = std::min_element(first_called_count_itr, std::cend(allele_counts)); - const auto maf = static_cast(*min_count_itr) / read_count; - if (result) { - result = std::min(*result, maf); - } else { - result = maf; + if (read_count > 0) { + auto first_called_count_itr = std::cbegin(allele_counts); + if (has_ref) ++first_called_count_itr; + assert(first_called_count_itr != std::cend(allele_counts)); + const auto min_count_itr = std::min_element(first_called_count_itr, std::cend(allele_counts)); + sample_result = static_cast(*min_count_itr) / read_count; } } + result.push_back(sample_result); } return result; } -std::string AlleleFrequency::do_name() const +Measure::ResultCardinality AlleleFrequency::do_cardinality() const noexcept +{ + return ResultCardinality::num_samples; +} + +const std::string& AlleleFrequency::do_name() const +{ + return name_; +} + +std::string AlleleFrequency::do_describe() const { - return "AF"; + return "Minor allele frequency of ALT alleles"; } std::vector AlleleFrequency::do_requirements() const { - return {"ReadAssignments"}; + return {"Samples", "ReadAssignments"}; } } // namespace csr diff --git a/src/core/csr/measures/allele_frequency.hpp b/src/core/csr/measures/allele_frequency.hpp index 304893779..0b714a129 100644 --- a/src/core/csr/measures/allele_frequency.hpp +++ b/src/core/csr/measures/allele_frequency.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef allele_frequency_hpp @@ -17,9 +17,12 @@ namespace csr { class AlleleFrequency : public Measure { + const static std::string name_; std::unique_ptr do_clone() const override; ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; - std::string do_name() const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; std::vector do_requirements() const override; }; diff --git a/src/core/csr/measures/alt_allele_count.cpp b/src/core/csr/measures/alt_allele_count.cpp new file mode 100644 index 000000000..a7c1c4d7a --- /dev/null +++ b/src/core/csr/measures/alt_allele_count.cpp @@ -0,0 +1,66 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#include "alt_allele_count.hpp" + +#include +#include + +#include + +#include "io/variant/vcf_record.hpp" +#include "io/variant/vcf_spec.hpp" +#include "../facets/samples.hpp" + +namespace octopus { namespace csr { + +const std::string AltAlleleCount::name_ = "AC"; + +std::unique_ptr AltAlleleCount::do_clone() const +{ + return std::make_unique(*this); +} + +namespace { + +int count_non_ref_alleles(const VcfRecord& call, const VcfRecord::SampleName& sample) +{ + const auto genotype = get_genotype(call, sample); + return genotype.size() - std::count(std::cbegin(genotype), std::cend(genotype), call.ref()); +} + +} // namespace + +Measure::ResultType AltAlleleCount::do_evaluate(const VcfRecord& call, const FacetMap& facets) const +{ + const auto& samples = get_value(facets.at("Samples")); + std::vector result {}; + result.reserve(samples.size()); + for (const auto& sample : samples) { + result.push_back(count_non_ref_alleles(call, sample)); + } + return result; +} + +Measure::ResultCardinality AltAlleleCount::do_cardinality() const noexcept +{ + return ResultCardinality::num_samples; +} + +const 
std::string& AltAlleleCount::do_name() const +{ + return name_; +} + +std::string AltAlleleCount::do_describe() const +{ + return "Number of non-reference alleles called for each sample"; +} + +std::vector AltAlleleCount::do_requirements() const +{ + return {"Samples"}; +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/measures/alt_allele_count.hpp b/src/core/csr/measures/alt_allele_count.hpp new file mode 100644 index 000000000..ab9e06c92 --- /dev/null +++ b/src/core/csr/measures/alt_allele_count.hpp @@ -0,0 +1,32 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#ifndef alt_allele_count_hpp +#define alt_allele_count_hpp + +#include +#include + +#include "measure.hpp" + +namespace octopus { + +class VcfRecord; + +namespace csr { + +class AltAlleleCount : public Measure +{ + const static std::string name_; + std::unique_ptr do_clone() const override; + ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; + std::vector do_requirements() const override; +}; + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/measures/ambiguous_read_fraction.cpp b/src/core/csr/measures/ambiguous_read_fraction.cpp new file mode 100644 index 000000000..9ddb889b3 --- /dev/null +++ b/src/core/csr/measures/ambiguous_read_fraction.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "ambiguous_read_fraction.hpp" + +#include +#include + +#include + +#include "core/tools/read_assigner.hpp" +#include "core/types/allele.hpp" +#include "io/variant/vcf_record.hpp" +#include "io/variant/vcf_spec.hpp" +#include "utils/genotype_reader.hpp" +#include "../facets/samples.hpp" +#include "../facets/overlapping_reads.hpp" +#include "../facets/read_assignments.hpp" + +namespace octopus { namespace csr { + +const std::string AmbiguousReadFraction::name_ = "ARF"; + +std::unique_ptr AmbiguousReadFraction::do_clone() const +{ + return std::make_unique(*this); +} + +Measure::ResultType AmbiguousReadFraction::do_evaluate(const VcfRecord& call, const FacetMap& facets) const +{ + const auto& samples = get_value(facets.at("Samples")); + const auto& reads = get_value(facets.at("OverlappingReads")); + const auto& ambiguous_reads = get_value(facets.at("ReadAssignments")).ambiguous; + std::vector> result {}; + result.reserve(samples.size()); + for (const auto& sample : samples) { + const auto num_overlapping_reads = count_overlapped(reads.at(sample), call); + if (num_overlapping_reads > 0) { + const auto num_ambiguous_reads = count_overlapped(ambiguous_reads.at(sample), call); + const auto ambiguous_fraction = static_cast(num_ambiguous_reads) / num_overlapping_reads; + result.emplace_back(ambiguous_fraction); + } else { + result.push_back(boost::none); + } + } + return result; +} + +Measure::ResultCardinality AmbiguousReadFraction::do_cardinality() const noexcept +{ + return ResultCardinality::num_samples; +} + +const std::string& AmbiguousReadFraction::do_name() const +{ + return name_; +} + +std::string AmbiguousReadFraction::do_describe() const +{ + return "Fraction of reads overlapping the call that cannot be assigned to a unique haplotype"; +} + +std::vector AmbiguousReadFraction::do_requirements() const +{ + return {"Samples", "OverlappingReads", "ReadAssignments"}; +} + +} // namespace csr +} // namespace octopus diff --git 
a/src/core/csr/measures/ambiguous_read_fraction.hpp b/src/core/csr/measures/ambiguous_read_fraction.hpp new file mode 100644 index 000000000..376cf0176 --- /dev/null +++ b/src/core/csr/measures/ambiguous_read_fraction.hpp @@ -0,0 +1,32 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#ifndef ambiguous_read_fraction_hpp +#define ambiguous_read_fraction_hpp + +#include +#include + +#include "measure.hpp" + +namespace octopus { + +class VcfRecord; + +namespace csr { + +class AmbiguousReadFraction : public Measure +{ + const static std::string name_; + std::unique_ptr do_clone() const override; + ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; + std::vector do_requirements() const override; +}; + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/measures/clipped_read_fraction.cpp b/src/core/csr/measures/clipped_read_fraction.cpp index e9a60dd63..be2cf827c 100644 --- a/src/core/csr/measures/clipped_read_fraction.cpp +++ b/src/core/csr/measures/clipped_read_fraction.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "clipped_read_fraction.hpp" @@ -14,6 +14,8 @@ namespace octopus { namespace csr { +const std::string ClippedReadFraction::name_ = "CRF"; + std::unique_ptr ClippedReadFraction::do_clone() const { return std::make_unique(*this); @@ -32,7 +34,7 @@ bool is_significantly_clipped(const AlignedRead& read) noexcept return is_soft_clipped(read) && clip_fraction(read) > 0.25; } -double clipped_fraction(const ReadMap& reads, const GenomicRegion& region) +Measure::ResultType clipped_fraction(const ReadMap& reads, const GenomicRegion& region) { unsigned num_reads {0}, num_soft_clipped_reads {0}; for (const auto& p : reads) { @@ -41,7 +43,11 @@ double clipped_fraction(const ReadMap& reads, const GenomicRegion& region) ++num_reads; } } - return static_cast(num_soft_clipped_reads) / num_reads; + boost::optional result {}; + if (num_reads > 0) { + result = static_cast(num_soft_clipped_reads) / num_reads; + } + return result; } } // namespace @@ -52,9 +58,19 @@ Measure::ResultType ClippedReadFraction::do_evaluate(const VcfRecord& call, cons return clipped_fraction(reads, mapped_region(call)); } -std::string ClippedReadFraction::do_name() const +Measure::ResultCardinality ClippedReadFraction::do_cardinality() const noexcept +{ + return ResultCardinality::one; +} + +const std::string& ClippedReadFraction::do_name() const +{ + return name_; +} + +std::string ClippedReadFraction::do_describe() const { - return "CRF"; + return "Fraction of clipped reads covering the call"; } std::vector ClippedReadFraction::do_requirements() const diff --git a/src/core/csr/measures/clipped_read_fraction.hpp b/src/core/csr/measures/clipped_read_fraction.hpp index fecc2ebe6..1d2e68932 100644 --- a/src/core/csr/measures/clipped_read_fraction.hpp +++ b/src/core/csr/measures/clipped_read_fraction.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef clipped_read_fraction_hpp @@ -17,9 +17,12 @@ namespace csr { class ClippedReadFraction : public Measure { + const static std::string name_; std::unique_ptr do_clone() const override; ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; - std::string do_name() const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; std::vector do_requirements() const override; }; diff --git a/src/core/csr/measures/denovo_contamination.cpp b/src/core/csr/measures/denovo_contamination.cpp new file mode 100644 index 000000000..1dada3704 --- /dev/null +++ b/src/core/csr/measures/denovo_contamination.cpp @@ -0,0 +1,179 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#include "denovo_contamination.hpp" + +#include +#include +#include + +#include +#include +#include + +#include "basics/trio.hpp" +#include "core/types/allele.hpp" +#include "core/types/haplotype.hpp" +#include "core/types/genotype.hpp" +#include "core/tools/read_assigner.hpp" +#include "io/variant/vcf_record.hpp" +#include "utils/genotype_reader.hpp" +#include "utils/concat.hpp" +#include "utils/append.hpp" +#include "is_denovo.hpp" +#include "../facets/samples.hpp" +#include "../facets/genotypes.hpp" +#include "../facets/pedigree.hpp" +#include "../facets/read_assignments.hpp" + +namespace octopus { namespace csr { + +const std::string DeNovoContamination::name_ = "DC"; + +std::unique_ptr DeNovoContamination::do_clone() const +{ + return std::make_unique(*this); +} + +namespace { + +bool is_denovo(const VcfRecord& call, const Measure::FacetMap& facets) +{ + return boost::get(IsDenovo(false).evaluate(call, facets)); +} + +auto find_child_idx(const std::vector& samples, const octopus::Pedigree& pedigree) +{ + assert(samples.size() == 3); + if (is_parent_of(samples[0], samples[1], 
pedigree)) { + return 1; + } else if (is_parent_of(samples[1], samples[0], pedigree)) { + return 0; + } else { + return 2; + } +} + +template +void sort_unique(Container& values) +{ + std::sort(std::begin(values), std::end(values)); + values.erase(std::unique(std::begin(values), std::end(values)), std::end(values)); +} + +auto get_denovo_alleles(const VcfRecord& denovo, const Trio& trio) +{ + auto parent_alleles = concat(get_called_alleles(denovo, trio.mother(), true).first, + get_called_alleles(denovo, trio.father(), true).first); + auto child_alleles = get_called_alleles(denovo, trio.child(), true).first; + sort_unique(parent_alleles); sort_unique(child_alleles); + std::vector result {}; + result.reserve(child_alleles.size()); + std::set_difference(std::cbegin(child_alleles), std::cend(child_alleles), + std::cbegin(parent_alleles), std::cend(parent_alleles), + std::back_inserter(result)); + return result; +} + +auto get_denovo_haplotypes(const Facet::GenotypeMap& genotypes, const std::vector& denovos) +{ + std::vector result {}; + if (!denovos.empty()) { + const auto allele_region = denovos.front().mapped_region(); + for (const auto& p :genotypes) { + const auto& overlapped_genotypes = overlap_range(p.second, allele_region); + if (size(overlapped_genotypes) == 1) { + const auto& genotype = overlapped_genotypes.front(); + for (const auto& haplotype : genotype) { + if (std::any_of(std::cbegin(denovos), std::cend(denovos), + [&] (const auto& denovo) { return haplotype.includes(denovo); })) { + result.push_back(haplotype); + } + } + } + } + sort_unique(result); + } + return result; +} + +auto get_denovo_haplotypes(const VcfRecord& denovo, const Facet::GenotypeMap& genotypes, const Trio& trio) +{ + const auto denovo_alleles = get_denovo_alleles(denovo, trio); + return get_denovo_haplotypes(genotypes, denovo_alleles); +} + +} // namespace + +Measure::ResultType DeNovoContamination::do_evaluate(const VcfRecord& call, const FacetMap& facets) const +{ + boost::optional 
result {}; + if (is_denovo(call, facets)) { + result = 0; + const auto& samples = get_value(facets.at("Samples")); + const auto& pedigree = get_value(facets.at("Pedigree")); + assert(is_trio(samples, pedigree)); // TODO: Implement for general pedigree + const auto trio = *make_trio(samples[find_child_idx(samples, pedigree)], pedigree); + const auto& genotypes = get_value(facets.at("Genotypes")); + const auto denovo_haplotypes = get_denovo_haplotypes(call, genotypes, trio); + const auto& assignments = get_value(facets.at("ReadAssignments")).support; + Genotype denovo_genotype {static_cast(denovo_haplotypes.size() + 1)}; + HaplotypeProbabilityMap haplotype_priors {}; + haplotype_priors.reserve(denovo_haplotypes.size() + 1); + for (const auto& haplotype : denovo_haplotypes) { + denovo_genotype.emplace(haplotype); + haplotype_priors[haplotype] = -1; + } + const std::array parents {trio.mother(), trio.father()}; + for (const auto& sample : parents) { + for (const auto& p : assignments.at(sample)) { + const auto overlapped_reads = copy_overlapped(p.second, call); + if (!overlapped_reads.empty()) { + const Haplotype& assigned_haplotype {p.first}; + if (!denovo_genotype.contains(assigned_haplotype)) { + auto dummy = denovo_genotype; + dummy.emplace(assigned_haplotype); + haplotype_priors[assigned_haplotype] = 0; + const auto support = compute_haplotype_support(dummy, overlapped_reads, haplotype_priors); + haplotype_priors.erase(assigned_haplotype); + for (const auto& denovo : denovo_haplotypes) { + if (support.count(denovo) == 1) { + *result += support.at(denovo).size(); + } + } + } else { + // This could happen if we don't call all 'de novo' alleles on the called de novo haplotype. 
+ *result += overlapped_reads.size(); + } + } + } + } + } + return result; +} + +Measure::ResultCardinality DeNovoContamination::do_cardinality() const noexcept +{ + return ResultCardinality::one; +} + +const std::string& DeNovoContamination::do_name() const +{ + return name_; +} + +std::string DeNovoContamination::do_describe() const +{ + return "Number of reads supporting a de novo haplotype in the normal"; +} + +std::vector DeNovoContamination::do_requirements() const +{ + std::vector result {"Samples", "Genotypes", "ReadAssignments", "Pedigree"}; + utils::append(IsDenovo(false).requirements(), result); + sort_unique(result); + return result; +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/measures/denovo_contamination.hpp b/src/core/csr/measures/denovo_contamination.hpp new file mode 100644 index 000000000..649089df0 --- /dev/null +++ b/src/core/csr/measures/denovo_contamination.hpp @@ -0,0 +1,33 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#ifndef denovo_contamination_hpp +#define denovo_contamination_hpp + +#include +#include + +#include "measure.hpp" + +namespace octopus { + +class VcfRecord; + +namespace csr { + +class DeNovoContamination : public Measure +{ + const static std::string name_; + + std::unique_ptr do_clone() const override; + ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; + std::vector do_requirements() const override; +}; + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/measures/depth.cpp b/src/core/csr/measures/depth.cpp index e2d7b2463..bed222d88 100644 --- a/src/core/csr/measures/depth.cpp +++ b/src/core/csr/measures/depth.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "depth.hpp" @@ -7,11 +7,19 @@ #include "io/variant/vcf_record.hpp" #include "io/variant/vcf_spec.hpp" +#include "../facets/samples.hpp" #include "../facets/overlapping_reads.hpp" namespace octopus { namespace csr { -Depth::Depth(bool recalculate) : recalculate_ {recalculate} {} +const std::string Depth::name_ = "DP"; + +Depth::Depth() : Depth {false, false} {} + +Depth::Depth(bool recalculate, bool aggregate_samples) +: recalculate_ {recalculate} +, aggregate_ {aggregate_samples} +{} std::unique_ptr Depth::do_clone() const { @@ -20,27 +28,63 @@ std::unique_ptr Depth::do_clone() const Measure::ResultType Depth::do_evaluate(const VcfRecord& call, const FacetMap& facets) const { - if (recalculate_) { - const auto& reads = get_value(facets.at("OverlappingReads")); - return static_cast(count_overlapped(reads, call)); + if (aggregate_) { + if (recalculate_) { + const auto& reads = get_value(facets.at("OverlappingReads")); + return static_cast(count_overlapped(reads, call)); + } else { + return static_cast(std::stoull(call.info_value(vcfspec::info::combinedReadDepth).front())); + } + } else { + const auto& samples = get_value(facets.at("Samples")); + std::vector result {}; + result.reserve(samples.size()); + if (recalculate_) { + const auto& reads = get_value(facets.at("OverlappingReads")); + for (const auto& sample : samples) { + result.push_back(count_overlapped(reads.at(sample), call)); + } + } else { + for (const auto& sample : samples) { + result.push_back(std::stoull(call.get_sample_value(sample, vcfspec::format::combinedReadDepth).front())); + } + } + return result; + } +} + +Measure::ResultCardinality Depth::do_cardinality() const noexcept +{ + if (aggregate_) { + return ResultCardinality::one; } else { - return static_cast(std::stoull(call.info_value(vcfspec::info::combinedReadDepth).front())); + return ResultCardinality::num_samples; } } -std::string Depth::do_name() const +const std::string& Depth::do_name() const { - return "DP"; + return name_; +} + 
+std::string Depth::do_describe() const +{ + return "Number of read overlapping the call"; } std::vector Depth::do_requirements() const { - if (recalculate_) { - return {"OverlappingReads"}; - } else { - return {}; - } + std::vector result {}; + if (!aggregate_) result.push_back("Samples"); + if (recalculate_) result.push_back("OverlappingReads"); + return result; +} + +bool Depth::is_equal(const Measure& other) const noexcept +{ + const auto& other_depth = static_cast(other); + return recalculate_ == other_depth.recalculate_ && aggregate_ == other_depth.aggregate_; } } // namespace csr -} // namespace octopus \ No newline at end of file +} // namespace octopus diff --git a/src/core/csr/measures/depth.hpp b/src/core/csr/measures/depth.hpp index a6ee1d96e..426f2cc91 100644 --- a/src/core/csr/measures/depth.hpp +++ b/src/core/csr/measures/depth.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef depth_hpp @@ -17,13 +17,18 @@ namespace csr { class Depth : public Measure { - bool recalculate_; + const static std::string name_; + bool recalculate_ = false, aggregate_ = false; std::unique_ptr do_clone() const override; ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; - std::string do_name() const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; std::vector do_requirements() const override; + bool is_equal(const Measure& other) const noexcept override; public: - Depth(bool recalculate = false); + Depth(); + Depth(bool recalculate, bool aggregate_samples); }; } // namespace csr diff --git a/src/core/csr/measures/filtered_read_fraction.cpp b/src/core/csr/measures/filtered_read_fraction.cpp index 810d85ed8..6ab4b8eee 100644 --- a/src/core/csr/measures/filtered_read_fraction.cpp +++ b/src/core/csr/measures/filtered_read_fraction.cpp @@ -1,15 +1,20 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "filtered_read_fraction.hpp" +#include +#include + #include namespace octopus { namespace csr { -FilteredReadFraction::FilteredReadFraction() -: calling_depth_ {false} -, filtering_depth_ {true} +const std::string FilteredReadFraction::name_ = "FRF"; + +FilteredReadFraction::FilteredReadFraction(bool aggregate_samples) +: calling_depth_ {false, aggregate_samples} +, filtering_depth_ {true, aggregate_samples} {} std::unique_ptr FilteredReadFraction::do_clone() const @@ -19,18 +24,38 @@ std::unique_ptr FilteredReadFraction::do_clone() const Measure::ResultType FilteredReadFraction::do_evaluate(const VcfRecord& call, const FacetMap& facets) const { - const auto filtering_depth = boost::get(filtering_depth_.evaluate(call, facets)); - double result {0}; - if (filtering_depth > 0) { - auto calling_depth = boost::get(calling_depth_.evaluate(call, facets)); - result = 1.0 - (static_cast(calling_depth) / filtering_depth); + if (filtering_depth_.cardinality() == Measure::ResultCardinality::num_samples) { + const auto calling_depth = boost::get>(calling_depth_.evaluate(call, facets)); + const auto filtering_depth = boost::get>(filtering_depth_.evaluate(call, facets)); + assert(calling_depth.size() == filtering_depth.size()); + std::vector result(calling_depth.size()); + std::transform(std::cbegin(calling_depth), std::cend(calling_depth), std::cbegin(filtering_depth), std::begin(result), + [] (auto cd, auto fd) { return fd > 0 ? 
1.0 - (static_cast(cd) / fd) : 0; }); + return result; + } else { + const auto filtering_depth = boost::get(filtering_depth_.evaluate(call, facets)); + double result {0}; + if (filtering_depth > 0) { + auto calling_depth = boost::get(calling_depth_.evaluate(call, facets)); + result = 1.0 - (static_cast(calling_depth) / filtering_depth); + } + return result; } - return result; } -std::string FilteredReadFraction::do_name() const +Measure::ResultCardinality FilteredReadFraction::do_cardinality() const noexcept +{ + return filtering_depth_.cardinality(); +} + +const std::string& FilteredReadFraction::do_name() const { - return "FRF"; + return name_; +} + +std::string FilteredReadFraction::do_describe() const +{ + return "Fraction of reads filtered for calling"; } std::vector FilteredReadFraction::do_requirements() const @@ -38,5 +63,10 @@ std::vector FilteredReadFraction::do_requirements() const return filtering_depth_.requirements(); } +bool FilteredReadFraction::is_equal(const Measure& other) const noexcept +{ + return calling_depth_ == static_cast(other).calling_depth_; +} + } // namespace csr } // namespace octopus diff --git a/src/core/csr/measures/filtered_read_fraction.hpp b/src/core/csr/measures/filtered_read_fraction.hpp index 142690b9f..debb96eda 100644 --- a/src/core/csr/measures/filtered_read_fraction.hpp +++ b/src/core/csr/measures/filtered_read_fraction.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef filtered_read_fraction_hpp @@ -18,13 +18,17 @@ namespace csr { class FilteredReadFraction : public Measure { + const static std::string name_; Depth calling_depth_, filtering_depth_; std::unique_ptr do_clone() const override; ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; - std::string do_name() const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; std::vector do_requirements() const override; + bool is_equal(const Measure& other) const noexcept override; public: - FilteredReadFraction(); + FilteredReadFraction(bool aggregate_samples = false); }; } // namespace csr diff --git a/src/core/csr/measures/gc_content.cpp b/src/core/csr/measures/gc_content.cpp index be14804fa..1218967ec 100644 --- a/src/core/csr/measures/gc_content.cpp +++ b/src/core/csr/measures/gc_content.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "gc_content.hpp" @@ -11,6 +11,8 @@ namespace octopus { namespace csr { +const std::string GCContent::name_ = "GC"; + std::unique_ptr GCContent::do_clone() const { return std::make_unique(*this); @@ -22,9 +24,19 @@ Measure::ResultType GCContent::do_evaluate(const VcfRecord& call, const FacetMap return utils::gc_content(reference.sequence()); } -std::string GCContent::do_name() const +Measure::ResultCardinality GCContent::do_cardinality() const noexcept +{ + return ResultCardinality::one; +} + +const std::string& GCContent::do_name() const +{ + return name_; +} + +std::string GCContent::do_describe() const { - return "GC"; + return "GC bias of the reference in a window centred on the call"; } std::vector GCContent::do_requirements() const diff --git a/src/core/csr/measures/gc_content.hpp b/src/core/csr/measures/gc_content.hpp index 299209122..bfce071b6 100644 --- a/src/core/csr/measures/gc_content.hpp +++ b/src/core/csr/measures/gc_content.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef gc_content_hpp @@ -17,9 +17,12 @@ namespace csr { class GCContent : public Measure { + const static std::string name_; std::unique_ptr do_clone() const override; ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; - std::string do_name() const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; std::vector do_requirements() const override; }; diff --git a/src/core/csr/measures/genotype_quality.cpp b/src/core/csr/measures/genotype_quality.cpp new file mode 100644 index 000000000..bf104b221 --- /dev/null +++ b/src/core/csr/measures/genotype_quality.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#include "genotype_quality.hpp" + +#include +#include + +#include +#include + +#include "io/variant/vcf_record.hpp" +#include "io/variant/vcf_spec.hpp" +#include "../facets/samples.hpp" + +namespace octopus { namespace csr { + +const std::string GenotypeQuality::name_ = "GQ"; + +std::unique_ptr GenotypeQuality::do_clone() const +{ + return std::make_unique(*this); +} + +Measure::ResultType GenotypeQuality::do_evaluate(const VcfRecord& call, const FacetMap& facets) const +{ + const auto& samples = get_value(facets.at("Samples")); + std::vector> result {}; + result.reserve(samples.size()); + for (const auto& sample : samples) { + static const std::string gq_field {vcfspec::format::conditionalQuality}; + boost::optional sample_gq {}; + if (call.has_format(gq_field)) { + sample_gq = std::stod(call.get_sample_value(sample, gq_field).front()); + } + result.push_back(sample_gq); + } + return result; +} + +Measure::ResultCardinality GenotypeQuality::do_cardinality() const noexcept +{ + return ResultCardinality::num_samples; +} + +const std::string& GenotypeQuality::do_name() const +{ + return name_; +} + +std::string 
GenotypeQuality::do_describe() const +{ + return "GQ of each sample"; +} + +std::vector GenotypeQuality::do_requirements() const +{ + return {"Samples"}; +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/measures/genotype_quality.hpp b/src/core/csr/measures/genotype_quality.hpp new file mode 100644 index 000000000..13ca563f2 --- /dev/null +++ b/src/core/csr/measures/genotype_quality.hpp @@ -0,0 +1,33 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#ifndef genotype_quality_hpp +#define genotype_quality_hpp + +#include +#include + +#include "measure.hpp" + +namespace octopus { + +class VcfRecord; + +namespace csr { + +class GenotypeQuality : public Measure +{ + const static std::string name_; + std::unique_ptr do_clone() const override; + ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; + std::vector do_requirements() const override; + bool is_required_vcf_field() const noexcept override { return true; } +}; + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/measures/is_denovo.cpp b/src/core/csr/measures/is_denovo.cpp index aaa39d264..8f7f21592 100644 --- a/src/core/csr/measures/is_denovo.cpp +++ b/src/core/csr/measures/is_denovo.cpp @@ -1,26 +1,96 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "is_denovo.hpp" +#include "basics/pedigree.hpp" #include "io/variant/vcf_record.hpp" #include "config/octopus_vcf.hpp" +#include "../facets/samples.hpp" +#include "../facets/pedigree.hpp" namespace octopus { namespace csr { +const std::string IsDenovo::name_ = "DENOVO"; + +IsDenovo::IsDenovo(bool report_sample_status) : report_sample_status_ {report_sample_status} {} + std::unique_ptr IsDenovo::do_clone() const { return std::make_unique(*this); } -Measure::ResultType IsDenovo::do_evaluate(const VcfRecord& call, const FacetMap& facets) const +namespace { + +bool is_denovo(const VcfRecord& call) { return call.has_info(vcf::spec::info::denovo); } -std::string IsDenovo::do_name() const +auto child_idx(const std::vector& samples, const octopus::Pedigree& pedigree) +{ + assert(samples.size() == 3); + if (is_parent_of(samples[0], samples[1], pedigree)) { + return 1; + } else if (is_parent_of(samples[1], samples[0], pedigree)) { + return 0; + } else { + return 2; + } +} + +} // namespace + +Measure::ResultType IsDenovo::do_evaluate(const VcfRecord& call, const FacetMap& facets) const +{ + if (report_sample_status_) { + const auto& samples = get_value(facets.at("Samples")); + std::vector result(samples.size(), false); + if (is_denovo(call)) { + const auto& pedigree = get_value(facets.at("Pedigree")); + result[child_idx(samples, pedigree)] = true; + } + return result; + } else { + return is_denovo(call); + } +} + +Measure::ResultCardinality IsDenovo::do_cardinality() const noexcept +{ + if (report_sample_status_) { + return ResultCardinality::num_samples; + } else { + return ResultCardinality::one; + } +} + +const std::string& IsDenovo::do_name() const +{ + return name_; +} + +std::string IsDenovo::do_describe() const +{ + if (report_sample_status_) { + return "DENOVO status of each sample"; + } else { + return "Is the call marked DENOVO"; + } +} + +std::vector IsDenovo::do_requirements() const +{ + if (report_sample_status_) { + return {"Samples", "Pedigree"}; + } 
else { + return {}; + } +} + +bool IsDenovo::is_equal(const Measure& other) const noexcept { - return "DENOVO"; + return report_sample_status_ == static_cast(other).report_sample_status_; } } // namespace csr diff --git a/src/core/csr/measures/is_denovo.hpp b/src/core/csr/measures/is_denovo.hpp index 01b2a8760..ec7d32f71 100644 --- a/src/core/csr/measures/is_denovo.hpp +++ b/src/core/csr/measures/is_denovo.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef is_denovo_hpp @@ -16,9 +16,17 @@ namespace csr { class IsDenovo : public Measure { + bool report_sample_status_; + const static std::string name_; std::unique_ptr do_clone() const override; ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; - std::string do_name() const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; + std::vector do_requirements() const override; + bool is_equal(const Measure& other) const noexcept override; +public: + IsDenovo(bool report_sample_status = true); }; } // namespace csr diff --git a/src/core/csr/measures/is_refcall.cpp b/src/core/csr/measures/is_refcall.cpp new file mode 100644 index 000000000..9d93fbf4c --- /dev/null +++ b/src/core/csr/measures/is_refcall.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "is_refcall.hpp" + +#include "io/variant/vcf_record.hpp" +#include "config/octopus_vcf.hpp" +#include "../facets/samples.hpp" + +namespace octopus { namespace csr { + +const std::string IsRefcall::name_ = "REFCALL"; + +IsRefcall::IsRefcall(bool report_sample_status) : report_sample_status_ {report_sample_status} {} + +std::unique_ptr IsRefcall::do_clone() const +{ + return std::make_unique(*this); +} + +namespace { + +bool is_refcall(const VcfRecord& record) +{ + return record.alt().size() == 1 && record.alt()[0] == vcf::spec::allele::nonref; +} + +} // namespace + +Measure::ResultType IsRefcall::do_evaluate(const VcfRecord& call, const FacetMap& facets) const +{ + if (report_sample_status_) { + const auto& samples = get_value(facets.at("Samples")); + std::vector result(samples.size()); + std::transform(std::cbegin(samples), std::cend(samples), std::begin(result), + [&call] (const auto& sample) { return call.is_homozygous_ref(sample); }); + return result; + } else { + return is_refcall(call); + } +} + +Measure::ResultCardinality IsRefcall::do_cardinality() const noexcept +{ + if (report_sample_status_) { + return ResultCardinality::num_samples; + } else { + return ResultCardinality::one; + } +} + +const std::string& IsRefcall::do_name() const +{ + return name_; +} + +std::string IsRefcall::do_describe() const +{ + if (report_sample_status_) { + return "REFCALL status of each sample"; + } else { + return "Is the call marked REFCALL"; + } +} + +std::vector IsRefcall::do_requirements() const +{ + if (report_sample_status_) { + return {"Samples"}; + } else { + return {}; + } +} + +bool IsRefcall::is_equal(const Measure& other) const noexcept +{ + return report_sample_status_ == static_cast(other).report_sample_status_; +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/measures/is_refcall.hpp b/src/core/csr/measures/is_refcall.hpp new file mode 100644 index 000000000..cd3c1b755 --- /dev/null +++ b/src/core/csr/measures/is_refcall.hpp 
@@ -0,0 +1,36 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#ifndef is_refcall_hpp +#define is_refcall_hpp + +#include +#include + +#include "measure.hpp" + +namespace octopus { + +class VcfRecord; + +namespace csr { + +class IsRefcall : public Measure +{ + bool report_sample_status_; + const static std::string name_; + std::unique_ptr do_clone() const override; + ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; + std::vector do_requirements() const override; + bool is_equal(const Measure& other) const noexcept override; +public: + IsRefcall(bool report_sample_status = true); +}; + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/measures/is_somatic.cpp b/src/core/csr/measures/is_somatic.cpp index 15e1b0655..475ca8a6e 100644 --- a/src/core/csr/measures/is_somatic.cpp +++ b/src/core/csr/measures/is_somatic.cpp @@ -1,25 +1,93 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "is_somatic.hpp" +#include +#include +#include + #include "io/variant/vcf_record.hpp" +#include "../facets/samples.hpp" +#include "../facets/ploidies.hpp" namespace octopus { namespace csr { +const std::string IsSomatic::name_ = "SOMATIC"; + +IsSomatic::IsSomatic(bool report_sample_status) : report_sample_status_ {report_sample_status} {} + std::unique_ptr IsSomatic::do_clone() const { return std::make_unique(*this); } +namespace { + +bool is_somatic_sample(const VcfRecord& call, const VcfRecord::SampleName& sample, const unsigned sample_ploidy) +{ + assert(is_somatic(call)); + const auto observed_sample_ploidy = call.ploidy(sample); + if (observed_sample_ploidy > sample_ploidy) { + return call.has_alt_allele(sample); + } else { + return false; + } +} + +} // namespace + Measure::ResultType IsSomatic::do_evaluate(const VcfRecord& call, const FacetMap& facets) const { - return is_somatic(call); + if (report_sample_status_) { + const auto& samples = get_value(facets.at("Samples")); + std::vector result(samples.size(), false); + if (is_somatic(call)) { + const auto& ploidies = get_value(facets.at("Ploidies")); + std::transform(std::cbegin(samples), std::cend(samples), std::begin(result), + [&] (const auto& sample) { return is_somatic_sample(call, sample, ploidies.at(sample)); }); + } + return result; + } else { + return is_somatic(call); + } +} + +Measure::ResultCardinality IsSomatic::do_cardinality() const noexcept +{ + if (report_sample_status_) { + return ResultCardinality::num_samples; + } else { + return ResultCardinality::one; + } +} + +const std::string& IsSomatic::do_name() const +{ + return name_; +} + +std::string IsSomatic::do_describe() const +{ + if (report_sample_status_) { + return "SOMATIC status of each sample"; + } else { + return "Is the call marked SOMATIC"; + } +} + +std::vector IsSomatic::do_requirements() const +{ + if (report_sample_status_) { + return {"Samples", "Ploidies"}; + } else { + return {}; + } } -std::string 
IsSomatic::do_name() const +bool IsSomatic::is_equal(const Measure& other) const noexcept { - return "SOMATIC"; + return report_sample_status_ == static_cast(other).report_sample_status_; } } // namespace csr diff --git a/src/core/csr/measures/is_somatic.hpp b/src/core/csr/measures/is_somatic.hpp index e60d339d7..66c67eed0 100644 --- a/src/core/csr/measures/is_somatic.hpp +++ b/src/core/csr/measures/is_somatic.hpp @@ -1,10 +1,11 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef is_somatic_hpp #define is_somatic_hpp #include +#include #include "measure.hpp" @@ -16,9 +17,17 @@ namespace csr { class IsSomatic : public Measure { + bool report_sample_status_; + const static std::string name_; std::unique_ptr do_clone() const override; ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; - std::string do_name() const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; + std::vector do_requirements() const override; + bool is_equal(const Measure& other) const noexcept override; +public: + IsSomatic(bool report_sample_status = true); }; } // namespace csr diff --git a/src/core/csr/measures/mapping_quality_divergence.cpp b/src/core/csr/measures/mapping_quality_divergence.cpp index 7697f6eff..1ea4ef8f9 100644 --- a/src/core/csr/measures/mapping_quality_divergence.cpp +++ b/src/core/csr/measures/mapping_quality_divergence.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "mapping_quality_divergence.hpp" @@ -17,6 +17,8 @@ namespace octopus { namespace csr { +const std::string MappingQualityDivergence::name_ = "MQD"; + std::unique_ptr MappingQualityDivergence::do_clone() const { return std::make_unique(*this); @@ -106,7 +108,7 @@ double calculate_max_pairwise_kl_divergence(const std::vector(facets.at("ReadAssignments")); + const auto& assignments = get_value(facets.at("ReadAssignments")).support; boost::optional result {0}; for (const auto& p : assignments) { if (call.is_heterozygous(p.first)) { @@ -122,9 +124,19 @@ Measure::ResultType MappingQualityDivergence::do_evaluate(const VcfRecord& call, return result; } -std::string MappingQualityDivergence::do_name() const +Measure::ResultCardinality MappingQualityDivergence::do_cardinality() const noexcept +{ + return ResultCardinality::one; +} + +const std::string& MappingQualityDivergence::do_name() const +{ + return name_; +} + +std::string MappingQualityDivergence::do_describe() const { - return "MQD"; + return "Symmetric KL divergence of reads supporting the REF versus ALT alleles"; } std::vector MappingQualityDivergence::do_requirements() const diff --git a/src/core/csr/measures/mapping_quality_divergence.hpp b/src/core/csr/measures/mapping_quality_divergence.hpp index 843613c52..8f817be49 100644 --- a/src/core/csr/measures/mapping_quality_divergence.hpp +++ b/src/core/csr/measures/mapping_quality_divergence.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef mapping_quality_divergence_hpp @@ -17,9 +17,12 @@ namespace csr { class MappingQualityDivergence : public Measure { + const static std::string name_; std::unique_ptr do_clone() const override; ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; - std::string do_name() const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; std::vector do_requirements() const override; }; diff --git a/src/core/csr/measures/mapping_quality_zero_count.cpp b/src/core/csr/measures/mapping_quality_zero_count.cpp index 4dd2aa9ac..11590c0ea 100644 --- a/src/core/csr/measures/mapping_quality_zero_count.cpp +++ b/src/core/csr/measures/mapping_quality_zero_count.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "mapping_quality_zero_count.hpp" @@ -11,6 +11,8 @@ namespace octopus { namespace csr { +const std::string MappingQualityZeroCount::name_ = "MQ0"; + MappingQualityZeroCount::MappingQualityZeroCount(bool recalculate) : recalculate_ {recalculate} {} std::unique_ptr MappingQualityZeroCount::do_clone() const @@ -28,9 +30,19 @@ Measure::ResultType MappingQualityZeroCount::do_evaluate(const VcfRecord& call, } } -std::string MappingQualityZeroCount::do_name() const +Measure::ResultCardinality MappingQualityZeroCount::do_cardinality() const noexcept +{ + return ResultCardinality::one; +} + +const std::string& MappingQualityZeroCount::do_name() const +{ + return name_; +} + +std::string MappingQualityZeroCount::do_describe() const { - return "MQ0"; + return "Number of reads overlapping the call with mapping quality zero"; } std::vector MappingQualityZeroCount::do_requirements() const @@ -42,5 +54,10 @@ std::vector MappingQualityZeroCount::do_requirements() const } } +bool 
MappingQualityZeroCount::is_equal(const Measure& other) const noexcept +{ + return recalculate_ == static_cast(other).recalculate_; +} + } // namespace csr } // namespace octopus diff --git a/src/core/csr/measures/mapping_quality_zero_count.hpp b/src/core/csr/measures/mapping_quality_zero_count.hpp index 01c84a24f..b5b49818b 100644 --- a/src/core/csr/measures/mapping_quality_zero_count.hpp +++ b/src/core/csr/measures/mapping_quality_zero_count.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef mapping_quality_zero_count_hpp @@ -17,11 +17,15 @@ namespace csr { class MappingQualityZeroCount : public Measure { + const static std::string name_; bool recalculate_; std::unique_ptr do_clone() const override; ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; - std::string do_name() const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; std::vector do_requirements() const override; + bool is_equal(const Measure& other) const noexcept override; public: MappingQualityZeroCount(bool recalculate = true); }; diff --git a/src/core/csr/measures/max_genotype_quality.cpp b/src/core/csr/measures/max_genotype_quality.cpp deleted file mode 100644 index 2354bbf8e..000000000 --- a/src/core/csr/measures/max_genotype_quality.cpp +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) 2017 Daniel Cooke -// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
- -#include "max_genotype_quality.hpp" - -#include -#include - -#include -#include - -#include "io/variant/vcf_record.hpp" -#include "io/variant/vcf_spec.hpp" -#include "../facets/samples.hpp" - -namespace octopus { namespace csr { - -std::unique_ptr MaxGenotypeQuality::do_clone() const -{ - return std::make_unique(*this); -} - -Measure::ResultType MaxGenotypeQuality::do_evaluate(const VcfRecord& call, const FacetMap& facets) const -{ - const auto& samples = get_value(facets.at("Samples")); - boost::optional result {}; - for (const auto& sample : samples) { - static const std::string gq_field {vcfspec::format::conditionalQuality}; - if (call.has_format(gq_field)) { - const auto sample_gq = std::stod(call.get_sample_value(sample, gq_field).front()); - if (result) { - result = std::max(sample_gq, *result); - } else { - result = sample_gq; - } - } - } - return result; -} - -std::string MaxGenotypeQuality::do_name() const -{ - return "GQ"; -} - -std::vector MaxGenotypeQuality::do_requirements() const -{ - return {"Samples"}; -} - -} // namespace csr -} // namespace octopus diff --git a/src/core/csr/measures/mean_mapping_quality.cpp b/src/core/csr/measures/mean_mapping_quality.cpp index 47d84d35e..f4b755afe 100644 --- a/src/core/csr/measures/mean_mapping_quality.cpp +++ b/src/core/csr/measures/mean_mapping_quality.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "mean_mapping_quality.hpp" @@ -14,6 +14,8 @@ namespace octopus { namespace csr { +const std::string MeanMappingQuality::name_ = "MQ"; + MeanMappingQuality::MeanMappingQuality(bool recalculate) : recalculate_ {recalculate} {} std::unique_ptr MeanMappingQuality::do_clone() const @@ -32,9 +34,19 @@ Measure::ResultType MeanMappingQuality::do_evaluate(const VcfRecord& call, const } } -std::string MeanMappingQuality::do_name() const +Measure::ResultCardinality MeanMappingQuality::do_cardinality() const noexcept +{ + return ResultCardinality::one; +} + +const std::string& MeanMappingQuality::do_name() const +{ + return name_; +} + +std::string MeanMappingQuality::do_describe() const { - return "MQ"; + return "Mean mapping quality of reads overlapping the call"; } std::vector MeanMappingQuality::do_requirements() const @@ -46,5 +58,10 @@ std::vector MeanMappingQuality::do_requirements() const } } +bool MeanMappingQuality::is_equal(const Measure& other) const noexcept +{ + return recalculate_ == static_cast(other).recalculate_; +} + } // namespace csr } // namespace octopus diff --git a/src/core/csr/measures/mean_mapping_quality.hpp b/src/core/csr/measures/mean_mapping_quality.hpp index 8a3372b5f..00b455c18 100644 --- a/src/core/csr/measures/mean_mapping_quality.hpp +++ b/src/core/csr/measures/mean_mapping_quality.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef mean_mapping_quality_hpp @@ -17,11 +17,15 @@ namespace csr { class MeanMappingQuality : public Measure { + const static std::string name_; bool recalculate_; std::unique_ptr do_clone() const override; ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; - std::string do_name() const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; std::vector do_requirements() const override; + bool is_equal(const Measure& other) const noexcept override; public: MeanMappingQuality(bool recalculate = true); }; diff --git a/src/core/csr/measures/measure.cpp b/src/core/csr/measures/measure.cpp index fe4208bcb..875f5ad5f 100644 --- a/src/core/csr/measures/measure.cpp +++ b/src/core/csr/measures/measure.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "measure.hpp" @@ -6,9 +6,14 @@ #include #include #include +#include +#include +#include #include +#include "io/variant/vcf_spec.hpp" + namespace octopus { namespace csr { struct MeasureSerialiseVisitor : boost::static_visitor<> @@ -39,6 +44,23 @@ struct MeasureSerialiseVisitor : boost::static_visitor<> str = "."; } } + template + void operator()(const std::vector& values) + { + if (values.empty()) { + str = "."; + } else { + auto tmp_str = std::move(str); + std::for_each(std::cbegin(values), std::prev(std::cend(values)), [&] (const auto& value) { + (*this)(value); + tmp_str += str; + tmp_str += ','; + }); + (*this)(values.back()); + tmp_str += str; + str = std::move(tmp_str); + } + } }; std::string Measure::do_serialise(const ResultType& value) const @@ -48,5 +70,67 @@ std::string Measure::do_serialise(const ResultType& value) const return vis.str; } +void Measure::annotate(VcfHeader::Builder& header) const +{ + if (!is_required_vcf_field()) { + std::string number; + using namespace vcfspec::header::meta::number; + switch (this->cardinality()) { + case Measure::ResultCardinality::num_samples: number = unknown; break; + case Measure::ResultCardinality::num_alleles: number = per_allele; break; + case Measure::ResultCardinality::one: number = "1"; break; + } + header.add_info(this->name(), number, "String", this->describe()); + } +} + +void Measure::annotate(VcfRecord::Builder& record, const ResultType& value) const +{ + if (!is_required_vcf_field()) { + record.set_info(this->name(), this->serialise(value)); + } +} + +// non-member methods + +struct IsMissingMeasureVisitor : public boost::static_visitor +{ + template bool operator()(const boost::optional& value) const noexcept { return !value; } + template bool operator()(const T& value) const noexcept { return false; } +}; + +bool is_missing(const Measure::ResultType& value) noexcept +{ + return boost::apply_visitor(IsMissingMeasureVisitor {}, value); +} + +struct VectorIndexGetterVisitor : public 
boost::static_visitor +{ + VectorIndexGetterVisitor(std::size_t idx) : idx_ {idx} {} + template T operator()(const std::vector& value) const noexcept { return value[idx_]; } + template T operator()(const T& value) const noexcept { return value; } +private: + std::size_t idx_; +}; + +Measure::ResultType get_sample_value(const Measure::ResultType& value, const MeasureWrapper& measure, const std::size_t sample_idx) +{ + if (measure.cardinality() == Measure::ResultCardinality::num_samples) { + return boost::apply_visitor(VectorIndexGetterVisitor {sample_idx}, value); + } else { + return value; + } +} + +std::vector +get_sample_values(const std::vector& values, const std::vector& measures, std::size_t sample_idx) +{ + assert(values.size() == measures.size()); + std::vector result(values.size()); + std::transform(std::cbegin(values), std::cend(values), std::cbegin(measures), std::begin(result), + [&] (const auto& value, const auto& measure) { return get_sample_value(value, measure, sample_idx); }); + return result; +} + } // namespace csr } // namespace octopus diff --git a/src/core/csr/measures/measure.hpp b/src/core/csr/measures/measure.hpp index a1c33e666..6eb1b7ea2 100644 --- a/src/core/csr/measures/measure.hpp +++ b/src/core/csr/measures/measure.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef measure_hpp @@ -14,22 +14,26 @@ #include #include +#include "concepts/equitable.hpp" +#include "io/variant/vcf_header.hpp" +#include "io/variant/vcf_record.hpp" #include "../facets/facet.hpp" -namespace octopus { +namespace octopus { namespace csr { -class VcfRecord; - -namespace csr { - -class Measure +class Measure : public Equitable { public: using FacetMap = std::unordered_map; - using ResultType = boost::variant, - std::size_t, boost::optional, - bool, + using ResultType = boost::variant, + boost::optional, std::vector>, + int, std::vector, + boost::optional, std::vector>, + std::size_t, std::vector, + boost::optional, std::vector>, + bool, std::vector, boost::any>; + enum class ResultCardinality { one, num_alleles, num_samples }; Measure() = default; @@ -43,19 +47,31 @@ class Measure std::unique_ptr clone() const { return do_clone(); } ResultType evaluate(const VcfRecord& call, const FacetMap& facets) const { return do_evaluate(call, facets); } - std::string name() const { return do_name(); } + ResultCardinality cardinality() const noexcept { return do_cardinality(); } + const std::string& name() const { return do_name(); } + std::string describe() const { return do_describe(); } std::vector requirements() const { return do_requirements(); } std::string serialise(const ResultType& value) const { return do_serialise(value); } + void annotate(VcfHeader::Builder& header) const; + void annotate(VcfRecord::Builder& record, const ResultType& value) const; + friend bool operator==(const Measure& lhs, const Measure& rhs) noexcept + { + return lhs.name() == rhs.name() && lhs.is_equal(rhs); + } private: virtual std::unique_ptr do_clone() const = 0; virtual ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const = 0; - virtual std::string do_name() const = 0; + virtual ResultCardinality do_cardinality() const noexcept = 0; + virtual const std::string& do_name() const = 0; + virtual std::string do_describe() const = 0; virtual std::vector 
do_requirements() const { return {}; } virtual std::string do_serialise(const ResultType& value) const; + virtual bool is_required_vcf_field() const noexcept { return false; } + virtual bool is_equal(const Measure& other) const noexcept { return true; } }; -class MeasureWrapper +class MeasureWrapper : public Equitable { public: MeasureWrapper() = delete; @@ -76,14 +92,23 @@ class MeasureWrapper const Measure* base() const noexcept { return measure_.get(); } auto operator()(const VcfRecord& call) const { return measure_->evaluate(call, {}); } auto operator()(const VcfRecord& call, const Measure::FacetMap& facets) const { return measure_->evaluate(call, facets); } - std::string name() const { return measure_->name(); } + Measure::ResultCardinality cardinality() const noexcept { return measure_->cardinality(); } + const std::string& name() const { return measure_->name(); } + std::string describe() const { return measure_->describe(); } std::vector requirements() const { return measure_->requirements(); } std::string serialise(const Measure::ResultType& value) const { return measure_->serialise(value); } + void annotate(VcfHeader::Builder& header) const { measure_->annotate(header); } + void annotate(VcfRecord::Builder& record, const Measure::ResultType& value) const { measure_->annotate(record, value); } private: std::unique_ptr measure_; }; +inline bool operator==(const MeasureWrapper& lhs, const MeasureWrapper& rhs) noexcept +{ + return *lhs.base() == *rhs.base(); +} + template MeasureWrapper make_wrapped_measure(Args&&... args) { @@ -91,27 +116,31 @@ MeasureWrapper make_wrapped_measure(Args&&... 
args) } template -std::string name() +const std::string& name() { return Measure().name(); } -namespace detail { +bool is_missing(const Measure::ResultType& value) noexcept; -struct IsMissingMeasureVisitor : public boost::static_visitor -{ - template bool operator()(const boost::optional& value) const noexcept { return !value; } - template bool operator()(const T& value) const noexcept { return false; } -}; +Measure::ResultType get_sample_value(const Measure::ResultType& value, const MeasureWrapper& measure, std::size_t sample_idx); +std::vector get_sample_values(const std::vector& values, + const std::vector& measures, + std::size_t sample_idx); -} // namespace detail +} // namespace csr +} // namespace octopus + +namespace std { -inline bool is_missing(const Measure::ResultType& value) noexcept +template <> struct hash { - return boost::apply_visitor(detail::IsMissingMeasureVisitor {}, value); -} + size_t operator()(const octopus::csr::MeasureWrapper& measure) const + { + return hash{}(measure.name()); + } +}; -} // namespace csr -} // namespace octopus +} // namespace std #endif diff --git a/src/core/csr/measures/measure_factory.cpp b/src/core/csr/measures/measure_factory.cpp index a66a4307f..bb1b41720 100644 --- a/src/core/csr/measures/measure_factory.cpp +++ b/src/core/csr/measures/measure_factory.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "measure_factory.hpp" @@ -8,6 +8,7 @@ #include "measures_fwd.hpp" #include "exceptions/user_error.hpp" +#include "utils/map_utils.hpp" namespace octopus { namespace csr { @@ -23,15 +24,25 @@ void init(MeasureMakerMap& measure_makers) measure_makers[name()] = [] () { return make_wrapped_measure(); }; measure_makers[name()] = [] () { return make_wrapped_measure(); }; measure_makers[name()] = [] () { return make_wrapped_measure(); }; - measure_makers[name()] = [] () { return make_wrapped_measure(); }; + measure_makers[name()] = [] () { return make_wrapped_measure(); }; measure_makers[name()] = [] () { return make_wrapped_measure(); }; measure_makers[name()] = [] () { return make_wrapped_measure(); }; measure_makers[name()] = [] () { return make_wrapped_measure(); }; measure_makers[name()] = [] () { return make_wrapped_measure(); }; measure_makers[name()] = [] () { return make_wrapped_measure(); }; measure_makers[name()] = [] () { return make_wrapped_measure(); }; - measure_makers[name()] = [] () { return make_wrapped_measure(); }; - measure_makers[name()] = [] () { return make_wrapped_measure(); }; + measure_makers[name()] = [] () { return make_wrapped_measure(); }; + measure_makers[name()] = [] () { return make_wrapped_measure(); }; + measure_makers[name()] = [] () { return make_wrapped_measure(); }; + measure_makers[name()] = [] () { return make_wrapped_measure(); }; + measure_makers[name()] = [] () { return make_wrapped_measure(); }; + measure_makers[name()] = [] () { return make_wrapped_measure(); }; + measure_makers[name()] = [] () { return make_wrapped_measure(); }; + measure_makers[name()] = [] () { return make_wrapped_measure(); }; + measure_makers[name()] = [] () { return make_wrapped_measure(); }; + measure_makers[name()] = [] () { return make_wrapped_measure(); }; + measure_makers[name()] = [] () { return make_wrapped_measure(); }; + measure_makers[name()] = [] () { return make_wrapped_measure(); }; } class UnknownMeasure : public UserError @@ 
-62,5 +73,14 @@ MeasureWrapper make_measure(const std::string& name) return measure_makers.at(name)(); } +std::vector get_all_measure_names() +{ + static MeasureMakerMap measure_makers {}; + if (measure_makers.empty()) { + init(measure_makers); + } + return extract_sorted_keys(measure_makers); +} + } // namespace csr } // namespace octopus diff --git a/src/core/csr/measures/measure_factory.hpp b/src/core/csr/measures/measure_factory.hpp index 69f9006b3..66f6c6168 100644 --- a/src/core/csr/measures/measure_factory.hpp +++ b/src/core/csr/measures/measure_factory.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef measure_factory_hpp @@ -12,6 +12,8 @@ namespace octopus { namespace csr { MeasureWrapper make_measure(const std::string& name); +std::vector get_all_measure_names(); + } // namespace csr } // namespace octopus diff --git a/src/core/csr/measures/measures_fwd.hpp b/src/core/csr/measures/measures_fwd.hpp index 12323fb67..47bba5f46 100644 --- a/src/core/csr/measures/measures_fwd.hpp +++ b/src/core/csr/measures/measures_fwd.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef measures_fwd_hpp @@ -13,14 +13,24 @@ #include "core/csr/measures/model_posterior.hpp" #include "core/csr/measures/quality.hpp" #include "core/csr/measures/quality_by_depth.hpp" -#include "core/csr/measures/max_genotype_quality.hpp" +#include "core/csr/measures/genotype_quality.hpp" #include "core/csr/measures/strand_bias.hpp" #include "core/csr/measures/gc_content.hpp" #include "core/csr/measures/filtered_read_fraction.hpp" #include "core/csr/measures/clipped_read_fraction.hpp" #include "core/csr/measures/is_denovo.hpp" #include "core/csr/measures/is_somatic.hpp" -#include "core/csr/measures/unassigned_read_fraction.hpp" -#include "core/csr/measures/realignments.hpp" +#include "core/csr/measures/ambiguous_read_fraction.hpp" +#include "core/csr/measures/median_base_quality.hpp" +#include "core/csr/measures/mismatch_count.hpp" +#include "core/csr/measures/mismatch_fraction.hpp" +#include "core/csr/measures/is_refcall.hpp" +#include "core/csr/measures/somatic_contamination.hpp" +#include "core/csr/measures/denovo_contamination.hpp" +#include "core/csr/measures/read_position_bias.hpp" +#include "core/csr/measures/alt_allele_count.hpp" +#include "core/csr/measures/overlaps_tandem_repeat.hpp" +#include "core/csr/measures/str_length.hpp" +#include "core/csr/measures/str_period.hpp" #endif diff --git a/src/core/csr/measures/median_base_quality.cpp b/src/core/csr/measures/median_base_quality.cpp new file mode 100644 index 000000000..9825b5074 --- /dev/null +++ b/src/core/csr/measures/median_base_quality.cpp @@ -0,0 +1,124 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "median_base_quality.hpp" + +#include +#include + +#include +#include + +#include "core/tools/read_assigner.hpp" +#include "core/types/allele.hpp" +#include "io/variant/vcf_record.hpp" +#include "io/variant/vcf_spec.hpp" +#include "utils/genotype_reader.hpp" +#include "utils/maths.hpp" +#include "utils/append.hpp" +#include "../facets/samples.hpp" +#include "../facets/read_assignments.hpp" + +namespace octopus { namespace csr { + +const std::string MedianBaseQuality::name_ = "BQ"; + +std::unique_ptr MedianBaseQuality::do_clone() const +{ + return std::make_unique(*this); +} + +namespace { + +bool is_canonical(const VcfRecord::NucleotideSequence& allele) noexcept +{ + const static VcfRecord::NucleotideSequence deleted_allele {vcfspec::deletedBase}; + return !(allele == vcfspec::missingValue || allele == deleted_allele); +} + +bool has_called_alt_allele(const VcfRecord& call, const VcfRecord::SampleName& sample) +{ + if (!call.has_genotypes()) return true; + const auto& genotype = get_genotype(call, sample); + return std::any_of(std::cbegin(genotype), std::cend(genotype), + [&] (const auto& allele) { return allele != call.ref() && is_canonical(allele); }); +} + +bool is_evaluable(const VcfRecord& call, const VcfRecord::SampleName& sample) +{ + return has_called_alt_allele(call, sample); +} + +auto median_base_quality(const ReadRefSupportSet& reads, const Allele& allele) +{ + boost::optional result {}; + if (!is_indel(allele)) { + std::vector base_qualities {}; + base_qualities.reserve(reads.size() * sequence_size(allele)); + for (const auto& read : reads) { + if (overlaps(read.get(), allele)) { + utils::append(copy_base_qualities(read, mapped_region(allele)), base_qualities); + } + } + if (!base_qualities.empty()) { + result = maths::median(base_qualities); + } + } + return result; +} + +} // namespace + +Measure::ResultType MedianBaseQuality::do_evaluate(const VcfRecord& call, const FacetMap& facets) const +{ + const auto& samples = 
get_value(facets.at("Samples")); + const auto& assignments = get_value(facets.at("ReadAssignments")).support; + std::vector> result {}; + result.reserve(call.num_alt()); + for (const auto& sample : samples) { + boost::optional sample_result {}; + if (is_evaluable(call, sample)) { + std::vector alleles; bool has_ref; + std::tie(alleles, has_ref) = get_called_alleles(call, sample, true); + if (has_ref) alleles.erase(std::cbegin(alleles)); + if (!alleles.empty()) { + const auto sample_allele_support = compute_allele_support(alleles, assignments.at(sample)); + for (const auto& allele : alleles) { + const auto median_bq = median_base_quality(sample_allele_support.at(allele), allele); + if (median_bq) { + if (sample_result) { + sample_result = std::min(*sample_result, static_cast(*median_bq)); + } else { + sample_result = *median_bq; + } + } + } + } + } + result.push_back(sample_result); + } + return result; +} + +Measure::ResultCardinality MedianBaseQuality::do_cardinality() const noexcept +{ + return ResultCardinality::num_samples; +} + +const std::string& MedianBaseQuality::do_name() const +{ + return name_; +} + +std::string MedianBaseQuality::do_describe() const +{ + return "Median base quality of reads supporting each allele"; +} + +std::vector MedianBaseQuality::do_requirements() const +{ + return {"Samples", "ReadAssignments"}; +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/measures/median_base_quality.hpp b/src/core/csr/measures/median_base_quality.hpp new file mode 100644 index 000000000..90d7f3b79 --- /dev/null +++ b/src/core/csr/measures/median_base_quality.hpp @@ -0,0 +1,32 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#ifndef median_base_quality_hpp +#define median_base_quality_hpp + +#include +#include + +#include "measure.hpp" + +namespace octopus { + +class VcfRecord; + +namespace csr { + +class MedianBaseQuality : public Measure +{ + const static std::string name_; + std::unique_ptr do_clone() const override; + ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; + std::vector do_requirements() const override; +}; + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/measures/mismatch_count.cpp b/src/core/csr/measures/mismatch_count.cpp new file mode 100644 index 000000000..ca2aef90d --- /dev/null +++ b/src/core/csr/measures/mismatch_count.cpp @@ -0,0 +1,86 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "mismatch_count.hpp" + +#include +#include +#include + +#include "io/variant/vcf_record.hpp" +#include "basics/aligned_read.hpp" +#include "core/types/allele.hpp" +#include "utils/genotype_reader.hpp" +#include "../facets/samples.hpp" +#include "../facets/read_assignments.hpp" + +namespace octopus { namespace csr { + +const std::string MismatchCount::name_ = "MC"; + +std::unique_ptr MismatchCount::do_clone() const +{ + return std::make_unique(*this); +} + +bool mismatches(const AlignedRead& read, const Allele& allele) +{ + if (!overlaps(read, allele)) return false; + const auto read_section = copy_sequence(read, mapped_region(allele)); + if (contains(read, allele)) { + return read_section != allele.sequence(); + } else { + if (read_section.size() > sequence_size(allele)) return true; + if (begins_before(read, allele)) { + return !std::equal(std::cbegin(read_section), std::cend(read_section), std::cbegin(allele.sequence())); + } else { + return !std::equal(std::crbegin(read_section), std::crend(read_section), std::crbegin(allele.sequence())); + } + } +} + +Measure::ResultType MismatchCount::do_evaluate(const VcfRecord& call, const FacetMap& facets) const +{ + const auto& samples = get_value(facets.at("Samples")); + const auto& assignments = get_value(facets.at("ReadAssignments")); + std::vector result {}; + result.reserve(samples.size()); + for (const auto& sample : samples) { + std::vector alleles; bool has_ref; + std::tie(alleles, has_ref) = get_called_alleles(call, sample, true); + int sample_result {0}; + if (alleles.empty()) { + const auto sample_allele_support = compute_allele_support(alleles, assignments.support.at(sample)); + for (const auto& p : sample_allele_support) { + for (const auto& read : p.second) { + sample_result += mismatches(read, p.first); + } + } + } + result.push_back(sample_result); + } + return result; +} + +Measure::ResultCardinality MismatchCount::do_cardinality() const noexcept +{ + return ResultCardinality::num_samples; +} + 
+const std::string& MismatchCount::do_name() const +{ + return name_; +} + +std::string MismatchCount::do_describe() const +{ + return "Number of mismatches at variant position in reads supporting variant haplotype"; +} + +std::vector MismatchCount::do_requirements() const +{ + return {"Samples", "ReadAssignments"}; +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/measures/realignments.hpp b/src/core/csr/measures/mismatch_count.hpp similarity index 58% rename from src/core/csr/measures/realignments.hpp rename to src/core/csr/measures/mismatch_count.hpp index cf636b588..05bf85b8f 100644 --- a/src/core/csr/measures/realignments.hpp +++ b/src/core/csr/measures/mismatch_count.hpp @@ -1,8 +1,8 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. -#ifndef realignments_hpp -#define realignments_hpp +#ifndef mismatch_count_hpp +#define mismatch_count_hpp #include #include @@ -15,16 +15,18 @@ class VcfRecord; namespace csr { -class Realignments : public Measure +class MismatchCount : public Measure { + const static std::string name_; std::unique_ptr do_clone() const override; ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; - std::string do_name() const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; std::vector do_requirements() const override; - std::string do_serialise(const ResultType& value) const override; }; } // namespace csr } // namespace octopus -#endif \ No newline at end of file +#endif diff --git a/src/core/csr/measures/mismatch_fraction.cpp b/src/core/csr/measures/mismatch_fraction.cpp new file mode 100644 index 000000000..6b9ccfdac --- /dev/null +++ b/src/core/csr/measures/mismatch_fraction.cpp @@ -0,0 +1,61 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of 
this source code is governed by the MIT license that can be found in the LICENSE file. + +#include "mismatch_fraction.hpp" + +#include +#include +#include + +#include + +#include "io/variant/vcf_record.hpp" +#include "utils/concat.hpp" + +namespace octopus { namespace csr { + +const std::string MismatchFraction::name_ = "MF"; + +MismatchFraction::MismatchFraction() +: mismatch_count_ {} +, depth_ {true, false} +{} + +std::unique_ptr MismatchFraction::do_clone() const +{ + return std::make_unique(*this); +} + +Measure::ResultType MismatchFraction::do_evaluate(const VcfRecord& call, const FacetMap& facets) const +{ + const auto depths = boost::get>(depth_.evaluate(call, facets)); + const auto mismatch_counts = boost::get>(mismatch_count_.evaluate(call, facets)); + assert(depths.size() == mismatch_counts.size()); + std::vector result(depths.size()); + std::transform(std::cbegin(mismatch_counts), std::cend(mismatch_counts), std::cbegin(depths), std::begin(result), + [] (auto mismatches, auto depth) { return depth > 0 ? static_cast(mismatches) / depth : 0.0; }); + return result; +} + +Measure::ResultCardinality MismatchFraction::do_cardinality() const noexcept +{ + return ResultCardinality::num_samples; +} + +const std::string& MismatchFraction::do_name() const +{ + return name_; +} + +std::string MismatchFraction::do_describe() const +{ + return "Fraction of reads with mismatches at variant position"; +} + +std::vector MismatchFraction::do_requirements() const +{ + return concat(depth_.requirements(), mismatch_count_.requirements()); +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/measures/mismatch_fraction.hpp b/src/core/csr/measures/mismatch_fraction.hpp new file mode 100644 index 000000000..6bc7a4848 --- /dev/null +++ b/src/core/csr/measures/mismatch_fraction.hpp @@ -0,0 +1,38 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#ifndef mismatch_fraction_hpp +#define mismatch_fraction_hpp + +#include +#include + +#include "measure.hpp" +#include "mismatch_count.hpp" +#include "depth.hpp" + +namespace octopus { + +class VcfRecord; + +namespace csr { + +class MismatchFraction : public Measure +{ + MismatchCount mismatch_count_; + Depth depth_; + const static std::string name_; + std::unique_ptr do_clone() const override; + ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; + std::vector do_requirements() const override; +public: + MismatchFraction(); +}; + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/measures/model_posterior.cpp b/src/core/csr/measures/model_posterior.cpp index f8a550ac6..9337cdbe5 100644 --- a/src/core/csr/measures/model_posterior.cpp +++ b/src/core/csr/measures/model_posterior.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "model_posterior.hpp" @@ -8,6 +8,8 @@ namespace octopus { namespace csr { +const std::string ModelPosterior::name_ = "MP"; + std::unique_ptr ModelPosterior::do_clone() const { return std::make_unique(*this); @@ -23,9 +25,19 @@ Measure::ResultType ModelPosterior::do_evaluate(const VcfRecord& call, const Fac return result; } -std::string ModelPosterior::do_name() const +Measure::ResultCardinality ModelPosterior::do_cardinality() const noexcept +{ + return ResultCardinality::one; +} + +const std::string& ModelPosterior::do_name() const +{ + return name_; +} + +std::string ModelPosterior::do_describe() const { - return "MP"; + return "Model posterior for this haplotype block"; } } // namespace csr diff --git a/src/core/csr/measures/model_posterior.hpp b/src/core/csr/measures/model_posterior.hpp index bc141fb3c..abba68469 100644 --- a/src/core/csr/measures/model_posterior.hpp +++ b/src/core/csr/measures/model_posterior.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef model_posterior_hpp @@ -16,9 +16,12 @@ namespace csr { class ModelPosterior : public Measure { + const static std::string name_; std::unique_ptr do_clone() const override; ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; - std::string do_name() const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; }; } // namespace csr diff --git a/src/core/csr/measures/overlaps_tandem_repeat.cpp b/src/core/csr/measures/overlaps_tandem_repeat.cpp new file mode 100644 index 000000000..614c9c9b5 --- /dev/null +++ b/src/core/csr/measures/overlaps_tandem_repeat.cpp @@ -0,0 +1,49 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "overlaps_tandem_repeat.hpp" + +#include + +#include "io/variant/vcf_record.hpp" +#include "utils/repeat_finder.hpp" +#include "../facets/reference_context.hpp" + +namespace octopus { namespace csr { + +const std::string OverlapsTandemRepeat::name_ = "STRC"; + +std::unique_ptr OverlapsTandemRepeat::do_clone() const +{ + return std::make_unique(*this); +} + +Measure::ResultType OverlapsTandemRepeat::do_evaluate(const VcfRecord& call, const FacetMap& facets) const +{ + const auto& reference = get_value(facets.at("ReferenceContext")); + const auto repeats = find_exact_tandem_repeats(reference.sequence(), reference.mapped_region(), 1, 6); + return has_overlapped(repeats, call); +} + +Measure::ResultCardinality OverlapsTandemRepeat::do_cardinality() const noexcept +{ + return ResultCardinality::one; +} + +const std::string& OverlapsTandemRepeat::do_name() const +{ + return name_; +} + +std::string OverlapsTandemRepeat::do_describe() const +{ + return "Is the variant in a tandem repeat"; +} + +std::vector OverlapsTandemRepeat::do_requirements() const +{ + return {"ReferenceContext"}; +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/measures/overlaps_tandem_repeat.hpp b/src/core/csr/measures/overlaps_tandem_repeat.hpp new file mode 100644 index 000000000..ab8f17dfc --- /dev/null +++ b/src/core/csr/measures/overlaps_tandem_repeat.hpp @@ -0,0 +1,32 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#ifndef overlaps_tandem_repeat_hpp +#define overlaps_tandem_repeat_hpp + +#include +#include + +#include "measure.hpp" + +namespace octopus { + +class VcfRecord; + +namespace csr { + +class OverlapsTandemRepeat : public Measure +{ + const static std::string name_; + std::unique_ptr do_clone() const override; + ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; + std::vector do_requirements() const override; +}; + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/measures/quality.cpp b/src/core/csr/measures/quality.cpp index 73662a1e1..c5f997e8b 100644 --- a/src/core/csr/measures/quality.cpp +++ b/src/core/csr/measures/quality.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "quality.hpp" @@ -7,6 +7,8 @@ namespace octopus { namespace csr { +const std::string Quality::name_ = "QUAL"; + std::unique_ptr Quality::do_clone() const { return std::make_unique(*this); @@ -21,9 +23,19 @@ Measure::ResultType Quality::do_evaluate(const VcfRecord& call, const FacetMap& return result; } -std::string Quality::do_name() const +Measure::ResultCardinality Quality::do_cardinality() const noexcept +{ + return ResultCardinality::one; +} + +const std::string& Quality::do_name() const +{ + return name_; +} + +std::string Quality::do_describe() const { - return "QUAL"; + return "Call QUAL"; } } // namespace csr diff --git a/src/core/csr/measures/quality.hpp b/src/core/csr/measures/quality.hpp index fd7fde0dd..4c076b730 100644 --- a/src/core/csr/measures/quality.hpp +++ b/src/core/csr/measures/quality.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef quality_hpp @@ -16,9 +16,13 @@ namespace csr { class Quality : public Measure { + const static std::string name_; std::unique_ptr do_clone() const override; ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; - std::string do_name() const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; + bool is_required_vcf_field() const noexcept override { return true; } }; } // namespace csr diff --git a/src/core/csr/measures/quality_by_depth.cpp b/src/core/csr/measures/quality_by_depth.cpp index 4bc9e2e3a..e3e9468eb 100644 --- a/src/core/csr/measures/quality_by_depth.cpp +++ b/src/core/csr/measures/quality_by_depth.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "quality_by_depth.hpp" @@ -10,7 +10,9 @@ namespace octopus { namespace csr { -QualityByDepth::QualityByDepth(bool recalculate) : depth_ {recalculate} {} +const std::string QualityByDepth::name_ = "QD"; + +QualityByDepth::QualityByDepth(bool recalculate) : depth_ {recalculate, true} {} std::unique_ptr QualityByDepth::do_clone() const { @@ -27,9 +29,19 @@ Measure::ResultType QualityByDepth::do_evaluate(const VcfRecord& call, const Fac return result; } -std::string QualityByDepth::do_name() const +Measure::ResultCardinality QualityByDepth::do_cardinality() const noexcept +{ + return depth_.cardinality(); +} + +const std::string& QualityByDepth::do_name() const +{ + return name_; +} + +std::string QualityByDepth::do_describe() const { - return "QD"; + return "QUAL divided by DP"; } std::vector QualityByDepth::do_requirements() const @@ -37,5 +49,10 @@ std::vector QualityByDepth::do_requirements() const return depth_.requirements(); } +bool QualityByDepth::is_equal(const Measure& other) const noexcept +{ + return depth_ == static_cast(other); +} + } // namespace csr } // namespace octopus diff --git a/src/core/csr/measures/quality_by_depth.hpp b/src/core/csr/measures/quality_by_depth.hpp index dc710b021..f17c2de05 100644 --- a/src/core/csr/measures/quality_by_depth.hpp +++ b/src/core/csr/measures/quality_by_depth.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef quality_by_depth_hpp @@ -18,11 +18,15 @@ namespace csr { class QualityByDepth : public Measure { + const static std::string name_; Depth depth_; std::unique_ptr do_clone() const override; ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; - std::string do_name() const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; std::vector do_requirements() const override; + bool is_equal(const Measure& other) const noexcept override; public: QualityByDepth(bool recalculate = false); }; diff --git a/src/core/csr/measures/read_position_bias.cpp b/src/core/csr/measures/read_position_bias.cpp new file mode 100644 index 000000000..79e97d320 --- /dev/null +++ b/src/core/csr/measures/read_position_bias.cpp @@ -0,0 +1,151 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "read_position_bias.hpp" + +#include +#include +#include +#include +#include +#include + +#include + +#include "io/variant/vcf_record.hpp" +#include "basics/aligned_read.hpp" +#include "core/types/allele.hpp" +#include "utils/maths.hpp" +#include "utils/genotype_reader.hpp" +#include "../facets/samples.hpp" +#include "../facets/read_assignments.hpp" + +namespace octopus { namespace csr { + +const std::string ReadPositionBias::name_ = "RPB"; + +std::unique_ptr ReadPositionBias::do_clone() const +{ + return std::make_unique(*this); +} + +namespace { + +struct PositionCounts +{ + unsigned head, tail; +}; + +bool overlaps_rhs(const Allele& allele, const AlignedRead& read) +{ + using D = GenomicRegion::Distance; + return overlaps(allele, expand_lhs(mapped_region(read), -static_cast(region_size(read) / 2))); +} + +void update_counts(const Allele& allele, const AlignedRead& read, PositionCounts& counts) +{ + if (overlaps(allele, read)) { + if (region_size(allele) >= region_size(read) / 2) { + ++counts.head; + ++counts.tail; + } else if (overlaps_rhs(allele, read)) { + if (is_forward_strand(read)) { + ++counts.tail; + } else { + ++counts.head; + } + } else { + if (is_forward_strand(read)) { + ++counts.head; + } else { + ++counts.tail; + } + } + } +} + +auto compute_allele_positions(const Allele& allele, const ReadRefSupportSet& support) +{ + PositionCounts forward_counts {}, reverse_counts {}; + for (const auto& read : support) { + if (is_forward_strand(read)) { + update_counts(allele, read, forward_counts); + } else { + update_counts(allele, read, reverse_counts); + } + } + return std::make_pair(forward_counts, reverse_counts); +} + +double calculate_position_bias(const PositionCounts forward_counts, const PositionCounts reverse_counts, + const double tolerance = 0.5) +{ + assert(tolerance > 0.0 && tolerance < 1.0); + const auto num_lhs = forward_counts.head + reverse_counts.tail; + const auto num_rhs = forward_counts.tail + reverse_counts.head; + const auto 
prob_lhs_biased = maths::beta_sf(static_cast(num_lhs + 1), static_cast(num_rhs + 1), 0.5 + tolerance / 2); + const auto prob_rhs_biased = maths::beta_cdf(static_cast(num_lhs + 1), static_cast(num_rhs + 1), 0.5 - tolerance / 2); + return prob_lhs_biased + prob_rhs_biased; +} + +double calculate_position_bias(const Allele& allele, const ReadRefSupportSet& support) +{ + PositionCounts forward_counts, reverse_counts; + std::tie(forward_counts, reverse_counts) = compute_allele_positions(allele, support); + return calculate_position_bias(forward_counts, reverse_counts); +} + +double calculate_position_bias(const AlleleSupportMap& support) +{ + double result {0}; + for (const auto& p : support) { + auto bias = calculate_position_bias(p.first, p.second); + result = std::max(result, bias); + } + return result; +} + +} // namespace + +Measure::ResultType ReadPositionBias::do_evaluate(const VcfRecord& call, const FacetMap& facets) const +{ + const auto& samples = get_value(facets.at("Samples")); + const auto& assignments = get_value(facets.at("ReadAssignments")).support; + std::vector result {}; + result.reserve(samples.size()); + for (const auto& sample : samples) { + std::vector alleles; bool has_ref; + std::tie(alleles, has_ref) = get_called_alleles(call, sample, true); + if (!alleles.empty()) { + const auto allele_support = compute_allele_support(alleles, assignments.at(sample)); + auto position_bias = calculate_position_bias(allele_support); + result.push_back(position_bias); + } else { + result.push_back(0); + } + } + return result; +} + +Measure::ResultCardinality ReadPositionBias::do_cardinality() const noexcept +{ + return ResultCardinality::num_samples; +} + +const std::string& ReadPositionBias::do_name() const +{ + return name_; +} + +std::string ReadPositionBias::do_describe() const +{ + return "Bias of variant position in supporting reads"; +} + +std::vector ReadPositionBias::do_requirements() const +{ + return {"Samples", "ReadAssignments"}; +} + +} // namespace 
csr +} // namespace octopus diff --git a/src/core/csr/measures/read_position_bias.hpp b/src/core/csr/measures/read_position_bias.hpp new file mode 100644 index 000000000..3fb20e27e --- /dev/null +++ b/src/core/csr/measures/read_position_bias.hpp @@ -0,0 +1,32 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#ifndef read_position_bias_hpp +#define read_position_bias_hpp + +#include +#include + +#include "measure.hpp" + +namespace octopus { + +class VcfRecord; + +namespace csr { + +class ReadPositionBias : public Measure +{ + const static std::string name_; + std::unique_ptr do_clone() const override; + ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; + std::vector do_requirements() const override; +}; + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/measures/realignments.cpp b/src/core/csr/measures/realignments.cpp deleted file mode 100644 index 16ceb8c14..000000000 --- a/src/core/csr/measures/realignments.cpp +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright (c) 2017 Daniel Cooke -// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
- -#include "realignments.hpp" - -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "basics/aligned_read.hpp" -#include "core/types/allele.hpp" -#include "core/tools/read_assigner.hpp" -#include "core/tools/read_realigner.hpp" -#include "io/variant/vcf_record.hpp" -#include "io/variant/vcf_spec.hpp" -#include "utils/append.hpp" -#include "utils/genotype_reader.hpp" -#include "../facets/overlapping_reads.hpp" -#include "../facets/read_assignments.hpp" -#include "../facets/reference_context.hpp" - -namespace octopus { namespace csr { - -std::unique_ptr Realignments::do_clone() const -{ - return std::make_unique(*this); -} - -namespace { - -using ReadRealignments = std::vector>; - -bool is_forward(const AlignedRead& read) noexcept -{ - return read.direction() == AlignedRead::Direction::forward; -} - -struct DirectionCounts -{ - unsigned forward, reverse; -}; - -DirectionCounts count_directions(const std::vector& reads) -{ - auto n_fwd = static_cast(std::count_if(std::cbegin(reads), std::cend(reads), - [] (const auto& read) { return is_forward(read); })); - return {n_fwd, static_cast(reads.size()) - n_fwd}; -} - -using MQHistogram = std::array; - -MQHistogram compute_mq_hist(const std::vector& reads) -{ - MQHistogram result {}; - for (const auto& read : reads) { - ++result[std::min(read.mapping_quality(), AlignedRead::MappingQuality {59}) / 10]; - } - return result; -} - -struct Pileup -{ - std::array match_counts, mismatch_counts; - unsigned insertion_count, deletion_count; - unsigned match_quality_sum, mismatch_quality_sum, insertion_quality_sum; -}; - -constexpr std::size_t window_size {61}; - -using PileupWindow = std::array; - -PileupWindow make_pileup(const std::vector& reads, const GenomicRegion& region) -{ - assert(size(region) == window_size); - PileupWindow result {}; - for (const auto& read : reads) { - if (overlaps(read, region)) { - const auto read_fragment = copy(read, region); - const auto expanded_cigar = 
decompose(read_fragment.cigar()); - const auto fragment_offset = static_cast(begin_distance(region, read_fragment)); - unsigned read_pos {0}, pileup_pos {fragment_offset}; - for (const auto flag : expanded_cigar) { - if (pileup_pos >= result.size() || read_pos >= sequence_size(read_fragment)) break; - switch (flag) { - assert(pileup_pos < result.size()); - using Flag = CigarOperation::Flag; - case Flag::sequenceMatch: - { - assert(read_pos < sequence_size(read_fragment)); - switch (read_fragment.sequence()[read_pos]) { - case 'A': ++result[pileup_pos].match_counts[0]; break; - case 'C': ++result[pileup_pos].match_counts[1]; break; - case 'G': ++result[pileup_pos].match_counts[2]; break; - case 'T': ++result[pileup_pos].match_counts[3]; break; - } - result[pileup_pos].match_quality_sum += read_fragment.base_qualities()[read_pos]; - ++pileup_pos; ++read_pos; - break; - } - case Flag::substitution: - { - assert(read_pos < sequence_size(read_fragment)); - switch (read_fragment.sequence()[read_pos]) { - case 'A': ++result[pileup_pos].mismatch_counts[0]; break; - case 'C': ++result[pileup_pos].mismatch_counts[1]; break; - case 'G': ++result[pileup_pos].mismatch_counts[2]; break; - case 'T': ++result[pileup_pos].mismatch_counts[3]; break; - } - result[pileup_pos].mismatch_quality_sum += read_fragment.base_qualities()[read_pos]; - ++pileup_pos; ++read_pos; - break; - } - case Flag::insertion: - { - ++result[pileup_pos].insertion_count; - result[pileup_pos].insertion_quality_sum += read_fragment.base_qualities()[read_pos]; - ++read_pos; - break; - } - case Flag::deletion: - { - ++result[pileup_pos].deletion_count; - ++pileup_pos; - break; - } - default: continue; - } - } - } - } - return result; -} - -struct HaplotypeSummary -{ - DirectionCounts strand_counts; - MQHistogram mq_hist; - PileupWindow pileups; -}; - -using RealignmentSummary = std::vector; - -auto compute_realignment_summary(const ReadRealignments& realignments, const GenomicRegion& region) -{ - 
RealignmentSummary result {}; - result.reserve(realignments.size()); - for (const auto& reads : realignments) { - result.push_back({count_directions(reads), compute_mq_hist(reads), make_pileup(reads, region)}); - } - return result; -} - -bool is_padded(const VcfRecord& call) noexcept -{ - const auto& ref = call.ref(); - if (ref.empty()) return false; - return std::any_of(std::cbegin(call.alt()), std::cend(call.alt()), - [&] (const auto& alt) { return !alt.empty() && alt.front() == ref.front(); }); -} - -auto compute_realignment_summary(const ReadRealignments& realignments, const VcfRecord& call) -{ - auto call_position = head_region(call, 1); - if (is_padded(call)) { - call_position = shift(call_position, 1); - } - return compute_realignment_summary(realignments, expand(head_position(call), window_size / 2)); -} - -} // namespace - -Measure::ResultType Realignments::do_evaluate(const VcfRecord& call, const FacetMap& facets) const -{ - const auto& assignments = get_value(facets.at("ReadAssignments")); - assert(assignments.size() == 1); - const auto& sample = assignments.cbegin()->first; - const auto& support = assignments.cbegin()->second; - std::vector alleles; bool has_ref; - std::tie(alleles, has_ref) = get_called_alleles(call, sample, true); - ReadRealignments result {}; - result.reserve(alleles.size() + 1); - std::vector assigned_reads {}; - for (const auto& allele : alleles) { - std::vector allele_realignments {}; - for (const auto& h : support) { - const auto& haplotype = h.first; - const auto& supporting_reads = h.second; - if (!supporting_reads.empty() && haplotype.contains(allele)) { - auto realignments = safe_realign(supporting_reads, haplotype); - for (std::size_t i {0}; i < realignments.size(); ++i) { - const auto& original_read = supporting_reads[i]; - const auto& realigned_read = realignments[i]; - if (overlaps(realigned_read, allele)) { - assigned_reads.push_back(original_read); - allele_realignments.push_back(realigned_read); - } - } - } - } - 
result.push_back(std::move(allele_realignments)); - } - const auto& reads = get_value(facets.at("OverlappingReads")); - const auto overlapping_reads = overlap_range(reads.at(sample), call); - std::vector unassigned_reads {}; - if (assigned_reads.size() < size(overlapping_reads)) { - std::sort(std::begin(assigned_reads), std::end(assigned_reads)); - unassigned_reads.reserve(size(overlapping_reads) - assigned_reads.size()); - std::set_difference(std::cbegin(overlapping_reads), std::cend(overlapping_reads), - std::cbegin(assigned_reads), std::cend(assigned_reads), - std::back_inserter(unassigned_reads)); - } - if (!unassigned_reads.empty()) { - const auto reference = get_value(facets.at("ReferenceContext")); - auto unassigned_realignments = safe_realign(unassigned_reads, reference); - result.push_back(std::move(unassigned_realignments)); - } else { - result.push_back(std::move(unassigned_reads)); - } - return boost::any {compute_realignment_summary(result, call)}; -} - -std::string Realignments::do_name() const -{ - return "RA"; -} - -std::vector Realignments::do_requirements() const -{ - return {"ReferenceContext", "OverlappingReads", "ReadAssignments"}; -} - -void serialise_reads(const std::vector& reads, std::ostringstream& ss) -{ - ss << "{"; - for (const auto& read : reads) { - ss << "[" << read.name() << "|" << mapped_begin(read) << "|" << read.cigar() << "]"; - } - ss << "}"; -} - -namespace { - -std::ostream& operator<<(std::ostream& os, const DirectionCounts& counts) -{ - os << counts.forward << "," << counts.reverse; - return os; -} - -std::ostream& operator<<(std::ostream& os, const MQHistogram& hist) -{ - std::copy(std::cbegin(hist), std::prev(std::cend(hist)), std::ostream_iterator {os, ","}); - os << hist.back(); - return os; -} - -std::ostream& operator<<(std::ostream& os, const Pileup& pileup) -{ - std::copy(std::cbegin(pileup.match_counts), std::cend(pileup.match_counts), std::ostream_iterator {os, ","}); - 
std::copy(std::cbegin(pileup.mismatch_counts), std::cend(pileup.mismatch_counts), std::ostream_iterator {os, ","}); - os << pileup.insertion_count << "," << pileup.deletion_count << "," - << pileup.match_quality_sum << "," << pileup.mismatch_quality_sum << "," << pileup.insertion_quality_sum; - return os; -} - -std::ostream& operator<<(std::ostream& os, const PileupWindow& window) -{ - std::copy(std::cbegin(window), std::prev(std::cend(window)), std::ostream_iterator {os, ":"}); - os << window.back(); - return os; -} - -std::ostream& operator<<(std::ostream& os, const HaplotypeSummary& summary) -{ - os << "{" << summary.strand_counts << "|" << summary.mq_hist << "|" << summary.pileups << "}"; - return os; -} - -} // namespace - -std::string Realignments::do_serialise(const ResultType& value) const -{ - const auto any = boost::get(value); - const auto& summary = *boost::any_cast(&any); - std::ostringstream ss {}; - std::copy(std::cbegin(summary), std::cend(summary), std::ostream_iterator {ss}); - return ss.str(); -} - -} // namespace csr -} // namespace octopus diff --git a/src/core/csr/measures/somatic_contamination.cpp b/src/core/csr/measures/somatic_contamination.cpp new file mode 100644 index 000000000..c25849e9d --- /dev/null +++ b/src/core/csr/measures/somatic_contamination.cpp @@ -0,0 +1,170 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "somatic_contamination.hpp" + +#include +#include + +#include +#include +#include + +#include "core/types/allele.hpp" +#include "core/types/haplotype.hpp" +#include "core/types/genotype.hpp" +#include "core/tools/read_assigner.hpp" +#include "io/variant/vcf_record.hpp" +#include "utils/genotype_reader.hpp" +#include "utils/append.hpp" +#include "is_somatic.hpp" +#include "../facets/samples.hpp" +#include "../facets/genotypes.hpp" +#include "../facets/read_assignments.hpp" + +namespace octopus { namespace csr { + +const std::string SomaticContamination::name_ = "SC"; + +std::unique_ptr SomaticContamination::do_clone() const +{ + return std::make_unique(*this); +} + +namespace { + +template +void sort_unique(Container& values) +{ + std::sort(std::begin(values), std::end(values)); + values.erase(std::unique(std::begin(values), std::end(values)), std::end(values)); +} + +auto get_somatic_alleles(const VcfRecord& somatic, const std::vector& somatic_samples, + const std::vector& normal_samples) +{ + std::vector somatic_sample_alleles {}, normal_sample_alleles {}; + for (const auto& sample : somatic_samples) { + utils::append(get_called_alleles(somatic, sample, true).first, somatic_sample_alleles); + } + for (const auto& sample : normal_samples) { + utils::append(get_called_alleles(somatic, sample, true).first, normal_sample_alleles); + } + sort_unique(somatic_sample_alleles); sort_unique(normal_sample_alleles); + std::vector result {}; + result.reserve(somatic_sample_alleles.size()); + std::set_difference(std::cbegin(somatic_sample_alleles), std::cend(somatic_sample_alleles), + std::cbegin(normal_sample_alleles), std::cend(normal_sample_alleles), + std::back_inserter(result)); + return result; +} + +auto get_somatic_haplotypes(const Facet::GenotypeMap& genotypes, const std::vector& somatics) +{ + std::vector result {}; + if (!somatics.empty()) { + const auto allele_region = somatics.front().mapped_region(); + for (const auto& p :genotypes) { + const auto& 
overlapped_genotypes = overlap_range(p.second, allele_region); + if (size(overlapped_genotypes) == 1) { + const auto& genotype = overlapped_genotypes.front(); + for (const auto& haplotype : genotype) { + if (std::any_of(std::cbegin(somatics), std::cend(somatics), + [&] (const auto& somatic) { return haplotype.includes(somatic); })) { + result.push_back(haplotype); + } + } + } + } + sort_unique(result); + } + return result; +} + +auto get_somatic_haplotypes(const VcfRecord& somatic, const Facet::GenotypeMap& genotypes, + const std::vector& somatic_samples, const std::vector& normal_samples) +{ + const auto somatic_alleles = get_somatic_alleles(somatic, somatic_samples, normal_samples); + return get_somatic_haplotypes(genotypes, somatic_alleles); +} + +} // namespace + +Measure::ResultType SomaticContamination::do_evaluate(const VcfRecord& call, const FacetMap& facets) const +{ + boost::optional result {}; + if (is_somatic(call)) { + result = 0; + const auto& samples = get_value(facets.at("Samples")); + const auto somatic_status = boost::get>(IsSomatic(true).evaluate(call, facets)); + std::vector somatic_samples {}, normal_samples {}; + somatic_samples.reserve(samples.size()); normal_samples.reserve(samples.size()); + for (auto tup : boost::combine(samples, somatic_status)) { + if (tup.get<1>()) { + somatic_samples.push_back(tup.get<0>()); + } else { + normal_samples.push_back(tup.get<0>()); + } + } + const auto& genotypes = get_value(facets.at("Genotypes")); + const auto somatic_haplotypes = get_somatic_haplotypes(call, genotypes, somatic_samples, normal_samples); + const auto& assignments = get_value(facets.at("ReadAssignments")).support; + Genotype somatic_genotype {static_cast(somatic_haplotypes.size() + 1)}; + HaplotypeProbabilityMap haplotype_priors {}; + haplotype_priors.reserve(somatic_haplotypes.size() + 1); + for (const auto& haplotype : somatic_haplotypes) { + somatic_genotype.emplace(haplotype); + haplotype_priors[haplotype] = -1; + } + for (const auto& 
sample : normal_samples) { + for (const auto& p : assignments.at(sample)) { + const auto overlapped_reads = copy_overlapped(p.second, call); + if (!overlapped_reads.empty()) { + const Haplotype& assigned_haplotype {p.first}; + if (!somatic_genotype.contains(assigned_haplotype)) { + auto dummy = somatic_genotype; + dummy.emplace(assigned_haplotype); + haplotype_priors[assigned_haplotype] = 0; + const auto support = compute_haplotype_support(dummy, overlapped_reads, haplotype_priors); + haplotype_priors.erase(assigned_haplotype); + for (const auto& somatic : somatic_haplotypes) { + if (support.count(somatic) == 1) { + *result += support.at(somatic).size(); + } + } + } else { + // This could happen if we don't call all 'somatic' alleles on the called somatic haplotype. + *result += overlapped_reads.size(); + } + } + } + } + } + return result; +} + +Measure::ResultCardinality SomaticContamination::do_cardinality() const noexcept +{ + return ResultCardinality::one; +} + +const std::string& SomaticContamination::do_name() const +{ + return name_; +} + +std::string SomaticContamination::do_describe() const +{ + return "Number of reads supporting a somatic haplotype in the normal"; +} + +std::vector SomaticContamination::do_requirements() const +{ + std::vector result {"Samples", "Genotypes", "ReadAssignments"}; + utils::append(IsSomatic(true).requirements(), result); + sort_unique(result); + return result; +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/measures/somatic_contamination.hpp b/src/core/csr/measures/somatic_contamination.hpp new file mode 100644 index 000000000..bc7fee8a6 --- /dev/null +++ b/src/core/csr/measures/somatic_contamination.hpp @@ -0,0 +1,33 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#ifndef somatic_contamination_hpp +#define somatic_contamination_hpp + +#include +#include + +#include "measure.hpp" + +namespace octopus { + +class VcfRecord; + +namespace csr { + +class SomaticContamination : public Measure +{ + const static std::string name_; + + std::unique_ptr do_clone() const override; + ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; + std::vector do_requirements() const override; +}; + +} // namespace csr +} // namespace octopus + +#endif diff --git a/src/core/csr/measures/str_length.cpp b/src/core/csr/measures/str_length.cpp new file mode 100644 index 000000000..71637f08c --- /dev/null +++ b/src/core/csr/measures/str_length.cpp @@ -0,0 +1,74 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "str_length.hpp" + +#include +#include + +#include + +#include "io/variant/vcf_record.hpp" +#include "utils/repeat_finder.hpp" +#include "../facets/reference_context.hpp" + +namespace octopus { namespace csr { + +const std::string STRLength::name_ = "STR_LENGTH"; + +std::unique_ptr STRLength::do_clone() const +{ + return std::make_unique(*this); +} + +namespace { + +auto num_periods(const TandemRepeat& repeat) noexcept +{ + return region_size(repeat) / repeat.period; +} + +struct PeriodCountLess +{ + bool operator()(const TandemRepeat& lhs, const TandemRepeat& rhs) const noexcept + { + return num_periods(lhs) < num_periods(rhs); + } +}; + +} // namespace + +Measure::ResultType STRLength::do_evaluate(const VcfRecord& call, const FacetMap& facets) const +{ + int result {0}; + const auto& reference = get_value(facets.at("ReferenceContext")); + const auto repeats = find_exact_tandem_repeats(reference.sequence(), reference.mapped_region(), 1, 6); + const auto overlapping_repeats = overlap_range(repeats, call); + if (!empty(overlapping_repeats)) { + result = num_periods(*std::max_element(std::cbegin(overlapping_repeats), std::cend(overlapping_repeats), PeriodCountLess {})); + } + return result; +} + +Measure::ResultCardinality STRLength::do_cardinality() const noexcept +{ + return ResultCardinality::one; +} + +const std::string& STRLength::do_name() const +{ + return name_; +} + +std::string STRLength::do_describe() const +{ + return "Length of overlapping STR"; +} + +std::vector STRLength::do_requirements() const +{ + return {"ReferenceContext"}; +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/measures/unassigned_read_fraction.hpp b/src/core/csr/measures/str_length.hpp similarity index 60% rename from src/core/csr/measures/unassigned_read_fraction.hpp rename to src/core/csr/measures/str_length.hpp index eac70e4bf..d78f6db42 100644 --- a/src/core/csr/measures/unassigned_read_fraction.hpp +++ b/src/core/csr/measures/str_length.hpp @@ 
-1,8 +1,8 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. -#ifndef unassigned_read_fraction_hpp -#define unassigned_read_fraction_hpp +#ifndef str_length_hpp +#define str_length_hpp #include #include @@ -15,11 +15,14 @@ class VcfRecord; namespace csr { -class UnassignedReadFraction : public Measure +class STRLength : public Measure { + const static std::string name_; std::unique_ptr do_clone() const override; ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; - std::string do_name() const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; std::vector do_requirements() const override; }; diff --git a/src/core/csr/measures/str_period.cpp b/src/core/csr/measures/str_period.cpp new file mode 100644 index 000000000..c9edd02f8 --- /dev/null +++ b/src/core/csr/measures/str_period.cpp @@ -0,0 +1,69 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "str_period.hpp" + +#include +#include + +#include + +#include "io/variant/vcf_record.hpp" +#include "utils/repeat_finder.hpp" +#include "../facets/reference_context.hpp" + +namespace octopus { namespace csr { + +const std::string STRPeriod::name_ = "STR_PERIOD"; + +std::unique_ptr STRPeriod::do_clone() const +{ + return std::make_unique(*this); +} + +namespace { + +struct PeriodLess +{ + bool operator()(const TandemRepeat& lhs, const TandemRepeat& rhs) const noexcept + { + return lhs.period < rhs.period; + } +}; + +} // namespace + +Measure::ResultType STRPeriod::do_evaluate(const VcfRecord& call, const FacetMap& facets) const +{ + int result {0}; + const auto& reference = get_value(facets.at("ReferenceContext")); + const auto repeats = find_exact_tandem_repeats(reference.sequence(), reference.mapped_region(), 1, 6); + const auto overlapping_repeats = overlap_range(repeats, call); + if (!empty(overlapping_repeats)) { + result = std::max_element(std::cbegin(overlapping_repeats), std::cend(overlapping_repeats), PeriodLess {})->period; + } + return result; +} + +Measure::ResultCardinality STRPeriod::do_cardinality() const noexcept +{ + return ResultCardinality::one; +} + +const std::string& STRPeriod::do_name() const +{ + return name_; +} + +std::string STRPeriod::do_describe() const +{ + return "Length of overlapping STR"; +} + +std::vector STRPeriod::do_requirements() const +{ + return {"ReferenceContext"}; +} + +} // namespace csr +} // namespace octopus diff --git a/src/core/csr/measures/max_genotype_quality.hpp b/src/core/csr/measures/str_period.hpp similarity index 60% rename from src/core/csr/measures/max_genotype_quality.hpp rename to src/core/csr/measures/str_period.hpp index e68ee7144..875f61599 100644 --- a/src/core/csr/measures/max_genotype_quality.hpp +++ b/src/core/csr/measures/str_period.hpp @@ -1,8 +1,8 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license 
that can be found in the LICENSE file. -#ifndef max_genotype_quality_hpp -#define max_genotype_quality_hpp +#ifndef str_period_hpp +#define str_period_hpp #include #include @@ -15,11 +15,14 @@ class VcfRecord; namespace csr { -class MaxGenotypeQuality : public Measure +class STRPeriod : public Measure { + const static std::string name_; std::unique_ptr do_clone() const override; ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; - std::string do_name() const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; std::vector do_requirements() const override; }; diff --git a/src/core/csr/measures/strand_bias.cpp b/src/core/csr/measures/strand_bias.cpp index 507c5eba5..2fcc81032 100644 --- a/src/core/csr/measures/strand_bias.cpp +++ b/src/core/csr/measures/strand_bias.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "strand_bias.hpp" @@ -15,13 +15,18 @@ #include #include "io/variant/vcf_record.hpp" +#include "io/variant/vcf_spec.hpp" #include "basics/aligned_read.hpp" #include "utils/maths.hpp" #include "utils/beta_distribution.hpp" +#include "utils/genotype_reader.hpp" +#include "../facets/samples.hpp" #include "../facets/read_assignments.hpp" namespace octopus { namespace csr { +const std::string StrandBias::name_ = "SB"; + StrandBias::StrandBias(const double critical_value) : min_medium_trigger_ {critical_value / 2} , min_big_trigger_ {critical_value / 8} @@ -35,9 +40,25 @@ std::unique_ptr StrandBias::do_clone() const return std::make_unique(*this); } -bool is_forward(const AlignedRead& read) noexcept +namespace { + +bool is_canonical(const VcfRecord::NucleotideSequence& allele) noexcept +{ + const static VcfRecord::NucleotideSequence deleted_allele {vcfspec::deletedBase}; + return !(allele == vcfspec::missingValue || allele == deleted_allele); +} + +bool has_called_alt_allele(const VcfRecord& call, const VcfRecord::SampleName& sample) { - return read.direction() == AlignedRead::Direction::forward; + if (!call.has_genotypes()) return true; + const auto& genotype = get_genotype(call, sample); + return std::any_of(std::cbegin(genotype), std::cend(genotype), + [&] (const auto& allele) { return allele != call.ref() && is_canonical(allele); }); +} + +bool is_evaluable(const VcfRecord& call, const VcfRecord::SampleName& sample) +{ + return has_called_alt_allele(call, sample) && call.is_heterozygous(sample); } struct DirectionCounts @@ -49,11 +70,13 @@ template DirectionCounts count_directions(const Container& reads, const GenomicRegion& call_region) { unsigned n_forward {0}, n_reverse {0}; - for (const auto& read : overlap_range(reads, call_region)) { - if (is_forward(read)) { - ++n_forward; - } else { - ++n_reverse; + for (const auto& read : reads) { + if (overlaps(read.get(), call_region)) { + if (is_forward_strand(read)) { + ++n_forward; + } else { + ++n_reverse; + } 
} } return {n_forward, n_reverse}; @@ -61,7 +84,7 @@ DirectionCounts count_directions(const Container& reads, const GenomicRegion& ca using DirectionCountVector = std::vector; -auto get_direction_counts(const HaplotypeSupportMap& support, const GenomicRegion& call_region, const unsigned prior = 1) +auto get_direction_counts(const AlleleSupportMap& support, const GenomicRegion& call_region, const unsigned prior = 1) { DirectionCountVector result {}; result.reserve(support.size()); @@ -117,17 +140,22 @@ double calculate_max_prob_different(const DirectionCountVector& direction_counts return result; } +} // namespace + Measure::ResultType StrandBias::do_evaluate(const VcfRecord& call, const FacetMap& facets) const { - const auto& assignments = get_value(facets.at("ReadAssignments")); - // TODO: What we should really do here is calculate which reads directly support each allele in the - // genotype by looking if each supporting read overlaps the allele given the realignment to the called haplotype. - // The current approach of just removing non-overlapping reads may not work optimally in complex indel regions. 
- boost::optional result {}; - for (const auto& p : assignments) { - if (call.is_heterozygous(p.first)) { - const auto& supporting_reads = p.second; - const auto direction_counts = get_direction_counts(supporting_reads, mapped_region(call)); + const auto& samples = get_value(facets.at("Samples")); + const auto& assignments = get_value(facets.at("ReadAssignments")).support; + std::vector> result {}; + result.reserve(samples.size()); + for (const auto& sample : samples) { + boost::optional sample_result {}; + if (is_evaluable(call, sample)) { + std::vector alleles; bool has_ref; + std::tie(alleles, has_ref) = get_called_alleles(call, sample, true); + assert(!alleles.empty()); + const auto sample_allele_support = compute_allele_support(alleles, assignments.at(sample)); + const auto direction_counts = get_direction_counts(sample_allele_support, mapped_region(call)); double prob; if (use_resampling_) { prob = calculate_max_prob_different(direction_counts, small_sample_size_, min_difference_); @@ -145,24 +173,36 @@ Measure::ResultType StrandBias::do_evaluate(const VcfRecord& call, const FacetMa } else { prob = calculate_max_prob_different(direction_counts, big_sample_size_, min_difference_); } - if (result) { - result = std::max(*result, prob); - } else { - result = prob; - } + sample_result = prob; } + result.push_back(sample_result); } return result; } -std::string StrandBias::do_name() const +Measure::ResultCardinality StrandBias::do_cardinality() const noexcept +{ + return ResultCardinality::num_samples; +} + +const std::string& StrandBias::do_name() const { - return "SB"; + return name_; +} + +std::string StrandBias::do_describe() const +{ + return "Strand bias of reads based on haplotype support"; } std::vector StrandBias::do_requirements() const { - return {"ReadAssignments"}; + return {"Samples", "ReadAssignments"}; +} + +bool StrandBias::is_equal(const Measure& other) const noexcept +{ + return min_medium_trigger_ == static_cast(other).min_medium_trigger_; } } 
// namespace csr diff --git a/src/core/csr/measures/strand_bias.hpp b/src/core/csr/measures/strand_bias.hpp index 2574f33e6..76989ff6c 100644 --- a/src/core/csr/measures/strand_bias.hpp +++ b/src/core/csr/measures/strand_bias.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef strand_bias_hpp @@ -17,10 +17,14 @@ namespace csr { class StrandBias : public Measure { + const static std::string name_; std::unique_ptr do_clone() const override; ResultType do_evaluate(const VcfRecord& call, const FacetMap& facets) const override; - std::string do_name() const override; + ResultCardinality do_cardinality() const noexcept override; + const std::string& do_name() const override; + std::string do_describe() const override; std::vector do_requirements() const override; + bool is_equal(const Measure& other) const noexcept override; double min_difference_ = 0.25; std::size_t small_sample_size_ = 200, medium_sample_size_ = 1'000, diff --git a/src/core/csr/measures/unassigned_read_fraction.cpp b/src/core/csr/measures/unassigned_read_fraction.cpp deleted file mode 100644 index 959c0b820..000000000 --- a/src/core/csr/measures/unassigned_read_fraction.cpp +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) 2017 Daniel Cooke -// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
- -#include "unassigned_read_fraction.hpp" - -#include -#include - -#include - -#include "core/tools/read_assigner.hpp" -#include "core/types/allele.hpp" -#include "io/variant/vcf_record.hpp" -#include "io/variant/vcf_spec.hpp" -#include "utils/genotype_reader.hpp" -#include "../facets/samples.hpp" -#include "../facets/overlapping_reads.hpp" -#include "../facets/read_assignments.hpp" - -namespace octopus { namespace csr { - -std::unique_ptr UnassignedReadFraction::do_clone() const -{ - return std::make_unique(*this); -} - -Measure::ResultType UnassignedReadFraction::do_evaluate(const VcfRecord& call, const FacetMap& facets) const -{ - const auto& samples = get_value(facets.at("Samples")); - const auto& reads = get_value(facets.at("OverlappingReads")); - const auto& assignments = get_value(facets.at("ReadAssignments")); - boost::optional result {}; - for (const auto& sample : samples) { - const auto num_overlapping_reads = count_overlapped(reads.at(sample), call); - if (num_overlapping_reads > 0) { - std::vector alleles; bool has_ref; - std::tie(alleles, has_ref) = get_called_alleles(call, sample); - std::vector assigned_reads {}; - const auto& support = assignments.at(sample); - for (const auto& allele : alleles) { - std::vector allele_realignments {}; - for (const auto& h : support) { - const auto& haplotype = h.first; - const auto& supporting_reads = h.second; - if (!supporting_reads.empty() && haplotype.includes(allele)) { - for (const auto& read : supporting_reads) { - if (overlaps(read, allele)) { - assigned_reads.push_back(read); - } - } - } - } - } - const auto assigned_fraction = static_cast(assigned_reads.size()) / num_overlapping_reads; - if (result) { - result = std::max(*result, 1 - assigned_fraction); - } else { - result = 1 - assigned_fraction; - } - } - } - return result; -} - -std::string UnassignedReadFraction::do_name() const -{ - return "URF"; -} - -std::vector UnassignedReadFraction::do_requirements() const -{ - return {"Samples", 
"OverlappingReads", "ReadAssignments"}; -} - -} // namespace csr -} // namespace octopus diff --git a/src/core/models/error/error_model_factory.cpp b/src/core/models/error/error_model_factory.cpp index 03c2c8c18..783392547 100644 --- a/src/core/models/error/error_model_factory.cpp +++ b/src/core/models/error/error_model_factory.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "error_model_factory.hpp" diff --git a/src/core/models/error/error_model_factory.hpp b/src/core/models/error/error_model_factory.hpp index def0cb182..92adf8090 100644 --- a/src/core/models/error/error_model_factory.hpp +++ b/src/core/models/error/error_model_factory.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef error_model_factory_hpp diff --git a/src/core/models/error/hiseq_indel_error_model.cpp b/src/core/models/error/hiseq_indel_error_model.cpp index db6ecb122..126ed9d2b 100644 --- a/src/core/models/error/hiseq_indel_error_model.cpp +++ b/src/core/models/error/hiseq_indel_error_model.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "hiseq_indel_error_model.hpp" @@ -97,10 +97,9 @@ HiSeqIndelErrorModel::do_evaluate(const Haplotype& haplotype, PenaltyVector& gap } } switch (max_repeat.period) { - case 1: return 3; - case 2: return 5; - case 3: return 5; - default: return defaultGapExtension_; + case 2: + case 3: return 1; + default: return defaultGapExtension_; } } diff --git a/src/core/models/error/hiseq_indel_error_model.hpp b/src/core/models/error/hiseq_indel_error_model.hpp index 415ae8e90..bd391d484 100644 --- a/src/core/models/error/hiseq_indel_error_model.hpp +++ b/src/core/models/error/hiseq_indel_error_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef hiseq_indel_error_model_hpp @@ -31,12 +31,12 @@ class HiSeqIndelErrorModel : public IndelErrorModel }}; static constexpr std::array diNucleotideTandemRepeatErrors_ = {{ - 60,60,48,45,43,41,39,35,31,28,25,21,19,17,15,13,12,11,11,10, + 60,60,50,46,42,37,31,27,23,22,21,20,19,17,15,13,12,11,11,10, 9,9,8,8,7,7,7,6,6,6,5,5,5,4,4,4,3,3,3,3,2,2,2,2,2,1,1,1,1,1 }}; static constexpr std::array triNucleotideTandemRepeatErrors_ = {{ - 60,60,50,48,46,45,42,39,35,31,28,25,22,20,16,14,13,12,12,11, + 60,60,50,48,46,44,40,38,35,31,28,25,22,20,16,14,13,12,12,11, 10,9,8,8,7,7,7,6,6,6,5,5,5,4,4,4,3,3,3,3,2,2,2,2,2,1,1,1,1,1 }}; diff --git a/src/core/models/error/hiseq_snv_error_model.cpp b/src/core/models/error/hiseq_snv_error_model.cpp index a2d343e32..31fad34fe 100644 --- a/src/core/models/error/hiseq_snv_error_model.cpp +++ b/src/core/models/error/hiseq_snv_error_model.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "hiseq_snv_error_model.hpp" @@ -114,7 +114,7 @@ auto make_substitution_mask(const Haplotype& haplotype) std::vector result(sequence_size(haplotype)); auto mask_itr = std::begin(result); for (const auto& op : cigar) { - if (op.advances_sequence()) { + if (advances_sequence(op)) { mask_itr = std::fill_n(mask_itr, op.size(), op.flag() == CigarOperation::Flag::substitution); } } diff --git a/src/core/models/error/hiseq_snv_error_model.hpp b/src/core/models/error/hiseq_snv_error_model.hpp index 6705985fa..879f458c6 100644 --- a/src/core/models/error/hiseq_snv_error_model.hpp +++ b/src/core/models/error/hiseq_snv_error_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef hiseq_snv_error_model_hpp diff --git a/src/core/models/error/indel_error_model.cpp b/src/core/models/error/indel_error_model.cpp index 0f0553073..c00526c4c 100644 --- a/src/core/models/error/indel_error_model.cpp +++ b/src/core/models/error/indel_error_model.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "indel_error_model.hpp" diff --git a/src/core/models/error/indel_error_model.hpp b/src/core/models/error/indel_error_model.hpp index 17ade91eb..3f7dc9f06 100644 --- a/src/core/models/error/indel_error_model.hpp +++ b/src/core/models/error/indel_error_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef indel_error_model_hpp diff --git a/src/core/models/error/snv_error_model.cpp b/src/core/models/error/snv_error_model.cpp index 0d9daf4de..6c89d9e5c 100644 --- a/src/core/models/error/snv_error_model.cpp +++ b/src/core/models/error/snv_error_model.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "snv_error_model.hpp" diff --git a/src/core/models/error/snv_error_model.hpp b/src/core/models/error/snv_error_model.hpp index 195b66120..1578c9a12 100644 --- a/src/core/models/error/snv_error_model.hpp +++ b/src/core/models/error/snv_error_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef snv_error_model_hpp diff --git a/src/core/models/error/x10_indel_error_model.cpp b/src/core/models/error/x10_indel_error_model.cpp index 7cd2f98d0..c74195560 100644 --- a/src/core/models/error/x10_indel_error_model.cpp +++ b/src/core/models/error/x10_indel_error_model.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "x10_indel_error_model.hpp" @@ -97,10 +97,9 @@ X10IndelErrorModel::do_evaluate(const Haplotype& haplotype, PenaltyVector& gap_o } } switch (max_repeat.period) { - case 1: return 2; - case 2: return 4; - case 3: return 4; - default: return defaultGapExtension_; + case 2: + case 3: return 1; + default: return defaultGapExtension_; } } diff --git a/src/core/models/error/x10_indel_error_model.hpp b/src/core/models/error/x10_indel_error_model.hpp index 7e3970652..6f8d8acb0 100644 --- a/src/core/models/error/x10_indel_error_model.hpp +++ b/src/core/models/error/x10_indel_error_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef x10_indel_error_model_hpp @@ -26,26 +26,27 @@ class X10IndelErrorModel : public IndelErrorModel private: static constexpr std::array homopolymerErrors_ = {{ - 60,59,48,43,38,32,28,23,20,18,16,15,14,13,12,11,10,10,9, - 9,8,8,8,7,7,7,6,6,6,5,5,5,4,4,4,3,3,3,3,2,2,2,2,2,1,1,1,1,1,1 + 60,60,49,44,40,35,28,24,21,19,18,16,15,15,14,13,12,11,11,10, + 9,9,8,8,7,7,7,6,6,6,5,5,5,4,4,4,3,3,3,3,2,2,2,2,2,1,1,1,1,1 }}; static constexpr std::array diNucleotideTandemRepeatErrors_ = {{ - 60,58,47,42,37,31,27,22,19,18,16,15,14,13,12,11,10,10,10, - 9,9,8,8,7,7,7,6,6,6,5,5,5,4,4,4,3,3,3,3,2,2,2,2,2,1,1,1,1,1,1 + 60,59,49,45,41,36,30,26,22,21,20,19,18,17,15,13,12,11,11,10, + 9,9,8,8,7,7,7,6,6,6,5,5,5,4,4,4,3,3,3,3,2,2,2,2,2,1,1,1,1,1 }}; static constexpr std::array triNucleotideTandemRepeatErrors_ = {{ - 60,57,46,41,36,30,28,23,20,19,17,16,15,14,13,12,11,11,10, - 9,9,8,8,7,7,7,6,6,6,5,5,5,4,4,4,3,3,3,3,2,2,2,2,2,1,1,1,1,1,1 + 60,59,49,47,45,43,39,37,34,30,27,24,21,18,16,14,13,12,12,11, + 10,9,8,8,7,7,7,6,6,6,5,5,5,4,4,4,3,3,3,3,2,2,2,2,2,1,1,1,1,1 }}; + static constexpr std::array polyNucleotideTandemRepeatErrors_ = {{ - 60,60,51,45,45,45,45,45,23,20,19,17,16,15,14,13,12,11,11,10, + 
60,60,50,44,44,44,44,44,22,19,18,16,16,15,14,13,12,11,11,10, 9,9,8,8,7,7,7,6,6,6,5,5,5,4,4,4,3,3,3,3,2,2,2,2,2,1,1,1,1,1 }}; - static constexpr PenaltyType defaultGapExtension_ = 2; + static constexpr PenaltyType defaultGapExtension_ = 3; virtual std::unique_ptr do_clone() const override; virtual PenaltyType do_evaluate(const Haplotype& haplotype, PenaltyVector& gap_open_penalties) const override; diff --git a/src/core/models/error/x10_snv_error_model.cpp b/src/core/models/error/x10_snv_error_model.cpp index 3eea656ec..11150fd14 100644 --- a/src/core/models/error/x10_snv_error_model.cpp +++ b/src/core/models/error/x10_snv_error_model.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "x10_snv_error_model.hpp" @@ -114,7 +114,7 @@ auto make_substitution_mask(const Haplotype& haplotype) std::vector result(sequence_size(haplotype)); auto mask_itr = std::begin(result); for (const auto& op : cigar) { - if (op.advances_sequence()) { + if (advances_sequence(op)) { mask_itr = std::fill_n(mask_itr, op.size(), op.flag() == CigarOperation::Flag::substitution); } } diff --git a/src/core/models/error/x10_snv_error_model.hpp b/src/core/models/error/x10_snv_error_model.hpp index 2d6c5809d..b24617ce9 100644 --- a/src/core/models/error/x10_snv_error_model.hpp +++ b/src/core/models/error/x10_snv_error_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef x10_snv_error_model_hpp diff --git a/src/core/models/genotype/cancer_genotype_prior_model.cpp b/src/core/models/genotype/cancer_genotype_prior_model.cpp index 07344636f..1276bcca3 100644 --- a/src/core/models/genotype/cancer_genotype_prior_model.cpp +++ b/src/core/models/genotype/cancer_genotype_prior_model.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "cancer_genotype_prior_model.hpp" @@ -34,16 +34,24 @@ const SomaticMutationModel& CancerGenotypePriorModel::mutation_model() const noe double CancerGenotypePriorModel::evaluate(const CancerGenotype& genotype) const { - const auto& germline = genotype.germline_genotype(); - const auto& somatic = genotype.somatic_element(); - const auto germline_log_prior = germline_model_.get().evaluate(germline); - return germline_log_prior + ln_probability_of_somatic_given_genotype(somatic, germline); + const auto& germline_genotype = genotype.germline(); + auto result = germline_model_.get().evaluate(germline_genotype); + // Model assumes independence between somatic haplotypes given germline genotype + for (const auto& somatic_haplotype : genotype.somatic()) { + result += ln_probability_of_somatic_given_genotype(somatic_haplotype, germline_genotype); + } + return result; } -double CancerGenotypePriorModel::evaluate(const std::vector& germline_indices, const unsigned somatic_index) const +double CancerGenotypePriorModel::evaluate(const CancerGenotypeIndex& genotype) const { - return germline_model_.get().evaluate(germline_indices) - + ln_probability_of_somatic_given_genotype(somatic_index, germline_indices); + const auto germline_indices = genotype.germline; + auto result = germline_model_.get().evaluate(germline_indices); + // Model assumes independence between somatic haplotypes given germline genotype + for (const auto& somatic_index : genotype.somatic) { + result += 
ln_probability_of_somatic_given_genotype(somatic_index, germline_indices); + } + return result; } double CancerGenotypePriorModel::ln_probability_of_somatic_given_haplotype(const Haplotype& somatic, const Haplotype& germline) const diff --git a/src/core/models/genotype/cancer_genotype_prior_model.hpp b/src/core/models/genotype/cancer_genotype_prior_model.hpp index 7ceeab7fb..752c6f58b 100644 --- a/src/core/models/genotype/cancer_genotype_prior_model.hpp +++ b/src/core/models/genotype/cancer_genotype_prior_model.hpp @@ -1,10 +1,9 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef cancer_genotype_prior_model_hpp #define cancer_genotype_prior_model_hpp -#include #include #include #include @@ -40,7 +39,7 @@ class CancerGenotypePriorModel const SomaticMutationModel& mutation_model() const noexcept; double evaluate(const CancerGenotype& genotype) const; - double evaluate(const std::vector& germline_indices, unsigned somatic_index) const; + double evaluate(const CancerGenotypeIndex& genotype) const; private: std::reference_wrapper germline_model_; @@ -60,9 +59,9 @@ inline auto get_ploidy(const Genotype& genotype) noexcept return genotype.ploidy(); } -inline auto get_ploidy(const std::vector& genotype_indices) noexcept +inline auto get_ploidy(const GenotypeIndex& genotype) noexcept { - return static_cast(genotype_indices.size()); + return static_cast(genotype.size()); } } // namespace detail @@ -105,14 +104,13 @@ double CancerGenotypePriorModel::ln_probability_of_somatic_given_genotype(const // non-member methods template -auto calculate_log_priors(const Container& genotypes, const CancerGenotypePriorModel& model) +auto calculate_log_priors(const Container& genotypes, const CancerGenotypePriorModel& model, + const bool normalise = false) { - static_assert(std::is_same>::value, - "genotypes must contain CancerGenotype's"); std::vector 
result(genotypes.size()); std::transform(std::cbegin(genotypes), std::cend(genotypes), std::begin(result), [&model] (const auto& genotype) { return model.evaluate(genotype); }); - maths::normalise_logs(result); + if (normalise) maths::normalise_logs(result); return result; } diff --git a/src/core/models/genotype/coalescent_genotype_prior_model.hpp b/src/core/models/genotype/coalescent_genotype_prior_model.hpp index 0b2175023..8a4050fd6 100644 --- a/src/core/models/genotype/coalescent_genotype_prior_model.hpp +++ b/src/core/models/genotype/coalescent_genotype_prior_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef coalescent_genotype_prior_model_hpp @@ -32,7 +32,7 @@ class CoalescentGenotypePriorModel : public GenotypePriorModel { return model_.evaluate(genotype); } - double do_evaluate(const std::vector& genotype) const override + double do_evaluate(const GenotypeIndex& genotype) const override { return model_.evaluate(genotype); } diff --git a/src/core/models/genotype/coalescent_population_prior_model.cpp b/src/core/models/genotype/coalescent_population_prior_model.cpp index d6dc861ad..7eeed3f76 100644 --- a/src/core/models/genotype/coalescent_population_prior_model.cpp +++ b/src/core/models/genotype/coalescent_population_prior_model.cpp @@ -1,10 +1,20 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "coalescent_population_prior_model.hpp" namespace octopus { +CoalescentPopulationPriorModel::CoalescentPopulationPriorModel(CoalescentModel segregation_model) +: segregation_model_ {std::move(segregation_model)} +, genotype_model_ {} +{} + +CoalescentPopulationPriorModel::CoalescentPopulationPriorModel(CoalescentModel segregation_model, HardyWeinbergModel genotype_model) +: segregation_model_ {std::move(segregation_model)} +, genotype_model_ {std::move(genotype_model)} +{} + template auto sum_sizes(const std::vector>& values) noexcept { @@ -19,10 +29,10 @@ auto sum_sizes(const std::vector>>& [] (auto curr, const auto& v) noexcept { return curr + v.get().size(); }); } -double CoalescentPopulationPriorModel::do_evaluate(const std::vector>& indices) const +double CoalescentPopulationPriorModel::evaluate_segregation_model(const std::vector& indices) const { if (indices.size() == 1) { - return model_.evaluate(indices.front()); + return segregation_model_.evaluate(indices.front()); } const auto num_indices = sum_sizes(indices); index_buffer_.resize(num_indices); @@ -30,13 +40,13 @@ double CoalescentPopulationPriorModel::do_evaluate(const std::vector& indices) const +double CoalescentPopulationPriorModel::evaluate_segregation_model(const std::vector& indices) const { if (indices.size() == 1) { - return model_.evaluate(indices.front().get()); + return segregation_model_.evaluate(indices.front().get()); } const auto num_indices = sum_sizes(indices); index_buffer_.resize(num_indices); @@ -44,7 +54,7 @@ double CoalescentPopulationPriorModel::do_evaluate(const std::vector #include "population_prior_model.hpp" +#include "hardy_weinberg_model.hpp" #include "../mutation/coalescent_model.hpp" -#include "timers.hpp" - namespace octopus { class CoalescentPopulationPriorModel : public PopulationPriorModel @@ -26,7 +25,8 @@ class CoalescentPopulationPriorModel : public PopulationPriorModel CoalescentPopulationPriorModel() = delete; - 
CoalescentPopulationPriorModel(CoalescentModel model) : model_ {std::move(model)} {} + CoalescentPopulationPriorModel(CoalescentModel segregation_model); + CoalescentPopulationPriorModel(CoalescentModel segregation_model, HardyWeinbergModel genotype_model); CoalescentPopulationPriorModel(const CoalescentPopulationPriorModel&) = default; CoalescentPopulationPriorModel& operator=(const CoalescentPopulationPriorModel&) = default; @@ -38,37 +38,59 @@ class CoalescentPopulationPriorModel : public PopulationPriorModel private: using HaplotypeReference = std::reference_wrapper; - CoalescentModel model_; + CoalescentModel segregation_model_; + HardyWeinbergModel genotype_model_; + mutable std::vector index_buffer_; - + double do_evaluate(const std::vector>& genotypes) const override { - return do_evaluate_helper(genotypes); + return evaluate_helper(genotypes); } double do_evaluate(const std::vector& genotypes) const override { - return do_evaluate_helper(genotypes); + return evaluate_helper(genotypes); + } + double do_evaluate(const std::vector& indices) const override + { + return evaluate_helper(indices); + } + double do_evaluate(const std::vector& indices) const override + { + return evaluate_helper(indices); } - double do_evaluate(const std::vector>& indices) const override; - double do_evaluate(const std::vector& indices) const override; void do_prime(const std::vector& haplotypes) override { - model_.prime(haplotypes); + segregation_model_.prime(haplotypes); } void do_unprime() noexcept override { - model_.unprime(); + segregation_model_.unprime(); } bool check_is_primed() const noexcept override { - return model_.is_primed(); + return segregation_model_.is_primed(); } - template - double do_evaluate_helper(const Container& genotypes) const; - + template + double evaluate_helper(const Range& genotypes) const; + template + double evaluate_segregation_model(const Range& genotypes) const; + double evaluate_segregation_model(const std::vector>& indices) const; + double 
evaluate_segregation_model(const std::vector& indices) const; }; +template +double CoalescentPopulationPriorModel::evaluate_helper(const Range& genotypes) const +{ + // p({g_1, ..., g_n}) = p(g_1 u ... u g_n) p({g_1, ..., g_n} | g_1 u ... u g_n) + // => ln p({g_1, ..., g_n}) = ln p(g_1 u ... u g_n) + ln p({g_1, ..., g_n} | g_1 u ... u g_n) + // i.e The prior probability of observing a particular combination of genotypes is the + // probability the haplotypes defined by the set of genotypes segregate, times the probability + // of the particular genotypes given the haplotypes segregate. + return evaluate_segregation_model(genotypes) + genotype_model_.evaluate(genotypes); +} + namespace detail { template @@ -104,10 +126,10 @@ inline auto ploidy(const Genotype& genotype) noexcept } // namespace detail -template -double CoalescentPopulationPriorModel::do_evaluate_helper(const Container& genotypes) const +template +double CoalescentPopulationPriorModel::evaluate_segregation_model(const Range& genotypes) const { - if (genotypes.size() == 1) return model_.evaluate(detail::get(genotypes.front())); + if (genotypes.size() == 1) return segregation_model_.evaluate(detail::get(genotypes.front())); if (genotypes.size() == 2) { const auto ploidy1 = detail::ploidy(genotypes[0]); const auto ploidy2 = detail::ploidy(genotypes[1]); @@ -115,12 +137,12 @@ double CoalescentPopulationPriorModel::do_evaluate_helper(const Container& genot if (ploidy1 == 1) { using detail::get; const std::array haplotypes {get(genotypes[0], 0), get(genotypes[0], 0)}; - return model_.evaluate(haplotypes); + return segregation_model_.evaluate(haplotypes); } else if (ploidy1 == 2) { using detail::get; const std::array haplotypes {get(genotypes[0], 0), get(genotypes[0], 1), get(genotypes[1], 0), get(genotypes[1], 1)}; - return model_.evaluate(haplotypes); + return segregation_model_.evaluate(haplotypes); } } } @@ -129,7 +151,7 @@ double CoalescentPopulationPriorModel::do_evaluate_helper(const Container& genot 
for (const auto& genotype : genotypes) { detail::append(genotype, haplotypes); } - return model_.evaluate(haplotypes); + return segregation_model_.evaluate(haplotypes); } } // namespace octopus diff --git a/src/core/models/genotype/genotype_prior_model.hpp b/src/core/models/genotype/genotype_prior_model.hpp index edea63659..f50eb6842 100644 --- a/src/core/models/genotype/genotype_prior_model.hpp +++ b/src/core/models/genotype/genotype_prior_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef genotype_prior_model_hpp @@ -26,11 +26,11 @@ class GenotypePriorModel bool is_primed() const noexcept { return check_is_primed(); } double evaluate(const Genotype& genotype) const { return do_evaluate(genotype); } - double evaluate(const std::vector& genotype_indices) const { return do_evaluate(genotype_indices); } + double evaluate(const GenotypeIndex& genotype) const { return do_evaluate(genotype); } private: virtual double do_evaluate(const Genotype& genotype) const = 0; - virtual double do_evaluate(const std::vector& genotype) const = 0; + virtual double do_evaluate(const GenotypeIndex& genotype) const = 0; virtual void do_prime(const std::vector& haplotypes) {}; virtual void do_unprime() noexcept {}; virtual bool check_is_primed() const noexcept = 0; diff --git a/src/core/models/genotype/germline_likelihood_model.cpp b/src/core/models/genotype/germline_likelihood_model.cpp index 96c2fbda6..77ec3acde 100644 --- a/src/core/models/genotype/germline_likelihood_model.cpp +++ b/src/core/models/genotype/germline_likelihood_model.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "germline_likelihood_model.hpp" diff --git a/src/core/models/genotype/germline_likelihood_model.hpp b/src/core/models/genotype/germline_likelihood_model.hpp index e0773b110..fba814de9 100644 --- a/src/core/models/genotype/germline_likelihood_model.hpp +++ b/src/core/models/genotype/germline_likelihood_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef germline_likelihood_model_hpp diff --git a/src/core/models/genotype/hardy_weinberg_model.cpp b/src/core/models/genotype/hardy_weinberg_model.cpp new file mode 100644 index 000000000..1f5da2822 --- /dev/null +++ b/src/core/models/genotype/hardy_weinberg_model.cpp @@ -0,0 +1,336 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#include "hardy_weinberg_model.hpp" + +#include +#include +#include +#include +#include + +#include "utils/maths.hpp" + +#include + +namespace octopus { + +HardyWeinbergModel::HardyWeinbergModel(Haplotype reference) +: reference_ {std::move(reference)} +, reference_idx_ {} +, haplotype_frequencies_ {} +, haplotype_idx_frequencies_ {} +, empirical_ {false} +{} + +HardyWeinbergModel::HardyWeinbergModel(unsigned reference_idx) +: reference_ {} +, reference_idx_ {reference_idx} +, haplotype_frequencies_ {} +, haplotype_idx_frequencies_ {} +, empirical_ {false} +{} + +HardyWeinbergModel::HardyWeinbergModel(HaplotypeFrequencyMap haplotype_frequencies) +: reference_ {} +, reference_idx_ {} +, haplotype_frequencies_ {std::move(haplotype_frequencies)} +, haplotype_idx_frequencies_ {} +, empirical_ {true} +{} + +HardyWeinbergModel::HardyWeinbergModel(HaplotypeFrequencyVector haplotype_frequencies) +: reference_ {} +, reference_idx_ {} +, haplotype_frequencies_ {} +, haplotype_idx_frequencies_ {std::move(haplotype_frequencies)} +, empirical_ {true} 
+{} + +void HardyWeinbergModel::set_frequencies(HaplotypeFrequencyMap haplotype_frequencies) +{ + haplotype_frequencies_ = std::move(haplotype_frequencies); + empirical_ = true; +} + +void HardyWeinbergModel::set_frequencies(HaplotypeFrequencyVector haplotype_frequencies) +{ + haplotype_idx_frequencies_ = std::move(haplotype_frequencies); + empirical_ = true; +} + +HardyWeinbergModel::HaplotypeFrequencyMap& HardyWeinbergModel::frequencies() noexcept +{ + return haplotype_frequencies_; +} + +HardyWeinbergModel::HaplotypeFrequencyVector& HardyWeinbergModel::index_frequencies() noexcept +{ + return haplotype_idx_frequencies_; +} + +namespace { + +auto ln_hardy_weinberg_haploid(const Genotype& genotype, + const HardyWeinbergModel::HaplotypeFrequencyMap& haplotype_frequencies) +{ + return std::log(haplotype_frequencies.at(genotype[0])); +} + +auto ln_hardy_weinberg_diploid(const Genotype& genotype, + const HardyWeinbergModel::HaplotypeFrequencyMap& haplotype_frequencies) +{ + if (genotype.is_homozygous()) { + return 2 * std::log(haplotype_frequencies.at(genotype[0])); + } + static const double ln2 {std::log(2.0)}; + return std::log(haplotype_frequencies.at(genotype[0])) + std::log(haplotype_frequencies.at(genotype[1])) + ln2; +} + +auto ln_hardy_weinberg_polyploid(const Genotype& genotype, + const HardyWeinbergModel::HaplotypeFrequencyMap& haplotype_frequencies) +{ + auto unique_haplotypes = genotype.copy_unique(); + std::vector occurences {}; + occurences.reserve(unique_haplotypes.size()); + double r {0}; + for (const auto& haplotype : unique_haplotypes) { + auto num_occurences = genotype.count(haplotype); + occurences.push_back(num_occurences); + r += num_occurences * std::log(haplotype_frequencies.at(haplotype)); + } + return maths::log_multinomial_coefficient(occurences) + r; +} + +auto ln_hardy_weinberg_haploid(const GenotypeIndex& genotype, + const HardyWeinbergModel::HaplotypeFrequencyVector& haplotype_frequencies) +{ + return 
std::log(haplotype_frequencies[genotype[0]]); +} + +auto ln_hardy_weinberg_diploid(const GenotypeIndex& genotype, + const HardyWeinbergModel::HaplotypeFrequencyVector& haplotype_frequencies) +{ + if (genotype[0] == genotype[1]) { + return 2 * std::log(haplotype_frequencies[genotype[0]]); + } + static const double ln2 {std::log(2.0)}; + return std::log(haplotype_frequencies[genotype[0]]) + std::log(haplotype_frequencies[genotype[1]]) + ln2; +} + +auto ln_hardy_weinberg_polyploid(const GenotypeIndex& genotype, + const HardyWeinbergModel::HaplotypeFrequencyVector& haplotype_frequencies) +{ + std::vector counts(haplotype_frequencies.size()); + for (auto idx : genotype) ++counts[idx]; + return maths::log_multinomial_pdf<>(counts, haplotype_frequencies); +} + +template +auto sum(const Range& values) noexcept +{ + using T = typename Range::value_type; + return std::accumulate(std::cbegin(values), std::cend(values), T {0}); +} + +std::vector to_frequencies(const std::vector& counts) +{ + std::vector result(counts.size()); + const auto norm = static_cast(sum(counts)); + std::transform(std::cbegin(counts), std::cend(counts), std::begin(result), + [norm] (auto count) noexcept { return static_cast(count) / norm; }); + return result; +} + +} // namespace + +double HardyWeinbergModel::evaluate(const Genotype& genotype) const +{ + if (empirical_) { + switch (genotype.ploidy()) { + case 1 : return ln_hardy_weinberg_haploid(genotype, haplotype_frequencies_); + case 2 : return ln_hardy_weinberg_diploid(genotype, haplotype_frequencies_); + default: return ln_hardy_weinberg_polyploid(genotype, haplotype_frequencies_); + } + } else { + static const double ln2 {std::log(2.0)}, ln3 {std::log(3.0)}; + if (is_haploid(genotype)) { + return reference_ && genotype.contains(*reference_) ? -ln2 : 0.0; + } + if (is_diploid(genotype)) { + if (reference_ && genotype.contains(*reference_)) { + return genotype.is_homozygous() ? -ln2 : -ln3; + } else { + return genotype.is_homozygous() ? 
-ln2 : 0.0; + } + } + auto counts = genotype.unique_counts(); + if (reference_ && !genotype.contains(*reference_)) { + counts.push_back(1); + } + auto probs = to_frequencies(counts); + return maths::log_multinomial_pdf(counts, probs); + } +} + +namespace { + +template +void unique_counts(const Range& range, std::vector& result) +{ + for (auto itr = std::cbegin(range), last = std::cend(range); itr != last;) { + auto next = std::find_if_not(std::next(itr), last, [itr] (const auto& x) { return x == *itr; }); + result.push_back(std::distance(itr, next)); + itr = next; + } +} + +} // namespace + +double HardyWeinbergModel::evaluate(const GenotypeIndex& genotype) const +{ + assert(!genotype.empty()); + if (empirical_) { + switch (genotype.size()) { + case 1 : return ln_hardy_weinberg_haploid(genotype, haplotype_idx_frequencies_); + case 2 : return ln_hardy_weinberg_diploid(genotype, haplotype_idx_frequencies_); + default: return ln_hardy_weinberg_polyploid(genotype, haplotype_idx_frequencies_); + } + } else { + static const double ln2 {std::log(2.0)}, ln3 {std::log(3.0)}; + if (genotype.size() == 1) { + return reference_idx_ && genotype[0] == *reference_idx_ ? -ln2 : 0.0; + } + if (genotype.size() == 2) { + if (reference_idx_ && !(genotype[0] == *reference_idx_ || genotype[1] == *reference_idx_)) { + return genotype[0] == genotype[1] ? -ln2 : -ln3; + } else { + return genotype[0] == genotype[1] ? 
-ln2 : 0.0; + } + } + std::vector counts {}; + counts.reserve(genotype.size()); + if (std::is_sorted(std::cbegin(genotype), std::cend(genotype))) { + unique_counts(genotype, counts); + } else { + auto sorted_genotype = genotype; + std::sort(std::begin(sorted_genotype), std::end(sorted_genotype)); + unique_counts(sorted_genotype, counts); + } + if (reference_idx_ && std::find(std::cbegin(genotype), std::cend(genotype), *reference_idx_) == std::cend(genotype)) { + counts.push_back(1); + } + auto probs = to_frequencies(counts); + return maths::log_multinomial_pdf(counts, probs); + } +} + +namespace { + +template const T& get(const T& value) noexcept { return value; } +template const T& get(std::reference_wrapper value) noexcept { return value.get(); } + +template +auto sum_plodies(const Range& genotypes) noexcept +{ + return std::accumulate(std::cbegin(genotypes), std::cend(genotypes), 0u, + [] (auto curr, const auto& g) { return curr + get(g).ploidy(); }); +} + +template +void fill_frequencies(const Range& genotypes, HardyWeinbergModel::HaplotypeFrequencyMap& result) +{ + const auto n = sum_plodies(genotypes); + const auto weight = 1.0 / n; + for (const auto& genotype : genotypes) { + for (const auto& haplotype : get(genotype)) { + result[haplotype] += weight; + } + } +} + +template +void fill_frequencies(const Range& genotypes, HardyWeinbergModel::HaplotypeFrequencyVector& result) +{ + unsigned max_haplotype_idx {0}, n {0}; + for (const auto& genotype : genotypes) { + for (auto haplotype_idx : get(genotype)) { + max_haplotype_idx = std::max(max_haplotype_idx, haplotype_idx); + ++n; + } + } + result.resize(max_haplotype_idx + 1); + const auto weight = 1.0 / n; + for (const auto& genotype : genotypes) { + for (auto haplotype_idx : get(genotype)) { + result[haplotype_idx] += weight; + } + } +} + +template +double joint_evaluate(const Range& genotypes, const HardyWeinbergModel& model) +{ + return std::accumulate(std::cbegin(genotypes), std::cend(genotypes), 0.0, + 
[&model] (auto curr, const auto& genotype) { return curr + model.evaluate(get(genotype)); }); +} + +} // namespace + +double HardyWeinbergModel::evaluate(const std::vector>& genotypes) const +{ + if (empirical_) { + return joint_evaluate(genotypes, *this); + } else { + fill_frequencies(genotypes, haplotype_frequencies_); + empirical_ = true; + auto result = joint_evaluate(genotypes, *this); + haplotype_frequencies_.clear(); + empirical_ = false; + return result; + } +} + +double HardyWeinbergModel::evaluate(const GenotypeReferenceVector& genotypes) const +{ + if (empirical_) { + return joint_evaluate(genotypes, *this); + } else { + fill_frequencies(genotypes, haplotype_frequencies_); + empirical_ = true; + auto result = joint_evaluate(genotypes, *this); + haplotype_frequencies_.clear(); + empirical_ = false; + return result; + } +} + +double HardyWeinbergModel::evaluate(const GenotypeIndexVector& genotypes) const +{ + if (empirical_) { + return joint_evaluate(genotypes, *this); + } else { + fill_frequencies(genotypes, haplotype_idx_frequencies_); + empirical_ = true; + auto result = joint_evaluate(genotypes, *this); + haplotype_idx_frequencies_.clear(); + empirical_ = false; + return result; + } +} + +double HardyWeinbergModel::evaluate(const GenotypeIndexReferenceVector& genotypes) const +{ + if (empirical_) { + return joint_evaluate(genotypes, *this); + } else { + fill_frequencies(genotypes, haplotype_idx_frequencies_); + empirical_ = true; + auto result = joint_evaluate(genotypes, *this); + haplotype_idx_frequencies_.clear(); + empirical_ = false; + return result; + } +} + +} // namespace diff --git a/src/core/models/genotype/hardy_weinberg_model.hpp b/src/core/models/genotype/hardy_weinberg_model.hpp new file mode 100644 index 000000000..a99ac8187 --- /dev/null +++ b/src/core/models/genotype/hardy_weinberg_model.hpp @@ -0,0 +1,69 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the 
LICENSE file. + +#ifndef hardy_weinberg_model_hpp +#define hardy_weinberg_model_hpp + +#include +#include +#include + +#include + +#include "core/types/haplotype.hpp" +#include "core/types/genotype.hpp" + +namespace octopus { + +class HardyWeinbergModel +{ +public: + using GenotypeReference = std::reference_wrapper>; + using GenotypeReferenceVector = std::vector; + using GenotypeIndexReference = std::reference_wrapper; + using GenotypeIndexVector = std::vector; + using GenotypeIndexReferenceVector = std::vector; + + using HaplotypeFrequencyMap = std::unordered_map; + using HaplotypeFrequencyVector = std::vector; + + HardyWeinbergModel() = default; + + HardyWeinbergModel(Haplotype reference); + HardyWeinbergModel(unsigned reference_idx); + HardyWeinbergModel(HaplotypeFrequencyMap haplotype_frequencies); + HardyWeinbergModel(HaplotypeFrequencyVector haplotype_frequencies); + + HardyWeinbergModel(const HardyWeinbergModel&) = default; + HardyWeinbergModel& operator=(const HardyWeinbergModel&) = default; + HardyWeinbergModel(HardyWeinbergModel&&) = default; + HardyWeinbergModel& operator=(HardyWeinbergModel&&) = default; + + ~HardyWeinbergModel() = default; + + void set_frequencies(HaplotypeFrequencyMap haplotype_frequencies); + void set_frequencies(HaplotypeFrequencyVector haplotype_frequencies); + + HaplotypeFrequencyMap& frequencies() noexcept; + HaplotypeFrequencyVector& index_frequencies() noexcept; + + double evaluate(const Genotype& genotype) const; + double evaluate(const GenotypeIndex& genotype) const; + + double evaluate(const std::vector>& genotypes) const; + double evaluate(const GenotypeReferenceVector& genotypes) const; + double evaluate(const GenotypeIndexVector& genotypes) const; + double evaluate(const GenotypeIndexReferenceVector& genotypes) const; + +private: + boost::optional reference_; + boost::optional reference_idx_; + mutable HaplotypeFrequencyMap haplotype_frequencies_; + mutable HaplotypeFrequencyVector haplotype_idx_frequencies_; + mutable 
bool empirical_; +}; + +} // namespace octopus + + +#endif diff --git a/src/core/models/genotype/independent_population_model.cpp b/src/core/models/genotype/independent_population_model.cpp index e3e46a0e9..343faf6f8 100644 --- a/src/core/models/genotype/independent_population_model.cpp +++ b/src/core/models/genotype/independent_population_model.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "independent_population_model.hpp" diff --git a/src/core/models/genotype/independent_population_model.hpp b/src/core/models/genotype/independent_population_model.hpp index 308bbe33e..163098593 100644 --- a/src/core/models/genotype/independent_population_model.hpp +++ b/src/core/models/genotype/independent_population_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef independent_population_model_hpp diff --git a/src/core/models/genotype/individual_model.cpp b/src/core/models/genotype/individual_model.cpp index 0e08d21ab..54cad0044 100644 --- a/src/core/models/genotype/individual_model.cpp +++ b/src/core/models/genotype/individual_model.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "individual_model.hpp" @@ -93,7 +93,7 @@ IndividualModel::evaluate(const std::vector>& genotypes, IndividualModel::InferredLatents IndividualModel::evaluate(const std::vector>& genotypes, - const std::vector>& genotype_indices, + const std::vector& genotype_indices, const HaplotypeLikelihoodCache& haplotype_likelihoods) const { assert(!genotypes.empty()); diff --git a/src/core/models/genotype/individual_model.hpp b/src/core/models/genotype/individual_model.hpp index cc3f85e11..eeb8cc52a 100644 --- a/src/core/models/genotype/individual_model.hpp +++ b/src/core/models/genotype/individual_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef individual_model_hpp @@ -50,7 +50,7 @@ class IndividualModel const HaplotypeLikelihoodCache& haplotype_likelihoods) const; InferredLatents evaluate(const std::vector>& genotypes, - const std::vector>& genotype_indices, + const std::vector& genotype_indices, const HaplotypeLikelihoodCache& haplotype_likelihoods) const; private: diff --git a/src/core/models/genotype/population_model.cpp b/src/core/models/genotype/population_model.cpp index 9314b8947..1c8cf8cb9 100644 --- a/src/core/models/genotype/population_model.cpp +++ b/src/core/models/genotype/population_model.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "population_model.hpp" @@ -7,11 +7,13 @@ #include #include #include +#include #include -#include #include "utils/maths.hpp" +#include "utils/select_top_k.hpp" #include "germline_likelihood_model.hpp" +#include "hardy_weinberg_model.hpp" namespace octopus { namespace model { @@ -60,10 +62,7 @@ auto make_inverse_genotype_table(const std::vector& haplotypes, const auto cardinality = element_cardinality_in_genotypes(static_cast(haplotypes.size()), genotypes.front().ploidy()); for (const auto& haplotype : haplotypes) { - auto itr = result_map.emplace(std::piecewise_construct, - std::forward_as_tuple(std::cref(haplotype)), - std::forward_as_tuple()); - itr.first->second.reserve(cardinality); + result_map[std::cref(haplotype)].reserve(cardinality); } for (std::size_t i {0}; i < genotypes.size(); ++i) { for (const auto& haplotype : genotypes[i]) { @@ -78,15 +77,28 @@ auto make_inverse_genotype_table(const std::vector& haplotypes, return result; } -using HaplotypeReference = std::reference_wrapper; -using HaplotypeFrequencyMap = std::unordered_map; +auto make_inverse_genotype_table(const std::vector& genotype_indices, const std::size_t num_haplotypes) +{ + InverseGenotypeTable result(num_haplotypes); + const auto num_genotypes = genotype_indices.size(); + for (auto& entry : result) entry.reserve(num_genotypes / 2); + for (std::size_t genotype_idx {0}; genotype_idx < num_genotypes; ++genotype_idx) { + for (auto idx : genotype_indices[genotype_idx]) { + if (result[idx].empty() || result[idx].back() != genotype_idx) { + result[idx].push_back(genotype_idx); + } + } + } + for (auto& entry : result) entry.shrink_to_fit(); + return result; +} double calculate_frequency_update_norm(const std::size_t num_samples, const unsigned ploidy) { return static_cast(num_samples) * ploidy; } -struct EmOptions +struct EMOptions { unsigned max_iterations; double epsilon; @@ -111,79 +123,26 @@ struct ModelConstants , frequency_update_norm 
{calculate_frequency_update_norm(genotype_log_likilhoods.size(), ploidy)} , genotypes_containing_haplotypes {make_inverse_genotype_table(haplotypes, genotypes)} {} + ModelConstants(const std::vector& haplotypes, + const std::vector>& genotypes, + const std::vector& genotype_indices, + const GenotypeLogLikelihoodMatrix& genotype_log_likilhoods) + : haplotypes {haplotypes} + , genotypes {genotypes} + , genotype_log_likilhoods {genotype_log_likilhoods} + , ploidy {genotypes.front().ploidy()} + , frequency_update_norm {calculate_frequency_update_norm(genotype_log_likilhoods.size(), ploidy)} + , genotypes_containing_haplotypes {make_inverse_genotype_table(genotype_indices, haplotypes.size())} + {} }; -HaplotypeFrequencyMap -init_haplotype_frequencies(const ModelConstants& constants) +HardyWeinbergModel make_hardy_weinberg_model(const ModelConstants& constants) { - HaplotypeFrequencyMap result {constants.haplotypes.size()}; + HardyWeinbergModel::HaplotypeFrequencyMap frequencies {constants.haplotypes.size()}; for (const auto& haplotype : constants.haplotypes) { - result.emplace(haplotype, 1.0 / constants.haplotypes.size()); - } - return result; -} - -template -double log_hardy_weinberg_haploid(const Genotype& genotype, - const Map& haplotype_frequencies) -{ - return std::log(haplotype_frequencies.at(genotype[0])); -} - -template -double log_hardy_weinberg_diploid(const Genotype& genotype, - const Map& haplotype_frequencies) -{ - if (genotype.is_homozygous()) { - return 2 * std::log(haplotype_frequencies.at(genotype[0])); - } - static const double ln2 {std::log(2.0)}; - return std::log(haplotype_frequencies.at(genotype[0])) + std::log(haplotype_frequencies.at(genotype[1])) + ln2; -} - -template -double log_hardy_weinberg_triploid(const Genotype& genotype, - const Map& haplotype_frequencies) -{ - // TODO: optimise this case - auto unique_haplotypes = genotype.copy_unique(); - std::vector occurences {}; - occurences.reserve(unique_haplotypes.size()); - double r {0}; - for 
(const auto& haplotype : unique_haplotypes) { - auto num_occurences = genotype.count(haplotype); - occurences.push_back(num_occurences); - r += num_occurences * std::log(haplotype_frequencies.at(haplotype)); - } - return maths::log_multinomial_coefficient(occurences) + r; -} - -template -double log_hardy_weinberg_polyploid(const Genotype& genotype, - const Map& haplotype_frequencies) -{ - auto unique_haplotypes = genotype.copy_unique(); - std::vector occurences {}; - occurences.reserve(unique_haplotypes.size()); - double r {0}; - for (const auto& haplotype : unique_haplotypes) { - auto num_occurences = genotype.count(haplotype); - occurences.push_back(num_occurences); - r += num_occurences * std::log(haplotype_frequencies.at(haplotype)); - } - return maths::log_multinomial_coefficient(occurences) + r; -} - -// TODO: improve this, possible bottleneck in EM update at the moment -template -double log_hardy_weinberg(const Genotype& genotype, const Map& haplotype_frequencies) -{ - switch (genotype.ploidy()) { - case 1 : return log_hardy_weinberg_haploid(genotype, haplotype_frequencies); - case 2 : return log_hardy_weinberg_diploid(genotype, haplotype_frequencies); - case 3 : return log_hardy_weinberg_triploid(genotype, haplotype_frequencies); - default: return log_hardy_weinberg_polyploid(genotype, haplotype_frequencies); + frequencies.emplace(haplotype, 1.0 / constants.haplotypes.size()); } + return HardyWeinbergModel {std::move(frequencies)}; } GenotypeLogLikelihoodMatrix @@ -210,23 +169,21 @@ compute_genotype_log_likelihoods(const std::vector& samples, GenotypeLogMarginalVector init_genotype_log_marginals(const std::vector>& genotypes, - const HaplotypeFrequencyMap& haplotype_frequencies) + const HardyWeinbergModel& hw_model) { GenotypeLogMarginalVector result {}; result.reserve(genotypes.size()); for (const auto& genotype : genotypes) { - result.push_back({genotype, log_hardy_weinberg(genotype, haplotype_frequencies)}); + result.push_back({genotype, 
hw_model.evaluate(genotype)}); } return result; } void update_genotype_log_marginals(GenotypeLogMarginalVector& current_log_marginals, - const HaplotypeFrequencyMap& haplotype_frequencies) + const HardyWeinbergModel& hw_model) { std::for_each(std::begin(current_log_marginals), std::end(current_log_marginals), - [&haplotype_frequencies] (auto& p) { - p.log_probability = log_hardy_weinberg(p.genotype, haplotype_frequencies); - }); + [&hw_model] (auto& p) { p.log_probability = hw_model.evaluate(p.genotype); }); } GenotypeMarginalPosteriorMatrix @@ -238,8 +195,7 @@ init_genotype_posteriors(const GenotypeLogMarginalVector& genotype_log_marginals for (const auto& sample_genotype_log_likilhoods : genotype_log_likilhoods) { GenotypeMarginalPosteriorVector posteriors(genotype_log_marginals.size()); std::transform(std::cbegin(genotype_log_marginals), std::cend(genotype_log_marginals), - std::cbegin(sample_genotype_log_likilhoods), - std::begin(posteriors), + std::cbegin(sample_genotype_log_likilhoods), std::begin(posteriors), [] (const auto& genotype_log_marginal, const auto genotype_log_likilhood) { return genotype_log_marginal.log_probability + genotype_log_likilhood; }); @@ -253,13 +209,12 @@ void update_genotype_posteriors(GenotypeMarginalPosteriorMatrix& current_genotyp const GenotypeLogMarginalVector& genotype_log_marginals, const GenotypeLogLikelihoodMatrix& genotype_log_likilhoods) { - auto it = std::cbegin(genotype_log_likilhoods); + auto likelihood_itr = std::cbegin(genotype_log_likilhoods); for (auto& sample_genotype_posteriors : current_genotype_posteriors) { std::transform(std::cbegin(genotype_log_marginals), std::cend(genotype_log_marginals), - std::cbegin(*it++), - std::begin(sample_genotype_posteriors), - [] (const auto& log_marginal, const auto& log_likilhood) { - return log_marginal.log_probability + log_likilhood; + std::cbegin(*likelihood_itr++), std::begin(sample_genotype_posteriors), + [] (const auto& log_marginal, const auto& log_likeilhood) { + return 
log_marginal.log_probability + log_likeilhood; }); maths::normalise_exp(sample_genotype_posteriors); } @@ -277,13 +232,14 @@ auto collapse_genotype_posteriors(const GenotypeMarginalPosteriorMatrix& genotyp } double update_haplotype_frequencies(const std::vector& haplotypes, - HaplotypeFrequencyMap& current_haplotype_frequencies, + HardyWeinbergModel& hw_model, const GenotypeMarginalPosteriorMatrix& genotype_posteriors, const InverseGenotypeTable& genotypes_containing_haplotypes, const double frequency_update_norm) { const auto collaped_posteriors = collapse_genotype_posteriors(genotype_posteriors); double max_frequency_change {0}; + auto& current_haplotype_frequencies = hw_model.frequencies(); for (std::size_t i {0}; i < haplotypes.size(); ++i) { auto& current_frequency = current_haplotype_frequencies.at(haplotypes[i]); double new_frequency {0}; @@ -301,56 +257,86 @@ double update_haplotype_frequencies(const std::vector& haplotypes, } double do_em_iteration(GenotypeMarginalPosteriorMatrix& genotype_posteriors, - HaplotypeFrequencyMap& haplotype_frequencies, + HardyWeinbergModel& hw_model, GenotypeLogMarginalVector& genotype_log_marginals, const ModelConstants& constants) { const auto max_change = update_haplotype_frequencies(constants.haplotypes, - haplotype_frequencies, + hw_model, genotype_posteriors, constants.genotypes_containing_haplotypes, constants.frequency_update_norm); - update_genotype_log_marginals(genotype_log_marginals, haplotype_frequencies); - update_genotype_posteriors(genotype_posteriors, genotype_log_marginals, - constants.genotype_log_likilhoods); + update_genotype_log_marginals(genotype_log_marginals, hw_model); + update_genotype_posteriors(genotype_posteriors, genotype_log_marginals, constants.genotype_log_likilhoods); return max_change; } void run_em(GenotypeMarginalPosteriorMatrix& genotype_posteriors, - HaplotypeFrequencyMap& haplotype_frequencies, + HardyWeinbergModel& hw_model, GenotypeLogMarginalVector& genotype_log_marginals, - const 
ModelConstants& constants, const EmOptions options, + const ModelConstants& constants, const EMOptions options, boost::optional trace_log = boost::none) { for (unsigned n {1}; n <= options.max_iterations; ++n) { - const auto max_change = do_em_iteration(genotype_posteriors, haplotype_frequencies, - genotype_log_marginals,constants); + const auto max_change = do_em_iteration(genotype_posteriors, hw_model, genotype_log_marginals,constants); if (max_change <= options.epsilon) break; } } -auto compute_approx_genotype_marginal_posteriors(const std::vector>& genotypes, +auto compute_approx_genotype_marginal_posteriors(const std::vector& haplotypes, + const std::vector>& genotypes, const GenotypeLogLikelihoodMatrix& genotype_likelihoods, - const EmOptions options) + const EMOptions options) { - const auto haplotypes = extract_unique_elements(genotypes); const ModelConstants constants {haplotypes, genotypes, genotype_likelihoods}; - auto haplotype_frequencies = init_haplotype_frequencies(constants); - auto genotype_log_marginals = init_genotype_log_marginals(genotypes, haplotype_frequencies); + auto hw_model = make_hardy_weinberg_model(constants); + auto genotype_log_marginals = init_genotype_log_marginals(genotypes, hw_model); + auto result = init_genotype_posteriors(genotype_log_marginals, genotype_likelihoods); + run_em(result, hw_model, genotype_log_marginals, constants, options); + return result; +} + +auto compute_approx_genotype_marginal_posteriors(const std::vector& haplotypes, + const std::vector>& genotypes, + const std::vector& genotype_indices, + const GenotypeLogLikelihoodMatrix& genotype_likelihoods, + const EMOptions options) +{ + const ModelConstants constants {haplotypes, genotypes, genotype_indices, genotype_likelihoods}; + auto hw_model = make_hardy_weinberg_model(constants); + auto genotype_log_marginals = init_genotype_log_marginals(genotypes, hw_model); auto result = init_genotype_posteriors(genotype_log_marginals, genotype_likelihoods); - 
run_em(result, haplotype_frequencies, genotype_log_marginals, constants, options); + run_em(result, hw_model, genotype_log_marginals, constants, options); return result; } +auto compute_approx_genotype_marginal_posteriors(const std::vector>& genotypes, + const GenotypeLogLikelihoodMatrix& genotype_likelihoods, + const EMOptions options) +{ + const auto haplotypes = extract_unique_elements(genotypes); + return compute_approx_genotype_marginal_posteriors(haplotypes, genotypes, genotype_likelihoods, options); +} + using GenotypeCombinationVector = std::vector; using GenotypeCombinationMatrix = std::vector; +auto log(std::size_t base, std::size_t x) +{ + return std::log2(x) / std::log2(base); +} + auto num_combinations(const std::size_t num_genotypes, const std::size_t num_samples) { - return std::pow(num_genotypes, num_samples); + static constexpr auto max_combinations = std::numeric_limits::max(); + if (num_samples <= log(num_genotypes, max_combinations)) { + return static_cast(std::pow(num_genotypes, num_samples)); + } else { + return max_combinations; + } } -auto get_all_genotype_combinations(const std::size_t num_genotypes, const std::size_t num_samples) +auto generate_all_genotype_combinations(const std::size_t num_genotypes, const std::size_t num_samples) { GenotypeCombinationMatrix result {}; result.reserve(num_combinations(num_genotypes, num_samples)); @@ -374,128 +360,101 @@ auto get_all_genotype_combinations(const std::size_t num_genotypes, const std::s return result; } -template -auto index(const std::vector& values) +bool is_homozygous_reference(const Genotype& genotype) { - std::vector> result(values.size()); - for (std::size_t i {0}; i < values.size(); ++i) { - result[i] = std::make_pair(values[i], i); - } - return result; + assert(genotype.ploidy() > 0); + return genotype.is_homozygous() && is_reference(genotype[0]); } -auto index_and_sort(const GenotypeMarginalPosteriorVector& genotype_posteriors, const std::size_t k) +boost::optional 
find_hom_ref_idx(const std::vector>& genotypes) { - auto result = index(genotype_posteriors); - const auto middle = std::next(std::begin(result), std::min(k, result.size())); - std::partial_sort(std::begin(result), middle, std::end(result), std::greater<> {}); - std::for_each(std::begin(result), middle, [] (auto& p) { p.first = std::log(p.first); }); - result.erase(middle, std::end(result)); - return result; -} - -using IndexedProbability = std::pair; -using IndexedProbabilityVector = std::vector; - -struct CombinationProbabilityRow -{ - std::vector combination; - double log_probability; -}; - -using CombinationProbabilityMatrix = std::vector; - -auto get_differences(const IndexedProbabilityVector& genotype_posteriors) -{ - std::vector result(genotype_posteriors.size()); - std::transform(std::cbegin(genotype_posteriors), std::cend(genotype_posteriors), std::begin(result), - [] (const auto& p) { return p.first; }); - std::adjacent_difference(std::cbegin(result), std::cend(result), std::begin(result)); - return result; + auto itr = std::find_if(std::cbegin(genotypes), std::cend(genotypes), + [] (const auto& g) { return is_homozygous_reference(g); }); + if (itr != std::cend(genotypes)) { + return std::distance(std::cbegin(genotypes), itr); + } else { + return boost::none; + } } -auto get_differences(const CombinationProbabilityMatrix& matrix) +template +auto zip_index(const std::vector& v) { - std::vector result(matrix.size()); - std::transform(std::cbegin(matrix), std::cend(matrix), std::begin(result), - [] (const auto& p) { return p.log_probability; }); - std::adjacent_difference(std::cbegin(result), std::cend(result), std::begin(result)); + std::vector> result(v.size()); + for (unsigned idx {0}; idx < v.size(); ++idx) { + result[idx] = std::make_pair(v[idx], idx); + } return result; } -void join(const IndexedProbabilityVector& genotype_posteriors, - CombinationProbabilityMatrix& result, - const std::size_t k) +std::vector +select_top_k_genotypes(const std::vector>& 
genotypes, + const GenotypeMarginalPosteriorMatrix& em_genotype_marginals, + const std::size_t k) { - const auto n = std::min(k, genotype_posteriors.size()); - if (result.empty()) { - std::transform(std::cbegin(genotype_posteriors), std::next(std::cbegin(genotype_posteriors), n), - std::back_inserter(result), [=] (const auto& p) -> CombinationProbabilityRow { - return {{p.second}, p.first}; - }); + if (genotypes.size() <= k) { + std::vector result(genotypes.size()); + std::iota(std::begin(result), std::end(result), 0); + return result; } else { - const auto m = result.size(); -// const auto differences1 = get_differences(result); -// const auto differences2 = get_differences(genotype_posteriors); - const auto K = std::min(k, n * m); - CombinationProbabilityMatrix tmp {}; - tmp.reserve(n * m); - for (std::size_t i {0}; i < m; ++i) { - for (std::size_t j {0}; j < n; ++j) { - tmp.push_back(result[i]); - tmp.back().combination.push_back(genotype_posteriors[j].second); - tmp.back().log_probability += genotype_posteriors[j].first; + std::vector>> indexed_marginals {}; + indexed_marginals.reserve(em_genotype_marginals.size()); + for (const auto& marginals : em_genotype_marginals) { + auto tmp = zip_index(marginals); + std::nth_element(std::begin(tmp), std::next(std::begin(tmp), k), std::end(tmp), std::greater<> {}); + indexed_marginals.push_back(std::move(tmp)); + } + std::vector result {}, top(genotypes.size(), 0u); + result.reserve(k); + for (std::size_t j {0}; j <= k; ++j) { + for (const auto& marginals : indexed_marginals) { + ++top[marginals.front().second]; + } + const auto max_itr = std::max_element(std::begin(top), std::end(top)); + const auto max_idx = static_cast(std::distance(std::begin(top), max_itr)); + if (std::find(std::cbegin(result), std::cend(result), max_idx) == std::cend(result)) { + result.push_back(max_idx); + } + *max_itr = 0; + for (auto& marginals : indexed_marginals) { + if (marginals.front().second == max_idx) { + 
marginals.erase(std::cbegin(marginals)); + } } } - const auto middle = std::next(std::begin(tmp), K); - std::partial_sort(std::begin(tmp), middle, std::end(tmp), - [] (const auto& lhs, const auto& rhs) { - return lhs.log_probability > rhs.log_probability; - }); - tmp.erase(middle, std::end(tmp)); -// tmp.reserve(K); -// for (std::size_t i {0}, j {0}, t {0}; t < K; ++t) { -// assert(i < m && j < n); -// std::cout << i << " " << j << std::endl; -// tmp.push_back(result[i]); -// tmp.back().combination.push_back(genotype_posteriors[j].second); -// tmp.back().log_probability += genotype_posteriors[j].first; -// const auto prev_score = tmp.back().log_probability; -// if (j < n - 1 && (i == m - 1 || differences1[i + 1] < differences2[j + 1])) { -// ++j; -// while (i > 0 && result[i - 1].log_probability + genotype_posteriors[j].first < prev_score) { -// --i; -// } -// } else { -// ++i; -// while (j > 0 && result[i].log_probability + genotype_posteriors[j - 1].first < prev_score) { -// --j; -// } -// } -// } - result = std::move(tmp); + return result; } } -auto get_genotype_combinations(const std::vector>& genotypes, - const GenotypeMarginalPosteriorMatrix& genotype_posteriors, - const std::size_t max_combinations) +auto propose_joint_genotypes(const std::vector>& genotypes, + const GenotypeMarginalPosteriorMatrix& em_genotype_marginals, + const std::size_t max_joint_genotypes) { - const auto num_samples = genotype_posteriors.size(); - assert(max_combinations >= num_samples); - const auto num_possible_combinations = num_combinations(genotypes.size(), num_samples); - if (num_possible_combinations <= max_combinations) { - return get_all_genotype_combinations(genotypes.size(), num_samples); + const auto num_samples = em_genotype_marginals.size(); + assert(max_joint_genotypes >= num_samples * genotypes.size()); + const auto num_joint_genotypes = num_combinations(genotypes.size(), num_samples); + if (num_joint_genotypes <= max_joint_genotypes) { + return 
generate_all_genotype_combinations(genotypes.size(), num_samples); } - CombinationProbabilityMatrix combinations {}; - combinations.reserve(max_combinations); - for (const auto& sample : genotype_posteriors) { - join(index_and_sort(sample, max_combinations), combinations, max_combinations); + auto result = select_top_k_tuples(em_genotype_marginals, max_joint_genotypes); + const auto top_k_genotype_indices = select_top_k_genotypes(genotypes, em_genotype_marginals, num_samples / 2); + for (const auto genotype_idx : top_k_genotype_indices) { + for (std::size_t sample_idx {0}; sample_idx < num_samples; ++sample_idx) { + if (result.front()[sample_idx] != genotype_idx) { + auto tmp = result.front(); + tmp[sample_idx] = genotype_idx; + if (std::find(std::cbegin(result), std::cend(result), tmp) == std::cend(result)) { + result.push_back(std::move(tmp)); + } + } + } } - GenotypeCombinationMatrix result {}; - result.reserve(max_combinations); - for (auto&& row : combinations) { - result.push_back(std::move(row.combination)); + const auto hom_ref_idx = find_hom_ref_idx(genotypes); + if (hom_ref_idx) { + std::vector ref_indices(num_samples, *hom_ref_idx); + if (std::find(std::cbegin(result), std::cend(result), ref_indices) == std::cend(result)) { + result.back() = std::move(ref_indices); + } } return result; } @@ -518,9 +477,8 @@ void fill(const GenotypeLogLikelihoodMatrix& genotype_likelihoods, using GenotypeReferenceVector = std::vector>>; -void fill(const std::vector>& genotypes, - const GenotypeCombinationVector& indices, - GenotypeReferenceVector& result) +template +void fill(const std::vector& genotypes, const GenotypeCombinationVector& indices, V& result) { result.clear(); std::transform(std::cbegin(indices), std::cend(indices), std::back_inserter(result), @@ -528,21 +486,69 @@ void fill(const std::vector>& genotypes, } auto calculate_posteriors(const std::vector>& genotypes, - const GenotypeCombinationMatrix& genotype_combinations, + const GenotypeCombinationMatrix& 
joint_genotypes, + const GenotypeLogLikelihoodMatrix& genotype_likelihoods, + const PopulationPriorModel& prior_model) +{ + std::vector result {}; + GenotypeLogLikelihoodVector likelihoods_buffer(genotype_likelihoods.size()); + GenotypeReferenceVector genotypes_refs {}; + for (const auto& indices : joint_genotypes) { + fill(genotype_likelihoods, indices, likelihoods_buffer); + fill(genotypes, indices, genotypes_refs); + result.push_back(prior_model.evaluate(genotypes_refs) + sum(likelihoods_buffer)); + } + const auto norm = maths::normalise_exp(result); + return std::make_pair(std::move(result), norm); +} + +using GenotypeIndexRefVector = std::vector; + +auto calculate_posteriors(const std::vector& genotype_indices, + const GenotypeCombinationMatrix& joint_genotypes, const GenotypeLogLikelihoodMatrix& genotype_likelihoods, const PopulationPriorModel& prior_model) { - assert(!genotypes.empty()); std::vector result {}; - GenotypeLogLikelihoodVector tmp_likelihoods(genotype_likelihoods.size()); - GenotypeReferenceVector tmp_genotypes {}; - for (const auto& indices : genotype_combinations) { - fill(genotype_likelihoods, indices, tmp_likelihoods); - fill(genotypes, indices, tmp_genotypes); - result.push_back(prior_model.evaluate(tmp_genotypes) + sum(tmp_likelihoods)); + GenotypeLogLikelihoodVector likelihoods_buffer(genotype_likelihoods.size()); + GenotypeIndexRefVector genotypes_index_refs {}; + for (const auto& indices : joint_genotypes) { + fill(genotype_likelihoods, indices, likelihoods_buffer); + fill(genotype_indices, indices, genotypes_index_refs); + result.push_back(prior_model.evaluate(genotypes_index_refs) + sum(likelihoods_buffer)); } const auto norm = maths::normalise_exp(result); - return std::make_pair(result, norm); + return std::make_pair(std::move(result), norm); +} + +void set_posterior_marginals(const GenotypeCombinationMatrix& joint_genotypes, + const std::vector& joint_posteriors, + const std::size_t num_genotypes, const std::size_t num_samples, + 
PopulationModel::InferredLatents& result) +{ + assert(joint_posteriors.size() == joint_genotypes.size()); + std::vector> marginals(num_samples, std::vector(num_genotypes, 0.0)); + for (std::size_t i {0}; i < joint_genotypes.size(); ++i) { + assert(joint_genotypes[i].size() == num_samples); + for (std::size_t s {0}; s < num_samples; ++s) { + marginals[s][joint_genotypes[i][s]] += joint_posteriors[i]; + } + } + result.posteriors.marginal_genotype_probabilities = std::move(marginals); +} + +template +void calculate_posterior_marginals(const std::vector& genotypes, + const GenotypeCombinationMatrix& joint_genotypes, + const GenotypeLogLikelihoodMatrix& genotype_likelihoods, + const PopulationPriorModel& prior_model, + PopulationModel::InferredLatents& result) +{ + std::vector joint_posteriors; double norm; + std::tie(joint_posteriors, norm) = calculate_posteriors(genotypes, joint_genotypes, genotype_likelihoods, prior_model); + const auto num_samples = genotype_likelihoods.size(); + set_posterior_marginals(joint_genotypes, joint_posteriors, genotypes.size(), num_samples, result); + result.log_evidence = norm; } } // namespace @@ -553,12 +559,42 @@ PopulationModel::evaluate(const SampleVector& samples, const GenotypeVector& gen { assert(!genotypes.empty()); const auto genotype_log_likelihoods = compute_genotype_log_likelihoods(samples, genotypes, haplotype_likelihoods); - const auto approx_genotype_posteriors = compute_approx_genotype_marginal_posteriors(genotypes, genotype_log_likelihoods, - {options_.max_em_iterations, 0.0001}); - const auto max_combinations = options_.max_combinations_per_sample * samples.size(); - auto genotype_combinations = get_genotype_combinations(genotypes, approx_genotype_posteriors, max_combinations); - auto p = calculate_posteriors(genotypes, genotype_combinations, genotype_log_likelihoods, prior_model_); - return {{std::move(genotype_combinations), std::move(p.first)}, p.second}; + const auto num_joint_genotypes = 
num_combinations(genotypes.size(), samples.size()); + InferredLatents result; + if (num_joint_genotypes <= options_.max_joint_genotypes) { + const auto joint_genotypes = generate_all_genotype_combinations(genotypes.size(), samples.size()); + calculate_posterior_marginals(genotypes, joint_genotypes, genotype_log_likelihoods, prior_model_, result); + } else { + const EMOptions em_options {options_.max_em_iterations, options_.em_epsilon}; + const auto em_genotype_marginals = compute_approx_genotype_marginal_posteriors(genotypes, genotype_log_likelihoods, em_options); + const auto joint_genotypes = propose_joint_genotypes(genotypes, em_genotype_marginals, options_.max_joint_genotypes); + calculate_posterior_marginals(genotypes, joint_genotypes, genotype_log_likelihoods, prior_model_, result); + } + return result; +} + +PopulationModel::InferredLatents +PopulationModel::evaluate(const SampleVector& samples, + const GenotypeVector& genotypes, + const std::vector& genotype_indices, + const std::vector& haplotypes, + const HaplotypeLikelihoodCache& haplotype_likelihoods) const +{ + assert(!genotypes.empty()); + const auto genotype_log_likelihoods = compute_genotype_log_likelihoods(samples, genotypes, haplotype_likelihoods); + const auto num_joint_genotypes = num_combinations(genotypes.size(), samples.size()); + InferredLatents result; + if (num_joint_genotypes <= options_.max_joint_genotypes) { + const auto joint_genotypes = generate_all_genotype_combinations(genotypes.size(), samples.size()); + calculate_posterior_marginals(genotypes, joint_genotypes, genotype_log_likelihoods, prior_model_, result); + } else { + const EMOptions em_options {options_.max_em_iterations, options_.em_epsilon}; + const auto em_genotype_marginals = compute_approx_genotype_marginal_posteriors(haplotypes, genotypes, genotype_indices, + genotype_log_likelihoods, em_options); + const auto joint_genotypes = propose_joint_genotypes(genotypes, em_genotype_marginals, options_.max_joint_genotypes); + 
calculate_posterior_marginals(genotype_indices, joint_genotypes, genotype_log_likelihoods, prior_model_, result); + } + return result; } PopulationModel::InferredLatents diff --git a/src/core/models/genotype/population_model.hpp b/src/core/models/genotype/population_model.hpp index 9e14c54e7..a995463bf 100644 --- a/src/core/models/genotype/population_model.hpp +++ b/src/core/models/genotype/population_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef population_model_hpp @@ -22,25 +22,23 @@ namespace octopus { namespace model { class PopulationModel { public: + struct Options + { + std::size_t max_joint_genotypes = 1'000'000; + unsigned max_em_iterations = 100; + double em_epsilon = 0.001; + }; struct Latents { - std::vector> genotype_combinations; - using GenotypeProbabilityVector = std::vector; - GenotypeProbabilityVector joint_genotype_probabilities; + using ProbabilityVector = std::vector; + std::vector marginal_genotype_probabilities; }; - struct InferredLatents { Latents posteriors; double log_evidence; }; - struct Options - { - std::size_t max_combinations_per_sample = 200; - unsigned max_em_iterations = 100; - }; - using SampleVector = std::vector; using GenotypeVector = std::vector>; using GenotypeVectorReference = std::reference_wrapper; @@ -66,7 +64,12 @@ class PopulationModel InferredLatents evaluate(const SampleVector& samples, const GenotypeVector& genotypes, const HaplotypeLikelihoodCache& haplotype_likelihoods) const; - + // All samples have same ploidy + InferredLatents evaluate(const SampleVector& samples, + const GenotypeVector& genotypes, + const std::vector& genotype_indices, + const std::vector& haplotypes, + const HaplotypeLikelihoodCache& haplotype_likelihoods) const; // Samples have different ploidy InferredLatents evaluate(const SampleVector& samples, const std::vector& genotypes, 
diff --git a/src/core/models/genotype/population_prior_model.hpp b/src/core/models/genotype/population_prior_model.hpp index 373991eaa..7a9f492cc 100644 --- a/src/core/models/genotype/population_prior_model.hpp +++ b/src/core/models/genotype/population_prior_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef population_prior_model_hpp @@ -33,15 +33,15 @@ class PopulationPriorModel double evaluate(const std::vector>& genotypes) const { return do_evaluate(genotypes); } double evaluate(const std::vector& genotypes) const { return do_evaluate(genotypes); } - double evaluate(const std::vector>& indices) const { return do_evaluate(indices); } - double evaluate(const std::vector& indices) const { return do_evaluate(indices); } + double evaluate(const std::vector& indices) const { return do_evaluate(indices); } + double evaluate(const std::vector& genotypes) const { return do_evaluate(genotypes); } private: std::vector haplotypes_; virtual double do_evaluate(const std::vector>& genotypes) const = 0; virtual double do_evaluate(const std::vector& genotypes) const = 0; - virtual double do_evaluate(const std::vector>& indices) const = 0; + virtual double do_evaluate(const std::vector& genotypes) const = 0; virtual double do_evaluate(const std::vector& indices) const = 0; virtual void do_prime(const std::vector& haplotypes) {}; virtual void do_unprime() noexcept {}; diff --git a/src/core/models/genotype/cnv_model.cpp b/src/core/models/genotype/subclone_model.cpp similarity index 80% rename from src/core/models/genotype/cnv_model.cpp rename to src/core/models/genotype/subclone_model.cpp index df32af77e..8d754e99d 100644 --- a/src/core/models/genotype/cnv_model.cpp +++ b/src/core/models/genotype/subclone_model.cpp @@ -1,7 +1,7 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this 
source code is governed by the MIT license that can be found in the LICENSE file. -#include "cnv_model.hpp" +#include "subclone_model.hpp" #include #include @@ -18,17 +18,17 @@ namespace octopus { namespace model { -CNVModel::CNVModel(std::vector samples, Priors priors) -: CNVModel {std::move(samples), std::move(priors), AlgorithmParameters {}} +SubcloneModel::SubcloneModel(std::vector samples, Priors priors) +: SubcloneModel {std::move(samples), std::move(priors), AlgorithmParameters {}} {} -CNVModel::CNVModel(std::vector samples, Priors priors, AlgorithmParameters parameters) +SubcloneModel::SubcloneModel(std::vector samples, Priors priors, AlgorithmParameters parameters) : samples_ {std::move(samples)} , priors_ {std::move(priors)} , parameters_ {parameters} {} -const CNVModel::Priors& CNVModel::priors() const noexcept +const SubcloneModel::Priors& SubcloneModel::priors() const noexcept { return priors_; } @@ -36,26 +36,26 @@ const CNVModel::Priors& CNVModel::priors() const noexcept namespace { template -CNVModel::InferredLatents +SubcloneModel::InferredLatents run_variational_bayes(const std::vector& samples, const std::vector>& genotypes, - const CNVModel::Priors& priors, + const SubcloneModel::Priors& priors, const HaplotypeLikelihoodCache& haplotype_log_likelihoods, const VariationalBayesParameters& params); template -CNVModel::InferredLatents +SubcloneModel::InferredLatents run_variational_bayes(const std::vector& samples, const std::vector>& genotypes, - const std::vector>& genotype_indices, - const CNVModel::Priors& priors, + const std::vector& genotype_indices, + const SubcloneModel::Priors& priors, const HaplotypeLikelihoodCache& haplotype_log_likelihoods, const VariationalBayesParameters& params); } // namespace -CNVModel::InferredLatents -CNVModel::evaluate(const std::vector>& genotypes, +SubcloneModel::InferredLatents +SubcloneModel::evaluate(const std::vector>& genotypes, const HaplotypeLikelihoodCache& haplotype_likelihoods) const { 
assert(!genotypes.empty()); @@ -72,9 +72,9 @@ CNVModel::evaluate(const std::vector>& genotypes, } } -CNVModel::InferredLatents -CNVModel::evaluate(const std::vector>& genotypes, - const std::vector>& genotype_indices, +SubcloneModel::InferredLatents +SubcloneModel::evaluate(const std::vector>& genotypes, + const std::vector& genotype_indices, const HaplotypeLikelihoodCache& haplotype_likelihoods) const { assert(!genotypes.empty()); @@ -94,7 +94,7 @@ CNVModel::evaluate(const std::vector>& genotypes, namespace { template -VBAlpha flatten(const CNVModel::Priors::GenotypeMixturesDirichletAlphas& alpha) +VBAlpha flatten(const SubcloneModel::Priors::GenotypeMixturesDirichletAlphas& alpha) { VBAlpha result {}; std::copy_n(std::cbegin(alpha), K, std::begin(result)); @@ -102,7 +102,7 @@ VBAlpha flatten(const CNVModel::Priors::GenotypeMixturesDirichletAlphas& alph } template -VBAlphaVector flatten(const CNVModel::Priors::GenotypeMixturesDirichletAlphaMap& alphas, +VBAlphaVector flatten(const SubcloneModel::Priors::GenotypeMixturesDirichletAlphaMap& alphas, const std::vector& samples) { VBAlphaVector result(samples.size()); @@ -154,16 +154,16 @@ flatten(const std::vector>& genotypes, } template -CNVModel::Latents::GenotypeMixturesDirichletAlphas expand(VBAlpha& alpha) +SubcloneModel::Latents::GenotypeMixturesDirichletAlphas expand(VBAlpha& alpha) { - return CNVModel::Latents::GenotypeMixturesDirichletAlphas(std::begin(alpha), std::end(alpha)); + return SubcloneModel::Latents::GenotypeMixturesDirichletAlphas(std::begin(alpha), std::end(alpha)); } template -CNVModel::Latents::GenotypeMixturesDirichletAlphaMap +SubcloneModel::Latents::GenotypeMixturesDirichletAlphaMap expand(const std::vector& samples, VBAlphaVector&& alphas) { - CNVModel::Latents::GenotypeMixturesDirichletAlphaMap result {}; + SubcloneModel::Latents::GenotypeMixturesDirichletAlphaMap result {}; std::transform(std::cbegin(samples), std::cend(samples), std::begin(alphas), std::inserter(result, 
std::begin(result)), [] (const auto& sample, auto&& vb_alpha) { @@ -173,10 +173,10 @@ expand(const std::vector& samples, VBAlphaVector&& alphas) } template -CNVModel::InferredLatents +SubcloneModel::InferredLatents expand(const std::vector& samples, VBLatents&& inferred_latents, double evidence) { - CNVModel::Latents posterior_latents { + SubcloneModel::Latents posterior_latents { std::move(inferred_latents.genotype_posteriors), expand(samples, std::move(inferred_latents.alphas)) }; @@ -184,12 +184,12 @@ expand(const std::vector& samples, VBLatents&& inferred_latents, } template -auto calculate_log_priors(const Container& genotypes, const GenotypePriorModel& model) +auto calculate_log_priors(const Container& genotypes, const GenotypePriorModel& model, const bool normalise = false) { std::vector result(genotypes.size()); std::transform(std::cbegin(genotypes), std::cend(genotypes), std::begin(result), [&model] (const auto& genotype) { return model.evaluate(genotype); }); - maths::normalise_logs(result); + if (normalise) maths::normalise_logs(result); return result; } @@ -201,14 +201,13 @@ LogProbabilityVector log_uniform_dist(const std::size_t n) auto generate_seeds(const std::vector& samples, const std::vector>& genotypes, const LogProbabilityVector& genotype_log_priors, - const CNVModel::Priors& priors, + const SubcloneModel::Priors& priors, const HaplotypeLikelihoodCache& haplotype_log_likelihoods, - const boost::optional>&> genotype_indices = boost::none) + const boost::optional&> genotype_indices = boost::none) { std::vector result {}; - result.reserve(2 + 2 * samples.size()); + result.reserve(1 + samples.size()); result.push_back(genotype_log_priors); - result.push_back(log_uniform_dist(genotypes.size())); IndividualModel germline_model {priors.genotype_prior_model}; for (const auto& sample : samples) { haplotype_log_likelihoods.prime(sample); @@ -220,18 +219,15 @@ auto generate_seeds(const std::vector& samples, } 
result.push_back(latents.posteriors.genotype_probabilities); maths::log_each(result.back()); - result.push_back(latents.posteriors.genotype_probabilities); - for (auto& p : result.back()) p = 1.0 - p; - maths::log_each(result.back()); } return result; } template -CNVModel::InferredLatents +SubcloneModel::InferredLatents run_variational_bayes(const std::vector& samples, const std::vector>& genotypes, - const CNVModel::Priors::GenotypeMixturesDirichletAlphaMap& prior_alphas, + const SubcloneModel::Priors::GenotypeMixturesDirichletAlphaMap& prior_alphas, const std::vector& genotype_log_priors, const HaplotypeLikelihoodCache& haplotype_log_likelihoods, const VariationalBayesParameters& params, @@ -246,10 +242,10 @@ run_variational_bayes(const std::vector& samples, // Main entry point template -CNVModel::InferredLatents +SubcloneModel::InferredLatents run_variational_bayes(const std::vector& samples, const std::vector>& genotypes, - const CNVModel::Priors& priors, + const SubcloneModel::Priors& priors, const HaplotypeLikelihoodCache& haplotype_log_likelihoods, const VariationalBayesParameters& params) { @@ -260,15 +256,14 @@ run_variational_bayes(const std::vector& samples, } template -CNVModel::InferredLatents +SubcloneModel::InferredLatents run_variational_bayes(const std::vector& samples, const std::vector>& genotypes, - const std::vector>& genotype_indices, - const CNVModel::Priors& priors, + const std::vector& genotype_indices, + const SubcloneModel::Priors& priors, const HaplotypeLikelihoodCache& haplotype_log_likelihoods, const VariationalBayesParameters& params) { - const auto genotype_log_priors = calculate_log_priors(genotype_indices, priors.genotype_prior_model); auto seeds = generate_seeds(samples, genotypes, genotype_log_priors, priors, haplotype_log_likelihoods, genotype_indices); return run_variational_bayes(samples, genotypes, priors.alphas, genotype_log_priors, diff --git a/src/core/models/genotype/cnv_model.hpp 
b/src/core/models/genotype/subclone_model.hpp similarity index 74% rename from src/core/models/genotype/cnv_model.hpp rename to src/core/models/genotype/subclone_model.hpp index 2e2120bea..10e98e298 100644 --- a/src/core/models/genotype/cnv_model.hpp +++ b/src/core/models/genotype/subclone_model.hpp @@ -1,8 +1,8 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. -#ifndef cnv_model_hpp -#define cnv_model_hpp +#ifndef subclone_model_hpp +#define subclone_model_hpp #include #include @@ -16,7 +16,7 @@ namespace octopus { namespace model { -class CNVModel +class SubcloneModel { public: struct AlgorithmParameters @@ -50,17 +50,17 @@ class CNVModel double approx_log_evidence; }; - CNVModel() = delete; + SubcloneModel() = delete; - CNVModel(std::vector samples, Priors priors); - CNVModel(std::vector samples, Priors priors, AlgorithmParameters parameters); + SubcloneModel(std::vector samples, Priors priors); + SubcloneModel(std::vector samples, Priors priors, AlgorithmParameters parameters); - CNVModel(const CNVModel&) = default; - CNVModel& operator=(const CNVModel&) = default; - CNVModel(CNVModel&&) = default; - CNVModel& operator=(CNVModel&&) = default; + SubcloneModel(const SubcloneModel&) = default; + SubcloneModel& operator=(const SubcloneModel&) = default; + SubcloneModel(SubcloneModel&&) = default; + SubcloneModel& operator=(SubcloneModel&&) = default; - ~CNVModel() = default; + ~SubcloneModel() = default; const Priors& priors() const noexcept; @@ -68,7 +68,7 @@ class CNVModel const HaplotypeLikelihoodCache& haplotype_likelihoods) const; InferredLatents evaluate(const std::vector>& genotypes, - const std::vector>& genotype_indices, + const std::vector& genotype_indices, const HaplotypeLikelihoodCache& haplotype_likelihoods) const; private: diff --git a/src/core/models/genotype/trio_model.cpp b/src/core/models/genotype/trio_model.cpp index 
049e0f76a..de7406589 100644 --- a/src/core/models/genotype/trio_model.cpp +++ b/src/core/models/genotype/trio_model.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "trio_model.hpp" @@ -19,6 +19,11 @@ namespace octopus { namespace model { +unsigned TrioModel::max_ploidy() noexcept +{ + return 3; +} + TrioModel::TrioModel(const Trio& trio, const PopulationPriorModel& prior_model, const DeNovoModel& mutation_model, @@ -46,7 +51,7 @@ void clear(Container& c) } using GenotypeReference = std::reference_wrapper>; -using GenotypeIndiceVector = std::vector; +using GenotypeIndiceVector = GenotypeIndex; using GenotypeIndiceVectorReference = std::reference_wrapper; bool operator==(const GenotypeReference lhs, const GenotypeReference rhs) @@ -422,17 +427,17 @@ auto join(const ReducedVectorMap& maternal, return result; } -bool is_haploid(const std::vector& genotype) noexcept +bool is_haploid(const GenotypeIndex& genotype) noexcept { return genotype.size() == 1; } -bool is_diploid(const std::vector& genotype) noexcept +bool is_diploid(const GenotypeIndex& genotype) noexcept { return genotype.size() == 2; } -bool is_triploid(const std::vector& genotype) noexcept +bool is_triploid(const GenotypeIndex& genotype) noexcept { return genotype.size() == 3; } @@ -735,7 +740,7 @@ TrioModel::evaluate(const GenotypeVector& genotypes, const HaplotypeLikelihoodCa } TrioModel::InferredLatents -TrioModel::evaluate(const GenotypeVector& genotypes, std::vector>& genotype_indices, +TrioModel::evaluate(const GenotypeVector& genotypes, std::vector& genotype_indices, const HaplotypeLikelihoodCache& haplotype_likelihoods) const { assert(prior_model_.is_primed() && mutation_model_.is_primed()); diff --git a/src/core/models/genotype/trio_model.hpp b/src/core/models/genotype/trio_model.hpp index b9c685616..f79ba9974 100644 --- 
a/src/core/models/genotype/trio_model.hpp +++ b/src/core/models/genotype/trio_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef trio_model_hpp @@ -64,6 +64,8 @@ class TrioModel ~TrioModel() = default; + static unsigned max_ploidy() noexcept; + const PopulationPriorModel& prior_model() const noexcept; InferredLatents evaluate(const GenotypeVector& maternal_genotypes, @@ -76,7 +78,7 @@ class TrioModel const HaplotypeLikelihoodCache& haplotype_likelihoods) const; InferredLatents evaluate(const GenotypeVector& genotypes, - std::vector>& genotype_indices, + std::vector& genotype_indices, const HaplotypeLikelihoodCache& haplotype_likelihoods) const; private: diff --git a/src/core/models/genotype/tumour_model.cpp b/src/core/models/genotype/tumour_model.cpp index 52ec6f821..dc4b59057 100644 --- a/src/core/models/genotype/tumour_model.cpp +++ b/src/core/models/genotype/tumour_model.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "tumour_model.hpp" @@ -11,6 +11,7 @@ #include #include +#include "exceptions/unimplemented_feature_error.hpp" #include "utils/maths.hpp" #include "logging/logging.hpp" #include "germline_likelihood_model.hpp" @@ -35,7 +36,6 @@ const TumourModel::Priors& TumourModel::priors() const noexcept namespace { -template TumourModel::InferredLatents run_variational_bayes(const std::vector& samples, const std::vector>& genotypes, @@ -43,11 +43,10 @@ run_variational_bayes(const std::vector& samples, const HaplotypeLikelihoodCache& haplotype_log_likelihoods, const VariationalBayesParameters& params); -template TumourModel::InferredLatents run_variational_bayes(const std::vector& samples, const std::vector>& genotypes, - const std::vector, unsigned>>& genotype_indices, + const std::vector& genotype_indices, const TumourModel::Priors& priors, const HaplotypeLikelihoodCache& haplotype_log_likelihoods, const VariationalBayesParameters& params); @@ -60,28 +59,18 @@ TumourModel::evaluate(const std::vector>& genotypes, { assert(!genotypes.empty()); const VariationalBayesParameters vb_params {parameters_.epsilon, parameters_.max_iterations}; - auto ploidy = genotypes.front().ploidy(); - assert(ploidy < 3); - if (ploidy == 1) { - return run_variational_bayes<2>(samples_, genotypes, priors_, haplotype_likelihoods, vb_params); - } - return run_variational_bayes<3>(samples_, genotypes, priors_, haplotype_likelihoods, vb_params); + return run_variational_bayes(samples_, genotypes, priors_, haplotype_likelihoods, vb_params); } TumourModel::InferredLatents TumourModel::evaluate(const std::vector>& genotypes, - const std::vector, unsigned>>& genotype_indices, + const std::vector& genotype_indices, const HaplotypeLikelihoodCache& haplotype_likelihoods) const { assert(!genotypes.empty()); assert(genotypes.size() == genotype_indices.size()); const VariationalBayesParameters vb_params {parameters_.epsilon, parameters_.max_iterations}; - auto ploidy = genotypes.front().ploidy(); - 
assert(ploidy < 3); - if (ploidy == 1) { - return run_variational_bayes<2>(samples_, genotypes, genotype_indices, priors_, haplotype_likelihoods, vb_params); - } - return run_variational_bayes<3>(samples_, genotypes, genotype_indices, priors_, haplotype_likelihoods, vb_params); + return run_variational_bayes(samples_, genotypes, genotype_indices, priors_, haplotype_likelihoods, vb_params); } namespace { @@ -104,20 +93,27 @@ VBAlphaVector flatten(const TumourModel::Priors::GenotypeMixturesDirichletAlp return result; } +template +auto copy_cref(const Genotype& genotype, const SampleName& sample, + const HaplotypeLikelihoodCache& haplotype_likelihoods, + typename VBGenotype::iterator result_itr) +{ + return std::transform(std::cbegin(genotype), std::cend(genotype), result_itr, + [&sample, &haplotype_likelihoods] (const Haplotype& haplotype) + -> std::reference_wrapper { + return std::cref(haplotype_likelihoods(sample, haplotype)); + }); +} + template VBGenotype flatten(const CancerGenotype& genotype, const SampleName& sample, const HaplotypeLikelihoodCache& haplotype_likelihoods) { VBGenotype result {}; - const Genotype& germline_genotype{genotype.germline_genotype()}; - assert(germline_genotype.ploidy() == (K - 1)); - std::transform(std::cbegin(germline_genotype), std::cend(germline_genotype), std::begin(result), - [&sample, &haplotype_likelihoods] (const Haplotype& haplotype) - -> std::reference_wrapper { - return std::cref(haplotype_likelihoods(sample, haplotype)); - }); - result.back() = haplotype_likelihoods(sample, genotype.somatic_element()); + assert(genotype.ploidy() == K); + auto itr = copy_cref(genotype.germline(), sample, haplotype_likelihoods, std::begin(result)); + copy_cref(genotype.somatic(), sample, haplotype_likelihoods, itr); return result; } @@ -148,6 +144,7 @@ flatten(const std::vector>& genotypes, }); return result; } + template TumourModel::Latents::GenotypeMixturesDirichletAlphas expand(VBAlpha& alpha) { @@ -169,41 +166,40 @@ expand(const 
std::vector& samples, VBAlphaVector&& alphas) template TumourModel::InferredLatents -expand(const std::vector& samples, VBLatents&& inferred_latents, double evidence) +expand(const std::vector& samples, VBLatents&& inferred_latents, + std::vector genotype_log_priors, double evidence) { - TumourModel::Latents posterior_latents { - std::move(inferred_latents.genotype_posteriors), - expand(samples, std::move(inferred_latents.alphas)) - }; - return {std::move(posterior_latents), evidence}; + TumourModel::Latents posterior_latents {std::move(inferred_latents.genotype_posteriors), + expand(samples, std::move(inferred_latents.alphas))}; + return {std::move(posterior_latents), std::move(genotype_log_priors), evidence}; } -auto calculate_log_priors(const std::vector, unsigned>>& genotype_indices, - const CancerGenotypePriorModel& model) +auto compute_germline_log_likelihoods(const SampleName& sample, + const std::vector>& genotypes, + const HaplotypeLikelihoodCache& haplotype_log_likelihoods) { - - std::vector result(genotype_indices.size()); - std::transform(std::cbegin(genotype_indices), std::cend(genotype_indices), std::begin(result), - [&] (const auto& p) { return model.evaluate(p.first, p.second); }); - maths::normalise_logs(result); + haplotype_log_likelihoods.prime(sample); + const GermlineLikelihoodModel likelihood_model {haplotype_log_likelihoods}; + std::vector result(genotypes.size()); + std::transform(std::cbegin(genotypes), std::cend(genotypes), std::begin(result), + [&] (const auto& genotype) { return likelihood_model.evaluate(genotype.germline()); }); return result; } -auto compute_germline_log_likelihoods(const SampleName& sample, - const std::vector>& genotypes, - const HaplotypeLikelihoodCache& haplotype_log_likelihoods) +auto compute_demoted_log_likelihoods(const SampleName& sample, + const std::vector>& genotypes, + const HaplotypeLikelihoodCache& haplotype_log_likelihoods) { + assert(!genotypes.empty()); haplotype_log_likelihoods.prime(sample); const 
GermlineLikelihoodModel likelihood_model {haplotype_log_likelihoods}; std::vector result(genotypes.size()); std::transform(std::cbegin(genotypes), std::cend(genotypes), std::begin(result), - [&] (const auto& genotype) { - return likelihood_model.evaluate(genotype.germline_genotype()); - }); + [&] (const auto& genotype) { return likelihood_model.evaluate(demote(genotype)); }); return result; } -auto compute_germline_log_posteriors(const LogProbabilityVector& log_priors, const LogProbabilityVector& log_likelihoods) +auto compute_log_posteriors(const LogProbabilityVector& log_priors, const LogProbabilityVector& log_likelihoods) { assert(log_priors.size() == log_likelihoods.size()); LogProbabilityVector result(log_priors.size()); @@ -213,52 +209,139 @@ auto compute_germline_log_posteriors(const LogProbabilityVector& log_priors, con return result; } -auto compute_log_posteriors_with_germline_model(const SampleName& sample, - const std::vector>& genotypes, - const HaplotypeLikelihoodCache& haplotype_log_likelihoods) +LogProbabilityVector log_uniform_dist(const std::size_t n) { - assert(!genotypes.empty()); - haplotype_log_likelihoods.prime(sample); - const GermlineLikelihoodModel likelihood_model {haplotype_log_likelihoods}; - std::vector result(genotypes.size()); - std::transform(std::cbegin(genotypes), std::cend(genotypes), std::begin(result), - [&] (const auto& genotype) { - return likelihood_model.evaluate(demote(genotype)); - }); - maths::normalise_logs(result); + return LogProbabilityVector(n, -std::log(static_cast(n))); +} + +auto make_point_seed(const std::size_t num_genotypes, const std::size_t n, const double p = 0.9999) +{ + LogProbabilityVector result(num_genotypes, num_genotypes > 1 ? 
std::log((1 - p) / (num_genotypes - 1)) : 0); + if (num_genotypes > 1) result[n] = std::log(p); return result; } -LogProbabilityVector log_uniform_dist(const std::size_t n) +auto make_range_seed(const std::size_t num_genotypes, const std::size_t begin, const std::size_t n, const double p = 0.9999) { - return LogProbabilityVector(n, -std::log(static_cast(n))); + LogProbabilityVector result(num_genotypes, std::log((1 - p) / (num_genotypes - n))); + std::fill_n(std::next(std::begin(result), begin), n, std::log(p / n)); + return result; } -auto generate_seeds(const std::vector& samples, - const std::vector>& genotypes, - const LogProbabilityVector& genotype_log_priors, - const HaplotypeLikelihoodCache& haplotype_log_likelihoods) +namespace debug { + +template +void print_top(S&& stream, const std::vector>& genotypes, + const LogProbabilityVector& probs, std::size_t n) +{ + assert(probs.size() == genotypes.size()); + n = std::min(n, genotypes.size()); + std::vector, double> > pairs {}; + pairs.reserve(genotypes.size()); + std::transform(std::cbegin(genotypes), std::cend(genotypes), std::cbegin(probs), std::back_inserter(pairs), + [] (const auto& g, auto p) { return std::make_pair(g, p); }); + const auto mth = std::next(std::begin(pairs), n); + std::partial_sort(std::begin(pairs), mth, std::end(pairs), + [] (const auto& lhs, const auto& rhs) { return lhs.second > rhs.second; }); + std::for_each(std::begin(pairs), mth, [&] (const auto& p) { + octopus::debug::print_variant_alleles(stream, p.first); + stream << " " << p.second << '\n'; + }); +} + +void print_top(const std::vector>& genotypes, + const LogProbabilityVector& probs, std::size_t n = 10) +{ + print_top(std::cout, genotypes, probs, n); +} + +} // namespace debug + +bool is_somatic_expected(const SampleName& sample, const TumourModel::Priors& priors) +{ + const auto& alphas = priors.alphas.at(sample); + auto e = maths::dirichlet_expectation(alphas.size() - 1, alphas); + return e > 0.05; +} + +void add_to(const 
LogProbabilityVector& other, LogProbabilityVector& result) +{ + std::transform(std::cbegin(other), std::cend(other), std::cbegin(result), std::begin(result), + [] (auto a, auto b) { return a + b; }); +} + +auto generate_exhaustive_seeds(const std::size_t n) { std::vector result {}; - result.reserve(2 + 3 * samples.size()); + result.reserve(n); + for (unsigned i {0}; i < n; ++i) { + result.push_back(make_point_seed(n, i)); + } + return result; +} + +auto num_targetted_seeds(const std::vector& samples, + const std::vector>& genotypes) noexcept +{ + return 1 + 4 * samples.size() + 2 * (samples.size() > 1); +} + +auto generate_targetted_seeds(const std::vector& samples, + const std::vector>& genotypes, + const LogProbabilityVector& genotype_log_priors, + const HaplotypeLikelihoodCache& haplotype_log_likelihoods, + const TumourModel::Priors& priors) +{ + std::vector result {}; + result.reserve(num_targetted_seeds(samples, genotypes)); result.push_back(genotype_log_priors); - result.push_back(log_uniform_dist(genotypes.size())); + maths::normalise_logs(result.back()); + LogProbabilityVector combined_log_likelihoods(genotypes.size(), 0); for (const auto& sample : samples) { auto log_likelihoods = compute_germline_log_likelihoods(sample, genotypes, haplotype_log_likelihoods); - result.push_back(compute_germline_log_posteriors(genotype_log_priors, log_likelihoods)); - maths::normalise_exp(log_likelihoods); // convert to probabilities + auto demoted_log_likelihoods = compute_demoted_log_likelihoods(sample, genotypes, haplotype_log_likelihoods); + if (is_somatic_expected(sample, priors)) { + add_to(demoted_log_likelihoods, combined_log_likelihoods); + } else { + add_to(log_likelihoods, combined_log_likelihoods); + } + result.push_back(compute_log_posteriors(genotype_log_priors, log_likelihoods)); + maths::normalise_logs(log_likelihoods); result.push_back(std::move(log_likelihoods)); - result.push_back(compute_log_posteriors_with_germline_model(sample, genotypes, 
haplotype_log_likelihoods)); + result.push_back(compute_log_posteriors(genotype_log_priors, demoted_log_likelihoods)); + maths::normalise_logs(demoted_log_likelihoods); + result.push_back(std::move(demoted_log_likelihoods)); + } + if (samples.size() > 1) { + auto combined_log_posteriors = combined_log_likelihoods; + add_to(genotype_log_priors, combined_log_posteriors); + maths::normalise_logs(combined_log_posteriors); + result.push_back(std::move(combined_log_posteriors)); + maths::normalise_logs(combined_log_likelihoods); + result.push_back(std::move(combined_log_likelihoods)); } return result; } +auto generate_seeds(const std::vector& samples, + const std::vector>& genotypes, + const LogProbabilityVector& genotype_log_priors, + const HaplotypeLikelihoodCache& haplotype_log_likelihoods, + const TumourModel::Priors& priors) +{ + if (genotypes.size() <= num_targetted_seeds(samples, genotypes)) { + return generate_exhaustive_seeds(genotypes.size()); + } else { + return generate_targetted_seeds(samples, genotypes, genotype_log_priors, haplotype_log_likelihoods, priors); + } +} + template TumourModel::InferredLatents run_variational_bayes(const std::vector& samples, const std::vector>& genotypes, const TumourModel::Priors::GenotypeMixturesDirichletAlphaMap& prior_alphas, - const std::vector& genotype_log_priors, + std::vector genotype_log_priors, const HaplotypeLikelihoodCache& haplotype_log_likelihoods, const VariationalBayesParameters& params, std::vector>&& seeds) @@ -266,12 +349,40 @@ run_variational_bayes(const std::vector& samples, const auto vb_prior_alphas = flatten(prior_alphas, samples); const auto log_likelihoods = flatten(genotypes, samples, haplotype_log_likelihoods); auto p = run_variational_bayes(vb_prior_alphas, genotype_log_priors, log_likelihoods, params, std::move(seeds)); - return expand(samples, std::move(p.first), p.second); + return expand(samples, std::move(p.first), std::move(genotype_log_priors), p.second); +} + +TumourModel::InferredLatents 
+run_variational_bayes_helper(const std::vector& samples, + const std::vector>& genotypes, + const TumourModel::Priors::GenotypeMixturesDirichletAlphaMap& prior_alphas, + std::vector genotype_log_priors, + const HaplotypeLikelihoodCache& haplotype_log_likelihoods, + const VariationalBayesParameters& params, + std::vector>&& seeds) +{ + using std::move; + switch (genotypes.front().ploidy()) { + case 2: return run_variational_bayes<2>(samples, genotypes, prior_alphas, move(genotype_log_priors), + haplotype_log_likelihoods, params, move(seeds)); + case 3: return run_variational_bayes<3>(samples, genotypes, prior_alphas, move(genotype_log_priors), + haplotype_log_likelihoods, params, move(seeds)); + case 4: return run_variational_bayes<4>(samples, genotypes, prior_alphas, move(genotype_log_priors), + haplotype_log_likelihoods, params, move(seeds)); + case 5: return run_variational_bayes<5>(samples, genotypes, prior_alphas, move(genotype_log_priors), + haplotype_log_likelihoods, params, move(seeds)); + case 6: return run_variational_bayes<6>(samples, genotypes, prior_alphas, move(genotype_log_priors), + haplotype_log_likelihoods, params, move(seeds)); + case 7: return run_variational_bayes<7>(samples, genotypes, prior_alphas, move(genotype_log_priors), + haplotype_log_likelihoods, params, move(seeds)); + case 8: return run_variational_bayes<8>(samples, genotypes, prior_alphas, move(genotype_log_priors), + haplotype_log_likelihoods, params, move(seeds)); + default: throw UnimplementedFeatureError {"tumour model ploidies above 8", "TumourModel"}; + } } // Main entry point -template TumourModel::InferredLatents run_variational_bayes(const std::vector& samples, const std::vector>& genotypes, @@ -279,25 +390,24 @@ run_variational_bayes(const std::vector& samples, const HaplotypeLikelihoodCache& haplotype_log_likelihoods, const VariationalBayesParameters& params) { - const auto genotype_log_priors = calculate_log_priors(genotypes, priors.genotype_prior_model); - auto seeds = 
generate_seeds(samples, genotypes, genotype_log_priors, haplotype_log_likelihoods); - return run_variational_bayes(samples, genotypes, priors.alphas, genotype_log_priors, - haplotype_log_likelihoods, params, std::move(seeds)); + auto genotype_log_priors = calculate_log_priors(genotypes, priors.genotype_prior_model); + auto seeds = generate_seeds(samples, genotypes, genotype_log_priors, haplotype_log_likelihoods, priors); + return run_variational_bayes_helper(samples, genotypes, priors.alphas, std::move(genotype_log_priors), + haplotype_log_likelihoods, params, std::move(seeds)); } -template TumourModel::InferredLatents run_variational_bayes(const std::vector& samples, const std::vector>& genotypes, - const std::vector, unsigned>>& genotype_indices, + const std::vector& genotype_indices, const TumourModel::Priors& priors, const HaplotypeLikelihoodCache& haplotype_log_likelihoods, const VariationalBayesParameters& params) { - const auto genotype_log_priors = calculate_log_priors(genotype_indices, priors.genotype_prior_model); - auto seeds = generate_seeds(samples, genotypes, genotype_log_priors, haplotype_log_likelihoods); - return run_variational_bayes(samples, genotypes, priors.alphas, genotype_log_priors, - haplotype_log_likelihoods, params, std::move(seeds)); + auto genotype_log_priors = calculate_log_priors(genotype_indices, priors.genotype_prior_model); + auto seeds = generate_seeds(samples, genotypes, genotype_log_priors, haplotype_log_likelihoods, priors); + return run_variational_bayes_helper(samples, genotypes, priors.alphas, std::move(genotype_log_priors), + haplotype_log_likelihoods, params, std::move(seeds)); } } // namespace diff --git a/src/core/models/genotype/tumour_model.hpp b/src/core/models/genotype/tumour_model.hpp index dd40334a5..324d98929 100644 --- a/src/core/models/genotype/tumour_model.hpp +++ b/src/core/models/genotype/tumour_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of 
this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef tumour_model_hpp @@ -46,6 +46,7 @@ class TumourModel struct InferredLatents { Latents posteriors; + Latents::ProbabilityVector genotype_log_priors; double approx_log_evidence; }; @@ -67,7 +68,7 @@ class TumourModel const HaplotypeLikelihoodCache& haplotype_likelihoods) const; InferredLatents evaluate(const std::vector>& genotypes, - const std::vector, unsigned>>& genotype_indices, + const std::vector& genotype_indices, const HaplotypeLikelihoodCache& haplotype_likelihoods) const; private: diff --git a/src/core/models/genotype/uniform_genotype_prior_model.hpp b/src/core/models/genotype/uniform_genotype_prior_model.hpp index cbe729784..6804b83ac 100644 --- a/src/core/models/genotype/uniform_genotype_prior_model.hpp +++ b/src/core/models/genotype/uniform_genotype_prior_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef uniform_genotype_prior_model_hpp @@ -22,7 +22,7 @@ class UniformGenotypePriorModel : public GenotypePriorModel private: virtual double do_evaluate(const Genotype& genotype) const override { return 1.0; } - virtual double do_evaluate(const std::vector& genotype) const override { return 1.0; } + virtual double do_evaluate(const GenotypeIndex& genotype) const override { return 1.0; } bool check_is_primed() const noexcept override { return true; } }; diff --git a/src/core/models/genotype/uniform_population_prior_model.hpp b/src/core/models/genotype/uniform_population_prior_model.hpp index 0db29c6fa..e044c2f0c 100644 --- a/src/core/models/genotype/uniform_population_prior_model.hpp +++ b/src/core/models/genotype/uniform_population_prior_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef uniform_population_prior_model_hpp @@ -38,7 +38,7 @@ class UniformPopulationPriorModel : public PopulationPriorModel { return 1.0; } - double do_evaluate(const std::vector>& indices) const override + double do_evaluate(const std::vector& genotypes) const override { return 1.0; } diff --git a/src/core/models/genotype/variational_bayes_mixture_model.hpp b/src/core/models/genotype/variational_bayes_mixture_model.hpp index 1ec7255b4..0b9aa7fc7 100644 --- a/src/core/models/genotype/variational_bayes_mixture_model.hpp +++ b/src/core/models/genotype/variational_bayes_mixture_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef variational_bayes_mixture_model_hpp @@ -12,7 +12,9 @@ #include #include #include +#include +#include #include #include "utils/maths.hpp" @@ -77,10 +79,9 @@ using VBGenotypeVector = std::vector>; // Per element per genotype template using VBReadLikelihoodMatrix = std::vector>; // One element per sample +using VBTau = std::vector; // One element per read template -using VBTau = std::array; // One element per haplotype in genotype (i.e. K) -template -using VBResponsabilityVector = std::vector>; // One element per genotype +using VBResponsabilityVector = std::array; // One element per haplotype in genotype (i.e. K) template using VBResponsabilityMatrix = std::vector>; // One element per sample @@ -224,11 +225,19 @@ auto marginalise(const ProbabilityVector& distribution, const VBGenotypeVector +auto inner_product(const T1& lhs, const T2& rhs) noexcept +{ + assert(std::distance(std::cbegin(lhs), std::cend(lhs)) == std::distance(std::cbegin(rhs), std::cend(rhs))); + using T = typename T1::value_type; + return std::inner_product(std::cbegin(lhs), std::cend(lhs), std::cbegin(rhs), T {0}); +} + template auto marginalise(const ProbabilityVector& distribution, const VBInverseGenotypeVector& likelihoods, const unsigned k, const std::size_t n) noexcept { - return std::inner_product(std::cbegin(distribution), std::cend(distribution), std::cbegin(likelihoods[k][n]), 0.0); + return inner_product(distribution, likelihoods[k][n]); } template @@ -244,7 +253,8 @@ init_responsabilities(const VBAlpha& prior_alphas, al[k] = digamma_diff(prior_alphas[k], a0); } const auto N = count_reads(read_likelihoods); - VBResponsabilityVector result(N); + VBResponsabilityVector result {}; + for (auto& tau : result) tau.resize(N); std::array ln_rho; for (std::size_t n {0}; n < N; ++n) { for (unsigned k {0}; k < K; ++k) { @@ -252,7 +262,7 @@ init_responsabilities(const VBAlpha& prior_alphas, } const auto ln_rho_norm = log_sum_exp(ln_rho); for (unsigned k {0}; k < K; ++k) { - result[n][k] = 
std::exp(ln_rho[k] - ln_rho_norm); + result[k][n] = std::exp(ln_rho[k] - ln_rho_norm); } } return result; @@ -293,7 +303,7 @@ void update_responsabilities(VBResponsabilityVector& result, } const auto ln_rho_norm = log_sum_exp(ln_rho); for (unsigned k {0}; k < K; ++k) { - result[n][k] = std::exp(ln_rho[k] - ln_rho_norm); + result[k][n] = std::exp(ln_rho[k] - ln_rho_norm); } } } @@ -311,13 +321,10 @@ void update_responsabilities(VBResponsabilityMatrix& result, } } -template -auto sum(const VBResponsabilityVector& taus, const unsigned k) noexcept +template +inline auto sum(const std::vector& values) noexcept { - return std::accumulate(std::cbegin(taus), std::cend(taus), 0.0, - [k] (const auto curr, const auto& tau) noexcept { - return curr + tau[k]; - }); + return std::accumulate(std::cbegin(values), std::cend(values), T {}); } template @@ -325,7 +332,7 @@ void update_alpha(VBAlpha& alpha, const VBAlpha& prior_alpha, const VBResponsabilityVector& taus) noexcept { for (unsigned k {0}; k < K; ++k) { - alpha[k] = prior_alpha[k] + sum(taus, k); + alpha[k] = prior_alpha[k] + sum(taus[k]); } } @@ -340,20 +347,19 @@ void update_alphas(VBAlphaVector& alphas, const VBAlphaVector& prior_alpha } } +inline auto marginalise(const VBTau& responsabilities, const VBReadLikelihoodArray& likelihoods) noexcept +{ + assert(responsabilities.size() == likelihoods.size()); // num reads + return inner_product(responsabilities, likelihoods); +} + template auto marginalise(const VBResponsabilityVector& responsabilities, - const VBGenotypeVector& read_likelihoods, - const std::size_t g) noexcept + const VBGenotype& read_likelihoods) noexcept { double result {0}; - const auto N = read_likelihoods[0][0].size(); // num reads in sample s - assert(responsabilities.size() == N); - assert(responsabilities[0].size() == K && read_likelihoods[g].size() == K); for (unsigned k {0}; k < K; ++k) { - const auto& k_likelihoods = read_likelihoods[g][k]; - for (std::size_t n {0}; n < N; ++n) { - result += 
responsabilities[n][k] * k_likelihoods[n]; - } + result += marginalise(responsabilities[k], read_likelihoods[k]); } return result; } @@ -367,7 +373,7 @@ auto marginalise(const VBResponsabilityMatrix& responsabilities, const auto S = read_likelihoods.size(); // num samples assert(S == responsabilities.size()); for (std::size_t s {0}; s < S; ++s) { - result += marginalise(responsabilities[s], read_likelihoods[s], g); + result += marginalise(responsabilities[s], read_likelihoods[s][g]); } return result; } @@ -385,209 +391,49 @@ void update_genotype_log_posteriors(LogProbabilityVector& result, maths::normalise_logs(result); } -inline auto max_change(const VBAlpha<2>& lhs, const VBAlpha<2>& rhs) noexcept -{ - return std::max(std::abs(lhs.front() - rhs.front()), std::abs(lhs.back() - rhs.back())); -} - -inline auto max_change(const VBAlpha<3>& lhs, const VBAlpha<3>& rhs) noexcept -{ - return std::max({std::abs(lhs[0] - rhs[0]), std::abs(lhs[1] - rhs[1]), std::abs(lhs[2] - rhs[2])}); -} - -template -auto max_change(const VBAlpha& lhs, const VBAlpha& rhs) noexcept -{ - double result {0}; - for (std::size_t k {0}; k < K; ++k) { - const auto curr = std::abs(lhs[k] - rhs[k]); - if (curr > result) result = curr; - } - return result; -} - -template -auto max_change(const VBAlphaVector& prior_alphas, const VBAlphaVector& posterior_alphas) noexcept -{ - const auto S = prior_alphas.size(); - assert(S == posterior_alphas.size()); - double result {0}; - for (std::size_t s {0}; s < S; ++s) { - const auto curr = max_change(prior_alphas[s], posterior_alphas[s]); - if (curr > result) result = curr; - } - return result; -} - -template -std::pair check_convergence(const VBAlphaVector& prior_alphas, - const VBAlphaVector& posterior_alphas, - const double prev_max_change, - const double epsilon) noexcept -{ - const auto new_max_change = max_change(prior_alphas, posterior_alphas); - return std::make_pair(std::abs(new_max_change - prev_max_change) < epsilon, new_max_change); -} - -// 
lower-bound calculation - -inline auto expectation(const ProbabilityVector& genotype_posteriors, - const LogProbabilityVector& genotype_log_priors) noexcept -{ - return std::inner_product(std::cbegin(genotype_posteriors), std::cend(genotype_posteriors), - std::cbegin(genotype_log_priors), 0.0); -} - -template -auto dirichlet_expectation(const VBAlpha& priors, const VBAlpha& posteriors) -{ - using boost::math::digamma; - const auto da0 = digamma(sum(posteriors)); - return std::inner_product(std::cbegin(priors), std::cend(priors), - std::cbegin(posteriors), 0.0, std::plus<> {}, - [da0] (const auto& prior, const auto& post) { - return (prior - 1) * (digamma(post) - da0); - }) - maths::log_beta(priors); -} - -template -auto expectation(const VBAlphaVector& priors, const VBAlphaVector& posteriors) -{ - return std::inner_product(std::cbegin(priors), std::cend(priors), - std::cbegin(posteriors), 0.0, std::plus<> {}, - [] (const auto& prior, const auto& post) { - return dirichlet_expectation(prior, post); - }); -} - -// E[ln p(Z_s | pi_s)] -template -auto expectation(const VBResponsabilityVector& taus, const VBAlpha& alpha) +inline auto entropy(const VBTau& tau) noexcept { - using boost::math::digamma; - const auto das = digamma(sum(alpha)); - double result {0}; - for (unsigned k {0}; k < K; ++k) { - result += (digamma(alpha[k]) - das) * sum(taus, k); - } - return result; + return -std::accumulate(std::cbegin(tau), std::cend(tau), 0.0, + [] (const auto curr, const auto t) noexcept { return curr + (t * std::log(t)); }); } -// sum s E[ln p(Z_s | pi_s)] +// E [ln q(Z_s)] template -auto expectation(const VBResponsabilityMatrix& taus, const VBAlphaVector& alphas) +auto sum_entropies(const VBResponsabilityVector& taus) noexcept { - return std::inner_product(std::cbegin(taus), std::cend(taus), std::cbegin(alphas), 0.0, std::plus<> {}, - [] (const auto& tau, const auto& alpha) { - return expectation(tau, alpha); - }); + return std::accumulate(std::cbegin(taus), std::cend(taus), 
0.0, + [] (const auto curr, const auto& tau) noexcept { return curr + entropy(tau); }); } template -auto expectation(const VBResponsabilityMatrix& taus, - const VBReadLikelihoodMatrix& log_likelihoods, - const std::size_t g) -{ +auto calculate_evidence_lower_bound(const VBAlphaVector& prior_alphas, + const VBAlphaVector& posterior_alphas, + const LogProbabilityVector& genotype_log_priors, + const ProbabilityVector& genotype_posteriors, + const LogProbabilityVector& genotype_log_posteriors, + const VBResponsabilityMatrix& taus, + const VBReadLikelihoodMatrix& log_likelihoods, + const boost::optional max_posterior_skip = boost::none) +{ + const auto G = genotype_log_priors.size(); + const auto S = log_likelihoods.size(); double result {0}; - for (std::size_t s {0}; s < taus.size(); ++s) { - for (std::size_t n {0}; n < taus[s].size(); ++n) { - for (unsigned k {0}; k < K; ++k) { - result += taus[s][n][k] * log_likelihoods[s][g][k][n]; + for (std::size_t g {0}; g < G; ++g) { + if (!max_posterior_skip || genotype_posteriors[g] >= *max_posterior_skip) { + auto w = genotype_log_priors[g] - genotype_log_posteriors[g]; + for (std::size_t s {0}; s < S; ++s) { + w += marginalise(taus[s], log_likelihoods[s][g]); } + result += genotype_posteriors[g] * w; } } - return result; -} - -// E[ln p(R | Z, g)] -template -auto expectation(const ProbabilityVector& genotype_posteriors, - const VBResponsabilityMatrix& taus, - const VBReadLikelihoodMatrix& log_likelihoods) -{ - double result {0}; - for (std::size_t g {0}; g < genotype_posteriors.size(); ++g) { - result += genotype_posteriors[g] * expectation(taus, log_likelihoods, g); + for (std::size_t s {0}; s < S; ++s) { + result += (maths::log_beta(posterior_alphas[s]) - maths::log_beta(prior_alphas[s])); + result += sum_entropies(taus[s]); } return result; } -template -auto dirichlet_expectation(const VBAlpha& posterior) -{ - using boost::math::digamma; - const auto da0 = digamma(sum(posterior)); - return 
std::accumulate(std::cbegin(posterior), std::cend(posterior), 0.0, - [da0] (const auto curr, const auto a) { - return curr + ((a - 1) * (digamma(a) - da0)); - }) - maths::log_beta(posterior); -} - -template -auto expectation(const VBAlphaVector& posteriors) -{ - return std::accumulate(std::cbegin(posteriors), std::cend(posteriors), 0.0, - [] (const auto curr, const auto& posterior) { - return curr + dirichlet_expectation(posterior); - }); -} - -template -auto q_expectation(const VBTau& tau) noexcept -{ - return std::accumulate(std::cbegin(tau), std::cend(tau), 0.0, - [] (const auto curr, const auto t) noexcept { - return curr + (t * std::log(t)); - }); -} - -template <> -inline auto q_expectation<2>(const VBTau<2>& tau) noexcept -{ - return tau[0] * std::log(tau[0]) + tau[1] * std::log(tau[1]); -} - -// E [ln q(Z_s)] -template -auto q_expectation(const VBResponsabilityVector& taus) noexcept -{ - return std::accumulate(std::cbegin(taus), std::cend(taus), 0.0, - [] (const auto curr, const auto& tau) noexcept { - return curr + q_expectation(tau); - }); -} - -// sum s E [ln q(Z_s)] -template -auto q_expectation(const VBResponsabilityMatrix& taus) noexcept -{ - return std::accumulate(std::cbegin(taus), std::cend(taus), 0.0, - [] (const auto curr, const auto& t) noexcept { - return curr + q_expectation(t); - }); -} - -template -auto calculate_lower_bound(const VBAlphaVector& prior_alphas, - const LogProbabilityVector& genotype_log_priors, - const VBReadLikelihoodMatrix& log_likelihoods, - const VBLatents& latents) -{ - const auto& genotype_posteriors = latents.genotype_posteriors; - const auto& genotype_log_posteriors = latents.genotype_log_posteriors; - const auto& posterior_alphas = latents.alphas; - const auto& taus = latents.responsabilities; - double result {0}; - result += expectation(genotype_posteriors, genotype_log_priors); - result += expectation(prior_alphas, posterior_alphas); - result += expectation(taus, posterior_alphas); - result += 
expectation(genotype_posteriors, taus, log_likelihoods); - result -= expectation(genotype_posteriors, genotype_log_posteriors); - result -= expectation(posterior_alphas); - result -= q_expectation(taus); - return result; -} - // Main algorithm - single seed // Starting iteration with given genotype_log_posteriors @@ -611,15 +457,17 @@ run_variational_bayes(const VBAlphaVector& prior_alphas, auto posterior_alphas = prior_alphas; auto responsabilities = init_responsabilities(posterior_alphas, genotype_posteriors, log_likelihoods2); assert(responsabilities.size() == log_likelihoods1.size()); // num samples - bool is_converged {false}; - double max_change {0}; + auto prev_evidence = std::numeric_limits::lowest(); for (unsigned i {0}; i < params.max_iterations; ++i) { update_genotype_log_posteriors(genotype_log_posteriors, genotype_log_priors, responsabilities, log_likelihoods1); exp(genotype_log_posteriors, genotype_posteriors); update_alphas(posterior_alphas, prior_alphas, responsabilities); + auto curr_evidence = calculate_evidence_lower_bound(prior_alphas, posterior_alphas, genotype_log_priors, + genotype_posteriors, genotype_log_posteriors, responsabilities, + log_likelihoods1, 1e-10); + if (curr_evidence <= prev_evidence || (curr_evidence - prev_evidence) < params.epsilon) break; + prev_evidence = curr_evidence; update_responsabilities(responsabilities, posterior_alphas, genotype_posteriors, log_likelihoods2); - std::tie(is_converged, max_change) = check_convergence(prior_alphas, posterior_alphas, max_change, params.epsilon); - if (is_converged) break; } return VBLatents { std::move(genotype_posteriors), std::move(genotype_log_posteriors), @@ -677,6 +525,20 @@ run_variational_bayes(const VBAlphaVector& prior_alphas, return result; } +// lower-bound calculation + +template +auto calculate_evidence_lower_bound(const VBAlphaVector& prior_alphas, + const LogProbabilityVector& genotype_log_priors, + const VBReadLikelihoodMatrix& log_likelihoods, + const VBLatents& 
latents) +{ + return calculate_evidence_lower_bound(prior_alphas, latents.alphas, genotype_log_priors, + latents.genotype_posteriors, latents.genotype_log_posteriors, + latents.responsabilities, log_likelihoods); + +} + template std::pair, double> get_max_evidence_latents(const VBAlphaVector& prior_alphas, @@ -687,7 +549,7 @@ get_max_evidence_latents(const VBAlphaVector& prior_alphas, std::vector seed_evidences(latents.size()); std::transform(std::cbegin(latents), std::cend(latents), std::begin(seed_evidences), [&] (const auto& seed_latents) { - return calculate_lower_bound(prior_alphas, genotype_log_priors, log_likelihoods, seed_latents); + return calculate_evidence_lower_bound(prior_alphas, genotype_log_priors, log_likelihoods, seed_latents); }); const auto max_itr = std::max_element(std::cbegin(seed_evidences), std::cend(seed_evidences)); const auto max_idx = std::distance(std::cbegin(seed_evidences), max_itr); @@ -704,6 +566,7 @@ run_variational_bayes(const VBAlphaVector& prior_alphas, const VariationalBayesParameters& params, std::vector seeds) { + assert(!seeds.empty()); auto latents = detail::run_variational_bayes(prior_alphas, genotype_log_priors, log_likelihoods, params, std::move(seeds)); return detail::get_max_evidence_latents(prior_alphas, genotype_log_priors, log_likelihoods, std::move(latents)); } diff --git a/src/core/models/haplotype_likelihood_cache.cpp b/src/core/models/haplotype_likelihood_cache.cpp index 9d1afd217..b77f25276 100644 --- a/src/core/models/haplotype_likelihood_cache.cpp +++ b/src/core/models/haplotype_likelihood_cache.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "haplotype_likelihood_cache.hpp" diff --git a/src/core/models/haplotype_likelihood_cache.hpp b/src/core/models/haplotype_likelihood_cache.hpp index e61a628d4..1bd404517 100644 --- a/src/core/models/haplotype_likelihood_cache.hpp +++ b/src/core/models/haplotype_likelihood_cache.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef haplotype_likelihood_cache_hpp @@ -177,17 +177,18 @@ void print_read_haplotype_likelihoods(S&& stream, [] (const auto& lhs, const auto& rhs) { return lhs.second > rhs.second; }); - std::for_each(std::begin(likelihoods), mth, - [&] (const auto& p) { - if (is_single_sample) { - stream << "\t"; - } else { - stream << "\t\t"; - } - stream << p.first.get().mapped_region() - << " " << p.first.get().cigar() << ": "; - stream << p.second << '\n'; - }); + std::for_each(std::begin(likelihoods), mth, [&] (const auto& p) { + if (is_single_sample) { + stream << "\t"; + } else { + stream << "\t\t"; + } + const auto& read = p.first.get(); + stream << read.name() << " " + << mapped_region(read) << " " + << p.first.get().cigar() << ": " + << p.second << '\n'; + }); } } } diff --git a/src/core/models/haplotype_likelihood_model.cpp b/src/core/models/haplotype_likelihood_model.cpp index bd311fc54..9b04d61b8 100644 --- a/src/core/models/haplotype_likelihood_model.cpp +++ b/src/core/models/haplotype_likelihood_model.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "haplotype_likelihood_model.hpp" @@ -307,22 +307,22 @@ compute_optimal_alignment(const AlignedRead& read, const Haplotype& haplotype, } if (is_in_range(position, read, haplotype)) { has_in_range_mapping_position = true; - auto p = hmm::align(read.sequence(), haplotype.sequence(), read.base_qualities(), position, model); - if (p.second > result.likelihood) { - result.mapping_position = position; - result.likelihood = p.second; - result.cigar = std::move(p.first); + auto alignment = hmm::align(read.sequence(), haplotype.sequence(), read.base_qualities(), position, model); + if (alignment.likelihood > result.likelihood) { + result.mapping_position = alignment.target_offset; + result.likelihood = alignment.likelihood; + result.cigar = std::move(alignment.cigar); } } }); if (!is_original_position_mapped && is_in_range(original_mapping_position, read, haplotype)) { has_in_range_mapping_position = true; - auto p = hmm::align(read.sequence(), haplotype.sequence(), read.base_qualities(), - original_mapping_position, model); - if (p.second > result.likelihood) { - result.mapping_position = original_mapping_position; - result.likelihood = p.second; - result.cigar = std::move(p.first); + auto alignment = hmm::align(read.sequence(), haplotype.sequence(), read.base_qualities(), + original_mapping_position, model); + if (alignment.likelihood >= result.likelihood) { + result.mapping_position = alignment.target_offset; + result.likelihood = alignment.likelihood; + result.cigar = std::move(alignment.cigar); } } if (!has_in_range_mapping_position) { @@ -342,9 +342,11 @@ compute_optimal_alignment(const AlignedRead& read, const Haplotype& haplotype, throw HaplotypeLikelihoodModel::ShortHaplotypeError {haplotype, required_extension}; } } - result.mapping_position = final_mapping_position; - std::tie(result.cigar, result.likelihood) = hmm::align(read.sequence(), haplotype.sequence(), read.base_qualities(), - final_mapping_position, model); + auto alignment = 
hmm::align(read.sequence(), haplotype.sequence(), read.base_qualities(), + final_mapping_position, model); + result.likelihood = alignment.likelihood; + result.cigar = std::move(alignment.cigar); + result.mapping_position = alignment.target_offset; } assert(result.likelihood > std::numeric_limits::lowest() && result.likelihood <= 0); return result; diff --git a/src/core/models/haplotype_likelihood_model.hpp b/src/core/models/haplotype_likelihood_model.hpp index 82282b78c..8cbdcd99c 100644 --- a/src/core/models/haplotype_likelihood_model.hpp +++ b/src/core/models/haplotype_likelihood_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef haplotype_likelihood_model_hpp diff --git a/src/core/models/mutation/coalescent_model.cpp b/src/core/models/mutation/coalescent_model.cpp index 11be22591..aad35d38a 100644 --- a/src/core/models/mutation/coalescent_model.cpp +++ b/src/core/models/mutation/coalescent_model.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "coalescent_model.hpp" @@ -11,58 +11,14 @@ #include -#include "tandem/tandem.hpp" #include "utils/maths.hpp" namespace octopus { -auto find_repeats(const Haplotype& haplotype, const unsigned max_period) -{ - if (max_period < 4) { - return tandem::extract_exact_tandem_repeats(haplotype.sequence(), 1, max_period); - } else { - thread_local std::vector buffer {}; - buffer.resize(sequence_size(haplotype) + 1); - std::copy(std::cbegin(haplotype.sequence()), std::cend(haplotype.sequence()), std::begin(buffer)); - buffer.back() = '$'; - return tandem::extract_exact_tandem_repeats(buffer, 1, max_period); - } -} - -auto percent_of_bases_in_repeat(const Haplotype& haplotype) -{ - const auto repeats = find_repeats(haplotype, 6); - if (repeats.empty()) return 0.0; - std::vector repeat_counts(sequence_size(haplotype), 0); - for (const auto& repeat : repeats) { - const auto itr1 = std::next(std::begin(repeat_counts), repeat.pos); - const auto itr2 = std::next(itr1, repeat.length); - std::transform(itr1, itr2, itr1, [] (const auto c) { return c + 1; }); - } - const auto c = std::count_if(std::cbegin(repeat_counts), std::cend(repeat_counts), - [] (const auto c) { return c > 0; }); - return static_cast(c) / repeat_counts.size(); -} - -auto calculate_base_indel_heterozygosities(const Haplotype& haplotype, const double base_indel_heterozygosity) -{ - std::vector result(sequence_size(haplotype), base_indel_heterozygosity); - const auto repeats = find_repeats(haplotype, 3); - for (const auto& repeat : repeats) { - const auto itr1 = std::next(std::begin(result), repeat.pos); - const auto itr2 = std::next(itr1, repeat.length); - const auto n = repeat.length / repeat.period; - // TODO: implement a proper model for this - const auto t = std::min(base_indel_heterozygosity * std::pow(n, 2.6), 1.0); - std::transform(itr1, itr2, itr1, [t] (const auto h) { return std::max(h, t); }); - } - return result; -} - CoalescentModel::CoalescentModel(Haplotype reference, Parameters params, 
std::size_t num_haplotyes_hint, CachingStrategy caching) : reference_ {std::move(reference)} -, reference_base_indel_heterozygosities_ {} +, indel_heterozygosity_model_ {make_indel_model(reference_, {params.indel_heterozygosity})} , params_ {params} , haplotypes_ {} , caching_ {caching} @@ -73,7 +29,6 @@ CoalescentModel::CoalescentModel(Haplotype reference, Parameters params, if (params_.snp_heterozygosity <= 0 || params_.indel_heterozygosity <= 0) { throw std::domain_error {"CoalescentModel: snp and indel heterozygosity must be > 0"}; } - reference_base_indel_heterozygosities_ = calculate_base_indel_heterozygosities(reference_, params_.indel_heterozygosity); site_buffer1_.reserve(128); site_buffer2_.reserve(128); if (caching == CachingStrategy::address) { @@ -122,6 +77,11 @@ bool CoalescentModel::is_primed() const noexcept return !index_cache_.empty(); } +double CoalescentModel::evaluate(const Haplotype& haplotype) const +{ + return evaluate(count_segregating_sites(haplotype)); +} + double CoalescentModel::evaluate(const std::vector& haplotype_indices) const { return evaluate(count_segregating_sites(haplotype_indices)); @@ -158,12 +118,10 @@ template auto complex_log_sum_exp(ForwardIt first, ForwardIt last) { using ComplexType = typename std::iterator_traits::value_type; - const auto l = [](const auto& lhs, const auto& rhs) { return lhs.real() < rhs.real(); }; + const auto l = [] (const auto& lhs, const auto& rhs) { return lhs.real() < rhs.real(); }; const auto max = *std::max_element(first, last, l); return max + std::log(std::accumulate(first, last, ComplexType {}, - [max](const auto curr, const auto x) { - return curr + std::exp(x - max); - })); + [max] (const auto curr, const auto x) { return curr + std::exp(x - max); })); } template @@ -242,21 +200,8 @@ double CoalescentModel::evaluate(const unsigned k_snp, const unsigned n) const double CoalescentModel::evaluate(const unsigned k_snp, const unsigned k_indel, const unsigned n) const { - auto 
indel_heterozygosity = params_.indel_heterozygosity; - int max_offset {-1}; - for (const auto& site : site_buffer1_) { - if (is_indel(site)) { - const auto offset = begin_distance(reference_, site.get()); - auto itr = std::next(std::cbegin(reference_base_indel_heterozygosities_), offset); - using S = Variant::MappingDomain::Size; - itr = std::max_element(itr, std::next(itr, std::max(S {1}, region_size(site.get())))); - if (*itr > indel_heterozygosity) { - indel_heterozygosity = *itr; - max_offset = offset; - } - } - } - const auto t = std::make_tuple(k_snp, k_indel, n, max_offset); + const auto indel_heterozygosity = calculate_buffered_indel_heterozygosity(); + const auto t = std::make_tuple(k_snp, k_indel, n, maths::round_sf(indel_heterozygosity, 6)); auto itr = k_indel_pos_result_cache_.find(t); if (itr != std::cend(k_indel_pos_result_cache_)) { return itr->second; @@ -266,6 +211,19 @@ double CoalescentModel::evaluate(const unsigned k_snp, const unsigned k_indel, c return result; } +void CoalescentModel::fill_site_buffer(const Haplotype& haplotype) const +{ + assert(site_buffer2_.empty()); + site_buffer1_.clear(); + if (caching_ == CachingStrategy::address) { + fill_site_buffer_from_address_cache(haplotype); + } else { + fill_site_buffer_from_value_cache(haplotype); + } + site_buffer1_ = std::move(site_buffer2_); + site_buffer2_.clear(); +} + void CoalescentModel::fill_site_buffer(const std::vector& haplotype_indices) const { site_buffer1_.clear(); @@ -297,13 +255,26 @@ void CoalescentModel::fill_site_buffer(const std::vector& haplotype_in } } +void CoalescentModel::fill_site_buffer_uncached(const Haplotype& haplotype) const +{ + // Although we won't retrieve from the cache, we need to make sure all the variants + // stay in existence as we populate the buffers by reference. 
+ auto itr = difference_value_cache_.find(reference_); + if (itr == std::cend(difference_value_cache_)) { + itr = difference_value_cache_.emplace(reference_, haplotype.difference(reference_)).first; + } else { + itr->second = haplotype.difference(reference_); + } + std::set_union(std::begin(site_buffer1_), std::end(site_buffer1_), + std::cbegin(itr->second), std::cend(itr->second), + std::back_inserter(site_buffer2_)); +} + void CoalescentModel::fill_site_buffer_from_value_cache(const Haplotype& haplotype) const { auto itr = difference_value_cache_.find(haplotype); if (itr == std::cend(difference_value_cache_)) { - itr = difference_value_cache_.emplace(std::piecewise_construct, - std::forward_as_tuple(haplotype), - std::forward_as_tuple(haplotype.difference(reference_))).first; + itr = difference_value_cache_.emplace(haplotype, haplotype.difference(reference_)).first; } std::set_union(std::begin(site_buffer1_), std::end(site_buffer1_), std::cbegin(itr->second), std::cend(itr->second), @@ -314,13 +285,78 @@ void CoalescentModel::fill_site_buffer_from_address_cache(const Haplotype& haplo { auto itr = difference_address_cache_.find(std::addressof(haplotype)); if (itr == std::cend(difference_address_cache_)) { - itr = difference_address_cache_.emplace(std::piecewise_construct, - std::forward_as_tuple(std::addressof(haplotype)), - std::forward_as_tuple(haplotype.difference(reference_))).first; + itr = difference_address_cache_.emplace(std::addressof(haplotype), haplotype.difference(reference_)).first; } std::set_union(std::begin(site_buffer1_), std::end(site_buffer1_), std::cbegin(itr->second), std::cend(itr->second), std::back_inserter(site_buffer2_)); } +CoalescentModel::SiteCountTuple CoalescentModel::count_segregating_sites(const Haplotype& haplotype) const +{ + fill_site_buffer(haplotype); + return count_segregating_sites_in_buffer(1); +} + +CoalescentModel::SiteCountTuple CoalescentModel::count_segregating_sites_in_buffer(const unsigned num_haplotypes) const +{ + 
const auto num_indels = std::count_if(std::cbegin(site_buffer1_), std::cend(site_buffer1_), + [] (const auto& v) noexcept { return is_indel(v); }); + return std::make_tuple(site_buffer1_.size() - num_indels, num_indels, num_haplotypes + 1); +} + +double CoalescentModel::calculate_buffered_indel_heterozygosity() const +{ + boost::optional result {}; + for (const auto& site : site_buffer1_) { + if (is_indel(site)) { + auto site_heterozygosity = calculate_heterozygosity(site); + if (result) { + result = std::max(*result, site_heterozygosity); + } else { + result = site_heterozygosity; + } + } + } + return result ? *result : params_.indel_heterozygosity; +} + +double CoalescentModel::calculate_heterozygosity(const Variant& indel) const +{ + assert(is_indel(indel)); + const auto offset = static_cast(begin_distance(reference_, indel)); + const auto indel_length = indel_size(indel); + assert(offset < indel_heterozygosity_model_.gap_open.size()); + constexpr decltype(indel_length) max_indel_length {50}; + return indel_heterozygosity_model_.gap_open[offset] + * std::pow(indel_heterozygosity_model_.gap_extend[offset], std::min(indel_length, max_indel_length) - 1); +} + +CoalescentProbabilityGreater::CoalescentProbabilityGreater(CoalescentModel model) +: model_ {std::move(model)} +, buffer_ {} +, cache_ {} +{ + buffer_.reserve(1); + cache_.reserve(100); +} + +bool CoalescentProbabilityGreater::operator()(const Haplotype& lhs, const Haplotype& rhs) const +{ + if (have_same_alleles(lhs, rhs)) return true; + auto cache_itr = cache_.find(lhs); + if (cache_itr == std::cend(cache_)) { + buffer_.assign({lhs}); + cache_itr = cache_.emplace(lhs, model_.evaluate(buffer_)).first; + } + const auto lhs_probability = cache_itr->second; + cache_itr = cache_.find(rhs); + if (cache_itr == std::cend(cache_)) { + buffer_.assign({rhs}); + cache_itr = cache_.emplace(rhs, model_.evaluate(buffer_)).first; + } + const auto rhs_probability = cache_itr->second; + return lhs_probability > 
rhs_probability; +} + } // namespace octopus diff --git a/src/core/models/mutation/coalescent_model.hpp b/src/core/models/mutation/coalescent_model.hpp index c82f813b4..2ef12ffe7 100644 --- a/src/core/models/mutation/coalescent_model.hpp +++ b/src/core/models/mutation/coalescent_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef coalescent_model_hpp @@ -20,6 +20,7 @@ #include "core/types/haplotype.hpp" #include "core/types/variant.hpp" +#include "indel_mutation_model.hpp" namespace octopus { @@ -54,15 +55,15 @@ class CoalescentModel void unprime() noexcept; bool is_primed() const noexcept; - template - double evaluate(const Container& haplotypes) const; - + // ln p(haplotype(s)) + double evaluate(const Haplotype& haplotype) const; + template double evaluate(const Container& haplotypes) const; double evaluate(const std::vector& haplotype_indices) const; private: using VariantReference = std::reference_wrapper; using SiteCountTuple = std::tuple; - using SiteCountIndelTuple = std::tuple; + using SiteCountIndelTuple = std::tuple; struct SiteCountTupleHash { @@ -73,7 +74,7 @@ class CoalescentModel }; Haplotype reference_; - std::vector reference_base_indel_heterozygosities_; + IndelMutationModel::ContextIndelModel indel_heterozygosity_model_; Parameters params_; std::vector haplotypes_; CachingStrategy caching_; @@ -90,21 +91,24 @@ class CoalescentModel double evaluate(unsigned k_snp, unsigned n) const; double evaluate(unsigned k_snp, unsigned k_indel, unsigned n) const; - template - void fill_site_buffer(const Container& haplotypes) const; + void fill_site_buffer(const Haplotype& haplotype) const; + template void fill_site_buffer(const Container& haplotypes) const; void fill_site_buffer(const std::vector& haplotype_indices) const; + void fill_site_buffer_uncached(const Haplotype& haplotype) const; void 
fill_site_buffer_from_value_cache(const Haplotype& haplotype) const; void fill_site_buffer_from_address_cache(const Haplotype& haplotype) const; - template - SiteCountTuple count_segregating_sites(const Container& haplotypes) const; + SiteCountTuple count_segregating_sites(const Haplotype& haplotype) const; + template SiteCountTuple count_segregating_sites(const Container& haplotypes) const; + SiteCountTuple count_segregating_sites_in_buffer(unsigned num_haplotypes) const; + double calculate_buffered_indel_heterozygosity() const; + double calculate_heterozygosity(const Variant& indel) const; }; template double CoalescentModel::evaluate(const Container& haplotypes) const { - const auto t = count_segregating_sites(haplotypes); - return evaluate(t); + return evaluate(count_segregating_sites(haplotypes)); } // private methods @@ -115,10 +119,10 @@ void CoalescentModel::fill_site_buffer(const Container& haplotypes) const assert(site_buffer2_.empty()); site_buffer1_.clear(); for (const Haplotype& haplotype : haplotypes) { - if (caching_ == CachingStrategy::address) { - fill_site_buffer_from_address_cache(haplotype); - } else { - fill_site_buffer_from_value_cache(haplotype); + switch (caching_) { + case CachingStrategy::address: fill_site_buffer_from_address_cache(haplotype); break; + case CachingStrategy::value: fill_site_buffer_from_value_cache(haplotype); break; + default: fill_site_buffer_uncached(haplotype); } site_buffer1_ = std::move(site_buffer2_); site_buffer2_.clear(); @@ -141,12 +145,21 @@ CoalescentModel::SiteCountTuple CoalescentModel::count_segregating_sites(const Container& haplotypes) const { fill_site_buffer(haplotypes); - const auto num_indels = std::count_if(std::cbegin(site_buffer1_), std::cend(site_buffer1_), - [] (const auto& v) noexcept { return is_indel(v); }); - return std::make_tuple(site_buffer1_.size() - num_indels, num_indels, - static_cast(detail::size(haplotypes) + 1)); + return count_segregating_sites_in_buffer(detail::size(haplotypes)); } 
+struct CoalescentProbabilityGreater +{ + CoalescentProbabilityGreater(CoalescentModel model); + + bool operator()(const Haplotype& lhs, const Haplotype& rhs) const; + +private: + CoalescentModel model_; + mutable std::vector buffer_; + mutable std::unordered_map, HaveSameAlleles> cache_; +}; + } // namespace octopus #endif diff --git a/src/core/models/mutation/denovo_model.cpp b/src/core/models/mutation/denovo_model.cpp index e20d1c15d..bea077dd8 100644 --- a/src/core/models/mutation/denovo_model.cpp +++ b/src/core/models/mutation/denovo_model.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "denovo_model.hpp" @@ -26,55 +26,22 @@ auto make_flat_hmm_model(const double snv_mutation_rate, const double indel_muta { auto mutation_penalty = static_cast(probability_to_phred(snv_mutation_rate).score()); auto gap_open_penalty = mutation_penalty; - auto gap_extension_penalty = static_cast(probability_to_phred(std::min(100 * indel_mutation_rate, max_rate)).score()); + auto gap_extension_penalty = static_cast(probability_to_phred(std::min(10'000 * indel_mutation_rate, max_rate)).score()); return hmm::FlatGapMutationModel {mutation_penalty, gap_open_penalty, gap_extension_penalty}; } -auto make_exponential_repeat_count_model(const double base_rate, const double repeat_count_log_multiplier, - const std::size_t max_repeat_count, - const boost::optional max_rate = boost::none) -{ - using Penalty = hmm::VariableGapOpenMutationModel::Penalty; - std::vector result(max_repeat_count); - const auto log_base_mutation_rate = std::log(base_rate); - boost::optional min_penalty {}; - if (max_rate) min_penalty = static_cast(std::max(std::log(*max_rate) / -maths::constants::ln10Div10<>, 1.0)); - for (unsigned i {0}; i < max_repeat_count; ++i) { - auto adjusted_log_rate = repeat_count_log_multiplier * i + log_base_mutation_rate; - auto 
adjusted_phred_rate = adjusted_log_rate / -maths::constants::ln10Div10<>; - if (min_penalty && adjusted_phred_rate < *min_penalty) { - std::fill(std::next(std::begin(result), i), std::end(result), *min_penalty); - break; - } else { - static constexpr double max_representable_penalty {127}; - result[i] = std::max(std::min(adjusted_phred_rate, max_representable_penalty), 1.0); - } - } - return result; -} - -auto make_gap_open_model(double indel_mutation_rate, std::size_t max_repeat_number, double max_rate = 0.2) -{ - return make_exponential_repeat_count_model(indel_mutation_rate, 1.5, max_repeat_number, max_rate); -} - -auto make_gap_extend_model(double indel_mutation_rate, std::size_t max_repeat_number) -{ - return make_exponential_repeat_count_model(100 * indel_mutation_rate, 1.6, max_repeat_number); -} - } // namespace DeNovoModel::DeNovoModel(Parameters parameters, std::size_t num_haplotypes_hint, CachingStrategy caching) : flat_mutation_model_ {make_flat_hmm_model(parameters.snv_mutation_rate, parameters.indel_mutation_rate)} -, repeat_length_gap_open_model_ {make_gap_open_model(parameters.indel_mutation_rate, 10)} -, repeat_length_gap_extend_model_ {make_gap_extend_model(parameters.indel_mutation_rate, 10)} +, indel_model_ {{parameters.indel_mutation_rate}} , min_ln_probability_ {} , num_haplotypes_hint_ {num_haplotypes_hint} , haplotypes_ {} , caching_ {caching} , gap_open_penalties_ {} -, gap_open_index_cache_ {} +, gap_extend_penalties_ {} +, gap_model_index_cache_ {} , value_cache_ {} , address_cache_ {} , guarded_index_cache_ {} @@ -83,7 +50,6 @@ DeNovoModel::DeNovoModel(Parameters parameters, std::size_t num_haplotypes_hint, , use_unguarded_ {false} { gap_open_penalties_.reserve(1000); - min_ln_probability_ = std::log(parameters.snv_mutation_rate) + std::log(parameters.indel_mutation_rate); if (caching_ == CachingStrategy::address) { address_cache_.reserve(num_haplotypes_hint_ * num_haplotypes_hint_); } else if (caching == CachingStrategy::value) { @@ 
-97,7 +63,7 @@ void DeNovoModel::prime(std::vector haplotypes) if (is_primed()) throw std::runtime_error {"DeNovoModel: already primed"}; constexpr std::size_t max_unguardered {50}; haplotypes_ = std::move(haplotypes); - gap_open_index_cache_.resize(haplotypes_.size()); + gap_model_index_cache_.resize(haplotypes_.size()); if (haplotypes_.size() <= max_unguardered) { unguarded_index_cache_.assign(haplotypes_.size(), std::vector(haplotypes_.size(), 0)); for (unsigned target {0}; target < haplotypes_.size(); ++target) { @@ -117,8 +83,12 @@ void DeNovoModel::unprime() noexcept { haplotypes_.clear(); haplotypes_.shrink_to_fit(); - gap_open_index_cache_.clear(); - gap_open_index_cache_.shrink_to_fit(); + gap_open_penalties_.clear(); + gap_open_penalties_.shrink_to_fit(); + gap_extend_penalties_.clear(); + gap_extend_penalties_.shrink_to_fit(); + gap_model_index_cache_.clear(); + gap_model_index_cache_.shrink_to_fit(); guarded_index_cache_.clear(); guarded_index_cache_.shrink_to_fit(); unguarded_index_cache_.clear(); @@ -141,7 +111,7 @@ double DeNovoModel::evaluate(const Haplotype& target, const Haplotype& given) co } } -double DeNovoModel::evaluate(const unsigned target, const unsigned given) const noexcept +double DeNovoModel::evaluate(const unsigned target, const unsigned given) const { if (use_unguarded_) { return unguarded_index_cache_[target][given]; @@ -191,7 +161,7 @@ auto sequence_length_distance(const Haplotype& lhs, const Haplotype& rhs) noexce bool can_align_with_hmm(const Haplotype& target, const Haplotype& given) noexcept { - return sequence_length_distance(target, given) <= 10 * hmm::min_flank_pad(); + return sequence_length_distance(target, given) < hmm::min_flank_pad(); } template @@ -201,145 +171,97 @@ void rotate_right(Container& c, const std::size_t n) std::rotate(std::rbegin(c), std::next(std::rbegin(c), n), std::rend(c)); } -double hmm_align(const Haplotype::NucleotideSequence& target, const Haplotype::NucleotideSequence& given, - const 
hmm::FlatGapMutationModel& model, const boost::optional min_ln_probability) noexcept +double calculate_score(const Variant& variant, const Haplotype& context, + const hmm::VariableGapExtendMutationModel& mutation_model) { - const auto p = hmm::evaluate(target, given, model); - return min_ln_probability ? std::max(p, *min_ln_probability) : p; + if (is_indel(variant)) { + assert(contains(context, variant)); + const auto offset = static_cast(begin_distance(context, variant)); + const auto indel_length = indel_size(variant); + assert(offset < mutation_model.gap_open.size()); + constexpr decltype(indel_length) max_indel_length {50}; + return mutation_model.gap_open[offset] + (std::min(indel_length, max_indel_length) - 1) * mutation_model.gap_extend[offset]; + } else { + return mutation_model.mutation; + } } -double hmm_align(const Haplotype::NucleotideSequence& target, const Haplotype::NucleotideSequence& given, - const hmm::VariableGapOpenMutationModel& model, const boost::optional min_ln_probability) noexcept +double approx_align(const Haplotype& target, const Haplotype& given, + const hmm::VariableGapExtendMutationModel& mutation_model) { - const auto p = hmm::evaluate(target, given, model); - return min_ln_probability ? 
std::max(p, *min_ln_probability) : p; + double score {0}; + const auto variants = target.difference(given); + for (const auto& variant : variants) { + score += calculate_score(variant, given, mutation_model); + } + return score * -maths::constants::ln10Div10<>; } -double approx_align(const Haplotype& target, const Haplotype& given, const hmm::FlatGapMutationModel& model, - const boost::optional min_ln_probability) +void set_penalties(const IndelMutationModel::ContextIndelModel::ProbabilityVector& probabilities, + hmm::VariableGapExtendMutationModel::PenaltyVector& penalties) { - using maths::constants::ln10Div10; - const auto indel_size = sequence_length_distance(target, given); - double score = {model.gap_open + model.gap_extend * (static_cast(indel_size) - 1)}; - if (min_ln_probability) { - const auto max_score = -*min_ln_probability / ln10Div10<>; - if (score >= max_score) { - return *min_ln_probability; - } - } - const auto variants = target.difference(given); - const auto mismatch_size = std::accumulate(std::cbegin(variants), std::cend(variants), 0, - [](auto curr, const auto& variant) noexcept { - return curr + (!is_indel(variant) ? region_size(variant) : 0); - }); - score += mismatch_size * model.mutation; - return min_ln_probability ? 
std::max(-ln10Div10<> * score, *min_ln_probability) : -ln10Div10<> * score; + assert(probabilities.size() == penalties.size()); + std::transform(std::cbegin(probabilities), std::cend(probabilities), std::begin(penalties), + [] (auto rate) { + static constexpr double max_penalty {127}; + return std::max(std::min(std::log(rate) / -maths::constants::ln10Div10<>, max_penalty), 1.0); + }); } } // namespace -boost::optional DeNovoModel::set_gap_open_penalties(const Haplotype& given) const +void DeNovoModel::set_gap_penalties(const Haplotype& given) const { - const auto repeats = get_short_tandem_repeats(given.sequence()); - if (!repeats.empty()) { - gap_open_penalties_.assign(sequence_size(given), flat_mutation_model_.gap_open); - const auto max_num_repeats = static_cast(repeat_length_gap_open_model_.size()); - unsigned max_repeat_number {0}; - for (const auto& repeat : repeats) { - const auto num_repeats = repeat.length / repeat.period; - assert(num_repeats > 0); - const auto penalty = repeat_length_gap_open_model_[std::min(num_repeats - 1, max_num_repeats - 1)]; - assert(repeat.pos + repeat.length <= gap_open_penalties_.size()); - std::fill_n(std::next(std::begin(gap_open_penalties_), repeat.pos), repeat.length, penalty); - max_repeat_number = std::max(num_repeats, max_repeat_number); - } - return max_repeat_number; - } else { - gap_open_penalties_.clear(); - return boost::none; - } + const auto contextual_indel_model = indel_model_.evaluate(given); + const auto num_bases = sequence_size(given); + assert(contextual_indel_model.gap_open.size() == num_bases); + gap_open_penalties_.resize(num_bases); + set_penalties(contextual_indel_model.gap_open, gap_open_penalties_); + assert(contextual_indel_model.gap_extend.size() == num_bases); + gap_extend_penalties_.resize(num_bases); + set_penalties(contextual_indel_model.gap_extend, gap_extend_penalties_); } -boost::optional DeNovoModel::set_gap_open_penalties(const unsigned given) const +void 
DeNovoModel::set_gap_penalties(const unsigned given) const { - assert(given < gap_open_index_cache_.size()); - auto& cached_result = gap_open_index_cache_[given]; + assert(given < gap_model_index_cache_.size()); + auto& cached_result = gap_model_index_cache_[given]; if (cached_result) { - if (cached_result->second) { - gap_open_penalties_ = cached_result->first; - } - return cached_result->second; + gap_open_penalties_ = cached_result->first; + gap_extend_penalties_ = cached_result->second; } else { - auto result = set_gap_open_penalties(haplotypes_[given]); - cached_result = std::make_pair(gap_open_penalties_, result); - return result; + set_gap_penalties(haplotypes_[given]); + cached_result = std::make_pair(gap_open_penalties_, gap_extend_penalties_); } } -hmm::VariableGapOpenMutationModel DeNovoModel::make_variable_hmm_model(const unsigned max_repeat_number) const +hmm::VariableGapExtendMutationModel DeNovoModel::make_hmm_model_from_cache() const { - assert(max_repeat_number > 0); - auto extension_idx = std::min(repeat_length_gap_extend_model_.size() - 1, static_cast(max_repeat_number) - 1); - auto extension_penalty = repeat_length_gap_extend_model_[extension_idx]; - return {flat_mutation_model_.mutation, gap_open_penalties_, extension_penalty}; + return {flat_mutation_model_.mutation, gap_open_penalties_, gap_extend_penalties_}; } -double DeNovoModel::evaluate_uncached(const Haplotype& target, const Haplotype& given) const +double DeNovoModel::evaluate_uncached(const Haplotype& target, const Haplotype& given, const bool gap_penalties_cached) const { - if (sequence_size(target) == sequence_size(given)) { + if (!gap_penalties_cached) set_gap_penalties(given); + const auto mutation_model = make_hmm_model_from_cache(); + double result; + if (can_align_with_hmm(target, given)) { pad_given(target, given, padded_given_); - return hmm_align(target.sequence(), padded_given_, flat_mutation_model_, min_ln_probability_); + gap_open_penalties_.resize(padded_given_.size(), 
mutation_model.mutation); + rotate_right(gap_open_penalties_, hmm::min_flank_pad()); + gap_extend_penalties_.resize(padded_given_.size(), mutation_model.mutation); + rotate_right(gap_extend_penalties_, hmm::min_flank_pad()); + result = hmm::evaluate(target.sequence(), padded_given_, mutation_model); } else { - const auto max_repeat_number = set_gap_open_penalties(given); - if (max_repeat_number) { - if (can_align_with_hmm(target, given)) { - pad_given(target, given, padded_given_); - gap_open_penalties_.resize(padded_given_.size(), flat_mutation_model_.gap_open); - rotate_right(gap_open_penalties_, hmm::min_flank_pad()); - const auto model = make_variable_hmm_model(*max_repeat_number); - return hmm_align(target.sequence(), padded_given_, model, min_ln_probability_); - } else { - return approx_align(target, given, flat_mutation_model_, min_ln_probability_); - } - } else { - if (can_align_with_hmm(target, given)) { - pad_given(target, given, padded_given_); - return hmm_align(target.sequence(), padded_given_, flat_mutation_model_, min_ln_probability_); - } else { - return approx_align(target, given, flat_mutation_model_, min_ln_probability_); - } - } + result = approx_align(target, given, mutation_model); } + return min_ln_probability_ ? 
std::max(result, *min_ln_probability_) : result; } double DeNovoModel::evaluate_uncached(const unsigned target_idx, const unsigned given_idx) const { - const auto& target = haplotypes_[target_idx]; - const auto& given = haplotypes_[given_idx]; - if (sequence_size(target) == sequence_size(given)) { - pad_given(target, given, padded_given_); - return hmm_align(target.sequence(), padded_given_, flat_mutation_model_, min_ln_probability_); - } else { - const auto max_repeat_length = set_gap_open_penalties(given_idx); - if (max_repeat_length) { - if (can_align_with_hmm(target, given)) { - pad_given(target, given, padded_given_); - gap_open_penalties_.resize(padded_given_.size(), flat_mutation_model_.gap_open); - rotate_right(gap_open_penalties_, hmm::min_flank_pad()); - const auto model = make_variable_hmm_model(*max_repeat_length); - return hmm_align(target.sequence(), padded_given_, model, min_ln_probability_); - } else { - return approx_align(target, given, flat_mutation_model_, min_ln_probability_); - } - } else { - if (can_align_with_hmm(target, given)) { - pad_given(target, given, padded_given_); - return hmm_align(target.sequence(), padded_given_, flat_mutation_model_, min_ln_probability_); - } else { - return approx_align(target, given, flat_mutation_model_, min_ln_probability_); - } - } - } + set_gap_penalties(given_idx); + return evaluate_uncached(haplotypes_[target_idx], haplotypes_[given_idx], true); } double DeNovoModel::evaluate_basic_cache(const Haplotype& target, const Haplotype& given) const diff --git a/src/core/models/mutation/denovo_model.hpp b/src/core/models/mutation/denovo_model.hpp index b896b713a..9c13360c6 100644 --- a/src/core/models/mutation/denovo_model.hpp +++ b/src/core/models/mutation/denovo_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef denovo_model_hpp @@ -14,6 +14,7 @@ #include "core/types/haplotype.hpp" #include "../pairhmm/pair_hmm.hpp" +#include "indel_mutation_model.hpp" namespace octopus { @@ -46,8 +47,7 @@ class DeNovoModel // ln p(target | given) double evaluate(const Haplotype& target, const Haplotype& given) const; - - double evaluate(unsigned target, unsigned given) const noexcept; + double evaluate(unsigned target, unsigned given) const; private: struct AddressPairHash @@ -60,17 +60,18 @@ class DeNovoModel } }; - using PenaltyVector = hmm::VariableGapOpenMutationModel::PenaltyVector; - using GapOpenResult = std::pair>; + using PenaltyVector = hmm::VariableGapOpenMutationModel::PenaltyVector; + using GapPenaltyModel = std::pair; hmm::FlatGapMutationModel flat_mutation_model_; - std::vector repeat_length_gap_open_model_, repeat_length_gap_extend_model_; + IndelMutationModel indel_model_; boost::optional min_ln_probability_; std::size_t num_haplotypes_hint_; std::vector haplotypes_; CachingStrategy caching_; - mutable PenaltyVector gap_open_penalties_; - mutable std::vector> gap_open_index_cache_; + + mutable PenaltyVector gap_open_penalties_, gap_extend_penalties_; + mutable std::vector> gap_model_index_cache_; mutable std::unordered_map> value_cache_; mutable std::unordered_map, double, AddressPairHash> address_cache_; mutable std::vector>> guarded_index_cache_; @@ -78,10 +79,10 @@ class DeNovoModel mutable std::string padded_given_; mutable bool use_unguarded_; - boost::optional set_gap_open_penalties(const Haplotype& given) const; - boost::optional set_gap_open_penalties(unsigned given) const; - hmm::VariableGapOpenMutationModel make_variable_hmm_model(unsigned max_repeat_length) const; - double evaluate_uncached(const Haplotype& target, const Haplotype& given) const; + void set_gap_penalties(const Haplotype& given) const; + void set_gap_penalties(unsigned given) const; + hmm::VariableGapExtendMutationModel make_hmm_model_from_cache() const; + double evaluate_uncached(const 
Haplotype& target, const Haplotype& given, bool gap_penalties_cached = false) const; double evaluate_uncached(unsigned target, unsigned given) const; double evaluate_basic_cache(const Haplotype& target, const Haplotype& given) const; double evaluate_address_cache(const Haplotype& target, const Haplotype& given) const; diff --git a/src/core/models/mutation/indel_mutation_model.cpp b/src/core/models/mutation/indel_mutation_model.cpp new file mode 100644 index 000000000..254b6c5fc --- /dev/null +++ b/src/core/models/mutation/indel_mutation_model.cpp @@ -0,0 +1,96 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#include "indel_mutation_model.hpp" + +#include +#include + +#include "utils/maths.hpp" +#include "utils/repeat_finder.hpp" + +namespace octopus { + +namespace { + +double calculate_gap_open_rate(const double base_rate, const unsigned period, const unsigned num_periods) +{ + return base_rate * std::pow(10.0, (3.0 / (6 + std::min(2 * period, 12u))) * period * num_periods); +} + +double calculate_gap_extend_rate(const double base_rate, const unsigned period, const unsigned num_periods, + const double gap_open_rate) +{ + return std::max(1'000 * gap_open_rate, 0.7); +} + +} // namespace + +IndelMutationModel::IndelMutationModel(Parameters params) +: params_ {std::move(params)} +, indel_repeat_model_ {params_.max_period + 1, std::vector(params_.max_periodicity + 1)} +{ + for (unsigned period {0}; period <= params_.max_period; ++period) { + for (unsigned n {0}; n <= params_.max_periodicity; ++n) { + const auto open_rate = calculate_gap_open_rate(params.indel_mutation_rate, period, n); + indel_repeat_model_[period][n].open = std::min(open_rate, params_.max_open_probability); + const auto extend_rate = calculate_gap_extend_rate(params.indel_mutation_rate, period, n, open_rate); + indel_repeat_model_[period][n].extend = std::min(extend_rate, params_.max_extend_probability); + 
} + } +} + +namespace { + +auto find_short_tandem_repeats(const Haplotype& haplotype) +{ + constexpr unsigned max_repeat_period {5}; + return find_exact_tandem_repeats(haplotype.sequence(), haplotype.mapped_region(), 1, max_repeat_period); +} + +template +auto fill_if_greater(FordwardIt first, FordwardIt last, const Tp& value) +{ + return std::transform(first, last, first, [&] (const auto& x) { return std::max(x, value); }); +} + +template +auto fill_n_if_greater(FordwardIt first, std::size_t n, const Tp& value) +{ + return fill_if_greater(first, std::next(first, n), value); +} + +} // namespace + +IndelMutationModel::ContextIndelModel IndelMutationModel::evaluate(const Haplotype& haplotype) const +{ + const auto repeats = find_short_tandem_repeats(haplotype); + ContextIndelModel result {}; + const auto haplotype_len = sequence_size(haplotype); + const auto& base_probabilities = indel_repeat_model_[0][0]; + result.gap_open.resize(haplotype_len, base_probabilities.open); + result.gap_extend.resize(haplotype_len, base_probabilities.extend); + for (const auto& repeat : repeats) { + assert(repeat.period > 0 && repeat.period <= params_.max_period); + const auto repeat_offset = static_cast(begin_distance(haplotype, repeat)); + const auto repeat_len = region_size(repeat); + const auto num_repeats = static_cast(repeat_len / repeat.period); + assert(num_repeats > 0); + const auto& repeat_state = indel_repeat_model_[repeat.period][std::min(num_repeats, params_.max_periodicity)]; + assert(repeat_offset + repeat_len <= result.gap_open.size()); + fill_n_if_greater(std::next(std::begin(result.gap_open), repeat_offset), repeat_len, repeat_state.open); + assert(repeat_offset + repeat_len <= result.gap_extend.size()); + fill_n_if_greater(std::next(std::begin(result.gap_extend), repeat_offset), repeat_len, repeat_state.extend); + } + return result; +} + +// non-member methods + +IndelMutationModel::ContextIndelModel make_indel_model(const Haplotype& context, 
IndelMutationModel::Parameters params) +{ + IndelMutationModel model {params}; + return model.evaluate(context); +} + +} // namespace octopus diff --git a/src/core/models/mutation/indel_mutation_model.hpp b/src/core/models/mutation/indel_mutation_model.hpp new file mode 100644 index 000000000..cb80b49be --- /dev/null +++ b/src/core/models/mutation/indel_mutation_model.hpp @@ -0,0 +1,57 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#ifndef indel_mutation_model_hpp +#define indel_mutation_model_hpp + +#include +#include + +#include "core/types/haplotype.hpp" +#include "core/types/variant.hpp" + +namespace octopus { + +class IndelMutationModel +{ +public: + struct Parameters + { + double indel_mutation_rate; + unsigned max_period = 10, max_periodicity = 20; + double max_open_probability = 1.0, max_extend_probability = 1.0; + }; + + struct ContextIndelModel + { + using Probability = double; + using ProbabilityVector = std::vector; + ProbabilityVector gap_open, gap_extend; + }; + + IndelMutationModel() = delete; + + IndelMutationModel(Parameters params); + + IndelMutationModel(const IndelMutationModel&) = default; + IndelMutationModel& operator=(const IndelMutationModel&) = default; + IndelMutationModel(IndelMutationModel&&) = default; + IndelMutationModel& operator=(IndelMutationModel&&) = default; + + ~IndelMutationModel() = default; + + ContextIndelModel evaluate(const Haplotype& haplotype) const; + +private: + struct ModelCell { double open, extend; }; + using RepeatModel = std::vector>; + + Parameters params_; + RepeatModel indel_repeat_model_; +}; + +IndelMutationModel::ContextIndelModel make_indel_model(const Haplotype& context, IndelMutationModel::Parameters params); + +} // namespace octopus + +#endif diff --git a/src/core/models/mutation/somatic_mutation_model.cpp b/src/core/models/mutation/somatic_mutation_model.cpp index 887398f65..d3cdfb4f4 100644 --- 
a/src/core/models/mutation/somatic_mutation_model.cpp +++ b/src/core/models/mutation/somatic_mutation_model.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "somatic_mutation_model.hpp" @@ -37,7 +37,7 @@ double SomaticMutationModel::evaluate(const Haplotype& somatic, const Haplotype& return model_.evaluate(somatic, germline); } -double SomaticMutationModel::evaluate(unsigned somatic, unsigned germline) const noexcept +double SomaticMutationModel::evaluate(unsigned somatic, unsigned germline) const { return model_.evaluate(somatic, germline); } diff --git a/src/core/models/mutation/somatic_mutation_model.hpp b/src/core/models/mutation/somatic_mutation_model.hpp index 170c10fc5..7718d6ffe 100644 --- a/src/core/models/mutation/somatic_mutation_model.hpp +++ b/src/core/models/mutation/somatic_mutation_model.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef somatic_mutation_model_hpp @@ -38,7 +38,7 @@ class SomaticMutationModel // ln p(somatic | germline) double evaluate(const Haplotype& somatic, const Haplotype& germline) const; - double evaluate(unsigned somatic, unsigned germline) const noexcept; + double evaluate(unsigned somatic, unsigned germline) const; private: DeNovoModel model_; diff --git a/src/core/models/pairhmm/pair_hmm.cpp b/src/core/models/pairhmm/pair_hmm.cpp index 1e4d0c300..a29e82665 100644 --- a/src/core/models/pairhmm/pair_hmm.cpp +++ b/src/core/models/pairhmm/pair_hmm.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "pair_hmm.hpp" @@ -193,10 +193,10 @@ auto simd_align(const std::string& truth, const std::string& target, } } -auto simd_align_with_cigar(const std::string& truth, const std::string& target, - const std::vector& target_qualities, - const std::size_t target_offset, - const MutationModel& model) noexcept +Alignment simd_align_with_cigar(const std::string& truth, const std::string& target, + const std::vector& target_qualities, + const std::size_t target_offset, + const MutationModel& model) noexcept { constexpr auto pad = simd::min_flank_pad(); const auto truth_size = static_cast(truth.size()); @@ -204,7 +204,7 @@ auto simd_align_with_cigar(const std::string& truth, const std::string& target, const auto truth_alignment_size = static_cast(target_size + 2 * pad - 1); const auto alignment_offset = std::max(0, static_cast(target_offset) - pad); if (alignment_offset + truth_alignment_size > truth_size) { - return std::make_pair(CigarString {}, std::numeric_limits::lowest()); + return {0, CigarString {}, std::numeric_limits::lowest()}; } const auto qualities = reinterpret_cast(target_qualities.data()); thread_local std::vector align1 {}, align2 {}; @@ -261,7 +261,8 @@ auto simd_align_with_cigar(const std::string& truth, const std::string& target, // Overflow has occurred when calculating score; } } - return std::make_pair(make_cigar(align1, align2), -ln10Div10<> * static_cast(score)); + auto mapping_position = target_offset - pad + first_pos; + return {mapping_position, make_cigar(align1, align2), -ln10Div10<> * static_cast(score)}; } unsigned min_flank_pad() noexcept @@ -304,6 +305,8 @@ double evaluate(const std::string& target, const std::string& truth, const auto m2 = std::mismatch(next(m1.first), cend(target), next(m1.second)); if (m2.first == cend(target)) { // then there is only a single base difference between the sequences, can optimise + // target: ACGTACGT + // truth: ACGTTCGT const auto truth_mismatch_idx = distance(offsetted_truth_begin_itr, m1.second) 
+ target_offset; if (truth_mismatch_idx < model.lhs_flank_size || truth_mismatch_idx >= (truth.size() - model.rhs_flank_size)) { return 0; @@ -314,30 +317,85 @@ double evaluate(const std::string& target, const std::string& truth, mispatch_penalty = std::min(target_qualities[target_index], static_cast(model.snv_priors[truth_mismatch_idx])); } - if (mispatch_penalty <= model.gap_open[truth_mismatch_idx] - || !std::equal(next(m1.first), cend(target), m1.second)) { + if (mispatch_penalty <= model.gap_open[truth_mismatch_idx]) { return lnProbability[mispatch_penalty]; + } else { + if (std::equal(next(m1.first), cend(target), m1.second)) { + // target: AAAAGGGG + // truth: AAA GGGGG + return lnProbability[model.gap_open[truth_mismatch_idx]]; + } else if (std::equal(m1.first, cend(target), next(m1.second))) { + // target: AAA GGGGG + // truth: AAAAGGGGG + return lnProbability[model.gap_open[truth_mismatch_idx]]; + } else if (mispatch_penalty <= (model.gap_open[truth_mismatch_idx] + model.gap_extend)) { + return lnProbability[mispatch_penalty]; + } } - return lnProbability[model.gap_open[truth_mismatch_idx]]; } // TODO: we should be able to optimise the alignment based of the first mismatch postition return simd_align(truth, target, target_qualities, target_offset, model); } -std::pair -align(const std::string& target, const std::string& truth, - const std::vector& target_qualities, - std::size_t target_offset, - const MutationModel& model) +Alignment align(const std::string& target, const std::string& truth, + const std::vector& target_qualities, + std::size_t target_offset, + const MutationModel& model) { validate(truth, target, target_qualities, target_offset, model); if (std::equal(std::cbegin(target), std::cend(target), std::next(std::cbegin(truth), target_offset))) { - return {{CigarOperation {static_cast(target.size()), CigarOperation::Flag::sequenceMatch}}, 0}; + return {target_offset, + {CigarOperation {static_cast(target.size()), 
CigarOperation::Flag::sequenceMatch}}, + 0}; } else { return simd_align_with_cigar(truth, target, target_qualities, target_offset, model); } } +double evaluate(const std::string& target, const std::string& truth, const VariableGapExtendMutationModel& model) noexcept +{ + assert(truth.size() == model.gap_open.size()); + using std::cbegin; using std::cend; using std::next; using std::distance; + static constexpr auto lnProbability = make_phred_to_ln_prob_lookup(); + const auto truth_begin = next(cbegin(truth), min_flank_pad()); + const auto m1 = std::mismatch(cbegin(target), cend(target), truth_begin); + if (m1.first == cend(target)) { + return 0; // sequences are equal, can't do better than this + } + const auto m2 = std::mismatch(next(m1.first), cend(target), next(m1.second)); + if (m2.first == cend(target)) { + // target: ACGTACGT + // truth: ACGTTCGT + const auto truth_mismatch_idx = static_cast(distance(cbegin(truth), m1.second)); + if (model.mutation <= model.gap_open[truth_mismatch_idx]) { + return lnProbability[model.mutation]; + } else { + if (std::equal(next(m1.first), cend(target), m1.second)) { + // target: AAAAGGGG + // truth: AAA GGGGG + return lnProbability[model.gap_open[truth_mismatch_idx]]; + } else if (std::equal(m1.first, cend(target), next(m1.second))) { + // target: AAA GGGGG + // truth: AAAAGGGGG + return lnProbability[model.gap_open[truth_mismatch_idx]]; + } else if (model.mutation <= (model.gap_open[truth_mismatch_idx] + model.gap_extend[truth_mismatch_idx])) { + return lnProbability[model.mutation]; + } + } + } + const auto truth_alignment_size = static_cast(target.size() + 2 * min_flank_pad() - 1); + thread_local std::vector dummy_qualities; + dummy_qualities.assign(target.size(), model.mutation); + auto score = simd::align(truth.c_str(), target.c_str(), + dummy_qualities.data(), + truth_alignment_size, + static_cast(target.size()), + model.gap_open.data(), + model.gap_extend.data(), + model.nuc_prior); + return -ln10Div10<> * 
static_cast(score); +} + double evaluate(const std::string& target, const std::string& truth, const VariableGapOpenMutationModel& model) noexcept { assert(truth.size() == model.gap_open.size()); @@ -350,12 +408,24 @@ double evaluate(const std::string& target, const std::string& truth, const Varia } const auto m2 = std::mismatch(next(m1.first), cend(target), next(m1.second)); if (m2.first == cend(target)) { - // then there is only a single base difference between the sequences, can optimise + // target: ACGTACGT + // truth: ACGTTCGT const auto truth_mismatch_idx = static_cast(distance(cbegin(truth), m1.second)); - if (model.mutation <= model.gap_open[truth_mismatch_idx] || !std::equal(next(m1.first), cend(target), m1.second)) { - return lnProbability[model.gap_open[truth_mismatch_idx]]; + if (model.mutation <= model.gap_open[truth_mismatch_idx]) { + return lnProbability[model.mutation]; + } else { + if (std::equal(next(m1.first), cend(target), m1.second)) { + // target: AAAAGGGG + // truth: AAA GGGGG + return lnProbability[model.gap_open[truth_mismatch_idx]]; + } else if (std::equal(m1.first, cend(target), next(m1.second))) { + // target: AAA GGGGG + // truth: AAAAGGGGG + return lnProbability[model.gap_open[truth_mismatch_idx]]; + } else if (model.mutation <= (model.gap_open[truth_mismatch_idx] + model.gap_extend)) { + return lnProbability[model.mutation]; + } } - return lnProbability[model.mutation]; } const auto truth_alignment_size = static_cast(target.size() + 2 * min_flank_pad() - 1); thread_local std::vector dummy_qualities; @@ -365,7 +435,8 @@ double evaluate(const std::string& target, const std::string& truth, const Varia truth_alignment_size, static_cast(target.size()), model.gap_open.data(), - model.gap_extend, 2); + model.gap_extend, + model.nuc_prior); return -ln10Div10<> * static_cast(score); } @@ -380,11 +451,23 @@ double evaluate(const std::string& target, const std::string& truth, const FlatG } const auto m2 = std::mismatch(next(m1.first), 
cend(target), next(m1.second)); if (m2.first == cend(target)) { - // then there is only a single base difference between the sequences, can optimise - if (model.mutation <= model.gap_open || !std::equal(next(m1.first), cend(target), m1.second)) { - return lnProbability[model.gap_open]; + // target: ACGTACGT + // truth: ACGTTCGT + if (model.mutation <= model.gap_open) { + return lnProbability[model.mutation]; + } else { + if (std::equal(next(m1.first), cend(target), m1.second)) { + // target: AAAAGGGG + // truth: AAA GGGGG + return lnProbability[model.gap_open]; + } else if (std::equal(m1.first, cend(target), next(m1.second))) { + // target: AAA GGGGG + // truth: AAAAGGGGG + return lnProbability[model.gap_open]; + } else if (model.mutation <= (model.gap_open + model.gap_extend)) { + return lnProbability[model.mutation]; + } } - return lnProbability[model.mutation]; } const auto truth_alignment_size = static_cast(target.size() + 2 * min_flank_pad() - 1); thread_local std::vector dummy_qualities; @@ -394,7 +477,8 @@ double evaluate(const std::string& target, const std::string& truth, const FlatG truth_alignment_size, static_cast(target.size()), model.gap_open, - model.gap_extend, 2); + model.gap_extend, + model.nuc_prior); return -ln10Div10<> * static_cast(score); } diff --git a/src/core/models/pairhmm/pair_hmm.hpp b/src/core/models/pairhmm/pair_hmm.hpp index 5e06b011b..d54083e62 100644 --- a/src/core/models/pairhmm/pair_hmm.hpp +++ b/src/core/models/pairhmm/pair_hmm.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef pair_hmm_hpp @@ -29,6 +29,16 @@ struct MutationModel std::size_t lhs_flank_size = 0, rhs_flank_size = 0; }; +struct VariableGapExtendMutationModel +{ + using Penalty = std::int8_t; + using PenaltyVector = std::vector; + Penalty mutation; + const PenaltyVector& gap_open; + const std::vector& gap_extend; + short nuc_prior = 2; +}; + struct VariableGapOpenMutationModel { using Penalty = std::int8_t; @@ -36,12 +46,21 @@ struct VariableGapOpenMutationModel Penalty mutation; const PenaltyVector& gap_open; short gap_extend; + short nuc_prior = 2; }; struct FlatGapMutationModel { std::int8_t mutation; short gap_open, gap_extend; + short nuc_prior = 2; +}; + +struct Alignment +{ + std::size_t target_offset; + CigarString cigar; + double likelihood; }; // p(target | truth, target_qualities, target_offset, model) @@ -53,11 +72,17 @@ double evaluate(const std::string& target, const std::string& truth, std::size_t target_offset, const MutationModel& model); -std::pair -align(const std::string& target, const std::string& truth, - const std::vector& target_qualities, - std::size_t target_offset, - const MutationModel& model); +Alignment align(const std::string& target, const std::string& truth, + const std::vector& target_qualities, + std::size_t target_offset, + const MutationModel& model); + +// p(target | truth, model) +// +// Warning: The target must be contained by the truth by exactly +// min_flank_pad() on either side. 
+double evaluate(const std::string& target, const std::string& truth, + const VariableGapExtendMutationModel& model) noexcept; // p(target | truth, model) // diff --git a/src/core/models/pairhmm/simd_pair_hmm.cpp b/src/core/models/pairhmm/simd_pair_hmm.cpp index d5f1ae26f..939879e50 100755 --- a/src/core/models/pairhmm/simd_pair_hmm.cpp +++ b/src/core/models/pairhmm/simd_pair_hmm.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke and Gerton Lunter +// Copyright (c) 2015-2018 Daniel Cooke and Gerton Lunter // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #if __GNUC__ >= 6 @@ -307,6 +307,113 @@ int align(const char* truth, const char* target, const std::int8_t* qualities, return (minscore + 0x8000) >> 2; } +int align(const char* truth, const char* target, const std::int8_t* qualities, + const int truth_len, const int target_len, + const std::int8_t* gap_open, const std::int8_t* gap_extend, + short nuc_prior) noexcept +{ + assert(truth_len > bandSize && (truth_len == target_len + 2 * bandSize - 1)); + + nuc_prior <<= 2; + + using SimdInt = __m128i; + + SimdInt _m1 {_mm_set1_epi16(inf)}; + auto _i1 = _m1; + auto _d1 = _m1; + auto _m2 = _m1; + auto _i2 = _m1; + auto _d2 = _m1; + + SimdInt _nuc_prior {_mm_set1_epi16(nuc_prior)}; + + SimdInt _initmask {_mm_set_epi16(0,0,0,0,0,0,0,-1)}; + SimdInt _initmask2 {_mm_set_epi16(0,0,0,0,0,0,0,-0x8000)}; + + // truth is initialized with the n-long prefix, in forward direction + // target is initialized as empty; reverse direction + SimdInt _truthwin {_mm_set_epi16(truth[7], truth[6], truth[5], truth[4], + truth[3], truth[2], truth[1], truth[0])}; + SimdInt _targetwin {_m1}; + SimdInt _qualitieswin {_mm_set1_epi16(64 << 2)}; + + // if N, make nScore; if != N, make inf + SimdInt _truthnqual {_mm_add_epi16(_mm_and_si128(_mm_cmpeq_epi16(_truthwin, _mm_set1_epi16('N')), + _mm_set1_epi16(nScore - inf)), + _mm_set1_epi16(inf))}; + + SimdInt _gap_open {_mm_set_epi16(gap_open[7] << 
2,gap_open[6] << 2,gap_open[5] << 2,gap_open[4] << 2, + gap_open[3] << 2,gap_open[2] << 2,gap_open[1] << 2,gap_open[0] << 2)}; + SimdInt _gap_extend {_mm_set_epi16(gap_extend[7] << 2,gap_extend[6] << 2,gap_extend[5] << 2,gap_extend[4] << 2, + gap_extend[3] << 2,gap_extend[2] << 2,gap_extend[1] << 2,gap_extend[0] << 2)}; + + short minscore {inf}; + + for (int s {0}; s <= 2 * (target_len + bandSize); s += 2) { + // truth is current; target needs updating + _targetwin = _mm_slli_si128(_targetwin, 2); + _qualitieswin = _mm_slli_si128(_qualitieswin, 2); + + if (s / 2 < target_len) { + _targetwin = _mm_insert_epi16(_targetwin, target[s / 2], 0); + _qualitieswin = _mm_insert_epi16(_qualitieswin, qualities[s / 2] << 2, 0); + } else { + _targetwin = _mm_insert_epi16(_targetwin, '0', 0); + _qualitieswin = _mm_insert_epi16(_qualitieswin, 64 << 2, 0); + } + + // S even + _m1 = _mm_or_si128(_initmask2, _mm_andnot_si128(_initmask, _m1)); + _m2 = _mm_or_si128(_initmask2, _mm_andnot_si128(_initmask, _m2)); + _m1 = _mm_min_epi16(_m1, _mm_min_epi16(_i1, _d1)); + + if (s / 2 >= target_len) { + minscore = std::min(static_cast(extract_epi16(_m1, s / 2 - target_len)), minscore); + } + + _m1 = _mm_add_epi16(_m1, _mm_min_epi16(_mm_andnot_si128(_mm_cmpeq_epi16(_targetwin, _truthwin), + _qualitieswin), _truthnqual)); + _d1 = _mm_min_epi16(_mm_add_epi16(_d2, _gap_extend), + _mm_add_epi16(_mm_min_epi16(_m2, _i2), + _mm_srli_si128(_gap_open, 2))); // allow I->D + _d1 = _mm_insert_epi16(_mm_slli_si128(_d1, 2), inf, 0); + _i1 = _mm_add_epi16(_mm_min_epi16(_mm_add_epi16(_i2, _gap_extend), + _mm_add_epi16(_m2, _gap_open)), + _nuc_prior); + + // S odd; truth needs updating; target is current + const auto pos = bandSize + s / 2; + const char base {(pos < truth_len) ? truth[pos] : 'N'}; + _truthwin = _mm_insert_epi16(_mm_srli_si128(_truthwin, 2), base, bandSize - 1); + _truthnqual = _mm_insert_epi16(_mm_srli_si128(_truthnqual, 2), base == 'N' ? 
nScore : inf, bandSize - 1); + const auto gap_idx = pos < truth_len ? pos : truth_len - 1; + _gap_open = _mm_insert_epi16(_mm_srli_si128(_gap_open, 2), gap_open[gap_idx] << 2, bandSize - 1); + _gap_extend = _mm_insert_epi16(_mm_srli_si128(_gap_extend, 2), gap_extend[gap_idx] << 2, bandSize - 1); + + _initmask = _mm_slli_si128(_initmask, 2); + _initmask2 = _mm_slli_si128(_initmask2, 2); + + _m2 = _mm_min_epi16(_m2, _mm_min_epi16(_i2, _d2)); + + if (s / 2 >= target_len) { + minscore = std::min(static_cast(extract_epi16(_m2, s / 2 - target_len)), minscore); + } + + _m2 = _mm_add_epi16(_m2, _mm_min_epi16(_mm_andnot_si128(_mm_cmpeq_epi16(_targetwin, _truthwin), + _qualitieswin), _truthnqual)); + _d2 = _mm_min_epi16(_mm_add_epi16(_d1, _gap_extend), + _mm_add_epi16(_mm_min_epi16(_m1, _i1), _gap_open)); // allow I->D + _i2 = _mm_insert_epi16(_mm_add_epi16(_mm_min_epi16(_mm_add_epi16(_mm_srli_si128(_i1, 2), + _gap_extend), + _mm_add_epi16(_mm_srli_si128(_m1, 2), + _gap_open)), + _nuc_prior), inf, bandSize - 1); + + } + + return (minscore + 0x8000) >> 2; +} + int align(const char* truth, const char* target, const std::int8_t* qualities, const int truth_len, const int target_len, const char* snv_mask, const std::int8_t* snv_prior, diff --git a/src/core/models/pairhmm/simd_pair_hmm.hpp b/src/core/models/pairhmm/simd_pair_hmm.hpp index 3abcdb658..258a526d2 100755 --- a/src/core/models/pairhmm/simd_pair_hmm.hpp +++ b/src/core/models/pairhmm/simd_pair_hmm.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke and Gerton Lunter +// Copyright (c) 2015-2018 Daniel Cooke and Gerton Lunter // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef simd_pair_hmm_hpp @@ -18,6 +18,11 @@ int align(const char* truth, const char* target, const std::int8_t* qualities, int truth_len, int target_len, const std::int8_t* gap_open, short gap_extend, short nuc_prior) noexcept; +int align(const char* truth, const char* target, const std::int8_t* qualities, + int truth_len, int target_len, + const std::int8_t* gap_open, const std::int8_t* gap_extend, + short nuc_prior) noexcept; + int align(const char* truth, const char* target, const std::int8_t* qualities, int truth_len, int target_len, const char* snv_mask, const std::int8_t* snv_prior, diff --git a/src/core/models/reference/individual_reference_likelihood_model.cpp b/src/core/models/reference/individual_reference_likelihood_model.cpp new file mode 100644 index 000000000..b16974bef --- /dev/null +++ b/src/core/models/reference/individual_reference_likelihood_model.cpp @@ -0,0 +1,9 @@ +// Copyright (c) 2016 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#include "individual_reference_likelihood_model.hpp" + +namespace octopus { + + +} // namespace octopus diff --git a/src/core/models/reference/individual_reference_likelihood_model.hpp b/src/core/models/reference/individual_reference_likelihood_model.hpp new file mode 100644 index 000000000..f2dd9d0ee --- /dev/null +++ b/src/core/models/reference/individual_reference_likelihood_model.hpp @@ -0,0 +1,12 @@ +// Copyright (c) 2016 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#ifndef individual_reference_likelihood_model_hpp +#define individual_reference_likelihood_model_hpp + +namespace octopus { + + +} // namespace octopus + +#endif diff --git a/src/core/octopus.cpp b/src/core/octopus.cpp index b4cc7e9a9..98aacb550 100644 --- a/src/core/octopus.cpp +++ b/src/core/octopus.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "octopus.hpp" @@ -28,6 +28,7 @@ #include "config/common.hpp" #include "basics/genomic_region.hpp" +#include "basics/ploidy_map.hpp" #include "concepts/mappable.hpp" #include "containers/mappable_flat_multi_set.hpp" #include "containers/mappable_map.hpp" @@ -52,6 +53,7 @@ #include "csr/filters/variant_call_filter.hpp" #include "csr/filters/variant_call_filter_factory.hpp" #include "readpipe/buffered_read_pipe.hpp" +#include "core/tools/bam_realigner.hpp" #include "timers.hpp" // BENCHMARK @@ -1256,7 +1258,7 @@ bool use_unfiltered_call_region_hints_for_filtering(const GenomeCallingComponent return true; } -void run_filtering(GenomeCallingComponents& components) +void run_csr(GenomeCallingComponents& components) { if (apply_csr(components)) { log_filtering_info(components); @@ -1283,12 +1285,14 @@ void run_filtering(GenomeCallingComponents& components) if (components.sites_only()) { output_config.emit_sites_only = true; } - const auto filter = filter_factory.make(components.reference(), std::move(buffered_rp), + const VcfReader in {std::move(*input_path)}; + const auto filter = filter_factory.make(components.reference(), std::move(buffered_rp), in.fetch_header(), + components.ploidies(), components.pedigree(), output_config, progress, components.num_threads()); assert(filter); - const VcfReader in {std::move(*input_path)}; VcfWriter& out {*components.filtered_output()}; filter->filter(in, out); + out.close(); } } @@ -1347,6 +1351,130 @@ class CallingBug : public 
ProgramError CallingBug(const std::exception& e) : what_ {e.what()} {} }; +bool is_bam_realignment_requested(const GenomeCallingComponents& components) +{ + return static_cast(components.bamout()); +} + +bool is_stdout_final_output(const GenomeCallingComponents& components) +{ + return (components.filtered_output() && !components.filtered_output()->path()) || !components.output().path(); +} + +bool check_bam_realign(const GenomeCallingComponents& components) +{ + logging::WarningLogger warn_log {}; + if (components.samples().size() > 1) { + warn_log << "BAM realignment currently only supported for single sample"; + return false; + } + if (components.read_manager().num_files() > 1) { + warn_log << "BAM realignment currently only supported for single input BAM"; + return false; + } + if (is_stdout_final_output(components)) { + warn_log << "BAM realignment is not supported for stdout calling"; + return false; + } + return true; +} + +auto get_bam_realignment_vcf(const GenomeCallingComponents& components) +{ + if (components.filtered_output()) { + return *components.filtered_output()->path(); + } else { + return *components.output().path(); + } +} + +bool is_sam_type(const boost::filesystem::path& path) +{ + const auto type = path.extension().string(); + return type == ".bam" || type == ".sam" || type == ".cram"; +} + +bool is_called_ploidy_known(const GenomeCallingComponents& components) +{ + const auto contigs = components.contigs(); + return std::all_of(std::cbegin(contigs), std::cend(contigs), [&] (const auto& contig) { + const auto caller = components.caller_factory().make(components.contigs().front()); + return caller->min_callable_ploidy() == caller->max_callable_ploidy(); + }); +} + +auto get_max_called_ploidy(VcfReader& vcf) +{ + const auto samples = vcf.fetch_header().samples(); + unsigned result {0}; + for (auto p = vcf.iterate(); p.first != p.second; ++p.first) { + for (const auto& sample : samples) { + result = std::max(p.first->ploidy(sample), result); + 
} + } + return result; +} + +auto get_max_called_ploidy(const boost::filesystem::path& output_vcf) +{ + VcfReader vcf {output_vcf}; + return get_max_called_ploidy(vcf); +} + +auto get_final_output_path(const GenomeCallingComponents& components) +{ + if (apply_csr(components)) { + return components.filtered_output()->path(); + } else { + return components.output().path(); + } +} + +auto get_max_ploidy(const GenomeCallingComponents& components) +{ + if (is_called_ploidy_known(components)) { + return get_max_ploidy(components.samples(), components.contigs(), components.ploidies()); + } else { + assert(get_final_output_path(components)); + return get_max_called_ploidy(*get_final_output_path(components)); + } +} + +auto get_haplotype_bam_paths(const boost::filesystem::path& prefix, const unsigned max_ploidy) +{ + std::vector result {}; + result.reserve(max_ploidy + 1); // + 1 for unassigned reads + for (unsigned i {1}; i <= max_ploidy + 1; ++i) { + result.emplace_back(prefix.string() + std::to_string(i) + ".bam"); + } + return result; +} + +auto get_bamout_paths(const GenomeCallingComponents& components) +{ + namespace fs = boost::filesystem; + std::vector result {}; + auto request = components.bamout(); + if (!request) return result; + if (is_sam_type(*request)) { + result.assign({*request}); + } else { + result = get_haplotype_bam_paths(*request, get_max_ploidy(components)); + } + return result; +} + +void run_bam_realign(GenomeCallingComponents& components) +{ + if (is_bam_realignment_requested(components)) { + if (check_bam_realign(components)) { + components.read_manager().close(); + realign(components.read_manager().paths().front(), get_bam_realignment_vcf(components), + get_bamout_paths(components), components.reference()); + } + } +} + void run_octopus(GenomeCallingComponents& components, std::string command) { static auto debug_log = get_debug_log(); @@ -1381,7 +1509,7 @@ void run_octopus(GenomeCallingComponents& components, std::string command) } 
components.output().close(); try { - run_filtering(components); + run_csr(components); } catch (...) { try { if (debug_log) *debug_log << "Encountered an error whilst filtering, attempting to cleanup"; @@ -1400,6 +1528,7 @@ void run_octopus(GenomeCallingComponents& components, std::string command) stream(info_log) << "Finished calling " << utils::format_with_commas(search_size) << "bp, total runtime " << TimeInterval {start, end}; + run_bam_realign(components); cleanup(components); } diff --git a/src/core/octopus.hpp b/src/core/octopus.hpp index 9025cc078..c922c7b2b 100644 --- a/src/core/octopus.hpp +++ b/src/core/octopus.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef octopus_hpp diff --git a/src/core/tools/bam_realigner.cpp b/src/core/tools/bam_realigner.cpp new file mode 100644 index 000000000..f56414741 --- /dev/null +++ b/src/core/tools/bam_realigner.cpp @@ -0,0 +1,365 @@ +// Copyright (c) 2017 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#include "bam_realigner.hpp" + +#include +#include +#include +#include +#include +#include + +#include "basics/genomic_region.hpp" +#include "basics/cigar_string.hpp" +#include "utils/genotype_reader.hpp" +#include "io/variant/vcf_record.hpp" +#include "utils/append.hpp" +#include "read_assigner.hpp" +#include "read_realigner.hpp" + +namespace octopus { + +namespace { + +unsigned get_pool_size(const BAMRealigner::Config& config) +{ + const auto num_cores = std::thread::hardware_concurrency(); + if (config.max_threads) { + if (*config.max_threads > 1) { + return num_cores > 0 ? std::min(*config.max_threads, num_cores) : *config.max_threads; + } else { + return 0; + } + } else { + return num_cores > 0 ? 
num_cores : 8; + } +} + +} // namespace + +BAMRealigner::BAMRealigner(Config config) +: config_ {std::move(config)} +, workers_ {get_pool_size(config_)} +{} + +namespace { + +bool is_alignable(const AlignedRead& read) noexcept +{ + return is_valid(read.cigar()); +} + +auto remove_unalignable_reads(std::vector& reads) +{ + const auto bad_read_itr = std::stable_partition(std::begin(reads), std::end(reads), is_alignable); + std::deque result {std::make_move_iterator(bad_read_itr), + std::make_move_iterator(std::end(reads))}; + reads.erase(bad_read_itr, std::end(reads)); + return result; +} + +bool is_homozygous_nonreference(const Genotype& genotype) +{ + return genotype.is_homozygous() && !is_reference(genotype[0]); +} + +auto vectorise(std::deque&& reads) +{ + std::vector result {}; + utils::append(std::move(reads), result); + return result; +} + +auto assign_and_realign(const std::vector& reads, const Genotype& genotype, + BAMRealigner::Report& report) +{ + std::vector result {}; + if (!reads.empty()) { + result.reserve(reads.size()); + if (is_homozygous_nonreference(genotype)) { + utils::append(safe_realign_to_reference(reads, genotype[0]), result); + } else { + std::deque unassigned_reads {}; + auto support = compute_haplotype_support(genotype, reads, unassigned_reads); + for (auto& p : support) { + if (!p.second.empty()) { + report.n_reads_assigned += p.second.size(); + safe_realign_to_reference(p.second, p.first); + utils::append(std::move(p.second), result); + } + } + if (!unassigned_reads.empty()) { + report.n_reads_assigned += unassigned_reads.size(); + utils::append(safe_realign_to_reference(vectorise(std::move(unassigned_reads)), genotype[0]), result); + } + } + std::sort(std::begin(result), std::end(result)); + } + return result; +} + +template +auto move_merge(Container& src, std::vector& dst) +{ + auto itr = utils::append(std::move(src), dst); + std::inplace_merge(std::begin(dst), itr, std::end(dst)); +} + +} // namespace + +BAMRealigner::Report 
BAMRealigner::realign(ReadReader& src, VcfReader& variants, ReadWriter& dst, + const ReferenceGenome& reference, SampleList samples) const +{ + Report report {}; + boost::optional batch_region {}; + for (auto p = variants.iterate(); p.first != p.second; ) { + auto batch = read_next_batch(p.first, p.second, src, reference, samples, batch_region); + for (auto& sample : batch) { + std::vector genotype_reads {}, realigned_reads {}; + auto sample_reads_itr = std::begin(sample.reads); + for (const auto& genotype : sample.genotypes) { + const auto padded_genotype_region = expand(mapped_region(genotype), 1); + const auto overlapped_reads = bases(overlap_range(sample_reads_itr, std::end(sample.reads), padded_genotype_region)); + genotype_reads.assign(std::make_move_iterator(overlapped_reads.begin()), + std::make_move_iterator(overlapped_reads.end())); + sample_reads_itr = sample.reads.erase(overlapped_reads.begin(), overlapped_reads.end()); + auto bad_reads = remove_unalignable_reads(genotype_reads); + auto realignments = assign_and_realign(genotype_reads, genotype, report); + report.n_reads_unassigned += bad_reads.size(); + move_merge(bad_reads, realignments); + utils::append(std::move(realignments), realigned_reads); + } + move_merge(realigned_reads, sample.reads); + dst << sample.reads; + } + batch_region = encompassing_region(batch.front().genotypes); + } + return report; +} + +BAMRealigner::Report BAMRealigner::realign(ReadReader& src, VcfReader& variants, ReadWriter& dst, + const ReferenceGenome& reference) const +{ + return realign(src, variants, dst, reference, src.extract_samples()); +} + +namespace { + +auto split_and_realign(const std::vector& reads, const Genotype& genotype, + BAMRealigner::Report& report) +{ + std::vector> result(genotype.zygosity() + 1); + if (!reads.empty()) { + if (is_homozygous_nonreference(genotype)) { + report.n_reads_assigned += reads.size(); + result.back() = safe_realign_to_reference(reads, genotype[0]); + } else { + std::deque 
unassigned_reads {}; + auto support = compute_haplotype_support(genotype, reads, unassigned_reads); + std::size_t result_idx {0}; + for (const auto& haplotype : genotype) { + auto support_itr = support.find(haplotype); + if (support_itr != std::cend(support)) { + auto& haplotype_support = support_itr->second; + if (!haplotype_support.empty()) { + report.n_reads_assigned += haplotype_support.size(); + safe_realign_to_reference(haplotype_support, haplotype); + result[result_idx] = std::move(haplotype_support); + ++result_idx; + } + support.erase(support_itr); + } + } + if (!unassigned_reads.empty()) { + report.n_reads_assigned += unassigned_reads.size(); + utils::append(std::move(unassigned_reads), result.back()); + } + } + for (auto& set : result) std::sort(std::begin(set), std::end(set)); + } + return result; +} + +} // namespace + +BAMRealigner::Report BAMRealigner::realign(ReadReader& src, VcfReader& variants, std::vector& dsts, + const ReferenceGenome& reference, SampleList samples) const +{ + if (dsts.size() == 1) return realign(src, variants, dsts.front(), reference, samples); + Report report {}; + boost::optional batch_region {}; + for (auto p = variants.iterate(); p.first != p.second; ) { + auto batch = read_next_batch(p.first, p.second, src, reference, samples, batch_region); + for (auto& sample : batch) { + std::vector genotype_reads {}, unassigned_realigned_reads {}; + auto sample_reads_itr = std::begin(sample.reads); + for (const auto& genotype : sample.genotypes) { + const auto overlapped_reads = bases(overlap_range(sample_reads_itr, std::end(sample.reads), genotype)); + genotype_reads.assign(std::make_move_iterator(overlapped_reads.begin()), + std::make_move_iterator(overlapped_reads.end())); + sample_reads_itr = sample.reads.erase(overlapped_reads.begin(), overlapped_reads.end()); + auto bad_reads = remove_unalignable_reads(genotype_reads); + auto realignments = split_and_realign(genotype_reads, genotype, report); + report.n_reads_unassigned += 
bad_reads.size(); + move_merge(bad_reads, realignments.back()); + assert(realignments.size() <= dsts.size()); + for (unsigned i {0}; i < realignments.size() - 1; ++i) { + dsts[i] << realignments[i]; + } + utils::append(std::move(realignments.back()), unassigned_realigned_reads); // end is always unassigned, but ploidy can change + } + move_merge(unassigned_realigned_reads, sample.reads); + dsts.back() << sample.reads; + } + batch_region = encompassing_region(batch.front().genotypes); + } + return report; +} + +BAMRealigner::Report BAMRealigner::realign(ReadReader& src, VcfReader& variants, std::vector& dsts, + const ReferenceGenome& reference) const +{ + return realign(src, variants, dsts, reference, src.extract_samples()); +} + +// private methods + +namespace { + +GenomicRegion get_phase_set(const VcfRecord& record, const SampleName& sample) +{ + auto result = get_phase_region(record, sample); + return result ? *result : mapped_region(record); +} + +std::vector get_phase_sets(const VcfRecord& record, const std::vector& samples) +{ + std::vector result{}; + result.reserve(samples.size()); + std::transform(std::cbegin(samples), std::cend(samples), std::back_inserter(result), + [&record](const auto& sample) { return get_phase_set(record, sample); }); + return result; +} + +GenomicRegion get_phase_region(const VcfRecord& record, const std::vector& samples) +{ + return encompassing_region(get_phase_sets(record, samples)); +} + +template +std::vector copy_each_first(const std::vector>& items) +{ + std::vector result {}; + result.reserve(items.size()); + std::transform(std::cbegin(items), std::cend(items), std::back_inserter(result), + [] (const auto& p) { return p.first; }); + return result; +} + +} // namespace + +BAMRealigner::CallBlock +BAMRealigner::read_next_block(VcfIterator& first, const VcfIterator& last, const SampleList& samples) const +{ + std::vector> block {}; + for (; first != last; ++first) { + const VcfRecord& call {*first}; + auto call_phase_region = 
get_phase_region(call, samples); + if (!block.empty() && !overlaps(block.back().second, call_phase_region)) { + return copy_each_first(block); + } + block.emplace_back(call, std::move(call_phase_region)); + } + return copy_each_first(block); +} + +namespace { + +void erase_overlapped(std::vector& reads, const GenomicRegion& region) +{ + auto itr = std::remove_if(std::begin(reads), std::end(reads), [&] (const auto& read) { return overlaps(read, region); }); + reads.erase(itr, std::end(reads)); +} + +} // namespace + +BAMRealigner::BatchList +BAMRealigner::read_next_batch(VcfIterator& first, const VcfIterator& last, ReadReader& src, + const ReferenceGenome& reference, const SampleList& samples, + const boost::optional& prev_batch_region) const +{ + const auto records = read_next_block(first, last, samples); + BatchList result {}; + if (!records.empty()) { + auto genotypes = extract_genotypes(records, samples, reference); + result.reserve(samples.size()); + auto reads_region = encompassing_region(records); + if (config_.copy_hom_ref_reads) { + if (prev_batch_region && is_same_contig(reads_region, *prev_batch_region)) { + reads_region = expand_lhs(reads_region, intervening_region_size(*prev_batch_region, reads_region)); + } else { + reads_region = expand_lhs(reads_region, reads_region.begin()); + } + } + auto reads = src.fetch_reads(samples, expand(reads_region, 1)); + for (const auto& sample : samples) { + auto& sample_genotypes = genotypes[sample]; + if (prev_batch_region) { + erase_overlapped(reads[sample], *prev_batch_region); + } + result.push_back({std::move(sample_genotypes), std::move(reads[sample])}); + } + } else if (prev_batch_region && config_.copy_hom_ref_reads) { + const auto contig_region = reference.contig_region(prev_batch_region->contig_name()); + const auto reads_region = right_overhang_region(contig_region, *prev_batch_region); + if (!is_empty(reads_region)) { + auto reads = src.fetch_reads(samples, reads_region); + for (const auto& sample : 
samples) { + erase_overlapped(reads[sample], *prev_batch_region); + result.push_back({{}, std::move(reads[sample])}); + } + } + } + return result; +} + +// non-member methods + +BAMRealigner::Report realign(io::ReadReader::Path src, VcfReader::Path variants, io::ReadWriter::Path dst, + const ReferenceGenome& reference) +{ + return realign(std::move(src), std::move(variants), std::move(dst), reference, BAMRealigner::Config {}); +} + +BAMRealigner::Report realign(io::ReadReader::Path src, VcfReader::Path variants, io::ReadWriter::Path dst, + const ReferenceGenome& reference, BAMRealigner::Config config) +{ + io::ReadWriter dst_bam {std::move(dst), src}; + io::ReadReader src_bam {std::move(src)}; + VcfReader vcf {std::move(variants)}; + BAMRealigner realigner {std::move(config)}; + return realigner.realign(src_bam, vcf, dst_bam, reference); +} + +BAMRealigner::Report realign(io::ReadReader::Path src, VcfReader::Path variants, + std::vector dsts, const ReferenceGenome& reference) +{ + return realign(std::move(src), std::move(variants), std::move(dsts), reference, BAMRealigner::Config {}); +} + +BAMRealigner::Report realign(io::ReadReader::Path src, VcfReader::Path variants, std::vector dsts, + const ReferenceGenome& reference, BAMRealigner::Config config) +{ + std::vector dst_bams {}; + dst_bams.reserve(dsts.size()); + for (auto& dst : dsts) { + dst_bams.emplace_back(std::move(dst), src); + } + io::ReadReader src_bam {std::move(src)}; + VcfReader vcf {std::move(variants)}; + BAMRealigner realigner {std::move(config)}; + return realigner.realign(src_bam, vcf, dst_bams, reference); +} + +} // namespace octopus diff --git a/src/core/tools/bam_realigner.hpp b/src/core/tools/bam_realigner.hpp new file mode 100644 index 000000000..a1a2abe35 --- /dev/null +++ b/src/core/tools/bam_realigner.hpp @@ -0,0 +1,95 @@ +// Copyright (c) 2017 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#ifndef bam_realigner_hpp +#define bam_realigner_hpp + +#include +#include + +#include + +#include "basics/aligned_read.hpp" +#include "core/types/haplotype.hpp" +#include "core/types/genotype.hpp" +#include "containers/mappable_flat_set.hpp" +#include "io/reference/reference_genome.hpp" +#include "io/read/read_reader.hpp" +#include "io/read/read_writer.hpp" +#include "io/variant/vcf_reader.hpp" +#include "utils/thread_pool.hpp" + +namespace octopus { + +class BAMRealigner +{ +public: + using ReadReader = io::ReadReader; + using ReadWriter = io::ReadWriter; + using SampleName = ReadReader::SampleName; + using SampleList = std::vector; + + struct Config + { + bool copy_hom_ref_reads = false; + bool simplify_cigars = false; + boost::optional max_threads = 1; + }; + + struct Report + { + std::size_t n_reads_assigned; + std::size_t n_reads_unassigned; + }; + + BAMRealigner() = default; + BAMRealigner(Config config); + + BAMRealigner(const BAMRealigner&) = default; + BAMRealigner& operator=(const BAMRealigner&) = default; + BAMRealigner(BAMRealigner&&) = default; + BAMRealigner& operator=(BAMRealigner&&) = default; + + ~BAMRealigner() = default; + + Report realign(ReadReader& src, VcfReader& variants, ReadWriter& dst, + const ReferenceGenome& reference, SampleList samples) const; + Report realign(ReadReader& src, VcfReader& variants, ReadWriter& dst, + const ReferenceGenome& reference) const; + + Report realign(ReadReader& src, VcfReader& variants, std::vector& dsts, + const ReferenceGenome& reference, SampleList samples) const; + Report realign(ReadReader& src, VcfReader& variants, std::vector& dsts, + const ReferenceGenome& reference) const; + +private: + using VcfIterator = VcfReader::RecordIterator; + using CallBlock = std::vector; + struct Batch + { + MappableFlatSet> genotypes; + std::vector reads; + }; + using BatchList = std::vector; + + Config config_; + mutable ThreadPool workers_; + + CallBlock read_next_block(VcfIterator& first, const VcfIterator& last, 
const SampleList& samples) const; + BatchList read_next_batch(VcfIterator& first, const VcfIterator& last, ReadReader& src, + const ReferenceGenome& reference, const SampleList& samples, + const boost::optional& prev_batch_region) const; +}; + +BAMRealigner::Report realign(io::ReadReader::Path src, VcfReader::Path variants, io::ReadWriter::Path dst, + const ReferenceGenome& reference); +BAMRealigner::Report realign(io::ReadReader::Path src, VcfReader::Path variants, io::ReadWriter::Path dst, + const ReferenceGenome& reference, BAMRealigner::Config config); +BAMRealigner::Report realign(io::ReadReader::Path src, VcfReader::Path variants, std::vector dsts, + const ReferenceGenome& reference); +BAMRealigner::Report realign(io::ReadReader::Path src, VcfReader::Path variants, std::vector dsts, + const ReferenceGenome& reference, BAMRealigner::Config config); + +} // namespace octopus + +#endif diff --git a/src/core/tools/coretools.hpp b/src/core/tools/coretools.hpp index b1ab03df6..18905ac8b 100644 --- a/src/core/tools/coretools.hpp +++ b/src/core/tools/coretools.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef coretools_hpp diff --git a/src/core/tools/hapgen/dense_variation_detector.cpp b/src/core/tools/hapgen/dense_variation_detector.cpp new file mode 100644 index 000000000..472b3414d --- /dev/null +++ b/src/core/tools/hapgen/dense_variation_detector.cpp @@ -0,0 +1,328 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "dense_variation_detector.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "concepts/mappable.hpp" +#include "concepts/comparable.hpp" +#include "basics/contig_region.hpp" +#include "basics/aligned_read.hpp" +#include "containers/mappable_flat_set.hpp" +#include "containers/mappable_map.hpp" +#include "utils/mappable_algorithms.hpp" +#include "utils/read_stats.hpp" +#include "utils/maths.hpp" + +#include + +namespace octopus { namespace coretools { + +DenseVariationDetector::DenseVariationDetector(double heterozygosity, double heterozygosity_stdev, + boost::optional reads_profile) +: expected_heterozygosity_ {heterozygosity} +, heterozygosity_stdev_ {heterozygosity_stdev} +, reads_profile_ {std::move(reads_profile)} +{} + +namespace { + +auto get_max_expected_log_allele_count_per_base(double heterozygosity, double heterozygosity_stdev) noexcept +{ + return heterozygosity + 6 * heterozygosity_stdev; +} + +struct AlleleBlock : public Mappable, public Comparable +{ + ContigRegion region; + double log_count; + const ContigRegion& mapped_region() const noexcept { return region; } + AlleleBlock() = default; + AlleleBlock(ContigRegion region, double count) noexcept : region {region}, log_count {count} {} + AlleleBlock(const GenomicRegion& region, double count) noexcept : region {region.contig_region()}, log_count {count} {} +}; + +bool operator==(const AlleleBlock& lhs, const AlleleBlock& rhs) noexcept { return lhs.region < rhs.region; } +bool operator<(const AlleleBlock& lhs, const AlleleBlock& rhs) noexcept { return lhs.region < rhs.region; } + +template +auto sum_block_counts(const Range& block_range) noexcept +{ + return std::accumulate(std::cbegin(block_range), std::cend(block_range), 0.0, + [] (const auto curr, const AlleleBlock& block) noexcept { return curr + block.log_count; }); +} + +bool all_empty(const ReadMap& reads) noexcept +{ + return std::all_of(std::cbegin(reads), std::cend(reads), [](const 
auto& p) noexcept { return p.second.empty(); }); +} + +auto mean_mapped_region_size(const ReadMap& reads) noexcept +{ + double total_read_size {0}; + std::size_t num_reads {0}; + for (const auto& p : reads) { + total_read_size += sum_region_sizes(p.second); + num_reads += p.second.size(); + } + return total_read_size / num_reads; +} + +auto calculate_positional_coverage(const MappableFlatSet& variants, const GenomicRegion& region) +{ + std::vector variant_regions(variants.size()); + std::transform(std::cbegin(variants), std::cend(variants), std::begin(variant_regions), + [] (const auto& v) { return is_simple_insertion(v) ? expand(contig_region(v), 1) : contig_region(v); }); + return calculate_positional_coverage(variant_regions, region.contig_region()); +} + +struct BaseState +{ + std::uint8_t variant_depth; + std::uint16_t read_depth; + AlignedRead::MappingQuality median_mq; +}; + +auto compute_median(std::vector& mapping_qualities) +{ + if (mapping_qualities.empty()) { + static constexpr auto max_mapping_quality = std::numeric_limits::max(); + return max_mapping_quality; + } else { + return maths::median(mapping_qualities); + } +} + +auto compute_positional_median_mapping_qualities(const MappableFlatMultiSet& reads, const GenomicRegion& region) +{ + const auto num_positions = size(region); + std::vector result(num_positions); + std::deque> pileup_buffer {}; + auto pileup_region = head_region(region.contig_region()); + auto result_itr = std::begin(result); + auto pileup_buffer_begin_itr = std::begin(pileup_buffer); + for (const AlignedRead& read : overlap_range(reads, region)) { + const auto& read_region = contig_region(read); + if (ends_before(pileup_region, read_region)) { + const auto extension_size = right_overhang_size(read_region, pileup_region); + pileup_buffer.resize(pileup_buffer.size() + extension_size); + pileup_buffer_begin_itr = std::begin(pileup_buffer); + pileup_region = expand_rhs(pileup_region, extension_size); + } + auto 
pileup_read_overlap_begin_itr = pileup_buffer_begin_itr; + if (begins_before(pileup_region, read_region)) { + const auto offset = left_overhang_size(pileup_region, read_region); + pileup_read_overlap_begin_itr = std::next(pileup_buffer_begin_itr, offset); + } + const auto num_overlapped_positions = overlap_size(pileup_region, read_region); + const auto pileup_read_overlap_end_itr = std::next(pileup_read_overlap_begin_itr, num_overlapped_positions); + for (auto itr = pileup_read_overlap_begin_itr; itr != pileup_read_overlap_end_itr; ++itr) { + itr->push_back(read.mapping_quality()); + } + if (pileup_read_overlap_begin_itr != pileup_buffer_begin_itr) { + result_itr = std::transform(pileup_buffer_begin_itr, pileup_read_overlap_begin_itr, result_itr, compute_median); + pileup_buffer_begin_itr = pileup_buffer.erase(pileup_buffer_begin_itr, pileup_read_overlap_begin_itr); + pileup_region = expand_lhs(pileup_region, begin_distance(read_region, pileup_region)); + } + } + return result; +} + +auto compute_positional_mean_median_mapping_qualities(const ReadMap& reads, const GenomicRegion& region) +{ + if (reads.size() == 1) { + return compute_positional_median_mapping_qualities(std::cbegin(reads)->second, region); + } + const auto num_positions = size(region); + std::vector mean_mqs(num_positions); + for (const auto& p : reads) { + const auto median_mqs = compute_positional_median_mapping_qualities(p.second, region); + for (std::size_t i {0}; i < num_positions; ++i) { + mean_mqs[i] += (static_cast(median_mqs[i]) - mean_mqs[i]) / num_positions; + } + } + return std::vector {std::cbegin(mean_mqs), std::cend(mean_mqs)}; +} + +auto compute_base_states(const MappableFlatSet& variants, const ReadMap& reads) +{ + const auto region = encompassing_region(variants); + const auto num_bases = size(region); + std::vector result(num_bases); + { + auto variant_depths = calculate_positional_coverage(variants, region); + for (std::size_t i {0}; i < num_bases; ++i) { + static constexpr 
unsigned max_variant_depth {std::numeric_limits::max()}; + result[i].variant_depth = std::min(variant_depths[i], max_variant_depth); + } + } + { + auto read_depths = calculate_positional_coverage(reads, region); + for (std::size_t i {0}; i < num_bases; ++i) { + static constexpr unsigned max_read_depth {std::numeric_limits::max()}; + result[i].read_depth = std::min(read_depths[i], max_read_depth); + } + } + { + auto median_mqs = compute_positional_mean_median_mapping_qualities(reads, region); + for (std::size_t i {0}; i < num_bases; ++i) { + result[i].median_mq = median_mqs[i]; + } + } + return result; +} + +auto find_dense_regions(const MappableFlatSet& variants, const ReadMap& reads, + const double dense_zone_log_count_threshold, + const double max_shared_dense_zones) +{ + const auto initial_blocks = extract_covered_regions(variants); + MappableFlatSet blocks {}; + for (const auto& region : initial_blocks) { + blocks.emplace(region, std::log2(2 * count_overlapped(variants, region))); + } + for (const auto& p : reads) { + for (const auto& read : p.second) { + const auto interacting_blocks = bases(contained_range(blocks, contig_region(read))); + if (size(interacting_blocks) > 1) { + auto joined_block_region = closed_region(interacting_blocks.front(), interacting_blocks.back()); + auto joined_block_count = sum_block_counts(interacting_blocks); + auto hint = blocks.erase(std::cbegin(interacting_blocks), std::cend(interacting_blocks)); + blocks.insert(hint, AlleleBlock {joined_block_region, joined_block_count}); + } + } + } + std::deque seeds {}; + const auto& contig = contig_name(variants.front()); + for (const auto& block : blocks) { + if (block.log_count > dense_zone_log_count_threshold) { + seeds.push_back(GenomicRegion {contig, block.region}); + } + } + auto joined_seeds = join_if(seeds, [&] (const auto& lhs, const auto& rhs) { return has_shared(reads, lhs, rhs); }); + std::vector dense_blocks(joined_seeds.size()); + std::transform(std::cbegin(joined_seeds), 
std::cend(joined_seeds), std::begin(dense_blocks), + [&](const auto& region) { + auto contained = contained_range(blocks, region.contig_region()); + assert(!empty(contained)); + auto joined_block_region = closed_region(contained.front(), contained.back()); + auto joined_block_count = sum_block_counts(contained); + return AlleleBlock {joined_block_region, joined_block_count}; + }); + MappableFlatSet result {}; + const auto join_threshold = max_shared_dense_zones * dense_zone_log_count_threshold; + for (const auto& block : dense_blocks) { + if (block.log_count > join_threshold) { + result.insert(GenomicRegion {contig, block.region}); + } + } + return result; +} + +struct RegionState +{ + GenomicRegion region; + unsigned variant_count; + double variant_density; + unsigned mean_read_depth; + double rmq_mapping_quality; +}; + +auto compute_state(const GenomicRegion& region, const MappableFlatSet& variants, const ReadMap& reads) +{ + RegionState result {}; + result.region = region; + result.rmq_mapping_quality = rmq_mapping_quality(reads, region); + result.mean_read_depth = mean_coverage(reads, region) / reads.size(); + result.variant_count = count_overlapped(variants, region); + result.variant_density = static_cast(result.variant_count) / size(region); + return result; +} + +auto compute_states(const MappableFlatSet& regions, const MappableFlatSet& variants, const ReadMap& reads) +{ + std::vector result {}; + result.reserve(regions.size()); + for (const auto& region : regions) { + result.push_back(compute_state(region, variants, reads)); + } + return result; +} + +bool should_join(const RegionState& lhs_state, const RegionState& connecting_state, const RegionState& rhs_state) +{ + if (connecting_state.variant_density > std::min(lhs_state.variant_density, rhs_state.variant_density) / 2) { + return true; + } + if (size(connecting_state.region) > std::max(size(lhs_state.region), size(rhs_state.region))) { + return false; + } + if (connecting_state.rmq_mapping_quality < 
std::min(lhs_state.rmq_mapping_quality, rhs_state.rmq_mapping_quality)) { + return true; + } + if (connecting_state.mean_read_depth > std::min(lhs_state.mean_read_depth, rhs_state.mean_read_depth)) { + return true; + } + return false; +} + +auto join_dense_regions(const MappableFlatSet& dense_regions, + const MappableFlatSet& variants, const ReadMap& reads) +{ + if (dense_regions.size() > 1) { + std::vector final_regions {}; + final_regions.reserve(dense_regions.size()); + const auto dense_states = compute_states(dense_regions, variants, reads); + final_regions.push_back(dense_regions.front()); + for (std::size_t i {1}; i < dense_regions.size(); ++i) { + const auto connecting_region = *intervening_region(dense_regions[i - 1], dense_regions[i]); + const auto connecting_state = compute_state(connecting_region, variants, reads); + if (should_join(dense_states[i - 1], connecting_state, dense_states[i])) { + final_regions.push_back(connecting_region); + } + final_regions.push_back(dense_regions[i]); + } + return extract_covered_regions(final_regions); + } else { + return std::vector {dense_regions.front()}; + } +} + +} // namespace + +std::vector +DenseVariationDetector::detect(const MappableFlatSet& variants, const ReadMap& reads) const +{ + const auto mean_read_size = mean_mapped_region_size(reads); + auto expected_log_count = get_max_expected_log_allele_count_per_base(expected_heterozygosity_, heterozygosity_stdev_); + const auto dense_zone_log_count_threshold = expected_log_count * mean_read_size; + auto dense_regions = find_dense_regions(variants, reads, dense_zone_log_count_threshold, 1); + if (dense_regions.empty()) return {}; + auto joined_dense_regions = join_dense_regions(dense_regions, variants, reads); + std::vector result {}; + result.reserve(joined_dense_regions.size()); + double max_expected_coverage {}; + if (reads_profile_) { + max_expected_coverage = 2 * reads_profile_->mean_depth + 2 * reads_profile_->depth_stdev; + } else { + max_expected_coverage = 
2 * mean_coverage(reads); + } + for (const auto& region : joined_dense_regions) { + const auto state = compute_state(region, variants, reads); + if (state.variant_count > 100 && size(state.region) > 3 * mean_read_size && state.mean_read_depth > max_expected_coverage) { + result.push_back({region, DenseRegion::RecommendedAction::skip}); + } + } + return result; +} + +} // namespace coretools +} // namespace octopus diff --git a/src/core/tools/hapgen/dense_variation_detector.hpp b/src/core/tools/hapgen/dense_variation_detector.hpp new file mode 100644 index 000000000..1b55dc56a --- /dev/null +++ b/src/core/tools/hapgen/dense_variation_detector.hpp @@ -0,0 +1,50 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#ifndef dense_variation_detector_hpp +#define dense_variation_detector_hpp + +#include + +#include + +#include "config/common.hpp" +#include "basics/genomic_region.hpp" +#include "core/types/variant.hpp" +#include "utils/input_reads_profiler.hpp" + +namespace octopus { namespace coretools { + +class DenseVariationDetector +{ +public: + struct DenseRegion + { + enum class RecommendedAction { skip, restrict_lagging }; + GenomicRegion region; + RecommendedAction action; + }; + + DenseVariationDetector() = default; + + DenseVariationDetector(double heterozygosity, double heterozygosity_stdev, + boost::optional = boost::none); + + DenseVariationDetector(const DenseVariationDetector&) = default; + DenseVariationDetector& operator=(const DenseVariationDetector&) = default; + DenseVariationDetector(DenseVariationDetector&&) = default; + DenseVariationDetector& operator=(DenseVariationDetector&&) = default; + + ~DenseVariationDetector() = default; + + std::vector detect(const MappableFlatSet& variants, const ReadMap& reads) const; + +private: + double expected_heterozygosity_, heterozygosity_stdev_; + boost::optional reads_profile_; +}; + +} // namespace coretools +} // 
namespace octopus + +#endif diff --git a/src/core/tools/hapgen/genome_walker.cpp b/src/core/tools/hapgen/genome_walker.cpp index 5a2dca67d..4edf68271 100644 --- a/src/core/tools/hapgen/genome_walker.cpp +++ b/src/core/tools/hapgen/genome_walker.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "genome_walker.hpp" @@ -49,18 +49,33 @@ bool is_indel_boundary(BidirIt first, BidirIt allele, BidirIt last) { if (allele != first && allele != last && std::next(allele) != last && is_indel(*allele)) { auto itr = std::find_if(std::make_reverse_iterator(std::prev(allele)), std::make_reverse_iterator(first), - [&] (const auto& a) { return !overlaps(a, *allele) || is_indel(a); }); + [allele] (const auto& a) { return !overlaps(a, *allele) || is_indel(a); }); return itr != std::make_reverse_iterator(first) && overlaps(*itr, *allele) && is_indel(*itr); } else { return false; } } +template +bool is_interacting_indel(BidirIt first, BidirIt allele, BidirIt last, + const GenomicRegion::Size max_gap = 3) +{ + if (allele != first && allele != last && std::next(allele) != last && is_indel(*allele)) { + const auto interaction_region = expand_lhs(mapped_region(*allele), std::min(reference_distance(*allele), max_gap)); + auto itr = std::find_if(std::make_reverse_iterator(std::prev(allele)), std::make_reverse_iterator(first), + [&interaction_region] (const auto& a) { return overlaps(a, interaction_region); }); + return itr != std::make_reverse_iterator(first); + } else { + return false; + } +} + template bool is_good_indicator_begin(BidirIt first_possible, BidirIt allele_itr, BidirIt last_possible) { return !(is_sandwich_allele(first_possible, allele_itr, last_possible) - || is_indel_boundary(first_possible, allele_itr, last_possible)); + || is_indel_boundary(first_possible, allele_itr, last_possible) + || is_interacting_indel(first_possible, 
allele_itr, last_possible)); } template diff --git a/src/core/tools/hapgen/genome_walker.hpp b/src/core/tools/hapgen/genome_walker.hpp index d19e9384e..23759825e 100644 --- a/src/core/tools/hapgen/genome_walker.hpp +++ b/src/core/tools/hapgen/genome_walker.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef genome_walker_hpp diff --git a/src/core/tools/hapgen/haplotype_generator.cpp b/src/core/tools/hapgen/haplotype_generator.cpp index 3ee53858f..e2852be8b 100644 --- a/src/core/tools/hapgen/haplotype_generator.cpp +++ b/src/core/tools/hapgen/haplotype_generator.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "haplotype_generator.hpp" @@ -117,97 +117,20 @@ auto make_lagged_walker(const HaplotypeGenerator::Policies& policies) }; } -struct AlleleBlock : public Mappable, public Comparable -{ - ContigRegion region; - double log_count; - const ContigRegion& mapped_region() const noexcept { return region; } - AlleleBlock() = default; - AlleleBlock(ContigRegion region, double count) noexcept : region {region}, log_count {count} {} - AlleleBlock(const GenomicRegion& region, double count) noexcept : region {region.contig_region()}, log_count {count} {} -}; - -bool operator==(const AlleleBlock& lhs, const AlleleBlock& rhs) noexcept { return lhs.region < rhs.region; } -bool operator<(const AlleleBlock& lhs, const AlleleBlock& rhs) noexcept { return lhs.region < rhs.region; } - -template -auto sum_block_counts(const Range& block_range) noexcept -{ - return std::accumulate(std::cbegin(block_range), std::cend(block_range), 0.0, - [] (const auto curr, const AlleleBlock& block) noexcept { - return curr + block.log_count; - }); -} - bool all_empty(const ReadMap& reads) noexcept { 
return std::all_of(std::cbegin(reads), std::cend(reads), [] (const auto& p) noexcept { return p.second.empty(); }); } -auto mean_mapped_region_size(const ReadMap& reads) noexcept -{ - double total_read_size {0}; - std::size_t num_reads {0}; - for (const auto& p : reads) { - total_read_size += sum_region_sizes(p.second); - num_reads += p.second.size(); - } - return total_read_size / num_reads; -} - -auto find_dense_regions(const MappableFlatSet& alleles, const ReadMap& reads, - const double dense_zone_log_count_threshold, - const double max_shared_dense_zones) -{ - const auto initial_blocks = extract_covered_regions(alleles); - MappableFlatSet blocks {}; - for (const auto& region : initial_blocks) { - blocks.emplace(region, std::log2(count_overlapped(alleles, region))); - } - for (const auto& p : reads) { - for (const auto& read : p.second) { - const auto interacting_blocks = bases(contained_range(blocks, contig_region(read))); - if (size(interacting_blocks) > 1) { - auto joined_block_region = closed_region(interacting_blocks.front(), interacting_blocks.back()); - auto joined_block_count = sum_block_counts(interacting_blocks); - auto hint = blocks.erase(std::cbegin(interacting_blocks), std::cend(interacting_blocks)); - blocks.insert(hint, AlleleBlock {joined_block_region, joined_block_count}); - } - } - } - std::deque seeds {}; - const auto& contig = contig_name(alleles.front()); - for (const auto& block : blocks) { - if (block.log_count > dense_zone_log_count_threshold) { - seeds.push_back(GenomicRegion {contig, block.region}); - } - } - auto joined_seeds = join_if(seeds, [&] (const auto& lhs, const auto& rhs) { return has_shared(reads, lhs, rhs); }); - std::vector dense_blocks(joined_seeds.size()); - std::transform(std::cbegin(joined_seeds), std::cend(joined_seeds), std::begin(dense_blocks), - [&] (const auto& region) { - auto contained = contained_range(blocks, region.contig_region()); - assert(!empty(contained)); - auto joined_block_region = 
closed_region(contained.front(), contained.back()); - auto joined_block_count = sum_block_counts(contained); - return AlleleBlock {joined_block_region, joined_block_count}; - }); - MappableFlatSet result {}; - const auto join_threshold = max_shared_dense_zones * dense_zone_log_count_threshold; - for (const auto& block : dense_blocks) { - if (block.log_count > join_threshold) { - result.insert(GenomicRegion {contig, block.region}); - } - } - return result; -} - } // namespace // public members -HaplotypeGenerator::HaplotypeGenerator(const ReferenceGenome& reference, const MappableFlatSet& candidates, - const ReadMap& reads, Policies policies) +HaplotypeGenerator::HaplotypeGenerator(const ReferenceGenome& reference, + const MappableFlatSet& candidates, + const ReadMap& reads, + Policies policies, + DenseVariationDetector dense_variation_detector) : policies_ {std::move(policies)} , tree_ {get_contig(candidates), reference} , default_walker_ { @@ -230,6 +153,28 @@ HaplotypeGenerator::HaplotypeGenerator(const ReferenceGenome& reference, const M , debug_log_ {logging::get_debug_log()} , trace_log_ {logging::get_trace_log()} { + assert(!candidates.empty()); + if (!all_empty(reads_)) { + for (const auto& dense : dense_variation_detector.detect(candidates, reads)) { + if (dense.action == DenseVariationDetector::DenseRegion::RecommendedAction::skip) { + if (debug_log_) { + stream(*debug_log_) << "Erasing " << count_contained(alleles_, dense.region) + << " alleles in dense region " << dense.region; + } + alleles_.erase_contained(dense.region); + } else if (is_lagging_enabled()) { + lagging_exclusion_zones_.insert(dense.region); + } + } + if (!lagging_exclusion_zones_.empty() && debug_log_) { + auto log = stream(*debug_log_); + log << "Found lagging exclusion zones: "; + for (const auto& zone : lagging_exclusion_zones_) log << zone << " "; + } + if (alleles_.empty()) { + alleles_.insert(candidates.back().ref_allele()); + } + } assert(!alleles_.empty()); rightmost_allele_ = 
alleles_.rightmost(); active_region_ = head_region(alleles_.leftmost()); @@ -239,16 +184,6 @@ HaplotypeGenerator::HaplotypeGenerator(const ReferenceGenome& reference, const M if (policies.lagging != Policies::Lagging::none) { lagged_walker_ = make_lagged_walker(policies); } - if (is_lagging_enabled() && policies.max_expected_log_allele_count_per_base && !all_empty(reads_)) { - const auto mean_read_size = mean_mapped_region_size(reads_); - const auto dense_zone_log_count_threshold = *policies.max_expected_log_allele_count_per_base * mean_read_size; - lagging_exclusion_zones_ = find_dense_regions(alleles_, reads_, dense_zone_log_count_threshold, 3.0); - if (!lagging_exclusion_zones_.empty() && debug_log_) { - auto log = stream(*debug_log_); - log << "Found lagging exclusion zones: "; - for (const auto& zone : lagging_exclusion_zones_) log << zone << " "; - } - } } namespace { @@ -313,7 +248,8 @@ void HaplotypeGenerator::clear_progress() noexcept void HaplotypeGenerator::jump(GenomicRegion region) { clear_progress(); - progress(std::move(region)); + next_active_region_ = std::move(region); + remove_passed_alleles(); } bool HaplotypeGenerator::removal_has_impact() const @@ -886,8 +822,11 @@ void HaplotypeGenerator::update_lagged_next_active_region() const const auto ideal_num_new_novel_blocks = get_num_ideal_new_novel_blocks(novel_blocks, indicator_region, alleles_); auto num_novel_blocks_added = extend_novel(test_tree, novel_blocks, novel_alleles, ideal_num_new_novel_blocks, policies_.haplotype_limits); - if (num_novel_blocks_added == 0 && !protected_indicator_blocks.empty()) { + if (num_novel_blocks_added == 0 && protected_indicator_blocks.size() > 1) { + auto last_block = protected_indicator_blocks.back(); + protected_indicator_blocks.pop_back(); prune_indicators(test_tree, protected_indicator_blocks, target_tree_size); + protected_indicator_blocks.assign({last_block}); num_novel_blocks_added = extend_novel(test_tree, novel_blocks, novel_alleles, 1, 
policies_.haplotype_limits); } if (num_novel_blocks_added > 0) { @@ -903,48 +842,44 @@ void HaplotypeGenerator::update_lagged_next_active_region() const } } -void HaplotypeGenerator::progress(GenomicRegion to) -{ - if (to == active_region_) return; - next_active_region_ = std::move(to); - if (!in_holdout_mode()) { - if (begins_before(active_region_, *next_active_region_)) { - auto passed_region = left_overhang_region(active_region_, *next_active_region_); - const auto passed_alleles = overlap_range(alleles_, passed_region); - if (passed_alleles.empty()) return; - if (can_remove_entire_passed_region(active_region_, *next_active_region_, passed_alleles)) { - alleles_.erase_overlapped(passed_region); - tree_.clear(passed_region); - } else if (requires_staged_removal(passed_alleles)) { - // We need to be careful here as insertions adjacent to passed_region are - // considered overlapped and would be wrongly erased if we erased the whole - // region. But, we also want to clear all single base alleles left adjacent with - // next_active_region_, as they have truly been passed. - - // This will erase everything to the left of the adjacent insertion, other than - // the single base alleles adjacent with next_active_region_. - const auto first_removal_region = expand_rhs(passed_region, -1); - alleles_.erase_overlapped(first_removal_region); - // This will erase the remaining single base alleles in passed_region, but not the - // insertions in next_active_region_. 
- const auto second_removal_region = tail_region(first_removal_region); - alleles_.erase_overlapped(second_removal_region); - - if (is_after(*next_active_region_, active_region_)) { - assert(tree_.is_empty() || contains(active_region_, tree_.encompassing_region())); - tree_.clear(); - } else { - tree_.clear(first_removal_region); - tree_.clear(second_removal_region); - } +void HaplotypeGenerator::remove_passed_alleles() +{ + if (begins_before(active_region_, *next_active_region_)) { + auto passed_region = left_overhang_region(active_region_, *next_active_region_); + const auto passed_alleles = overlap_range(alleles_, passed_region); + if (passed_alleles.empty()) return; + if (can_remove_entire_passed_region(active_region_, *next_active_region_, passed_alleles)) { + alleles_.erase_overlapped(passed_region); + tree_.clear(passed_region); + } else if (requires_staged_removal(passed_alleles)) { + // We need to be careful here as insertions adjacent to passed_region are + // considered overlapped and would be wrongly erased if we erased the whole + // region. But, we also want to clear all single base alleles left adjacent with + // next_active_region_, as they have truly been passed. + + // This will erase everything to the left of the adjacent insertion, other than + // the single base alleles adjacent with next_active_region_. + const auto first_removal_region = expand_rhs(passed_region, -1); + alleles_.erase_overlapped(first_removal_region); + // This will erase the remaining single base alleles in passed_region, but not the + // insertions in next_active_region_. 
+ const auto second_removal_region = tail_region(first_removal_region); + alleles_.erase_overlapped(second_removal_region); + + if (is_after(*next_active_region_, active_region_)) { + assert(tree_.is_empty() || contains(active_region_, tree_.encompassing_region())); + tree_.clear(); } else { - const auto removal_region = expand_rhs(passed_region, -1); - alleles_.erase_overlapped(removal_region); - tree_.clear(removal_region); - } - if (overlaps(passed_region, rightmost_allele_) && !alleles_.empty()) { - rightmost_allele_ = alleles_.rightmost(); + tree_.clear(first_removal_region); + tree_.clear(second_removal_region); } + } else { + const auto removal_region = expand_rhs(passed_region, -1); + alleles_.erase_overlapped(removal_region); + tree_.clear(removal_region); + } + if (overlaps(passed_region, rightmost_allele_) && !alleles_.empty()) { + rightmost_allele_ = alleles_.rightmost(); } } } @@ -967,7 +902,7 @@ void HaplotypeGenerator::populate_tree_with_novel_alleles() active_region_ = *next_active_region_; return; } - progress(*next_active_region_); + if (!in_holdout_mode()) remove_passed_alleles(); auto novel_active_region = *next_active_region_; if (!tree_.is_empty()) { novel_active_region = right_overhang_region(*next_active_region_, active_region_); @@ -1010,7 +945,6 @@ void HaplotypeGenerator::populate_tree_with_novel_alleles() next_holdout_region = novel_active_region; } } - assert(!begins_before(active_region_before_holdout, active_region_)); if (last_added_novel_itr != std::cend(novel_active_alleles)) { last_added_novel_itr = extend_tree_until(last_added_novel_itr, std::cend(novel_active_alleles), tree_, policies_.haplotype_limits.overflow); @@ -1411,11 +1345,17 @@ HaplotypeGenerator::Builder& HaplotypeGenerator::Builder::set_max_expected_log_a return *this; } +HaplotypeGenerator::Builder& HaplotypeGenerator::Builder::set_dense_variation_detector(DenseVariationDetector detector) noexcept +{ + dense_variation_detector_ = std::move(detector); + return *this; 
+} + HaplotypeGenerator HaplotypeGenerator::Builder::build(const ReferenceGenome& reference, const MappableFlatSet& candidates, const ReadMap& reads) const { - return HaplotypeGenerator {reference, candidates, reads, policies_}; + return HaplotypeGenerator {reference, candidates, reads, policies_, dense_variation_detector_}; } } // namespace coretools diff --git a/src/core/tools/hapgen/haplotype_generator.hpp b/src/core/tools/hapgen/haplotype_generator.hpp index 6c0f244ac..6555d7722 100644 --- a/src/core/tools/hapgen/haplotype_generator.hpp +++ b/src/core/tools/hapgen/haplotype_generator.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef haplotype_generator_hpp @@ -21,6 +21,7 @@ #include "core/types/allele.hpp" #include "genome_walker.hpp" #include "haplotype_tree.hpp" +#include "dense_variation_detector.hpp" namespace octopus { @@ -57,7 +58,8 @@ class HaplotypeGenerator HaplotypeGenerator(const ReferenceGenome& reference, const MappableFlatSet& candidates, const ReadMap& reads, - Policies policies); + Policies policies, + DenseVariationDetector dense_variation_detector); HaplotypeGenerator(const HaplotypeGenerator&) = default; HaplotypeGenerator& operator=(const HaplotypeGenerator&) = default; @@ -144,7 +146,7 @@ class HaplotypeGenerator GenomicRegion find_max_lagged_region() const; void update_next_active_region() const; void update_lagged_next_active_region() const; - void progress(GenomicRegion to); + void remove_passed_alleles(); void populate_tree(); void populate_tree_with_novel_alleles(); void populate_tree_with_holdouts(); @@ -222,18 +224,14 @@ class HaplotypeGenerator::Builder Builder& set_lagging_policy(Policies::Lagging policy) noexcept; Builder& set_extension_policy(Policies::Extension policy) noexcept; - Builder& set_target_limit(unsigned n) noexcept; Builder& set_holdout_limit(unsigned n) noexcept; 
Builder& set_overflow_limit(unsigned n) noexcept; - Builder& set_max_holdout_depth(unsigned n) noexcept; - Builder& set_min_flank_pad(Haplotype::MappingDomain::Size n) noexcept; - Builder& set_max_indicator_join_distance(Haplotype::NucleotideSequence::size_type n) noexcept; - Builder& set_max_expected_log_allele_count_per_base(double v) noexcept; + Builder& set_dense_variation_detector(DenseVariationDetector detector) noexcept; HaplotypeGenerator build(const ReferenceGenome& reference, const MappableFlatSet& candidates, @@ -241,6 +239,7 @@ class HaplotypeGenerator::Builder private: Policies policies_; + DenseVariationDetector dense_variation_detector_; }; } // namespace coretools diff --git a/src/core/tools/hapgen/haplotype_tree.cpp b/src/core/tools/hapgen/haplotype_tree.cpp index a1f697fd7..1a3120425 100644 --- a/src/core/tools/hapgen/haplotype_tree.cpp +++ b/src/core/tools/hapgen/haplotype_tree.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "haplotype_tree.hpp" @@ -213,10 +213,23 @@ auto make_splicer(const ContigAllele& allele, std::stack& candidate_splice_si return Splicer {allele, candidate_splice_sites, splice_sites, root}; } +template +bool is_possible_splice_site(const ContigAllele& allele, const V& v, const G& tree) +{ + // Can allele go before v in the tree? 
+ return begins_before(allele, tree[v]) + || (boost::out_degree(v, tree) == 0 && overlaps(allele, tree[v])) + || (begins_equal(allele, tree[v]) && (!is_empty_region(tree[v]) || (is_insertion(tree[v]) && is_deletion(allele)))); +} + +bool is_deletion_and_insertion(const ContigAllele& new_allele, const ContigAllele& leaf) +{ + return (is_insertion(leaf) && is_deletion(new_allele)) || (is_deletion(leaf) && is_insertion(new_allele)); +} + bool can_add_to_branch(const ContigAllele& new_allele, const ContigAllele& leaf) { - return !are_adjacent(leaf, new_allele) - || !((is_insertion(leaf) && is_deletion(new_allele)) || (is_deletion(leaf) && is_insertion(new_allele))); + return !are_adjacent(leaf, new_allele) || !is_deletion_and_insertion(new_allele, leaf); } void HaplotypeTree::splice(const ContigAllele& allele) @@ -233,19 +246,19 @@ void HaplotypeTree::splice(const ContigAllele& allele) make_splicer(allele, candidate_splice_sites, splice_sites, root_), boost::make_assoc_property_map(colours), [&] (const Vertex v, const Tree& tree) -> bool { - if (v != root_ && (begins_before(allele, tree[v]) - || (begins_equal(allele, tree[v]) && !is_empty_region(tree[v])))) { - const auto p = boost::inv_adjacent_vertices(v, tree); - if (p.first != p.second) { - const auto u = *p.first; - if (candidate_splice_sites.empty() || candidate_splice_sites.top() != u) { - candidate_splice_sites.push(u); + if (v != root_) { + if (is_possible_splice_site(allele, v, tree_)) { + const auto p = boost::inv_adjacent_vertices(v, tree); + if (p.first != p.second) { + const auto u = *p.first; + if (candidate_splice_sites.empty() || candidate_splice_sites.top() != u) { + candidate_splice_sites.push(u); + } } + return true; } - return true; - } else { - return false; } + return false; }); assert(candidate_splice_sites.empty()); for (const auto v : splice_sites) { diff --git a/src/core/tools/hapgen/haplotype_tree.hpp b/src/core/tools/hapgen/haplotype_tree.hpp index 56fa023ae..cb7453465 100644 --- 
a/src/core/tools/hapgen/haplotype_tree.hpp +++ b/src/core/tools/hapgen/haplotype_tree.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef haplotype_tree_hpp diff --git a/src/core/tools/haplotype_filter.cpp b/src/core/tools/haplotype_filter.cpp index 3022c2447..644a31399 100644 --- a/src/core/tools/haplotype_filter.cpp +++ b/src/core/tools/haplotype_filter.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "haplotype_filter.hpp" diff --git a/src/core/tools/haplotype_filter.hpp b/src/core/tools/haplotype_filter.hpp index f2ddbe397..b7907de5c 100644 --- a/src/core/tools/haplotype_filter.hpp +++ b/src/core/tools/haplotype_filter.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef haplotype_filter_hpp diff --git a/src/core/tools/phaser/phaser.cpp b/src/core/tools/phaser/phaser.cpp index 540b0685a..3106d60db 100644 --- a/src/core/tools/phaser/phaser.cpp +++ b/src/core/tools/phaser/phaser.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "phaser.hpp" diff --git a/src/core/tools/phaser/phaser.hpp b/src/core/tools/phaser/phaser.hpp index 7e5ff82d7..38e1f3d30 100644 --- a/src/core/tools/phaser/phaser.hpp +++ b/src/core/tools/phaser/phaser.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef phaser_hpp diff --git a/src/core/tools/read_assigner.cpp b/src/core/tools/read_assigner.cpp index 48c5b1547..033eabb7c 100644 --- a/src/core/tools/read_assigner.cpp +++ b/src/core/tools/read_assigner.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "read_assigner.hpp" @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -19,16 +20,35 @@ namespace octopus { +namespace { + using HaplotypeLikelihoods = std::vector>; -void find_max_likelihood_haplotypes(const std::vector& haplotypes, const unsigned read, - const HaplotypeLikelihoods& likelihoods, - std::vector& result) +auto vectorise(const std::vector& haplotypes, const HaplotypeProbabilityMap& priors) +{ + std::vector result(haplotypes.size()); + std::transform(std::cbegin(haplotypes), std::cend(haplotypes), std::begin(result), + [&] (const auto& haplotype) { return priors.at(haplotype); }); + return result; +} + +auto get_priors(const std::vector& haplotypes, const HaplotypeProbabilityMap& log_priors) +{ + if (log_priors.empty()) { + return std::vector(haplotypes.size()); + } else { + return vectorise(haplotypes, log_priors); + } +} + +void find_map_haplotypes(const std::vector& haplotypes, const unsigned read, + const HaplotypeLikelihoods& likelihoods, const std::vector& log_priors, + std::vector& result) { assert(result.empty()); auto max_likelihood = std::numeric_limits::lowest(); for (unsigned k {0}; k < haplotypes.size(); ++k) { - const auto curr = likelihoods[k][read]; + const auto curr = likelihoods[k][read] + log_priors[k]; if (maths::almost_equal(curr, max_likelihood)) { result.push_back(k); } else if (curr > max_likelihood) { @@ -36,22 +56,70 @@ void find_max_likelihood_haplotypes(const std::vector& haplotypes, co max_likelihood = curr; } } + if (result.empty()) { + result.resize(haplotypes.size()); + 
std::iota(std::begin(result), std::end(result), 0); + } +} + +template +ForwardIt random_select(ForwardIt first, ForwardIt last, RandomGenerator& g) +{ + if (first == last) return first; + const auto max = static_cast(std::distance(first, last)); + if (max == 1) return first; + std::uniform_int_distribution dist {0, max - 1}; + std::advance(first, dist(g)); + return first; +} + +template +ForwardIt random_select(ForwardIt first, ForwardIt last) +{ + static thread_local std::mt19937 generator {42}; + return random_select(first, last, generator); +} + +template +decltype(auto) random_select(const Range& values) +{ + assert(!values.empty()); + return *random_select(std::cbegin(values), std::cend(values)); } auto calculate_support(const std::vector& haplotypes, const std::vector& reads, + const std::vector& log_priors, const HaplotypeLikelihoods& likelihoods, - boost::optional&> unassigned) + boost::optional&> ambiguous, + const AssignmentConfig& config) { HaplotypeSupportMap result {}; std::vector top {}; top.reserve(haplotypes.size()); for (unsigned i {0}; i < reads.size(); ++i) { - find_max_likelihood_haplotypes(haplotypes, i, likelihoods, top); + find_map_haplotypes(haplotypes, i, likelihoods, log_priors, top); if (top.size() == 1) { result[haplotypes[top.front()]].push_back(reads[i]); - } else if (unassigned) { - unassigned->push_back(reads[i]); + } else { + using UA = AssignmentConfig::AmbiguousAction; + switch (config.ambiguous_action) { + case UA::first: + result[haplotypes[top.front()]].push_back(reads[i]); + break; + case UA::all: { + for (auto idx : top) result[haplotypes[idx]].push_back(reads[i]); + break; + } + case UA::random: { + result[haplotypes[random_select(top)]].push_back(reads[i]); + break; + } + case UA::drop: + default: + break; + } + if (ambiguous) ambiguous->push_back(reads[i]); } top.clear(); } @@ -88,7 +156,7 @@ auto calculate_likelihoods(const std::vector& haplotypes, const auto& haplotype_region = mapped_region(haplotypes.front()); const 
auto reads_region = encompassing_region(reads); const auto min_flank_pad = HaplotypeLikelihoodModel::pad_requirement(); - unsigned min_lhs_expansion {2 * min_flank_pad}, min_rhs_expansion {2 * min_flank_pad}; + unsigned min_lhs_expansion{2 * min_flank_pad}, min_rhs_expansion{2 * min_flank_pad}; if (begins_before(reads_region, haplotype_region)) { min_lhs_expansion += begin_distance(reads_region, haplotype_region); } @@ -99,7 +167,7 @@ auto calculate_likelihoods(const std::vector& haplotypes, const auto read_hashes = compute_read_hashes(reads); static constexpr unsigned char mapperKmerSize {6}; auto haplotype_hashes = init_kmer_hash_table(); - HaplotypeLikelihoods result {}; + HaplotypeLikelihoods result{}; result.reserve(haplotypes.size()); for (const auto& haplotype : haplotypes) { const auto expanded_haplotype = expand(haplotype, min_expansion); @@ -119,64 +187,119 @@ auto calculate_likelihoods(const std::vector& haplotypes, return result; } -HaplotypeSupportMap compute_haplotype_support(const Genotype& genotype, - const std::vector& reads, - HaplotypeLikelihoodModel model, - boost::optional&> unassigned) +} // namespace + +HaplotypeSupportMap +compute_haplotype_support(const Genotype& genotype, + const std::vector& reads, + const HaplotypeProbabilityMap& log_priors, + HaplotypeLikelihoodModel model, + boost::optional&> ambiguous, + AssignmentConfig config) { - if (!genotype.is_homozygous() && !reads.empty()) { - const auto unique_haplotypes = genotype.copy_unique(); - assert(unique_haplotypes.size() > 1); - const auto likelihoods = calculate_likelihoods(unique_haplotypes, reads, model); - return calculate_support(unique_haplotypes, reads, likelihoods, unassigned); - } else { - return {}; + if (!reads.empty()) { + if (!genotype.is_homozygous()) { + const auto unique_haplotypes = genotype.copy_unique(); + assert(unique_haplotypes.size() > 1); + const auto priors = get_priors(unique_haplotypes, log_priors); + const auto likelihoods = 
calculate_likelihoods(unique_haplotypes, reads, model); + return calculate_support(unique_haplotypes, reads, priors, likelihoods, ambiguous, config); + } else if (config.ambiguous_action != AssignmentConfig::AmbiguousAction::drop) { + HaplotypeSupportMap result {}; + result.emplace(genotype[0], reads); + return result; + } } + return {}; } -HaplotypeSupportMap compute_haplotype_support(const Genotype& genotype, - const std::vector& reads, - std::deque& unassigned) +HaplotypeSupportMap +compute_haplotype_support(const Genotype& genotype, + const std::vector& reads, + std::deque& ambiguous, + const HaplotypeProbabilityMap& log_priors, + HaplotypeLikelihoodModel model, + AssignmentConfig config) { - return compute_haplotype_support(genotype, reads, HaplotypeLikelihoodModel {nullptr, make_indel_error_model(), false}, unassigned); + return compute_haplotype_support(genotype, reads, log_priors, model, ambiguous, config); } -HaplotypeSupportMap compute_haplotype_support(const Genotype& genotype, - const std::vector& reads) +HaplotypeSupportMap +compute_haplotype_support(const Genotype& genotype, + const std::vector& reads, + HaplotypeLikelihoodModel model, + std::deque& ambiguous, + AssignmentConfig config) { - return compute_haplotype_support(genotype, reads, HaplotypeLikelihoodModel {nullptr, make_indel_error_model(), false}); + return compute_haplotype_support(genotype, reads, ambiguous, {}, model, config); } -HaplotypeSupportMap compute_haplotype_support(const Genotype& genotype, - const std::vector& reads, - HaplotypeLikelihoodModel model) +HaplotypeSupportMap +compute_haplotype_support(const Genotype& genotype, + const std::vector& reads, + std::deque& ambiguous, + const HaplotypeProbabilityMap& log_priors, + AssignmentConfig config) { - return compute_haplotype_support(genotype, reads, std::move(model), boost::none); + HaplotypeLikelihoodModel model {nullptr, make_indel_error_model(), false}; + return compute_haplotype_support(genotype, reads, log_priors, model, 
ambiguous, config); } -HaplotypeSupportMap compute_haplotype_support(const Genotype& genotype, - const std::vector& reads, - std::deque& unassigned, - HaplotypeLikelihoodModel model) +HaplotypeSupportMap +compute_haplotype_support(const Genotype& genotype, + const std::vector& reads, + std::deque& ambiguous, + AssignmentConfig config) { - return compute_haplotype_support(genotype, reads, std::move(model), unassigned); + HaplotypeLikelihoodModel model {nullptr, make_indel_error_model(), false}; + return compute_haplotype_support(genotype, reads, model, ambiguous, config); } -AlleleSupportMap compute_allele_support(const std::vector& alleles, - const HaplotypeSupportMap& haplotype_support) +HaplotypeSupportMap +compute_haplotype_support(const Genotype& genotype, + const std::vector& reads, + const HaplotypeProbabilityMap& log_priors, + AssignmentConfig config) { - AlleleSupportMap result {}; - result.reserve(alleles.size()); - for (const auto& allele : alleles) { - ReadRefSupportSet allele_support {}; - for (const auto& p : haplotype_support) { - if (p.first.contains(allele)) { - allele_support.insert(std::cend(allele_support), std::cbegin(p.second), std::cend(p.second)); - } - } - result.emplace(allele, std::move(allele_support)); - } - return result; + HaplotypeLikelihoodModel model {nullptr, make_indel_error_model(), false}; + return compute_haplotype_support(genotype, reads, log_priors, model, boost::none, config); +} + +HaplotypeSupportMap +compute_haplotype_support(const Genotype& genotype, + const std::vector& reads, + AssignmentConfig config) +{ + HaplotypeLikelihoodModel model {nullptr, make_indel_error_model(), false}; + return compute_haplotype_support(genotype, reads, model, config); +} + +HaplotypeSupportMap +compute_haplotype_support(const Genotype& genotype, + const std::vector& reads, + HaplotypeLikelihoodModel model, + AssignmentConfig config) +{ + return compute_haplotype_support(genotype, reads, {}, std::move(model), boost::none, config); +} + 
+HaplotypeSupportMap +compute_haplotype_support(const Genotype& genotype, + const std::vector& reads, + std::deque& ambiguous, + HaplotypeLikelihoodModel model, + AssignmentConfig config) +{ + return compute_haplotype_support(genotype, reads, std::move(model), ambiguous, config); +} + +AlleleSupportMap +compute_allele_support(const std::vector& alleles, const HaplotypeSupportMap& haplotype_support) +{ + return compute_allele_support(alleles, haplotype_support, + [] (const Haplotype& haplotype, const Allele& allele) { + return haplotype.includes(allele); + }); } } // namespace octopus diff --git a/src/core/tools/read_assigner.hpp b/src/core/tools/read_assigner.hpp index 87dad2dc9..ec266d2a6 100644 --- a/src/core/tools/read_assigner.hpp +++ b/src/core/tools/read_assigner.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef read_assigner_hpp @@ -20,29 +20,85 @@ namespace octopus { class HaplotypeLikelihoodModel; +using HaplotypeProbabilityMap = std::unordered_map; using ReadSupportSet = std::vector; using HaplotypeSupportMap = std::unordered_map; using ReadRefSupportSet = std::vector>; using AlleleSupportMap = std::unordered_map; -HaplotypeSupportMap compute_haplotype_support(const Genotype& genotype, - const std::vector& reads); +struct AssignmentConfig +{ + enum class AmbiguousAction { drop, first, random, all } ambiguous_action = AmbiguousAction::drop; +}; -HaplotypeSupportMap compute_haplotype_support(const Genotype& genotype, - const std::vector& reads, - std::deque& unassigned); +HaplotypeSupportMap +compute_haplotype_support(const Genotype& genotype, + const std::vector& reads, + const HaplotypeProbabilityMap& log_priors, + AssignmentConfig config = AssignmentConfig {}); -HaplotypeSupportMap compute_haplotype_support(const Genotype& genotype, - const std::vector& reads, - HaplotypeLikelihoodModel model); 
+HaplotypeSupportMap +compute_haplotype_support(const Genotype& genotype, + const std::vector& reads, + AssignmentConfig config = AssignmentConfig {}); -HaplotypeSupportMap compute_haplotype_support(const Genotype& genotype, - const std::vector& reads, - std::deque& unassigned, - HaplotypeLikelihoodModel model); +HaplotypeSupportMap +compute_haplotype_support(const Genotype& genotype, + const std::vector& reads, + std::deque& ambiguous, + const HaplotypeProbabilityMap& log_priors, + AssignmentConfig config = AssignmentConfig {}); -AlleleSupportMap compute_allele_support(const std::vector& alleles, - const HaplotypeSupportMap& haplotype_support); +HaplotypeSupportMap +compute_haplotype_support(const Genotype& genotype, + const std::vector& reads, + std::deque& ambiguous, + AssignmentConfig config = AssignmentConfig {}); + +HaplotypeSupportMap +compute_haplotype_support(const Genotype& genotype, + const std::vector& reads, + HaplotypeLikelihoodModel model, + AssignmentConfig config = AssignmentConfig {}); + +HaplotypeSupportMap +compute_haplotype_support(const Genotype& genotype, + const std::vector& reads, + std::deque& ambiguous, + const HaplotypeProbabilityMap& log_priors, + HaplotypeLikelihoodModel model, + AssignmentConfig config = AssignmentConfig {}); + +HaplotypeSupportMap +compute_haplotype_support(const Genotype& genotype, + const std::vector& reads, + std::deque& ambiguous, + HaplotypeLikelihoodModel model, + AssignmentConfig config = AssignmentConfig {}); + +template +AlleleSupportMap +compute_allele_support(const std::vector& alleles, + const HaplotypeSupportMap& haplotype_support, + BinaryPredicate inclusion_pred) +{ + AlleleSupportMap result {}; + result.reserve(alleles.size()); + for (const auto& allele : alleles) { + ReadRefSupportSet allele_support {}; + for (const auto& p : haplotype_support) { + if (inclusion_pred(p.first, allele)) { + allele_support.insert(std::cend(allele_support), std::cbegin(p.second), std::cend(p.second)); + } + } + 
result.emplace(allele, std::move(allele_support)); + } + return result; +} + +AlleleSupportMap +compute_allele_support(const std::vector& alleles, + const HaplotypeSupportMap& haplotype_support); } // namespace octopus diff --git a/src/core/tools/read_realigner.cpp b/src/core/tools/read_realigner.cpp index 0dab48ab7..87d077732 100644 --- a/src/core/tools/read_realigner.cpp +++ b/src/core/tools/read_realigner.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "read_realigner.hpp" @@ -6,10 +6,12 @@ #include #include #include +#include #include "basics/genomic_region.hpp" #include "utils/maths.hpp" #include "utils/kmer_mapper.hpp" +#include "utils/append.hpp" #include "core/models/error/error_model_factory.hpp" namespace octopus { @@ -69,29 +71,27 @@ auto compute_read_hashes(const std::vector& reads) return result; } -AlignedRead realign(const AlignedRead& read, GenomicRegion region, CigarString cigar) -{ - return AlignedRead {read.name(), std::move(region), read.sequence(), read.base_qualities(), - std::move(cigar), read.mapping_quality(), read.flags()}; -} - -AlignedRead realign(const AlignedRead& read, const Haplotype& haplotype, - GenomicRegion::Position mapping_position, CigarString alignment) +void realign(AlignedRead& read, const Haplotype& haplotype, + GenomicRegion::Position mapping_position, CigarString alignment) { const auto remapped_read_begin = mapped_begin(haplotype) + mapping_position; const auto remapped_read_end = remapped_read_begin + reference_size(alignment); - return realign(read, GenomicRegion {contig_name(read), remapped_read_begin, remapped_read_end}, std::move(alignment)); + read.realign(GenomicRegion {contig_name(read), remapped_read_begin, remapped_read_end}, std::move(alignment)); } -AlignedRead realign(const AlignedRead& read, const Haplotype& haplotype, HaplotypeLikelihoodModel::Alignment 
alignment) +void realign(AlignedRead& read, const Haplotype& haplotype, HaplotypeLikelihoodModel::Alignment alignment) { - return realign(read, haplotype, alignment.mapping_position, std::move(alignment.cigar)); + realign(read, haplotype, alignment.mapping_position, std::move(alignment.cigar)); } } // namespace -std::vector realign(const std::vector& reads, const Haplotype& haplotype, - HaplotypeLikelihoodModel model) +void realign(std::vector& reads, const Haplotype& haplotype) +{ + realign(reads, haplotype, HaplotypeLikelihoodModel {nullptr, make_indel_error_model(), false}); +} + +void realign(std::vector& reads, const Haplotype& haplotype, HaplotypeLikelihoodModel model) { const auto read_hashes = compute_read_hashes(reads); static constexpr unsigned char mapperKmerSize {6}; @@ -99,14 +99,18 @@ std::vector realign(const std::vector& reads, const Ha populate_kmer_hash_table(haplotype.sequence(), haplotype_hashes); auto haplotype_mapping_counts = init_mapping_counts(haplotype_hashes); model.reset(haplotype); - std::vector result {}; - result.reserve(reads.size()); - std::transform(std::cbegin(reads), std::cend(reads), std::cbegin(read_hashes), std::back_inserter(result), - [&] (const auto& read, const auto& read_hash) { - auto mapping_positions = map_query_to_target(read_hash, haplotype_hashes, haplotype_mapping_counts); - reset_mapping_counts(haplotype_mapping_counts); - return realign(read, haplotype, model.align(read, mapping_positions)); - }); + for (std::size_t i {0}; i < reads.size(); ++i) { + auto mapping_positions = map_query_to_target(read_hashes[i], haplotype_hashes, haplotype_mapping_counts); + reset_mapping_counts(haplotype_mapping_counts); + realign(reads[i], haplotype, model.align(reads[i], mapping_positions)); + } +} + +std::vector realign(const std::vector& reads, const Haplotype& haplotype, + HaplotypeLikelihoodModel model) +{ + auto result = reads; + realign(result, haplotype, model); return result; } @@ -115,15 +119,375 @@ std::vector 
realign(const std::vector& reads, const Ha return realign(reads, haplotype, HaplotypeLikelihoodModel {nullptr, make_indel_error_model(), false}); } -std::vector safe_realign(const std::vector& reads, const Haplotype& haplotype) +void safe_realign(std::vector& reads, const Haplotype& haplotype) { auto expanded_haplotype = expand_for_realignment(haplotype, reads); try { - return realign(reads, expanded_haplotype); + realign(reads, expanded_haplotype); } catch (const HaplotypeLikelihoodModel::ShortHaplotypeError& e) { expanded_haplotype = expand(expanded_haplotype, e.required_extension()); - return realign(reads, expanded_haplotype); + realign(reads, expanded_haplotype); + } +} + +std::vector safe_realign(const std::vector& reads, const Haplotype& haplotype) +{ + auto result = reads; + safe_realign(result, haplotype); + return result; +} + +namespace { + +CigarString minimise(const CigarString& cigar) +{ + CigarString result {}; + result.reserve(cigar.size()); + for (auto op_itr = std::cbegin(cigar); op_itr != std::cend(cigar);) { + const auto next_op_itr = std::find_if_not(std::next(op_itr), std::cend(cigar), + [=] (const auto& op) { return op.flag() == op_itr->flag();}); + const auto op_size = std::accumulate(op_itr, next_op_itr, CigarOperation::Size {0}, + [] (auto curr, const auto& op) { return curr + op.size(); }); + if (op_size > 0) { + result.emplace_back(op_size, op_itr->flag()); + } + op_itr = next_op_itr; + } + return result; +} + +auto get_match_type(const CigarOperation::Flag haplotype, const CigarOperation::Flag read) noexcept +{ + assert(is_match(haplotype) && is_match(read)); + using Flag = CigarOperation::Flag; + if ((haplotype == Flag::substitution && read == Flag::sequenceMatch) + || (haplotype == Flag::sequenceMatch && read == Flag::substitution)) { + return Flag::substitution; + } else if (haplotype == Flag::sequenceMatch && read == Flag::sequenceMatch) { + return Flag::sequenceMatch; + } else { + return Flag::alignmentMatch; + } +} + +} // 
namespace + +CigarString rebase(const CigarString& read_to_haplotype, const CigarString& haplotype_to_reference) +{ + assert(is_valid(read_to_haplotype) && is_valid(haplotype_to_reference)); + assert(reference_size(read_to_haplotype) <= sequence_size(haplotype_to_reference)); + const auto haplotypes_ops = decompose(haplotype_to_reference); + CigarString result {}; + result.reserve(haplotypes_ops.size()); + auto hap_flag_itr = std::cbegin(haplotypes_ops); + for (const auto& read_op : read_to_haplotype) { + if (is_match(read_op)) { + for (unsigned n {0}; n < read_op.size();) { + assert(hap_flag_itr != std::cend(haplotypes_ops)); + if (is_match(*hap_flag_itr)) { + result.emplace_back(1, get_match_type(*hap_flag_itr, read_op.flag())); + ++n; + } else { + result.emplace_back(1, *hap_flag_itr); + if (advances_sequence(*hap_flag_itr)) { + ++n; + } + } + ++hap_flag_itr; + } + } else if (is_insertion(read_op)) { + result.push_back(read_op); + } else { // deletion + auto op_size = read_op.size(); + for (unsigned n {0}; n < read_op.size();) { + assert(hap_flag_itr != std::cend(haplotypes_ops)); + if (is_deletion(*hap_flag_itr)) { + result.emplace_back(1, *hap_flag_itr); + ++hap_flag_itr; + } else { + if (is_insertion(*hap_flag_itr)) { + --op_size; + } + ++hap_flag_itr; + ++n; + } + } + if (op_size > 0) { + result.emplace_back(op_size, read_op.flag()); + } + } + } + return minimise(result); +} + +namespace { + +bool is_sequence_match(const CigarOperation& op) noexcept +{ + return op.flag() == CigarOperation::Flag::sequenceMatch; +} + +CigarString pad_reference(const GenomicRegion& read_region, const CigarString& read_to_haplotype, + const GenomicRegion& haplotype_region, const CigarString& haplotype_to_reference) +{ + assert(overlaps(read_region, haplotype_region)); + assert(!read_to_haplotype.empty() || !haplotype_to_reference.empty()); + CigarString result {}; + using Flag = CigarOperation::Flag; + if (read_region == haplotype_region) { + result = haplotype_to_reference; + } 
else if (contains(haplotype_region, read_region)) { + const auto offset = left_overhang_size(haplotype_region, read_region); + const auto copy_length = std::max(sequence_size(read_to_haplotype), sequence_size(haplotype_to_reference)); + result = copy_sequence(haplotype_to_reference, offset, copy_length); + } else { + result.reserve(haplotype_to_reference.size() + 2); + if (contains(read_region, haplotype_region)) { + const auto lhs_pad_size = left_overhang_size(read_region, haplotype_region); + const auto rhs_pad_size = right_overhang_size(read_region, haplotype_region); + if (is_sequence_match(haplotype_to_reference.front())) { + if (haplotype_to_reference.size() == 1) { + result.emplace_back(lhs_pad_size + haplotype_to_reference.front().size() + rhs_pad_size, Flag::sequenceMatch); + } else { + result.emplace_back(lhs_pad_size + haplotype_to_reference.front().size(), Flag::sequenceMatch); + result.insert(result.cend(), haplotype_to_reference.cbegin() + 1, haplotype_to_reference.cend() - 1); + if (is_sequence_match(haplotype_to_reference.back())) { + result.emplace_back(haplotype_to_reference.back().size() + rhs_pad_size, Flag::sequenceMatch); + } else { + result.push_back(haplotype_to_reference.back()); + if (rhs_pad_size > 0) result.emplace_back(rhs_pad_size, Flag::sequenceMatch); + } + } + } else { + if (lhs_pad_size > 0) result.emplace_back(lhs_pad_size, Flag::sequenceMatch); + if (is_sequence_match(haplotype_to_reference.back())) { + result.insert(result.cend(), haplotype_to_reference.cbegin(), haplotype_to_reference.cend() - 1); + result.emplace_back(haplotype_to_reference.back().size() + rhs_pad_size, Flag::sequenceMatch); + } else { + utils::append(haplotype_to_reference, result); + if (rhs_pad_size > 0) result.emplace_back(rhs_pad_size, Flag::sequenceMatch); + } + } + } else if (begins_before(read_region, haplotype_region)) { + assert(ends_before(read_region, haplotype_region)); + const auto lhs_pad_size = left_overhang_size(read_region, haplotype_region); 
+ if (is_sequence_match(haplotype_to_reference.front())) { + result.emplace_back(lhs_pad_size + haplotype_to_reference.front().size(), Flag::sequenceMatch); + result.insert(result.cend(), haplotype_to_reference.cbegin() + 1, haplotype_to_reference.cend()); + } else { + assert(lhs_pad_size > 0); + result.emplace_back(lhs_pad_size, Flag::sequenceMatch); + utils::append(haplotype_to_reference, result); + } + } else { + assert(begins_before(haplotype_region, read_region) && ends_before(haplotype_region, read_region)); + const auto offset = left_overhang_size(haplotype_region, read_region); + const auto rhs_pad_size = right_overhang_size(read_region, haplotype_region); + result = copy_sequence(haplotype_to_reference, offset); + if (!result.empty() && is_sequence_match(result.back())) { + increment_size(result.back(), rhs_pad_size); + } else if (rhs_pad_size > 0) { + result.emplace_back(rhs_pad_size, Flag::sequenceMatch); + } + } + } + if (sequence_size(result) < reference_size(read_to_haplotype)) { + const auto rhs_pad_size = reference_size(read_to_haplotype) - sequence_size(result); + if (!result.empty() && is_sequence_match(result.back())) { + increment_size(result.back(), rhs_pad_size); + } else { + assert(rhs_pad_size > 0); + result.emplace_back(rhs_pad_size, Flag::sequenceMatch); + } + } + return result; +} + +CigarString copy_tail(const GenomicRegion& haplotype_region, const CigarString& haplotype_to_reference, + const GenomicRegion& rebased_read_region) +{ + const auto offset = left_overhang_size(haplotype_region, rebased_read_region); + auto result = copy_reference(haplotype_to_reference, offset, size(haplotype_region)); + if (sequence_size(result) < size(rebased_read_region)) { + const auto rhs_pad_size = size(rebased_read_region) - sequence_size(result); + if (is_sequence_match(result.back())) { + increment_size(result.back(), rhs_pad_size); + } else { + result.emplace_back(rhs_pad_size, CigarOperation::Flag::sequenceMatch); + } + } + return result; +} + +auto 
get_tail_op_size(const GenomicRegion& haplotype_region, const CigarString& haplotype_to_reference, + const GenomicRegion& read_region) +{ + auto tail = copy_sequence(haplotype_to_reference, left_overhang_size(haplotype_region, read_region), + haplotype_to_reference.back().size()); + return tail.empty() ? 0 : tail.back().size(); +} + +bool has_indels(const CigarString& cigar) noexcept +{ + return std::any_of(std::cbegin(cigar), std::cend(cigar), [] (const auto& op) { return is_indel(op); }); +} + +auto calculate_rebase_shift(const AlignedRead& read, const GenomicRegion& haplotype_region, + const CigarString& haplotype_to_reference) +{ + GenomicRegion::Distance result {0}; + if (begins_before(haplotype_region, read) && has_indels(haplotype_to_reference)) { + auto lhs_flank_length = static_cast(left_overhang_size(haplotype_region, read)); + for (const auto& op : haplotype_to_reference) { + if (lhs_flank_length == 0) { + if (is_deletion(op)) { + result += op.size(); + } + break; + } else if (lhs_flank_length < 0) { + break; + } + if (is_insertion(op)) { + result -= std::min(static_cast(op.size()), lhs_flank_length); + lhs_flank_length -= op.size(); + } else if (is_deletion(op)) { + result += op.size(); + } else { + lhs_flank_length -= op.size(); + } + } + } + return result; +} + +void rebase_overlapped(AlignedRead& read, const GenomicRegion& haplotype_region, const CigarString& haplotype_to_reference, + const GenomicRegion::Distance rebase_shift) +{ + auto padded_haplotype_cigar = pad_reference(read.mapped_region(), read.cigar(), haplotype_region, haplotype_to_reference); + assert(!padded_haplotype_cigar.empty()); + if (is_deletion(padded_haplotype_cigar.front())) { + padded_haplotype_cigar.erase(std::cbegin(padded_haplotype_cigar)); + } + auto rebased_cigar = rebase(read.cigar(), padded_haplotype_cigar); + auto rebased_read_region = expand_rhs(shift(head_region(read), rebase_shift), reference_size(rebased_cigar)); + read.realign(std::move(rebased_read_region), 
std::move(rebased_cigar)); +} + +void rebase_not_overlapped(AlignedRead& read, const GenomicRegion& haplotype_region, const CigarString& haplotype_to_reference, + GenomicRegion rebased_read_region) +{ + if (overlaps(rebased_read_region, haplotype_region)) { + auto padded_haplotype_cigar = copy_tail(haplotype_region, haplotype_to_reference, rebased_read_region); + assert(!padded_haplotype_cigar.empty()); + if (is_insertion(padded_haplotype_cigar.front())) { + const auto called_insertion_size = padded_haplotype_cigar.front().size(); + const auto rebase_shift = static_cast(begin_distance(rebased_read_region, read)); + const auto supported_insertion_size = called_insertion_size - std::min(rebase_shift, called_insertion_size); + assert(supported_insertion_size <= called_insertion_size); + if (supported_insertion_size > 0) { + padded_haplotype_cigar.front().set_size(supported_insertion_size); + if (is_match(padded_haplotype_cigar.back())) { + increment_size(padded_haplotype_cigar.back(), supported_insertion_size); + } else { + padded_haplotype_cigar.emplace_back(supported_insertion_size, CigarOperation::Flag::sequenceMatch); + } + } else { + padded_haplotype_cigar.erase(std::cbegin(padded_haplotype_cigar)); + } + } + if (reference_size(read.cigar()) > sequence_size(padded_haplotype_cigar)) { + const auto pad = reference_size(read.cigar()) - sequence_size(padded_haplotype_cigar); + if (is_match(padded_haplotype_cigar.back())) { + increment_size(padded_haplotype_cigar.back(), pad); + } else { + padded_haplotype_cigar.emplace_back(pad, CigarOperation::Flag::sequenceMatch); + } + } + auto rebased_cigar = rebase(read.cigar(), padded_haplotype_cigar); + rebased_read_region = expand_rhs(head_region(rebased_read_region), reference_size(rebased_cigar)); + read.realign(std::move(rebased_read_region), std::move(rebased_cigar)); + } else if (are_adjacent(haplotype_region, rebased_read_region) && is_insertion(haplotype_to_reference.back())) { + const auto insertion_size = 
get_tail_op_size(haplotype_region, haplotype_to_reference, mapped_region(read)); + if (insertion_size > 0) { + using Flag = CigarOperation::Flag; + const CigarString padded_haplotype_cigar {CigarOperation {insertion_size, Flag::insertion}, + CigarOperation {size(rebased_read_region), Flag::sequenceMatch}}; + auto rebased_cigar = rebase(read.cigar(), padded_haplotype_cigar); + rebased_read_region = expand_rhs(head_region(rebased_read_region), reference_size(rebased_cigar)); + read.realign(std::move(rebased_read_region), std::move(rebased_cigar)); + } else { + read.realign(std::move(rebased_read_region), read.cigar()); + } + } else { + read.realign(std::move(rebased_read_region), read.cigar()); + } +} + +void rebase_adjacent(AlignedRead& read, const GenomicRegion& haplotype_region, const CigarString& haplotype_to_reference) +{ + assert(!haplotype_to_reference.empty()); + if (is_before(haplotype_region, read) && is_insertion(haplotype_to_reference.back())) { + using Flag = CigarOperation::Flag; + const CigarString padded_haplotype_cigar {haplotype_to_reference.back(), + CigarOperation {region_size(read), Flag::sequenceMatch}}; + auto rebased_read_cigar = rebase(read.cigar(), padded_haplotype_cigar); + auto rebased_read_region = expand_rhs(head_region(read), reference_size(rebased_read_cigar)); + read.realign(std::move(rebased_read_region), std::move(rebased_read_cigar)); } } +void rebase(AlignedRead& read, const GenomicRegion& haplotype_region, const CigarString& haplotype_to_reference) +{ + const auto rebase_shift = calculate_rebase_shift(read, haplotype_region, haplotype_to_reference); + if (overlaps(read, haplotype_region)) { + rebase_overlapped(read, haplotype_region, haplotype_to_reference, rebase_shift); + } else if (rebase_shift != 0) { + auto rebased_read_region = shift(mapped_region(read), rebase_shift); + if (rebase_shift < 0) { + rebase_not_overlapped(read, haplotype_region, haplotype_to_reference, std::move(rebased_read_region)); + } else { + 
read.realign(std::move(rebased_read_region), read.cigar()); + } + } else if (are_adjacent(haplotype_region, read)) { + rebase_adjacent(read, haplotype_region, haplotype_to_reference); + } +} + +} // namespace + +void rebase(std::vector& reads, const Haplotype& haplotype) +{ + const auto haplotype_cigar = haplotype.cigar(); + for (auto& read : reads) { + rebase(read, haplotype.mapped_region(), haplotype_cigar); + } + std::sort(std::begin(reads), std::end(reads)); +} + +void realign_to_reference(std::vector& reads, const Haplotype& haplotype) +{ + realign(reads, haplotype); + rebase(reads, haplotype); +} + +std::vector realign_to_reference(const std::vector& reads, const Haplotype& haplotype) +{ + auto result = realign(reads, haplotype); + rebase(result, haplotype); + return result; +} + +void safe_realign_to_reference(std::vector& reads, const Haplotype& haplotype) +{ + safe_realign(reads, haplotype); + rebase(reads, haplotype); +} + +std::vector safe_realign_to_reference(const std::vector& reads, const Haplotype& haplotype) +{ + auto result = safe_realign(reads, haplotype); + rebase(result, haplotype); + return result; +} + } // namespace octopus diff --git a/src/core/tools/read_realigner.hpp b/src/core/tools/read_realigner.hpp index 03a0a6072..5d5c8b434 100644 --- a/src/core/tools/read_realigner.hpp +++ b/src/core/tools/read_realigner.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef read_realigner_hpp @@ -15,13 +15,24 @@ namespace octopus { Haplotype expand_for_realignment(const Haplotype& haplotype, const std::vector& reads); +void realign(std::vector& reads, const Haplotype& haplotype); +void realign(std::vector& reads, const Haplotype& haplotype, HaplotypeLikelihoodModel model); std::vector realign(const std::vector& reads, const Haplotype& haplotype, HaplotypeLikelihoodModel model); - std::vector realign(const std::vector& reads, const Haplotype& haplotype); +void safe_realign(std::vector& reads, const Haplotype& haplotype); std::vector safe_realign(const std::vector& reads, const Haplotype& haplotype); +CigarString rebase(const CigarString& read_to_haplotype, const CigarString& haplotype_to_reference); +void rebase(std::vector& reads, const Haplotype& haplotype); + +void realign_to_reference(std::vector& reads, const Haplotype& haplotype); +std::vector realign_to_reference(const std::vector& reads, const Haplotype& haplotype); + +void safe_realign_to_reference(std::vector& reads, const Haplotype& haplotype); +std::vector safe_realign_to_reference(const std::vector& reads, const Haplotype& haplotype); + } // namespace #endif diff --git a/src/core/tools/vargen/active_region_generator.cpp b/src/core/tools/vargen/active_region_generator.cpp new file mode 100644 index 000000000..d541c9e3c --- /dev/null +++ b/src/core/tools/vargen/active_region_generator.cpp @@ -0,0 +1,109 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "active_region_generator.hpp" + +#include + +#include "utils/repeat_finder.hpp" +#include "utils/mappable_algorithms.hpp" +#include "utils/append.hpp" + +namespace octopus { namespace coretools { + +ActiveRegionGenerator::ActiveRegionGenerator(const ReferenceGenome& reference, Options options) +: reference_ {reference} +, options_ {options} +, assembler_name_ {"LocalReassembler"} +, cigar_scanner_name_ {"CigarScanner"} +, using_assembler_ {false} +, assembler_active_region_generator_ {} +, max_read_length_ {} +{} + +void ActiveRegionGenerator::add_generator(const std::string& name) +{ + if (is_assembler(name)) { + using_assembler_ = true; + if (!options_.assemble_all) { + assembler_active_region_generator_ = AssemblerActiveRegionGenerator {reference_}; + } + } +} + +void ActiveRegionGenerator::add_read(const SampleName& sample, const AlignedRead& read) +{ + max_read_length_ = std::max(max_read_length_, sequence_size(read)); + if (assembler_active_region_generator_) assembler_active_region_generator_->add(sample, read); +} + +auto merge(std::vector lhs, std::vector rhs) +{ + auto itr = utils::append(std::move(rhs), lhs); + std::inplace_merge(std::begin(lhs), itr, std::end(lhs)); + return extract_covered_regions(lhs); +} + +template +auto append(const Range& range, std::vector& result) +{ + return result.insert(std::cend(result), std::cbegin(range), std::cend(range)); +}; + +auto find_minisatellites(const std::vector& repeats, const GenomicRegion& region, + const std::size_t max_read_length) +{ + InexactRepeatDefinition minisatellite_def {}; + minisatellite_def.min_exact_repeat_seed_length = 2 * max_read_length / 3; + minisatellite_def.min_exact_repeat_seed_periods = 3; + minisatellite_def.max_seed_join_distance = max_read_length / 3; + minisatellite_def.min_joined_repeat_length = 2 * max_read_length; + return join(find_repeat_regions(repeats, region, minisatellite_def), max_read_length / 2); +} + +auto find_compound_microsatellites(const std::vector& 
repeats, const GenomicRegion& region, + const std::size_t max_read_length) +{ + InexactRepeatDefinition compound_microsatellite_def {}; + compound_microsatellite_def.max_exact_repeat_seed_period = 6; + compound_microsatellite_def.min_exact_repeat_seed_length = 4; + compound_microsatellite_def.min_exact_repeat_seed_periods = 4; + compound_microsatellite_def.min_exact_seeds = 2; + compound_microsatellite_def.max_seed_join_distance = 1; + compound_microsatellite_def.min_joined_repeat_length = max_read_length / 4; + return join(find_repeat_regions(repeats, region, compound_microsatellite_def), max_read_length / 2); +} + +std::vector ActiveRegionGenerator::generate(const GenomicRegion& region, const std::string& generator) const +{ + if (is_assembler(generator) && assembler_active_region_generator_) { + return assembler_active_region_generator_->generate(region); + } else { + return {region}; + } +} + +void ActiveRegionGenerator::clear() noexcept +{ + if (assembler_active_region_generator_) assembler_active_region_generator_->clear(); +} + +// private methods + +bool ActiveRegionGenerator::is_cigar_scanner(const std::string& generator) const noexcept +{ + return generator == cigar_scanner_name_; +} + +bool ActiveRegionGenerator::is_assembler(const std::string& generator) const noexcept +{ + return generator == assembler_name_; +} + +bool ActiveRegionGenerator::using_assembler() const noexcept +{ + return using_assembler_; +} + +} // namespace coretools +} // namespace octopus diff --git a/src/core/tools/vargen/active_region_generator.hpp b/src/core/tools/vargen/active_region_generator.hpp new file mode 100644 index 000000000..576e26ec6 --- /dev/null +++ b/src/core/tools/vargen/active_region_generator.hpp @@ -0,0 +1,89 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#ifndef active_region_generator_hpp +#define active_region_generator_hpp + +#include +#include +#include +#include + +#include + +#include "config/common.hpp" +#include "basics/genomic_region.hpp" +#include "basics/aligned_read.hpp" +#include "io/reference/reference_genome.hpp" +#include "utils/assembler_active_region_generator.hpp" + +namespace octopus { namespace coretools { + +class ActiveRegionGenerator +{ +public: + struct Options + { + bool assemble_all = false; + }; + + ActiveRegionGenerator() = delete; + + ActiveRegionGenerator(const ReferenceGenome& reference, Options options); + + ActiveRegionGenerator(const ActiveRegionGenerator&) = default; + ActiveRegionGenerator& operator=(const ActiveRegionGenerator&) = default; + ActiveRegionGenerator(ActiveRegionGenerator&&) = default; + ActiveRegionGenerator& operator=(ActiveRegionGenerator&&) = default; + + ~ActiveRegionGenerator() = default; + + void add_generator(const std::string& name); + + void add_read(const SampleName& sample, const AlignedRead& read); + template + void add_reads(const SampleName& sample, ForwardIterator first, ForwardIterator last); + + std::vector generate(const GenomicRegion& region, const std::string& generator) const; + + void clear() noexcept; + +private: + struct RepeatRegions + { + GenomicRegion request_region; + std::vector minisatellites, compound_microsatellites; + std::vector assembler_microsatellites; + }; + struct AssemblerActiveRegions + { + GenomicRegion request_region; + std::vector active_regions; + }; + + std::reference_wrapper reference_; + Options options_; + std::string assembler_name_, cigar_scanner_name_; + bool using_assembler_; + + boost::optional assembler_active_region_generator_; + std::size_t max_read_length_; + mutable boost::optional repeats_; + mutable boost::optional assembler_active_regions_; + + bool is_cigar_scanner(const std::string& generator) const noexcept; + bool is_assembler(const std::string& generator) const noexcept; + bool using_assembler() 
const noexcept; +}; + +template +void ActiveRegionGenerator::add_reads(const SampleName& sample, ForwardIterator first, ForwardIterator last) +{ + if (assembler_active_region_generator_) assembler_active_region_generator_->add(sample, first, last); + std::for_each(first, last, [this] (const auto& read) { max_read_length_ = std::max(max_read_length_, sequence_size(read)); }); +} + +} // namespace coretools +} // namespace octopus + +#endif diff --git a/src/core/tools/vargen/cigar_scanner.cpp b/src/core/tools/vargen/cigar_scanner.cpp index 1cf0b74ad..a30451ec6 100644 --- a/src/core/tools/vargen/cigar_scanner.cpp +++ b/src/core/tools/vargen/cigar_scanner.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "cigar_scanner.hpp" @@ -17,6 +17,7 @@ #include "concepts/mappable_range.hpp" #include "utils/mappable_algorithms.hpp" #include "utils/append.hpp" +#include "utils/sequence_utils.hpp" #include "logging/logging.hpp" #include "utils/maths.hpp" @@ -82,25 +83,19 @@ void CigarScanner::add_read(const SampleName& sample, const AlignedRead& read, { using std::cbegin; using std::next; using std::move; using Flag = CigarOperation::Flag; - const auto& read_contig = contig_name(read); const auto& read_sequence = read.sequence(); - auto sequence_iter = cbegin(read_sequence); - auto base_quality_iter = cbegin(read.base_qualities()); auto ref_index = mapped_begin(read); std::size_t read_index {0}; GenomicRegion region; double misalignment_penalty {0}; + buffer_.clear(); for (const auto& cigar_operation : read.cigar()) { const auto op_size = cigar_operation.size(); switch (cigar_operation.flag()) { case Flag::alignmentMatch: misalignment_penalty += add_snvs_in_match_range(GenomicRegion {read_contig, ref_index, ref_index + op_size}, - next(sequence_iter, read_index), - next(sequence_iter, read_index + op_size), - sample, - 
next(base_quality_iter, read_index), - read.direction()); + read, read_index, sample); read_index += op_size; ref_index += op_size; break; @@ -114,9 +109,7 @@ void CigarScanner::add_read(const SampleName& sample, const AlignedRead& read, add_candidate(region, reference_.get().fetch_sequence(region), copy(read_sequence, read_index, op_size), - sample, - next(base_quality_iter, read_index), - read.direction()); + read, read_index, sample); read_index += op_size; ref_index += op_size; misalignment_penalty += op_size * options_.misalignment_parameters.snv_penalty; @@ -127,9 +120,7 @@ void CigarScanner::add_read(const SampleName& sample, const AlignedRead& read, add_candidate(GenomicRegion {read_contig, ref_index, ref_index}, "", copy(read_sequence, read_index, op_size), - sample, - next(base_quality_iter, read_index), - read.direction()); + read, read_index, sample); read_index += op_size; misalignment_penalty += options_.misalignment_parameters.indel_penalty; break; @@ -140,9 +131,7 @@ void CigarScanner::add_read(const SampleName& sample, const AlignedRead& read, add_candidate(move(region), reference_.get().fetch_sequence(region), "", - sample, - next(base_quality_iter, read_index), - read.direction()); + read, read_index, sample); ref_index += op_size; misalignment_penalty += options_.misalignment_parameters.indel_penalty; break; @@ -187,13 +176,13 @@ void CigarScanner::add_read(const SampleName& sample, const AlignedRead& read, } } -void CigarScanner::do_add_reads(const SampleName& sample, VectorIterator first, VectorIterator last) +void CigarScanner::do_add_reads(const SampleName& sample, ReadVectorIterator first, ReadVectorIterator last) { auto& sample_coverage_tracker = sample_read_coverage_tracker_[sample]; std::for_each(first, last, [&] (const AlignedRead& read) { add_read(sample, read, sample_coverage_tracker); }); } -void CigarScanner::do_add_reads(const SampleName& sample, FlatSetIterator first, FlatSetIterator last) +void CigarScanner::do_add_reads(const 
SampleName& sample, ReadFlatSetIterator first, ReadFlatSetIterator last) { auto& sample_coverage_tracker = sample_read_coverage_tracker_[sample]; std::for_each(first, last, [&] (const AlignedRead& read) { add_read(sample, read, sample_coverage_tracker); }); @@ -204,12 +193,12 @@ unsigned get_min_depth(const Variant& v, const CoverageTracker& t if (is_insertion(v)) { const auto& region = mapped_region(v); if (region.begin() > 0) { - return tracker.min_coverage(expand(region, 1, 1)); + return tracker.min(expand(region, 1, 1)); } else { - return tracker.min_coverage(expand_rhs(region, 1)); + return tracker.min(expand_rhs(region, 1)); } } else { - return tracker.min_coverage(mapped_region(v)); + return tracker.min(mapped_region(v)); } } @@ -254,69 +243,14 @@ void choose_push_back(Variant candidate, std::vector& final_candidates, } } -std::vector CigarScanner::do_generate_variants(const GenomicRegion& region) +std::vector CigarScanner::do_generate(const RegionSet& regions) const { - using std::begin; using std::end; using std::cbegin; using std::cend; using std::next; - - std::sort(begin(candidates_), end(candidates_)); - auto viable_candidates = overlap_range(candidates_, region, max_seen_candidate_size_); + std::sort(std::begin(candidates_), std::end(candidates_)); + std::sort(std::begin(likely_misaligned_candidates_), std::end(likely_misaligned_candidates_)); std::vector result {}; - if (empty(viable_candidates)) return result; - result.reserve(size(viable_candidates, BidirectionallySortedTag {})); // maximum possible - const auto repeat_regions = get_repeat_regions(region); - auto repeat_buckets = init_variant_buckets(repeat_regions); - const auto last_viable_candidate_itr = cend(viable_candidates); - - while (!viable_candidates.empty()) { - const Candidate& candidate {viable_candidates.front()}; - const auto next_candidate_itr = std::find_if_not(next(cbegin(viable_candidates)), last_viable_candidate_itr, - [this, &candidate] (const Candidate& c) { - return 
options_.match(c.variant, candidate.variant); - }); - const auto num_matches = std::distance(cbegin(viable_candidates), next_candidate_itr); - const auto observation = make_observation(cbegin(viable_candidates), next_candidate_itr); - if (options_.include(observation)) { - if (num_matches > 1) { - auto unique_itr = cbegin(viable_candidates); - while (unique_itr != next_candidate_itr) { - choose_push_back(unique_itr->variant, result, repeat_buckets); - unique_itr = std::find_if_not(next(unique_itr), next_candidate_itr, - [unique_itr] (const Candidate& c) { - return c.variant == unique_itr->variant; - }); - } - } else { - choose_push_back(candidate.variant, result, repeat_buckets); - } - } - viable_candidates.advance_begin(num_matches); - } - const auto novel_unique_misaligned_variants = get_novel_likely_misaligned_candidates(result); - if (debug_log_ && !novel_unique_misaligned_variants.empty()) { - stream(*debug_log_) << "DynamicCigarScanner: ignoring " - << count_overlapped(novel_unique_misaligned_variants, region) - << " unique candidates in " << region; + for (const auto& region : regions) { + generate(region, result); } - for (const auto& candidate : novel_unique_misaligned_variants) { - auto bucket = find_contained(repeat_buckets, candidate); - if (bucket) bucket->variants.push_back(candidate); - } - for (auto& bucket : repeat_buckets) { - std::sort(begin(bucket.variants), end(bucket.variants)); - } - std::vector good_repeat_region_variants {}; - for (auto& bucket : repeat_buckets) { - if (options_.include_repeat_region(bucket.region, bucket.variants)) { - utils::append(std::move(bucket.variants), good_repeat_region_variants); - } else { - if (debug_log_) { - stream(*debug_log_) << "DynamicCigarScanner: ignoring " << bucket.variants.size() - << " candidates in repetitive region " << bucket.region; - } - } - } - auto itr = utils::append(std::move(good_repeat_region_variants), result); - std::inplace_merge(begin(result), itr, end(result)); return result; } @@ 
-340,50 +274,72 @@ std::string CigarScanner::name() const // private methods -double CigarScanner::add_snvs_in_match_range(const GenomicRegion& region, - const SequenceIterator first_base, const SequenceIterator last_base, - const SampleName& origin, - AlignedRead::BaseQualityVector::const_iterator first_base_quality, - AlignedRead::Direction support_direction) +double CigarScanner::add_snvs_in_match_range(const GenomicRegion& region, const AlignedRead& read, + std::size_t read_index, const SampleName& origin) { - using boost::make_zip_iterator; using std::for_each; using std::cbegin; using std::cend; - using Tuple = boost::tuple; const NucleotideSequence ref_segment {reference_.get().fetch_sequence(region)}; - const auto& contig = region.contig_name(); - auto ref_index = mapped_begin(region); double misalignment_penalty {0}; - for_each(make_zip_iterator(boost::make_tuple(cbegin(ref_segment), first_base)), - make_zip_iterator(boost::make_tuple(cend(ref_segment), last_base)), - [this, &contig, &ref_index, &origin, &first_base_quality, - &misalignment_penalty, support_direction] (const Tuple& t) { - const char ref_base {t.get<0>()}, read_base {t.get<1>()}; - if (ref_base != read_base && ref_base != 'N' && read_base != 'N') { - add_candidate(GenomicRegion {contig, ref_index, ref_index + 1}, - ref_base, read_base, origin, first_base_quality, support_direction); - if (*first_base_quality >= options_.misalignment_parameters.snv_threshold) { - misalignment_penalty += options_.misalignment_parameters.snv_penalty; - } - } - ++ref_index; - ++first_base_quality; - }); + for (std::size_t ref_index {0}; ref_index < ref_segment.size(); ++ref_index, ++read_index) { + const char ref_base {ref_segment[ref_index]}, read_base {read.sequence()[read_index]}; + if (ref_base != read_base && ref_base != 'N' && read_base != 'N') { + const auto begin_pos = region.begin() + static_cast(ref_index); + add_candidate(GenomicRegion {region.contig_name(), begin_pos, begin_pos + 1}, + ref_base, 
read_base, read, read_index, origin); + if (read.base_qualities()[read_index] >= options_.misalignment_parameters.snv_threshold) { + misalignment_penalty += options_.misalignment_parameters.snv_penalty; + } + } + } return misalignment_penalty; } -unsigned CigarScanner::sum_base_qualities(const Candidate& candidate) const noexcept +void CigarScanner::generate(const GenomicRegion& region, std::vector& result) const { - return std::accumulate(candidate.first_base_quality_iter, - std::next(candidate.first_base_quality_iter, alt_sequence_size(candidate.variant)), - 0u); + using std::begin; using std::end; using std::cbegin; using std::cend; using std::next; + assert(std::is_sorted(std::cbegin(candidates_), std::cend(candidates_))); + auto viable_candidates = overlap_range(candidates_, region, max_seen_candidate_size_); + if (empty(viable_candidates)) return; + result.reserve(result.size() + size(viable_candidates, BidirectionallySortedTag {})); // maximum possible + const auto last_viable_candidate_itr = cend(viable_candidates); + while (!viable_candidates.empty()) { + const Candidate& candidate {viable_candidates.front()}; + const auto next_candidate_itr = std::find_if_not(next(cbegin(viable_candidates)), last_viable_candidate_itr, + [this, &candidate] (const Candidate& c) { + return options_.match(c.variant, candidate.variant); + }); + const auto num_matches = std::distance(cbegin(viable_candidates), next_candidate_itr); + const auto observation = make_observation(cbegin(viable_candidates), next_candidate_itr); + if (options_.include(observation)) { + if (num_matches > 1) { + auto unique_itr = cbegin(viable_candidates); + while (unique_itr != next_candidate_itr) { + result.push_back(unique_itr->variant); + unique_itr = std::find_if_not(next(unique_itr), next_candidate_itr, + [unique_itr] (const Candidate& c) { + return c.variant == unique_itr->variant; + }); + } + } else { + result.push_back(candidate.variant); + } + } + viable_candidates.advance_begin(num_matches); + 
} + if (debug_log_ && !likely_misaligned_candidates_.empty()) { + const auto novel_unique_misaligned_variants = get_novel_likely_misaligned_candidates(result); + if (!novel_unique_misaligned_variants.empty()) { + stream(*debug_log_) << "DynamicCigarScanner: ignoring " + << count_overlapped(novel_unique_misaligned_variants, region) + << " unique candidates in " << region; + } + } } -std::vector CigarScanner::get_repeat_regions(const GenomicRegion& region) const +unsigned CigarScanner::sum_base_qualities(const Candidate& candidate) const noexcept { - if (options_.repeat_region_generator) { - return (*options_.repeat_region_generator)(reference_, region); - } else { - return {}; - } + const auto first_base_qual_itr = std::next(std::cbegin(candidate.source.get().base_qualities()), candidate.offset); + const auto last_base_qual_itr = std::next(first_base_qual_itr, alt_sequence_size(candidate.variant)); + return std::accumulate(first_base_qual_itr, last_base_qual_itr, 0u); } bool CigarScanner::is_likely_misaligned(const AlignedRead& read, const double penalty) const @@ -393,6 +349,7 @@ bool CigarScanner::is_likely_misaligned(const AlignedRead& read, const double pe auto min_ln_prob_misaligned = options_.misalignment_parameters.min_ln_prob_correctly_aligned; return ln_prob_misaligned < min_ln_prob_misaligned; } + CigarScanner::ObservedVariant CigarScanner::make_observation(const CandidateIterator first_match, const CandidateIterator last_match) const { @@ -401,34 +358,39 @@ CigarScanner::make_observation(const CandidateIterator first_match, const Candid ObservedVariant result {}; result.variant = candidate.variant; result.total_depth = get_min_depth(candidate.variant, read_coverage_tracker_); - result.num_samples = sample_read_coverage_tracker_.size(); std::vector observations {first_match, last_match}; - std::sort(begin(observations), end(observations), [] (const Candidate& lhs, const Candidate& rhs) { return lhs.origin < rhs.origin; }); + std::sort(begin(observations), 
end(observations), + [] (const Candidate& lhs, const Candidate& rhs) { return lhs.origin.get() < rhs.origin.get(); }); for (auto observation_itr = begin(observations); observation_itr != end(observations);) { const auto& origin = observation_itr->origin; auto next_itr = std::find_if_not(next(observation_itr), end(observations), - [&] (const Candidate& c) { return c.origin == origin; }); - std::vector observed_qualities(std::distance(observation_itr, next_itr)); - std::transform(observation_itr, next_itr, begin(observed_qualities), + [&] (const Candidate& c) { return c.origin.get() == origin.get(); }); + const auto num_observations = static_cast(std::distance(observation_itr, next_itr)); + std::vector observed_base_qualities(num_observations); + std::transform(observation_itr, next_itr, begin(observed_base_qualities), [this] (const Candidate& c) noexcept { return sum_base_qualities(c); }); + std::vector observed_mapping_qualities(num_observations); + std::transform(observation_itr, next_itr, begin(observed_mapping_qualities), + [] (const Candidate& c) noexcept { return c.source.get().mapping_quality(); }); const auto num_fwd_support = std::accumulate(observation_itr, next_itr, 0u, [] (unsigned curr, const Candidate& c) noexcept { - if (c.support_direction == AlignedRead::Direction::forward) { + if (c.source.get().direction() == AlignedRead::Direction::forward) { ++curr; } return curr; }); const auto depth = get_min_depth(candidate.variant, sample_read_coverage_tracker_.at(origin)); - result.sample_observations.push_back({depth, std::move(observed_qualities), num_fwd_support}); + result.sample_observations.push_back({origin, depth, std::move(observed_base_qualities), + std::move(observed_mapping_qualities), num_fwd_support}); observation_itr = next_itr; } return result; } std::vector -CigarScanner::get_novel_likely_misaligned_candidates(const std::vector& current_candidates) +CigarScanner::get_novel_likely_misaligned_candidates(const std::vector& current_candidates) 
const { - std::sort(std::begin(likely_misaligned_candidates_), std::end(likely_misaligned_candidates_)); + std::is_sorted(std::cbegin(likely_misaligned_candidates_), std::cend(likely_misaligned_candidates_)); std::vector unique_misaligned_candidates {}; unique_misaligned_candidates.reserve(likely_misaligned_candidates_.size()); std::unique_copy(std::cbegin(likely_misaligned_candidates_), std::cend(likely_misaligned_candidates_), @@ -469,55 +431,108 @@ void partial_sort(std::vector& observed_qualities, const unsigned n) std::end(observed_qualities), std::greater<> {}); } -bool is_strongly_strand_biased(const unsigned num_observations, const unsigned num_fwd_observations) noexcept +bool is_strongly_strand_biased(const unsigned num_fwd_observations, const unsigned num_rev_observations, + const unsigned min_observations = 20) noexcept +{ + const auto num_observations = num_fwd_observations + num_rev_observations; + return num_observations > min_observations && (num_observations == num_fwd_observations || num_fwd_observations == 0); +} + +bool is_likely_runthrough_artifact(const unsigned num_fwd_observations, const unsigned num_rev_observations, + std::vector& observed_qualities) { - return num_observations > 20 && (num_observations == num_fwd_observations || num_fwd_observations == 0); + if (!is_strongly_strand_biased(num_fwd_observations, num_rev_observations, 10)) return false; + assert(!observed_qualities.empty()); + const auto median_bq = maths::median(observed_qualities); + return median_bq < 15; } -bool is_good(const Variant& variant, const unsigned depth, const unsigned num_fwd_observations, - std::vector observed_qualities) +bool is_tandem_repeat(const Allele& allele, const unsigned max_period = 4) +{ + for (unsigned period {0}; period <= max_period; ++period) { + if (utils::is_tandem_repeat(allele.sequence(), period)) return true; + } + return false; +} + +bool is_good_germline(const Variant& variant, const unsigned depth, const unsigned num_fwd_observations, 
+ std::vector observed_qualities) { const auto num_observations = observed_qualities.size(); if (depth < 4) { - return num_observations > 1 || sum(observed_qualities) >= 20 || is_deletion(variant); + return num_observations > 1 || sum(observed_qualities) >= 30 || is_deletion(variant); } - if (is_strongly_strand_biased(num_observations, num_fwd_observations)) { + const auto num_rev_observations = num_observations - num_fwd_observations; + if (is_strongly_strand_biased(num_fwd_observations, num_rev_observations)) { return false; } if (is_snv(variant)) { - const auto base_quality_sum = sum(observed_qualities); - if (depth <= 60) { - if (num_observations < 2) return false; - if (base_quality_sum > 100) return true; - erase_below(observed_qualities, 5); - if (observed_qualities.size() < 2) return false; - if (static_cast(observed_qualities.size()) / depth > 0.2) return true; + if (is_likely_runthrough_artifact(num_fwd_observations, num_rev_observations, observed_qualities)) return false; + erase_below(observed_qualities, 20); + if (depth <= 10) return observed_qualities.size() > 1; + return observed_qualities.size() > 2 && static_cast(observed_qualities.size()) / depth > 0.1; + } else if (is_insertion(variant)) { + if (num_observations == 1 && alt_sequence_size(variant) > 10) return false; + if (depth < 10) { + return num_observations > 1 || (alt_sequence_size(variant) > 3 && is_tandem_repeat(variant.alt_allele())); + } else if (depth <= 30) { + return num_observations > 1; + } else if (depth <= 60) { + if (num_observations == 1) return false; + if (static_cast(num_observations) / depth > 0.3) return true; + erase_below(observed_qualities, 25); + if (observed_qualities.size() <= 1) return false; + if (observed_qualities.size() > 2) return true; partial_sort(observed_qualities, 2); - return observed_qualities[0] >= 20 && observed_qualities[1] >= 20; - } else if (depth < 300) { - if (num_observations < 3) return false; - if (base_quality_sum > 150) return true; - 
erase_below(observed_qualities, 10); - if (observed_qualities.size() < 3) return false; - if (static_cast(observed_qualities.size()) / depth > 0.2) return true; - partial_sort(observed_qualities, 3); - return observed_qualities[0] >= 30 && observed_qualities[1] >= 25 && observed_qualities[2] >= 20; + return static_cast(observed_qualities[0]) / alt_sequence_size(variant) > 20; + } else { + if (num_observations == 1) return false; + if (static_cast(num_observations) / depth > 0.35) return true; + erase_below(observed_qualities, 20); + if (observed_qualities.size() <= 1) return false; + if (observed_qualities.size() > 3) return true; + return static_cast(observed_qualities[0]) / alt_sequence_size(variant) > 20; + } + } else { + // deletion or mnv + if (region_size(variant) < 10) { + return num_observations > 1 && static_cast(num_observations) / depth > 0.05; + } else { + return static_cast(num_observations) / (depth - std::sqrt(depth)) > 0.1; + } + } +} + +bool is_good_somatic(const Variant& variant, const unsigned depth, const unsigned num_fwd_observations, + std::vector observed_qualities) +{ + const auto num_observations = observed_qualities.size(); + if (depth < 4) { + return num_observations > 1 || sum(observed_qualities) >= 20 || is_deletion(variant); + } + const auto num_rev_observations = num_observations - num_fwd_observations; + if (is_strongly_strand_biased(num_fwd_observations, num_rev_observations)) { + return false; + } + if (is_snv(variant)) { + if (is_likely_runthrough_artifact(num_fwd_observations, num_rev_observations, observed_qualities)) return false; + erase_below(observed_qualities, 10); + if (depth <= 30) { + return observed_qualities.size() >= 2; } else { - if (num_observations < 8) return false; - erase_below(observed_qualities, 10); - return observed_qualities.size() > 7; + return static_cast(observed_qualities.size()) / depth > 0.03; } } else if (is_insertion(variant)) { if (num_observations == 1 && alt_sequence_size(variant) > 8) return 
false; - if (depth <= 15) { - return num_observations > 1; + if (depth <= 10) { + return num_observations > 1 || (alt_sequence_size(variant) > 3 && is_tandem_repeat(variant.alt_allele())); } else if (depth <= 30) { - if (static_cast(num_observations) / depth > 0.45) return true; + if (static_cast(num_observations) / depth > 0.35) return true; erase_below(observed_qualities, 20); return num_observations > 1; } else if (depth <= 60) { if (num_observations == 1) return false; - if (static_cast(num_observations) / depth > 0.4) return true; + if (static_cast(num_observations) / depth > 0.3) return true; erase_below(observed_qualities, 25); if (observed_qualities.size() <= 1) return false; if (observed_qualities.size() > 2) return true; @@ -525,7 +540,7 @@ bool is_good(const Variant& variant, const unsigned depth, const unsigned num_fw return static_cast(observed_qualities[0]) / alt_sequence_size(variant) > 20; } else { if (num_observations == 1) return false; - if (static_cast(num_observations) / depth > 0.35) return true; + if (static_cast(num_observations) / depth > 0.25) return true; erase_below(observed_qualities, 20); if (observed_qualities.size() <= 1) return false; if (observed_qualities.size() > 3) return true; @@ -534,20 +549,71 @@ bool is_good(const Variant& variant, const unsigned depth, const unsigned num_fw } else { // deletion or mnv if (region_size(variant) < 10) { - return num_observations > 1 && static_cast(num_observations) / depth > 0.05; + return num_observations > 1 && static_cast(num_observations) / depth > 0.02; } else { - return num_observations > 1 && num_observations >= 5; + return static_cast(num_observations) / (depth - std::sqrt(depth)) > 0.04; } } } +bool is_good_germline(const Variant& v, const CigarScanner::ObservedVariant::SampleObservation& observation) +{ + return is_good_germline(v, observation.depth, observation.num_fwd_observations, observation.observed_base_qualities); +} + +bool any_good_germline_samples(const 
CigarScanner::ObservedVariant& candidate) +{ + return std::any_of(std::cbegin(candidate.sample_observations), std::cend(candidate.sample_observations), + [&] (const auto& observation) { return is_good_germline(candidate.variant, observation); }); +} + +auto count_forward_observations(const CigarScanner::ObservedVariant& candidate) +{ + return std::accumulate(std::cbegin(candidate.sample_observations), std::cend(candidate.sample_observations), 0u, + [&] (auto curr, const auto& observation) { return curr + observation.num_fwd_observations; }); +} + +auto concat_observed_base_qualities(const CigarScanner::ObservedVariant& candidate) +{ + std::size_t num_base_qualities {0}; + for (const auto& observation : candidate.sample_observations) { + num_base_qualities += observation.observed_base_qualities.size(); + } + std::vector result {}; + result.reserve(num_base_qualities); + for (const auto& observation : candidate.sample_observations) { + utils::append(observation.observed_base_qualities, result); + } + return result; +} + +bool is_good_germline_pooled(const CigarScanner::ObservedVariant& candidate) +{ + return is_good_germline(candidate.variant, candidate.total_depth, count_forward_observations(candidate), + concat_observed_base_qualities(candidate)); +} + +bool is_good_somatic(const Variant& v, const CigarScanner::ObservedVariant::SampleObservation& observation) +{ + return is_good_somatic(v, observation.depth, observation.num_fwd_observations, observation.observed_base_qualities); +} + } // namespace bool DefaultInclusionPredicate::operator()(const CigarScanner::ObservedVariant& candidate) +{ + return any_good_germline_samples(candidate) || (candidate.sample_observations.size() > 1 && is_good_germline_pooled(candidate)); +} + +bool DefaultSomaticInclusionPredicate::operator()(const CigarScanner::ObservedVariant& candidate) { return std::any_of(std::cbegin(candidate.sample_observations), std::cend(candidate.sample_observations), - [&] (const auto& o) { - return 
is_good(candidate.variant, o.depth, o.num_fwd_observations, o.observed_qualities); + [&] (const auto& observation) { + if (normal_ && observation.sample.get() == *normal_) { + return is_good_germline(candidate.variant, observation); + } else { + return is_good_somatic(candidate.variant, observation); + } }); } @@ -556,7 +622,7 @@ namespace { auto count_observations(const CigarScanner::ObservedVariant& candidate) { return std::accumulate(std::cbegin(candidate.sample_observations), std::cend(candidate.sample_observations), std::size_t {0}, - [] (auto curr, const auto& sample) { return curr + sample.observed_qualities.size(); }); + [] (auto curr, const auto& sample) { return curr + sample.observed_base_qualities.size(); }); } } // namespace diff --git a/src/core/tools/vargen/cigar_scanner.hpp b/src/core/tools/vargen/cigar_scanner.hpp index 84d559e97..d4b1e14f5 100644 --- a/src/core/tools/vargen/cigar_scanner.hpp +++ b/src/core/tools/vargen/cigar_scanner.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef cigar_scanner_hpp @@ -11,6 +11,8 @@ #include #include +#include + #include "concepts/mappable.hpp" #include "concepts/comparable.hpp" #include "basics/aligned_read.hpp" @@ -31,11 +33,13 @@ class CigarScanner : public VariantGenerator struct ObservedVariant { Variant variant; - unsigned num_samples, total_depth; + unsigned total_depth; struct SampleObservation { + std::reference_wrapper sample; unsigned depth; - std::vector observed_qualities; + std::vector observed_base_qualities; + std::vector observed_mapping_qualities; unsigned num_fwd_observations; }; std::vector sample_observations; @@ -54,16 +58,12 @@ class CigarScanner : public VariantGenerator using InclusionPredicate = std::function; using MatchPredicate = std::function; - using RepeatRegionGenerator = std::function(const ReferenceGenome&, GenomicRegion)>; - using RepeatRegionInclusionPredicate = std::function&)>; InclusionPredicate include; MatchPredicate match = std::equal_to<> {}; bool use_clipped_coverage_tracking = false; Variant::MappingDomain::Size max_variant_size = 2000; MisalignmentParameters misalignment_parameters = MisalignmentParameters {}; - boost::optional repeat_region_generator = boost::none; - RepeatRegionInclusionPredicate include_repeat_region = [] (const auto& region, const auto& variants) { return variants.size() < 100; }; }; CigarScanner() = delete; @@ -77,31 +77,31 @@ class CigarScanner : public VariantGenerator ~CigarScanner() override = default; private: - using VariantGenerator::VectorIterator; - using VariantGenerator::FlatSetIterator; + using VariantGenerator::ReadVectorIterator; + using VariantGenerator::ReadFlatSetIterator; std::unique_ptr do_clone() const override; bool do_requires_reads() const noexcept override; void do_add_read(const SampleName& sample, const AlignedRead& read) override; void add_read(const SampleName& sample, const AlignedRead& read, CoverageTracker& sample_coverage_tracker); - void do_add_reads(const SampleName& sample, VectorIterator first, 
VectorIterator last) override; - void do_add_reads(const SampleName& sample, FlatSetIterator first, FlatSetIterator last) override; - std::vector do_generate_variants(const GenomicRegion& region) override; + void do_add_reads(const SampleName& sample, ReadVectorIterator first, ReadVectorIterator last) override; + void do_add_reads(const SampleName& sample, ReadFlatSetIterator first, ReadFlatSetIterator last) override; + std::vector do_generate(const RegionSet& regions) const override; void do_clear() noexcept override; std::string name() const override; struct Candidate : public Comparable, public Mappable { Variant variant; - SampleName origin; - AlignedRead::BaseQualityVector::const_iterator first_base_quality_iter; - AlignedRead::Direction support_direction; + std::reference_wrapper source; + std::reference_wrapper origin; + std::size_t offset; - template - Candidate(T1&& region, T2&& sequence_removed, T3&& sequence_added, T4&& origin, - AlignedRead::BaseQualityVector::const_iterator first_base_quality, - AlignedRead::Direction support_direction); + template + Candidate(T1&& region, T2&& sequence_removed, T3&& sequence_added, + const AlignedRead& source, std::size_t offset, + const SampleName& origin); const GenomicRegion& mapped_region() const noexcept { return variant.mapped_region(); } @@ -115,52 +115,46 @@ class CigarScanner : public VariantGenerator std::reference_wrapper reference_; Options options_; std::vector buffer_; - std::deque candidates_, likely_misaligned_candidates_; + mutable std::deque candidates_, likely_misaligned_candidates_; Variant::MappingDomain::Size max_seen_candidate_size_; CoverageTracker read_coverage_tracker_, misaligned_tracker_; std::unordered_map> sample_read_coverage_tracker_; using CandidateIterator = OverlapIterator; - template - void add_candidate(T1&& region, T2&& sequence_removed, T3&& sequence_added, T4&& origin, - AlignedRead::BaseQualityVector::const_iterator first_base_quality, - AlignedRead::Direction 
support_direction); - double add_snvs_in_match_range(const GenomicRegion& region, SequenceIterator first_base, SequenceIterator last_base, - const SampleName& origin, - AlignedRead::BaseQualityVector::const_iterator first_quality, - AlignedRead::Direction support_direction); + template + void add_candidate(T1&& region, T2&& sequence_removed, T3&& sequence_added, + const AlignedRead& read, std::size_t offset, const SampleName& sample); + double add_snvs_in_match_range(const GenomicRegion& region, const AlignedRead& read, + std::size_t read_index, const SampleName& origin); + void generate(const GenomicRegion& region, std::vector& result) const; unsigned sum_base_qualities(const Candidate& candidate) const noexcept; - std::vector get_repeat_regions(const GenomicRegion& region) const; bool is_likely_misaligned(const AlignedRead& read, double penalty) const; ObservedVariant make_observation(CandidateIterator first_match, CandidateIterator last_match) const; - std::vector get_novel_likely_misaligned_candidates(const std::vector& current_candidates); + std::vector get_novel_likely_misaligned_candidates(const std::vector& current_candidates) const; }; -template +template CigarScanner::Candidate::Candidate(T1&& region, T2&& sequence_removed, T3&& sequence_added, - T4&& origin, - AlignedRead::BaseQualityVector::const_iterator first_base_quality, - AlignedRead::Direction support_direction) + const AlignedRead& source, std::size_t offset, + const SampleName& origin) : variant {std::forward(region), std::forward(sequence_removed), std::forward(sequence_added)} -, origin {std::forward(origin)} -, first_base_quality_iter {first_base_quality} -, support_direction {support_direction} +, source {source} +, origin {origin} +, offset {offset} {} -template -void CigarScanner::add_candidate(T1&& region, T2&& sequence_removed, T3&& sequence_added, T4&& origin, - AlignedRead::BaseQualityVector::const_iterator first_base_quality, - AlignedRead::Direction support_direction) +template +void 
CigarScanner::add_candidate(T1&& region, T2&& sequence_removed, T3&& sequence_added, + const AlignedRead& read, const std::size_t offset, + const SampleName& sample) { const auto candidate_size = size(region); if (candidate_size <= options_.max_variant_size) { buffer_.emplace_back(std::forward(region), std::forward(sequence_removed), std::forward(sequence_added), - std::forward(origin), - first_base_quality, - support_direction); + read, offset, sample); max_seen_candidate_size_ = std::max(max_seen_candidate_size_, candidate_size); } } @@ -170,6 +164,15 @@ struct DefaultInclusionPredicate bool operator()(const CigarScanner::ObservedVariant& candidate); }; +struct DefaultSomaticInclusionPredicate +{ + DefaultSomaticInclusionPredicate() = default; + DefaultSomaticInclusionPredicate(SampleName normal) : normal_ {std::move(normal)} {} + bool operator()(const CigarScanner::ObservedVariant& candidate); +private: + boost::optional normal_; +}; + struct SimpleThresholdInclusionPredicate { SimpleThresholdInclusionPredicate(std::size_t min_observations) : min_observations_ {min_observations} {} diff --git a/src/core/tools/vargen/downloader.cpp b/src/core/tools/vargen/downloader.cpp index 37276ae7a..af2ffa471 100644 --- a/src/core/tools/vargen/downloader.cpp +++ b/src/core/tools/vargen/downloader.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "downloader.hpp" @@ -19,7 +19,7 @@ std::unique_ptr Downloader::do_clone() const return std::make_unique(*this); } -std::vector Downloader::do_generate_variants(const GenomicRegion& region) +std::vector Downloader::do_generate(const RegionSet& regions) const { //namespace http = boost::network::http; diff --git a/src/core/tools/vargen/downloader.hpp b/src/core/tools/vargen/downloader.hpp index 1e99071c4..82d4a43e4 100644 --- a/src/core/tools/vargen/downloader.hpp +++ b/src/core/tools/vargen/downloader.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef downloader_hpp @@ -39,14 +39,12 @@ class Downloader : public VariantGenerator ~Downloader() override = default; private: - std::unique_ptr do_clone() const override; - - std::vector do_generate_variants(const GenomicRegion& region) override; - - std::string name() const override; - std::reference_wrapper reference_; Options options_; + + std::unique_ptr do_clone() const override; + std::vector do_generate(const RegionSet& regions) const override; + std::string name() const override; }; } // namespace coretools diff --git a/src/core/tools/vargen/local_reassembler.cpp b/src/core/tools/vargen/local_reassembler.cpp index aa218020e..64cd68e8b 100644 --- a/src/core/tools/vargen/local_reassembler.cpp +++ b/src/core/tools/vargen/local_reassembler.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "local_reassembler.hpp" @@ -57,14 +57,11 @@ LocalReassembler::LocalReassembler(const ReferenceGenome& reference, Options opt , read_buffer_ {} , max_bin_size_ {options.bin_size} , max_bin_overlap_ {options.bin_overlap} -, bins_ {} -, masked_sequence_buffer_ {} , mask_threshold_ {options.mask_threshold} , min_kmer_observations_ {options.min_kmer_observations} , max_bubbles_ {options.max_bubbles} , min_bubble_score_ {options.min_bubble_score} , max_variant_size_ {options.max_variant_size} -, active_region_generator_ {reference} { if (max_bin_size_ == 0) { throw std::runtime_error {"bin size must be greater than zero"}; @@ -157,7 +154,7 @@ bool has_low_quality_match(const AlignedRead& read, const AlignedRead::BaseQuali [=] (auto q) { return q < good_quality; }); std::advance(quality_itr, op.size()); return result; - } else if (op.advances_sequence()) { + } else if (advances_sequence(op)) { std::advance(quality_itr, op.size()); } return false; @@ -175,16 +172,6 @@ auto find_first_sequence_op(const std::vector& cigar) noex [] (auto op) { return op == CigarOperation::Flag::hardClipped; }); } -bool is_match(const CigarOperation::Flag op) noexcept -{ - switch (op) { - case CigarOperation::Flag::alignmentMatch: - case CigarOperation::Flag::sequenceMatch: - case CigarOperation::Flag::substitution: return true; - default: return false; - } -} - template auto make_optional(bool b, T&& value) { @@ -282,89 +269,57 @@ auto overlapped_bins(Container& bins, const M& mappable) void LocalReassembler::do_add_read(const SampleName& sample, const AlignedRead& read) { - active_region_generator_.add(sample, read); read_buffer_[sample].insert(read); } -void LocalReassembler::do_add_reads(const SampleName& sample, VectorIterator first, VectorIterator last) +void LocalReassembler::do_add_reads(const SampleName& sample, ReadVectorIterator first, ReadVectorIterator last) { - active_region_generator_.add(sample, first, last); read_buffer_[sample].insert(first, last); } -void 
LocalReassembler::do_add_reads(const SampleName& sample, FlatSetIterator first, FlatSetIterator last) +void LocalReassembler::do_add_reads(const SampleName& sample, ReadFlatSetIterator first, ReadFlatSetIterator last) { - active_region_generator_.add(sample, first, last); read_buffer_[sample].insert(first, last); } -template -void remove_nonoverlapping(Container& candidates, const GenomicRegion& region) +void remove_nonoverlapping(std::vector& candidates, std::vector& active_regions) { const auto it = std::remove_if(std::begin(candidates), std::end(candidates), - [®ion] (const Variant& candidate) { - return !overlaps(candidate, region); + [&] (const Variant& candidate) { + return !has_overlapped(active_regions, candidate); }); candidates.erase(it, std::end(candidates)); } -auto extract_unique(std::deque&& variants) +void remove_duplicates(std::deque& variants) { - using std::make_move_iterator; - std::vector result {make_move_iterator(std::begin(variants)), make_move_iterator(std::end(variants))}; - std::sort(std::begin(result), std::end(result)); - result.erase(std::unique(std::begin(result), std::end(result)), std::end(result)); - return result; + std::sort(std::begin(variants), std::end(variants)); + variants.erase(std::unique(std::begin(variants), std::end(variants)), std::end(variants)); } -void remove_oversized(std::vector& variants, const Variant::MappingDomain::Size max_size) +void remove_larger_than(std::deque& variants, const Variant::MappingDomain::Size max_size) { variants.erase(std::remove_if(std::begin(variants), std::end(variants), - [max_size] (const auto& variant) { - return region_size(variant) > max_size; - }), + [max_size] (const auto& variant) { return region_size(variant) > max_size; }), std::end(variants)); } -auto extract_final(std::deque&& variants, const GenomicRegion& extract_region, - const Variant::MappingDomain::Size max_size) +std::vector LocalReassembler::do_generate(const RegionSet& regions) const { - auto result = 
extract_unique(std::move(variants)); - remove_oversized(result, max_size); - remove_nonoverlapping(result, extract_region); // as we expanded original region - return result; -} - -namespace debug { - -template -void log_active_regions(const Range& regions, boost::optional& log) -{ - if (log) { - auto log_stream = stream(*log); - log_stream << "Assembler active regions are: "; - for (const auto& region : regions) log_stream << region << ' '; - } -} - -} // namespace debug - -std::vector LocalReassembler::do_generate_variants(const GenomicRegion& region) -{ - const auto active_regions = active_region_generator_.generate(region); - debug::log_active_regions(active_regions, debug_log_); - for (const auto& active_region : active_regions) { - prepare_bins(active_region); + BinList bins {}; + SequenceBuffer masked_sequence_buffer {}; + for (const auto& region : regions) { + prepare_bins(region, bins); for (const auto& p : read_buffer_) { - for (const auto& read : overlap_range(p.second, active_region)) { - auto active_bins = overlapped_bins(bins_, read); + for (const auto& read : overlap_range(p.second, region)) { + auto active_bins = overlapped_bins(bins, read); assert(!active_bins.empty()); if (requires_masking(read, mask_threshold_)) { auto masked_sequence = mask(read, mask_threshold_, reference_); if (masked_sequence) { - masked_sequence_buffer_.emplace_back(std::move(*masked_sequence)); + masked_sequence_buffer.emplace_back(std::move(*masked_sequence)); for (auto& bin : active_bins) { - bin.add(mapped_region(read), std::cref(masked_sequence_buffer_.back())); + bin.add(mapped_region(read), std::cref(masked_sequence_buffer.back())); } } } else { @@ -373,14 +328,11 @@ std::vector LocalReassembler::do_generate_variants(const GenomicRegion& } } } - read_buffer_.clear(); - finalise_bins(); - if (bins_.empty()) return {}; - const auto active_bins = overlapped_bins(bins_, region); - const auto num_bins = size(active_bins); + finalise_bins(bins, regions); + if (bins.empty()) 
return {}; std::deque candidates {}; - if (execution_policy_ == ExecutionPolicy::seq || num_bins < 2) { - for (auto& bin : active_bins) { + if (execution_policy_ == ExecutionPolicy::seq || bins.size() < 2) { + for (auto& bin : bins) { if (debug_log_) { stream(*debug_log_) << "Assembling " << bin.read_sequences.size() << " reads in bin " << mapped_region(bin); @@ -393,8 +345,8 @@ std::vector LocalReassembler::do_generate_variants(const GenomicRegion& } } else { const std::size_t num_threads {4}; - std::vector>> bin_futures(std::min(num_bins, num_threads)); - for (auto first_bin = std::begin(active_bins), last_bin = std::end(active_bins); first_bin != last_bin; ) { + std::vector>> bin_futures(std::min(bins.size(), num_threads)); + for (auto first_bin = std::begin(bins), last_bin = std::end(bins); first_bin != last_bin; ) { const auto batch_size = std::min(num_threads, static_cast(std::distance(first_bin, last_bin))); const auto next_bin = std::next(first_bin, batch_size); auto last_future = std::transform(first_bin, next_bin, std::begin(bin_futures), [&] (Bin& bin) { @@ -416,19 +368,14 @@ std::vector LocalReassembler::do_generate_variants(const GenomicRegion& first_bin = next_bin; } } - bins_.clear(); - bins_.shrink_to_fit(); - return extract_final(std::move(candidates), region, max_variant_size_); + remove_duplicates(candidates); + remove_larger_than(candidates, max_variant_size_); + return {std::make_move_iterator(std::begin(candidates)), std::make_move_iterator(std::end(candidates))}; } void LocalReassembler::do_clear() noexcept { read_buffer_.clear(); - masked_sequence_buffer_.clear(); - masked_sequence_buffer_.shrink_to_fit(); - bins_.clear(); - bins_.shrink_to_fit(); - active_region_generator_.clear(); } std::string LocalReassembler::name() const @@ -438,20 +385,42 @@ std::string LocalReassembler::name() const // private methods -void LocalReassembler::prepare_bins(const GenomicRegion& region) +template +auto decompose(const MappableTp& mappable, const 
GenomicRegion::Position n, + const GenomicRegion::Size overlap = 0) +{ + if (overlap >= n) { + throw std::runtime_error {"decompose: overlap must be less than n"}; + } + std::vector result {}; + if (n == 0) return result; + const auto num_elements = region_size(mappable) / (n - overlap); + if (num_elements == 0) return result; + result.reserve(num_elements); + const auto& contig = contig_name(mappable); + auto curr = mapped_begin(mappable); + std::generate_n(std::back_inserter(result), num_elements, [&contig, &curr, n, overlap] () { + auto tmp = curr; + curr += (n - overlap); + return GenomicRegion {contig, tmp, tmp + n}; + }); + return result; +} + +void LocalReassembler::prepare_bins(const GenomicRegion& region, BinList& bins) const { - assert(bins_.empty() || is_after(region, bins_.back())); + assert(bins.empty() || is_after(region, bins.back())); if (size(region) > max_bin_size_) { auto bin_region = expand_rhs(head_region(region), max_bin_size_); while (ends_before(bin_region, region)) { - bins_.push_back(bin_region); + bins.push_back(bin_region); bin_region = shift(bin_region, max_bin_overlap_); } if (overlap_size(region, bin_region) > 0) { - bins_.push_back(*overlapped_region(region, bin_region)); + bins.push_back(*overlapped_region(region, bin_region)); } } else { - bins_.push_back(region); + bins.push_back(region); } } @@ -460,23 +429,22 @@ bool LocalReassembler::should_assemble_bin(const Bin& bin) const return !bin.empty(); } -void LocalReassembler::finalise_bins() +void LocalReassembler::finalise_bins(BinList& bins, const RegionSet& active_regions) const { - auto itr = std::remove_if(std::begin(bins_), std::end(bins_), - [this] (const Bin& bin) { return !should_assemble_bin(bin); }); - bins_.erase(itr, std::end(bins_)); - for (auto& bin : bins_) { + bins.erase(std::remove_if(std::begin(bins), std::end(bins), + [this] (const Bin& bin) { return !should_assemble_bin(bin); }), + std::end(bins)); + for (auto& bin : bins) { if (bin.read_region) { bin.region = 
GenomicRegion {bin.region.contig_name(), *bin.read_region}; } } // unique in reverse order as we want to keep bigger bins, which // are sorted after smaller bins with the same starting point - itr = std::unique(std::rbegin(bins_), std::rend(bins_), - [] (const Bin& lhs, const Bin& rhs) noexcept { - return begins_equal(lhs, rhs); - }).base(); - bins_.erase(std::begin(bins_), itr); + bins.erase(std::begin(bins), std::unique(std::rbegin(bins), std::rend(bins), + [] (const Bin& lhs, const Bin& rhs) { + return begins_equal(lhs, rhs); + }).base()); } namespace { @@ -531,7 +499,9 @@ void LocalReassembler::try_assemble_with_fallbacks(const Bin& bin, std::deque 5) { - assemble_bin((prev_k + k) / 2, bin, result); + const auto gap = k - prev_k; + assemble_bin(k - gap / 2, bin, result); + assemble_bin(k + gap / 2, bin, result); } return; case AssemblerStatus::partial_success: @@ -662,11 +632,16 @@ struct Repeat : public Mappable { ContigRegion region; unsigned period; + Assembler::NucleotideSequence::const_iterator begin_itr, end_itr; const auto& mapped_region() const noexcept { return region; } + auto begin() const noexcept { return begin_itr; } + auto end() const noexcept { return end_itr; } Repeat() = default; - Repeat(const tandem::Repeat& repeat) noexcept + Repeat(const tandem::Repeat& repeat, const Assembler::NucleotideSequence& sequence) noexcept : region {repeat.pos, repeat.pos + repeat.length} , period {repeat.period} + , begin_itr {std::next(std::cbegin(sequence), repeat.pos)} + , end_itr {std::next(begin_itr, repeat.length)} {} }; @@ -677,7 +652,7 @@ auto find_repeats(Assembler::NucleotideSequence& sequence, const unsigned max_pe sequence.pop_back(); std::vector result(repeats.size()); std::transform(std::cbegin(repeats), std::cend(repeats), std::begin(result), - [] (auto repeat) { return Repeat {repeat}; }); + [&] (auto repeat) { return Repeat {repeat, sequence}; }); std::sort(std::begin(result), std::end(result)); return result; } @@ -700,8 +675,18 @@ struct 
VariantReference : public Mappable bool matches_rhs(const Repeat& repeat, const Assembler::NucleotideSequence& sequence) noexcept { - if (sequence.size() < 2 * repeat.period) return false; - return utils::is_tandem_repeat(sequence, repeat.period); + if (sequence.size() < repeat.period) return false; + if (sequence.size() == repeat.period) { + return std::equal(std::cbegin(sequence), std::cend(sequence), std::cbegin(repeat)); + } else if (utils::is_tandem_repeat(sequence, repeat.period)) { + assert(std::distance(std::cbegin(repeat), std::cend(repeat)) >= 2 * repeat.period); + const auto repeat_match_end_itr = std::next(std::cbegin(repeat), 2 * repeat.period); + auto match_itr = std::search(std::cbegin(repeat), repeat_match_end_itr, + std::cbegin(sequence), std::next(std::cbegin(sequence), repeat.period)); + return match_itr != repeat_match_end_itr; + } else { + return false; + } } template @@ -757,13 +742,13 @@ std::vector try_to_split_repeats(Assembler::Variant& v, cons complete_partial_ref_repeat(v, ref_repeat); } else { auto alt_repeat_ritr = std::make_reverse_iterator(ref_repeat_itr); - auto alt_repeat_match_ritr = alt_repeat_ritr; + auto alt_repeat_match_ritr = std::crend(ref_repeats); for (; alt_repeat_ritr != std::crend(ref_repeats); ++alt_repeat_ritr) { if (is_before(*alt_repeat_ritr, ref_repeat)) break; if (matches_rhs(*alt_repeat_ritr, v.alt)) alt_repeat_match_ritr = alt_repeat_ritr; } if (alt_repeat_match_ritr == std::crend(ref_repeats)) return {}; - if (v.alt.size() < 2 * alt_repeat_match_ritr->period) return {}; + if (v.alt.size() < alt_repeat_match_ritr->period) return {}; complete_partial_alt_repeat(v, *alt_repeat_match_ritr); } Assembler::Variant deletion {v.begin_pos, std::move(v.ref), ""}, insertion {v.begin_pos, "", std::move(v.alt)}; diff --git a/src/core/tools/vargen/local_reassembler.hpp b/src/core/tools/vargen/local_reassembler.hpp index d2d9bb603..3f1141978 100644 --- a/src/core/tools/vargen/local_reassembler.hpp +++ 
b/src/core/tools/vargen/local_reassembler.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef local_reassembler_hpp @@ -22,8 +22,6 @@ #include "variant_generator.hpp" #include "utils/assembler.hpp" -#include "utils/assembler_active_region_generator.hpp" - namespace octopus { class ReferenceGenome; @@ -60,27 +58,28 @@ class LocalReassembler : public VariantGenerator ~LocalReassembler() override = default; private: - using VariantGenerator::VectorIterator; - using VariantGenerator::FlatSetIterator; + using VariantGenerator::ReadVectorIterator; + using VariantGenerator::ReadFlatSetIterator; std::unique_ptr do_clone() const override; bool do_requires_reads() const noexcept override; void do_add_read(const SampleName& sample, const AlignedRead& read) override; - void do_add_reads(const SampleName& sample, VectorIterator first, VectorIterator last) override; - void do_add_reads(const SampleName& sample, FlatSetIterator first, FlatSetIterator last) override; + void do_add_reads(const SampleName& sample, ReadVectorIterator first, ReadVectorIterator last) override; + void do_add_reads(const SampleName& sample, ReadFlatSetIterator first, ReadFlatSetIterator last) override; - std::vector do_generate_variants(const GenomicRegion& region) override; + std::vector do_generate(const RegionSet& regions) const override; void do_clear() noexcept override; std::string name() const override; using NucleotideSequence = AlignedRead::NucleotideSequence; - - using ReadBuffer = MappableFlatMultiSet>; - using ReadBufferMap = std::map; + using SequenceBuffer = std::deque; + using ReadReference = MappableReferenceWrapper; + using ReadBuffer = MappableFlatMultiSet; + using ReadBufferMap = std::map; struct Bin : public Mappable { @@ -99,31 +98,24 @@ class LocalReassembler : public VariantGenerator std::deque> read_sequences; }; + using BinList = 
std::deque; + enum class AssemblerStatus { success, partial_success, failed }; ExecutionPolicy execution_policy_; - std::reference_wrapper reference_; - std::vector default_kmer_sizes_, fallback_kmer_sizes_; - ReadBufferMap read_buffer_; - GenomicRegion::Size max_bin_size_, max_bin_overlap_; - std::deque bins_; - std::deque masked_sequence_buffer_; - AlignedRead::BaseQuality mask_threshold_; unsigned min_kmer_observations_; unsigned max_bubbles_; double min_bubble_score_; Variant::MappingDomain::Size max_variant_size_; - AssemblerActiveRegionGenerator active_region_generator_; - - void prepare_bins(const GenomicRegion& active_region); + void prepare_bins(const GenomicRegion& active_region, BinList& bins) const; bool should_assemble_bin(const Bin& bin) const; - void finalise_bins(); + void finalise_bins(BinList& bins, const RegionSet& active_regions) const; unsigned try_assemble_with_defaults(const Bin& bin, std::deque& result) const; void try_assemble_with_fallbacks(const Bin& bin, std::deque& result) const; GenomicRegion propose_assembler_region(const GenomicRegion& input_region, unsigned kmer_size) const; diff --git a/src/core/tools/vargen/randomiser.cpp b/src/core/tools/vargen/randomiser.cpp index 37df43816..9367c8642 100644 --- a/src/core/tools/vargen/randomiser.cpp +++ b/src/core/tools/vargen/randomiser.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "randomiser.hpp" @@ -24,44 +24,35 @@ std::unique_ptr Randomiser::do_clone() const return std::make_unique(*this); } -void Randomiser::do_add_reads(const SampleName& sample, VectorIterator first, VectorIterator last) +void Randomiser::do_add_reads(const SampleName& sample, ReadVectorIterator first, ReadVectorIterator last) { max_read_size_ = region_size(*largest_mappable(first, last)); } -void Randomiser::do_add_reads(const SampleName& sample, FlatSetIterator first, FlatSetIterator last) +void Randomiser::do_add_reads(const SampleName& sample, ReadFlatSetIterator first, ReadFlatSetIterator last) { max_read_size_ = region_size(*largest_mappable(first, last)); } -std::vector Randomiser::do_generate_variants(const GenomicRegion& region) +std::vector Randomiser::do_generate(const RegionSet& regions) const { - auto num_positions = region_size(region); - std::vector result {}; - - if (num_positions == 0) return result; - - static const auto seed = std::chrono::system_clock::now().time_since_epoch().count(); - - static std::default_random_engine generator {static_cast(seed)}; - - using T = Variant::MappingDomain::Size; - - std::uniform_int_distribution uniform {0, std::min(num_positions, max_read_size_)}; - - auto positions = decompose(region); - - for (auto p = uniform(generator); p < num_positions; p += max_read_size_) { - auto position = positions[p]; - - auto reference_allele = make_reference_allele(position, reference_); - - Allele mutation {position, utils::reverse_complement_copy(reference_allele.sequence())}; - - result.emplace_back(reference_allele, mutation); + for (const auto& region : regions) { + auto num_positions = region_size(region); + std::vector result {}; + if (num_positions == 0) return result; + static const auto seed = std::chrono::system_clock::now().time_since_epoch().count(); + static std::default_random_engine generator {static_cast(seed)}; + using T = Variant::MappingDomain::Size; + std::uniform_int_distribution uniform {0, 
std::min(num_positions, max_read_size_)}; + auto positions = decompose(region); + for (auto p = uniform(generator); p < num_positions; p += max_read_size_) { + auto position = positions[p]; + auto reference_allele = make_reference_allele(position, reference_); + Allele mutation {position, utils::reverse_complement_copy(reference_allele.sequence())}; + result.emplace_back(reference_allele, mutation); + } } - return result; } diff --git a/src/core/tools/vargen/randomiser.hpp b/src/core/tools/vargen/randomiser.hpp index bb6b6cabf..b70a5ddff 100644 --- a/src/core/tools/vargen/randomiser.hpp +++ b/src/core/tools/vargen/randomiser.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef randomiser_hpp @@ -40,16 +40,13 @@ class Randomiser : public VariantGenerator ~Randomiser() override = default; private: - using VariantGenerator::VectorIterator; - using VariantGenerator::FlatSetIterator; + using VariantGenerator::ReadVectorIterator; + using VariantGenerator::ReadFlatSetIterator; std::unique_ptr do_clone() const override; - - void do_add_reads(const SampleName& sample, VectorIterator first, VectorIterator last) override; - void do_add_reads(const SampleName& sample, FlatSetIterator first, FlatSetIterator last) override; - - std::vector do_generate_variants(const GenomicRegion& region) override; - + void do_add_reads(const SampleName& sample, ReadVectorIterator first, ReadVectorIterator last) override; + void do_add_reads(const SampleName& sample, ReadFlatSetIterator first, ReadFlatSetIterator last) override; + std::vector do_generate(const RegionSet& regions) const override; std::string name() const override; std::reference_wrapper reference_; diff --git a/src/core/tools/vargen/utils/assembler.cpp b/src/core/tools/vargen/utils/assembler.cpp index 3445c8a9a..c4db89b4c 100644 --- a/src/core/tools/vargen/utils/assembler.cpp +++ 
b/src/core/tools/vargen/utils/assembler.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "assembler.hpp" @@ -960,10 +960,10 @@ void Assembler::remove_all_nonreference_cycles(const bool break_chains) bad_kmers.insert(cycle_origin); cycle_origin = *boost::inv_adjacent_vertices(cycle_origin, graph_).first; } - bool refererence_origin {false}; + bool is_reference_origin {false}; if (is_reference(cycle_origin)) { reference_origins.insert(cycle_origin); - refererence_origin = true; + is_reference_origin = true; } else { bad_kmers.insert(cycle_origin); } @@ -974,7 +974,7 @@ void Assembler::remove_all_nonreference_cycles(const bool break_chains) } if (is_reference(cycle_sink)) { reference_sinks.insert(cycle_sink); - if (refererence_origin) { + if (is_reference_origin) { cyclic_reference_segments.emplace_back(cycle_sink, cycle_origin); } else if (boost::out_degree(cycle_origin, graph_) > 1) { const auto p = boost::out_edges(cycle_origin, graph_); @@ -994,21 +994,24 @@ void Assembler::remove_all_nonreference_cycles(const bool break_chains) } }); if (!reference_tails.empty()) { + Vertex cycle_tail; if (reference_tails.size() == 1) { - const auto& cycle_tail = reference_tails.front(); - Edge e; bool present; - std::tie(e, present) = boost::edge(cycle_origin, cycle_tail, graph_); - if (!present) { - cyclic_reference_segments.emplace_back(cycle_sink, cycle_tail); - } else { - cyclic_reference_segments.emplace_back(cycle_tail, cycle_sink); - } + cycle_tail = reference_tails.front(); } else { // Just add the rightmost reference vertex auto itr = std::find_first_of(std::crbegin(reference_vertices_), std::crend(reference_vertices_), std::cbegin(reference_tails), std::cend(reference_tails)); assert(itr != std::crend(reference_vertices_)); - cyclic_reference_segments.emplace_back(cycle_sink, *itr); + cycle_tail = *itr; + } + const 
std::array cycle_vertices {cycle_sink, cycle_tail}; + auto itr = std::find_first_of(std::cbegin(reference_vertices_), std::cend(reference_vertices_), + std::cbegin(cycle_vertices), std::cend(cycle_vertices)); + assert(itr != std::cend(reference_vertices_)); + if (*itr == cycle_vertices.front()) { + cyclic_reference_segments.emplace_back(cycle_sink, cycle_tail); + } else { + cyclic_reference_segments.emplace_back(cycle_tail, cycle_sink); } } } diff --git a/src/core/tools/vargen/utils/assembler.hpp b/src/core/tools/vargen/utils/assembler.hpp index f629166f4..dcdd1f672 100644 --- a/src/core/tools/vargen/utils/assembler.hpp +++ b/src/core/tools/vargen/utils/assembler.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef assembler_hpp diff --git a/src/core/tools/vargen/utils/assembler_active_region_generator.cpp b/src/core/tools/vargen/utils/assembler_active_region_generator.cpp index d84946022..bf7879af6 100644 --- a/src/core/tools/vargen/utils/assembler_active_region_generator.cpp +++ b/src/core/tools/vargen/utils/assembler_active_region_generator.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "assembler_active_region_generator.hpp" @@ -167,20 +167,6 @@ auto compute_base_deletion_probabilities(const std::vector& coverages, return result; } -auto get_regions(const std::vector& good_bases, const GenomicRegion& region) -{ - std::vector result {}; - auto itr = std::find(std::cbegin(good_bases), std::cend(good_bases), true); - for (; itr != std::cend(good_bases);) { - const auto itr2 = std::find(itr, std::cend(good_bases), false); - const auto begin = region.begin() + std::distance(std::cbegin(good_bases), itr); - const auto end = begin + std::distance(itr, itr2); - result.emplace_back(region.contig_name(), begin, end); - itr = std::find(itr2, std::cend(good_bases), true); - } - return result; -} - template auto expand_each(const Container& regions, const GenomicRegion::Distance n) { @@ -193,14 +179,14 @@ auto expand_each(const Container& regions, const GenomicRegion::Distance n) auto get_deletion_hotspots(const GenomicRegion& region, const CoverageTracker& tracker) { - const auto coverages = tracker.coverage(region); - const auto mean_coverage = tracker.mean_coverage(region); - const auto stdev_coverage = tracker.stdev_coverage(region); + const auto coverages = tracker.get(region); + const auto mean_coverage = tracker.mean(region); + const auto stdev_coverage = tracker.stdev(region); const auto deletion_base_probs = compute_base_deletion_probabilities(coverages, mean_coverage, stdev_coverage); std::vector deletion_bases(deletion_base_probs.size()); std::transform(std::cbegin(deletion_base_probs), std::cend(deletion_base_probs), std::begin(deletion_bases), [] (const auto p) { return p > 0.5; }); - return extract_covered_regions(expand_each(get_regions(deletion_bases, region), 50)); + return extract_covered_regions(expand_each(select_regions(region, deletion_bases), 50)); } auto get_interesting_hotspots(const GenomicRegion& region, @@ -219,15 +205,15 @@ auto get_interesting_hotspots(const GenomicRegion& region, return 10 * interesting_coverage >= 
coverage; } }); - return get_regions(interesting_bases, region); + return select_regions(region, interesting_bases); } auto get_interesting_hotspots(const GenomicRegion& region, const CoverageTracker& interesting_read_tracker, const CoverageTracker& tracker) { - const auto interesting_coverages = interesting_read_tracker.coverage(region); - const auto coverages = tracker.coverage(region); + const auto interesting_coverages = interesting_read_tracker.get(region); + const auto coverages = tracker.get(region); return get_interesting_hotspots(region, interesting_coverages, coverages); } @@ -255,8 +241,8 @@ get_interesting_hotspots(const GenomicRegion& region, std::vector best_sample_interesting_coverage(n), best_sample_coverage(n); for (const auto& p : interesting_read_tracker) { assert(tracker.count(p.first) == 1); - const auto sample_coverage = tracker.at(p.first).coverage(region); - const auto sample_interesting_coverage = p.second.coverage(region); + const auto sample_coverage = tracker.at(p.first).get(region); + const auto sample_interesting_coverage = p.second.get(region); assert(sample_coverage.size() == n); assert(sample_interesting_coverage.size() == n); for (std::size_t i {0}; i < sample_coverage.size(); ++i) { diff --git a/src/core/tools/vargen/utils/assembler_active_region_generator.hpp b/src/core/tools/vargen/utils/assembler_active_region_generator.hpp index 3e68206c6..f33c504e1 100644 --- a/src/core/tools/vargen/utils/assembler_active_region_generator.hpp +++ b/src/core/tools/vargen/utils/assembler_active_region_generator.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef assembler_active_region_generator_hpp diff --git a/src/core/tools/vargen/utils/global_aligner.cpp b/src/core/tools/vargen/utils/global_aligner.cpp index 120e758b0..9f600fc75 100644 --- a/src/core/tools/vargen/utils/global_aligner.cpp +++ b/src/core/tools/vargen/utils/global_aligner.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "global_aligner.hpp" diff --git a/src/core/tools/vargen/utils/global_aligner.hpp b/src/core/tools/vargen/utils/global_aligner.hpp index a3bc8a1f9..dd7730f23 100644 --- a/src/core/tools/vargen/utils/global_aligner.hpp +++ b/src/core/tools/vargen/utils/global_aligner.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef global_aligner_hpp diff --git a/src/core/tools/vargen/utils/misaligned_reads_detector.cpp b/src/core/tools/vargen/utils/misaligned_reads_detector.cpp new file mode 100644 index 000000000..a0e48157c --- /dev/null +++ b/src/core/tools/vargen/utils/misaligned_reads_detector.cpp @@ -0,0 +1,181 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "misaligned_reads_detector.hpp" + +#include +#include +#include + +#include +#include + +#include "basics/cigar_string.hpp" +#include "utils/mappable_algorithms.hpp" +#include "utils/maths.hpp" +#include "utils/append.hpp" + +namespace octopus { namespace coretools { + +MisalignedReadsDetector::MisalignedReadsDetector(const ReferenceGenome& reference) : reference_ {reference} {} + +MisalignedReadsDetector::MisalignedReadsDetector(const ReferenceGenome& reference, Options options) +: reference_ {reference} +, options_ {options} +{} + +void MisalignedReadsDetector::add(const SampleName& sample, const AlignedRead& read) +{ + coverage_tracker_[sample].add(read); + if (is_likely_misaligned(read)) { + likely_misaligned_coverage_tracker_[sample].add(read); + } +} + +std::vector MisalignedReadsDetector::generate(const GenomicRegion& region) const +{ + if (likely_misaligned_coverage_tracker_.empty()) return {}; + if (coverage_tracker_.size() == 1) { + const auto& sample = std::cbegin(coverage_tracker_)->first; + const auto total_coverages = coverage_tracker_.at(sample).get(region); + const auto misaligned_coverages = likely_misaligned_coverage_tracker_.at(sample).get(region); + assert(total_coverages.size() == misaligned_coverages.size()); + std::vector likely_misaligned_base_mask(total_coverages.size()); + std::transform(std::cbegin(total_coverages), std::cend(total_coverages), std::cbegin(misaligned_coverages), + std::begin(likely_misaligned_base_mask), + [] (auto depth, auto misaligned_depth) -> bool { + return misaligned_depth > depth / 2; + }); + return join(select_regions(region, likely_misaligned_base_mask), 30); + } + return {}; +} + +void MisalignedReadsDetector::clear() noexcept +{ + coverage_tracker_.clear(); + likely_misaligned_coverage_tracker_.clear(); +} + +namespace { + +using NucleotideSequenceIterator = AlignedRead::NucleotideSequence::const_iterator; +using BaseQualityVectorIterator = AlignedRead::BaseQualityVector::const_iterator; + +bool 
count_snvs_in_match_range(const NucleotideSequenceIterator first_ref, const NucleotideSequenceIterator last_ref, + const NucleotideSequenceIterator first_base, const BaseQualityVectorIterator first_quality, + const AlignedRead::BaseQuality trigger) +{ + using boost::make_zip_iterator; + using Tuple = boost::tuple; + const auto num_bases = std::distance(first_ref, last_ref); + const auto last_base = std::next(first_base, num_bases); + const auto last_quality = std::next(first_quality, num_bases); + return std::count_if(make_zip_iterator(boost::make_tuple(first_ref, first_base, first_quality)), + make_zip_iterator(boost::make_tuple(last_ref, last_base, last_quality)), + [trigger](const Tuple& t) { + const char ref_base{t.get<0>()}, read_base{t.get<1>()}; + return ref_base != read_base && ref_base != 'N' && read_base != 'N' && t.get<2>() >= trigger; + }); +} + +double ln_probability_read_correctly_aligned(const double misalign_penalty, const AlignedRead& read, + const double max_expected_mutation_rate) +{ + const auto k = static_cast(std::floor(misalign_penalty)); + if (k == 0) { + return 0; + } else { + const auto ln_prob_missmapped = -maths::constants::ln10Div10<> * read.mapping_quality(); + const auto ln_prob_mapped = std::log(1.0 - std::exp(ln_prob_missmapped)); + const auto mu = max_expected_mutation_rate * region_size(read); + auto ln_prob_given_mapped = maths::log_poisson_sf(k, mu); + return ln_prob_mapped + ln_prob_given_mapped; + } +} + +} // namespace + +bool MisalignedReadsDetector::is_likely_misaligned(const AlignedRead& read) const +{ + using std::cbegin; using std::next; using std::move; + using Flag = CigarOperation::Flag; + const auto& read_sequence = read.sequence(); + auto sequence_itr = cbegin(read_sequence); + auto base_quality_itr = cbegin(read.base_qualities()); + auto ref_index = mapped_begin(read); + std::size_t read_index {0}; + double misalignment_penalty {0}; + for (const auto& cigar_operation : read.cigar()) { + const auto op_size = 
cigar_operation.size(); + switch (cigar_operation.flag()) { + case Flag::alignmentMatch: + { + const GenomicRegion region {contig_name(read), ref_index, ref_index + op_size}; + const auto ref_segment = reference_.get().fetch_sequence(region); + auto num_snvs = count_snvs_in_match_range(std::cbegin(ref_segment), std::cend(ref_segment), + next(sequence_itr, read_index), + next(base_quality_itr, read_index), + options_.snv_threshold); + misalignment_penalty += num_snvs * options_.snv_penalty; + read_index += op_size; + ref_index += op_size; + break; + } + case Flag::sequenceMatch: + read_index += op_size; + ref_index += op_size; + break; + case Flag::substitution: + { + auto num_snvs = std::count_if(next(base_quality_itr, read_index), next(base_quality_itr, read_index + op_size), + [this] (const AlignedRead::BaseQuality quality) { return quality >= options_.snv_threshold; }); + misalignment_penalty += num_snvs * options_.snv_penalty; + read_index += op_size; + ref_index += op_size; + break; + } + case Flag::insertion: + { + read_index += op_size; + misalignment_penalty += options_.indel_penalty; + break; + } + case Flag::deletion: + { + ref_index += op_size; + misalignment_penalty += options_.indel_penalty; + break; + } + case Flag::softClipped: + { + if (op_size > options_.max_unpenalised_clip_size) { + misalignment_penalty += options_.clip_penalty; + } + read_index += op_size; + ref_index += op_size; + break; + } + case Flag::hardClipped: + { + if (op_size > options_.max_unpenalised_clip_size) { + misalignment_penalty += options_.clip_penalty; + } + break; + } + case Flag::padding: + ref_index += op_size; + break; + case Flag::skipped: + ref_index += op_size; + break; + } + } + auto mu = options_.max_expected_mutation_rate; + auto ln_prob_misaligned = ln_probability_read_correctly_aligned(misalignment_penalty, read, mu); + auto min_ln_prob_misaligned = options_.min_ln_prob_correctly_aligned; + return ln_prob_misaligned < min_ln_prob_misaligned; +} + +} // namespace 
coretools +} // namespace octopus diff --git a/src/core/tools/vargen/utils/misaligned_reads_detector.hpp b/src/core/tools/vargen/utils/misaligned_reads_detector.hpp new file mode 100644 index 000000000..196a2354d --- /dev/null +++ b/src/core/tools/vargen/utils/misaligned_reads_detector.hpp @@ -0,0 +1,78 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#ifndef misaligned_reads_detector_hpp +#define misaligned_reads_detector_hpp + +#include +#include +#include +#include + +#include "config/common.hpp" +#include "basics/genomic_region.hpp" +#include "basics/aligned_read.hpp" +#include "utils/coverage_tracker.hpp" +#include "io/reference/reference_genome.hpp" + +namespace octopus { namespace coretools { + +class MisalignedReadsDetector +{ +public: + struct Options + { + AlignedRead::BaseQuality snv_threshold; + double snv_penalty = 1, indel_penalty = 1, clip_penalty = 1; + double max_expected_mutation_rate = 1e-3; + double min_ln_prob_correctly_aligned = std::log(0.0001); + unsigned max_unpenalised_clip_size = 3; + }; + + MisalignedReadsDetector() = delete; + + MisalignedReadsDetector(const ReferenceGenome& reference); + MisalignedReadsDetector(const ReferenceGenome& reference, Options options); + + MisalignedReadsDetector(const MisalignedReadsDetector&) = default; + MisalignedReadsDetector& operator=(const MisalignedReadsDetector&) = default; + MisalignedReadsDetector(MisalignedReadsDetector&&) = default; + MisalignedReadsDetector& operator=(MisalignedReadsDetector&&) = default; + + ~MisalignedReadsDetector() = default; + + void add(const SampleName& sample, const AlignedRead& read); + template + void add(const SampleName& sample, ForwardIterator first_read, ForwardIterator last_read); + + std::vector generate(const GenomicRegion& region) const; + + void clear() noexcept; + +private: + using CoverageTrackerMap = std::unordered_map>; + + std::reference_wrapper reference_; + 
Options options_; + CoverageTrackerMap coverage_tracker_, likely_misaligned_coverage_tracker_; + + bool is_likely_misaligned(const AlignedRead& read) const; +}; + +template +void MisalignedReadsDetector::add(const SampleName& sample, ForwardIterator first_read, ForwardIterator last_read) +{ + auto& coverage_tracker = coverage_tracker_[sample]; + auto& likely_misaligned_coverage_tracker = likely_misaligned_coverage_tracker_[sample]; + std::for_each(first_read, last_read, [&] (const AlignedRead& read) { + coverage_tracker.add(read); + if (is_likely_misaligned(read)) { + likely_misaligned_coverage_tracker.add(read); + } + }); +} + +} // namespace coretools +} // namespace octopus + +#endif diff --git a/src/core/tools/vargen/variant_generator.cpp b/src/core/tools/vargen/variant_generator.cpp index 295a4153b..81af2e67d 100644 --- a/src/core/tools/vargen/variant_generator.cpp +++ b/src/core/tools/vargen/variant_generator.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "variant_generator.hpp" @@ -6,38 +6,46 @@ #include #include -#include "config/common.hpp" -#include "basics/genomic_region.hpp" +#include "utils/append.hpp" namespace octopus { namespace coretools { VariantGenerator::VariantGenerator() : debug_log_ {logging::get_debug_log()} , trace_log_ {logging::get_trace_log()} +, active_region_generator_ {} +{} + +VariantGenerator::VariantGenerator(ActiveRegionGenerator region_generator) +: debug_log_ {logging::get_debug_log()} +, trace_log_ {logging::get_trace_log()} +, active_region_generator_ {std::move(region_generator)} {} VariantGenerator::VariantGenerator(const VariantGenerator& other) { - generators_.reserve(other.generators_.size()); - for (const auto& generator : other.generators_) { - generators_.push_back(generator->clone()); + variant_generators_.reserve(other.variant_generators_.size()); + for (const auto& generator : other.variant_generators_) { + variant_generators_.push_back(generator->clone()); } + active_region_generator_ = other.active_region_generator_; } VariantGenerator& VariantGenerator::operator=(VariantGenerator other) { - std::swap(generators_, other.generators_); + std::swap(variant_generators_, other.variant_generators_); return *this; } void VariantGenerator::add(std::unique_ptr generator) { - generators_.push_back(std::move(generator)); + if (active_region_generator_) active_region_generator_->add_generator(generator->name()); + variant_generators_.push_back(std::move(generator)); } unsigned VariantGenerator::num_generators() const noexcept { - return static_cast(generators_.size()); + return static_cast(variant_generators_.size()); } std::unique_ptr VariantGenerator::clone() const @@ -46,34 +54,46 @@ std::unique_ptr VariantGenerator::clone() const } namespace debug { - template - void print_generated_candidates(S&& stream, const Container& candidates, - const std::string& generator_name) - { + +void log_active_regions(const std::vector& regions, const std::string& generator, + 
boost::optional& log) +{ + if (log) { + auto log_stream = stream(*log); + log_stream << generator << " active regions: "; + for (const auto& region : regions) log_stream << region << ' '; + } +} + +void log_candidates(const std::vector& candidates, const std::string& generator, + boost::optional& log) +{ + if (log) { + auto log_stream = stream(*log); if (candidates.empty()) { - stream << "No candidates generated from " << generator_name << '\n'; + log_stream << "No candidates generated from " << generator << '\n'; } else { - stream << "Generated " << candidates.size(); - stream << " candidate"; - if (candidates.size() > 1) stream << "s"; - stream << " from " << generator_name << ":\n"; - for (const auto& c : candidates) stream << c << '\n'; + log_stream << "Generated " << candidates.size(); + log_stream << " candidate"; + if (candidates.size() > 1) log_stream << "s"; + log_stream << " from " << generator << ":\n"; + for (const auto& c : candidates) log_stream << c << '\n'; } } +} + } // namespace debug -std::vector VariantGenerator::generate(const GenomicRegion& region) +std::vector VariantGenerator::generate(const GenomicRegion& region) const { std::vector result {}; - for (auto& generator : generators_) { - auto generator_result = generator->do_generate_variants(region); + for (auto& generator : variant_generators_) { + const auto active_regions = generate_active_regions(region, *generator); + debug::log_active_regions(active_regions, generator->name(), debug_log_); + auto generator_result = generator->do_generate(active_regions); + debug::log_candidates(generator_result, generator->name(), debug_log_); assert(std::is_sorted(std::cbegin(generator_result), std::cend(generator_result))); - if (debug_log_) { - debug::print_generated_candidates(stream(*debug_log_), generator_result, generator->name()); - } - auto itr = result.insert(std::end(result), - std::make_move_iterator(std::begin(generator_result)), - std::make_move_iterator(std::end(generator_result))); + 
auto itr = utils::append(std::move(generator_result), result); std::inplace_merge(std::begin(result), itr, std::end(result)); } // Each generator is guaranteed to return unique variants, but two generators can still @@ -84,18 +104,20 @@ std::vector VariantGenerator::generate(const GenomicRegion& region) bool VariantGenerator::requires_reads() const noexcept { - return std::any_of(std::cbegin(generators_), std::cend(generators_), + return std::any_of(std::cbegin(variant_generators_), std::cend(variant_generators_), [] (const auto& generator) { return generator->do_requires_reads(); }); } void VariantGenerator::add_read(const SampleName& sample, const AlignedRead& read) { - for (auto& generator : generators_) generator->do_add_read(sample, read); + if (active_region_generator_) active_region_generator_->add_read(sample, read); + for (auto& generator : variant_generators_) generator->do_add_read(sample, read); } void VariantGenerator::clear() noexcept { - for (auto& generator : generators_) generator->do_clear(); + if (active_region_generator_) active_region_generator_->clear(); + for (auto& generator : variant_generators_) generator->do_clear(); } std::unique_ptr VariantGenerator::do_clone() const @@ -103,6 +125,15 @@ std::unique_ptr VariantGenerator::do_clone() const return std::make_unique(*this); } +VariantGenerator::RegionSet +VariantGenerator::generate_active_regions(const GenomicRegion& region, const VariantGenerator& generator) const +{ + if (active_region_generator_) { + return active_region_generator_->generate(region, generator.name()); + } else { + return {region}; + } +} } // namespace coretools } // namespace octopus diff --git a/src/core/tools/vargen/variant_generator.hpp b/src/core/tools/vargen/variant_generator.hpp index 6e7f45f83..5490bcf08 100644 --- a/src/core/tools/vargen/variant_generator.hpp +++ b/src/core/tools/vargen/variant_generator.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of 
this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef variant_generator_hpp @@ -16,20 +16,19 @@ #include "config/common.hpp" #include "logging/logging.hpp" +#include "basics/genomic_region.hpp" #include "basics/aligned_read.hpp" #include "core/types/variant.hpp" #include "containers/mappable_flat_multi_set.hpp" +#include "active_region_generator.hpp" -namespace octopus { - -class GenomicRegion; - -namespace coretools { +namespace octopus { namespace coretools { class VariantGenerator { public: VariantGenerator(); + VariantGenerator(ActiveRegionGenerator region_generator); VariantGenerator(const VariantGenerator&); VariantGenerator& operator=(VariantGenerator); @@ -44,53 +43,53 @@ class VariantGenerator std::unique_ptr clone() const; - std::vector generate(const GenomicRegion& region); + std::vector generate(const GenomicRegion& region) const; bool requires_reads() const noexcept; void add_read(const SampleName& sample, const AlignedRead& read); - template void add_reads(const SampleName& sample, InputIt first, InputIt last); void clear() noexcept; protected: - using VectorIterator = std::vector::const_iterator; - using FlatSetIterator = MappableFlatMultiSet::const_iterator; + using ReadVectorIterator = std::vector::const_iterator; + using ReadFlatSetIterator = MappableFlatMultiSet::const_iterator; + + using RegionSet = std::vector; mutable boost::optional debug_log_; mutable boost::optional trace_log_; private: - std::vector> generators_; + std::vector> variant_generators_; + boost::optional active_region_generator_; virtual std::unique_ptr do_clone() const; - - virtual std::vector do_generate_variants(const GenomicRegion& region) { return {}; }; - + virtual std::vector do_generate(const RegionSet& regions) const { return {}; }; virtual bool do_requires_reads() const noexcept { return false; }; - virtual void do_add_read(const SampleName& sample, const AlignedRead& read) {}; - // add_reads is not strictly necessary as the 
effect of calling add_reads must be the same as // calling add_read for each read. However, there may be performance benefits // to having an add_reads method to avoid many virtual dispatches. // Ideally add_reads would be a template to accept any InputIterator, but it is not possible // to have template virtual methods. The best solution is therefore to just overload add_reads // for common container iterators, more can easily be added if needed. - virtual void do_add_reads(const SampleName& sample, VectorIterator first, VectorIterator last) {}; - virtual void do_add_reads(const SampleName& sample, FlatSetIterator first, FlatSetIterator last) {}; - + virtual void do_add_reads(const SampleName& sample, ReadVectorIterator first, ReadVectorIterator last) {}; + virtual void do_add_reads(const SampleName& sample, ReadFlatSetIterator first, ReadFlatSetIterator last) {}; virtual void do_clear() noexcept {}; virtual std::string name() const { return "VariantGenerator"; } + + RegionSet generate_active_regions(const GenomicRegion& region, const VariantGenerator& generator) const; }; template void VariantGenerator::add_reads(const SampleName& sample, InputIt first, InputIt last) { - for (auto& generator : generators_) generator->do_add_reads(sample, first, last); + if (active_region_generator_) active_region_generator_->add_reads(sample, first, last); + for (auto& generator : variant_generators_) generator->do_add_reads(sample, first, last); } // non-member methods @@ -100,12 +99,13 @@ namespace detail { template void add_reads(const Container& reads, G& generator, std::true_type) { - generator.add_reads("octopus-sample", std::cbegin(reads), std::cend(reads)); + generator.add_reads("octopus", std::cbegin(reads), std::cend(reads)); } template void add_reads(const ReadMap& reads, G& generator, std::false_type) { + for (const auto& p : reads) { generator.add_reads(p.first, std::cbegin(p.second), std::cend(p.second)); } diff --git 
a/src/core/tools/vargen/variant_generator_builder.cpp b/src/core/tools/vargen/variant_generator_builder.cpp index 95702d601..5dbf40536 100644 --- a/src/core/tools/vargen/variant_generator_builder.cpp +++ b/src/core/tools/vargen/variant_generator_builder.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "variant_generator_builder.hpp" @@ -45,10 +45,17 @@ VariantGeneratorBuilder::add_randomiser(Randomiser::Options options) return *this; } +VariantGeneratorBuilder& +VariantGeneratorBuilder::set_active_region_generator(ActiveRegionGenerator::Options options) +{ + active_region_generator_ = std::move(options); + return *this; +} + VariantGenerator VariantGeneratorBuilder::build(const ReferenceGenome& reference) const { - VariantGenerator result {}; + VariantGenerator result {ActiveRegionGenerator {reference, active_region_generator_}}; if (cigar_scanner_) { result.add(std::make_unique(reference, *cigar_scanner_)); } @@ -64,7 +71,6 @@ VariantGenerator VariantGeneratorBuilder::build(const ReferenceGenome& reference for (auto options : randomisers_) { result.add(std::make_unique(reference, options)); } - return result; } diff --git a/src/core/tools/vargen/variant_generator_builder.hpp b/src/core/tools/vargen/variant_generator_builder.hpp index 9d7eaf210..0624ce3bd 100644 --- a/src/core/tools/vargen/variant_generator_builder.hpp +++ b/src/core/tools/vargen/variant_generator_builder.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef variant_generator_builder_hpp @@ -16,6 +16,7 @@ #include "randomiser.hpp" #include "io/reference/reference_genome.hpp" #include "io/variant/vcf_reader.hpp" +#include "active_region_generator.hpp" namespace octopus { namespace coretools { @@ -40,6 +41,7 @@ class VariantGeneratorBuilder VcfExtractor::Options options = VcfExtractor::Options {}); VariantGeneratorBuilder& add_downloader(Downloader::Options options = Downloader::Options {}); VariantGeneratorBuilder& add_randomiser(Randomiser::Options options = Randomiser::Options {}); + VariantGeneratorBuilder& set_active_region_generator(ActiveRegionGenerator::Options options = ActiveRegionGenerator::Options {}); VariantGenerator build(const ReferenceGenome& reference) const; @@ -55,6 +57,7 @@ class VariantGeneratorBuilder std::deque vcf_extractors_; std::deque downloaders_; std::deque randomisers_; + ActiveRegionGenerator::Options active_region_generator_; }; } // namespace coretools diff --git a/src/core/tools/vargen/vcf_extractor.cpp b/src/core/tools/vargen/vcf_extractor.cpp index 78910c9ab..6ad3d561d 100644 --- a/src/core/tools/vargen/vcf_extractor.cpp +++ b/src/core/tools/vargen/vcf_extractor.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "vcf_extractor.hpp" @@ -11,6 +11,7 @@ #include "io/variant/vcf_spec.hpp" #include "io/variant/vcf_record.hpp" #include "utils/sequence_utils.hpp" +#include "utils/append.hpp" namespace octopus { namespace coretools { @@ -37,11 +38,6 @@ static bool is_canonical(const VcfRecord::NucleotideSequence& allele) [](const auto base) { return base == vcfspec::deletedBase; }); } -bool is_good_quality(const VcfRecord& record, boost::optional min_quality) noexcept -{ - return !min_quality || (record.qual() && *record.qual() >= *min_quality); -} - template auto make_allele(const Iterator first_base, const Iterator last_base) { @@ -89,34 +85,41 @@ void extract_variants(const VcfRecord& record, Container& result) } } -std::vector fetch_variants(const GenomicRegion& region, const VcfReader& reader, - const boost::optional min_quality) +} // namespace + +std::vector VcfExtractor::do_generate(const RegionSet& regions) const { - std::deque variants{}; // Use deque to prevent reallocating - auto p = reader.iterate(region, VcfReader::UnpackPolicy::sites); - std::for_each(std::move(p.first), std::move(p.second), - [&variants, min_quality](const auto& record) { - if (is_good_quality(record, min_quality)) { - extract_variants(record, variants); - } - }); - std::vector result{std::make_move_iterator(std::begin(variants)), - std::make_move_iterator(std::end(variants))}; - std::sort(std::begin(result), std::end(result)); - result.erase(std::unique(std::begin(result), std::end(result)), std::end(result)); + std::vector result {}; + for (const auto& region : regions) { + utils::append(fetch_variants(region), result); + } return result; } -} // namespace +std::string VcfExtractor::name() const +{ + return "VCF extraction"; +} -std::vector VcfExtractor::do_generate_variants(const GenomicRegion& region) +std::vector VcfExtractor::fetch_variants(const GenomicRegion& region) const { - return fetch_variants(region, *reader_, options_.min_quality); + std::deque variants {}; + for (auto p = 
reader_->iterate(region, VcfReader::UnpackPolicy::sites); p.first != p.second; ++p.first) { + if (is_good(*p.first)) { + extract_variants(*p.first, variants); + } + } + std::vector result {std::make_move_iterator(std::begin(variants)), + std::make_move_iterator(std::end(variants))}; + std::sort(std::begin(result), std::end(result)); + result.erase(std::unique(std::begin(result), std::end(result)), std::end(result)); + return result; } -std::string VcfExtractor::name() const +bool VcfExtractor::is_good(const VcfRecord& record) const { - return "VCF extraction"; + if (!options_.extract_filtered && is_filtered(record)) return false; + return !options_.min_quality || (record.qual() && *record.qual() >= *options_.min_quality); } } // namespace coretools diff --git a/src/core/tools/vargen/vcf_extractor.hpp b/src/core/tools/vargen/vcf_extractor.hpp index 68e5c8de5..691dc09b0 100644 --- a/src/core/tools/vargen/vcf_extractor.hpp +++ b/src/core/tools/vargen/vcf_extractor.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef vcf_extractor_hpp @@ -26,6 +26,7 @@ class VcfExtractor : public VariantGenerator struct Options { Variant::MappingDomain::Size max_variant_size = 100; + bool extract_filtered = false; boost::optional min_quality = boost::none; }; @@ -43,13 +44,14 @@ class VcfExtractor : public VariantGenerator private: std::unique_ptr do_clone() const override; - - std::vector do_generate_variants(const GenomicRegion& region) override; - + std::vector do_generate(const RegionSet& regions) const override; std::string name() const override; std::shared_ptr reader_; Options options_; + + std::vector fetch_variants(const GenomicRegion& region) const; + bool is_good(const VcfRecord& record) const; }; } // namespace coretools diff --git a/src/core/tools/vcf_header_factory.cpp b/src/core/tools/vcf_header_factory.cpp index ec6c1e4fb..07c8fd95d 100644 --- a/src/core/tools/vcf_header_factory.cpp +++ b/src/core/tools/vcf_header_factory.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "vcf_header_factory.hpp" @@ -21,7 +21,7 @@ VcfHeaderFactory::AnnotatorMap VcfHeaderFactory::annotators_ = }}, {std::type_index(typeid(SomaticCall)), [] (auto& hb) { hb.add_info("SOMATIC", "0", "Flag", "Indicates that the record is a somatic mutation, for cancer genomics"); - hb.add_format("SCR", "2", "Float", "99% credible region of the somatic allele frequency"); + hb.add_format("VAF_CR", "2", "Float", "Credible region for the Variant Allele Frequency"); hb.add_info("MP", "1", "Float", "Model posterior"); }}, {std::type_index(typeid(DenovoCall)), [] (auto& hb) { diff --git a/src/core/tools/vcf_header_factory.hpp b/src/core/tools/vcf_header_factory.hpp index 18ceb0ac5..33b6e6b3f 100644 --- a/src/core/tools/vcf_header_factory.hpp +++ b/src/core/tools/vcf_header_factory.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef vcf_header_factory_hpp diff --git a/src/core/tools/vcf_record_factory.cpp b/src/core/tools/vcf_record_factory.cpp index bcc10d3d0..8bba1a45c 100644 --- a/src/core/tools/vcf_record_factory.cpp +++ b/src/core/tools/vcf_record_factory.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "vcf_record_factory.hpp" @@ -29,6 +29,7 @@ #include "utils/append.hpp" #include "exceptions/program_error.hpp" #include "io/variant/vcf_spec.hpp" +#include "config/octopus_vcf.hpp" #define _unused(x) ((void)(x)) @@ -167,61 +168,73 @@ void resolve_indel_genotypes(std::vector& calls, const std::vector< } } +bool is_modified_phase_boundary(const CallWrapper& call, const std::vector& samples) +{ + return std::none_of(std::cbegin(samples), std::cend(samples), + [&call] (const auto& sample) { + const auto& old_phase = call->get_genotype_call(sample).phase; + return old_phase && begins_before(call, old_phase->region()); + }); +} + +template +void resolve_phase(CallWrapper& call, const SampleName& sample, + const Iterator first_phase_boundary_itr, const Iterator last_phase_boundary_itr) +{ + const auto& phase = call->get_genotype_call(sample).phase; + if (phase) { + auto overlapped = overlap_range(first_phase_boundary_itr, last_phase_boundary_itr, phase->region()); + if (overlapped.empty()) { + overlapped = overlap_range(first_phase_boundary_itr, last_phase_boundary_itr, expand_lhs(phase->region(), 1)); + if (!overlapped.empty() && begin_distance(overlapped.front(), phase->region()) != 1) { + overlapped.advance_begin(1); + } + } + if (!overlapped.empty() && overlapped.front() != call) { + const auto& old_phase = call->get_genotype_call(sample).phase; + auto new_phase_region = encompassing_region(overlapped.front(), old_phase->region()); + Call::PhaseCall new_phase {std::move(new_phase_region), old_phase->score()}; + call->set_phase(sample, std::move(new_phase)); + } + } +} + void pad_indels(std::vector& calls, const std::vector& samples) { - using std::begin; using std::end; using std::move; - const auto first_modified = std::stable_partition(begin(calls), end(calls), - [] (const auto& call) { return !call->parsimonise(dummy_base); }); - if (first_modified != end(calls)) { - const auto last = end(calls); - const auto first_phase_adjusted = 
std::partition(first_modified, last, - [&samples] (const auto& call) { - return std::none_of(begin(samples), cend(samples), - [&call] (const auto& sample) { - const auto& old_phase = call->get_genotype_call(sample).phase; - return old_phase && begins_before(mapped_region(call), old_phase->region()); - }); - }); - if (first_phase_adjusted != last) { - std::sort(first_phase_adjusted, last); - for_each(first_phase_adjusted, last, - [&samples] (auto& call) { - for (const auto& sample : samples) { - const auto& old_phase = call->get_genotype_call(sample).phase; - if (old_phase && begins_before(mapped_region(call), old_phase->region())) { - auto new_phase_region = expand_lhs(old_phase->region(), 1); - Call::PhaseCall new_phase {move(new_phase_region), old_phase->score()}; - call->set_phase(sample, move(new_phase)); - } - } - }); - for_each(begin(calls), first_phase_adjusted, - [&samples, first_phase_adjusted, last] (auto& call) { - for (const auto& sample : samples) { - const auto& phase = call->get_genotype_call(sample).phase; - if (phase) { - auto overlapped = overlap_range(first_phase_adjusted, last, phase->region()); - if (overlapped.empty()) { - overlapped = overlap_range(first_phase_adjusted, last, expand_lhs(phase->region(), 1)); - if (!overlapped.empty()) { - if (begin_distance(overlapped.front(), phase->region()) != 1) { - overlapped.advance_begin(1); - } - } - } - if (!overlapped.empty() && overlapped.front() != call) { - const auto& old_phase = call->get_genotype_call(sample).phase; - auto new_phase_region = encompassing_region(overlapped.front(), old_phase->region()); - Call::PhaseCall new_phase {move(new_phase_region), old_phase->score()}; - call->set_phase(sample, move(new_phase)); - } - } - } - }); + using std::begin; using std::end; + const auto first_modified_itr = std::stable_partition(begin(calls), end(calls), + [] (const auto& call) { return !call->parsimonise(dummy_base); }); + if (first_modified_itr != end(calls)) { + const auto last_call_itr = 
end(calls); + const auto first_phase_adjusted_itr = std::partition(first_modified_itr, last_call_itr, + [&samples] (const auto& call) { + return is_modified_phase_boundary(call, samples); }); + if (first_phase_adjusted_itr != last_call_itr) { + std::sort(first_phase_adjusted_itr, last_call_itr); + for (auto call_itr = first_phase_adjusted_itr; call_itr != last_call_itr; ++call_itr) { + auto& call = *call_itr; + for (const auto& sample : samples) { + const auto& old_phase = call->get_genotype_call(sample).phase; + if (old_phase) { + if (begins_before(call, old_phase->region())) { + auto new_phase_region = expand_lhs(old_phase->region(), 1); + Call::PhaseCall new_phase {std::move(new_phase_region), old_phase->score()}; + call->set_phase(sample, std::move(new_phase)); + } else { + resolve_phase(call, sample, first_phase_adjusted_itr, call_itr); + } + } + } + } + for_each(begin(calls), first_phase_adjusted_itr, [&samples, first_phase_adjusted_itr, last_call_itr] (auto& call) { + for (const auto& sample : samples) { + resolve_phase(call, sample, first_phase_adjusted_itr, last_call_itr); + } + }); } - std::sort(first_modified, first_phase_adjusted); - std::inplace_merge(first_modified, first_phase_adjusted, last); - std::inplace_merge(begin(calls), first_modified, last); + std::sort(first_modified_itr, first_phase_adjusted_itr); + std::inplace_merge(first_modified_itr, first_phase_adjusted_itr, last_call_itr); + std::inplace_merge(begin(calls), first_modified_itr, last_call_itr); } } @@ -240,7 +253,7 @@ std::vector VcfRecordFactory::make(std::vector&& calls) const auto block_begin_itr = adjacent_overlap_find(call_itr, end(calls)); transform(std::make_move_iterator(call_itr), std::make_move_iterator(block_begin_itr), std::back_inserter(result), [this] (CallWrapper&& call) { - call->replace(dummy_base, reference_.fetch_sequence(head_position(call->mapped_region())).front()); + call->replace(dummy_base, reference_.fetch_sequence(head_position(call)).front()); // We may 
still have uncalled genotyped alleles here if the called genotype // did not have a high posterior call->replace_uncalled_genotype_alleles(Allele {call->mapped_region(), vcfspec::missingValue}, 'N'); @@ -259,8 +272,8 @@ std::vector VcfRecordFactory::make(std::vector&& calls) [] (const auto& call) { return call->reference().sequence().front() == dummy_base; }); - boost::optional base; - if (alt_itr != block_head_end_itr) base = alt_itr; + boost::optional base {}; + if (alt_itr != block_head_end_itr) base = alt_itr; std::deque duplicates {}; for_each(block_begin_itr, block_head_end_itr, [this, base, &duplicates] (auto& call) { assert(!call->reference().sequence().empty()); @@ -281,11 +294,16 @@ std::vector VcfRecordFactory::make(std::vector&& calls) if (old_genotype[i].sequence().front() == dummy_base) { auto new_sequence = old_genotype[i].sequence(); if (base) { - const auto& base_sequence = (**base)->get_genotype_call(sample).genotype[i].sequence(); - if (!base_sequence.empty()) { - new_sequence.front() = base_sequence.front(); + const auto& base_genotype = (**base)->get_genotype_call(sample).genotype; + if (base_genotype.ploidy() == ploidy) { + const auto& base_sequence = base_genotype[i].sequence(); + if (!base_sequence.empty()) { + new_sequence.front() = base_sequence.front(); + } else { + new_sequence = vcfspec::missingValue; + } } else { - new_sequence = vcfspec::missingValue; + new_sequence.front() = actual_reference_base; } } else { new_sequence.front() = actual_reference_base; @@ -376,7 +394,8 @@ std::vector VcfRecordFactory::make(std::vector&& calls) Allele new_allele {mapped_region(curr_call), move(new_sequence)}; new_genotype.emplace(move(new_allele)); } else if (old_genotype[i].sequence().front() == dummy_base) { - if (prev_represented[s][i] && begins_before(*prev_represented[s][i], curr_call)) { + if (prev_represented[s].size() > i && prev_represented[s][i] + && begins_before(*prev_represented[s][i], curr_call)) { const auto& prev_represented_genotype 
= prev_represented[s][i]->get_genotype_call(sample); if (are_in_phase(genotype_call, prev_represented_genotype)) { const auto& prev_allele = prev_represented_genotype.genotype[i]; @@ -446,6 +465,9 @@ std::vector VcfRecordFactory::make(std::vector&& calls) const auto& seq = new_genotype[i].sequence(); if (std::find(std::cbegin(seq), std::cend(seq), vcfspec::deletedBase) == std::cend(seq) && block_head_end_itr->call->is_represented(new_genotype[i])) { + if (prev_represented[s].size() <= i) { + prev_represented[s].resize(i + 1, nullptr); + } prev_represented[s][i] = std::addressof(*block_head_end_itr->call); } } @@ -530,7 +552,7 @@ void set_vcf_genotype(const SampleName& sample, const Call::GenotypeCall& call, auto genotyped_alleles = extract_allele_sequences(call.genotype); if (replace_missing_with_non_ref) { std::replace(std::begin(genotyped_alleles), std::end(genotyped_alleles), - std::string {vcfspec::missingValue}, std::string {""}); + std::string {vcfspec::missingValue}, std::string {vcf::spec::allele::nonref}); } record.set_genotype(sample, std::move(genotyped_alleles), VcfRecord::Builder::Phasing::phased); } @@ -592,10 +614,10 @@ VcfRecord VcfRecordFactory::make(std::unique_ptr call) const bool has_non_ref {false}; auto alts = extract_genotyped_alt_alleles(call.get(), samples_); if (alts.empty()) { - alts.push_back(""); + alts.push_back(vcf::spec::allele::nonref); has_non_ref = true; } else { - has_non_ref = std::find(std::cbegin(alts), std::cend(alts), "") != std::cend(alts); + has_non_ref = std::find(std::cbegin(alts), std::cend(alts), vcf::spec::allele::nonref) != std::cend(alts); } result.set_chrom(contig_name(region)); @@ -606,8 +628,6 @@ VcfRecord VcfRecordFactory::make(std::unique_ptr call) const const auto call_reads = copy_overlapped(reads_, region); result.set_info("NS", count_samples_with_coverage(call_reads)); result.set_info("DP", sum_max_coverages(call_reads)); - result.set_info("SB", utils::to_string(strand_bias(call_reads), 2)); - 
result.set_info("BQ", static_cast(rmq_base_quality(call_reads))); result.set_info("MQ", static_cast(rmq_mapping_quality(call_reads))); result.set_info("MQ0", count_mapq_zero(call_reads)); set_allele_counts(*call, samples_, result); @@ -618,9 +638,9 @@ VcfRecord VcfRecordFactory::make(std::unique_ptr call) const if (!sites_only_) { if (call->all_phased()) { - result.set_format({"GT", "GQ", "DP", "BQ", "MQ", "PS", "PQ"}); + result.set_format({"GT", "GQ", "DP", "MQ", "PS", "PQ"}); } else { - result.set_format({"GT", "GQ", "DP", "BQ", "MQ"}); + result.set_format({"GT", "GQ", "DP", "MQ"}); } for (const auto& sample : samples_) { const auto& genotype_call = call->get_genotype_call(sample); @@ -628,7 +648,6 @@ VcfRecord VcfRecordFactory::make(std::unique_ptr call) const set_vcf_genotype(sample, genotype_call, result, has_non_ref); result.set_format(sample, "GQ", std::to_string(gq)); result.set_format(sample, "DP", max_coverage(call_reads.at(sample))); - result.set_format(sample, "BQ", static_cast(rmq_base_quality(call_reads.at(sample)))); result.set_format(sample, "MQ", static_cast(rmq_mapping_quality(call_reads.at(sample)))); if (call->is_phased(sample)) { const auto& phase = *genotype_call.phase; @@ -744,10 +763,10 @@ VcfRecord VcfRecordFactory::make_segment(std::vector>&& ca alt_alleles.erase(itr, std::end(alt_alleles)); bool has_non_ref {false}; if (alt_alleles.empty()) { - alt_alleles.push_back(""); + alt_alleles.push_back(vcf::spec::allele::nonref); has_non_ref = true; } else { - has_non_ref = std::find(std::cbegin(alt_alleles), std::cend(alt_alleles), "") != std::cend(alt_alleles); + has_non_ref = std::find(std::cbegin(alt_alleles), std::cend(alt_alleles), vcf::spec::allele::nonref) != std::cend(alt_alleles); } set_allele_counts(alt_alleles, resolved_genotypes, result); result.set_alt(std::move(alt_alleles)); @@ -756,8 +775,6 @@ VcfRecord VcfRecordFactory::make_segment(std::vector>&& ca result.set_qual(std::min(max_qual, maths::round(q->get()->quality().score(), 
2))); result.set_info("NS", count_samples_with_coverage(reads_, region)); result.set_info("DP", sum_max_coverages(reads_, region)); - result.set_info("SB", utils::to_string(strand_bias(reads_, region), 2)); - result.set_info("BQ", static_cast(rmq_base_quality(reads_, region))); result.set_info("MQ", static_cast(rmq_mapping_quality(reads_, region))); result.set_info("MQ0", count_mapq_zero(reads_, region)); @@ -767,9 +784,9 @@ VcfRecord VcfRecordFactory::make_segment(std::vector>&& ca } if (!sites_only_) { if (calls.front()->all_phased()) { - result.set_format({"GT", "GQ", "DP", "BQ", "MQ", "PS", "PQ"}); + result.set_format({"GT", "GQ", "DP", "MQ", "PS", "PQ"}); } else { - result.set_format({"GT", "GQ", "DP", "BQ", "MQ"}); + result.set_format({"GT", "GQ", "DP", "MQ"}); } auto sample_itr = std::begin(resolved_genotypes); @@ -779,12 +796,11 @@ VcfRecord VcfRecordFactory::make_segment(std::vector>&& ca auto& genotype_call = *sample_itr++; if (has_non_ref) { std::replace(std::begin(genotype_call), std::end(genotype_call), - std::string {vcfspec::missingValue}, std::string {""}); + std::string {vcfspec::missingValue}, std::string {vcf::spec::allele::nonref}); } result.set_genotype(sample, genotype_call, VcfRecord::Builder::Phasing::phased); result.set_format(sample, "GQ", std::to_string(gq)); result.set_format(sample, "DP", max_coverage(reads_.at(sample), region)); - result.set_format(sample, "BQ", static_cast(rmq_base_quality(reads_.at(sample), region))); result.set_format(sample, "MQ", static_cast(rmq_mapping_quality(reads_.at(sample), region))); if (calls.front()->is_phased(sample)) { const auto phase = *calls.front()->get_genotype_call(sample).phase; diff --git a/src/core/tools/vcf_record_factory.hpp b/src/core/tools/vcf_record_factory.hpp index cd957752a..c7d604256 100644 --- a/src/core/tools/vcf_record_factory.hpp +++ b/src/core/tools/vcf_record_factory.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this 
source code is governed by the MIT license that can be found in the LICENSE file. #ifndef vcf_record_factory_hpp diff --git a/src/core/types/allele.cpp b/src/core/types/allele.cpp index a0f834488..b3072a0d0 100644 --- a/src/core/types/allele.cpp +++ b/src/core/types/allele.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "allele.hpp" diff --git a/src/core/types/allele.hpp b/src/core/types/allele.hpp index 111a6d074..3b6981193 100644 --- a/src/core/types/allele.hpp +++ b/src/core/types/allele.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef allele_hpp @@ -154,13 +154,15 @@ auto copy_sequence(const BasicAllele& allele, const RegionTp& region) if (mapped_region(allele) == region) return sequence; const auto region_offset = static_cast(begin_distance(allele, region)); auto first_base_itr = std::cbegin(sequence), last_base_itr = std::cend(sequence); - const auto region_size = static_cast(size(region)); if (is_deletion(allele)) { if (!is_sequence_empty(allele)) { - const auto base_offset = std::min(region_offset, sequence_size(allele)); - first_base_itr = std::next(std::cbegin(allele.sequence()), base_offset); - const auto num_remaining_bases = std::min(region_size, sequence_size(allele) - base_offset); - last_base_itr = std::next(first_base_itr, num_remaining_bases); + first_base_itr = std::cbegin(allele.sequence()); + const auto num_deleted_based = region_size(allele) - sequence_size(allele); + if (size(region) <= num_deleted_based) { + last_base_itr = first_base_itr; + } else { + last_base_itr = std::next(first_base_itr, size(region) - num_deleted_based); + } } } else { first_base_itr = std::next(std::cbegin(allele.sequence()), region_offset); @@ -169,7 +171,7 @@ 
auto copy_sequence(const BasicAllele& allele, const RegionTp& region) const auto num_subsequence_bases = sequence_size(allele) - region_offset - num_trailing_bases; last_base_itr = std::next(first_base_itr, num_subsequence_bases); } else { - last_base_itr = std::next(first_base_itr, region_size); + last_base_itr = std::next(first_base_itr, size(region)); } } assert(first_base_itr <= last_base_itr); diff --git a/src/core/types/calls/call.cpp b/src/core/types/calls/call.cpp index 4c8ff84bf..4415e18c9 100644 --- a/src/core/types/calls/call.cpp +++ b/src/core/types/calls/call.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "call.hpp" @@ -7,6 +7,10 @@ namespace octopus { +Call::Call(Phred quality) +: quality_ {quality} +{} + std::unique_ptr Call::clone() const { return do_clone(); diff --git a/src/core/types/calls/call.hpp b/src/core/types/calls/call.hpp index 260ba0ab9..a62dbb1b0 100644 --- a/src/core/types/calls/call.hpp +++ b/src/core/types/calls/call.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef call_hpp diff --git a/src/core/types/calls/call_types.hpp b/src/core/types/calls/call_types.hpp index 701e9e8b4..ccc25a9f5 100644 --- a/src/core/types/calls/call_types.hpp +++ b/src/core/types/calls/call_types.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef call_types_hpp diff --git a/src/core/types/calls/call_wrapper.cpp b/src/core/types/calls/call_wrapper.cpp index 7cc5381b0..cacc0bcd9 100644 --- a/src/core/types/calls/call_wrapper.cpp +++ b/src/core/types/calls/call_wrapper.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "call_wrapper.hpp" diff --git a/src/core/types/calls/call_wrapper.hpp b/src/core/types/calls/call_wrapper.hpp index 3d53a4211..2b2c71bb7 100644 --- a/src/core/types/calls/call_wrapper.hpp +++ b/src/core/types/calls/call_wrapper.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef call_wrapper_hpp diff --git a/src/core/types/calls/denovo_call.cpp b/src/core/types/calls/denovo_call.cpp index 4de74cf7e..ea0e23dd9 100644 --- a/src/core/types/calls/denovo_call.cpp +++ b/src/core/types/calls/denovo_call.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "denovo_call.hpp" diff --git a/src/core/types/calls/denovo_call.hpp b/src/core/types/calls/denovo_call.hpp index cca883a7e..22ed5bbe3 100644 --- a/src/core/types/calls/denovo_call.hpp +++ b/src/core/types/calls/denovo_call.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef denovo_call_hpp diff --git a/src/core/types/calls/denovo_reference_reversion_call.cpp b/src/core/types/calls/denovo_reference_reversion_call.cpp index a7412a4ef..6f69b0c9f 100644 --- a/src/core/types/calls/denovo_reference_reversion_call.cpp +++ b/src/core/types/calls/denovo_reference_reversion_call.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "denovo_reference_reversion_call.hpp" diff --git a/src/core/types/calls/denovo_reference_reversion_call.hpp b/src/core/types/calls/denovo_reference_reversion_call.hpp index 49498848b..92303686c 100644 --- a/src/core/types/calls/denovo_reference_reversion_call.hpp +++ b/src/core/types/calls/denovo_reference_reversion_call.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef denovo_reference_reversion_call_hpp diff --git a/src/core/types/calls/germline_variant_call.cpp b/src/core/types/calls/germline_variant_call.cpp index ebce505a5..42222212f 100644 --- a/src/core/types/calls/germline_variant_call.cpp +++ b/src/core/types/calls/germline_variant_call.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "germline_variant_call.hpp" diff --git a/src/core/types/calls/germline_variant_call.hpp b/src/core/types/calls/germline_variant_call.hpp index ccebca92b..7787e7ec8 100644 --- a/src/core/types/calls/germline_variant_call.hpp +++ b/src/core/types/calls/germline_variant_call.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef germline_variant_call_hpp diff --git a/src/core/types/calls/reference_call.cpp b/src/core/types/calls/reference_call.cpp index e6ec897be..1c5803f32 100644 --- a/src/core/types/calls/reference_call.cpp +++ b/src/core/types/calls/reference_call.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "reference_call.hpp" diff --git a/src/core/types/calls/reference_call.hpp b/src/core/types/calls/reference_call.hpp index 5d302d4da..11c862daf 100644 --- a/src/core/types/calls/reference_call.hpp +++ b/src/core/types/calls/reference_call.hpp @@ -1,9 +1,10 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef reference_call_hpp #define reference_call_hpp +#include #include #include "call.hpp" @@ -15,8 +16,14 @@ class ReferenceCall : public Call public: ReferenceCall() = delete; + struct GenotypeCall + { + unsigned ploidy; + Phred posterior; + }; + template - ReferenceCall(A&& reference, Phred quality); + ReferenceCall(A&& reference, Phred quality, std::map genotypes); ReferenceCall(const ReferenceCall&) = default; ReferenceCall& operator=(const ReferenceCall&) = default; @@ -44,10 +51,17 @@ class ReferenceCall : public Call }; template -ReferenceCall::ReferenceCall(A&& reference, Phred quality) +ReferenceCall::ReferenceCall(A&& reference, Phred quality, std::map genotypes) : Call {quality} , reference_ {std::forward(reference)} -{} +{ + genotype_calls_.reserve(genotypes.size()); + for (const auto& p : genotypes) { + genotype_calls_.emplace(std::piecewise_construct, + std::forward_as_tuple(p.first), + std::forward_as_tuple(Genotype {p.second.ploidy, reference_}, p.second.posterior)); + } +} } // namespace octopus diff --git a/src/core/types/calls/somatic_call.cpp 
b/src/core/types/calls/somatic_call.cpp index c765488cc..fb059c124 100644 --- a/src/core/types/calls/somatic_call.cpp +++ b/src/core/types/calls/somatic_call.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "somatic_call.hpp" @@ -10,17 +10,13 @@ namespace octopus { void SomaticCall::decorate(VcfRecord::Builder& record) const { record.set_somatic(); - record.add_format("SCR"); - + record.add_format("VAF_CR"); for (const auto& p : credible_regions_) { if (p.second.somatic) { using utils::to_string; - record.set_format(p.first, "SCR", { - to_string(p.second.somatic->first, 2), - to_string(p.second.somatic->second, 2) - }); + record.set_format(p.first, "VAF_CR", {to_string(p.second.somatic->first), to_string(p.second.somatic->second)}); } else { - record.set_format(p.first, "SCR", {"0", "0"}); + record.set_format_missing(p.first, "VAF_CR"); } } } diff --git a/src/core/types/calls/somatic_call.hpp b/src/core/types/calls/somatic_call.hpp index 499e57275..29b254bd1 100644 --- a/src/core/types/calls/somatic_call.hpp +++ b/src/core/types/calls/somatic_call.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef somatic_call_hpp @@ -76,7 +76,7 @@ SomaticCall::SomaticCall(V&& variant, if (p.second.somatic) { genotype_calls_.emplace(p.first, GenotypeCall {demote(genotype_call), genotype_posterior}); } else { - genotype_calls_.emplace(p.first, GenotypeCall {genotype_call.germline_genotype(), genotype_posterior}); + genotype_calls_.emplace(p.first, GenotypeCall {genotype_call.germline(), genotype_posterior}); } } } diff --git a/src/core/types/calls/variant_call.cpp b/src/core/types/calls/variant_call.cpp index 07635874a..f0e271e0f 100644 --- a/src/core/types/calls/variant_call.cpp +++ b/src/core/types/calls/variant_call.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "variant_call.hpp" diff --git a/src/core/types/calls/variant_call.hpp b/src/core/types/calls/variant_call.hpp index 176be3fe3..4d632d758 100644 --- a/src/core/types/calls/variant_call.hpp +++ b/src/core/types/calls/variant_call.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef variant_call_hpp diff --git a/src/core/types/cancer_genotype.cpp b/src/core/types/cancer_genotype.cpp index 32b5dcf33..822b5d263 100644 --- a/src/core/types/cancer_genotype.cpp +++ b/src/core/types/cancer_genotype.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "cancer_genotype.hpp" @@ -10,12 +10,12 @@ namespace octopus { bool contains(const CancerGenotype& genotype, const Allele& allele) { - return contains(genotype.germline_genotype(), allele) || genotype.somatic_element().contains(allele); + return contains(genotype.germline(), allele) || contains(genotype.somatic(), allele); } bool includes(const CancerGenotype& genotype, const Allele& allele) { - return includes(genotype.germline_genotype(), allele) || genotype.somatic_element().includes(allele); + return includes(genotype.germline(), allele) || includes(genotype.somatic(), allele); } namespace { @@ -28,19 +28,58 @@ auto make_all_shared(const std::vector& elements) return result; } +bool is_proper_somatic_genotype(const Genotype& genotype) +{ + return genotype.zygosity() == genotype.ploidy(); +} + +std::vector> generate_somatic_genotypes(const std::vector& haplotypes, + const unsigned ploidy) +{ + auto result = generate_all_genotypes(haplotypes, ploidy); + if (ploidy > 1) { + auto itr = std::remove_if(std::begin(result), std::end(result), is_proper_somatic_genotype); + result.erase(itr, std::end(result)); + } + return result; +} + +std::vector> generate_somatic_genotypes(const std::vector& haplotypes, + const unsigned ploidy, + std::vector& germline_genotype_indices) +{ + std::vector all_genotype_indices {}; + auto all_genotypes = generate_all_genotypes(haplotypes, ploidy, all_genotype_indices); + if (ploidy > 1) { + std::vector> result {}; + result.reserve(all_genotypes.size()); + for (std::size_t i {0}; i < all_genotypes.size(); ++i) { + if (is_proper_somatic_genotype(all_genotypes[i])) { + result.push_back(std::move(all_genotypes[i])); + germline_genotype_indices.push_back(std::move(all_genotype_indices[i])); + } + } + return result; + } else { + germline_genotype_indices = std::move(all_genotype_indices); + return all_genotypes; + } +} + } // namespace std::vector> generate_all_cancer_genotypes(const std::vector>& germline_genotypes, - const std::vector& 
somatic_haplotypes) + const std::vector& somatic_haplotypes, + const unsigned somatic_ploidy, const bool allow_shared) { - const auto haplotype_ptrs = make_all_shared(somatic_haplotypes); + const auto somatic_genotypes = generate_somatic_genotypes(somatic_haplotypes, somatic_ploidy); std::vector> result {}; - result.reserve(germline_genotypes.size() * somatic_haplotypes.size()); - for (const auto& genotype : germline_genotypes) { - for (const auto& haplotype : haplotype_ptrs) { - if (!contains(genotype, *haplotype)) { - result.emplace_back(genotype, haplotype); + result.reserve(germline_genotypes.size() * somatic_genotypes.size()); + for (const auto& germline : germline_genotypes) { + for (const auto& somatic : somatic_genotypes) { + if (allow_shared || !have_shared(germline, somatic)) { + result.emplace_back(germline, somatic); } } } @@ -49,21 +88,23 @@ generate_all_cancer_genotypes(const std::vector>& germline_g std::vector> generate_all_cancer_genotypes(const std::vector>& germline_genotypes, - const std::vector>& germline_genotype_indices, + const std::vector& germline_genotype_indices, const std::vector& somatic_haplotypes, - std::vector, unsigned>>& cancer_genotype_indices) + std::vector& cancer_genotype_indices, + const unsigned somatic_ploidy, const bool allow_shared) { assert(germline_genotypes.size() == germline_genotype_indices.size()); - const auto haplotype_ptrs = make_all_shared(somatic_haplotypes); + std::vector somatic_genotype_indices {}; + const auto somatic_genotypes = generate_somatic_genotypes(somatic_haplotypes, somatic_ploidy, somatic_genotype_indices); std::vector> result {}; - const auto max_cancer_genotypes = germline_genotypes.size() * somatic_haplotypes.size(); + const auto max_cancer_genotypes = germline_genotypes.size() * somatic_genotypes.size(); result.reserve(max_cancer_genotypes); cancer_genotype_indices.reserve(max_cancer_genotypes); for (std::size_t g {0}; g < germline_genotypes.size(); ++g) { - for (unsigned h {0}; h < 
somatic_haplotypes.size(); ++h) { - if (!contains(germline_genotypes[g], *haplotype_ptrs[h])) { - result.emplace_back(germline_genotypes[g], haplotype_ptrs[h]); - cancer_genotype_indices.emplace_back(germline_genotype_indices[g], h); + for (std::size_t h {0}; h < somatic_genotypes.size(); ++h) { + if (allow_shared || !have_shared(germline_genotypes[g], somatic_genotypes[h])) { + result.emplace_back(germline_genotypes[g], somatic_genotypes[h]); + cancer_genotype_indices.push_back({germline_genotype_indices[g], somatic_genotype_indices[h]}); } } } diff --git a/src/core/types/cancer_genotype.hpp b/src/core/types/cancer_genotype.hpp index 4b991ca77..0c662f12e 100644 --- a/src/core/types/cancer_genotype.hpp +++ b/src/core/types/cancer_genotype.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef cancer_genotype_hpp @@ -8,10 +8,13 @@ #include #include #include +#include +#include #include #include "concepts/equitable.hpp" +#include "utils/append.hpp" #include "genotype.hpp" namespace octopus { @@ -25,16 +28,16 @@ class CancerGenotype CancerGenotype() = default; - CancerGenotype(std::initializer_list normal_elements, - const MappableType& somatic_element); - CancerGenotype(std::initializer_list normal_elements, - MappableType&& somatic_element); - + CancerGenotype(std::initializer_list germline, + std::initializer_list somatic); + CancerGenotype(std::initializer_list germline, + const MappableType& somatic); + CancerGenotype(std::initializer_list germline, + MappableType&& somatic); + template + CancerGenotype(G&& germline, S&& somatic); template - CancerGenotype(G&& germline_genotype, const std::shared_ptr& somatic_element); - - template - CancerGenotype(G&& germline_genotype, C&& somatic_element); + CancerGenotype(G&& germline, const std::shared_ptr& somatic); CancerGenotype(const CancerGenotype&) = default; CancerGenotype& 
operator=(const CancerGenotype&) = default; @@ -47,116 +50,133 @@ class CancerGenotype const MappableType& operator[](unsigned n) const; - const Genotype& germline_genotype() const; - - const MappableType& somatic_element() const; + const Genotype& germline() const; + const Genotype& somatic() const; + unsigned germline_ploidy() const noexcept; + unsigned somatic_ploidy() const noexcept; unsigned ploidy() const noexcept; - bool contains(const MappableType& element) const; - unsigned count(const MappableType& element) const; - bool is_homozygous() const; - unsigned zygosity() const; std::vector copy_unique() const; private: - Genotype germline_genotype_; - std::shared_ptr somatic_element_; + Genotype germline_, somatic_; }; template -CancerGenotype::CancerGenotype(std::initializer_list germline_elements, - const MappableType& somatic_element) -: germline_genotype_ {germline_elements} -, somatic_element_ {std::make_shared(somatic_element)} +CancerGenotype::CancerGenotype(std::initializer_list germline, std::initializer_list somatic) +: germline_ {germline} +, somatic_ {somatic} {} template -CancerGenotype::CancerGenotype(std::initializer_list germline_elements, - MappableType&& somatic_element) -: germline_genotype_ {germline_elements} -, somatic_element_ {std::make_shared(std::move(somatic_element))} +CancerGenotype::CancerGenotype(std::initializer_list germline, const MappableType& somatic) +: germline_ {germline} +, somatic_ {std::make_shared(somatic)} +{} + +template +CancerGenotype::CancerGenotype(std::initializer_list germline, MappableType&& somatic) +: germline_ {germline} +, somatic_ {std::make_shared(std::move(somatic))} {} template template -CancerGenotype::CancerGenotype(G&& germline_genotype, - const std::shared_ptr& somatic_element) -: germline_genotype_ {std::forward(germline_genotype)} -, somatic_element_ {somatic_element} +CancerGenotype::CancerGenotype(G&& germline, const std::shared_ptr& somatic) +: germline_ {std::forward(germline)} +, somatic_ 
{somatic} {} template -template -CancerGenotype::CancerGenotype(G&& germline_genotype, C&& somatic_element) -: germline_genotype_ {std::forward(germline_genotype)} -, somatic_element_ {std::make_shared(std::forward(somatic_element))} +template +CancerGenotype::CancerGenotype(G&& germline, S&& somatic) +: germline_ {std::forward(germline)} +, somatic_ {std::forward(somatic)} {} template const GenomicRegion& CancerGenotype::mapped_region() const noexcept { using octopus::mapped_region; - return mapped_region(*somatic_element_); + return mapped_region(germline_); } template const MappableType& CancerGenotype::operator[](unsigned n) const { - return (n < ploidy()) ? germline_genotype_[n] : *somatic_element_; + return (n < germline_ploidy()) ? germline_[n] : somatic_[n]; } template -const Genotype& CancerGenotype::germline_genotype() const +const Genotype& CancerGenotype::germline() const { - return germline_genotype_; + return germline_; } template -const MappableType& CancerGenotype::somatic_element() const +const Genotype& CancerGenotype::somatic() const { - return *somatic_element_; + return somatic_; +} + +template +unsigned CancerGenotype::germline_ploidy() const noexcept +{ + return germline_.ploidy(); +} + +template +unsigned CancerGenotype::somatic_ploidy() const noexcept +{ + return somatic_.ploidy(); } template unsigned CancerGenotype::ploidy() const noexcept { - return germline_genotype_.ploidy(); + return germline_ploidy() + somatic_ploidy(); } template bool CancerGenotype::contains(const MappableType& element) const { - return germline_genotype_.contains(element) || somatic_element_ == element; + return germline_.contains(element) || somatic_.contains(element); } template unsigned CancerGenotype::count(const MappableType& element) const { - return germline_genotype_.count(element) + ((*somatic_element_ == element) ? 
1 : 0); + return germline_.count(element) + somatic_.count(element); } template bool CancerGenotype::is_homozygous() const { - return germline_genotype_.is_homozygous() && somatic_element_ == germline_genotype_[0]; + return germline_.is_homozygous() && somatic_.count(germline_[0]) == somatic_.ploidy(); } template unsigned CancerGenotype::zygosity() const { - return germline_genotype_.zygosity() + ((germline_genotype_.contains(somatic_element_)) ? 0 : 1); + if (somatic_.ploidy() == 1) { + return germline_.zygosity() + ((germline_.contains(somatic_)) ? 0 : 1); + } else { + return copy_unique().size(); + } } template std::vector CancerGenotype::copy_unique() const { - auto result = germline_genotype_.get_unique(); - if (!germline_genotype_.contains(somatic_element_)) result.push_back(somatic_element_); + auto result = germline_.copy_unique(); + auto itr = utils::append(somatic_.copy_unique(), result); + std::inplace_merge(std::begin(result), itr, std::end(result)); + result.erase(std::unique(std::begin(result), std::end(result)), std::end(result)); return result; } @@ -168,10 +188,8 @@ bool includes(const CancerGenotype& genotype, const Allele& allele); template CancerGenotype copy(const CancerGenotype& genotype, const GenomicRegion& region) { - return CancerGenotype { - copy(genotype.germline_genotype(), region), - copy(genotype.somatic_element(), region) - }; + return CancerGenotype {copy(genotype.germline(), region), + copy(genotype.somatic(), region)}; } template @@ -180,37 +198,46 @@ bool contains(const CancerGenotype& lhs, const CancerGenotype(lhs, rhs.mapped_region()) == rhs; } +struct CancerGenotypeIndex +{ + GenotypeIndex germline, somatic; +}; + std::vector> generate_all_cancer_genotypes(const std::vector>& germline_genotypes, - const std::vector& somatic_haplotypes); + const std::vector& somatic_haplotypes, + unsigned somatic_ploidy = 1, bool allow_shared = false); std::vector> generate_all_cancer_genotypes(const std::vector>& germline_genotypes, - const 
std::vector>& germline_genotype_indices, + const std::vector& germline_genotype_indices, const std::vector& somatic_haplotypes, - std::vector, unsigned>>& cancer_genotype_indices); + std::vector& cancer_genotype_indices, + unsigned somatic_ploidy = 1, bool allow_shared = false); template Genotype demote(const CancerGenotype& genotype) { - Genotype result {genotype.ploidy() + 1}; - for (const auto& e : genotype.germline_genotype()) { + Genotype result {genotype.ploidy()}; + for (const auto& e : genotype.germline()) { + result.emplace(e); + } + for (const auto& e : genotype.somatic()) { result.emplace(e); } - result.emplace(genotype.somatic_element()); return result; } template bool operator==(const CancerGenotype& lhs, const CancerGenotype& rhs) { - return lhs.somatic_element() == rhs.somatic_element() && lhs.germline_genotype() == rhs.germline_genotype(); + return lhs.germline() == rhs.germline() && lhs.somatic() == rhs.somatic(); } template std::ostream& operator<<(std::ostream& os, const CancerGenotype& genotype) { - os << genotype.germline_genotype() << "," << genotype.somatic_element() << "(cancer)"; + os << genotype.germline() << "," << genotype.somatic() << "(cancer)"; return os; } @@ -221,8 +248,8 @@ struct CancerGenotypeHash { using boost::hash_combine; size_t result {}; - hash_combine(result, std::hash>()(genotype.germline_genotype())); - hash_combine(result, std::hash()(genotype.somatic_element())); + hash_combine(result, std::hash>()(genotype.germline())); + hash_combine(result, std::hash>()(genotype.somatic())); return result; } }; @@ -232,9 +259,9 @@ namespace debug { template void print_alleles(S&& stream, const CancerGenotype& genotype) { - print_alleles(stream, genotype.germline_genotype()); + print_alleles(stream, genotype.germline()); stream << " + "; - print_alleles(stream, genotype.somatic_element()); + print_alleles(stream, genotype.somatic()); } void print_alleles(const CancerGenotype& genotype); @@ -242,9 +269,9 @@ void print_alleles(const 
CancerGenotype& genotype); template void print_variant_alleles(S&& stream, const CancerGenotype& genotype) { - print_variant_alleles(stream, genotype.germline_genotype()); + print_variant_alleles(stream, genotype.germline()); stream << " + "; - print_variant_alleles(stream, genotype.somatic_element()); + print_variant_alleles(stream, genotype.somatic()); } void print_variant_alleles(const CancerGenotype& genotype); diff --git a/src/core/types/genotype.cpp b/src/core/types/genotype.cpp index cbc91317b..ee30457b1 100644 --- a/src/core/types/genotype.cpp +++ b/src/core/types/genotype.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "genotype.hpp" @@ -7,6 +7,8 @@ #include +#include "utils/maths.hpp" + namespace octopus { Genotype::Genotype(const unsigned ploidy) @@ -71,19 +73,16 @@ bool Genotype::is_homozygous() const unsigned Genotype::zygosity() const { unsigned result {0}; - for (auto it = std::cbegin(haplotypes_), last = std::cend(haplotypes_); it != last; ++result) { // naive algorithm faster in practice than binary searching it = std::find_if_not(std::next(it), last, [it] (const auto& x) { return *x == **it; }); } - return result; } bool Genotype::contains(const Haplotype& haplotype) const { - return std::binary_search(std::cbegin(haplotypes_), std::cend(haplotypes_), haplotype, - HaplotypePtrLess {}); + return std::binary_search(std::cbegin(haplotypes_), std::cend(haplotypes_), haplotype, HaplotypePtrLess {}); } unsigned Genotype::count(const Haplotype& haplotype) const @@ -97,16 +96,11 @@ std::vector Genotype::copy_unique() const { std::vector> ptr_copy {}; ptr_copy.reserve(ploidy()); - - std::unique_copy(std::cbegin(haplotypes_), std::cend(haplotypes_), std::back_inserter(ptr_copy), - HaplotypePtrEqual {}); - + std::unique_copy(std::cbegin(haplotypes_), std::cend(haplotypes_), std::back_inserter(ptr_copy), 
HaplotypePtrEqual {}); std::vector result {}; result.reserve(ptr_copy.size()); - std::transform(std::cbegin(ptr_copy), std::cend(ptr_copy), std::back_inserter(result), [] (const auto& ptr) { return *ptr.get(); }); - return result; } @@ -114,47 +108,50 @@ std::vector> Genotype::copy_u { std::vector> result {}; result.reserve(ploidy()); - std::transform(std::cbegin(haplotypes_), std::cend(haplotypes_), std::back_inserter(result), [] (const HaplotypePtr& haplotype) { return std::cref(*haplotype); }); - result.erase(std::unique(std::begin(result), std::end(result)), std::end(result)); - return result; } -bool Genotype::HaplotypePtrLess::operator()(const HaplotypePtr& lhs, - const HaplotypePtr& rhs) const +std::vector Genotype::unique_counts() const +{ + std::vector result {}; + result.reserve(haplotypes_.size()); + for (auto itr = std::cbegin(haplotypes_), last = std::cend(haplotypes_); itr != last;) { + auto next = std::find_if_not(std::next(itr), last, [itr] (const auto& x) { return *x == **itr; }); + result.push_back(std::distance(itr, next)); + itr = next; + } + return result; +} + +bool Genotype::HaplotypePtrLess::operator()(const HaplotypePtr& lhs, const HaplotypePtr& rhs) const { return *lhs < *rhs; } -bool Genotype::HaplotypePtrLess::operator()(const Haplotype& lhs, - const HaplotypePtr& rhs) const +bool Genotype::HaplotypePtrLess::operator()(const Haplotype& lhs, const HaplotypePtr& rhs) const { return lhs < *rhs; } -bool Genotype::HaplotypePtrLess::operator()(const HaplotypePtr& lhs, - const Haplotype& rhs) const +bool Genotype::HaplotypePtrLess::operator()(const HaplotypePtr& lhs, const Haplotype& rhs) const { return *lhs < rhs; } -bool Genotype::HaplotypePtrEqual::operator()(const HaplotypePtr& lhs, - const HaplotypePtr& rhs) const +bool Genotype::HaplotypePtrEqual::operator()(const HaplotypePtr& lhs, const HaplotypePtr& rhs) const { return *lhs == *rhs; } -bool Genotype::HaplotypePtrEqual::operator()(const Haplotype& lhs, - const HaplotypePtr& rhs) const 
+bool Genotype::HaplotypePtrEqual::operator()(const Haplotype& lhs, const HaplotypePtr& rhs) const { return lhs == *rhs; } -bool Genotype::HaplotypePtrEqual::operator()(const HaplotypePtr& lhs, - const Haplotype& rhs) const +bool Genotype::HaplotypePtrEqual::operator()(const HaplotypePtr& lhs, const Haplotype& rhs) const { return *lhs == rhs; } @@ -207,8 +204,22 @@ bool is_homozygous(const Genotype& genotype, const Allele& allele) std::size_t num_genotypes(const unsigned num_elements, const unsigned ploidy) { - return static_cast(boost::math::binomial_coefficient(num_elements + ploidy - 1, - num_elements - 1)); + return boost::math::binomial_coefficient(num_elements + ploidy - 1, num_elements - 1); +} + +std::size_t max_num_elements(const std::size_t num_genotypes, const unsigned ploidy) +{ + if (num_genotypes == 0 || ploidy == 0) return 0; + auto y = maths::factorial(ploidy); + if (y >= num_genotypes) return 1; + const auto t = num_genotypes * y; + unsigned j {1}; + for (; j < num_genotypes; ++j) { + y /= j; + y *= j + ploidy; + if (y >= t) break; + } + return j + 1; } std::size_t element_cardinality_in_genotypes(const unsigned num_elements, const unsigned ploidy) @@ -223,14 +234,17 @@ generate_all_genotypes(const std::vector>& haplotypes } namespace debug { - void print_alleles(const Genotype& genotype) - { - print_alleles(std::cout, genotype); - } - void print_variant_alleles(const Genotype& genotype) - { - print_variant_alleles(std::cout, genotype); - } +void print_alleles(const Genotype& genotype) +{ + print_alleles(std::cout, genotype); +} + +void print_variant_alleles(const Genotype& genotype) +{ + print_variant_alleles(std::cout, genotype); +} + } // namespace debug + } // namespace octopus diff --git a/src/core/types/genotype.hpp b/src/core/types/genotype.hpp index af96f7cb8..1ea0e0312 100644 --- a/src/core/types/genotype.hpp +++ b/src/core/types/genotype.hpp @@ -1,10 +1,11 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // 
Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef genotype_hpp #define genotype_hpp #include +#include #include #include #include @@ -122,6 +123,7 @@ class Genotype : public Equitable>, public Mappab std::vector copy_unique() const; std::vector> copy_unique_ref() const; + std::vector unique_counts() const; private: using HaplotypePtr = std::shared_ptr; @@ -203,7 +205,7 @@ const MappableType& Genotype::operator[](const unsigned n) const template unsigned Genotype::ploidy() const noexcept { - return static_cast(elements_.size()); + return elements_.size(); } template @@ -221,7 +223,7 @@ unsigned Genotype::zygosity() const } else if (ploidy() == 2) { return 2; } - return static_cast(copy_unique().size()); + return copy_unique().size(); } template @@ -233,7 +235,7 @@ bool Genotype::contains(const MappableType& element) const template unsigned Genotype::count(const MappableType& element) const { - return static_cast(std::count(std::cbegin(elements_), std::cend(elements_), element)); + return std::count(std::cbegin(elements_), std::cend(elements_), element); } template @@ -559,8 +561,38 @@ struct GenotypeHash }; std::size_t num_genotypes(unsigned num_elements, unsigned ploidy); +std::size_t max_num_elements(std::size_t num_genotypes, unsigned ploidy); std::size_t element_cardinality_in_genotypes(unsigned num_elements, unsigned ploidy); +template +unsigned count_shared(const Genotype& lhs, const Genotype& rhs) +{ + if (lhs.ploidy() <= rhs.ploidy()) { + if (lhs.ploidy() < 2 || (lhs.ploidy() == 2 && !lhs.is_homozygous())) { + return std::count_if(std::cbegin(lhs), std::cend(lhs), + [&rhs] (const auto& element) { return rhs.contains(element); }); + } else { + const auto& lhs_unique = lhs.copy_unique_ref(); + return std::count_if(std::cbegin(lhs_unique), std::cend(lhs_unique), + [&rhs] (const auto& element) { return rhs.contains(element); }); + } + } else { + return count_shared(rhs, lhs); + } +} + +template +bool 
have_shared(const Genotype& lhs, const Genotype& rhs) +{ + if (lhs.ploidy() <= rhs.ploidy()) { + return std::any_of(std::cbegin(lhs), std::cend(lhs), [&rhs] (const auto& element) { return rhs.contains(element); }); + } else { + return have_shared(rhs, lhs); + } +} + +using GenotypeIndex = std::vector; + namespace detail { namespace { @@ -654,7 +686,7 @@ auto generate_all_triploid_biallelic_genotypes(const Container& elements) } template -auto generate_genotype(const Container& elements, const std::vector& element_indicies) +auto generate_genotype(const Container& elements, const GenotypeIndex& element_indicies) { GenotypeType result{static_cast(element_indicies.size())}; for (const auto i : element_indicies) { @@ -716,7 +748,7 @@ auto do_generate_all_genotypes(const Container& elements, const unsigned ploidy) template auto do_generate_all_genotypes(const Container& elements, const unsigned ploidy, - std::vector>& indices) + std::vector& indices) { using GenotypeTp = GenotypeType; using ResultType = std::vector; @@ -728,7 +760,7 @@ auto do_generate_all_genotypes(const Container& elements, const unsigned ploidy, const auto result_size = num_genotypes(num_elements, ploidy); result.reserve(result_size); indices.reserve(result_size); - std::vector element_indicies(ploidy, 0); + GenotypeIndex element_indicies(ploidy, 0); while (true) { if (element_indicies[0] == num_elements) { unsigned i {0}; @@ -744,6 +776,54 @@ auto do_generate_all_genotypes(const Container& elements, const unsigned ploidy, return result; } +template +auto do_generate_all_genotypes(const Container& elements, const unsigned ploidy, + UnaryPredicate pred, OutputIterator result_itr) +{ + if (ploidy == 0 || elements.empty()) return result_itr; + const auto num_elements = static_cast(elements.size()); + std::vector element_indicies(ploidy, 0); + while (true) { + if (element_indicies[0] == num_elements) { + unsigned i {0}; + while (++i < ploidy && element_indicies[i] == num_elements - 1); + if (i == ploidy) 
break; + ++element_indicies[i]; + std::fill_n(std::begin(element_indicies), i + 1, element_indicies[i]); + } + auto genotype = detail::generate_genotype(elements, element_indicies); + if (pred(genotype)) *result_itr++ = std::move(genotype); + ++element_indicies[0]; + } + return result_itr; +} + +template +auto do_generate_all_genotypes(const Container& elements, const unsigned ploidy, + UnaryPredicate pred, OutputIterator result_itr, + std::vector& indices) +{ + if (ploidy == 0 || elements.empty()) return result_itr; + const auto num_elements = static_cast(elements.size()); + std::vector element_indicies(ploidy, 0); + while (true) { + if (element_indicies[0] == num_elements) { + unsigned i {0}; + while (++i < ploidy && element_indicies[i] == num_elements - 1); + if (i == ploidy) break; + ++element_indicies[i]; + std::fill_n(std::begin(element_indicies), i + 1, element_indicies[i]); + } + auto genotype = detail::generate_genotype(elements, element_indicies); + if (pred(genotype)) { + *result_itr++ = std::move(genotype); + indices.push_back(element_indicies); + } + ++element_indicies[0]; + } + return result_itr; +} + template struct RequiresSharedMemory : public std::is_same {}; @@ -753,9 +833,7 @@ auto generate_all_genotypes(const std::vector& elements, const uns { std::vector> temp_pointers(elements.size()); std::transform(std::cbegin(elements), std::cend(elements), std::begin(temp_pointers), - [] (const auto& element) { - return std::make_shared(element); - }); + [] (const auto& element) { return std::make_shared(element); }); return do_generate_all_genotypes(temp_pointers, ploidy); } @@ -766,21 +844,17 @@ auto generate_all_genotypes(const std::vector> temp_pointers(elements.size()); std::transform(std::cbegin(elements), std::cend(elements), std::begin(temp_pointers), - [] (const auto& element) { - return std::make_shared(element.get()); - }); + [] (const auto& element) { return std::make_shared(element.get()); }); return do_generate_all_genotypes(temp_pointers, 
ploidy); } template auto generate_all_genotypes(const std::vector& elements, const unsigned ploidy, - std::vector>& indices, std::true_type) + std::vector& indices, std::true_type) { std::vector> temp_pointers(elements.size()); std::transform(std::cbegin(elements), std::cend(elements), std::begin(temp_pointers), - [] (const auto& element) { - return std::make_shared(element); - }); + [] (const auto& element) { return std::make_shared(element); }); return do_generate_all_genotypes(temp_pointers, ploidy, indices); } @@ -793,11 +867,53 @@ auto generate_all_genotypes(const std::vector& elements, const uns template auto generate_all_genotypes(const std::vector& elements, const unsigned ploidy, - std::vector>& indices, std::false_type) + std::vector& indices, std::false_type) { return do_generate_all_genotypes(elements, ploidy, indices); } +template +OutputIterator +generate_all_genotypes(const std::vector& elements, const unsigned ploidy, + UnaryPredicate selector, OutputIterator result_itr, + std::true_type) +{ + std::vector> temp_pointers(elements.size()); + std::transform(std::cbegin(elements), std::cend(elements), std::begin(temp_pointers), + [] (const auto& element) { return std::make_shared(element); }); + return do_generate_all_genotypes(temp_pointers, ploidy, selector, result_itr); +} + +template +OutputIterator +generate_all_genotypes(const std::vector& elements, const unsigned ploidy, + UnaryPredicate selector, OutputIterator result_itr, std::vector& indices, + std::true_type) +{ + std::vector> temp_pointers(elements.size()); + std::transform(std::cbegin(elements), std::cend(elements), std::begin(temp_pointers), + [] (const auto& element) { return std::make_shared(element); }); + return do_generate_all_genotypes(temp_pointers, ploidy, selector, result_itr, indices); +} + +template +OutputIterator +generate_all_genotypes(const std::vector& elements, const unsigned ploidy, + UnaryPredicate selector, OutputIterator result_itr, + std::false_type) +{ + return 
do_generate_all_genotypes(elements, ploidy, selector, result_itr); +} + +template +OutputIterator +generate_all_genotypes(const std::vector& elements, const unsigned ploidy, + UnaryPredicate selector, OutputIterator result_itr, std::vector& indices, + std::false_type) +{ + return do_generate_all_genotypes(elements, ploidy, selector, result_itr, indices); +} + } // namespace detail template @@ -810,11 +926,27 @@ generate_all_genotypes(const std::vector& elements, const unsigned template std::vector> generate_all_genotypes(const std::vector& elements, const unsigned ploidy, - std::vector>& indices) + std::vector& indices) { return detail::generate_all_genotypes(elements, ploidy, indices, detail::RequiresSharedMemory {}); } +template +OutputIterator +generate_all_genotypes(const std::vector& elements, const unsigned ploidy, + UnaryPredicate selector, OutputIterator result_itr) +{ + return detail::generate_all_genotypes(elements, ploidy, selector, result_itr, detail::RequiresSharedMemory {}); +} + +template +OutputIterator +generate_all_genotypes(const std::vector& elements, const unsigned ploidy, + UnaryPredicate selector, OutputIterator result_itr, std::vector& indices) +{ + return detail::generate_all_genotypes(elements, ploidy, selector, result_itr, indices, detail::RequiresSharedMemory {}); +} + template std::vector> generate_all_genotypes(const std::vector>& elements, @@ -826,6 +958,29 @@ generate_all_genotypes(const std::vector> generate_all_genotypes(const std::vector>& haplotypes, unsigned ploidy); +template +std::vector> +generate_all_full_rank_genotypes(const std::vector& elements, const unsigned ploidy) +{ + if (elements.size() < ploidy) return {}; + std::deque> tmp {}; + generate_all_genotypes(elements, ploidy, [ploidy] (const auto& genotype) { return genotype.zygosity() == ploidy; }, + std::back_inserter(tmp)); + return {std::make_move_iterator(std::begin(tmp)), std::make_move_iterator(std::end(tmp))}; +} + +template +std::vector> 
+generate_all_full_rank_genotypes(const std::vector& elements, const unsigned ploidy, + std::vector& indices) +{ + if (elements.size() < ploidy) return {}; + std::deque> tmp {}; + generate_all_genotypes(elements, ploidy, [ploidy] (const auto& genotype) { return genotype.zygosity() == ploidy; }, + std::back_inserter(tmp), indices); + return {std::make_move_iterator(std::begin(tmp)), std::make_move_iterator(std::end(tmp))}; +} + namespace detail { inline std::size_t estimate_num_elements(const std::size_t num_genotypes) diff --git a/src/core/types/haplotype.cpp b/src/core/types/haplotype.cpp index 5700e8e7d..fa5c61400 100644 --- a/src/core/types/haplotype.cpp +++ b/src/core/types/haplotype.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "haplotype.hpp" @@ -37,18 +37,19 @@ const GenomicRegion& Haplotype::mapped_region() const } namespace { - template - BidirIt binary_find(BidirIt first, BidirIt last, const T& value) - { - const auto it = std::lower_bound(first, last, value); - return (it != last && *it == value) ? it : last; - } + +template +BidirIt binary_find(BidirIt first, BidirIt last, const T& value) +{ + const auto itr = std::lower_bound(first, last, value); + return (itr != last && *itr == value) ? 
itr : last; } +} // namespace + bool Haplotype::contains(const ContigAllele& allele) const { using octopus::contains; using std::cbegin; using std::cend; - if (contains(region_.contig_region(), allele)) { if (begins_before(allele, explicit_allele_region_)) { if (is_before(allele, explicit_allele_region_)) { @@ -68,10 +69,10 @@ bool Haplotype::contains(const ContigAllele& allele) const return false; } } - const auto it = binary_find(cbegin(explicit_alleles_), cend(explicit_alleles_), allele.mapped_region()); - if (it != cend(explicit_alleles_)) { - if (*it == allele) return true; - if (is_same_region(*it, allele)) { + const auto match_itr = binary_find(cbegin(explicit_alleles_), cend(explicit_alleles_), allele.mapped_region()); + if (match_itr != cend(explicit_alleles_)) { + if (*match_itr == allele) return true; + if (is_same_region(*match_itr, allele)) { // If the allele is not explcitly contained but the region is then it must be a different // allele, unless it is an insertion, in which case we must check the sequence if (is_insertion(allele)) { @@ -87,7 +88,6 @@ bool Haplotype::contains(const ContigAllele& allele) const } return sequence(allele.mapped_region()) == allele.sequence(); } - return false; } @@ -100,18 +100,24 @@ bool Haplotype::contains(const Allele& allele) const bool Haplotype::includes(const ContigAllele& allele) const { using octopus::contains; - const auto& this_region = region_.contig_region(); - if (!contains(this_region, allele)) { - return false; - } - if (contains(explicit_allele_region_, allele)) { - return std::binary_search(std::cbegin(explicit_alleles_), std::cend(explicit_alleles_), allele); - } - if (overlaps(explicit_allele_region_, allele) || is_indel(allele)) { + if (!contains(region_.contig_region(), allele)) { return false; + } else if (!explicit_alleles_.empty()) { + if (contains(explicit_allele_region_, allele)) { + return std::binary_search(std::cbegin(explicit_alleles_), std::cend(explicit_alleles_), allele); + } else if 
(overlaps(explicit_allele_region_, allele)) { + return false; + } else if (is_after(allele, explicit_allele_region_)) { + if (is_indel(allele)) return false; + const auto ref_ritr = std::next(std::crbegin(sequence_), end_distance(allele, region_.contig_region())); + assert(static_cast(std::distance(ref_ritr, std::crend(sequence_))) >= allele.sequence().size()); + return std::equal(std::crbegin(allele.sequence()), std::crend(allele.sequence()), ref_ritr); + } } - return std::equal(std::cbegin(allele.sequence()), std::cend(allele.sequence()), - std::next(std::cbegin(sequence_), begin_distance(this_region, allele))); + if (is_indel(allele)) return false; + const auto ref_itr = std::next(std::cbegin(sequence_), begin_distance(region_.contig_region(), allele)); + assert(static_cast(std::distance(ref_itr, std::cend(sequence_))) >= allele.sequence().size()); + return std::equal(std::cbegin(allele.sequence()), std::cend(allele.sequence()), ref_itr); } bool Haplotype::includes(const Allele& allele) const @@ -130,9 +136,9 @@ bool is_in_reference_flank(const ContigRegion& region, const ContigRegion& expli return true; } if (begins_before(region, explicit_allele_region_)) { - return !is_insertion(explicit_alleles.front()); + return !is_simple_insertion(explicit_alleles.front()); } - return !is_insertion(explicit_alleles.back()); + return !is_simple_insertion(explicit_alleles.back()); } Haplotype::NucleotideSequence Haplotype::sequence(const ContigRegion& region) const @@ -246,37 +252,17 @@ CigarString Haplotype::cigar() const } auto allele_op_flag = curr_op_flag; CigarOperation::Size allele_op_size {0}; - if (is_insertion(allele)) { - if (is_empty_region(allele)) { + if (is_indel(allele)) { + if (is_simple_insertion(allele)) { allele_op_flag = Flag::insertion; allele_op_size += allele.sequence().size(); - } else { - const auto insertion_size = allele.sequence().size() - region_size(allele); - if (curr_op_flag == Flag::insertion) { - curr_op_size += insertion_size; - 
result.emplace_back(curr_op_size, curr_op_flag); - curr_op_flag = Flag::deletion; - curr_op_size = region_size(allele); - } else if (curr_op_flag == Flag::deletion) { - curr_op_size += region_size(allele); - result.emplace_back(curr_op_size, curr_op_flag); - curr_op_flag = Flag::insertion; - curr_op_size = insertion_size; - } else { - result.emplace_back(curr_op_size, curr_op_flag); - result.emplace_back(region_size(allele), Flag::deletion); - curr_op_flag = Flag::insertion; - curr_op_size = insertion_size; - } - } - } else if (is_deletion(allele)) { - if (is_sequence_empty(allele)) { + } else if (is_simple_deletion(allele)) { allele_op_flag = Flag::deletion; allele_op_size += region_size(allele); } else { - const auto deletion_size = region_size(allele) - allele.sequence().size(); + // all complex indels are treated as replacements if (curr_op_flag == Flag::deletion) { - curr_op_size += deletion_size; + curr_op_size += region_size(allele); result.emplace_back(curr_op_size, curr_op_flag); curr_op_flag = Flag::insertion; curr_op_size = allele.sequence().size(); @@ -284,12 +270,14 @@ CigarString Haplotype::cigar() const curr_op_size += allele.sequence().size(); result.emplace_back(curr_op_size, curr_op_flag); curr_op_flag = Flag::deletion; - curr_op_size = deletion_size; + curr_op_size = region_size(allele); } else { - result.emplace_back(curr_op_size, curr_op_flag); - result.emplace_back(allele.sequence().size(), Flag::insertion); - curr_op_flag = Flag::deletion; - curr_op_size = deletion_size; + if (curr_op_size > 0) { + result.emplace_back(curr_op_size, curr_op_flag); + } + result.emplace_back(region_size(allele), Flag::deletion); + curr_op_flag = Flag::insertion; + curr_op_size = allele.sequence().size(); } } } else if (!is_empty_region(allele)) { @@ -665,7 +653,7 @@ bool have_same_alleles(const Haplotype& lhs, const Haplotype& rhs) IsLessComplex::IsLessComplex(boost::optional reference) : reference_ {std::move(reference)} {} -bool IsLessComplex::operator()(const 
Haplotype& lhs, const Haplotype& rhs) const noexcept +bool IsLessComplex::operator()(const Haplotype& lhs, const Haplotype& rhs) const { if (lhs.explicit_alleles_.size() != rhs.explicit_alleles_.size()) { return lhs.explicit_alleles_.size() < rhs.explicit_alleles_.size(); @@ -674,10 +662,8 @@ bool IsLessComplex::operator()(const Haplotype& lhs, const Haplotype& rhs) const return lhs.difference(*reference_).size() < rhs.difference(*reference_).size(); } // otherwise prefer the sequence with the least amount of indels - auto score = std::inner_product(std::cbegin(lhs.explicit_alleles_), - std::cend(lhs.explicit_alleles_), - std::cbegin(rhs.explicit_alleles_), 0, - std::plus {}, + auto score = std::inner_product(std::cbegin(lhs.explicit_alleles_), std::cend(lhs.explicit_alleles_), + std::cbegin(rhs.explicit_alleles_), 0, std::plus<> {}, [] (const auto& lhs, const auto& rhs) { if (lhs == rhs) { return 0; @@ -696,6 +682,17 @@ bool IsLessComplex::operator()(const Haplotype& lhs, const Haplotype& rhs) const return score >= 0; } +unsigned remove_duplicates(std::vector& haplotypes) +{ + return remove_duplicates(haplotypes, IsLessComplex {}); +} + +unsigned remove_duplicates(std::vector& haplotypes, Haplotype reference) +{ + IsLessComplex cmp {std::move(reference)}; + return remove_duplicates(haplotypes, cmp); +} + unsigned unique_least_complex(std::vector& haplotypes, boost::optional reference) { using std::begin; using std::end; diff --git a/src/core/types/haplotype.hpp b/src/core/types/haplotype.hpp index 22a319c06..885cfb774 100644 --- a/src/core/types/haplotype.hpp +++ b/src/core/types/haplotype.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef haplotype_hpp @@ -300,15 +300,43 @@ struct IsLessComplex { IsLessComplex() = default; explicit IsLessComplex(boost::optional reference); - bool operator()(const Haplotype& lhs, const Haplotype& rhs) const noexcept; + bool operator()(const Haplotype& lhs, const Haplotype& rhs) const; private: boost::optional reference_; }; -// Erases all duplicates haplotypes (w.r.t operator==) keeping the duplicate which is -// considered least complex w.r.t IsLessComplex. -// The optional Haplotype arguement can be used as a basis for the complexity comparison. -unsigned unique_least_complex(std::vector& haplotypes, boost::optional = boost::none); +// Removes all duplicates haplotypes (w.r.t operator==) keeping the duplicate which is considered least complex w.r.t cmp. +template +unsigned remove_duplicates(std::vector& haplotypes, const Cmp& cmp) +{ + using std::begin; using std::end; + std::sort(begin(haplotypes), end(haplotypes)); + auto first_dup_itr = begin(haplotypes); + const auto last_itr = end(haplotypes); + while (true) { + first_dup_itr = std::adjacent_find(first_dup_itr, last_itr); + if (first_dup_itr == last_itr) break; + auto dup_keep_itr = (cmp(*first_dup_itr, *std::next(first_dup_itr))) ? 
first_dup_itr : std::next(first_dup_itr); + auto last_dup_itr = std::next(first_dup_itr, 2); + for (; last_dup_itr != last_itr; ++last_dup_itr) { + if (*last_dup_itr != *first_dup_itr) { + break; + } + if (cmp(*last_dup_itr, *dup_keep_itr)) { + dup_keep_itr = last_dup_itr; + } + } + std::iter_swap(first_dup_itr, dup_keep_itr); + first_dup_itr = last_dup_itr; + } + const auto last_keep_itr = std::unique(begin(haplotypes), end(haplotypes)); + const auto result = std::distance(last_keep_itr, end(haplotypes)); + haplotypes.erase(last_keep_itr, last_itr); + return static_cast(result); +} + +unsigned remove_duplicates(std::vector& haplotypes); +unsigned remove_duplicates(std::vector& haplotypes, Haplotype reference); std::ostream& operator<<(std::ostream& os, const Haplotype& haplotype); diff --git a/src/core/types/variant.cpp b/src/core/types/variant.cpp index 9a28c1aad..78ce5b116 100644 --- a/src/core/types/variant.cpp +++ b/src/core/types/variant.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "variant.hpp" @@ -497,6 +497,12 @@ bool is_transversion(const Variant& variant) noexcept return is_snv(variant) && !is_transition(variant); } +Variant::NucleotideSequence::size_type indel_size(const Variant& variant) noexcept +{ + const auto p = std::minmax({ref_sequence_size(variant), alt_sequence_size(variant)}); + return p.second - p.first; +} + std::vector extract_alt_allele_sequences(const std::vector& variants) { std::vector result {}; diff --git a/src/core/types/variant.hpp b/src/core/types/variant.hpp index bdd997455..20b380764 100644 --- a/src/core/types/variant.hpp +++ b/src/core/types/variant.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef variant_hpp @@ -263,9 +263,9 @@ bool is_simple_insertion(const Variant& variant) noexcept; bool is_simple_deletion(const Variant& variant) noexcept; bool is_simple_indel(const Variant& variant) noexcept; bool are_same_type(const Variant& lhs, const Variant& rhs) noexcept; - bool is_transition(const Variant& variant) noexcept; bool is_transversion(const Variant& variant) noexcept; +Variant::NucleotideSequence::size_type indel_size(const Variant& variant) noexcept; std::vector extract_alt_allele_sequences(const std::vector& variants); diff --git a/src/exceptions/error.cpp b/src/exceptions/error.cpp index 4b4a96593..83bb5ed22 100644 --- a/src/exceptions/error.cpp +++ b/src/exceptions/error.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "error.hpp" diff --git a/src/exceptions/error.hpp b/src/exceptions/error.hpp index 7318e19bf..ca04e3fe1 100644 --- a/src/exceptions/error.hpp +++ b/src/exceptions/error.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef error_hpp diff --git a/src/exceptions/malformed_file_error.cpp b/src/exceptions/malformed_file_error.cpp index a15c719df..d7a78cf99 100644 --- a/src/exceptions/malformed_file_error.cpp +++ b/src/exceptions/malformed_file_error.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "malformed_file_error.hpp" diff --git a/src/exceptions/malformed_file_error.hpp b/src/exceptions/malformed_file_error.hpp index 44ad8df91..739159a0c 100644 --- a/src/exceptions/malformed_file_error.hpp +++ b/src/exceptions/malformed_file_error.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef malformed_file_error_hpp diff --git a/src/exceptions/missing_file_error.cpp b/src/exceptions/missing_file_error.cpp index c7167ba9e..2feebb66c 100644 --- a/src/exceptions/missing_file_error.cpp +++ b/src/exceptions/missing_file_error.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "missing_file_error.hpp" diff --git a/src/exceptions/missing_file_error.hpp b/src/exceptions/missing_file_error.hpp index f5e5e3585..ee9db146b 100644 --- a/src/exceptions/missing_file_error.hpp +++ b/src/exceptions/missing_file_error.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef missing_file_error_hpp diff --git a/src/exceptions/missing_index_error.cpp b/src/exceptions/missing_index_error.cpp index be4f0829d..f019adeea 100644 --- a/src/exceptions/missing_index_error.cpp +++ b/src/exceptions/missing_index_error.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "missing_index_error.hpp" diff --git a/src/exceptions/missing_index_error.hpp b/src/exceptions/missing_index_error.hpp index 41fac2505..8d06633f2 100644 --- a/src/exceptions/missing_index_error.hpp +++ b/src/exceptions/missing_index_error.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef missing_index_error_hpp diff --git a/src/exceptions/program_error.hpp b/src/exceptions/program_error.hpp index 46c43de75..078e7b2eb 100644 --- a/src/exceptions/program_error.hpp +++ b/src/exceptions/program_error.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include diff --git a/src/exceptions/system_error.hpp b/src/exceptions/system_error.hpp index 448b365c0..2f975c9d7 100644 --- a/src/exceptions/system_error.hpp +++ b/src/exceptions/system_error.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef system_error_hpp diff --git a/src/exceptions/unimplemented_feature_error.cpp b/src/exceptions/unimplemented_feature_error.cpp new file mode 100644 index 000000000..eeba27c8d --- /dev/null +++ b/src/exceptions/unimplemented_feature_error.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "unimplemented_feature_error.hpp" + +#include + +namespace octopus { + +UnimplementedFeatureError::UnimplementedFeatureError(std::string feature, std::string where) +: feature_ {std::move(feature)} +, where_ {std::move(where)} +{} + +std::string UnimplementedFeatureError::do_why() const +{ + return feature_ + " is not currently implemented"; +} + +std::string UnimplementedFeatureError::do_help() const +{ + return "submit a feature request"; +} + +std::string UnimplementedFeatureError::do_where() const +{ + return where_; +} + +} // namespace octopus diff --git a/src/exceptions/unimplemented_feature_error.hpp b/src/exceptions/unimplemented_feature_error.hpp new file mode 100644 index 000000000..37e25a268 --- /dev/null +++ b/src/exceptions/unimplemented_feature_error.hpp @@ -0,0 +1,31 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#ifndef unimplemented_feature_error_hpp +#define unimplemented_feature_error_hpp + +#include + +#include "program_error.hpp" + +namespace octopus { + +class UnimplementedFeatureError : public ProgramError +{ +public: + UnimplementedFeatureError() = delete; + UnimplementedFeatureError(std::string feature, std::string where); + + virtual ~UnimplementedFeatureError() override = default; + +private: + virtual std::string do_why() const override; + virtual std::string do_help() const override; + virtual std::string do_where() const override; + + std::string feature_, where_; +}; + +} // namespace octopus + +#endif diff --git a/src/exceptions/unwritable_file_error.cpp b/src/exceptions/unwritable_file_error.cpp index 4b4ef4ac8..de8b86c1a 100644 --- a/src/exceptions/unwritable_file_error.cpp +++ b/src/exceptions/unwritable_file_error.cpp @@ -1,10 +1,35 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "unwritable_file_error.hpp" +#include +#include + namespace octopus { +UnwritableFileError::UnwritableFileError(Path file) : file_ {std::move(file)} {} + +UnwritableFileError::UnwritableFileError(Path file, std::string type) +: file_ {std::move(file)} +, type_ {std::move(type)} +{} + +std::string UnwritableFileError::do_why() const +{ + std::ostringstream ss {}; + ss << "the "; + if (type_) { + ss << *type_ << ' '; + } + ss << "file you specified " << file_ << ' '; + ss << "is not writable"; + return ss.str(); +} +std::string UnwritableFileError::do_help() const +{ + return "ensure the specified path is correct and the location is writable (check permissions)"; +} } // namespace octopus diff --git a/src/exceptions/unwritable_file_error.hpp b/src/exceptions/unwritable_file_error.hpp index c31963ff3..1c3e8af7f 100644 --- a/src/exceptions/unwritable_file_error.hpp +++ b/src/exceptions/unwritable_file_error.hpp @@ -1,20 +1,37 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+#ifndef unwritable_file_error_hpp +#define unwritable_file_error_hpp + #include #include +#include #include "user_error.hpp" -#ifndef unwritable_file_error_hpp -#define unwritable_file_error_hpp - namespace octopus { class UnwritableFileError : public UserError { +public: + using Path = boost::filesystem::path; + + UnwritableFileError() = delete; + + UnwritableFileError(Path file); + + UnwritableFileError(Path file, std::string type); + + virtual ~UnwritableFileError() override = default; + +private: + virtual std::string do_why() const override; + virtual std::string do_help() const override; + Path file_; + boost::optional type_; }; } // namespace octopus diff --git a/src/exceptions/user_error.hpp b/src/exceptions/user_error.hpp index fc5f4b78e..e89959d90 100644 --- a/src/exceptions/user_error.hpp +++ b/src/exceptions/user_error.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef user_error_hpp diff --git a/src/io/pedigree/pedigree_reader.cpp b/src/io/pedigree/pedigree_reader.cpp index 61b1983f7..6ef563ef5 100644 --- a/src/io/pedigree/pedigree_reader.cpp +++ b/src/io/pedigree/pedigree_reader.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "pedigree_reader.hpp" diff --git a/src/io/pedigree/pedigree_reader.hpp b/src/io/pedigree/pedigree_reader.hpp index 4b86ed18d..a8230773e 100644 --- a/src/io/pedigree/pedigree_reader.hpp +++ b/src/io/pedigree/pedigree_reader.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef pedigree_reader_hpp diff --git a/src/io/read/htslib_sam_facade.cpp b/src/io/read/htslib_sam_facade.cpp index 57e6e3216..4bdd29ab3 100644 --- a/src/io/read/htslib_sam_facade.cpp +++ b/src/io/read/htslib_sam_facade.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "htslib_sam_facade.hpp" @@ -12,6 +12,7 @@ #include #include +#include #include "basics/cigar_string.hpp" #include "basics/genomic_region.hpp" @@ -19,6 +20,7 @@ #include "exceptions/missing_file_error.hpp" #include "exceptions/missing_index_error.hpp" #include "exceptions/malformed_file_error.hpp" +#include "exceptions/unwritable_file_error.hpp" namespace octopus { namespace io { @@ -115,6 +117,13 @@ class MalformedCRAMHeader : public MalformedFileError MalformedCRAMHeader(boost::filesystem::path file) : MalformedFileError {std::move(file)} {} }; +class UnwritableBAM : public UnwritableFileError +{ + std::string do_where() const override { return "HtslibSamFacade"; } +public: + UnwritableBAM(boost::filesystem::path file) : UnwritableFileError {std::move(file), "bam"} {} +}; + class InvalidBamRecord : public std::runtime_error { public: @@ -207,6 +216,41 @@ HtslibSamFacade::HtslibSamFacade(Path file_path) std::sort(std::begin(samples_), std::end(samples_)); } +auto open_hts_writable_file(const boost::filesystem::path& path) +{ + std::string mode {"[w]"}; + const auto extension = path.extension(); + if (extension == ".bam") { + mode += "b"; + } + return sam_open(path.c_str(), mode.c_str()); +} + +HtslibSamFacade::HtslibSamFacade(Path sam_out, Path sam_template) +: HtslibSamFacade {std::move(sam_template)} +{ + file_path_ = std::move(sam_out); + hts_file_.reset(open_hts_writable_file(file_path_)); + if (!hts_file_) { + throw UnwritableBAM {std::move(file_path_)}; + } + hts_index_ = nullptr; + if (sam_hdr_write(hts_file_.get(), hts_header_.get()) < 0) { 
+ throw UnwritableBAM {std::move(file_path_)}; + } +} + +HtslibSamFacade::~HtslibSamFacade() +{ + if (!hts_index_) { + hts_header_.reset(nullptr); + hts_file_.reset(nullptr); + if (sam_index_build(file_path_.c_str(), 0) < 0) { + return; + } + } +} + bool HtslibSamFacade::is_open() const noexcept { return hts_file_ != nullptr && hts_header_ != nullptr && hts_index_ != nullptr; @@ -215,7 +259,6 @@ bool HtslibSamFacade::is_open() const noexcept void HtslibSamFacade::open() { hts_file_.reset(sam_open(file_path_.string().c_str(), "r")); - if (hts_file_) { hts_header_.reset(sam_hdr_read(hts_file_.get())); hts_index_.reset(sam_index_load(hts_file_.get(), file_path_.c_str())); @@ -445,6 +488,24 @@ HtslibSamFacade::extract_read_positions(const std::vector& samples, // fetch_reads +namespace { + +template +bool try_reserve(Container& c, const std::size_t max, const std::size_t min) +{ + assert(max >= min); + if (max == 0) return true; + for (auto curr = max; curr >= min; curr /= 2) { + try { + c.reserve(curr); + return true; + } catch (const std::bad_alloc& e) {} + } + return false; +} + +} // namespace + HtslibSamFacade::SampleReadMap HtslibSamFacade::fetch_reads(const GenomicRegion& region) const { SampleReadMap result {samples_.size()}; @@ -454,7 +515,7 @@ HtslibSamFacade::SampleReadMap HtslibSamFacade::fetch_reads(const GenomicRegion& HtslibIterator it {*this, region}; for (const auto& sample : samples_) { auto p = result.emplace(std::piecewise_construct, std::forward_as_tuple(sample), std::forward_as_tuple()); - p.first->second.reserve(defaultReserve_); + try_reserve(p.first->second, defaultReserve_, defaultReserve_ / 10); } while (++it) { try { @@ -475,7 +536,7 @@ HtslibSamFacade::ReadContainer HtslibSamFacade::fetch_reads(const SampleName& sa if (samples_.size() == 1) return fetch_all_reads(region); HtslibIterator it {*this, region}; ReadContainer result {}; - result.reserve(defaultReserve_); + try_reserve(result, defaultReserve_, defaultReserve_ / 10); while (++it) { 
if (sample_names_.at(it.read_group()) == sample) { try { @@ -504,7 +565,7 @@ HtslibSamFacade::SampleReadMap HtslibSamFacade::fetch_reads(const std::vectorsecond.reserve(defaultReserve_); + try_reserve(p.first->second, defaultReserve_, defaultReserve_ / 10); } } if (result.empty()) return result; // no matching samples @@ -548,13 +609,28 @@ boost::optional> HtslibSamFacade::mapped_ return result; } +void HtslibSamFacade::write(const AlignedRead& read) +{ + if (!hts_file_ || !hts_header_) { + throw UnwritableBAM {file_path_}; + } + std::unique_ptr record {bam_init1(), HtsBam1Deleter {}}; + if (!record) { + throw UnwritableBAM {file_path_}; + } + write(read, record.get()); + if (sam_write1(hts_file_.get(), hts_header_.get(), record.get()) < 0) { + throw UnwritableBAM {file_path_}; + } +} + // private methods HtslibSamFacade::ReadContainer HtslibSamFacade::fetch_all_reads(const GenomicRegion& region) const { HtslibIterator it {*this, region}; ReadContainer result {}; - result.reserve(defaultReserve_); + try_reserve(result, defaultReserve_, defaultReserve_ / 10); while (++it) { try { result.emplace_back(*it); @@ -708,8 +784,7 @@ AlignedRead::NucleotideSequence extract_sequence(const bam1_t* b) const auto hts_sequence = bam_get_seq(b); NucleotideSequence result(sequence_length, 'N'); std::uint32_t i {0}; - std::generate_n(std::begin(result), sequence_length, - [&i, &hts_sequence] () { return extract_base(hts_sequence, i++); }); + std::generate_n(std::begin(result), sequence_length, [&] () { return extract_base(hts_sequence, i++); }); return result; } @@ -732,10 +807,8 @@ CigarString extract_cigar_string(const bam1_t* b) CigarString result(cigar_length); std::transform(cigar_operations, cigar_operations + cigar_length, std::begin(result), [] (const auto op) noexcept { - return CigarOperation { - static_cast(bam_cigar_oplen(op)), - static_cast(bam_cigar_opchr(op)) - }; + return CigarOperation {static_cast(bam_cigar_oplen(op)), + static_cast(bam_cigar_opchr(op))}; }); 
return result; } @@ -761,12 +834,12 @@ auto mapping_quality(const bam1_core_t& c) noexcept { return static_cast(c.qual); } - + bool has_multiple_segments(const bam1_core_t& c) noexcept { return c.mtid != -1; } - + auto next_segment_position(const bam1_core_t& c) noexcept { return static_cast(c.mpos); @@ -776,7 +849,7 @@ auto template_length(const bam1_core_t& c) noexcept { return static_cast(std::abs(c.isize)); } - + auto extract_next_segment_flags(const bam1_core_t& c) noexcept { AlignedRead::Segment::Flags result {}; @@ -788,25 +861,20 @@ auto extract_next_segment_flags(const bam1_core_t& c) noexcept AlignedRead HtslibSamFacade::HtslibIterator::operator*() const { using std::begin; using std::end; using std::next; using std::move; - auto qualities = extract_qualities(hts_bam1_.get()); - if (qualities.empty() || qualities[0] == 0xff) { throw InvalidBamRecord {hts_facade_.file_path_, extract_read_name(hts_bam1_.get()), "corrupt sequence data"}; } - auto cigar = extract_cigar_string(hts_bam1_.get()); const auto& info = hts_bam1_->core; auto read_begin_tmp = clipped_begin(cigar, info.pos); auto sequence = extract_sequence(hts_bam1_.get()); - if (read_begin_tmp < 0) { // Then the read hangs off the left of the contig, and we must remove bases, base_qualities, and // adjust the cigar string as we cannot have a negative begin position const auto overhang_size = static_cast(std::abs(read_begin_tmp)); sequence.erase(begin(sequence), next(begin(sequence), overhang_size)); qualities.erase(begin(qualities), next(begin(qualities), overhang_size)); - auto soft_clip_size = cigar.front().size(); if (overhang_size == soft_clip_size) { cigar.erase(begin(cigar)); @@ -815,23 +883,18 @@ AlignedRead HtslibSamFacade::HtslibIterator::operator*() const } read_begin_tmp = 0; } - const auto read_begin = static_cast(read_begin_tmp); const auto& contig_name = hts_facade_.get_contig_name(info.tid); - if (has_multiple_segments(info)) { return AlignedRead { extract_read_name(hts_bam1_.get()), - 
GenomicRegion { - contig_name, - read_begin, - read_begin + octopus::reference_size(cigar) - }, + GenomicRegion {contig_name, read_begin, read_begin + octopus::reference_size(cigar)}, move(sequence), move(qualities), move(cigar), mapping_quality(info), extract_flags(info), + read_group(), hts_facade_.get_contig_name(info.mtid), next_segment_position(info), template_length(info), @@ -840,16 +903,13 @@ AlignedRead HtslibSamFacade::HtslibIterator::operator*() const } else { return AlignedRead { extract_read_name(hts_bam1_.get()), - GenomicRegion { - contig_name, - read_begin, - read_begin + octopus::reference_size(cigar) - }, + GenomicRegion {contig_name, read_begin, read_begin + octopus::reference_size(cigar)}, move(sequence), move(qualities), move(cigar), mapping_quality(info), - extract_flags(info) + extract_flags(info), + read_group() }; } } @@ -875,9 +935,7 @@ bool HtslibSamFacade::HtslibIterator::is_good() const noexcept if (cigar_length == 0) return false; const auto cigar_operations = bam_get_cigar(hts_bam1_.get()); return std::all_of(cigar_operations, cigar_operations + cigar_length, - [] (const auto op) { - return bam_cigar_oplen(op) > 0; - }); + [] (const auto op) { return bam_cigar_oplen(op) > 0; }); } std::size_t HtslibSamFacade::HtslibIterator::begin() const noexcept @@ -885,5 +943,177 @@ std::size_t HtslibSamFacade::HtslibIterator::begin() const noexcept return hts_bam1_->core.pos; } +namespace { + +void set_contig(const std::int32_t tid, bam1_t* result) noexcept +{ + result->core.tid = tid; +} + +void set_pos(const AlignedRead& read, bam1_t* result) noexcept +{ + result->core.pos = mapped_begin(read); + if (is_front_soft_clipped(read)) { + result->core.pos += get_soft_clipped_sizes(read).first; + } +} + +void set_mapping_quality(const AlignedRead& read, bam1_t* result) noexcept +{ + result->core.qual = read.mapping_quality(); +} + +void set_flag(bool set, std::uint16_t mask, std::uint16_t& result) noexcept +{ + constexpr std::uint16_t zeros {0}, ones = 
-1; + result |= (set ? ones : zeros) & mask; +} + +void set_flags(const AlignedRead& read, bam1_t* result) noexcept +{ + const auto flags = read.flags(); + auto& bitset = result->core.flag; + set_flag(flags.multiple_segment_template, BAM_FPAIRED, bitset); + set_flag(flags.all_segments_in_read_aligned, BAM_FPROPER_PAIR, bitset); + set_flag(flags.unmapped, BAM_FUNMAP, bitset); + set_flag(flags.reverse_mapped, BAM_FREVERSE, bitset); + set_flag(flags.secondary_alignment, BAM_FSECONDARY, bitset); + set_flag(flags.qc_fail, BAM_FQCFAIL, bitset); + set_flag(flags.duplicate, BAM_FDUP, bitset); + set_flag(flags.supplementary_alignment, BAM_FSUPPLEMENTARY, bitset); + set_flag(flags.first_template_segment, BAM_FREAD1, bitset); + set_flag(flags.last_template_segment, BAM_FREAD2, bitset); +} + +void set_segment(const AlignedRead& read, const std::int32_t tid, bam1_t* result) +{ + const auto& segment = read.next_segment(); + result->core.mtid = tid; + result->core.mpos = segment.begin(); + result->core.isize = static_cast(segment.inferred_template_length()); + if (segment.begin() < mapped_begin(read)) { + result->core.isize *= -1; + } + set_flag(segment.is_marked_unmapped(), BAM_FMUNMAP, result->core.flag); + set_flag(segment.is_marked_reverse_mapped(), BAM_FMREVERSE, result->core.flag); +} + +void init_variable_length_data(const AlignedRead& read, bam1_t* result) +{ + result->m_data = 0; + result->m_data += read.name().size() + 1 + read.name().size() % 4; + result->m_data += (read.sequence().size() + 1) / 2; // 4 bits per base + result->m_data += read.base_qualities().size(); + result->m_data += 4 * read.cigar().size(); + if (!read.read_group().empty()) { + result->m_data += read.read_group().size() + readGroupTag.size() + 2; // 1 for tag, 1 for '\0' + } + result->l_data = static_cast(result->m_data); + result->data = (std::uint8_t*) std::realloc(result->data, result->m_data); + std::fill_n(result->data, result->m_data, 0); +} + +void set_name(const AlignedRead& read, bam1_t* 
result) +{ + const auto& name = read.name(); + std::copy(std::cbegin(name), std::cend(name), result->data); + result->core.l_extranul = name.size() % 4; + result->core.l_qname = name.size() + result->core.l_extranul + 1; +} + +void set_cigar(const AlignedRead& read, bam1_t* result) noexcept +{ + const auto& cigar = read.cigar(); + result->core.n_cigar = cigar.size(); + std::transform(std::cbegin(cigar), std::cend(cigar), bam_get_cigar(result), + [] (const CigarOperation& op) noexcept -> std::uint32_t { + // Lower 4 bits for CIGAR operation and the higher 28 bits for size + std::uint32_t result = op.size(); + result <<= BAM_CIGAR_SHIFT; + using Flag = CigarOperation::Flag; + switch (op.flag()) { + case Flag::alignmentMatch: result |= BAM_CMATCH; break; + case Flag::insertion: result |= BAM_CINS; break; + case Flag::deletion: result |= BAM_CDEL; break; + case Flag::skipped: result |= BAM_CREF_SKIP; break; + case Flag::softClipped: result |= BAM_CSOFT_CLIP; break; + case Flag::hardClipped: result |= BAM_CHARD_CLIP; break; + case Flag::padding: result |= BAM_CPAD; break; + case Flag::sequenceMatch: result |= BAM_CEQUAL; break; + case Flag::substitution: result |= BAM_CDIFF; break; + } + return result; + }); +} + +static constexpr std::array sam_bases +{ +0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +0, 1, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 15,0, +0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 +}; + +void set_read_sequence(const AlignedRead& read, bam1_t* result) noexcept +{ + const auto& sequence = read.sequence(); + result->core.l_qseq = sequence.size(); + // Each base is encoded in 4 bits: 1 for A, 2 for C, 4 for G, + // 8 for T and 15 for N. 
Two bases are packed in one byte with the base + // at the higher 4 bits having smaller coordinate on the read. + auto bam_seq_itr = bam_get_seq(result); + for (std::size_t i {1}; i < sequence.size(); i += 2, ++bam_seq_itr) { + *bam_seq_itr = sam_bases[sequence[i - 1]] << 4 | sam_bases[sequence[i]]; + } + if (sequence.size() % 2 == 1) { + *bam_seq_itr = sam_bases[sequence.back()] << 4; + } +} + +void set_base_qualities(const AlignedRead& read, bam1_t* result) noexcept +{ + std::copy(std::cbegin(read.base_qualities()), std::cend(read.base_qualities()), bam_get_qual(result)); +} + +void set_aux(const AlignedRead& read, bam1_t* result) noexcept +{ + const auto& rg = read.read_group(); + if (!rg.empty()) { + auto aux_itr = std::copy(std::cbegin(readGroupTag), std::cend(readGroupTag), bam_get_aux(result)); + *aux_itr++ = 'Z'; + std::copy(std::cbegin(rg), std::cend(rg), aux_itr); + } +} + +void set_variable_length_data(const AlignedRead& read, bam1_t* result) +{ + init_variable_length_data(read, result); + set_name(read, result); + set_cigar(read, result); + set_read_sequence(read, result); + set_base_qualities(read, result); + set_aux(read, result); +} + +} // namespace + +void HtslibSamFacade::write(const AlignedRead& read, bam1_t* result) const +{ + set_contig(hts_targets_.at(contig_name(read)), result); + set_pos(read, result); + set_mapping_quality(read, result); + set_flags(read, result); + if (read.has_other_segment()) { + set_segment(read, hts_targets_.at(read.next_segment().contig_name()), result); + } else { + result->core.mtid = '*'; + } + set_variable_length_data(read, result); +} + } // namespace io } // namespace octopus diff --git a/src/io/read/htslib_sam_facade.hpp b/src/io/read/htslib_sam_facade.hpp index 324ed7507..4358defa7 100644 --- a/src/io/read/htslib_sam_facade.hpp +++ b/src/io/read/htslib_sam_facade.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT 
license that can be found in the LICENSE file. #ifndef htslib_sam_facade_hpp @@ -43,14 +43,16 @@ class HtslibSamFacade : public IReadReaderImpl using ReadGroupIdType = std::string; HtslibSamFacade() = delete; + HtslibSamFacade(Path file_path); + HtslibSamFacade(Path sam_out, Path sam_template); HtslibSamFacade(const HtslibSamFacade&) = delete; HtslibSamFacade& operator=(const HtslibSamFacade&) = delete; HtslibSamFacade(HtslibSamFacade&&) = default; HtslibSamFacade& operator=(HtslibSamFacade&&) = default; - ~HtslibSamFacade() override = default; + ~HtslibSamFacade() override; bool is_open() const noexcept override; void open() override; @@ -90,11 +92,30 @@ class HtslibSamFacade : public IReadReaderImpl std::vector reference_contigs() const override; boost::optional> mapped_contigs() const override; + void write(const AlignedRead& read); + private: using HtsTid = std::int32_t; static constexpr std::size_t defaultReserve_ {10'000'000}; + struct HtsFileDeleter + { + void operator()(htsFile* file) const { hts_close(file); } + }; + struct HtsHeaderDeleter + { + void operator()(bam_hdr_t* header) const { bam_hdr_destroy(header); } + }; + struct HtsIndexDeleter + { + void operator()(hts_idx_t* index) const { hts_idx_destroy(index); } + }; + struct HtsBam1Deleter + { + void operator()(bam1_t* b) const { bam_destroy1(b); } + }; + class HtslibIterator { public: @@ -115,7 +136,7 @@ class HtslibSamFacade : public IReadReaderImpl bool is_good() const noexcept; std::size_t begin() const noexcept; - + private: struct HtsIteratorDeleter { @@ -132,19 +153,6 @@ class HtslibSamFacade : public IReadReaderImpl std::unique_ptr hts_bam1_; }; - struct HtsFileDeleter - { - void operator()(htsFile* file) const { hts_close(file); } - }; - struct HtsHeaderDeleter - { - void operator()(bam_hdr_t* header) const { bam_hdr_destroy(header); } - }; - struct HtsIndexDeleter - { - void operator()(hts_idx_t* index) const { hts_idx_destroy(index); } - }; - Path file_path_; std::unique_ptr hts_file_; @@ 
-162,6 +170,7 @@ class HtslibSamFacade : public IReadReaderImpl const GenomicRegion::ContigName& get_contig_name(HtsTid target) const; std::uint64_t get_num_mapped_reads(const GenomicRegion::ContigName& contig) const; ReadContainer fetch_all_reads(const GenomicRegion& region) const; + void write(const AlignedRead& read, bam1_t* result) const; }; } // namespace io diff --git a/src/io/read/read_manager.cpp b/src/io/read/read_manager.cpp index 0269defeb..01539af32 100644 --- a/src/io/read/read_manager.cpp +++ b/src/io/read/read_manager.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "read_manager.hpp" @@ -44,15 +44,16 @@ ReadManager::ReadManager(std::initializer_list read_file_paths) {} ReadManager::ReadManager(ReadManager&& other) -: -num_files_ {std::move(other.num_files_)} { std::lock_guard lock {other.mutex_}; - closed_readers_ = std::move(other.closed_readers_); - open_readers_ = std::move(other.open_readers_); - reader_paths_containing_sample_ = std::move(other.reader_paths_containing_sample_); - possible_regions_in_readers_ = std::move(other.possible_regions_in_readers_); - samples_ = std::move(other.samples_); + using std::move; + max_open_files_ = move(other.max_open_files_); + num_files_ = move(other.num_files_); + closed_readers_ = move(other.closed_readers_); + open_readers_ = move(other.open_readers_); + reader_paths_containing_sample_ = move(other.reader_paths_containing_sample_); + possible_regions_in_readers_ = move(other.possible_regions_in_readers_); + samples_ = move(other.samples_); } void swap(ReadManager& lhs, ReadManager& rhs) noexcept @@ -61,11 +62,19 @@ void swap(ReadManager& lhs, ReadManager& rhs) noexcept std::lock(lhs.mutex_, rhs.mutex_); std::lock_guard lock_lhs {lhs.mutex_, std::adopt_lock}, lock_rhs {rhs.mutex_, std::adopt_lock}; using std::swap; + swap(lhs.max_open_files_, 
rhs.max_open_files_); + swap(lhs.num_files_, rhs.num_files_); swap(lhs.closed_readers_, rhs.closed_readers_); swap(lhs.open_readers_, rhs.open_readers_); swap(lhs.reader_paths_containing_sample_, rhs.reader_paths_containing_sample_); swap(lhs.possible_regions_in_readers_, rhs.possible_regions_in_readers_); - swap(lhs.samples_, rhs.samples_); + swap(lhs.samples_, rhs.samples_); +} + +void ReadManager::close() const noexcept +{ + std::lock_guard lock {mutex_}; + close_readers(num_files_); } bool ReadManager::good() const noexcept @@ -79,6 +88,21 @@ unsigned ReadManager::num_files() const noexcept return static_cast(closed_readers_.size() + open_readers_.size()); } +std::vector ReadManager::paths() const +{ + std::vector result {}; + result.reserve(num_files_); + std::lock_guard lock {mutex_}; + for (const auto& path : closed_readers_) { + result.push_back(path); + } + for (const auto& p : open_readers_) { + result.push_back(p.first); + } + std::sort(std::begin(result), std::end(result)); + return result; +} + unsigned ReadManager::num_samples() const noexcept { return static_cast(samples_.size()); @@ -277,7 +301,7 @@ auto max_head_region(const CoverageTracker& position_tracker, if (position_tracker.num_tracked() <= max_coverage) return region; const auto max_region = max_head_region(position_tracker, region); if (size(max_region) <= 1) return max_region; - auto position_coverage = position_tracker.coverage(max_region.contig_region()); + auto position_coverage = position_tracker.get(max_region.contig_region()); std::partial_sum(std::begin(position_coverage), std::end(position_coverage), std::begin(position_coverage)); const auto last_position = std::upper_bound(std::cbegin(position_coverage), std::cend(position_coverage), max_coverage); return expand_rhs(head_region(region), std::distance(std::cbegin(position_coverage), last_position)); @@ -359,7 +383,7 @@ ReadManager::ReadContainer ReadManager::fetch_reads(const SampleName& sample, co ReadManager::SampleReadMap 
ReadManager::fetch_reads(const std::vector& samples, const GenomicRegion& region) const { SampleReadMap result {samples.size()}; - // Populate here so we can do unchcked access + // Populate here so we can make unchecked access for (const auto& sample : samples) { result.emplace(std::piecewise_construct, std::forward_as_tuple(sample), std::forward_as_tuple()); } diff --git a/src/io/read/read_manager.hpp b/src/io/read/read_manager.hpp index e8356861d..c531d79c1 100644 --- a/src/io/read/read_manager.hpp +++ b/src/io/read/read_manager.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef read_manager_hpp @@ -50,9 +50,10 @@ class ReadManager friend void swap(ReadManager& lhs, ReadManager& rhs) noexcept; + void close() const noexcept; // close all readers bool good() const noexcept; unsigned num_files() const noexcept; - + std::vector paths() const; // Managed files unsigned num_samples() const noexcept; const std::vector& samples() const; unsigned drop_samples(std::vector samples); @@ -84,7 +85,7 @@ class ReadManager }; using OpenReaderMap = std::map; - using ClosedReaders = std::unordered_set; + using ClosedReaderSet = std::unordered_set; using SampleIdToReaderPathMap = std::unordered_map>; using ContigMap = MappableMap; using ReaderRegionsMap = std::unordered_map; @@ -92,13 +93,11 @@ class ReadManager unsigned max_open_files_ = 200; unsigned num_files_; - mutable ClosedReaders closed_readers_; + mutable ClosedReaderSet closed_readers_; mutable OpenReaderMap open_readers_; SampleIdToReaderPathMap reader_paths_containing_sample_; - ReaderRegionsMap possible_regions_in_readers_; - std::vector samples_; mutable std::mutex mutex_; diff --git a/src/io/read/read_reader.cpp b/src/io/read/read_reader.cpp index 4b3c2c6a3..ecf015bec 100644 --- a/src/io/read/read_reader.cpp +++ b/src/io/read/read_reader.cpp @@ -1,4 +1,4 @@ -// 
Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "read_reader.hpp" diff --git a/src/io/read/read_reader.hpp b/src/io/read/read_reader.hpp index d4c9d2c2a..4dcf3a502 100644 --- a/src/io/read/read_reader.hpp +++ b/src/io/read/read_reader.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef read_reader_hpp diff --git a/src/io/read/read_reader_impl.hpp b/src/io/read/read_reader_impl.hpp index 193aecf6d..c48bc45df 100644 --- a/src/io/read/read_reader_impl.hpp +++ b/src/io/read/read_reader_impl.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef read_reader_impl_hpp diff --git a/src/io/read/read_writer.cpp b/src/io/read/read_writer.cpp new file mode 100644 index 000000000..fc376dace --- /dev/null +++ b/src/io/read/read_writer.cpp @@ -0,0 +1,50 @@ +// Copyright (c) 2017 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "read_writer.hpp" + +#include + +namespace octopus { namespace io { + +ReadWriter::ReadWriter(Path bam_out, Path bam_template) +: path_ {std::move(bam_out)} +, impl_ {std::make_unique(path_, std::move(bam_template))} +{} + +ReadWriter::ReadWriter(ReadWriter&& other) +{ + std::lock_guard lock {other.mutex_}; + path_ = std::move(other.path_); + impl_ = std::move(other.impl_); +} + +void swap(ReadWriter& lhs, ReadWriter& rhs) noexcept +{ + if (&lhs == &rhs) return; + std::lock(lhs.mutex_, rhs.mutex_); + std::lock_guard lock_lhs {lhs.mutex_, std::adopt_lock}, lock_rhs {rhs.mutex_, std::adopt_lock}; + using std::swap; + swap(lhs.path_, rhs.path_); + swap(lhs.impl_, rhs.impl_); +} + +const ReadWriter::Path& ReadWriter::path() const noexcept +{ + return path_; +} + +void ReadWriter::write(const AlignedRead& read) +{ + std::lock_guard lock {mutex_}; + impl_->write(read); +} + +ReadWriter& operator<<(ReadWriter& dst, const AlignedRead& read) +{ + dst.write(read); + return dst; +} + +} // namespace io +} // namespace octopus diff --git a/src/io/read/read_writer.hpp b/src/io/read/read_writer.hpp new file mode 100644 index 000000000..fe30f8b2f --- /dev/null +++ b/src/io/read/read_writer.hpp @@ -0,0 +1,69 @@ +// Copyright (c) 2017 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#ifndef read_writer_hpp +#define read_writer_hpp + +#include +#include + +#include + +#include "concepts/equitable.hpp" +#include "htslib_sam_facade.hpp" + +namespace octopus { + +class AlignedRead; + +namespace io { + +class ReadWriter +{ +public: + using Path = boost::filesystem::path; + + ReadWriter() = delete; + + ReadWriter(Path bam_out, Path bam_template); + + ReadWriter(const ReadWriter&) = delete; + ReadWriter& operator=(const ReadWriter&) = delete; + ReadWriter(ReadWriter&&); + ReadWriter& operator=(ReadWriter&&) = delete; + + ~ReadWriter() = default; + + friend void swap(ReadWriter& lhs, ReadWriter& rhs) noexcept; + + const Path& path() const noexcept; + + void write(const AlignedRead& read); + +private: + Path path_; + std::unique_ptr impl_; + mutable std::mutex mutex_; +}; + +ReadWriter& operator<<(ReadWriter& dst, const AlignedRead& read); + +template +void write(const Container& reads, ReadWriter& dst) +{ + for (const auto& read : reads) { + dst << read; + } +} + +template +ReadWriter& operator<<(ReadWriter& dst, const Container& reads) +{ + write(reads, dst); + return dst; +} + +} // namespace io +} // namespace octopus + +#endif diff --git a/src/io/reference/caching_fasta.cpp b/src/io/reference/caching_fasta.cpp index c1c043daa..0e00bed06 100644 --- a/src/io/reference/caching_fasta.cpp +++ b/src/io/reference/caching_fasta.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "caching_fasta.hpp" diff --git a/src/io/reference/caching_fasta.hpp b/src/io/reference/caching_fasta.hpp index 8a0aea04b..1e09f7f14 100644 --- a/src/io/reference/caching_fasta.hpp +++ b/src/io/reference/caching_fasta.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef caching_fasta_hpp diff --git a/src/io/reference/fasta.cpp b/src/io/reference/fasta.cpp index c30de3f2c..68e588210 100644 --- a/src/io/reference/fasta.cpp +++ b/src/io/reference/fasta.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "fasta.hpp" diff --git a/src/io/reference/fasta.hpp b/src/io/reference/fasta.hpp index 6bdae0800..e286e5c99 100644 --- a/src/io/reference/fasta.hpp +++ b/src/io/reference/fasta.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef fasta_hpp diff --git a/src/io/reference/reference_genome.cpp b/src/io/reference/reference_genome.cpp index d64c31a62..9052f666d 100644 --- a/src/io/reference/reference_genome.cpp +++ b/src/io/reference/reference_genome.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "reference_genome.hpp" @@ -95,7 +95,7 @@ ReferenceGenome::GeneticSequence ReferenceGenome::fetch_sequence(const GenomicRe // non-member functions ReferenceGenome make_reference(boost::filesystem::path reference_path, - const std::size_t max_cached_bases, + const MemoryFootprint max_cache_size, const bool is_threaded, bool capitalise_bases) { @@ -111,10 +111,10 @@ ReferenceGenome make_reference(boost::filesystem::path reference_path, } else { impl_ = std::make_unique(std::move(reference_path), options); } - if (max_cached_bases > 0) { + if (max_cache_size.num_bytes() > 0) { double locality_bias {0.99}, forward_bias {0.99}; if (is_threaded) locality_bias = 0.25; - return ReferenceGenome {std::make_unique(std::move(impl_), max_cached_bases, + return ReferenceGenome {std::make_unique(std::move(impl_), max_cache_size.num_bytes(), locality_bias, forward_bias)}; } else { return ReferenceGenome {std::move(impl_)}; diff --git a/src/io/reference/reference_genome.hpp b/src/io/reference/reference_genome.hpp index a5726686b..cfc3cd7e8 100644 --- a/src/io/reference/reference_genome.hpp +++ b/src/io/reference/reference_genome.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef reference_genome_hpp @@ -13,6 +13,7 @@ #include #include "basics/genomic_region.hpp" +#include "utils/memory_footprint.hpp" #include "reference_reader.hpp" namespace octopus { @@ -37,13 +38,9 @@ class ReferenceGenome const std::string& name() const; bool has_contig(const ContigName& contig) const noexcept; - std::size_t num_contigs() const noexcept; - std::vector contig_names() const; - ContigRegion::Size contig_size(const ContigName& contig) const; - GenomicRegion contig_region(const ContigName& contig) const; bool contains(const GenomicRegion& region) const noexcept; @@ -52,18 +49,15 @@ class ReferenceGenome private: std::unique_ptr impl_; - std::string name_; - std::unordered_map contig_sizes_; - std::vector ordered_contigs_; }; // non-member functions ReferenceGenome make_reference(boost::filesystem::path reference_path, - std::size_t max_cached_bases = 0, + MemoryFootprint max_cache_size = 0, bool is_threaded = false, bool capitalise_bases = true); diff --git a/src/io/reference/reference_reader.hpp b/src/io/reference/reference_reader.hpp index 4dd953a6e..179f61a11 100644 --- a/src/io/reference/reference_reader.hpp +++ b/src/io/reference/reference_reader.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef reference_reader_hpp diff --git a/src/io/reference/threadsafe_fasta.cpp b/src/io/reference/threadsafe_fasta.cpp index 5c2e1219a..1b1e4e86e 100644 --- a/src/io/reference/threadsafe_fasta.cpp +++ b/src/io/reference/threadsafe_fasta.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "threadsafe_fasta.hpp" diff --git a/src/io/reference/threadsafe_fasta.hpp b/src/io/reference/threadsafe_fasta.hpp index f238bf6c5..e5f403bd8 100644 --- a/src/io/reference/threadsafe_fasta.hpp +++ b/src/io/reference/threadsafe_fasta.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef threadsafe_fasta_hpp diff --git a/src/io/region/region_parser.cpp b/src/io/region/region_parser.cpp index 0d432822a..cd586ffb2 100644 --- a/src/io/region/region_parser.cpp +++ b/src/io/region/region_parser.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "region_parser.hpp" diff --git a/src/io/region/region_parser.hpp b/src/io/region/region_parser.hpp index c70e436a8..b461591b1 100644 --- a/src/io/region/region_parser.hpp +++ b/src/io/region/region_parser.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef region_parser_hpp diff --git a/src/io/variant/htslib_bcf_facade.cpp b/src/io/variant/htslib_bcf_facade.cpp index a73511a78..c599d9dc2 100644 --- a/src/io/variant/htslib_bcf_facade.cpp +++ b/src/io/variant/htslib_bcf_facade.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "htslib_bcf_facade.hpp" @@ -12,6 +12,7 @@ #include #include #include +#include #include #include @@ -29,7 +30,12 @@ namespace octopus { namespace { -static const std::string vcfMissingValue {vcfspec::missingValue}; +static const std::string bcf_missing_str {vcfspec::missingValue}; + +bool is_missing(const std::string& value) noexcept +{ + return value == bcf_missing_str; +} namespace bc = boost::container; @@ -584,7 +590,7 @@ void extract_info(const bcf_hdr_t* header, bcf1_t* record, VcfRecord::Builder& b values.reserve(nintinfo); std::transform(intinfo, intinfo + nintinfo, std::back_inserter(values), [] (auto v) { - return v != bcf_int32_missing ? std::to_string(v) : vcfMissingValue; + return v != bcf_int32_missing ? std::to_string(v) : bcf_missing_str; }); } break; @@ -593,7 +599,7 @@ void extract_info(const bcf_hdr_t* header, bcf1_t* record, VcfRecord::Builder& b values.reserve(nfloatinfo); std::transform(floatinfo, floatinfo + nfloatinfo, std::back_inserter(values), [] (auto v) { - return v != bcf_float_missing ? std::to_string(v) : vcfMissingValue; + return v != bcf_float_missing ? std::to_string(v) : bcf_missing_str; }); } break; @@ -621,22 +627,25 @@ void extract_info(const bcf_hdr_t* header, bcf1_t* record, VcfRecord::Builder& b if (flaginfo != nullptr) std::free(flaginfo); } +float get_bcf_float_missing() noexcept +{ + float result; + bcf_float_set_missing(result); + return result; +} + void set_info(const bcf_hdr_t* header, bcf1_t* dest, const VcfRecord& source) { for (const auto& key : source.info_keys()) { const auto& values = source.info_value(key); const auto num_values = static_cast(values.size()); - static constexpr std::size_t defaultBufferCapacity {100}; - switch (bcf_hdr_id2type(header, BCF_HL_INFO, bcf_hdr_id2int(header, BCF_DT_ID, key.c_str()))) { case BCF_HT_INT: { bc::small_vector vals(num_values); std::transform(std::cbegin(values), std::cend(values), std::begin(vals), - [] (const auto& v) { - return v != vcfMissingValue ? 
std::stoi(v) : bcf_int32_missing; - }); + [] (const auto& v) { return !is_missing(v) ? std::stoi(v) : bcf_int32_missing; }); bcf_update_info_int32(header, dest, key.c_str(), vals.data(), num_values); break; } @@ -644,9 +653,7 @@ void set_info(const bcf_hdr_t* header, bcf1_t* dest, const VcfRecord& source) { bc::small_vector vals(num_values); std::transform(std::cbegin(values), std::cend(values), std::begin(vals), - [] (const auto& v) { - return v != vcfMissingValue ? std::stof(v) : bcf_float_missing; - }); + [] (const auto& v) { return !is_missing(v) ? std::stof(v) : get_bcf_float_missing(); }); bcf_update_info_float(header, dest, key.c_str(), vals.data(), num_values); break; } @@ -675,7 +682,6 @@ auto extract_format(const bcf_hdr_t* header, const bcf1_t* record) { std::vector result {}; result.reserve(record->n_fmt); - for (unsigned i {0}; i < record->n_fmt; ++i) { const auto key_id = record->d.fmt[i].id; if (key_id >= header->n[BCF_DT_ID]) { @@ -683,7 +689,6 @@ auto extract_format(const bcf_hdr_t* header, const bcf1_t* record) } result.emplace_back(header->id[BCF_DT_ID][key_id].key); } - return result; } @@ -692,54 +697,47 @@ void extract_samples(const bcf_hdr_t* header, bcf1_t* record, VcfRecord::Builder auto format = extract_format(header, record); const auto num_samples = record->n_sample; builder.reserve_samples(num_samples); - + auto first_format = std::cbegin(format); if (format.front() == vcfspec::format::genotype) { // the first key must be GT if present int ngt {}, g {}; int* gt {nullptr}; bcf_get_genotypes(header, record, >, &ngt); // mallocs gt const auto max_ploidy = static_cast(record->d.fmt->n); - for (unsigned sample {0}, i {0}; sample < num_samples; ++sample, i += max_ploidy) { std::vector alleles {}; alleles.reserve(max_ploidy); - for (unsigned p {0}; p < max_ploidy; ++p) { g = gt[i + p]; if (g == bcf_int32_vector_end) { alleles.shrink_to_fit(); break; } else if (bcf_gt_is_missing(g)) { - alleles.push_back(vcfMissingValue); + 
alleles.push_back(bcf_missing_str); } else { const auto idx = bcf_gt_allele(g); if (idx < record->n_allele) { alleles.emplace_back(record->d.allele[idx]); } else { - alleles.push_back(vcfMissingValue); + alleles.push_back(bcf_missing_str); } } } - using Phasing = VcfRecord::Builder::Phasing; builder.set_genotype(header->samples[sample], std::move(alleles), bcf_gt_is_phased(g) ? Phasing::phased : Phasing::unphased); } - std::free(gt); + ++first_format; } - int nintformat {}; int* intformat {nullptr}; int nfloatformat {}; float* floatformat {nullptr}; int nstringformat {}; char** stringformat {nullptr}; - - for (auto it = std::next(std::cbegin(format)), end = std::cend(format); it != end; ++it) { - const auto& key = *it; - + for (auto itr = first_format, end = std::cend(format); itr != end; ++itr) { + const auto& key = *itr; std::vector> values(num_samples, std::vector {}); - switch (bcf_hdr_id2type(header, BCF_HL_FMT, bcf_hdr_id2int(header, BCF_DT_ID, key.c_str()))) { case BCF_HT_INT: if (bcf_get_format_int32(header, record, key.c_str(), &intformat, &nintformat) > 0) { @@ -749,7 +747,7 @@ void extract_samples(const bcf_hdr_t* header, bcf1_t* record, VcfRecord::Builder values[sample].reserve(num_values_per_sample); std::transform(ptr, ptr + num_values_per_sample, std::back_inserter(values[sample]), [] (auto v) { - return v != bcf_int32_missing ? std::to_string(v) : vcfMissingValue; + return v != bcf_int32_missing ? std::to_string(v) : bcf_missing_str; }); } } @@ -760,10 +758,9 @@ void extract_samples(const bcf_hdr_t* header, bcf1_t* record, VcfRecord::Builder auto ptr = floatformat; for (unsigned sample {0}; sample < num_samples; ++sample, ptr += num_values_per_sample) { values[sample].reserve(num_values_per_sample); - std::transform(ptr, ptr + num_samples, std::back_inserter(values[sample]), [] (auto v) { - return v != bcf_float_missing ? std::to_string(v) : vcfMissingValue; + return v != bcf_float_missing ? 
std::to_string(v) : bcf_missing_str; }); } } @@ -779,14 +776,11 @@ void extract_samples(const bcf_hdr_t* header, bcf1_t* record, VcfRecord::Builder } break; } - for (unsigned sample {0}; sample < num_samples; ++sample) { builder.set_format(header->samples[sample], key, std::move(values[sample])); } } - builder.set_format(std::move(format)); - if (intformat != nullptr) std::free(intformat); if (floatformat != nullptr) std::free(floatformat); if (stringformat != nullptr) { @@ -799,7 +793,7 @@ void extract_samples(const bcf_hdr_t* header, bcf1_t* record, VcfRecord::Builder template auto genotype_number(const T& allele, const Container& alleles, const bool is_phased) { - if (allele == vcfMissingValue) { + if (is_missing(allele)) { return (is_phased) ? bcf_gt_missing + 1 : bcf_gt_missing; } const auto it = std::find(std::cbegin(alleles), std::cend(alleles), allele); @@ -807,6 +801,22 @@ auto genotype_number(const T& allele, const Container& alleles, const bool is_ph return (is_phased) ? allele_num + 1 : allele_num; } +auto max_format_cardinality(const VcfRecord& record, const VcfRecord::KeyType& key, const std::vector& samples) +{ + std::size_t result {0}; + for (const auto& sample : samples) { + result = std::max(result, record.get_sample_value(sample, key).size()); + } + return result; +} + +float get_bcf_float_pad() noexcept +{ + float result; + bcf_float_set_vector_end(result); + return result; +} + void set_samples(const bcf_hdr_t* header, bcf1_t* dest, const VcfRecord& source, const std::vector& samples) { @@ -814,7 +824,6 @@ void set_samples(const bcf_hdr_t* header, bcf1_t* dest, const VcfRecord& source, const auto num_samples = static_cast(source.num_samples()); const auto& format = source.format(); if (format.empty()) return; - auto first_format = std::cbegin(format); if (*first_format == vcfspec::format::genotype) { const auto& alt_alleles = source.alt(); @@ -822,75 +831,90 @@ void set_samples(const bcf_hdr_t* header, bcf1_t* dest, const VcfRecord& source, 
alleles.reserve(alt_alleles.size() + 1); alleles.push_back(source.ref()); alleles.insert(std::end(alleles), std::cbegin(alt_alleles), std::cend(alt_alleles)); - unsigned max_ploidy {}; for (const auto& sample : samples) { const auto p = source.ploidy(sample); if (p > max_ploidy) max_ploidy = p; } - const auto ngt = num_samples * static_cast(max_ploidy); - bc::small_vector genotype(ngt); - auto it = std::begin(genotype); - + bc::small_vector genotype(ngt); + auto genotype_itr = std::begin(genotype); for (const auto& sample : samples) { const bool is_phased {source.is_sample_phased(sample)}; const auto& genotype = source.get_sample_value(sample, vcfspec::format::genotype); const auto ploidy = static_cast(genotype.size()); - - it = std::transform(std::cbegin(genotype), std::cend(genotype), it, - [is_phased, &alleles] (const auto& allele) { - return genotype_number(allele, alleles, is_phased); - }); - it = std::fill_n(it, max_ploidy - ploidy, bcf_int32_vector_end); + genotype_itr = std::transform(std::cbegin(genotype), std::cend(genotype), genotype_itr, + [is_phased, &alleles] (const auto& allele) { + return genotype_number(allele, alleles, is_phased); + }); + genotype_itr = std::fill_n(genotype_itr, max_ploidy - ploidy, bcf_int32_vector_end); } - bcf_update_genotypes(header, dest, genotype.data(), ngt); ++first_format; } - + std::vector str_buffer {}; std::for_each(first_format, std::cend(format), [&] (const auto& key) { - const auto num_values = num_samples * static_cast(source.format_cardinality(key)); - - static constexpr std::size_t defaultValueCapacity {100}; - + const auto key_cardinality = source.format_cardinality(key); + int num_values {}; + if (key_cardinality) { + num_values = *key_cardinality * num_samples; + } else { + num_values = max_format_cardinality(source, key, samples) * num_samples; + } + const auto num_values_per_sample = static_cast(num_values / num_samples); + static constexpr std::size_t defaultValueCapacity {1'000}; switch 
(bcf_hdr_id2type(header, BCF_HL_FMT, bcf_hdr_id2int(header, BCF_DT_ID, key.c_str()))) { case BCF_HT_INT: { + static const int pad {bcf_int32_vector_end}; bc::small_vector typed_values(num_values); - auto it = std::begin(typed_values); + auto value_itr = std::begin(typed_values); for (const auto& sample : samples) { const auto& values = source.get_sample_value(sample, key); - it = std::transform(std::cbegin(values), std::cend(values), it, - [] (const auto& v) { - return v != vcfMissingValue ? std::stoi(v) : bcf_int32_missing; - }); + value_itr = std::transform(std::cbegin(values), std::cend(values), value_itr, + [] (const auto& v) { return !is_missing(v) ? std::stoi(v) : bcf_int32_missing; }); + assert(values.size() <= num_values_per_sample); + value_itr = std::fill_n(value_itr, num_values_per_sample - values.size(), pad); } bcf_update_format_int32(header, dest, key.c_str(), typed_values.data(), num_values); break; } case BCF_HT_REAL: { + static const float pad {get_bcf_float_pad()}; bc::small_vector typed_values(num_values); - auto it = std::begin(typed_values); + auto value_itr = std::begin(typed_values); for (const auto& sample : samples) { const auto& values = source.get_sample_value(sample, key); - it = std::transform(std::cbegin(values), std::cend(values), it, - [] (const auto& v) { - return v != vcfMissingValue ? std::stof(v) : bcf_float_missing; - }); + value_itr = std::transform(std::cbegin(values), std::cend(values), value_itr, + [] (const auto& v) { return !is_missing(v) ? 
std::stof(v) : get_bcf_float_missing(); }); + assert(values.size() <= num_values_per_sample); + value_itr = std::fill_n(value_itr, num_values_per_sample - values.size(), pad); } bcf_update_format_float(header, dest, key.c_str(), typed_values.data(), num_values); break; } case BCF_HT_STR: { - bc::small_vector typed_values(num_values); - auto it = std::begin(typed_values); - for (const auto& sample : samples) { - const auto& values = source.get_sample_value(sample, key); - it = std::transform(std::cbegin(values), std::cend(values), it, - [] (const auto& value) { return value.c_str(); }); + bc::small_vector typed_values; + if (key_cardinality && *key_cardinality <= 1) { + typed_values.resize(num_values); + auto value_itr = std::begin(typed_values); + for (const auto& sample : samples) { + const auto& values = source.get_sample_value(sample, key); + value_itr = std::transform(std::cbegin(values), std::cend(values), value_itr, + [] (const auto& value) { return value.c_str(); }); + } + } else { + str_buffer.clear(); + str_buffer.reserve(num_samples); + for (const auto& sample : samples) { + str_buffer.push_back(utils::join(source.get_sample_value(sample, key), vcfspec::format::valueSeperator)); + } + num_values = num_samples; + typed_values.resize(num_values); + std::transform(std::cbegin(str_buffer), std::cend(str_buffer), std::begin(typed_values), + [] (const auto& value) { return value.c_str(); }); } bcf_update_format_string(header, dest, key.c_str(), typed_values.data(), num_values); break; diff --git a/src/io/variant/htslib_bcf_facade.hpp b/src/io/variant/htslib_bcf_facade.hpp index eebfa7e15..9750a732e 100644 --- a/src/io/variant/htslib_bcf_facade.hpp +++ b/src/io/variant/htslib_bcf_facade.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef htslib_bcf_facade_hpp diff --git a/src/io/variant/vcf.hpp b/src/io/variant/vcf.hpp index f2d11aa92..0fceeb438 100644 --- a/src/io/variant/vcf.hpp +++ b/src/io/variant/vcf.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef vcf_hpp diff --git a/src/io/variant/vcf_header.cpp b/src/io/variant/vcf_header.cpp index e99fa914d..35f527192 100644 --- a/src/io/variant/vcf_header.cpp +++ b/src/io/variant/vcf_header.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "vcf_header.hpp" @@ -418,10 +418,6 @@ VcfHeader::Builder get_default_header_builder() result.add_format("BQ", "1", "Integer", "RMS base quality at this position"); result.add_filter("PASS", "All filters passed"); - result.add_filter("MQ", "Root-mean-square mapping quality across calling region is low"); - result.add_filter("q10", "Variant quality is below 10"); - result.add_filter("SB", "One of the alternative alleles has strand bias"); - result.add_filter("KL", "High Kullback–Leibler divergence between REF and ALT mapping quality distributions"); return result; } diff --git a/src/io/variant/vcf_header.hpp b/src/io/variant/vcf_header.hpp index d454bb60d..e8ef35a59 100644 --- a/src/io/variant/vcf_header.hpp +++ b/src/io/variant/vcf_header.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef vcf_header_hpp diff --git a/src/io/variant/vcf_parser.cpp b/src/io/variant/vcf_parser.cpp index 98fb44b25..1dc32eabf 100644 --- a/src/io/variant/vcf_parser.cpp +++ b/src/io/variant/vcf_parser.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "vcf_parser.hpp" diff --git a/src/io/variant/vcf_parser.hpp b/src/io/variant/vcf_parser.hpp index 252040020..ec975d5f4 100644 --- a/src/io/variant/vcf_parser.hpp +++ b/src/io/variant/vcf_parser.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef vcf_parser_hpp diff --git a/src/io/variant/vcf_reader.cpp b/src/io/variant/vcf_reader.cpp index abaac40a1..6834ba12f 100644 --- a/src/io/variant/vcf_reader.cpp +++ b/src/io/variant/vcf_reader.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "vcf_reader.hpp" diff --git a/src/io/variant/vcf_reader.hpp b/src/io/variant/vcf_reader.hpp index dda3cf056..f10930821 100644 --- a/src/io/variant/vcf_reader.hpp +++ b/src/io/variant/vcf_reader.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef vcf_reader_hpp diff --git a/src/io/variant/vcf_reader_impl.hpp b/src/io/variant/vcf_reader_impl.hpp index 1c1a8170c..4329857c8 100644 --- a/src/io/variant/vcf_reader_impl.hpp +++ b/src/io/variant/vcf_reader_impl.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef vcf_reader_impl_hpp diff --git a/src/io/variant/vcf_record.cpp b/src/io/variant/vcf_record.cpp index 941709f2a..7e98da1b2 100644 --- a/src/io/variant/vcf_record.cpp +++ b/src/io/variant/vcf_record.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "vcf_record.hpp" @@ -59,7 +59,7 @@ bool VcfRecord::has_filter(const KeyType& filter) const noexcept return std::find(std::cbegin(filter_), std::cend(filter_), filter) != std::cend(filter_); } -const std::vector VcfRecord::filter() const noexcept +const std::vector& VcfRecord::filter() const noexcept { return filter_; } @@ -91,9 +91,20 @@ bool VcfRecord::has_format(const KeyType& key) const noexcept return std::find(std::cbegin(format_), std::cend(format_), key) != std::cend(format_); } -unsigned VcfRecord::format_cardinality(const KeyType& key) const noexcept +boost::optional VcfRecord::format_cardinality(const KeyType& key) const noexcept { - return (has_format(key)) ? 
static_cast(std::cbegin(samples_)->second.at(key).size()) : 0; + boost::optional result {}; + if (has_format(key)) { + for (const auto& p : samples_) { + const auto sample_format_cardinality = p.second.at(key).size(); + if (result) { + if (*result != sample_format_cardinality) return boost::none; + } else { + result = sample_format_cardinality; + } + } + } + return result; } const std::vector& VcfRecord::format() const noexcept @@ -310,6 +321,12 @@ std::vector get_genotype(const VcfRecord& record, return record.get_sample_value(sample, vcfspec::format::genotype); } +bool is_filtered(const VcfRecord& record) noexcept +{ + const auto& filters = record.filter(); + return !filters.empty() && !(filters[0] == vcfspec::filter::pass || filters[0] == vcfspec::missingValue); +} + bool is_dbsnp_member(const VcfRecord& record) noexcept { return record.has_info(vcfspec::info::dbSNPMember); @@ -382,7 +399,7 @@ std::ostream& operator<<(std::ostream& os, const VcfRecord& record) if (record.qual_) { os << static_cast(*record.qual_) << "\t"; } else { - os << '.' 
<< "\t"; + os << vcfspec::missingValue << "\t"; } os << record.filter_ << "\t"; record.print_info(os); @@ -466,7 +483,7 @@ VcfRecord::Builder& VcfRecord::Builder::set_qual(QualityType quality) VcfRecord::Builder& VcfRecord::Builder::set_passed() { - filter_.assign({"PASS"}); + filter_.assign({vcfspec::filter::pass}); return *this; } @@ -527,6 +544,11 @@ VcfRecord::Builder& VcfRecord::Builder::set_info_flag(KeyType key) return this->set_info(std::move(key), {}); } +VcfRecord::Builder& VcfRecord::Builder::set_info_missing(const KeyType& key) +{ + return this->set_info(key, {vcfspec::missingValue}); +} + VcfRecord::Builder& VcfRecord::Builder::clear_info() noexcept { info_.clear(); @@ -563,75 +585,112 @@ VcfRecord::Builder& VcfRecord::Builder::reserve_samples(unsigned n) return *this; } -VcfRecord::Builder&VcfRecord::Builder:: set_homozygous_ref_genotype(const SampleName& sample, - unsigned ploidy) +VcfRecord::Builder&VcfRecord::Builder:: set_homozygous_ref_genotype(const SampleName& sample, unsigned ploidy) { std::vector tmp(ploidy, ref_); return set_genotype(sample, tmp, Phasing::phased); } -VcfRecord::Builder& VcfRecord::Builder::set_genotype(const SampleName& sample, - std::vector alleles, +VcfRecord::Builder& VcfRecord::Builder::set_genotype(const SampleName& sample, std::vector alleles, Phasing phasing) { genotypes_[sample] = std::make_pair(std::move(alleles), phasing == Phasing::phased); return *this; } -VcfRecord::Builder& VcfRecord::Builder::set_genotype(const SampleName& sample, - const std::vector>& alleles, +VcfRecord::Builder& VcfRecord::Builder::set_genotype(const SampleName& sample, const std::vector>& alleles, Phasing phasing) { std::vector tmp {}; tmp.reserve(alleles.size()); - std::transform(std::cbegin(alleles), std::cend(alleles), std::back_inserter(tmp), [this] (const auto& allele) -> NucleotideSequence { if (allele) { return (*allele == 0) ? 
ref_ : alt_[*allele - 1]; } else { - return "."; + return vcfspec::missingValue; } }); - return set_genotype(sample, tmp, phasing); } -VcfRecord::Builder& VcfRecord::Builder::set_format(const SampleName& sample, - const KeyType& key, - const ValueType& value) +VcfRecord::Builder& VcfRecord::Builder::clear_genotype(const SampleName& sample) noexcept +{ + genotypes_.erase(sample); + return *this; +} + +VcfRecord::Builder& VcfRecord::Builder::set_format(const SampleName& sample, const KeyType& key, const ValueType& value) { return this->set_format(sample, key, std::vector {value}); } -VcfRecord::Builder& VcfRecord::Builder::set_format(const SampleName& sample, - const KeyType& key, - std::vector values) +VcfRecord::Builder& VcfRecord::Builder::set_format(const SampleName& sample, const KeyType& key, std::vector values) { samples_[sample][key] = std::move(values); return *this; } -VcfRecord::Builder& VcfRecord::Builder::set_format(const SampleName& sample, - const KeyType& key, - std::initializer_list values) +VcfRecord::Builder& VcfRecord::Builder::set_format(const SampleName& sample, const KeyType& key, std::initializer_list values) { return this->set_format(sample, key, std::vector {values}); } -VcfRecord::Builder& VcfRecord::Builder::set_format_missing(const SampleName& sample, - const KeyType& key) +VcfRecord::Builder& VcfRecord::Builder::set_format_missing(const SampleName& sample, const KeyType& key) { - return this->set_format(sample, key, std::string {"."}); + return this->set_format(sample, key, std::string {vcfspec::missingValue}); } VcfRecord::Builder& VcfRecord::Builder::clear_format() noexcept { format_.clear(); + samples_.clear(); genotypes_.clear(); return *this; } +VcfRecord::Builder& VcfRecord::Builder::clear_format(const SampleName& sample) noexcept +{ + samples_.erase(sample); + genotypes_.erase(sample); + return *this; +} + +VcfRecord::Builder& VcfRecord::Builder::clear_format(const SampleName& sample, const KeyType& key) noexcept +{ + const auto 
sample_itr = samples_.find(sample); + if (sample_itr != std::cend(samples_)) { + sample_itr->second.erase(key); + } + return *this; +} + +VcfRecord::Builder& VcfRecord::Builder::set_passed(const SampleName& sample) +{ + return this->set_format(sample, vcfspec::format::filter, {vcfspec::filter::pass}); +} + +VcfRecord::Builder& VcfRecord::Builder::set_filter(const SampleName& sample, std::vector filter) +{ + return this->set_format(sample, vcfspec::format::filter, std::move(filter)); +} + +VcfRecord::Builder& VcfRecord::Builder::set_filter(const SampleName& sample, std::initializer_list filter) +{ + return this->set_format(sample, vcfspec::format::filter, std::move(filter)); +} + +VcfRecord::Builder& VcfRecord::Builder::add_filter(const SampleName& sample, KeyType filter) +{ + samples_[sample][vcfspec::format::filter].push_back(std::move(filter)); + return *this; +} + +VcfRecord::Builder& VcfRecord::Builder::clear_filter(const SampleName& sample) noexcept +{ + return this->clear_format(sample, vcfspec::format::filter); +} + VcfRecord::Builder& VcfRecord::Builder::set_refcall() { return set_alt(""); @@ -660,13 +719,9 @@ GenomicRegion::Position VcfRecord::Builder::pos() const noexcept VcfRecord VcfRecord::Builder::build() const { if (genotypes_.empty() && samples_.empty()) { - return VcfRecord { - chrom_, pos_, id_, ref_, alt_, qual_, filter_, info_ - }; + return VcfRecord {chrom_, pos_, id_, ref_, alt_, qual_, filter_, info_}; } else { - return VcfRecord { - chrom_, pos_, id_, ref_, alt_, qual_, filter_, info_, format_, genotypes_, samples_ - }; + return VcfRecord {chrom_, pos_, id_, ref_, alt_, qual_, filter_, info_, format_, genotypes_, samples_}; } } @@ -674,11 +729,11 @@ VcfRecord VcfRecord::Builder::build_once() noexcept { if (genotypes_.empty() && samples_.empty()) { return VcfRecord {std::move(chrom_), pos_, std::move(id_), std::move(ref_), - std::move(alt_), qual_, std::move(filter_), std::move(info_)}; + std::move(alt_), qual_, std::move(filter_), 
std::move(info_)}; } else { return VcfRecord {std::move(chrom_), pos_, std::move(id_), std::move(ref_), - std::move(alt_), qual_, std::move(filter_), std::move(info_), - std::move(format_), std::move(genotypes_), std::move(samples_)}; + std::move(alt_), qual_, std::move(filter_), std::move(info_), + std::move(format_), std::move(genotypes_), std::move(samples_)}; } } diff --git a/src/io/variant/vcf_record.hpp b/src/io/variant/vcf_record.hpp index bfdc910c8..14f87b16b 100644 --- a/src/io/variant/vcf_record.hpp +++ b/src/io/variant/vcf_record.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef vcf_record_hpp @@ -74,7 +74,7 @@ class VcfRecord : public Comparable, public Mappable const std::vector& alt() const noexcept; boost::optional qual() const noexcept; bool has_filter(const KeyType& filter) const noexcept; - const std::vector filter() const noexcept; + const std::vector& filter() const noexcept; bool has_info(const KeyType& key) const noexcept; std::vector info_keys() const; const std::vector& info_value(const KeyType& key) const; @@ -83,7 +83,7 @@ class VcfRecord : public Comparable, public Mappable // Sample releated functions // bool has_format(const KeyType& key) const noexcept; - unsigned format_cardinality(const KeyType& key) const noexcept; + boost::optional format_cardinality(const KeyType& key) const noexcept; const std::vector& format() const noexcept; unsigned num_samples() const noexcept; bool has_genotypes() const noexcept; @@ -133,6 +133,8 @@ std::vector get_genotype(const VcfRecord& record, VcfRecord::NucleotideSequence get_ancestral_allele(const VcfRecord& record); std::vector get_allele_count(const VcfRecord& record); std::vector get_allele_frequency(const VcfRecord& record); + +bool is_filtered(const VcfRecord& record) noexcept; bool is_dbsnp_member(const VcfRecord& record) noexcept; bool 
is_hapmap2_member(const VcfRecord& record) noexcept; bool is_hapmap3_member(const VcfRecord& record) noexcept; @@ -164,18 +166,23 @@ class VcfRecord::Builder Builder& set_chrom(std::string name); Builder& set_pos(GenomicRegion::Position pos); + Builder& set_id(std::string id); + Builder& set_ref(const char allele); Builder& set_ref(NucleotideSequence allele); Builder& set_alt(const char allele); // if just one Builder& set_alt(NucleotideSequence allele); // if just one Builder& set_alt(std::vector alleles); + Builder& set_qual(QualityType quality); + Builder& set_passed(); Builder& set_filter(std::vector filter); Builder& set_filter(std::initializer_list filter); Builder& add_filter(KeyType filter); Builder& clear_filter() noexcept; + Builder& reserve_info(unsigned n); Builder& add_info(const KeyType& key); // flags Builder& set_info(const KeyType& key, const ValueType& value); @@ -183,8 +190,10 @@ class VcfRecord::Builder Builder& set_info(const KeyType& key, std::vector values); Builder& set_info(const KeyType& key, std::initializer_list values); Builder& set_info_flag(KeyType key); + Builder& set_info_missing(const KeyType& key); Builder& clear_info() noexcept; Builder& clear_info(const KeyType& key); + Builder& set_format(std::vector format); Builder& set_format(std::initializer_list format); Builder& add_format(KeyType key); @@ -192,6 +201,7 @@ class VcfRecord::Builder Builder& reserve_samples(unsigned n); Builder& set_genotype(const SampleName& sample, std::vector alleles, Phasing phasing); Builder& set_genotype(const SampleName& sample, const std::vector>& alleles, Phasing is_phased); + Builder& clear_genotype(const SampleName& sample) noexcept; Builder& set_format(const SampleName& sample, const KeyType& key, const ValueType& value); template Builder& set_format(const SampleName& sample, const KeyType& key, const T& value); // calls std::to_string @@ -199,6 +209,13 @@ class VcfRecord::Builder Builder& set_format(const SampleName& sample, const KeyType& key, 
std::initializer_list values); Builder& set_format_missing(const SampleName& sample, const KeyType& key); Builder& clear_format() noexcept; + Builder& clear_format(const SampleName& sample) noexcept; + Builder& clear_format(const SampleName& sample, const KeyType& key) noexcept; + Builder& set_passed(const SampleName& sample); + Builder& set_filter(const SampleName& sample, std::vector filter); + Builder& set_filter(const SampleName& sample, std::initializer_list filter); + Builder& add_filter(const SampleName& sample, KeyType filter); + Builder& clear_filter(const SampleName& sample) noexcept; Builder& set_refcall(); Builder& set_somatic(); @@ -228,10 +245,7 @@ template VcfRecord::VcfRecord(String1&& chrom, GenomicRegion::Position pos, String2&& id, Sequence1&& ref, Sequence2&& alt, boost::optional qual, Filters&& filters, Info&& info) -: region_ { - std::forward(chrom), - pos - 1, - pos + static_cast(utils::length(ref)) - 1} +: region_ {std::forward(chrom), pos - 1, pos + static_cast(utils::length(ref)) - 1} , id_ {std::forward(id)} , ref_ {std::forward(ref)} , alt_ {std::forward(alt)} @@ -248,10 +262,7 @@ typename Filters, typename Info, typename Format, typename Genotypes, typename S VcfRecord::VcfRecord(String1&& chrom, GenomicRegion::Position pos, String2&& id, Sequence1&& ref, Sequence2&& alt, boost::optional qual, Filters&& filters, Info&& info, Format&& format, Genotypes&& genotypes, Samples&& samples) -: region_ { - std::forward(chrom), - pos - 1, - pos + static_cast(utils::length(ref)) - 1} +: region_ {std::forward(chrom), pos - 1, pos + static_cast(utils::length(ref)) - 1} , id_ {std::forward(id)} , ref_ {std::forward(ref)} , alt_ {std::forward(alt)} @@ -266,26 +277,29 @@ VcfRecord::VcfRecord(String1&& chrom, GenomicRegion::Position pos, String2&& id, template VcfRecord::Builder& VcfRecord::Builder::set_info(const KeyType& key, const T& value) { - return set_info(key, std::to_string(value)); + using std::to_string; + return set_info(key, 
to_string(value)); } template -VcfRecord::Builder& VcfRecord::Builder::set_format(const SampleName& sample, const KeyType& key, - const T& value) +VcfRecord::Builder& VcfRecord::Builder::set_format(const SampleName& sample, const KeyType& key, const T& value) { - return set_format(sample, key, std::to_string(value)); + using std::to_string; + return set_format(sample, key, to_string(value)); } } // namespace octopus namespace std { - template <> struct hash + +template <> struct hash +{ + size_t operator()(const octopus::VcfRecord& record) const { - size_t operator()(const octopus::VcfRecord& record) const - { - return hash()(record.id()); - } - }; + return hash()(record.id()); + } +}; + } // namespace std #endif diff --git a/src/io/variant/vcf_spec.hpp b/src/io/variant/vcf_spec.hpp index a4aecf872..2ce0ca941 100644 --- a/src/io/variant/vcf_spec.hpp +++ b/src/io/variant/vcf_spec.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef vcf_spec_hpp @@ -69,6 +69,25 @@ static constexpr std::array order { } // namespace struc +namespace type { + +VCF_SPEC_CONSTANT string {"String"}; +VCF_SPEC_CONSTANT floating {"Float"}; +VCF_SPEC_CONSTANT flag {"Flag"}; +VCF_SPEC_CONSTANT integer {"Integer"}; +VCF_SPEC_CONSTANT character {"Character"}; + +} // namespace type + +namespace number { + +VCF_SPEC_CONSTANT per_alt_allele {"A"}; +VCF_SPEC_CONSTANT per_allele {"R"}; +VCF_SPEC_CONSTANT per_genotype {"G"}; +VCF_SPEC_CONSTANT unknown {"."}; + +} // namespace number + VCF_SPEC_CONSTANT vcfVersion {"fileformat"}; } // namespace meta diff --git a/src/io/variant/vcf_type.cpp b/src/io/variant/vcf_type.cpp index fd7f9f5ab..a1db183be 100644 --- a/src/io/variant/vcf_type.cpp +++ b/src/io/variant/vcf_type.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "vcf_type.hpp" diff --git a/src/io/variant/vcf_type.hpp b/src/io/variant/vcf_type.hpp index 629fc3f27..9494c4d13 100644 --- a/src/io/variant/vcf_type.hpp +++ b/src/io/variant/vcf_type.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef vcf_type_hpp diff --git a/src/io/variant/vcf_utils.cpp b/src/io/variant/vcf_utils.cpp index ae3db91c0..d6521c85b 100644 --- a/src/io/variant/vcf_utils.cpp +++ b/src/io/variant/vcf_utils.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "vcf_utils.hpp" diff --git a/src/io/variant/vcf_utils.hpp b/src/io/variant/vcf_utils.hpp index bcf736964..45c7cee43 100644 --- a/src/io/variant/vcf_utils.hpp +++ b/src/io/variant/vcf_utils.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef vcf_utils_hpp diff --git a/src/io/variant/vcf_writer.cpp b/src/io/variant/vcf_writer.cpp index 96d5518e6..feeb9ff97 100644 --- a/src/io/variant/vcf_writer.cpp +++ b/src/io/variant/vcf_writer.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "vcf_writer.hpp" diff --git a/src/io/variant/vcf_writer.hpp b/src/io/variant/vcf_writer.hpp index 3f9c5e78c..273aab923 100644 --- a/src/io/variant/vcf_writer.hpp +++ b/src/io/variant/vcf_writer.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef vcf_writer_hpp diff --git a/src/logging/error_handler.cpp b/src/logging/error_handler.cpp index b073fded5..9260167ca 100644 --- a/src/logging/error_handler.cpp +++ b/src/logging/error_handler.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "error_handler.hpp" @@ -8,6 +8,7 @@ #include #include +#include "exceptions/system_error.hpp" #include "config/config.hpp" #include "utils/string_utils.hpp" #include "logging.hpp" @@ -97,6 +98,22 @@ void log_error(const Error& error) log_error_help(error, log); } +class BadAlloc : public SystemError +{ + std::string do_where() const override { return "unknown"; } + std::string do_why() const override { return "system could not satisfy memory request"; } + std::string do_help() const override + { + return "ensure the system sufficient resources or submit an error report"; + } +}; + +void log_error(const std::bad_alloc& error) +{ + const BadAlloc e {}; + log_error(e); +} + class UnclassifiedError : public Error { std::string do_type() const override { return "unclassified"; } diff --git a/src/logging/error_handler.hpp b/src/logging/error_handler.hpp index 619ed5d2f..21295f1bf 100644 --- a/src/logging/error_handler.hpp +++ b/src/logging/error_handler.hpp @@ -1,17 +1,18 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef error_handler_hpp #define error_handler_hpp #include +#include #include "exceptions/error.hpp" namespace octopus { void log_error(const Error& error); - +void log_error(const std::bad_alloc& error); void log_error(const std::exception& error); void log_unknown_error(); diff --git a/src/logging/logging.cpp b/src/logging/logging.cpp index 4cab51c5b..8178f37fa 100644 --- a/src/logging/logging.cpp +++ b/src/logging/logging.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "logging.hpp" diff --git a/src/logging/logging.hpp b/src/logging/logging.hpp index 3eef9aedf..3a1c7ebf9 100644 --- a/src/logging/logging.hpp +++ b/src/logging/logging.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef logging_hpp diff --git a/src/logging/main_logging.cpp b/src/logging/main_logging.cpp index a82f2cc69..93ec12172 100644 --- a/src/logging/main_logging.cpp +++ b/src/logging/main_logging.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "main_logging.hpp" diff --git a/src/logging/main_logging.hpp b/src/logging/main_logging.hpp index 96f59bc9c..82ea80ec3 100644 --- a/src/logging/main_logging.hpp +++ b/src/logging/main_logging.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef main_logging_hpp diff --git a/src/logging/progress_meter.cpp b/src/logging/progress_meter.cpp index 87288b9a7..20bf6a918 100644 --- a/src/logging/progress_meter.cpp +++ b/src/logging/progress_meter.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "progress_meter.hpp" @@ -227,14 +227,10 @@ ProgressMeter::~ProgressMeter() if (!done_ && !target_regions_.empty() && num_bp_completed_ > 0) { const TimeInterval duration {start_, std::chrono::system_clock::now()}; const auto time_taken = to_string(duration); - stream(log_) << std::string(position_tab_length_ - 4, ' ') - << "-" - << completed_pad("100%") - << "100%" - << time_taken_pad(time_taken) - << time_taken - << ttc_pad("-") - << "-"; + stream(log_) << std::string(position_tab_length_ - 4, ' ') << "-" + << completed_pad("100%", position_tab_length_ - 3) << "100%" + << time_taken_pad(time_taken) << time_taken + << ttc_pad("-") << "-"; } } @@ -271,14 +267,10 @@ void ProgressMeter::stop() if (!done_ && !target_regions_.empty()) { const TimeInterval duration {start_, std::chrono::system_clock::now()}; const auto time_taken = to_string(duration); - stream(log_) << std::string(position_tab_length_ - 4, ' ') - << "-" - << completed_pad("100%") - << "100%" - << time_taken_pad(time_taken) - << time_taken - << ttc_pad("-") - << "-"; + stream(log_) << std::string(position_tab_length_ - 4, ' ') << "-" + << completed_pad("100%", position_tab_length_ - 3) << "100%" + << time_taken_pad(time_taken) << time_taken + << ttc_pad("-") << "-"; } done_ = true; } @@ -426,14 +418,12 @@ void ProgressMeter::output_log(const GenomicRegion& region) } } const auto percent_completed = percent_completed_str(num_bp_completed_, num_bp_to_search_); - stream(log_) << position_pad(region) - << region.contig_name() << ':' << region.end() - << completed_pad(percent_completed) - << percent_completed - << time_taken_pad(time_taken) - << time_taken - << ttc_pad(ttc) - << ttc; + const auto position_tick = region.contig_name() + ":" + std::to_string(region.end()); + const auto position_str = position_pad(region) + position_tick; + stream(log_) << position_str + << completed_pad(percent_completed, position_str.size()) << percent_completed + << time_taken_pad(time_taken) << time_taken + << 
ttc_pad(ttc) << ttc; tick_durations_.emplace_back(now - last_tick_); last_tick_ = now; percent_until_tick_ = curr_tick_size_; @@ -453,22 +443,24 @@ std::string ProgressMeter::position_pad(const GenomicRegion& completed_region) c return ""; } -std::string ProgressMeter::completed_pad(const std::string& percent_completed) const +std::string ProgressMeter::completed_pad(const std::string& percent_completed, const std::size_t position_tick_size) const { - if (percent_completed.size() >= 17) return {}; - return std::string(std::size_t {17} - percent_completed.size(), ' '); + std::size_t pad {1}; + pad += position_tab_length_ - position_tick_size; + if (percent_completed.size() < 13) pad += 13 - percent_completed.size(); + return std::string(pad, ' '); } std::string ProgressMeter::time_taken_pad(const std::string& time_taken) const { - if (time_taken.size() >= 16) return {}; - return std::string(16 - time_taken.size(), ' '); + if (time_taken.size() >= 17) return {}; + return std::string(17 - time_taken.size(), ' '); } std::string ProgressMeter::ttc_pad(const std::string& ttc) const { - if (ttc.size() >= 16) return {}; - return std::string(16 - ttc.size(), ' '); + if (ttc.size() >= 18) return {}; + return std::string(18 - ttc.size(), ' '); } void ProgressMeter::update_tick_size() diff --git a/src/logging/progress_meter.hpp b/src/logging/progress_meter.hpp index b45e84cb2..54d9f50a2 100644 --- a/src/logging/progress_meter.hpp +++ b/src/logging/progress_meter.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef progress_meter_hpp @@ -69,7 +69,7 @@ class ProgressMeter void output_log(const GenomicRegion& region); std::string position_pad(const GenomicRegion& completed_region) const; - std::string completed_pad(const std::string& percent_completed) const; + std::string completed_pad(const std::string& percent_completed, std::size_t position_tick_size) const; std::string time_taken_pad(const std::string& time_taken) const; std::string ttc_pad(const std::string& ttc) const; void update_tick_size(); diff --git a/src/main.cpp b/src/main.cpp index 1644f20e3..f40bbc095 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include @@ -14,6 +14,7 @@ #include "config/option_collation.hpp" #include "core/octopus.hpp" #include "utils/timing.hpp" +#include "utils/system_utils.hpp" #include "utils/string_utils.hpp" #include "exceptions/error.hpp" #include "logging/error_handler.hpp" @@ -52,6 +53,19 @@ std::string to_string(const int argc, const char** argv) return utils::join(arguements, ' '); } +bool could_exceed_open_file_limit(const OptionMap& options) +{ + return options::estimate_max_open_files(options) >= get_max_open_files(); +} + +void sanity_check(const OptionMap& options) +{ + logging::WarningLogger warn_log {}; + if (could_exceed_open_file_limit(options)) { + warn_log << "Detected potential to exceed open file limit. 
Consult your OS documentation if errors occur!"; + } +} + } // namespace int main(const int argc, const char** argv) @@ -68,13 +82,13 @@ int main(const int argc, const char** argv) log_program_end(); return EXIT_FAILURE; } - if (is_run_command(options)) { try { init_common(options); log_program_startup(); logging::InfoLogger info_log {}; const auto start = std::chrono::system_clock::now(); + sanity_check(options); auto components = collate_genome_calling_components(options); auto end = std::chrono::system_clock::now(); using utils::TimeInterval; @@ -94,6 +108,5 @@ int main(const int argc, const char** argv) return EXIT_FAILURE; } } - return EXIT_SUCCESS; } diff --git a/src/readpipe/buffered_read_pipe.cpp b/src/readpipe/buffered_read_pipe.cpp index ab316a2e9..a463f22d9 100644 --- a/src/readpipe/buffered_read_pipe.cpp +++ b/src/readpipe/buffered_read_pipe.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "buffered_read_pipe.hpp" diff --git a/src/readpipe/buffered_read_pipe.hpp b/src/readpipe/buffered_read_pipe.hpp index c881f28d7..18ea12b37 100644 --- a/src/readpipe/buffered_read_pipe.hpp +++ b/src/readpipe/buffered_read_pipe.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef buffered_read_pipe_hpp diff --git a/src/readpipe/downsampling/downsampler.cpp b/src/readpipe/downsampling/downsampler.cpp index 219f0c16a..d90c8da56 100644 --- a/src/readpipe/downsampling/downsampler.cpp +++ b/src/readpipe/downsampling/downsampler.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "downsampler.hpp" diff --git a/src/readpipe/downsampling/downsampler.hpp b/src/readpipe/downsampling/downsampler.hpp index 488fc1b78..79b6ae666 100644 --- a/src/readpipe/downsampling/downsampler.hpp +++ b/src/readpipe/downsampling/downsampler.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef downsampler_hpp diff --git a/src/readpipe/filtering/read_filter.cpp b/src/readpipe/filtering/read_filter.cpp index 9375003ee..020075b8b 100644 --- a/src/readpipe/filtering/read_filter.cpp +++ b/src/readpipe/filtering/read_filter.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "read_filter.hpp" diff --git a/src/readpipe/filtering/read_filter.hpp b/src/readpipe/filtering/read_filter.hpp index 4db9d8e9c..91758c5c3 100644 --- a/src/readpipe/filtering/read_filter.hpp +++ b/src/readpipe/filtering/read_filter.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef read_filter_hpp diff --git a/src/readpipe/filtering/read_filterer.hpp b/src/readpipe/filtering/read_filterer.hpp index f06cba9a0..4a76cbed6 100644 --- a/src/readpipe/filtering/read_filterer.hpp +++ b/src/readpipe/filtering/read_filterer.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef read_filterer_hpp diff --git a/src/readpipe/read_pipe.cpp b/src/readpipe/read_pipe.cpp index d8ab47f54..0798c5c1b 100644 --- a/src/readpipe/read_pipe.cpp +++ b/src/readpipe/read_pipe.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "read_pipe.hpp" diff --git a/src/readpipe/read_pipe.hpp b/src/readpipe/read_pipe.hpp index b6af97d9c..176f6cd5c 100644 --- a/src/readpipe/read_pipe.hpp +++ b/src/readpipe/read_pipe.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef read_pipe_hpp diff --git a/src/readpipe/read_pipe_fwd.hpp b/src/readpipe/read_pipe_fwd.hpp index a3a2ff658..18433cdb8 100644 --- a/src/readpipe/read_pipe_fwd.hpp +++ b/src/readpipe/read_pipe_fwd.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef read_pipe_fwd_hpp diff --git a/src/readpipe/transformers/read_transform.cpp b/src/readpipe/transformers/read_transform.cpp index 20b53ce93..eb8fc3be6 100644 --- a/src/readpipe/transformers/read_transform.cpp +++ b/src/readpipe/transformers/read_transform.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "read_transform.hpp" diff --git a/src/readpipe/transformers/read_transform.hpp b/src/readpipe/transformers/read_transform.hpp index 806c3e4a7..fd532683e 100644 --- a/src/readpipe/transformers/read_transform.hpp +++ b/src/readpipe/transformers/read_transform.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef read_transform_hpp diff --git a/src/readpipe/transformers/read_transformer.cpp b/src/readpipe/transformers/read_transformer.cpp index d97e6ff4b..00cbc1c39 100644 --- a/src/readpipe/transformers/read_transformer.cpp +++ b/src/readpipe/transformers/read_transformer.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "read_transformer.hpp" diff --git a/src/readpipe/transformers/read_transformer.hpp b/src/readpipe/transformers/read_transformer.hpp index 5e07dc86e..151c12340 100644 --- a/src/readpipe/transformers/read_transformer.hpp +++ b/src/readpipe/transformers/read_transformer.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef read_transformer_hpp diff --git a/src/timers.cpp b/src/timers.cpp index 707ccaa05..28a4d412f 100644 --- a/src/timers.cpp +++ b/src/timers.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "timers.hpp" diff --git a/src/timers.hpp b/src/timers.hpp index 1509c0d00..477d25ce9 100644 --- a/src/timers.hpp +++ b/src/timers.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef timers_hpp diff --git a/src/utils/append.hpp b/src/utils/append.hpp index 873f364a5..54d2e48cc 100644 --- a/src/utils/append.hpp +++ b/src/utils/append.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef append_hpp diff --git a/src/utils/beta_distribution.hpp b/src/utils/beta_distribution.hpp index 782b28844..c7950b410 100644 --- a/src/utils/beta_distribution.hpp +++ b/src/utils/beta_distribution.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. // Based of https://stackoverflow.com/a/15166623/2970186 diff --git a/src/utils/compression.cpp b/src/utils/compression.cpp index c05b7d204..a7843cef3 100644 --- a/src/utils/compression.cpp +++ b/src/utils/compression.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "compression.hpp" diff --git a/src/utils/compression.hpp b/src/utils/compression.hpp index 88bf8943d..d06e5c639 100644 --- a/src/utils/compression.hpp +++ b/src/utils/compression.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef compression_hpp diff --git a/src/utils/concat.hpp b/src/utils/concat.hpp new file mode 100644 index 000000000..8d896938f --- /dev/null +++ b/src/utils/concat.hpp @@ -0,0 +1,44 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#ifndef concat_hpp +#define concat_hpp + +#include +#include + +template +std::vector concat(const std::vector& lhs, const std::vector& rhs) +{ + if (lhs.empty()) return rhs; + if (rhs.empty()) return lhs; + std::vector result {}; + result.reserve(lhs.size() + rhs.size()); + result.insert(result.cend(), lhs.cbegin(), lhs.cend()); + result.insert(result.cend(), rhs.cbegin(), rhs.cend()); + return result; +} + +template +std::vector concat(std::vector&& lhs, const std::vector& rhs) +{ + lhs.insert(lhs.cend(), rhs.cbegin(), rhs.cend()); + return std::move(lhs); +} + +template +std::vector concat(const std::vector& lhs, std::vector&& rhs) +{ + rhs.insert(rhs.cbegin(), lhs.cbegin(), lhs.cend()); + return std::move(rhs); +} + +template +std::vector concat(std::vector&& lhs, std::vector&& rhs) +{ + if (lhs.empty()) return std::move(rhs); + lhs.insert(lhs.cend(), std::make_move_iterator(rhs.begin()), std::make_move_iterator(rhs.end())); + return std::move(lhs); +} + +#endif diff --git a/src/utils/coverage_tracker.hpp b/src/utils/coverage_tracker.hpp index 426d18039..7a069e78a 100644 --- a/src/utils/coverage_tracker.hpp +++ b/src/utils/coverage_tracker.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef coverage_tracker_hpp @@ -25,10 +25,13 @@ namespace octopus { CoverageTracker provides an efficient method for tracking coverage statistics over a range of Mappable objects without having to store the entire collection. 
*/ -template +template class CoverageTracker { public: + using RegionType = Region; + using DepthType = T; + CoverageTracker() = default; CoverageTracker(const CoverageTracker&) = default; @@ -41,78 +44,116 @@ class CoverageTracker template void add(const MappableType& mappable); - std::size_t total_coverage() const noexcept; - std::size_t total_coverage(const Region& region) const noexcept; + bool any() const noexcept; + bool any(const Region& region) const noexcept; - unsigned max_coverage() const noexcept; - unsigned max_coverage(const Region& region) const noexcept; + std::size_t sum() const noexcept; + std::size_t sum(const Region& region) const noexcept; - unsigned min_coverage() const noexcept; - unsigned min_coverage(const Region& region) const noexcept; + DepthType max() const noexcept; + DepthType max(const Region& region) const noexcept; - double mean_coverage() const noexcept; - double mean_coverage(const Region& region) const noexcept; + DepthType min() const noexcept; + DepthType min(const Region& region) const noexcept; - double stdev_coverage() const noexcept; - double stdev_coverage(const Region& region) const noexcept; + double mean() const noexcept; + double mean(const Region& region) const noexcept; - double median_coverage(const Region& region) const; + double stdev() const noexcept; + double stdev(const Region& region) const noexcept; - std::vector coverage(const Region& region) const; + double median() const; + double median(const Region& region) const; - boost::optional encompassing_region() const; + template + OutputIt get(const Region& region, OutputIt result) const; + std::vector get(const Region& region) const; + boost::optional encompassing_region() const; bool is_empty() const noexcept; - std::size_t num_tracked() const noexcept; - void clear() noexcept; private: - std::deque coverage_ = {}; + std::deque coverage_ = {}; Region encompassing_region_; std::size_t num_tracked_ = 0; - using Iterator = typename 
decltype(coverage_)::const_iterator; + using Iterator = typename decltype(coverage_)::const_iterator; + using IteratorPair = std::pair; void do_add(const Region& region); - std::pair range(const Region& region) const; + IteratorPair range(const Region& region) const; }; -// public methods +// non-member methods template +std::vector get_covered_regions(const CoverageTracker& tracker, const Region& region) +{ + const auto depths = tracker.coverage(region); + return select_regions(region, depths, [] (unsigned depth) { return depth > 0; }); +} + +template +std::vector get_covered_regions(const CoverageTracker& tracker) +{ + const auto tracker_region = tracker.encompassing_region(); + if (tracker_region) { + return get_covered_regions(tracker, *tracker_region); + } else { + return {}; + } +} + +// public methods + +template template -void CoverageTracker::add(const MappableType& mappable) +void CoverageTracker::add(const MappableType& mappable) { static_assert(is_region_or_mappable, "MappableType not Mappable"); do_add(mapped_region(mappable)); } -template -std::size_t CoverageTracker::total_coverage() const noexcept +template +bool CoverageTracker::any() const noexcept +{ + return std::find_if(std::cbegin(coverage_), std::cend(coverage_), + [] (auto depth) noexcept { return depth > 0; }) != std::cend(coverage_); +} + +template +bool CoverageTracker::any(const Region& region) const noexcept +{ + if (octopus::is_empty(region)) return false; + const auto p = range(region); + return std::find_if(p.first, p.second, [] (auto depth) noexcept { return depth > 0; }) != p.second; +} + +template +std::size_t CoverageTracker::sum() const noexcept { return std::accumulate(std::cbegin(coverage_), std::cend(coverage_), std::size_t {0}); } -template -std::size_t CoverageTracker::total_coverage(const Region& region) const noexcept +template +std::size_t CoverageTracker::sum(const Region& region) const noexcept { if (octopus::is_empty(region)) return 0; const auto p = range(region); - 
if (p.first == p.second) return 0; return std::accumulate(p.first, p.second, std::size_t {0}); } -template -unsigned CoverageTracker::max_coverage() const noexcept +template +T CoverageTracker::max() const noexcept { if (coverage_.empty()) return 0; return *std::max_element(std::cbegin(coverage_), std::cend(coverage_)); } -template -unsigned CoverageTracker::max_coverage(const Region& region) const noexcept +template +T CoverageTracker::max(const Region& region) const noexcept { if (octopus::is_empty(region)) return 0; const auto p = range(region); @@ -120,15 +161,15 @@ unsigned CoverageTracker::max_coverage(const Region& region) const noexc return *std::max_element(p.first, p.second); } -template -unsigned CoverageTracker::min_coverage() const noexcept +template +T CoverageTracker::min() const noexcept { if (coverage_.empty()) return 0; return *std::min_element(std::cbegin(coverage_), std::cend(coverage_)); } -template -unsigned CoverageTracker::min_coverage(const Region& region) const noexcept +template +T CoverageTracker::min(const Region& region) const noexcept { if (octopus::is_empty(region)) return 0; const auto p = range(region); @@ -136,63 +177,81 @@ unsigned CoverageTracker::min_coverage(const Region& region) const noexc return *std::min_element(p.first, p.second); } -template -double CoverageTracker::mean_coverage() const noexcept +template +double CoverageTracker::mean() const noexcept { if (coverage_.empty()) return 0; return maths::mean(coverage_); } -template -double CoverageTracker::mean_coverage(const Region& region) const noexcept +template +double CoverageTracker::mean(const Region& region) const noexcept { if (octopus::is_empty(region)) return 0; const auto p = range(region); return maths::mean(p.first, p.second); } -template -double CoverageTracker::stdev_coverage() const noexcept +template +double CoverageTracker::stdev() const noexcept { if (coverage_.empty()) return 0; return maths::stdev(coverage_); } -template -double 
CoverageTracker::stdev_coverage(const Region& region) const noexcept +template +double CoverageTracker::stdev(const Region& region) const noexcept { if (octopus::is_empty(region)) return 0; const auto p = range(region); return maths::stdev(p.first, p.second); } -template -double CoverageTracker::median_coverage(const Region& region) const +template +double CoverageTracker::median() const +{ + if (coverage_.empty()) return 0; + return maths::median(coverage_); +} + +template +double CoverageTracker::median(const Region& region) const { - auto range_coverage = coverage(region); + auto range_coverage = this->get(region); if (range_coverage.empty()) return 0; - const auto first = std::begin(range_coverage); - const auto nth = std::next(first, range_coverage.size() / 2); - std::nth_element(first, nth, std::end(range_coverage)); - return *nth; + return maths::median(range_coverage); } -template -std::vector CoverageTracker::coverage(const Region& region) const +template +template +OutputIt CoverageTracker::get(const Region& region, OutputIt result) const { - if (coverage_.empty()) return std::vector(size(region), 0); + if (coverage_.empty()) { + return std::fill_n(result, size(region), 0); + } const auto p = range(region); - if (!contains(encompassing_region_, region)) { - std::vector result(size(region), 0); - const auto d = std::max(begin_distance(region, encompassing_region_), GenomicRegion::Distance {0}); - std::copy(p.first, p.second, std::next(std::begin(result), d)); - return result; + if (contains(encompassing_region_, region)) { + return std::copy(p.first, p.second, result); + } else { + using D = typename Region::Distance; + const auto lhs_pad = std::max(begin_distance(region, encompassing_region_), D {0}); + result = std::fill_n(result, lhs_pad, 0); + result = std::copy(p.first, p.second, result); + const auto rhs_pad = std::max(end_distance(encompassing_region_, region), D {0}); + return std::fill_n(result, rhs_pad, 0); } - return std::vector {p.first, 
p.second}; } -template -boost::optional CoverageTracker::encompassing_region() const +template +std::vector CoverageTracker::get(const Region& region) const +{ + std::vector result(size(region)); + this->get(region, std::begin(result)); + return result; +} + +template +boost::optional CoverageTracker::encompassing_region() const { if (!is_empty()) { return encompassing_region_; @@ -201,20 +260,20 @@ boost::optional CoverageTracker::encompassing_region() const } } -template -bool CoverageTracker::is_empty() const noexcept +template +bool CoverageTracker::is_empty() const noexcept { return num_tracked_ == 0; } -template -std::size_t CoverageTracker::num_tracked() const noexcept +template +std::size_t CoverageTracker::num_tracked() const noexcept { return num_tracked_; } -template -void CoverageTracker::clear() noexcept +template +void CoverageTracker::clear() noexcept { coverage_.clear(); coverage_.shrink_to_fit(); @@ -237,8 +296,8 @@ inline bool is_same_contig_helper(const GenomicRegion& lhs, const GenomicRegion& } // namespace detail -template -void CoverageTracker::do_add(const Region& region) +template +void CoverageTracker::do_add(const Region& region) { if (octopus::is_empty(region)) return; if (num_tracked_ == 0) { @@ -263,23 +322,22 @@ void CoverageTracker::do_add(const Region& region) const auto first = std::next(std::begin(coverage_), begin_distance(encompassing_region_, region)); assert(first < std::end(coverage_)); assert(std::next(first, size(region)) <= std::end(coverage_)); - std::transform(first, std::next(first, size(region)), first, [] (auto count) { return count + 1; }); + std::transform(first, std::next(first, size(region)), first, [] (auto count) noexcept { return count + 1; }); } ++num_tracked_; } -template -std::pair::Iterator, typename CoverageTracker::Iterator> -CoverageTracker::range(const Region& region) const +template +typename CoverageTracker::IteratorPair CoverageTracker::range(const Region& region) const { if (coverage_.empty() || 
!overlaps(region, encompassing_region_)) { return {std::end(coverage_), std::end(coverage_)}; } - auto first = std::begin(coverage_); + auto range_start_itr = std::begin(coverage_); if (begins_before(encompassing_region_, region)) { - std::advance(first, begin_distance(encompassing_region_, region)); + std::advance(range_start_itr, begin_distance(encompassing_region_, region)); } - return {first, std::next(first, overlap_size(region, encompassing_region_))}; + return {range_start_itr, std::next(range_start_itr, overlap_size(region, encompassing_region_))}; } } // namespace octopus diff --git a/src/utils/emplace_iterator.hpp b/src/utils/emplace_iterator.hpp index b6fb13bca..b2f15608f 100644 --- a/src/utils/emplace_iterator.hpp +++ b/src/utils/emplace_iterator.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef emplace_iterator_hpp diff --git a/src/utils/genotype_reader.cpp b/src/utils/genotype_reader.cpp index 82e66979e..ff44781b7 100644 --- a/src/utils/genotype_reader.cpp +++ b/src/utils/genotype_reader.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#include "genotype_reader.hpp" @@ -88,7 +88,26 @@ bool has_indel(const VcfRecord& call) noexcept [&] (const auto& allele) { return allele.size() != call.ref().size(); }); } -boost::optional make_allele(const VcfRecord& call, VcfRecord::NucleotideSequence allele_sequence, const int ref_pad) +bool has_simple_indel(const VcfRecord& call) noexcept +{ + return std::any_of(std::cbegin(call.alt()), std::cend(call.alt()), + [&] (const auto& allele) { + return allele.size() != call.ref().size() && (allele.size() == 1 || call.ref().size() == 1); + }); +} + +bool has_non_complex_indel(const VcfRecord& call) noexcept +{ + assert(!call.ref().empty()); + return std::any_of(std::cbegin(call.alt()), std::cend(call.alt()), + [&] (const auto& allele) { + assert(!allele.empty()); + return allele.size() != call.ref().size() && allele.front() == call.ref().front(); + }); +} + +boost::optional +make_allele(const VcfRecord& call, VcfRecord::NucleotideSequence allele_sequence, const int max_ref_pad) { if (is_missing(allele_sequence)) { return boost::none; @@ -107,10 +126,11 @@ boost::optional make_allele(const VcfRecord& call, VcfRecord::Nucl auto delete_mask_len = std::distance(std::cbegin(allele_sequence), first_base_itr); allele_sequence.erase(std::cbegin(allele_sequence), first_base_itr); region = expand_lhs(region, -delete_mask_len); - } else if (ref_pad > 0) { - assert(static_cast(ref_pad) <= allele_sequence.size()); - allele_sequence.erase(std::cbegin(allele_sequence), std::next(std::cbegin(allele_sequence), ref_pad)); - region = expand_lhs(region, -ref_pad); + } else if (max_ref_pad > 0) { + auto p = std::mismatch(std::cbegin(call.ref()), std::next(std::cbegin(call.ref()), max_ref_pad), + std::cbegin(allele_sequence), std::cend(allele_sequence)); + allele_sequence.erase(std::cbegin(allele_sequence), p.second); + region = expand_lhs(region, std::distance(p.second, std::cbegin(allele_sequence))); } return ContigAllele {region, std::move(allele_sequence)}; } @@ -118,7 +138,7 @@ 
boost::optional make_allele(const VcfRecord& call, VcfRecord::Nucl auto extract_genotype(const VcfRecord& call, const SampleName& sample) { auto genotype = get_genotype(call, sample); - boost::optional min_ref_pad {}; + boost::optional max_ref_pad {}; std::vector unknown_pad_indices {}; const auto ploidy = genotype.size(); std::vector> result(ploidy, boost::none); @@ -126,21 +146,21 @@ auto extract_genotype(const VcfRecord& call, const SampleName& sample) auto& allele = genotype[i]; if (is_ref_pad_size_known(allele, call)) { const auto allele_pad = num_matching_lhs_bases(call.ref(), allele); - if (min_ref_pad) { - min_ref_pad = std::min(*min_ref_pad, allele_pad); + if (max_ref_pad) { + max_ref_pad = std::max(*max_ref_pad, allele_pad); } else { - min_ref_pad = allele_pad; + max_ref_pad = allele_pad; } result[i] = make_allele(call, std::move(allele), allele_pad); } else { unknown_pad_indices.push_back(i); } } - if (!min_ref_pad && has_indel(call)) { - min_ref_pad = 1; + if (!max_ref_pad) { + max_ref_pad = has_non_complex_indel(call) ? 
1 : 0; } for (auto idx : unknown_pad_indices) { - result[idx] = make_allele(call, std::move(genotype[idx]), *min_ref_pad); + result[idx] = make_allele(call, std::move(genotype[idx]), *max_ref_pad); } return result; } @@ -170,7 +190,7 @@ get_called_alleles(const VcfRecord& call, const VcfRecord::SampleName& sample, c has_ref = true; } std::vector unknwown_pad_allele_indices {}; - boost::optional min_ref_pad {}; + boost::optional max_ref_pad {}; auto allele_idx = std::distance(std::begin(genotype), first_itr); std::for_each(first_itr, std::end(genotype), [&] (auto& allele) { if (is_ref_pad_size_known(allele, call)) { @@ -178,31 +198,33 @@ get_called_alleles(const VcfRecord& call, const VcfRecord::SampleName& sample, c allele.erase(std::cbegin(allele), std::next(std::cbegin(allele), pad_size)); auto allele_region = expand_lhs(call_region, -pad_size); result.emplace_back(std::move(allele_region), std::move(allele)); - if (min_ref_pad) { - min_ref_pad = std::min(*min_ref_pad, pad_size); + if (max_ref_pad) { + max_ref_pad = std::max(*max_ref_pad, pad_size); } else { - min_ref_pad = pad_size; + max_ref_pad = pad_size; } } else { unknwown_pad_allele_indices.push_back(allele_idx); } ++allele_idx; }); - if (!min_ref_pad && has_indel(call)) { - min_ref_pad = 1; + if (!max_ref_pad) { + max_ref_pad = has_non_complex_indel(call) ? 
1 : 0; } if (has_ref) { auto& ref = genotype.front(); - ref.erase(std::cbegin(ref), std::next(std::cbegin(ref), *min_ref_pad)); - auto allele_region = expand_lhs(call_region, -*min_ref_pad); + ref.erase(std::cbegin(ref), std::next(std::cbegin(ref), *max_ref_pad)); + auto allele_region = expand_lhs(call_region, -*max_ref_pad); result.emplace_back(std::move(allele_region), std::move(ref)); std::rotate(std::rbegin(result), std::next(std::rbegin(result)), std::rend(result)); } - if (*min_ref_pad > 0 && !unknwown_pad_allele_indices.empty()) { + if (!unknwown_pad_allele_indices.empty()) { for (auto idx : unknwown_pad_allele_indices) { auto& allele = genotype[idx]; - allele.erase(std::cbegin(allele), std::next(std::cbegin(allele), *min_ref_pad)); - auto allele_region = expand_lhs(call_region, -*min_ref_pad); + auto p = std::mismatch(std::cbegin(call.ref()), std::next(std::cbegin(call.ref()), *max_ref_pad), + std::cbegin(allele), std::cend(allele)); + allele.erase(std::cbegin(allele), p.second); + auto allele_region = expand_lhs(call_region, std::distance(p.second, std::cbegin(allele))); result.emplace_back(std::move(allele_region), std::move(allele)); } auto alt_alleles_begin_itr = std::begin(result); @@ -240,6 +262,7 @@ struct CallWrapper : public Mappable std::reference_wrapper call; GenomicRegion phase_region; const GenomicRegion& mapped_region() const noexcept { return phase_region; } + const VcfRecord& get() const noexcept { return call.get(); } }; auto wrap_calls(const std::vector& calls, const SampleName& sample) @@ -252,10 +275,13 @@ auto wrap_calls(const std::vector& calls, const SampleName& sample) return result; } -auto get_ploidy(const std::vector& phased_calls, const SampleName& sample) +auto get_max_ploidy(const std::vector& calls, const SampleName& sample) { - assert(!phased_calls.empty()); - return get_genotype(phased_calls.front().call, sample).size(); + unsigned result {0}; + for (const auto& call : calls) { + result = std::max(result, 
call.get().ploidy(sample)); + } + return result; } auto make_genotype(std::vector&& haplotypes) @@ -274,12 +300,12 @@ Genotype extract_genotype(const std::vector& phased_call { assert(!phased_calls.empty()); assert(contains(region, encompassing_region(phased_calls))); - const auto ploidy = get_ploidy(phased_calls, sample); - std::vector haplotypes(ploidy, Haplotype::Builder {region, reference}); + const auto max_ploidy = get_max_ploidy(phased_calls, sample); + std::vector haplotypes(max_ploidy, Haplotype::Builder {region, reference}); for (const auto& call : phased_calls) { auto genotype = extract_genotype(call.call, sample); - assert(genotype.size() == ploidy); - for (unsigned i {0}; i < ploidy; ++i) { + assert(genotype.size() <= max_ploidy); + for (unsigned i {0}; i < genotype.size(); ++i) { if (genotype[i] && haplotypes[i].can_push_back(*genotype[i])) { haplotypes[i].push_back(std::move(*genotype[i])); } diff --git a/src/utils/genotype_reader.hpp b/src/utils/genotype_reader.hpp index e29950344..2d1f293f3 100644 --- a/src/utils/genotype_reader.hpp +++ b/src/utils/genotype_reader.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef genotype_reader_hpp diff --git a/src/utils/hash_functions.hpp b/src/utils/hash_functions.hpp index fc564bfa9..772113019 100644 --- a/src/utils/hash_functions.hpp +++ b/src/utils/hash_functions.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef hash_functions_hpp diff --git a/src/utils/input_reads_profiler.cpp b/src/utils/input_reads_profiler.cpp new file mode 100644 index 000000000..19cc9a772 --- /dev/null +++ b/src/utils/input_reads_profiler.cpp @@ -0,0 +1,223 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#include "input_reads_profiler.hpp" + +#include +#include +#include +#include +#include +#include + +#include "mappable_algorithms.hpp" +#include "maths.hpp" +#include "append.hpp" + +namespace octopus { + +namespace { + +template +ForwardIt random_select(ForwardIt first, ForwardIt last, RandomGenerator& g) +{ + if (first == last) return first; + const auto max = static_cast(std::distance(first, last)); + std::uniform_int_distribution dist {0, max - 1}; + std::advance(first, dist(g)); + return first; +} + +template +ForwardIt random_select(ForwardIt first, ForwardIt last) +{ + static std::default_random_engine gen {}; + return random_select(first, last, gen); +} + +auto estimate_dynamic_size(const AlignedRead& read) noexcept +{ + return read.name().size() * sizeof(char) + + read.read_group().size() * sizeof(char) + + sequence_size(read) * sizeof(char) + + sequence_size(read) * sizeof(AlignedRead::BaseQuality) + + read.cigar().size() * sizeof(CigarOperation) + + contig_name(read).size() * sizeof(char) + + (read.has_other_segment() ? 
sizeof(AlignedRead::Segment) : 0); +} + +auto estimate_read_size(const AlignedRead& read) noexcept +{ + return sizeof(AlignedRead) + estimate_dynamic_size(read); +} + +auto get_covered_sample_regions(const std::vector& samples, const InputRegionMap& input_regions, + const ReadManager& read_manager) +{ + InputRegionMap result {}; + result.reserve(input_regions.size()); + for (const auto& p : input_regions) { + InputRegionMap::mapped_type contig_regions {}; + std::copy_if(std::cbegin(p.second), std::cend(p.second), + std::inserter(contig_regions, std::begin(contig_regions)), + [&] (const auto& region) { return read_manager.has_reads(samples, region); }); + if (!contig_regions.empty()) { + result.emplace(p.first, std::move(contig_regions)); + } + } + return result; +} + +auto choose_sample_region(const GenomicRegion& from, GenomicRegion::Size max_size) +{ + if (size(from) <= max_size) return from; + const auto max_begin = from.end() - max_size; + static std::default_random_engine gen {}; + std::uniform_int_distribution dist {from.begin(), max_begin}; + return GenomicRegion {from.contig_name(), dist(gen), from.end()}; +} + +auto draw_sample(const SampleName& sample, const InputRegionMap& regions, + const ReadManager& source, const ReadSetProfileConfig& config) +{ + const auto contig_itr = random_select(std::cbegin(regions), std::cend(regions)); + assert(!contig_itr->second.empty()); + const auto region_itr = random_select(std::cbegin(contig_itr->second), std::cend(contig_itr->second)); + const auto sample_region = choose_sample_region(*region_itr, config.max_sample_size); + auto test_region = source.find_covered_subregion(sample, sample_region, config.max_sample_size); + if (is_empty(test_region)) { + test_region = expand_rhs(test_region, 1); + } + return source.fetch_reads(sample, test_region); +} + +auto draw_sample_from_begin(const SampleName& sample, const InputRegionMap& regions, + const ReadManager& source, const ReadSetProfileConfig& config) +{ + const auto 
contig_itr = random_select(std::cbegin(regions), std::cend(regions)); + assert(!contig_itr->second.empty()); + const auto region_itr = random_select(std::cbegin(contig_itr->second), std::cend(contig_itr->second)); + auto test_region = source.find_covered_subregion(sample, *region_itr, config.max_sample_size); + if (is_empty(test_region)) { + test_region = expand_rhs(test_region, 1); + } + return source.fetch_reads(sample, test_region); +} + +using ReadSetSamples = std::vector; + +bool all_empty(const ReadSetSamples& samples) +{ + return std::all_of(std::cbegin(samples), std::cend(samples), [] (const auto& reads) { return reads.empty(); }); +} + +auto draw_samples(const SampleName& sample, const InputRegionMap& regions, + const ReadManager& source, const ReadSetProfileConfig& config) +{ + ReadSetSamples result {}; + result.reserve(config.max_samples_per_sample); + std::generate_n(std::back_inserter(result), config.max_samples_per_sample, + [&] () { return draw_sample(sample, regions, source, config); }); + if (all_empty(result)) { + result.back() = draw_sample_from_begin(sample, regions, source, config); + } + return result; +} + +auto draw_samples(const std::vector& samples, const InputRegionMap& regions, + const ReadManager& source, const ReadSetProfileConfig& config) +{ + std::vector result {}; + result.reserve(samples.size()); + for (const auto& sample : samples) { + result.push_back(draw_samples(sample, regions, source, config)); + } + return result; +} + +auto get_read_bytes(const std::vector& read_sets) +{ + std::deque result {}; + for (const auto& set : read_sets) { + for (const auto& reads : set) { + std::transform(std::cbegin(reads), std::cend(reads), std::back_inserter(result), estimate_read_size); + } + } + return result; +} + +} // namespace + +boost::optional profile_reads(const std::vector& samples, + const InputRegionMap& input_regions, + const ReadManager& source, + ReadSetProfileConfig config) +{ + if (input_regions.empty()) return boost::none; + 
const auto sampling_regions = get_covered_sample_regions(samples, input_regions, source); + if (sampling_regions.empty()) return boost::none; + const auto read_sets = draw_samples(samples, sampling_regions, source, config); + if (read_sets.empty()) return boost::none; + const auto bytes = get_read_bytes(read_sets); + if (bytes.empty()) return boost::none; + ReadSetProfile result {}; + result.mean_read_bytes = maths::mean(bytes); + result.read_bytes_stdev = maths::stdev(bytes); + result.sample_mean_depth.resize(samples.size()); + result.sample_depth_stdev.resize(samples.size()); + std::deque depths {}; + for (std::size_t s {0}; s < samples.size(); ++s) { + std::deque sample_depths {}; + for (const auto& reads : read_sets[s]) { + if (!reads.empty()) { + utils::append(calculate_positional_coverage(reads), sample_depths); + } + } + if (!sample_depths.empty()) { + result.sample_mean_depth[s] = maths::mean(sample_depths); + result.sample_depth_stdev[s] = maths::stdev(sample_depths); + } else { + result.sample_mean_depth[s] = 0; + result.sample_depth_stdev[s] = 0; + } + utils::append(std::move(sample_depths), depths); + } + assert(!depths.empty()); + result.mean_depth = maths::mean(depths); + result.depth_stdev = maths::stdev(depths); + return result; +} + +boost::optional estimate_mean_read_size(const std::vector& samples, + const InputRegionMap& input_regions, + ReadManager& read_manager, + const unsigned max_sample_size) +{ + if (input_regions.empty()) return boost::none; + const auto sample_regions = get_covered_sample_regions(samples, input_regions, read_manager); + if (sample_regions.empty()) return boost::none; + const auto num_samples_per_sample = max_sample_size / samples.size(); + std::deque read_size_samples {}; + // take read samples from each sample seperatly to ensure we cover each + for (const auto& sample : samples) { + const auto it = random_select(std::cbegin(sample_regions), std::cend(sample_regions)); + assert(!it->second.empty()); + const auto it2 = 
random_select(std::cbegin(it->second), std::cend(it->second)); + auto test_region = read_manager.find_covered_subregion(sample, *it2, num_samples_per_sample); + if (is_empty(test_region)) { + test_region = expand_rhs(test_region, 1); + } + const auto reads = read_manager.fetch_reads(sample, test_region); + std::transform(std::cbegin(reads), std::cend(reads), std::back_inserter(read_size_samples), + estimate_read_size); + } + if (read_size_samples.empty()) return boost::none; + return static_cast(maths::mean(read_size_samples) + maths::stdev(read_size_samples)); +} + +std::size_t default_read_size_estimate() noexcept +{ + return sizeof(AlignedRead) + 300; +} + +} // namespace octopus diff --git a/src/utils/input_reads_profiler.hpp b/src/utils/input_reads_profiler.hpp new file mode 100644 index 000000000..bf8926898 --- /dev/null +++ b/src/utils/input_reads_profiler.hpp @@ -0,0 +1,46 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#ifndef input_reads_profiler_hpp +#define input_reads_profiler_hpp + +#include +#include + +#include + +#include "config/common.hpp" +#include "io/read/read_manager.hpp" +#include "readpipe/read_pipe.hpp" + +namespace octopus { + +struct ReadSetProfileConfig +{ + unsigned max_samples_per_sample = 10; + unsigned max_sample_size = 1000; +}; + +struct ReadSetProfile +{ + std::size_t mean_read_bytes, read_bytes_stdev; + std::size_t mean_depth, depth_stdev; + std::vector sample_mean_depth; + std::vector sample_depth_stdev; +}; + +boost::optional profile_reads(const std::vector& samples, + const InputRegionMap& input_regions, + const ReadManager& source, + ReadSetProfileConfig config = ReadSetProfileConfig {}); + +boost::optional estimate_mean_read_size(const std::vector& samples, + const InputRegionMap& input_regions, + ReadManager& read_manager, + unsigned max_sample_size = 1000); + +std::size_t default_read_size_estimate() noexcept; + +} // namespace octopus + +#endif diff --git a/src/utils/kmer_mapper.cpp b/src/utils/kmer_mapper.cpp index 7a9e4d90a..71f0db4b0 100644 --- a/src/utils/kmer_mapper.cpp +++ b/src/utils/kmer_mapper.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "kmer_mapper.hpp" diff --git a/src/utils/kmer_mapper.hpp b/src/utils/kmer_mapper.hpp index d5579d399..550c966bb 100644 --- a/src/utils/kmer_mapper.hpp +++ b/src/utils/kmer_mapper.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef kmer_mapping_hpp diff --git a/src/utils/map_utils.hpp b/src/utils/map_utils.hpp index ca3793d48..589d15958 100644 --- a/src/utils/map_utils.hpp +++ b/src/utils/map_utils.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef map_utils_hpp @@ -45,25 +45,35 @@ auto extract_values(const MapType& map) return result; } +template +auto extract_sorted_keys(const std::map& map) +{ + return extract_keys(map); +} + +template +auto extract_sorted_keys(const MapType& map) +{ + static_assert(is_map, "MapType must be a map type"); + auto result = extract_keys(map); + std::sort(std::begin(result), std::end(result)); + return result; +} + template auto extract_value_sorted_keys(const MapType& map) { static_assert(is_map, "MapType must be a map type"); - std::vector> pairs {}; pairs.reserve(map.size()); - std::transform(std::cbegin(map), std::cend(map), std::back_inserter(pairs), [] (const auto& p) { return std::make_pair(p.second, p.first); }); std::sort(std::begin(pairs), std::end(pairs), [] (const auto& lhs, const auto& rhs) { return lhs.first > rhs.first; }); - std::vector result {}; result.reserve(pairs.size()); - std::transform(std::cbegin(pairs), std::cend(pairs), std::back_inserter(result), [] (const auto& p) { return p.second; }); - return result; } diff --git a/src/utils/mappable_algorithms.hpp b/src/utils/mappable_algorithms.hpp index 6305e3f82..4c22769fb 100644 --- a/src/utils/mappable_algorithms.hpp +++ b/src/utils/mappable_algorithms.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef mappable_algorithms_hpp @@ -1675,10 +1675,12 @@ template auto extract_intervening_regions(ForwardIt first, ForwardIt last, const MappableTp& mappable) { using MappableTp2 = typename std::iterator_traits::value_type; - static_assert(is_region_or_mappable && is_region_or_mappable, - "Mappable required"); + static_assert(is_region_or_mappable && is_region_or_mappable, "Mappable required"); std::vector> result {}; - if (first == last) return result; + if (first == last) { + result.assign({mapped_region(mappable)}); + return result; + } result.reserve(std::distance(first, last) + 1); if (begins_before(mappable, *first)) { result.push_back(left_overhang_region(mappable, *first)); @@ -1879,7 +1881,7 @@ auto calculate_positional_coverage(ForwardIt first, ForwardIt last) } template > + typename = EnableIfRegionOrMappable> auto calculate_positional_coverage(const Range& mappables) { return calculate_positional_coverage(std::cbegin(mappables), std::cend(mappables)); @@ -2030,7 +2032,75 @@ auto join(const Range& regions, const GenomicRegion::Distance n) { return join(std::cbegin(regions), std::cend(regions), n); } - + +namespace detail { + +inline void append(const ContigRegion& base, ContigRegion::Position begin, ContigRegion::Position end, + std::vector& result) +{ + result.emplace_back(begin, end); +} + +inline void append(const GenomicRegion& base, GenomicRegion::Position begin, GenomicRegion::Position end, + std::vector& result) +{ + result.emplace_back(base.contig_name(), begin, end); +} + +} // namespace detail + +// select_regions: returns minimal subset of regions defined by each element in a range. 
+ +template ::value_type, bool>::value>> +std::vector select_regions(const Region& region, const ForwardIt first, const ForwardIt last) +{ + static_assert(is_region, "must be ContigRegion or GenomicRegion"); + assert(static_cast(std::distance(first, last)) == size(region)); + std::vector result {}; + result.reserve(std::distance(first, last) / 2); // max possible + auto itr = std::find(first, last, true); + for (; itr != last;) { + const auto itr2 = std::find(itr, last, false); + const auto begin = region.begin() + std::distance(first, itr); + const auto end = begin + std::distance(itr, itr2); + detail::append(region, begin, end, result); + itr = std::find(itr2, last, true); + } + return result; +} + +template ::value>> +auto select_regions(const Region& region, const Range& selections) +{ + return select_regions(region, std::cbegin(selections), std::cend(selections)); +} + +template +std::vector select_regions(const Region& region, const ForwardIt first, const ForwardIt last, UnaryPredicate pred) +{ + static_assert(is_region, "must be ContigRegion or GenomicRegion"); + assert(std::distance(first, last) == size(region)); + std::vector result {}; + result.reserve(std::distance(first, last) / 2); // max possible + auto itr = std::find_if(first, last, pred); + for (; itr != last;) { + const auto itr2 = std::find_if_not(itr, last, pred); + const auto begin = region.begin() + std::distance(first, itr); + const auto end = begin + std::distance(itr, itr2); + detail::append(region, begin, end, result); + itr = std::find_if(itr2, last, pred); + } + return result; +} + +template +auto select_regions(const Region& region, const Range& values, UnaryPredicate pred) +{ + return select_regions(region, std::cbegin(values), std::cend(values), std::move(pred)); +} + } // namespace octopus #endif diff --git a/src/utils/maths.hpp b/src/utils/maths.hpp index 539c8802b..6ea6aab0a 100644 --- a/src/utils/maths.hpp +++ b/src/utils/maths.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel 
Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef maths_hpp @@ -23,9 +23,11 @@ #include #include #include +#include +#include + +namespace octopus { namespace maths { -namespace octopus { namespace maths -{ namespace constants { template @@ -40,6 +42,15 @@ RealType round(const RealType val, const unsigned precision = 2) return std::round(val * factor) / factor; } +template ::value>> +T round_sf(const T x, const int n) +{ + // https://stackoverflow.com/a/13094362/2970186 + if (x == 0.0) return 0; + auto factor = std::pow(10.0, n - std::ceil(std::log10(std::abs(x)))); + return std::round(x * factor) / factor; +} + template ::value>> bool almost_equal(const T lhs, T rhs, const int ulp = 1) { @@ -58,14 +69,22 @@ bool almost_one(const T x, const int ulp = 1) return almost_equal(x, T {1}, ulp); } +template ::value>> +int count_leading_zeros(const T x) +{ + if (x == 0.0) return 0; + return -std::ceil(std::log10(std::abs(x - std::numeric_limits::epsilon()))); +} template -constexpr RealType exp_maclaurin(const RealType x) { +constexpr RealType exp_maclaurin(const RealType x) +{ return (6 + x * (6 + x * (3 + x))) * 0.16666666; } template -constexpr RealType mercator(const RealType x) { +constexpr RealType mercator(const RealType x) +{ return x - x * x / 2 + x * x * x / 3; } @@ -102,6 +121,70 @@ auto mean(const Container& values, UnaryOperation unary_op) return mean(std::cbegin(values), std::cend(values), unary_op); } +namespace detail { + +template +T median_unsorted(ForwardIt first, ForwardIt last) +{ + const auto n = std::distance(first, last); + assert(n > 0); + if (n == 1) return *first; + if (n == 2) return static_cast(*first + *std::next(first)) / 2; + const auto middle = std::next(first, n / 2); + std::nth_element(first, middle, last); + if (n % 2 == 1) { + return *middle; + } else { + auto prev_middle_itr = std::max_element(first, middle); + return 
static_cast(*prev_middle_itr + *middle) / 2; + } +} + +template +auto median_sorted(ForwardIt first, ForwardIt last) +{ + const auto n = std::distance(first, last); + assert(n > 0); + if (n == 1) return *first; + const auto middle = std::next(first, n / 2); + if (n % 2 == 1) { + return *middle; + } else { + return static_cast(*std::prev(middle) + *middle) / 2; + } +} + +template +auto median_const(ForwardIt first, ForwardIt last) +{ + if (std::is_sorted(first, last)) { + return median_sorted(first, last); + } else { + std::vector::value_type> tmp {first, last}; + return median_unsorted(std::begin(tmp), std::end(tmp)); + } +} + +} // namespace detail + +template +auto median(ForwardIt first, ForwardIt last) +{ + return detail::median_unsorted(first, last); +} + +template +auto median(Range& values) +{ + return median(std::begin(values), std::end(values)); +} + +template +auto median(const Range& values) +{ + return detail::median_const(std::cbegin(values), std::cend(values)); +} + template double stdev(InputIt first, InputIt last, UnaryOperation unary_op) { @@ -190,6 +273,13 @@ auto log_sum_exp(const Container& values) return log_sum_exp(std::cbegin(values), std::cend(values)); } +template ::value>> +T factorial(const IntegerType x) +{ + return boost::math::factorial(x); +} + template ::value>, typename = std::enable_if_t::value>> @@ -211,12 +301,34 @@ RealType log_factorial(IntegerType x) } } +template ::value>, + typename = std::enable_if_t::value>> +RealType geometric_pdf(const IntegerType k, const RealType p) +{ + boost::math::geometric_distribution dist {p}; + return boost::math::pdf(dist, k); +} + +template ::value>, + typename = std::enable_if_t::value>> +RealType binomial_pdf(const IntegerType k, const IntegerType n, const RealType p) +{ + boost::math::binomial_distribution dist {static_cast(n), p}; + return boost::math::pdf(dist, k); +} + template ::value>, typename = std::enable_if_t::value>> RealType log_poisson_pmf(const IntegerType k, const RealType mu) 
{ - return k * std::log(mu) - boost::math::lgamma(k) - mu; + if (k > 0) { + return k * std::log(mu) - std::lgamma(k) - mu; + } else { + return -mu; + } } template ::value_type>; static_assert(std::is_floating_point::value, "log_beta is only defined for floating point types."); - return std::accumulate(first, last, T {0}, - [] (const auto curr, const auto x) { - return curr + boost::math::lgamma(x); - }) - boost::math::lgamma(std::accumulate(first, last, T {0})); + return std::accumulate(first, last, T {0}, [] (const auto curr, const auto x) { return curr + std::lgamma(x); }) + - std::lgamma(std::accumulate(first, last, T {0})); } template @@ -275,7 +385,7 @@ auto log_dirichlet(ForwardIt1 firstalpha, ForwardIt1 lastalpha, ForwardIt2 first using T = std::decay_t::value_type>; static_assert(std::is_floating_point::value, "log_dirichlet is only defined for floating point types."); - return std::inner_product(firstalpha, lastalpha, firstpi, T {0}, std::plus {}, + return std::inner_product(firstalpha, lastalpha, firstpi, T {0}, std::plus<> {}, [] (const auto a, const auto p) { return (a - 1) * std::log(p); }) - log_beta(firstalpha, lastalpha); } @@ -286,6 +396,60 @@ auto log_dirichlet(const Container1& alpha, const Container2& pi) return log_dirichlet(std::cbegin(alpha), std::cend(alpha), std::cbegin(pi)); } +template +auto dirichlet_expectation(ForwardIt first_alpha, ForwardIt last_alpha) +{ + using T = std::decay_t::value_type>; + static_assert(std::is_floating_point::value, + "log_dirichlet is only defined for floating point types."); + const auto K = static_cast(std::distance(first_alpha, last_alpha)); + const auto a0 = std::accumulate(first_alpha, last_alpha, T {0}); + std::vector result(K); + std::transform(first_alpha, last_alpha, std::begin(result), [a0] (auto a) { return a / a0; }); + return result; +} + +template +auto dirichlet_expectation(const Range& values) +{ + return dirichlet_expectation(std::cbegin(values), std::cend(values)); +} + +template +auto 
dirichlet_expectation(const unsigned i, ForwardIt first_alpha, ForwardIt last_alpha) +{ + using T = std::decay_t::value_type>; + static_assert(std::is_floating_point::value, + "log_dirichlet is only defined for floating point types."); + assert(i < static_cast(std::distance(first_alpha, last_alpha))); + return *std::next(first_alpha, i) / std::accumulate(first_alpha, last_alpha, T {0}); +} + +template +auto dirichlet_expectation(const unsigned i, const Range& values) +{ + return dirichlet_expectation(i, std::cbegin(values), std::cend(values)); +} + +template +auto dirichlet_entropy(ForwardIt first_alpha, ForwardIt last_alpha) +{ + using T = std::decay_t::value_type>; + static_assert(std::is_floating_point::value, + "log_dirichlet is only defined for floating point types."); + const auto K = static_cast(std::distance(first_alpha, last_alpha)); + const auto a0 = std::accumulate(first_alpha, last_alpha, T {0}); + using boost::math::digamma; + return log_beta(first_alpha, last_alpha) - (K - a0) * digamma(a0) + - std::accumulate(first_alpha, last_alpha, T {0}, [] (auto curr, auto a) { return curr + (a - 1) * digamma(a); }); +} + +template +auto dirichlet_entropy(const Range& values) +{ + return dirichlet_entropy(std::cbegin(values), std::cend(values)); +} + template inline RealType log_multinomial_coefficient(std::initializer_list il) { @@ -330,14 +494,34 @@ inline IntegerType multinomial_coefficient(const Container& values) return multinomial_coefficient(std::cbegin(values), std::cend(values)); } +template +inline RealType multinomial_pdf(ForwardIt1 first_z, ForwardIt1 last_z, ForwardIt2 first_p) +{ + auto r = std::inner_product(first_z, last_z, first_p, RealType {0}, std::multiplies<> {}, + [] (auto z_i, auto p_i) { return std::pow(p_i, z_i); }); + return multinomial_coefficient(first_z, last_z) * r; +} + template inline RealType multinomial_pdf(const std::vector& z, const std::vector& p) { - RealType r {1}; - for (std::size_t i {0}; i < z.size(); ++i) { - r *= 
std::pow(p[i], z[i]); - } - return multinomial_coefficient(std::cbegin(z), std::cend(z)) * r; + assert(z.size() == p.size()); + return multinomial_pdf(std::cbegin(z), std::cend(z), std::cbegin(p)); +} + +template +inline RealType log_multinomial_pdf(ForwardIt1 first_z, ForwardIt1 last_z, ForwardIt2 first_p) +{ + auto r = std::inner_product(first_z, last_z, first_p, RealType {0}, std::plus<> {}, + [] (auto z_i, auto p_i) { return z_i > 0 ? z_i * std::log(p_i) : 0.0; }); + return log_multinomial_coefficient(first_z, last_z) + r; +} + +template +inline RealType log_multinomial_pdf(const std::vector& z, const std::vector& p) +{ + assert(z.size() == p.size()); + return log_multinomial_pdf(std::cbegin(z), std::cend(z), std::cbegin(p)); } // Returns approximate y such that digamma(y) = x @@ -353,16 +537,38 @@ inline RealType digamma_inv(const RealType x, const RealType epsilon = 10e-8) return y; } +namespace detail { + +template +T ifactorial(RealType x, std::true_type) +{ + return factorial(x); +} + +template +T ifactorial(RealType x, std::false_type) +{ + return factorial(x); +} + +template +T ifactorial(RealType x) +{ + return ifactorial(x, std::is_floating_point {}); +} + +} // namespace detail + template RealType dirichlet_multinomial(const RealType z1, const RealType z2, const RealType a1, const RealType a2) { auto z_0 = z1 + z2; auto a_0 = a1 + a2; - auto z_m = boost::math::factorial(z1) * boost::math::factorial(z2); - return (boost::math::factorial(z_0) / z_m) * - (boost::math::tgamma(a_0) / boost::math::tgamma(z_0 + a_0)) * - (boost::math::tgamma(z1 + a1) * boost::math::tgamma(z2 + a2)) / - (boost::math::tgamma(a1) + boost::math::tgamma(a2)); + using detail::ifactorial; + auto z_m = ifactorial(z1) * ifactorial(z2); + return (ifactorial(z_0) / z_m) * + (std::tgamma(a_0) / std::tgamma(z_0 + a_0)) * + (std::tgamma(z1 + a1) * std::tgamma(z2 + a2)) / (std::tgamma(a1) + std::tgamma(a2)); } template @@ -371,13 +577,12 @@ RealType dirichlet_multinomial(const RealType z1, 
const RealType z2, const RealT { auto z_0 = z1 + z2 + z3; auto a_0 = a1 + a2 + a3; - auto z_m = boost::math::factorial(z1) * boost::math::factorial(z2) * - boost::math::factorial(z3); - return (boost::math::factorial(z_0) / z_m) * - (boost::math::tgamma(a_0) / boost::math::tgamma(z_0 + a_0)) * - (boost::math::tgamma(z1 + a1) * boost::math::tgamma(z2 + a2) * - boost::math::tgamma(z3 + a3)) / - (boost::math::tgamma(a1) + boost::math::tgamma(a2) + boost::math::tgamma(a3)); + using detail::ifactorial; + auto z_m = ifactorial(z1) * ifactorial(z2) * ifactorial(z3); + return (ifactorial(z_0) / z_m) * + (std::tgamma(a_0) / std::tgamma(z_0 + a_0)) * + (std::tgamma(z1 + a1) * std::tgamma(z2 + a2) * + std::tgamma(z3 + a3)) / (std::tgamma(a1) + std::tgamma(a2) + std::tgamma(a3)); } template @@ -386,16 +591,15 @@ RealType dirichlet_multinomial(const std::vector& z, const std::vector auto z_0 = std::accumulate(std::cbegin(z), std::cend(z), RealType {0}); auto a_0 = std::accumulate(std::cbegin(a), std::cend(a), RealType {0}); RealType z_m {1}; + using detail::ifactorial; for (auto z_i : z) { - z_m *= boost::math::factorial(z_i); + z_m *= ifactorial(z_i); } RealType g {1}; for (std::size_t i {0}; i < z.size(); ++i) { - g *= boost::math::tgamma(z[i] + a[i]) / boost::math::tgamma(a[i]); + g *= std::tgamma(z[i] + a[i]) / std::tgamma(a[i]); } - - return (boost::math::factorial(z_0) / z_m) * - (boost::math::tgamma(a_0) / boost::math::tgamma(z_0 + a_0)) * g; + return (ifactorial(z_0) / z_m) * (std::tgamma(a_0) / std::tgamma(z_0 + a_0)) * g; } template @@ -405,17 +609,19 @@ RealType beta_binomial(const RealType k, const RealType n, const RealType alpha, } namespace detail { - template - bool is_mldp_converged(std::vector& lhs, const std::vector& rhs, - const RealType epsilon) - { - std::transform(std::cbegin(lhs), std::cend(lhs), std::cbegin(rhs), std::begin(lhs), - [] (const auto a, const auto b) { return std::abs(a - b); }); - return std::all_of(std::cbegin(lhs), std::cend(lhs), - 
[epsilon] (const auto x) { return x < epsilon; }); - } + +template +bool is_mldp_converged(std::vector& lhs, const std::vector& rhs, + const RealType epsilon) +{ + std::transform(std::cbegin(lhs), std::cend(lhs), std::cbegin(rhs), std::begin(lhs), + [] (const auto a, const auto b) { return std::abs(a - b); }); + return std::all_of(std::cbegin(lhs), std::cend(lhs), + [epsilon] (const auto x) { return x < epsilon; }); } +} // namespace detail + template std::vector dirichlet_mle(std::vector pi, const RealType precision, @@ -423,11 +629,9 @@ dirichlet_mle(std::vector pi, const RealType precision, { std::transform(std::cbegin(pi), std::cend(pi), std::begin(pi), [] (const auto p) { return std::log(p); }); - const auto l = pi.size(); const RealType u {RealType {1} / l}; std::vector result(l, u), curr_result(l, u), means(l, u); - for (unsigned n {0}; n < max_iterations; ++n) { RealType v {0}; for (std::size_t j {0}; j < l; ++j) { @@ -435,16 +639,13 @@ dirichlet_mle(std::vector pi, const RealType precision, } for (std::size_t k {0}; k < l; ++k) { curr_result[k] = digamma_inv(pi[k] - v); - means[k] = curr_result[k] / std::accumulate(std::cbegin(curr_result), - std::cend(curr_result), - RealType {0}); + means[k] = curr_result[k] / std::accumulate(std::cbegin(curr_result), std::cend(curr_result), RealType {0}); } if (detail::is_mldp_converged(result, curr_result, epsilon)) { return curr_result; } result = curr_result; } - return result; } @@ -452,7 +653,7 @@ template ::value>> NumericType probability_to_phred(const RealType p) { - return static_cast(-10.0 * std::log10(std::max(1.0 - p, std::numeric_limits::epsilon()))); + return -10.0 * std::log10(std::max(1.0 - p, std::numeric_limits::epsilon())); } template -RealType beta_cdf_complement(const RealType a, const RealType b, const RealType x) +RealType beta_sf(const RealType a, const RealType b, const RealType x) { const boost::math::beta_distribution<> beta_dist {a, b}; return 
boost::math::cdf(boost::math::complement(beta_dist, x)); @@ -588,8 +789,7 @@ std::pair beta_hdi_skewed(const RealType a, const RealType b, const RealType mass) { const auto c = (RealType {1} - mass) / 2; - return std::make_pair(boost::math::ibeta_inv(a, b, c), - boost::math::ibeta_inv(a, b, c + mass)); + return std::make_pair(boost::math::ibeta_inv(a, b, c), boost::math::ibeta_inv(a, b, c + mass)); } } // namespace detail @@ -598,16 +798,13 @@ template std::pair beta_hdi(RealType a, RealType b, const RealType mass = 0.99) { - static_assert(std::is_floating_point::value, - "beta_hdi only works for floating point types"); - + static_assert(std::is_floating_point::value, "beta_hdi only works for floating point types"); if (mass < RealType {0} || mass > RealType {1}) { throw std::domain_error {"beta_hdi: given mass not in range [0, 1]"}; } if (a <= RealType {0} || b <= RealType {0}) { throw std::domain_error {"beta_hdi: given non-positive parameter"}; } - if (mass == RealType {0}) { const auto mean = a / (a + b); return std::make_pair(mean, mean); @@ -631,6 +828,20 @@ beta_hdi(RealType a, RealType b, const RealType mass = 0.99) return detail::beta_hdi_skewed(a, b, mass); } +template +RealType dirichlet_marginal_cdf(const std::vector& alphas, const std::size_t k, const RealType x) +{ + const auto a_0 = std::accumulate(std::cbegin(alphas), std::cend(alphas), RealType {}); + return beta_cdf(alphas[k], a_0 - alphas[k], x); +} + +template +RealType dirichlet_marginal_sf(const std::vector& alphas, const std::size_t k, const RealType x) +{ + const auto a_0 = std::accumulate(std::cbegin(alphas), std::cend(alphas), RealType {}); + return beta_sf(alphas[k], a_0 - alphas[k], x); +} + template void log_each(Container& values) { @@ -658,6 +869,7 @@ auto normalise_exp(Container& logs) for (auto& p : logs) p = std::exp(p -= norm); return norm; } + } // namespace maths } // namespace octopus diff --git a/src/utils/memory_footprint.cpp b/src/utils/memory_footprint.cpp index 
8a96266f3..f2f48aa15 100644 --- a/src/utils/memory_footprint.cpp +++ b/src/utils/memory_footprint.cpp @@ -1,57 +1,56 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "memory_footprint.hpp" -#include +#include #include #include #include #include +#include +#include +#include #include "string_utils.hpp" namespace octopus { - -MemoryFootprint::MemoryFootprint(std::size_t num_bytes) noexcept -: num_bytes_ {num_bytes} -{} - -std::size_t MemoryFootprint::num_bytes() const noexcept -{ - return num_bytes_; -} -std::ostream& operator<<(std::ostream& os, MemoryFootprint footprint) +bool operator==(const MemoryFootprint& lhs, const MemoryFootprint& rhs) noexcept { - os << footprint.num_bytes(); - return os; + return lhs.num_bytes() == rhs.num_bytes(); } -std::istream& operator>>(std::istream& is, MemoryFootprint& result) +bool operator<(const MemoryFootprint& lhs, const MemoryFootprint& rhs) noexcept { - if (is.good()) { - std::string input; - std::getline(is, input, ' '); - auto footprint = parse_footprint(input); - if (footprint) result = *footprint; - } - return is; + return lhs.num_bytes() < rhs.num_bytes(); } -enum class MemoryUnit { kB, KiB, MB, MiB, GB, GiB, TB, TiB, PB, PiB, EB, EiB, ZB, ZiB, YB, YiB, }; +namespace { + +enum class MemoryUnit { B, kB, KiB, MB, MiB, GB, GiB, TB, TiB, PB, PiB, EB, EiB, ZB, ZiB, YB, YiB, }; boost::optional parse_units(std::string& units_str) { - static const std::unordered_map units { - {"K", MemoryUnit::kB}, {"KB", MemoryUnit::kB}, - {"M", MemoryUnit::kB}, {"MB", MemoryUnit::MB}, - {"G", MemoryUnit::kB}, {"GB", MemoryUnit::GB}, - {"T", MemoryUnit::kB}, {"TB", MemoryUnit::TB}, - {"P", MemoryUnit::kB}, {"PB", MemoryUnit::PB}, - {"E", MemoryUnit::kB}, {"EB", MemoryUnit::EB}, - {"Z", MemoryUnit::kB}, {"ZB", MemoryUnit::ZB}, - {"Y", MemoryUnit::kB}, {"YB", MemoryUnit::YB}, + static const std::unordered_map 
units + { + {"B", MemoryUnit::B}, + {"K", MemoryUnit::kB}, + {"KB", MemoryUnit::kB}, + {"M", MemoryUnit::MB}, + {"MB", MemoryUnit::MB}, + {"G", MemoryUnit::GB}, + {"GB", MemoryUnit::GB}, + {"T", MemoryUnit::TB}, + {"TB", MemoryUnit::TB}, + {"P", MemoryUnit::PB}, + {"PB", MemoryUnit::PB}, + {"E", MemoryUnit::EB}, + {"EB", MemoryUnit::EB}, + {"Z", MemoryUnit::ZB}, + {"ZB", MemoryUnit::ZB}, + {"Y", MemoryUnit::YB}, + {"YB", MemoryUnit::YB}, {"KIB", MemoryUnit::KiB}, {"MIB", MemoryUnit::MiB}, {"GIB", MemoryUnit::GiB}, @@ -61,10 +60,10 @@ boost::optional parse_units(std::string& units_str) {"ZIB", MemoryUnit::ZiB}, {"YIB", MemoryUnit::YiB} }; - utils::capitalise(units_str); - const auto iter = units.find(units_str); - if (iter != std::cend(units)) { - return iter->second; + utils::capitalise(units_str); + const auto itr = units.find(units_str); + if (itr != std::cend(units)) { + return itr->second; } else { return boost::none; } @@ -89,45 +88,140 @@ struct MemoryUnitHash } }; -std::size_t get_multiplier(const MemoryUnit units) +constexpr std::size_t get_multiplier(const MemoryUnit unit) noexcept { - static const std::unordered_map multiplier { - {MemoryUnit::kB, ipow(1000, 1)}, - {MemoryUnit::MB, ipow(1000, 2)}, - {MemoryUnit::GB, ipow(1000, 3)}, - {MemoryUnit::TB, ipow(1000, 4)}, - {MemoryUnit::PB, ipow(1000, 5)}, - {MemoryUnit::EB, ipow(1000, 6)}, - {MemoryUnit::ZB, ipow(1000, 7)}, - {MemoryUnit::YB, ipow(1000, 8)}, - {MemoryUnit::KiB, ipow(1024, 1)}, - {MemoryUnit::MiB, ipow(1024, 2)}, - {MemoryUnit::GiB, ipow(1024, 3)}, - {MemoryUnit::TiB, ipow(1024, 4)}, - {MemoryUnit::PiB, ipow(1024, 5)}, - {MemoryUnit::EiB, ipow(1024, 6)}, - {MemoryUnit::ZiB, ipow(1024, 7)}, - {MemoryUnit::YiB, ipow(1024, 8)} - }; - return multiplier.at(units); + switch (unit) { + case MemoryUnit::B: return ipow(1000, 0); + case MemoryUnit::kB: return ipow(1000, 1); + case MemoryUnit::MB: return ipow(1000, 2); + case MemoryUnit::GB: return ipow(1000, 3); + case MemoryUnit::TB: return ipow(1000, 
4); + case MemoryUnit::PB: return ipow(1000, 5); + case MemoryUnit::EB: return ipow(1000, 6); + case MemoryUnit::ZB: return ipow(1000, 7); + case MemoryUnit::YB: return ipow(1000, 8); + case MemoryUnit::KiB: return ipow(1024, 1); + case MemoryUnit::MiB: return ipow(1024, 2); + case MemoryUnit::GiB: return ipow(1024, 3); + case MemoryUnit::TiB: return ipow(1024, 4); + case MemoryUnit::PiB: return ipow(1024, 5); + case MemoryUnit::EiB: return ipow(1024, 6); + case MemoryUnit::ZiB: return ipow(1024, 7); + case MemoryUnit::YiB: return ipow(1024, 8); + default: return ipow(1000, 0); + } +} + +std::string to_string(const MemoryUnit unit) +{ + switch (unit) { + case MemoryUnit::B: return "B"; + case MemoryUnit::kB: return "kB"; + case MemoryUnit::MB: return "MB"; + case MemoryUnit::GB: return "GB"; + case MemoryUnit::TB: return "TB"; + case MemoryUnit::PB: return "PB"; + case MemoryUnit::EB: return "EB"; + case MemoryUnit::ZB: return "ZB"; + case MemoryUnit::YB: return "YB"; + case MemoryUnit::KiB: return "KiB"; + case MemoryUnit::MiB: return "MiB"; + case MemoryUnit::GiB: return "GiB"; + case MemoryUnit::TiB: return "TiB"; + case MemoryUnit::PiB: return "PiB"; + case MemoryUnit::EiB: return "EiB"; + case MemoryUnit::ZiB: return "ZiB"; + case MemoryUnit::YiB: return "YiB"; + default: return "B"; + } +} + +auto get_human_format_units(std::size_t bytes) noexcept +{ + using MU = MemoryUnit; + if (bytes >= 1000) { + if (bytes % 10 == 0) { + static constexpr std::array units { MU::kB, MU::MB, MU::GB, MU::TB, MU::PB, MU::EB, MU::ZB, MU::YB }; + std::size_t unit_idx {0}; + bytes /= 1000; + while (bytes >= 1000 && bytes % 10 == 0) { + bytes /= 1000; + ++unit_idx; + } + return units[std::min(unit_idx, units.size() - 1)]; + } else { + static constexpr std::array units { MU::B, MU::KiB, MU::MiB, MU::GiB, MU::TiB, MU::PiB, MU::EiB, MU::ZiB, MU::YiB }; + const auto unit_idx = static_cast(std::floor(std::log(bytes) / std::log(1024))); + const auto dv = std::lldiv(bytes, std::pow(1024, 
unit_idx)); + if (dv.rem == 0 || (dv.rem % 2 == 0 && std::log2(dv.rem) > 6)) { + return units[std::min(unit_idx, units.size() - 1)]; + } + } + } + return MU::B; } +auto get_human_format_units(const MemoryFootprint& footprint) noexcept +{ + return get_human_format_units(footprint.num_bytes()); +} + +} // namespace + boost::optional parse_footprint(std::string footprint_str) { using std::cbegin; using std::cend; - const auto first_digit_itr = std::find_if_not(cbegin(footprint_str), cend(footprint_str), - [] (char c) { return std::isdigit(c); }); - if (first_digit_itr == cbegin(footprint_str)) return boost::none; - const auto unit_begin_itr = std::find_if_not(first_digit_itr, cend(footprint_str), [] (char c) { return c == ' '; }); + static const auto is_digit = [] (char c) { return std::isdigit(c); }; + auto last_digit_itr = std::find_if_not(cbegin(footprint_str), cend(footprint_str), is_digit); + if (last_digit_itr == cend(footprint_str)) { + MemoryFootprint {static_cast(std::stoll(footprint_str))}; + } + bool is_float {false}; + if (*last_digit_itr == '.') { + last_digit_itr = std::find_if_not(std::next(last_digit_itr), cend(footprint_str), is_digit); + is_float = true; + } + if (last_digit_itr == cbegin(footprint_str)) return boost::none; + const auto unit_begin_itr = std::find_if_not(last_digit_itr, cend(footprint_str), [] (char c) { return c == ' '; }); std::string unit_part {unit_begin_itr, cend(footprint_str)}; std::size_t multiplier {1}; if (!unit_part.empty()) { - footprint_str.erase(first_digit_itr, cend(footprint_str)); + footprint_str.erase(last_digit_itr, cend(footprint_str)); const auto units = parse_units(unit_part); if (!units) return boost::none; multiplier = get_multiplier(*units); } - return MemoryFootprint {multiplier * std::stoll(footprint_str)}; + std::size_t bytes {}; + if (is_float) { + bytes = multiplier * std::stod(footprint_str); + } else { + bytes = multiplier * std::stoll(footprint_str); + } + return MemoryFootprint {bytes}; +} + 
+std::ostream& operator<<(std::ostream& os, MemoryFootprint footprint) +{ + const auto units = get_human_format_units(footprint); + const auto multiplier = get_multiplier(units); + if (footprint.num_bytes() % multiplier == 0) { + os << footprint.num_bytes() / multiplier; + } else { + os << static_cast(footprint.num_bytes()) / multiplier; + } + os << to_string(units); + return os; +} + +std::istream& operator>>(std::istream& is, MemoryFootprint& result) +{ + if (is.good()) { + std::string input; + std::getline(is, input, ' '); + auto footprint = parse_footprint(input); + if (footprint) result = *footprint; + } + return is; } } // namespace octopus diff --git a/src/utils/memory_footprint.hpp b/src/utils/memory_footprint.hpp index 1a9f74ee9..680d74d95 100644 --- a/src/utils/memory_footprint.hpp +++ b/src/utils/memory_footprint.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef memory_footprint_hpp @@ -11,14 +11,16 @@ #include +#include "concepts/comparable.hpp" + namespace octopus { -class MemoryFootprint +class MemoryFootprint : public Comparable { public: MemoryFootprint() = default; - MemoryFootprint(std::size_t num_bytes) noexcept; + constexpr MemoryFootprint(std::size_t num_bytes) noexcept : num_bytes_ {num_bytes} {} MemoryFootprint(const MemoryFootprint&) = default; MemoryFootprint& operator=(const MemoryFootprint&) = default; @@ -26,13 +28,16 @@ class MemoryFootprint MemoryFootprint& operator=(MemoryFootprint&&) = default; ~MemoryFootprint() = default; - - std::size_t num_bytes() const noexcept; + + constexpr std::size_t num_bytes() const noexcept { return num_bytes_; } private: std::size_t num_bytes_; }; +bool operator==(const MemoryFootprint& lhs, const MemoryFootprint& rhs) noexcept; +bool operator<(const MemoryFootprint& lhs, const MemoryFootprint& rhs) noexcept; + std::ostream& operator<<(std::ostream& os, MemoryFootprint footprint); std::istream& operator>>(std::istream& is, MemoryFootprint& result); diff --git a/src/utils/merge_transform.hpp b/src/utils/merge_transform.hpp index 63ec68140..54b9a277a 100644 --- a/src/utils/merge_transform.hpp +++ b/src/utils/merge_transform.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef merge_transform_hpp diff --git a/src/utils/parallel_transform.hpp b/src/utils/parallel_transform.hpp index 69c6895ba..e34571d4b 100644 --- a/src/utils/parallel_transform.hpp +++ b/src/utils/parallel_transform.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef parallel_transform_hpp diff --git a/src/utils/path_utils.cpp b/src/utils/path_utils.cpp index e5bc483ca..4aed5fad3 100644 --- a/src/utils/path_utils.cpp +++ b/src/utils/path_utils.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "path_utils.hpp" diff --git a/src/utils/path_utils.hpp b/src/utils/path_utils.hpp index 23e363919..63b9dacd5 100644 --- a/src/utils/path_utils.hpp +++ b/src/utils/path_utils.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef path_utils_hpp diff --git a/src/utils/read_algorithms.hpp b/src/utils/read_algorithms.hpp index 78415a0f6..4489ce3a8 100644 --- a/src/utils/read_algorithms.hpp +++ b/src/utils/read_algorithms.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef read_algorithms_hpp diff --git a/src/utils/read_size_estimator.cpp b/src/utils/read_size_estimator.cpp deleted file mode 100644 index c180b30cf..000000000 --- a/src/utils/read_size_estimator.cpp +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright (c) 2017 Daniel Cooke -// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
- -#include "read_size_estimator.hpp" - -#include -#include -#include -#include -#include -#include - -#include "maths.hpp" - -namespace octopus { - -namespace { - -template -ForwardIt random_select(ForwardIt first, ForwardIt last, RandomGenerator& g) { - std::uniform_int_distribution dis(0, std::distance(first, last) - 1); - std::advance(first, dis(g)); - return first; -} - -template -ForwardIt random_select(ForwardIt first, ForwardIt last) { - static std::default_random_engine gen {}; - return random_select(first, last, gen); -} - -auto estimate_dynamic_size(const AlignedRead& read) noexcept -{ - return read.name().size() * sizeof(char) - + sequence_size(read) * sizeof(char) - + sequence_size(read) * sizeof(AlignedRead::BaseQuality) - + read.cigar().size() * sizeof(CigarOperation) - + contig_name(read).size() - + (read.has_other_segment() ? sizeof(AlignedRead::Segment) : 0); -} - -auto estimate_read_size(const AlignedRead& read) noexcept -{ - return sizeof(AlignedRead) + estimate_dynamic_size(read); -} - -auto get_covered_sample_regions(const std::vector& samples, const InputRegionMap& input_regions, - ReadManager& read_manager) -{ - InputRegionMap result {}; - result.reserve(input_regions.size()); - for (const auto& p : input_regions) { - InputRegionMap::mapped_type contig_regions {}; - std::copy_if(std::cbegin(p.second), std::cend(p.second), - std::inserter(contig_regions, std::begin(contig_regions)), - [&] (const auto& region) { return read_manager.has_reads(samples, region); }); - if (!contig_regions.empty()) { - result.emplace(p.first, std::move(contig_regions)); - } - } - return result; -} - -} // namespace - -boost::optional estimate_mean_read_size(const std::vector& samples, - const InputRegionMap& input_regions, - ReadManager& read_manager, - const unsigned max_sample_size) -{ - if (input_regions.empty()) return boost::none; - const auto sample_regions = get_covered_sample_regions(samples, input_regions, read_manager); - if (sample_regions.empty()) 
return boost::none; - const auto num_samples_per_sample = max_sample_size / samples.size(); - std::deque read_size_samples {}; - // take read samples from each sample seperatly to ensure we cover each - for (const auto& sample : samples) { - const auto it = random_select(std::cbegin(sample_regions), std::cend(sample_regions)); - assert(!it->second.empty()); - const auto it2 = random_select(std::cbegin(it->second), std::cend(it->second)); - auto test_region = read_manager.find_covered_subregion(sample, *it2, num_samples_per_sample); - if (is_empty(test_region)) { - test_region = expand_rhs(test_region, 1); - } - const auto reads = read_manager.fetch_reads(sample, test_region); - std::transform(std::cbegin(reads), std::cend(reads), std::back_inserter(read_size_samples), - estimate_read_size); - } - if (read_size_samples.empty()) return boost::none; - return static_cast(maths::mean(read_size_samples) + maths::stdev(read_size_samples)); -} - -std::size_t default_read_size_estimate() noexcept -{ - return sizeof(AlignedRead) + 300; -} - -} // namespace octopus diff --git a/src/utils/read_size_estimator.hpp b/src/utils/read_size_estimator.hpp deleted file mode 100644 index 975b6ddcc..000000000 --- a/src/utils/read_size_estimator.hpp +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2017 Daniel Cooke -// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
- -#ifndef read_size_estimator_hpp -#define read_size_estimator_hpp - -#include -#include - -#include - -#include "config/common.hpp" -#include "io/read/read_manager.hpp" - -namespace octopus { - -boost::optional estimate_mean_read_size(const std::vector& samples, - const InputRegionMap& input_regions, - ReadManager& read_manager, - unsigned max_sample_size = 1000); - -std::size_t default_read_size_estimate() noexcept; - -} // namespace octopus - -#endif diff --git a/src/utils/read_stats.cpp b/src/utils/read_stats.cpp index fa8f6b1ea..f5245def0 100644 --- a/src/utils/read_stats.cpp +++ b/src/utils/read_stats.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "read_stats.hpp" diff --git a/src/utils/read_stats.hpp b/src/utils/read_stats.hpp index b5c20f0a3..79ad7f92c 100644 --- a/src/utils/read_stats.hpp +++ b/src/utils/read_stats.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef read_stats_hpp @@ -53,6 +53,12 @@ struct IsMappingQualityZero } }; +template +bool all_empty(const Map& map) noexcept +{ + return std::all_of(std::cbegin(map), std::cend(map), [] (const auto& p) { return p.second.empty(); }); +} + template bool has_coverage(const T& reads, NonMapTag) { @@ -102,6 +108,13 @@ double mean_coverage(const T& reads, const GenomicRegion& region, NonMapTag) return maths::mean(calculate_positional_coverage(reads, region)); } +template +double mean_coverage(const T& reads, NonMapTag) +{ + if (reads.empty()) return 0; + return mean_coverage(reads, encompassing_region(reads), NonMapTag {}); +} + template double stdev_coverage(const T& reads, const GenomicRegion& region, NonMapTag) { @@ -109,6 +122,13 @@ double stdev_coverage(const T& reads, const GenomicRegion& region, NonMapTag) return maths::stdev(calculate_positional_coverage(reads, region)); } +template +double stdev_coverage(const T& reads, NonMapTag) +{ + if (reads.empty()) return 0; + return stdev_coverage(reads, encompassing_region(reads), NonMapTag {}); +} + template std::size_t count_reads(const T& reads, NonMapTag) { @@ -229,6 +249,48 @@ std::size_t count_mapq_zero(const T& reads, const GenomicRegion& region, NonMapT return std::count_if(std::cbegin(overlapped), std::cend(overlapped), IsMappingQualityZero {}); } +struct MappingQualityLess +{ + bool operator()(const AlignedRead& lhs, const AlignedRead& rhs) const noexcept + { + return lhs.mapping_quality() < rhs.mapping_quality(); + } +}; + +template +auto min_mapping_quality(const T& reads, NonMapTag) +{ + static_assert(is_aligned_read_container, "T must be a container of AlignedReads"); + auto result_itr = std::min_element(std::cbegin(reads), std::cend(reads), MappingQualityLess {}); + return result_itr->mapping_quality(); +} + +template +auto min_mapping_quality(const T& reads, const GenomicRegion& region, NonMapTag) +{ + static_assert(is_aligned_read_container, "T must be a container of AlignedReads"); + const auto 
overlapped = overlap_range(reads, region); + auto result_itr = std::min_element(std::cbegin(overlapped), std::cend(overlapped), MappingQualityLess {}); + return result_itr->mapping_quality(); +} + +template +auto max_mapping_quality(const T& reads, NonMapTag) +{ + static_assert(is_aligned_read_container, "T must be a container of AlignedReads"); + auto result_itr = std::max_element(std::cbegin(reads), std::cend(reads), MappingQualityLess {}); + return result_itr->mapping_quality(); +} + +template +auto max_mapping_quality(const T& reads, const GenomicRegion& region, NonMapTag) +{ + static_assert(is_aligned_read_container, "T must be a container of AlignedReads"); + const auto overlapped = overlap_range(reads, region); + auto result_itr = std::max_element(std::cbegin(overlapped), std::cend(overlapped), MappingQualityLess {}); + return result_itr->mapping_quality(); +} + template double rmq_mapping_quality(const T& reads, NonMapTag) { @@ -346,51 +408,48 @@ unsigned max_coverage(const ReadMap& reads, const GenomicRegion& region, MapTag) } template -double mean_coverage(const T& reads, MapTag) +auto sum_positional_coverages(const T& reads, const GenomicRegion& region) { - if (reads.empty()) return 0.0; - std::vector sample_mean_coverages(reads.size(), 0.0); - std::transform(std::cbegin(reads), std::cend(reads), std::begin(sample_mean_coverages), - [] (const auto& sample_reads) { - return mean_coverage(sample_reads.second, NonMapTag {}); - }); - return maths::mean(sample_mean_coverages); + const auto num_bases = size(region); + std::vector result(num_bases); + for (const auto& p : reads) { + const auto sample_coverages = calculate_positional_coverage(p.second, region); + assert(sample_coverages.size() == num_bases); + for (std::size_t i {0}; i < num_bases; ++i) { + result[i] += sample_coverages[i]; + } + } + return result; } template double mean_coverage(const T& reads, const GenomicRegion& region, MapTag) { - if (reads.empty()) return 0.0; - std::vector 
sample_mean_coverages(reads.size(), 0.0); - std::transform(std::cbegin(reads), std::cend(reads), std::begin(sample_mean_coverages), - [&region] (const auto& sample_reads) { - return mean_coverage(sample_reads.second, region, NonMapTag {}); - }); - return maths::mean(sample_mean_coverages); + if (reads.empty() || is_empty_region(region)) return 0.0; + if (reads.size() == 1) return mean_coverage(std::cbegin(reads)->second, region, NonMapTag {}); + return maths::mean(sum_positional_coverages(reads, region)); } template -double stdev_coverage(const T& reads, MapTag) +double mean_coverage(const T& reads, MapTag) { - if (reads.empty()) return 0.0; - std::vector sample_stdev_coverages(reads.size(), 0.0); - std::transform(std::cbegin(reads), std::cend(reads), std::begin(sample_stdev_coverages), - [] (const auto& sample_reads) { - return stdev_coverage(sample_reads.second, NonMapTag {}); - }); - return maths::stdev(sample_stdev_coverages); + if (reads.empty() || all_empty(reads)) return 0.0; + return mean_coverage(reads, encompassing_region(reads), MapTag {}); } template double stdev_coverage(const T& reads, const GenomicRegion& region, MapTag) { - if (reads.empty()) return 0.0; - std::vector sample_stdev_coverages(reads.size(), 0.0); - std::transform(std::cbegin(reads), std::cend(reads), std::begin(sample_stdev_coverages), - [&region] (const auto& sample_reads) { - return stdev_coverage(sample_reads.second, region, NonMapTag {}); - }); - return maths::stdev(sample_stdev_coverages); + if (reads.empty() || is_empty_region(region)) return 0.0; + if (reads.size() == 1) return stdev_coverage(std::cbegin(reads)->second, region, NonMapTag {}); + return maths::stdev(sum_positional_coverages(reads, region)); +} + +template +double stdev_coverage(const T& reads, MapTag) +{ + if (reads.empty() || all_empty(reads)) return 0.0; + return stdev_coverage(reads, encompassing_region(reads), MapTag {}); } template @@ -516,6 +575,46 @@ std::size_t count_mapq_zero(const T& reads, const 
GenomicRegion& region, MapTag) }); } +template +auto min_mapping_quality(const T& reads, MapTag) +{ + auto result = min_mapping_quality(std::cbegin(reads)->second, NonMapTag {}); + std::for_each(std::next(std::cbegin(reads)), std::cend(reads), [&] (const auto& p) { + result = std::min(result, min_mapping_quality(p.second, NonMapTag {})); + }); + return result; +} + +template +auto min_mapping_quality(const T& reads, const GenomicRegion& region, MapTag) +{ + auto result = min_mapping_quality(std::cbegin(reads)->second, region, NonMapTag {}); + std::for_each(std::next(std::cbegin(reads)), std::cend(reads), [&] (const auto& p) { + result = std::min(result, min_mapping_quality(p.second, region, NonMapTag {})); + }); + return result; +} + +template +auto max_mapping_quality(const T& reads, MapTag) +{ + auto result = max_mapping_quality(std::cbegin(reads)->second, NonMapTag {}); + std::for_each(std::next(std::cbegin(reads)), std::cend(reads), [&] (const auto& p) { + result = std::max(result, max_mapping_quality(p.second, NonMapTag {})); + }); + return result; +} + +template +auto max_mapping_quality(const T& reads, const GenomicRegion& region, MapTag) +{ + auto result = max_mapping_quality(std::cbegin(reads)->second, region, NonMapTag {}); + std::for_each(std::next(std::cbegin(reads)), std::cend(reads), [&] (const auto& p) { + result = std::max(result, max_mapping_quality(p.second, region, NonMapTag {})); + }); + return result; +} + template double rmq_mapping_quality(const T& reads, MapTag) { @@ -741,6 +840,30 @@ std::size_t count_mapq_zero(const T& reads, const GenomicRegion& region) return detail::count_mapq_zero(reads, region, MapTagType {}); } +template +auto min_mapping_quality(const T& reads) +{ + return detail::min_mapping_quality(reads, MapTagType {}); +} + +template +auto min_mapping_quality(const T& reads, const GenomicRegion& region) +{ + return detail::min_mapping_quality(reads, region, MapTagType {}); +} + +template +auto max_mapping_quality(const T& reads) 
+{ + return detail::max_mapping_quality(reads, MapTagType {}); +} + +template +auto max_mapping_quality(const T& reads, const GenomicRegion& region) +{ + return detail::max_mapping_quality(reads, region, MapTagType {}); +} + template double rmq_mapping_quality(const T& reads) { diff --git a/src/utils/repeat_finder.cpp b/src/utils/repeat_finder.cpp index d0b045375..cb1199f7f 100644 --- a/src/utils/repeat_finder.cpp +++ b/src/utils/repeat_finder.cpp @@ -1,16 +1,62 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "repeat_finder.hpp" namespace octopus { +std::vector +find_exact_tandem_repeats(const ReferenceGenome& reference, const GenomicRegion& region, unsigned max_period) +{ + auto sequence = reference.fetch_sequence(region); + return find_exact_tandem_repeats(sequence, region, 1, max_period); +} + +bool is_good_seed(const TandemRepeat& repeat, const InexactRepeatDefinition& repeat_def) noexcept +{ + const auto repeat_length = region_size(repeat); + return repeat_length >= repeat_def.min_exact_repeat_seed_length + && repeat_length / repeat.period >= repeat_def.min_exact_repeat_seed_periods; +} + +std::vector +find_repeat_regions(const std::vector& repeats, const GenomicRegion& region, + const InexactRepeatDefinition repeat_def) +{ + assert(std::is_sorted(std::cbegin(repeats), std::cend(repeats))); + std::vector seeds {}; + seeds.reserve(repeats.size()); + std::copy_if(std::cbegin(repeats), std::cend(repeats), std::back_inserter(seeds), + [&] (const auto& repeat) { return is_good_seed(repeat, repeat_def); }); + if (seeds.size() < repeat_def.min_exact_seeds) { + return {}; + } + auto repeat_begin_itr = std::cbegin(repeats); + std::vector hits {}; + hits.reserve(repeats.size()); + for (const auto& seed : seeds) { + const auto expanded_seed_region = expand(seed.region, repeat_def.max_seed_join_distance); + const auto overlapped_repeats 
= overlap_range(repeat_begin_itr, std::cend(repeats), expanded_seed_region); + for (const auto& repeat : overlapped_repeats) { + hits.push_back(repeat.region); + } + repeat_begin_itr = overlapped_repeats.begin().base(); + } + auto result = join(extract_covered_regions(hits), repeat_def.max_seed_join_distance); + result.erase(std::remove_if(std::begin(result), std::end(result), + [&] (const auto& region) { + return size(region) < repeat_def.min_joined_repeat_length; + }), std::end(result)); + return result; +} + std::vector find_repeat_regions(const ReferenceGenome& reference, const GenomicRegion& region, - const InexactRepeatDefinition repeat_definition) + const InexactRepeatDefinition repeat_def) { - const auto sequence = reference.fetch_sequence(region); - return find_repeat_regions(sequence, region, repeat_definition); + auto sequence = reference.fetch_sequence(region); + auto seeds = find_exact_tandem_repeats(sequence, region, 1, repeat_def.max_exact_repeat_seed_period); + return find_repeat_regions(seeds, region, repeat_def); } } // namespace octopus diff --git a/src/utils/repeat_finder.hpp b/src/utils/repeat_finder.hpp index 9fc3bc11e..2f254eeae 100644 --- a/src/utils/repeat_finder.hpp +++ b/src/utils/repeat_finder.hpp @@ -1,8 +1,8 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
-#ifndef repeat_finer_hpp -#define repeat_finer_hpp +#ifndef repeat_finder_hpp +#define repeat_finder_hpp #include #include @@ -33,19 +33,22 @@ struct TandemRepeat : public Mappable struct InexactRepeatDefinition { - unsigned max_exact_repeat_seed_period = 100; - unsigned min_exact_repeat_seed_length = 100; - unsigned max_seed_join_distance = 50; + unsigned max_exact_repeat_seed_period = 6; + unsigned min_exact_repeat_seed_length = 3; + unsigned min_exact_repeat_seed_periods = 3; + unsigned min_exact_seeds = 1; + unsigned max_seed_join_distance = 2; + unsigned min_joined_repeat_length = 10; }; template std::vector -find_exact_tandem_repeats(SequenceType sequence, const GenomicRegion& region, - GenomicRegion::Size min_period = 2, GenomicRegion::Size max_period = 10000) +find_exact_tandem_repeats(SequenceType& sequence, const GenomicRegion& region, + GenomicRegion::Size min_period, GenomicRegion::Size max_period) { - if (sequence.back() != 'N') { + if (sequence.back() != '$') { sequence.reserve(sequence.size() + 1); - sequence.push_back('N'); + sequence.push_back('$'); } auto n_shift_map = tandem::collapse(sequence, 'N'); auto maximal_repetitions = tandem::extract_exact_tandem_repeats(sequence , min_period, max_period); @@ -64,29 +67,24 @@ find_exact_tandem_repeats(SequenceType sequence, const GenomicRegion& region, } template -std::vector -find_repeat_regions(const SequenceType& sequence, const GenomicRegion& region, - const InexactRepeatDefinition repeat_definition) +std::vector +find_exact_tandem_repeats(const SequenceType& sequence, const GenomicRegion& region, + GenomicRegion::Size min_period, GenomicRegion::Size max_period) { - auto repeats = find_exact_tandem_repeats(sequence, region, 2, repeat_definition.max_exact_repeat_seed_period); - auto itr = std::partition(std::begin(repeats), std::end(repeats), - [&] (const auto& repeat) noexcept { - return region_size(repeat) >= repeat_definition.min_exact_repeat_seed_length; - }); - itr = std::remove_if(itr, 
std::end(repeats), - [&] (const auto& small_repeat) { - return std::none_of(std::begin(repeats), itr, [&] (const auto& seed) { - return std::abs(inner_distance(small_repeat, seed)) <= repeat_definition.max_seed_join_distance; - }); - }); - repeats.erase(itr, std::end(repeats)); - std::sort(std::begin(repeats), std::end(repeats)); - return join(extract_covered_regions(repeats), repeat_definition.max_seed_join_distance); + auto tmp = sequence; + return find_exact_tandem_repeats(tmp, region, min_period, max_period); } +std::vector +find_exact_tandem_repeats(const ReferenceGenome& reference, const GenomicRegion& region, unsigned max_period); + +std::vector +find_repeat_regions(const std::vector& repeats, const GenomicRegion& region, + const InexactRepeatDefinition repeat_def); + std::vector find_repeat_regions(const ReferenceGenome& reference, const GenomicRegion& region, - InexactRepeatDefinition repeat_definition = InexactRepeatDefinition {}); + InexactRepeatDefinition repeat_def = InexactRepeatDefinition {}); } // namespace octopus diff --git a/src/utils/select_top_k.hpp b/src/utils/select_top_k.hpp new file mode 100644 index 000000000..9506a501b --- /dev/null +++ b/src/utils/select_top_k.hpp @@ -0,0 +1,147 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#ifndef select_top_k_hpp +#define select_top_k_hpp + +#include +#include +#include +#include +#include +#include +#include + +namespace octopus { + +using Index = std::size_t; +using IndexTuple = std::vector; +using IndexTupleVector = std::vector; + +namespace detail { + +template +auto index(const std::vector& values) +{ + std::vector> result(values.size()); + for (std::size_t i {0}; i < values.size(); ++i) { + result[i] = std::make_pair(values[i], i); + } + return result; +} + +template +auto index_and_sort(const std::vector& values, const std::size_t k) +{ + auto result = index(values); + const auto kth = std::next(std::begin(result), std::min(k, result.size())); + std::partial_sort(std::begin(result), kth, std::end(result), std::greater<> {}); + result.erase(kth, std::end(result)); + return result; +} + +template +struct IndexTupleScorePair +{ + IndexTuple indices; + T score; +}; + +template +bool operator>(const IndexTupleScorePair& lhs, const IndexTupleScorePair& rhs) noexcept +{ + return lhs.score > rhs.score; +} + +template +using IndexTupleScorePairVector = std::vector>; + +using IndexPair = std::pair; + +template +std::vector +find_k_max_pairs(const std::vector& lhs, const std::vector& rhs, const std::size_t k) +{ + // Modified from https://leetcode.com/problems/find-k-pairs-with-smallest-sums/discuss/84607/Clean-16ms-C++-O(N)-Space-O(KlogN)-Time-Solution-using-Priority-queue + std::vector result {}; + if (lhs.empty() || rhs.empty() || k == 0) + return result; + auto cmp = [&lhs, &rhs] (const IndexPair& a, const IndexPair& b) { + return lhs[a.first] + rhs[a.second] > lhs[b.first] + rhs[b.second]; }; + std::priority_queue, decltype(cmp)> max_heap {cmp}; + max_heap.emplace(0, 0); + for (std::size_t i {0}; i < k && !max_heap.empty(); ++i) { + const auto idx_pair = max_heap.top(); max_heap.pop(); + result.push_back(idx_pair); + if (idx_pair.first + 1 < lhs.size()) + max_heap.emplace(idx_pair.first + 1, idx_pair.second); + if (idx_pair.first == 0 && 
idx_pair.second + 1 < rhs.size()) + max_heap.emplace(idx_pair.first, idx_pair.second + 1); + } + return result; +} + +template +void join(const std::vector>& values, + IndexTupleScorePairVector& result, + const std::size_t k, + IndexTupleScorePairVector& buffer) +{ + const auto n = std::min(k, values.size()); + if (n == 0) return; + if (result.empty()) { + std::transform(std::cbegin(values), std::next(std::cbegin(values), n), + std::back_inserter(result), [=] (const auto& p) -> IndexTupleScorePair { + return {{p.second}, p.first}; + }); + } else { + if (values.size() > 1) { + std::vector vals1(values.size()), vals2(result.size()); + std::transform(std::cbegin(values), std::cend(values), std::begin(vals1), [] (const auto& p) { return p.first; }); + std::transform(std::cbegin(result), std::cend(result), std::begin(vals2), [] (const auto& p) { return p.score; }); + const auto max_pairs = find_k_max_pairs(vals1, vals2, std::min(k, n * result.size())); + buffer.clear(); + buffer.reserve(max_pairs.size()); + for (const auto& p : max_pairs) { + buffer.push_back(result[p.second]); + buffer.back().indices.push_back(values[p.first].second); + buffer.back().score += values[p.first].first; + } + result = std::move(buffer); + std::sort(std::begin(result), std::end(result), std::greater<> {}); + } else { + assert(values.size() == 1); + for (auto& t : result) { + t.indices.push_back(values[0].second); + t.score += values[0].first; + } + } + } +} + +} // namespace detail + +template +IndexTupleVector +select_top_k_tuples(const std::vector>& values, const std::size_t k) +{ + // Implements a variant of the top-K selection algorithm + // See Henderson & Eliassi-Rad (http://eliassi.org/papers/henderson-llnltr09.pdf) + // for comparison. 
+ // Returns results in descending score order + detail::IndexTupleScorePairVector joins {}, buffer {}; + joins.reserve(k); + for (const auto& v : values) { + detail::join(detail::index_and_sort(v, k), joins, k, buffer); + } + IndexTupleVector result {}; + result.reserve(k); + for (auto& p : joins) { + result.push_back(std::move(p.indices)); + } + return result; +} + +} // namespace octopus + +#endif diff --git a/src/utils/sequence_utils.hpp b/src/utils/sequence_utils.hpp index a08b46912..0b41944df 100644 --- a/src/utils/sequence_utils.hpp +++ b/src/utils/sequence_utils.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef sequence_utils_hpp diff --git a/src/utils/string_utils.cpp b/src/utils/string_utils.cpp index b4c331bf8..b33c8073d 100644 --- a/src/utils/string_utils.cpp +++ b/src/utils/string_utils.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "string_utils.hpp" diff --git a/src/utils/string_utils.hpp b/src/utils/string_utils.hpp index e36a14a36..87b8587e5 100644 --- a/src/utils/string_utils.hpp +++ b/src/utils/string_utils.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef string_utils_hpp diff --git a/src/utils/system_utils.cpp b/src/utils/system_utils.cpp new file mode 100644 index 000000000..067a04594 --- /dev/null +++ b/src/utils/system_utils.cpp @@ -0,0 +1,17 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+ +#include "system_utils.hpp" + +#include + +namespace octopus { + +std::size_t get_max_open_files() +{ + struct rlimit lim; + getrlimit(RLIMIT_NOFILE, &lim); + return lim.rlim_cur; +} + +} // namespace octopus diff --git a/src/utils/system_utils.hpp b/src/utils/system_utils.hpp new file mode 100644 index 000000000..d90e02d71 --- /dev/null +++ b/src/utils/system_utils.hpp @@ -0,0 +1,15 @@ +// Copyright (c) 2015-2018 Daniel Cooke +// Use of this source code is governed by the MIT license that can be found in the LICENSE file. + +#ifndef system_utils_hpp +#define system_utils_hpp + +#include + +namespace octopus { + +std::size_t get_max_open_files(); + +} // namespace octopus + +#endif diff --git a/src/utils/thread_pool.cpp b/src/utils/thread_pool.cpp index 1d49ebfc3..a596e5f0b 100644 --- a/src/utils/thread_pool.cpp +++ b/src/utils/thread_pool.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #include "thread_pool.hpp" diff --git a/src/utils/thread_pool.hpp b/src/utils/thread_pool.hpp index e8a72bb3f..2f0f6787a 100644 --- a/src/utils/thread_pool.hpp +++ b/src/utils/thread_pool.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. // This thread pool implementation is mostly derived from https://github.com/progschj/ThreadPool diff --git a/src/utils/timing.hpp b/src/utils/timing.hpp index e80ded56d..620999845 100644 --- a/src/utils/timing.hpp +++ b/src/utils/timing.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
#ifndef timing_hpp @@ -39,34 +39,39 @@ inline std::ostream& operator<<(std::ostream& os, const TimeInterval& interval) return os; } const auto duration_s = duration(interval); - if (duration_s.count() < 60) { - os << duration_s.count() << 's'; + const auto secs = duration_s.count(); + if (secs < 60) { + os << secs << 's'; return os; } const auto duration_m = duration(interval); - if (duration_m.count() < 60) { - os << duration_m.count() << '.' << std::setw(2) << std::setfill('0') - << ((100 * (duration_s.count() % 60)) / 60) << 'm'; + const auto mins = duration_m.count(); + if (mins < 60) { + os << mins << 'm'; + const auto remainder_secs = secs % 60; + if (remainder_secs > 0) os << ' ' << remainder_secs << 's'; } else { const auto duration_h = duration(interval); if (duration_h.count() <= 24) { - os << duration_h.count() << '.' << std::setw(2) << std::setfill('0') - << ((100 * (duration_m.count() % 60)) / 60) << 'h'; + const auto hours = duration_h.count(); + os << hours << 'h'; + const auto remainder_mins = mins % 60; + if (remainder_mins > 0) os << ' ' << remainder_mins << 'm'; } else { using H = std::chrono::hours::rep; constexpr H num_hours_in_day {24}; const auto days = std::div(duration_h.count(), num_hours_in_day); if (days.quot < 7) { - os << days.quot << '.' << std::setw(2) << std::setfill('0') - << ((100 * days.rem) / num_hours_in_day) << 'd'; + os << days.quot << 'd'; + if (days.rem > 0) os << ' ' << days.rem << 'h'; } else { constexpr H num_hours_in_week {7 * num_hours_in_day}; const auto weeks = std::div(duration_h.count(), num_hours_in_week); - os << weeks.quot << '.' 
<< std::setw(2) << std::setfill('0') - << ((100 * weeks.rem) / num_hours_in_week) << 'w'; + os << weeks.quot << 'w'; + const auto rem_days = weeks.rem / num_hours_in_day; + if (rem_days > 0) os << ' ' << rem_days << 'd'; } } - } return os; } diff --git a/src/utils/type_tricks.hpp b/src/utils/type_tricks.hpp index 51e21d8c6..c67cd2ae5 100644 --- a/src/utils/type_tricks.hpp +++ b/src/utils/type_tricks.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Daniel Cooke +// Copyright (c) 2015-2018 Daniel Cooke // Use of this source code is governed by the MIT license that can be found in the LICENSE file. #ifndef type_tricks_hpp