![JAR search and dependency download from the Maven repository](/logo.png)
org.deeplearning4j.iterator.provider.CollectionLabeledPairSentenceProvider Maven / Gradle / Ivy
The newest version!
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.iterator.provider;
import lombok.NonNull;
import org.deeplearning4j.iterator.LabeledPairSentenceProvider;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Triple;
import org.nd4j.common.util.MathUtils;
import java.util.*;
/**
 * A {@link LabeledPairSentenceProvider} backed by three parallel in-memory lists: the left sentence,
 * the right sentence and the label of each pair. Entries at the same index in the three lists belong
 * together.
 *
 * <p>When a {@link Random} is supplied, pairs are served in a shuffled order that is re-shuffled on every
 * {@link #reset()}; when it is {@code null}, pairs are served in list order.
 *
 * <p>The supplied lists are stored by reference (not copied), so they should not be modified while this
 * provider is in use.
 */
public class CollectionLabeledPairSentenceProvider implements LabeledPairSentenceProvider {

    private final List<String> sentenceL;
    private final List<String> sentenceR;
    private final List<String> labels;
    /** Source of randomness for shuffling; {@code null} disables shuffling. */
    private final Random rng;
    /** Permutation of indices used when shuffling is enabled; {@code null} when {@link #rng} is null. */
    private final int[] order;
    /** Distinct labels, sorted in their natural order. */
    private final List<String> allLabels;
    /** Number of pairs already handed out since construction or the last {@link #reset()}. */
    private int cursor = 0;

    /**
     * Creates a provider that serves the pairs in a shuffled order, using a freshly created {@link Random}.
     * Sentences in the same position in the first two lists are considered a pair.
     *
     * @param sentenceL          first ("left") sentence of each pair
     * @param sentenceR          second ("right") sentence of each pair; must be the same size as {@code sentenceL}
     * @param labelsForSentences label of each pair; must be the same size as the sentence lists
     * @throws IllegalArgumentException if the three lists do not all have the same size
     */
    public CollectionLabeledPairSentenceProvider(@NonNull List<String> sentenceL, @NonNull List<String> sentenceR,
                    @NonNull List<String> labelsForSentences) {
        this(sentenceL, sentenceR, labelsForSentences, new Random());
    }

    /**
     * Creates a provider over the given sentence pairs and labels.
     * Sentences in the same position in the first two lists are considered a pair.
     *
     * @param sentenceL          first ("left") sentence of each pair
     * @param sentenceR          second ("right") sentence of each pair; must be the same size as {@code sentenceL}
     * @param labelsForSentences label of each pair; must be the same size as the sentence lists
     * @param rng                If null, list order is not shuffled
     * @throws IllegalArgumentException if the three lists do not all have the same size
     */
    public CollectionLabeledPairSentenceProvider(@NonNull List<String> sentenceL, @NonNull List<String> sentenceR,
                    @NonNull List<String> labelsForSentences, Random rng) {
        if (sentenceR.size() != sentenceL.size()) {
            throw new IllegalArgumentException("Sentence lists must be same size (first list size: "
                            + sentenceL.size() + ", second list size: " + sentenceR.size() + ")");
        }
        if (sentenceR.size() != labelsForSentences.size()) {
            throw new IllegalArgumentException("Sentence pairs and labels must be same size (sentence pair size: "
                            + sentenceR.size() + ", labels size: " + labelsForSentences.size() + ")");
        }

        this.sentenceL = sentenceL;
        this.sentenceR = sentenceR;
        this.labels = labelsForSentences;
        this.rng = rng;

        if (rng == null) {
            // No shuffling requested: nextSentencePair() walks the lists directly.
            order = null;
        } else {
            // Start from the identity permutation, then shuffle it with the supplied generator.
            order = new int[sentenceR.size()];
            for (int i = 0; i < order.length; i++) {
                order[i] = i;
            }
            MathUtils.shuffleArray(order, rng);
        }

        // Collect the set of unique labels over all sentence pairs, in sorted order.
        Set<String> uniqueLabels = new HashSet<>(labelsForSentences);
        allLabels = new ArrayList<>(uniqueLabels);
        Collections.sort(allLabels);
    }

    /**
     * @return true while at least one pair has not yet been returned since construction or the last reset
     */
    @Override
    public boolean hasNext() {
        return cursor < sentenceR.size();
    }

    /**
     * Returns the next pair and advances the cursor. Fails the {@code Preconditions.checkState} check
     * when no pairs remain.
     *
     * @return a triple of (left sentence, right sentence, label) for the next pair
     */
    @Override
    public Triple<String, String, String> nextSentencePair() {
        Preconditions.checkState(hasNext(), "No next element available");
        int idx;
        if (rng == null) {
            idx = cursor++;
        } else {
            idx = order[cursor++];
        }
        return new Triple<>(sentenceL.get(idx), sentenceR.get(idx), labels.get(idx));
    }

    /**
     * Rewinds to the first pair and, when shuffling is enabled, re-shuffles the serving order.
     */
    @Override
    public void reset() {
        cursor = 0;
        if (rng != null) {
            MathUtils.shuffleArray(order, rng);
        }
    }

    /**
     * @return the total number of sentence pairs held by this provider
     */
    @Override
    public int totalNumSentences() {
        return sentenceR.size();
    }

    /**
     * @return the distinct labels, sorted; this is the provider's internal list, not a copy
     */
    @Override
    public List<String> allLabels() {
        return allLabels;
    }

    /**
     * @return the number of distinct labels
     */
    @Override
    public int numLabelClasses() {
        return allLabels.size();
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy