All downloads are free. The search and download functionality uses the official Maven repository.

edu.uci.jforestsx.tools.FoldGenerator Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package edu.uci.jforestsx.tools;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Splits a line-oriented data file into cross-validation folds.
 *
 * <p>Every input line is treated as one instance. Instances are shuffled and
 * dealt round-robin into {@code folds} partitions. Fold {@code f} (0-based)
 * uses partition {@code f} as its validation set, optionally partition
 * {@code (f + 1) % folds} as its test set, and every remaining partition as
 * its training set.
 *
 * @author Yasser Ganjisaffar
 */

public class FoldGenerator {

	/** Number of processed lines between two progress messages on standard output. */
	private static final int PROGRESS_INTERVAL = 10000;

	/**
	 * Generates the fold files for {@code inputFilename}.
	 *
	 * <p>For each fold {@code k} in {@code 1..folds} the directory
	 * {@code foldsFolder/Fold<k>/} is created (if missing) and populated with
	 * {@code train.txt}, {@code valid.txt} and, when {@code addTestSet} is
	 * true, {@code test.txt}. Lines keep their relative input order inside
	 * each output file. The input file is read twice: once to count its lines
	 * and once to distribute them. Nothing is created when the input cannot
	 * be read during the first pass.
	 *
	 * @param inputFilename path of the input file, one instance per line
	 * @param folds         number of folds to generate; must be at least 1
	 * @param foldsFolder   directory under which the {@code Fold<k>} directories are created
	 * @param addTestSet    whether each fold also gets a {@code test.txt}
	 * @throws IllegalArgumentException if {@code folds} is smaller than 1
	 * @throws IllegalStateException    if the input file has more lines on the second pass than on the first
	 * @throws Exception                if the input cannot be read or an output file cannot be created or written
	 */
	public void generate(String inputFilename, int folds, String foldsFolder, boolean addTestSet) throws Exception {
		if (folds < 1) {
			throw new IllegalArgumentException("folds must be at least 1, but was " + folds);
		}

		// First pass: count the lines and give every line number a random partition.
		int[] partitionOfLine = assignPartitions(countLines(inputFilename), folds);

		PrintStream[] trainingFiles = new PrintStream[folds];
		PrintStream[] validationFiles = new PrintStream[folds];
		PrintStream[] testFiles = addTestSet ? new PrintStream[folds] : null;
		try {
			for (int f = 0; f < folds; f++) {
				String curFolder = foldsFolder + "/Fold" + (f + 1) + "/";
				new File(curFolder).mkdirs();
				trainingFiles[f] = new PrintStream(curFolder + "train.txt");
				validationFiles[f] = new PrintStream(curFolder + "valid.txt");
				if (addTestSet) {
					testFiles[f] = new PrintStream(curFolder + "test.txt");
				}
			}

			// Second pass: route every line to exactly one file per fold.
			BufferedReader reader = new BufferedReader(new FileReader(new File(inputFilename)));
			try {
				String line;
				int curId = 0;
				while ((line = reader.readLine()) != null) {
					if (curId >= partitionOfLine.length) {
						throw new IllegalStateException(inputFilename + " changed while folds were being generated");
					}
					int partition = partitionOfLine[curId];
					for (int f = 0; f < folds; f++) {
						PrintStream target;
						if (f == partition) {
							target = validationFiles[f];
						} else if (addTestSet && (f + 1) % folds == partition) {
							target = testFiles[f];
						} else {
							target = trainingFiles[f];
						}
						target.println(line);
					}
					curId++;
					if (curId % PROGRESS_INTERVAL == 0) {
						System.out.println("Dumped " + curId + " lines.");
					}
				}
			} finally {
				reader.close();
			}

			// PrintStream never throws on write failures, so ask explicitly.
			failOnWriteError(trainingFiles, "train.txt");
			failOnWriteError(validationFiles, "valid.txt");
			failOnWriteError(testFiles, "test.txt");
		} finally {
			// Close everything that was opened, including the test files, even on failure.
			closeAll(trainingFiles);
			closeAll(validationFiles);
			closeAll(testFiles);
		}
	}

	/** Counts the lines of {@code inputFilename}, reporting progress on standard output. */
	private static int countLines(String inputFilename) throws java.io.IOException {
		BufferedReader reader = new BufferedReader(new FileReader(new File(inputFilename)));
		try {
			int count = 0;
			while (reader.readLine() != null) {
				count++;
				if (count % PROGRESS_INTERVAL == 0) {
					System.out.println("Loaded " + count + " lines.");
				}
			}
			return count;
		} finally {
			reader.close();
		}
	}

	/** Returns an array mapping each line number to a partition in {@code [0, folds)}, chosen by shuffling. */
	private static int[] assignPartitions(int lineCount, int folds) {
		List<Integer> shuffledIds = new ArrayList<Integer>(lineCount);
		for (int id = 0; id < lineCount; id++) {
			shuffledIds.add(id);
		}
		Collections.shuffle(shuffledIds);

		// Dealing the shuffled ids round-robin keeps partition sizes within one of each other.
		int[] partitionOfLine = new int[lineCount];
		for (int position = 0; position < lineCount; position++) {
			partitionOfLine[shuffledIds.get(position)] = position % folds;
		}
		return partitionOfLine;
	}

	/** Throws if any stream in {@code streams} (may be null) recorded a write error. */
	private static void failOnWriteError(PrintStream[] streams, String fileName) throws java.io.IOException {
		if (streams == null) {
			return;
		}
		for (int f = 0; f < streams.length; f++) {
			if (streams[f].checkError()) {
				throw new java.io.IOException("Failed to write " + fileName + " of fold " + (f + 1));
			}
		}
	}

	/** Closes every non-null stream in {@code streams}; a null array is ignored. */
	private static void closeAll(PrintStream[] streams) {
		if (streams == null) {
			return;
		}
		for (int f = 0; f < streams.length; f++) {
			if (streams[f] != null) {
				streams[f].close();
			}
		}
	}

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy