weka.classifiers.functions.LinearRegression Maven / Gradle / Ivy
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This version represents the developer version, the
"bleeding edge" of development, you could say. New functionality gets added
to this version.
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* LinearRegression.java
* Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.functions;
import java.util.Collections;
import java.util.Enumeration;
import no.uib.cipr.matrix.*;
import no.uib.cipr.matrix.Matrix;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.evaluation.RegressionAnalysis;
import weka.core.*;
import weka.core.Capabilities.Capability;
import weka.filters.Filter;
import weka.filters.supervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
/**
* Class for using linear regression for prediction. Uses the Akaike
* criterion for model selection, and is able to deal with weighted
* instances.
*
* Valid options are:
*
* -S <number of selection method>
*  Set the attribute selection method to use. 1 = None, 2 = Greedy.
*  (default 0 = M5' method)
*
* -C
*  Do not try to eliminate colinear attributes.
*
* -R <double>
*  Set ridge parameter (default 1.0e-8).
*
* -minimal
*  Conserve memory, don't keep dataset header and means/stdevs.
*  Model cannot be printed out if this option is enabled. (default: keep data)
*
* -additional-stats
*  Output additional statistics.
*
* -output-debug-info
*  If set, classifier is run in debug mode and
*  may output additional info to the console.
*
* -do-not-check-capabilities
*  If set, classifier capabilities are not checked before classifier is built
*  (use with caution).
*
* @author Eibe Frank ([email protected])
* @author Len Trigg ([email protected])
* @version $Revision: 13013 $
*/
public class LinearRegression extends AbstractClassifier implements
OptionHandler, WeightedInstancesHandler {
/** Attribute selection method: M5 method */
public static final int SELECTION_M5 = 0;
/** Attribute selection method: No attribute selection */
public static final int SELECTION_NONE = 1;
/** Attribute selection method: Greedy method */
public static final int SELECTION_GREEDY = 2;
/** Attribute selection methods */
public static final Tag[] TAGS_SELECTION = {
new Tag(SELECTION_NONE, "No attribute selection"),
new Tag(SELECTION_M5, "M5 method"),
new Tag(SELECTION_GREEDY, "Greedy method") };
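// Usage sketch: the attribute selection method is chosen by wrapping one of the
// SELECTION_* constants above in a weka.core.SelectedTag. The
// setAttributeSelectionMethod setter assumed here is not shown in this
// truncated listing.
//
//   LinearRegression lr = new LinearRegression();
//   lr.setAttributeSelectionMethod(
//     new SelectedTag(LinearRegression.SELECTION_GREEDY, LinearRegression.TAGS_SELECTION));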
/** for serialization */
static final long serialVersionUID = -3364580862046573747L;
/** Array for storing coefficients of linear regression. */
protected double[] m_Coefficients;
/** Which attributes are relevant? */
protected boolean[] m_SelectedAttributes;
/** Variable for storing transformed training data. */
protected Instances m_TransformedData;
/** The filter for removing missing values. */
protected ReplaceMissingValues m_MissingFilter;
/**
* The filter storing the transformation from nominal to binary attributes.
*/
protected NominalToBinary m_TransformFilter;
/** The standard deviation of the class attribute */
protected double m_ClassStdDev;
/** The mean of the class attribute */
protected double m_ClassMean;
/** The index of the class attribute */
protected int m_ClassIndex;
/** The attribute means */
protected double[] m_Means;
/** The attribute standard deviations */
protected double[] m_StdDevs;
/**
* Whether to output additional statistics such as std. dev. of coefficients
* and t-stats
*/
protected boolean m_outputAdditionalStats;
/** The current attribute selection method */
protected int m_AttributeSelection;
/** Try to eliminate correlated attributes? */
protected boolean m_EliminateColinearAttributes = true;
/** Turn off all checks and conversions? */
protected boolean m_checksTurnedOff = false;
/** The ridge parameter */
protected double m_Ridge = 1.0e-8;
/** Conserve memory? */
protected boolean m_Minimal = false;
/** Model already built? */
protected boolean m_ModelBuilt = false;
/** True if the model is a ZeroR-style one (a constant prediction) */
protected boolean m_isZeroR;
/** The degrees of freedom of the regression model */
private int m_df;
/** The R-squared value of the regression model */
private double m_RSquared;
/** The adjusted R-squared value of the regression model */
private double m_RSquaredAdj;
/** The F-statistic of the regression model */
private double m_FStat;
/** Array for storing the standard error of each coefficient */
private double[] m_StdErrorOfCoef;
/** Array for storing the t-statistic of each coefficient */
private double[] m_TStats;
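/**
* Default constructor. Sets the number of decimal places used when printing
* coefficients to 4.
*/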
public LinearRegression() {
m_numDecimalPlaces = 4;
}
/**
* Main method for running this classifier from the command line.
*
* @param argv the options
*/
public static void main(String argv[]) {
runClassifier(new LinearRegression(), argv);
}
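// Command-line sketch (the dataset file name is hypothetical): with weka.jar on
// the classpath, the options documented in the class Javadoc can be passed
// directly to this class, e.g.
//
//   java weka.classifiers.functions.LinearRegression -t housing.arff -S 1 -R 1.0e-6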
/**
* Returns a string describing this classifier
*
* @return a description of the classifier suitable for displaying in the
* explorer/experimenter gui
*/
public String globalInfo() {
return "Class for using linear regression for prediction. Uses the Akaike "
+ "criterion for model selection, and is able to deal with weighted "
+ "instances.";
}
/**
* Returns default capabilities of the classifier.
*
* @return the capabilities of this classifier
*/
@Override
public Capabilities getCapabilities() {
Capabilities result = super.getCapabilities();
result.disableAll();
// attributes
result.enable(Capability.NOMINAL_ATTRIBUTES);
result.enable(Capability.NUMERIC_ATTRIBUTES);
result.enable(Capability.DATE_ATTRIBUTES);
result.enable(Capability.MISSING_VALUES);
// class
result.enable(Capability.NUMERIC_CLASS);
result.enable(Capability.DATE_CLASS);
result.enable(Capability.MISSING_CLASS_VALUES);
return result;
}
/**
* Builds a regression model for the given data.
*
* @param data the training data to be used for generating the linear
* regression function
* @throws Exception if the classifier could not be built successfully
*/
@Override
public void buildClassifier(Instances data) throws Exception {
m_ModelBuilt = false;
m_isZeroR = false;
if (data.numInstances() == 1) {
m_Coefficients = new double[1];
m_Coefficients[0] = data.instance(0).classValue();
m_SelectedAttributes = new boolean[data.numAttributes()];
m_isZeroR = true;
return;
}
if (!m_checksTurnedOff) {
// can classifier handle the data?
getCapabilities().testWithFail(data);
if (m_outputAdditionalStats) {
// check that the instance weights are all 1
// because the RegressionAnalysis class does
// not handle weights
boolean ok = true;
for (int i = 0; i < data.numInstances(); i++) {
if (data.instance(i).weight() != 1) {
ok = false;
break;
}
}
if (!ok) {
throw new Exception(
"Can only compute additional statistics on unweighted data");
}
}
// remove instances with missing class
data = new Instances(data);
data.deleteWithMissingClass();
m_TransformFilter = new NominalToBinary();
m_TransformFilter.setInputFormat(data);
data = Filter.useFilter(data, m_TransformFilter);
m_MissingFilter = new ReplaceMissingValues();
m_MissingFilter.setInputFormat(data);
data = Filter.useFilter(data, m_MissingFilter);
data.deleteWithMissingClass();
} else {
m_TransformFilter = null;
m_MissingFilter = null;
}
m_ClassIndex = data.classIndex();
m_TransformedData = data;
// Turn all attributes on for a start
m_Coefficients = null;
// Compute means and standard deviations
m_SelectedAttributes = new boolean[data.numAttributes()];
m_Means = new double[data.numAttributes()];
m_StdDevs = new double[data.numAttributes()];
for (int j = 0; j < data.numAttributes(); j++) {
if (j != m_ClassIndex) {
m_SelectedAttributes[j] = true; // Turn attributes on for a start
m_Means[j] = data.meanOrMode(j);
m_StdDevs[j] = Math.sqrt(data.variance(j));
if (m_StdDevs[j] == 0) {
m_SelectedAttributes[j] = false;
}
}
}
m_ClassStdDev = Math.sqrt(data.variance(m_TransformedData.classIndex()));
m_ClassMean = data.meanOrMode(m_TransformedData.classIndex());
// Perform the regression
findBestModel();
if (m_outputAdditionalStats) {
// find number of coefficients, degrees of freedom
int k = 1;
for (int i = 0; i < data.numAttributes(); i++) {
if (i != data.classIndex()) {
if (m_SelectedAttributes[i]) {
k++;
}
}
}
m_df = m_TransformedData.numInstances() - k;
// calculate R^2 and F-stat
double se = calculateSE(m_SelectedAttributes, m_Coefficients);
m_RSquared = RegressionAnalysis.calculateRSquared(m_TransformedData, se);
m_RSquaredAdj =
RegressionAnalysis.calculateAdjRSquared(m_RSquared,
m_TransformedData.numInstances(), k);
m_FStat =
RegressionAnalysis.calculateFStat(m_RSquared,
m_TransformedData.numInstances(), k);
// calculate std error of coefficients and t-stats
m_StdErrorOfCoef =
RegressionAnalysis.calculateStdErrorOfCoef(m_TransformedData,
m_SelectedAttributes, se, m_TransformedData.numInstances(), k);
m_TStats =
RegressionAnalysis.calculateTStats(m_Coefficients, m_StdErrorOfCoef, k);
}
// Save memory
if (m_Minimal) {
m_TransformedData = null;
m_Means = null;
m_StdDevs = null;
} else {
m_TransformedData = new Instances(data, 0);
}
m_ModelBuilt = true;
}
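// Programmatic training sketch (the ARFF file name is hypothetical): DataSource
// is weka.core.converters.ConverterUtils.DataSource; the class attribute must be
// set before calling buildClassifier.
//
//   Instances data =
//     new weka.core.converters.ConverterUtils.DataSource("housing.arff").getDataSet();
//   data.setClassIndex(data.numAttributes() - 1);
//   LinearRegression lr = new LinearRegression();
//   lr.buildClassifier(data);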
/**
* Classifies the given instance using the linear regression function.
*
* @param instance the test instance
* @return the classification
* @throws Exception if classification can't be done successfully
*/
@Override
public double classifyInstance(Instance instance) throws Exception {
// Transform the input instance
Instance transformedInstance = instance;
if (!m_checksTurnedOff && !m_isZeroR) {
m_TransformFilter.input(transformedInstance);
m_TransformFilter.batchFinished();
transformedInstance = m_TransformFilter.output();
m_MissingFilter.input(transformedInstance);
m_MissingFilter.batchFinished();
transformedInstance = m_MissingFilter.output();
}
// Calculate the dependent variable from the regression model
return regressionPrediction(transformedInstance, m_SelectedAttributes,
m_Coefficients);
}
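// Prediction sketch: once built, the model returns the predicted numeric class
// value for a test instance with the same structure as the training data, e.g.
//
//   double predicted = lr.classifyInstance(test.instance(0));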
/**
* Outputs the linear regression model as a string.
*
* @return the model as string
*/
@Override
public String toString() {
if (!m_ModelBuilt) {
return "Linear Regression: No model built yet.";
}
if (m_Minimal) {
return "Linear Regression: Model built.";
}
try {
StringBuilder text = new StringBuilder();
int column = 0;
boolean first = true;
text.append("\nLinear Regression Model\n\n");
text.append(m_TransformedData.classAttribute().name() + " =\n\n");
for (int i = 0; i < m_TransformedData.numAttributes(); i++) {
if ((i != m_ClassIndex) && (m_SelectedAttributes[i])) {
if (!first) {
text.append(" +\n");
} else {
first = false;
}
text.append(Utils.doubleToString(m_Coefficients[column], 12,
m_numDecimalPlaces) + " * ");
text.append(m_TransformedData.attribute(i).name());
column++;
}
}
text.append(" +\n"
+ Utils.doubleToString(m_Coefficients[column], 12, m_numDecimalPlaces));
if (m_outputAdditionalStats) {
int maxAttLength = 0;
for (int i = 0; i < m_TransformedData.numAttributes(); i++) {
if ((i != m_ClassIndex) && (m_SelectedAttributes[i])) {
if (m_TransformedData.attribute(i).name().length() > maxAttLength) {
maxAttLength = m_TransformedData.attribute(i).name().length();
}
}
}
maxAttLength += 3;
if (maxAttLength < "Variable".length() + 3) {
maxAttLength = "Variable".length() + 3;
}
text.append("\n\nRegression Analysis:\n\n"
+ Utils.padRight("Variable", maxAttLength)
+ " Coefficient SE of Coef t-Stat");
column = 0;
for (int i = 0; i < m_TransformedData.numAttributes(); i++) {
if ((i != m_ClassIndex) && (m_SelectedAttributes[i])) {
text.append("\n"
+ Utils.padRight(m_TransformedData.attribute(i).name(),
maxAttLength));
text.append(Utils.doubleToString(m_Coefficients[column], 12,
m_numDecimalPlaces));
text.append(" "
+ Utils.doubleToString(m_StdErrorOfCoef[column], 12,
m_numDecimalPlaces));
text.append(" "
+ Utils.doubleToString(m_TStats[column], 12, m_numDecimalPlaces));
column++;
}
}
text.append(Utils.padRight("\nconst", maxAttLength + 1)
+ Utils
.doubleToString(m_Coefficients[column], 12, m_numDecimalPlaces));
text.append(" "
+ Utils.doubleToString(m_StdErrorOfCoef[column], 12,
m_numDecimalPlaces));
text.append(" "
+ Utils.doubleToString(m_TStats[column], 12, m_numDecimalPlaces));
text.append("\n\nDegrees of freedom = " + Integer.toString(m_df));
text.append("\nR^2 value = "
+ Utils.doubleToString(m_RSquared, m_numDecimalPlaces));
text.append("\nAdjusted R^2 = "
+ Utils.doubleToString(m_RSquaredAdj, 5));
text.append("\nF-statistic = "
+ Utils.doubleToString(m_FStat, m_numDecimalPlaces));
}
return text.toString();
} catch (Exception e) {
return "Can't print Linear Regression!";
}
}
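// Printing sketch: the fitted function can be inspected by printing the
// classifier (which calls toString()); with the -minimal option enabled only
// "Linear Regression: Model built." is reported.
//
//   System.out.println(lr);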
/**
* Returns an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
@Override
public Enumeration<Option> listOptions() {