All Downloads are FREE. Search and download functionality uses the official Maven repository.

ai.djl.nn.core.Embedding Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.nn.core;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.util.Arrays;
import java.util.Optional;

/**
 * An Embedding block map a collection of items to 1-Dimensional representative {@link NDArray}s.
 *
 * @param  the type of item that should be embedded and map to the array
 */
public abstract class Embedding extends AbstractBlock implements AbstractIndexedEmbedding {

    private static final byte VERSION = 6;

    protected int numEmbeddings;
    protected int embeddingSize;
    protected SparseFormat sparseFormat;

    protected AbstractIndexedEmbedding fallthroughEmbedding;

    protected Parameter embedding;

    @SuppressWarnings("this-escape")
    protected Embedding(BaseBuilder baseBuilder) {
        super(VERSION);
        embeddingSize = baseBuilder.embeddingSize;
        numEmbeddings = baseBuilder.numEmbeddings != 0 ? baseBuilder.numEmbeddings : 1;
        sparseFormat = baseBuilder.sparseFormat;
        embedding =
                addParameter(
                        Parameter.builder()
                                .setName("embedding")
                                .setType(Parameter.Type.WEIGHT)
                                .build());
        if (baseBuilder.fallthrough != null && baseBuilder.defaultItem != null) {
            throw new IllegalArgumentException(
                    "You can not specify both a fallthrough and a defaultItem");
        } else if (baseBuilder.fallthrough != null) {
            fallthroughEmbedding = baseBuilder.fallthrough;
        } else if (baseBuilder.defaultItem != null) {
            fallthroughEmbedding = new DefaultItem(baseBuilder.defaultItem);
        } else if (baseBuilder.useDefault) {
            fallthroughEmbedding = new DefaultEmbedding();
        }
        inputShapes = new Shape[] {new Shape(-1)};
    }

    /**
     * Constructs a pretrained embedding.
     *
     * @param embedding the embedding array
     */
    protected Embedding(NDArray embedding) {
        this(embedding, SparseFormat.DENSE);
    }

    /**
     * Constructs a pretrained embedding.
     *
     * 

Because it is created with preTrained data, it is created as a frozen block. If you with * to update it, call {@link Block#freezeParameters(boolean)}. * * @param embedding the embedding array * @param format whether to compute row sparse gradient in the backward calculation */ @SuppressWarnings("this-escape") protected Embedding(NDArray embedding, SparseFormat format) { super(VERSION); numEmbeddings = Math.toIntExact(embedding.getShape().get(0)); embeddingSize = Math.toIntExact(embedding.getShape().get(1)); this.sparseFormat = format; this.embedding = addParameter( Parameter.builder() .setName("embedding") .setType(Parameter.Type.WEIGHT) .build()); this.embedding.setArray(embedding); inputShapes = new Shape[] {new Shape(-1)}; freezeParameters(true); } /** {@inheritDoc} */ @Override public void prepare(Shape[] inputShapes) { // numItems will be adjusted by embedding array or fallthroughEmbedding embedding.setShape(new Shape(numEmbeddings, embeddingSize)); } /** {@inheritDoc} */ @Override public Shape[] getOutputShapes(Shape[] inputShapes) { return new Shape[] {inputShapes[0].addAll(new Shape(embeddingSize))}; } /** {@inheritDoc} */ @Override protected NDList forwardInternal( ParameterStore parameterStore, NDList inputs, boolean training, PairList params) { NDArray input = inputs.head(); Device device = input.getDevice(); NDArray weightArr = parameterStore.getValue(embedding, device, training); return embedding(input, weightArr, sparseFormat); } /** {@inheritDoc} */ @Override public void saveParameters(DataOutputStream os) throws IOException { os.writeByte(VERSION); saveInputShapes(os); os.writeInt(sparseFormat.getValue()); embedding.save(os); } /** {@inheritDoc} */ @Override public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException { byte version = is.readByte(); // True to prepend an empty zero index to embedding table // For compatibility with versions that did not always have // the zero index reserved for the 
fallthrough embedding boolean addMissingZero = false; if (version >= 3) { readInputShapes(is); if (version == 3) { addMissingZero = !is.readBoolean(); } if (version == 6) { sparseFormat = SparseFormat.fromValue(is.readInt()); } else { sparseFormat = is.readBoolean() ? SparseFormat.ROW_SPARSE : SparseFormat.DENSE; } if (version < 6) { // read the datatype from old version is.readUTF(); } if (version == 3 || version == 4) { int embedderSize = is.readInt(); for (int i = 1; i <= embedderSize; i++) { int encodedKeySize = is.readInt(); byte[] encodedKey = new byte[encodedKeySize]; if (is.read(encodedKey) != encodedKey.length) { throw new MalformedModelException("Model data is malformed"); } is.readInt(); } } } else if (version == 2) { readInputShapes(is); addMissingZero = true; } else if (version != 1) { throw new MalformedModelException("Unsupported encoding version: " + version); } embedding.load(manager, is); numEmbeddings = (int) embedding.getArray().getShape().get(0); embeddingSize = (int) embedding.getArray().getShape().get(1); if (addMissingZero) { numEmbeddings++; embedding.setArray( NDArrays.concat( new NDList( manager.zeros(new Shape(1, embeddingSize)), embedding.getArray()))); } } /** {@inheritDoc} */ @Override public NDArray embed(NDManager manager, T[] items) { return manager.create(Arrays.stream(items).mapToLong(this::embed).toArray()); } /** * A simple lookup table that looks up embeddings in a fixed dictionary and size. * * @param input NDArray containing indices into the embedding matrix * @param weight The embedding matrix with number of rows equal to the maximum possible index + * 1, and number of columns equal to the embedding size * @param sparse SparseFormat of the gradient * @return output NDArray */ public static NDList embedding(NDArray input, NDArray weight, SparseFormat sparse) { return input.getNDArrayInternal().embedding(input, weight, sparse); } /** * The Builder to construct a {@link Embedding} type of {@link Block}. 
* * @param the type of object to embed */ public abstract static class BaseBuilder> { protected Class embeddingType; protected int numEmbeddings; protected int embeddingSize; protected boolean useDefault = true; protected T defaultItem; protected AbstractIndexedEmbedding fallthrough; protected SparseFormat sparseFormat = SparseFormat.DENSE; protected BaseBuilder() {} /** * Returns the embedded type. * * @return the embedded type */ public Class getEmbeddingType() { return embeddingType; } /** * Creates a new {@link BaseBuilder} with the specified embedding type. * * @param embeddingType the embedding class * @return a new {@link BaseBuilder} class with the specified embedding type */ protected abstract B setType(Class embeddingType); /** * Sets the size of the embeddings. * * @param embeddingSize the size of the 1D embedding array * @return this Builder */ public B setEmbeddingSize(int embeddingSize) { this.embeddingSize = embeddingSize; return self(); } /** * Sets the size of the dictionary of embeddings. * * @param numEmbeddings the size of the dictionary of embeddings * @return this Builder */ public B optNumEmbeddings(int numEmbeddings) { this.numEmbeddings = numEmbeddings; return self(); } /** * Sets whether to use a default embedding for undefined items (default true). * * @param useDefault true to provide a default embedding and false to throw an {@link * IllegalArgumentException} when the item can not be found * @return this Builder */ public B optUseDefault(boolean useDefault) { this.useDefault = useDefault; return self(); } /** * Sets whether to use a default item's embedding for undefined items. * * @param defaultItem the item to use as a default. * @return this Builder */ public B optDefaultItem(T defaultItem) { this.defaultItem = defaultItem; return self(); } /** * Sets a custom handler for items not found in the embedding. * *

See the standard fallthrough handlers {@link #optUseDefault(boolean)} and {@link * #optDefaultItem(Object)}. * * @param fallthrough the embedding to handle default cases. * @return this Builder */ public B optFallthrough(AbstractIndexedEmbedding fallthrough) { this.fallthrough = fallthrough; return self(); } /** * Sets the optional parameter whether to compute row sparse gradient in the backward * calculation. If set to True, the grad’s storage type is row_sparse. * * @param sparseFormat whether to compute row sparse gradient in the backward calculation * @return this Builder */ public B optSparseFormat(SparseFormat sparseFormat) { this.sparseFormat = sparseFormat; return self(); } /** * Returns this {code Builder} object. * * @return this {@code BaseBuilder} */ protected abstract B self(); } protected class DefaultEmbedding implements AbstractIndexedEmbedding { /** {@inheritDoc} */ @Override public byte[] encode(T input) throws IOException { return Embedding.this.encode(input); } /** {@inheritDoc} */ @Override public T decode(byte[] byteArray) throws IOException { return Embedding.this.decode(byteArray); } /** {@inheritDoc} */ @Override public boolean hasItem(T item) { return true; } /** {@inheritDoc} */ @Override public NDArray embed(NDManager manager, T[] items) { int length = items.length; NDArray base = embedding.getArray().get(0); base.attach(manager); return base.repeat(new Shape(length, embeddingSize)); } /** {@inheritDoc} */ @Override public long embed(T item) { return 0; } /** {@inheritDoc} */ @Override public Optional unembed(long index) { return Optional.empty(); } } protected class DefaultItem implements AbstractIndexedEmbedding { private T defaultItem; public DefaultItem(T defaultItem) { this.defaultItem = defaultItem; } /** {@inheritDoc} */ @Override public byte[] encode(T input) throws IOException { return Embedding.this.encode(input); } /** {@inheritDoc} */ @Override public T decode(byte[] byteArray) throws IOException { return 
Embedding.this.decode(byteArray); } /** {@inheritDoc} */ @Override public boolean hasItem(T item) { return true; } /** {@inheritDoc} */ @Override @SuppressWarnings("unchecked") public NDArray embed(NDManager manager, T[] items) { Object[] defaults = new Object[items.length]; Arrays.fill(defaults, defaultItem); return Embedding.this.embed(manager, (T[]) defaults); } /** {@inheritDoc} */ @Override public long embed(T item) { return 0; } /** {@inheritDoc} */ @Override public Optional unembed(long index) { return Optional.of(defaultItem); } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy