All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.pmml.consumer.Regression Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This is the stable version. Apart from bugfixes, this version does not receive any other updates.

There is a newer version: 3.8.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see .
 */

/*
 *    Regression.java
 *    Copyright (C) 2008-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.pmml.consumer;

import java.io.Serializable;
import java.util.ArrayList;

import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.pmml.MiningSchema;
import weka.core.pmml.TargetMetaInfo;

/**
 * Class implementing import of PMML Regression model. Can be
 * used as a Weka classifier for prediction (buildClassifier()
 * raises an Exception).
 *
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com
 * @version $Revision: 8034 $
 */
public class Regression extends PMMLClassifier
  implements Serializable {

  /** For serialization */
  private static final long serialVersionUID = -5551125528409488634L;

  /**
   * Inner class for encapsulating a regression table
   */
  static class RegressionTable implements Serializable {

    /** For serialization */
    private static final long serialVersionUID = -5259866093996338995L;

    /**
     * Abstract inner base class for different predictor types.
     */
    abstract static class Predictor implements Serializable {

      /** For serialization */
      private static final long serialVersionUID = 7043831847273383618L;
      
      /** Name of this predictor */
      protected String m_name;
      
      /** 
       * Index of the attribute in the mining schema that corresponds to this
       * predictor
       */
      protected int m_miningSchemaAttIndex = -1;
      
      /** Coefficient for this predictor */
      protected double m_coefficient = 1.0;
      
      /**
       * Constructs a new Predictor.
       * 
       * @param predictor the Element encapsulating this predictor
       * @param miningSchema the mining schema as an Instances object
       * @throws Exception if there is a problem constructing this Predictor
       */
      protected Predictor(Element predictor, Instances miningSchema) throws Exception {
        m_name = predictor.getAttribute("name");
        for (int i = 0; i < miningSchema.numAttributes(); i++) {
          Attribute temp = miningSchema.attribute(i);
          if (temp.name().equals(m_name)) {
            m_miningSchemaAttIndex = i;
          }
        }
        
        if (m_miningSchemaAttIndex == -1) {
          throw new Exception("[Predictor] unable to find matching attribute for "
                              + "predictor " + m_name);
        }

        String coeff = predictor.getAttribute("coefficient");
        if (coeff.length() > 0) {
          m_coefficient = Double.parseDouble(coeff);
        }
      }

      /**
       * Returns a textual description of this predictor applicable
       * to all sub classes.
       */
      public String toString() {
        return Utils.doubleToString(m_coefficient, 12, 4) + " * ";
      }

      /**
       * Abstract add method. Adds this predictor into the sum for the
       * current prediction.
       * 
       * @param preds the prediction computed so far. For regression, it is a
       * single element array; for classification it is a multi-element array
       * @param input the input instance's values
       */
      public abstract void add(double[] preds, double[] input);
    }

    /**
     * Inner class for a numeric predictor
     */
    protected class NumericPredictor extends Predictor {
      /**
       * For serialization
       */
      private static final long serialVersionUID = -4335075205696648273L;
      
      /** The exponent*/
      protected double m_exponent = 1.0;

      /**
       * Constructs a NumericPredictor.
       * 
       * @param predictor the Element holding the predictor
       * @param miningSchema the mining schema as an Instances object
       * @throws Exception if something goes wrong while constructing this
       * predictor
       */
      protected NumericPredictor(Element predictor, 
                              Instances miningSchema) throws Exception {
        super(predictor, miningSchema);
        
        String exponent = predictor.getAttribute("exponent");
        if (exponent.length() > 0) {
          m_exponent = Double.parseDouble(exponent);
        }
      }

      /**
       * Return a textual description of this predictor.
       */
      public String toString() {
        String output = super.toString();
        output += m_name;
        if (m_exponent > 1.0 || m_exponent < 1.0) {
          output += "^" + Utils.doubleToString(m_exponent, 4);
        }
        return output;
      }

      /**
       * Adds this predictor into the sum for the
       * current prediction.
       * 
       * @param preds the prediction computed so far. For regression, it is a
       * single element array; for classification it is a multi-element array
       * @param input the input instance's values
       */
      public void add(double[] preds, double[] input) {
        if (m_targetCategory == -1) {
          preds[0] += m_coefficient * Math.pow(input[m_miningSchemaAttIndex], m_exponent);
        } else {
          preds[m_targetCategory] += 
            m_coefficient * Math.pow(input[m_miningSchemaAttIndex], m_exponent);
        }
      }
    }

    /**
     * Inner class encapsulating a categorical predictor.
     */
    protected class CategoricalPredictor extends Predictor {
      
      /**For serialization */
      private static final long serialVersionUID = 3077920125549906819L;
      
      /** The attribute value for this predictor */
      protected String m_valueName;
      
      /** The index of the attribute value for this predictor */
      protected int m_valueIndex = -1;

      /**
       * Constructs a CategoricalPredictor.
       * 
       * @param predictor the Element containing the predictor
       * @param miningSchema the mining schema as an Instances object
       * @throws Exception if something goes wrong while constructing
       * this predictor
       */
      protected CategoricalPredictor(Element predictor,
                                  Instances miningSchema) throws Exception {
        super(predictor, miningSchema);
        
        String valName = predictor.getAttribute("value");
        if (valName.length() == 0) {
          throw new Exception("[CategoricalPredictor] attribute value not specified!");
        }
        
        m_valueName = valName;

        Attribute att = miningSchema.attribute(m_miningSchemaAttIndex);
        if (att.isString()) {
          // means that there were no Value elements defined in the 
          // data dictionary (and hence the mining schema).
          // We add our value here.
          att.addStringValue(m_valueName);
        }
        m_valueIndex = att.indexOfValue(m_valueName);
        /*        for (int i = 0; i < att.numValues(); i++) {
          if (att.value(i).equals(m_valueName)) {
            m_valueIndex = i;
          }
          }*/

        if (m_valueIndex == -1) {
          throw new Exception("[CategoricalPredictor] unable to find value "
                              + m_valueName + " in mining schema attribute "
                              + att.name());
        }
      }

      /**
       * Return a textual description of this predictor.
       */
      public String toString() {
        String output = super.toString();
        output += m_name + "=" + m_valueName;
        return output;
      }

      /**
       * Adds this predictor into the sum for the
       * current prediction.
       * 
       * @param preds the prediction computed so far. For regression, it is a
       * single element array; for classification it is a multi-element array
       * @param input the input instance's values
       */
      public void add(double[] preds, double[] input) {
        
        // if the value is equal to the one in the input then add the coefficient
        if (m_valueIndex == (int)input[m_miningSchemaAttIndex]) {
          if (m_targetCategory == -1) {
            preds[0] += m_coefficient;
          } else {
            preds[m_targetCategory] += m_coefficient;
          }
        }
      }
    }

    /**
     * Inner class to handle PredictorTerms.
     */
    protected class PredictorTerm implements Serializable {

      /** For serialization */
      private static final long serialVersionUID = 5493100145890252757L;

      /** The coefficient for this predictor term */
      protected double m_coefficient = 1.0;

      /** the indexes of the terms to be multiplied */
      protected int[] m_indexes;

      /** The names of the terms (attributes) to be multiplied */
      protected String[] m_fieldNames;

      /**
       * Construct a new PredictorTerm.
       * 
       * @param predictorTerm the Element describing the predictor term
       * @param miningSchema the mining schema as an Instances object
       * @throws Exception if something goes wrong while constructing this
       * predictor term
       */
      protected PredictorTerm(Element predictorTerm, 
                              Instances miningSchema) throws Exception {

        String coeff = predictorTerm.getAttribute("coefficient");
        if (coeff != null && coeff.length() > 0) {
          try {
            m_coefficient = Double.parseDouble(coeff);
          } catch (IllegalArgumentException ex) {
            throw new Exception("[PredictorTerm] unable to parse coefficient");
          }
        }
        
        NodeList fields = predictorTerm.getElementsByTagName("FieldRef");
        if (fields.getLength() > 0) {
          m_indexes = new int[fields.getLength()];
          m_fieldNames = new String[fields.getLength()];

          for (int i = 0; i < fields.getLength(); i++) {
            Node fieldRef = fields.item(i);
            if (fieldRef.getNodeType() == Node.ELEMENT_NODE) {
              String fieldName = ((Element)fieldRef).getAttribute("field");
              if (fieldName != null && fieldName.length() > 0) {
                boolean found = false;
                // look for this field in the mining schema
                for (int j = 0; j < miningSchema.numAttributes(); j++) {
                  if (miningSchema.attribute(j).name().equals(fieldName)) {
                    
                    // all referenced fields MUST be numeric
                    if (!miningSchema.attribute(j).isNumeric()) {
                      throw new Exception("[PredictorTerm] field is not continuous: "
                                          + fieldName);
                    }
                    found = true;
                    m_indexes[i] = j;
                    m_fieldNames[i] = fieldName;
                    break;
                  }
                }
                if (!found) {
                  throw new Exception("[PredictorTerm] Unable to find field "
                                      + fieldName + " in mining schema!");
                }
              }
            }
          }
        }
      }

      /**
       * Return a textual description of this predictor term.
       */
      public String toString() {
        StringBuffer result = new StringBuffer();
        result.append("(" + Utils.doubleToString(m_coefficient, 12, 4));
        for (int i = 0; i < m_fieldNames.length; i++) {
          result.append(" * " + m_fieldNames[i]);
        }
        result.append(")");
        return result.toString();
      }

      /**
       * Adds this predictor term into the sum for the
       * current prediction.
       * 
       * @param preds the prediction computed so far. For regression, it is a
       * single element array; for classification it is a multi-element array
       * @param input the input instance's values
       */
      public void add(double[] preds, double[] input) {
        int indx = 0;
        if (m_targetCategory != -1) {
          indx = m_targetCategory;
        }

        double result = m_coefficient;
        for (int i = 0; i < m_indexes.length; i++) {
          result *= input[m_indexes[i]];
        }
        preds[indx] += result;
      }
    }
    
    /** Constant for regression model type */
    public static final int REGRESSION = 0;
    
    /** Constant for classification model type */
    public static final int CLASSIFICATION = 1;

    /** The type of function - regression or classification */
    protected int m_functionType = REGRESSION;
    
    /** The mining schema */
    protected MiningSchema m_miningSchema;
        
    /** The intercept */
    protected double m_intercept = 0.0;
    
    /** classification only */
    protected int m_targetCategory = -1;

    /** Numeric and categorical predictors */
    protected ArrayList m_predictors = 
      new ArrayList();

    /** Interaction terms */
    protected ArrayList m_predictorTerms =
      new ArrayList();

    /**
     * Return a textual description of this RegressionTable.
     */
    public String toString() {
      Instances miningSchema = m_miningSchema.getFieldsAsInstances();
      StringBuffer temp = new StringBuffer();
      temp.append("Regression table:\n");
      temp.append(miningSchema.classAttribute().name());
      if (m_functionType == CLASSIFICATION) {
        temp.append("=" + miningSchema.
                    classAttribute().value(m_targetCategory));
      }

      temp.append(" =\n\n");
      
      // do the predictors
      for (int i = 0; i < m_predictors.size(); i++) {
        temp.append(m_predictors.get(i).toString() + " +\n");
      }
      
      // do the predictor terms
      for (int i = 0; i < m_predictorTerms.size(); i++) {
        temp.append(m_predictorTerms.get(i).toString() + " +\n");
      }

      temp.append(Utils.doubleToString(m_intercept, 12, 4));
      temp.append("\n\n");

      return temp.toString();
    }

    /**
     * Construct a regression table from an Element
     *
     * @param table the table to encapsulate
     * @param functionType the type of function 
     * (regression or classification)
     * to use
     * @param mSchema the mining schema
     * @throws Exception if there is a problem while constructing
     * this regression table
     */
    protected RegressionTable(Element table, 
                           int functionType,
                           MiningSchema mSchema) throws Exception {

      m_miningSchema = mSchema;
      m_functionType = functionType;

      Instances miningSchema = m_miningSchema.getFieldsAsInstances();

      // get the intercept
      String intercept = table.getAttribute("intercept");
      if (intercept.length() > 0) {
        m_intercept = Double.parseDouble(intercept);
      }

      // get the target category (if classification)
      if (m_functionType == CLASSIFICATION) {
        // target category MUST be defined
        String targetCat = table.getAttribute("targetCategory");
        if (targetCat.length() > 0) {
          Attribute classA = miningSchema.classAttribute();
          for (int i = 0; i < classA.numValues(); i++) {
            if (classA.value(i).equals(targetCat)) {
              m_targetCategory = i;
            }
          }
        } 
        if (m_targetCategory == -1) {
          throw new Exception("[RegressionTable] No target categories defined for classification");
        }
      }

      // read all the numeric predictors
      NodeList numericPs = table.getElementsByTagName("NumericPredictor");
      for (int i = 0; i < numericPs.getLength(); i++) {
        Node nP = numericPs.item(i);
        if (nP.getNodeType() == Node.ELEMENT_NODE) {
          NumericPredictor numP = new NumericPredictor((Element)nP, miningSchema);
          m_predictors.add(numP);
        }
      }

      // read all the categorical predictors
      NodeList categoricalPs = table.getElementsByTagName("CategoricalPredictor");
      for (int i = 0; i < categoricalPs.getLength(); i++) {
        Node cP = categoricalPs.item(i);
        if (cP.getNodeType() == Node.ELEMENT_NODE) {
          CategoricalPredictor catP = new CategoricalPredictor((Element)cP, miningSchema);
          m_predictors.add(catP);
        }
      }

      // read all the PredictorTerms
      NodeList predictorTerms = table.getElementsByTagName("PredictorTerm");
      for (int i = 0; i < predictorTerms.getLength(); i++) {
        Node pT = predictorTerms.item(i);
        PredictorTerm predT = new PredictorTerm((Element)pT, miningSchema);
        m_predictorTerms.add(predT);
      }
    }

    public void predict(double[] preds, double[] input) {
      if (m_targetCategory == -1) {
        preds[0] = m_intercept;
      } else {
        preds[m_targetCategory] = m_intercept;
      }
      
      // add the predictors
      for (int i = 0; i < m_predictors.size(); i++) {
        Predictor p = m_predictors.get(i);
        p.add(preds, input);
      }

      // add the PredictorTerms
      for (int i = 0; i < m_predictorTerms.size(); i++) {
        PredictorTerm pt = m_predictorTerms.get(i);
        pt.add(preds, input);
      }
    }
  }

  /** Description of the algorithm */
  protected String m_algorithmName;

  /** The regression tables for this regression */
  protected RegressionTable[] m_regressionTables;

  /**
   * Enum for the normalization methods.
   */
  enum Normalization {
    NONE, SIMPLEMAX, SOFTMAX, LOGIT, PROBIT, CLOGLOG,
      EXP, LOGLOG, CAUCHIT}

  /** The normalization to use */
  protected Normalization m_normalizationMethod = Normalization.NONE;

  /**
   * Constructs a new PMML Regression.
   * 
   * @param model the Element containing the regression model
   * @param dataDictionary the data dictionary as an Instances object
   * @param miningSchema the mining schema
   * @throws Exception if there is a problem constructing this Regression
   */
  public Regression(Element model, Instances dataDictionary,
                    MiningSchema miningSchema) throws Exception {
    super(dataDictionary, miningSchema);
    
    int functionType = RegressionTable.REGRESSION;

    // determine function name first
    String fName = model.getAttribute("functionName");
    
    if (fName.equals("regression")) {
      functionType = RegressionTable.REGRESSION;
    } else if (fName.equals("classification")) {
      functionType = RegressionTable.CLASSIFICATION;
    } else {
      throw new Exception("[PMML Regression] Function name not defined in pmml!");
    }

    // do we have an algorithm name?
    String algName = model.getAttribute("algorithmName");
    if (algName != null && algName.length() > 0) {
      m_algorithmName = algName;
    }

    // determine normalization method (if any)
    m_normalizationMethod = determineNormalization(model);

    setUpRegressionTables(model, functionType);

    // convert any string attributes in the mining schema
    //miningSchema.convertStringAttsToNominal();
  }

  /**
   * Create all the RegressionTables for this model.
   * 
   * @param model the Element holding this regression model
   * @param functionType the type of function (regression or
   * classification)
   * @throws Exception if there is a problem setting up the regression
   * tables
   */
  private void setUpRegressionTables(Element model,
                                     int functionType) throws Exception {
    NodeList tableList = model.getElementsByTagName("RegressionTable");
    
    if (tableList.getLength() == 0) {
      throw new Exception("[Regression] no regression tables defined!");
    }

    m_regressionTables = new RegressionTable[tableList.getLength()];
    
    for (int i = 0; i < tableList.getLength(); i++) {
      Node table = tableList.item(i);
      if (table.getNodeType() == Node.ELEMENT_NODE) {
        RegressionTable tempRTable = 
          new RegressionTable((Element)table, 
                              functionType, 
                              m_miningSchema);
        m_regressionTables[i] = tempRTable;
      }
    }
  }

  /**
   * Return the type of normalization used for this regression
   * 
   * @param model the Element holding the model
   * @return the normalization used in this regression
   */
  private static Normalization determineNormalization(Element model) {
    
    Normalization normMethod = Normalization.NONE;

    String normName = model.getAttribute("normalizationMethod");
    if (normName.equals("simplemax")) {
      normMethod = Normalization.SIMPLEMAX;
    } else if (normName.equals("softmax")) {
      normMethod = Normalization.SOFTMAX;
    } else if (normName.equals("logit")) {
      normMethod = Normalization.LOGIT;
    } else if (normName.equals("probit")) {
      normMethod = Normalization.PROBIT;
    } else if (normName.equals("cloglog")) {
      normMethod = Normalization.CLOGLOG;
    } else if (normName.equals("exp")) {
      normMethod = Normalization.EXP;
    } else if (normName.equals("loglog")) {
      normMethod = Normalization.LOGLOG;
    } else if (normName.equals("cauchit")) {
      normMethod = Normalization.CAUCHIT;
    } 
    return normMethod;
  }

  /**
   * Return a textual description of this Regression model.
   */
  public String toString() {
    StringBuffer temp = new StringBuffer();
    temp.append("PMML version " + getPMMLVersion());
    if (!getCreatorApplication().equals("?")) {
      temp.append("\nApplication: " + getCreatorApplication());
    }
    if (m_algorithmName != null) {
      temp.append("\nPMML Model: " + m_algorithmName);
    }
    temp.append("\n\n");
    temp.append(m_miningSchema);

    for (RegressionTable table : m_regressionTables) {
      temp.append(table);
    }
    
    if (m_normalizationMethod != Normalization.NONE) {
      temp.append("Normalization: " + m_normalizationMethod);
    }
    temp.append("\n");

    return temp.toString();
  }

  /**                                                                                                             
   * Classifies the given test instance. The instance has to belong to a                                          
   * dataset when it's being classified.                                                          
   *                                                                                                              
   * @param inst the instance to be classified                                                                
   * @return the predicted most likely class for the instance or                                                  
   * Utils.missingValue() if no prediction is made                                                             
   * @exception Exception if an error occurred during the prediction                                              
   */
  public double[] distributionForInstance(Instance inst) throws Exception {
    if (!m_initialized) {
      mapToMiningSchema(inst.dataset());
    }
    double[] preds = null;
    if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) {
      preds = new double[1];
    } else {
      preds = new double[m_miningSchema.getFieldsAsInstances().classAttribute().numValues()];
    }

    // create an array of doubles that holds values from the incoming
    // instance; in order of the fields in the mining schema. We will
    // also handle missing values and outliers here.
    //    System.err.println(inst);
    double[] incoming = m_fieldsMap.instanceToSchema(inst, m_miningSchema);

    // scan for missing values. If there are still missing values after instanceToSchema(),
    // then missing value handling has been deferred to the PMML scheme. The specification
    // (Regression PMML 3.2) seems to contradict itself with regards to classification and categorical
    // variables. In one place it states that if a categorical variable is missing then 
    // variable_name=value is 0 for any value. Further down in the document it states: "if
    // one or more of the y_j cannot be evaluated because the value in one of the referenced
    // fields is missing, then the following formulas (for computing p_j) do not apply. In
    // that case the predictions are defined by the priorProbability values in the Target
    // element".

    // In this implementation we will default to information in the Target element (default
    // value for numeric prediction and prior probabilities for classification). If there is
    // no Target element defined, then an Exception is thrown.

    boolean hasMissing = false;
    for (int i = 0; i < incoming.length; i++) {
      if (i != m_miningSchema.getFieldsAsInstances().classIndex() && 
          Utils.isMissingValue(incoming[i])) {
        hasMissing = true;
        break;
      }
    }

    if (hasMissing) {
      if (!m_miningSchema.hasTargetMetaData()) {
        String message = "[Regression] WARNING: Instance to predict has missing value(s) but "
          + "there is no missing value handling meta data and no "
          + "prior probabilities/default value to fall back to. No "
          + "prediction will be made (" 
          + ((m_miningSchema.getFieldsAsInstances().classAttribute().isNominal() ||
              m_miningSchema.getFieldsAsInstances().classAttribute().isString())
              ? "zero probabilities output)."
              : "NaN output).");
        if (m_log == null) {
          System.err.println(message);
        } else {
          m_log.logMessage(message);
        }
        if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) {
          preds[0] = Utils.missingValue();
        }
        return preds;
      } else {
        // use prior probablilities/default value
        TargetMetaInfo targetData = m_miningSchema.getTargetMetaData();
        if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) {
          preds[0] = targetData.getDefaultValue();
        } else {
          Instances miningSchemaI = m_miningSchema.getFieldsAsInstances();
          for (int i = 0; i < miningSchemaI.classAttribute().numValues(); i++) {
            preds[i] = targetData.getPriorProbability(miningSchemaI.classAttribute().value(i));
          }
        }
        return preds;
      }
    } else {
      // loop through the RegressionTables
      for (int i = 0; i < m_regressionTables.length; i++) {
        m_regressionTables[i].predict(preds, incoming);
      }
 
      // Now apply the normalization
      switch (m_normalizationMethod) {
      case NONE:
        // nothing to be done
        break;
      case SIMPLEMAX:
        Utils.normalize(preds);
        break;
      case SOFTMAX:
        for (int i = 0; i < preds.length; i++) {
          preds[i] = Math.exp(preds[i]);
        }
        if (preds.length == 1) {
          // hack for those models that do binary logistic regression as
          // a numeric prediction model
          preds[0] = preds[0] / (preds[0] + 1.0);
        } else {
          Utils.normalize(preds);
        }
        break;
      case LOGIT:
        for (int i = 0; i < preds.length; i++) {
          preds[i] = 1.0 / (1.0 + Math.exp(-preds[i]));
        }
        Utils.normalize(preds);
        break;
      case PROBIT:
        for (int i = 0; i < preds.length; i++) {
          preds[i] = weka.core.matrix.Maths.pnorm(preds[i]);
        }
        Utils.normalize(preds);
        break;
      case CLOGLOG:
        // note this is supposed to be illegal for regression
        for (int i = 0; i < preds.length; i++) {
          preds[i] = 1.0 - Math.exp(-Math.exp(-preds[i]));
        }
        Utils.normalize(preds);
        break;
      case EXP:
        for (int i = 0; i < preds.length; i++) {
          preds[i] = Math.exp(preds[i]);
        }
        Utils.normalize(preds);
        break;
      case LOGLOG:
        // note this is supposed to be illegal for regression
        for (int i = 0; i < preds.length; i++) {
          preds[i] = Math.exp(-Math.exp(-preds[i]));
        }
        Utils.normalize(preds);
        break;
      case CAUCHIT:
        for (int i = 0; i < preds.length; i++) {
          preds[i] = 0.5 + (1.0 / Math.PI) * Math.atan(preds[i]);
        }
        Utils.normalize(preds);
        break;
      default:
          throw new Exception("[Regression] unknown normalization method");
      }

      // If there is a Target defined, and this is a numeric prediction problem,
      // then apply any min, max, rescaling etc.
      if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()
          && m_miningSchema.hasTargetMetaData()) {
        TargetMetaInfo targetData = m_miningSchema.getTargetMetaData();
        preds[0] = targetData.applyMinMaxRescaleCast(preds[0]);
      }
    }
    
    return preds;
  }

  /* (non-Javadoc)
   * @see weka.core.RevisionHandler#getRevision()
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 8034 $");
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy