All downloads are free. Search and download functionality uses the official Maven repository.

weka.estimators.EstimatorUtils Maven / Gradle / Ivy

Go to download

The Waikato Environment for Knowledge Analysis (WEKA), a machine learning workbench. This is the stable version. Apart from bugfixes, this version does not receive any other updates.

There is a newer version: 3.8.6
Show newest version
/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    EstimatorUtils.java
 *    Copyright (C) 2004-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.estimators;

import java.io.FileOutputStream;
import java.io.PrintWriter;
import java.util.Enumeration;
import java.util.Vector;

import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;

/**
 * Contains static utility functions for Estimators.
 * 

* * @author Gabi Schmidberger ([email protected]) * @version $Revision: 10153 $ */ public class EstimatorUtils implements RevisionHandler { /** * Find the minimum distance between values * * @param inst sorted instances, sorted * @param attrIndex index of the attribute, they are sorted after * @return the minimal distance */ public static double findMinDistance(Instances inst, int attrIndex) { double min = Double.MAX_VALUE; int numInst = inst.numInstances(); double diff; if (numInst < 2) { return min; } int begin = -1; Instance instance = null; do { begin++; if (begin < numInst) { instance = inst.instance(begin); } } while (begin < numInst && instance.isMissing(attrIndex)); double secondValue = inst.instance(begin).value(attrIndex); for (int i = begin; i < numInst && !inst.instance(i).isMissing(attrIndex); i++) { double firstValue = secondValue; secondValue = inst.instance(i).value(attrIndex); if (secondValue != firstValue) { diff = secondValue - firstValue; if (diff < min && diff > 0.0) { min = diff; } } } return min; } /** * Find the minimum and the maximum of the attribute and return it in the last * parameter.. 
* * @param inst instances used to build the estimator * @param attrIndex index of the attribute * @param minMax the array to return minimum and maximum in * @return number of not missing values * @exception Exception if parameter minMax wasn't initialized properly */ public static int getMinMax(Instances inst, int attrIndex, double[] minMax) throws Exception { double min = Double.NaN; double max = Double.NaN; Instance instance = null; int numNotMissing = 0; if ((minMax == null) || (minMax.length < 2)) { throw new Exception("Error in Program, privat method getMinMax"); } Enumeration enumInst = inst.enumerateInstances(); if (enumInst.hasMoreElements()) { do { instance = enumInst.nextElement(); } while (instance.isMissing(attrIndex) && (enumInst.hasMoreElements())); // add values if not missing if (!instance.isMissing(attrIndex)) { numNotMissing++; min = instance.value(attrIndex); max = instance.value(attrIndex); } while (enumInst.hasMoreElements()) { instance = enumInst.nextElement(); if (!instance.isMissing(attrIndex)) { numNotMissing++; if (instance.value(attrIndex) < min) { min = (instance.value(attrIndex)); } else { if (instance.value(attrIndex) > max) { max = (instance.value(attrIndex)); } } } } } minMax[0] = min; minMax[1] = max; return numNotMissing; } /** * Returns a dataset that contains all instances of a certain class value. 
* * @param data dataset to select the instances from * @param attrIndex index of the relevant attribute * @param classIndex index of the class attribute * @param classValue the relevant class value * @return a dataset with only */ public static Vector getInstancesFromClass(Instances data, int attrIndex, int classIndex, double classValue, Instances workData) { // Oops.pln("getInstancesFromClass classValue"+classValue+" workData"+data.numInstances()); Vector dataPlusInfo = new Vector(0); int num = 0; int numClassValue = 0; // workData = new Instances(data, 0); for (int i = 0; i < data.numInstances(); i++) { if (!data.instance(i).isMissing(attrIndex)) { num++; if (data.instance(i).value(classIndex) == classValue) { workData.add(data.instance(i)); numClassValue++; } } } Double alphaFactor = new Double((double) numClassValue / (double) num); dataPlusInfo.add(workData); dataPlusInfo.add(alphaFactor); return dataPlusInfo; } /** * Returns a dataset that contains of all instances of a certain class value. * * @param data dataset to select the instances from * @param classIndex index of the class attribute * @param classValue the class value * @return a dataset with only instances of one class value */ public static Instances getInstancesFromClass(Instances data, int classIndex, double classValue) { Instances workData = new Instances(data, 0); for (int i = 0; i < data.numInstances(); i++) { if (data.instance(i).value(classIndex) == classValue) { workData.add(data.instance(i)); } } return workData; } /** * Output of an n points of a density curve. Filename is parameter f + * ".curv". 
* * @param f string to build filename * @param est * @param min * @param max * @param numPoints * @throws Exception if something goes wrong */ public static void writeCurve(String f, Estimator est, double min, double max, int numPoints) throws Exception { PrintWriter output = null; StringBuffer text = new StringBuffer(""); if (f.length() != 0) { // add attribute indexnumber to filename and extension .hist String name = f + ".curv"; output = new PrintWriter(new FileOutputStream(name)); } else { return; } double diff = (max - min) / (numPoints - 1.0); try { text.append("" + min + " " + est.getProbability(min) + " \n"); for (double value = min + diff; value < max; value += diff) { text.append("" + value + " " + est.getProbability(value) + " \n"); } text.append("" + max + " " + est.getProbability(max) + " \n"); } catch (Exception ex) { ex.printStackTrace(); System.out.println(ex.getMessage()); } output.println(text.toString()); // close output if (output != null) { output.close(); } } /** * Output of an n points of a density curve. Filename is parameter f + * ".curv". 
* * @param f string to build filename * @param est * @param classEst * @param classIndex * @param min * @param max * @param numPoints * @throws Exception if something goes wrong */ public static void writeCurve(String f, Estimator est, Estimator classEst, double classIndex, double min, double max, int numPoints) throws Exception { PrintWriter output = null; StringBuffer text = new StringBuffer(""); if (f.length() != 0) { // add attribute indexnumber to filename and extension .hist String name = f + ".curv"; output = new PrintWriter(new FileOutputStream(name)); } else { return; } double diff = (max - min) / (numPoints - 1.0); try { text.append("" + min + " " + est.getProbability(min) * classEst.getProbability(classIndex) + " \n"); for (double value = min + diff; value < max; value += diff) { text.append("" + value + " " + est.getProbability(value) * classEst.getProbability(classIndex) + " \n"); } text.append("" + max + " " + est.getProbability(max) * classEst.getProbability(classIndex) + " \n"); } catch (Exception ex) { ex.printStackTrace(); System.out.println(ex.getMessage()); } output.println(text.toString()); // close output if (output != null) { output.close(); } } /** * Returns a dataset that contains of all instances of a certain value for the * given attribute. 
* * @param data dataset to select the instances from * @param index the index of the attribute * @param v the value * @return a subdataset with only instances of one value for the attribute */ public static Instances getInstancesFromValue(Instances data, int index, double v) { Instances workData = new Instances(data, 0); for (int i = 0; i < data.numInstances(); i++) { if (data.instance(i).value(index) == v) { workData.add(data.instance(i)); } } return workData; } /** * Returns a string representing the cutpoints */ public static String cutpointsToString(double[] cutPoints, boolean[] cutAndLeft) { StringBuffer text = new StringBuffer(""); if (cutPoints == null) { text.append("\n# no cutpoints found - attribute \n"); } else { text.append("\n#* " + cutPoints.length + " cutpoint(s) -\n"); for (int i = 0; i < cutPoints.length; i++) { text.append("# " + cutPoints[i] + " "); text.append("" + cutAndLeft[i] + "\n"); } text.append("# end\n"); } return text.toString(); } /** * Returns the revision string. * * @return the revision */ @Override public String getRevision() { return RevisionUtils.extract("$Revision: 10153 $"); } }