All downloads are free. Search and download functionality uses the official Maven repository.

weka.core.pmml.PMMLFactory Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA) is a machine learning workbench. This is the developer version, the "bleeding edge" of development, to which new functionality is added.

There is a newer version: 3.9.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    PMMLFactory.java
 *    Copyright (C) 2008-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.core.pmml;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;

import javax.xml.XMLConstants;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;

import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.pmml.consumer.GeneralRegression;
import weka.classifiers.pmml.consumer.NeuralNetwork;
import weka.classifiers.pmml.consumer.PMMLClassifier;
import weka.classifiers.pmml.consumer.Regression;
import weka.classifiers.pmml.consumer.RuleSetModel;
import weka.classifiers.pmml.consumer.SupportVectorMachineModel;
import weka.classifiers.pmml.consumer.TreeModel;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.gui.Logger;

/**
 * This class is a factory class for reading/writing PMML models
 * 
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision: 10203 $
 */
public class PMMLFactory {

  /**
   * The kinds of PMML model this factory can construct. Each constant holds a
   * short display name (returned by {@link #toString()}) and the name of the
   * PMML element that encodes a model of that kind. Constants are declared in
   * the order in which {@link PMMLFactory#getModelType(Document)} probes a
   * document for them, so the first matching element wins.
   */
  protected enum ModelType {
    UNKNOWN_MODEL("unknown", null),
    REGRESSION_MODEL("Regression", "RegressionModel"),
    GENERAL_REGRESSION_MODEL("GeneralRegression", "GeneralRegressionModel"),
    NEURAL_NETWORK_MODEL("NeuralNetwork", "NeuralNetwork"),
    TREE_MODEL("TreeModel", "TreeModel"),
    RULESET_MODEL("RuleSetModel", "RuleSetModel"),
    SVM_MODEL("SupportVectorMachineModel", "SupportVectorMachineModel");

    /** Display name returned by toString() */
    private final String m_stringVal;

    /** PMML element name for this model type (null for UNKNOWN_MODEL) */
    private final String m_elementName;

    ModelType(String name, String elementName) {
      m_stringVal = name;
      m_elementName = elementName;
    }

    /**
     * Get the name of the PMML element that encodes this model type.
     * 
     * @return the element name, or null for {@link #UNKNOWN_MODEL}
     */
    String getElementName() {
      return m_elementName;
    }

    @Override
    public String toString() {
      return m_stringVal;
    }
  }

  /**
   * Read and return a PMML model.
   * 
   * @param filename the name of the file to read from
   * @return a PMML model
   * @throws Exception if there is a problem while reading the file
   */
  public static PMMLModel getPMMLModel(String filename) throws Exception {
    return getPMMLModel(filename, null);
  }

  /**
   * Read and return a PMML model.
   * 
   * @param file a File to read from
   * @return a PMML model
   * @throws Exception if there is a problem while reading the file
   */
  public static PMMLModel getPMMLModel(File file) throws Exception {
    return getPMMLModel(file, null);
  }

  /**
   * Read and return a PMML model. The stream is closed by this method.
   * 
   * @param stream the InputStream to read from
   * @return a PMML model
   * @throws Exception if there is a problem while reading from the stream
   */
  public static PMMLModel getPMMLModel(InputStream stream) throws Exception {
    return getPMMLModel(stream, null);
  }

  /**
   * Read and return a PMML model.
   * 
   * @param filename the name of the file to read from
   * @param log the logging object to use (or null if none is to be used)
   * @return a PMML model
   * @throws Exception if there is a problem while reading the file
   */
  public static PMMLModel getPMMLModel(String filename, Logger log)
    throws Exception {
    return getPMMLModel(new File(filename), log);
  }

  /**
   * Read and return a PMML model.
   * 
   * @param file a File to read from
   * @param log the logging object to use (or null if none is to be used)
   * @return a PMML model
   * @throws Exception if there is a problem while reading the file
   */
  public static PMMLModel getPMMLModel(File file, Logger log) throws Exception {
    // the stream overload closes the stream on both success and failure
    return getPMMLModel(new BufferedInputStream(new FileInputStream(file)), log);
  }

  /**
   * Check whether a parsed document contains a PMML element.
   * 
   * @param doc the parsed document
   * @return true if at least one element named "PMML" is present
   */
  private static boolean isPMML(Document doc) {
    return doc.getElementsByTagName("PMML").getLength() > 0;
  }

  /**
   * Create a DocumentBuilder hardened against XML external entity (XXE)
   * attacks, since PMML files may come from untrusted sources.
   * 
   * @return a new DocumentBuilder
   * @throws ParserConfigurationException if a builder cannot be created
   */
  private static DocumentBuilder newHardenedDocumentBuilder()
    throws ParserConfigurationException {
    DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
    trySetFeature(dbf, XMLConstants.FEATURE_SECURE_PROCESSING, true);
    trySetFeature(dbf, "http://xml.org/sax/features/external-general-entities",
      false);
    trySetFeature(dbf,
      "http://xml.org/sax/features/external-parameter-entities", false);
    trySetFeature(dbf,
      "http://apache.org/xml/features/nonvalidating/load-external-dtd", false);
    return dbf.newDocumentBuilder();
  }

  /**
   * Set a parser feature, ignoring features the JAXP implementation does not
   * recognize (hardening is best-effort so parsing still works there).
   * 
   * @param dbf the factory to configure
   * @param feature the feature URI
   * @param value the value to set
   */
  private static void trySetFeature(DocumentBuilderFactory dbf,
    String feature, boolean value) {
    try {
      dbf.setFeature(feature, value);
    } catch (ParserConfigurationException ignored) {
      // feature not supported by this parser implementation - skip it
    }
  }

  /**
   * Read and return a PMML model. The stream is always closed by this method,
   * even if parsing fails.
   * 
   * @param stream the InputStream to read from
   * @param log the logging object to use (or null if none is to be used)
   * @return a PMML model
   * @throws IllegalArgumentException if the stream is null or the source does
   *           not contain a PMML element
   * @throws Exception if there is a problem while reading from the stream or
   *           the model type is not supported
   */
  public static PMMLModel getPMMLModel(InputStream stream, Logger log)
    throws Exception {
    if (stream == null) {
      throw new IllegalArgumentException(
        "[PMMLFactory] input stream must not be null");
    }

    Document doc;
    try {
      doc = newHardenedDocumentBuilder().parse(stream);
    } finally {
      stream.close();
    }
    doc.getDocumentElement().normalize();
    if (!isPMML(doc)) {
      throw new IllegalArgumentException(
        "[PMMLFactory] Source is not a PMML file!!");
    }

    Instances dataDictionary = getDataDictionaryAsInstances(doc);
    TransformationDictionary transDict = getTransformationDictionary(doc,
      dataDictionary);

    ModelType modelType = getModelType(doc);
    if (modelType == ModelType.UNKNOWN_MODEL) {
      throw new Exception("Unsupported PMML model type");
    }
    Element model = getModelElement(doc, modelType);

    // Construct mining schema and meta data
    MiningSchema ms = new MiningSchema(model, dataDictionary, transDict);

    PMMLModel theModel = getModelInstance(doc, modelType, model,
      dataDictionary, ms);
    if (log != null) {
      theModel.setLog(log);
    }
    return theModel;
  }

  /**
   * Get the transformation dictionary (if there is one).
   * 
   * @param doc the Document containing the PMML model
   * @param dataDictionary the data dictionary as an Instances object
   * @return the transformation dictionary or null if there is none defined in
   *         the Document
   * @throws Exception if there is a problem getting the transformation
   *           dictionary
   */
  protected static TransformationDictionary getTransformationDictionary(
    Document doc, Instances dataDictionary) throws Exception {
    NodeList transL = doc.getElementsByTagName("TransformationDictionary");
    // should be of size 0 or 1; only the first one is used
    if (transL.getLength() > 0) {
      Node transNode = transL.item(0);
      if (transNode.getNodeType() == Node.ELEMENT_NODE) {
        return new TransformationDictionary((Element) transNode,
          dataDictionary);
      }
    }
    return null;
  }

  /**
   * Serialize a PMMLModel object that encapsulates a PMML model
   * 
   * @param model the PMMLModel to serialize
   * @param filename the name of the file to save to
   * @throws Exception if something goes wrong during serialization
   */
  public static void serializePMMLModel(PMMLModel model, String filename)
    throws Exception {
    serializePMMLModel(model, new File(filename));
  }

  /**
   * Serialize a PMMLModel object that encapsulates a PMML model
   * 
   * @param model the PMMLModel to serialize
   * @param file the File to save to
   * @throws Exception if something goes wrong during serialization
   */
  public static void serializePMMLModel(PMMLModel model, File file)
    throws Exception {
    // the stream overload closes the stream on both success and failure
    serializePMMLModel(model, new BufferedOutputStream(new FileOutputStream(
      file)));
  }

  /**
   * Serialize a PMMLModel object that encapsulates a PMML model. The mining
   * schema header (as Instances) is written first, followed by the model
   * itself. The stream is always closed by this method.
   * 
   * @param model the PMMLModel to serialize
   * @param stream the OutputStream to serialize to
   * @throws Exception if something goes wrong during serialization
   */
  public static void serializePMMLModel(PMMLModel model, OutputStream stream)
    throws Exception {
    ObjectOutputStream oo = null;
    try {
      oo = new ObjectOutputStream(stream);
      Instances header = model.getMiningSchema().getFieldsAsInstances();
      oo.writeObject(header);
      oo.writeObject(model);
      oo.flush();
    } finally {
      if (oo != null) {
        oo.close(); // also closes the underlying stream
      } else {
        stream.close(); // ObjectOutputStream construction failed
      }
    }
  }

  /**
   * Get an instance of a PMMLModel from the supplied Document
   * 
   * @param doc the Document holding the pmml
   * @param modelType the type of model
   * @param model the Element encapsulating the model part of the Document
   * @param dataDictionary the data dictionary as an Instances object
   * @param miningSchema the mining schema
   * @return a PMMLModel object
   * @throws Exception if there is a problem constructing the model or if the
   *           model type is not supported
   */
  protected static PMMLModel getModelInstance(Document doc,
    ModelType modelType, Element model, Instances dataDictionary,
    MiningSchema miningSchema) throws Exception {
    PMMLModel pmmlM = null;
    switch (modelType) {
    case REGRESSION_MODEL:
      pmmlM = new Regression(model, dataDictionary, miningSchema);
      break;
    case GENERAL_REGRESSION_MODEL:
      pmmlM = new GeneralRegression(model, dataDictionary, miningSchema);
      break;
    case NEURAL_NETWORK_MODEL:
      pmmlM = new NeuralNetwork(model, dataDictionary, miningSchema);
      break;
    case TREE_MODEL:
      pmmlM = new TreeModel(model, dataDictionary, miningSchema);
      break;
    case RULESET_MODEL:
      pmmlM = new RuleSetModel(model, dataDictionary, miningSchema);
      break;
    case SVM_MODEL:
      pmmlM = new SupportVectorMachineModel(model, dataDictionary, miningSchema);
      break;
    default:
      throw new Exception("[PMMLFactory] Unknown model type!!");
    }
    pmmlM.setPMMLVersion(doc);
    pmmlM.setCreatorApplication(doc);
    return pmmlM;
  }

  /**
   * Get the type of model. Model element names are probed in the declaration
   * order of {@link ModelType}; the first one present in the document wins.
   * 
   * @param doc the Document encapsulating the pmml
   * @return the type of model, or UNKNOWN_MODEL if no supported model element
   *         is present
   */
  protected static ModelType getModelType(Document doc) {
    for (ModelType type : ModelType.values()) {
      String elementName = type.getElementName();
      if (elementName != null
        && doc.getElementsByTagName(elementName).getLength() > 0) {
        return type;
      }
    }
    return ModelType.UNKNOWN_MODEL;
  }

  /**
   * Get the Element that contains the pmml model
   * 
   * @param doc the Document encapsulating the pmml
   * @param modelType the type of model
   * @return the first element for the model type, or null if the document
   *         contains none
   * @throws Exception if the model type is unsupported/unknown
   */
  protected static Element getModelElement(Document doc, ModelType modelType)
    throws Exception {
    String elementName = (modelType == null) ? null : modelType
      .getElementName();
    if (elementName == null) {
      throw new Exception("[PMMLFactory] unknown/unsupported model type.");
    }

    NodeList candidates = doc.getElementsByTagName(elementName);
    if (candidates.getLength() > 0) {
      Node modelNode = candidates.item(0);
      if (modelNode.getNodeType() == Node.ELEMENT_NODE) {
        return (Element) modelNode;
      }
    }
    return null;
  }

  /**
   * Get the mining schema as an Instances object
   * 
   * @param model the Element containing the pmml model
   * @param dataDictionary the data dictionary as an Instances object
   * @return the mining schema as an Instances object
   * @throws Exception if something goes wrong during reading the mining schema
   * @deprecated Use the MiningSchema class instead
   */
  @Deprecated
  protected static Instances getMiningSchemaAsInstances(Element model,
    Instances dataDictionary) throws Exception {
    ArrayList<Attribute> attInfo = new ArrayList<Attribute>();
    NodeList fieldList = model.getElementsByTagName("MiningField");
    int classIndex = -1;
    for (int i = 0; i < fieldList.getLength(); i++) {
      Node miningField = fieldList.item(i);
      if (miningField.getNodeType() != Node.ELEMENT_NODE) {
        continue;
      }
      Element miningFieldEl = (Element) miningField;
      String name = miningFieldEl.getAttribute("name");
      String usage = miningFieldEl.getAttribute("usageType");
      // TO-DO: also missing value replacement etc.

      // find this attribute in the dataDictionary
      Attribute miningAtt = dataDictionary.attribute(name);
      if (miningAtt == null) {
        throw new Exception("Can't find mining field: " + name
          + " in the data dictionary.");
      }
      // an absent usageType defaults to "active"
      if (usage.length() == 0 || usage.equals("active")
        || usage.equals("predicted")) {
        attInfo.add(miningAtt);
      }
      if (usage.equals("predicted")) {
        classIndex = attInfo.size() - 1;
      }
    }

    Instances insts = new Instances("miningSchema", attInfo, 0);
    if (classIndex != -1) {
      insts.setClassIndex(classIndex);
    }

    return insts;
  }

  /**
   * Get the data dictionary as an Instances object
   * 
   * @param doc the Document encapsulating the pmml
   * @return the data dictionary as an Instances object
   * @throws Exception if there are fields that are not continuous, ordinal or
   *           categorical in the data dictionary
   */
  protected static Instances getDataDictionaryAsInstances(Document doc)
    throws Exception {

    // TO-DO: definition of missing values (see below)

    ArrayList<Attribute> attInfo = new ArrayList<Attribute>();
    NodeList dataDictionary = doc.getElementsByTagName("DataField");
    for (int i = 0; i < dataDictionary.getLength(); i++) {
      Node dataField = dataDictionary.item(i);
      if (dataField.getNodeType() != Node.ELEMENT_NODE) {
        continue;
      }
      Element dataFieldEl = (Element) dataField;
      String name = dataFieldEl.getAttribute("name");
      String type = dataFieldEl.getAttribute("optype");
      // DOM returns "" for absent attributes; the null check is defensive
      if (name == null || type == null) {
        continue;
      }

      Attribute tempAtt;
      if (type.equals("continuous")) {
        tempAtt = new Attribute(name);
      } else if (type.equals("categorical") || type.equals("ordinal")) {
        NodeList valueList = dataFieldEl.getElementsByTagName("Value");
        if (valueList == null || valueList.getLength() == 0) {
          // assume that categorical values will be revealed in the actual
          // model. Create a string attribute for now
          ArrayList<String> nullV = null;
          tempAtt = new Attribute(name, nullV);
        } else {
          tempAtt = new Attribute(name, getValidValues(valueList));
        }
      } else {
        throw new Exception("[PMMLFactory] can't handle " + type
          + " attributes.");
      }
      attInfo.add(tempAtt);
    }

    // TO-DO: check whether certain values are declared to represent
    // missing or invalid values (applies to both categorical and continuous
    // attributes

    // create the Instances structure
    return new Instances("dataDictionary", attInfo, 0);
  }

  /**
   * Collect the "value" attribute of every Value element whose property is
   * "valid" (the default when the property is absent).
   * 
   * @param valueList the Value elements of a DataField
   * @return the valid values in document order
   */
  private static ArrayList<String> getValidValues(NodeList valueList) {
    ArrayList<String> values = new ArrayList<String>();
    for (int j = 0; j < valueList.getLength(); j++) {
      Node val = valueList.item(j);
      if (val.getNodeType() != Node.ELEMENT_NODE) {
        continue;
      }
      String property = ((Element) val).getAttribute("property");
      if (property == null || property.length() == 0
        || property.equals("valid")) {
        values.add(((Element) val).getAttribute("value"));
      }
      // Invalid or missing value definitions are ignored for now.
      // TO-DO: implement Value meta data with missing/invalid value defs.
    }
    return values;
  }

  /**
   * Apply a PMML classifier to a set of test instances and return a textual
   * report of actual class values and predicted distributions, one line per
   * instance. Missing actual class values are reported as "?".
   * 
   * @param model the PMML model (must be a PMMLClassifier)
   * @param test the test instances (class index must be set)
   * @return the report
   * @throws Exception if the model is not a classifier or prediction fails
   */
  public static String applyClassifier(PMMLModel model, Instances test)
    throws Exception {
    if (!(model instanceof PMMLClassifier)) {
      throw new Exception("PMML model is not a classifier!");
    }

    PMMLClassifier classifier = (PMMLClassifier) model;
    StringBuilder buff = new StringBuilder();
    for (int i = 0; i < test.numInstances(); i++) {
      Instance temp = test.instance(i);
      buff.append("Actual: ");
      if (temp.classIsMissing()) {
        // casting the missing value (NaN) to an index would yield label 0
        buff.append("? ");
      } else if (temp.classAttribute().isNumeric()) {
        buff.append(temp.value(temp.classIndex())).append(' ');
      } else {
        buff.append(
          temp.classAttribute().value((int) temp.value(temp.classIndex())))
          .append(' ');
      }
      double[] preds = classifier.distributionForInstance(temp);
      buff.append(" Predicted: ");
      for (double pred : preds) {
        buff.append(pred).append(' ');
      }
      buff.append('\n');
    }
    return buff.toString();
  }

  /**
   * Thin AbstractClassifier used only to reach the standard command-line
   * evaluation via runClassifier; it cannot be trained or queried itself.
   */
  private static class PMMLClassifierRunner extends AbstractClassifier {

    /** ID added to avoid warning */
    private static final long serialVersionUID = -3742334356788083347L;

    @Override
    public double[] distributionForInstance(Instance test) throws Exception {
      throw new Exception("Don't call this method!!");
    }

    @Override
    public void buildClassifier(Instances instances) throws Exception {
      throw new Exception("Don't call this method!!");
    }

    @Override
    public String getRevision() {
      return weka.core.RevisionUtils.extract("$Revision: 10203 $");
    }

    /**
     * Run the standard classifier evaluation with the given options.
     * 
     * @param options the command-line options
     */
    public void evaluatePMMLClassifier(String[] options) {
      runClassifier(this, options);
    }
  }

  /**
   * Command-line entry point: checks that the PMML file given with -l is
   * supported, then runs the standard classifier evaluation with all args.
   * 
   * @param args the command-line options
   */
  public static void main(String[] args) {
    try {
      // Utils.getOption removes the option it finds, so work on a copy and
      // pass the untouched args on to the evaluation
      String[] optionsTmp = args.clone();
      String pmmlFile = Utils.getOption('l', optionsTmp);
      if (pmmlFile.length() == 0) {
        throw new Exception(
          "[PMMLFactory] must specify a PMML file using the -l option.");
      }
      // see if it is supported before going any further
      getPMMLModel(pmmlFile, null);

      PMMLClassifierRunner pcr = new PMMLClassifierRunner();
      pcr.evaluatePMMLClassifier(args);
    } catch (Exception ex) {
      ex.printStackTrace();
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy