All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.campagnelab.dl.somatic.learning.AbstractPredictMutations Maven / Gradle / Ivy

package org.campagnelab.dl.somatic.learning;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.ParameterException;
import it.unimi.dsi.logging.ProgressLogger;
import org.campagnelab.dl.somatic.utils.CalcCalibrator;
import org.campagnelab.dl.somatic.utils.ProtoPredictor;
import org.campagnelab.dl.framework.mappers.FeatureMapper;
import org.campagnelab.dl.somatic.learning.calibrate.CalibratingModel;
import org.campagnelab.dl.varanalysis.protobuf.BaseInformationRecords;
import org.campagnelab.dl.framework.performance.AreaUnderTheROCCurve;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.io.PrintWriter;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;

/**
 * Base class for tools that run a trained model over base-information records and
 * write one tab-separated row per record (label, predicted probabilities, correctness,
 * record features and, optionally, a calibrated probability) to a {@link PrintWriter}.
 * <p>
 * Subclasses supply the feature columns through {@link #featuresToString}.
 * <p>
 * Created by fac2003 on 6/10/16.
 */
public abstract class AbstractPredictMutations {

    /** Last header line built by {@link #writeHeader(PrintWriter, boolean)} (without the calibration column). */
    String header;
    /** File-name component of {@code arguments.modelPath}. */
    String modelName;
    /** Parent directory of {@code arguments.modelPath}, or {@code "."} when the path has no parent. */
    String modelDir;
    String version;
    String testSet;
    String type;
    int scoreN;
    /** When true, extra count/formatted columns are added to the header and to each row. */
    boolean longReport;
    // Column names for the third sample. They are blanked by writeHeader(results, false).
    String s3Scores = "\tsample3Scores";
    String s3Counts = "\tsampleS3Counts";
    String formatted3 = "\tformatted3";

    /**
     * Optional calibration model. When non-null, a {@code calibratedP} column is added to
     * the header and to every row. Nothing in this class assigns it.
     */
    CalibratingModel cmodel;

    protected PredictionArguments arguments;

    /**
     * Legacy constructor that leaves every field at its default value.
     *
     * @deprecated use {@link #AbstractPredictMutations(PredictionArguments)} instead.
     */
    @Deprecated
    public AbstractPredictMutations(){
        System.out.println("deprecated predictor in use");
    }

    /**
     * Copies the settings needed for prediction out of the parsed command-line arguments.
     *
     * @param arguments parsed prediction arguments; {@code modelPath} is split into
     *                  {@link #modelName} and {@link #modelDir}.
     */
    public AbstractPredictMutations(PredictionArguments arguments) {
        this.arguments = arguments;
        Path path = Paths.get(arguments.modelPath);
        modelName = path.getFileName().toString();
        // Path.getParent() returns null for a bare relative name such as "model";
        // fall back to the current directory instead of failing with a NullPointerException.
        Path parent = path.getParent();
        modelDir = (parent == null) ? "." : parent.toString();
        testSet = arguments.testSet;
        version = arguments.modelVersion;
        type = arguments.type;
        longReport = arguments.longReport;
        scoreN = arguments.scoreN;
    }


    /**
     * Parses command-line arguments into a {@link PredictionArguments} instance.
     * On a parse failure the usage text is printed before the exception is rethrown.
     *
     * @param args        raw command-line arguments.
     * @param commandName program name shown in the usage text.
     * @return the populated arguments object.
     * @throws ParameterException when JCommander rejects the arguments.
     */
    protected static PredictionArguments parseArguments(String[] args, String commandName) {
        PredictionArguments arguments = new PredictionArguments();
        JCommander commander = new JCommander(arguments);
        commander.setProgramName(commandName);
        try {
            commander.parse(args);
        } catch (ParameterException e) {
            // Show the usage before propagating so the user sees what was expected.
            commander.usage();
            System.out.flush();
            throw e;
        }
        return arguments;
    }

    /**
     * Legacy header writer. Writes nothing to {@code results}; it only prints a notice
     * to standard output.
     *
     * @param results ignored.
     * @deprecated use {@link #writeHeader(PrintWriter, boolean)} instead.
     */
    @Deprecated
    protected void writeHeader(PrintWriter results) {
        System.out.println("deprectead header writer in use. old feature mapepr?");
    }


    /**
     * Writes the tab-separated header line, terminated by a newline.
     * <p>
     * Side effect: when {@code isTrio} is false the third-sample column-name fields
     * ({@link #s3Scores}, {@link #s3Counts}, {@link #formatted3}) are set to the empty
     * string on this instance and are not restored, so a later call with
     * {@code isTrio == true} on the same instance will not emit those columns either.
     *
     * @param results destination for the header line.
     * @param isTrio  whether records carry a third sample.
     */
    protected void writeHeader(PrintWriter results, boolean isTrio) {
        if (!isTrio) {
            s3Counts = "";
            s3Scores = "";
            formatted3 = "";
        }
        header = "mutatedLabel\tProbabilityMut\tProbabilityUnmut\tcorrectness\tfrequency" +
                "\trefId\tposition\treferenceBase" +
                "\tsample1Scores\tsample2Scores" + s3Scores + "\tsumAllCounts";

        if (longReport){
            header = header + "\tmutatedBase\tsample1Counts\tsample2Counts" + s3Counts +
                    "\tformatted1\tformatted2" + formatted3;
        }
        results.append(header);
        if (cmodel != null) {
            results.append("\tcalibratedP");
        }
        results.append("\n");
    }

    /**
     * Convenience overload: no calibration network and no calibrator.
     *
     * @see #writeRecordResult(MultiLayerNetwork, MultiLayerNetwork, PrintWriter, FeatureMapper, ProgressLogger, BaseInformationRecords.BaseInformation, AreaUnderTheROCCurve, CalcCalibrator, boolean)
     */
    protected void writeRecordResult(MultiLayerNetwork model, PrintWriter results, FeatureMapper featureMapper, ProgressLogger pgReadWrite, BaseInformationRecords.BaseInformation record, AreaUnderTheROCCurve aucLossCalculator, boolean isTrio) {
        writeRecordResult(model, null, results, featureMapper, pgReadWrite, record, aucLossCalculator, null, isTrio);
    }

    /**
     * Convenience overload: no calibration network, no calibrator, and {@code isTrio == false}.
     *
     * @see #writeRecordResult(MultiLayerNetwork, MultiLayerNetwork, PrintWriter, FeatureMapper, ProgressLogger, BaseInformationRecords.BaseInformation, AreaUnderTheROCCurve, CalcCalibrator, boolean)
     */
    protected void writeRecordResult(MultiLayerNetwork model, PrintWriter results, FeatureMapper featureMapper, ProgressLogger pgReadWrite, BaseInformationRecords.BaseInformation record, AreaUnderTheROCCurve aucLossCalculator) {
        writeRecordResult(model, null, results, featureMapper, pgReadWrite, record, aucLossCalculator, null, false);
    }

    /**
     * Predicts one record and appends its result row (terminated by a newline) to
     * {@code results}, then advances the progress logger.
     *
     * @param model             network used to build the {@link ProtoPredictor}.
     * @param calibrationModel  not referenced by this method.
     * @param results           destination for the row.
     * @param featureMapper     mapper used both for the predictor and to fill the feature
     *                          vector handed to {@link #cmodel}.
     * @param pgReadWrite       progress logger, updated once per call.
     * @param record            record to predict; samples 0 and 1 (and 2 when
     *                          {@code isTrio}) are read when {@link #longReport} is set.
     * @param aucLossCalculator optional; when non-null it observes the positive
     *                          probability with label 1 (mutated) or -1 (not mutated).
     * @param calc              optional; when non-null it observes the positive
     *                          probability and the true label.
     * @param isTrio            whether a third-sample formatted column is written.
     */
    protected void writeRecordResult(MultiLayerNetwork model, MultiLayerNetwork calibrationModel, PrintWriter results, FeatureMapper featureMapper, ProgressLogger pgReadWrite, BaseInformationRecords.BaseInformation record, AreaUnderTheROCCurve aucLossCalculator, CalcCalibrator calc, boolean isTrio) {
        // Feature vector for this single record; only consumed by cmodel below.
        INDArray testFeatures = Nd4j.zeros(1, featureMapper.numberOfFeatures());
        featureMapper.prepareToNormalize(record,0);
        featureMapper.mapFeatures(record, testFeatures, 0);
        String features = featuresToString(record,longReport);

        boolean mutated = record.getMutated();
        ProtoPredictor predictor = new ProtoPredictor(null,model, featureMapper);
        ProtoPredictor.Prediction prediction = predictor.mutPrediction(record);

        // Per-sample formatted columns are only part of the long report.
        String formatted0 = longReport?"\t"+genFormattedString(record.getSamples(0)):"";
        String formatted1 = longReport?"\t"+genFormattedString(record.getSamples(1)):"";
        String formatted2 = longReport?isTrio?"\t"+genFormattedString(record.getSamples(2)):"":"";
        String correctness = (prediction.clas == mutated) ? "right" : "wrong";
        if (aucLossCalculator != null) {
            aucLossCalculator.observe(prediction.posProb, mutated ? 1 : -1);
        }

        // NOTE(review): String.format uses the default locale, so the decimal separator
        // of the %f columns depends on the JVM locale.
        results.append(String.format("%s\t%f\t%f\t%s\t%s%s%s%s",
                (mutated ? "1" : "0"),
                prediction.posProb, prediction.negProb,
                correctness, features,
                formatted0, formatted1, formatted2
        ));
        if (cmodel != null) {
            results.append(String.format("\t%f", cmodel.estimateCalibratedP(testFeatures)));
        }

        results.append("\n");
        pgReadWrite.update();

        // The convenience overloads pass a null calibrator; only record the observation
        // when one was actually supplied.
        if (calc != null) {
            calc.observe(prediction.posProb,mutated);
        }
    }

    /**
     * Builds a human-readable summary of one sample's genotype counts.
     * <p>
     * Counts at indices 0-4 are reported as A, T, C, G and N (forward plus reverse
     * strand); every count from index 5 onwards is listed under {@code indels}.
     * The ninth space-separated token of {@code getFormattedCounts()} is copied through
     * unchanged, so that string must contain at least nine tokens.
     *
     * @param sample sample with at least five count entries.
     * @return the formatted summary string.
     */
    private static String genFormattedString(BaseInformationRecords.SampleInfo sample) {
        int a = sample.getCounts(0).getGenotypeCountReverseStrand() + sample.getCounts(0).getGenotypeCountForwardStrand();
        int t = sample.getCounts(1).getGenotypeCountReverseStrand() + sample.getCounts(1).getGenotypeCountForwardStrand();
        int c = sample.getCounts(2).getGenotypeCountReverseStrand() + sample.getCounts(2).getGenotypeCountForwardStrand();
        int g = sample.getCounts(3).getGenotypeCountReverseStrand() + sample.getCounts(3).getGenotypeCountForwardStrand();
        int n = sample.getCounts(4).getGenotypeCountReverseStrand() + sample.getCounts(4).getGenotypeCountForwardStrand();
        String fb = sample.getFormattedCounts().split(" ")[8];
        int numIndels = sample.getCountsCount() - 5;
        int[] indels = new int[numIndels];
        for (int i = 5; i < numIndels + 5 ; i++) {
            indels[i-5] = sample.getCounts(i).getGenotypeCountForwardStrand() + sample.getCounts(i).getGenotypeCountReverseStrand();
        }
        return String.format("counts A=%d T=%d C=%d G=%d N=%d %s indels:%s",a,t,c,g,n,fb, Arrays.toString(indels));
    }

    /**
     * Renders the record-specific columns that follow the {@code correctness} column.
     *
     * @param record     record being reported.
     * @param longReport whether the long-report columns should be included.
     * @return the tab-separated feature columns for this record.
     */
    protected abstract String featuresToString(BaseInformationRecords.BaseInformation record,boolean longReport);
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy