All Downloads are FREE. Search and download functionalities are using the official Maven repository.

chalk.tools.util.eval.CrossValidationPartitioner Maven / Gradle / Ivy

There is a newer version: 1.3.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


package chalk.tools.util.eval;

import java.io.IOException;
import java.util.Collection;
import java.util.NoSuchElementException;

import chalk.tools.util.CollectionObjectStream;
import chalk.tools.util.ObjectStream;


/**
 * Provides access to training and test partitions for n-fold cross validation.
 * 

* Cross validation is used to evaluate the performance of a classifier when only * training data is available. The training set is split into n parts * and the training / evaluation is performed n times on these parts. * The training partition always consists of n -1 parts and one part is used for testing. *

* To use the CrossValidationPartioner a client iterates over the n * TrainingSampleStreams. Each TrainingSampleStream represents * one partition and is used first for training and afterwards for testing. * The TestSampleStream can be obtained from the TrainingSampleStream * with the getTestSampleStream method. */ public class CrossValidationPartitioner { /** * The TestSampleStream iterates over all test elements. * * @param */ private static class TestSampleStream implements ObjectStream { private ObjectStream sampleStream; private final int numberOfPartitions; private final int testIndex; private int index; private boolean isPoisened; private TestSampleStream(ObjectStream sampleStream, int numberOfPartitions, int testIndex) { this.numberOfPartitions = numberOfPartitions; this.sampleStream = sampleStream; this.testIndex = testIndex; } public E read() throws IOException { if (isPoisened) { throw new IllegalStateException(); } // skip training samples while (index % numberOfPartitions != testIndex) { sampleStream.read(); index++; } index++; return sampleStream.read(); } /** * Throws UnsupportedOperationException */ public void reset() { throw new UnsupportedOperationException(); } public void close() throws IOException { sampleStream.close(); isPoisened = true; } void poison() { isPoisened = true; } } /** * The TrainingSampleStream which iterates over * all training elements. * * Note: * After the TestSampleStream was obtained * the TrainingSampleStream must not be used * anymore, otherwise a {@link IllegalStateException} * is thrown. * * The ObjectStream>s must not be used anymore after the * CrossValidationPartitioner was moved * to one of next partitions. If they are called anyway * a {@link IllegalStateException} is thrown. * * @param */ public static class TrainingSampleStream implements ObjectStream { private ObjectStream sampleStream; private final int numberOfPartitions; private final int testIndex; private int index; private boolean isPoisened; private TestSampleStream testSampleStream; TrainingSampleStream(ObjectStream sampleStream, int numberOfPartitions, int testIndex) { this.numberOfPartitions = numberOfPartitions; this.sampleStream = sampleStream; this.testIndex = testIndex; } public E read() throws IOException { if (testSampleStream != null || isPoisened) { throw new IllegalStateException(); } // If the test element is reached skip over it to not include it in // the training data if (index % numberOfPartitions == testIndex) { sampleStream.read(); index++; } index++; return sampleStream.read(); } /** * Resets the training sample. Use this if you need to collect things before * training, for example, to collect induced abbreviations or create a POS * Dictionary. * * @throws IOException */ public void reset() throws IOException { if (testSampleStream != null || isPoisened) { throw new IllegalStateException(); } this.index = 0; this.sampleStream.reset(); } public void close() throws IOException { sampleStream.close(); poison(); } void poison() { isPoisened = true; if (testSampleStream != null) testSampleStream.poison(); } /** * Retrieves the ObjectStream over the test/evaluations * elements and poisons this TrainingSampleStream. * From now on calls to the hasNext and next methods are forbidden * and will raise anIllegalArgumentException. * * @return the test sample stream */ public ObjectStream getTestSampleStream() throws IOException { if (isPoisened) { throw new IllegalStateException(); } if (testSampleStream == null) { sampleStream.reset(); testSampleStream = new TestSampleStream(sampleStream, numberOfPartitions, testIndex); } return testSampleStream; } } /** * An ObjectStream over the whole set of data samples which * are used for the cross validation. */ private ObjectStream sampleStream; /** * The number of parts the data is divided into. */ private final int numberOfPartitions; /** * The index of test part. */ private int testIndex; /** * The last handed out TrainingIterator. The reference * is needed to poison the instance to fail fast if it is used * despite the fact that it is forbidden!. */ private TrainingSampleStream lastTrainingSampleStream; /** * Initializes the current instance. * * @param inElements * @param numberOfPartitions */ public CrossValidationPartitioner(ObjectStream inElements, int numberOfPartitions) { this.sampleStream = inElements; this.numberOfPartitions = numberOfPartitions; } /** * Initializes the current instance. * * @param elements * @param numberOfPartitions */ public CrossValidationPartitioner(Collection elements, int numberOfPartitions) { this(new CollectionObjectStream(elements), numberOfPartitions); } /** * Checks if there are more partitions available. */ public boolean hasNext() { return testIndex < numberOfPartitions; } /** * Retrieves the next training and test partitions. */ public TrainingSampleStream next() throws IOException { if (hasNext()) { if (lastTrainingSampleStream != null) lastTrainingSampleStream.poison(); sampleStream.reset(); TrainingSampleStream trainingSampleStream = new TrainingSampleStream(sampleStream, numberOfPartitions, testIndex); testIndex++; lastTrainingSampleStream = trainingSampleStream; return trainingSampleStream; } else { throw new NoSuchElementException(); } } @Override public String toString() { return "At partition" + Integer.toString(testIndex + 1) + " of " + Integer.toString(numberOfPartitions); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy