// org.datacleaner.components.machinelearning.MLRegressionTransformer (Maven / Gradle / Ivy)
/**
* DataCleaner (community edition)
* Copyright (C) 2014 Free Software Foundation, Inc.
*
* This copyrighted material is made available to anyone wishing to use, modify,
* copy, or redistribute it subject to the terms and conditions of the GNU
* Lesser General Public License, as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
* or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
* for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this distribution; if not, write to:
* Free Software Foundation, Inc.
* 51 Franklin Street, Fifth Floor
* Boston, MA 02110-1301 USA
*/
package org.datacleaner.components.machinelearning;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Locale;
import javax.inject.Named;
import org.apache.commons.lang.SerializationUtils;
import org.datacleaner.api.Categorized;
import org.datacleaner.api.Configured;
import org.datacleaner.api.Description;
import org.datacleaner.api.FileProperty;
import org.datacleaner.api.FileProperty.FileAccessMode;
import org.datacleaner.api.Initialize;
import org.datacleaner.api.InputColumn;
import org.datacleaner.api.InputRow;
import org.datacleaner.api.OutputColumns;
import org.datacleaner.api.Transformer;
import org.datacleaner.api.Validate;
import org.datacleaner.components.machinelearning.api.MLClassificationRecord;
import org.datacleaner.components.machinelearning.api.MLRegressor;
import org.datacleaner.components.machinelearning.impl.MLClassificationRecordImpl;
import com.google.common.io.Files;
@Named("Apply regression")
@Description("Applies a regression model to incoming records. Note that the regression model must first be trained using one of the analyzers found in the 'Machine Learning' category.")
@Categorized(MachineLearningCategory.class)
public class MLRegressionTransformer implements Transformer {
@Configured
InputColumn>[] features;
@Configured
@FileProperty(accessMode = FileAccessMode.OPEN, extension = ".model.ser")
File modelFile = new File("regression.model.ser");
private MLRegressor regressor;
@Validate
public void validate() throws IOException {
if (!modelFile.exists()) {
throw new IllegalArgumentException("Model file '" + modelFile + "' does not exist.");
}
regressor = (MLRegressor) SerializationUtils.deserialize(Files.toByteArray(modelFile));
MLComponentUtils.validateRegressorMapping(regressor, features);
}
@Initialize
public void init() {
try {
final byte[] bytes = Files.toByteArray(modelFile);
regressor = (MLRegressor) SerializationUtils.deserialize(bytes);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
@Override
public OutputColumns getOutputColumns() {
String modelName = modelFile.getName();
if (modelName.toLowerCase().endsWith(".model.ser")) {
modelName = modelName.substring(0, modelName.length() - ".model.ser".length());
}
final String[] columnNames = new String[] { modelName + " value" };
final Class>[] columnTypes = new Class[] { Double.class };
return new OutputColumns(columnNames, columnTypes);
}
@Override
public Object[] transform(InputRow inputRow) {
final MLClassificationRecord record = MLClassificationRecordImpl.forEvaluation(inputRow, features);
final double prediction = regressor.predict(record);
return new Object[] { prediction };
}
}