chalk.tools.util.eval.CrossValidationPartitioner Maven / Gradle / Ivy
/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package chalk.tools.util.eval; import java.io.IOException; import java.util.Collection; import java.util.NoSuchElementException; import chalk.tools.util.CollectionObjectStream; import chalk.tools.util.ObjectStream; /** * Provides access to training and test partitions for n-fold cross validation. *
 * Cross validation is used to evaluate the performance of a classifier when only
 * training data is available. The training set is split into n parts
 * and the training / evaluation is performed n times on these parts.
 * The training partition always consists of n - 1 parts and one part is used for testing.
 *
 * To use the {@code CrossValidationPartitioner} a client iterates over the n
 * {@code TrainingSampleStream}s. Each {@code TrainingSampleStream} represents
 * one partition and is used first for training and afterwards for testing.
 * The {@code TestSampleStream} can be obtained from the {@code TrainingSampleStream}
 * with the {@code getTestSampleStream}
method. */ public class CrossValidationPartitioner{ /** * The TestSampleStream
iterates over all test elements. * * @param*/ private static class TestSampleStream implements ObjectStream { private ObjectStream sampleStream; private final int numberOfPartitions; private final int testIndex; private int index; private boolean isPoisened; private TestSampleStream(ObjectStream sampleStream, int numberOfPartitions, int testIndex) { this.numberOfPartitions = numberOfPartitions; this.sampleStream = sampleStream; this.testIndex = testIndex; } public E read() throws IOException { if (isPoisened) { throw new IllegalStateException(); } // skip training samples while (index % numberOfPartitions != testIndex) { sampleStream.read(); index++; } index++; return sampleStream.read(); } /** * Throws UnsupportedOperationException
*/ public void reset() { throw new UnsupportedOperationException(); } public void close() throws IOException { sampleStream.close(); isPoisened = true; } void poison() { isPoisened = true; } } /** * TheTrainingSampleStream
which iterates over * all training elements. * * Note: * After theTestSampleStream
was obtained * theTrainingSampleStream
must not be used * anymore, otherwise a {@link IllegalStateException} * is thrown. * * TheObjectStream>
s must not be used anymore after the *CrossValidationPartitioner
was moved * to one of next partitions. If they are called anyway * a {@link IllegalStateException} is thrown. * * @param*/ public static class TrainingSampleStream implements ObjectStream { private ObjectStream sampleStream; private final int numberOfPartitions; private final int testIndex; private int index; private boolean isPoisened; private TestSampleStream testSampleStream; TrainingSampleStream(ObjectStream sampleStream, int numberOfPartitions, int testIndex) { this.numberOfPartitions = numberOfPartitions; this.sampleStream = sampleStream; this.testIndex = testIndex; } public E read() throws IOException { if (testSampleStream != null || isPoisened) { throw new IllegalStateException(); } // If the test element is reached skip over it to not include it in // the training data if (index % numberOfPartitions == testIndex) { sampleStream.read(); index++; } index++; return sampleStream.read(); } /** * Resets the training sample. Use this if you need to collect things before * training, for example, to collect induced abbreviations or create a POS * Dictionary. * * @throws IOException */ public void reset() throws IOException { if (testSampleStream != null || isPoisened) { throw new IllegalStateException(); } this.index = 0; this.sampleStream.reset(); } public void close() throws IOException { sampleStream.close(); poison(); } void poison() { isPoisened = true; if (testSampleStream != null) testSampleStream.poison(); } /** * Retrieves the ObjectStream
over the test/evaluations * elements and poisons thisTrainingSampleStream
. * From now on calls to the hasNext and next methods are forbidden * and will raise anIllegalArgumentException
. * * @return the test sample stream */ public ObjectStreamgetTestSampleStream() throws IOException { if (isPoisened) { throw new IllegalStateException(); } if (testSampleStream == null) { sampleStream.reset(); testSampleStream = new TestSampleStream (sampleStream, numberOfPartitions, testIndex); } return testSampleStream; } } /** * An ObjectStream
over the whole set of data samples which * are used for the cross validation. */ private ObjectStreamsampleStream; /** * The number of parts the data is divided into. */ private final int numberOfPartitions; /** * The index of test part. */ private int testIndex; /** * The last handed out TrainingIterator
. The reference * is needed to poison the instance to fail fast if it is used * despite the fact that it is forbidden!. */ private TrainingSampleStreamlastTrainingSampleStream; /** * Initializes the current instance. * * @param inElements * @param numberOfPartitions */ public CrossValidationPartitioner(ObjectStream inElements, int numberOfPartitions) { this.sampleStream = inElements; this.numberOfPartitions = numberOfPartitions; } /** * Initializes the current instance. * * @param elements * @param numberOfPartitions */ public CrossValidationPartitioner(Collection elements, int numberOfPartitions) { this(new CollectionObjectStream (elements), numberOfPartitions); } /** * Checks if there are more partitions available. */ public boolean hasNext() { return testIndex < numberOfPartitions; } /** * Retrieves the next training and test partitions. */ public TrainingSampleStream next() throws IOException { if (hasNext()) { if (lastTrainingSampleStream != null) lastTrainingSampleStream.poison(); sampleStream.reset(); TrainingSampleStream trainingSampleStream = new TrainingSampleStream (sampleStream, numberOfPartitions, testIndex); testIndex++; lastTrainingSampleStream = trainingSampleStream; return trainingSampleStream; } else { throw new NoSuchElementException(); } } @Override public String toString() { return "At partition" + Integer.toString(testIndex + 1) + " of " + Integer.toString(numberOfPartitions); } }
© 2015 - 2024 Weber Informatics LLC | Privacy Policy