All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.dkpro.tc.ml.weka.task.WekaOutcomeHarmonizer Maven / Gradle / Ivy

/**
 * Copyright 2018
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see http://www.gnu.org/licenses/.
 */
package org.dkpro.tc.ml.weka.task;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;

import org.apache.commons.lang.StringUtils;
import org.dkpro.tc.core.Constants;

import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Add;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Aligns the outcome (class) attribute(s) of a Weka test set with those of the training set, so
 * that a model trained on the training data can be applied to the test data even when both were
 * built with different outcome label sets.
 * <p>
 * <b>Single-label</b> data gets a new nominal class attribute carrying the training labels. Test
 * instances whose label is unknown to the training data, or whose label is missing, get a missing
 * value in that new attribute. The old class attribute is removed.
 * <p>
 * <b>Multi-label</b> data gets one new {@code {0,1}} attribute per training label. Each is filled
 * from the test label attribute with the same name. Test labels absent from the training data are
 * dropped, and the old label attributes are removed.
 * <p>
 * <b>Regression</b> data is returned unchanged.
 * <p>
 * The fields are never reassigned; all filtered data is built in local variables. This relies on
 * Weka's {@code Filter.useFilter} producing new {@code Instances}, which is why
 * {@link #harmonize()} can be called more than once.
 */
public class WekaOutcomeHarmonizer
{
    private final Instances trainData;
    private final Instances testData;
    private final boolean multiLabel;
    private final boolean isRegression;

    /**
     * Suffix for class label names in the test data that have been adapted to match the training
     * data
     *
     */
    public static final String COMPATIBLE_OUTCOME_CLASS = "_Comp";

    /**
     * @param train
     *            training data; its outcome attribute(s) define the target label set
     * @param test
     *            test data whose outcome attribute(s) are aligned with {@code train}
     * @param learningMode
     *            learning mode; {@code Constants.LM_REGRESSION} disables harmonization and
     *            {@code Constants.LM_MULTI_LABEL} selects the multi-label path, anything else is
     *            treated as single-label
     */
    public WekaOutcomeHarmonizer(Instances train, Instances test, String learningMode)
    {
        this.trainData = train;
        this.testData = test;
        this.multiLabel = learningMode.equals(Constants.LM_MULTI_LABEL);
        this.isRegression = learningMode.equals(Constants.LM_REGRESSION);
    }

    /**
     * Builds a version of the test data whose outcome attribute(s) match the training data.
     *
     * @return the harmonized test data, or the unchanged test data in regression mode
     * @throws Exception
     *             if a Weka filter fails, if the single-label class attribute is not nominal, or
     *             if the multi-label relation name lacks a {@code -C} option
     */
    public Instances harmonize() throws Exception
    {
        if (isRegression) {
            // numeric outcomes have no label set that could be aligned
            return testData;
        }
        return multiLabel ? harmonizeMultiLabel() : harmonizeSingleLabel();
    }

    /**
     * Multi-label harmonization. Assumes the layout used throughout this class (presumably
     * MEKA-style): the label attributes are the first {@code classIndex()} attributes and the
     * relation name carries a {@code -C <numLabels>} option.
     */
    private Instances harmonizeMultiLabel() throws Exception
    {
        // in this layout classIndex() equals the number of leading label attributes
        int numTrainLabels = trainData.classIndex();
        int numTestLabels = testData.classIndex();

        List<String> trainLabels = getLabels(trainData);

        // append one {0,1} attribute per training label right after the existing test labels
        Instances extended = testData;
        for (int i = 0; i < numTrainLabels; i++) {
            Add add = new Add();
            // Add expects a 1-based index; 0-based position numTestLabels + i becomes +1
            add.setAttributeIndex(Integer.toString(numTestLabels + i + 1));
            add.setNominalLabels("0,1");
            add.setAttributeName(trainData.attribute(i).name() + COMPATIBLE_OUTCOME_CLASS);
            add.setInputFormat(extended);
            extended = Filter.useFilter(extended, add);
        }

        // fill NEW test data with values from old test data plus the new label attributes
        Instances compTestData = new Instances(extended, extended.numInstances());
        for (int i = 0; i < extended.numInstances(); i++) {
            weka.core.Instance instance = extended.instance(i);
            // default every compatible label to 0
            for (int j = 0; j < numTrainLabels; j++) {
                instance.setValue(numTestLabels + j, 0.);
            }
            // copy test labels that also exist in the training data; all others are dropped
            for (int j = 0; j < numTestLabels; j++) {
                int trainPos = trainLabels.indexOf(instance.attribute(j).name());
                if (trainPos != -1) {
                    instance.setValue(numTestLabels + trainPos, instance.value(j));
                }
            }
            compTestData.add(instance);
        }

        // remove the old label attributes, which are the first numTestLabels attributes
        for (int i = 0; i < numTestLabels; i++) {
            Remove remove = new Remove();
            remove.setAttributeIndices("1");
            remove.setInvertSelection(false);
            remove.setInputFormat(compTestData);
            compTestData = Filter.useFilter(compTestData, remove);
        }

        // rewrite the "-C <n>" part of the relation name to the new label count
        String relationTag = compTestData.relationName();
        int optionPos = relationTag.indexOf("-C");
        if (optionPos < 0) {
            // previously this silently produced a truncated, malformed relation name
            throw new IllegalStateException(
                    "Expected a '-C' option in the multi-label relation name, but got ["
                            + relationTag + "]");
        }
        compTestData.setRelationName(
                relationTag.substring(0, optionPos + 2) + " " + numTrainLabels + " ");
        compTestData.setClassIndex(numTrainLabels);

        return compTestData;
    }

    /**
     * Single-label harmonization: replaces the nominal class attribute of the test data with one
     * whose value set equals the training labels.
     */
    private Instances harmonizeSingleLabel() throws Exception
    {
        List<String> trainLabels = getNominalClassValues(trainData);
        List<String> testLabels = getNominalClassValues(testData);

        String compatibleName = Constants.CLASS_ATTRIBUTE_NAME + COMPATIBLE_OUTCOME_CLASS;

        // add new outcome class attribute (training label set) to a copy of the test data
        Add addFilter = new Add();
        addFilter.setNominalLabels(StringUtils.join(trainLabels, ','));
        addFilter.setAttributeName(compatibleName);
        addFilter.setInputFormat(testData);
        Instances extended = Filter.useFilter(testData, addFilter);

        int compatibleIndex = extended.attribute(compatibleName).index();
        // the old class attribute, i.e. the one testLabels was read from
        int oldClassIndex = extended.classIndex();

        // fill NEW test data with values from old test data plus the new class attribute
        Instances compTestData = new Instances(extended, extended.numInstances());
        for (int i = 0; i < extended.numInstances(); i++) {
            weka.core.Instance instance = extended.instance(i);
            // a missing value is NaN and (int) NaN == 0, so it must be checked before indexing
            String label = instance.isMissing(oldClassIndex) ? null
                    : testLabels.get((int) instance.value(oldClassIndex));
            if (label != null && trainLabels.contains(label)) {
                instance.setValue(compatibleIndex, label);
            }
            else {
                // label unknown to the training data (or absent): no compatible outcome
                instance.setMissing(compatibleIndex);
            }
            compTestData.add(instance);
        }

        // remove old class attribute (Remove expects a 1-based index)
        Remove remove = new Remove();
        remove.setAttributeIndices(Integer.toString(oldClassIndex + 1));
        remove.setInvertSelection(false);
        remove.setInputFormat(compTestData);
        compTestData = Filter.useFilter(compTestData, remove);

        // set new class attribute
        compTestData.setClass(compTestData.attribute(compatibleName));
        return compTestData;
    }

    /**
     * Returns the values of the nominal class attribute of {@code data}, in declaration order.
     *
     * @throws IllegalStateException
     *             if the class attribute is not nominal
     */
    private static List<String> getNominalClassValues(Instances data)
    {
        if (!data.classAttribute().isNominal()) {
            throw new IllegalStateException("Expected a nominal class attribute in relation ["
                    + data.relationName() + "], but found [" + data.classAttribute().name()
                    + "]");
        }
        int count = data.classAttribute().numValues();
        List<String> values = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            values.add(data.classAttribute().value(i));
        }
        return values;
    }

    /**
     * Returns the names of the leading label attributes, i.e. the first {@code classIndex()}
     * attributes of {@code data} (multi-label layout).
     */
    private static List<String> getLabels(Instances data)
    {
        int numLabels = data.classIndex();
        List<String> names = new ArrayList<>();
        for (int i = 0; i < numLabels; i++) {
            names.add(data.attribute(i).name());
        }
        return names;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy