org.deeplearning4j.zoo.model.VGG19 Maven / Gradle / Ivy
package org.deeplearning4j.zoo.model;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.NoArgsConstructor;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.zoo.ModelMetaData;
import org.deeplearning4j.zoo.PretrainedType;
import org.deeplearning4j.zoo.ZooModel;
import org.deeplearning4j.zoo.ZooType;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
/**
* VGG-19, from Very Deep Convolutional Networks for Large-Scale Image Recognition
* https://arxiv.org/abs/1409.1556)
*
* ImageNet weights for this model are available and have been converted from https://github.com/fchollet/keras/tree/1.1.2/keras/applications.
*
* @author Justin Long (crockpotveggies)
*/
@AllArgsConstructor
@Builder
public class VGG19 extends ZooModel {
@Builder.Default private long seed = 1234;
@Builder.Default private int[] inputShape = new int[] {3, 224, 224};
@Builder.Default private int numClasses = 0;
@Builder.Default private IUpdater updater = new Nesterovs();
@Builder.Default private CacheMode cacheMode = CacheMode.NONE;
@Builder.Default private WorkspaceMode workspaceMode = WorkspaceMode.ENABLED;
@Builder.Default private ConvolutionLayer.AlgoMode cudnnAlgoMode = ConvolutionLayer.AlgoMode.NO_WORKSPACE;
private VGG19() {}
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
if (pretrainedType == PretrainedType.IMAGENET)
return "http://blob.deeplearning4j.org/models/vgg19_dl4j_inference.zip";
else
return null;
}
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
if (pretrainedType == PretrainedType.IMAGENET)
return 2782932419L;
else
return 0L;
}
@Override
public Class extends Model> modelType() {
return ComputationGraph.class;
}
public ComputationGraphConfiguration conf() {
ComputationGraphConfiguration conf =
new NeuralNetConfiguration.Builder().seed(seed)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(updater)
.activation(Activation.RELU)
.cacheMode(cacheMode)
.trainingWorkspaceMode(workspaceMode)
.inferenceWorkspaceMode(workspaceMode)
.graphBuilder()
.addInputs("in")
// block 1
.layer(0, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1)
.padding(1, 1).nIn(inputShape[0]).nOut(64)
.cudnnAlgoMode(cudnnAlgoMode).build(), "in")
.layer(1, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1)
.padding(1, 1).nOut(64).cudnnAlgoMode(cudnnAlgoMode).build(), "0")
.layer(2, new SubsamplingLayer.Builder()
.poolingType(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
.stride(2, 2).build(), "1")
// block 2
.layer(3, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1)
.padding(1, 1).nOut(128).cudnnAlgoMode(cudnnAlgoMode).build(), "2")
.layer(4, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1)
.padding(1, 1).nOut(128).cudnnAlgoMode(cudnnAlgoMode).build(), "3")
.layer(5, new SubsamplingLayer.Builder()
.poolingType(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
.stride(2, 2).build(), "4")
// block 3
.layer(6, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1)
.padding(1, 1).nOut(256).cudnnAlgoMode(cudnnAlgoMode).build(), "5")
.layer(7, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1)
.padding(1, 1).nOut(256).cudnnAlgoMode(cudnnAlgoMode).build(), "6")
.layer(8, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1)
.padding(1, 1).nOut(256).cudnnAlgoMode(cudnnAlgoMode).build(), "7")
.layer(9, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1)
.padding(1, 1).nOut(256).cudnnAlgoMode(cudnnAlgoMode).build(), "8")
.layer(10, new SubsamplingLayer.Builder()
.poolingType(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
.stride(2, 2).build(), "9")
// block 4
.layer(11, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1)
.padding(1, 1).nOut(512).cudnnAlgoMode(cudnnAlgoMode).build(), "10")
.layer(12, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1)
.padding(1, 1).nOut(512).cudnnAlgoMode(cudnnAlgoMode).build(), "11")
.layer(13, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1)
.padding(1, 1).nOut(512).cudnnAlgoMode(cudnnAlgoMode).build(), "12")
.layer(14, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1)
.padding(1, 1).nOut(512).cudnnAlgoMode(cudnnAlgoMode).build(), "13")
.layer(15, new SubsamplingLayer.Builder()
.poolingType(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
.stride(2, 2).build(), "14")
// block 5
.layer(16, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1)
.padding(1, 1).nOut(512).cudnnAlgoMode(cudnnAlgoMode).build(), "15")
.layer(17, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1)
.padding(1, 1).nOut(512).cudnnAlgoMode(cudnnAlgoMode).build(), "16")
.layer(18, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1)
.padding(1, 1).nOut(512).cudnnAlgoMode(cudnnAlgoMode).build(), "17")
.layer(19, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1)
.padding(1, 1).nOut(512).cudnnAlgoMode(cudnnAlgoMode).build(), "18")
.layer(20, new SubsamplingLayer.Builder()
.poolingType(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
.stride(2, 2).build(), "19")
.layer(21, new DenseLayer.Builder().nOut(4096).build(), "20")
.layer(22, new OutputLayer.Builder(
LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).name("output")
.nOut(numClasses).activation(Activation.SOFTMAX) // radial basis function required
.build(), "21")
.setOutputs("22")
.backprop(true).pretrain(false)
.setInputTypes(InputType.convolutionalFlat(inputShape[2], inputShape[1], inputShape[0]))
.build();
return conf;
}
@Override
public ComputationGraph init() {
ComputationGraph network = new ComputationGraph(conf());
network.init();
return network;
}
@Override
public ModelMetaData metaData() {
return new ModelMetaData(new int[][] {inputShape}, 1, ZooType.CNN);
}
@Override
public void setInputShape(int[][] inputShape) {
this.inputShape = inputShape[0];
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy