All downloads are free. Search and download functionality uses the official Maven repository.

com.github.lwhite1.tablesaw.api.ml.regression.LeastSquares Maven / Gradle / Ivy

package com.github.lwhite1.tablesaw.api.ml.regression;

import com.github.lwhite1.tablesaw.api.NumericColumn;
import com.github.lwhite1.tablesaw.util.DoubleArrays;
import com.google.common.base.Strings;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import smile.regression.OLS;

/**
 * Ordinary least squares linear regression fitted on Tablesaw numeric columns.
 *
 * <p>Thin wrapper around Smile's {@link OLS}. The columns are converted to primitive arrays, the
 * model is fitted once in the constructor, and the accessors delegate to the fitted model. On top
 * of Smile's behavior, {@link #toString()} labels each coefficient with its column name.
 */
public class LeastSquares {

  /**
   * Fixed width, in characters, of the variable labels written by {@link #toString()}. Longer
   * column names are truncated and shorter ones are right-padded with spaces.
   */
  private static final int MAX_NAME_LENGTH = "(intercept)".length() - 1;

  private final OLS model;
  /** Training inputs; {@link #fitted()} passes each row to {@link #predict(double[])}. */
  private final double[][] explanatoryVariables;
  private final int explanatoryVariableCount;
  private final double[] responseVarArray;
  private final String[] explanatoryVariableNames;

  /**
   * Fits a least-squares model. Equivalent to
   * {@code new LeastSquares(responseVar, explanatoryVars)}.
   *
   * @param responseVar the column to predict
   * @param explanatoryVars the predictor columns
   * @return the fitted model
   */
  public static LeastSquares train(NumericColumn responseVar, NumericColumn... explanatoryVars) {
    return new LeastSquares(responseVar, explanatoryVars);
  }

  /**
   * Fits Smile's {@link OLS}, regressing {@code responseVariable} on {@code explanatoryVars}.
   *
   * @param responseVariable the column to predict
   * @param explanatoryVars the predictor columns; their names label the coefficients in
   *     {@link #toString()}
   */
  public LeastSquares(NumericColumn responseVariable, NumericColumn... explanatoryVars) {
    this.explanatoryVariables = DoubleArrays.to2dArray(explanatoryVars);

    this.responseVarArray = responseVariable.toDoubleArray();
    this.model = new OLS(explanatoryVariables, responseVarArray);
    this.explanatoryVariableCount = explanatoryVars.length;
    this.explanatoryVariableNames = new String[explanatoryVariableCount];

    for (int i = 0; i < explanatoryVariableCount; i++) {
      explanatoryVariableNames[i] = explanatoryVars[i].name();
    }
  }

  /**
   * Returns Smile's textual summary of the model, with two changes: "Intercept" is shown as
   * "(Intercept)", and the generic {@code "Var N\t"} labels are replaced by the explanatory
   * column names, padded or truncated to a fixed width.
   */
  @Override
  public String toString() {
    String result = model.toString().replace("Intercept", "(Intercept)");

    // TODO(lwhite): This hack is needed because Smile doesn't name the vars in its output; we do,
    // by string replacement.
    for (int i = 0; i < explanatoryVariableCount; i++) {
      String label = toFixedWidthLabel(explanatoryVariableNames[i]);
      // Column names are user data and may contain '$' or '\'. Unquoted, replaceFirst would treat
      // them as group references / escapes and either throw IllegalArgumentException or emit the
      // wrong text, so both the pattern and the replacement are quoted literally.
      result = result.replaceFirst(
          Pattern.quote("Var " + (i + 1) + '\t'), Matcher.quoteReplacement(label));
    }
    return result;
  }

  /**
   * Truncates {@code name}, or right-pads it with spaces, to exactly {@link #MAX_NAME_LENGTH}
   * characters.
   */
  private static String toFixedWidthLabel(String name) {
    if (name.length() >= MAX_NAME_LENGTH) {
      return name.substring(0, MAX_NAME_LENGTH);
    }
    return Strings.padEnd(name, MAX_NAME_LENGTH, ' ');
  }

  /** Returns the residuals reported by the fitted Smile model ({@link OLS#residuals()}). */
  public double[] residuals() {
    return model.residuals();
  }

  /**
   * Computes the model's prediction for every row of the training inputs.
   *
   * @return one predicted value per training row, in row order
   */
  public double[] fitted() {
    double[] fitted = new double[explanatoryVariables.length];
    for (int i = 0; i < explanatoryVariables.length; i++) {
      fitted[i] = predict(explanatoryVariables[i]);
    }
    return fitted;
  }

  /** Returns the adjusted R-squared reported by {@link OLS#adjustedRSquared()}. */
  public double adjustedRSquared() {
    return model.adjustedRSquared();
  }

  /** Returns the degrees of freedom reported by {@link OLS#df()}. */
  public double df() {
    return model.df();
  }

  /** Returns the error value reported by {@link OLS#error()}. */
  public double error() {
    return model.error();
  }

  /** Returns the F-statistic reported by {@link OLS#ftest()}. */
  public double ftest() {
    return model.ftest();
  }

  /** Returns the model p-value reported by {@link OLS#pvalue()}. */
  public double pValue() {
    return model.pvalue();
  }

  /** Returns the fitted intercept reported by {@link OLS#intercept()}. */
  public double intercept() {
    return model.intercept();
  }

  /** Returns the R-squared reported by {@link OLS#RSquared()}. */
  public double RSquared() {
    return model.RSquared();
  }

  /** Returns the residual sum of squares reported by {@link OLS#RSS()}. */
  public double RSS() {
    return model.RSS();
  }

  /**
   * Returns the coefficient t-test table reported by {@link OLS#ttest()}. The row/column layout
   * is defined by Smile; see its documentation.
   */
  public double[][] ttest() {
    return model.ttest();
  }

  /**
   * Predicts the response for one observation.
   *
   * @param x explanatory values, in the same order as the columns passed to the constructor
   * @return the model's prediction from {@link OLS#predict(double[])}
   */
  public double predict(double[] x) {
    return model.predict(x);
  }

  /** Returns the coefficients reported by {@link OLS#coefficients()}. */
  public double[] coefficients() {
    return model.coefficients();
  }

  /**
   * Returns the response values the model was trained on.
   *
   * @return a copy of the training response values; modifying it does not affect this instance
   */
  public double[] actuals() {
    // Defensive copy: handing out the field itself would let callers overwrite the stored
    // training data that later calls return.
    return responseVarArray.clone();
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy