All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.modality.nlp.generate.TextGenerator Maven / Gradle / Ivy

/*
 * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.modality.nlp.generate;

import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.TranslateException;

import java.util.Arrays;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * {@code TextGenerator} is an LMSearch (language model search) which contains multiple
 * autoregressive search methods.
 *
 * 

It has a Predictor from NDList to CausalLMOutput, which is called inside an autoregressive * inference loop. */ public class TextGenerator { private String searchName; private SearchConfig config; private Predictor predictor; private NDArray positionOffset; private long[] endPosition; /** * Constructs a new {@code TextGenerator} instance. * * @param predictor the language model * @param searchName the autoregressive search name * @param searchConfig the autoregressive search configuration */ public TextGenerator( Predictor predictor, String searchName, SearchConfig searchConfig) { this.predictor = predictor; this.searchName = searchName; this.config = searchConfig; } /** * Executes greedy search. * * @param inputIds the input token ids. * @return the output token ids stored as NDArray and the endPosition of each sentence * @throws TranslateException if forward fails */ @SuppressWarnings("try") public NDArray greedySearch(NDArray inputIds) throws TranslateException { // Initialize the end position of each sentence endPosition = new long[Math.toIntExact(inputIds.getShape().get(0))]; Arrays.fill(endPosition, config.getMaxSeqLength()); NDArray attentionMask = prepareAttentionMaskOffset(inputIds, config); NDManager manager = inputIds.getManager(); GreedyBatchTensorList searchState = new GreedyBatchTensorList(inputIds, null, null, attentionMask); while (true) { try (NDScope ignore = new NDScope()) { NDArray pastOutputIds = searchState.getPastOutputIds(); NDArray nextInputIds = searchState.getNextInputIds(); NDArray pastAttentionMask = searchState.getPastAttentionMask(); NDList pastKeyValues = searchState.getPastKeyValues(); long pastSeqLength = pastOutputIds == null ? 0 : pastOutputIds.getShape().getLastDimension(); NDList modelInput = prepareInput(nextInputIds, pastAttentionMask, pastSeqLength, 1); if (pastKeyValues != null) { modelInput.addAll(pastKeyValues); } CausalLMOutput modelOutput = predictor.predict(modelInput); NDArray outputIds = StepGeneration.greedyStepGen(modelOutput.getLogits()); // Update searchState if (pastOutputIds == null) { pastOutputIds = nextInputIds; searchState.setPastOutputIds(pastOutputIds); } else { pastOutputIds = pastOutputIds.concat(nextInputIds, 1); searchState.setPastOutputIds(pastOutputIds); } nextInputIds = outputIds; searchState.setNextInputIds(nextInputIds); pastKeyValues = modelOutput.getPastKeyValuesList(); searchState.setPastKeyValues(pastKeyValues); pastAttentionMask = pastAttentionMask.concat( manager.ones( new Shape(inputIds.getShape().get(0), 1), DataType.INT64), 1); searchState.setPastAttentionMask(pastAttentionMask); // memory management NDScope.unregister(nextInputIds, pastAttentionMask, pastOutputIds); NDScope.unregister(pastKeyValues); } // Termination Criteria long[] outputIdsArray = searchState.getNextInputIds().toLongArray(); for (int i = 0; i < endPosition.length; ++i) { for (long tokenId : outputIdsArray) { if (tokenId == config.getEosTokenId()) { endPosition[i] = searchState.getPastOutputIds().getShape().get(1) + 1; break; } } } if (searchState.getPastOutputIds().getShape().get(1) + 1 >= config.getMaxSeqLength()) { break; } } return searchState.getPastOutputIds().concat(searchState.getNextInputIds(), 1); } /** * Generates text using beam search. * * @param inputIds input tokens ids * @return the output token ids stored as NDArray and the endPosition of each sentence * @throws TranslateException if failed run forward * @see Beam Search */ @SuppressWarnings("try") public NDArray beamSearch(NDArray inputIds) throws TranslateException { // Initialize the end position of each sentence endPosition = new long[Math.toIntExact(inputIds.getShape().get(0))]; Arrays.fill(endPosition, config.getMaxSeqLength()); NDArray attentionMask = prepareAttentionMaskOffset(inputIds, config); NDManager manager = inputIds.getManager(); long numBeam = config.getBeam(); long numBatch = inputIds.getShape().get(0); BeamBatchTensorList searchState = new BeamBatchTensorList(); long numHeads = 0; long kvDim = 0; while (true) { if (searchState.getPastAttentionMask() == null) { // Initial beams NDList modelInput = prepareInput(inputIds, attentionMask, 0, 1); CausalLMOutput modelOutput = predictor.predict(modelInput); // [batch, probDim] NDArray allProbs = modelOutput.getLogits().get(":, -1, :").softmax(1); // [batch, beam] NDList topK = allProbs.topK(Math.toIntExact(numBeam), -1, true, false); NDArray outputIds = topK.get(1).expandDims(2); NDArray lastProbs = topK.get(0).normalize(1, 1); assert outputIds.getShape().getShape().length == 3 : "Wrong shape"; assert lastProbs.getShape().getShape().length == 2 : "Wrong Shape"; // [batch, beam, seq + 1] attentionMask = attentionMask .concat(manager.ones(new Shape(numBatch, 1), DataType.INT64), -1) .expandDims(1) .repeat(1, numBeam); // [batch, beam, heads, seq_past, kvFeature] Function fn = ndarray -> ndarray.expandDims(1).repeat(1, numBeam); NDList pastKeyValues = new NDList( modelOutput.getPastKeyValuesList().stream() .map(fn) .collect(Collectors.toList())); // [batch, beam, seq_past] NDArray pastOutputIds = inputIds.expandDims(1).repeat(1, numBeam); searchState = new BeamBatchTensorList( outputIds, pastOutputIds, pastKeyValues, attentionMask, lastProbs); numHeads = pastKeyValues.get(0).getShape().get(2); kvDim = pastKeyValues.get(0).getShape().getLastDimension(); } try (NDScope ignore = new NDScope()) { long pastSeqLength = searchState.getPastOutputIds().getShape().getLastDimension(); NDList modelInput = prepareInput( searchState.getNextInputIds().reshape(numBatch * numBeam, 1), searchState.getPastAttentionMask().reshape(numBatch * numBeam, -1), pastSeqLength, config.getBeam()); final long finalNumHeads = numHeads; final long finalKvDim = kvDim; Function fn = ndarray -> ndarray.reshape( numBatch * numBeam, finalNumHeads, pastSeqLength, finalKvDim); NDList pastKeyValues = new NDList( searchState.getPastKeyValues().stream() .map(fn) .collect(Collectors.toList())); modelInput.addAll(pastKeyValues); CausalLMOutput modelOutput = predictor.predict(modelInput); NDList generatedOutput = StepGeneration.beamStepGeneration( searchState.getLastProbs(), modelOutput.getLogits(), numBatch, numBeam); // Update searchState searchState = updateSearchState(searchState, modelOutput, generatedOutput, manager); // Memory management NDScope.unregister( searchState.getNextInputIds(), searchState.getPastOutputIds(), searchState.getPastAttentionMask(), searchState.getLastProbs()); NDScope.unregister(searchState.getPastKeyValues()); } // Termination Criteria long[] outputIdsArray = searchState.getNextInputIds().toLongArray(); for (int i = 0; i < endPosition.length; ++i) { for (long tokenId : outputIdsArray) { if (tokenId == config.getEosTokenId()) { endPosition[i] = searchState.getPastOutputIds().getShape().get(1) + 1; break; } } } if (searchState.getPastOutputIds().getShape().getLastDimension() + 1 >= config.getMaxSeqLength()) { break; } } return searchState .getPastOutputIds() .concat(searchState.getNextInputIds(), -1) .reshape(numBatch * numBeam, -1); } /** * Generates text using contrastive search. * * @param inputIds input token ids * @return the output token ids stored as NDArray * @throws TranslateException if forward failed * @see Contrastive Search */ @SuppressWarnings("try") public NDArray contrastiveSearch(NDArray inputIds) throws TranslateException { // inputIds: [batchSize, seqLength: t_init] // attentionMask: [batchSize, pastSeq]. seq-dim-size = |past_seq| + |inputIds|. // Initialize the end position of each sentence endPosition = new long[Math.toIntExact(inputIds.getShape().get(0))]; Arrays.fill(endPosition, config.getMaxSeqLength()); NDManager manager = inputIds.getManager(); NDArray attentionMask = prepareAttentionMaskOffset(inputIds, config); ContrastiveBatchTensorList searchState = new ContrastiveBatchTensorList(); while (true) { if (searchState.getPastKeyValues() == null) { NDList modelInput = prepareInput(inputIds, attentionMask, 0, 1); CausalLMOutput output = predictor.predict(modelInput); NDArray lastLogits = output.getLogits().get(":, -1, :"); searchState = new ContrastiveBatchTensorList( inputIds, attentionMask, output.getHiddenState(), lastLogits, output.getPastKeyValuesList(), new long[] {}); } /* Contrastive search loop main part */ // (1) candidate tokens recall; // (2) candidate re-rank by degeneration penalty try (NDScope ignore = new NDScope()) { NDArray topKIds = searchState .getLogits() .topK(config.getK(), -1, true, false) .get(1); // [batch, topK] // Generate model inputs and put candidates together into batch // [batch, topK] -> [batch * [topK]] -> [[batch * [topK]], seqLength=1] NDArray candidateInputIds = topKIds.flatten().reshape(-1, 1); assert candidateInputIds.getDataType() == DataType.INT64 : "inputIds datatype should be int64"; assert candidateInputIds.getShape().getShape().length == 2 : "shape not right"; // [batch, heads, seq_past, feature] -> [batch * topK, head, seq_past, feature] NDList kCopyPastKeyValues = new NDList( searchState.getPastKeyValues().stream() .map(ndarray -> ndarray.repeat(0, config.getK())) .collect(Collectors.toList())); assert kCopyPastKeyValues.get(0).getDataType() == DataType.FLOAT32 : "inputIds datatype should be Float32"; // [batch, seq_past] -> [batch * topK, seq_past] -> [batch * topK, seq_past + 1] long numBatch = topKIds.getShape().get(0); NDArray kCopyPastAttentionMask = searchState.getPastAttentionMask().repeat(0, config.getK()); kCopyPastAttentionMask = kCopyPastAttentionMask.concat( manager.ones( new Shape(numBatch * config.getK(), 1), DataType.INT64), 1); assert kCopyPastKeyValues.get(0).getShape().get(2) + 1 == kCopyPastAttentionMask.getShape().getLastDimension() : "attentionMask_seq = past_seq + new_input_seq"; // Forward with candidates in batch input NDList candidateModelInput = prepareInput( candidateInputIds, kCopyPastAttentionMask, searchState.getPastOutputIds().getShape().getLastDimension(), config.getK()); candidateModelInput.addAll(kCopyPastKeyValues); CausalLMOutput candidateOutput = predictor.predict(candidateModelInput); NDList generatedOutput = StepGeneration.constrastiveStepGeneration( topKIds, searchState.getLogits(), searchState.getPastHiddenStates(), candidateOutput.getHiddenState(), positionOffset, config.getAlpha()); // Update searchState for next loop searchState = updateSearchState(searchState, candidateOutput, generatedOutput, manager); // Memory NDScope.unregister( searchState.getPastOutputIds(), searchState.getPastAttentionMask(), searchState.getLogits(), searchState.getPastHiddenStates()); NDScope.unregister(searchState.getPastKeyValues()); } // Termination Criteria long[] outputIdsArray = searchState.getPastOutputIds().toLongArray(); for (int i = 0; i < endPosition.length; ++i) { for (long tokenId : outputIdsArray) { if (tokenId == config.getEosTokenId()) { endPosition[i] = searchState.getPastOutputIds().getShape().get(1); break; } } } if (searchState.getPastOutputIds().getShape().get(1) >= config.getMaxSeqLength()) { break; } } return searchState.getPastOutputIds(); } private static BeamBatchTensorList updateSearchState( BeamBatchTensorList searchState, CausalLMOutput modelOutput, NDList generatedOutput, NDManager manager) { NDList pastKeyValues = searchState.getPastKeyValues(); long numHeads = pastKeyValues.get(0).getShape().get(2); long kvDim = pastKeyValues.get(0).getShape().getLastDimension(); long numBatch = searchState.getPastOutputIds().getShape().get(0); long numBeam = searchState.getPastOutputIds().getShape().get(1); long pastSeqLength = searchState.getPastOutputIds().getShape().getLastDimension(); NDArray nextInputIds = generatedOutput.get(0); assert nextInputIds.getShape().getShape().length == 3 : "Wrong Shape"; NDArray newProbs = generatedOutput.get(1); // [batch, beamNew] NDArray sourceBeamSelected = generatedOutput.get(2); // Act on [batch, beam, ...] dimension and the output will be [batch, beam, ...] NDIndex sourceBeamIndex = new NDIndex( "{}, {}, ...", manager.arange(0, numBatch, 1, DataType.INT64) .expandDims(1) .repeat(1, numBeam), sourceBeamSelected); // A simple concatenation is not enough. During the beam selection process, some source // beams are selected several times while some source beams are not selected even once. // The pastOutput should be reselected to have the right correspondence to the // newInputIds. NDArray pastOutputIds = searchState .getPastOutputIds() .concat(searchState.getNextInputIds(), -1) .get(sourceBeamIndex); Function fn = ndarray -> ndarray.reshape(numBatch, numBeam, numHeads, pastSeqLength + 1, kvDim) .get(sourceBeamIndex); pastKeyValues = new NDList( modelOutput.getPastKeyValuesList().stream() .map(fn) .collect(Collectors.toList())); NDArray pastAttentionMask = searchState .getPastAttentionMask() .concat(manager.ones(new Shape(numBatch, numBeam, 1), DataType.INT64), -1) .get(sourceBeamIndex); return new BeamBatchTensorList( nextInputIds, pastOutputIds, pastKeyValues, pastAttentionMask, newProbs); } private static ContrastiveBatchTensorList updateSearchState( ContrastiveBatchTensorList searchState, CausalLMOutput candidateOutput, NDList generatedOutput, NDManager manager) { // Update searchState for next iteration assert candidateOutput.getLogits().getShape().get(1) == 1 : "dimension check: here, outputLogits corresponds to inputSeq == 1"; long numBatch = searchState.getLogits().getShape().get(0); long logitsDim = searchState.getLogits().getShape().get(1); long pastSeqLengthPriorUpdate = searchState.getPastOutputIds().getShape().get(1); long numHeads = searchState.getPastKeyValues().get(0).getShape().get(1); long kvDim = searchState.getPastKeyValues().get(0).getShape().get(3); long hiddenDim = searchState.getPastHiddenStates().getShape().get(2); long k = candidateOutput.getLogits().getShape().get(0) / numBatch; // [batch, 1] NDArray select = generatedOutput.get(1); NDIndex selectIndex = new NDIndex( "{}, {}, ...", manager.arange(0, numBatch, 1, DataType.INT64), select.flatten()); // Take from candidateOutput // [batch, k, inputSeq=1, logitsDim] --select--> [batch, logitDim] NDArray nextLogits = candidateOutput.getLogits().reshape(numBatch, k, logitsDim).get(selectIndex); // Take from candidateOutput // [batch * k, heads, seq_past, feature] --select--> [batch, heads, seq_past, feature] Function fn = ndarray -> ndarray.reshape(numBatch, k, numHeads, pastSeqLengthPriorUpdate + 1, kvDim) .get(selectIndex); NDList nextPastKeyValue = new NDList( candidateOutput.getPastKeyValuesList().stream() .map(fn) .collect(Collectors.toList())); // To be concatenated into searchState.pastHiddenStates // [batch * k, inputSeq=1, hiddenDim] NDArray newHiddenState = candidateOutput.getHiddenState(); assert newHiddenState.getManager() == manager : "possible leaky memory"; NDArray nextPastHiddenStates = searchState .getPastHiddenStates() .concat( newHiddenState.reshape(numBatch, k, 1, hiddenDim).get(selectIndex), 1); // To be concatenated into searchState.outputIds // [batch, seq_past] NDArray outputIds = generatedOutput.get(0); NDArray nextOutputIds = searchState.getPastOutputIds().concat(outputIds, 1); // [batch, seq_past] NDArray nextPastAttentionMask = searchState .getPastAttentionMask() .concat(manager.ones(new Shape(numBatch, 1), DataType.INT64), 1); return new ContrastiveBatchTensorList( nextOutputIds, nextPastAttentionMask, nextPastHiddenStates, nextLogits, nextPastKeyValue, new long[] {}); } private NDArray prepareAttentionMaskOffset(NDArray inputIds, SearchConfig config) { // prepare attentionMask and positionOffset // Used to initialize the search boolean suffixPadding = config.isSuffixPadding(); NDManager manager = inputIds.getManager(); int numBatch = Math.toIntExact(inputIds.getShape().get(0)); int initSeqSize = Math.toIntExact(inputIds.getShape().get(1)); NDArray attentionMask = manager.ones(new Shape(1, inputIds.getShape().getLastDimension()), DataType.INT64) .reshape(1, -1) .repeat(0, numBatch); // Linear search from left to find the first position that's not padTokenId. long[][] offset = new long[numBatch][1]; for (int i = 0; i < numBatch; i++) { long[] aSequence = inputIds.get("{},:", i).toLongArray(); int idx = 0; while (idx < initSeqSize) { if (suffixPadding && aSequence[idx] == config.getPadTokenId() || !suffixPadding && aSequence[idx] != config.getPadTokenId()) { break; } idx++; } attentionMask.set( new NDIndex( "{},{}:{}", i, suffixPadding ? idx : 0, suffixPadding ? initSeqSize : idx), 0); if (!suffixPadding) { offset[i][0] = idx; } } positionOffset = manager.create(offset); return attentionMask; } private NDList prepareInput( NDArray inputIds, NDArray attentionMask, long pastSeqLength, int repeat) { // Pack the model input NDArray positionIds = inputIds.getManager() .arange( pastSeqLength, pastSeqLength + inputIds.getShape().getLastDimension(), 1, DataType.INT64) .expandDims(0) .repeat(0, inputIds.getShape().get(0)); NDArray positionIdsShifted = positionIds.subi(positionOffset.repeat(0, repeat)); positionIds = positionIdsShifted.maximum(positionIdsShifted.zerosLike()); return new NDList(inputIds, positionIds, attentionMask); } /** * Generate function call to generate text. * * @param inputIds the input token ids * @return generated token ids * @throws TranslateException if prediction fails */ public NDArray generate(NDArray inputIds) throws TranslateException { switch (searchName) { case "greedy": return greedySearch(inputIds); case "beam": return beamSearch(inputIds); case "contrastive": return contrastiveSearch(inputIds); default: throw new IllegalArgumentException( "searchName not correctly specified. Please choose among: {greedy, beam," + " contrastive}"); } } /** * Returns the value of the positionOffset. * * @return the value of positionOffset */ public NDArray getPositionOffset() { return positionOffset; } /** * Returns the end position of each sentence induced by EOS tokenId or reaching maxSeqLength. * * @return the end position of each sentence */ public long[] getEndPosition() { return endPosition; } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy