org.campagnelab.dl.somatic.learning.PredictMutations

package org.campagnelab.dl.somatic.learning;

import it.unimi.dsi.logging.ProgressLogger;
import org.campagnelab.dl.framework.mappers.FeatureMapper;
import org.campagnelab.dl.framework.models.ModelLoader;
import org.campagnelab.dl.framework.performance.AreaUnderTheROCCurve;
import org.campagnelab.dl.somatic.utils.BayesCalibrator;
import org.campagnelab.dl.somatic.utils.CalcCalibrator;
import org.campagnelab.dl.somatic.utils.ProtoPredictor;
import org.campagnelab.dl.somatic.mappers.QualityFeatures;
import org.campagnelab.dl.somatic.learning.calibrate.CalibratingModel;
import org.campagnelab.dl.varanalysis.protobuf.BaseInformationRecords;
import org.campagnelab.dl.somatic.storage.RecordReader;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.stream.IntStream;

/**
 * Use a trained neural network model to predict mutations.
 * <p>
 * Created by rct2002 on 6/8/16.
 *
 * @deprecated Use Predict instead.
 * @author Remi Torracinta
 */
public class PredictMutations extends AbstractPredictMutations {
    static private Logger LOG = LoggerFactory.getLogger(PredictMutations.class);

    final boolean SKIP0COUNTS = true;
    // Adjusted automatically when the loaded model's feature mapper is a trio mapper; do not change manually.
    boolean isTrio = false;
    private ModelLoader modelLoader;
    CalcCalibrator calculator;
    // Holds the model wrapped with its calibrating network when one is available (declared here so the class compiles).
    private CalibratingModel cmodel;
    // String[] unmutFilenames = new String[]{"unmut_genotypes_test_proto_VN.parquet", "training_batch/genotypes_proto_" + version + "_randomized_mutated.parquet"};
    // String[] mutFilenames = new String[]{"mutated-MHFC-13-CTL_B_NK.parquet", "training_batch/genotypes_proto_" + version + "_randomized_mutated.parquet"};
    FeatureMapper featureMapper; // = new FeatureMapperV9();

    public PredictMutations(PredictionArguments arguments) {
        super(arguments);
    }

    public static void main(String[] args) throws Exception {
        PredictionArguments arguments = parseArguments(args, "PredictMutations");
        PredictMutations predictor = new PredictMutations(arguments);
        System.out.println("modelName: " + predictor.modelName);
        predictor.printPredictions(predictor.version, arguments.modelPath, predictor.testSet,
                "tests/" + predictor.modelName + "/");
        System.out.println(arguments.modelPath);
    }

    // Formats one record as tab-separated features: no starting tab, no ending tab.
    public String featuresToString(BaseInformationRecords.BaseInformation pos, boolean longReport) {
        // Indices 0-4 hold forward-strand genotype counts, 5-9 the matching reverse-strand counts.
        int[] s1Counts = new int[10];
        int[] s2Counts = new int[10];
        for (int i = 0; i < 5; i++) {
            s1Counts[i] = pos.getSamples(0).getCounts(i).getGenotypeCountForwardStrand();
            s1Counts[i + 5] = pos.getSamples(0).getCounts(i).getGenotypeCountReverseStrand();
            s2Counts[i] = pos.getSamples(1).getCounts(i).getGenotypeCountForwardStrand();
            s2Counts[i + 5] = pos.getSamples(1).getCounts(i).getGenotypeCountReverseStrand();
        }
        // Average base quality per genotype, laid out the same way as the counts.
        float[] s1Scores = new float[10];
        float[] s2Scores = new float[10];
        for (int i = 0; i < 5; i++) {
            s1Scores[i] = QualityFeatures.avgQuality(ProtoPredictor.expandFreq(pos.getSamples(0).getCounts(i).getQualityScoresForwardStrandList()));
            s1Scores[i + 5] = QualityFeatures.avgQuality(ProtoPredictor.expandFreq(pos.getSamples(0).getCounts(i).getQualityScoresReverseStrandList()));
            s2Scores[i] = QualityFeatures.avgQuality(ProtoPredictor.expandFreq(pos.getSamples(1).getCounts(i).getQualityScoresForwardStrandList()));
            s2Scores[i + 5] = QualityFeatures.avgQuality(ProtoPredictor.expandFreq(pos.getSamples(1).getCounts(i).getQualityScoresReverseStrandList()));
        }
        String s3CountsString = "";
        String s3ScoresString = "";
        int s3CountsSum = 0;
        if (isTrio) {
            // Trio models carry a third sample; emit its counts and scores as extra columns.
            int[] s3Counts = new int[10];
            for (int i = 0; i < 5; i++) {
                s3Counts[i] = pos.getSamples(2).getCounts(i).getGenotypeCountForwardStrand();
                s3Counts[i + 5] = pos.getSamples(2).getCounts(i).getGenotypeCountReverseStrand();
            }
            float[] s3Scores = new float[10];
            for (int i = 0; i < 5; i++) {
                s3Scores[i] = QualityFeatures.avgQuality(ProtoPredictor.expandFreq(pos.getSamples(2).getCounts(i).getQualityScoresForwardStrandList()));
                s3Scores[i + 5] = QualityFeatures.avgQuality(ProtoPredictor.expandFreq(pos.getSamples(2).getCounts(i).getQualityScoresReverseStrandList()));
            }
            s3CountsString = "\t" + Arrays.toString(s3Counts).replaceAll(" ", "");
            s3ScoresString = "\t" + Arrays.toString(s3Scores).replaceAll(" ", "");
            s3CountsSum = IntStream.of(s3Counts).sum();
        }
        String features;
        String refId = (pos.hasReferenceId() ? pos.getReferenceId() : Integer.toString(pos.getReferenceIndex()));
        if (!longReport) {
            features = (pos.hasFrequencyOfMutation() ?
                    pos.getFrequencyOfMutation() : "") + "\t" + refId + "\t" + pos.getPosition() + "\t"
                    + pos.getReferenceBase() + "\t"
                    + Arrays.toString(s1Scores).replaceAll(" ", "") + "\t"
                    + Arrays.toString(s2Scores).replaceAll(" ", "") + s3ScoresString + "\t"
                    + Integer.toString(IntStream.of(s1Counts).sum() + IntStream.of(s2Counts).sum() + s3CountsSum);
        } else {
            features = (pos.hasFrequencyOfMutation() ? pos.getFrequencyOfMutation() : "") + "\t" + refId + "\t"
                    + pos.getPosition() + "\t" + pos.getReferenceBase() + "\t"
                    + Arrays.toString(s1Scores).replaceAll(" ", "") + "\t"
                    + Arrays.toString(s2Scores).replaceAll(" ", "") + s3ScoresString + "\t"
                    + Integer.toString(IntStream.of(s1Counts).sum() + IntStream.of(s2Counts).sum() + s3CountsSum)
                    + (pos.hasMutatedBase() ? pos.getMutatedBase() : "") + "\t"
                    + Arrays.toString(s1Counts).replaceAll(" ", "") + "\t"
                    + Arrays.toString(s2Counts).replaceAll(" ", "") + s3CountsString;
        }
        return features;
    }

    public void printPredictions(String prefix, String modelPath, String evaluationDataFilename, String resultsPath) throws Exception {
        modelLoader = new ModelLoader(modelPath);
        RecordReader reader = new RecordReader(evaluationDataFilename);
        featureMapper = modelLoader.loadFeatureMapper(reader.getProperties());
        if (featureMapper.getClass().getCanonicalName().contains("Trio")) {
            // we have a trio mapper, need to output features for a third sample
            isTrio = true;
            System.out.println("setting output to trio mode");
        }
        calculator = new BayesCalibrator(modelPath, prefix, false);
        MultiLayerNetwork model = modelLoader.loadMultiLayerNetwork(prefix);
        if (model == null) {
            System.err.println("Cannot load model with prefix: " + prefix);
            System.exit(1);
        }
        MultiLayerNetwork calibratingModel = modelLoader.loadMultiLayerNetwork(prefix + "Calibrated");
        if (cmodel == null && calibratingModel != null) {
            cmodel = new CalibratingModel(model, featureMapper, calibratingModel);
        }
        File dir = new File(resultsPath);
        // attempt to create the directory here
        dir.mkdir();
        // initialize results printer
        String resultFilename = resultsPath + prefix + "-" + type + ".tsv";
        PrintWriter results = new PrintWriter(resultFilename, "UTF-8");
        writeHeader(results, isTrio);
        System.out.println("Writing predictions to " + resultFilename);
        // may need to adjust batch size and write outputs piecewise if test sets are very large
        // BaseInformationIterator baseIter = new BaseInformationIterator(testsetPath, Integer.MAX_VALUE, new FeatureMapperV2(), new SimpleFeatureCalculator());
        // DataSet ds = baseIter.next();
        // set up logger
        ProgressLogger pgReadWrite = new ProgressLogger(LOG);
        pgReadWrite.itemsName = prefix;
        pgReadWrite.expectedUpdates = Math.min(arguments.scoreN, reader.getTotalRecords());
        pgReadWrite.displayFreeMemory = true;
        pgReadWrite.start();
        AreaUnderTheROCCurve aucLossCalculator = new AreaUnderTheROCCurve(50000);
        int index = 0;
        for (BaseInformationRecords.BaseInformation record : reader) {
            // Don't bother trying to make predictions when a sample has 0 counts: the model outputs NaN apparently.
            if (SKIP0COUNTS) {
                boolean bothHaveCount = true;
                for (BaseInformationRecords.SampleInfo sample : record.getSamplesList()) {
                    boolean hasCount = false;
                    for (BaseInformationRecords.CountInfo count : sample.getCountsList()) {
                        if (count.getGenotypeCountReverseStrand() > 0 || count.getGenotypeCountForwardStrand() > 0) {
                            hasCount = true;
                            break;
                        }
                    }
                    if (!hasCount) {
                        bothHaveCount = false;
                        break;
                    }
                }
                if (!bothHaveCount) {
                    continue;
                }
            }
            writeRecordResult(model, calibratingModel, results, featureMapper, pgReadWrite, record, aucLossCalculator, calculator, isTrio);
            index++;
            if (index > scoreN) break;
        }
        System.out.println("AUC on " + prefix + "=" + aucLossCalculator.evaluateStatistic());
        results.close();
        pgReadWrite.stop();
        calculator.save();
        modelLoader.writeTestCount(reader.getTotalRecords());
    }
}
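
A minimal usage sketch, inferred from main() above. The paths are placeholders and the command-line options are defined by PredictionArguments, which is not part of this file, so the argument strings below are assumptions only:

// Hypothetical driver class; the flag names and file paths are illustrative, not real options.
public class PredictMutationsExample {
    public static void main(String[] args) throws Exception {
        PredictMutations.main(new String[]{
                // substitute the options PredictionArguments actually expects
                "--model-path", "models/<model-directory>",
                "--test-set", "<test-set>.parquet"
        });
        // On success, predictions are written to tests/<modelName>/<version>-<type>.tsv
        // and the AUC over the scored records is printed to stdout.
    }
}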




