All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.modality.nlp.generate.ContrastiveBatchTensorList Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.modality.nlp.generate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/**
 * A {@link BatchTensorList} that additionally carries the past hidden states and the most recent
 * logits alongside the tensors held by the superclass.
 *
 * <p>When flattened with {@link #getList()}, or rebuilt with {@link #fromList(NDList, long[])},
 * the tensors are laid out as: past output ids, past attention mask, past hidden states, logits,
 * followed by all past key/value tensors.
 */
class ContrastiveBatchTensorList extends BatchTensorList {

    // Logits of the last position only, shaped [batch, vocabSize]; used to recall candidate
    // tokens.
    private NDArray logits;

    // Embedding vectors of the past sequence, shaped [batch, seq_past, hiddenDim]. The sequence
    // dimension has size |past_seq| and grows over time.
    private NDArray pastHiddenStates;

    /** Constructs an empty {@code ContrastiveBatchTensorList} with no tensors assigned. */
    public ContrastiveBatchTensorList() {
        super();
    }

    /**
     * Constructs a {@code ContrastiveBatchTensorList} from its individual tensors.
     *
     * @param pastOutputIds the past output ids, handed to the superclass
     * @param pastAttentionMask the past attention mask, handed to the superclass
     * @param pastHiddenStates the hidden states of the past sequence
     * @param logits the logits of the last position
     * @param pastKeyValues the past key/value tensors, handed to the superclass
     * @param seqDimOrder the sequence dimension order, handed to the superclass
     */
    ContrastiveBatchTensorList(
            NDArray pastOutputIds,
            NDArray pastAttentionMask,
            NDArray pastHiddenStates,
            NDArray logits,
            NDList pastKeyValues,
            long[] seqDimOrder) {
        super(pastOutputIds, pastAttentionMask, pastKeyValues, seqDimOrder);
        this.logits = logits;
        this.pastHiddenStates = pastHiddenStates;
    }

    /**
     * Constructs a {@code ContrastiveBatchTensorList} from a flat list of tensors.
     *
     * <p>Elements 0 and 1 and every element from index 4 onwards are handed to the superclass;
     * element 2 becomes the past hidden states and element 3 becomes the logits.
     *
     * @param list the flat tensor list, in the same layout that {@link #getList()} produces
     * @param seqDimOrder the sequence dimension order, handed to the superclass
     */
    ContrastiveBatchTensorList(NDList list, long[] seqDimOrder) {
        super(list.get(0), list.get(1), list.subNDList(4), seqDimOrder);
        this.pastHiddenStates = list.get(2);
        this.logits = list.get(3);
    }

    /** {@inheritDoc} */
    @Override
    public ContrastiveBatchTensorList fromList(NDList inputList, long[] seqDimOrder) {
        return new ContrastiveBatchTensorList(inputList, seqDimOrder);
    }

    /** {@inheritDoc} */
    @Override
    public NDList getList() {
        // pastOutputIds must stay at index 0 of the resulting list.
        NDList tensors =
                new NDList(
                        getPastOutputIds(),
                        getPastAttentionMask(),
                        getPastHiddenStates(),
                        getLogits());
        // The past key/value tensors are appended after the four fixed entries.
        return tensors.addAll(getPastKeyValues());
    }

    /**
     * Returns the hidden states of the past sequence.
     *
     * @return the current pastHiddenStates
     */
    public NDArray getPastHiddenStates() {
        return this.pastHiddenStates;
    }

    /**
     * Replaces the hidden states of the past sequence.
     *
     * @param pastHiddenStates the new pastHiddenStates
     */
    public void setPastHiddenStates(NDArray pastHiddenStates) {
        this.pastHiddenStates = pastHiddenStates;
    }

    /**
     * Returns the logits of the last position.
     *
     * @return the current logits
     */
    public NDArray getLogits() {
        return this.logits;
    }

    /**
     * Replaces the logits of the last position.
     *
     * @param logits the new logits
     */
    public void setLogits(NDArray logits) {
        this.logits = logits;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy