All Downloads are FREE. Search and download functionalities are using the official Maven repository.

smile.vision.layer.SqueezeExcitation Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2010-2024 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */
package smile.vision.layer;

import smile.deep.activation.ActivationFunction;
import smile.deep.activation.ReLU;
import smile.deep.activation.Sigmoid;
import smile.deep.layer.*;
import smile.deep.tensor.Tensor;

/**
 * Squeeze-and-Excitation block from "Squeeze-and-Excitation Networks".
 * <p>
 * The block pools the input with an adaptive average pooling layer, passes
 * the pooled tensor through two convolution layers (input channels to squeeze
 * channels, then squeeze channels back to input channels) with the
 * {@code delta} activation between them and the {@code sigma} activation
 * after them, and finally multiplies the result with the original input.
 * <p>
 * Ownership: on success {@link #forward(Tensor)} closes its input tensor and
 * every intermediate tensor it created; the caller owns the returned tensor.
 *
 * @author Haifeng Li
 */
public class SqueezeExcitation extends LayerBlock {
    /** Pooling layer applied to the input before the two convolutions. */
    private final AdaptiveAvgPool2dLayer avgpool;
    /** Convolutions mapping input to squeeze channels (conv1) and back (conv2). */
    private final Conv2dLayer conv1, conv2;
    /** Activation applied after conv1 (delta) and after conv2 (sigma). */
    private final ActivationFunction delta, sigma;

    /**
     * Constructor. Uses {@code new ReLU(true)} as the delta activation and
     * {@code new Sigmoid(true)} as the sigma activation.
     * @param inputChannels the number of channels in the input image.
     * @param squeezeChannels the number of squeeze channels.
     */
    public SqueezeExcitation(int inputChannels, int squeezeChannels) {
        this(inputChannels, squeezeChannels, new ReLU(true), new Sigmoid(true));
    }

    /**
     * Constructor.
     * @param inputChannels the number of channels in the input image.
     * @param squeezeChannels the number of squeeze channels.
     * @param delta the delta activation function, applied after the first convolution.
     * @param sigma the sigma activation function, applied after the second convolution.
     */
    public SqueezeExcitation(int inputChannels, int squeezeChannels, ActivationFunction delta, ActivationFunction sigma) {
        super("SqueezeExcitation");
        this.avgpool = new AdaptiveAvgPool2dLayer(1);
        this.conv1 = Layer.conv2d(inputChannels, squeezeChannels, 1);
        this.conv2 = Layer.conv2d(squeezeChannels, inputChannels, 1);
        this.delta = delta;
        this.sigma = sigma;
        // Register the sub-layers under the same names as before so that
        // anything keyed on these names keeps working.
        add("avgpool", avgpool);
        add("fc1", conv1);
        add("fc2", conv2);
        add("activation", delta);
        add("scale_activation", sigma);
    }

    /**
     * Computes the scale tensor from the input and multiplies it with the input.
     * <p>
     * On success the input tensor is closed before returning. If any step
     * throws, the intermediate tensors created so far are closed and the
     * input is left open, so the caller keeps ownership of it.
     *
     * @param input the input tensor; closed by this method on success.
     * @return the input multiplied by the computed scale.
     */
    @Override
    public Tensor forward(Tensor input) {
        Tensor scale = computeScale(input);
        Tensor output;
        try {
            output = scale.mul(input);
        } finally {
            // Release the scale whether or not the multiply succeeded.
            scale.close();
        }
        input.close();
        return output;
    }

    /**
     * Runs pooling, both convolutions and both activations, closing every
     * intermediate tensor even when a later step throws.
     *
     * @param input the block input; not closed here.
     * @return the scale tensor, owned by the caller.
     */
    private Tensor computeScale(Tensor input) {
        Tensor pooled = avgpool.forward(input);
        Tensor squeezed;
        try {
            squeezed = conv1.forward(pooled);
        } finally {
            pooled.close();
        }

        Tensor hidden = applyActivation(delta, squeezed);
        Tensor expanded;
        try {
            expanded = conv2.forward(hidden);
        } finally {
            hidden.close();
        }

        return applyActivation(sigma, expanded);
    }

    /**
     * Applies an activation and returns its result, releasing the
     * pre-activation tensor when it is no longer referenced.
     *
     * @param activation the activation function to apply.
     * @param x the pre-activation tensor; ownership is taken by this method.
     * @return the activated tensor, which may be {@code x} itself.
     */
    private static Tensor applyActivation(ActivationFunction activation, Tensor x) {
        Tensor y;
        try {
            y = activation.forward(x);
        } catch (RuntimeException | Error e) {
            // Do not leak the pre-activation tensor when the activation fails.
            x.close();
            throw e;
        }
        // NOTE(review): assumes an in-place activation returns the very same
        // Tensor object it was given - confirm for custom activations. When a
        // different object comes back, the old tensor would otherwise leak.
        if (y != x) {
            x.close();
        }
        return y;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy