All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.pmml.producer.LogisticProducerHelper Maven / Gradle / Ivy

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    LogisticProducerHelper.java
 *    Copyright (C) 2014 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.pmml.producer;

import java.io.StringWriter;
import java.math.BigInteger;

import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;

import weka.core.Attribute;
import weka.core.Instances;
import weka.core.pmml.jaxbbindings.DATATYPE;
import weka.core.pmml.jaxbbindings.DerivedField;
import weka.core.pmml.jaxbbindings.FIELDUSAGETYPE;
import weka.core.pmml.jaxbbindings.LocalTransformations;
import weka.core.pmml.jaxbbindings.MININGFUNCTION;
import weka.core.pmml.jaxbbindings.MISSINGVALUETREATMENTMETHOD;
import weka.core.pmml.jaxbbindings.MiningField;
import weka.core.pmml.jaxbbindings.MiningSchema;
import weka.core.pmml.jaxbbindings.NormDiscrete;
import weka.core.pmml.jaxbbindings.NumericPredictor;
import weka.core.pmml.jaxbbindings.OPTYPE;
import weka.core.pmml.jaxbbindings.Output;
import weka.core.pmml.jaxbbindings.OutputField;
import weka.core.pmml.jaxbbindings.PMML;
import weka.core.pmml.jaxbbindings.REGRESSIONNORMALIZATIONMETHOD;
import weka.core.pmml.jaxbbindings.RegressionModel;
import weka.core.pmml.jaxbbindings.RegressionTable;
import weka.core.pmml.jaxbbindings.TransformationDictionary;

/**
 * Helper class for producing PMML for a Logistic classifier. Not designed to be
 * used directly - you should call toPMML() on a trained Logistic classifier.
 * <p>
 * The generated document contains a data dictionary, an optional
 * transformation dictionary (one derived field per attribute that exists only
 * in the filtered structure) and a single softmax-normalized regression model.
 * 
 * @author David Persons
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision: $
 */
public class LogisticProducerHelper extends AbstractPMMLProducerHelper {

  /**
   * Produce the PMML for a Logistic classifier.
   * 
   * @param train the training data used to build the Logistic model
   * @param structureAfterFiltering the structure of the training data after
   *          filtering
   * @param par the parameters of the function(s); {@code par[0][c]} is read
   *          as the intercept for class value {@code c} and
   *          {@code par[j][c]} (j &gt;= 1) as the coefficient of the j-th
   *          non-class attribute of {@code structureAfterFiltering}
   * @param numClasses the number of classes in the data; a regression table
   *          is emitted for each of the first {@code numClasses - 1} class
   *          values
   * @return the PMML for the classifier, or an empty string if the document
   *         could not be marshalled to XML
   */
  public static String toPMML(Instances train,
    Instances structureAfterFiltering, double[][] par, int numClasses) {

    PMML pmml = initPMML();
    addDataDictionary(train, pmml);

    MiningSchema schema = new MiningSchema();
    TransformationDictionary transformDict = populateMiningSchema(train,
      structureAfterFiltering, schema);

    RegressionModel model = new RegressionModel();
    if (transformDict != null) {
      // only present when at least one derived attribute was encountered
      pmml.setTransformationDictionary(transformDict);
    }
    model.addContent(schema);

    model.setFunctionName(MININGFUNCTION.CLASSIFICATION);
    model.setAlgorithmName("logisticRegression");
    model.setModelType("logisticRegression");
    model.setNormalizationMethod(REGRESSIONNORMALIZATIONMETHOD.SOFTMAX);

    model.addContent(buildOutput(structureAfterFiltering.classAttribute()));

    // NOTE(review): no table is written for the last class value - presumably
    // it acts as the reference category; confirm against the Logistic model.
    for (int i = 0; i < numClasses - 1; i++) {
      model.addContent(buildRegressionTable(structureAfterFiltering, par, i));
    }

    pmml.addAssociationModelOrBaselineModelOrClusteringModes(model);

    return marshalToString(pmml);
  }

  /**
   * Adds one mining field per input attribute (plus the predicted field for
   * the class) to the supplied schema and collects a derived field for every
   * attribute of the filtered structure that has no attribute of the same
   * name in the training data.
   * 
   * @param train the training data used to build the Logistic model
   * @param structureAfterFiltering the structure of the training data after
   *          filtering
   * @param schema the mining schema to add the mining fields to
   * @return the transformation dictionary holding the derived fields, or null
   *         if no derived attribute was found
   */
  private static TransformationDictionary populateMiningSchema(
    Instances train, Instances structureAfterFiltering, MiningSchema schema) {

    final int classIndex = structureAfterFiltering.classIndex();
    TransformationDictionary transformDict = null;
    // name of the source attribute of the most recently seen derived field;
    // used so that a run of derived fields sharing one source attribute only
    // yields a single mining field
    String currentAttrName = null;

    for (int i = 0; i < structureAfterFiltering.numAttributes(); i++) {
      Attribute attr = structureAfterFiltering.attribute(i);
      Attribute originalAttr = train.attribute(attr.name());
      boolean isClass = (i == classIndex);

      if (isClass) {
        schema.addMiningFields(new MiningField(attr.name(),
          FIELDUSAGETYPE.PREDICTED));
      }

      if (originalAttr == null) {
        // this must be a derived one
        if (transformDict == null) {
          transformDict = new TransformationDictionary();
        }
        String[] nameAndValue = getNameAndValueFromUnsupervisedNominalToBinaryDerivedAttribute(
          train, attr);

        if (!nameAndValue[0].equals(currentAttrName)) {
          currentAttrName = nameAndValue[0];
          if (!isClass) {
            // add a mining field for the source attribute; missing values
            // are replaced by its mode in the training data
            Attribute sourceAttr = train.attribute(nameAndValue[0]);
            int mode = (int) train.meanOrMode(sourceAttr);
            schema.addMiningFields(new MiningField(nameAndValue[0],
              FIELDUSAGETYPE.ACTIVE, MISSINGVALUETREATMENTMETHOD.AS_MODE,
              sourceAttr.value(mode)));
          }
        }

        DerivedField derivedfield = new DerivedField(attr.name(),
          DATATYPE.DOUBLE, OPTYPE.CONTINUOUS);
        NormDiscrete normDiscrete = new NormDiscrete(nameAndValue[0],
          nameAndValue[1]);
        derivedfield.setNormDiscrete(normDiscrete);
        transformDict.addDerivedField(derivedfield);
      } else if (!isClass) {
        // its either already numeric or was a binary nominal one
        if (originalAttr.isNumeric()) {
          String mean = String.valueOf(train.meanOrMode(originalAttr));
          schema.addMiningFields(new MiningField(originalAttr.name(),
            FIELDUSAGETYPE.ACTIVE, MISSINGVALUETREATMENTMETHOD.AS_MEAN, mean));
        } else {
          int mode = (int) train.meanOrMode(originalAttr);
          schema.addMiningFields(new MiningField(originalAttr.name(),
            FIELDUSAGETYPE.ACTIVE, MISSINGVALUETREATMENTMETHOD.AS_MODE,
            originalAttr.value(mode)));
        }
      }
    }

    return transformDict;
  }

  /**
   * Builds the Output element containing one output field per value of the
   * class attribute. Every field is named after the class attribute and
   * carries the corresponding class value.
   * 
   * @param classAttribute the class attribute of the filtered structure
   * @return the populated Output element
   */
  private static Output buildOutput(Attribute classAttribute) {
    Output output = new Output();
    for (int i = 0; i < classAttribute.numValues(); i++) {
      OutputField outputField = new OutputField();
      outputField.setName(classAttribute.name());
      outputField.setValue(classAttribute.value(i));
      output.addOutputField(outputField);
    }
    return output;
  }

  /**
   * Builds the regression table for a single class value.
   * 
   * @param structureAfterFiltering the structure of the training data after
   *          filtering
   * @param par the parameters of the function(s); row 0 holds the intercepts
   *          and rows 1.. hold the coefficients of the non-class attributes
   *          in attribute order
   * @param classValueIndex the index of the class value (column of
   *          {@code par}) to build the table for
   * @return the regression table for the given class value
   */
  private static RegressionTable buildRegressionTable(
    Instances structureAfterFiltering, double[][] par, int classValueIndex) {

    final int classIndex = structureAfterFiltering.classIndex();
    RegressionTable table = new RegressionTable(structureAfterFiltering
      .classAttribute().value(classValueIndex));

    // coefficients; row 0 of par is the intercept, so predictors start at 1
    int row = 1;
    for (int k = 0; k < structureAfterFiltering.numAttributes(); k++) {
      if (k != classIndex) {
        Attribute attr = structureAfterFiltering.attribute(k);
        table.addNumericPredictor(new NumericPredictor(attr.name(),
          BigInteger.valueOf(1), par[row][classValueIndex]));
        row++;
      }
    }
    table.setIntercept(par[0][classValueIndex]);
    return table;
  }

  /**
   * Marshals the PMML document to a formatted XML string.
   * 
   * @param pmml the document to marshal
   * @return the XML, or an empty string if JAXB fails (the stack trace is
   *         printed in that case)
   */
  private static String marshalToString(PMML pmml) {
    try {
      StringWriter sw = new StringWriter();
      JAXBContext jc = JAXBContext.newInstance(PMML.class);
      Marshaller marshaller = jc.createMarshaller();
      marshaller.setProperty(Marshaller.JAXB_FORMATTED_OUTPUT, true);
      marshaller.marshal(pmml, sw);
      return sw.toString();
    } catch (JAXBException e) {
      // kept as best-effort: callers receive "" rather than an exception
      e.printStackTrace();
    }
    return "";
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy