All downloads are free. Search and download functionality uses the official Maven repository.

weka.core.pmml.PMMLFactory Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This is the stable version. Apart from bugfixes, this version does not receive any other updates.

There is a newer version: 3.8.6
Show newest version
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    PMMLFactory.java
 *    Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.core.pmml;

import java.io.File;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.BufferedOutputStream;
import java.io.ObjectOutputStream;
import java.io.BufferedReader;
import java.io.FileReader;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

import weka.classifiers.Classifier;
import weka.classifiers.pmml.consumer.*;
import weka.core.Instances;
import weka.core.Instance;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Utils;
import weka.gui.Logger;

/**
 * Factory class for reading PMML models and for serializing the resulting
 * PMMLModel objects.
 *
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision: 5562 $
 */
public class PMMLFactory {

  /**
   * The kinds of PMML model this factory distinguishes. Every constant
   * carries a short textual label which is what {@link #toString()} returns.
   */
  protected enum ModelType {
    UNKNOWN_MODEL("unknown"),
    REGRESSION_MODEL("Regression"),
    GENERAL_REGRESSION_MODEL("GeneralRegression"),
    NEURAL_NETWORK_MODEL("NeuralNetwork");

    /** the textual label of this model type */
    private final String m_label;

    /**
     * Create a model type constant.
     *
     * @param label the textual label reported by toString()
     */
    ModelType(String label) {
      this.m_label = label;
    }

    /**
     * Get the textual label of this model type.
     *
     * @return the label supplied at construction time
     */
    public String toString() {
      return this.m_label;
    }
  }

  /**
   * Load a PMML model from a named file without attaching a logger.
   *
   * @param filename the name of the file to read from
   * @return a PMML model
   * @throws Exception if there is a problem while reading the file
   */
  public static PMMLModel getPMMLModel(String filename) throws Exception {
    // no logging requested by this overload
    final Logger noLog = null;
    return getPMMLModel(filename, noLog);
  }
  
  /**
   * Load a PMML model from a file without attaching a logger.
   *
   * @param file a File to read from
   * @return a PMML model
   * @throws Exception if there is a problem while reading the file
   */
  public static PMMLModel getPMMLModel(File file) throws Exception {
    // no logging requested by this overload
    final Logger noLog = null;
    return getPMMLModel(file, noLog);
  }
  
  /**
   * Load a PMML model from a stream without attaching a logger.
   *
   * @param stream the InputStream to read from
   * @return a PMML model
   * @throws Exception if there is a problem while reading from the stream
   */
  public static PMMLModel getPMMLModel(InputStream stream) throws Exception {
    // no logging requested by this overload
    final Logger noLog = null;
    return getPMMLModel(stream, noLog);
  }
  
  /**
   * Load a PMML model from a named file. The name is turned into a File
   * and handed to the File-based overload.
   *
   * @param filename the name of the file to read from
   * @param log the logging object to use (or null if none is to be used)
   * @return a PMML model
   * @throws Exception if there is a problem while reading the file
   */
  public static PMMLModel getPMMLModel(String filename, Logger log) throws Exception {
    File source = new File(filename);
    return getPMMLModel(source, log);
  }

  /**
   * Read and return a PMML model.
   * <p>
   * The file is opened through a buffered stream which is always closed
   * before this method returns, whether or not the model could be read.
   *
   * @param file a File to read from
   * @param log the logging object to use (or null if none is to be used)
   * @return a PMML model
   * @throws Exception if there is a problem while reading the file
   */
  public static PMMLModel getPMMLModel(File file, Logger log) throws Exception {
    InputStream in = new BufferedInputStream(new FileInputStream(file));
    try {
      return getPMMLModel(in, log);
    } finally {
      // Close unconditionally so the file handle is released even when
      // reading fails; closing an already closed stream is harmless.
      in.close();
    }
  }
  
  /**
   * Check whether a parsed document contains at least one element with the
   * tag name "PMML".
   *
   * @param doc the Document to inspect
   * @return true if a "PMML" element is present, false otherwise
   */
  private static boolean isPMML(Document doc) {
    int pmmlElementCount = doc.getElementsByTagName("PMML").getLength();
    return pmmlElementCount != 0;
  }

  /**
   * Read and return a PMML model.
   * <p>
   * The supplied stream is parsed as XML and is always closed by this
   * method, whether or not parsing succeeds. The data dictionary, the
   * optional transformation dictionary and the mining schema are extracted
   * from the document and used to construct the model.
   *
   * @param stream the InputStream to read from
   * @param log the logging object to use (or null if none is to be used)
   * @return a PMML model
   * @throws IllegalArgumentException if the source contains no PMML element
   * @throws Exception if there is a problem while reading from the stream
   * or if the model type is not supported
   */
  public static PMMLModel getPMMLModel(InputStream stream, Logger log) throws Exception {
    Document doc;
    try {
      // NOTE(review): the parser factory is used as-is, no XXE hardening is
      // applied here. If PMML may come from an untrusted source, consider
      // restricting DTD/external entity processing on the factory.
      DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
      DocumentBuilder db = dbf.newDocumentBuilder();
      doc = db.parse(stream);
    } finally {
      // release the stream even if the parser could not be created or the
      // document could not be parsed
      if (stream != null) {
        stream.close();
      }
    }
    doc.getDocumentElement().normalize();
    if (!isPMML(doc)) {
      throw new IllegalArgumentException("[PMMLFactory] Source is not a PMML file!!");
    }

    Instances dataDictionary = getDataDictionaryAsInstances(doc);
    TransformationDictionary transDict = getTransformationDictionary(doc, dataDictionary);

    ModelType modelType = getModelType(doc);
    if (modelType == ModelType.UNKNOWN_MODEL) {
      throw new Exception("Unsupported PMML model type");
    }
    Element model = getModelElement(doc, modelType);

    // Construct mining schema and meta data
    MiningSchema ms = new MiningSchema(model, dataDictionary, transDict);

    PMMLModel theModel = getModelInstance(doc, modelType, model, dataDictionary, ms);
    if (log != null) {
      theModel.setLog(log);
    }
    return theModel;
  }
  
  /**
   * Extract the transformation dictionary from a PMML document, if it
   * defines one. Only the first "TransformationDictionary" element found in
   * the document is considered.
   * 
   * @param doc the Document containing the PMML model
   * @param dataDictionary the data dictionary as an Instances object
   * @return the transformation dictionary or null if there is none defined in
   * the Document
   * @throws Exception if there is a problem getting the transformation
   * dictionary
   */
  protected static TransformationDictionary getTransformationDictionary(Document doc, 
      Instances dataDictionary) throws Exception {
    NodeList candidates = doc.getElementsByTagName("TransformationDictionary");
    // should be of size 0 or 1
    if (candidates.getLength() <= 0) {
      return null;
    }

    Node first = candidates.item(0);
    if (first.getNodeType() != Node.ELEMENT_NODE) {
      return null;
    }

    return new TransformationDictionary((Element)first, dataDictionary);
  }

  /**
   * Serialize a PMMLModel object to a named file. The name is turned into
   * a File and handed to the File-based overload.
   *
   * @param model the PMMLModel to serialize
   * @param filename the name of the file to save to
   * @throws Exception if something goes wrong during serialization
   */
  public static void serializePMMLModel(PMMLModel model, String filename) 
    throws Exception {
    File target = new File(filename);
    serializePMMLModel(model, target);
  }

  /**
   * Serialize a PMMLModel object that encapsulates a PMML model
   * <p>
   * The file is written through a buffered stream which is always closed
   * before this method returns, whether or not serialization succeeds.
   *
   * @param model the PMMLModel to serialize
   * @param file the File to save to
   * @throws Exception if something goes wrong during serialization
   */
  public static void serializePMMLModel(PMMLModel model, File file)
    throws Exception {
    OutputStream out = new BufferedOutputStream(new FileOutputStream(file));
    try {
      serializePMMLModel(model, out);
    } finally {
      // Close unconditionally so the file handle is released even when
      // serialization fails; closing an already closed stream is harmless.
      out.close();
    }
  }

  /**
   * Serialize a PMMLModel object that encapsulates a PMML model
   * <p>
   * Two objects are written, in this order: the mining schema fields as an
   * Instances header, then the model itself. The supplied stream is always
   * closed by this method, whether or not serialization succeeds.
   *
   * @param model the PMMLModel to serialize
   * @param stream the OutputStream to serialize to
   * @throws Exception if something goes wrong during serialization
   */
  public static void serializePMMLModel(PMMLModel model, OutputStream stream)
    throws Exception {
    ObjectOutputStream oo = null;
    try {
      oo = new ObjectOutputStream(stream);
      Instances header = model.getMiningSchema().getFieldsAsInstances();
      oo.writeObject(header);
      oo.writeObject(model);
      oo.flush();
    } finally {
      if (oo != null) {
        // closes the underlying stream as well
        oo.close();
      } else if (stream != null) {
        // the object stream could not be created; still release the target
        stream.close();
      }
    }
  }

  /**
   * Construct the PMMLModel implementation that matches the given model
   * type and stamp it with the PMML version and creator application taken
   * from the document.
   *
   * @param doc the Document holding the pmml
   * @param modelType the type of model
   * @param model the Element encapsulating the model part of the Document
   * @param dataDictionary the data dictionary as an Instances object
   * @param miningSchema the mining schema
   * @return a PMMLModel object
   * @throws Exception if there is a problem constructing the model or
   * if the model type is not supported
   */
  protected static PMMLModel getModelInstance(Document doc, 
                                              ModelType modelType, Element model,
                                              Instances dataDictionary,
                                              MiningSchema miningSchema) throws Exception {
    final PMMLModel built;
    switch (modelType) {
    case REGRESSION_MODEL:
      built = new Regression(model, dataDictionary, miningSchema);
      break;
    case GENERAL_REGRESSION_MODEL:
      built = new GeneralRegression(model, dataDictionary, miningSchema);
      break;
    case NEURAL_NETWORK_MODEL:
      built = new NeuralNetwork(model, dataDictionary, miningSchema);
      break;
    default:
      throw new Exception("[PMMLFactory] Unknown model type!!");
    }

    // record where the model came from
    built.setPMMLVersion(doc);
    built.setCreatorApplication(doc);
    return built;
  }

  /**
   * Determine the type of model held in a PMML document by looking for the
   * supported model elements. The element names are probed in a fixed order
   * (RegressionModel, GeneralRegressionModel, NeuralNetwork) and the first
   * one present decides the result.
   *
   * @param doc the Document encapsulating the pmml
   * @return the type of model, or UNKNOWN_MODEL if none of the supported
   * model elements is present
   */
  protected static ModelType getModelType(Document doc) {
    final String[] tagNames = {
      "RegressionModel", "GeneralRegressionModel", "NeuralNetwork"
    };
    final ModelType[] types = {
      ModelType.REGRESSION_MODEL,
      ModelType.GENERAL_REGRESSION_MODEL,
      ModelType.NEURAL_NETWORK_MODEL
    };

    for (int i = 0; i < tagNames.length; i++) {
      NodeList matches = doc.getElementsByTagName(tagNames[i]);
      if (matches.getLength() > 0) {
        return types[i];
      }
    }

    return ModelType.UNKNOWN_MODEL;
  }

  /**
   * Locate the Element that holds the model of the given type. Only the
   * first element with the matching tag name is considered.
   *
   * @param doc the Document encapsulating the pmml
   * @param modelType the type of model
   * @return the model Element, or null if the document contains no element
   * for the given model type
   * @throws Exception if the model type is unsupported/unknown
   */
  protected static Element getModelElement(Document doc, ModelType modelType) 
    throws Exception {
    final String tagName;
    switch (modelType) {
    case REGRESSION_MODEL:
      tagName = "RegressionModel";
      break;
    case GENERAL_REGRESSION_MODEL:
      tagName = "GeneralRegressionModel";
      break;
    case NEURAL_NETWORK_MODEL:
      tagName = "NeuralNetwork";
      break;
    default:
      throw new Exception("[PMMLFactory] unknown/unsupported model type.");
    }

    NodeList matches = doc.getElementsByTagName(tagName);
    if (matches == null || matches.getLength() <= 0) {
      return null;
    }

    Node first = matches.item(0);
    return (first.getNodeType() == Node.ELEMENT_NODE) ? (Element)first : null;
  }

  /**
   * Build an Instances header from the MiningField elements of a model.
   * Fields whose usageType is absent, "active" or "predicted" become
   * attributes (looked up by name in the data dictionary); the last
   * "predicted" field encountered is made the class attribute.
   *
   * @param model the Element containing the pmml model
   * @param dataDictionary the data dictionary as an Instances object
   * @return the mining schema as an Instances object
   * @throws Exception if a mining field cannot be found in the data
   * dictionary
   * @deprecated Use the MiningSchema class instead
   */
  protected static Instances getMiningSchemaAsInstances(Element model,
                                                        Instances dataDictionary) 
    throws Exception {
    FastVector schemaAtts = new FastVector();
    NodeList miningFields = model.getElementsByTagName("MiningField");
    int classIndex = -1;

    for (int i = 0; i < miningFields.getLength(); i++) {
      Node fieldNode = miningFields.item(i);
      if (fieldNode.getNodeType() != Node.ELEMENT_NODE) {
        continue;
      }

      Element fieldEl = (Element)fieldNode;
      String name = fieldEl.getAttribute("name");
      String usage = fieldEl.getAttribute("usageType");
      // TO-DO: also missing value replacement etc.

      // find this attribute in the dataDictionary
      Attribute dictAtt = dataDictionary.attribute(name);
      if (dictAtt == null) {
        throw new Exception("Can't find mining field: " + name 
                            + " in the data dictionary.");
      }

      boolean predicted = usage.equals("predicted");
      if (predicted || usage.length() == 0 || usage.equals("active")) {
        schemaAtts.addElement(dictAtt);
      }
      if (predicted) {
        // the attribute just added is the class
        classIndex = schemaAtts.size() - 1;
      }
    }

    Instances insts = new Instances("miningSchema", schemaAtts, 0);
    if (classIndex != -1) {
      insts.setClassIndex(classIndex);
    }

    return insts;
  }

  /**
   * Get the data dictionary as an Instances object
   * <p>
   * Every DataField element in the document becomes one attribute:
   * "continuous" fields become numeric attributes, "categorical" and
   * "ordinal" fields become attributes built from their Value children
   * (only values whose "property" is absent or "valid" are used).
   *
   * @param doc the Document encapsulating the pmml
   * @return the data dictionary as an Instances object
   * @throws Exception if there are fields that are not continuous, 
   * ordinal or categorical in the data dictionary
   */
  protected static Instances getDataDictionaryAsInstances(Document doc) 
    throws Exception {

    // TO-DO: definition of missing values (see below)

    FastVector attInfo = new FastVector();
    NodeList dataDictionary = doc.getElementsByTagName("DataField");
    for (int i = 0; i < dataDictionary.getLength(); i++) {
      Node dataField = dataDictionary.item(i);
      if (dataField.getNodeType() != Node.ELEMENT_NODE) {
        continue;
      }

      Element dataFieldEl = (Element)dataField;
      String name = dataFieldEl.getAttribute("name");
      String type = dataFieldEl.getAttribute("optype");
      if (name == null || type == null) {
        continue;
      }

      Attribute tempAtt;
      if (type.equals("continuous")) {
        tempAtt = new Attribute(name);
      } else if (type.equals("categorical") || type.equals("ordinal")) {
        NodeList valueList = dataFieldEl.getElementsByTagName("Value");
        if (valueList == null || valueList.getLength() == 0) {
          // assume that categorical values will be revealed in the actual model.
          // Create a string attribute for now
          tempAtt = new Attribute(name, (FastVector)null);
        } else {
          // add the values (if defined as "valid")
          FastVector valueVector = new FastVector();
          for (int j = 0; j < valueList.getLength(); j++) {
            Node val = valueList.item(j);
            if (val.getNodeType() != Node.ELEMENT_NODE) {
              continue;
            }
            // property is optional (default value is "valid")
            String property = ((Element)val).getAttribute("property");
            if (property == null || property.length() == 0 || property.equals("valid")) {
              valueVector.addElement(((Element)val).getAttribute("value"));
            }
            // Invalid or missing value definitions are just ignored for now...
            // TO-DO: implement Value meta data with missing/invalid value defs.
          }
          tempAtt = new Attribute(name, valueVector);
        }
      } else {
        // note the space before "attributes": the optype is part of the message
        throw new Exception("[PMMLFactory] can't handle " + type + " attributes.");
      }
      attInfo.addElement(tempAtt);
    }

    // TO-DO: check whether certain values are declared to represent
    // missing or invalid values (applies to both categorical and continuous
    // attributes

    // create the Instances structure
    Instances insts = new Instances("dataDictionary", attInfo, 0);

    return insts;
  }

  /**
   * Apply a PMML classifier to every instance of a test set and return a
   * textual report. The report has one line per instance of the form
   * "Actual: &lt;class value&gt;  Predicted: &lt;distribution values&gt;".
   * <p>
   * A missing nominal class value is reported as "?" instead of being
   * mapped to the first label of the class attribute.
   *
   * @param model the PMML model to apply; must be a PMMLClassifier
   * @param test the instances to predict; the class index must be set
   * @return the report, one line per test instance
   * @throws Exception if the model is not a classifier or if a prediction
   * cannot be made
   */
  public static String applyClassifier(PMMLModel model, Instances test) throws Exception {
    // reject unsupported models before doing any work
    if (!(model instanceof PMMLClassifier)) {
      throw new Exception("PMML model is not a classifier!");
    }
    PMMLClassifier classifier = (PMMLClassifier)model;

    StringBuilder buff = new StringBuilder();
    for (int i = 0; i < test.numInstances(); i++) {
      Instance temp = test.instance(i);

      buff.append("Actual: ");
      if (temp.classAttribute().isNumeric()) {
        buff.append(temp.value(temp.classIndex())).append(" ");
      } else if (temp.classIsMissing()) {
        // a missing value cannot be used as an index into the class labels
        buff.append("? ");
      } else {
        int labelIndex = (int)temp.value(temp.classIndex());
        buff.append(temp.classAttribute().value(labelIndex)).append(" ");
      }

      double[] preds = classifier.distributionForInstance(temp);
      buff.append(" Predicted: ");
      for (int j = 0; j < preds.length; j++) {
        buff.append(preds[j]).append(" ");
      }
      buff.append("\n");
    }
    return buff.toString();
  }
  
  /**
   * Minimal Classifier subclass used by main() to reach runClassifier().
   * It is not a usable classifier in its own right: both the training and
   * the prediction method always throw.
   */
  private static class PMMLClassifierRunner extends Classifier {

    /**
     * Not supported.
     *
     * @param data ignored
     * @throws Exception always
     */
    public void buildClassifier(Instances data) throws Exception {
      throw new Exception("Don't call this method!!");
    }

    /**
     * Not supported.
     *
     * @param inst ignored
     * @return never returns normally
     * @throws Exception always
     */
    public double[] distributionForInstance(Instance inst) throws Exception {
      throw new Exception("Don't call this method!!");
    }

    /**
     * Get the revision string.
     *
     * @return the revision extracted from the embedded revision keyword
     */
    public String getRevision() {
      return weka.core.RevisionUtils.extract("$Revision: 5562 $");
    }

    /**
     * Pass this object and the given options on to runClassifier().
     *
     * @param options the command line options
     */
    public void evaluatePMMLClassifier(String[] options) {
      runClassifier(this, options);
    }
  }

  /**
   * Command line entry point. A PMML file must be given with the -l option.
   * The file is loaded first to make sure that it holds a supported model;
   * afterwards the unmodified arguments are handed to a
   * PMMLClassifierRunner. Any exception is reported by printing its stack
   * trace.
   *
   * @param args the command line arguments
   */
  public static void main(String[] args) {
    try {
      // Look up -l on a copy so that args itself reaches the runner
      // untouched (presumably Utils.getOption consumes the option from the
      // array it is given -- verify against Utils).
      String[] optionsCopy = args.clone();
      String pmmlFile = Utils.getOption('l', optionsCopy); 
      if (pmmlFile.length() == 0) {
        throw new Exception("[PMMLFactory] must specify a PMML file using the -l option.");
      }

      // see if it is supported before going any further; the loaded model
      // itself is not needed here
      getPMMLModel(pmmlFile, null);

      new PMMLClassifierRunner().evaluatePMMLClassifier(args);
    } catch (Exception ex) {
      ex.printStackTrace();
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy