All Downloads are FREE. Search and download functionality uses the official Maven repository.

org.deeplearning4j.zoo.model.VGG19 Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.zoo.model;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.NoArgsConstructor;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.zoo.ModelMetaData;
import org.deeplearning4j.zoo.PretrainedType;
import org.deeplearning4j.zoo.ZooModel;
import org.deeplearning4j.zoo.ZooType;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 * VGG-19, from Very Deep Convolutional Networks for Large-Scale Image Recognition
 * https://arxiv.org/abs/1409.1556)
 *
 * 

ImageNet weights for this model are available and have been converted from https://github.com/fchollet/keras/tree/1.1.2/keras/applications.

* * @author Justin Long (crockpotveggies) */ @AllArgsConstructor @Builder public class VGG19 extends ZooModel { @Builder.Default private long seed = 1234; @Builder.Default private int[] inputShape = new int[] {3, 224, 224}; @Builder.Default private int numClasses = 0; @Builder.Default private IUpdater updater = new Nesterovs(); @Builder.Default private CacheMode cacheMode = CacheMode.NONE; @Builder.Default private WorkspaceMode workspaceMode = WorkspaceMode.ENABLED; @Builder.Default private ConvolutionLayer.AlgoMode cudnnAlgoMode = ConvolutionLayer.AlgoMode.NO_WORKSPACE; private VGG19() {} @Override public String pretrainedUrl(PretrainedType pretrainedType) { if (pretrainedType == PretrainedType.IMAGENET) return "http://blob.deeplearning4j.org/models/vgg19_dl4j_inference.zip"; else return null; } @Override public long pretrainedChecksum(PretrainedType pretrainedType) { if (pretrainedType == PretrainedType.IMAGENET) return 2782932419L; else return 0L; } @Override public Class modelType() { return ComputationGraph.class; } public ComputationGraphConfiguration conf() { ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(updater) .activation(Activation.RELU) .cacheMode(cacheMode) .trainingWorkspaceMode(workspaceMode) .inferenceWorkspaceMode(workspaceMode) .graphBuilder() .addInputs("in") // block 1 .layer(0, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1) .padding(1, 1).nIn(inputShape[0]).nOut(64) .cudnnAlgoMode(cudnnAlgoMode).build(), "in") .layer(1, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1) .padding(1, 1).nOut(64).cudnnAlgoMode(cudnnAlgoMode).build(), "0") .layer(2, new SubsamplingLayer.Builder() .poolingType(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2) .stride(2, 2).build(), "1") // block 2 .layer(3, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1) .padding(1, 
1).nOut(128).cudnnAlgoMode(cudnnAlgoMode).build(), "2") .layer(4, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1) .padding(1, 1).nOut(128).cudnnAlgoMode(cudnnAlgoMode).build(), "3") .layer(5, new SubsamplingLayer.Builder() .poolingType(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2) .stride(2, 2).build(), "4") // block 3 .layer(6, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1) .padding(1, 1).nOut(256).cudnnAlgoMode(cudnnAlgoMode).build(), "5") .layer(7, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1) .padding(1, 1).nOut(256).cudnnAlgoMode(cudnnAlgoMode).build(), "6") .layer(8, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1) .padding(1, 1).nOut(256).cudnnAlgoMode(cudnnAlgoMode).build(), "7") .layer(9, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1) .padding(1, 1).nOut(256).cudnnAlgoMode(cudnnAlgoMode).build(), "8") .layer(10, new SubsamplingLayer.Builder() .poolingType(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2) .stride(2, 2).build(), "9") // block 4 .layer(11, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1) .padding(1, 1).nOut(512).cudnnAlgoMode(cudnnAlgoMode).build(), "10") .layer(12, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1) .padding(1, 1).nOut(512).cudnnAlgoMode(cudnnAlgoMode).build(), "11") .layer(13, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1) .padding(1, 1).nOut(512).cudnnAlgoMode(cudnnAlgoMode).build(), "12") .layer(14, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1) .padding(1, 1).nOut(512).cudnnAlgoMode(cudnnAlgoMode).build(), "13") .layer(15, new SubsamplingLayer.Builder() .poolingType(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2) .stride(2, 2).build(), "14") // block 5 .layer(16, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1) .padding(1, 1).nOut(512).cudnnAlgoMode(cudnnAlgoMode).build(), "15") .layer(17, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1) .padding(1, 
1).nOut(512).cudnnAlgoMode(cudnnAlgoMode).build(), "16") .layer(18, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1) .padding(1, 1).nOut(512).cudnnAlgoMode(cudnnAlgoMode).build(), "17") .layer(19, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(1, 1) .padding(1, 1).nOut(512).cudnnAlgoMode(cudnnAlgoMode).build(), "18") .layer(20, new SubsamplingLayer.Builder() .poolingType(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2) .stride(2, 2).build(), "19") .layer(21, new DenseLayer.Builder().nOut(4096).build(), "20") .layer(22, new OutputLayer.Builder( LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).name("output") .nOut(numClasses).activation(Activation.SOFTMAX) // radial basis function required .build(), "21") .setOutputs("22") .backprop(true).pretrain(false) .setInputTypes(InputType.convolutionalFlat(inputShape[2], inputShape[1], inputShape[0])) .build(); return conf; } @Override public ComputationGraph init() { ComputationGraph network = new ComputationGraph(conf()); network.init(); return network; } @Override public ModelMetaData metaData() { return new ModelMetaData(new int[][] {inputShape}, 1, ZooType.CNN); } @Override public void setInputShape(int[][] inputShape) { this.inputShape = inputShape[0]; } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy