All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.vision.EfficientNet Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2010-2024 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */
package smile.vision;

import java.awt.Image;
import java.util.function.IntFunction;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import smile.deep.activation.SiLU;
import smile.deep.layer.*;
import smile.deep.tensor.Tensor;
import smile.vision.layer.*;
import smile.vision.transform.Transform;

/**
 * EfficientNet is an image classification model family. It was first
 * described in "EfficientNet: Rethinking Model Scaling for Convolutional
 * Neural Networks". The pre-configured {@link #V2S()}, {@link #V2M()} and
 * {@link #V2L()} factories build the EfficientNetV2 variants, whose early
 * stages use FusedMBConv blocks and later stages use MBConv blocks.
 * <p>
 * The network consists of a {@code features} block (stem convolution,
 * one sequential stage per {@link MBConvConfig}, and a final 1x1
 * convolution), an adaptive average pooling layer, and a
 * {@code classifier} block (dropout followed by a linear layer).
 *
 * @author Haifeng Li
 */
public class EfficientNet extends LayerBlock {
    private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(EfficientNet.class);
    /** Adaptive average pooling to a 1x1 spatial output. */
    private final AdaptiveAvgPool2dLayer avgpool;
    /** Stem convolution, inverted residual stages and final 1x1 convolution. */
    private final SequentialBlock features;
    /** Dropout followed by the linear classification layer. */
    private final SequentialBlock classifier;

    /**
     * Constructor.
     * @param invertedResidualSetting the network structure, one entry per stage.
     *                                Must contain at least one entry.
     * @param dropout the dropout probability of the classifier head.
     * @param stochasticDepthProb the stochastic depth probability. Each block
     *                            receives a share of it that grows linearly with
     *                            the block's global index.
     * @param numClasses the number of classes.
     * @param lastChannel the number of channels on the penultimate layer.
     *                    If non-positive, 4 times the output channels of the
     *                    last stage is used.
     * @param normLayer the functor to create the normalization layer from
     *                  the number of channels. If null,
     *                  {@code BatchNorm2dLayer::new} is used.
     * @throws IllegalArgumentException if {@code invertedResidualSetting}
     *         is null or empty.
     */
    public EfficientNet(MBConvConfig[] invertedResidualSetting, double dropout, double stochasticDepthProb,
                        int numClasses, int lastChannel, IntFunction<Layer> normLayer) {
        super("EfficientNet");

        if (invertedResidualSetting == null || invertedResidualSetting.length == 0) {
            throw new IllegalArgumentException("invertedResidualSetting must contain at least one stage");
        }

        if (normLayer == null) {
            normLayer = BatchNorm2dLayer::new;
        }

        // Slot 0: stem conv, slots 1..n: stages, slot n+1: final 1x1 conv.
        Layer[] layers = new Layer[invertedResidualSetting.length + 2];

        // Stem: 3x3 stride-2 convolution from the 3-channel input.
        int firstconvOutputChannels = invertedResidualSetting[0].inputChannels();
        layers[0] = new Conv2dNormActivation(new Conv2dNormActivation.Options(
                3, firstconvOutputChannels, 3, 2, normLayer, new SiLU(true)));

        // The total block count is needed to scale the stochastic depth probability.
        int totalStageBlocks = 0;
        for (var config : invertedResidualSetting) {
            totalStageBlocks += config.numLayers();
        }

        int stageBlockId = 0;
        for (int i = 1; i <= invertedResidualSetting.length; i++) {
            var config = invertedResidualSetting[i - 1];
            logger.debug("Layer {}: {}", i, config);
            Layer[] stage = new Layer[config.numLayers()];
            for (int j = 0; j < stage.length; j++) {
                // Only the first block of a stage applies the configured stride and
                // channel change; the following blocks keep outputChannels with stride 1.
                if (j > 0) {
                    config = new MBConvConfig(
                            config.expandRatio(),
                            config.kernel(),
                            1, // stride
                            config.outputChannels(),
                            config.outputChannels(),
                            config.numLayers(),
                            config.block()
                    );
                }
                // Stochastic depth probability grows linearly with the global block index.
                double sdprob = stochasticDepthProb * stageBlockId / totalStageBlocks;
                logger.debug("Stage {} BlockId {}: {} sdprob = {}", i, stageBlockId, config, sdprob);
                stage[j] = config.block().equals("MBConv")
                        ? new MBConv(config, sdprob, normLayer)
                        : new FusedMBConv(config, sdprob, normLayer);
                stageBlockId++;
            }
            layers[i] = new SequentialBlock(stage);
        }

        // Final 1x1 convolution expanding to lastChannel.
        int lastConvInputChannels = invertedResidualSetting[invertedResidualSetting.length - 1].outputChannels();
        if (lastChannel <= 0) {
            lastChannel = 4 * lastConvInputChannels;
        }
        layers[invertedResidualSetting.length + 1] = new Conv2dNormActivation(new Conv2dNormActivation.Options(
                lastConvInputChannels, lastChannel, 1, normLayer, new SiLU(true)));

        features = new SequentialBlock(layers);
        avgpool = new AdaptiveAvgPool2dLayer(1);
        classifier = new SequentialBlock(
                new DropoutLayer(dropout, true),
                new LinearLayer(lastChannel, numClasses));
        add("features", features);
        add("avgpool", avgpool);
        add("classifier", classifier);

        initializeWeights();
    }

    /**
     * Initializes, in place, the parameters of all underlying torch modules:
     * <ul>
     * <li>Conv2d: Kaiming-normal weights (fan-out mode), zero bias if present.</li>
     * <li>BatchNorm2d / GroupNorm: unit weight and zero bias if present.</li>
     * <li>Linear: weights uniform in [-r, r] with r = 1 / sqrt(out_features),
     *     zero bias if present.</li>
     * </ul>
     * Runs under a no-grad guard so that autograd does not record the updates.
     */
    private void initializeWeights() {
        try (var guard = Tensor.noGradGuard();
             var modules = module.modules()) {
            for (int i = 0; i < modules.size(); i++) {
                // Named 'child' so it does not shadow the inherited 'module' field.
                var child = modules.get(i);
                switch (child.name().getString()) {
                    case "torch::nn::Conv2dImpl": {
                        var conv2d = child.asConv2d();
                        torch.kaiming_normal_(conv2d.weight(), 0.0, new FanModeType(new kFanOut()), new Nonlinearity(new kLeakyReLU()));
                        var bias = conv2d.bias();
                        if (bias.defined()) {
                            bias.zero_();
                        }
                        break;
                    }
                    case "torch::nn::BatchNorm2dImpl": {
                        // Weight/bias are undefined for non-affine normalization.
                        var batchNorm2d = child.asBatchNorm2d();
                        var weight = batchNorm2d.weight();
                        if (weight.defined()) {
                            torch.ones_(weight);
                        }
                        var bias = batchNorm2d.bias();
                        if (bias.defined()) {
                            bias.zero_();
                        }
                        break;
                    }
                    case "torch::nn::GroupNormImpl": {
                        // Weight/bias are undefined for non-affine normalization.
                        var groupNorm = child.asGroupNorm();
                        var weight = groupNorm.weight();
                        if (weight.defined()) {
                            torch.ones_(weight);
                        }
                        var bias = groupNorm.bias();
                        if (bias.defined()) {
                            bias.zero_();
                        }
                        break;
                    }
                    case "torch::nn::LinearImpl": {
                        var linear = child.asLinear();
                        double range = 1.0 / Math.sqrt(linear.options().out_features().get());
                        torch.uniform_(linear.weight(), -range, range);
                        var bias = linear.bias();
                        if (bias.defined()) {
                            torch.zeros_(bias);
                        }
                        break;
                    }
                    default:
                        // Other modules keep their default initialization.
                        break;
                }
            }
        }
    }

    /**
     * Runs the feature extractor, global average pooling, flattening and the
     * classifier. Intermediate tensors are released even if a later stage
     * throws.
     * @param input the input tensor.
     * @return the classifier output.
     */
    @Override
    public Tensor forward(Tensor input) {
        Tensor x = features.forward(input);
        Tensor pooled;
        try {
            pooled = avgpool.forward(x);
        } finally {
            x.close();
        }

        Tensor flat = null;
        try {
            flat = pooled.flatten(1);
            return classifier.forward(flat);
        } finally {
            if (flat != null) {
                flat.close();
            }
            pooled.close();
        }
    }

    /**
     * Returns the feature layer block.
     * @return the feature layer block.
     */
    public SequentialBlock features() {
        return features;
    }

    /**
     * Builds an ImageNet (1000-class) EfficientNetV2 network with stochastic
     * depth 0.2, 1280 penultimate channels and BatchNorm2d (eps 0.001,
     * momentum 0.01), wraps it with the given transform, and loads the
     * pre-trained weights.
     * @param config the stage configuration.
     * @param dropout the classifier dropout probability.
     * @param transform the input image transform.
     * @param path the pre-trained model file path.
     * @return the model.
     */
    private static VisionModel pretrained(MBConvConfig[] config, double dropout, Transform transform, String path) {
        var net = new EfficientNet(config, dropout, 0.2, 1000, 1280,
                channels -> new BatchNorm2dLayer(channels, 0.001, 0.01, true));
        var model = new VisionModel(net, transform);
        model.load(path);
        return model;
    }

    /**
     * EfficientNet-V2_S (baseline) model, loaded from
     * {@code model/EfficientNet/efficientnet_v2_s.pt}.
     * @return the model.
     */
    public static VisionModel V2S() {
        return V2S("model/EfficientNet/efficientnet_v2_s.pt");
    }

    /**
     * EfficientNet-V2_S (baseline) model with 384x384 input.
     * @param path the pre-trained model file path.
     * @return the model.
     */
    public static VisionModel V2S(String path) {
        MBConvConfig[] config = {
                MBConvConfig.FusedMBConv(1, 3, 1, 24, 24, 2),
                MBConvConfig.FusedMBConv(4, 3, 2, 24, 48, 4),
                MBConvConfig.FusedMBConv(4, 3, 2, 48, 64, 4),
                MBConvConfig.MBConv(4, 3, 2, 64, 128, 6),
                MBConvConfig.MBConv(6, 3, 1, 128, 160, 9),
                MBConvConfig.MBConv(6, 3, 2, 160, 256, 15)
        };
        Transform transform = Transform.classification(384, 384);
        return pretrained(config, 0.2, transform, path);
    }

    /**
     * EfficientNet-V2_M (larger) model, loaded from
     * {@code model/EfficientNet/efficientnet_v2_m.pt}.
     * @return the model.
     */
    public static VisionModel V2M() {
        return V2M("model/EfficientNet/efficientnet_v2_m.pt");
    }

    /**
     * EfficientNet-V2_M (larger) model with 480x480 input.
     * @param path the pre-trained model file path.
     * @return the model.
     */
    public static VisionModel V2M(String path) {
        MBConvConfig[] config = {
                MBConvConfig.FusedMBConv(1, 3, 1, 24, 24, 3),
                MBConvConfig.FusedMBConv(4, 3, 2, 24, 48, 5),
                MBConvConfig.FusedMBConv(4, 3, 2, 48, 80, 5),
                MBConvConfig.MBConv(4, 3, 2, 80, 160, 7),
                MBConvConfig.MBConv(6, 3, 1, 160, 176, 14),
                MBConvConfig.MBConv(6, 3, 2, 176, 304, 18),
                MBConvConfig.MBConv(6, 3, 1, 304, 512, 5)
        };
        Transform transform = Transform.classification(480, 480);
        return pretrained(config, 0.3, transform, path);
    }

    /**
     * EfficientNet-V2_L (largest) model, loaded from
     * {@code model/EfficientNet/efficientnet_v2_l.pt}.
     * @return the model.
     */
    public static VisionModel V2L() {
        return V2L("model/EfficientNet/efficientnet_v2_l.pt");
    }

    /**
     * EfficientNet-V2_L (largest) model with 480x480 input, normalized with
     * mean 0.5 and standard deviation 0.5 per channel, and resized with
     * {@code Image.SCALE_SMOOTH}.
     * @param path the pre-trained model file path.
     * @return the model.
     */
    public static VisionModel V2L(String path) {
        MBConvConfig[] config = {
                MBConvConfig.FusedMBConv(1, 3, 1, 32, 32, 4),
                MBConvConfig.FusedMBConv(4, 3, 2, 32, 64, 7),
                MBConvConfig.FusedMBConv(4, 3, 2, 64, 96, 7),
                MBConvConfig.MBConv(4, 3, 2, 96, 192, 10),
                MBConvConfig.MBConv(6, 3, 1, 192, 224, 19),
                MBConvConfig.MBConv(6, 3, 2, 224, 384, 25),
                MBConvConfig.MBConv(6, 3, 1, 384, 640, 7)
        };
        Transform transform = Transform.classification(480, 480,
                new float[]{0.5f, 0.5f, 0.5f}, new float[]{0.5f, 0.5f, 0.5f}, Image.SCALE_SMOOTH);
        return pretrained(config, 0.4, transform, path);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy