
org.deeplearning4j.iterator.CnnSentenceDataSetIterator

/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.iterator;

import lombok.AllArgsConstructor;
import lombok.NonNull;
import org.deeplearning4j.iterator.provider.LabelAwareConverter;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.text.documentiterator.LabelAwareDocumentIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.interoperability.DocumentIteratorConverter;
import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIteratorConverter;
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.common.primitives.Pair;

import java.util.*;

/**
 * A DataSetIterator that provides data for training CNN sentence classification models (though it can of course
 * be used for general documents, not just sentences). The iterator handles conversion of sentences to training
 * data for CNNs, where each word is encoded using the word vector from the specified WordVectors (i.e., word2vec
 * etc.) model. Labels are encoded using a one-hot representation and are 2d - i.e., they are intended to be used
 * with a model that utilizes global pooling.
 * <p>
 * Specifically:<br>
 * - Features have shape [minibatchSize, 1, maxSentenceLength, wordVectorSize] OR
 *   [minibatchSize, 1, wordVectorSize, maxSentenceLength], depending on the configuration (for
 *   sentencesAlongHeight = true/false respectively)<br>
 * - Labels are a 2d array with shape [minibatchSize, numLabels].
 * <p>
 * Sentences and labels are provided by a {@link LabeledSentenceProvider} - different implementations of this
 * provide different ways of loading sentences/documents with labels - for example, from files, etc.
 * <p>
 * Note: With regard to the label to class index assignment, labels are sorted alphabetically. To get the
 * assignment/mapping, use {@link #getLabels()} or {@link #getLabelClassMap()}
 *
 * @author Alex Black
 */
@AllArgsConstructor
public class CnnSentenceDataSetIterator implements DataSetIterator {

    public enum UnknownWordHandling {
        RemoveWord, UseUnknownVector
    }
    /**
     * Format of features:<br>
     * RNN: For use with RNN layers: same layout as CNN1D, shape [minibatch, vectorSize, sentenceLength]<br>
     * CNN1D: For use with 1d convolution layers: Shape [minibatch, vectorSize, sentenceLength]<br>
     * CNN2D: For use with 2d convolution layers: Shape [minibatch, 1, vectorSize, sentenceLength] or
     * [minibatch, 1, sentenceLength, vectorSize], depending on the setting for the 'sentencesAlongHeight'
     * configuration.
     */
    public enum Format {
        RNN, CNN1D, CNN2D
    }

    private static final String UNKNOWN_WORD_SENTINEL = "UNKNOWN_WORD_SENTINEL";

    private Format format;
    private LabeledSentenceProvider sentenceProvider;
    private WordVectors wordVectors;
    private TokenizerFactory tokenizerFactory;
    private UnknownWordHandling unknownWordHandling;
    private boolean useNormalizedWordVectors;
    private int minibatchSize;
    private int maxSentenceLength;
    private boolean sentencesAlongHeight;
    private DataSetPreProcessor dataSetPreProcessor;

    private int wordVectorSize;
    private int numClasses;
    private Map<String, Integer> labelClassMap;
    private INDArray unknown;

    private int cursor = 0;

    private Pair<List<String>, String> preLoadedTokens;

    protected CnnSentenceDataSetIterator(Builder builder) {
        this.format = builder.format;
        this.sentenceProvider = builder.sentenceProvider;
        this.wordVectors = builder.wordVectors;
        this.tokenizerFactory = builder.tokenizerFactory;
        this.unknownWordHandling = builder.unknownWordHandling;
        this.useNormalizedWordVectors = builder.useNormalizedWordVectors;
        this.minibatchSize = builder.minibatchSize;
        this.maxSentenceLength = builder.maxSentenceLength;
        this.sentencesAlongHeight = builder.sentencesAlongHeight;
        this.dataSetPreProcessor = builder.dataSetPreProcessor;

        this.numClasses = this.sentenceProvider.numLabelClasses();
        this.labelClassMap = new HashMap<>();
        int count = 0;
        //First: sort the labels to ensure the same label assignment order (say train vs. test)
        List<String> sortedLabels = new ArrayList<>(this.sentenceProvider.allLabels());
        Collections.sort(sortedLabels);

        this.wordVectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;

        for (String s : sortedLabels) {
            this.labelClassMap.put(s, count++);
        }

        if (unknownWordHandling == UnknownWordHandling.UseUnknownVector) {
            if (useNormalizedWordVectors) {
                unknown = wordVectors.getWordVectorMatrixNormalized(wordVectors.getUNK());
            } else {
                unknown = wordVectors.getWordVectorMatrix(wordVectors.getUNK());
            }
            if (unknown == null) {
                //Model has no unknown-word vector: fall back to an all-zeros vector of the correct shape
                unknown = wordVectors.getWordVectorMatrix(wordVectors.vocab().wordAtIndex(0)).like();
            }
        }
    }
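    // Illustrative note (the example labels are hypothetical, not from the original source): because of the
    // alphabetical sort above, a provider with labels {"positive", "negative", "neutral"} produces the
    // mapping {negative=0, neutral=1, positive=2}, so a "positive" example gets the one-hot label row
    // [0.0, 0.0, 1.0].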
    /**
     * Generally used after training to load a single sentence for prediction
     */
    public INDArray loadSingleSentence(String sentence) {
        List<String> tokens = tokenizeSentence(sentence);
        if (tokens.isEmpty()) {
            throw new IllegalStateException("No tokens available for input sentence - empty string or no words in"
                            + " vocabulary with RemoveWord unknown handling? Sentence = \"" + sentence + "\"");
        }
        if (format == Format.CNN1D || format == Format.RNN) {
            int[] featuresShape = new int[] {1, wordVectorSize, Math.min(maxSentenceLength, tokens.size())};
            INDArray features = Nd4j.create(featuresShape, (format == Format.CNN1D ? 'c' : 'f'));
            INDArrayIndex[] indices = new INDArrayIndex[3];
            indices[0] = NDArrayIndex.point(0);
            for (int i = 0; i < featuresShape[2]; i++) {
                INDArray vector = getVector(tokens.get(i));
                indices[1] = NDArrayIndex.all();
                indices[2] = NDArrayIndex.point(i);
                features.put(indices, vector);
            }
            return features;
        } else {
            int[] featuresShape = new int[] {1, 1, 0, 0};
            if (sentencesAlongHeight) {
                featuresShape[2] = Math.min(maxSentenceLength, tokens.size());
                featuresShape[3] = wordVectorSize;
            } else {
                featuresShape[2] = wordVectorSize;
                featuresShape[3] = Math.min(maxSentenceLength, tokens.size());
            }
            INDArray features = Nd4j.create(featuresShape);
            int length = (sentencesAlongHeight ? featuresShape[2] : featuresShape[3]);
            INDArrayIndex[] indices = new INDArrayIndex[4];
            indices[0] = NDArrayIndex.point(0);
            indices[1] = NDArrayIndex.point(0);
            for (int i = 0; i < length; i++) {
                INDArray vector = getVector(tokens.get(i));
                if (sentencesAlongHeight) {
                    indices[2] = NDArrayIndex.point(i);
                    indices[3] = NDArrayIndex.all();
                } else {
                    indices[2] = NDArrayIndex.all();
                    indices[3] = NDArrayIndex.point(i);
                }
                features.put(indices, vector);
            }
            return features;
        }
    }
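    // Illustrative usage sketch for the method above (the names 'iter' and 'net' are hypothetical; 'net' is
    // assumed to be a network trained on this iterator's output with a matching Format):
    //   INDArray features = iter.loadSingleSentence("It was a great movie");
    //   INDArray prediction = net.output(features);   //e.g., MultiLayerNetwork.output(INDArray)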
    private INDArray getVector(String word) {
        INDArray vector;
        if (unknownWordHandling == UnknownWordHandling.UseUnknownVector && word == UNKNOWN_WORD_SENTINEL) {
            //Yes, this *should* be using == for the sentinel String here
            vector = unknown;
        } else {
            if (useNormalizedWordVectors) {
                vector = wordVectors.getWordVectorMatrixNormalized(word);
            } else {
                vector = wordVectors.getWordVectorMatrix(word);
            }
        }
        return vector;
    }

    private List<String> tokenizeSentence(String sentence) {
        Tokenizer t = tokenizerFactory.create(sentence);

        List<String> tokens = new ArrayList<>();
        while (t.hasMoreTokens()) {
            String token = t.nextToken();
            if (!wordVectors.outOfVocabularySupported() && !wordVectors.hasWord(token)) {
                switch (unknownWordHandling) {
                    case RemoveWord:
                        continue;
                    case UseUnknownVector:
                        token = UNKNOWN_WORD_SENTINEL;
                }
            }
            tokens.add(token);
        }
        return tokens;
    }

    public Map<String, Integer> getLabelClassMap() {
        return new HashMap<>(labelClassMap);
    }

    @Override
    public List<String> getLabels() {
        //We don't want to just return the list from the LabeledSentenceProvider, as we sorted them earlier to do
        // the String -> Integer mapping
        String[] str = new String[labelClassMap.size()];
        for (Map.Entry<String, Integer> e : labelClassMap.entrySet()) {
            str[e.getValue()] = e.getKey();
        }
        return Arrays.asList(str);
    }

    @Override
    public boolean hasNext() {
        if (sentenceProvider == null) {
            throw new UnsupportedOperationException("Cannot do next/hasNext without a sentence provider");
        }
        while (preLoadedTokens == null && sentenceProvider.hasNext()) {
            //Pre-load tokens. Because we filter out empty strings and sentences with no valid words,
            //we need to pre-load some tokens. Otherwise, sentenceProvider could have 1 (invalid) sentence
            //next; hasNext() would return true, but next(int) wouldn't be able to return anything
            preLoadTokens();
        }
        return preLoadedTokens != null;
    }

    private void preLoadTokens() {
        if (preLoadedTokens != null) {
            return;
        }
        Pair<String, String> p = sentenceProvider.nextSentence();
        List<String> tokens = tokenizeSentence(p.getFirst());
        if (!tokens.isEmpty()) {
            preLoadedTokens = new Pair<>(tokens, p.getSecond());
        }
    }

    @Override
    public DataSet next() {
        return next(minibatchSize);
    }

    @Override
    public DataSet next(int num) {
        if (sentenceProvider == null) {
            throw new UnsupportedOperationException("Cannot do next/hasNext without a sentence provider");
        }
        if (!hasNext()) {
            throw new NoSuchElementException("No next element");
        }

        List<Pair<List<String>, String>> tokenizedSentences = new ArrayList<>(num);
        int maxLength = -1;
        int minLength = Integer.MAX_VALUE; //Track so we know if we can skip mask creation for "all same length" case
        if (preLoadedTokens != null) {
            tokenizedSentences.add(preLoadedTokens);
            maxLength = Math.max(maxLength, preLoadedTokens.getFirst().size());
            minLength = Math.min(minLength, preLoadedTokens.getFirst().size());
            preLoadedTokens = null;
        }
        for (int i = tokenizedSentences.size(); i < num && sentenceProvider.hasNext(); i++) {
            Pair<String, String> p = sentenceProvider.nextSentence();
            List<String> tokens = tokenizeSentence(p.getFirst());
            if (!tokens.isEmpty()) {
                //Handle edge case: no tokens from sentence
                maxLength = Math.max(maxLength, tokens.size());
                minLength = Math.min(minLength, tokens.size());
                tokenizedSentences.add(new Pair<>(tokens, p.getSecond()));
            } else {
                //Skip this sentence (no valid tokens); don't count it toward the minibatch
                i--;
            }
        }

        if (maxSentenceLength > 0 && maxLength > maxSentenceLength) {
            maxLength = maxSentenceLength;
        }

        int currMinibatchSize = tokenizedSentences.size();

        INDArray labels = Nd4j.create(currMinibatchSize, numClasses);
        for (int i = 0; i < tokenizedSentences.size(); i++) {
            String labelStr = tokenizedSentences.get(i).getSecond();
            if (!labelClassMap.containsKey(labelStr)) {
                throw new IllegalStateException("Got label \"" + labelStr
                                + "\" that is not present in list of LabeledSentenceProvider labels");
            }
            int labelIdx = labelClassMap.get(labelStr);
            labels.putScalar(i, labelIdx, 1.0);
        }

        INDArray features;
        INDArray featuresMask = null;
        if (format == Format.CNN1D || format == Format.RNN) {
            int[] featuresShape = new int[] {currMinibatchSize, wordVectorSize, maxLength};
            features = Nd4j.create(featuresShape, (format == Format.CNN1D ? 'c' : 'f'));
            INDArrayIndex[] idxs = new INDArrayIndex[3];
            idxs[1] = NDArrayIndex.all();
            for (int i = 0; i < currMinibatchSize; i++) {
                idxs[0] = NDArrayIndex.point(i);
                List<String> currSentence = tokenizedSentences.get(i).getFirst();
                for (int j = 0; j < currSentence.size() && j < maxSentenceLength; j++) {
                    idxs[2] = NDArrayIndex.point(j);
                    INDArray vector = getVector(currSentence.get(j));
                    features.put(idxs, vector);
                }
            }

            if (minLength != maxLength) {
                //Sentences have different lengths: a mask marks which time steps hold real words
                featuresMask = Nd4j.create(currMinibatchSize, maxLength);
                for (int i = 0; i < currMinibatchSize; i++) {
                    int sentenceLength = tokenizedSentences.get(i).getFirst().size();
                    if (sentenceLength >= maxLength) {
                        featuresMask.getRow(i).assign(1.0);
                    } else {
                        featuresMask.get(NDArrayIndex.point(i), NDArrayIndex.interval(0, sentenceLength)).assign(1.0);
                    }
                }
            }
        } else {
            int[] featuresShape = new int[4];
            featuresShape[0] = currMinibatchSize;
            featuresShape[1] = 1;
            if (sentencesAlongHeight) {
                featuresShape[2] = maxLength;
                featuresShape[3] = wordVectorSize;
            } else {
                featuresShape[2] = wordVectorSize;
                featuresShape[3] = maxLength;
            }

            features = Nd4j.create(featuresShape);
            INDArrayIndex[] indices = new INDArrayIndex[4];
            indices[1] = NDArrayIndex.point(0);
            for (int i = 0; i < currMinibatchSize; i++) {
                indices[0] = NDArrayIndex.point(i);
                List<String> currSentence = tokenizedSentences.get(i).getFirst();
                for (int j = 0; j < currSentence.size() && j < maxSentenceLength; j++) {
                    INDArray vector = getVector(currSentence.get(j));
                    if (sentencesAlongHeight) {
                        indices[2] = NDArrayIndex.point(j);
                        indices[3] = NDArrayIndex.all();
                    } else {
                        indices[2] = NDArrayIndex.all();
                        indices[3] = NDArrayIndex.point(j);
                    }
                    features.put(indices, vector);
                }
            }

            if (minLength != maxLength) {
                //Mask along the sentence-length dimension, which depends on the sentencesAlongHeight setting
                if (sentencesAlongHeight) {
                    featuresMask = Nd4j.create(currMinibatchSize, 1, maxLength, 1);
                    for (int i = 0; i < currMinibatchSize; i++) {
                        int sentenceLength = tokenizedSentences.get(i).getFirst().size();
                        if (sentenceLength >= maxLength) {
                            featuresMask.slice(i).assign(1.0);
                        } else {
                            featuresMask.get(NDArrayIndex.point(i), NDArrayIndex.point(0),
                                            NDArrayIndex.interval(0, sentenceLength), NDArrayIndex.point(0))
                                            .assign(1.0);
                        }
                    }
                } else {
                    featuresMask = Nd4j.create(currMinibatchSize, 1, 1, maxLength);
                    for (int i = 0; i < currMinibatchSize; i++) {
                        int sentenceLength = tokenizedSentences.get(i).getFirst().size();
                        if (sentenceLength >= maxLength) {
                            featuresMask.slice(i).assign(1.0);
                        } else {
                            featuresMask.get(NDArrayIndex.point(i), NDArrayIndex.point(0), NDArrayIndex.point(0),
                                            NDArrayIndex.interval(0, sentenceLength)).assign(1.0);
                        }
                    }
                }
            }
        }

        DataSet ds = new DataSet(features, labels, featuresMask, null);

        if (dataSetPreProcessor != null) {
            dataSetPreProcessor.preProcess(ds);
        }

        cursor += ds.numExamples();
        return ds;
    }

    @Override
    public int inputColumns() {
        return wordVectorSize;
    }

    @Override
    public int totalOutcomes() {
        return numClasses;
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return true;
    }

    @Override
    public void reset() {
        cursor = 0;
        sentenceProvider.reset();
    }

    @Override
    public int batch() {
        return minibatchSize;
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.dataSetPreProcessor = preProcessor;
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        return dataSetPreProcessor;
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException("Not supported");
    }
    public static class Builder {

        private Format format;
        private LabeledSentenceProvider sentenceProvider = null;
        private WordVectors wordVectors;
        private TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        private UnknownWordHandling unknownWordHandling = UnknownWordHandling.RemoveWord;
        private boolean useNormalizedWordVectors = true;
        private int maxSentenceLength = -1;
        private int minibatchSize = 32;
        private boolean sentencesAlongHeight = true;
        private DataSetPreProcessor dataSetPreProcessor;

        /**
         * @deprecated Due to the old default, which will be changed in the future. Use {@link #Builder(Format)}
         * to specify the {@link Format} of the activations
         */
        @Deprecated
        public Builder() {
            //Default for backward compatibility
            this(Format.CNN2D);
        }

        /**
         * @param format The format to use for the features - i.e., for 1D or 2D CNNs
         */
        public Builder(@NonNull Format format) {
            this.format = format;
        }

        /**
         * Specify how the (labelled) sentences / documents should be provided
         */
        public Builder sentenceProvider(LabeledSentenceProvider labeledSentenceProvider) {
            this.sentenceProvider = labeledSentenceProvider;
            return this;
        }

        /**
         * Specify how the (labelled) sentences / documents should be provided
         */
        public Builder sentenceProvider(LabelAwareIterator iterator, @NonNull List<String> labels) {
            LabelAwareConverter converter = new LabelAwareConverter(iterator, labels);
            return sentenceProvider(converter);
        }

        /**
         * Specify how the (labelled) sentences / documents should be provided
         */
        public Builder sentenceProvider(LabelAwareDocumentIterator iterator, @NonNull List<String> labels) {
            DocumentIteratorConverter converter = new DocumentIteratorConverter(iterator);
            return sentenceProvider(converter, labels);
        }

        /**
         * Specify how the (labelled) sentences / documents should be provided
         */
        public Builder sentenceProvider(LabelAwareSentenceIterator iterator, @NonNull List<String> labels) {
            SentenceIteratorConverter converter = new SentenceIteratorConverter(iterator);
            return sentenceProvider(converter, labels);
        }

        /**
         * Provide the WordVectors instance that should be used for training
         */
        public Builder wordVectors(WordVectors wordVectors) {
            this.wordVectors = wordVectors;
            return this;
        }

        /**
         * The {@link TokenizerFactory} that should be used. Defaults to {@link DefaultTokenizerFactory}
         */
        public Builder tokenizerFactory(TokenizerFactory tokenizerFactory) {
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        /**
         * Specify how unknown words (those that don't have a word vector in the provided WordVectors instance)
         * should be handled. Default: remove/ignore unknown words.
         */
        public Builder unknownWordHandling(UnknownWordHandling unknownWordHandling) {
            this.unknownWordHandling = unknownWordHandling;
            return this;
        }

        /**
         * Minibatch size to use for the DataSetIterator
         */
        public Builder minibatchSize(int minibatchSize) {
            this.minibatchSize = minibatchSize;
            return this;
        }

        /**
         * Whether normalized word vectors should be used. Default: true
         */
        public Builder useNormalizedWordVectors(boolean useNormalizedWordVectors) {
            this.useNormalizedWordVectors = useNormalizedWordVectors;
            return this;
        }

        /**
         * Maximum sentence/document length. If sentences exceed this, they will be truncated to this length by
         * taking the first 'maxSentenceLength' known words.
         */
        public Builder maxSentenceLength(int maxSentenceLength) {
            this.maxSentenceLength = maxSentenceLength;
            return this;
        }

        /**
         * If true (default): output features data with shape [minibatchSize, 1, maxSentenceLength, wordVectorSize]<br>
         * If false: output features with shape [minibatchSize, 1, wordVectorSize, maxSentenceLength]
         */
        public Builder sentencesAlongHeight(boolean sentencesAlongHeight) {
            this.sentencesAlongHeight = sentencesAlongHeight;
            return this;
        }

        /**
         * Optional DataSetPreProcessor
         */
        public Builder dataSetPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
            this.dataSetPreProcessor = dataSetPreProcessor;
            return this;
        }

        public CnnSentenceDataSetIterator build() {
            if (wordVectors == null) {
                throw new IllegalStateException(
                                "Cannot build CnnSentenceDataSetIterator without a WordVectors instance");
            }
            return new CnnSentenceDataSetIterator(this);
        }
    }
}
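
For reference, a minimal usage sketch (not part of the source above). It assumes a pre-trained word-vector file at a placeholder path, and uses CollectionLabeledSentenceProvider (from org.deeplearning4j.iterator.provider) with two hard-coded example sentences; the example class name and file path are hypothetical:

import java.io.File;
import java.util.Arrays;
import java.util.List;

import org.deeplearning4j.iterator.CnnSentenceDataSetIterator;
import org.deeplearning4j.iterator.LabeledSentenceProvider;
import org.deeplearning4j.iterator.provider.CollectionLabeledSentenceProvider;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.nd4j.linalg.dataset.DataSet;

public class CnnSentenceIteratorExample {
    public static void main(String[] args) {
        //Hypothetical path - substitute a real pre-trained word-vector file (e.g., word2vec binary format)
        WordVectors wordVectors = WordVectorSerializer.loadStaticModel(new File("/path/to/wordvectors.bin"));

        //Toy data: one label per sentence; labels will be sorted alphabetically for the class-index mapping
        List<String> sentences = Arrays.asList("I loved this movie", "Terrible film, do not watch");
        List<String> labels = Arrays.asList("positive", "negative");
        LabeledSentenceProvider provider = new CollectionLabeledSentenceProvider(sentences, labels);

        CnnSentenceDataSetIterator iter =
                new CnnSentenceDataSetIterator.Builder(CnnSentenceDataSetIterator.Format.CNN2D)
                        .sentenceProvider(provider)
                        .wordVectors(wordVectors)
                        .maxSentenceLength(256)
                        .minibatchSize(32)
                        .build();

        while (iter.hasNext()) {
            DataSet ds = iter.next();
            //Features: [minibatchSize, 1, maxSentenceLength, wordVectorSize] (sentencesAlongHeight default = true)
            //Labels:   [minibatchSize, numLabels], one-hot, with alphabetical label order
            System.out.println(Arrays.toString(ds.getFeatures().shape()));
        }
    }
}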




