All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.pmml.consumer.SupportVectorMachineModel Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This version represents the developer version, the "bleeding edge" of development, you could say. New functionality gets added to this version.

There is a newer version: 3.9.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    SupportVectorMachineModel.java
 *    Copyright (C) 2010-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.pmml.consumer;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

import weka.classifiers.pmml.consumer.NeuralNetwork.MiningFunction;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.pmml.MiningSchema;
import weka.core.pmml.TargetMetaInfo;
import weka.core.pmml.VectorDictionary;
import weka.core.pmml.VectorInstance;
import weka.gui.Logger;

/**
 * Implements a PMML SupportVectorMachineModel
 * 
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision: 8034 $
 */
public class SupportVectorMachineModel extends PMMLClassifier 
  implements Serializable {
  
  /** For serialization */
  private static final long serialVersionUID = 6225095165118374296L;

  /**
   * Common superclass of the kernel functions that can be read from a
   * PMML support vector machine model.
   */
  static abstract class Kernel implements Serializable {

    /** For serialization */
    private static final long serialVersionUID = -6696443459968934767L;

    /** The log object to use */
    protected Logger m_log = null;

    /**
     * Stores the log object so that subclasses can report warnings through it.
     *
     * @param log the log object to use
     */
    protected Kernel(Logger log) {
      m_log = log;
    }

    /**
     * Evaluates the kernel on two vector instances.
     *
     * @param x the first vector instance
     * @param y the second vector instance
     * @return the value of the kernel function for x and y
     * @throws Exception if something goes wrong
     */
    public abstract double evaluate(VectorInstance x, VectorInstance y)
      throws Exception;

    /**
     * Evaluates the kernel on a vector instance and a plain array of values.
     *
     * @param x the first vector instance
     * @param y the second vector (as an array of values)
     * @return the value of the kernel function for x and y
     * @throws Exception if something goes wrong
     */
    public abstract double evaluate(VectorInstance x, double[] y)
      throws Exception;

    /**
     * Factory method that inspects the supplied XML element and builds the
     * first kernel type it finds. The element names are tried in the order
     * LinearKernelType, PolynomialKernelType, RadialBasisKernelType and
     * SigmoidKernelType.
     *
     * @param svmMachineModelElement the XML element containing the kernel
     * @param log the logging object to use
     * @return a new Kernel
     * @throws Exception if none of the known kernel elements is present
     */
    public static Kernel getKernel(Element svmMachineModelElement,
        Logger log) throws Exception {

      Element kernelElement =
        firstDescendant(svmMachineModelElement, "LinearKernelType");
      if (kernelElement != null) {
        return new LinearKernel(log);
      }

      kernelElement =
        firstDescendant(svmMachineModelElement, "PolynomialKernelType");
      if (kernelElement != null) {
        return new PolynomialKernel(kernelElement, log);
      }

      kernelElement =
        firstDescendant(svmMachineModelElement, "RadialBasisKernelType");
      if (kernelElement != null) {
        return new RadialBasisKernel(kernelElement, log);
      }

      kernelElement =
        firstDescendant(svmMachineModelElement, "SigmoidKernelType");
      if (kernelElement != null) {
        return new SigmoidKernel(kernelElement, log);
      }

      throw new Exception("[Kernel] Can't find a kernel that I recognize!");
    }

    /**
     * Returns the first descendant element with the given tag name.
     *
     * @param parent the element to search below
     * @param tagName the tag name to look for
     * @return the first matching element, or null if there is none
     */
    private static Element firstDescendant(Element parent, String tagName) {
      NodeList matches = parent.getElementsByTagName(tagName);
      if (matches.getLength() == 0) {
        return null;
      }
      return (Element) matches.item(0);
    }
  }
  
  /**
   * Subclass of Kernel implementing a simple linear (dot product) kernel
   */
  static class LinearKernel extends Kernel implements Serializable {
    
    public LinearKernel(Logger log) {
      super(log);
    }
    
    public LinearKernel() {
      super(null);
    }
    
    /** For serialization */
    private static final long serialVersionUID = 8991716708484953837L;

    /**
     * Compute the result of the kernel evaluation on the supplied vectors
     * 
     * @param x the first vector instance
     * @param y the second vector instance
     * @return the result of the kernel evaluation
     * @throws Exception if something goes wrong
     */
    public double evaluate(VectorInstance x, VectorInstance y) 
      throws Exception {
      return x.dotProduct(y);
    }
    
    /**
     * Compute the result of the kernel evaluation on the supplied vectors
     * 
     * @param x the first vector instance
     * @param y the second vector (as an array of values)
     * @return the result of the kernel evaluation
     * @throws Exception if something goes wrong
     */
    public double evaluate(VectorInstance x, double[] y)
      throws Exception {
      return x.dotProduct(y);
    }
    
    /**
     * Return a textual description of this kernel
     * 
     * @return a string describing this kernel
     */
    public String toString() {
      return "Linear kernel: K(x,y) = ";
    }
  }
  
  /**
   * Subclass of Kernel implementing a polynomial kernel
   */
  static class PolynomialKernel extends Kernel implements Serializable {
    
    /** For serialization */
    private static final long serialVersionUID = -616176630397865281L;
    
    protected double m_gamma = 1;
    protected double m_coef0 = 1;
    protected double m_degree = 1;
    
    public PolynomialKernel(Element polyNode) {
      this(polyNode, null);
    }
    
    public PolynomialKernel(Element polyNode, Logger log) {
      super(log);
      
      String gammaString = polyNode.getAttribute("gamma");
      if (gammaString != null && gammaString.length() > 0) {
        try {
          m_gamma = Double.parseDouble(gammaString);
        } catch (NumberFormatException e) {
          String message = "[PolynomialKernel] : WARNING, can't parse "
            + "gamma attribute. Using default value of 1.";
          if (m_log == null) {
            System.err.println(message);
          } else {
            m_log.logMessage(message);
          }
        }
      }
      
      String coefString = polyNode.getAttribute("coef0");
      if (coefString != null && coefString.length() > 0) {
        try {
          m_coef0 = Double.parseDouble(coefString);
        } catch (NumberFormatException e) {
          String message = "[PolynomialKernel] : WARNING, can't parse "
            + "coef0 attribute. Using default value of 1.";
          if (m_log == null) {
            System.err.println(message);
          } else {
            m_log.logMessage(message);
          }
        }
      }
      
      String degreeString = polyNode.getAttribute("degree");
      if (degreeString != null && degreeString.length() > 0) {
        try {
          m_degree = Double.parseDouble(degreeString);
        } catch (NumberFormatException e) {
          String message = "[PolynomialKernel] : WARNING, can't parse "
            + "degree attribute. Using default value of 1.";
          if (m_log == null) {
            System.err.println(message);
          } else {
            m_log.logMessage(message);
          }
        }
      }
    }
    
    /**
     * Compute the result of the kernel evaluation on the supplied vectors
     * 
     * @param x the first vector instance
     * @param y the second vector instance
     * @return the result of the kernel evaluation
     * @throws Exception if something goes wrong
     */
    public double evaluate(VectorInstance x, VectorInstance y)
      throws Exception {
      double dotProd = x.dotProduct(y);
      return Math.pow(m_gamma * dotProd + m_coef0, m_degree);      
    }
    
    /**
     * Compute the result of the kernel evaluation on the supplied vectors
     * 
     * @param x the first vector instance
     * @param y the second vector (as an array of values)
     * @return the result of the kernel evaluation
     * @throws Exception if something goes wrong
     */
    public double evaluate(VectorInstance x, double[] y)
      throws Exception {
      double dotProd = x.dotProduct(y);
      return Math.pow(m_gamma * dotProd + m_coef0, m_degree);      
    }
    
    /**
     * Return a textual description of this kernel
     * 
     * @return a string describing this kernel
     */
    public String toString() {
      return "Polynomial kernel: K(x,y) = (" + m_gamma + " *  + " 
        + m_coef0 +")^" + m_degree;
    }
  }
  
  /**
   * Kernel subclass for the radial basis function:
   * K(x,y) = exp(-gamma * ||x - y||^2).
   */
  static class RadialBasisKernel extends Kernel implements Serializable {

    /** For serialization */
    private static final long serialVersionUID = -3834238621822239042L;

    /** Multiplier applied to the squared distance (default 1) */
    protected double m_gamma = 1;

    /**
     * Builds the kernel from the supplied XML element, without a log object.
     *
     * @param radialElement the XML element holding the kernel parameters
     */
    public RadialBasisKernel(Element radialElement) {
      this(radialElement, null);
    }

    /**
     * Builds the kernel from the supplied XML element. Only the "gamma"
     * attribute is read; if it is absent or can't be parsed the default
     * value of 1 is kept.
     *
     * @param radialElement the XML element holding the kernel parameters
     * @param log the log object to use (warnings go to System.err if null)
     */
    public RadialBasisKernel(Element radialElement, Logger log) {
      super(log);

      String gammaText = radialElement.getAttribute("gamma");
      if (gammaText == null || gammaText.length() == 0) {
        return;
      }

      try {
        m_gamma = Double.parseDouble(gammaText);
      } catch (NumberFormatException ex) {
        String warning = "[RadialBasisKernel] : WARNING, can't parse "
          + "gamma attribute. Using default value of 1.";
        if (m_log != null) {
          m_log.logMessage(warning);
        } else {
          System.err.println(warning);
        }
      }
    }

    /**
     * Turns a difference vector into the kernel value.
     *
     * @param difference the vector x - y
     * @return exp(-gamma * (difference . difference))
     * @throws Exception if the dot product can't be computed
     */
    private double fromDifference(VectorInstance difference) throws Exception {
      double exponent = -m_gamma * difference.dotProduct(difference);
      return Math.exp(exponent);
    }

    /**
     * Evaluates the kernel on two vector instances.
     *
     * @param x the first vector instance
     * @param y the second vector instance
     * @return the result of the kernel evaluation
     * @throws Exception if something goes wrong
     */
    public double evaluate(VectorInstance x, VectorInstance y)
      throws Exception {
      return fromDifference(x.subtract(y));
    }

    /**
     * Evaluates the kernel on a vector instance and an array of values.
     *
     * @param x the first vector instance
     * @param y the second vector (as an array of values)
     * @return the result of the kernel evaluation
     * @throws Exception if something goes wrong
     */
    public double evaluate(VectorInstance x, double[] y)
      throws Exception {
      return fromDifference(x.subtract(y));
    }

    /**
     * Describes this kernel as text.
     *
     * @return a string describing this kernel
     */
    public String toString() {
      StringBuilder description = new StringBuilder();
      description.append("Radial kernel: K(x,y) = exp(-");
      description.append(m_gamma);
      description.append(" * ||x - y||^2)");
      return description.toString();
    }
  }
  
  /**
   * Subclass of Kernel implementing a sigmoid function
   */
  static class SigmoidKernel extends Kernel implements Serializable {
    
    /** For serialization */
    private static final long serialVersionUID = 8713475894705750117L;

    protected double m_gamma = 1;
    
    protected double m_coef0 = 1;
    
    public SigmoidKernel(Element sigElement) {
      this(sigElement, null);
    }
    
    public SigmoidKernel(Element sigElement, Logger log) {
      super(log);
      
      String gammaString = sigElement.getAttribute("gamma");
      if (gammaString != null && gammaString.length() > 0) {
        try {
          m_gamma = Double.parseDouble(gammaString);
        } catch (NumberFormatException e) {
          String message = "[SigmoidKernel] : WARNING, can't parse "
            + "gamma attribute. Using default value of 1.";
          if (m_log == null) {
            System.err.println(message);
          } else {
            m_log.logMessage(message);
          }
        }
      }
      
      String coefString = sigElement.getAttribute("coef0");
      if (coefString != null && coefString.length() > 0) {
        try {
          m_coef0 = Double.parseDouble(coefString);
        } catch (NumberFormatException e) {
          String message = "[SigmoidKernel] : WARNING, can't parse "
            + "coef0 attribute. Using default value of 1.";
          if (m_log == null) {
            System.err.println(message);
          } else {
            m_log.logMessage(message);
          }
        }
      }
    }
    
    /**
     * Compute the result of the kernel evaluation on the supplied vectors
     * 
     * @param x the first vector instance
     * @param y the second vector instance
     * @return the result of the kernel evaluation
     * @throws Exception if something goes wrong
     */
    public double evaluate(VectorInstance x, VectorInstance y)
      throws Exception {
      
      double dotProd = x.dotProduct(y);
      double z = m_gamma * dotProd + m_coef0;
      double a = Math.exp(z);
      double b = Math.exp(-z);
      return ((a - b) / (a + b));      
    }

    /**
     * Compute the result of the kernel evaluation on the supplied vectors
     * 
     * @param x the first vector instance
     * @param y the second vector (as an array of values)
     * @return the result of the kernel evaluation
     * @throws Exception if something goes wrong
     */
    public double evaluate(VectorInstance x, double[] y)
      throws Exception {
      
      double dotProd = x.dotProduct(y);
      double z = m_gamma * dotProd + m_coef0;
      double a = Math.exp(z);
      double b = Math.exp(-z);
      return ((a - b) / (a + b));
    }
    
    /**
     * Return a textual description of this kernel
     * 
     * @return a string describing this kernel
     */
    public String toString() {
      return "Sigmoid kernel: K(x,y) = tanh(" + m_gamma + " *  + " 
        + m_coef0 +")";
    }
  }
  
  /**
   * Inner class implementing a single binary (classification) SVM
   */
  static class SupportVectorMachine implements Serializable {
    
    /** For serialization */
    private static final long serialVersionUID = -7650496802836815608L;

    /** The target label for classification problems */
    protected String m_targetCategory;
    
    /** The index of the global alternate target label for classification */
    protected int m_globalAlternateTargetCategoryIndex = -1;
    
    /** The index of the target label for classification problems */
    protected int m_targetCategoryIndex = -1;
        
    /** PMML 4.0 - index of the alternate category for one-vs-one */
    protected int m_localAlternateTargetCategoryIndex = -1;
    
    /** PMML 4.0 - local threshold (overrides the global one, if set) */
    protected double m_localThreshold = Double.MAX_VALUE; // default - not set
        
    /** The mining schema */
    protected MiningSchema m_miningSchema;
    
    /** The log object to use (if supplied) */
    protected Logger m_log;
    
    /** 
     * True if this is a linear machine expressed in terms of the original
     * attributes
     */
    protected boolean m_coeffsOnly = false;
    
    /** The support vectors used by this machine (if not linear with coeffs only) */
    protected List<VectorInstance> m_supportVectors = 
      new ArrayList<VectorInstance>();
    
    /** The constant term - b */
    protected double m_intercept = 0;
    
    /** The coefficients for the vectors */
    protected double[] m_coefficients;
    
    /**
     * Computes the prediction from this support vector machine. For classification,
     * it fills in the appropriate element in the preds array with either a 0 or a 1.
     * If the output of the machine is &lt; 0, then a 1 is entered into the array
     * element corresponding to this machine's targetCategory, otherwise a 0 is entered.
     * Note that this is different to the scoring procedure in the 3.2 spec (see the comments
     * in the source code of this method for more information about this).
     * <p>
     * When a PMML 4.0 classification method is in effect, the raw output is
     * stored instead (one-against-all), or a vote is added to either the target
     * or the alternate category (one-against-one).
     * 
     * @param input the test instance to predict
     * @param kernel the kernel to use
     * @param vecDict the vector dictionary (null if the machine is linear and expressed
     * in terms of attribute weights)
     * @param preds the prediction array to fill in
     * @param cMethod the classification method to use (for classification problems)
     * @param globalThreshold the global threshold (used if there is no local threshold
     * for this machine)
     * @throws Exception if something goes wrong, including a one-against-one
     * vote for which no alternate target category is available
     */
    public void distributionForInstance(double[] input, Kernel kernel, 
        VectorDictionary vecDict, double[] preds, classificationMethod cMethod,
        double globalThreshold) 
      throws Exception {

      if (!m_coeffsOnly) {
        // get an array that only holds the values that are referenced
        // by the support vectors
        input = vecDict.incomingInstanceToVectorFieldVals(input);
      }

      // look the class attribute up once rather than for every test below
      Attribute classAtt = m_miningSchema.getFieldsAsInstances().classAttribute();
      boolean nominalClass = classAtt.isNominal();
      boolean numericClass = classAtt.isNumeric();

      // numeric targets always use slot 0 of the prediction array
      int targetIndex = nominalClass ? m_targetCategoryIndex : 0;

      // weighted sum of kernel evaluations (or of the raw inputs for a
      // coefficient-only linear machine) plus the constant term
      double result = 0;
      for (int i = 0; i < m_coefficients.length; i++) {
        double val;
        if (!m_coeffsOnly) {
          val = kernel.evaluate(m_supportVectors.get(i), input);
        } else {
          val = input[i];
        }
        val *= m_coefficients[i];
        result += val;
      }
      result += m_intercept;

      // TODO revisit this when I actually find out what is going on with
      // the Zementis model (the only easily available SVM model out there
      // at this time).

      if (cMethod == classificationMethod.NONE || numericClass) {
        // ----------------------------------------------------------------------
        // The PMML 3.2 spec states that the output of the machine should lie
        // between 0 and 1 for a binary classification case with 1 corresponding
        // to the machine's targetCategory and 0 corresponding to the
        // alternateBinaryTargetCategory (implying a threshold of 0.5). This seems kind of
        // non standard, and indeed, in the 4.0 spec, it has changed to output < threshold
        // corresponding to the machine's targetCategory and >= threshold corresponding
        // to the alternateBinaryTargetCategory. What has been implemented here is
        // the latter with a default threshold of 0 (since the 3.2 spec doesn't have
        // a way to specify a threshold). The example SVM PMML model from Zementis,
        // which is PMML version 3.2, produces output between -1 and 1 (give or take).
        // Implementing the 3.2 scoring mechanism as described in the spec (and truncating 
        // output at 0 and 1) results in all the predicted labels getting flipped on the data
        // used to construct the Zementis model!
        //
        // April 2010 - the Zementis guys have emailed me to say that their model
        // has been prepared for PMML 4.0, even though it states it is 3.2 in the
        // XML.
        // ----------------------------------------------------------------------

        if (nominalClass) {
          preds[targetIndex] = (result < 0) ? 1 : 0;
        } else {
          preds[targetIndex] = result;
        }
        return;
      }

      // PMML 4.0
      if (cMethod == classificationMethod.ONE_AGAINST_ALL) {
        // smallest value output by a machine is the predicted class
        preds[targetIndex] = result;
        return;
      }

      // one-vs-one: a local threshold (if set) overrides the global one
      double threshold = (m_localThreshold < Double.MAX_VALUE)
        ? m_localThreshold
        : globalThreshold;

      // vote
      if (result < threshold) {
        preds[targetIndex]++;
      } else {
        int altCat = (m_localAlternateTargetCategoryIndex != -1)
          ? m_localAlternateTargetCategoryIndex
          : m_globalAlternateTargetCategoryIndex;

        if (altCat < 0) {
          // would otherwise surface as an ArrayIndexOutOfBoundsException
          throw new Exception("[SupportVectorMachine] : no alternate target "
              + "category is available for one-against-one voting!");
        }

        preds[altCat]++;
      }
    }
        
    /**
     * Constructs a new SupportVectorMachine from the supplied XML element
     * 
     * @param machineElement the XML element containing the SVM
     * @param miningSchema the mining schema for the PMML model
     * @param dictionary the VectorDictionary from which to look up the support
     * vectors used by this machine (may be null if the machine is linear and
     * expressed in terms of attribute weights)
     * @param svmRep the representation of this SVM (uses support vectors or is linear
     * and uses attribute weights)
     * @param altCategoryInd the index of the global alternateBinaryTarget (if classification)
     * @param log the log object to use
     * @throws Exception if a target category or support vector can't be
     * resolved, or if the coefficients are missing or inconsistent with the
     * support vectors
     */
    public SupportVectorMachine(Element machineElement, 
        MiningSchema miningSchema, VectorDictionary dictionary, SVM_representation svmRep, 
        int altCategoryInd, Logger log) throws Exception {

      m_miningSchema = miningSchema;
      m_log = log;

      // --- target / alternate target categories ---
      Attribute classAtt = m_miningSchema.getFieldsAsInstances().classAttribute();
      String targetCat = machineElement.getAttribute("targetCategory");
      if (targetCat != null && targetCat.length() > 0) {
        if (!classAtt.isNominal()) {
          throw new Exception("[SupportVectorMachine] : target category supplied " +
              "but class attribute is numeric!");
        }

        m_targetCategory = targetCat;
        int index = classAtt.indexOfValue(m_targetCategory);
        if (index < 0) {
          throw new Exception("[SupportVectorMachine] : can't find target category: "
              + m_targetCategory + " in the class attribute!");
        }
        m_targetCategoryIndex = index;

        // now check for the PMML 4.0 alternateTargetCategory
        String altTargetCat = machineElement.getAttribute("alternateTargetCategory");
        if (altTargetCat != null && altTargetCat.length() > 0) {
          index = classAtt.indexOfValue(altTargetCat);
          if (index < 0) {
            throw new Exception("[SupportVectorMachine] : can't find alternate target category: "
                + altTargetCat + " in the class attribute!");
          }
          m_localAlternateTargetCategoryIndex = index;
        } else {
          // set the global one
          m_globalAlternateTargetCategoryIndex = altCategoryInd;
        }
      } else if (classAtt.isNominal()) {
        // no explicit target: use whichever of the first two class values
        // is not the global alternate
        m_targetCategoryIndex = (altCategoryInd == 0)
          ? 1
          : 0;
        m_globalAlternateTargetCategoryIndex = altCategoryInd;

        // report through the supplied log (if any) instead of always
        // writing to System.err
        String message = "Setting target index for machine to " + m_targetCategoryIndex;
        if (m_log == null) {
          System.err.println(message);
        } else {
          m_log.logMessage(message);
        }
      }

      // --- support vectors ---
      if (svmRep == SVM_representation.SUPPORT_VECTORS) {
        NodeList vectorsL = 
          machineElement.getElementsByTagName("SupportVectors");
        if (vectorsL.getLength() > 0) {
          Element vectors = (Element)vectorsL.item(0);
          NodeList allTheVectorsL = 
            vectors.getElementsByTagName("SupportVector");

          if (allTheVectorsL.getLength() > 0 && dictionary == null) {
            // would otherwise fail with a NullPointerException below
            throw new Exception("[SupportVectorMachine] : support vectors are "
                + "referenced but no vector dictionary was supplied!");
          }

          for (int i = 0; i < allTheVectorsL.getLength(); i++) {
            Element vec = (Element)allTheVectorsL.item(i);
            String vecId = vec.getAttribute("vectorId");
            VectorInstance suppV = dictionary.getVector(vecId);
            if (suppV == null) {
              throw new Exception("[SupportVectorMachine] : can't find " +
                  "vector with ID: " + vecId + " in the " +
                  "vector dictionary!");
            }
            m_supportVectors.add(suppV);
          }
        }
      } else {
        m_coeffsOnly = true;
      }

      // --- coefficients ---
      NodeList coefficientsL = machineElement.getElementsByTagName("Coefficients");
      // should be just one list of coefficients
      if (coefficientsL.getLength() != 1) {
        throw new Exception("[SupportVectorMachine] Should be just one list of " +
            "coefficients per binary SVM!");
      }
      Element cL = (Element)coefficientsL.item(0);

      // the constant term is stored in the "absoluteValue" attribute
      String intercept = cL.getAttribute("absoluteValue");
      if (intercept != null && intercept.length() > 0) {
        m_intercept = Double.parseDouble(intercept);
      }

      // now get the individual coefficient elements
      NodeList coeffL = cL.getElementsByTagName("Coefficient");
      if (coeffL.getLength() == 0) {
        throw new Exception("[SupportVectorMachine] No coefficients defined!");
      }

      if (!m_coeffsOnly && coeffL.getLength() > m_supportVectors.size()) {
        // prediction pairs coefficient i with support vector i, so extra
        // coefficients would cause an index error at prediction time
        throw new Exception("[SupportVectorMachine] " + coeffL.getLength()
            + " coefficients defined, but only " + m_supportVectors.size()
            + " support vectors!");
      }

      m_coefficients = new double[coeffL.getLength()];
      for (int i = 0; i < m_coefficients.length; i++) {
        Element coeff = (Element) coeffL.item(i);
        m_coefficients[i] = Double.parseDouble(coeff.getAttribute("value"));
      }
    }
    
    /**
     * Get a textual description of this binary SVM: the target (and local
     * alternate) category for nominal classes, one line per weighted term,
     * and finally the intercept.
     * 
     * @return a description of this SVM as a string.
     */
    public String toString() {
      StringBuilder temp = new StringBuilder();
      Attribute classAtt = m_miningSchema.getFieldsAsInstances().classAttribute();

      temp.append("Binary SVM");
      if (classAtt.isNominal()) {
        temp.append(" (target category = " + m_targetCategory + ")");
        if (m_localAlternateTargetCategoryIndex != -1) {
          temp.append("\n (alternate category = " 
              + classAtt.value(m_localAlternateTargetCategoryIndex) + ")");
        }
      }
      temp.append("\n\n");

      if (m_coeffsOnly) {
        // linear machine on the raw inputs: coefficient i weights input i
        for (int i = 0; i < m_coefficients.length; i++) {
          temp.append("\n" + m_coefficients[i] + " * X[" + i + "]");
        }
      } else {
        // never index past the end of either the coefficients or the vectors
        int numTerms = Math.min(m_supportVectors.size(), m_coefficients.length);
        for (int i = 0; i < numTerms; i++) {
          temp.append("\n" + m_coefficients[i] + " * [" 
              + m_supportVectors.get(i).getValues() + " * X]");
        }
      }

      if (m_intercept >= 0) {
        temp.append("\n +" + m_intercept);
      } else {
        temp.append("\n " + m_intercept);
      }
      return temp.toString();
    }
  }
  
  /**
   * The ways in which the binary SVMs of a model can be represented.
   */
  enum SVM_representation {
    /** The machines are expressed in terms of support vectors */
    SUPPORT_VECTORS,

    /** for the inputs if machine is linear and expressed in terms of the attributes */
    COEFFICIENTS
  }
  
  /**
   * The schemes for combining the outputs of the binary SVMs.
   */
  enum classificationMethod {
    /** PMML 3.x */
    NONE,

    /** PMML 4.0 default */
    ONE_AGAINST_ALL,

    /** One machine per pair of classes, combined by voting */
    ONE_AGAINST_ONE
  }
  
  /** The mining function **/
  protected MiningFunction m_functionType = MiningFunction.CLASSIFICATION;
  
  /** The classification method (PMML 4.0) */
  protected classificationMethod m_classificationMethod = 
    classificationMethod.NONE; // PMML 3.x (only handles binary problems)
  
  /** The model name (if defined) */
  protected String m_modelName;
  
  /** The algorithm name (if defined) */
  protected String m_algorithmName;
  
  /** The dictionary of support vectors */
  protected VectorDictionary m_vectorDictionary;
  
  /** The kernel function to use */
  protected Kernel m_kernel;
  
  /** The individual binary SVMs */
  protected List<SupportVectorMachine> m_machines = 
    new ArrayList<SupportVectorMachine>();
  
  /** The other class index (in the case of a single binary SVM - PMML 3.2). */
  protected int m_alternateBinaryTargetCategory = -1;
  
  /** Do we have support vectors, or just attribute coefficients for a linear machine? */
  protected SVM_representation m_svmRepresentation = SVM_representation.SUPPORT_VECTORS;
  
  /** PMML 4.0 threshold value */
  protected double m_threshold = 0;
  
  /**
   * Construct a new SupportVectorMachineModel encapsulating the information provided
   * in the PMML document. Reads the model-level attributes (function name,
   * model and algorithm names, SVM representation, alternate binary target
   * category, threshold and classification method), then the vector
   * dictionary (when support vectors are used), the kernel and finally each
   * binary SVM.
   * 
   * @param model the SVM element from the PMML document
   * @param dataDictionary the data dictionary
   * @param miningSchema the mining schema
   * @throws Exception if the model can't be constructed from the PMML
   */
  public SupportVectorMachineModel(Element model, Instances dataDictionary,
      MiningSchema miningSchema) throws Exception {
    
    super(dataDictionary, miningSchema);

    // TODO might have to throw an exception and only support PMML 3.2
    // (the former empty version check had no effect and has been removed)

    if ("regression".equals(model.getAttribute("functionName"))) {
      m_functionType = MiningFunction.REGRESSION;
    }

    String modelName = model.getAttribute("modelName");
    if (modelName != null && modelName.length() > 0) {
      m_modelName = modelName;
    }

    String algoName = model.getAttribute("algorithmName");
    if (algoName != null && algoName.length() > 0) {
      m_algorithmName = algoName;
    }

    // support vectors is the default representation
    if ("Coefficients".equals(model.getAttribute("svmRepresentation"))) {
      m_svmRepresentation = SVM_representation.COEFFICIENTS;
    }

    String altTargetCat = model.getAttribute("alternateBinaryTargetCategory");
    if (altTargetCat != null && altTargetCat.length() > 0) {
      int altTargetInd = 
        m_miningSchema.getFieldsAsInstances().classAttribute().indexOfValue(altTargetCat);

      if (altTargetInd < 0) {
        throw new Exception("[SupportVectorMachineModel] can't find alternate " +
            "target value " + altTargetCat);
      }
      m_alternateBinaryTargetCategory = altTargetInd;
    }

    // PMML 4.0
    String thresholdS = model.getAttribute("threshold");
    if (thresholdS != null && thresholdS.length() > 0) {
      m_threshold = Double.parseDouble(thresholdS);
    }

    // PMML 4.0
    if (getPMMLVersion().startsWith("4.")) {
      m_classificationMethod = classificationMethod.ONE_AGAINST_ALL; // default for PMML 4.0
    }

    // an explicit OneAgainstOne setting overrides the version-based default
    if ("OneAgainstOne".equals(model.getAttribute("classificationMethod"))) {
      m_classificationMethod = classificationMethod.ONE_AGAINST_ONE;
    }

    if (m_svmRepresentation == SVM_representation.SUPPORT_VECTORS) {
      m_vectorDictionary = VectorDictionary.getVectorDictionary(model, miningSchema);
    }

    m_kernel = Kernel.getKernel(model, m_log);
    if (m_svmRepresentation == SVM_representation.COEFFICIENTS && 
        !(m_kernel instanceof LinearKernel)) {
      throw new Exception("[SupportVectorMachineModel] representation is " +
          "coefficients, but kernel is not linear!");
    }

    // Get the individual machines
    NodeList machineL = model.getElementsByTagName("SupportVectorMachine");
    if (machineL.getLength() == 0) {
      throw new Exception("[SupportVectorMachineModel] No binary SVMs" +
          " defined in model file!");
    }
    for (int i = 0; i < machineL.getLength(); i++) {
      Element machine = (Element) machineL.item(i);
      m_machines.add(new SupportVectorMachine(machine, 
          m_miningSchema, m_vectorDictionary, m_svmRepresentation,
          m_alternateBinaryTargetCategory, m_log));
    }
  }
  
  /**
   * Computes the prediction for the given test instance. The instance has to
   * belong to a dataset when it's being classified.
   *
   * @param inst the instance to be classified
   * @return for a numeric class, a single-element array holding the predicted
   * value; otherwise an array with one entry per class value. If the instance
   * has missing input values, the target meta data's default value / prior
   * probabilities are returned when available; without target meta data the
   * result is Utils.missingValue() (numeric class) or all zeros (nominal or
   * string class)
   * @exception Exception if an error occurred during the prediction
   */
  public double[] distributionForInstance(Instance inst) throws Exception {
    if (!m_initialized) {
      mapToMiningSchema(inst.dataset());
    }

    Instances schemaInstances = m_miningSchema.getFieldsAsInstances();
    Attribute classAtt = schemaInstances.classAttribute();

    double[] preds;
    if (classAtt.isNumeric()) {
      preds = new double[1];
    } else {
      preds = new double[classAtt.numValues()];
      for (int i = 0; i < preds.length; i++) {
        preds[i] = -1; // mark all entries as not calculated to begin with
      }
    }

    double[] incoming = m_fieldsMap.instanceToSchema(inst, m_miningSchema);

    // any missing (NaN) value outside of the class column prevents the
    // machines from being evaluated
    int classIndex = schemaInstances.classIndex();
    boolean hasMissing = false;
    for (int i = 0; i < incoming.length; i++) {
      if (i != classIndex && Double.isNaN(incoming[i])) {
        hasMissing = true;
        break;
      }
    }

    if (hasMissing) {
      if (!m_miningSchema.hasTargetMetaData()) {
        String message = "[SupportVectorMachineModel] WARNING: Instance to predict has missing value(s) but "
          + "there is no missing value handling meta data and no "
          + "prior probabilities/default value to fall back to. No "
          + "prediction will be made ("
          + ((classAtt.isNominal() || classAtt.isString())
              ? "zero probabilities output)."
              : "NaN output).");
        if (m_log == null) {
          System.err.println(message);
        } else {
          m_log.logMessage(message);
        }

        if (classAtt.isNumeric()) {
          preds[0] = Utils.missingValue();
        } else {
          // the -1 "not calculated" markers must not leak to the caller;
          // output zero probabilities as announced in the warning
          preds = new double[preds.length];
        }
        return preds;
      }

      // use prior probabilities/default value
      TargetMetaInfo targetData = m_miningSchema.getTargetMetaData();
      if (classAtt.isNumeric()) {
        preds[0] = targetData.getDefaultValue();
      } else {
        for (int i = 0; i < classAtt.numValues(); i++) {
          preds[i] = targetData.getPriorProbability(classAtt.value(i));
        }
      }
      return preds;
    }

    // preds is handed to every machine and no return value is used, so
    // the machines are expected to record their output in it directly
    for (SupportVectorMachine m : m_machines) {
      m.distributionForInstance(incoming, m_kernel, m_vectorDictionary,
          preds, m_classificationMethod, m_threshold);
    }

    // PMML 4.0: for one-against-all the entry with the minimum value wins and
    // gets all of the probability mass. Nothing to do here for
    // one-against-one - the votes just get normalized below.
    if (m_classificationMethod == classificationMethod.ONE_AGAINST_ALL
        && classAtt.isNominal()) {
      int minI = Utils.minIndex(preds);
      preds = new double[preds.length];
      preds[minI] = 1.0;
    }

    if (m_machines.size() == preds.length - 1) {
      // one machine fewer than there are entries: an entry that is still
      // marked as not calculated receives the remaining probability mass
      double total = 0;
      int unset = -1;
      for (int i = 0; i < preds.length; i++) {
        if (preds[i] != -1) {
          total += preds[i];
        } else {
          unset = i;
        }
      }

      if (total > 1.0) {
        throw new Exception("[SupportVectorMachineModel] total of probabilities" +
            " is greater than 1!");
      }

      // every entry may already be set (e.g. after the one-against-all
      // step above), in which case there is nothing left to fill in
      if (unset >= 0) {
        preds[unset] = 1.0 - total;
      }
    }

    if (preds.length > 1) {
      Utils.normalize(preds);
    }

    return preds;
  }
  
  /**
   * Returns the revision string.
   *
   * @return the result of passing this class's revision keyword string to
   * RevisionUtils.extract()
   */
  public String getRevision() {
    final String revisionKeyword = "$Revision: 8034 $";
    return RevisionUtils.extract(revisionKeyword);
  }
  
  /**
   * Get a textual description of this SupportVectorMachineModel. The
   * description lists the PMML version, the creator application (unless it is
   * "?"), the mining schema, the kernel, the multi-class classification
   * method (if one is set) and each of the individual machines.
   * 
   * @return a description of this SupportVectorMachineModel as a string
   */
  public String toString() {
    // method-local buffer, so the unsynchronized StringBuilder is sufficient
    StringBuilder text = new StringBuilder();

    text.append("PMML version ").append(getPMMLVersion());
    String creator = getCreatorApplication();
    if (!creator.equals("?")) {
      text.append("\nApplication: ").append(creator);
    }
    text.append("\nPMML Model: Support Vector Machine Model");

    text.append("\n\n");
    text.append(m_miningSchema);

    text.append("Kernel: \n\t");
    text.append(m_kernel);
    text.append("\n");

    if (m_classificationMethod != classificationMethod.NONE) {
      text.append("Multi-class classification using ");
      if (m_classificationMethod == classificationMethod.ONE_AGAINST_ALL) {
        text.append("one-against-all");
      } else {
        text.append("one-against-one");
      }
      text.append("\n\n");
    }

    for (SupportVectorMachine machine : m_machines) {
      text.append("\n").append(machine);
    }

    return text.toString();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy