/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.iterator;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.deeplearning4j.iterator.bert.BertMaskedLMMasker;
import org.deeplearning4j.iterator.bert.BertSequenceMasker;
import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer;
import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Triple;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
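/**
 * A MultiDataSetIterator for BERT-style models. Two modes of operation are supported, set via {@link Task}:
 * unsupervised training of a masked language model, and supervised sequence classification.
 * Sequence lengths are controlled via {@link LengthHandling}, and the returned feature arrays via {@link FeatureArrays}.
 * A minimal usage sketch for sequence classification follows; the sentence provider and the maximum length of 128
 * are assumptions, and the tokenizer factory construction is elided:
 * <pre>{@code
 * BertWordPieceTokenizerFactory tf = ...; //Typically constructed from a BERT vocabulary file
 * BertIterator iter = BertIterator.builder()
 *         .tokenizer(tf)
 *         .vocabMap(tf.getVocab())
 *         .task(BertIterator.Task.SEQ_CLASSIFICATION)
 *         .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 128)
 *         .minibatchSize(32)
 *         .sentenceProvider(mySentenceProvider) //Hypothetical: any LabeledSentenceProvider implementation
 *         .build();
 * }</pre>
 */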
public class BertIterator implements MultiDataSetIterator {
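/**
 * Task the iterator is set up for: UNSUPERVISED (masked language model pretraining; requires a
 * {@link BertSequenceMasker}, a mask token, and an {@link UnsupervisedLabelFormat}), or
 * SEQ_CLASSIFICATION (supervised sequence classification with one-hot class labels).
 */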
public enum Task {UNSUPERVISED, SEQ_CLASSIFICATION}
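/**
 * How sequence lengths are handled: FIXED_LENGTH (output arrays always have maxTokens columns),
 * ANY_LENGTH (output arrays are sized to the longest sequence in the minibatch), or CLIP_ONLY
 * (as ANY_LENGTH, but sequences longer than maxTokens are truncated).
 */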
public enum LengthHandling {FIXED_LENGTH, ANY_LENGTH, CLIP_ONLY}
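/**
 * Which feature arrays the iterator should return: INDICES_MASK (token indices plus feature mask), or
 * INDICES_MASK_SEGMENTID (token indices, feature mask, and a segment ID array, as used for sentence pair tasks).
 */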
public enum FeatureArrays {INDICES_MASK, INDICES_MASK_SEGMENTID}
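/**
 * Format of the label array for UNSUPERVISED (masked LM) training: RANK2_IDX (integer token indices,
 * shape [minibatch, seqLength]), RANK3_NCL (one-hot, shape [minibatch, vocabSize, seqLength]), or
 * RANK3_LNC (one-hot, shape [seqLength, minibatch, vocabSize]).
 */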
public enum UnsupervisedLabelFormat {RANK2_IDX, RANK3_NCL, RANK3_LNC}
protected Task task;
protected TokenizerFactory tokenizerFactory;
protected int maxTokens = -1;
protected int minibatchSize = 32;
protected boolean padMinibatches = false;
@Getter
@Setter
protected MultiDataSetPreProcessor preProcessor;
protected LabeledSentenceProvider sentenceProvider = null;
protected LabeledPairSentenceProvider sentencePairProvider = null;
protected LengthHandling lengthHandling;
protected FeatureArrays featureArrays;
protected Map<String, Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects?
protected BertSequenceMasker masker = null;
protected UnsupervisedLabelFormat unsupervisedLabelFormat = null;
protected String maskToken;
protected String prependToken;
protected String appendToken;
protected List<String> vocabKeysAsList;
protected BertIterator(Builder b) {
this.task = b.task;
this.tokenizerFactory = b.tokenizerFactory;
this.maxTokens = b.maxTokens;
this.minibatchSize = b.minibatchSize;
this.padMinibatches = b.padMinibatches;
this.preProcessor = b.preProcessor;
this.sentenceProvider = b.sentenceProvider;
this.sentencePairProvider = b.sentencePairProvider;
this.lengthHandling = b.lengthHandling;
this.featureArrays = b.featureArrays;
this.vocabMap = b.vocabMap;
this.masker = b.masker;
this.unsupervisedLabelFormat = b.unsupervisedLabelFormat;
this.maskToken = b.maskToken;
this.prependToken = b.prependToken;
this.appendToken = b.appendToken;
}
@Override
public boolean hasNext() {
if (sentenceProvider != null)
return sentenceProvider.hasNext();
return sentencePairProvider.hasNext();
}
@Override
public MultiDataSet next() {
return next(minibatchSize);
}
@Override
public void remove() {
throw new UnsupportedOperationException("Not supported");
}
@Override
public MultiDataSet next(int num) {
Preconditions.checkState(hasNext(), "No next element available");
List<Pair<List<String>, String>> tokensAndLabelList;
int mbSize = 0;
int outLength;
long[] segIdOnesFrom = null;
if (sentenceProvider != null) {
List<Pair<String, String>> list = new ArrayList<>(num);
while (sentenceProvider.hasNext() && mbSize++ < num) {
list.add(sentenceProvider.nextSentence());
}
SentenceListProcessed sentenceListProcessed = tokenizeMiniBatch(list);
tokensAndLabelList = sentenceListProcessed.getTokensAndLabelList();
outLength = sentenceListProcessed.getMaxL();
} else if (sentencePairProvider != null) {
List<Triple<String, String, String>> listPairs = new ArrayList<>(num);
while (sentencePairProvider.hasNext() && mbSize++ < num) {
listPairs.add(sentencePairProvider.nextSentencePair());
}
SentencePairListProcessed sentencePairListProcessed = tokenizePairsMiniBatch(listPairs);
tokensAndLabelList = sentencePairListProcessed.getTokensAndLabelList();
outLength = sentencePairListProcessed.getMaxL();
segIdOnesFrom = sentencePairListProcessed.getSegIdOnesFrom();
} else {
//TODO - other types of iterators...
throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented");
}
Pair<INDArray[], INDArray[]> featuresAndMaskArraysPair = convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
INDArray[] featureArray = featuresAndMaskArraysPair.getFirst();
INDArray[] featureMaskArray = featuresAndMaskArraysPair.getSecond();
Pair<INDArray[], INDArray[]> labelsAndMaskArraysPair = convertMiniBatchLabels(tokensAndLabelList, featureArray, outLength);
INDArray[] labelArray = labelsAndMaskArraysPair.getFirst();
INDArray[] labelMaskArray = labelsAndMaskArraysPair.getSecond();
org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(featureArray, labelArray, featureMaskArray, labelMaskArray);
if (preProcessor != null)
preProcessor.preProcess(mds);
return mds;
}
/**
* For use during inference. Will convert a given list of sentences to features and feature masks as appropriate.
*
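* A minimal usage sketch ("iter" here is assumed to be a fully configured BertIterator):
* <pre>{@code
* Pair<INDArray[], INDArray[]> fm = iter.featurizeSentences(Arrays.asList("some sentence", "another sentence"));
* INDArray[] features = fm.getFirst();
* INDArray[] featureMasks = fm.getSecond();
* }</pre>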
* @param listOnlySentences Sentences to featurize
* @return Pair of INDArray[]: the first element is the feature arrays and the second is the feature mask arrays
*/
public Pair<INDArray[], INDArray[]> featurizeSentences(List<String> listOnlySentences) {
List<Pair<String, String>> sentencesWithNullLabel = addDummyLabel(listOnlySentences);
SentenceListProcessed sentenceListProcessed = tokenizeMiniBatch(sentencesWithNullLabel);
List<Pair<List<String>, String>> tokensAndLabelList = sentenceListProcessed.getTokensAndLabelList();
int outLength = sentenceListProcessed.getMaxL();
if (preProcessor != null) {
Pair<INDArray[], INDArray[]> featureFeatureMasks = convertMiniBatchFeatures(tokensAndLabelList, outLength, null);
MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featureFeatureMasks.getFirst(), null, featureFeatureMasks.getSecond(), null);
preProcessor.preProcess(dummyMDS);
return new Pair<>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
}
return convertMiniBatchFeatures(tokensAndLabelList, outLength, null);
}
/**
* For use during inference. Will convert a given list of sentence pairs to features and feature masks as appropriate.
*
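* A minimal usage sketch ("iter" is assumed to be a BertIterator configured with a sentence pair provider):
* <pre>{@code
* List<Pair<String, String>> pairs = Arrays.asList(new Pair<>("first of pair", "second of pair"));
* Pair<INDArray[], INDArray[]> fm = iter.featurizeSentencePairs(pairs);
* }</pre>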
* @param listOnlySentencePairs Sentence pairs to featurize
* @return Pair of INDArray[]: the first element is the feature arrays and the second is the feature mask arrays
*/
public Pair<INDArray[], INDArray[]> featurizeSentencePairs(List<Pair<String, String>> listOnlySentencePairs) {
Preconditions.checkState(sentencePairProvider != null, "The featurizeSentencePairs method is meant for inference with sentence pairs. Use only when the sentence pair provider is set (i.e., not null).");
List<Triple<String, String, String>> sentencePairsWithNullLabel = addDummyLabelForPairs(listOnlySentencePairs);
SentencePairListProcessed sentencePairListProcessed = tokenizePairsMiniBatch(sentencePairsWithNullLabel);
List<Pair<List<String>, String>> tokensAndLabelList = sentencePairListProcessed.getTokensAndLabelList();
int outLength = sentencePairListProcessed.getMaxL();
long[] segIdOnesFrom = sentencePairListProcessed.getSegIdOnesFrom();
if (preProcessor != null) {
Pair featuresAndMaskArraysPair = convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
MultiDataSet dummyMDS = new org.nd4j.linalg.dataset.MultiDataSet(featuresAndMaskArraysPair.getFirst(), null, featuresAndMaskArraysPair.getSecond(), null);
preProcessor.preProcess(dummyMDS);
return new Pair<>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
}
return convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
}
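//Convert tokenized sentences to a [minibatch, outLength] token index array plus a matching feature mask array.
//If segment IDs are requested, positions from segIdOnesFrom[i] onwards get segment ID 1 (second sentence of a pair)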
private Pair<INDArray[], INDArray[]> convertMiniBatchFeatures(List<Pair<List<String>, String>> tokensAndLabelList, int outLength, long[] segIdOnesFrom) {
int mbPadded = padMinibatches ? minibatchSize : tokensAndLabelList.size();
int[][] outIdxs = new int[mbPadded][outLength];
int[][] outMask = new int[mbPadded][outLength];
int[][] outSegmentId = null;
if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID)
outSegmentId = new int[mbPadded][outLength];
for (int i = 0; i < tokensAndLabelList.size(); i++) {
Pair<List<String>, String> p = tokensAndLabelList.get(i);
List<String> t = p.getFirst();
for (int j = 0; j < outLength && j < t.size(); j++) {
Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encountered: token \"%s\" is not in vocabulary", t.get(j));
int idx = vocabMap.get(t.get(j));
outIdxs[i][j] = idx;
outMask[i][j] = 1;
if (segIdOnesFrom != null && j >= segIdOnesFrom[i])
outSegmentId[i][j] = 1;
}
}
//Create actual arrays. Indices, mask, and optional segment ID
INDArray outIdxsArr = Nd4j.createFromArray(outIdxs);
INDArray outMaskArr = Nd4j.createFromArray(outMask);
INDArray outSegmentIdArr;
INDArray[] f;
INDArray[] fm;
if (featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID) {
outSegmentIdArr = Nd4j.createFromArray(outSegmentId);
f = new INDArray[]{outIdxsArr, outSegmentIdArr};
fm = new INDArray[]{outMaskArr, null};
} else {
f = new INDArray[]{outIdxsArr};
fm = new INDArray[]{outMaskArr};
}
return new Pair<>(f, fm);
}
private SentenceListProcessed tokenizeMiniBatch(List<Pair<String, String>> list) {
//Get and tokenize the sentences for this minibatch
SentenceListProcessed sentenceListProcessed = new SentenceListProcessed(list.size());
int longestSeq = -1;
for (Pair<String, String> p : list) {
List<String> tokens = tokenizeSentence(p.getFirst());
sentenceListProcessed.addProcessedToList(new Pair<>(tokens, p.getSecond()));
longestSeq = Math.max(longestSeq, tokens.size());
}
//Determine output array length...
int outLength;
switch (lengthHandling) {
case FIXED_LENGTH:
outLength = maxTokens;
break;
case ANY_LENGTH:
outLength = longestSeq;
break;
case CLIP_ONLY:
outLength = Math.min(maxTokens, longestSeq);
break;
default:
throw new RuntimeException("Not implemented length handling mode: " + lengthHandling);
}
sentenceListProcessed.setMaxL(outLength);
return sentenceListProcessed;
}
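//Tokenize each sentence pair to a single sequence of the form: [prepend] + left + [append] + right + [append].
//If the combined length exceeds the budget (maxTokens minus special tokens): when both halves exceed half the
//budget, clip both; otherwise clip only the longer one (matching the truncation strategy of the TF BERT implementation)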
private SentencePairListProcessed tokenizePairsMiniBatch(List<Triple<String, String, String>> listPairs) {
SentencePairListProcessed sentencePairListProcessed = new SentencePairListProcessed(listPairs.size());
for (Triple<String, String, String> t : listPairs) {
List<String> tokensL = tokenizeSentence(t.getFirst(), true);
List<String> tokensR = tokenizeSentence(t.getSecond(), true);
List<String> tokens = new ArrayList<>(maxTokens);
int maxLength = maxTokens;
if (prependToken != null)
maxLength--;
if (appendToken != null)
maxLength -= 2;
if (tokensL.size() + tokensR.size() > maxLength) {
boolean shortOnL = tokensL.size() < tokensR.size();
int shortSize = Math.min(tokensL.size(), tokensR.size());
if (shortSize > maxLength / 2) {
//both lists need to be sliced
tokensL.subList(maxLength / 2, tokensL.size()).clear(); //if maxsize/2 is odd pop extra on L side to match implementation in TF
tokensR.subList(maxLength - maxLength / 2, tokensR.size()).clear();
} else {
//slice longer list
if (shortOnL) {
//longer on R - slice R
tokensR.subList(maxLength - tokensL.size(), tokensR.size()).clear();
} else {
//longer on L - slice L
tokensL.subList(maxLength - tokensR.size(), tokensL.size()).clear();
}
}
}
if (prependToken != null)
tokens.add(prependToken);
tokens.addAll(tokensL);
if (appendToken != null)
tokens.add(appendToken);
int segIdOnesFrom = tokens.size();
tokens.addAll(tokensR);
if (appendToken != null)
tokens.add(appendToken);
sentencePairListProcessed.addProcessedToList(segIdOnesFrom, new Pair<>(tokens, t.getThird()));
}
sentencePairListProcessed.setMaxL(maxTokens);
return sentencePairListProcessed;
}
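//Create label and label mask arrays. For SEQ_CLASSIFICATION: one-hot class labels, shape [minibatch, numClasses].
//For UNSUPERVISED: apply the masker, write the masked-token targets in the configured UnsupervisedLabelFormat,
//and overwrite the corresponding positions of the feature index array with the masked tokens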
private Pair<INDArray[], INDArray[]> convertMiniBatchLabels(List<Pair<List<String>, String>> tokenizedSentences, INDArray[] featureArray, int outLength) {
INDArray[] l = new INDArray[1];
INDArray[] lm;
int mbSize = tokenizedSentences.size();
int mbPadded = padMinibatches ? minibatchSize : tokenizedSentences.size();
if (task == Task.SEQ_CLASSIFICATION) {
//Sequence classification task: output is 2d, one-hot, shape [minibatch, numClasses]
int numClasses;
int[] classLabels = new int[mbPadded];
if (sentenceProvider != null) {
numClasses = sentenceProvider.numLabelClasses();
List<String> labels = sentenceProvider.allLabels();
for (int i = 0; i < mbSize; i++) {
String lbl = tokenizedSentences.get(i).getRight();
classLabels[i] = labels.indexOf(lbl);
Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
}
} else if (sentencePairProvider != null) {
numClasses = sentencePairProvider.numLabelClasses();
List<String> labels = sentencePairProvider.allLabels();
for (int i = 0; i < mbSize; i++) {
String lbl = tokenizedSentences.get(i).getRight();
classLabels[i] = labels.indexOf(lbl);
Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
}
} else {
throw new RuntimeException("Expected either a sentence provider or a sentence pair provider to be set");
}
l[0] = Nd4j.create(DataType.FLOAT, mbPadded, numClasses);
for (int i = 0; i < mbSize; i++) {
l[0].putScalar(i, classLabels[i], 1.0);
}
lm = null;
if (padMinibatches && mbSize != mbPadded) {
INDArray a = Nd4j.zeros(DataType.FLOAT, mbPadded, 1);
lm = new INDArray[]{a};
a.get(NDArrayIndex.interval(0, mbSize), NDArrayIndex.all()).assign(1);
}
} else if (task == Task.UNSUPERVISED) {
//Unsupervised, masked language model task
//Output is either 2d, or 3d depending on settings
if (vocabKeysAsList == null) {
String[] arr = new String[vocabMap.size()];
for (Map.Entry<String, Integer> e : vocabMap.entrySet()) {
arr[e.getValue()] = e.getKey();
}
vocabKeysAsList = Arrays.asList(arr);
}
int vocabSize = vocabMap.size();
INDArray labelArr;
INDArray lMask = Nd4j.zeros(DataType.INT, mbPadded, outLength);
if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) {
labelArr = Nd4j.create(DataType.INT, mbPadded, outLength);
} else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL) {
labelArr = Nd4j.create(DataType.FLOAT, mbPadded, vocabSize, outLength);
} else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) {
labelArr = Nd4j.create(DataType.FLOAT, outLength, mbPadded, vocabSize);
} else {
throw new IllegalStateException("Unknown unsupervised label format: " + unsupervisedLabelFormat);
}
for (int i = 0; i < mbSize; i++) {
List<String> tokens = tokenizedSentences.get(i).getFirst();
Pair<List<String>, boolean[]> p = masker.maskSequence(tokens, maskToken, vocabKeysAsList);
List<String> maskedTokens = p.getFirst();
boolean[] predictionTarget = p.getSecond();
int seqLen = Math.min(predictionTarget.length, outLength);
for (int j = 0; j < seqLen; j++) {
if (predictionTarget[j]) {
String oldToken = tokenizedSentences.get(i).getFirst().get(j); //This is target
int targetTokenIdx = vocabMap.get(oldToken);
if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) {
labelArr.putScalar(i, j, targetTokenIdx);
} else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL) {
labelArr.putScalar(i, targetTokenIdx, j, 1.0); //One-hot over the vocab (dimension 1) at sequence position j, per the [minibatch, vocabSize, seqLength] shape
} else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) {
labelArr.putScalar(j, i, targetTokenIdx, 1.0);
}
lMask.putScalar(i, j, 1.0);
//Also update previously created feature label indexes:
String newToken = maskedTokens.get(j);
int newTokenIdx = vocabMap.get(newToken);
//first element of features is outIdxsArr
featureArray[0].putScalar(i, j, newTokenIdx);
}
}
}
l[0] = labelArr;
lm = new INDArray[1];
lm[0] = lMask;
} else {
throw new IllegalStateException("Task not yet implemented: " + task);
}
return new Pair<>(l, lm);
}
private List<String> tokenizeSentence(String sentence) {
return tokenizeSentence(sentence, false);
}
private List<String> tokenizeSentence(String sentence, boolean ignorePrependAppend) {
Tokenizer t = tokenizerFactory.create(sentence);
List<String> tokens = new ArrayList<>();
if (prependToken != null && !ignorePrependAppend)
tokens.add(prependToken);
while (t.hasMoreTokens()) {
String token = t.nextToken();
tokens.add(token);
}
if (appendToken != null && !ignorePrependAppend)
tokens.add(appendToken);
return tokens;
}
private List<Pair<String, String>> addDummyLabel(List<String> listOnlySentences) {
List<Pair<String, String>> list = new ArrayList<>(listOnlySentences.size());
for (String s : listOnlySentences) {
list.add(new Pair<>(s, null));
}
return list;
}
private List<Triple<String, String, String>> addDummyLabelForPairs(List<Pair<String, String>> listOnlySentencePairs) {
List<Triple<String, String, String>> list = new ArrayList<>(listOnlySentencePairs.size());
for (Pair<String, String> p : listOnlySentencePairs) {
list.add(new Triple<>(p.getFirst(), p.getSecond(), null));
}
return list;
}
@Override
public boolean resetSupported() {
return true;
}
@Override
public boolean asyncSupported() {
return true;
}
@Override
public void reset() {
if (sentenceProvider != null) {
sentenceProvider.reset();
} else if (sentencePairProvider != null) {
sentencePairProvider.reset();
}
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
protected Task task;
protected TokenizerFactory tokenizerFactory;
protected LengthHandling lengthHandling = LengthHandling.FIXED_LENGTH;
protected int maxTokens = -1;
protected int minibatchSize = 32;
protected boolean padMinibatches = false;
protected MultiDataSetPreProcessor preProcessor;
protected LabeledSentenceProvider sentenceProvider = null;
protected LabeledPairSentenceProvider sentencePairProvider = null;
protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID;
protected Map<String, Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects?
protected BertSequenceMasker masker = new BertMaskedLMMasker();
protected UnsupervisedLabelFormat unsupervisedLabelFormat;
protected String maskToken;
protected String prependToken;
protected String appendToken;
/**
* Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details.
*/
public Builder task(Task task) {
this.task = task;
return this;
}
/**
* Specify the TokenizerFactory to use.
* For BERT, typically {@link org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory}
* is used
*/
public Builder tokenizer(TokenizerFactory tokenizerFactory) {
this.tokenizerFactory = tokenizerFactory;
return this;
}
/**
* Specifies how the sequence length of the output data should be handled. See {@link BertIterator} for more details.
*
* @param lengthHandling Length handling
* @param maxLength Maximum sequence length. Not used if lengthHandling is set to {@link LengthHandling#ANY_LENGTH}
* @return This builder
*/
public Builder lengthHandling(@NonNull LengthHandling lengthHandling, int maxLength) {
this.lengthHandling = lengthHandling;
this.maxTokens = maxLength;
return this;
}
/**
* Minibatch size to use (number of examples to train on for each iteration)
* See also: {@link #padMinibatches}
*
* @param minibatchSize Minibatch size
*/
public Builder minibatchSize(int minibatchSize) {
this.minibatchSize = minibatchSize;
return this;
}
/**
* Default: false (disabled)
* If the dataset is not an exact multiple of the minibatch size, should we pad the smaller final minibatch?
* For example, if we have 100 examples total, and 32 minibatch size, the following number of examples will be returned
* for successive calls of next() within one epoch:
* padMinibatches = false (default): 32, 32, 32, 4.
* padMinibatches = true: 32, 32, 32, 32 (note: the last minibatch will have 4 real examples, and 28 masked out padding examples).
* Both options should result in exactly the same model. However, some BERT implementations may require all
* minibatches to be exactly the same size to function.
*/
public Builder padMinibatches(boolean padMinibatches) {
this.padMinibatches = padMinibatches;
return this;
}
/**
* Set the preprocessor to be used on the MultiDataSets before returning them. Default: none (null)
*/
public Builder preProcessor(MultiDataSetPreProcessor preProcessor) {
this.preProcessor = preProcessor;
return this;
}
/**
* Specify the source of the data for classification.
*/
public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider) {
this.sentenceProvider = sentenceProvider;
return this;
}
/**
* Specify the source of the data for classification on sentence pairs.
*/
public Builder sentencePairProvider(LabeledPairSentenceProvider sentencePairProvider) {
this.sentencePairProvider = sentencePairProvider;
return this;
}
/**
* Specify what arrays should be returned. See {@link BertIterator} for more details.
*/
public Builder featureArrays(FeatureArrays featureArrays) {
this.featureArrays = featureArrays;
return this;
}
/**
* Provide the vocabulary as a map. Keys are the words in the vocabulary, and values are the indices of those
* words. Indices should be in the range 0 to vocabMap.size()-1, inclusive.
* If using {@link BertWordPieceTokenizerFactory},
* this can be obtained using {@link BertWordPieceTokenizerFactory#getVocab()}
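* A sketch of typical usage (the tokenizer factory construction is elided):
* <pre>{@code
* BertWordPieceTokenizerFactory tf = ...; //Typically constructed from a BERT vocabulary file
* builder.vocabMap(tf.getVocab());
* }</pre>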
*/
public Builder vocabMap(Map<String, Integer> vocabMap) {
this.vocabMap = vocabMap;
return this;
}
/**
* Used only for unsupervised training (i.e., when task is set to {@link Task#UNSUPERVISED} for learning a
* masked language model). This can be used to customize how the masking is performed.
* Default: {@link BertMaskedLMMasker}
*/
public Builder masker(BertSequenceMasker masker) {
this.masker = masker;
return this;
}
/**
* Used only for unsupervised training (i.e., when task is set to {@link Task#UNSUPERVISED} for learning a
* masked language model). Used to specify the format that the labels should be returned in.
* See {@link BertIterator} for more details.
*/
public Builder unsupervisedLabelFormat(UnsupervisedLabelFormat labelFormat) {
this.unsupervisedLabelFormat = labelFormat;
return this;
}
/**
* Used only for unsupervised training (i.e., when task is set to {@link Task#UNSUPERVISED} for learning a
* masked language model). This specifies the token (such as "[MASK]") that should be used when a value is masked out.
* Note that this is passed to the {@link BertSequenceMasker} defined by {@link #masker(BertSequenceMasker)}, hence
* the exact behaviour will depend on what masker is used.
* Note that this must be in the vocabulary map set in {@link #vocabMap(Map)}
*/
public Builder maskToken(String maskToken) {
this.maskToken = maskToken;
return this;
}
/**
* Prepend the specified token to the sequences, when doing supervised training.
* That is, any token sequence will have this token added at the start.
* Some BERT/Transformer models may need this - for example sequences starting with a "[CLS]" token.
* No token is prepended by default.
*
* @param prependToken The token to start each sequence with (null: no token will be prepended)
*/
public Builder prependToken(String prependToken) {
this.prependToken = prependToken;
return this;
}
/**
* Append the specified token to the sequences, when doing training on sentence pairs.
* Generally "[SEP]" is used
* No token in appended by default.
*
* @param appendToken Token at end of each sentence for pairs of sentences (null: no token will be appended)
* @return This builder
*/
public Builder appendToken(String appendToken) {
this.appendToken = appendToken;
return this;
}
public BertIterator build() {
Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
Preconditions.checkState(vocabMap != null, "Cannot create iterator: No vocabMap has been set. Use Builder.vocabMap(Map) to set");
Preconditions.checkState(task != Task.UNSUPERVISED || masker != null, "If task is UNSUPERVISED training, a masker must be set via masker(BertSequenceMasker) method");
Preconditions.checkState(task != Task.UNSUPERVISED || unsupervisedLabelFormat != null, "If task is UNSUPERVISED training, a label format must be set via the unsupervisedLabelFormat(UnsupervisedLabelFormat) method");
Preconditions.checkState(task != Task.UNSUPERVISED || maskToken != null, "If task is UNSUPERVISED training, the mask token in the vocab (such as \"[MASK]\") must be specified");
if (sentencePairProvider != null) {
Preconditions.checkState(task == Task.SEQ_CLASSIFICATION, "Currently only supervised sequence classification is set up with sentence pairs. \".task(BertIterator.Task.SEQ_CLASSIFICATION)\" is required with a sentence pair provider");
Preconditions.checkState(featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID, "Currently only supervised sequence classification is set up with sentence pairs. \".featureArrays(FeatureArrays.INDICES_MASK_SEGMENTID)\" is required with a sentence pair provider");
Preconditions.checkState(lengthHandling == LengthHandling.FIXED_LENGTH, "Currently only fixed length is supported for sentence pairs. \".lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxLength)\" is required with a sentence pair provider");
Preconditions.checkState(sentencePairProvider != null, "Provide either a sentence provider or a sentence pair provider. Both cannot be non null");
}
if (appendToken != null) {
Preconditions.checkState(sentencePairProvider != null, "Tokens are only appended with sentence pairs. Sentence pair provider is not set. Set sentence pair provider.");
}
return new BertIterator(this);
}
}
private static class SentencePairListProcessed {
private int listLength = 0;
@Getter
private long[] segIdOnesFrom;
private int cursor = 0;
private SentenceListProcessed sentenceListProcessed;
private SentencePairListProcessed(int listLength) {
this.listLength = listLength;
segIdOnesFrom = new long[listLength];
sentenceListProcessed = new SentenceListProcessed(listLength);
}
private void addProcessedToList(long segIdIdx, Pair<List<String>, String> tokenizedSentencePairAndLabel) {
segIdOnesFrom[cursor] = segIdIdx;
sentenceListProcessed.addProcessedToList(tokenizedSentencePairAndLabel);
cursor++;
}
private void setMaxL(int maxL) {
sentenceListProcessed.setMaxL(maxL);
}
private int getMaxL() {
return sentenceListProcessed.getMaxL();
}
private List<Pair<List<String>, String>> getTokensAndLabelList() {
return sentenceListProcessed.getTokensAndLabelList();
}
}
private static class SentenceListProcessed {
private int listLength;
@Getter
@Setter
private int maxL;
@Getter
private List<Pair<List<String>, String>> tokensAndLabelList;
private SentenceListProcessed(int listLength) {
this.listLength = listLength;
tokensAndLabelList = new ArrayList<>(listLength);
}
private void addProcessedToList(Pair<List<String>, String> tokenizedSentenceAndLabel) {
tokensAndLabelList.add(tokenizedSentenceAndLabel);
}
}
}