// weka.core.pmml.TargetMetaInfo (source listing; Maven / Gradle / Ivy)
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* TargetMetaInfo.java
* Copyright (C) 2008-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.core.pmml;
import java.io.Serializable;
import java.util.ArrayList;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import weka.core.Attribute;
import weka.core.Utils;
/**
 * Class to encapsulate information about a Target.
 * <p>
 * An instance is built from a PMML <code>Target</code> element. It records the
 * optional <code>min</code>, <code>max</code>, <code>rescaleConstant</code>,
 * <code>rescaleFactor</code> and <code>castInteger</code> attributes, plus any
 * nested <code>TargetValue</code> elements (prior probabilities for a
 * categorical target, or a default value for a continuous one).
 *
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision 1.0 $
 */
public class TargetMetaInfo extends FieldMetaInfo implements Serializable {

  /** For serialization */
  private static final long serialVersionUID = 863500462237904927L;

  /** Lower bound for a continuous prediction (NaN when not defined) */
  protected double m_min = Double.NaN;

  /** Upper bound for a continuous prediction (NaN when not defined) */
  protected double m_max = Double.NaN;

  /** Constant added when re-scaling the target value (0 when not defined) */
  protected double m_rescaleConstant = 0;

  /** Factor applied when re-scaling the target value (1 when not defined) */
  protected double m_rescaleFactor = 1.0;

  /**
   * How to cast to integer: "round", "ceiling" or "floor". The empty string
   * (the default) means no casting.
   */
  protected String m_castInteger = "";

  // -------------------------------------------------------

  /**
   * Default value (numeric) or prior distribution (categorical). Stays
   * <code>null</code> when the Target has no TargetValue elements.
   */
  protected double[] m_defaultValueOrPriorProbs;

  /** For categorical values. Actual values */
  protected ArrayList<String> m_values = new ArrayList<String>();

  /** Corresponding display values (same order as {@link #m_values}) */
  protected ArrayList<String> m_displayValues = new ArrayList<String>();

  // TODO: toString method.

  /**
   * Constructor.
   *
   * @param target the <code>Element</code> encapsulating a Target
   * @throws Exception if there is a problem reading the Target
   */
  protected TargetMetaInfo(Element target) throws Exception {
    super(target);

    // NOTE(review): the "optype" attribute is not parsed here; presumably the
    // superclass constructor takes care of it - confirm.

    // min and max (if defined)
    m_min = parseOptionalDouble(target, "min", m_min,
      "[TargetMetaInfo] can't parse min value for target field "
        + m_fieldName);
    m_max = parseOptionalDouble(target, "max", m_max,
      "[TargetMetaInfo] can't parse max value for target field "
        + m_fieldName);

    // Re-scaling (if any)
    m_rescaleConstant = parseOptionalDouble(target, "rescaleConstant",
      m_rescaleConstant,
      "[TargetMetaInfo] can't parse rescale constant value for "
        + "target field " + m_fieldName);
    m_rescaleFactor = parseOptionalDouble(target, "rescaleFactor",
      m_rescaleFactor,
      "[TargetMetaInfo] can't parse rescale factor value for "
        + "target field " + m_fieldName);

    // Cast integers. The value is only validated when it is applied in
    // applyMinMaxRescaleCast().
    String cstI = target.getAttribute("castInteger");
    if (cstI != null && cstI.length() > 0) {
      m_castInteger = cstI;
    }

    // Get the target value(s). Apparently, there doesn't have to
    // be any target values defined.
    readTargetValues(target);
  }

  /**
   * Parse an optional numeric attribute of an element.
   *
   * @param element the element carrying the attribute
   * @param attribute the name of the attribute
   * @param fallback the value to return when the attribute is absent or empty
   * @param unparsableMessage the message of the exception thrown when the
   *          attribute is present but is not a valid double
   * @return the parsed value, or <code>fallback</code>
   * @throws Exception if the attribute value can't be parsed
   */
  private static double parseOptionalDouble(Element element, String attribute,
    double fallback, String unparsableMessage) throws Exception {
    String raw = element.getAttribute(attribute);
    if (raw == null || raw.length() == 0) {
      return fallback;
    }
    try {
      return Double.parseDouble(raw);
    } catch (NumberFormatException ex) {
      // keep the original failure as the cause so it is not lost
      throw new Exception(unparsableMessage, ex);
    }
  }

  /**
   * Parse a mandatory numeric attribute of an element.
   *
   * @param element the element carrying the attribute
   * @param attribute the name of the attribute
   * @param missingMessage the message of the exception thrown when the
   *          attribute is absent or empty
   * @param unparsableMessage the message of the exception thrown when the
   *          attribute is present but is not a valid double
   * @return the parsed value
   * @throws Exception if the attribute is missing or can't be parsed
   */
  private static double parseRequiredDouble(Element element, String attribute,
    String missingMessage, String unparsableMessage) throws Exception {
    String raw = element.getAttribute(attribute);
    if (raw == null || raw.length() == 0) {
      throw new Exception(missingMessage);
    }
    try {
      return Double.parseDouble(raw);
    } catch (NumberFormatException ex) {
      throw new Exception(unparsableMessage, ex);
    }
  }

  /**
   * Read all TargetValue elements below the supplied Target. Does nothing
   * (leaving the default value/prior array <code>null</code>) when there are
   * none.
   *
   * @param target the Target element
   * @throws Exception if a TargetValue is inconsistent with the optype or
   *           lacks a parsable prior probability / default value
   */
  private void readTargetValues(Element target) throws Exception {
    NodeList vals = target.getElementsByTagName("TargetValue");
    int count = vals.getLength();
    if (count == 0) {
      return;
    }

    m_defaultValueOrPriorProbs = new double[count];
    for (int i = 0; i < count; i++) {
      Node value = vals.item(i);
      if (value.getNodeType() != Node.ELEMENT_NODE) {
        continue;
      }
      Element valueE = (Element) value;
      String valueName = valueE.getAttribute("value");
      if (valueName != null && valueName.length() > 0) {
        readCategoricalValue(valueE, valueName, i);
      } else {
        readContinuousValue(valueE, i);
      }
    }
  }

  /**
   * Process a TargetValue that names a categorical value: record the value,
   * its display value and its prior probability.
   *
   * @param valueE the TargetValue element
   * @param valueName the (non-empty) content of its "value" attribute
   * @param index the slot of the prior probability array to fill
   * @throws Exception if the optype is not categorical or the prior
   *           probability is missing or unparsable
   */
  private void readCategoricalValue(Element valueE, String valueName,
    int index) throws Exception {
    // we have a categorical value - set optype if it's not
    // already set
    if (m_optype != Optype.CATEGORICAL && m_optype != Optype.NONE) {
      throw new Exception(
        "[TargetMetaInfo] TargetValue element has categorical value but "
          + "optype is not categorical!");
    }
    if (m_optype == Optype.NONE) {
      m_optype = Optype.CATEGORICAL;
    }

    m_values.add(valueName);

    // get display value (if any); fall back to the value itself
    String displayValue = valueE.getAttribute("displayValue");
    if (displayValue != null && displayValue.length() > 0) {
      m_displayValues.add(displayValue);
    } else {
      m_displayValues.add(valueName);
    }

    // get prior probability (should be defined!!)
    m_defaultValueOrPriorProbs[index] = parseRequiredDouble(valueE,
      "priorProbability",
      "[TargetMetaInfo] No prior probability defined for value " + valueName,
      "[TargetMetaInfo] Can't parse probability from "
        + "TargetValue element.");
  }

  /**
   * Process a TargetValue without a "value" attribute, i.e. one that
   * supplies the default value of a numeric target.
   *
   * @param valueE the TargetValue element
   * @param index the slot of the default value array to fill
   * @throws Exception if the optype is not continuous or the default value
   *           is missing or unparsable
   */
  private void readContinuousValue(Element valueE, int index)
    throws Exception {
    // we have a numeric field - check the optype
    if (m_optype != Optype.CONTINUOUS && m_optype != Optype.NONE) {
      throw new Exception(
        "[TargetMetaInfo] TargetValue element has continuous value but "
          + "optype is not continuous!");
    }
    if (m_optype == Optype.NONE) {
      m_optype = Optype.CONTINUOUS;
    }

    // get the default value
    m_defaultValueOrPriorProbs[index] = parseRequiredDouble(valueE,
      "defaultValue",
      "[TargetMetaInfo] No default value defined for target " + m_fieldName,
      "[TargetMetaInfo] Can't parse default value from "
        + "TargetValue element.");
  }

