All downloads are free. Search and download functionality uses the official Maven repository.

weka.classifiers.pmml.consumer.TreeModel Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This version represents the developer version, the "bleeding edge" of development, you could say. New functionality gets added to this version.

There is a newer version: 3.9.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    TreeModel.java
 *    Copyright (C) 2009-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.pmml.consumer;

import java.io.Serializable;
import java.util.ArrayList;

import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

import weka.core.Attribute;
import weka.core.Drawable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.pmml.Array;
import weka.core.pmml.MiningSchema;

/**
 * Class implementing import of PMML TreeModel. Can be used as a Weka classifier
 * for prediction (buildClassifier() raises an Exception).
 * 
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision: 10153 $;
 */
public class TreeModel extends PMMLClassifier implements Drawable {

  /**
   * For serialization
   */
  private static final long serialVersionUID = -2065158088298753129L;

  /**
   * Inner class representing the ScoreDistribution element
   */
  static class ScoreDistribution implements Serializable {

    /** For serialization */
    private static final long serialVersionUID = -123506262094299933L;

    /** Class label read from this entry's "value" attribute */
    private final String m_classLabel;

    /** Position of the label within the class attribute */
    private int m_classLabelIndex = -1;

    /** Record count read from this entry's "recordCount" attribute */
    private final double m_recordCount;

    /** Confidence for this label; stays "missing" until supplied or derived */
    private double m_confidence = Utils.missingValue();

    /**
     * Build a distribution entry from a ScoreDistribution element.
     * 
     * @param scoreE the element holding the "value", "recordCount" and
     *          (optionally) "confidence" attributes
     * @param miningSchema the mining schema; its class attribute is used to
     *          resolve the label to an index
     * @param baseCount record count of the owning node; used to derive a
     *          confidence when none is given explicitly (ignored if missing or
     *          not positive)
     * @throws Exception if the class attribute is unset or does not contain
     *           the label
     */
    protected ScoreDistribution(Element scoreE, MiningSchema miningSchema,
      double baseCount) throws Exception {
      m_classLabel = scoreE.getAttribute("value");

      // resolve the label against the class attribute, once
      final Attribute target = miningSchema.getFieldsAsInstances()
        .classAttribute();
      final int labelIndex = (target == null) ? -1 : target
        .indexOfValue(m_classLabel);
      if (labelIndex < 0) {
        throw new Exception(
          "[ScoreDistribution] class attribute not set or class value "
            + m_classLabel + " not found!");
      }
      m_classLabelIndex = labelIndex;

      // frequency of this label
      m_recordCount = Double.parseDouble(scoreE.getAttribute("recordCount"));

      // an explicit confidence wins; otherwise derive one if we can
      final String explicit = scoreE.getAttribute("confidence");
      if (explicit != null && !explicit.isEmpty()) {
        m_confidence = Double.parseDouble(explicit);
      } else if (baseCount > 0 && !Utils.isMissingValue(baseCount)) {
        m_confidence = m_recordCount / baseCount;
      }
    }

    /**
     * Fill in the confidence as recordCount / baseCount, but only when no
     * confidence has been set yet and the supplied base count is usable
     * (not missing and greater than zero).
     * 
     * @param baseCount the total number of records to normalise by
     */
    void deriveConfidenceValue(double baseCount) {
      if (!Utils.isMissingValue(m_confidence)) {
        // already known - leave untouched
        return;
      }
      if (Utils.isMissingValue(baseCount) || !(baseCount > 0)) {
        return;
      }
      m_confidence = m_recordCount / baseCount;
    }

    /** @return the class label of this entry */
    String getClassLabel() {
      return this.m_classLabel;
    }

    /** @return the index of the class label in the class attribute */
    int getClassLabelIndex() {
      return this.m_classLabelIndex;
    }

    /** @return the record count of this entry */
    double getRecordCount() {
      return this.m_recordCount;
    }

    /** @return the confidence of this entry (may be the missing value) */
    double getConfidence() {
      return this.m_confidence;
    }

    /**
     * Textual form: label, record count and the confidence to two decimals.
     */
    @Override
    public String toString() {
      final StringBuilder out = new StringBuilder();
      out.append(m_classLabel).append(": ").append(m_recordCount);
      out.append(" (").append(Utils.doubleToString(m_confidence, 2));
      out.append(") ");
      return out.toString();
    }
  }

  /**
   * Base class for Predicates
   */
  static abstract class Predicate implements Serializable {

    /** For serialization */
    private static final long serialVersionUID = 1035344165452733887L;

    /** Three-valued outcome of evaluating a predicate */
    enum Eval {
      TRUE, FALSE, UNKNOWN;
    }

    /**
     * Evaluate this predicate against an input vector.
     * 
     * @param input the input vector of attribute and derived field values.
     * @return the evaluation status of this predicate.
     */
    abstract Eval evaluate(double[] input);

    /**
     * Indented textual form. The base implementation ignores {@code cr} and
     * delegates to {@link #toString(int)}.
     * 
     * @param level the indentation depth
     * @param cr unused here; subclasses may use it to choose a line separator
     * @return the indented description
     */
    protected String toString(int level, boolean cr) {
      return toString(level);
    }

    /**
     * Prefix the plain {@code toString()} with one "|   " per level.
     * 
     * @param level the indentation depth
     * @return the indented description
     */
    protected String toString(int level) {
      final StringBuilder line = new StringBuilder();
      int remaining = level;
      while (remaining-- > 0) {
        line.append("|   ");
      }
      line.append(toString());
      return line.toString();
    }

    /**
     * Map a (missing, result) pair onto an {@link Eval}: missing always wins
     * and yields UNKNOWN, otherwise the boolean result decides.
     * 
     * @param missing true if the tested value is missing
     * @param result the outcome of the comparison
     * @return the corresponding evaluation status
     */
    static Eval booleanToEval(boolean missing, boolean result) {
      if (missing) {
        return Eval.UNKNOWN;
      }
      return result ? Eval.TRUE : Eval.FALSE;
    }

    /**
     * Factory method: build the predicate described by the first recognised
     * predicate element among the direct children of a tree node.
     * 
     * @param nodeE the XML node encapsulating the tree node.
     * @param miningSchema the mining schema in use
     * @return a Predicate
     * @throws Exception if no recognised predicate element is present or the
     *           predicate cannot be constructed.
     */
    static Predicate getPredicate(Element nodeE, MiningSchema miningSchema)
      throws Exception {

      final NodeList kids = nodeE.getChildNodes();
      final int count = kids.getLength();
      for (int k = 0; k < count; k++) {
        final Node candidate = kids.item(k);
        if (candidate.getNodeType() != Node.ELEMENT_NODE) {
          continue;
        }

        final Element element = (Element) candidate;
        final String tag = element.getTagName();
        if ("True".equals(tag)) {
          return new True();
        }
        if ("False".equals(tag)) {
          return new False();
        }
        if ("SimplePredicate".equals(tag)) {
          return new SimplePredicate(element, miningSchema);
        }
        if ("CompoundPredicate".equals(tag)) {
          return new CompoundPredicate(element, miningSchema);
        }
        if ("SimpleSetPredicate".equals(tag)) {
          return new SimpleSetPredicate(element, miningSchema);
        }
      }

      // fell off the end without seeing any predicate element
      throw new Exception(
        "[Predicate] unknown or missing predicate type in node");
    }
  }

  /**
   * Simple True Predicate
   */
  static class True extends Predicate {

    /** For serialization */
    private static final long serialVersionUID = 1817942234610531627L;

    /**
     * Always satisfied, regardless of the input.
     * 
     * @param input ignored
     * @return always TRUE
     */
    @Override
    public Eval evaluate(double[] input) {
      return Eval.TRUE;
    }

    /** @return the fixed label used when printing this predicate */
    @Override
    public String toString() {
      final String label = "True: ";
      return label;
    }
  }

  /**
   * Simple False Predicate
   */
  static class False extends Predicate {

    /** For serialization */
    private static final long serialVersionUID = -3647261386442860365L;

    /**
     * Never satisfied, regardless of the input.
     * 
     * @param input ignored
     * @return always FALSE
     */
    @Override
    public Eval evaluate(double[] input) {
      return Eval.FALSE;
    }

    /** @return the fixed label used when printing this predicate */
    @Override
    public String toString() {
      final String label = "False: ";
      return label;
    }
  }

  /**
   * Class representing the SimplePredicate
   */
  static class SimplePredicate extends Predicate {

    /**
     * For serialization
     */
    private static final long serialVersionUID = -6156684285069327400L;

    /**
     * Comparison operators of a SimplePredicate. Every operator reports
     * UNKNOWN when the tested input value is missing, except the two
     * missing-value tests, which always produce TRUE or FALSE.
     */
    enum Operator {
      EQUAL("equal") {
        @Override
        Predicate.Eval evaluate(double[] input, double value, int fieldIndex) {
          return Predicate.booleanToEval(
            Utils.isMissingValue(input[fieldIndex]),
            Utils.eq(input[fieldIndex], value));
        }

        @Override
        String shortName() {
          return "==";
        }
      },
      NOTEQUAL("notEqual") {
        @Override
        Predicate.Eval evaluate(double[] input, double value, int fieldIndex) {
          // must be the exact negation of EQUAL (which compares with
          // Utils.eq); otherwise a value could be both equal and notEqual
          return Predicate.booleanToEval(
            Utils.isMissingValue(input[fieldIndex]),
            !Utils.eq(input[fieldIndex], value));
        }

        @Override
        String shortName() {
          return "!=";
        }
      },
      LESSTHAN("lessThan") {
        @Override
        Predicate.Eval evaluate(double[] input, double value, int fieldIndex) {
          return Predicate.booleanToEval(
            Utils.isMissingValue(input[fieldIndex]),
            (input[fieldIndex] < value));
        }

        @Override
        String shortName() {
          return "<";
        }
      },
      LESSOREQUAL("lessOrEqual") {
        @Override
        Predicate.Eval evaluate(double[] input, double value, int fieldIndex) {
          return Predicate.booleanToEval(
            Utils.isMissingValue(input[fieldIndex]),
            (input[fieldIndex] <= value));
        }

        @Override
        String shortName() {
          return "<=";
        }
      },
      GREATERTHAN("greaterThan") {
        @Override
        Predicate.Eval evaluate(double[] input, double value, int fieldIndex) {
          return Predicate.booleanToEval(
            Utils.isMissingValue(input[fieldIndex]),
            (input[fieldIndex] > value));
        }

        @Override
        String shortName() {
          return ">";
        }
      },
      GREATEROREQUAL("greaterOrEqual") {
        @Override
        Predicate.Eval evaluate(double[] input, double value, int fieldIndex) {
          return Predicate.booleanToEval(
            Utils.isMissingValue(input[fieldIndex]),
            (input[fieldIndex] >= value));
        }

        @Override
        String shortName() {
          return ">=";
        }
      },
      ISMISSING("isMissing") {
        @Override
        Predicate.Eval evaluate(double[] input, double value, int fieldIndex) {
          return Predicate.booleanToEval(false,
            Utils.isMissingValue(input[fieldIndex]));
        }

        @Override
        String shortName() {
          return toString();
        }
      },
      ISNOTMISSING("isNotMissing") {
        @Override
        Predicate.Eval evaluate(double[] input, double value, int fieldIndex) {
          return Predicate.booleanToEval(false,
            !Utils.isMissingValue(input[fieldIndex]));
        }

        @Override
        String shortName() {
          return toString();
        }
      };

      /**
       * Apply the operator.
       * 
       * @param input the input vector
       * @param value the constant to compare against
       * @param fieldIndex index into {@code input} of the tested field
       * @return the evaluation status
       */
      abstract Predicate.Eval evaluate(double[] input, double value,
        int fieldIndex);

      /** @return a compact symbol (or name) used when printing the predicate */
      abstract String shortName();

      /** The operator name as it appears in the "operator" attribute */
      private final String m_stringVal;

      Operator(String name) {
        m_stringVal = name;
      }

      @Override
      public String toString() {
        return m_stringVal;
      }
    }

    /** the field that we are comparing against */
    int m_fieldIndex = -1;

    /** the name of the field */
    String m_fieldName;

    /** true if the field is nominal */
    boolean m_isNominal;

    /** the value as a string (if nominal) */
    String m_nominalValue;

    /**
     * the value to compare against (if nominal it holds the index of the value)
     */
    double m_value;

    /** the operator to use */
    Operator m_operator;

    /**
     * Build a simple predicate from its XML element.
     * 
     * @param simpleP the SimplePredicate element ("field", "operator" and,
     *          unless the operator is a missing-value test, "value")
     * @param miningSchema the mining schema used to resolve the field
     * @throws Exception if the field is unknown, the operator is missing or
     *           unrecognised, a numeric value cannot be parsed, or a nominal
     *           value is not a legal value of the field
     */
    public SimplePredicate(Element simpleP, MiningSchema miningSchema) throws Exception {
      Instances totalStructure = miningSchema.getFieldsAsInstances();

      // get the field name and set up the index
      String fieldS = simpleP.getAttribute("field");
      Attribute att = totalStructure.attribute(fieldS);
      if (att == null) {
        throw new Exception("[SimplePredicate] unable to find field " + fieldS
          + " in the incoming instance structure!");
      }

      // find the index
      int index = -1;
      for (int i = 0; i < totalStructure.numAttributes(); i++) {
        if (totalStructure.attribute(i).name().equals(fieldS)) {
          index = i;
          m_fieldName = totalStructure.attribute(i).name();
          break;
        }
      }
      m_fieldIndex = index;
      m_isNominal = att.isNominal();

      // get the operator
      String oppS = simpleP.getAttribute("operator");
      for (Operator o : Operator.values()) {
        if (o.toString().equals(oppS)) {
          m_operator = o;
          break;
        }
      }
      if (m_operator == null) {
        // fail here rather than with a NullPointerException at scoring time
        throw new Exception("[SimplePredicate] unknown or missing operator '"
          + oppS + "' for field " + fieldS);
      }

      if (m_operator != Operator.ISMISSING
        && m_operator != Operator.ISNOTMISSING) {
        String valueS = simpleP.getAttribute("value");
        if (att.isNumeric()) {
          try {
            m_value = Double.parseDouble(valueS);
          } catch (NumberFormatException ex) {
            throw new Exception("[SimplePredicate] unable to parse value '"
              + valueS + "' as a number for field " + fieldS, ex);
          }
        } else {
          m_nominalValue = valueS;
          m_value = att.indexOfValue(valueS);
          if (m_value < 0) {
            throw new Exception("[SimplePredicate] can't find value " + valueS
              + " in nominal " + "attribute " + att.name());
          }
        }
      }
    }

    /**
     * Apply the operator to the referenced field of the input vector.
     * 
     * @param input the input vector of attribute and derived field values
     * @return the evaluation status
     */
    @Override
    public Predicate.Eval evaluate(double[] input) {
      return m_operator.evaluate(input, m_value, m_fieldIndex);
    }

    /**
     * @return "field op [value]"; the value is omitted for the missing-value
     *         tests and printed as its label for nominal fields
     */
    @Override
    public String toString() {
      StringBuilder temp = new StringBuilder();

      temp.append(m_fieldName).append(" ").append(m_operator.shortName());
      if (m_operator != Operator.ISMISSING
        && m_operator != Operator.ISNOTMISSING) {
        temp.append(" ").append((m_isNominal) ? m_nominalValue : "" + m_value);
      }

      return temp.toString();
    }
  }

  /**
   * Class representing the CompoundPredicate
   */
  static class CompoundPredicate extends Predicate {

    /**
     * For serialization
     */
    private static final long serialVersionUID = -3332091529764559077L;

    /**
     * The boolean operators that can combine the constituent predicates.
     */
    enum BooleanOperator {
      OR("or") {
        @Override
        Predicate.Eval evaluate(ArrayList<Predicate> constituents,
          double[] input) {
          // TRUE as soon as one constituent is TRUE; otherwise UNKNOWN if any
          // constituent was UNKNOWN; otherwise FALSE
          Predicate.Eval currentStatus = Predicate.Eval.FALSE;
          for (Predicate p : constituents) {
            Predicate.Eval temp = p.evaluate(input);
            if (temp == Predicate.Eval.TRUE) {
              currentStatus = temp;
              break;
            } else if (temp == Predicate.Eval.UNKNOWN) {
              currentStatus = temp;
            }
          }
          return currentStatus;
        }
      },
      AND("and") {
        @Override
        Predicate.Eval evaluate(ArrayList<Predicate> constituents,
          double[] input) {
          // FALSE as soon as one constituent is FALSE; otherwise UNKNOWN if
          // any constituent was UNKNOWN; otherwise TRUE
          Predicate.Eval currentStatus = Predicate.Eval.TRUE;
          for (Predicate p : constituents) {
            Predicate.Eval temp = p.evaluate(input);
            if (temp == Predicate.Eval.FALSE) {
              currentStatus = temp;
              break;
            } else if (temp == Predicate.Eval.UNKNOWN) {
              currentStatus = temp;
            }
          }
          return currentStatus;
        }
      },
      XOR("xor") {
        @Override
        Predicate.Eval evaluate(ArrayList<Predicate> constituents,
          double[] input) {
          // fold the constituents left to right; any UNKNOWN makes the whole
          // expression UNKNOWN
          Predicate.Eval currentStatus = constituents.get(0).evaluate(input);
          if (currentStatus != Predicate.Eval.UNKNOWN) {
            for (int i = 1; i < constituents.size(); i++) {
              Predicate.Eval temp = constituents.get(i).evaluate(input);
              if (temp == Predicate.Eval.UNKNOWN) {
                currentStatus = temp;
                break;
              } else {
                if (currentStatus != temp) {
                  currentStatus = Predicate.Eval.TRUE;
                } else {
                  currentStatus = Predicate.Eval.FALSE;
                }
              }
            }
          }
          return currentStatus;
        }
      },
      SURROGATE("surrogate") {
        @Override
        Predicate.Eval evaluate(ArrayList<Predicate> constituents,
          double[] input) {
          Predicate.Eval currentStatus = constituents.get(0).evaluate(input);

          // walk through the surrogates in order until one of them can be
          // evaluated; the index must advance and stay in range, otherwise
          // an UNKNOWN surrogate would be re-evaluated forever
          int i = 1;
          while (currentStatus == Predicate.Eval.UNKNOWN
            && i < constituents.size()) {
            currentStatus = constituents.get(i).evaluate(input);
            i++;
          }

          // return false if all our surrogates evaluate to unknown.
          if (currentStatus == Predicate.Eval.UNKNOWN) {
            currentStatus = Predicate.Eval.FALSE;
          }

          return currentStatus;
        }
      };

      /**
       * Combine the evaluation of the constituent predicates.
       * 
       * @param constituents the predicates to combine, in document order
       * @param input the input vector of attribute and derived field values
       * @return the combined evaluation status
       */
      abstract Predicate.Eval evaluate(ArrayList<Predicate> constituents,
        double[] input);

      /** The operator name as it appears in the "booleanOperator" attribute */
      private final String m_stringVal;

      BooleanOperator(String name) {
        m_stringVal = name;
      }

      @Override
      public String toString() {
        return m_stringVal;
      }
    }

    /** the constituent Predicates */
    ArrayList<Predicate> m_components = new ArrayList<Predicate>();

    /** the boolean operator */
    BooleanOperator m_booleanOperator;

    /**
     * Build a compound predicate from its XML element.
     * 
     * @param compoundP the CompoundPredicate element
     * @param miningSchema the mining schema, passed on to the constituents
     * @throws Exception if the boolean operator is missing or unrecognised, a
     *           child element is not a known predicate type, or a constituent
     *           cannot be constructed
     */
    public CompoundPredicate(Element compoundP, MiningSchema miningSchema) throws Exception {
      String booleanOpp = compoundP.getAttribute("booleanOperator");
      for (BooleanOperator b : BooleanOperator.values()) {
        if (b.toString().equals(booleanOpp)) {
          m_booleanOperator = b;
          break;
        }
      }
      if (m_booleanOperator == null) {
        // fail here rather than with a NullPointerException at scoring time
        throw new Exception(
          "[CompoundPredicate] unknown or missing boolean operator '"
            + booleanOpp + "'");
      }

      // now get all the encapsulated operators
      NodeList children = compoundP.getChildNodes();
      for (int i = 0; i < children.getLength(); i++) {
        Node child = children.item(i);
        if (child.getNodeType() != Node.ELEMENT_NODE) {
          continue;
        }

        Element childE = (Element) child;
        String tagName = childE.getTagName();
        if (tagName.equals("True")) {
          m_components.add(new True());
        } else if (tagName.equals("False")) {
          m_components.add(new False());
        } else if (tagName.equals("SimplePredicate")) {
          m_components.add(new SimplePredicate(childE, miningSchema));
        } else if (tagName.equals("CompoundPredicate")) {
          m_components.add(new CompoundPredicate(childE, miningSchema));
        } else if (tagName.equals("SimpleSetPredicate")) {
          m_components.add(new SimpleSetPredicate(childE, miningSchema));
        } else if (!tagName.equals("Extension")) {
          // Extension elements carry no predicate and are skipped; anything
          // else is reported instead of being misread as a set predicate
          throw new Exception("[CompoundPredicate] unknown predicate type "
            + tagName);
        }
      }
    }

    /**
     * Evaluate the constituents and combine them with the boolean operator.
     * 
     * @param input the input vector of attribute and derived field values
     * @return the combined evaluation status
     */
    @Override
    public Predicate.Eval evaluate(double[] input) {
      return m_booleanOperator.evaluate(m_components, input);
    }

    @Override
    public String toString() {
      return toString(0, false);
    }

    /**
     * Multi-line description: a header naming the operator followed by one
     * line per constituent (with ':' characters removed).
     * 
     * @param level the indentation depth
     * @param cr if true, lines are separated by the two characters backslash
     *          and 'n' (as used inside graph labels) instead of a real
     *          newline
     * @return the description
     */
    @Override
    public String toString(int level, boolean cr) {
      final String lineBreak = cr ? "\\n" : "\n";

      StringBuilder text = new StringBuilder();
      for (int j = 0; j < level; j++) {
        text.append("|   ");
      }

      text.append("Compound [" + m_booleanOperator.toString() + "]");
      text.append(lineBreak);
      for (int i = 0; i < m_components.size(); i++) {
        text.append(m_components.get(i).toString(level, cr).replace(":", ""));
        if (i != m_components.size() - 1) {
          text.append(lineBreak);
        }
      }

      return text.toString();
    }
  }

  /**
   * Class representing the SimpleSetPredicate
   */
  static class SimpleSetPredicate extends Predicate {

    /**
     * For serialization
     */
    private static final long serialVersionUID = -2711995401345708486L;

    /**
     * Set membership operators.
     */
    enum BooleanOperator {
      IS_IN("isIn") {
        @Override
        Predicate.Eval evaluate(double[] input, int fieldIndex, Array set,
          Attribute nominalLookup) {
          final boolean missing = Utils.isMissingValue(input[fieldIndex]);

          if (set.getType() == Array.ArrayType.STRING) {
            // string sets are matched against the label of the nominal value
            String value = "";
            if (!missing) {
              value = nominalLookup.value((int) input[fieldIndex]);
            }
            return Predicate.booleanToEval(missing, set.contains(value));
          } else if (set.getType() == Array.ArrayType.NUM
            || set.getType() == Array.ArrayType.REAL) {
            return Predicate.booleanToEval(missing,
              set.contains(input[fieldIndex]));
          }
          // remaining array types are matched on the integer value
          return Predicate.booleanToEval(missing,
            set.contains((int) input[fieldIndex]));
        }
      },
      IS_NOT_IN("isNotIn") {
        @Override
        Predicate.Eval evaluate(double[] input, int fieldIndex, Array set,
          Attribute nominalLookup) {
          // negate IS_IN; UNKNOWN stays UNKNOWN
          Predicate.Eval result = IS_IN.evaluate(input, fieldIndex, set,
            nominalLookup);
          if (result == Predicate.Eval.FALSE) {
            result = Predicate.Eval.TRUE;
          } else if (result == Predicate.Eval.TRUE) {
            result = Predicate.Eval.FALSE;
          }

          return result;
        }
      };

      /**
       * Test the referenced field against the set.
       * 
       * @param input the input vector
       * @param fieldIndex index into {@code input} of the tested field
       * @param set the set of values
       * @param nominalLookup attribute used to turn a nominal index into its
       *          label (only used for string sets)
       * @return the evaluation status
       */
      abstract Predicate.Eval evaluate(double[] input, int fieldIndex,
        Array set, Attribute nominalLookup);

      /** The operator name as it appears in the "booleanOperator" attribute */
      private final String m_stringVal;

      BooleanOperator(String name) {
        m_stringVal = name;
      }

      @Override
      public String toString() {
        return m_stringVal;
      }
    }

    /** the field to reference */
    int m_fieldIndex = -1;

    /** the name of the field */
    String m_fieldName;

    /** is the referenced field nominal? */
    boolean m_isNominal = false;

    /** the attribute to lookup nominal values from */
    Attribute m_nominalLookup;

    /** the boolean operator (isIn unless the element says otherwise) */
    BooleanOperator m_operator = BooleanOperator.IS_IN;

    /** the array holding the set of values */
    Array m_set;

    /**
     * Build a set predicate from its XML element.
     * 
     * @param setP the SimpleSetPredicate element ("field", "booleanOperator"
     *          and a child array holding the set)
     * @param miningSchema the mining schema used to resolve the field
     * @throws Exception if the field is unknown, the boolean operator is not
     *           recognised, no array child is present, or the array type does
     *           not match the field type
     */
    public SimpleSetPredicate(Element setP, MiningSchema miningSchema) throws Exception {
      Instances totalStructure = miningSchema.getFieldsAsInstances();

      // get the field name and set up the index
      String fieldS = setP.getAttribute("field");
      Attribute att = totalStructure.attribute(fieldS);
      if (att == null) {
        throw new Exception("[SimpleSetPredicate] unable to find field "
          + fieldS + " in the incoming instance structure!");
      }

      // find the index
      int index = -1;
      for (int i = 0; i < totalStructure.numAttributes(); i++) {
        if (totalStructure.attribute(i).name().equals(fieldS)) {
          index = i;
          m_fieldName = totalStructure.attribute(i).name();
          break;
        }
      }
      m_fieldIndex = index;
      if (att.isNominal()) {
        m_isNominal = true;
        m_nominalLookup = att;
      }

      // get the operator; without this an isNotIn predicate would silently
      // be evaluated as isIn. An absent attribute keeps the isIn default.
      String oppS = setP.getAttribute("booleanOperator");
      if (oppS != null && oppS.length() > 0) {
        BooleanOperator matched = null;
        for (BooleanOperator b : BooleanOperator.values()) {
          if (b.toString().equals(oppS)) {
            matched = b;
            break;
          }
        }
        if (matched == null) {
          throw new Exception(
            "[SimpleSetPredicate] unknown boolean operator '" + oppS
              + "' for field " + fieldS);
        }
        m_operator = matched;
      }

      // need to scan the children looking for an array type
      NodeList children = setP.getChildNodes();
      for (int i = 0; i < children.getLength(); i++) {
        Node child = children.item(i);
        if (child.getNodeType() == Node.ELEMENT_NODE) {
          if (Array.isArray((Element) child)) {
            // found the array
            m_set = Array.create((Element) child);
            break;
          }
        }
      }

      if (m_set == null) {
        throw new Exception("[SimpleSetPredicate] couldn't find an "
          + "array containing the set values!");
      }

      // check array type against field type
      if (m_set.getType() == Array.ArrayType.STRING && !m_isNominal) {
        throw new Exception("[SimpleSetPredicate] referenced field "
          + totalStructure.attribute(m_fieldIndex).name()
          + " is numeric but array type is string!");
      } else if (m_set.getType() != Array.ArrayType.STRING && m_isNominal) {
        throw new Exception("[SimpleSetPredicate] referenced field "
          + totalStructure.attribute(m_fieldIndex).name()
          + " is nominal but array type is numeric!");
      }
    }

    /**
     * Test the referenced field of the input vector against the set.
     * 
     * @param input the input vector of attribute and derived field values
     * @return the evaluation status
     */
    @Override
    public Predicate.Eval evaluate(double[] input) {
      return m_operator.evaluate(input, m_fieldIndex, m_set, m_nominalLookup);
    }

    /** @return "field operator set" */
    @Override
    public String toString() {
      StringBuilder temp = new StringBuilder();

      temp.append(m_fieldName + " " + m_operator.toString() + " ");
      temp.append(m_set.toString());

      return temp.toString();
    }
  }

  /**
   * Class for handling a Node in the tree
   */
  class TreeNode implements Serializable {
    // TODO: perhaps implement a class called Statistics that contains
    // Partitions?

    /**
     * For serialization
     */
    private static final long serialVersionUID = 3011062274167063699L;

    /** ID for this node (defaults to the hash code when none is supplied) */
    private String m_ID = "" + this.hashCode();

    /** The score as a string */
    private String m_scoreString;

    /** The index of this predicted value (if class is nominal) */
    private int m_scoreIndex = -1;

    /** The score as a number (if target is numeric) */
    private double m_scoreNumeric = Utils.missingValue();

    /** The record count at this node (if defined) */
    private double m_recordCount = Utils.missingValue();

    /** The ID of the default child (if applicable) */
    private String m_defaultChildID;

    /** Holds the node of the default child (if defined) */
    private TreeNode m_defaultChild;

    /**
     * The distribution for labels (classification). The element type is
     * declared explicitly so that typed iteration over the list compiles.
     */
    private final ArrayList<ScoreDistribution> m_scoreDistributions =
      new ArrayList<ScoreDistribution>();

    /** The predicate for this node */
    private final Predicate m_predicate;

    /**
     * The children of this node. The element type is declared explicitly so
     * that typed iteration and indexed access compile.
     */
    private final ArrayList<TreeNode> m_childNodes = new ArrayList<TreeNode>();

    /**
     * Build a tree node (and, recursively, its whole subtree) from a Node
     * element.
     * 
     * @param nodeE the Node element
     * @param miningSchema the mining schema in use
     * @throws Exception if the score cannot be interpreted for the class
     *           attribute, or a nested element cannot be processed
     */
    protected TreeNode(Element nodeE, MiningSchema miningSchema) throws Exception {
      final Attribute target = miningSchema.getFieldsAsInstances()
        .classAttribute();

      // an explicit id overrides the hash-code based default
      final String idAttr = nodeE.getAttribute("id");
      if (idAttr != null && !idAttr.isEmpty()) {
        m_ID = idAttr;
      }

      // score: a class label index for nominal targets, a number otherwise
      final String scoreAttr = nodeE.getAttribute("score");
      if (scoreAttr != null && !scoreAttr.isEmpty()) {
        m_scoreString = scoreAttr;

        if (!target.isNumeric()) {
          m_scoreIndex = target.indexOfValue(m_scoreString);
          if (m_scoreIndex < 0) {
            throw new Exception(
              "[TreeNode] can't find match for predicted value "
                + m_scoreString + " in class attribute!");
          }
        } else {
          try {
            m_scoreNumeric = Double.parseDouble(scoreAttr);
          } catch (NumberFormatException ex) {
            throw new Exception(
              "[TreeNode] class is numeric but unable to parse score "
                + m_scoreString + " as a number!");
          }
        }
      }

      // optional record count
      final String countAttr = nodeE.getAttribute("recordCount");
      if (countAttr != null && !countAttr.isEmpty()) {
        m_recordCount = Double.parseDouble(countAttr);
      }

      // optional default child reference (resolved once children exist)
      final String defaultAttr = nodeE.getAttribute("defaultChild");
      if (defaultAttr != null && !defaultAttr.isEmpty()) {
        m_defaultChildID = defaultAttr;
      }

      // TODO: Embedded model (once we support model composition)

      // score distributions are only collected for classification models
      if (m_functionType == MiningFunction.CLASSIFICATION) {
        getScoreDistributions(nodeE, miningSchema);
      }

      // the predicate guarding this node
      m_predicate = Predicate.getPredicate(nodeE, miningSchema);

      // the subtree below this node
      getChildNodes(nodeE, miningSchema);

      // resolve the default child: first child whose id matches
      if (m_defaultChildID != null) {
        final int childCount = m_childNodes.size();
        int pos = 0;
        while (pos < childCount && m_defaultChild == null) {
          final TreeNode candidate = m_childNodes.get(pos);
          if (candidate.getID().equals(m_defaultChildID)) {
            m_defaultChild = candidate;
          }
          pos++;
        }
      }
    }

    /**
     * Create a TreeNode for every direct child element named "Node" and add
     * it to this node's list of children, in document order.
     * 
     * @param nodeE the Node element whose children are scanned
     * @param miningSchema the mining schema in use
     * @throws Exception if a child node cannot be constructed
     */
    private void getChildNodes(Element nodeE, MiningSchema miningSchema)
      throws Exception {
      final NodeList kids = nodeE.getChildNodes();
      final int total = kids.getLength();

      for (int k = 0; k < total; k++) {
        final Node candidate = kids.item(k);
        if (candidate.getNodeType() != Node.ELEMENT_NODE) {
          continue;
        }
        final Element element = (Element) candidate;
        if ("Node".equals(element.getTagName())) {
          m_childNodes.add(new TreeNode(element, miningSchema));
        }
      }
    }

    /**
     * Collect the ScoreDistribution child elements of a node. When the node
     * has no record count of its own, the sum of the entries' record counts
     * is used to derive any confidences that are still unset.
     * 
     * @param nodeE the Node element whose children are scanned
     * @param miningSchema the mining schema in use
     * @throws Exception if a distribution entry cannot be constructed
     */
    private void getScoreDistributions(Element nodeE, MiningSchema miningSchema)
      throws Exception {

      final NodeList kids = nodeE.getChildNodes();
      final int total = kids.getLength();
      for (int k = 0; k < total; k++) {
        final Node candidate = kids.item(k);
        if (candidate.getNodeType() != Node.ELEMENT_NODE) {
          continue;
        }
        final Element element = (Element) candidate;
        if ("ScoreDistribution".equals(element.getTagName())) {
          m_scoreDistributions.add(new ScoreDistribution(element,
            miningSchema, m_recordCount));
        }
      }

      // with an explicit record count there is nothing to backfit
      if (!Utils.isMissingValue(m_recordCount)) {
        return;
      }

      // backfit the confidence values from the summed record counts
      double summed = 0;
      for (int d = 0; d < m_scoreDistributions.size(); d++) {
        summed += m_scoreDistributions.get(d).getRecordCount();
      }
      for (int d = 0; d < m_scoreDistributions.size(); d++) {
        m_scoreDistributions.get(d).deriveConfidenceValue(summed);
      }
    }

    /**
     * Accessor for the textual score of this node.
     * 
     * @return the "score" attribute as read from the node, or null if the
     *         node did not define one.
     */
    protected String getScore() {
      return this.m_scoreString;
    }

    /**
     * Accessor for the numeric score of this node (regression trees only).
     * 
     * @return the score as a number; the missing value if no numeric score
     *         was parsed.
     */
    protected double getScoreNumeric() {
      return this.m_scoreNumeric;
    }

    /**
     * Accessor for the identifier of this node.
     * 
     * @return the "id" attribute of the node, or the hash-code based default
     *         if the node did not define one.
     */
    protected String getID() {
      return this.m_ID;
    }

    /**
     * Accessor for the predicate guarding this node.
     * 
     * @return the predicate built for this node.
     */
    protected Predicate getPredicate() {
      return this.m_predicate;
    }

    /**
     * Accessor for the record count of this node.
     * 
     * @return the "recordCount" attribute as a number; the missing value if
     *         the node did not define one.
     */
    protected double getRecordCount() {
      return this.m_recordCount;
    }

    /**
     * Append a dot-format description of this node, the edges to its
     * children and, recursively, the children themselves.
     * 
     * @param text the buffer to append to
     * @throws Exception if a problem occurs while describing a child
     */
    protected void dumpGraph(StringBuffer text) throws Exception {
      // the attribute list and label are always opened, so that the closing
      // quote and bracket below are balanced even for nodes without a score
      text.append("N" + m_ID + " [label=\"");
      if (m_scoreString != null) {
        text.append("score=" + m_scoreString);
      }

      // leaves additionally list their class distribution
      if (m_scoreDistributions.size() > 0 && m_childNodes.size() == 0) {
        text.append("\\n");
        for (ScoreDistribution s : m_scoreDistributions) {
          text.append(s + "\\n");
        }
      }

      text.append("\"");

      if (m_childNodes.size() == 0) {
        text.append(" shape=box style=filled");
      }

      text.append("]\n");

      // one edge per child, labelled with the child's predicate
      for (TreeNode c : m_childNodes) {
        text.append("N" + m_ID + "->" + "N" + c.getID());
        text.append(" [label=\"" + c.getPredicate().toString(0, true));
        text.append("\"]\n");
        c.dumpGraph(text);
      }
    }

    /**
     * Textual form of the subtree rooted at this node.
     * 
     * @return the tree, printed starting at indentation level 0
     */
    @Override
    public String toString() {
      final StringBuffer rendered = new StringBuffer();
      dumpTree(0, rendered);
      return rendered.toString();
    }

    /**
     * Appends a textual rendering of the subtree rooted at this node to the
     * supplied buffer.
     * <p>
     * For an internal node, each child contributes a newline, its predicate
     * (rendered via {@code toString(level, false)}) and then its own subtree
     * at {@code level + 1}. For a leaf, ": " is written followed by the
     * numeric score if one is set; otherwise the score string is written,
     * followed by the score distributions in square brackets when there are
     * any.
     *
     * @param level the depth of this node in the printout; passed on to the
     *          child predicates' {@code toString}
     * @param text the buffer to append to
     */
    protected void dumpTree(int level, StringBuffer text) {
      if (m_childNodes.size() > 0) {
        for (TreeNode child : m_childNodes) {
          text.append("\n");

          // output the predicate for this child node
          text.append(child.getPredicate().toString(level, false));

          // process recursively
          child.dumpTree(level + 1, text);
        }
        return;
      }

      // leaf
      text.append(": ");
      if (!Utils.isMissingValue(m_scoreNumeric)) {
        text.append(m_scoreNumeric);
        return;
      }

      // Print the score exactly once. Previously a leaf without score
      // distributions had its score string appended a second time.
      text.append(m_scoreString);
      if (m_scoreDistributions.size() > 0) {
        text.append(" [");
        for (ScoreDistribution s : m_scoreDistributions) {
          text.append(s);
        }
        text.append("]");
      }
    }

    /**
     * Score an incoming instance. A leaf returns its own prediction; any
     * other node delegates to the handler for the model's missing value
     * strategy, which selects the child (or children) to descend into.
     *
     * @param instance a vector of incoming attribute and derived field values.
     * @param classAtt the class attribute
     * @return a predicted probability distribution (a single-element array
     *         when the class is numeric).
     * @throws Exception if something goes wrong.
     */
    protected double[] score(double[] instance, Attribute classAtt)
      throws Exception {

      // leaf?
      if (m_childNodes.size() == 0) {
        double[] preds = new double[classAtt.isNumeric() ? 1 : classAtt
          .numValues()];
        doLeaf(classAtt, preds);
        return preds;
      }

      // process the children; every strategy that has a handler in this class
      // is dispatched (nullPrediction, weightedConfidence and aggregateNodes
      // used to fall through to "not implemented" despite being implemented)
      switch (TreeModel.this.m_missingValueStrategy) {
      case NONE:
        return missingValueStrategyNone(instance, classAtt);
      case LASTPREDICTION:
        return missingValueStrategyLastPrediction(instance, classAtt);
      case DEFAULTCHILD:
        return missingValueStrategyDefaultChild(instance, classAtt);
      case NULLPREDICTION:
        return missingValueStrategyNullPrediction(instance, classAtt);
      case WEIGHTEDCONFIDENCE:
        return missingValueStrategyWeightedConfidence(instance, classAtt);
      case AGGREGATENODES:
        return missingValueStrategyAggregateNodes(instance, classAtt);
      default:
        throw new Exception("[TreeModel] not implemented!");
      }
    }

    /**
     * Writes the prediction held at this node into the supplied array.
     * <p>
     * For a numeric class, slot 0 receives the numeric score. Otherwise, if
     * the node has no score distributions, the slot at the score index is set
     * to 1.0; if it has, each distribution's confidence is stored at that
     * distribution's class label index. All other slots are left untouched.
     *
     * @param classAtt the class attribute
     * @param preds an array to hold the predicted probabilities.
     * @throws Exception if something goes wrong.
     */
    protected void doLeaf(Attribute classAtt, double[] preds) throws Exception {
      if (classAtt.isNumeric()) {
        preds[0] = m_scoreNumeric;
        return;
      }

      if (m_scoreDistributions.size() == 0) {
        // no distributions: all of the mass goes to the scored class
        preds[m_scoreIndex] = 1.0;
        return;
      }

      // take the confidence recorded with each score distribution
      for (ScoreDistribution dist : m_scoreDistributions) {
        int labelIndex = dist.getClassLabelIndex();
        preds[labelIndex] = dist.getConfidence();
      }
    }

    /**
     * Evaluate on the basis of the no true child strategy. With
     * "returnNullPrediction" every slot of the supplied array is set to the
     * missing value; otherwise the prediction held at this node is written
     * (see doLeaf).
     *
     * @param classAtt the class attribute.
     * @param preds an array to hold the predicted probabilities; must not be
     *          null.
     * @throws Exception if something goes wrong.
     */
    protected void doNoTrueChild(Attribute classAtt, double[] preds)
      throws Exception {
      if (TreeModel.this.m_noTrueChildStrategy == NoTrueChildStrategy.RETURNNULLPREDICTION) {
        // Iterate over the array itself rather than classAtt.numValues():
        // numValues() is 0 for a numeric class, which left the single
        // prediction slot of a regression tree unmarked.
        for (int i = 0; i < preds.length; i++) {
          preds[i] = Utils.missingValue();
        }
      } else {
        // return the predictions at this node
        doLeaf(classAtt, preds);
      }
    }

    /**
     * Compute predictions and optionally invoke the weighted confidence missing
     * value handling strategy.
     * <p>
     * If no child predicate evaluates to UNKNOWN, the first child whose
     * predicate is TRUE is scored (or the no true child strategy is applied
     * when there is none). Otherwise every child whose predicate is TRUE or
     * UNKNOWN is scored and the resulting distributions are summed, each
     * weighted by the child's record count divided by this node's record
     * count.
     *
     * @param instance the incoming vector of attribute and derived field
     *          values.
     * @param classAtt the class attribute.
     * @return the predicted probability distribution.
     * @throws Exception if the class is numeric, if a required record count
     *           is not defined, or if something else goes wrong.
     */
    protected double[] missingValueStrategyWeightedConfidence(
      double[] instance, Attribute classAtt) throws Exception {

      if (classAtt.isNumeric()) {
        throw new Exception(
          "[TreeNode] missing value strategy weighted confidence, "
            + "but class is numeric!");
      }

      TreeNode trueNode = null;
      boolean strategyInvoked = false;
      // children whose predicate is TRUE or UNKNOWN, in their original order
      ArrayList<TreeNode> candidates = new ArrayList<TreeNode>();

      // evaluate each child predicate exactly once
      for (TreeNode c : m_childNodes) {
        Predicate.Eval outcome = c.getPredicate().evaluate(instance);
        if (outcome == Predicate.Eval.TRUE) {
          // note the first child to evaluate to true
          if (trueNode == null) {
            trueNode = c;
          }
          candidates.add(c);
        } else if (outcome == Predicate.Eval.UNKNOWN) {
          strategyInvoked = true;
          candidates.add(c);
        }
      }

      if (!strategyInvoked) {
        if (trueNode != null) {
          return trueNode.score(instance, classAtt);
        }
        // doNoTrueChild writes into the array it is given, so it has to be
        // allocated here (it used to be handed a null reference)
        double[] noTrueChildPreds = new double[classAtt.numValues()];
        doNoTrueChild(classAtt, noTrueChildPreds);
        return noTrueChildPreds;
      }

      // the weights are relative to this node's record count, so it must exist
      if (Utils.isMissingValue(m_recordCount)) {
        throw new Exception("[TreeNode] weighted confidence missing value "
          + "strategy invoked, but no record count defined for node " + m_ID);
      }

      // score the candidates and accumulate the weighted distributions
      double[] preds = new double[classAtt.numValues()];
      for (TreeNode c : candidates) {
        double weight = c.getRecordCount();
        if (Utils.isMissingValue(weight)) {
          throw new Exception(
            "[TreeNode] weighted confidence missing value "
              + "strategy invoked, but no record count defined for node "
              + c.getID());
        }
        double[] dist = c.score(instance, classAtt);
        for (int i = 0; i < preds.length; i++) {
          preds[i] += ((weight / m_recordCount) * dist[i]);
        }
      }

      return preds;
    }

    /**
     * Collects per-class record counts for the aggregate nodes missing value
     * strategy.
     * <p>
     * At a leaf, the record count of each score distribution is stored at that
     * distribution's class label index. At any other node, the counts returned
     * by every child whose predicate evaluates to TRUE or UNKNOWN are summed.
     *
     * @param instance the incoming vector of attribute and derived field
     *          values.
     * @param classAtt the class attribute.
     * @return an array of counts with one entry per class value.
     * @throws Exception if a leaf without score distributions is reached.
     */
    protected double[] freqCountsForAggNodesStrategy(double[] instance,
      Attribute classAtt) throws Exception {

      double[] totals = new double[classAtt.numValues()];

      if (m_childNodes.size() == 0) {
        // leaf: read the counts off the score distributions
        if (m_scoreDistributions.size() == 0) {
          throw new Exception(
            "[TreeModel] missing value strategy aggregate nodes:"
              + " no score distributions at leaf " + m_ID);
        }
        for (ScoreDistribution dist : m_scoreDistributions) {
          totals[dist.getClassLabelIndex()] = dist.getRecordCount();
        }
        return totals;
      }

      // internal node: add up the counts of all children that are not ruled out
      for (TreeNode child : m_childNodes) {
        Predicate test = child.getPredicate();
        boolean ruledOut = test.evaluate(instance) != Predicate.Eval.TRUE
          && test.evaluate(instance) != Predicate.Eval.UNKNOWN;
        if (ruledOut) {
          continue;
        }

        double[] childTotals = child.freqCountsForAggNodesStrategy(instance,
          classAtt);
        for (int k = 0; k < classAtt.numValues(); k++) {
          totals[k] += childTotals[k];
        }
      }

      return totals;
    }

    /**
     * Compute predictions and optionally invoke the aggregate nodes missing
     * value handling strategy.
     * <p>
     * If any child predicate evaluates to UNKNOWN, the per-class record counts
     * gathered by freqCountsForAggNodesStrategy are normalized and returned.
     * Otherwise the first child whose predicate is TRUE is scored, or the no
     * true child strategy is applied when there is none.
     *
     * @param instance the incoming vector of attribute and derived field
     *          values.
     * @param classAtt the class attribute.
     * @return the predicted probability distribution.
     * @throws Exception if the class is numeric or something else goes wrong.
     */
    protected double[] missingValueStrategyAggregateNodes(double[] instance,
      Attribute classAtt) throws Exception {

      if (classAtt.isNumeric()) {
        throw new Exception(
          "[TreeNode] missing value strategy aggregate nodes, "
            + "but class is numeric!");
      }

      TreeNode trueNode = null;
      boolean strategyInvoked = false;

      // evaluate each child predicate exactly once
      for (TreeNode c : m_childNodes) {
        Predicate.Eval outcome = c.getPredicate().evaluate(instance);
        if (outcome == Predicate.Eval.TRUE) {
          // note the first child to evaluate to true
          if (trueNode == null) {
            trueNode = c;
          }
        } else if (outcome == Predicate.Eval.UNKNOWN) {
          strategyInvoked = true;
        }
      }

      if (strategyInvoked) {
        double[] aggregatedCounts = freqCountsForAggNodesStrategy(instance,
          classAtt);

        // normalize
        Utils.normalize(aggregatedCounts);
        return aggregatedCounts;
      }

      if (trueNode != null) {
        return trueNode.score(instance, classAtt);
      }

      // doNoTrueChild writes into the array it is given, so it has to be
      // allocated here (it used to be handed a null reference)
      double[] preds = new double[classAtt.numValues()];
      doNoTrueChild(classAtt, preds);
      return preds;
    }

    /**
     * Compute predictions and optionally invoke the default child missing value
     * handling strategy.
     * <p>
     * Children are examined in order and the first one whose predicate is TRUE
     * is scored. If there is none and some earlier predicate was UNKNOWN, the
     * default child is scored instead; if there is none and no predicate was
     * UNKNOWN, the no true child strategy is applied.
     *
     * @param instance the incoming vector of attribute and derived field
     *          values.
     * @param classAtt the class attribute.
     * @return the predicted probability distribution.
     * @throws Exception if the strategy is invoked but this node has no
     *           default child, or if something else goes wrong.
     */
    protected double[] missingValueStrategyDefaultChild(double[] instance,
      Attribute classAtt) throws Exception {

      boolean strategyInvoked = false;

      // look for a child whose predicate evaluates to TRUE
      for (TreeNode c : m_childNodes) {
        Predicate.Eval outcome = c.getPredicate().evaluate(instance);
        if (outcome == Predicate.Eval.TRUE) {
          return c.score(instance, classAtt);
        }
        if (outcome == Predicate.Eval.UNKNOWN) {
          strategyInvoked = true;
        }
      }

      // no true child found
      if (!strategyInvoked) {
        // doNoTrueChild writes into the array it is given, so it has to be
        // allocated here (it used to be handed a null reference)
        double[] preds;
        if (classAtt.isNumeric()) {
          // numValues() is 0 for a numeric class; use one slot that starts
          // out as "missing" so a null prediction is reported as such
          preds = new double[] { Utils.missingValue() };
        } else {
          preds = new double[classAtt.numValues()];
        }
        doNoTrueChild(classAtt, preds);
        return preds;
      }

      // do the strategy

      // NOTE: we don't actually implement the missing value penalty since
      // we always return a full probability distribution.
      if (m_defaultChild == null) {
        throw new Exception(
          "[TreeNode] missing value strategy is defaultChild, but "
            + "no default child has been specified in node " + m_ID);
      }
      return m_defaultChild.score(instance, classAtt);
    }

    /**
     * Compute predictions and optionally invoke the last prediction missing
     * value handling strategy.
     * <p>
     * Children are examined in order and the first one whose predicate is TRUE
     * is scored. If there is none and some earlier predicate was UNKNOWN, the
     * prediction held at this node is returned; if there is none and no
     * predicate was UNKNOWN, the no true child strategy is applied.
     *
     * @param instance the incoming vector of attribute and derived field
     *          values.
     * @param classAtt the class attribute.
     * @return the predicted probability distribution (a single-element array
     *         when the class is numeric).
     * @throws Exception if something goes wrong.
     */
    protected double[] missingValueStrategyLastPrediction(double[] instance,
      Attribute classAtt) throws Exception {

      boolean strategyInvoked = false;

      // look for a child whose predicate evaluates to TRUE
      for (TreeNode c : m_childNodes) {
        Predicate.Eval outcome = c.getPredicate().evaluate(instance);
        if (outcome == Predicate.Eval.TRUE) {
          return c.score(instance, classAtt);
        }
        if (outcome == Predicate.Eval.UNKNOWN) {
          strategyInvoked = true;
        }
      }

      // no true child found
      double[] preds;
      if (classAtt.isNumeric()) {
        // numValues() is 0 for a numeric class, so sizing the array with it
        // made doLeaf write out of bounds; use one slot that starts out as
        // "missing" so a null prediction is reported as such
        preds = new double[] { Utils.missingValue() };
      } else {
        preds = new double[classAtt.numValues()];
      }

      if (!strategyInvoked) {
        // no true child
        doNoTrueChild(classAtt, preds);
      } else {
        // do the strategy
        doLeaf(classAtt, preds);
      }

      return preds;
    }

    /**
     * Compute predictions and optionally invoke the null prediction missing
     * value handling strategy.
     * <p>
     * Children are examined in order and the first one whose predicate is TRUE
     * is scored. If there is none and some earlier predicate was UNKNOWN,
     * every entry of the result is set to the missing value; if there is none
     * and no predicate was UNKNOWN, the no true child strategy is applied.
     *
     * @param instance the incoming vector of attribute and derived field
     *          values.
     * @param classAtt the class attribute.
     * @return the predicted probability distribution (a single-element array
     *         when the class is numeric).
     * @throws Exception if something goes wrong.
     */
    protected double[] missingValueStrategyNullPrediction(double[] instance,
      Attribute classAtt) throws Exception {

      boolean strategyInvoked = false;

      // look for a child whose predicate evaluates to TRUE
      for (TreeNode c : m_childNodes) {
        Predicate.Eval outcome = c.getPredicate().evaluate(instance);
        if (outcome == Predicate.Eval.TRUE) {
          return c.score(instance, classAtt);
        }
        if (outcome == Predicate.Eval.UNKNOWN) {
          strategyInvoked = true;
        }
      }

      // no true child found
      double[] preds;
      if (classAtt.isNumeric()) {
        // numValues() is 0 for a numeric class, which produced an empty
        // result; use one slot that starts out as "missing"
        preds = new double[] { Utils.missingValue() };
      } else {
        preds = new double[classAtt.numValues()];
      }

      if (!strategyInvoked) {
        doNoTrueChild(classAtt, preds);
      } else {
        // do the strategy: a null prediction in every slot
        for (int i = 0; i < preds.length; i++) {
          preds[i] = Utils.missingValue();
        }
      }

      return preds;
    }

    /**
     * Compute predictions and optionally invoke the "none" missing value
     * handling strategy (invokes no true child).
     * <p>
     * Children are examined in order and the first one whose predicate is TRUE
     * is scored; predicates that evaluate to UNKNOWN are treated like FALSE.
     * If no child qualifies, the no true child strategy is applied.
     *
     * @param instance the incoming vector of attribute and derived field
     *          values.
     * @param classAtt the class attribute.
     * @return the predicted probability distribution (a single-element array
     *         when the class is numeric).
     * @throws Exception if something goes wrong.
     */
    protected double[] missingValueStrategyNone(double[] instance,
      Attribute classAtt) throws Exception {

      // look for a child whose predicate evaluates to TRUE
      for (TreeNode c : m_childNodes) {
        if (c.getPredicate().evaluate(instance) == Predicate.Eval.TRUE) {
          return c.score(instance, classAtt);
        }
      }

      double[] preds;
      if (classAtt.isNumeric()) {
        // numValues() is 0 for a numeric class, so sizing the array with it
        // made doLeaf write out of bounds; use one slot that starts out as
        // "missing" so a null prediction is reported as such
        preds = new double[] { Utils.missingValue() };
      } else {
        preds = new double[classAtt.numValues()];
      }

      // no true child strategy
      doNoTrueChild(classAtt, preds);

      return preds;
    }
  }

  /**
   * Enumerated type for the mining function of the model: classification or
   * regression.
   */
  enum MiningFunction {
    CLASSIFICATION,
    REGRESSION
  }

  /**
   * Enumerated type for the missing value strategy. Each constant carries
   * the string that toString() returns; the TreeModel constructor compares
   * that string against the "missingValueStrategy" attribute of the model.
   */
  enum MissingValueStrategy {
    LASTPREDICTION("lastPrediction"),
    NULLPREDICTION("nullPrediction"),
    DEFAULTCHILD("defaultChild"),
    WEIGHTEDCONFIDENCE("weightedConfidence"),
    AGGREGATENODES("aggregateNodes"),
    NONE("none");

    /** The textual form of this constant */
    private final String m_stringVal;

    MissingValueStrategy(String textualForm) {
      this.m_stringVal = textualForm;
    }

    /**
     * @return the textual form supplied when the constant was declared
     */
    @Override
    public String toString() {
      return this.m_stringVal;
    }
  }

  /**
   * Enumerated type for the strategy applied when no child predicate of a
   * node evaluates to true. Each constant carries the string that toString()
   * returns.
   */
  enum NoTrueChildStrategy {
    RETURNNULLPREDICTION("returnNullPrediction"),
    RETURNLASTPREDICTION("returnLastPrediction");

    /** The textual form of this constant */
    private final String m_stringVal;

    NoTrueChildStrategy(String textualForm) {
      this.m_stringVal = textualForm;
    }

    /**
     * @return the textual form supplied when the constant was declared
     */
    @Override
    public String toString() {
      return this.m_stringVal;
    }
  }

  /**
   * Enumerated type for the split characteristic of the tree. Each constant
   * carries the string that toString() returns; the TreeModel constructor
   * compares that string against the "splitCharacteristic" attribute of the
   * model.
   */
  enum SplitCharacteristic {
    BINARYSPLIT("binarySplit"),
    MULTISPLIT("multiSplit");

    /** The textual form of this constant */
    private final String m_stringVal;

    SplitCharacteristic(String textualForm) {
      this.m_stringVal = textualForm;
    }

    /**
     * @return the textual form supplied when the constant was declared
     */
    @Override
    public String toString() {
      return this.m_stringVal;
    }
  }

  /** The mining function; classification unless the model says otherwise */
  protected MiningFunction m_functionType = MiningFunction.CLASSIFICATION;

  /** The missing value strategy; defaults to "none" */
  protected MissingValueStrategy m_missingValueStrategy = MissingValueStrategy.NONE;

  /**
   * The missing value penalty (if defined); starts out as the missing value
   * marker. We don't actually make use of this since we always return full
   * probability distributions.
   */
  protected double m_missingValuePenalty = Utils.missingValue();

  /** The no true child strategy to use; defaults to "returnNullPrediction" */
  protected NoTrueChildStrategy m_noTrueChildStrategy = NoTrueChildStrategy.RETURNNULLPREDICTION;

  /** The splitting type; defaults to "multiSplit" */
  protected SplitCharacteristic m_splitCharacteristic = SplitCharacteristic.MULTISPLIT;

  /** The root of the tree; assigned by the constructor */
  protected TreeNode m_root;

  public TreeModel(Element model, Instances dataDictionary,
    MiningSchema miningSchema) throws Exception {

    super(dataDictionary, miningSchema);

    if (!getPMMLVersion().equals("3.2")) {
      // TODO: might have to throw an exception and only support 3.2
    }

    String fn = model.getAttribute("functionName");
    if (fn.equals("regression")) {
      m_functionType = MiningFunction.REGRESSION;
    }

    // get the missing value strategy (if any)
    String missingVS = model.getAttribute("missingValueStrategy");
    if (missingVS != null && missingVS.length() > 0) {
      for (MissingValueStrategy m : MissingValueStrategy.values()) {
        if (m.toString().equals(missingVS)) {
          m_missingValueStrategy = m;
          break;
        }
      }
    }

    // get the missing value penalty (if any)
    String missingP = model.getAttribute("missingValuePenalty");
    if (missingP != null && missingP.length() > 0) {
      // try to parse as a number
      try {
        m_missingValuePenalty = Double.parseDouble(missingP);
      } catch (NumberFormatException ex) {
        System.err.println("[TreeModel] WARNING: "
          + "couldn't parse supplied missingValuePenalty as a number");
      }
    }

    String splitC = model.getAttribute("splitCharacteristic");

    if (splitC != null && splitC.length() > 0) {
      for (SplitCharacteristic s : SplitCharacteristic.values()) {
        if (s.toString().equals(splitC)) {
          m_splitCharacteristic = s;
          break;
        }
      }
    }

    // find the root node of the tree
    NodeList children = model.getChildNodes();
    for (int i = 0; i < children.getLength(); i++) {
      Node child = children.item(i);
      if (child.getNodeType() == Node.ELEMENT_NODE) {
        String tagName = ((Element) child).getTagName();
        if (tagName.equals("Node")) {
          m_root = new TreeNode((Element) child, miningSchema);
          break;
        }
      }
    }
  }

  /**
   * Computes the predicted distribution for the given test instance. The
   * instance has to belong to a dataset when it's being classified. If the
   * model has not been initialized yet, the instance's dataset is first
   * mapped to the mining schema.
   *
   * @param inst the instance to be classified
   * @return the prediction produced by scoring the instance from the root of
   *         the tree: a class probability distribution, or a single-element
   *         array when the class is numeric
   * @exception Exception if an error occurred during the prediction
   */
  @Override
  public double[] distributionForInstance(Instance inst) throws Exception {
    if (!m_initialized) {
      mapToMiningSchema(inst.dataset());
    }

    // No result array is allocated here: the tree creates and returns its own
    Attribute classAtt = m_miningSchema.getFieldsAsInstances()
      .classAttribute();
    double[] incoming = m_fieldsMap.instanceToSchema(inst, m_miningSchema);

    return m_root.score(incoming, classAtt);
  }

  /**
   * Returns a textual description of the model: the PMML version, the
   * creating application (unless it is "?"), the mining schema, the split
   * type, the no true child and missing value strategies and finally the
   * tree itself.
   *
   * @return a textual description of this model
   */
  @Override
  public String toString() {
    StringBuffer description = new StringBuffer();

    // header
    description.append("PMML version ").append(getPMMLVersion());
    if (!getCreatorApplication().equals("?")) {
      description.append("\nApplication: ").append(getCreatorApplication());
    }
    description.append("\nPMML Model: TreeModel").append("\n\n")
      .append(m_miningSchema);

    // model-level settings
    description.append("Split-type: ").append(m_splitCharacteristic)
      .append("\n");
    description.append("No true child strategy: ")
      .append(m_noTrueChildStrategy).append("\n");
    description.append("Missing value strategy: ")
      .append(m_missingValueStrategy).append("\n");

    // the tree
    description.append(m_root.toString());

    return description.toString();
  }

  /**
   * Returns the tree as a GraphViz "dot" digraph: a header line, the
   * description written by the root node, and a closing brace.
   *
   * @return the graph description as a String
   * @throws Exception if the root node fails to write its description
   */
  @Override
  public String graph() throws Exception {
    StringBuffer dot = new StringBuffer();
    dot.append("digraph PMMTree {\n");
    m_root.dumpGraph(dot);
    return dot.append("}\n").toString();
  }

  /**
   * Returns the revision string of this class, as extracted by RevisionUtils
   * from the revision keyword below.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    final String revisionKeyword = "$Revision: 10153 $";
    return RevisionUtils.extract(revisionKeyword);
  }

  /**
   * Returns the type of graph this model produces.
   *
   * @return Drawable.TREE
   */
  @Override
  public int graphType() {
    final int type = Drawable.TREE;
    return type;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy