/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.nlp.generate;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.TranslateException;
import java.util.Arrays;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
* {@code ContrastiveSeqBatchScheduler} is a {@link SeqBatchScheduler} that picks each generated
* token with the contrastive search algorithm, balancing model confidence against a degeneration
* penalty over the top-k candidates.
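*
* <p>A minimal usage sketch; the construction of {@code lmPredictor} (a {@code
* Predictor<NDList, CausalLMOutput>} over a causal language model) is elided and the variable
* names are illustrative:
*
* <pre>{@code
* SearchConfig config = new SearchConfig();
* config.setK(4);        // candidates considered per step
* config.setAlpha(0.6f); // weight of the degeneration penalty
* ContrastiveSeqBatchScheduler scheduler =
*         new ContrastiveSeqBatchScheduler(lmPredictor, config);
* scheduler.addRequest(inputIds, batchUids); // prompt token ids and request uids
* scheduler.incrementForward(50);            // generate up to 50 more tokens
* Map<Long, NDArray> results = scheduler.collectResults();
* }</pre>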
*/
public class ContrastiveSeqBatchScheduler extends SeqBatchScheduler {
/**
* Constructs a new {@code ContrastiveSeqBatchScheduler} instance.
*
* @param lmBlock the predictor containing the language model
* @param config the autoregressive search configuration
*/
public ContrastiveSeqBatchScheduler(
Predictor<NDList, CausalLMOutput> lmBlock, SearchConfig config) {
super(lmBlock, config);
}
/** {@inheritDoc} */
@Override
public SeqBatcher initForward(NDArray inputIds, NDArray batchUids) throws TranslateException {
try (NDScope scope = new NDScope()) {
scope.suppressNotUsedWarning();
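// Arrays created under this scope are freed when it closes, unless unregistered below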
manager = inputIds.getManager();
NDArray initOffSets = computeOffSets(inputIds, config);
NDArray attentionMask = computeAttentionMask(inputIds, config);
NDArray positionIds = computePositionIds(inputIds, initOffSets, 0, 1);
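// Prefill: a single forward pass over the full prompt yields the logits and the initial KV cache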
CausalLMOutput output =
predictor.predict(new NDList(inputIds, positionIds, attentionMask));
NDArray lastLogits = output.getLogits().get(":, -1, :");
// Marks, for each tensor in the serialized batchTensorList, which dimension is the
// sequence dimension
long[] seqDimOrder = new long[28];
Arrays.fill(seqDimOrder, 0, 3, 1);
seqDimOrder[3] = -1; // -1 means no sequence dimension
Arrays.fill(seqDimOrder, 4, seqDimOrder.length, 2);
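// Indices 0-2 (outputIds, attentionMask, hiddenStates) have the sequence at dim 1; index 3
// (logits) has no sequence dimension; the remaining past key-value entries have it at dim 2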
BatchTensorList batchTensorList =
new ContrastiveBatchTensorList(
inputIds,
attentionMask,
output.getHiddenState(),
lastLogits,
output.getPastKeyValuesList(),
seqDimOrder);
SeqBatcher ret = new SeqBatcher(batchTensorList, batchUids, initOffSets, manager);
// Memory management: keep the arrays referenced by the returned SeqBatcher alive after this scope closes
NDScope.unregister(output.getPastKeyValuesList());
NDScope.unregister(output.getHiddenState(), attentionMask, lastLogits);
NDScope.unregister(ret.offSets, ret.batchUid);
return ret;
}
}
/** {@inheritDoc} */
@Override
public NDArray inferenceCall() throws TranslateException {
NDArray outputIds;
try (NDScope scope = new NDScope()) {
scope.suppressNotUsedWarning();
/* Prepare input for one inference call */
NDArray logits = ((ContrastiveBatchTensorList) seqBatcher.getData()).getLogits();
NDArray topKIds = logits.topK(config.getK(), -1, true, false).get(1); // [batch, topK]
ContrastiveBatchTensorList searchState = (ContrastiveBatchTensorList) seqBatcher.data;
// Fold the topK dimension into the batch dimension for a single inference call
// [batch, topK] -> [batch * topK] -> [batch * topK, seqLength=1]
NDArray candidateInputIds = topKIds.flatten().reshape(-1, 1);
assert candidateInputIds.getDataType() == DataType.INT64
: "inputIds datatype should be int64";
assert candidateInputIds.getShape().getShape().length == 2
: "candidateInputIds should be 2-dimensional";
// [batch, heads, seq_past, feature] -> [batch * topK, heads, seq_past, feature]
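// repeat(0, k) copies each batch row k times, so every candidate attends to its own
// sequence's KV cache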
NDList kCopyPastKeyValues =
new NDList(
searchState.getPastKeyValues().stream()
.map(ndarray -> ndarray.repeat(0, config.getK()))
.collect(Collectors.toList()));
assert kCopyPastKeyValues.get(0).getDataType() == DataType.FLOAT32
: "pastKeyValues datatype should be float32";
// [batch, seq_past] -> [batch * topK, seq_past] -> [batch * topK, seq_past + 1]
long numBatch = topKIds.getShape().get(0);
NDArray kCopyPastAttentionMask =
searchState.getPastAttentionMask().repeat(0, config.getK());
kCopyPastAttentionMask =
kCopyPastAttentionMask.concat(
manager.ones(new Shape(numBatch * config.getK(), 1), DataType.INT64),
1);
assert kCopyPastKeyValues.get(0).getShape().get(2) + 1
== kCopyPastAttentionMask.getShape().getLastDimension()
: "attentionMask_seq = past_seq + new_input_seq";
// Forward with candidates in batch input
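// The single new token's position continues from the past sequence length, adjusted by
// each sequence's offset inside computePositionIds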
NDArray candidatePositionIds =
computePositionIds(
candidateInputIds,
seqBatcher.offSets,
searchState.getPastOutputIds().getShape().getLastDimension(),
config.getK());
NDList modelInputs =
new NDList(candidateInputIds, candidatePositionIds, kCopyPastAttentionMask);
modelInputs.addAll(kCopyPastKeyValues);
CausalLMOutput candidateOutput = predictor.predict(modelInputs);
NDList generatedOutput =
StepGeneration.constrastiveStepGeneration(
topKIds,
logits,
searchState.getPastHiddenStates(),
candidateOutput.getHiddenState(),
seqBatcher.offSets,
config.getAlpha());
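// generatedOutput[0] holds the chosen token ids [batch, 1]; generatedOutput[1] holds the
// index of the chosen candidate within the top-k [batch, 1], used below to gather tensors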
/* Update searchState for next loop */
long logitsDim = logits.getShape().get(1);
long numHeads = searchState.getPastKeyValues().get(0).getShape().get(1);
long kvDim = searchState.getPastKeyValues().get(0).getShape().get(3);
long currentSeqLength = searchState.getPastOutputIds().getShape().get(1);
long hiddenDim = searchState.getPastHiddenStates().getShape().get(2);
// [batch, 1]
NDArray select = generatedOutput.get(1);
NDIndex selectIndex =
new NDIndex(
"{}, {}, ...",
manager.arange(0, numBatch, 1, DataType.INT64),
select.flatten());
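// Fancy index that picks, for each batch row, the slice of the selected candidate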
// Take from candidateOutput
// [batch, k, inputSeq=1, logitsDim] --select--> [batch, logitsDim]
NDArray nextLogits =
candidateOutput
.getLogits()
.reshape(numBatch, config.getK(), logitsDim)
.get(selectIndex);
// Take from candidateOutput
// [batch * k, heads, seq_past, feature] --select--> [batch, heads, seq_past, feature]
Function<NDArray, NDArray> fn =
ndarray ->
ndarray.reshape(
numBatch,
config.getK(),
numHeads,
currentSeqLength + 1,
kvDim)
.get(selectIndex);
NDList nextPastKeyValue =
new NDList(
candidateOutput.getPastKeyValuesList().stream()
.map(fn)
.collect(Collectors.toList()));
// To be concatenated into searchState.pastHiddenStates
// [batch * k, inputSeq=1, hiddenDim]
NDArray newHiddenState = candidateOutput.getHiddenState();
assert newHiddenState.getManager() == manager : "possible leaky memory";
NDArray nextPastHiddenStates =
searchState
.getPastHiddenStates()
.concat(
newHiddenState
.reshape(numBatch, config.getK(), 1, hiddenDim)
.get(selectIndex),
1);
// To be concatenated into searchState.outputIds
// [batch, seq_past]
outputIds = generatedOutput.get(0);
NDArray nextOutputIds = searchState.getPastOutputIds().concat(outputIds, 1);
// [batch, seq_past]
NDArray nextPastAttentionMask =
searchState
.getPastAttentionMask()
.concat(manager.ones(new Shape(numBatch, 1), DataType.INT64), 1);
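// Every sequence in the batch grew by exactly one token this step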
seqBatcher.seqLength++;
seqBatcher.data =
new ContrastiveBatchTensorList(
nextOutputIds,
nextPastAttentionMask,
nextPastHiddenStates,
nextLogits,
nextPastKeyValue,
searchState.getSeqDimOrder());
/* Exit criteria */
seqBatcher.exitCriteria(outputIds, config.getMaxSeqLength(), config.getEosTokenId());
// Memory management: keep tensors stored in seqBatcher.data and the returned outputIds
// alive after this scope closes
NDScope.unregister(nextOutputIds);
NDScope.unregister(nextPastAttentionMask);
NDScope.unregister(nextPastHiddenStates);
NDScope.unregister(nextLogits);
NDScope.unregister(nextPastKeyValue);
NDScope.unregister(outputIds);
}
return outputIds;
}
}