All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.pmml.consumer.GeneralRegression Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This version represents the developer version, the "bleeding edge" of development, you could say. New functionality gets added to this version.

There is a newer version: 3.9.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    GeneralRegression.java
 *    Copyright (C) 2008-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.pmml.consumer;

import java.io.Serializable;
import java.util.ArrayList;

import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.pmml.MiningSchema;
import weka.core.pmml.PMMLUtils;
import weka.core.pmml.TargetMetaInfo;

/**
 * Class implementing import of a PMML General Regression model. It can be
 * used as a Weka classifier for prediction only (buildClassifier()
 * raises an Exception).
 *
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision: 8034 $
 */
public class GeneralRegression extends PMMLClassifier
  implements Serializable {

  /**
   * For serialization. Identifies the serialized form of this class;
   * changing the value makes previously serialized instances unreadable.
   */
  private static final long serialVersionUID = 2583880411828388959L;

  /**
   * Enumerated type for the model type. Each constant carries the exact
   * spelling used by the PMML {@code modelType} attribute; that spelling
   * is what {@link #toString()} returns and what the constructor matches.
   */
  enum ModelType {
    REGRESSION("regression"),
    GENERALLINEAR("generalLinear"),
    MULTINOMIALLOGISTIC("multinomialLogistic"),
    ORDINALMULTINOMIAL("ordinalMultinomial"),
    GENERALIZEDLINEAR("generalizedLinear");

    /** PMML spelling of this model type. */
    private final String m_pmmlName;

    ModelType(String pmmlName) {
      this.m_pmmlName = pmmlName;
    }

    @Override
    public String toString() {
      return m_pmmlName;
    }
  }
  
  // the model type, taken from the "modelType" attribute of the model
  // element (the constructor throws if the value is not recognized)
  protected ModelType m_modelType = ModelType.REGRESSION;

  // the model name (if defined by the "modelName" attribute; otherwise null)
  protected String m_modelName;
    
  // the algorithm name (if defined by the "algorithmName" attribute; otherwise null)
  protected String m_algorithmName;

  // the function type (regression or classification); switched to
  // CLASSIFICATION when the "functionName" attribute is "classification"
  protected int m_functionType = Regression.RegressionTable.REGRESSION;

  /**
   * Cumulative link functions available to the ordinal multinomial model
   * type. Each constant applies its function to {@code value + offset};
   * {@link #toString()} yields the PMML spelling that the constructor
   * matches against the {@code cumulativeLink} attribute.
   */
  enum CumulativeLinkFunction {
    NONE("none") {
      @Override
      double eval(double value, double offset) {
        // no evaluation defined in this case!
        return Double.NaN;
      }
    },
    LOGIT("logit") {
      @Override
      double eval(double value, double offset) {
        final double eta = value + offset;
        return 1.0 / (1.0 + Math.exp(-eta));
      }
    },
    PROBIT("probit") {
      @Override
      double eval(double value, double offset) {
        final double eta = value + offset;
        return weka.core.matrix.Maths.pnorm(eta);
      }
    },
    CLOGLOG("cloglog") {
      @Override
      double eval(double value, double offset) {
        final double eta = value + offset;
        return 1.0 - Math.exp(-Math.exp(eta));
      }
    },
    LOGLOG("loglog") {
      @Override
      double eval(double value, double offset) {
        final double eta = value + offset;
        return Math.exp(-Math.exp(-eta));
      }
    },
    CAUCHIT("cauchit") {
      @Override
      double eval(double value, double offset) {
        final double eta = value + offset;
        return 0.5 + (1.0 / Math.PI) * Math.atan(eta);
      }
    };

    /**
     * Applies this cumulative link function.
     *
     * @param value the raw response value
     * @param offset the offset added to the raw value before the function is applied
     * @return the function value, or NaN for {@link #NONE}
     */
    abstract double eval(double value, double offset);

    /** PMML spelling of this cumulative link function. */
    private final String m_pmmlName;

    /**
     * Constructor.
     *
     * @param pmmlName textual (PMML) name for this enum constant
     */
    CumulativeLinkFunction(String pmmlName) {
      this.m_pmmlName = pmmlName;
    }

    @Override
    public String toString() {
      return m_pmmlName;
    }
  }
  
  // cumulative link function (ordinal multinomial only); set by the
  // constructor from the "cumulativeLink" attribute, otherwise left as NONE
  protected CumulativeLinkFunction m_cumulativeLinkFunction 
    = CumulativeLinkFunction.NONE;


  /**
   * Link functions available to the general linear and generalized linear
   * model types. {@link #eval} applies the inverse of the named link to the
   * linear predictor {@code value + offset} and multiplies the result by
   * {@code trials}; {@link #toString()} yields the PMML spelling that the
   * constructor matches against the {@code linkFunction} attribute.
   */
  enum LinkFunction {
    NONE("none") {
      @Override
      double eval(double value, double offset, double trials,
                  double distParam, double linkParam) {
        // no evaluation defined in this case!
        return Double.NaN;
      }
    },
    CLOGLOG("cloglog") {
      @Override
      double eval(double value, double offset, double trials,
                  double distParam, double linkParam) {
        final double eta = value + offset;
        final double mean = 1.0 - Math.exp(-Math.exp(eta));
        return mean * trials;
      }
    },
    IDENTITY("identity") {
      @Override
      double eval(double value, double offset, double trials,
                  double distParam, double linkParam) {
        final double eta = value + offset;
        return eta * trials;
      }
    },
    LOG("log") {
      @Override
      double eval(double value, double offset, double trials,
                  double distParam, double linkParam) {
        final double eta = value + offset;
        final double mean = Math.exp(eta);
        return mean * trials;
      }
    },
    LOGC("logc") {
      @Override
      double eval(double value, double offset, double trials,
                  double distParam, double linkParam) {
        final double eta = value + offset;
        final double mean = 1.0 - Math.exp(eta);
        return mean * trials;
      }
    },
    LOGIT("logit") {
      @Override
      double eval(double value, double offset, double trials,
                  double distParam, double linkParam) {
        final double eta = value + offset;
        final double mean = 1.0 / (1.0 + Math.exp(-eta));
        return mean * trials;
      }
    },
    LOGLOG("loglog") {
      @Override
      double eval(double value, double offset, double trials,
                  double distParam, double linkParam) {
        final double eta = value + offset;
        final double mean = Math.exp(-Math.exp(-eta));
        return mean * trials;
      }
    },
    NEGBIN("negbin") {
      @Override
      double eval(double value, double offset, double trials,
                  double distParam, double linkParam) {
        final double eta = value + offset;
        final double mean = 1.0 / (distParam * (Math.exp(-eta) - 1.0));
        return mean * trials;
      }
    },
    ODDSPOWER("oddspower") {
      @Override
      double eval(double value, double offset, double trials,
                  double distParam, double linkParam) {
        final double eta = value + offset;
        final double mean;
        if (linkParam < 0.0 || linkParam > 0.0) {
          mean = 1.0 / (1.0 + Math.pow(1.0 + linkParam * eta, -1.0 / linkParam));
        } else {
          // a zero (or NaN) link parameter falls back to the logit form
          mean = 1.0 / (1.0 + Math.exp(-eta));
        }
        return mean * trials;
      }
    },
    POWER("power") {
      @Override
      double eval(double value, double offset, double trials,
                  double distParam, double linkParam) {
        final double eta = value + offset;
        final double mean;
        if (linkParam < 0.0 || linkParam > 0.0) {
          mean = Math.pow(eta, 1.0 / linkParam);
        } else {
          // a zero (or NaN) link parameter falls back to the log form
          mean = Math.exp(eta);
        }
        return mean * trials;
      }
    },
    PROBIT("probit") {
      @Override
      double eval(double value, double offset, double trials,
                  double distParam, double linkParam) {
        final double eta = value + offset;
        final double mean = weka.core.matrix.Maths.pnorm(eta);
        return mean * trials;
      }
    };

    /**
     * Applies the inverse of this link function.
     *
     * @param value the raw response value
     * @param offset the offset added to the raw value
     * @param trials the factor the result is multiplied by
     * @param distParam the distribution parameter (negbin only)
     * @param linkParam the link parameter (power and oddspower only)
     * @return the scaled result, or NaN for {@link #NONE}
     */
    abstract double eval(double value, double offset, double trials,
                         double distParam, double linkParam);

    /** PMML spelling of this link function. */
    private final String m_pmmlName;

    /**
     * Constructor.
     *
     * @param pmmlName the textual (PMML) name of this link function
     */
    LinkFunction(String pmmlName) {
      this.m_pmmlName = pmmlName;
    }

    @Override
    public String toString() {
      return m_pmmlName;
    }
  }
  
  // link function; the constructor reads the "linkFunction" attribute for the
  // generalLinear and generalizedLinear model types, otherwise it stays NONE
  protected LinkFunction m_linkFunction = LinkFunction.NONE;
  // optional "linkParameter" attribute (power/oddspower); NaN if absent
  protected double m_linkParameter = Double.NaN;
  // optional "trialsVariable" attribute; null if absent
  protected String m_trialsVariable;
  // optional "trialsValue" attribute; NaN if absent
  protected double m_trialsValue = Double.NaN;

  /**
   * Distributions recognized for the general linear and generalized linear
   * model types. {@link #toString()} returns the PMML spelling that the
   * constructor matches against the {@code distribution} attribute.
   */
  enum Distribution {
    NONE("none"),
    NORMAL("normal"),
    BINOMIAL("binomial"),
    GAMMA("gamma"),
    INVGAUSSIAN("igauss"),
    NEGBINOMIAL("negbin"),
    POISSON("poisson");

    /** PMML spelling of this distribution. */
    private final String m_pmmlName;

    Distribution(String pmmlName) {
      this.m_pmmlName = pmmlName;
    }

    @Override
    public String toString() {
      return m_pmmlName;
    }
  }
  
  // distribution, intended for the generalLinear and generalizedLinear model
  // types; note the constructor reads the "distribution" attribute for every
  // model type and keeps NORMAL when the attribute is absent
  protected Distribution m_distribution = Distribution.NORMAL;

  // ancillary parameter value for the negative binomial distribution
  // ("distParameter" attribute); NaN if absent
  protected double m_distParameter = Double.NaN;

  // if present ("offsetVariable" attribute), this variable is used during
  // scoring generalizedLinear/generalLinear or ordinalMultinomial models
  protected String m_offsetVariable;

  // if present ("offsetValue" attribute), this value is used during scoring
  // generalizedLinear/generalLinear or ordinalMultinomial models. It works like
  // a user-specified intercept. At most one of offsetVariable or offsetValue
  // may be specified (the constructor does not enforce this). NaN if absent.
  protected double m_offsetValue = Double.NaN;

  /**
   * Small inner class to hold the name of a parameter plus
   * its optional descriptive label.
   */
  static class Parameter implements Serializable {
    // ESCA-JAVA0096:
    /** For serialization */
    // CHECK ME WITH serialver
    private static final long serialVersionUID = 6502780192411755341L;

    // parameter name, from the "name" attribute of the Parameter element
    protected String m_name = null;
    // optional descriptive label ("label" attribute); null when absent
    protected String m_label = null;
  }

  // List of model parameters, in the order they appear in the ParameterList.
  // Typed so that callers can read m_name/m_label from get(i) without a cast.
  protected ArrayList<Parameter> m_parameterList = new ArrayList<Parameter>();

  /**
   * Small inner class to hold the name of a factor or covariate,
   * plus the index of the attribute it corresponds to in the
   * mining schema.
   */
  static class Predictor implements Serializable {
    /** For serialization */
    // CHECK ME WITH serialver
    private static final long serialVersionUID = 6502780192411755341L;

    // predictor name, from the "name" attribute of the Predictor element
    protected String m_name = null;
    // index of the same-named mining schema attribute; -1 until resolved
    protected int m_miningSchemaIndex = -1;

    @Override
    public String toString() {
      return m_name;
    }
  }

  // FactorList: predictors read from the FactorList element, in document order.
  // Typed so that for-each loops and get(i).m_name work without casts.
  protected ArrayList<Predictor> m_factorList = new ArrayList<Predictor>();

  // CovariateList: predictors read from the CovariateList element, in document order
  protected ArrayList<Predictor> m_covariateList = new ArrayList<Predictor>();

  /**
   * Small inner class to hold a single cell of the PPMatrix
   * (predictor-to-parameter matrix), which links one predictor
   * to one parameter.
   */
  static class PPCell implements Serializable {
    /** For serialization */
    // CHECK ME WITH serialver
    // NOTE(review): the same UID value is reused by the other nested classes;
    // UIDs are per class, so this is harmless, and changing it would break
    // reading previously serialized models.
    private static final long serialVersionUID = 6502780192411755341L;
    
    // "predictorName" attribute of the PPCell element
    protected String m_predictorName = null;
    // "parameterName" attribute of the PPCell element
    protected String m_parameterName = null;

    // for a covariate: the exponent parsed from the cell value;
    // for a factor on a numeric attribute: the numeric value to match;
    // for a factor on a nominal attribute: the index of the nominal value
    protected double m_value = 0;

    // optional. The default is for all target categories to
    // share the same PPMatrix.
    // TO-DO: implement multiple PPMatrixes 
    protected String m_targetCategory = null;
    
  }
  
  // PPMatrix (predictor-to-parameter matrix)
  // rows = parameters (parameter list order), columns = mining schema
  // attribute indexes; cells with no PPCell entry are null
  protected PPCell[][] m_ppMatrix;

  /**
   * Small inner class to hold a single entry in the 
   * ParamMatrix (parameter matrix): one coefficient for one parameter,
   * optionally restricted to one target category.
   */
  static class PCell implements Serializable {
    
    /** For serialization */
    // CHECK ME WITH serialver
    private static final long serialVersionUID = 6502780192411755341L;

    // may be null for numeric target. May also be null if this coefficient
    // applies to all target categories.
    protected String m_targetCategory = null;
    // parameter name, or the parameter's label when one is defined
    protected String m_parameterName = null;
    // coefficient
    protected double m_beta = 0.0;
    // optional degrees of freedom; -1 when the "df" attribute is absent
    protected int m_df = -1;
  }
  
  // ParamMatrix. rows = target categories (only one if target is numeric),
  // columns = parameters (in order that they occur in the parameter list).
  protected PCell[][] m_paramMatrix;

  /**
   * Constructs a GeneralRegression classifier.
   *
   * Reads the model-level attributes of the model element (model type,
   * cumulative link or link function, trials, distribution, offset and
   * names) and then the ParameterList, FactorList, CovariateList,
   * PPMatrix and ParamMatrix children.
   * 
   * @param model the Element that holds the model definition
   * @param dataDictionary the data dictionary as a set of Instances
   * @param miningSchema the mining schema
   * @throws Exception if there is a problem constructing the general regression
   * object from the PMML (e.g. an unknown model type, link function or
   * distribution, or a numeric attribute that cannot be parsed).
   */
  public GeneralRegression(Element model, Instances dataDictionary,
                           MiningSchema miningSchema) throws Exception {

    super(dataDictionary, miningSchema);

    // get the model type (mandatory)
    String mType = model.getAttribute("modelType");
    ModelType modelType = lookupByPmmlName(ModelType.values(), mType);
    if (modelType == null) {
      throw new Exception("[GeneralRegression] unknown model type: " + mType);
    }
    m_modelType = modelType;

    if (m_modelType == ModelType.ORDINALMULTINOMIAL) {
      // get the cumulative link function
      String cLink = model.getAttribute("cumulativeLink");
      CumulativeLinkFunction cumulativeLink =
        lookupByPmmlName(CumulativeLinkFunction.values(), cLink);
      if (cumulativeLink == null) {
        throw new Exception("[GeneralRegression] unknown cumulative link function "
                            + cLink);
      }
      m_cumulativeLinkFunction = cumulativeLink;
    } else if (m_modelType == ModelType.GENERALIZEDLINEAR || 
               m_modelType == ModelType.GENERALLINEAR) {
      // get the link function
      String link = model.getAttribute("linkFunction");
      LinkFunction linkFunction = lookupByPmmlName(LinkFunction.values(), link);
      if (linkFunction == null) {
        throw new Exception("[GeneralRegression] unknown link function " + link);
      }
      m_linkFunction = linkFunction;

      // optional link parameter, trials variable and trials value
      m_linkParameter = parseOptionalDoubleAttribute(model, "linkParameter",
                                                     m_linkParameter, "link parameter");
      String trials = getOptionalAttribute(model, "trialsVariable");
      if (trials != null) {
        m_trialsVariable = trials;
      }
      m_trialsValue = parseOptionalDoubleAttribute(model, "trialsValue",
                                                   m_trialsValue, "trials value");
    }

    String mName = getOptionalAttribute(model, "modelName");
    if (mName != null) {
      m_modelName = mName;
    }

    // DOM getAttribute() returns "" for a missing attribute; comparing
    // constant-first also guards against implementations that return null
    String fName = model.getAttribute("functionName");
    if ("classification".equals(fName)) {
      m_functionType = Regression.RegressionTable.CLASSIFICATION;
    }

    String algName = getOptionalAttribute(model, "algorithmName");
    if (algName != null) {
      m_algorithmName = algName;
    }

    String distribution = getOptionalAttribute(model, "distribution");
    if (distribution != null) {
      Distribution d = lookupByPmmlName(Distribution.values(), distribution);
      if (d == null) {
        throw new Exception("[GeneralRegression] unknown distribution type "
                            + distribution);
      }
      m_distribution = d;
    }

    m_distParameter = parseOptionalDoubleAttribute(model, "distParameter",
                                                   m_distParameter,
                                                   "distribution parameter");

    String offsetV = getOptionalAttribute(model, "offsetVariable");
    if (offsetV != null) {
      m_offsetVariable = offsetV;
    }

    m_offsetValue = parseOptionalDoubleAttribute(model, "offsetValue",
                                                 m_offsetValue, "offset value");

    // get the parameter list
    readParameterList(model);

    // get the factors and covariates
    readFactorsAndCovariates(model, "FactorList");
    readFactorsAndCovariates(model, "CovariateList");

    // read the PPMatrix
    readPPMatrix(model);

    // read the parameter estimates
    readParamMatrix(model);
  }

  /**
   * Finds the enum constant whose toString() (its PMML spelling) equals the
   * given name.
   *
   * @param candidates the constants to search (typically {@code values()})
   * @param pmmlName the PMML spelling to look for
   * @return the matching constant, or null if none matches
   */
  private static <E extends Enum<E>> E lookupByPmmlName(E[] candidates,
                                                        String pmmlName) {
    for (E candidate : candidates) {
      if (candidate.toString().equals(pmmlName)) {
        return candidate;
      }
    }
    return null;
  }

  /**
   * Reads an optional attribute of the model element.
   *
   * @param model the model element
   * @param attName the attribute name
   * @return the attribute value, or null if it is absent or empty
   */
  private static String getOptionalAttribute(Element model, String attName) {
    String value = model.getAttribute(attName);
    return (value != null && value.length() > 0) ? value : null;
  }

  /**
   * Parses an optional numeric attribute of the model element.
   *
   * @param model the model element
   * @param attName the attribute name
   * @param defaultValue the value returned when the attribute is absent or empty
   * @param description human-readable name of the value, used in error messages
   * @return the parsed value, or defaultValue if the attribute is absent
   * @throws Exception if the attribute is present but is not a valid double
   */
  private static double parseOptionalDoubleAttribute(Element model, String attName,
                                                     double defaultValue,
                                                     String description)
    throws Exception {
    String raw = getOptionalAttribute(model, attName);
    if (raw == null) {
      return defaultValue;
    }
    try {
      return Double.parseDouble(raw);
    } catch (NumberFormatException ex) {
      throw new Exception("[GeneralRegression] unable to parse the " + description
                          + ": " + raw, ex);
    }
  }

  /**
   * Read the list of parameters.
   *
   * Each Parameter element contributes one entry to m_parameterList, in
   * document order, holding its "name" and (if non-empty) its "label".
   *
   * @param model the Element that contains the model
   * @throws Exception if the model does not contain exactly one
   * ParameterList element.
   */
  protected void readParameterList(Element model) throws Exception {
    NodeList paramL = model.getElementsByTagName("ParameterList");

    // should be exactly one parameter list; report the two failure modes
    // separately (a missing list used to be reported as "more than one")
    if (paramL.getLength() == 0) {
      throw new Exception("[GeneralRegression] no parameter list defined!");
    }
    if (paramL.getLength() > 1) {
      throw new Exception("[GeneralRegression] more than one parameter list!");
    }

    Node paramN = paramL.item(0);
    if (paramN.getNodeType() != Node.ELEMENT_NODE) {
      return;
    }

    NodeList parameterList = ((Element) paramN).getElementsByTagName("Parameter");
    for (int i = 0; i < parameterList.getLength(); i++) {
      Node parameter = parameterList.item(i);
      if (parameter.getNodeType() != Node.ELEMENT_NODE) {
        continue;
      }
      Parameter p = new Parameter();
      p.m_name = ((Element) parameter).getAttribute("name");
      String label = ((Element) parameter).getAttribute("label");
      if (label != null && label.length() > 0) {
        p.m_label = label;
      }
      m_parameterList.add(p);
    }
  }

  /**
   * Read the lists of factors and covariates.
   *
   * At most one list element with the given tag may be present. Each
   * Predictor in it is resolved against the mining schema by name and
   * appended to m_factorList (for "FactorList") or m_covariateList
   * (for anything else).
   *
   * @param model the Element that contains the model
   * @param factorOrCovariate holds the String "FactorList" or
   * "CovariateList"
   * @throws Exception if there is more than one list element, or if a
   * factor or covariate is listed that isn't in the mining schema
   */
  protected void readFactorsAndCovariates(Element model, 
                                          String factorOrCovariate) 
    throws Exception {
    Instances schema = m_miningSchema.getFieldsAsInstances();
    NodeList lists = model.getElementsByTagName(factorOrCovariate);
    int numLists = lists.getLength();

    if (numLists > 1) {
      throw new Exception("[GeneralRegression] more than one " + factorOrCovariate
                          + "! ");
    }
    if (numLists == 0) {
      // the list is optional
      return;
    }

    Node listNode = lists.item(0);
    if (listNode.getNodeType() != Node.ELEMENT_NODE) {
      return;
    }

    NodeList predictors = ((Element) listNode).getElementsByTagName("Predictor");
    for (int k = 0; k < predictors.getLength(); k++) {
      Node predNode = predictors.item(k);
      if (predNode.getNodeType() != Node.ELEMENT_NODE) {
        continue;
      }

      Predictor predictor = new Predictor();
      predictor.m_name = ((Element) predNode).getAttribute("name");

      // first mining schema attribute carrying the same name
      int schemaIndex = -1;
      for (int a = 0; a < schema.numAttributes() && schemaIndex < 0; a++) {
        if (schema.attribute(a).name().equals(predictor.m_name)) {
          schemaIndex = a;
        }
      }
      if (schemaIndex < 0) {
        throw new Exception("[GeneralRegression] reading factors and covariates - "
                            + "unable to find predictor " +
                            predictor.m_name + " in the mining schema");
      }
      predictor.m_miningSchemaIndex = schemaIndex;

      if (factorOrCovariate.equals("FactorList")) {
        m_factorList.add(predictor);
      } else {
        m_covariateList.add(predictor);
      }
    }
  }

  /**
   * Read the PPMatrix from the xml. Does not handle multiple PPMatrixes yet.
   *
   * Each PPCell is stored at row = index of its parameter in the parameter
   * list, column = mining schema index of its predictor. For a covariate the
   * cell value is parsed as a number (an exponent). For a factor on a numeric
   * mining schema attribute it is parsed as the numeric value to match. For a
   * factor on a nominal attribute it is converted to the index of that value.
   *
   * @param model the Element that contains the model
   * @throws Exception if there is not exactly one PPMatrix, if a cell refers
   * to an unknown parameter or predictor, or if a cell value cannot be
   * parsed or found.
   */
  protected void readPPMatrix(Element model) throws Exception {
    Instances miningSchemaI = m_miningSchema.getFieldsAsInstances();
    
    NodeList matrixL = model.getElementsByTagName("PPMatrix");

    // should be exactly one PPMatrix; report the two failure modes separately
    // (a missing matrix used to be reported as "more than one")
    if (matrixL.getLength() == 0) {
      throw new Exception("[GeneralRegression] no PPMatrix defined!");
    }
    if (matrixL.getLength() > 1) {
      throw new Exception("[GeneralRegression] more than one PPMatrix!");
    }

    // allocate space for the matrix
    // column that corresponds to the class will be empty (and will be missed out
    // when printing the model).
    m_ppMatrix = new PPCell[m_parameterList.size()][miningSchemaI.numAttributes()];

    Node ppM = matrixL.item(0);
    if (ppM.getNodeType() != Node.ELEMENT_NODE) {
      return;
    }

    NodeList cellL = ((Element) ppM).getElementsByTagName("PPCell");
    for (int i = 0; i < cellL.getLength(); i++) {
      Node cell = cellL.item(i);
      if (cell.getNodeType() != Node.ELEMENT_NODE) {
        continue;
      }
      String predictorName = ((Element) cell).getAttribute("predictorName");
      String parameterName = ((Element) cell).getAttribute("parameterName");
      String value = ((Element) cell).getAttribute("value");

      // locate the matrix row for this parameter
      int parameterIndex = -1;
      for (int j = 0; j < m_parameterList.size(); j++) {
        if (m_parameterList.get(j).m_name.equals(parameterName)) {
          parameterIndex = j;
          break;
        }
      }
      if (parameterIndex == -1) {
        throw new Exception("[GeneralRegression] unable to find parameter name "
                            + parameterName + " in parameter list");
      }

      double expOrIndex;
      Predictor p = getCovariate(predictorName);
      if (p != null) {
        // covariate: the value is an exponent
        expOrIndex = parsePPCellValue(value);
      } else {
        // try as a factor
        p = getFactor(predictorName);
        if (p == null) {
          throw new Exception("[GeneralRegression] can't find predictor "
                              + predictorName + " in either the factors list "
                              + "or the covariates list");
        }
        Attribute att = miningSchemaI.attribute(p.m_miningSchemaIndex);
        if (att.isNumeric()) {
          // An example pmml file from DMG seems to suggest that it is possible
          // for a continuous variable in the mining schema to be treated as a
          // factor. The value is then a value to match rather than an exponent.
          expOrIndex = parsePPCellValue(value);
        } else {
          // nominal attribute: find the index that corresponds to this value
          expOrIndex = att.indexOfValue(value);
          if (expOrIndex == -1) {
            throw new Exception("[GeneralRegression] unable to find PPCell value "
                                + value + " in mining schema attribute "
                                + att.name());
          }
        }
      }

      // fill in cell value
      PPCell ppc = new PPCell();
      ppc.m_predictorName = predictorName;
      ppc.m_parameterName = parameterName;
      ppc.m_value = expOrIndex;

      // TO-DO: ppc.m_targetCategory (when handling for multiple PPMatrixes is implemented)
      m_ppMatrix[parameterIndex][p.m_miningSchemaIndex] = ppc;
    }
  }

  /**
   * Parses a numeric PPCell value.
   *
   * @param value the raw "value" attribute of a PPCell
   * @return the parsed double
   * @throws Exception (wrapping the NumberFormatException) if value is not a
   * valid double
   */
  private static double parsePPCellValue(String value) throws Exception {
    try {
      return Double.parseDouble(value);
    } catch (NumberFormatException ex) {
      throw new Exception("[GeneralRegression] unable to parse PPCell value: "
                          + value, ex);
    }
  }

  /**
   * Looks up a covariate by name.
   *
   * @param predictorName the predictor name to search for
   * @return the first covariate whose name equals predictorName, or null
   * if there is none
   */
  private Predictor getCovariate(String predictorName) {
    for (Predictor candidate : m_covariateList) {
      if (predictorName.equals(candidate.m_name)) {
        return candidate;
      }
    }
    return null;
  }

  /**
   * Looks up a factor by name.
   *
   * @param predictorName the predictor name to search for
   * @return the first factor whose name equals predictorName, or null
   * if there is none
   */
  private Predictor getFactor(String predictorName) {
    for (Predictor candidate : m_factorList) {
      if (predictorName.equals(candidate.m_name)) {
        return candidate;
      }
    }
    return null;
  }

  /**
   * Read the parameter matrix from the xml.
   *
   * The matrix has one row per target category (a single row when the class
   * is numeric) and one column per parameter, in parameter-list order. A
   * PCell without a target category is copied into every row. If the
   * function type is classification but the class attribute in the mining
   * schema is numeric, the class is first converted to nominal using the
   * discrete values of the PMML Target element.
   * 
   * @param model Element that holds the model
   * @throws Exception if a problem is encountered during extraction of
   * the parameter matrix (missing or duplicate ParamMatrix, unknown parameter
   * or target category, unparseable beta/df, or a numeric class that cannot
   * be converted to nominal)
   */
  private void readParamMatrix(Element model) throws Exception {

    Instances miningSchemaI = m_miningSchema.getFieldsAsInstances();
    Attribute classAtt = miningSchemaI.classAttribute();
    // used when function type is classification but class attribute is numeric
    // in the mining schema. We will assume that there is a Target specified in
    // the pmml that defines the legal values for this class.
    ArrayList<String> targetVals = null;

    NodeList matrixL = model.getElementsByTagName("ParamMatrix");
    // should be exactly one ParamMatrix; report the two failure modes separately
    if (matrixL.getLength() == 0) {
      throw new Exception("[GeneralRegression] no ParamMatrix defined!");
    }
    if (matrixL.getLength() > 1) {
      throw new Exception("[GeneralRegression] more than one ParamMatrix!");
    }
    Element matrix = (Element) matrixL.item(0);

    // check for the case where the class in the mining schema is numeric,
    // but this attribute is treated as discrete
    if (m_functionType == Regression.RegressionTable.CLASSIFICATION &&
        classAtt.isNumeric()) {
      // try and convert the class attribute to nominal. For this to succeed
      // there has to be a Target element defined in the PMML.
      if (!m_miningSchema.hasTargetMetaData()) {
        throw new Exception("[GeneralRegression] function type is classification and "
                            + "class attribute in mining schema is numeric, however, "
                            + "there is no Target element "
                            + "specifying legal discrete values for the target!");
      }

      if (m_miningSchema.getTargetMetaData().getOptype() 
          != TargetMetaInfo.Optype.CATEGORICAL) {
        throw new Exception("[GeneralRegression] function type is classification and "
                            + "class attribute in mining schema is numeric, however "
                            + "Target element in PMML does not have optype categorical!");
      }

      // OK now get legal values
      targetVals = m_miningSchema.getTargetMetaData().getValues();
      if (targetVals.size() == 0) {
        throw new Exception("[GeneralRegression] function type is classification and "
                            + "class attribute in mining schema is numeric, however "
                            + "Target element in PMML does not have any discrete values "
                            + "defined!");
      }

      // Finally, convert the class in the mining schema to nominal
      m_miningSchema.convertNumericAttToNominal(miningSchemaI.classIndex(), targetVals);

      // The conversion installs a new (nominal) class attribute in the mining
      // schema, so re-fetch it. Keeping the stale numeric attribute would size
      // the matrix with a single row and reject every PCell that names a
      // target category.
      miningSchemaI = m_miningSchema.getFieldsAsInstances();
      classAtt = miningSchemaI.classAttribute();
    }
    
    // allocate space for the matrix 
    m_paramMatrix = 
        new PCell[(classAtt.isNumeric())
                  ? 1
                  : classAtt.numValues()][m_parameterList.size()];

    NodeList pcellL = matrix.getElementsByTagName("PCell");
    for (int i = 0; i < pcellL.getLength(); i++) {
      Node pcell = pcellL.item(i);
      if (pcell.getNodeType() != Node.ELEMENT_NODE) {
        continue;
      }
      // -1 indicates that this beta applies to all target categories
      // (or that the target is numeric)
      int targetCategoryIndex = -1;
      int parameterIndex = -1;

      String paramName = ((Element) pcell).getAttribute("parameterName");
      String targetCatName = ((Element) pcell).getAttribute("targetCategory");
      String coefficient = ((Element) pcell).getAttribute("beta");
      String df = ((Element) pcell).getAttribute("df");

      for (int j = 0; j < m_parameterList.size(); j++) {
        if (m_parameterList.get(j).m_name.equals(paramName)) {
          parameterIndex = j;
          // use the label if defined
          if (m_parameterList.get(j).m_label != null) {
            paramName = m_parameterList.get(j).m_label;
          }
          break;
        }
      }
      if (parameterIndex == -1) {
        throw new Exception("[GeneralRegression] unable to find parameter name "
                            + paramName + " in parameter list");
      }

      if (targetCatName != null && targetCatName.length() > 0) {
        if (!(classAtt.isNominal() || classAtt.isString())) {
          throw new Exception("[GeneralRegression] found a PCell with a named "
                              + "target category: " + targetCatName
                              + " but class attribute is numeric in "
                              + "mining schema");
        }
        targetCategoryIndex = classAtt.indexOfValue(targetCatName);
        // an unknown category must not silently turn into a coefficient shared
        // by all categories (which would overwrite every row below)
        if (targetCategoryIndex == -1) {
          throw new Exception("[GeneralRegression] found a PCell with target category "
                              + targetCatName + " that is not a value of class "
                              + "attribute " + classAtt.name()
                              + " in the mining schema");
        }
      }

      PCell p = new PCell();
      if (targetCategoryIndex != -1) {
        p.m_targetCategory = targetCatName;
      }
      p.m_parameterName = paramName;
      try {
        p.m_beta = Double.parseDouble(coefficient);
      } catch (NumberFormatException ex) {
        throw new Exception("[GeneralRegression] unable to parse beta value "
                            + coefficient + " as a double from PCell", ex);
      }
      if (df != null && df.length() > 0) {
        try {
          p.m_df = Integer.parseInt(df);
        } catch (NumberFormatException ex) {
          throw new Exception("[GeneralRegression] unable to parse df value "
                              + df + " as an int from PCell", ex);
        }
      }
      
      if (targetCategoryIndex != -1) {
        m_paramMatrix[targetCategoryIndex][parameterIndex] = p;
      } else {
        // this PCell applies to all target categories (covers numeric class, in
        // which case there will be only one row in the matrix anyway)
        for (int j = 0; j < m_paramMatrix.length; j++) {
          m_paramMatrix[j][parameterIndex] = p;
        }
      }
    }
  }

  /**
   * Builds a human-readable summary of this general regression model: the
   * PMML version (and creator application, when known), the model type, the
   * mining schema, the factor and covariate lists, the predictor-to-parameter
   * matrix, the parameter estimates and, when present, the link function or
   * cumulative link function settings.
   * 
   * @return a description of this general regression
   */
  public String toString() {
    StringBuffer out = new StringBuffer();

    out.append("PMML version " + getPMMLVersion());
    String creatorApp = getCreatorApplication();
    if (!creatorApp.equals("?")) {
      out.append("\nApplication: " + creatorApp);
    }
    out.append("\nPMML Model: " + m_modelType);
    out.append("\n\n");
    out.append(m_miningSchema);

    if (!m_factorList.isEmpty()) {
      out.append("Factors:\n");
      for (Predictor factor : m_factorList) {
        out.append("\t" + factor + "\n");
      }
    }
    out.append("\n");

    if (!m_covariateList.isEmpty()) {
      out.append("Covariates:\n");
      for (Predictor covariate : m_covariateList) {
        out.append("\t" + covariate + "\n");
      }
    }
    out.append("\n");

    printPPMatrix(out);
    out.append("\n");
    printParameterMatrix(out);
    out.append("\n");

    // The same offset description is reported under both the link function
    // and the cumulative link function sections; a named offset variable
    // takes precedence over a fixed offset value.
    String offsetText = "";
    if (m_offsetVariable != null) {
      offsetText = "\n\tOffset variable " + m_offsetVariable;
    } else if (!Double.isNaN(m_offsetValue)) {
      offsetText = "\n\tOffset value " + m_offsetValue;
    }

    if (m_linkFunction != LinkFunction.NONE) {
      out.append("Link function: " + m_linkFunction);
      out.append(offsetText);

      if (m_trialsVariable != null) {
        out.append("\n\tTrials variable " + m_trialsVariable);
      } else if (!Double.isNaN(m_trialsValue)) {
        out.append("\n\tTrials value " + m_trialsValue);
      }

      if (m_distribution != Distribution.NONE) {
        out.append("\nDistribution: " + m_distribution);
      }

      boolean negBinomial = m_linkFunction == LinkFunction.NEGBIN
        && m_distribution == Distribution.NEGBINOMIAL;
      if (negBinomial && !Double.isNaN(m_distParameter)) {
        out.append("\n\tDistribution parameter " + m_distParameter);
      }

      boolean powerLink = m_linkFunction == LinkFunction.POWER
        || m_linkFunction == LinkFunction.ODDSPOWER;
      if (powerLink && !Double.isNaN(m_linkParameter)) {
        out.append("\n\nLink parameter " + m_linkParameter);
      }
    }

    if (m_cumulativeLinkFunction != CumulativeLinkFunction.NONE) {
      out.append("Cumulative link function: " + m_cumulativeLinkFunction);
      out.append(offsetText);
    }
    out.append("\n");

    return out.toString();
  }
  
  /**
   * Format and print the PPMatrix (predictor-to-parameter matrix) to the
   * supplied StringBuffer. One row is printed per parameter, labelled with
   * the parameter's label if it has one and its name otherwise. One column
   * is printed per non-class field of the mining schema. A cell for a
   * nominal/string field shows the attribute value its stored index refers
   * to; other cells show the value with 4 decimal places; empty cells are
   * printed blank.
   * 
   * @param buff the StringBuffer to append to
   */
  protected void printPPMatrix(StringBuffer buff) {
    Instances miningSchemaI = m_miningSchema.getFieldsAsInstances();
    // common column width: start from the longest field name
    int maxAttWidth = 0;
    for (int i = 0; i < miningSchemaI.numAttributes(); i++) {
      Attribute a = miningSchemaI.attribute(i);
      if (a.name().length() > maxAttWidth) {
        maxAttWidth = a.name().length();
      }
    }

    // check the width of the values.
    // NOTE: this scans every schema column, including the class column,
    // even though the class column is skipped when the matrix is printed.
    for (int i = 0; i < m_parameterList.size(); i++) {
      for (int j = 0; j < miningSchemaI.numAttributes(); j++) {
        if (m_ppMatrix[i][j] != null) {
          // rough count of integer digits (log10 of the magnitude); magnitudes
          // below 1 (including 0, where the log is -Infinity) count as 1
          double width = Math.log(Math.abs(m_ppMatrix[i][j].m_value)) /
            Math.log(10.0);
          if (width < 0) {
            width = 1;
          }
          // decimal + # decimal places + 1
          // NOTE(review): only 2 is added here although numeric cells are
          // printed with 4 decimal places below - confirm this is intended
          width += 2.0;
          if ((int)width > maxAttWidth) {
            maxAttWidth = (int)width;
          }
          if (miningSchemaI.attribute(j).isNominal() || 
              miningSchemaI.attribute(j).isString()) {
            // check the width of this value: for nominal/string fields the
            // cell value is an index into the attribute's values
            String val = miningSchemaI.attribute(j).value((int)m_ppMatrix[i][j].m_value) + " ";
            if (val.length() > maxAttWidth) {
              maxAttWidth = val.length();
            }
          }
        }
      }
    }

    // get the max parameter width (label if present, else name, plus a space)
    int maxParamWidth = "Parameter  ".length();
    for (Parameter p : m_parameterList) {
      String temp = (p.m_label != null)
        ? p.m_label + " "
        : p.m_name + " ";

      if (temp.length() > maxParamWidth) {
        maxParamWidth = temp.length();
      }
    }

    // two header lines: "Predictor" above the field columns, then
    // "Parameter" followed by the (non-class) field names
    buff.append("Predictor-to-Parameter matrix:\n");
    buff.append(PMMLUtils.pad("Predictor", " ", (maxParamWidth + (maxAttWidth * 2 + 2))
                              - "Predictor".length(), true));
    buff.append("\n" + PMMLUtils.pad("Parameter", " ", maxParamWidth - "Parameter".length(), false));
    // attribute names
    for (int i = 0; i < miningSchemaI.numAttributes(); i++) {
      if (i != miningSchemaI.classIndex()) {
        String attName = miningSchemaI.attribute(i).name();
        buff.append(PMMLUtils.pad(attName, " ", maxAttWidth + 1 - attName.length(), true));
      }
    }
    buff.append("\n");

    // one row per parameter; the class column is skipped
    for (int i = 0; i < m_parameterList.size(); i++) {
      Parameter param = m_parameterList.get(i);
      String paramS = (param.m_label != null)
        ? param.m_label
        : param.m_name;
      buff.append(PMMLUtils.pad(paramS, " ", 
                                maxParamWidth - paramS.length(), false));
      for (int j = 0; j < miningSchemaI.numAttributes(); j++) {
        if (j != miningSchemaI.classIndex()) {
          PPCell p = m_ppMatrix[i][j];
          String val = " ";
          if (p != null) {
            if (miningSchemaI.attribute(j).isNominal() ||
                miningSchemaI.attribute(j).isString()) {
              val = miningSchemaI.attribute(j).value((int)p.m_value);
            } else {
              val = "" + Utils.doubleToString(p.m_value, maxAttWidth, 4).trim();
            }
          }
          buff.append(PMMLUtils.pad(val, " ", maxAttWidth + 1 - val.length(), true));
        }
      }
      buff.append("\n");
    }
  }

  /**
   * Appends the parameter estimates table to the supplied StringBuffer. For
   * every row of the parameter matrix that holds at least one PCell, the
   * class value is printed (a blank for a numeric class), followed by one
   * line per PCell giving the parameter name, its coefficient and its degrees
   * of freedom, each formatted with 4 decimal places.
   * 
   * @param buff the StringBuffer to append to
   */
  protected void printParameterMatrix(StringBuffer buff) {
    Instances schema = m_miningSchema.getFieldsAsInstances();
    Attribute classAtt = schema.classAttribute();
    String className = classAtt.name();
    boolean labelledClass = classAtt.isNominal() || classAtt.isString();
    int numParams = m_parameterList.size();

    // class column: the wider of the class name and (for a nominal/string
    // class) the longest class value
    int classColWidth = className.length();
    if (labelledClass) {
      for (int v = 0; v < classAtt.numValues(); v++) {
        classColWidth = Math.max(classColWidth, classAtt.value(v).length());
      }
    }

    // parameter column: longest label (or name), counting a trailing space
    int paramColWidth = 0;
    for (Parameter param : m_parameterList) {
      String shown = ((param.m_label != null) ? param.m_label : param.m_name) + " ";
      paramColWidth = Math.max(paramColWidth, shown.length());
    }

    // coefficient column: estimated integer digits (at least 1) plus 7
    int betaColWidth = "Coeff.".length();
    for (int row = 0; row < m_paramMatrix.length; row++) {
      for (int col = 0; col < numParams; col++) {
        PCell cell = m_paramMatrix[row][col];
        if (cell == null) {
          continue;
        }
        double digits = Math.log(Math.abs(cell.m_beta)) / Math.log(10);
        if (digits < 0) {
          digits = 1;
        }
        int estimate = (int) (digits + 7.0);
        if (estimate > betaColWidth) {
          betaColWidth = estimate;
        }
      }
    }

    buff.append("Parameter estimates:\n");
    buff.append(PMMLUtils.pad(className, " ",
                              classColWidth + paramColWidth + 2 - className.length(), false));
    buff.append(PMMLUtils.pad("Coeff.", " ", betaColWidth + 1 - "Coeff.".length(), true));
    buff.append(PMMLUtils.pad("df", " ", betaColWidth - "df".length(), true));
    buff.append("\n");

    for (int row = 0; row < m_paramMatrix.length; row++) {
      // rows without any PCell are not printed at all
      boolean anyCell = false;
      for (int col = 0; col < numParams && !anyCell; col++) {
        anyCell = m_paramMatrix[row][col] != null;
      }
      if (!anyCell) {
        continue;
      }

      String classValue = labelledClass ? classAtt.value(row) : " ";
      buff.append(PMMLUtils.pad(classValue, " ", classColWidth - classValue.length(), false));
      buff.append("\n");

      for (int col = 0; col < numParams; col++) {
        PCell cell = m_paramMatrix[row][col];
        if (cell == null) {
          continue;
        }
        String label = cell.m_parameterName;
        buff.append(PMMLUtils.pad(label, " ",
                                  classColWidth + paramColWidth + 2 - label.length(), true));
        String betaText = Utils.doubleToString(cell.m_beta, betaColWidth, 4).trim();
        buff.append(PMMLUtils.pad(betaText, " ", betaColWidth + 1 - betaText.length(), true));
        String dfText = Utils.doubleToString(cell.m_df, betaColWidth, 4).trim();
        buff.append(PMMLUtils.pad(dfText, " ", betaColWidth - dfText.length(), true));
        buff.append("\n");
      }
    }
  }
  
  /**
   * Maps the values of an incoming test instance onto the model parameters.
   * Entry {@code i} of the result is the product, over the non-null cells in
   * row {@code i} of the PPMatrix, of:
   * <ul>
   * <li>for a factor: 1 if the instance's value index equals the cell value,
   * otherwise 0;</li>
   * <li>for a covariate: the instance's value raised to the cell value.</li>
   * </ul>
   * A row with no cells yields 1, so it acts as the intercept.
   * 
   * @param incomingInst the values of the incoming test instance, in mining
   * schema order
   * @return the populated parameter vector, ready to be multiplied against
   * the vector of coefficients
   * @throws Exception if a cell names a predictor that is neither a factor
   * nor a covariate
   */
  private double[] incomingParamVector(double[] incomingInst) throws Exception {
    Instances schema = m_miningSchema.getFieldsAsInstances();
    int numParams = m_parameterList.size();
    double[] paramVector = new double[numParams];

    for (int row = 0; row < numParams; row++) {
      double product = 1.0;
      for (int col = 0; col < schema.numAttributes(); col++) {
        PPCell cell = m_ppMatrix[row][col];
        if (cell == null) {
          continue;
        }

        Predictor factor = getFactor(cell.m_predictorName);
        if (factor != null) {
          boolean matches =
            (int) incomingInst[factor.m_miningSchemaIndex] == (int) cell.m_value;
          product *= matches ? 1.0 : 0.0;
          continue;
        }

        Predictor covariate = getCovariate(cell.m_predictorName);
        if (covariate == null) {
          throw new Exception("[GeneralRegression] can't find predictor "
              + cell.m_predictorName + " in either the list of factors or covariates");
        }
        product *= Math.pow(incomingInst[covariate.m_miningSchemaIndex], cell.m_value);
      }
      paramVector[row] = product;
    }

    return paramVector;
  }

  /**
   * Computes the prediction for the given test instance. The instance has to
   * belong to a dataset when it's being classified, because that dataset is
   * used to map onto the mining schema the first time a prediction is made.
   * <p>
   * The instance's values are first mapped onto the mining schema, including
   * missing value and outlier treatment. If a non-class value is still
   * missing, the Target meta data is used instead: the default value for a
   * numeric class, or the prior probabilities for a nominal/string class.
   * Without Target meta data a warning is logged and no prediction is made
   * (NaN for a numeric class, all-zero probabilities otherwise).
   * 
   * @param inst the instance to be classified
   * @return a single-element array holding the predicted value for a numeric
   * class, otherwise one entry per class value
   * @exception Exception if an error occurred during the prediction
   */
  public double[] distributionForInstance(Instance inst) throws Exception {
    if (!m_initialized) {
      mapToMiningSchema(inst.dataset());
    }

    Instances schema = m_miningSchema.getFieldsAsInstances();
    Attribute classAtt = schema.classAttribute();
    double[] preds = classAtt.isNumeric()
      ? new double[1]
      : new double[classAtt.numValues()];

    // values of the incoming instance in mining schema order, after missing
    // value and outlier handling
    double[] incoming = m_fieldsMap.instanceToSchema(inst, m_miningSchema);

    int classIdx = schema.classIndex();
    boolean hasMissing = false;
    for (int i = 0; i < incoming.length && !hasMissing; i++) {
      hasMissing = i != classIdx && Double.isNaN(incoming[i]);
    }

    if (!hasMissing) {
      computeResponses(incoming, incomingParamVector(incoming), preds);
      return preds;
    }

    if (m_miningSchema.hasTargetMetaData()) {
      // fall back to the default value / prior probabilities
      TargetMetaInfo targetData = m_miningSchema.getTargetMetaData();
      if (classAtt.isNumeric()) {
        preds[0] = targetData.getDefaultValue();
      } else {
        for (int v = 0; v < classAtt.numValues(); v++) {
          preds[v] = targetData.getPriorProbability(classAtt.value(v));
        }
      }
      return preds;
    }

    String outcome = (classAtt.isNominal() || classAtt.isString())
      ? "zero probabilities output)."
      : "NaN output).";
    String message = "[GeneralRegression] WARNING: Instance to predict has missing value(s) but "
      + "there is no missing value handling meta data and no "
      + "prior probabilities/default value to fall back to. No "
      + "prediction will be made ("
      + outcome;
    if (m_log == null) {
      System.err.println(message);
    } else {
      m_log.logMessage(message);
    }

    if (classAtt.isNumeric()) {
      preds[0] = Utils.missingValue();
    }
    return preds;
  }
  
  /**
   * Computes the model's responses for the current incoming instance. Each
   * response starts as the dot product of the incoming parameter vector with
   * the matching row of the parameter matrix. It is then transformed
   * according to the model type:
   * <ul>
   * <li>multinomial logistic: converted to probabilities;</li>
   * <li>regression: left as is;</li>
   * <li>general linear and generalized linear: passed through the link
   * function;</li>
   * <li>ordinal multinomial: passed through the cumulative link function.</li>
   * </ul>
   * 
   * @param incomingInst raw incoming instance values (after missing value
   * replacement and outlier treatment)
   * @param incomingParamVector incoming instance values mapped to parameters
   * @param responses accumulates (via +=) the linear predictors and then
   * holds the final responses
   * @throws Exception if a required link function is missing or the model
   * type is unknown
   */
  private void computeResponses(double[] incomingInst, 
                                double[] incomingParamVector,
                                double[] responses) throws Exception {
    int numParams = m_parameterList.size();
    for (int row = 0; row < responses.length; row++) {
      for (int col = 0; col < numParams; col++) {
        // A missing PCell contributes a zero coefficient. For classification,
        // the last class may have no intercept defined at all in the PMML file.
        PCell cell = m_paramMatrix[row][col];
        double beta = (cell == null) ? 0.0 : cell.m_beta;
        responses[row] += incomingParamVector[col] * beta;
      }
    }

    switch (m_modelType) {
      case MULTINOMIALLOGISTIC:
        computeProbabilitiesMultinomialLogistic(responses);
        break;
      case REGRESSION:
        // the linear predictor is already the response
        break;
      case GENERALLINEAR:
      case GENERALIZEDLINEAR:
        if (m_linkFunction == LinkFunction.NONE) {
          throw new Exception("[GeneralRegression] no link function specified!");
        }
        computeResponseGeneralizedLinear(incomingInst, responses);
        break;
      case ORDINALMULTINOMIAL:
        if (m_cumulativeLinkFunction == CumulativeLinkFunction.NONE) {
          throw new Exception("[GeneralRegression] no cumulative link function specified!");
        }
        computeResponseOrdinalMultinomial(incomingInst, responses);
        break;
      default:
        throw new Exception("[GeneralRegression] unknown model type");
    }
  }
  
  /**
   * Converts, in place, the linear predictors held in {@code responses} into
   * multinomial logistic (softmax) probabilities. The probability for class
   * j is computed as 1 / sum_k exp(r_k - r_j). To avoid overflow, it is set
   * to 0 as soon as any exponent r_k - r_j exceeds 700.
   * 
   * @param responses on entry the linear predictors; on exit the class
   * probabilities
   */
  private static void computeProbabilitiesMultinomialLogistic(double[] responses) {
    // snapshot, because responses is overwritten as we go
    final double[] logits = responses.clone();
    for (int target = 0; target < logits.length; target++) {
      responses[target] = multinomialLogisticProbability(logits, target);
    }
  }

  /**
   * Softmax probability of one class, computed relative to that class's own
   * logit. Returns 0 if any exponent exceeds 700.
   *
   * @param logits the linear predictors for all classes
   * @param target the index of the class whose probability is wanted
   * @return the probability of class {@code target}
   */
  private static double multinomialLogisticProbability(double[] logits, int target) {
    double denominator = 0;
    for (double logit : logits) {
      double exponent = logit - logits[target];
      if (exponent > 700) {
        return 0.0;
      }
      denominator += Math.exp(exponent);
    }
    return 1.0 / denominator;
  }
  
  /**
   * Applies the link function to the linear predictors for the general
   * linear and generalized linear model types.
   * <p>
   * The offset and trials come from a named mining schema field when one is
   * configured, otherwise from the fixed value when that is not NaN. The
   * defaults are 0 (offset) and 1 (trials). A distribution parameter is
   * required for the NEGBIN link with the NEGBINOMIAL distribution. A link
   * parameter is required for the POWER and ODDSPOWER links.
   * 
   * @param incomingInst the raw incoming instance values (after missing value
   * replacement and outlier treatment etc).
   * @param responses on entry the linear predictors; on exit the link
   * function's output for each
   * @throws Exception if a named offset/trials variable is not in the mining
   * schema, or a required distribution/link parameter is undefined
   */
  private void computeResponseGeneralizedLinear(double[] incomingInst, 
                                                double[] responses) 
    throws Exception {
    final double[] linear = responses.clone();

    double offset = 0;
    if (m_offsetVariable == null) {
      if (!Double.isNaN(m_offsetValue)) {
        offset = m_offsetValue;
      }
    } else {
      Attribute offsetField =
        m_miningSchema.getFieldsAsInstances().attribute(m_offsetVariable);
      if (offsetField == null) {
        throw new Exception("[GeneralRegression] unable to find offset variable "
            + m_offsetVariable + " in the mining schema!");
      }
      offset = incomingInst[offsetField.index()];
    }

    double trials = 1;
    if (m_trialsVariable == null) {
      if (!Double.isNaN(m_trialsValue)) {
        trials = m_trialsValue;
      }
    } else {
      Attribute trialsField =
        m_miningSchema.getFieldsAsInstances().attribute(m_trialsVariable);
      if (trialsField == null) {
        throw new Exception("[GeneralRegression] unable to find trials variable "
            + m_trialsVariable + " in the mining schema!");
      }
      trials = incomingInst[trialsField.index()];
    }

    boolean needsDistParam = m_linkFunction == LinkFunction.NEGBIN
      && m_distribution == Distribution.NEGBINOMIAL;
    if (needsDistParam && Double.isNaN(m_distParameter)) {
      throw new Exception("[GeneralRegression] no distribution parameter defined!");
    }
    double distParam = needsDistParam ? m_distParameter : 0;

    boolean needsLinkParam = m_linkFunction == LinkFunction.POWER
      || m_linkFunction == LinkFunction.ODDSPOWER;
    if (needsLinkParam && Double.isNaN(m_linkParameter)) {
      throw new Exception("[GeneralRegression] no link parameter defined!");
    }
    double linkParam = needsLinkParam ? m_linkParameter : 0;

    for (int k = 0; k < linear.length; k++) {
      responses[k] = m_linkFunction.eval(linear[k], offset, trials, distParam, linkParam);
    }
  }
    
  /**
   * Computes category probabilities for the ordinal multinomial model type.
   * <p>
   * On entry, {@code responses[i]} holds the linear predictor for category
   * {@code i}. The cumulative link function maps it (plus the offset) to the
   * cumulative probability F_i = P(Y &lt;= i). The probability of an
   * individual category is the difference between successive cumulative
   * probabilities:
   * <ul>
   * <li>first category: F_0;</li>
   * <li>middle categories: F_i - F_{i-1};</li>
   * <li>last category (when there is more than one): 1 - F_{J-2}.</li>
   * </ul>
   * The offset comes from a named mining schema field when one is configured,
   * otherwise from the fixed offset value when that is not NaN; it defaults
   * to 0.
   * 
   * @param incomingInst the raw incoming instance values (after missing value
   * replacement and outlier treatment etc).
   * @param responses on entry the linear predictors; on exit the category
   * probabilities
   * @throws Exception if a named offset variable is not in the mining schema
   */
  private void computeResponseOrdinalMultinomial(double[] incomingInst, 
                                                  double[] responses) throws Exception {
    
    double[] r = responses.clone();
    
    double offset = 0;
    if (m_offsetVariable != null) {
      Attribute offsetAtt = 
        m_miningSchema.getFieldsAsInstances().attribute(m_offsetVariable);
      if (offsetAtt == null) {
        throw new Exception("[GeneralRegression] unable to find offset variable "
            + m_offsetVariable + " in the mining schema!");
      }
      offset = incomingInst[offsetAtt.index()];
    } else if (!Double.isNaN(m_offsetValue)) {
      offset = m_offsetValue;
    }
    
    // Each category must be differenced against the previous CUMULATIVE
    // probability, not against the previous category's individual
    // probability. Otherwise, with three or more categories the probabilities
    // are wrong and no longer sum to 1.
    double previousCumulative = 0.0;
    for (int i = 0; i < r.length; i++) {
      // the last category (when there is more than one) takes the remaining mass
      double cumulative = (i > 0 && i == r.length - 1)
        ? 1.0
        : m_cumulativeLinkFunction.eval(r[i], offset);
      responses[i] = cumulative - previousCumulative;
      previousCumulative = cumulative;
    }
  }

  /**
   * Returns the revision string.
   * 
   * @return the revision
   * @see weka.core.RevisionHandler#getRevision()
   */
  public String getRevision() {
    final String revisionKeyword = "$Revision: 8034 $";
    return RevisionUtils.extract(revisionKeyword);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy