/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.iterator;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.deeplearning4j.iterator.bert.BertMaskedLMMasker;
import org.deeplearning4j.iterator.bert.BertSequenceMasker;
import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer;
import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
/**
* BertIterator is a MultiDataSetIterator for training BERT (Transformer) models in the following way:
* (a) Unsupervised - Masked language model task (no sentence matching task is implemented thus far)
* (b) Supervised - For sequence classification (i.e., 1 label per sequence, typically used for fine tuning)
* The task can be specified using {@link Task}.
*
* Example for unsupervised training:
*
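 * A minimal configuration sketch follows. The tokenizer factory construction and the sentence provider are
 * placeholders (any LabeledSentenceProvider can be used; its labels are ignored for unsupervised training), and the
 * masker and label format shown are just one possible choice:
 * <pre>
 * {@code
 * BertWordPieceTokenizerFactory t = ...;        //Placeholder: built from the model's WordPiece vocabulary file
 * BertIterator iter = BertIterator.builder()
 *     .tokenizer(t)
 *     .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 128)
 *     .minibatchSize(32)
 *     .sentenceProvider(sentenceProvider)       //Placeholder: a LabeledSentenceProvider for the training text
 *     .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
 *     .vocabMap(t.getVocab())
 *     .task(BertIterator.Task.UNSUPERVISED)
 *     .masker(new BertMaskedLMMasker())
 *     .unsupervisedLabelFormat(BertIterator.UnsupervisedLabelFormat.RANK2_IDX)
 *     .maskToken("[MASK]")
 *     .build();
 * }
 * </pre>
 *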
* This iterator supports numerous ways of configuring the behaviour with respect to the sequence lengths and data layout.
*
* {@link LengthHandling} configuration:
* Determines how to handle variable-length sequence situations.
* FIXED_LENGTH: Always trim longer sequences to the specified length, and always pad shorter sequences to the specified length.
* ANY_LENGTH: Output length is determined by the length of the longest sequence in the minibatch. Shorter sequences within the
* minibatch are zero padded and masked.
 * CLIP_ONLY: Sequences longer than the specified maximum are clipped to that length. If the longest sequence in a
 * minibatch is shorter than the specified maximum, no padding to the maximum occurs; sequences shorter than the
 * longest in the minibatch are zero padded and masked.
*
* {@link FeatureArrays} configuration:
* Determines what arrays should be included.
* INDICES_MASK: Indices array and mask array only, no segment ID array. Returns 1 feature array, 1 feature mask array (plus labels).
* INDICES_MASK_SEGMENTID: Indices array, mask array and segment ID array (which is all 0s for single segment tasks). Returns
* 2 feature arrays (indices, segment ID) and 1 feature mask array (plus labels)
*
* {@link UnsupervisedLabelFormat} configuration:
 * Only relevant when the task is set to {@link Task#UNSUPERVISED}. Determines the format of the labels:
* RANK2_IDX: return int32 [minibatch, numTokens] array with entries being class numbers. Example use case: with sparse softmax loss functions.
* RANK3_NCL: return float32 [minibatch, numClasses, numTokens] array with 1-hot entries along dimension 1. Example use case: RnnOutputLayer, RnnLossLayer
 * RANK3_LNC: return float32 [numTokens, minibatch, numClasses] array with 1-hot entries along dimension 2. This format is
 * occasionally used by some RNN layers in libraries such as TensorFlow.
*
*/
public class BertIterator implements MultiDataSetIterator {
public enum Task {UNSUPERVISED, SEQ_CLASSIFICATION}
public enum LengthHandling {FIXED_LENGTH, ANY_LENGTH, CLIP_ONLY}
public enum FeatureArrays {INDICES_MASK, INDICES_MASK_SEGMENTID}
public enum UnsupervisedLabelFormat {RANK2_IDX, RANK3_NCL, RANK3_LNC}
protected Task task;
protected TokenizerFactory tokenizerFactory;
protected int maxTokens = -1;
protected int minibatchSize = 32;
protected boolean padMinibatches = false;
@Getter @Setter
protected MultiDataSetPreProcessor preProcessor;
protected LabeledSentenceProvider sentenceProvider = null;
protected LengthHandling lengthHandling;
protected FeatureArrays featureArrays;
    protected Map<String,Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap or similar for fewer objects?
protected BertSequenceMasker masker = null;
protected UnsupervisedLabelFormat unsupervisedLabelFormat = null;
protected String maskToken;
protected String prependToken;
    protected List<String> vocabKeysAsList;
protected BertIterator(Builder b){
this.task = b.task;
this.tokenizerFactory = b.tokenizerFactory;
this.maxTokens = b.maxTokens;
this.minibatchSize = b.minibatchSize;
this.padMinibatches = b.padMinibatches;
this.preProcessor = b.preProcessor;
this.sentenceProvider = b.sentenceProvider;
this.lengthHandling = b.lengthHandling;
this.featureArrays = b.featureArrays;
this.vocabMap = b.vocabMap;
this.masker = b.masker;
this.unsupervisedLabelFormat = b.unsupervisedLabelFormat;
this.maskToken = b.maskToken;
this.prependToken = b.prependToken;
}
@Override
public boolean hasNext() {
return sentenceProvider.hasNext();
}
@Override
public MultiDataSet next() {
return next(minibatchSize);
}
@Override
public void remove() {
throw new UnsupportedOperationException("Not supported");
}
@Override
public MultiDataSet next(int num) {
Preconditions.checkState(hasNext(), "No next element available");
        List<Pair<String,String>> list = new ArrayList<>(num);
int count = 0;
if(sentenceProvider != null){
while(sentenceProvider.hasNext() && count++ < num) {
list.add(sentenceProvider.nextSentence());
}
} else {
//TODO - other types of iterators...
throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented");
}
//Get and tokenize the sentences for this minibatch
        List<Pair<List<String>, String>> tokenizedSentences = new ArrayList<>(num);
int longestSeq = -1;
        for(Pair<String,String> p : list){
            List<String> tokens = tokenizeSentence(p.getFirst());
tokenizedSentences.add(new Pair<>(tokens, p.getSecond()));
longestSeq = Math.max(longestSeq, tokens.size());
}
//Determine output array length...
int outLength;
switch (lengthHandling){
case FIXED_LENGTH:
outLength = maxTokens;
break;
case ANY_LENGTH:
outLength = longestSeq;
break;
case CLIP_ONLY:
outLength = Math.min(maxTokens, longestSeq);
break;
default:
throw new RuntimeException("Not implemented length handling mode: " + lengthHandling);
}
int mb = tokenizedSentences.size();
int mbPadded = padMinibatches ? minibatchSize : mb;
int[][] outIdxs = new int[mbPadded][outLength];
int[][] outMask = new int[mbPadded][outLength];
        for( int i=0; i<tokenizedSentences.size(); i++ ){
            Pair<List<String>,String> p = tokenizedSentences.get(i);
            List<String> t = p.getFirst();
            for( int j=0; j<outLength && j<t.size(); j++ ){
                Preconditions.checkState(vocabMap.containsKey(t.get(j)), "Unknown token encountered: token \"%s\" is not in vocabulary", t.get(j));
                outIdxs[i][j] = vocabMap.get(t.get(j));
                outMask[i][j] = 1;
            }
        }
        //Create the actual arrays: token indices, mask, and (optionally) segment ID
        INDArray outIdxsArr = Nd4j.createFromArray(outIdxs);
        INDArray outMaskArr = Nd4j.createFromArray(outMask);
        INDArray[] f;
        INDArray[] fm;
        if(featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID){
            //Segment ID is all 0s - only single-segment input is currently supported
            INDArray outSegmentIdArr = Nd4j.zeros(DataType.INT, mbPadded, outLength);
            f = new INDArray[]{outIdxsArr, outSegmentIdArr};
            fm = new INDArray[]{outMaskArr, null};
        } else {
            f = new INDArray[]{outIdxsArr};
            fm = new INDArray[]{outMaskArr};
        }
        INDArray[] l = new INDArray[1];
        INDArray[] lm = null;
        if(task == Task.SEQ_CLASSIFICATION){
            //Sequence classification task: labels are one-hot, shape [minibatch, numClasses]
            int numClasses;
            int[] classLabels = new int[mbPadded];
            if(sentenceProvider != null){
                numClasses = sentenceProvider.numLabelClasses();
                List<String> labels = sentenceProvider.allLabels();
                for(int i=0; i<mb; i++){
                    String lbl = tokenizedSentences.get(i).getSecond();
                    classLabels[i] = labels.indexOf(lbl);
                    Preconditions.checkState(classLabels[i] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
}
} else {
throw new RuntimeException();
}
l[0] = Nd4j.create(DataType.FLOAT, mbPadded, numClasses);
            for( int i=0; i<mb; i++ ){
                l[0].putScalar(i, classLabels[i], 1.0);
            }
            if(padMinibatches && mb != mbPadded){
                //Mask out the padding examples in the label mask
                INDArray lMaskPad = Nd4j.zeros(DataType.INT, mbPadded, 1);
                lMaskPad.get(NDArrayIndex.interval(0, mb), NDArrayIndex.all()).assign(1);
                lm = new INDArray[]{lMaskPad};
            }
        } else if(task == Task.UNSUPERVISED){
            //Unsupervised masked language model task: labels are the original (unmasked) tokens
            //Build an index-to-token lookup, as required by the masker, if not already created
            if(vocabKeysAsList == null){
                String[] arr = new String[vocabMap.size()];
                for(Map.Entry<String,Integer> e : vocabMap.entrySet()){
arr[e.getValue()] = e.getKey();
}
vocabKeysAsList = Arrays.asList(arr);
}
int vocabSize = vocabMap.size();
INDArray labelArr;
INDArray lMask = Nd4j.zeros(DataType.INT, mbPadded, outLength);
if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX){
labelArr = Nd4j.create(DataType.INT, mbPadded, outLength);
} else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL){
labelArr = Nd4j.create(DataType.FLOAT, mbPadded, vocabSize, outLength);
} else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC){
labelArr = Nd4j.create(DataType.FLOAT, outLength, mbPadded, vocabSize);
} else {
throw new IllegalStateException("Unknown unsupervised label format: " + unsupervisedLabelFormat);
}
            for( int i=0; i<mb; i++ ){
                List<String> tokens = tokenizedSentences.get(i).getFirst();
                Pair<List<String>,boolean[]> p = masker.maskSequence(tokens, maskToken, vocabKeysAsList);
                List<String> maskedTokens = p.getFirst();
boolean[] predictionTarget = p.getSecond();
int seqLen = Math.min(predictionTarget.length, outLength);
                for(int j=0; j<seqLen; j++ ){
                    if(predictionTarget[j]){
                        //The original (pre-masking) token is the prediction target
                        String oldToken = tokenizedSentences.get(i).getFirst().get(j);
                        int targetTokenIdx = vocabMap.get(oldToken);
                        if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX){
                            labelArr.putScalar(i, j, targetTokenIdx);
                        } else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL){
                            labelArr.putScalar(i, targetTokenIdx, j, 1.0);
                        } else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC){
                            labelArr.putScalar(j, i, targetTokenIdx, 1.0);
                        }
                        lMask.putScalar(i, j, 1.0);
                        //Also update the feature array with the masked/replaced token
                        String newToken = maskedTokens.get(j);
                        int newTokenIdx = vocabMap.get(newToken);
                        outIdxsArr.putScalar(i, j, newTokenIdx);
                    }
                }
            }
            l[0] = labelArr;
            lm = new INDArray[]{lMask};
        } else {
            throw new IllegalStateException("Task not yet implemented: " + task);
        }
        org.nd4j.linalg.dataset.MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(f, l, fm, lm);
        if(preProcessor != null)
            preProcessor.preProcess(mds);
        return mds;
    }
    private List<String> tokenizeSentence(String sentence) {
Tokenizer t = tokenizerFactory.create(sentence);
        List<String> tokens = new ArrayList<>();
if(prependToken != null)
tokens.add(prependToken);
while (t.hasMoreTokens()) {
String token = t.nextToken();
tokens.add(token);
}
return tokens;
}
@Override
public boolean resetSupported() {
return true;
}
@Override
public boolean asyncSupported() {
return true;
}
@Override
public void reset() {
if(sentenceProvider != null){
sentenceProvider.reset();
}
}
public static Builder builder(){
return new Builder();
}
public static class Builder {
protected Task task;
protected TokenizerFactory tokenizerFactory;
protected LengthHandling lengthHandling = LengthHandling.FIXED_LENGTH;
protected int maxTokens = -1;
protected int minibatchSize = 32;
protected boolean padMinibatches = false;
protected MultiDataSetPreProcessor preProcessor;
protected LabeledSentenceProvider sentenceProvider = null;
protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID;
        protected Map<String,Integer> vocabMap; //TODO maybe use Eclipse ObjectIntHashMap for fewer objects?
protected BertSequenceMasker masker = new BertMaskedLMMasker();
protected UnsupervisedLabelFormat unsupervisedLabelFormat;
protected String maskToken;
protected String prependToken;
/**
* Specify the {@link Task} the iterator should be set up for. See {@link BertIterator} for more details.
*/
public Builder task(Task task){
this.task = task;
return this;
}
/**
* Specify the TokenizerFactory to use.
* For BERT, typically {@link org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory}
* is used
*/
public Builder tokenizer(TokenizerFactory tokenizerFactory){
this.tokenizerFactory = tokenizerFactory;
return this;
}
/**
* Specifies how the sequence length of the output data should be handled. See {@link BertIterator} for more details.
* @param lengthHandling Length handling
         * @param maxLength      Maximum sequence length, in tokens; not used if lengthHandling is set to {@link LengthHandling#ANY_LENGTH}
         * @return This builder
*/
public Builder lengthHandling(@NonNull LengthHandling lengthHandling, int maxLength){
this.lengthHandling = lengthHandling;
this.maxTokens = maxLength;
return this;
}
/**
* Minibatch size to use (number of examples to train on for each iteration)
* See also: {@link #padMinibatches}
* @param minibatchSize Minibatch size
*/
public Builder minibatchSize(int minibatchSize){
this.minibatchSize = minibatchSize;
return this;
}
/**
* Default: false (disabled)
* If the dataset is not an exact multiple of the minibatch size, should we pad the smaller final minibatch?
* For example, if we have 100 examples total, and 32 minibatch size, the following number of examples will be returned
* for subsequent calls of next() in the one epoch:
* padMinibatches = false (default): 32, 32, 32, 4.
* padMinibatches = true: 32, 32, 32, 32 (note: the last minibatch will have 4 real examples, and 28 masked out padding examples).
         * Both options should result in exactly the same model. However, some BERT implementations may require an exact
         * number of examples in all minibatches to function.
*/
public Builder padMinibatches(boolean padMinibatches){
this.padMinibatches = padMinibatches;
return this;
}
/**
* Set the preprocessor to be used on the MultiDataSets before returning them. Default: none (null)
*/
public Builder preProcessor(MultiDataSetPreProcessor preProcessor){
this.preProcessor = preProcessor;
return this;
}
/**
* Specify the source of the data for classification. Can also be used for unsupervised learning; in the unsupervised
* use case, the labels will be ignored.
*/
public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider){
this.sentenceProvider = sentenceProvider;
return this;
}
/**
* Specify what arrays should be returned. See {@link BertIterator} for more details.
*/
public Builder featureArrays(FeatureArrays featureArrays){
this.featureArrays = featureArrays;
return this;
}
/**
* Provide the vocabulary as a map. Keys are the words in the vocabulary, and values are the indices of those
         * words. Indices should be in the range 0 to vocabMap.size()-1, inclusive.
* If using {@link BertWordPieceTokenizerFactory},
* this can be obtained using {@link BertWordPieceTokenizerFactory#getVocab()}
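         * A brief sketch, assuming the tokenizer factory has already been constructed from the model's WordPiece
         * vocabulary file (that construction is a placeholder here):
         * <pre>
         * {@code
         * BertWordPieceTokenizerFactory t = ...;    //Placeholder: built from the vocabulary file
         * BertIterator.Builder b = BertIterator.builder().vocabMap(t.getVocab());
         * }
         * </pre>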
*/
        public Builder vocabMap(Map<String,Integer> vocabMap){
this.vocabMap = vocabMap;
return this;
}
/**
         * Used only for unsupervised training (i.e., when the task is set to {@link Task#UNSUPERVISED}) for learning a
         * masked language model. This can be used to customize how the masking is performed.
* Default: {@link BertMaskedLMMasker}
*/
public Builder masker(BertSequenceMasker masker){
this.masker = masker;
return this;
}
/**
         * Used only for unsupervised training (i.e., when the task is set to {@link Task#UNSUPERVISED}) for learning a
         * masked language model. Used to specify the format that the labels should be returned in.
* See {@link BertIterator} for more details.
*/
public Builder unsupervisedLabelFormat(UnsupervisedLabelFormat labelFormat){
this.unsupervisedLabelFormat = labelFormat;
return this;
}
/**
         * Used only for unsupervised training (i.e., when the task is set to {@link Task#UNSUPERVISED}) for learning a
         * masked language model. This specifies the token (such as "[MASK]") that should be used when a value is masked out.
         * Note that this is passed to the {@link BertSequenceMasker} defined by {@link #masker(BertSequenceMasker)}, hence
         * the exact behaviour will depend on what masker is used.
* Note that this must be in the vocabulary map set in {@link #vocabMap}
*/
public Builder maskToken(String maskToken){
this.maskToken = maskToken;
return this;
}
/**
* Prepend the specified token to the sequences, when doing supervised training.
* i.e., any token sequences will have this added at the start.
* Some BERT/Transformer models may need this - for example sequences starting with a "[CLS]" token.
* No token is prepended by default.
*
* @param prependToken The token to start each sequence with (null: no token will be prepended)
*/
public Builder prependToken(String prependToken){
this.prependToken = prependToken;
return this;
}
public BertIterator build(){
Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
Preconditions.checkState(vocabMap != null, "Cannot create iterator: No vocabMap has been set. Use Builder.vocabMap(Map) to set");
Preconditions.checkState(task != Task.UNSUPERVISED || masker != null, "If task is UNSUPERVISED training, a masker must be set via masker(BertSequenceMasker) method");
            Preconditions.checkState(task != Task.UNSUPERVISED || unsupervisedLabelFormat != null, "If task is UNSUPERVISED training, a label format must be set via the unsupervisedLabelFormat(UnsupervisedLabelFormat) method");
            Preconditions.checkState(task != Task.UNSUPERVISED || maskToken != null, "If task is UNSUPERVISED training, the mask token in the vocab (such as \"[MASK]\") must be specified");
return new BertIterator(this);
}
}
}