weka.classifiers.pmml.consumer.Regression Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of weka-dev Show documentation
Show all versions of weka-dev Show documentation
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This version represents the developer version, the
"bleeding edge" of development, you could say. New functionality gets added
to this version.
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* Regression.java
* Copyright (C) 2008-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.pmml.consumer;
import java.io.Serializable;
import java.util.ArrayList;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.pmml.MiningSchema;
import weka.core.pmml.TargetMetaInfo;
/**
* Class implementing import of PMML Regression model. Can be
* used as a Weka classifier for prediction (buildClassifier()
* raises an Exception).
*
* @author Mark Hall (mhall{[at]}pentaho{[dot]}com
* @version $Revision: 8034 $
*/
public class Regression extends PMMLClassifier
implements Serializable {
/** For serialization */
private static final long serialVersionUID = -5551125528409488634L;
/**
* Inner class for encapsulating a regression table
*/
static class RegressionTable implements Serializable {
/** For serialization */
private static final long serialVersionUID = -5259866093996338995L;
/**
* Abstract inner base class for different predictor types.
*/
abstract static class Predictor implements Serializable {
/** For serialization */
private static final long serialVersionUID = 7043831847273383618L;
/** Name of this predictor */
protected String m_name;
/**
* Index of the attribute in the mining schema that corresponds to this
* predictor
*/
protected int m_miningSchemaAttIndex = -1;
/** Coefficient for this predictor */
protected double m_coefficient = 1.0;
/**
* Constructs a new Predictor.
*
* @param predictor the Element
encapsulating this predictor
* @param miningSchema the mining schema as an Instances object
* @throws Exception if there is a problem constructing this Predictor
*/
protected Predictor(Element predictor, Instances miningSchema) throws Exception {
m_name = predictor.getAttribute("name");
for (int i = 0; i < miningSchema.numAttributes(); i++) {
Attribute temp = miningSchema.attribute(i);
if (temp.name().equals(m_name)) {
m_miningSchemaAttIndex = i;
}
}
if (m_miningSchemaAttIndex == -1) {
throw new Exception("[Predictor] unable to find matching attribute for "
+ "predictor " + m_name);
}
String coeff = predictor.getAttribute("coefficient");
if (coeff.length() > 0) {
m_coefficient = Double.parseDouble(coeff);
}
}
/**
* Returns a textual description of this predictor applicable
* to all sub classes.
*/
public String toString() {
return Utils.doubleToString(m_coefficient, 12, 4) + " * ";
}
/**
* Abstract add method. Adds this predictor into the sum for the
* current prediction.
*
* @param preds the prediction computed so far. For regression, it is a
* single element array; for classification it is a multi-element array
* @param input the input instance's values
*/
public abstract void add(double[] preds, double[] input);
}
/**
* Inner class for a numeric predictor
*/
protected class NumericPredictor extends Predictor {
/**
* For serialization
*/
private static final long serialVersionUID = -4335075205696648273L;
/** The exponent*/
protected double m_exponent = 1.0;
/**
* Constructs a NumericPredictor.
*
* @param predictor the Element
holding the predictor
* @param miningSchema the mining schema as an Instances object
* @throws Exception if something goes wrong while constructing this
* predictor
*/
protected NumericPredictor(Element predictor,
Instances miningSchema) throws Exception {
super(predictor, miningSchema);
String exponent = predictor.getAttribute("exponent");
if (exponent.length() > 0) {
m_exponent = Double.parseDouble(exponent);
}
}
/**
* Return a textual description of this predictor.
*/
public String toString() {
String output = super.toString();
output += m_name;
if (m_exponent > 1.0 || m_exponent < 1.0) {
output += "^" + Utils.doubleToString(m_exponent, 4);
}
return output;
}
/**
* Adds this predictor into the sum for the
* current prediction.
*
* @param preds the prediction computed so far. For regression, it is a
* single element array; for classification it is a multi-element array
* @param input the input instance's values
*/
public void add(double[] preds, double[] input) {
if (m_targetCategory == -1) {
preds[0] += m_coefficient * Math.pow(input[m_miningSchemaAttIndex], m_exponent);
} else {
preds[m_targetCategory] +=
m_coefficient * Math.pow(input[m_miningSchemaAttIndex], m_exponent);
}
}
}
/**
* Inner class encapsulating a categorical predictor.
*/
protected class CategoricalPredictor extends Predictor {
/**For serialization */
private static final long serialVersionUID = 3077920125549906819L;
/** The attribute value for this predictor */
protected String m_valueName;
/** The index of the attribute value for this predictor */
protected int m_valueIndex = -1;
/**
* Constructs a CategoricalPredictor.
*
* @param predictor the Element
containing the predictor
* @param miningSchema the mining schema as an Instances object
* @throws Exception if something goes wrong while constructing
* this predictor
*/
protected CategoricalPredictor(Element predictor,
Instances miningSchema) throws Exception {
super(predictor, miningSchema);
String valName = predictor.getAttribute("value");
if (valName.length() == 0) {
throw new Exception("[CategoricalPredictor] attribute value not specified!");
}
m_valueName = valName;
Attribute att = miningSchema.attribute(m_miningSchemaAttIndex);
if (att.isString()) {
// means that there were no Value elements defined in the
// data dictionary (and hence the mining schema).
// We add our value here.
att.addStringValue(m_valueName);
}
m_valueIndex = att.indexOfValue(m_valueName);
/* for (int i = 0; i < att.numValues(); i++) {
if (att.value(i).equals(m_valueName)) {
m_valueIndex = i;
}
}*/
if (m_valueIndex == -1) {
throw new Exception("[CategoricalPredictor] unable to find value "
+ m_valueName + " in mining schema attribute "
+ att.name());
}
}
/**
* Return a textual description of this predictor.
*/
public String toString() {
String output = super.toString();
output += m_name + "=" + m_valueName;
return output;
}
/**
* Adds this predictor into the sum for the
* current prediction.
*
* @param preds the prediction computed so far. For regression, it is a
* single element array; for classification it is a multi-element array
* @param input the input instance's values
*/
public void add(double[] preds, double[] input) {
// if the value is equal to the one in the input then add the coefficient
if (m_valueIndex == (int)input[m_miningSchemaAttIndex]) {
if (m_targetCategory == -1) {
preds[0] += m_coefficient;
} else {
preds[m_targetCategory] += m_coefficient;
}
}
}
}
/**
* Inner class to handle PredictorTerms.
*/
protected class PredictorTerm implements Serializable {
/** For serialization */
private static final long serialVersionUID = 5493100145890252757L;
/** The coefficient for this predictor term */
protected double m_coefficient = 1.0;
/** the indexes of the terms to be multiplied */
protected int[] m_indexes;
/** The names of the terms (attributes) to be multiplied */
protected String[] m_fieldNames;
/**
* Construct a new PredictorTerm.
*
* @param predictorTerm the Element
describing the predictor term
* @param miningSchema the mining schema as an Instances object
* @throws Exception if something goes wrong while constructing this
* predictor term
*/
protected PredictorTerm(Element predictorTerm,
Instances miningSchema) throws Exception {
String coeff = predictorTerm.getAttribute("coefficient");
if (coeff != null && coeff.length() > 0) {
try {
m_coefficient = Double.parseDouble(coeff);
} catch (IllegalArgumentException ex) {
throw new Exception("[PredictorTerm] unable to parse coefficient");
}
}
NodeList fields = predictorTerm.getElementsByTagName("FieldRef");
if (fields.getLength() > 0) {
m_indexes = new int[fields.getLength()];
m_fieldNames = new String[fields.getLength()];
for (int i = 0; i < fields.getLength(); i++) {
Node fieldRef = fields.item(i);
if (fieldRef.getNodeType() == Node.ELEMENT_NODE) {
String fieldName = ((Element)fieldRef).getAttribute("field");
if (fieldName != null && fieldName.length() > 0) {
boolean found = false;
// look for this field in the mining schema
for (int j = 0; j < miningSchema.numAttributes(); j++) {
if (miningSchema.attribute(j).name().equals(fieldName)) {
// all referenced fields MUST be numeric
if (!miningSchema.attribute(j).isNumeric()) {
throw new Exception("[PredictorTerm] field is not continuous: "
+ fieldName);
}
found = true;
m_indexes[i] = j;
m_fieldNames[i] = fieldName;
break;
}
}
if (!found) {
throw new Exception("[PredictorTerm] Unable to find field "
+ fieldName + " in mining schema!");
}
}
}
}
}
}
/**
* Return a textual description of this predictor term.
*/
public String toString() {
StringBuffer result = new StringBuffer();
result.append("(" + Utils.doubleToString(m_coefficient, 12, 4));
for (int i = 0; i < m_fieldNames.length; i++) {
result.append(" * " + m_fieldNames[i]);
}
result.append(")");
return result.toString();
}
/**
* Adds this predictor term into the sum for the
* current prediction.
*
* @param preds the prediction computed so far. For regression, it is a
* single element array; for classification it is a multi-element array
* @param input the input instance's values
*/
public void add(double[] preds, double[] input) {
int indx = 0;
if (m_targetCategory != -1) {
indx = m_targetCategory;
}
double result = m_coefficient;
for (int i = 0; i < m_indexes.length; i++) {
result *= input[m_indexes[i]];
}
preds[indx] += result;
}
}
/** Constant for regression model type */
public static final int REGRESSION = 0;
/** Constant for classification model type */
public static final int CLASSIFICATION = 1;
/** The type of function - regression or classification */
protected int m_functionType = REGRESSION;
/** The mining schema */
protected MiningSchema m_miningSchema;
/** The intercept */
protected double m_intercept = 0.0;
/** classification only */
protected int m_targetCategory = -1;
/** Numeric and categorical predictors */
protected ArrayList m_predictors =
new ArrayList();
/** Interaction terms */
protected ArrayList m_predictorTerms =
new ArrayList();
/**
* Return a textual description of this RegressionTable.
*/
public String toString() {
Instances miningSchema = m_miningSchema.getFieldsAsInstances();
StringBuffer temp = new StringBuffer();
temp.append("Regression table:\n");
temp.append(miningSchema.classAttribute().name());
if (m_functionType == CLASSIFICATION) {
temp.append("=" + miningSchema.
classAttribute().value(m_targetCategory));
}
temp.append(" =\n\n");
// do the predictors
for (int i = 0; i < m_predictors.size(); i++) {
temp.append(m_predictors.get(i).toString() + " +\n");
}
// do the predictor terms
for (int i = 0; i < m_predictorTerms.size(); i++) {
temp.append(m_predictorTerms.get(i).toString() + " +\n");
}
temp.append(Utils.doubleToString(m_intercept, 12, 4));
temp.append("\n\n");
return temp.toString();
}
/**
* Construct a regression table from an Element
*
* @param table the table to encapsulate
* @param functionType the type of function
* (regression or classification)
* to use
* @param mSchema the mining schema
* @throws Exception if there is a problem while constructing
* this regression table
*/
protected RegressionTable(Element table,
int functionType,
MiningSchema mSchema) throws Exception {
m_miningSchema = mSchema;
m_functionType = functionType;
Instances miningSchema = m_miningSchema.getFieldsAsInstances();
// get the intercept
String intercept = table.getAttribute("intercept");
if (intercept.length() > 0) {
m_intercept = Double.parseDouble(intercept);
}
// get the target category (if classification)
if (m_functionType == CLASSIFICATION) {
// target category MUST be defined
String targetCat = table.getAttribute("targetCategory");
if (targetCat.length() > 0) {
Attribute classA = miningSchema.classAttribute();
for (int i = 0; i < classA.numValues(); i++) {
if (classA.value(i).equals(targetCat)) {
m_targetCategory = i;
}
}
}
if (m_targetCategory == -1) {
throw new Exception("[RegressionTable] No target categories defined for classification");
}
}
// read all the numeric predictors
NodeList numericPs = table.getElementsByTagName("NumericPredictor");
for (int i = 0; i < numericPs.getLength(); i++) {
Node nP = numericPs.item(i);
if (nP.getNodeType() == Node.ELEMENT_NODE) {
NumericPredictor numP = new NumericPredictor((Element)nP, miningSchema);
m_predictors.add(numP);
}
}
// read all the categorical predictors
NodeList categoricalPs = table.getElementsByTagName("CategoricalPredictor");
for (int i = 0; i < categoricalPs.getLength(); i++) {
Node cP = categoricalPs.item(i);
if (cP.getNodeType() == Node.ELEMENT_NODE) {
CategoricalPredictor catP = new CategoricalPredictor((Element)cP, miningSchema);
m_predictors.add(catP);
}
}
// read all the PredictorTerms
NodeList predictorTerms = table.getElementsByTagName("PredictorTerm");
for (int i = 0; i < predictorTerms.getLength(); i++) {
Node pT = predictorTerms.item(i);
PredictorTerm predT = new PredictorTerm((Element)pT, miningSchema);
m_predictorTerms.add(predT);
}
}
public void predict(double[] preds, double[] input) {
if (m_targetCategory == -1) {
preds[0] = m_intercept;
} else {
preds[m_targetCategory] = m_intercept;
}
// add the predictors
for (int i = 0; i < m_predictors.size(); i++) {
Predictor p = m_predictors.get(i);
p.add(preds, input);
}
// add the PredictorTerms
for (int i = 0; i < m_predictorTerms.size(); i++) {
PredictorTerm pt = m_predictorTerms.get(i);
pt.add(preds, input);
}
}
}
/** Description of the algorithm */
protected String m_algorithmName;
/** The regression tables for this regression */
protected RegressionTable[] m_regressionTables;
/**
* Enum for the normalization methods.
*/
enum Normalization {
NONE, SIMPLEMAX, SOFTMAX, LOGIT, PROBIT, CLOGLOG,
EXP, LOGLOG, CAUCHIT}
/** The normalization to use */
protected Normalization m_normalizationMethod = Normalization.NONE;
/**
* Constructs a new PMML Regression.
*
* @param model the Element
containing the regression model
* @param dataDictionary the data dictionary as an Instances object
* @param miningSchema the mining schema
* @throws Exception if there is a problem constructing this Regression
*/
public Regression(Element model, Instances dataDictionary,
MiningSchema miningSchema) throws Exception {
super(dataDictionary, miningSchema);
int functionType = RegressionTable.REGRESSION;
// determine function name first
String fName = model.getAttribute("functionName");
if (fName.equals("regression")) {
functionType = RegressionTable.REGRESSION;
} else if (fName.equals("classification")) {
functionType = RegressionTable.CLASSIFICATION;
} else {
throw new Exception("[PMML Regression] Function name not defined in pmml!");
}
// do we have an algorithm name?
String algName = model.getAttribute("algorithmName");
if (algName != null && algName.length() > 0) {
m_algorithmName = algName;
}
// determine normalization method (if any)
m_normalizationMethod = determineNormalization(model);
setUpRegressionTables(model, functionType);
// convert any string attributes in the mining schema
//miningSchema.convertStringAttsToNominal();
}
/**
* Create all the RegressionTables for this model.
*
* @param model the Element
holding this regression model
* @param functionType the type of function (regression or
* classification)
* @throws Exception if there is a problem setting up the regression
* tables
*/
private void setUpRegressionTables(Element model,
int functionType) throws Exception {
NodeList tableList = model.getElementsByTagName("RegressionTable");
if (tableList.getLength() == 0) {
throw new Exception("[Regression] no regression tables defined!");
}
m_regressionTables = new RegressionTable[tableList.getLength()];
for (int i = 0; i < tableList.getLength(); i++) {
Node table = tableList.item(i);
if (table.getNodeType() == Node.ELEMENT_NODE) {
RegressionTable tempRTable =
new RegressionTable((Element)table,
functionType,
m_miningSchema);
m_regressionTables[i] = tempRTable;
}
}
}
/**
* Return the type of normalization used for this regression
*
* @param model the Element
holding the model
* @return the normalization used in this regression
*/
private static Normalization determineNormalization(Element model) {
Normalization normMethod = Normalization.NONE;
String normName = model.getAttribute("normalizationMethod");
if (normName.equals("simplemax")) {
normMethod = Normalization.SIMPLEMAX;
} else if (normName.equals("softmax")) {
normMethod = Normalization.SOFTMAX;
} else if (normName.equals("logit")) {
normMethod = Normalization.LOGIT;
} else if (normName.equals("probit")) {
normMethod = Normalization.PROBIT;
} else if (normName.equals("cloglog")) {
normMethod = Normalization.CLOGLOG;
} else if (normName.equals("exp")) {
normMethod = Normalization.EXP;
} else if (normName.equals("loglog")) {
normMethod = Normalization.LOGLOG;
} else if (normName.equals("cauchit")) {
normMethod = Normalization.CAUCHIT;
}
return normMethod;
}
/**
* Return a textual description of this Regression model.
*/
public String toString() {
StringBuffer temp = new StringBuffer();
temp.append("PMML version " + getPMMLVersion());
if (!getCreatorApplication().equals("?")) {
temp.append("\nApplication: " + getCreatorApplication());
}
if (m_algorithmName != null) {
temp.append("\nPMML Model: " + m_algorithmName);
}
temp.append("\n\n");
temp.append(m_miningSchema);
for (RegressionTable table : m_regressionTables) {
temp.append(table);
}
if (m_normalizationMethod != Normalization.NONE) {
temp.append("Normalization: " + m_normalizationMethod);
}
temp.append("\n");
return temp.toString();
}
/**
* Classifies the given test instance. The instance has to belong to a
* dataset when it's being classified.
*
* @param inst the instance to be classified
* @return the predicted most likely class for the instance or
* Utils.missingValue() if no prediction is made
* @exception Exception if an error occurred during the prediction
*/
public double[] distributionForInstance(Instance inst) throws Exception {
if (!m_initialized) {
mapToMiningSchema(inst.dataset());
}
double[] preds = null;
if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) {
preds = new double[1];
} else {
preds = new double[m_miningSchema.getFieldsAsInstances().classAttribute().numValues()];
}
// create an array of doubles that holds values from the incoming
// instance; in order of the fields in the mining schema. We will
// also handle missing values and outliers here.
// System.err.println(inst);
double[] incoming = m_fieldsMap.instanceToSchema(inst, m_miningSchema);
// scan for missing values. If there are still missing values after instanceToSchema(),
// then missing value handling has been deferred to the PMML scheme. The specification
// (Regression PMML 3.2) seems to contradict itself with regards to classification and categorical
// variables. In one place it states that if a categorical variable is missing then
// variable_name=value is 0 for any value. Further down in the document it states: "if
// one or more of the y_j cannot be evaluated because the value in one of the referenced
// fields is missing, then the following formulas (for computing p_j) do not apply. In
// that case the predictions are defined by the priorProbability values in the Target
// element".
// In this implementation we will default to information in the Target element (default
// value for numeric prediction and prior probabilities for classification). If there is
// no Target element defined, then an Exception is thrown.
boolean hasMissing = false;
for (int i = 0; i < incoming.length; i++) {
if (i != m_miningSchema.getFieldsAsInstances().classIndex() &&
Utils.isMissingValue(incoming[i])) {
hasMissing = true;
break;
}
}
if (hasMissing) {
if (!m_miningSchema.hasTargetMetaData()) {
String message = "[Regression] WARNING: Instance to predict has missing value(s) but "
+ "there is no missing value handling meta data and no "
+ "prior probabilities/default value to fall back to. No "
+ "prediction will be made ("
+ ((m_miningSchema.getFieldsAsInstances().classAttribute().isNominal() ||
m_miningSchema.getFieldsAsInstances().classAttribute().isString())
? "zero probabilities output)."
: "NaN output).");
if (m_log == null) {
System.err.println(message);
} else {
m_log.logMessage(message);
}
if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) {
preds[0] = Utils.missingValue();
}
return preds;
} else {
// use prior probablilities/default value
TargetMetaInfo targetData = m_miningSchema.getTargetMetaData();
if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) {
preds[0] = targetData.getDefaultValue();
} else {
Instances miningSchemaI = m_miningSchema.getFieldsAsInstances();
for (int i = 0; i < miningSchemaI.classAttribute().numValues(); i++) {
preds[i] = targetData.getPriorProbability(miningSchemaI.classAttribute().value(i));
}
}
return preds;
}
} else {
// loop through the RegressionTables
for (int i = 0; i < m_regressionTables.length; i++) {
m_regressionTables[i].predict(preds, incoming);
}
// Now apply the normalization
switch (m_normalizationMethod) {
case NONE:
// nothing to be done
break;
case SIMPLEMAX:
Utils.normalize(preds);
break;
case SOFTMAX:
for (int i = 0; i < preds.length; i++) {
preds[i] = Math.exp(preds[i]);
}
if (preds.length == 1) {
// hack for those models that do binary logistic regression as
// a numeric prediction model
preds[0] = preds[0] / (preds[0] + 1.0);
} else {
Utils.normalize(preds);
}
break;
case LOGIT:
for (int i = 0; i < preds.length; i++) {
preds[i] = 1.0 / (1.0 + Math.exp(-preds[i]));
}
Utils.normalize(preds);
break;
case PROBIT:
for (int i = 0; i < preds.length; i++) {
preds[i] = weka.core.matrix.Maths.pnorm(preds[i]);
}
Utils.normalize(preds);
break;
case CLOGLOG:
// note this is supposed to be illegal for regression
for (int i = 0; i < preds.length; i++) {
preds[i] = 1.0 - Math.exp(-Math.exp(-preds[i]));
}
Utils.normalize(preds);
break;
case EXP:
for (int i = 0; i < preds.length; i++) {
preds[i] = Math.exp(preds[i]);
}
Utils.normalize(preds);
break;
case LOGLOG:
// note this is supposed to be illegal for regression
for (int i = 0; i < preds.length; i++) {
preds[i] = Math.exp(-Math.exp(-preds[i]));
}
Utils.normalize(preds);
break;
case CAUCHIT:
for (int i = 0; i < preds.length; i++) {
preds[i] = 0.5 + (1.0 / Math.PI) * Math.atan(preds[i]);
}
Utils.normalize(preds);
break;
default:
throw new Exception("[Regression] unknown normalization method");
}
// If there is a Target defined, and this is a numeric prediction problem,
// then apply any min, max, rescaling etc.
if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()
&& m_miningSchema.hasTargetMetaData()) {
TargetMetaInfo targetData = m_miningSchema.getTargetMetaData();
preds[0] = targetData.applyMinMaxRescaleCast(preds[0]);
}
}
return preds;
}
/* (non-Javadoc)
* @see weka.core.RevisionHandler#getRevision()
*/
public String getRevision() {
return RevisionUtils.extract("$Revision: 8034 $");
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy