opennlp.tools.util.eval.CrossValidationPartitioner Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package opennlp.tools.util.eval;
import java.io.IOException;
import java.util.Collection;
import java.util.NoSuchElementException;
import opennlp.tools.util.CollectionObjectStream;
import opennlp.tools.util.ObjectStream;
/**
 * Provides access to training and test partitions for n-fold cross validation.
 * <p>
 * Cross validation is used to evaluate the performance of a classifier when only
 * training data is available. The training set is split into n parts
 * and the training / evaluation is performed n times on these parts.
 * The training partition always consists of n - 1 parts and one part is used for testing.
 * <p>
 * To use the {@code CrossValidationPartitioner} a client iterates over the n
 * {@code TrainingSampleStream}s. Each {@code TrainingSampleStream} represents
 * one partition and is used first for training and afterwards for testing.
 * The {@code TestSampleStream} can be obtained from the {@code TrainingSampleStream}
 * with the {@code getTestSampleStream} method.
 */
public class CrossValidationPartitioner {
/**
* The TestSampleStream
iterates over all test elements.
*
* @param
*/
private static class TestSampleStream implements ObjectStream {
private ObjectStream sampleStream;
private final int numberOfPartitions;
private final int testIndex;
private int index;
private boolean isPoisened;
private TestSampleStream(ObjectStream sampleStream, int numberOfPartitions, int testIndex) {
this.numberOfPartitions = numberOfPartitions;
this.sampleStream = sampleStream;
this.testIndex = testIndex;
}
public E read() throws IOException {
if (isPoisened) {
throw new IllegalStateException();
}
// skip training samples
while (index % numberOfPartitions != testIndex) {
sampleStream.read();
index++;
}
index++;
return sampleStream.read();
}
/**
* Throws UnsupportedOperationException
*/
public void reset() {
throw new UnsupportedOperationException();
}
public void close() throws IOException {
sampleStream.close();
isPoisened = true;
}
void poison() {
isPoisened = true;
}
}
/**
* The TrainingSampleStream
which iterates over
* all training elements.
*
* Note:
* After the TestSampleStream
was obtained
* the TrainingSampleStream
must not be used
* anymore, otherwise a {@link IllegalStateException}
* is thrown.
*
* The ObjectStream
s must not be used anymore after the
* CrossValidationPartitioner
was moved
* to one of next partitions. If they are called anyway
* a {@link IllegalStateException} is thrown.
*
* @param
*/
public static class TrainingSampleStream implements ObjectStream {
private ObjectStream sampleStream;
private final int numberOfPartitions;
private final int testIndex;
private int index;
private boolean isPoisened;
private TestSampleStream testSampleStream;
TrainingSampleStream(ObjectStream sampleStream, int numberOfPartitions, int testIndex) {
this.numberOfPartitions = numberOfPartitions;
this.sampleStream = sampleStream;
this.testIndex = testIndex;
}
public E read() throws IOException {
if (testSampleStream != null || isPoisened) {
throw new IllegalStateException();
}
// If the test element is reached skip over it to not include it in
// the training data
if (index % numberOfPartitions == testIndex) {
sampleStream.read();
index++;
}
index++;
return sampleStream.read();
}
/**
* Resets the training sample. Use this if you need to collect things before
* training, for example, to collect induced abbreviations or create a POS
* Dictionary.
*
* @throws IOException
*/
public void reset() throws IOException {
if (testSampleStream != null || isPoisened) {
throw new IllegalStateException();
}
this.index = 0;
this.sampleStream.reset();
}
public void close() throws IOException {
sampleStream.close();
poison();
}
void poison() {
isPoisened = true;
if (testSampleStream != null)
testSampleStream.poison();
}
/**
* Retrieves the ObjectStream
over the test/evaluations
* elements and poisons this TrainingSampleStream
.
* From now on calls to the hasNext and next methods are forbidden
* and will raise anIllegalArgumentException
.
*
* @return the test sample stream
*/
public ObjectStream getTestSampleStream() throws IOException {
if (isPoisened) {
throw new IllegalStateException();
}
if (testSampleStream == null) {
sampleStream.reset();
testSampleStream = new TestSampleStream(sampleStream, numberOfPartitions, testIndex);
}
return testSampleStream;
}
}
/**
* An ObjectStream
over the whole set of data samples which
* are used for the cross validation.
*/
private ObjectStream sampleStream;
/**
* The number of parts the data is divided into.
*/
private final int numberOfPartitions;
/**
* The index of test part.
*/
private int testIndex;
/**
* The last handed out TrainingIterator
. The reference
* is needed to poison the instance to fail fast if it is used
* despite the fact that it is forbidden!.
*/
private TrainingSampleStream lastTrainingSampleStream;
/**
* Initializes the current instance.
*
* @param inElements
* @param numberOfPartitions
*/
public CrossValidationPartitioner(ObjectStream inElements, int numberOfPartitions) {
this.sampleStream = inElements;
this.numberOfPartitions = numberOfPartitions;
}
/**
* Initializes the current instance.
*
* @param elements
* @param numberOfPartitions
*/
public CrossValidationPartitioner(Collection elements, int numberOfPartitions) {
this(new CollectionObjectStream(elements), numberOfPartitions);
}
/**
* Checks if there are more partitions available.
*/
public boolean hasNext() {
return testIndex < numberOfPartitions;
}
/**
* Retrieves the next training and test partitions.
*/
public TrainingSampleStream next() throws IOException {
if (hasNext()) {
if (lastTrainingSampleStream != null)
lastTrainingSampleStream.poison();
sampleStream.reset();
TrainingSampleStream trainingSampleStream = new TrainingSampleStream(sampleStream,
numberOfPartitions, testIndex);
testIndex++;
lastTrainingSampleStream = trainingSampleStream;
return trainingSampleStream;
}
else {
throw new NoSuchElementException();
}
}
@Override
public String toString() {
return "At partition" + Integer.toString(testIndex + 1) +
" of " + Integer.toString(numberOfPartitions);
}
}