All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.campagnelab.dl.somatic.learning.SomaticTrainer Maven / Gradle / Ivy

package org.campagnelab.dl.somatic.learning;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.ParameterException;
import it.unimi.dsi.fastutil.floats.FloatArraySet;
import it.unimi.dsi.fastutil.floats.FloatSet;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.campagnelab.dl.framework.gpu.ParameterPrecision;
import org.campagnelab.dl.framework.mappers.ConfigurableFeatureMapper;
import org.campagnelab.dl.framework.mappers.FeatureMapper;
import org.campagnelab.dl.framework.mappers.LabelMapper;
import org.campagnelab.dl.framework.tools.TrainingArguments;
import org.campagnelab.dl.somatic.mappers.SimpleFeatureCalculator;
import org.campagnelab.dl.framework.models.ModelLoader;
import org.campagnelab.dl.framework.architecture.nets.NeuralNetAssembler;
import org.campagnelab.dl.somatic.learning.iterators.BaseInformationConcatIterator;
import org.campagnelab.dl.somatic.learning.iterators.BaseInformationIterator;
import org.campagnelab.dl.somatic.learning.iterators.FirstNIterator;
import org.campagnelab.dl.somatic.learning.iterators.NamedDataSetIterator;
import org.campagnelab.dl.framework.models.ModelPropertiesHelper;
import org.campagnelab.dl.framework.performance.PerformanceLogger;
import org.campagnelab.dl.framework.tools.arguments.ConditionRecordingTool;
import org.campagnelab.goby.baseinfo.SequenceBaseInformationReader;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.saver.LocalFileModelSaver;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.*;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Date;
import java.util.List;
import java.util.Properties;

/**
 * Abstract class to facilitate variations of training protocols.
 * Created by fac2003 on 7/12/16.
 */
public abstract class SomaticTrainer extends ConditionRecordingTool {
    static private Logger LOG = LoggerFactory.getLogger(TrainSomaticModel.class);
    protected ParameterPrecision precision = ParameterPrecision.FP32;

    protected static TrainingArguments parseArguments(String[] args, String commandName) {
        SomaticTrainingArguments arguments = new SomaticTrainingArguments();
        JCommander commander = new JCommander(arguments);
        commander.setProgramName(commandName);
        try {
            commander.parse(args);
        } catch (ParameterException e) {

            commander.usage();
            throw e;

        }
        return arguments;
    }

    @Override
    public void execute() {
        if ("FP16".equals(args().precision)) {
            precision = ParameterPrecision.FP16;
            System.out.println("Parameter precision set to FP16.");
        }
        if ("FP16".equals(args().precision)) {
            DataTypeUtil.setDTypeForContext(DataBuffer.Type.HALF);
        }
        if (args().getTrainingSets().length == 0) {
            System.err.println("You must provide training datasets.");
        }
        FeatureMapper featureMapper = null;
        try {
            featureMapper = configureFeatureMapper(args().featureMapperClassname, args().isTrio, args().getTrainingSets());
            execute(featureMapper, args().getTrainingSets(), args().miniBatchSize);
        } catch (IOException e) {
            System.err.println("An exception occured. Details may be provided below");
            e.printStackTrace();
        }

    }

    protected double dropoutRate = 0.5;
    protected LabelMapper labelMapper = new SimpleFeatureCalculator();
    protected FeatureMapper featureCalculator;
    protected String directory;
    protected long time;
    protected int numHiddenNodes;
    protected String attempt;
    protected double bestScore;
    protected int numTrainingFiles;
    protected MultiLayerNetwork net;
    protected LossFunctions.LossFunction lossFunction;
    protected PerformanceLogger performanceLogger;


    public void execute(FeatureMapper featureCalculator, String trainingDataset[], int miniBatchSize) throws IOException {
        if (args().previousModelPath != null) {
            System.out.println(String.format("Resuming training with %s model parameters from %s %n", args().previousModelName, args().previousModelPath));
        }
        this.featureCalculator = featureCalculator;
        this.numTrainingFiles = trainingDataset.length;

        String path = "";

        time = new Date().getTime();
        System.out.println("time: " + time);
        System.out.println("epochs: " + args().maxEpochs);
        System.out.println(featureCalculator.getClass().getTypeName());
        directory = "models/" + Long.toString(time);
        attempt = "batch=" + miniBatchSize + "-learningRate=" + args().learningRate + "-time=" + time;
        int generateSamplesEveryNMinibatches = 10;
        FileUtils.forceMkdir(new File(directory));

        // Assemble the training iterator:
        labelMapper = new SimpleFeatureCalculator();
        List trainIterList = new ObjectArrayList<>(trainingDataset.length);
        for (int i = 0; i < trainingDataset.length; i++) {
            trainIterList.add(new BaseInformationIterator(trainingDataset[i], miniBatchSize,
                    featureCalculator, labelMapper));
        }
        NamedDataSetIterator async = new BaseInformationConcatIterator(trainIterList, miniBatchSize, featureCalculator, labelMapper);
        if (args().numTraining != Integer.MAX_VALUE) {
            async = new FirstNIterator(async, args().numTraining);
        }
        async = decorateIterator(async);
        System.out.println("Estimating scaling parameters:");
        //Load the training data:
        int numInputs = async.inputColumns();
        int numOutputs = async.totalOutcomes();
        numHiddenNodes = numInputs * 5;
        NeuralNetAssembler assembler = getNeuralNetAssembler();
        assembler.setSeed(args().seed);
        assembler.setLearningRate(args().learningRate);
        assembler.setDropoutRate(args().dropoutRate);
        assembler.setNumHiddenNodes(numHiddenNodes);
        assembler.setNumInputs(numInputs);
        assembler.setNumOutputs(numOutputs);

        lossFunction = LossFunctions.LossFunction.MCXENT;

        assembler.setLossFunction(lossFunction);
        assembler.setRegularizationRate(args().regularizationRate);

        //   assembler.setDropoutRate(dropoutRate);


        //changed from XAVIER in iteration 14
        assembler.setWeightInitialization(WeightInit.RELU);
        assembler.setLearningRatePolicy(LearningRatePolicy.Score);
        MultiLayerConfiguration conf = assembler.createNetwork();
        net = new MultiLayerNetwork(conf);
        net.init();
        if (args().previousModelPath != null) {
            // Load the parameters of a previously trained model and set them on the new model to continue
            // training where we left it off. Note that models must have the same architecture or setting
            // parameters will fail.
            ModelLoader loader = new ModelLoader(args().previousModelPath);
            Model savedModel = loader.loadModel(args().previousModelName);
            MultiLayerNetwork savedNet = savedModel instanceof MultiLayerNetwork ?
                    (MultiLayerNetwork) savedModel : null;
            if (savedNet == null || savedNet.getUpdater() == null || savedNet.params() == null) {
                System.err.println("Unable to load model or updater from " + args().previousModelPath);
            } else {
                net.setUpdater(savedNet.getUpdater());
                net.setParams(savedNet.params());

            }
        }

        //Print the  number of parameters in the network (and for each layer)
        Layer[] layers = net.getLayers();
        int totalNumParams = 0;
        for (int i = 0; i < layers.length; i++) {
            int nParams = layers[i].numParams();
            System.out.println("Number of parameters in layer " + i + ": " + nParams);
            totalNumParams += nParams;
        }
        System.out.println("Total number of network parameters: " + totalNumParams);

        writeProperties(this);
        performanceLogger = new PerformanceLogger(directory);
        EarlyStoppingResult result = train(conf, async);


        //Print out the results:
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + performanceLogger.getBestScore());
        System.out.println("AUC at best epoch: " + performanceLogger.getBestAUC());

        writeProperties(this);
        writeBestScoreFile();
        System.out.println("Model completed, saved at time: " + attempt);
        performanceLogger.write();
        resultValues().put("AUC", performanceLogger.getBestAUC());
        resultValues().put("score", performanceLogger.getBestScore());
        resultValues().put("bestModelEpoch", performanceLogger.getBestEpoch("bestAUC"));
        resultValues().put("model-time", time);
    }

    private NeuralNetAssembler getNeuralNetAssembler() {
        try {
            return (NeuralNetAssembler) Class.forName(args().architectureClassname).newInstance();
        } catch (Exception e) {
            System.err.println("Unable to instantiate net architecture " + args().architectureClassname);
            System.exit(1);
        }
        return null;
    }


    protected NamedDataSetIterator decorateIterator(NamedDataSetIterator iterator) {
        return iterator;
    }

    protected void writeBestScoreFile() throws IOException {

        FileWriter scoreWriter = new FileWriter(directory + "/bestScore");
        scoreWriter.append(Double.toString(performanceLogger.getBestScore()));
        scoreWriter.close();
    }

    protected void writeProperties(SomaticTrainer trainer) throws IOException {
        ModelPropertiesHelper mpHelper = new ModelPropertiesHelper();
        appendProperties(mpHelper);
        mpHelper.addProperties(getReaderProperties(trainer.args().trainingSets.get(0)));
        mpHelper.writeProperties(directory);
    }

    protected abstract EarlyStoppingResult train(MultiLayerConfiguration conf, DataSetIterator async)
            throws IOException;

    protected static void saveModel(LocalFileModelSaver saver, String directory, String prefix, MultiLayerNetwork net) throws IOException {
        FilenameUtils.concat(directory, prefix + "ModelConf.json");
        String confOut = FilenameUtils.concat(directory, prefix + "ModelConf.json");
        String paramOut = FilenameUtils.concat(directory, prefix + "ModelParams.bin");
        String updaterOut = FilenameUtils.concat(directory, prefix + "ModelUpdater.bin");
        save(net, confOut, paramOut, updaterOut);
    }

    protected static void save(MultiLayerNetwork net, String confOut, String paramOut, String updaterOut) throws IOException {
        String confJSON = net.getLayerWiseConfigurations().toJson();
        INDArray params = net.params();
        Updater updater = net.getUpdater();

        FileUtils.writeStringToFile(new File(confOut), confJSON, "UTF-8");
        try (DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(Files.newOutputStream(Paths.get(paramOut))))) {
            Nd4j.write(params, dos);
        }

        try (ObjectOutputStream oos = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(new File(updaterOut))))) {
            oos.writeObject(updater);
        }
    }

    protected static int numLabels(INDArray labels) {
        FloatSet set = new FloatArraySet();
        for (int i = 0; i < labels.size(0); i++) {
            set.add(labels.getFloat(i));
        }
        return set.size();
    }


    public void appendProperties(ModelPropertiesHelper helper) {
        helper.setFeatureCalculator(featureCalculator);
        helper.setLearningRate(args().learningRate);
        helper.setDropoutRate(args().dropoutRate);
        helper.setNumHiddenNodes(numHiddenNodes);
        helper.setMiniBatchSize(args().miniBatchSize);
        // mpHelper.setBestScore(bestScore);
        helper.setNumEpochs(args().maxEpochs);
        helper.setNumTrainingSets(numTrainingFiles);
        helper.setTime(time);
        helper.setSeed(args().seed);
        helper.setLossFunction(lossFunction.name());
        helper.setEarlyStopCriterion(args().stopWhenEpochsWithoutImprovement);
        helper.setRegularization(args().regularizationRate);
        helper.setPrecision(precision);
    }


    public static FeatureMapper configureFeatureMapper(String featureMapperClassname, boolean isTrio, String[] trainingSets) throws IOException {


        try {
            Class clazz = Class.forName(featureMapperClassname + (isTrio ? "Trio" : ""));
            final FeatureMapper featureMapper = (FeatureMapper) clazz.newInstance();
            if (featureMapper instanceof ConfigurableFeatureMapper) {
                ConfigurableFeatureMapper cmapper = (ConfigurableFeatureMapper) featureMapper;
                if (trainingSets.length > 1) {
                    LOG.warn("sbip properties are only read from the first training set. Concat the files before training if you need to use properties across all inputs.");
                }
                final Properties properties = getReaderProperties(trainingSets[0]);
                cmapper.configure(properties);
            }
            return featureMapper;
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        } catch (InstantiationException e) {
            e.printStackTrace();
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }
        return null;
    }

    private static Properties getReaderProperties(String trainingSet) throws IOException {
        SequenceBaseInformationReader reader = new SequenceBaseInformationReader(trainingSet);
        final Properties properties = reader.getProperties();
        reader.close();
        return properties;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy