All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.nn.core.Embedding Maven / Gradle / Ivy

There is a newer version: 0.30.0
Show newest version
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.nn.core;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.util.Arrays;
import java.util.Optional;

/**
 * An Embedding block map a collection of items to 1-Dimensional representative {@link NDArray}s.
 *
 * @param  the type of item that should be embedded and map to the array
 */
public abstract class Embedding extends AbstractBlock implements AbstractIndexedEmbedding {

    private static final byte VERSION = 6;

    protected int numEmbeddings;
    protected int embeddingSize;
    protected SparseFormat sparseFormat;

    protected AbstractIndexedEmbedding fallthroughEmbedding;

    protected Parameter embedding;

    protected Embedding(BaseBuilder baseBuilder) {
        super(VERSION);
        embeddingSize = baseBuilder.embeddingSize;
        numEmbeddings = baseBuilder.numEmbeddings != 0 ? baseBuilder.numEmbeddings : 1;
        sparseFormat = baseBuilder.sparseFormat;
        embedding =
                addParameter(
                        Parameter.builder()
                                .setName("embedding")
                                .setType(Parameter.Type.WEIGHT)
                                .build());
        if (baseBuilder.fallthrough != null && baseBuilder.defaultItem != null) {
            throw new IllegalArgumentException(
                    "You can not specify both a fallthrough and a defaultItem");
        } else if (baseBuilder.fallthrough != null) {
            fallthroughEmbedding = baseBuilder.fallthrough;
        } else if (baseBuilder.defaultItem != null) {
            fallthroughEmbedding = new DefaultItem(baseBuilder.defaultItem);
        } else if (baseBuilder.useDefault) {
            fallthroughEmbedding = new DefaultEmbedding();
        }
        inputShapes = new Shape[] {new Shape(-1)};
    }

    /**
     * Constructs a pretrained embedding.
     *
     * @param embedding the embedding array
     */
    public Embedding(NDArray embedding) {
        this(embedding, SparseFormat.DENSE);
    }

    /**
     * Constructs a pretrained embedding.
     *
     * @param embedding the embedding array
     * @param format whether to compute row sparse gradient in the backward calculation
     */
    public Embedding(NDArray embedding, SparseFormat format) {
        super(VERSION);
        numEmbeddings = Math.toIntExact(embedding.getShape().get(0));
        embeddingSize = Math.toIntExact(embedding.getShape().get(1));
        this.sparseFormat = format;
        this.embedding =
                addParameter(
                        Parameter.builder()
                                .setName("embedding")
                                .setType(Parameter.Type.WEIGHT)
                                .build());
        this.embedding.setArray(embedding);
        inputShapes = new Shape[] {new Shape(-1)};
    }

    /** {@inheritDoc} */
    @Override
    public void prepare(Shape[] inputShapes) {
        // numItems will be adjusted by embedding array or fallthroughEmbedding
        embedding.setShape(new Shape(numEmbeddings, embeddingSize));
    }

    /** {@inheritDoc} */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        return new Shape[] {inputShapes[0].addAll(new Shape(embeddingSize))};
    }

    /** {@inheritDoc} */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList params) {
        NDArray input = inputs.head();
        Device device = input.getDevice();
        NDArray weightArr = parameterStore.getValue(embedding, device, training);
        return embedding(input, weightArr, sparseFormat);
    }

    /** {@inheritDoc} */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(VERSION);
        saveInputShapes(os);
        os.writeInt(sparseFormat.getValue());
        embedding.save(os);
    }

    /** {@inheritDoc} */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is)
            throws IOException, MalformedModelException {
        byte version = is.readByte();

        // True to prepend an empty zero index to embedding table
        // For compatibility with versions that did not always have
        // the zero index reserved for the fallthrough embedding
        boolean addMissingZero = false;

        if (version >= 3) {
            readInputShapes(is);
            if (version == 3) {
                addMissingZero = !is.readBoolean();
            }
            if (version == 6) {
                sparseFormat = SparseFormat.fromValue(is.readInt());
            } else {
                sparseFormat = is.readBoolean() ? SparseFormat.ROW_SPARSE : SparseFormat.DENSE;
            }
            if (version < 6) {
                // read the datatype from old version
                is.readUTF();
            }
            if (version == 3 || version == 4) {
                int embedderSize = is.readInt();
                for (int i = 1; i <= embedderSize; i++) {
                    int encodedKeySize = is.readInt();
                    byte[] encodedKey = new byte[encodedKeySize];
                    if (is.read(encodedKey) != encodedKey.length) {
                        throw new MalformedModelException("Model data is malformed");
                    }
                    is.readInt();
                }
            }
        } else if (version == 2) {
            readInputShapes(is);
            addMissingZero = true;
        } else if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        embedding.load(manager, is);
        numEmbeddings = (int) embedding.getArray().getShape().get(0);
        embeddingSize = (int) embedding.getArray().getShape().get(1);
        if (addMissingZero) {
            numEmbeddings++;
            embedding.setArray(
                    NDArrays.concat(
                            new NDList(
                                    manager.zeros(new Shape(1, embeddingSize)),
                                    embedding.getArray())));
        }
    }

    /** {@inheritDoc} */
    @Override
    public NDArray embed(NDManager manager, T[] items) {
        return manager.create(Arrays.stream(items).mapToLong(this::embed).toArray());
    }

    /**
     * A simple lookup table that looks up embeddings in a fixed dictionary and size.
     *
     * @param input NDArray containing indices into the embedding matrix
     * @param weight The embedding matrix with number of rows equal to the maximum possible index +
     *     1, and number of columns equal to the embedding size
     * @param sparse SparseFormat of the gradient
     * @return output NDArray
     */
    public static NDList embedding(NDArray input, NDArray weight, SparseFormat sparse) {
        return input.getNDArrayInternal().embedding(input, weight, sparse);
    }

    /**
     * The Builder to construct a {@link Embedding} type of {@link Block}.
     *
     * @param  the type of object to embed
     */
    public abstract static class BaseBuilder> {

        protected Class embeddingType;
        protected int numEmbeddings;
        protected int embeddingSize;
        protected boolean useDefault = true;
        protected T defaultItem;
        protected AbstractIndexedEmbedding fallthrough;
        protected SparseFormat sparseFormat = SparseFormat.DENSE;

        protected BaseBuilder() {}

        /**
         * Returns the embedded type.
         *
         * @return the embedded type
         */
        public Class getEmbeddingType() {
            return embeddingType;
        }

        /**
         * Creates a new {@link BaseBuilder} with the specified embedding type.
         *
         * @param embeddingType the embedding class
         * @return a new {@link BaseBuilder} class with the specified embedding type
         */
        protected abstract B setType(Class embeddingType);

        /**
         * Sets the size of the embeddings.
         *
         * @param embeddingSize the size of the 1D embedding array
         * @return this Builder
         */
        public B setEmbeddingSize(int embeddingSize) {
            this.embeddingSize = embeddingSize;
            return self();
        }

        /**
         * Sets the size of the dictionary of embeddings.
         *
         * @param numEmbeddings the size of the dictionary of embeddings
         * @return this Builder
         */
        public B optNumEmbeddings(int numEmbeddings) {
            this.numEmbeddings = numEmbeddings;
            return self();
        }

        /**
         * Sets whether to use a default embedding for undefined items (default true).
         *
         * @param useDefault true to provide a default embedding and false to throw an {@link
         *     IllegalArgumentException} when the item can not be found
         * @return this Builder
         */
        public B optUseDefault(boolean useDefault) {
            this.useDefault = useDefault;
            return self();
        }

        /**
         * Sets whether to use a default item's embedding for undefined items.
         *
         * @param defaultItem the item to use as a default.
         * @return this Builder
         */
        public B optDefaultItem(T defaultItem) {
            this.defaultItem = defaultItem;
            return self();
        }

        /**
         * Sets a custom handler for items not found in the embedding.
         *
         * 

See the standard fallthrough handlers {@link #optUseDefault(boolean)} and {@link * #optDefaultItem(Object)}. * * @param fallthrough the embedding to handle default cases. * @return this Builder */ public B optFallthrough(AbstractIndexedEmbedding fallthrough) { this.fallthrough = fallthrough; return self(); } /** * Sets the optional parameter whether to compute row sparse gradient in the backward * calculation. If set to True, the grad’s storage type is row_sparse. * * @param sparseFormat whether to compute row sparse gradient in the backward calculation * @return this Builder */ public B optSparseFormat(SparseFormat sparseFormat) { this.sparseFormat = sparseFormat; return self(); } /** * Returns this {code Builder} object. * * @return this {@code BaseBuilder} */ protected abstract B self(); } protected class DefaultEmbedding implements AbstractIndexedEmbedding { /** {@inheritDoc} */ @Override public byte[] encode(T input) throws IOException { return Embedding.this.encode(input); } /** {@inheritDoc} */ @Override public T decode(byte[] byteArray) throws IOException { return Embedding.this.decode(byteArray); } /** {@inheritDoc} */ @Override public boolean hasItem(T item) { return true; } /** {@inheritDoc} */ @Override public NDArray embed(NDManager manager, T[] items) { int length = items.length; NDArray base = embedding.getArray().get(0); base.attach(manager); return base.repeat(new Shape(length, embeddingSize)); } /** {@inheritDoc} */ @Override public long embed(T item) { return 0; } /** {@inheritDoc} */ @Override public Optional unembed(long index) { return Optional.empty(); } } protected class DefaultItem implements AbstractIndexedEmbedding { private T defaultItem; public DefaultItem(T defaultItem) { this.defaultItem = defaultItem; } /** {@inheritDoc} */ @Override public byte[] encode(T input) throws IOException { return Embedding.this.encode(input); } /** {@inheritDoc} */ @Override public T decode(byte[] byteArray) throws IOException { return 
Embedding.this.decode(byteArray); } /** {@inheritDoc} */ @Override public boolean hasItem(T item) { return true; } /** {@inheritDoc} */ @Override @SuppressWarnings("unchecked") public NDArray embed(NDManager manager, T[] items) { Object[] defaults = new Object[items.length]; Arrays.fill(defaults, defaultItem); return Embedding.this.embed(manager, (T[]) defaults); } /** {@inheritDoc} */ @Override public long embed(T item) { return 0; } /** {@inheritDoc} */ @Override public Optional unembed(long index) { return Optional.of(defaultItem); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy