All downloads are free. Search and download functionality uses the official Maven repository.

weka.classifiers.functions.LinearRegression Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA) is a machine learning workbench. This is the developer version, the "bleeding edge" of development, to which new functionality is added.

There is a newer version: 3.9.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    LinearRegression.java
 *    Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.functions;

import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;

import no.uib.cipr.matrix.*;
import no.uib.cipr.matrix.Matrix;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.evaluation.RegressionAnalysis;
import weka.core.*;
import weka.core.Capabilities.Capability;
import weka.core.matrix.QRDecomposition;
import weka.filters.Filter;
import weka.filters.supervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

/**
 
 * Class for using linear regression for prediction.
 * Uses the Akaike criterion for model selection, and is able to deal with
 * weighted instances.
 * 

 * <p>
 * Valid options are:
 * </p>
 *
 * <pre>
 * -S &lt;number of selection method&gt;
 *  Set the attribute selection method to use. 1 = None, 2 = Greedy.
 *  (default 0 = M5' method)
 * </pre>
 *
 * <pre>
 * -C
 *  Do not try to eliminate collinear attributes.
 * </pre>
 *
 * <pre>
 * -R &lt;double&gt;
 *  Set ridge parameter (default 1.0e-8).
 * </pre>
 *
 * <pre>
 * -minimal
 *  Conserve memory, don't keep dataset header and means/stdevs.
 *  Model cannot be printed out if this option is enabled. (default: keep data)
 * </pre>
 *
 * <pre>
 * -additional-stats
 *  Output additional statistics.
 * </pre>
 *
 * <pre>
 * -output-debug-info
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console
 * </pre>
 *
 * <pre>
 * -do-not-check-capabilities
 *  If set, classifier capabilities are not checked before classifier is built
 *  (use with caution).
 * </pre>
 *
*
 * @author Eibe Frank ([email protected])
 * @author Len Trigg ([email protected])
 * @version $Revision: 14872 $
 */
public class LinearRegression extends AbstractClassifier implements
  OptionHandler, WeightedInstancesHandler {

  /** Attribute selection method: M5 method (the default, -S 0). */
  public static final int SELECTION_M5 = 0;

  /** Attribute selection method: no attribute selection (-S 1). */
  public static final int SELECTION_NONE = 1;

  /** Attribute selection method: greedy method (-S 2). */
  public static final int SELECTION_GREEDY = 2;

  /** Tags naming the available attribute selection methods. */
  public static final Tag[] TAGS_SELECTION = {
    new Tag(SELECTION_NONE, "No attribute selection"),
    new Tag(SELECTION_M5, "M5 method"),
    new Tag(SELECTION_GREEDY, "Greedy method") };

  /** for serialization */
  static final long serialVersionUID = -3364580862046573747L;

  /**
   * Coefficients of the linear regression function. For a full model,
   * toString() reads one entry per selected attribute (in attribute order)
   * followed by the intercept as the last entry. A zero R model stores only
   * the class value of its single training instance.
   */
  protected double[] m_Coefficients;

  /**
   * Which attributes are used by the model, indexed by attribute index. The
   * class attribute and attributes with zero standard deviation start out
   * deselected before model selection runs.
   */
  protected boolean[] m_SelectedAttributes;

  /**
   * The filtered training data while the model is being built; afterwards an
   * empty copy that keeps only its header, or null if m_Minimal is set.
   */
  protected Instances m_TransformedData;

  /**
   * The filter for replacing missing values. Set to null when a model is
   * built with checks turned off.
   */
  protected ReplaceMissingValues m_MissingFilter;

  /**
   * The filter storing the transformation from nominal to binary attributes.
   * Set to null when a model is built with checks turned off.
   */
  protected NominalToBinary m_TransformFilter;

  /** The standard deviation of the class attribute in the training data. */
  protected double m_ClassStdDev;

  /** The mean of the class attribute in the training data. */
  protected double m_ClassMean;

  /** The index of the class attribute in the transformed training data. */
  protected int m_ClassIndex;

  /**
   * The attribute means (the class attribute's entry is left at 0); null if
   * m_Minimal is set.
   */
  protected double[] m_Means;

  /**
   * The attribute standard deviations (the class attribute's entry is left at
   * 0); null if m_Minimal is set.
   */
  protected double[] m_StdDevs;

  /**
   * Whether to output additional statistics such as std. dev. of coefficients
   * and t-stats. When set (and checks are not turned off), buildClassifier
   * rejects data whose instance weights are not all 1.
   */
  protected boolean m_outputAdditionalStats;

  /**
   * The current attribute selection method; expected to be one of
   * SELECTION_M5, SELECTION_NONE or SELECTION_GREEDY.
   */
  protected int m_AttributeSelection;

  /** Try to eliminate collinear attributes? (cf. option -C) */
  protected boolean m_EliminateColinearAttributes = true;

  /**
   * Turn off all checks and conversions? When set, buildClassifier skips the
   * capabilities test, the removal of instances with a missing class and the
   * nominal-to-binary / missing-value filtering, and classifyInstance uses
   * test instances exactly as given.
   */
  protected boolean m_checksTurnedOff = false;

  /** Use QR decomposition? */
  protected boolean m_useQRDecomposition = false;

  /** The ridge parameter (default 1.0e-8). */
  protected double m_Ridge = 1.0e-8;

  /**
   * Conserve memory? When set, the dataset header and the attribute means and
   * standard deviations are discarded after training, and toString() can
   * only report that a model was built.
   */
  protected boolean m_Minimal = false;

  /**
   * Model already built? Set at the end of buildClassifier.
   * NOTE(review): the single-instance (zero R) path of buildClassifier
   * returns without setting this flag.
   */
  protected boolean m_ModelBuilt = false;

  /** True if the model is a zero R one (trained on exactly one instance). */
  protected boolean m_isZeroR;

  /**
   * The degrees of freedom of the regression model: the number of training
   * instances minus the number of coefficients, intercept included. Only
   * computed when additional statistics are enabled.
   */
  private int m_df;

  /**
   * The R-squared value of the regression model (computed only with
   * additional statistics).
   */
  private double m_RSquared;

  /**
   * The adjusted R-squared value of the regression model (computed only with
   * additional statistics).
   */
  private double m_RSquaredAdj;

  /**
   * The F-statistic of the regression model (computed only with additional
   * statistics).
   */
  private double m_FStat;

  /**
   * The standard error of each coefficient (computed only with additional
   * statistics; null until then).
   */
  private double[] m_StdErrorOfCoef;

  /**
   * The t-statistic of each coefficient (computed only with additional
   * statistics; null until then).
   */
  private double[] m_TStats;

  /**
   * Creates a new, unbuilt LinearRegression that prints the numbers of its
   * model with 4 decimal places.
   */
  public LinearRegression() {
    m_numDecimalPlaces = 4;
  }

  /**
   * Generates a linear regression function predictor from the command line.
   *
   * @param argv the options
   */
  public static void main(String argv[]) {
    runClassifier(new LinearRegression(), argv);
  }

  /**
   * Returns a string describing this classifier.
   *
   * @return a description of the classifier suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String globalInfo() {
    return "Class for using linear regression for prediction. Uses the Akaike "
      + "criterion for model selection, and is able to deal with weighted "
      + "instances.";
  }

  /**
   * Returns default capabilities of the classifier.
 *
   * <p>Nominal, numeric and date attributes are supported, including missing
   * values. The class must be numeric or date and may have missing values.
   *
   * @return the capabilities of this classifier
   */
  @Override
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.NUMERIC_CLASS);
    result.enable(Capability.DATE_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    return result;
  }

  /**
   * Builds a regression model for the given data.
   *
   * <p>If the data contains exactly one instance, a "zero R" model is built
   * whose only coefficient is that instance's class value. Otherwise, unless
   * checks are turned off, the data is tested against
   * {@link #getCapabilities()}, copied, stripped of instances with a missing
   * class value, and passed through a NominalToBinary and a
   * ReplaceMissingValues filter (both kept for use by classifyInstance). The
   * means and standard deviations of the attributes and of the class are then
   * computed, and the regression itself is performed by findBestModel().
   *
   * @param data the training data to be used for generating the linear
   *          regression function
   * @throws Exception if the classifier could not be built successfully, or
   *           if additional statistics are requested for data whose instance
   *           weights are not all 1 (checked only when checks are on)
   */
  @Override
  public void buildClassifier(Instances data) throws Exception {
    m_ModelBuilt = false;
    m_isZeroR = false;

    // NOTE(review): this single-instance check runs before instances with a
    // missing class are removed below. A lone instance with a missing class
    // therefore becomes the stored prediction, while data that shrinks to one
    // (or zero) instances after that removal takes the full path instead.
    // m_ModelBuilt is also left false on this path, so toString() reports
    // that no model has been built.
    if (data.numInstances() == 1) {
      m_Coefficients = new double[1];
      m_Coefficients[0] = data.instance(0).classValue();
      m_SelectedAttributes = new boolean[data.numAttributes()];
      m_isZeroR = true;
      return;
    }

    if (!m_checksTurnedOff) {
      // can classifier handle the data?
      getCapabilities().testWithFail(data);

      if (m_outputAdditionalStats) {
        // check that the instances weights are all 1
        // because the RegressionAnalysis class does
        // not handle weights
        boolean ok = true;
        for (int i = 0; i < data.numInstances(); i++) {
          if (data.instance(i).weight() != 1) {
            ok = false;
            break;
          }
        }
        if (!ok) {
          throw new Exception(
            "Can only compute additional statistics on unweighted data");
        }
      }

      // remove instances with missing class
      // (deleting from a copy, not from the caller's dataset)
      data = new Instances(data);
      data.deleteWithMissingClass();

      // Fit the filters on the training data; classifyInstance reuses them.
      m_TransformFilter = new NominalToBinary();
      m_TransformFilter.setInputFormat(data);
      data = Filter.useFilter(data, m_TransformFilter);
      m_MissingFilter = new ReplaceMissingValues();
      m_MissingFilter.setInputFormat(data);
      data = Filter.useFilter(data, m_MissingFilter);
      data.deleteWithMissingClass();
    } else {
      m_TransformFilter = null;
      m_MissingFilter = null;
    }

    m_ClassIndex = data.classIndex();
    m_TransformedData = data;

    // Discard the coefficients of any previous model
    // (presumably set again by findBestModel()).
    m_Coefficients = null;

    // Compute means and standard deviations
    m_SelectedAttributes = new boolean[data.numAttributes()];
    m_Means = new double[data.numAttributes()];
    m_StdDevs = new double[data.numAttributes()];
    for (int j = 0; j < data.numAttributes(); j++) {
      if (j != m_ClassIndex) {
        m_SelectedAttributes[j] = true; // Turn attributes on for a start
        m_Means[j] = data.meanOrMode(j);
        m_StdDevs[j] = Math.sqrt(data.variance(j));
        // A constant attribute carries no information for the regression.
        if (m_StdDevs[j] == 0) {
          m_SelectedAttributes[j] = false;
        }
      }
    }

    m_ClassStdDev = Math.sqrt(data.variance(m_TransformedData.classIndex()));
    m_ClassMean = data.meanOrMode(m_TransformedData.classIndex());

    // Perform the regression
    findBestModel();

    // NOTE(review): when m_outputAdditionalStats is false the statistics
    // fields below are not reset and keep values from any earlier build.
    if (m_outputAdditionalStats) {
      // find number of coefficients, degrees of freedom
      // (k counts the intercept plus every selected attribute)
      int k = 1;
      for (int i = 0; i < data.numAttributes(); i++) {
        if (i != data.classIndex()) {
          if (m_SelectedAttributes[i]) {
            k++;
          }
        }
      }
      m_df = m_TransformedData.numInstances() - k;

      // calculate R^2 and F-stat
      double se = calculateSE(m_SelectedAttributes, m_Coefficients);
      m_RSquared = RegressionAnalysis.calculateRSquared(m_TransformedData, se);
      m_RSquaredAdj = RegressionAnalysis.calculateAdjRSquared(m_RSquared,
        m_TransformedData.numInstances(), k);
      m_FStat = RegressionAnalysis.calculateFStat(m_RSquared,
        m_TransformedData.numInstances(), k);

      // calculate std error of coefficients and t-stats
      m_StdErrorOfCoef = RegressionAnalysis.calculateStdErrorOfCoef(
        m_TransformedData, m_SelectedAttributes, se,
        m_TransformedData.numInstances(), k);
      m_TStats = RegressionAnalysis.calculateTStats(m_Coefficients,
        m_StdErrorOfCoef, k);
    }

    // Save memory
    if (m_Minimal) {
      m_TransformedData = null;
      m_Means = null;
      m_StdDevs = null;
    } else {
      // Keep only an empty copy with the header; toString() needs the
      // attribute names.
      m_TransformedData = new Instances(data, 0);
    }

    m_ModelBuilt = true;
  }

  /**
   * Classifies the given instance using the linear regression function.
   *
   * <p>Unless checks are turned off or the model is a zero R one, the
   * instance is first passed through the NominalToBinary and
   * ReplaceMissingValues filters that buildClassifier set up.
   *
   * @param instance the test instance
   * @return the classification
   * @throws Exception if classification can't be done successfully
   */
  @Override
  public double classifyInstance(Instance instance) throws Exception {
    // Transform the input instance
    Instance transformedInstance = instance;
    if (!m_checksTurnedOff && !m_isZeroR) {
      m_TransformFilter.input(transformedInstance);
      m_TransformFilter.batchFinished();
      transformedInstance = m_TransformFilter.output();
      m_MissingFilter.input(transformedInstance);
      m_MissingFilter.batchFinished();
      transformedInstance = m_MissingFilter.output();
    }

    // Calculate the dependent variable from the regression model
    return regressionPrediction(transformedInstance, m_SelectedAttributes,
      m_Coefficients);
  }

  /**
   * Outputs the linear regression model as a string.
* * @return the model as string */ @Override public String toString() { if (!m_ModelBuilt) { return "Linear Regression: No model built yet."; } if (m_Minimal) { return "Linear Regression: Model built."; } try { StringBuilder text = new StringBuilder(); int column = 0; boolean first = true; text.append("\nLinear Regression Model\n\n"); text.append(m_TransformedData.classAttribute().name() + " =\n\n"); for (int i = 0; i < m_TransformedData.numAttributes(); i++) { if ((i != m_ClassIndex) && (m_SelectedAttributes[i])) { if (!first) { text.append(" +\n"); } else { first = false; } text.append(Utils.doubleToString(m_Coefficients[column], 12, m_numDecimalPlaces) + " * "); text.append(m_TransformedData.attribute(i).name()); column++; } } text.append(" +\n" + Utils.doubleToString(m_Coefficients[column], 12, m_numDecimalPlaces)); if (m_outputAdditionalStats) { int maxAttLength = 0; for (int i = 0; i < m_TransformedData.numAttributes(); i++) { if ((i != m_ClassIndex) && (m_SelectedAttributes[i])) { if (m_TransformedData.attribute(i).name().length() > maxAttLength) { maxAttLength = m_TransformedData.attribute(i).name().length(); } } } maxAttLength += 3; if (maxAttLength < "Variable".length() + 3) { maxAttLength = "Variable".length() + 3; } text.append("\n\nRegression Analysis:\n\n" + Utils.padRight("Variable", maxAttLength) + " Coefficient SE of Coef t-Stat"); column = 0; for (int i = 0; i < m_TransformedData.numAttributes(); i++) { if ((i != m_ClassIndex) && (m_SelectedAttributes[i])) { text.append("\n" + Utils.padRight(m_TransformedData.attribute(i).name(), maxAttLength)); text.append(Utils.doubleToString(m_Coefficients[column], 12, m_numDecimalPlaces)); text.append(" " + Utils.doubleToString(m_StdErrorOfCoef[column], 12, m_numDecimalPlaces)); text.append(" " + Utils.doubleToString(m_TStats[column], 12, m_numDecimalPlaces)); column++; } } text.append(Utils.padRight("\nconst", maxAttLength + 1) + Utils .doubleToString(m_Coefficients[column], 12, m_numDecimalPlaces)); 
text.append(" " + Utils.doubleToString(m_StdErrorOfCoef[column], 12, m_numDecimalPlaces)); text.append(" " + Utils.doubleToString(m_TStats[column], 12, m_numDecimalPlaces)); text.append("\n\nDegrees of freedom = " + Integer.toString(m_df)); text.append("\nR^2 value = " + Utils.doubleToString(m_RSquared, m_numDecimalPlaces)); text.append("\nAdjusted R^2 = " + Utils.doubleToString(m_RSquaredAdj, 5)); text.append("\nF-statistic = " + Utils.doubleToString(m_FStat, m_numDecimalPlaces)); } return text.toString(); } catch (Exception e) { return "Can't print Linear Regression!"; } } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ @Override public Enumeration




© 2015 - 2024 Weber Informatics LLC | Privacy Policy