All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.nn.core.SparseMax Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.nn.core;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;

import java.util.stream.IntStream;

/**
 * {@code SparseMax} is a simplified variant of the sparsemax activation described in <a
 * href="https://arxiv.org/pdf/1602.02068.pdf">From Softmax to Sparsemax</a>.
 *
 * <p>Instead of computing the exact sparsemax projection, this block keeps only the {@code K}
 * largest values along the chosen axis ({@code K} is a hyperparameter, default 3). It applies
 * softmax to those values and sets every other position to 0. If the axis holds fewer than
 * {@code K} elements, all of them are kept, which is equivalent to a plain softmax.
 */
public class SparseMax extends AbstractBlock {
    private static final byte VERSION = 1;

    private final int axis;
    private final int topK;

    /** Creates a sparseMax activation function for the last axis and 3 elements. */
    public SparseMax() {
        this(-1, 3);
    }

    /**
     * Creates a sparseMax activation function along a given axis for 3 elements.
     *
     * @param axis the axis to do sparseMax for
     */
    public SparseMax(int axis) {
        this(axis, 3);
    }

    /**
     * Creates a sparseMax activation function along a given axis and number of elements.
     *
     * @param axis the axis to do sparseMax for; -1 (the default) selects the last axis
     * @param topK hyperparameter K, the number of largest elements kept along {@code axis}
     * @throws IllegalArgumentException if {@code topK} is less than 1
     */
    public SparseMax(int axis, int topK) {
        super(VERSION);
        if (topK < 1) {
            throw new IllegalArgumentException("topK must be at least 1, but was " + topK);
        }
        this.axis = axis;
        this.topK = topK;
    }

    /** {@inheritDoc} */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        // the shape of input and output are the same
        return new Shape[] {inputShapes[0]};
    }

    /** {@inheritDoc} */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        // A simple implementation of sparseMax: softmax restricted to the K largest entries.
        NDArray input = inputs.singletonOrThrow();
        // Move the target axis to the end so every operation below works on axis -1.
        if (axis != -1) {
            input = input.swapAxes(axis, -1);
        }

        int lastDimSize = (int) input.size(input.getShape().dimension() - 1);
        // Never keep more elements than the axis holds; indexing past its end would fail.
        int k = Math.min(topK, lastDimSize);

        // Indices of the k largest entries along the last axis, ordered by descending value.
        NDArray topIndices =
                input.argSort(-1, false).get("..., :" + k).toType(DataType.INT64, false);
        // One-hot encode each kept index (shape [..., k, n]) and collapse over k. The result is
        // a [..., n] mask that is 1 at the top-k positions and 0 elsewhere. argSort yields a
        // permutation, so no index repeats and the mask stays strictly 0/1. The mask is cast to
        // the input dtype so the multiplication below does not mix dtypes.
        NDArray maskTopK =
                topIndices
                        .oneHot(lastDimSize)
                        .sum(new int[] {-2})
                        .toType(input.getDataType(), false);

        // Subtract the per-row maximum before exponentiating so large inputs cannot overflow to
        // infinity (which would yield inf / inf = NaN). This does not change the softmax
        // result. The maximum is always among the kept entries, so for finite inputs the
        // denominator is at least 1.
        NDArray shifted = input.sub(input.max(new int[] {-1}, true));
        NDArray numerator = shifted.exp().mul(maskTopK);
        NDArray output = numerator.div(numerator.sum(new int[] {-1}, true));

        if (axis != -1) {
            output = output.swapAxes(axis, -1);
        }
        return new NDList(output);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy