
// ai.djl.nn.core.Embedding Maven / Gradle / Ivy
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.nn.core;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterBlock;
import ai.djl.nn.ParameterType;
import ai.djl.nn.convolutional.Conv2D.Builder;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* An Embedding block map a collection of items to 1-Dimensional representative {@link NDArray}s.
*
* @param the type of item that should be embedded and map to the array
*/
public class Embedding extends ParameterBlock {
private static final byte VERSION = 2;
private int embeddingSize;
private boolean useDefault;
private DataType dataType;
private Map embedder;
private int numItems;
private Parameter embedding;
Embedding(Builder builder) {
embeddingSize = builder.embeddingSize;
useDefault = builder.useDefault;
dataType = builder.dataType;
embedding = new Parameter("embedding", this, ParameterType.WEIGHT);
embedder = new ConcurrentHashMap<>(builder.items.size());
numItems = 0;
if (useDefault) {
numItems++;
}
for (T item : builder.items) {
embedder.put(item, numItems++);
}
inputShapes = new Shape[] {new Shape(-1)};
}
/**
* Constructs a pretrained embedding.
*
* @param embedding the embedding array
* @param items the items in the embedding (in matching order to the embedding array)
*/
public Embedding(NDArray embedding, List items) {
embeddingSize = Math.toIntExact(embedding.getShape().get(1));
useDefault = false;
dataType = embedding.getDataType();
this.embedding = new Parameter("embedding", this, ParameterType.WEIGHT);
this.embedding.setArray(embedding);
numItems = items.size();
embedder = new ConcurrentHashMap<>(numItems);
for (int i = 0; i < items.size(); i++) {
embedder.put(items.get(i), i);
}
inputShapes = new Shape[] {new Shape(-1)};
}
/**
* Creates a builder to build an {@link Embedding}.
*
* @return a new builder
*/
public static Embedding.Builder> builder() {
return new Embedding.Builder<>();
}
/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
return new Shape[] {inputShapes[0].addAll(new Shape(embeddingSize))};
}
/** {@inheritDoc} */
@Override
public List getDirectParameters() {
return Collections.singletonList(embedding);
}
/** {@inheritDoc} */
@Override
public Shape getParameterShape(String name, Shape[] inputShapes) {
if ("embedding".equals(name)) {
return new Shape(numItems, embeddingSize);
}
throw new IllegalArgumentException("Invalid parameter name");
}
/** {@inheritDoc} */
@Override
public NDList forward(
ParameterStore parameterStore, NDList inputs, PairList params) {
NDList opInputs = opInputs(parameterStore, inputs);
NDArrayEx ex = opInputs.head().getNDArrayInternal();
NDList result = ex.embedding(opInputs, numItems, embeddingSize, dataType, params);
if (inputs.singletonOrThrow().getShape().dimension() == 0) {
result = new NDList(result.singletonOrThrow().reshape(embeddingSize));
}
return result;
}
/** {@inheritDoc} */
@Override
public void saveParameters(DataOutputStream os) throws IOException {
os.writeByte(VERSION);
saveInputShapes(os);
embedding.save(os);
}
/** {@inheritDoc} */
@Override
public void loadParameters(NDManager manager, DataInputStream is)
throws IOException, MalformedModelException {
byte version = is.readByte();
if (version == VERSION) {
readInputShapes(is);
} else if (version != 1) {
throw new MalformedModelException("Unsupported encoding version: " + version);
}
embedding.load(manager, is);
}
/**
* Returns whether an item is in the embedding.
*
* @param item the item to test
* @return true if the item is in the embedding
*/
public boolean hasItem(T item) {
return embedder.containsKey(item);
}
private NDList opInputs(ParameterStore parameterStore, NDList inputs) {
NDArray items = inputs.singletonOrThrow();
Device device = items.getDevice();
NDList ret = new NDList(2);
if (items.getShape().dimension() == 0) {
ret.add(items.reshape(1));
} else {
ret.add(items);
}
ret.add(parameterStore.getValue(embedding, device));
return ret;
}
/**
* Embeds an array of items.
*
* @param manager the manager for the new embeddings
* @param items the items to embed
* @return the embedding {@link NDArray} of Shape(items.length)
*/
public NDArray embed(NDManager manager, T[] items) {
return manager.create(Arrays.stream(items).mapToInt(this::embedHelper).toArray());
}
/**
* Embeds an item.
*
* @param manager the manager for the new embedding
* @param item the item to embed
* @return the embedding {@link NDArray} of Shape()
*/
public NDArray embed(NDManager manager, T item) {
return manager.create(embedHelper(item));
}
private int embedHelper(T value) {
if (embedder.containsKey(value)) {
return embedder.get(value);
} else {
if (useDefault) {
return 0;
} else {
throw new IllegalArgumentException("The provided item was not found");
}
}
}
/**
* The Builder to construct a {@link Embedding} type of {@link Block}.
*
* @param the type of object to embed
*/
public static final class Builder {
private Class embeddingType;
private Collection items;
private int embeddingSize;
private boolean useDefault = true;
private DataType dataType = DataType.FLOAT32;
Builder() {}
private Builder(Class embeddingType, Builder> parent) {
this.embeddingType = embeddingType;
this.embeddingSize = parent.embeddingSize;
this.useDefault = parent.useDefault;
this.dataType = parent.dataType;
}
/**
* Returns the embedded type.
*
* @return the embedded type
*/
public Class getEmbeddingType() {
return embeddingType;
}
/**
* Creates a new {@link Builder} with the specified embedding type.
*
* @param embeddingType the embedding class
* @param the embedding type
* @return a new {@link Builder} class with the specified embedding type
*/
public Builder setType(Class embeddingType) {
return new Builder<>(embeddingType, this);
}
/**
* Sets the collection of items that should feature embeddings.
*
* @param items a collection containing all the items that embeddings should be created for
* @return this Builder
*/
public Builder setItems(Collection items) {
this.items = items;
return this;
}
/**
* Sets the size of the embeddings.
*
* @param embeddingSize the size of the 1D embedding array
* @return this Builder
*/
public Builder setEmbeddingSize(int embeddingSize) {
this.embeddingSize = embeddingSize;
return this;
}
/**
* Sets whether to use a default embedding for undefined items (default true).
*
* @param useDefault true to provide a default embedding and false to throw an {@link
* IllegalArgumentException} when the item can not be found
* @return this Builder
*/
public Builder optUseDefault(boolean useDefault) {
this.useDefault = useDefault;
return this;
}
/**
* Sets the data type of the embedding arrays (default is Float32).
*
* @param dataType the dataType to use for the embedding
* @return this Builder
*/
public Builder optDataType(DataType dataType) {
this.dataType = dataType;
return this;
}
/**
* Builds the {@link Embedding}.
*
* @return the constructed {@code Embedding}
* @throws IllegalArgumentException if all required parameters (items, embeddingSize) have
* not been set
*/
public Embedding build() {
if (items == null) {
throw new IllegalArgumentException("You must specify the items to embed");
}
if (embeddingSize == 0) {
throw new IllegalArgumentException("You must specify the embedding size");
}
return new Embedding<>(this);
}
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy