
speedup demultiplex barcode by parallelisation #198

Open · bpenaud opened this issue Feb 5, 2025 · 16 comments · Fixed by #200, #202 or #208
Labels: enhancement (New feature or request)

@bpenaud commented Feb 5, 2025

Description of feature

Hello Pavel,

I work in the same lab as BELKHIR, and on my side I tried to speed up the demultiplexing step by parallelizing the rule with GNU parallel.

My method uses seqkit split2 to split the Undetermined_[RS][12].fastq.gz files into a number of chunks equal to the number of CPUs allocated to the rule.

Then I use GNU parallel to run the demuxGen1 binary on each of those chunks.

Once the BX tag has been added, I concatenate the chunks back into files laid out the same way as the Undetermined files.

This method considerably speeds up the rule: on 2 × 2.5 billion reads, the demultiplexing step in harpy ran for 6d 20h 30min 43s, whereas with my method it ran for 8h 54min 13s on 60 cores, roughly an 18× speedup (time for the demultiplex_barcodes rule only).

On the other hand, this method requires writing temporary files, which takes up a lot of disk space and can be a problem.

I've modified the rule:

rule demultiplex_barcodes:
    input:
        collect(outdir + "/DATA_{IR}{ext}_001.fastq.gz", IR = ["R","I"], ext = [1,2]),
        collect(outdir + "/BC_{letter}.txt", letter = ["A","C","B","D"])
    output:
        temp(collect(outdir + "/demux_R{ext}_001.fastq.gz", ext = [1,2]))
    params:
        outdir
    container:
        None
    shell:
        """
        cd {params}
        demuxGen1 DATA_ demux
        mv demux*BC.log logs
        """

into:

rule demultiplex_barcodes:
    input:
        collect(outdir + "/DATA_{IR}{ext}_001.fastq.gz", IR = ["R","I"], ext = [1,2]),
        collect(outdir + "/BC_{letter}.txt", letter = ["A","C","B","D"])
    output:
        demux=temp(collect(outdir + "/demux_R{ext}_001.fastq.gz", ext = [1,2])),
        demux_dir=temp(directory(outdir+"/demux_temp/"))
    params:
        outdir
    threads:
        100
    container:
        None
    benchmark:
        outdir +"/benchmark_demultiplex/demultiplex_barcode_benchmark.txt"
    shell:
        """
        r1=$(echo {input} | grep -o "\\w*_R1_\\w*.fastq.gz")
        r2=$(echo {input} | grep -o "\\w*_R2_\\w*.fastq.gz")
        i1=$(echo {input} | grep -o "\\w*_I1_\\w*.fastq.gz")
        i2=$(echo {input} | grep -o "\\w*_I2_\\w*.fastq.gz")

        seqkit split2 -1 {params}/$r1 -2 {params}/$r2 -p {threads} -O {output.demux_dir} -e .gz -j {threads}
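        # rename each part from DATA_R[12]_001.part_NNN.fastq.gz to DATANNN_R[12]_001.fastq.gz so the chunks can be addressed by id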
        rename 's/DATA(_R[12]_001).part_(\\d+).fastq.gz/DATA$2$1.fastq.gz/' {output.demux_dir}/DATA_R[12]_*.fastq.gz

        seqkit split2 -1 {params}/$i1 -2 {params}/$i2 -p {threads} -O {output.demux_dir} -e .gz -j {threads}
        rename 's/DATA(_I[12]_001).part_(\\d+).fastq.gz/DATA$2$1.fastq.gz/' {output.demux_dir}/DATA_I[12]_*.fastq.gz

        cd {output.demux_dir}
        cp {params}/BC_*.txt {output.demux_dir}/
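        # run demuxGen1 on every chunk in parallel; the chunk ids are extracted from the R1 part file names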
        parallel -j {threads} demuxGen1 DATA{{}}_ demux{{}}  ::: $(ls *_R1_*.fastq.gz | sed 's/DATA\\([0-9]\\+\\)_R1_001.fastq.gz/\\1/')

        rm DATA*_001.fastq.gz
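        # concatenate the per-chunk outputs back together in numeric (part) order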
        find -name "demux*_R1_*.fastq.gz" |sort -V | xargs cat > {params}/demux_R1_001.fastq.gz
        find -name "demux*_R2_*.fastq.gz" |sort -V | xargs cat > {params}/demux_R2_001.fastq.gz

        mkdir -p {params}/logs/demultiplex
        mv demux*BC.log {params}/logs/demultiplex/
        rm demux*.fastq.gz
        """

My rule needs GNU parallel and seqkit installed in the conda environment:

conda install conda-forge::parallel
conda install bioconda::seqkit

Best regards,
Benjamin

@bpenaud added the enhancement label on Feb 5, 2025
@pdimens (Owner) commented Feb 5, 2025

Hi Benjamin, thanks for writing. The benchmarks you provided on real data (6 days) are kind of unacceptable, so thanks for bringing this to my attention.

I had a similar-ish line of thinking last night regarding a speedup that I would like to try today or this week. I would like to modify the new python demuxing script originally conceived by @BELKHIR in #190 to focus on a single sample at a time, and have that be parallelized through snakemake. If this works the way I'm hoping it will, it will also be flexible to changes in the schema, such as the inclusion of a new sample or something (probably a rare case, but it would work regardless).

The idea is as follows (pseudocode):

rule demux:
  input:
    data_R1.fq
    data_R2.fq
    data_I1.fq
    data_I2.fq
    segments a,b,c,d
  output:
    pipe({sample}.R1.fq)
    pipe({sample}.R2.fq)
  params:
    sample = lambda wc: wc.get("sample")
    id_segment = lambda wc: sample_dict(wc.sample)
  script:
    scripts/demux_gen1.py {params} {input}

rule compress:
  input:
    {sample}.R{FR}.fq
  output:
    {sample}.R{FR}.fq.gz
  shell:
    "gzip {input}"

@pdimens (Owner) commented Feb 5, 2025

The work for that can be seen in the demux_parallelized branch (https://github.com/pdimens/harpy/tree/demux_parallelized) and in PR #200.

@pdimens added this to the 2.0 milestone on Feb 5, 2025
@pdimens (Owner) commented Feb 6, 2025

On the [admittedly tiny] test data, the parallelized-demux-by-sample seems to be performant. Once all checks pass, would you be willing to try a dev build on your data to see how it performs in a real setting?

I'd also like to rope @BELKHIR into this, as I made some modifications to their python script that are worth noting.

Preamble

Given that the pythonic approach you provided essentially liberates us from the original Chan method, there is a lot of freedom to make it work in a way that seems more sensible for general use. With that in mind:

  1. Input files are no longer hardcoded into the script. However, to use the levenshtein package, a new conda environment for demuxing was required, and the script needed to be in Snakemake-script form, i.e. using the snakemake.input... variables, etc. (see the sketch after this list). So it's a victory, but at a slight cost.
  2. The script processes a single sample at a time, allowing snakemake to parallelize it arbitrarily
  3. I was unable to get a workable/performant solution that wrote gzipped fastqs directly, so the demuxing first writes uncompressed fastq. However, since gzipping is a separate step that is also parallelized, snakemake should start compressing outputs as soon as they become available, which is better than all the fastqs being created first and compressed afterwards.
  4. The barcode log outputs are unified into a single file, and each sample gets its own file. To make that work, I renamed the columns:
Barcode    Total    Correct_Reads    Corrected_Reads

This format allows the "unclear" barcodes to sit at the bottom. You will be able to recognize "unclear" barcodes because they have a number >0 for Total but zeroes for Correct_Reads and Corrected_Reads:

Barcode    Total    Correct_Reads    Corrected_Reads
A41B00C00D82   4    0    0
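Regarding point 1, here is roughly what the Snakemake-script form looks like; this is a hypothetical skeleton with illustrative names, not the actual scripts/demultiplex_gen1.py:

from Levenshtein import distance   # the package that required the separate conda env

# in script: mode, Snakemake injects a `snakemake` object in place of CLI arguments
r1_path = snakemake.input.R1
i1_path = snakemake.input.I1
sample  = snakemake.params.sample
out_r1  = snakemake.output[0]

def closest_segment(observed, whitelist, max_dist = 1):
    """Return the whitelisted barcode segment within max_dist edits of observed, else None."""
    best = min(whitelist, key = lambda seg: distance(observed, seg))
    return best if distance(observed, best) <= max_dist else None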

@bpenaud (Author) commented Feb 6, 2025

Yes, no problem to try it on my big dataset once all checks pass.

Regards,
Benjamin

@pdimens reopened this on Feb 6, 2025
@pdimens (Owner) commented Feb 6, 2025

@bpenaud thanks for your willingness to test it. The dev version can be installed using the instructions provided here

@pdimens (Owner) commented Feb 7, 2025

@bpenaud it's ready for testing off the main branch

@bpenaud (Author) commented Feb 10, 2025

Hi @pdimens ,

I launched the job on Friday afternoon and it's currently still running. I can see that the demultiplexing job isn't halfway through yet, so I don't think the modifications will save any computing time.

From what I can see of your modification, I think the problem is that the demultiplexing job is run as many times as there are samples. As a result, the entire Undetermined file is unzipped and read as many times as there are samples, which takes time.

BELKHIR's solution was to read the Undetermined file once and then write to the sample files directly. That could be done in a single job, and it could therefore be parallelized as I proposed: first split the Undetermined file, then run the python script on each chunk of it.
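As a rough sketch of that single-pass idea (assign_sample here stands in for the real barcode-matching logic):

import gzip

def demux_single_pass(undetermined_r1, sample_names, assign_sample, outdir):
    # open one output handle per sample, then stream the Undetermined reads exactly once
    handles = {name: open(f"{outdir}/{name}.R1.fq", "w") for name in sample_names}
    with gzip.open(undetermined_r1, "rt") as fq:
        while True:
            record = [fq.readline() for _ in range(4)]   # a FASTQ record is 4 lines
            if not record[0]:
                break
            handles[assign_sample(record)].writelines(record)
    for handle in handles.values():
        handle.close()

The same function could then be run on each chunk of a pre-split Undetermined file, one job per chunk, which is essentially the parallelization I proposed above.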

Don't hesitate to ask if you want help writing the snakefile and python script.

Best Regards,
Benjamin

@pdimens (Owner) commented Feb 10, 2025

It's really unfortunate that the current implementation is slow; I was concerned about it for the very reasons you outlined. When I have time this week, I'll investigate your solution in more detail. Ideally, it would be best to split the rule you provided above into separate rules: one to split/chunk the reads and another to demux each chunk.

@pdimens (Owner) commented Feb 12, 2025

@bpenaud the new divide-and-conquer approach has been merged into main. Would you be willing to test it?

@bpenaud (Author) commented Feb 13, 2025

Yes, I can launch it before the weekend.

Benjamin

@bpenaud (Author) commented Feb 17, 2025

Hi,

I tried to run the new workflow, but it has a problem resolving the DAG when I set a large number of threads (dry-run timings below):

The dry-run command:

time snakemake --rerun-incomplete --show-failed-logs --rerun-triggers input mtime params --nolock --conda-prefix .environments --conda-cleanup-pkgs cache --apptainer-prefix .environments --directory . --software-deployment-method conda --cores 1 --snakefile /home/bpenaud/Results/Haplotagging/Demultiplex_V2/workflow/demultiplex_gen1.smk --configfile /home/bpenaud/Results/Haplotagging/Demultiplex_V2/workflow/config.yaml -n

Dry-run times by thread count:

threads    real         user         sys
1          0m8.204s     0m7.466s     0m0.698s
2          0m17.795s    0m16.983s    0m0.769s
3          0m35.016s    0m33.948s    0m1.026s
4          1m1.851s     1m0.562s     0m1.251s
5          1m39.477s    1m37.365s    0m2.059s
6          2m27.727s    2m24.808s    0m2.854s
8          4m34.012s    4m28.808s    0m5.130s
10         7m42.499s    7m33.478s    0m8.920s

With 60 threads, the DAG was never resolved over the weekend.

For the moment, I can't find the reason for this behavior.

Regards,
Benjamin

@pdimens (Owner) commented Feb 17, 2025

So two things are happening, if I understand correctly:

  1. The DAG doesn't resolve
  2. Runtime increases with thread count? Or is that just the dry-run DAG resolution time?

@bpenaud (Author) commented Feb 17, 2025

All the runtimes above are for resolving the DAG (dry run). But since the DAG was never resolved with 60 cores, the demultiplex workflow doesn't start.

So with a small number of threads the DAG is resolved, but as the thread count increases, it is not.

@pdimens (Owner) commented Feb 17, 2025

That's so interesting. Thanks for letting me know, I'll look into it.

Update: I can reproduce the error on my system.

@pdimens (Owner) commented Feb 17, 2025

Update 2:
The issue (hopefully the only issue) is an infinite recursion in the sample wildcard, seen here when using --dry-run --debug-dag:

candidate job merge_partitions
    wildcards: sample=Sample_17.001, FR=1
candidate job merge_partitions
    wildcards: sample=Sample_17.001.001, FR=1
candidate job merge_partitions
    wildcards: sample=Sample_17.001.001.001, FR=1
candidate job merge_partitions
    wildcards: sample=Sample_17.001.001.001.001, FR=1
candidate job merge_partitions
    wildcards: sample=Sample_17.001.001.001.001.001, FR=1
candidate job merge_partitions
    wildcards: sample=Sample_17.001.001.001.001.001.001, FR=1
candidate job merge_partitions
    wildcards: sample=Sample_17.001.001.001.001.001.001.001, FR=1
candidate job merge_partitions
    wildcards: sample=Sample_17.001.001.001.001.001.001.001.001, FR=1
candidate job merge_partitions
    wildcards: sample=Sample_17.001.001.001.001.001.001.001.001.001, FR=1
candidate job merge_partitions
    wildcards: sample=Sample_17.001.001.001.001.001.001.001.001.001.001, FR=1

I'll get this fixed.

@pdimens (Owner) commented Feb 17, 2025

@bpenaud alright, I think I fixed the issue by setting proper wildcard_constraints. It seems to work (just about immediately) on my laptop up to --threads 999. When you have a chance, please replace the demultiplex_gen1.smk in OUTDIR/workflow/ with the one below and run

# OUTDIR being your output directory
harpy resume OUTDIR

demultiplex_gen1.smk

containerized: "docker://pdimens/harpy:latest"

import os
import logging

outdir = config["output_directory"]
envdir = os.path.join(os.getcwd(), outdir, "workflow", "envs")
samplefile = config["inputs"]["demultiplex_schema"]
skip_reports = config["reports"]["skip"]
keep_unknown = config["keep_unknown"]

onstart:
    logger.logger.addHandler(logging.FileHandler(config["snakemake_log"]))
    os.makedirs(f"{outdir}/reports/data", exist_ok = True)
onsuccess:
    os.remove(logger.logfile)
onerror:
    os.remove(logger.logfile)
wildcard_constraints:
    sample = r"[a-zA-Z0-9._-]+",
    FR = r"[12]",
    part = r"\d{3}"

def parse_schema(smpl, keep_unknown):
    d = {}
    with open(smpl, "r") as f:
        for i in f.readlines():
            # a casual way to ignore empty lines or lines with !=2 fields
            try:
                sample, bc = i.split()
                id_segment = bc[0]
                if sample not in d:
                    d[sample] = [bc]
                else:
                    d[sample].append(bc)
            except ValueError:
                continue
    if keep_unknown:
        d["_unknown_sample"] = f"{id_segment}00"
    return d

samples = parse_schema(samplefile, keep_unknown)
samplenames = [i for i in samples]
print(samplenames)
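# one zero-padded chunk id per available core (capped at 999), matching seqkit split2's .part_XXX suffixes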
fastq_parts = [f"{i:03d}" for i in range(1, min(workflow.cores, 999) + 1)]

rule barcode_segments:
    output:
        collect(outdir + "/workflow/segment_{letter}.bc", letter = ["A","C","B","D"])
    params:
        f"{outdir}/workflow"
    container:
        None
    shell:
        "haplotag_acbd.py {params}"

rule partition_reads:
    input:
        r1 = config["inputs"]["R1"],
        r2 = config["inputs"]["R2"]       
    output:
        r1 = temp(f"{outdir}/reads.R1.fq.gz"),
        r2 = temp(f"{outdir}/reads.R2.fq.gz"),
        parts = temp(collect(outdir + "/reads_chunks/reads.R{FR}.part_{part}.fq.gz", part = fastq_parts, FR = [1,2]))
    log:
        outdir + "/logs/partition.reads.log"
    threads:
        workflow.cores
    params:
        chunks = min(workflow.cores, 999),
        outdir = f"{outdir}/reads_chunks"
    conda:
        f"{envdir}/demultiplex.yaml"
    shell:
        """
        ln -sr {input.r1} {output.r1}
        ln -sr {input.r2} {output.r2}
        seqkit split2 -f --quiet -1 {output.r1} -2 {output.r2} -p {params.chunks} -j {threads} -O {params.outdir} -e .gz 2> {log}
        """

use rule partition_reads as partition_index with:
    input:
        r1 = config["inputs"]["I1"],
        r2 = config["inputs"]["I2"]       
    output:
        r1 = temp(f"{outdir}/reads.I1.fq.gz"),
        r2 = temp(f"{outdir}/reads.I2.fq.gz"),
        parts = temp(collect(outdir + "/index_chunks/reads.I{FR}.part_{part}.fq.gz", part = fastq_parts, FR = [1,2]))
    log:
        outdir + "/logs/partition.index.log"
    params:
        chunks = min(workflow.cores, 999),
        outdir = f"{outdir}/index_chunks"

rule demultiplex:
    input:
        R1 = outdir + "/reads_chunks/reads.R1.part_{part}.fq.gz",
        R2 = outdir + "/reads_chunks/reads.R2.part_{part}.fq.gz",
        I1 = outdir + "/index_chunks/reads.I1.part_{part}.fq.gz",
        I2 = outdir + "/index_chunks/reads.I2.part_{part}.fq.gz",
        segment_a = f"{outdir}/workflow/segment_A.bc",
        segment_b = f"{outdir}/workflow/segment_B.bc",
        segment_c = f"{outdir}/workflow/segment_C.bc",
        segment_d = f"{outdir}/workflow/segment_D.bc",
        schema = samplefile
    output:
        temp(collect(outdir + "/{sample}.{{part}}.R{FR}.fq", sample = samplenames, FR = [1,2])),
        bx_info = temp(f"{outdir}/logs/part.{{part}}.barcodes")
    log:
        f"{outdir}/logs/demultiplex.{{part}}.log"
    params:
        outdir = outdir,
        qxrx = config["include_qx_rx_tags"],
        keep_unknown = keep_unknown,
        part = lambda wc: wc.get("part")
    conda:
        f"{envdir}/demultiplex.yaml"
    script:
        "scripts/demultiplex_gen1.py"

rule merge_partitions:
    input:
        collect(outdir + "/{{sample}}.{part}.R{{FR}}.fq", part = fastq_parts)
    output:
        outdir + "/{sample}.R{FR}.fq.gz"
    log:
        outdir + "/logs/{sample}.{FR}.concat.log"
    container:
        None
    shell:
        "cat {input} | gzip > {output} 2> {log}"

rule merge_barcode_logs:
    input:
        bc = collect(outdir + "/logs/part.{part}.barcodes", part = fastq_parts)
    output:
        log = f"{outdir}/logs/barcodes.log"
    run:
        bc_dict = {}
        for i in input.bc:
            with open(i, "r") as bc_log:
                # skip first row of column names
                _ = bc_log.readline()
                for line in bc_log:
                    barcode,total,correct,corrected = line.split()
                    bc_stats = [int(total), int(correct), int(corrected)]
                    if barcode not in bc_dict:
                        bc_dict[barcode] = bc_stats
                    else:
                        bc_dict[barcode] = list(map(lambda x,y: x+y, bc_stats, bc_dict[barcode]))
        with open(output.log, "w") as f:
            f.write("Barcode\tTotal_Reads\tCorrect_Reads\tCorrected_Reads\n")
            for k,v in bc_dict.items():
                f.write(k + "\t" + "\t".join([str(i) for i in v]) + "\n")

rule assess_quality:
    input:
        outdir + "/{sample}.R{FR}.fq.gz"
    output: 
        outdir + "/reports/data/{sample}.R{FR}.fastqc"
    log:
        outdir + "/logs/{sample}.R{FR}.qc.log"
    threads:
        1
    conda:
        f"{envdir}/qc.yaml"
    shell:
        """
        ( falco --quiet --threads {threads} -skip-report -skip-summary -data-filename {output} {input} ) > {log} 2>&1 ||
cat <<EOF > {output}
##Falco	1.2.4
>>Basic Statistics	fail
#Measure	Value
Filename	{wildcards.sample}.R{wildcards.FR}.fq.gz
File type	Conventional base calls
Encoding	Sanger / Illumina 1.9
Total Sequences	0
Sequences flagged as poor quality	0
Sequence length	0
%GC	0
>>END_MODULE
EOF      
        """

rule report_config:
    output:
        outdir + "/workflow/multiqc.yaml"
    run:
        import yaml
        configs = {
            "sp": {"fastqc/data": {"fn" : "*.fastqc"}},
            "table_sample_merge": {
                "R1": ".R1",
                "R2": ".R2"
            },
            "title": "Quality Assessment of Demultiplexed Samples",
            "subtitle": "This report aggregates the QA results created by falco",
            "report_comment": "Generated as part of the Harpy demultiplex workflow",
            "report_header_info": [
                {"Submit an issue": "https://github.com/pdimens/harpy/issues/new/choose"},
                {"Read the Docs": "https://pdimens.github.io/harpy/"},
                {"Project Homepage": "https://github.com/pdimens/harpy"}
            ]
        }
        with open(output[0], "w", encoding="utf-8") as yml:
            yaml.dump(configs, yml, default_flow_style= False, sort_keys=False, width=float('inf'))

rule quality_report:
    input:
        fqc = collect(outdir + "/reports/data/{sample}.R{FR}.fastqc", sample = samplenames, FR = [1,2]),
        mqc_yaml = outdir + "/workflow/multiqc.yaml"
    output:
        outdir + "/reports/demultiplex.QA.html"
    log:
        f"{outdir}/logs/multiqc.log"
    params:
        options = "--no-version-check --force --quiet --no-data-dir",
        module = " --module fastqc",
        logdir = outdir + "/reports/data/"
    conda:
        f"{envdir}/qc.yaml"
    shell:
        "multiqc --filename {output} --config {input.mqc_yaml} {params} 2> {log}"

rule workflow_summary:
    default_target: True
    input:
        fq = collect(outdir + "/{sample}.R{FR}.fq.gz", sample = samplenames, FR = [1,2]),
        barcode_logs = f"{outdir}/logs/barcodes.log",
        reports = outdir + "/reports/demultiplex.QA.html" if not skip_reports else []
    params:
        R1 = config["inputs"]["R1"],
        R2 = config["inputs"]["R2"],
        I1 = config["inputs"]["I1"],
        I2 = config["inputs"]["I2"]
    run:
        summary = ["The harpy demultiplex workflow ran using these parameters:"]
        summary.append("Linked Read Barcode Design: Generation I")
        inputs = "The multiplexed input files:\n"
        inputs += f"\tread 1: {params.R1}\n"
        inputs += f"\tread 2: {params.R2}\n"
        inputs += f"\tindex 1: {params.I1}\n"
        inputs += f"\tindex 2: {params.I2}"
        inputs += f"Sample demultiplexing schema: {samplefile}"
        summary.append(inputs)
        demux = "Samples were demultiplexed using:\n"
        demux += "\tworkflow/scripts/demultiplex_gen1.py"
        summary.append(demux)
        qc = "QC checks were performed on demultiplexed FASTQ files using:\n"
        qc += "\tfalco -skip-report -skip-summary -data-filename output input.fq.gz"
        summary.append(qc)
        sm = "The Snakemake workflow was called via command line:\n"
        sm += f"\t{config['workflow_call']}"
        summary.append(sm)
        with open(outdir + "/workflow/demux.gen1.summary", "w") as f:
            f.write("\n\n".join(summary))
