All downloads are free. Search and download functionality uses the official Maven repository.

org.broadinstitute.hellbender.tools.spark.bwa.BwaSparkEngine Maven / Gradle / Ivy

The newest version!
package org.broadinstitute.hellbender.tools.spark.bwa;

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMRecord;
import htsjdk.samtools.SAMSequenceDictionary;
import org.apache.spark.SparkFiles;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.bwa.*;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.SAMRecordToGATKReadAdapter;

import java.util.*;

/**
 * The BwaSparkEngine provides a simple interface for transforming a JavaRDD in which the reads are paired
 * and unaligned, into a JavaRDD of aligned reads, and does so lazily.
 * Use it like this:
 *     Make one, call the {@link #align} method for each of your input RDDs in a pipeline that runs some action, close it.
 *
 * The reason that the pipeline must culminate in some action, is because this class implements a lazy
 * transform, and nothing will happen otherwise.
 *
 * See {@link BwaSpark#runTool runTool} for an example.
 */
public final class BwaSparkEngine implements AutoCloseable {
    private static final String REFERENCE_INDEX_IMAGE_FILE_SUFFIX = ".img";
    private final JavaSparkContext ctx;
    private final String indexFileName;
    private final boolean resolveIndexFileName;
    private final Broadcast broadcastHeader;

    /**
     * @param ctx           the Spark context
     * @param referenceFile the path to the reference file named _prefix_.fa, which is used to find the image file with name _prefix_.fa.img.
     *                      Can be null if the indexFileName is provided.
     * @param indexFileName the index image file name that already exists, or null to have the image file automatically distributed.
     * @param inputHeader   the SAM file header to use for reads
     * @param refDictionary the sequence dictionary to use for reads if the SAM file header doesn't have one (or it's empty)
     */
    public BwaSparkEngine(final JavaSparkContext ctx,
                          final String referenceFile,
                          final String indexFileName,
                          SAMFileHeader inputHeader,
                          final SAMSequenceDictionary refDictionary) {
        Utils.nonNull(referenceFile);
        Utils.nonNull(inputHeader);
        this.ctx = ctx;
        if (indexFileName != null) {
            this.indexFileName = indexFileName;
            this.resolveIndexFileName = false;
        } else {
            String indexFile = referenceFile + REFERENCE_INDEX_IMAGE_FILE_SUFFIX;
            ctx.addFile(indexFile); // distribute index file to all executors
            this.indexFileName = IOUtils.getPath(indexFile).getFileName().toString();
            this.resolveIndexFileName = true;
        }

        if (inputHeader.getSequenceDictionary() == null || inputHeader.getSequenceDictionary().isEmpty()) {
            Utils.nonNull(refDictionary);
            inputHeader = inputHeader.clone();
            inputHeader.setSequenceDictionary(refDictionary);
        }
        broadcastHeader = ctx.broadcast(inputHeader);
    }

    public SAMFileHeader getHeader() { return broadcastHeader.getValue(); }

    /**
     * Performs pair-end alignment on a RDD.
     * @param unalignedReads the read-pairs to align.
     * @return never {@code null}.
     */
    public JavaRDD alignPaired(final JavaRDD unalignedReads) {
        return align(unalignedReads, true);
    }

    /**
     * Performs single-end alignment on a RDD.
     *
     * @param unalignedReads the reads to align.
     * @return never {@code null}.
     */
    public JavaRDD alignUnpaired(final JavaRDD unalignedReads) {
        return align(unalignedReads, false);
    }

    /**
     * Performs read alignment on a RDD.
     * @param unalignedReads the reads to align.
     * @param pairedAlignment whether it should perform pair-end alignment ({@code true}) or single-end alignment ({@code false}).
     * @return never {@code null}.
     */
    public JavaRDD align(final JavaRDD unalignedReads, final boolean pairedAlignment) {
        final Broadcast broadcastHeader = this.broadcastHeader;
        final String indexFileName = this.indexFileName;
        final boolean resolveIndexFileName = this.resolveIndexFileName;
        return unalignedReads.mapPartitions(itr ->
                new ReadAligner(resolveIndexFileName ? SparkFiles.get(indexFileName) : indexFileName, broadcastHeader.value(), pairedAlignment).apply(itr));
    }

    @Override
    public void close() {
        broadcastHeader.destroy();
        BwaMemIndexCache.closeAllDistributedInstances(ctx);
    }

    private static final class ReadAligner {
        private final BwaMemIndex bwaMemIndex;
        private final SAMFileHeader readsHeader;
        private final boolean alignsPairs;

        // assumes 128Mb partitions, with reads needing about 100bytes each when BAM compressed
        private static final int READS_PER_PARTITION_GUESS = 1500000;

        ReadAligner( final String indexFileName, final SAMFileHeader readsHeader, final boolean alignsPairs) {
            this.bwaMemIndex = BwaMemIndexCache.getInstance(indexFileName);
            this.readsHeader = readsHeader;
            this.alignsPairs = alignsPairs;
            if ( alignsPairs && readsHeader.getSortOrder() != SAMFileHeader.SortOrder.queryname ) {
                throw new UserException("Input must be queryname sorted unless you use single-ended alignment mode.");
            }
        }

        Iterator apply( final Iterator readItr ) {
            final List inputReads = new ArrayList<>(READS_PER_PARTITION_GUESS);
            while ( readItr.hasNext() ) {
                inputReads.add(readItr.next());
            }
            final int nReads = inputReads.size();
            if ( alignsPairs ) {
                if ( (nReads & 1) != 0 ) {
                    throw new GATKException("We're supposed to be aligning paired reads, but there are an odd number of them.");
                }
                for ( int idx = 0; idx != nReads; idx += 2 ) {
                    final String readName1 = inputReads.get(idx).getName();
                    final String readName2 = inputReads.get(idx+1).getName();
                    if ( !Objects.equals(readName1,readName2) ) {
                        throw new GATKException("Read pair has varying template name: "+readName1+" .vs "+readName2);
                    }
                }
            }
            final List> allAlignments;
            if ( nReads == 0 ) allAlignments = Collections.emptyList();
            else {
                final List seqs = new ArrayList<>(nReads);
                for (final GATKRead read : inputReads) {
                    seqs.add(read.getBases());
                }
                final BwaMemAligner aligner = new BwaMemAligner(bwaMemIndex);
                // we are dealing with interleaved, paired reads.  tell BWA that they're paired.
                if (alignsPairs) {
                    aligner.alignPairs();
                }
                allAlignments = aligner.alignSeqs(seqs);
            }
            final List refNames = bwaMemIndex.getReferenceContigNames();
            final List outputReads = new ArrayList<>(allAlignments.stream().mapToInt(List::size).sum());
            for ( int idx = 0; idx != nReads; ++idx ) {
                final GATKRead originalRead = inputReads.get(idx);
                final String readName = originalRead.getName();
                final byte[] bases = originalRead.getBases();
                final byte[] quals = originalRead.getBaseQualities();
                final String readGroup = originalRead.getReadGroup();
                final List alignments = allAlignments.get(idx);
                final Map saTagMap = BwaMemAlignmentUtils.createSATags(alignments,refNames);
                for ( final BwaMemAlignment alignment : alignments ) {
                    final SAMRecord samRecord =
                            BwaMemAlignmentUtils.applyAlignment(readName, bases, quals, readGroup,
                                                                alignment, refNames, readsHeader, false, true);
                    final GATKRead rec = SAMRecordToGATKReadAdapter.headerlessReadAdapter(samRecord);
                    final String saTag = saTagMap.get(alignment);
                    if ( saTag != null ) rec.setAttribute("SA", saTag);
                    outputReads.add(rec);
                }
            }
            return outputReads.iterator();
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy