All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.broadinstitute.hellbender.tools.walkers.vqsr.TrancheManager Maven / Gradle / Ivy

The newest version!
package org.broadinstitute.hellbender.tools.walkers.vqsr;

import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.LogManager;

import org.broadinstitute.hellbender.exceptions.UserException;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class TrancheManager {

    protected final static Logger logger = LogManager.getLogger(TrancheManager.class);

    // ---------------------------------------------------------------------------------------------------------
    //
    // Code to determine FDR tranches for VariantDatum[]
    //
    // ---------------------------------------------------------------------------------------------------------

    public static abstract class SelectionMetric {
        String name = null;

        public SelectionMetric(String name) {
            this.name = name;
        }

        public String getName() { return name; }

        public abstract double getThreshold(double tranche);
        public abstract double getTarget();
        public abstract void calculateRunningMetric(List data);
        public abstract double getRunningMetric(int i);
        public abstract int datumValue(VariantDatum d);
    }

    public static class TruthSensitivityMetric extends SelectionMetric {
        double[] runningSensitivity;
        int nTrueSites = 0;

        public TruthSensitivityMetric(final int nTrueSites) {
            super("TruthSensitivity");
            this.nTrueSites = nTrueSites;
        }

        @Override
        public double getThreshold(final double tranche) {
            return 1.0 - tranche/100.0; // tranche of 1 => 99% sensitivity target
        }

        @Override
        public double getTarget() { return 1.0; }

        @Override
        public void calculateRunningMetric(final List data) {
            int nCalledAtTruth = 0;
            runningSensitivity = new double[data.size()];

            for ( int i = data.size() - 1; i >= 0; i-- ) {
                VariantDatum datum = data.get(i);
                nCalledAtTruth += datum.atTruthSite ? 1 : 0;
                runningSensitivity[i] = 1 - nCalledAtTruth / (1.0 * nTrueSites);
            }
        }

        @Override
        public double getRunningMetric(final int i) {
            return runningSensitivity[i];
        }

        @Override
        public int datumValue(final VariantDatum d) {
            return d.atTruthSite ? 1 : 0;
        }
    }

    public static List findTranches(final List data,
                                                             final List tranches,
                                                             final SelectionMetric metric,
                                                             final VariantRecalibratorArgumentCollection.Mode model ) {
        return findTranches( data, tranches, metric, model, null );
    }

    public static List findTranches(
            final List data,
            final List trancheThresholds,
            final SelectionMetric metric,
            final VariantRecalibratorArgumentCollection.Mode model,
            final File debugFile ) {
        logger.info(String.format("Finding %d tranches for %d variants", trancheThresholds.size(), data.size()));

        Collections.sort( data, VariantDatum.VariantDatumLODComparator );
        metric.calculateRunningMetric(data);

        if ( debugFile != null) {
            writeTranchesDebuggingInfo(debugFile, data, metric);
        }

        List tranches = new ArrayList<>();
        for ( double trancheThreshold : trancheThresholds ) {
            TruthSensitivityTranche t = findTranche(data, metric, trancheThreshold, model);

            if ( t == null ) {
                if ( tranches.size() == 0 ) {
                    throw new UserException(String.format(
                            "Couldn't find any tranche containing variants with a %s > %.2f. Are you sure the truth files contain unfiltered variants which overlap the input data?",
                            metric.getName(),
                            metric.getThreshold(trancheThreshold)));
                }
                break;
            }

            tranches.add(t);
        }

        return tranches;
    }

    public static List findVQSLODTranches(
            final List data,
            final List trancheThresholds,
            final SelectionMetric metric,
            final VariantRecalibratorArgumentCollection.Mode model) {
        logger.info(String.format("Finding %d tranches for %d variants", trancheThresholds.size(), data.size()));

        Collections.sort( data, VariantDatum.VariantDatumLODComparator );
        metric.calculateRunningMetric(data);

        final List tranches = new ArrayList<>();
        for ( double trancheThreshold : trancheThresholds ) {
            VQSLODTranche t = findVQSLODTranche(data, metric, trancheThreshold, model);

            if ( t == null ) {
                if ( tranches.size() == 0 ) {
                    throw new UserException(String.format(
                            "Couldn't find any tranche containing variants with a %s > %.2f. Are you sure the truth files contain unfiltered variants which overlap the input data?",
                            metric.getName(),
                            metric.getThreshold(trancheThreshold)));
                }
                break;
            }

            tranches.add(t);
        }

        return tranches;

    }

    private static void writeTranchesDebuggingInfo(final File f, final List tranchesData, final SelectionMetric metric ) {
        try {
            PrintStream out = new PrintStream(f);
            out.println("Qual metricValue runningValue");
            for ( int i = 0; i < tranchesData.size(); i++ ) {
                VariantDatum  d = tranchesData.get(i);
                int score = metric.datumValue(d);
                double runningValue = metric.getRunningMetric(i);
                out.printf("%.4f %d %.4f%n", d.lod, score, runningValue);
            }
            out.close();
        } catch (FileNotFoundException e) {
            throw new UserException.CouldNotCreateOutputFile(f, e);
        }
    }

    public static TruthSensitivityTranche findTranche(
            final List data,
            final SelectionMetric metric,
            final double trancheThreshold,
            final VariantRecalibratorArgumentCollection.Mode model ) {
        logger.info(String.format("  TruthSensitivityTranche threshold %.2f => selection metric threshold %.3f", trancheThreshold, metric.getThreshold(trancheThreshold)));

        final double metricThreshold = metric.getThreshold(trancheThreshold);
        final int n = data.size();
        for ( int i = 0; i < n; i++ ) {
            if ( metric.getRunningMetric(i) >= metricThreshold ) {
                // we've found the largest group of variants with sensitivity >= our target truth sensitivity
                final TruthSensitivityTranche t = TruthSensitivityTranche.trancheOfVariants(data, i, trancheThreshold, model);
                logger.info(String.format("  Found tranche for %.3f: %.3f threshold starting with variant %d; running score is %.3f ",
                        trancheThreshold, metricThreshold, i, metric.getRunningMetric(i)));
                logger.info(String.format("  TruthSensitivityTranche is %s", t));
                return t;
            }
        }

        return null;
    }

    public static VQSLODTranche findVQSLODTranche(
            final List data,
            final SelectionMetric metric,
            final double trancheThreshold,
            final VariantRecalibratorArgumentCollection.Mode model ) {
        logger.info(String.format("  VQSLODTranche threshold %.2f => selection metric threshold %.3f", trancheThreshold, metric.getThreshold(trancheThreshold)));

        final double metricThreshold = metric.getThreshold(trancheThreshold);
        final int n = data.size();
        for ( int i = 0; i < n; i++ ) {
            if ( data.get(i).lod >= trancheThreshold ) {
                // we've found the largest group of variants with LOD >= our target LOD
                final VQSLODTranche t = VQSLODTranche.trancheOfVariants(data, i, trancheThreshold, model);
                logger.info(String.format("  Found tranche for %.3f: %.3f threshold starting with variant %d; running score is %.3f ",
                        trancheThreshold, metricThreshold, i, metric.getRunningMetric(i)));
                logger.info(String.format("  VQSLODTranche is %s", t));
                return t;
            }
        }
        //If we don't have enough good data to meet the trancheThreshold, for VQSLODs we still want to return something
        logger.info(String.format("  Could not find tranche for %.3f: %.3f threshold; reporting empty tranche",
                trancheThreshold, metricThreshold));
        final VQSLODTranche t = VQSLODTranche.emptyTranche(data, n-1, trancheThreshold, model);
        return t;
    }

    public static double fdrToTiTv(final double desiredFDR, final double targetTiTv) {
        return (1.0 - desiredFDR / 100.0) * (targetTiTv - 0.5) + 0.5;
    }

    public static int countCallsAtTruth(final List data, double minLOD ) {
        int n = 0;
        for ( VariantDatum d : data ) { n += (d.atTruthSite && d.lod >= minLOD ? 1 : 0); }
        return n;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy