All downloads are free. Search and download functionality uses the official Maven repository.

org.broadinstitute.hellbender.tools.walkers.vqsr.CNNVariantTrain Maven / Gradle / Ivy

There is a newer version: 4.6.0.0
Show newest version
package org.broadinstitute.hellbender.tools.walkers.vqsr;

import org.broadinstitute.barclay.argparser.*;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.CommandLineProgram;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.utils.io.Resource;
import org.broadinstitute.hellbender.utils.python.PythonScriptExecutor;
import picard.cmdline.programgroups.VariantFilteringProgramGroup;


import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;


/**
 * Train a Convolutional Neural Network (CNN) for filtering variants.
 * This tool expects requires training data generated by {@link CNNVariantWriteTensors}.
 *
 *
 * 

Inputs

*
    *
  • data-dir The training data created by {@link CNNVariantWriteTensors}.
  • *
  • The --tensor-type argument determines what types of tensors the model will expect. * Set it to "reference" for 1D tensors or "read_tensor" for 2D tensors.
  • *
* *

Outputs

*
    *
  • output-dir The model weights file and semantic configuration json are saved here. * This default to the current working directory.
  • *
  • model-name The name for your model.
  • *
* *

Usage example

* *

Train a 1D CNN on Reference Tensors

*
 * gatk CNNVariantTrain \
 *   -tensor-type reference \
 *   -input-tensor-dir my_tensor_folder \
 *   -model-name my_1d_model
 * 
* *

Train a 2D CNN on Read Tensors

*
 * gatk CNNVariantTrain \
 *   -input-tensor-dir my_tensor_folder \
 *   -tensor-type read-tensor \
 *   -model-name my_2d_model
 * 
* */ @CommandLineProgramProperties( summary = "Train a CNN model for filtering variants", oneLineSummary = "Train a CNN model for filtering variants", programGroup = VariantFilteringProgramGroup.class ) @DocumentedFeature @ExperimentalFeature public class CNNVariantTrain extends CommandLineProgram { @Argument(fullName = "input-tensor-dir", shortName = "input-tensor-dir", doc = "Directory of training tensors to create.") private String inputTensorDir; @Argument(fullName = "output-dir", shortName = "output-dir", doc = "Directory where models will be saved, defaults to current working directory.", optional = true) private String outputDir = "./"; @Argument(fullName = "tensor-type", shortName = "tensor-type", doc = "Type of tensors to use as input reference for 1D reference tensors and read_tensor for 2D tensors.", optional = true) private TensorType tensorType = TensorType.reference; @Argument(fullName = "model-name", shortName = "model-name", doc = "Name of the model to be trained.", optional = true) private String modelName = "variant_filter_model"; @Argument(fullName = "epochs", shortName = "epochs", doc = "Maximum number of training epochs.", optional = true, minValue = 0) private int epochs = 10; @Argument(fullName = "training-steps", shortName = "training-steps", doc = "Number of training steps per epoch.", optional = true, minValue = 0) private int trainingSteps = 10; @Argument(fullName = "validation-steps", shortName = "validation-steps", doc = "Number of validation steps per epoch.", optional = true, minValue = 0) private int validationSteps = 2; @Argument(fullName = "image-dir", shortName = "image-dir", doc = "Path where plots and figures are saved.", optional = true) private String imageDir; @Argument(fullName = "conv-width", shortName = "conv-width", doc = "Width of convolution kernels", optional = true) private int convWidth = 5; @Argument(fullName = "conv-height", shortName = "conv-height", doc = "Height of convolution kernels", optional = true) private 
int convHeight = 5; @Argument(fullName = "conv-dropout", shortName = "conv-dropout", doc = "Dropout rate in convolution layers", optional = true) private float convDropout = 0.0f; @Argument(fullName = "conv-batch-normalize", shortName = "conv-batch-normalize", doc = "Batch normalize convolution layers", optional = true) private boolean convBatchNormalize = false; @Argument(fullName = "conv-layers", shortName = "conv-layers", doc = "List of number of filters to use in each convolutional layer", optional = true) private List convLayers = new ArrayList(); @Argument(fullName = "padding", shortName = "padding", doc = "Padding for convolution layers, valid or same", optional = true) private String padding = "valid"; @Argument(fullName = "spatial-dropout", shortName = "spatial-dropout", doc = "Spatial dropout on convolution layers", optional = true) private boolean spatialDropout = false; @Argument(fullName = "fc-layers", shortName = "fc-layers", doc = "List of number of filters to use in each fully-connected layer", optional = true) private List fcLayers = new ArrayList(); @Argument(fullName = "fc-dropout", shortName = "fc-dropout", doc = "Dropout rate in fully-connected layers", optional = true) private float fcDropout = 0.0f; @Argument(fullName = "fc-batch-normalize", shortName = "fc-batch-normalize", doc = "Batch normalize fully-connected layers", optional = true) private boolean fcBatchNormalize = false; @Argument(fullName = "annotation-units", shortName = "annotation-units", doc = "Number of units connected to the annotation input layer", optional = true) private int annotationUnits = 16; @Argument(fullName = "annotation-shortcut", shortName = "annotation-shortcut", doc = "Shortcut connections on the annotation layers.", optional = true) private boolean annotationShortcut = false; @Advanced @Argument(fullName = "channels-last", shortName = "channels-last", doc = "Store the channels in the last axis of tensors, tensorflow->true, theano->false", optional = true) 
private boolean channelsLast = true; @Advanced @Argument(fullName = "annotation-set", shortName = "annotation-set", doc = "Which set of annotations to use.", optional = true) private String annotationSet = "best_practices"; // Start the Python executor. This does not actually start the Python process, but fails if python can't be located final PythonScriptExecutor pythonExecutor = new PythonScriptExecutor(true); @Override protected void onStartup() { PythonScriptExecutor.checkPythonEnvironmentForPackage("vqsr_cnn"); } @Override protected Object doWork() { final Resource pythonScriptResource = new Resource("training.py", CNNVariantTrain.class); List arguments = new ArrayList<>(Arrays.asList( "--data_dir", inputTensorDir, "--output_dir", outputDir, "--tensor_name", tensorType.name(), "--annotation_set", annotationSet, "--conv_width", Integer.toString(convWidth), "--conv_height", Integer.toString(convHeight), "--conv_dropout", Float.toString(convDropout), "--padding", padding, "--fc_dropout", Float.toString(fcDropout), "--annotation_units", Integer.toString(annotationUnits), "--epochs", Integer.toString(epochs), "--training_steps", Integer.toString(trainingSteps), "--validation_steps", Integer.toString(validationSteps), "--gatk_version", this.getVersion(), "--id", modelName)); // Add boolean arguments if(channelsLast){ arguments.add("--channels_last"); } else { arguments.add("--channels_first"); } if(imageDir != null){ arguments.addAll(Arrays.asList("--image_dir", imageDir)); } if (convLayers.size() == 0 && fcLayers.size() == 0){ if (tensorType == TensorType.reference) { arguments.addAll(Arrays.asList("--mode", "train_default_1d_model")); } else if (tensorType == TensorType.read_tensor) { arguments.addAll(Arrays.asList("--mode", "train_default_2d_model")); } else { throw new GATKException("Unknown tensor mapping mode:"+ tensorType.name()); } } else { // Command line specified custom architecture if(convBatchNormalize){ arguments.add("--conv_batch_normalize"); } 
if(fcBatchNormalize){ arguments.add("--fc_batch_normalize"); } if(spatialDropout){ arguments.add("--spatial_dropout"); } if(annotationShortcut){ arguments.add("--annotation_shortcut"); } // Add list arguments arguments.add("--conv_layers"); for(Integer cl : convLayers){ arguments.add(Integer.toString(cl)); } arguments.add("--fc_layers"); for(Integer fl : fcLayers){ arguments.add(Integer.toString(fl)); } if (tensorType == TensorType.reference) { arguments.addAll(Arrays.asList("--mode", "train_args_model_on_reference_and_annotations")); } else if (tensorType == TensorType.read_tensor) { arguments.addAll(Arrays.asList("--mode", "train_args_model_on_read_tensors_and_annotations")); } else { throw new GATKException("Unknown tensor mapping mode:"+ tensorType.name()); } } logger.info("Args are:"+ Arrays.toString(arguments.toArray())); final boolean pythonReturnCode = pythonExecutor.executeScript( pythonScriptResource, null, arguments ); return pythonReturnCode; } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy