Please wait. This can take a few minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it as often as you want.
weka.classifiers.functions.SimpleLinearRegression Maven / Gradle / Ivy
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* SimpleLinearRegression.java
* Copyright (C) 2002-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.functions;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.evaluation.RegressionAnalysis;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
/**
* Learns a simple linear regression model. Picks the
* attribute that results in the lowest squared error. Can only deal with
* numeric attributes.
*
*
* Valid options are:
*
*
*
* -additional-stats
* Output additional statistics.
*
*
*
* -output-debug-info
* If set, classifier is run in debug mode and
* may output additional info to the console
*
*
*
* -do-not-check-capabilities
* If set, classifier capabilities are not checked before classifier is built
* (use with caution).
*
*
*
* @author Eibe Frank ([email protected] )
* @version $Revision: 11130 $
*/
public class SimpleLinearRegression extends AbstractClassifier implements
  WeightedInstancesHandler {

  /** for serialization */
  static final long serialVersionUID = 1679336022895414137L;

  /** The chosen attribute */
  private Attribute m_attribute;

  /** The index of the chosen attribute */
  private int m_attributeIndex;

  /** The slope */
  private double m_slope;

  /** The intercept */
  private double m_intercept;

  /** The class mean for missing values */
  private double m_classMeanForMissing;

  /**
   * Whether to output additional statistics such as std. dev. of coefficients
   * and t-stats
   */
  protected boolean m_outputAdditionalStats;

  /** Degrees of freedom, used in statistical calculations */
  private int m_df;

  /** standard error of the slope */
  private double m_seSlope = Double.NaN;

  /** standard error of the intercept */
  private double m_seIntercept = Double.NaN;

  /** t-statistic of the slope */
  private double m_tstatSlope = Double.NaN;

  /** t-statistic of the intercept */
  private double m_tstatIntercept = Double.NaN;

  /** R^2 value for the regression */
  private double m_rsquared = Double.NaN;

  /** Adjusted R^2 value for the regression */
  private double m_rsquaredAdj = Double.NaN;

  /** F-statistic for the regression */
  private double m_fstat = Double.NaN;

  /** If true, suppress error message if no useful attribute was found */
  private boolean m_suppressErrorMessage = false;

  /**
   * Returns a string describing this classifier
   *
   * @return a description of the classifier suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String globalInfo() {
    return "Learns a simple linear regression model. "
      + "Picks the attribute that results in the lowest squared error. "
      + "Can only deal with numeric attributes.";
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  @Override
  public Enumeration<Option> listOptions() {
    // Parameterized Vector instead of the raw type: avoids unchecked
    // warnings and matches the OptionHandler contract.
    Vector<Option> newVector = new Vector<Option>();
    newVector.addElement(new Option("\tOutput additional statistics.",
      "additional-stats", 0, "-additional-stats"));
    newVector.addAll(Collections.list(super.listOptions()));
    return newVector.elements();
  }

  /**
   * Parses a given list of options. Recognizes -additional-stats plus all
   * options of the superclass; complains about anything left over.
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  @Override
  public void setOptions(String[] options) throws Exception {
    setOutputAdditionalStats(Utils.getFlag("additional-stats", options));
    super.setOptions(options);
    Utils.checkForRemainingOptions(options);
  }

  /**
   * Gets the current settings of the classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  @Override
  public String[] getOptions() {
    Vector<String> result = new Vector<String>();
    if (getOutputAdditionalStats()) {
      result.add("-additional-stats");
    }
    Collections.addAll(result, super.getOptions());
    // Zero-length array argument: the runtime allocates the correctly
    // sized array, which is the preferred toArray idiom.
    return result.toArray(new String[0]);
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String outputAdditionalStatsTipText() {
    return "Output additional statistics (such as "
      + "std deviation of coefficients and t-statistics)";
  }

  /**
   * Set whether to output additional statistics (such as std. deviation of
   * coefficients and t-statistics
   *
   * @param additional true if additional stats are to be output
   */
  public void setOutputAdditionalStats(boolean additional) {
    m_outputAdditionalStats = additional;
  }

  /**
   * Get whether to output additional statistics (such as std. deviation of
   * coefficients and t-statistics
   *
   * @return true if additional stats are to be output
   */
  public boolean getOutputAdditionalStats() {
    return m_outputAdditionalStats;
  }

  /**
   * Generate a prediction for the supplied instance.
   *
   * @param inst the instance to predict.
   * @return the prediction
   * @throws Exception if an error occurs
   */
  @Override
  public double classifyInstance(Instance inst) throws Exception {
    if (m_attribute == null) {
      // No useful attribute was found during training: predict a constant.
      return m_intercept;
    } else {
      if (inst.isMissing(m_attributeIndex)) {
        // Fall back to the class mean computed (at training time) over the
        // instances whose value for the chosen attribute was missing.
        return m_classMeanForMissing;
      }
      return m_intercept + m_slope * inst.value(m_attributeIndex);
    }
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  @Override
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.NUMERIC_CLASS);
    result.enable(Capability.DATE_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    return result;
  }

  /**
   * Builds a simple linear regression model given the supplied training data.
   * For each attribute a univariate least-squares fit is computed (instances
   * with a missing value for that attribute are predicted by the class mean
   * over those instances); the attribute with the lowest total squared error
   * is selected.
   *
   * @param insts the training data.
   * @throws Exception if an error occurs
   */
  @Override
  public void buildClassifier(Instances insts) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(insts);

    if (m_outputAdditionalStats) {
      // check that the instances weights are all 1
      // because the RegressionAnalysis class does
      // not handle weights
      boolean ok = true;
      for (int i = 0; i < insts.numInstances(); i++) {
        if (insts.instance(i).weight() != 1) {
          ok = false;
          break;
        }
      }
      if (!ok) {
        throw new Exception(
          "Can only compute additional statistics on unweighted data");
      }
    }

    // Compute sums and counts
    double[] sum = new double[insts.numAttributes()];
    double[] count = new double[insts.numAttributes()];
    double[] classSumForMissing = new double[insts.numAttributes()];
    double[] classSumSquaredForMissing = new double[insts.numAttributes()];
    double classCount = 0;
    double classSum = 0;
    for (int j = 0; j < insts.numInstances(); j++) {
      Instance inst = insts.instance(j);
      if (!inst.classIsMissing()) {
        for (int i = 0; i < insts.numAttributes(); i++) {
          if (!inst.isMissing(i)) {
            sum[i] += inst.weight() * inst.value(i);
            count[i] += inst.weight();
          } else {
            // Track class statistics for instances where attribute i is
            // missing: they are later predicted by their class mean.
            classSumForMissing[i] += inst.classValue() * inst.weight();
            classSumSquaredForMissing[i] +=
              inst.classValue() * inst.classValue() * inst.weight();
          }
        }
        classCount += inst.weight();
        classSum += inst.weight() * inst.classValue();
      }
    }

    // Compute means
    double[] mean = new double[insts.numAttributes()];
    double[] classMeanForMissing = new double[insts.numAttributes()];
    double[] classMeanForKnown = new double[insts.numAttributes()];
    for (int i = 0; i < insts.numAttributes(); i++) {
      if (i != insts.classIndex()) {
        if (count[i] > 0) {
          mean[i] = sum[i] / count[i];
        }
        if (classCount - count[i] > 0) {
          classMeanForMissing[i] =
            classSumForMissing[i] / (classCount - count[i]);
        }
        if (count[i] > 0) {
          classMeanForKnown[i] = (classSum - classSumForMissing[i]) / count[i];
        }
      }
    }
    sum = null;
    count = null;

    double[] slopes = new double[insts.numAttributes()];
    double[] sumWeightedDiffsSquared = new double[insts.numAttributes()];
    double[] sumWeightedClassDiffsSquared = new double[insts.numAttributes()];

    // For all instances, accumulate the (weighted) cross products and
    // squared deviations needed for the per-attribute least-squares fits.
    for (int j = 0; j < insts.numInstances(); j++) {
      Instance inst = insts.instance(j);

      // Only need to do something if the class isn't missing
      if (!inst.classIsMissing()) {

        // For all attributes
        for (int i = 0; i < insts.numAttributes(); i++) {
          if (!inst.isMissing(i) && (i != insts.classIndex())) {
            double yDiff = inst.classValue() - classMeanForKnown[i];
            double weightedYDiff = inst.weight() * yDiff;
            double diff = inst.value(i) - mean[i];
            double weightedDiff = inst.weight() * diff;
            slopes[i] += weightedYDiff * diff;
            sumWeightedDiffsSquared[i] += weightedDiff * diff;
            sumWeightedClassDiffsSquared[i] += weightedYDiff * yDiff;
          }
        }
      }
    }

    // Pick the best attribute
    double minSSE = Double.MAX_VALUE;
    m_attribute = null;
    int chosen = -1;
    double chosenSlope = Double.NaN;
    double chosenIntercept = Double.NaN;
    double chosenMeanForMissing = Double.NaN;
    for (int i = 0; i < insts.numAttributes(); i++) {

      // Do we have missing values for this attribute?
      double sseForMissing = classSumSquaredForMissing[i] -
        (classSumForMissing[i] * classMeanForMissing[i]);

      // Should we skip this attribute? (Class attribute, or no variance in
      // the attribute values -- the slope would be undefined.)
      if ((i == insts.classIndex()) || (sumWeightedDiffsSquared[i] == 0)) {
        continue;
      }

      // Compute final slope and intercept
      double numerator = slopes[i];
      slopes[i] /= sumWeightedDiffsSquared[i];
      double intercept = classMeanForKnown[i] - slopes[i] * mean[i];

      // Compute sum of squared errors
      double sse = sumWeightedClassDiffsSquared[i] - slopes[i] * numerator;

      // Add component due to missing value prediction
      sse += sseForMissing;

      // Check whether this is the best attribute
      if (sse < minSSE) {
        minSSE = sse;
        chosen = i;
        chosenSlope = slopes[i];
        chosenIntercept = intercept;
        chosenMeanForMissing = classMeanForMissing[i];
      }
    }

    // Set parameters
    if (chosen == -1) {
      if (!m_suppressErrorMessage) {
        System.err.println("----- no useful attribute found");
      }
      m_attribute = null;
      m_attributeIndex = 0;
      m_slope = 0;
      m_intercept = classSum / classCount;
      m_classMeanForMissing = 0;
    } else {
      m_attribute = insts.attribute(chosen);
      m_attributeIndex = chosen;
      m_slope = chosenSlope;
      m_intercept = chosenIntercept;
      m_classMeanForMissing = chosenMeanForMissing;
      if (m_outputAdditionalStats) {
        // Reduce data so that stats are correct: RegressionAnalysis must
        // only see instances actually used in the fit.
        Instances newInsts = new Instances(insts, insts.numInstances());
        for (int i = 0; i < insts.numInstances(); i++) {
          Instance inst = insts.instance(i);
          if (!inst.classIsMissing() && !inst.isMissing(m_attributeIndex)) {
            newInsts.add(inst);
          }
        }
        insts = newInsts;

        // do regression analysis
        m_df = insts.numInstances() - 2;
        double[] stdErrors = RegressionAnalysis.calculateStdErrorOfCoef(insts,
          m_attribute, m_slope, m_intercept, m_df);
        m_seSlope = stdErrors[0];
        m_seIntercept = stdErrors[1];
        double[] coef = new double[2];
        coef[0] = m_slope;
        coef[1] = m_intercept;
        double[] tStats = RegressionAnalysis
          .calculateTStats(coef, stdErrors, 2);
        m_tstatSlope = tStats[0];
        m_tstatIntercept = tStats[1];
        double ssr = RegressionAnalysis.calculateSSR(insts, m_attribute,
          m_slope, m_intercept);
        m_rsquared = RegressionAnalysis.calculateRSquared(insts, ssr);
        m_rsquaredAdj = RegressionAnalysis.calculateAdjRSquared(m_rsquared,
          insts.numInstances(), 2);
        m_fstat = RegressionAnalysis.calculateFStat(m_rsquared,
          insts.numInstances(), 2);
      }
    }
  }

  /**
   * Returns true if a usable attribute was found.
   *
   * @return true if a usable attribute was found.
   */
  public boolean foundUsefulAttribute() {
    return (m_attribute != null);
  }

  /**
   * Returns the index of the attribute used in the regression.
   *
   * @return the index of the attribute.
   */
  public int getAttributeIndex() {
    return m_attributeIndex;
  }

  /**
   * Returns the slope of the function.
   *
   * @return the slope.
   */
  public double getSlope() {
    return m_slope;
  }

  /**
   * Returns the intercept of the function.
   *
   * @return the intercept.
   */
  public double getIntercept() {
    return m_intercept;
  }

  /**
   * Turn off the error message that is reported when no useful attribute is
   * found.
   *
   * @param s if set to true turns off the error message
   */
  public void setSuppressErrorMessage(boolean s) {
    m_suppressErrorMessage = s;
  }

  /**
   * Returns a description of this classifier as a string
   *
   * @return a description of the classifier.
   */
  @Override
  public String toString() {
    // StringBuilder instead of StringBuffer: this is a purely local,
    // single-threaded builder, so synchronization is unnecessary.
    StringBuilder text = new StringBuilder();
    if (m_attribute == null) {
      text.append("Predicting constant " + m_intercept);
    } else {
      text.append("Linear regression on " + m_attribute.name() + "\n\n");
      text
        .append(Utils.doubleToString(m_slope, 2) + " * " + m_attribute.name());
      if (m_intercept > 0) {
        text.append(" + " + Utils.doubleToString(m_intercept, 2));
      } else {
        text.append(" - " + Utils.doubleToString((-m_intercept), 2));
      }
      text.append("\n\nPredicting "
        + Utils.doubleToString(m_classMeanForMissing, 2) +
        " if attribute value is missing.");
      if (m_outputAdditionalStats) {
        // put regression analysis here
        int attNameLength = m_attribute.name().length() + 3;
        if (attNameLength < "Variable".length() + 3) {
          attNameLength = "Variable".length() + 3;
        }
        text.append("\n\nRegression Analysis:\n\n"
          + Utils.padRight("Variable", attNameLength)
          + " Coefficient SE of Coef t-Stat");
        text.append("\n" + Utils.padRight(m_attribute.name(), attNameLength));
        text.append(Utils.doubleToString(m_slope, 12, 4));
        text.append(" " + Utils.doubleToString(m_seSlope, 12, 5));
        text.append(" " + Utils.doubleToString(m_tstatSlope, 12, 5));
        text.append(Utils.padRight("\nconst", attNameLength + 1)
          + Utils.doubleToString(m_intercept, 12, 4));
        text.append(" " + Utils.doubleToString(m_seIntercept, 12, 5));
        text.append(" " + Utils.doubleToString(m_tstatIntercept, 12, 5));
        text.append("\n\nDegrees of freedom = " + Integer.toString(m_df));
        text.append("\nR^2 value = " + Utils.doubleToString(m_rsquared, 5));
        text.append("\nAdjusted R^2 = "
          + Utils.doubleToString(m_rsquaredAdj, 5));
        text.append("\nF-statistic = " + Utils.doubleToString(m_fstat, 5));
      }
    }
    text.append("\n");
    return text.toString();
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 11130 $");
  }

  /**
   * Main method for testing this class
   *
   * @param argv options
   */
  public static void main(String[] argv) {
    runClassifier(new SimpleLinearRegression(), argv);
  }
}