// All downloads are free. Search and download functionality uses the official Maven repository.
//
// org.deeplearning4j.arbiter.layers.GlobalPoolingLayerSpace (Maven / Gradle / Ivy)
//
// There is a newer version: 1.0.0-beta7
// Show newest version
package org.deeplearning4j.arbiter.layers;

import lombok.AccessLevel;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.util.LeafUtils;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;

/**
 * Layer space for a {@link GlobalPoolingLayer}.
 *
 * <p>Holds one optional {@link ParameterSpace} per configurable option: pooling dimensions,
 * the collapse-dimensions flag, pooling type and p-norm. An option whose space is {@code null}
 * is skipped in {@link #getValue(double[])}, i.e. the corresponding setter on
 * {@link GlobalPoolingLayer.Builder} is simply not called.
 *
 * <p>Instances are created through {@link Builder}.
 *
 * @author Alex Black
 */
@Data
@EqualsAndHashCode(callSuper = true)
@NoArgsConstructor(access = AccessLevel.PRIVATE) //For Jackson JSON/YAML deserialization
public class GlobalPoolingLayerSpace extends LayerSpace<GlobalPoolingLayer> {

    // Element types mirror the convenience overloads on Builder (int..., boolean, PoolingType, int).
    protected ParameterSpace<int[]> poolingDimensions;
    protected ParameterSpace<Boolean> collapseDimensions;
    protected ParameterSpace<PoolingType> poolingType;
    protected ParameterSpace<Integer> pNorm;

    // Computed once in the builder constructor from the collected leaves.
    private int numParameters;

    /**
     * Copies the configured parameter spaces from the builder and caches the parameter count.
     *
     * @param builder builder holding the parameter spaces (any of them may be {@code null})
     */
    private GlobalPoolingLayerSpace(Builder builder) {
        super(builder);
        this.poolingDimensions = builder.poolingDimensions;
        this.collapseDimensions = builder.collapseDimensions;
        this.poolingType = builder.poolingType;
        this.pNorm = builder.pNorm;

        // Must run after the fields above are assigned.
        // NOTE(review): collectLeaves() is inherited; presumably it inspects these fields - confirm.
        this.numParameters = LeafUtils.countUniqueParameters(collectLeaves());
    }

    /**
     * Builds a {@link GlobalPoolingLayer} from the given parameter values.
     *
     * <p>Common layer options are applied via {@code setLayerOptionsBuilder}; each option specific
     * to global pooling is applied only if its parameter space is non-null.
     *
     * @param parameterValues values passed on to every parameter space's {@code getValue}
     * @return the configured layer
     */
    @Override
    public GlobalPoolingLayer getValue(double[] parameterValues) {
        GlobalPoolingLayer.Builder builder = new GlobalPoolingLayer.Builder();
        super.setLayerOptionsBuilder(builder, parameterValues);
        if (poolingDimensions != null) {
            builder.poolingDimensions(poolingDimensions.getValue(parameterValues));
        }
        if (collapseDimensions != null) {
            builder.collapseDimensions(collapseDimensions.getValue(parameterValues));
        }
        if (poolingType != null) {
            builder.poolingType(poolingType.getValue(parameterValues));
        }
        if (pNorm != null) {
            builder.pnorm(pNorm.getValue(parameterValues));
        }
        return builder.build();
    }

    /**
     * @return the parameter count computed when this space was built
     */
    @Override
    public int numParameters() {
        return numParameters;
    }

    /**
     * @return always {@code false}
     */
    @Override
    public boolean isLeaf() {
        return false;
    }

    /**
     * Not supported for this space.
     *
     * @param indices ignored
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setIndices(int... indices) {
        throw new UnsupportedOperationException("Cannot set indices for non-leaf parameter space");
    }



    /**
     * Builder for {@link GlobalPoolingLayerSpace}. Every option has two overloads: one taking a
     * plain value (wrapped in a {@link FixedValue}) and one taking a {@link ParameterSpace}.
     */
    public static class Builder extends LayerSpace.Builder<Builder> {

        protected ParameterSpace<int[]> poolingDimensions;
        protected ParameterSpace<Boolean> collapseDimensions;
        protected ParameterSpace<PoolingType> poolingType;
        protected ParameterSpace<Integer> pNorm;

        /**
         * @param poolingDimensions fixed pooling dimensions
         * @return this builder
         */
        public Builder poolingDimensions(int... poolingDimensions) {
            return poolingDimensions(new FixedValue<>(poolingDimensions));
        }

        /**
         * @param poolingDimensions parameter space producing the pooling dimensions
         * @return this builder
         */
        public Builder poolingDimensions(ParameterSpace<int[]> poolingDimensions) {
            this.poolingDimensions = poolingDimensions;
            return this;
        }

        /**
         * @param collapseDimensions fixed value for the collapse-dimensions flag
         * @return this builder
         */
        public Builder collapseDimensions(boolean collapseDimensions) {
            return collapseDimensions(new FixedValue<>(collapseDimensions));
        }

        /**
         * @param collapseDimensions parameter space producing the collapse-dimensions flag
         * @return this builder
         */
        public Builder collapseDimensions(ParameterSpace<Boolean> collapseDimensions) {
            this.collapseDimensions = collapseDimensions;
            return this;
        }

        /**
         * @param poolingType fixed pooling type
         * @return this builder
         */
        public Builder poolingType(PoolingType poolingType) {
            return poolingType(new FixedValue<>(poolingType));
        }

        /**
         * @param poolingType parameter space producing the pooling type
         * @return this builder
         */
        public Builder poolingType(ParameterSpace<PoolingType> poolingType) {
            this.poolingType = poolingType;
            return this;
        }

        /**
         * @param pNorm fixed p-norm value
         * @return this builder
         */
        public Builder pNorm(int pNorm) {
            return pNorm(new FixedValue<>(pNorm));
        }

        /**
         * @param pNorm parameter space producing the p-norm value
         * @return this builder
         */
        public Builder pNorm(ParameterSpace<Integer> pNorm) {
            this.pNorm = pNorm;
            return this;
        }

        /**
         * @return a new {@link GlobalPoolingLayerSpace} configured from this builder
         */
        public GlobalPoolingLayerSpace build() {
            return new GlobalPoolingLayerSpace(this);
        }
    }
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy