All downloads are free. Search and download functionality uses the official Maven repository.

org.dkpro.tc.ml.xgboost.XgboostTestTask Maven / Gradle / Ivy

There is a newer version: 1.1.0
Show newest version
/*******************************************************************************
 * Copyright 2018
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ******************************************************************************/
package org.dkpro.tc.ml.xgboost;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.compress.utils.IOUtils;
import org.apache.commons.io.FileUtils;
import org.dkpro.lab.engine.TaskContext;
import org.dkpro.lab.storage.StorageService.AccessMode;
import org.dkpro.tc.core.Constants;
import org.dkpro.tc.io.libsvm.LibsvmDataFormatTestTask;
import org.dkpro.tc.ml.xgboost.core.XgboostPredictor;
import org.dkpro.tc.ml.xgboost.core.XgboostTrainer;

import de.tudarmstadt.ukp.dkpro.core.api.resources.PlatformDetector;

/**
 * Test task that trains an Xgboost model on the training file of the task context and writes
 * the predictions for the test file, paired with the gold values, to the predictions file.
 */
public class XgboostTestTask
    extends LibsvmDataFormatTestTask
    implements Constants
{

    /**
     * Assembles the parameter list that is handed to the Xgboost trainer.
     * <p>
     * Every entry of {@code classificationArguments} except the first one is copied into the
     * result. The first entry is skipped on purpose; presumably it is not an Xgboost parameter
     * (e.g. the machine learning adapter) - confirm against the callers. For every learning
     * mode other than {@code LM_REGRESSION} an additional {@code num_class} entry is appended
     * whose value is the number of lines in the outcomes file.
     *
     * @param aContext
     *            task context used to locate the outcomes folder
     * @param classificationArguments
     *            user-provided arguments, may be {@code null}; all entries from index 1 onwards
     *            must be strings
     * @param learningMode
     *            the learning mode of the experiment, must not be {@code null}
     * @return the parameters for the trainer, never {@code null}
     * @throws IOException
     *             if the outcomes file cannot be read
     */
    public static List<String> getClassificationParameters(TaskContext aContext,
            List<?> classificationArguments, String learningMode)
        throws IOException
    {
        List<String> parameters = new ArrayList<>();
        if (classificationArguments != null) {
            // index 0 is deliberately not forwarded to Xgboost
            for (int i = 1; i < classificationArguments.size(); i++) {
                parameters.add((String) classificationArguments.get(i));
            }
        }

        if (!learningMode.equals(LM_REGRESSION)) {
            File folder = aContext.getFolder(OUTCOMES_INPUT_KEY, AccessMode.READONLY);
            File file = new File(folder, FILENAME_OUTCOMES);
            List<String> outcomes = FileUtils.readLines(file, "utf-8");
            parameters.add("num_class=" + outcomes.size() + "\n");
        }

        return parameters;
    }

    /**
     * Trains an Xgboost model on the training file and stores it in the task's folder.
     *
     * @param aContext
     *            task context providing the training file and the output folder
     * @return the model {@link File} the trainer was asked to write
     * @throws UnsupportedOperationException
     *             when running on 32bit Windows
     */
    @Override
    protected Object trainModel(TaskContext aContext) throws Exception
    {
        catchWindows32BitUsers();

        File fileTrain = getTrainFile(aContext);
        File model = new File(aContext.getFolder("", AccessMode.READWRITE),
                Constants.MODEL_CLASSIFIER);

        List<String> parameters = getClassificationParameters(aContext, classificationArguments,
                learningMode);

        XgboostTrainer trainer = new XgboostTrainer();
        trainer.train(fileTrain, model, parameters);

        return model;
    }

    /**
     * Fails fast on 32bit Windows.
     *
     * @throws UnsupportedOperationException
     *             if the detected platform is Windows with a 32bit x86 architecture
     */
    private void catchWindows32BitUsers()
    {
        PlatformDetector pd = new PlatformDetector();
        if (pd.getOs().equals(PlatformDetector.OS_WINDOWS)
                && pd.getArch().equals(PlatformDetector.ARCH_X86_32)) {
            throw new UnsupportedOperationException(
                    "Xgboost is not available for 32bit Windows operating systems. Please use a 64bit version.");
        }
    }

    /**
     * Runs the predictor on the test file and writes the predictions together with the gold
     * values.
     *
     * @param aContext
     *            task context providing the test file and the predictions file
     * @param model
     *            the model to predict with; must be a {@link File}
     */
    @Override
    protected void runPrediction(TaskContext aContext, Object model) throws Exception
    {
        File testFile = getTestFile(aContext);

        XgboostPredictor predictor = new XgboostPredictor();
        List<String> prediction = predictor.predict(testFile, (File) model);

        mergePredictionWithGold(aContext, prediction);
    }

    /**
     * Writes one {@code prediction;gold} line per gold value of the test file to the
     * predictions file, preceded by a {@code #PREDICTION;GOLD} header line.
     * <p>
     * The gold values are read and validated before the output file is opened, so a failed
     * validation does not leave an empty predictions file behind. Predictions beyond the number
     * of gold values are ignored.
     *
     * @param aContext
     *            task context providing the test file and the predictions file
     * @param prediction
     *            predictions in the order of the non-empty lines of the test file
     * @throws IllegalStateException
     *             if the test file contains no gold values or if there are fewer predictions
     *             than gold values
     */
    private void mergePredictionWithGold(TaskContext aContext, List<String> prediction)
        throws Exception
    {
        File fileTest = getTestFile(aContext);

        List<String> gold = readGoldValues(fileTest);
        checkNoDataCondition(gold, fileTest);

        // Report a mismatch up front instead of failing with an index error mid-write.
        if (prediction.size() < gold.size()) {
            throw new IllegalStateException("Received [" + prediction.size()
                    + "] predictions but the file [" + fileTest.getAbsolutePath()
                    + "] contains [" + gold.size() + "] gold values");
        }

        File predictionFile = aContext.getFile(FILENAME_PREDICTIONS, AccessMode.READWRITE);

        // try-with-resources: a failing flush on close must surface, otherwise a truncated
        // predictions file would go unnoticed.
        try (BufferedWriter bw = new BufferedWriter(
                new OutputStreamWriter(new FileOutputStream(predictionFile), "utf-8"))) {
            bw.write("#PREDICTION;GOLD" + "\n");
            for (int i = 0; i < gold.size(); i++) {
                bw.write(prediction.get(i) + ";" + gold.get(i));
                bw.write("\n");
            }
        }
    }

    /**
     * Ensures that at least one value was read from the given file.
     *
     * @param l
     *            the values read from {@code source}
     * @param source
     *            the file the values were read from, used for the error message only
     * @throws IllegalStateException
     *             if {@code l} is empty
     */
    private void checkNoDataCondition(List<String> l, File source)
    {
        if (l.isEmpty()) {
            throw new IllegalStateException(
                    "The file [" + source.getAbsolutePath() + "] contains no prediction results");
        }
    }

    /**
     * Reads the gold value of every non-empty line of the given file. The gold value is the
     * text in front of the first tab character of a line (the whole line if it has no tab).
     *
     * @param f
     *            the file to read, decoded as UTF-8
     * @return the gold values in file order, never {@code null}
     */
    private List<String> readGoldValues(File f) throws Exception
    {
        List<String> goldValues = new ArrayList<>();

        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(f), "utf-8"))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.isEmpty()) {
                    continue;
                }
                String[] split = line.split("\t");
                goldValues.add(split[0]);
            }
        }

        return goldValues;
    }

}