  /**
   * Get the prior probability for the supplied value.
   *
   * @param value the value to get the probability for
   * @return the probability
   * @throws Exception if there are no TargetValues defined or
   * if the supplied value is not in the list of TargetValues
   */
  public double getPriorProbability(String value) throws Exception {
    if (m_defaultValueOrPriorProbs == null) {
      throw new Exception(
        "[TargetMetaInfo] no TargetValues defined (getPriorProbability)");
    }

    // indexOf copes with a null argument, so an unknown (or null) value ends
    // up in the documented exception rather than a NullPointerException
    int index = m_values.indexOf(value);
    if (index < 0) {
      throw new Exception("[TargetMetaInfo] couldn't find value " + value
        + " (getPriorProbability)");
    }
    return m_defaultValueOrPriorProbs[index];
  }

  /**
   * Get the default value (numeric target). This is the first entry read
   * from the TargetValue elements.
   *
   * @return the default value
   * @throws Exception if there is no TargetValue defined
   */
  public double getDefaultValue() throws Exception {
    if (m_defaultValueOrPriorProbs == null) {
      throw new Exception(
        "[TargetMetaInfo] no TargetValues defined (getDefaultValue)");
    }
    return m_defaultValueOrPriorProbs[0];
  }

  /**
   * Get the values (discrete case only) for this Target. Note: the
   * list may be empty if the pmml doesn't specify any values.
   *
   * @return a copy of the values of this Target
   */
  public ArrayList<String> getValues() {
    return new ArrayList<String>(m_values);
  }

  /**
   * Apply min and max, rescaleFactor, rescaleConstant and castInteger - in
   * that order (where defined).
   *
   * @param prediction the prediction to apply these modification to
   * @return the modified prediction
   * @throws Exception if this target is not a continuous one, or if the
   *           castInteger setting is not one of "round", "ceiling", "floor"
   */
  public double applyMinMaxRescaleCast(double prediction) throws Exception {
    if (m_optype != Optype.CONTINUOUS) {
      throw new Exception("[TargetMetaInfo] target must be continuous!");
    }

    // clip to [min, max]; a bound that was never defined is skipped
    if (!Utils.isMissingValue(m_min) && prediction < m_min) {
      prediction = m_min;
    }
    if (!Utils.isMissingValue(m_max) && prediction > m_max) {
      prediction = m_max;
    }

    prediction *= m_rescaleFactor;
    prediction += m_rescaleConstant;

    if (m_castInteger.length() > 0) {
      if (m_castInteger.equals("round")) {
        // NOTE(review): Math.round(double) yields a long, so a NaN prediction
        // becomes 0 here - confirm that this is intended.
        prediction = Math.round(prediction);
      } else if (m_castInteger.equals("ceiling")) {
        prediction = Math.ceil(prediction);
      } else if (m_castInteger.equals("floor")) {
        prediction = Math.floor(prediction);
      } else {
        throw new Exception("[TargetMetaInfo] unknown castInteger value "
          + m_castInteger);
      }
    }

    return prediction;
  }

  /**
   * Return this field as an Attribute.
   *
   * @return an Attribute for this field: built from the field name alone for
   *         a continuous target, from the field name and a null value list
   *         when no values are defined, and from the field name and a copy
   *         of the values otherwise.
   */
  public Attribute getFieldAsAttribute() {
    if (m_optype == Optype.CONTINUOUS) {
      return new Attribute(m_fieldName);
    }

    if (m_values.size() == 0) {
      // return a String attribute (the cast selects the list-based
      // constructor)
      return new Attribute(m_fieldName, (ArrayList<String>) null);
    }

    // hand the Attribute its own copy of the values
    ArrayList<String> values = new ArrayList<String>(m_values);
    return new Attribute(m_fieldName, values);
  }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy