All downloads are free. Search and download functionality uses the official Maven repository.

org.diffkt.model.EmbeddingBag.kt Maven / Gradle / Ivy

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.diffkt.model

import org.diffkt.*
import org.diffkt.model.Initializer.gaussian
import java.lang.IllegalArgumentException
import kotlin.random.Random

/**
 * A trainable embedding table of size numEmbeddings x embeddingSize whose looked-up
 * embeddings are reduced per bag.
 *
 * This layer is equivalent to Embedding followed by a reduction (e.g. sum)
 * across each bag of embeddings along the 0th axis.
 *
 * @param trainableWeights the embedding table; the secondary constructor creates it with
 *   shape (numEmbeddings, embeddingSize)
 * @param reduction the reduction mode used to reduce each bag
 */
// NOTE(review): the type arguments on TrainableLayer and on the withTrainables parameter
// were missing in this copy of the file and have been restored; confirm them against the
// TrainableLayer declaration.
class EmbeddingBag(
        val trainableWeights: TrainableTensor,
        val reduction: Reduction,
) : TrainableLayer<EmbeddingBag> {

    /** The single trainable of this layer: the embedding table. */
    override val trainables = listOf(trainableWeights)

    /**
     * Creates an EmbeddingBag with a freshly initialized embedding table.
     *
     * @param numEmbeddings the size of the vocabulary/number of embedding vectors
     * @param embeddingSize the size of each embedding vector
     * @param reduction the reduction mode used to reduce each bag
     * @param random the source of randomness handed to [initializer]
     * @param initializer produces the initial table for the requested shape;
     *   defaults to [Initializer.gaussian]
     */
    constructor(numEmbeddings: Int,
                embeddingSize: Int,
                reduction: Reduction,
                random: Random,
                initializer: (Shape, Random)->FloatTensor = Initializer.gaussian(),
    ) : this(TrainableTensor(initializer(Shape(numEmbeddings, embeddingSize), random)), reduction)

    /**
     * Looks up the embeddings for [indices] and reduces them bag by bag.
     *
     * Bag i covers the flattened indices from offset i (inclusive) up to offset i + 1
     * (exclusive); the last bag runs to the end of [indices].
     *
     * @param indices the indices into the embedding table
     * @param bagOffsets linear offsets into indices indicating the start of each bag.
     * @return tensor of shape (numBags, embeddingSize)
     */
    operator fun invoke(indices: IntTensor, bagOffsets: IntTensor): DTensor {
        // TODO: Make this more efficient by doing the reduction in-place
        //   https://github.com/facebookincubator/diffkt/issues/276
        // Get the embeddings
        val embeddings = embedding(trainableWeights.tensor, indices.flatten())
        // Reduce the embeddings into bags
        val flatOffsets = bagOffsets.flatten()
        // The final bag has no following offset, so it ends at the end of the indices.
        val lastOffset = indices.size
        // Explicit element type: there is no expected type to infer it from.
        val reducedBags = mutableListOf<DTensor>()
        for ((i, startOffset) in flatOffsets.dataIterator.withIndex()) {
            val endOffset = if (i + 1 == flatOffsets.size) lastOffset else flatOffsets.at(i + 1)
            reducedBags += reduction.reduce(embeddings.slice(startOffset, endOffset))
        }
        // Each reduced bag is concatenated into the result.
        return concat(reducedBags)
    }

    /**
     * Not supported: this layer needs integer indices and bag offsets, so it must be
     * called through the (IntTensor, IntTensor) overload.
     *
     * @throws IllegalArgumentException always
     */
    override fun invoke(vararg inputs: DTensor): DTensor {
        throw IllegalArgumentException("EmbeddingBag must be called with an IntTensor")
    }

    /**
     * Returns a copy of this layer that uses the given trainable as its embedding table
     * and keeps the same reduction.
     *
     * @param trainables must contain exactly one element, which is cast to TrainableTensor
     * @throws IllegalArgumentException if [trainables] does not contain exactly one element
     */
    override fun withTrainables(trainables: List<Trainable<*>>): EmbeddingBag {
        require(trainables.size == 1) {
            "EmbeddingBag expects exactly one trainable, but got ${trainables.size}"
        }
        return EmbeddingBag(trainables[0] as TrainableTensor, reduction)
    }

    // Equality and hashing are both defined over (trainableWeights, reduction).
    override fun hashCode(): Int = combineHash("EmbeddingBag", trainableWeights, reduction)
    override fun equals(other: Any?): Boolean = other is EmbeddingBag &&
            other.trainableWeights == trainableWeights &&
            other.reduction == reduction

    companion object {
        /** Strategy for collapsing the embeddings of one bag into a single row. */
        sealed class Reduction {
            /** Reduces the embeddings [x] of a single bag. */
            abstract fun reduce(x: DTensor): DTensor

            /** Sums along axis 0, keeping that axis so the reduced bags can be concatenated. */
            object Sum : Reduction() {
                override fun reduce(x: DTensor) = x.sum(0, keepDims = true)
            }
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy