/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.iterator;

import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.deeplearning4j.iterator.bert.BertMaskedLMMasker;
import org.deeplearning4j.iterator.bert.BertSequenceMasker;
import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer;
import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Triple;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
 * BertIterator is a MultiDataSetIterator for training BERT (Transformer) models in the following way:
 * (a) Unsupervised - Masked language model task (no sentence matching task is implemented thus far)
 * (b) Supervised - For sequence classification (i.e., 1 label per sequence, typically used for fine tuning)
 * The task can be specified using {@link Task}.
 *
 * Example for unsupervised training:
 *
 * {@code
 *          BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab);
 *          BertIterator b = BertIterator.builder()
 *              .tokenizer(t)
 *              .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
 *              .minibatchSize(2)
 *              .sentenceProvider(new TestSentenceProvider())
 *              .featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
 *              .vocabMap(t.getVocab())
 *              .task(BertIterator.Task.UNSUPERVISED)
 *              .masker(new BertMaskedLMMasker(new Random(12345), 0.2, 0.5, 0.5))
 *              .unsupervisedLabelFormat(BertIterator.UnsupervisedLabelFormat.RANK2_IDX)
 *              .maskToken("[MASK]")
 *              .build();
 * }
 *
 * Example for supervised (sequence classification - one label per sequence) training:
 *
 * {@code
 *          BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab);
 *          BertIterator b = BertIterator.builder()
 *              .tokenizer(t)
 *              .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
 *              .minibatchSize(2)
 *              .sentenceProvider(new TestSentenceProvider())
 *              .featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
 *              .vocabMap(t.getVocab())
 *              .task(BertIterator.Task.SEQ_CLASSIFICATION)
 *              .build();
 * }
 *
 * Example to use an instantiated iterator for inference:
 *
 * {@code
 *          BertIterator b;
 *          Pair<INDArray[], INDArray[]> featuresAndMask;
 *          INDArray[] features;
 *          INDArray[] featureMasks;
 *
 *          //With sentences
 *          List<String> forInference;
 *          featuresAndMask = b.featurizeSentences(forInference);
 *
 *          //OR with sentence pairs
 *          List<Pair<String, String>> forInferencePair;
 *          featuresAndMask = b.featurizeSentencePairs(forInferencePair);
 *
 *          features = featuresAndMask.getFirst();
 *          featureMasks = featuresAndMask.getSecond();
 * }
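 *
 * The featurized arrays could then be fed to a trained network. A minimal sketch (assuming a suitably configured
 * {@code ComputationGraph} named {@code net} whose inputs match the configured {@link FeatureArrays}; the name
 * {@code net} is illustrative, not part of this class):
 * {@code
 *          INDArray[] output = net.output(false, features, featureMasks);
 * }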
 *
 * This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.
 *
 * {@link LengthHandling} configuration:
 * Determines how to handle variable-length sequence situations.
 * FIXED_LENGTH: Always trim longer sequences to the specified length, and always pad shorter sequences to the specified length.
 * ANY_LENGTH: Output length is determined by the length of the longest sequence in the minibatch. Shorter sequences within the
 * minibatch are zero padded and masked.
 * CLIP_ONLY: For any sequences longer than the specified maximum, clip them. If the maximum sequence length in a minibatch is
 * shorter than the specified maximum, no padding will occur. Sequences that are shorter than the maximum (within the current
 * minibatch) are zero padded and masked. A configuration sketch for all three modes follows this list.
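 *
 * A configuration sketch for the three modes (assuming a {@code BertIterator.Builder} named {@code builder}; the
 * length value 128 is an illustrative choice only):
 * {@code
 *          builder.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 128);  //Always pad/trim to exactly 128 tokens
 *          builder.lengthHandling(BertIterator.LengthHandling.ANY_LENGTH, -1);     //Use longest sequence in each minibatch; max length is unused
 *          builder.lengthHandling(BertIterator.LengthHandling.CLIP_ONLY, 128);     //Clip at 128 tokens, but never pad beyond the longest sequence
 * }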
*
 * {@link FeatureArrays} configuration:
 * Determines what arrays should be included.
 * INDICES_MASK: Indices array and mask array only, no segment ID array. Returns 1 feature array, 1 feature mask array (plus labels).
 * INDICES_MASK_SEGMENTID: Indices array, mask array and segment ID array (which is all 0s for single segment tasks). Returns
 * 2 feature arrays (indices, segment ID) and 1 feature mask array (plus labels). A sketch of the resulting array counts
 * follows this list.
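 *
 * A sketch of the resulting array counts (assuming an iterator {@code b} built as in the examples above, with the
 * indicated {@link FeatureArrays} setting):
 * {@code
 *          MultiDataSet mds = b.next();
 *          //INDICES_MASK:           mds.getFeatures().length == 1, mds.getFeaturesMaskArrays().length == 1
 *          //INDICES_MASK_SEGMENTID: mds.getFeatures().length == 2, mds.getFeaturesMaskArrays().length == 2 (segment ID mask entry is null)
 * }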
*
 * {@link UnsupervisedLabelFormat} configuration:
 * Only relevant when the task is set to {@link Task#UNSUPERVISED}. Determines the format of the labels:
 * RANK2_IDX: return int32 [minibatch, numTokens] array with entries being class numbers. Example use case: with sparse softmax loss functions.
 * RANK3_NCL: return float32 [minibatch, numClasses, numTokens] array with 1-hot entries along dimension 1. Example use case: RnnOutputLayer, RnnLossLayer
 * RANK3_LNC: return float32 [numTokens, minibatch, numClasses] array with 1-hot entries along dimension 2. This format is
 * occasionally used for some RNN layers in libraries such as TensorFlow, for example. A worked shape example follows this list.
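 *
 * For example (illustrative numbers only): with a minibatch size of 2, 16 output tokens and a 30,522-token vocabulary,
 * the label array shapes per format would be:
 * {@code
 *          // RANK2_IDX:  [2, 16]           (int32, entries are token class indices)
 *          // RANK3_NCL:  [2, 30522, 16]    (float32, one-hot along dimension 1)
 *          // RANK3_LNC:  [16, 2, 30522]    (float32, one-hot along dimension 2)
 * }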
*
 */
public class BertIterator implements MultiDataSetIterator {

    public enum Task {UNSUPERVISED, SEQ_CLASSIFICATION}

    public enum LengthHandling {FIXED_LENGTH, ANY_LENGTH, CLIP_ONLY}

    public enum FeatureArrays {INDICES_MASK, INDICES_MASK_SEGMENTID}

    public enum UnsupervisedLabelFormat {RANK2_IDX, RANK3_NCL, RANK3_LNC}

    protected Task task;
    protected TokenizerFactory tokenizerFactory;
    protected int maxTokens = -1;
    protected int minibatchSize = 32;
    protected boolean padMinibatches = false;
    @Getter
    @Setter
    protected MultiDataSetPreProcessor preProcessor;
    protected LabeledSentenceProvider sentenceProvider = null;
    protected LabeledPairSentenceProvider sentencePairProvider = null;
    protected LengthHandling lengthHandling;
    protected FeatureArrays featureArrays;
    protected Map<String, Integer> vocabMap;    //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects?
    protected BertSequenceMasker masker = null;
    protected UnsupervisedLabelFormat unsupervisedLabelFormat = null;
    protected String maskToken;
    protected String prependToken;
    protected String appendToken;

    protected List<String> vocabKeysAsList;

    protected BertIterator(Builder b) {
        this.task = b.task;
        this.tokenizerFactory = b.tokenizerFactory;
        this.maxTokens = b.maxTokens;
        this.minibatchSize = b.minibatchSize;
        this.padMinibatches = b.padMinibatches;
        this.preProcessor = b.preProcessor;
        this.sentenceProvider = b.sentenceProvider;
        this.sentencePairProvider = b.sentencePairProvider;
        this.lengthHandling = b.lengthHandling;
        this.featureArrays = b.featureArrays;
        this.vocabMap = b.vocabMap;
        this.masker = b.masker;
        this.unsupervisedLabelFormat = b.unsupervisedLabelFormat;
        this.maskToken = b.maskToken;
        this.prependToken = b.prependToken;
        this.appendToken = b.appendToken;
    }

    @Override
    public boolean hasNext() {
        if (sentenceProvider != null)
            return sentenceProvider.hasNext();
        return sentencePairProvider.hasNext();
    }

    @Override
    public MultiDataSet next() {
        return next(minibatchSize);
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public MultiDataSet next(int num) {
        Preconditions.checkState(hasNext(), "No next element available");
        List<Pair<List<String>, String>> tokensAndLabelList;
        int mbSize = 0;
        int outLength;
        long[] segIdOnesFrom = null;
        if (sentenceProvider != null) {
            List<Pair<String, String>> list = new ArrayList<>(num);
            while (sentenceProvider.hasNext() && mbSize++ < num) {
                list.add(sentenceProvider.nextSentence());
            }
            SentenceListProcessed sentenceListProcessed = tokenizeMiniBatch(list);
            tokensAndLabelList = sentenceListProcessed.getTokensAndLabelList();
            outLength = sentenceListProcessed.getMaxL();
        } else if (sentencePairProvider != null) {
            List<Triple<String, String, String>> listPairs = new ArrayList<>(num);
            while (sentencePairProvider.hasNext() && mbSize++ < num) {
                listPairs.add(sentencePairProvider.nextSentencePair());
            }
            SentencePairListProcessed sentencePairListProcessed = tokenizePairsMiniBatch(listPairs);
            tokensAndLabelList = sentencePairListProcessed.getTokensAndLabelList();
            outLength = sentencePairListProcessed.getMaxL();
            segIdOnesFrom = sentencePairListProcessed.getSegIdOnesFrom();
        } else {
            //TODO - other types of iterators...
            throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented");
        }

        Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
        INDArray[] featureArray = featuresAndMaskArraysPair.getFirst();
        INDArray[] featureMaskArray = featuresAndMaskArraysPair.getSecond();

        Pair<INDArray[], INDArray[]> labelsAndMaskArraysPair = convertMiniBatchLabels(tokensAndLabelList, featureArray, outLength);
        INDArray[] labelArray = labelsAndMaskArraysPair.getFirst();
        INDArray[] labelMaskArray = labelsAndMaskArraysPair.getSecond();

        org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(featureArray, labelArray, featureMaskArray, labelMaskArray);
        if (preProcessor != null)
            preProcessor.preProcess(mds);

        return mds;
    }

    /**
     * For use during inference. Will convert a given list of sentences to features and feature masks as appropriate.
     *
     * @param listOnlySentences Sentences to featurize
     * @return Pair of INDArray[], first element is the feature arrays and the second is the mask arrays
     */
    public Pair<INDArray[], INDArray[]> featurizeSentences(List<String> listOnlySentences) {
        List<Pair<String, String>> sentencesWithNullLabel = addDummyLabel(listOnlySentences);
        SentenceListProcessed sentenceListProcessed = tokenizeMiniBatch(sentencesWithNullLabel);
        List<Pair<List<String>, String>> tokensAndLabelList = sentenceListProcessed.getTokensAndLabelList();
        int outLength = sentenceListProcessed.getMaxL();

        if (preProcessor != null) {
            Pair<INDArray[], INDArray[]> featureFeatureMasks = convertMiniBatchFeatures(tokensAndLabelList, outLength, null);
            MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featureFeatureMasks.getFirst(), null, featureFeatureMasks.getSecond(), null);
            preProcessor.preProcess(dummyMDS);
            return new Pair<>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
        }
        return convertMiniBatchFeatures(tokensAndLabelList, outLength, null);
    }

    /**
     * For use during inference. Will convert a given list of sentence pairs to features and feature masks as appropriate.
     *
     * @param listOnlySentencePairs Sentence pairs to featurize
     * @return Pair of INDArray[], first element is the feature arrays and the second is the mask arrays
     */
    public Pair<INDArray[], INDArray[]> featurizeSentencePairs(List<Pair<String, String>> listOnlySentencePairs) {
        Preconditions.checkState(sentencePairProvider != null, "The featurizeSentencePairs method is meant for inference with sentence pairs. Use only when the sentence pair provider is set (i.e. not null).");

        List<Triple<String, String, String>> sentencePairsWithNullLabel = addDummyLabelForPairs(listOnlySentencePairs);
        SentencePairListProcessed sentencePairListProcessed = tokenizePairsMiniBatch(sentencePairsWithNullLabel);
        List<Pair<List<String>, String>> tokensAndLabelList = sentencePairListProcessed.getTokensAndLabelList();
        int outLength = sentencePairListProcessed.getMaxL();
        long[] segIdOnesFrom = sentencePairListProcessed.getSegIdOnesFrom();
        if (preProcessor != null) {
            Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
            MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featuresAndMaskArraysPair.getFirst(), null, featuresAndMaskArraysPair.getSecond(), null);
            preProcessor.preProcess(dummyMDS);
            return new Pair<>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
        }
        return convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
    }

    private Pair<INDArray[], INDArray[]> convertMiniBatchFeatures(List<Pair<List<String>, String>> tokensAndLabelList, int outLength, long[] segIdOnesFrom) {
        int mbPadded = padMinibatches ? minibatchSize : tokensAndLabelList.size();
        int[][] outIdxs = new int[mbPadded][outLength];
        int[][] outMask = new int[mbPadded][outLength];
        int[][] outSegmentId = null;
        if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID)
            outSegmentId = new int[mbPadded][outLength];
        for (int i = 0; i < tokensAndLabelList.size(); i++) {
            Pair<List<String>, String> p = tokensAndLabelList.get(i);
            List<String> t = p.getFirst();
            for (int j = 0; j < outLength && j < t.size(); j++) {
                Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encountered: token \"%s\" is not in vocabulary", t.get(j));
                int idx = vocabMap.get(t.get(j));
                outIdxs[i][j] = idx;
                outMask[i][j] = 1;
                if (segIdOnesFrom != null && j >= segIdOnesFrom[i])
                    outSegmentId[i][j] = 1;
            }
        }

        //Create actual arrays. Indices, mask, and optional segment ID
        INDArray outIdxsArr = Nd4j.createFromArray(outIdxs);
        INDArray outMaskArr = Nd4j.createFromArray(outMask);
        INDArray outSegmentIdArr;
        INDArray[] f;
        INDArray[] fm;
        if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID) {
            outSegmentIdArr = Nd4j.createFromArray(outSegmentId);
            f = new INDArray[]{outIdxsArr, outSegmentIdArr};
            fm = new INDArray[]{outMaskArr, null};
        } else {
            f = new INDArray[]{outIdxsArr};
            fm = new INDArray[]{outMaskArr};
        }
        return new Pair<>(f, fm);
    }

    private SentenceListProcessed tokenizeMiniBatch(List<Pair<String, String>> list) {
        //Get and tokenize the sentences for this minibatch
        SentenceListProcessed sentenceListProcessed = new SentenceListProcessed(list.size());
        int longestSeq = -1;
        for (Pair<String, String> p : list) {
            List<String> tokens = tokenizeSentence(p.getFirst());
            sentenceListProcessed.addProcessedToList(new Pair<>(tokens, p.getSecond()));
            longestSeq = Math.max(longestSeq, tokens.size());
        }

        //Determine output array length...
        int outLength;
        switch (lengthHandling) {
            case FIXED_LENGTH:
                outLength = maxTokens;
                break;
            case ANY_LENGTH:
                outLength = longestSeq;
                break;
            case CLIP_ONLY:
                outLength = Math.min(maxTokens, longestSeq);
                break;
            default:
                throw new RuntimeException("Not implemented length handling mode: " + lengthHandling);
        }
        sentenceListProcessed.setMaxL(outLength);
        return sentenceListProcessed;
    }

    private SentencePairListProcessed tokenizePairsMiniBatch(List<Triple<String, String, String>> listPairs) {
        SentencePairListProcessed sentencePairListProcessed = new SentencePairListProcessed(listPairs.size());
        for (Triple<String, String, String> t : listPairs) {
            List<String> tokensL = tokenizeSentence(t.getFirst(), true);
            List<String> tokensR = tokenizeSentence(t.getSecond(), true);
            List<String> tokens = new ArrayList<>(maxTokens);
            //Reserve space for the special tokens: one prepended, and one appended per segment
            int maxLength = maxTokens;
            if (prependToken != null)
                maxLength--;
            if (appendToken != null)
                maxLength -= 2;
            if (tokensL.size() + tokensR.size() > maxLength) {
                boolean shortOnL = tokensL.size() < tokensR.size();
                int shortSize = Math.min(tokensL.size(), tokensR.size());
                if (shortSize > maxLength / 2) {
                    //both lists need to be sliced
                    tokensL.subList(maxLength / 2, tokensL.size()).clear();
                    //if maxsize/2 is odd pop extra on L side to match implementation in TF
                    tokensR.subList(maxLength - maxLength / 2, tokensR.size()).clear();
                } else {
                    //slice longer list
                    if (shortOnL) {
                        //longer on R - slice R
                        tokensR.subList(maxLength - tokensL.size(), tokensR.size()).clear();
                    } else {
                        //longer on L - slice L
                        tokensL.subList(maxLength - tokensR.size(), tokensL.size()).clear();
                    }
                }
            }
            if (prependToken != null)
                tokens.add(prependToken);
            tokens.addAll(tokensL);
            if (appendToken != null)
                tokens.add(appendToken);
            int segIdOnesFrom = tokens.size();
            tokens.addAll(tokensR);
            if (appendToken != null)
                tokens.add(appendToken);
            sentencePairListProcessed.addProcessedToList(segIdOnesFrom, new Pair<>(tokens, t.getThird()));
        }
        sentencePairListProcessed.setMaxL(maxTokens);
        return sentencePairListProcessed;
    }

    private Pair<INDArray[], INDArray[]> convertMiniBatchLabels(List<Pair<List<String>, String>> tokenizedSentences, INDArray[] featureArray, int outLength) {
        INDArray[] l = new INDArray[1];
        INDArray[] lm;
        int mbSize = tokenizedSentences.size();
        int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size();
        if (task == Task.SEQ_CLASSIFICATION) {
            //Sequence classification task: output is 2d, one-hot, shape [minibatch, numClasses]
            int numClasses;
            int[] classLabels = new int[mbPadded];
            if (sentenceProvider != null) {
                numClasses = sentenceProvider.numLabelClasses();
                List<String> labels = sentenceProvider.allLabels();
                for (int i = 0; i < mbSize; i++) {
                    String lbl = tokenizedSentences.get(i).getRight();
                    classLabels[i] = labels.indexOf(lbl);
                    Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
                }
            } else if (sentencePairProvider != null) {
                numClasses = sentencePairProvider.numLabelClasses();
                List<String> labels = sentencePairProvider.allLabels();
                for (int i = 0; i < mbSize; i++) {
                    String lbl = tokenizedSentences.get(i).getRight();
                    classLabels[i] = labels.indexOf(lbl);
                    Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
                }
            } else {
                throw new RuntimeException();
            }
            l[0] = Nd4j.create(DataType.FLOAT, mbPadded, numClasses);
            for (int i = 0; i < mbSize; i++) {
                l[0].putScalar(i, classLabels[i], 1.0);
            }
            lm = null;
            if (padMinibatches && mbSize != mbPadded) {
                //Mask out the padding examples in the label mask
                INDArray a = Nd4j.zeros(DataType.FLOAT, mbPadded, 1);
                lm = new INDArray[]{a};
                a.get(NDArrayIndex.interval(0, mbSize), NDArrayIndex.all()).assign(1);
            }
        } else if (task == Task.UNSUPERVISED) {
            //Unsupervised, masked language model task
            //Output is either 2d, or 3d depending on settings
            if (vocabKeysAsList == null) {
                String[] arr = new String[vocabMap.size()];
                for (Map.Entry<String, Integer> e : vocabMap.entrySet()) {
                    arr[e.getValue()] = e.getKey();
                }
                vocabKeysAsList = Arrays.asList(arr);
            }

            int vocabSize = vocabMap.size();
            INDArray labelArr;
            INDArray lMask = Nd4j.zeros(DataType.INT, mbPadded, outLength);
            if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) {
                labelArr = Nd4j.create(DataType.INT, mbPadded, outLength);
            } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL) {
                labelArr = Nd4j.create(DataType.FLOAT, mbPadded, vocabSize, outLength);
            } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) {
                labelArr = Nd4j.create(DataType.FLOAT, outLength, mbPadded, vocabSize);
            } else {
                throw new IllegalStateException("Unknown unsupervised label format: " + unsupervisedLabelFormat);
            }

            for (int i = 0; i < mbSize; i++) {
                List<String> tokens = tokenizedSentences.get(i).getFirst();
                Pair<List<String>, boolean[]> p = masker.maskSequence(tokens, maskToken, vocabKeysAsList);
                List<String> maskedTokens = p.getFirst();
                boolean[] predictionTarget = p.getSecond();
                int seqLen = Math.min(predictionTarget.length, outLength);
                for (int j = 0; j < seqLen; j++) {
                    if (predictionTarget[j]) {
                        String oldToken = tokenizedSentences.get(i).getFirst().get(j);      //This is the target
                        int targetTokenIdx = vocabMap.get(oldToken);
                        if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) {
                            labelArr.putScalar(i, j, targetTokenIdx);
                        } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL) {
                            //One-hot along the class (vocabulary) dimension: shape [minibatch, numClasses, numTokens]
                            labelArr.putScalar(i, targetTokenIdx, j, 1.0);
                        } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) {
                            labelArr.putScalar(j, i, targetTokenIdx, 1.0);
                        }

                        lMask.putScalar(i, j, 1.0);

                        //Also update previously created feature label indexes:
                        String newToken = maskedTokens.get(j);
                        int newTokenIdx = vocabMap.get(newToken);
                        //first element of features is outIdxsArr
                        featureArray[0].putScalar(i, j, newTokenIdx);
                    }
                }
            }
            l[0] = labelArr;
            lm = new INDArray[1];
            lm[0] = lMask;
        } else {
            throw new IllegalStateException("Task not yet implemented: " + task);
        }

        return new Pair<>(l, lm);
    }

    private List<String> tokenizeSentence(String sentence) {
        return tokenizeSentence(sentence, false);
    }

    private List<String> tokenizeSentence(String sentence, boolean ignorePrependAppend) {
        Tokenizer t = tokenizerFactory.create(sentence);

        List<String> tokens = new ArrayList<>();
        if (prependToken != null && !ignorePrependAppend)
            tokens.add(prependToken);

        while (t.hasMoreTokens()) {
            String token = t.nextToken();
            tokens.add(token);
        }
        if (appendToken != null && !ignorePrependAppend)
            tokens.add(appendToken);
        return tokens;
    }

    private List<Pair<String, String>> addDummyLabel(List<String> listOnlySentences) {
        List<Pair<String, String>> list = new ArrayList<>(listOnlySentences.size());
        for (String s : listOnlySentences) {
            list.add(new Pair<>(s, null));
        }
        return list;
    }

    private List<Triple<String, String, String>> addDummyLabelForPairs(List<Pair<String, String>> listOnlySentencePairs) {
        List<Triple<String, String, String>> list = new ArrayList<>(listOnlySentencePairs.size());
        for (Pair<String, String> p : listOnlySentencePairs) {
            list.add(new Triple<>(p.getFirst(), p.getSecond(), null));
        }
        return list;
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return true;
    }

    @Override
    public void reset() {
        if (sentenceProvider != null) {
            sentenceProvider.reset();
        }
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {

        protected Task task;
        protected TokenizerFactory tokenizerFactory;
        protected LengthHandling lengthHandling = LengthHandling.FIXED_LENGTH;
        protected int maxTokens = -1;
        protected int minibatchSize = 32;
        protected boolean padMinibatches = false;
        protected MultiDataSetPreProcessor preProcessor;
        protected LabeledSentenceProvider sentenceProvider = null;
        protected LabeledPairSentenceProvider sentencePairProvider = null;
        protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID;
        protected Map<String, Integer> vocabMap;    //TODO maybe use Eclipse ObjectIntHashMap for fewer objects?
        protected BertSequenceMasker masker = new BertMaskedLMMasker();
        protected UnsupervisedLabelFormat unsupervisedLabelFormat;
        protected String maskToken;
        protected String prependToken;
        protected String appendToken;

        /**
         * Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details.
         */
        public Builder task(Task task) {
            this.task = task;
            return this;
        }

        /**
         * Specify the TokenizerFactory to use.
         * For BERT, typically {@link org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory}
         * is used
         */
        public Builder tokenizer(TokenizerFactory tokenizerFactory) {
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        /**
         * Specifies how the sequence length of the output data should be handled. See {@link BertIterator} for more details.
         *
         * @param lengthHandling Length handling
         * @param maxLength      Not used if LengthHandling is set to {@link LengthHandling#ANY_LENGTH}
         */
        public Builder lengthHandling(@NonNull LengthHandling lengthHandling, int maxLength) {
            this.lengthHandling = lengthHandling;
            this.maxTokens = maxLength;
            return this;
        }

        /**
         * Minibatch size to use (number of examples to train on for each iteration)
         * See also: {@link #padMinibatches}
         *
         * @param minibatchSize Minibatch size
         */
        public Builder minibatchSize(int minibatchSize) {
            this.minibatchSize = minibatchSize;
            return this;
        }

        /**
         * Default: false (disabled)
         * If the dataset is not an exact multiple of the minibatch size, should we pad the smaller final minibatch?
         * For example, if we have 100 examples total, and 32 minibatch size, the following number of examples will be returned
         * for subsequent calls of next() in the one epoch:
         * padMinibatches = false (default): 32, 32, 32, 4.
         * padMinibatches = true: 32, 32, 32, 32 (note: the last minibatch will have 4 real examples, and 28 masked out padding examples).
         * Both options should result in exactly the same model. However, some BERT implementations may require an exact number
         * of examples in all minibatches to function.
         */
        public Builder padMinibatches(boolean padMinibatches) {
            this.padMinibatches = padMinibatches;
            return this;
        }

        /**
         * Set the preprocessor to be used on the MultiDataSets before returning them. Default: none (null)
         */
        public Builder preProcessor(MultiDataSetPreProcessor preProcessor) {
            this.preProcessor = preProcessor;
            return this;
        }

        /**
         * Specify the source of the data for classification.
         */
        public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider) {
            this.sentenceProvider = sentenceProvider;
            return this;
        }

        /**
         * Specify the source of the data for classification on sentence pairs.
         */
        public Builder sentencePairProvider(LabeledPairSentenceProvider sentencePairProvider) {
            this.sentencePairProvider = sentencePairProvider;
            return this;
        }

        /**
         * Specify what arrays should be returned. See {@link BertIterator} for more details.
         */
        public Builder featureArrays(FeatureArrays featureArrays) {
            this.featureArrays = featureArrays;
            return this;
        }

        /**
         * Provide the vocabulary as a map. Keys are the words in the vocabulary, and values are the indices of those
         * words. Indices should be in the range 0 to vocabMap.size()-1, inclusive.
         * If using {@link BertWordPieceTokenizerFactory},
         * this can be obtained using {@link BertWordPieceTokenizerFactory#getVocab()}
         */
        public Builder vocabMap(Map<String, Integer> vocabMap) {
            this.vocabMap = vocabMap;
            return this;
        }

        /**
         * Used only for unsupervised training (i.e., when task is set to {@link Task#UNSUPERVISED}) for learning a
         * masked language model. This can be used to customize how the masking is performed.
         * Default: {@link BertMaskedLMMasker}
         */
        public Builder masker(BertSequenceMasker masker) {
            this.masker = masker;
            return this;
        }

        /**
         * Used only for unsupervised training (i.e., when task is set to {@link Task#UNSUPERVISED}) for learning a
         * masked language model. Used to specify the format that the labels should be returned in.
         * See {@link BertIterator} for more details.
         */
        public Builder unsupervisedLabelFormat(UnsupervisedLabelFormat labelFormat) {
            this.unsupervisedLabelFormat = labelFormat;
            return this;
        }

        /**
         * Used only for unsupervised training (i.e., when task is set to {@link Task#UNSUPERVISED}) for learning a
         * masked language model. This specifies the token (such as "[MASK]") that should be used when a value is masked out.
         * Note that this is passed to the {@link BertSequenceMasker} defined by {@link #masker(BertSequenceMasker)} hence
         * the exact behaviour will depend on what masker is used.
         * Note that this must be in the vocabulary map set in {@link #vocabMap}
         */
        public Builder maskToken(String maskToken) {
            this.maskToken = maskToken;
            return this;
        }

        /**
         * Prepend the specified token to the sequences, when doing supervised training.
         * i.e., any token sequences will have this added at the start.
         * Some BERT/Transformer models may need this - for example sequences starting with a "[CLS]" token.
         * No token is prepended by default.
         *
         * @param prependToken The token to start each sequence with (null: no token will be prepended)
         */
        public Builder prependToken(String prependToken) {
            this.prependToken = prependToken;
            return this;
        }

        /**
         * Append the specified token to the sequences, when doing training on sentence pairs.
         * Generally "[SEP]" is used. No token is appended by default.
         *
         * @param appendToken Token at end of each sentence for pairs of sentences (null: no token will be appended)
         */
        public Builder appendToken(String appendToken) {
            this.appendToken = appendToken;
            return this;
        }

        public BertIterator build() {
            Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
            Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
            Preconditions.checkState(vocabMap != null, "Cannot create iterator: No vocabMap has been set. Use Builder.vocabMap(Map<String,Integer>) to set");
            Preconditions.checkState(task != Task.UNSUPERVISED || masker != null, "If task is UNSUPERVISED training, a masker must be set via the masker(BertSequenceMasker) method");
            Preconditions.checkState(task != Task.UNSUPERVISED || unsupervisedLabelFormat != null, "If task is UNSUPERVISED training, a label format must be set via the unsupervisedLabelFormat(UnsupervisedLabelFormat) method");
            Preconditions.checkState(task != Task.UNSUPERVISED || maskToken != null, "If task is UNSUPERVISED training, the mask token in the vocab (such as \"[MASK]\") must be specified");
            if (sentencePairProvider != null) {
                Preconditions.checkState(task == Task.SEQ_CLASSIFICATION, "Currently only supervised sequence classification is set up with sentence pairs. \".task(BertIterator.Task.SEQ_CLASSIFICATION)\" is required with a sentence pair provider");
                Preconditions.checkState(featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID, "Currently only supervised sequence classification is set up with sentence pairs. \".featureArrays(FeatureArrays.INDICES_MASK_SEGMENTID)\" is required with a sentence pair provider");
                Preconditions.checkState(lengthHandling == LengthHandling.FIXED_LENGTH, "Currently only fixed length is supported for sentence pairs. \".lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxLength)\" is required with a sentence pair provider");
                Preconditions.checkState(sentenceProvider == null, "Provide either a sentence provider or a sentence pair provider. Both cannot be non null");
            }
            if (appendToken != null) {
                Preconditions.checkState(sentencePairProvider != null, "Tokens are only appended with sentence pairs. Sentence pair provider is not set. Set sentence pair provider.");
            }
            return new BertIterator(this);
        }
    }

    private static class SentencePairListProcessed {
        private int listLength = 0;

        @Getter
        private long[] segIdOnesFrom;
        private int cursor = 0;
        private SentenceListProcessed sentenceListProcessed;

        private SentencePairListProcessed(int listLength) {
            this.listLength = listLength;
            segIdOnesFrom = new long[listLength];
            sentenceListProcessed = new SentenceListProcessed(listLength);
        }

        private void addProcessedToList(long segIdIdx, Pair<List<String>, String> tokenizedSentencePairAndLabel) {
            segIdOnesFrom[cursor] = segIdIdx;
            sentenceListProcessed.addProcessedToList(tokenizedSentencePairAndLabel);
            cursor++;
        }

        private void setMaxL(int maxL) {
            sentenceListProcessed.setMaxL(maxL);
        }

        private int getMaxL() {
            return sentenceListProcessed.getMaxL();
        }

        private List<Pair<List<String>, String>> getTokensAndLabelList() {
            return sentenceListProcessed.getTokensAndLabelList();
        }
    }

    private static class SentenceListProcessed {
        private int listLength;

        @Getter
        @Setter
        private int maxL;

        @Getter
        private List<Pair<List<String>, String>> tokensAndLabelList;

        private SentenceListProcessed(int listLength) {
            this.listLength = listLength;
            tokensAndLabelList = new ArrayList<>(listLength);
        }

        private void addProcessedToList(Pair<List<String>, String> tokenizedSentenceAndLabel) {
            tokensAndLabelList.add(tokenizedSentenceAndLabel);
        }
    }
}