Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
org.broadinstitute.hellbender.tools.walkers.vqsr.CNNVariantTrain Maven / Gradle / Ivy
package org.broadinstitute.hellbender.tools.walkers.vqsr;
import org.broadinstitute.barclay.argparser.*;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.CommandLineProgram;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.utils.io.Resource;
import org.broadinstitute.hellbender.utils.python.PythonScriptExecutor;
import picard.cmdline.programgroups.VariantFilteringProgramGroup;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* Train a Convolutional Neural Network (CNN) for filtering variants.
 * This tool requires training data generated by {@link CNNVariantWriteTensors}.
*
*
* Inputs
*
 * input-tensor-dir The training data created by {@link CNNVariantWriteTensors}.
* The --tensor-type argument determines what types of tensors the model will expect.
* Set it to "reference" for 1D tensors or "read_tensor" for 2D tensors.
*
*
* Outputs
*
* output-dir The model weights file and semantic configuration json are saved here.
 * This defaults to the current working directory.
* model-name The name for your model.
*
*
* Usage example
*
* Train a 1D CNN on Reference Tensors
*
* gatk CNNVariantTrain \
* -tensor-type reference \
* -input-tensor-dir my_tensor_folder \
* -model-name my_1d_model
*
*
* Train a 2D CNN on Read Tensors
*
* gatk CNNVariantTrain \
* -input-tensor-dir my_tensor_folder \
 * -tensor-type read_tensor \
* -model-name my_2d_model
*
*
*/
@CommandLineProgramProperties(
    summary = "Train a CNN model for filtering variants",
    oneLineSummary = "Train a CNN model for filtering variants",
    programGroup = VariantFilteringProgramGroup.class
)
@DocumentedFeature
@ExperimentalFeature
public class CNNVariantTrain extends CommandLineProgram {

    @Argument(fullName = "input-tensor-dir", shortName = "input-tensor-dir", doc = "Directory of training tensors to create.")
    private String inputTensorDir;

    @Argument(fullName = "output-dir", shortName = "output-dir", doc = "Directory where models will be saved, defaults to current working directory.", optional = true)
    private String outputDir = "./";

    @Argument(fullName = "tensor-type", shortName = "tensor-type", doc = "Type of tensors to use as input reference for 1D reference tensors and read_tensor for 2D tensors.", optional = true)
    private TensorType tensorType = TensorType.reference;

    @Argument(fullName = "model-name", shortName = "model-name", doc = "Name of the model to be trained.", optional = true)
    private String modelName = "variant_filter_model";

    @Argument(fullName = "epochs", shortName = "epochs", doc = "Maximum number of training epochs.", optional = true, minValue = 0)
    private int epochs = 10;

    @Argument(fullName = "training-steps", shortName = "training-steps", doc = "Number of training steps per epoch.", optional = true, minValue = 0)
    private int trainingSteps = 10;

    @Argument(fullName = "validation-steps", shortName = "validation-steps", doc = "Number of validation steps per epoch.", optional = true, minValue = 0)
    private int validationSteps = 2;

    @Argument(fullName = "image-dir", shortName = "image-dir", doc = "Path where plots and figures are saved.", optional = true)
    private String imageDir;

    @Argument(fullName = "conv-width", shortName = "conv-width", doc = "Width of convolution kernels", optional = true)
    private int convWidth = 5;

    @Argument(fullName = "conv-height", shortName = "conv-height", doc = "Height of convolution kernels", optional = true)
    private int convHeight = 5;

    @Argument(fullName = "conv-dropout", shortName = "conv-dropout", doc = "Dropout rate in convolution layers", optional = true)
    private float convDropout = 0.0f;

    @Argument(fullName = "conv-batch-normalize", shortName = "conv-batch-normalize", doc = "Batch normalize convolution layers", optional = true)
    private boolean convBatchNormalize = false;

    // Parameterized as List<Integer>: the original declaration used a raw List,
    // which compiled only through erasure and defeated compile-time type checking.
    @Argument(fullName = "conv-layers", shortName = "conv-layers", doc = "List of number of filters to use in each convolutional layer", optional = true)
    private List<Integer> convLayers = new ArrayList<>();

    @Argument(fullName = "padding", shortName = "padding", doc = "Padding for convolution layers, valid or same", optional = true)
    private String padding = "valid";

    @Argument(fullName = "spatial-dropout", shortName = "spatial-dropout", doc = "Spatial dropout on convolution layers", optional = true)
    private boolean spatialDropout = false;

    @Argument(fullName = "fc-layers", shortName = "fc-layers", doc = "List of number of filters to use in each fully-connected layer", optional = true)
    private List<Integer> fcLayers = new ArrayList<>();

    @Argument(fullName = "fc-dropout", shortName = "fc-dropout", doc = "Dropout rate in fully-connected layers", optional = true)
    private float fcDropout = 0.0f;

    @Argument(fullName = "fc-batch-normalize", shortName = "fc-batch-normalize", doc = "Batch normalize fully-connected layers", optional = true)
    private boolean fcBatchNormalize = false;

    @Argument(fullName = "annotation-units", shortName = "annotation-units", doc = "Number of units connected to the annotation input layer", optional = true)
    private int annotationUnits = 16;

    @Argument(fullName = "annotation-shortcut", shortName = "annotation-shortcut", doc = "Shortcut connections on the annotation layers.", optional = true)
    private boolean annotationShortcut = false;

    @Advanced
    @Argument(fullName = "channels-last", shortName = "channels-last", doc = "Store the channels in the last axis of tensors, tensorflow->true, theano->false", optional = true)
    private boolean channelsLast = true;

    @Advanced
    @Argument(fullName = "annotation-set", shortName = "annotation-set", doc = "Which set of annotations to use.", optional = true)
    private String annotationSet = "best_practices";

    // Start the Python executor. This does not actually start the Python process, but fails if python can't be located
    final PythonScriptExecutor pythonExecutor = new PythonScriptExecutor(true);

    /**
     * Fail fast before doing any work if the Python environment is missing the
     * {@code vqsr_cnn} package that the training script depends on.
     */
    @Override
    protected void onStartup() {
        PythonScriptExecutor.checkPythonEnvironmentForPackage("vqsr_cnn");
    }

    /**
     * Assemble the command line for the bundled {@code training.py} script from the
     * tool's arguments and run it via the Python executor.
     *
     * @return the Boolean return status of the python process (as reported by the executor)
     */
    @Override
    protected Object doWork() {
        final Resource pythonScriptResource = new Resource("training.py", CNNVariantTrain.class);

        // Scalar arguments that are always forwarded to the python script.
        final List<String> arguments = new ArrayList<>(Arrays.asList(
                "--data_dir", inputTensorDir,
                "--output_dir", outputDir,
                "--tensor_name", tensorType.name(),
                "--annotation_set", annotationSet,
                "--conv_width", Integer.toString(convWidth),
                "--conv_height", Integer.toString(convHeight),
                "--conv_dropout", Float.toString(convDropout),
                "--padding", padding,
                "--fc_dropout", Float.toString(fcDropout),
                "--annotation_units", Integer.toString(annotationUnits),
                "--epochs", Integer.toString(epochs),
                "--training_steps", Integer.toString(trainingSteps),
                "--validation_steps", Integer.toString(validationSteps),
                "--gatk_version", this.getVersion(),
                "--id", modelName));

        // Add boolean arguments: exactly one of the two channel-order flags is passed.
        arguments.add(channelsLast ? "--channels_last" : "--channels_first");

        if (imageDir != null) {
            arguments.addAll(Arrays.asList("--image_dir", imageDir));
        }

        if (convLayers.isEmpty() && fcLayers.isEmpty()) {
            // No architecture given on the command line: use the default model for the tensor type.
            arguments.addAll(Arrays.asList("--mode", defaultModeFor(tensorType)));
        } else { // Command line specified custom architecture
            if (convBatchNormalize) {
                arguments.add("--conv_batch_normalize");
            }
            if (fcBatchNormalize) {
                arguments.add("--fc_batch_normalize");
            }
            if (spatialDropout) {
                arguments.add("--spatial_dropout");
            }
            if (annotationShortcut) {
                arguments.add("--annotation_shortcut");
            }

            // Add list arguments: each layer size becomes its own token after the flag.
            arguments.add("--conv_layers");
            for (final Integer cl : convLayers) {
                arguments.add(Integer.toString(cl));
            }
            arguments.add("--fc_layers");
            for (final Integer fl : fcLayers) {
                arguments.add(Integer.toString(fl));
            }
            arguments.addAll(Arrays.asList("--mode", customModeFor(tensorType)));
        }

        logger.info("Args are:" + Arrays.toString(arguments.toArray()));
        final boolean pythonReturnCode = pythonExecutor.executeScript(
                pythonScriptResource,
                null,
                arguments
        );
        return pythonReturnCode;
    }

    /** Script mode used when no custom architecture is specified. */
    private static String defaultModeFor(final TensorType tensorType) {
        if (tensorType == TensorType.reference) {
            return "train_default_1d_model";
        } else if (tensorType == TensorType.read_tensor) {
            return "train_default_2d_model";
        }
        throw new GATKException("Unknown tensor mapping mode:" + tensorType.name());
    }

    /** Script mode used when conv/fc layers are given on the command line. */
    private static String customModeFor(final TensorType tensorType) {
        if (tensorType == TensorType.reference) {
            return "train_args_model_on_reference_and_annotations";
        } else if (tensorType == TensorType.read_tensor) {
            return "train_args_model_on_read_tensors_and_annotations";
        }
        throw new GATKException("Unknown tensor mapping mode:" + tensorType.name());
    }
}