All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.tika.eval.tools.TrainTestSplit Maven / Gradle / Ivy

There is a newer version: 3.0.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tika.eval.tools;

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

import org.apache.commons.compress.compressors.gzip.GzipCompressorOutputStream;

public class TrainTestSplit {

    private static String TRAINING = "train";
    private static String TESTING = "test";
    private static String DEVTEST = "devtest";

    private float trainingP = 0.7f;
    private float devTestP = 0.1f;
    private float testP = 0.2f;
    private Random random = new Random();

    public static void main(String[] args) throws Exception {
        Path leipzigDir = Paths.get(args[0]);
        Path outputDir = Paths.get(args[1]);
        TrainTestSplit splitter = new TrainTestSplit();
        splitter.execute(leipzigDir, outputDir);
    }

    private void execute(Path leipzigDir, Path outputDir) throws Exception {
        initOutDirs(outputDir);
        for (File f : leipzigDir.toFile().listFiles()) {
            if (f.isDirectory()) {
                continue;
            }
            processFile(f, outputDir);
        }
    }

    private void initOutDirs(Path outputDir) throws Exception {
        for (String which : new String[] { TRAINING, DEVTEST, TESTING}) {
            Path target = outputDir.resolve(which);
            if (! Files.isDirectory(target)) {
                Files.createDirectories(target);
            }
        }

    }

    private void processFile(File f, Path outputDir) throws Exception {
        Map writers = getWriters(outputDir, f);
        System.err.println("working on "+f);
        try (BufferedReader reader = Files.newBufferedReader(f.toPath(), StandardCharsets.UTF_8)) {
            String line = reader.readLine();
            while (line != null) {
                float r = random.nextFloat();
                if (r <= trainingP) {
                    writers.get(TRAINING).write(line + "\n");
                } else if (r < trainingP + devTestP) {
                    writers.get(DEVTEST).write(line + "\n");
                } else {
                    writers.get(TESTING).write(line+"\n");
                }
                line = reader.readLine();
            }
        }


        for (Writer w : writers.values()) {
            w.flush();
            w.close();
        }
    }

    private Map getWriters(Path outputDir, File f) throws IOException {
        Map writers = new HashMap<>();
        for (String which : new String[] { TRAINING, DEVTEST, TESTING}) {
            writers.put(which, getWriter(outputDir, which, f));
        }
        return writers;
    }

    private BufferedWriter getWriter(Path outputDir, String which, File f) throws IOException {
        OutputStream os = new GzipCompressorOutputStream(
                new BufferedOutputStream(
                Files.newOutputStream(outputDir.resolve(which).resolve(f.getName()+".gz"))));
        return new BufferedWriter(new OutputStreamWriter(os, StandardCharsets.UTF_8));
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy