All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.zoo.model.UNet Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.zoo.model;

import lombok.AllArgsConstructor;
import lombok.Builder;
import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.TruncatedNormalDistribution;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.zoo.ModelMetaData;
import org.deeplearning4j.zoo.PretrainedType;
import org.deeplearning4j.zoo.ZooModel;
import org.deeplearning4j.zoo.ZooType;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.AdaDelta;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 * U-Net
 *
 * An implementation of U-Net, a deep learning network for image segmentation in Deeplearning4j.
 * The u-net is convolutional network architecture for fast and precise segmentation of images.
 * Up to now it has outperformed the prior best method (a sliding-window convolutional network) on the ISBI challenge for
 * segmentation of neuronal structures in electron microscopic stacks.
 *
 * 

Paper: https://arxiv.org/abs/1505.04597

*

Weights are available for image segmentation trained on a synthetic dataset

* * @author Justin Long (crockpotveggies) * */ @AllArgsConstructor @Builder public class UNet extends ZooModel { @Builder.Default private long seed = 1234; @Builder.Default private int[] inputShape = new int[] {3, 512, 512}; @Builder.Default private int numClasses = 0; @Builder.Default private WeightInit weightInit = WeightInit.RELU; @Builder.Default private IUpdater updater = new AdaDelta(); @Builder.Default private CacheMode cacheMode = CacheMode.NONE; @Builder.Default private WorkspaceMode workspaceMode = WorkspaceMode.ENABLED; @Builder.Default private ConvolutionLayer.AlgoMode cudnnAlgoMode = ConvolutionLayer.AlgoMode.PREFER_FASTEST; private UNet() {} @Override public String pretrainedUrl(PretrainedType pretrainedType) { if (pretrainedType == PretrainedType.SEGMENT) return DL4JResources.getURLString("models/unet_dl4j_segment_inference.v1.zip"); else return null; } @Override public long pretrainedChecksum(PretrainedType pretrainedType) { if (pretrainedType == PretrainedType.SEGMENT) return 712347958L; else return 0L; } @Override public Class modelType() { return ComputationGraph.class; } @Override public ComputationGraph init() { ComputationGraphConfiguration.GraphBuilder graph = graphBuilder(); graph.addInputs("input").setInputTypes(InputType.convolutional(inputShape[2], inputShape[1], inputShape[0])); ComputationGraphConfiguration conf = graph.build(); ComputationGraph model = new ComputationGraph(conf); model.init(); return model; } public ComputationGraphConfiguration.GraphBuilder graphBuilder() { ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder().seed(seed) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(updater) .weightInit(weightInit) .l2(5e-5) .miniBatch(true) .cacheMode(cacheMode) .trainingWorkspaceMode(workspaceMode) .inferenceWorkspaceMode(workspaceMode) .graphBuilder(); graph .addLayer("conv1-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(64) 
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "input") .addLayer("conv1-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(64) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "conv1-1") .addLayer("pool1", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2,2) .build(), "conv1-2") .addLayer("conv2-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(128) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "pool1") .addLayer("conv2-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(128) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "conv2-1") .addLayer("pool2", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2,2) .build(), "conv2-2") .addLayer("conv3-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(256) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "pool2") .addLayer("conv3-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(256) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "conv3-1") .addLayer("pool3", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2,2) .build(), "conv3-2") .addLayer("conv4-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(512) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "pool3") .addLayer("conv4-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(512) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "conv4-1") .addLayer("drop4", new DropoutLayer.Builder(0.5).build(), "conv4-2") .addLayer("pool4", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2,2) .build(), 
"drop4") .addLayer("conv5-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(1024) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "pool4") .addLayer("conv5-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(1024) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "conv5-1") .addLayer("drop5", new DropoutLayer.Builder(0.5).build(), "conv5-2") // up6 .addLayer("up6-1", new Upsampling2D.Builder(2).build(), "drop5") .addLayer("up6-2", new ConvolutionLayer.Builder(2,2).stride(1,1).nOut(512) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "up6-1") .addVertex("merge6", new MergeVertex(), "drop4", "up6-2") .addLayer("conv6-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(512) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "merge6") .addLayer("conv6-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(512) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "conv6-1") // up7 .addLayer("up7-1", new Upsampling2D.Builder(2).build(), "conv6-2") .addLayer("up7-2", new ConvolutionLayer.Builder(2,2).stride(1,1).nOut(256) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "up7-1") .addVertex("merge7", new MergeVertex(), "conv3-2", "up7-2") .addLayer("conv7-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(256) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "merge7") .addLayer("conv7-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(256) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "conv7-1") // up8 .addLayer("up8-1", new Upsampling2D.Builder(2).build(), "conv7-2") .addLayer("up8-2", new 
ConvolutionLayer.Builder(2,2).stride(1,1).nOut(128) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "up8-1") .addVertex("merge8", new MergeVertex(), "conv2-2", "up8-2") .addLayer("conv8-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(128) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "merge8") .addLayer("conv8-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(128) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "conv8-1") // up9 .addLayer("up9-1", new Upsampling2D.Builder(2).build(), "conv8-2") .addLayer("up9-2", new ConvolutionLayer.Builder(2,2).stride(1,1).nOut(64) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "up9-1") .addVertex("merge9", new MergeVertex(), "conv1-2", "up9-2") .addLayer("conv9-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(64) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "merge9") .addLayer("conv9-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(64) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "conv9-1") .addLayer("conv9-3", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(2) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.RELU).build(), "conv9-2") .addLayer("conv10", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(1) .convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) .activation(Activation.IDENTITY).build(), "conv9-3") .addLayer("output", new CnnLossLayer.Builder(LossFunctions.LossFunction.XENT) .activation(Activation.SIGMOID).build(), "conv10") .setOutputs("output"); return graph; } @Override public ModelMetaData metaData() { return new ModelMetaData(new int[][] {inputShape}, 1, ZooType.CNN); } 
@Override public void setInputShape(int[][] inputShape) { this.inputShape = inputShape[0]; } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy