All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.datacleaner.components.machinelearning.MLClassificationTransformer Maven / Gradle / Ivy

There is a newer version: 5.8.1
Show newest version
/**
 * DataCleaner (community edition)
 * Copyright (C) 2014 Free Software Foundation, Inc.
 *
 * This copyrighted material is made available to anyone wishing to use, modify,
 * copy, or redistribute it subject to the terms and conditions of the GNU
 * Lesser General Public License, as published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License
 * for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this distribution; if not, write to:
 * Free Software Foundation, Inc.
 * 51 Franklin Street, Fifth Floor
 * Boston, MA  02110-1301  USA
 */
package org.datacleaner.components.machinelearning;

import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;

import javax.inject.Named;

import org.apache.commons.lang.SerializationUtils;
import org.datacleaner.api.Categorized;
import org.datacleaner.api.Configured;
import org.datacleaner.api.Description;
import org.datacleaner.api.FileProperty;
import org.datacleaner.api.FileProperty.FileAccessMode;
import org.datacleaner.api.Initialize;
import org.datacleaner.api.InputColumn;
import org.datacleaner.api.InputRow;
import org.datacleaner.api.OutputColumns;
import org.datacleaner.api.Transformer;
import org.datacleaner.api.Validate;
import org.datacleaner.components.machinelearning.api.MLClassification;
import org.datacleaner.components.machinelearning.api.MLClassificationRecord;
import org.datacleaner.components.machinelearning.api.MLClassifier;
import org.datacleaner.components.machinelearning.impl.MLClassificationRecordImpl;

import com.google.common.io.Files;

@Named("Apply classifier")
@Description("Applies a classifier to incoming records. Note that the classifier must first be trained using one of the analyzers found in the 'Machine Learning' category.")
@Categorized(MachineLearningCategory.class)
public class MLClassificationTransformer implements Transformer {

    public enum OutputFormat {
        WINNER_CLASS_AND_CONFIDENCE,
        CONFIDENCE_MATRIX
    }

    @Configured
    InputColumn[] features;

    @Configured
    @FileProperty(accessMode = FileAccessMode.OPEN, extension = ".model.ser")
    File modelFile = new File("classifier.model.ser");

    @Configured
    OutputFormat outputFormat = OutputFormat.WINNER_CLASS_AND_CONFIDENCE;

    private MLClassifier classifier;

    @Validate
    public void validate() throws IOException {
        if (!modelFile.exists()) {
            throw new IllegalArgumentException("Model file '" + modelFile + "' does not exist.");
        }
        classifier = (MLClassifier) SerializationUtils.deserialize(Files.toByteArray(modelFile));

        MLComponentUtils.validateClassifierMapping(classifier, features);
    }

    @Initialize
    public void init() {
        try {
            final byte[] bytes = Files.toByteArray(modelFile);
            classifier = (MLClassifier) SerializationUtils.deserialize(bytes);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    @Override
    public OutputColumns getOutputColumns() {
        if (classifier == null) {
            init();
        }

        if (outputFormat == OutputFormat.WINNER_CLASS_AND_CONFIDENCE) {
            String modelName = modelFile.getName();
            if (modelName.toLowerCase().endsWith(".model.ser")) {
                modelName = modelName.substring(0, modelName.length() - ".model.ser".length());
            }
            final String[] columnNames = new String[] { modelName + " class", modelName + " confidence" };
            final Class[] columnTypes =
                    new Class[] { classifier.getMetadata().getClassificationType(), Double.class };
            return new OutputColumns(columnNames, columnTypes);
        } else {
            final List classifications = classifier.getMetadata().getClassifications();
            final String[] columnNames = new String[classifications.size()];
            for (int i = 0; i < columnNames.length; i++) {
                columnNames[i] = classifications.toString() + " confidence";
            }
            return new OutputColumns(Double.class, columnNames);
        }
    }

    @Override
    public Object[] transform(InputRow inputRow) {
        final MLClassificationRecord record = MLClassificationRecordImpl.forEvaluation(inputRow, features);
        final MLClassification classification = classifier.classify(record);
        final int bestClassificationIndex = classification.getBestClassificationIndex();
        final double confidence = classification.getConfidence(bestClassificationIndex);
        final Object classificationValue = classifier.getMetadata().getClassification(bestClassificationIndex);
        return new Object[] { classificationValue, confidence };
    }
}