All Downloads are FREE. Search and download functionality uses the official Maven repository.

mulan.transformations.BinaryRelevanceTransformation Maven / Gradle / Ivy

Go to download

Mulan is an open-source Java library for learning from multi-label datasets.

The newest version!
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    BinaryRelevanceTransformation.java
 *    Copyright (C) 2009-2012 Aristotle University of Thessaloniki, Greece
 */
package mulan.transformations;

import java.io.Serializable;
import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.data.DataUtils;
import mulan.data.MultiLabelInstances;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Add;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Implements the binary relevance (BR) problem transformation for multi-label data.
 * <p>
 * For a chosen label, the instance-based API removes every label attribute of
 * the multi-label dataset given to the constructor and appends a single nominal
 * class attribute {@code BinaryRelevanceLabel} with values {@code 0,1} holding
 * that label's value. The static helpers work on plain {@link Instances} /
 * {@link Instance} objects given explicit label attribute indices and keep the
 * chosen label attribute itself instead of appending a new one.
 *
 * @author Grigorios Tsoumakas
 * @version 2012.05.30
 */
public class BinaryRelevanceTransformation implements Serializable {

    /** The multi-label dataset this transformation was built for. */
    private MultiLabelInstances data;
    /**
     * The dataset's instances with all label attributes removed and a trailing
     * binary class attribute appended; the class values are not filled in here.
     */
    private Instances shell;
    /** Filter that removes all label attributes from instances of {@link #data}. */
    private Remove remove;
    /** Filter that appends the trailing binary class attribute. */
    private Add add;

    /**
     * Builds the transformation for the given multi-label dataset.
     * <p>
     * Prepares {@link #shell} (non-label attributes followed by the nominal
     * attribute {@code BinaryRelevanceLabel} with values {@code 0,1}, set as the
     * class attribute) and keeps the configured filters so single instances can
     * later be pushed through the same pipeline.
     * <p>
     * If filter setup fails, the error is logged at {@code SEVERE} and the object
     * is left only partially initialised, so later transform calls will fail.
     *
     * @param data a multi-label dataset
     */
    public BinaryRelevanceTransformation(MultiLabelInstances data) {
        try {
            this.data = data;
            // Remove every label attribute (invertSelection=false: delete the listed
            // attributes). The index array is cloned in case getLabelIndices()
            // exposes the dataset's internal array.
            remove = new Remove();
            remove.setAttributeIndicesArray(data.getLabelIndices().clone());
            remove.setInvertSelection(false);
            remove.setInputFormat(data.getDataSet());
            shell = Filter.useFilter(data.getDataSet(), remove);

            // Append the binary target as the last attribute and make it the class.
            add = new Add();
            add.setAttributeIndex("last");
            add.setNominalLabels("0,1");
            add.setAttributeName("BinaryRelevanceLabel");
            add.setInputFormat(shell);
            shell = Filter.useFilter(shell, add);
            shell.setClassIndex(shell.numAttributes() - 1);
        } catch (Exception ex) {
            Logger.getLogger(BinaryRelevanceTransformation.class.getName()).log(Level.SEVERE,
                    "Could not initialise the binary relevance transformation", ex);
        }
    }

    /**
     * Transforms one multi-label instance into a binary instance for a single label.
     * <p>
     * All label attributes are removed, and the value of the selected label is
     * written into the trailing binary class attribute. The returned instance
     * has {@link #shell} as its dataset.
     * <p>
     * NOTE(review): this feeds the instance through the {@code remove} and
     * {@code add} filter objects held by this transformation (input followed by
     * output), so concurrent calls on the same object are presumably unsafe.
     *
     * @param instance an instance with the same header as the constructor's dataset (assumed)
     * @param labelToKeep position of the label to keep within {@code data.getLabelIndices()}
     *        (not an attribute index)
     * @return transformed Instance
     */
    public Instance transformInstance(Instance instance, int labelToKeep) {
        remove.input(instance);
        Instance transformedInstance = remove.output();
        add.input(transformedInstance);
        transformedInstance = add.output();
        transformedInstance.setDataset(shell);

        int labelAttribute = data.getLabelIndices()[labelToKeep];
        boolean reversed = isReversedLabelOrder(labelAttribute);
        transformedInstance.setValue(shell.numAttributes() - 1,
                toBinaryValue(instance.value(labelAttribute), reversed));
        return transformedInstance;
    }

    /**
     * Transforms the whole constructor dataset into a binary dataset for one label.
     * <p>
     * Returns a copy of {@link #shell} whose trailing class attribute is filled
     * with the value of the selected label of the corresponding instance of
     * {@code data.getDataSet()}.
     *
     * @param labelToKeep position of the label to keep within {@code data.getLabelIndices()}
     *        (not an attribute index)
     * @return transformed Instances object
     * @throws Exception declared for compatibility; not thrown by this implementation itself
     */
    public Instances transformInstances(int labelToKeep) throws Exception {
        Instances shellCopy = new Instances(this.shell);
        Instances source = data.getDataSet();
        int labelAttribute = data.getLabelIndices()[labelToKeep];
        boolean reversed = isReversedLabelOrder(labelAttribute);

        int classAttribute = shellCopy.numAttributes() - 1;
        for (int j = 0; j < shellCopy.numInstances(); j++) {
            double rawValue = source.instance(j).value(labelAttribute);
            shellCopy.instance(j).setValue(classAttribute, toBinaryValue(rawValue, reversed));
        }
        return shellCopy;
    }

    /**
     * Removes all label attributes of {@code train} except the one at attribute
     * index {@code indexToKeep}, which becomes the class attribute of the result.
     * <p>
     * Side effect: the class index of {@code train} itself is set to
     * {@code indexToKeep}. This happens only after the arguments have been
     * validated.
     *
     * @param train the dataset to transform
     * @param labelIndices attribute indices of all labels in {@code train}
     * @param indexToKeep attribute index of the label to keep; must occur exactly
     *        once in {@code labelIndices}
     * @return transformed Instances object
     * @throws IllegalArgumentException if {@code labelIndices} is empty or does
     *         not contain {@code indexToKeep} exactly once
     * @throws Exception if the Weka filter fails
     */
    public static Instances transformInstances(Instances train, int[] labelIndices, int indexToKeep) throws Exception {
        int[] indicesToRemove = otherLabelIndices(labelIndices, indexToKeep);

        // Must precede setInputFormat(): Remove derives the output class index from it.
        train.setClassIndex(indexToKeep);

        Remove remove = new Remove();
        remove.setAttributeIndicesArray(indicesToRemove);
        // invertSelection=false: delete the listed attributes and keep everything else.
        // Options are set before setInputFormat(), which builds the output format.
        remove.setInvertSelection(false);
        remove.setInputFormat(train);
        return Filter.useFilter(train, remove);
    }

    /**
     * Removes all label attributes except the one at attribute index
     * {@code indexToKeep} from a single instance.
     * <p>
     * The result is built by {@code DataUtils.createInstance} with weight 1 from
     * the remaining attribute values, in their original order.
     *
     * @param instance the instance to transform
     * @param labelIndices attribute indices of all labels in {@code instance}
     * @param indexToKeep attribute index of the label to keep (an element of {@code labelIndices})
     * @return transformed Instance
     */
    public static Instance transformInstance(Instance instance, int[] labelIndices, int indexToKeep) {
        double[] values = instance.toDoubleArray();

        // Mark every label attribute except the kept one for removal. Out-of-range
        // label indices are ignored, as they can never match an attribute position.
        boolean[] drop = new boolean[values.length];
        for (int labelIndex : labelIndices) {
            if (labelIndex != indexToKeep && labelIndex >= 0 && labelIndex < values.length) {
                drop[labelIndex] = true;
            }
        }

        double[] transformedValues = new double[values.length - labelIndices.length + 1];
        int counterTransformed = 0;
        for (int i = 0; i < values.length; i++) {
            if (!drop[i]) {
                transformedValues[counterTransformed++] = values[i];
            }
        }

        return DataUtils.createInstance(instance, 1, transformedValues);
    }

    /**
     * Tells whether a label attribute declares its nominal values in the order
     * {@code 1,0}. In that case the value index must be flipped to match the
     * {@code 0,1} order of the binary class attribute.
     *
     * @param labelAttribute attribute index of the label in {@code data.getDataSet()}
     * @return {@code true} if the attribute's first declared value is {@code "1"}
     */
    private boolean isReversedLabelOrder(int labelAttribute) {
        return data.getDataSet().attribute(labelAttribute).value(0).equals("1");
    }

    /**
     * Maps a raw label value index onto the {@code 0,1} encoding of the binary class.
     *
     * @param rawValue the label's internal value in the source instance
     * @param reversed whether the source attribute declares its values as {@code 1,0}
     * @return the value to store in the binary class attribute
     */
    private static double toBinaryValue(double rawValue, boolean reversed) {
        return reversed ? 1 - rawValue : rawValue;
    }

    /**
     * Returns the elements of {@code labelIndices} other than {@code indexToKeep},
     * preserving their order.
     *
     * @param labelIndices attribute indices of all labels
     * @param indexToKeep attribute index that must occur exactly once in {@code labelIndices}
     * @return the label indices to remove
     * @throws IllegalArgumentException if {@code labelIndices} is empty or does
     *         not contain {@code indexToKeep} exactly once
     */
    private static int[] otherLabelIndices(int[] labelIndices, int indexToKeep) {
        if (labelIndices.length == 0) {
            throw new IllegalArgumentException("labelIndices must not be empty");
        }
        int[] others = new int[labelIndices.length - 1];
        int count = 0;
        for (int labelIndex : labelIndices) {
            if (labelIndex == indexToKeep) {
                continue;
            }
            if (count == others.length) {
                throw new IllegalArgumentException(
                        "indexToKeep " + indexToKeep + " is not one of the label indices");
            }
            others[count++] = labelIndex;
        }
        if (count != others.length) {
            throw new IllegalArgumentException(
                    "indexToKeep " + indexToKeep + " occurs more than once in the label indices");
        }
        return others;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy