All downloads are free. The search and download functionality uses the official Maven repository.

org.campagnelab.dl.somatic.learning.TrainSomaticModelOnGPU Maven / Gradle / Ivy

package org.campagnelab.dl.somatic.learning;

import it.unimi.dsi.logging.ProgressLogger;
import org.campagnelab.dl.framework.gpu.ParameterPrecision;
import org.campagnelab.dl.somatic.learning.iterators.NamedCachingDataSetIterator;
import org.campagnelab.dl.somatic.learning.iterators.NamedDataSetIterator;
import org.campagnelab.dl.framework.models.ModelSaver;
import org.campagnelab.dl.somatic.learning.performance.MeasurePerformance;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.dataset.api.iterator.CachingDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.cache.InMemoryDataSetCache;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

//import org.nd4j.jita.conf.CudaEnvironment;

/**
 * Train a neural network to predict mutations.
 * 

* Created by fac2003 on 5/21/16. * * @author Fabien Campagne */ public class TrainSomaticModelOnGPU extends SomaticTrainer { public static final int MIN_ITERATION_BETWEEN_BEST_MODEL = 1000; static private Logger LOG = LoggerFactory.getLogger(TrainSomaticModelOnGPU.class); private String validationDatasetFilename = null; @Override public void execute() { if (args().trainingSets.size() == 0) { System.out.println("Please add at least one training set to the args()."); return; } /* CudaEnvironment.getInstance().getConfiguration() .enableDebug(false) .allowMultiGPU(true) .setMaximumGridSize(512) .setMaximumBlockSize(512) .setMaximumDeviceCacheableLength(1024 * 1024 * 1024L) .setMaximumDeviceCache(8L * 1024 * 1024 * 1024L) .setMaximumHostCacheableLength(1024 * 1024 * 1024L) .setMaximumHostCache(8L * 1024 * 1024 * 1024L) // cross-device access is used for faster model averaging over pcie .allowCrossDeviceAccess(true);*/ if ("FP16".equals(args().precision)) { precision = ParameterPrecision.FP16; System.out.println("Parameter precision set to FP16."); } super.execute(); } public static void main(String[] args) throws IOException { TrainSomaticModelOnGPU tool = new TrainSomaticModelOnGPU(); SomaticTrainingArguments arguments = tool.createArguments(); tool.parseArguments(args, "TrainSomaticModelOnGPU", arguments); if ("FP16".equals(tool.args().precision)) { DataTypeUtil.setDTypeForContext(DataBuffer.Type.HALF); } tool.execute(); tool.writeModelingConditions(arguments); System.err.println("Allow Multi-GPU"); } @Override protected NamedDataSetIterator decorateIterator(NamedDataSetIterator iterator) { return new NamedCachingDataSetIterator(new CachingDataSetIterator(iterator, new InMemoryDataSetCache()), ""); } @Override protected EarlyStoppingResult train(MultiLayerConfiguration conf, DataSetIterator async) throws IOException { validationDatasetFilename = args().validationSet; //check validation file for error if (!(new File(validationDatasetFilename).exists())) { throw new 
IOException("Validation file not found! " + validationDatasetFilename); } perf = new MeasurePerformance(args().numValidation, validationDatasetFilename, args().miniBatchSize, featureCalculator, labelMapper); ParallelWrapper wrapper = new ParallelWrapper.Builder(net) .prefetchBuffer(args().miniBatchSize) .workers(4) .averagingFrequency(1) .reportScoreAfterAveraging(false) .useLegacyAveraging(false) .build(); //Do training, and then generate and print samples from network int miniBatchNumber = 0; boolean init = true; ProgressLogger pgEpoch = new ProgressLogger(LOG); pgEpoch.displayLocalSpeed = true; pgEpoch.itemsName = "epoch"; pgEpoch.expectedUpdates = args().maxEpochs; pgEpoch.start(); bestScore = Double.MAX_VALUE; ModelSaver saver = new ModelSaver(directory); int numExamplesUsed = 0; Map scoreMap = new HashMap(); double bestAUC = 0; int notImproved = 0; int iter = 0; int epoch; assert async.resetSupported(): "Iterator must support reset."; for (epoch = 0; epoch < args().maxEpochs; epoch++) { wrapper.fit(async); pgEpoch.update(); double score = net.score(); scoreMap.put(epoch, score); bestScore = Math.min(score, bestScore); writeBestScoreFile(); async.reset(); saver.saveLatestModel(net, net.score()); writeProperties(this); writeBestScoreFile(); if (epoch % args().validateEvery == 0) { double auc = estimateTestSetPerf(epoch, iter); performanceLogger.log("epochs", numExamplesUsed, epoch, score, auc); if (auc > bestAUC) { saver.saveModel(net, "bestAUC", auc); bestAUC = auc; writeBestAUC(bestAUC); performanceLogger.log("bestAUC", numExamplesUsed, epoch, bestScore, bestAUC); notImproved = 0; } else { notImproved++; } if (notImproved > args().stopWhenEpochsWithoutImprovement) { // we have not improved after earlyStopCondition epoch, time to stop. 
break; } System.out.printf("epoch %d auc=%g%n", epoch, auc); } numExamplesUsed += async.totalExamples(); performanceLogger.write(); } pgEpoch.stop(); //TODO enable with 0.6.0+ // wrapper.shutdown(); return new EarlyStoppingResult(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, "not early stopping", scoreMap, performanceLogger.getBestEpoch("bestAUC"), bestScore, args().maxEpochs, net); } private void writeBestAUC(double bestAUC) { try { FileWriter scoreWriter = new FileWriter(directory + "/bestAUC"); scoreWriter.append(Double.toString(bestAUC)); scoreWriter.close(); } catch (IOException e) { } } MeasurePerformance perf; protected double estimateTestSetPerf(int epoch, int iter) throws IOException { if (validationDatasetFilename == null) return 0; double auc = perf.estimateAUC(net); System.out.printf("Epoch %d Iteration %d AUC=%f %n", epoch, iter, auc); return auc; } @Override public SomaticTrainingArguments createArguments() { return new SomaticTrainingArguments(); } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy