![JAR search and dependency download from the Maven repository](/logo.png)
opennlp.tools.util.eval.CrossValidationPartitioner Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package opennlp.tools.util.eval;
import java.io.IOException;
import java.util.Collection;
import java.util.NoSuchElementException;
import opennlp.tools.util.CollectionObjectStream;
import opennlp.tools.util.ObjectStream;
/**
* Provides access to training and test partitions for n-fold cross validation.
*
* Cross validation is used to evaluate the performance of a classifier when only
* training data is available. The training set is split into n parts
* and the training / evaluation is performed {@code n} times on these parts.
* The training partition always consists of {@code n - 1} parts and one part is used for testing.
*
* To use the {@link CrossValidationPartitioner} a client iterates over the n
* {@link TrainingSampleStream stream}. Each {@link TrainingSampleStream} represents
* one partition and is used first for training and afterwards for testing.
* The {@link TestSampleStream} can be obtained via the
* {@link TrainingSampleStream#getTestSampleStream()} method.
*/
public class CrossValidationPartitioner {
/**
* The {@link TestSampleStream} iterates over all test elements.
*
* @param The generic type of samples.
*/
private static class TestSampleStream implements ObjectStream {
private final ObjectStream sampleStream;
private final int numberOfPartitions;
private final int testIndex;
private int index;
private boolean isPoisoned;
private TestSampleStream(ObjectStream sampleStream, int numberOfPartitions, int testIndex) {
this.numberOfPartitions = numberOfPartitions;
this.sampleStream = sampleStream;
this.testIndex = testIndex;
}
@Override
public E read() throws IOException {
if (isPoisoned) {
throw new IllegalStateException();
}
// skip training samples
while (index % numberOfPartitions != testIndex) {
sampleStream.read();
index++;
}
index++;
return sampleStream.read();
}
/**
* @throws UnsupportedOperationException Thrown to signal no implementation is available.
*/
@Override
public void reset() {
throw new UnsupportedOperationException();
}
@Override
public void close() throws IOException {
sampleStream.close();
isPoisoned = true;
}
void poison() {
isPoisoned = true;
}
}
/**
* The {@link TrainingSampleStream} which iterates over
* all training elements.
*
* Note:
* After the {@link TestSampleStream} was obtained
* the {@link TrainingSampleStream} must not be used
* anymore, otherwise a {@link IllegalStateException}
* is thrown.
*
* The {@link ObjectStream streams} must not be used anymore after the
* {@link CrossValidationPartitioner} was moved to one of next partitions.
* If they are called anyway an {@link IllegalStateException} is thrown.
*
* @param The generic type of samples.
*/
public static class TrainingSampleStream implements ObjectStream {
private final ObjectStream sampleStream;
private final int numberOfPartitions;
private final int testIndex;
private int index;
private boolean isPoisoned;
private TestSampleStream testSampleStream;
TrainingSampleStream(ObjectStream sampleStream, int numberOfPartitions, int testIndex) {
this.numberOfPartitions = numberOfPartitions;
this.sampleStream = sampleStream;
this.testIndex = testIndex;
}
@Override
public E read() throws IOException {
if (testSampleStream != null || isPoisoned) {
throw new IllegalStateException();
}
// If the test element is reached skip over it to not include it in
// the training data
if (index % numberOfPartitions == testIndex) {
sampleStream.read();
index++;
}
index++;
return sampleStream.read();
}
/**
* Resets the training sample. Use this if you need to collect things before
* training, for example, to collect induced abbreviations or create a POS
* Dictionary.
*
* @throws IOException Thrown if IO errors occurred.
* @throws IllegalStateException Thrown if a non-consistent state occurred.
*/
@Override
public void reset() throws IOException {
if (testSampleStream != null || isPoisoned) {
throw new IllegalStateException();
}
this.index = 0;
this.sampleStream.reset();
}
@Override
public void close() throws IOException {
sampleStream.close();
poison();
}
void poison() {
isPoisoned = true;
if (testSampleStream != null)
testSampleStream.poison();
}
/**
* Retrieves the {@link ObjectStream} over the test/evaluations
* elements and poisons this {@link TrainingSampleStream}.
* From now on calls to the hasNext and next methods are forbidden
* and will raise an {@link IllegalArgumentException}.
*
* @return The test sample {@link ObjectStream stream}.
*/
public ObjectStream getTestSampleStream() throws IOException {
if (isPoisoned) {
throw new IllegalStateException();
}
if (testSampleStream == null) {
sampleStream.reset();
testSampleStream = new TestSampleStream<>(sampleStream, numberOfPartitions, testIndex);
}
return testSampleStream;
}
}
/**
* An {@link ObjectStream stream} over the whole set of data samples which
* are used for the cross validation.
*/
private final ObjectStream sampleStream;
/**
* The number of parts the data is divided into.
*/
private final int numberOfPartitions;
/**
* The index of test part.
*/
private int testIndex;
/**
* The last handed out {@link TrainingSampleStream}. The reference
* is needed to poison the instance to fail fast if it is used
* despite the fact that it is forbidden!.
*/
private TrainingSampleStream lastTrainingSampleStream;
/**
* Initializes {@link CrossValidationPartitioner} instance.
*
* @param inElements The {@link ObjectStream} that provides the elements.
* @param numberOfPartitions The number of partitions. Must be greater than {@code 0}.
*/
public CrossValidationPartitioner(ObjectStream inElements, int numberOfPartitions) {
this.sampleStream = inElements;
this.numberOfPartitions = numberOfPartitions;
}
/**
* Initializes {@link CrossValidationPartitioner} instance.
*
* @param elements A {@link Collection} that provides the elements.
* @param numberOfPartitions The number of partitions. Must be greater than {@code 0}.
*/
public CrossValidationPartitioner(Collection elements, int numberOfPartitions) {
this(new CollectionObjectStream<>(elements), numberOfPartitions);
}
/**
* Checks if there are more partitions available.
*/
public boolean hasNext() {
return testIndex < numberOfPartitions;
}
/**
* Retrieves the next training and test partitions.
*/
public TrainingSampleStream next() throws IOException {
if (hasNext()) {
if (lastTrainingSampleStream != null)
lastTrainingSampleStream.poison();
sampleStream.reset();
TrainingSampleStream trainingSampleStream = new TrainingSampleStream<>(sampleStream,
numberOfPartitions, testIndex);
testIndex++;
lastTrainingSampleStream = trainingSampleStream;
return trainingSampleStream;
}
else {
throw new NoSuchElementException();
}
}
@Override
public String toString() {
return "At partition" + (testIndex + 1) +
" of " + numberOfPartitions;
}
}