weka.classifiers.CostMatrix Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of weka-dev Show documentation
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This is the developer version, the "bleeding edge"
of development, to which new functionality is added.
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* CostMatrix.java
* Copyright (C) 2006-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers;
import java.io.LineNumberReader;
import java.io.Reader;
import java.io.Serializable;
import java.io.StreamTokenizer;
import java.io.Writer;
import java.util.Random;
import java.util.StringTokenizer;
import weka.core.AttributeExpression;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;
/**
 * Class for storing and manipulating a misclassification cost matrix. Rows
 * correspond to actual classes and columns to predicted classes: the element
 * at position i,j is the penalty for classifying an instance of class i as
 * class j. Cost values can be fixed or computed on a per-instance basis (cost
 * sensitive evaluation only) from the value of an attribute or an expression
 * involving attribute(s).
 *
 * @author Mark Hall
 * @author Richard Kirkby ([email protected])
 * @version $Revision: 10141 $
 */
public class CostMatrix implements Serializable, RevisionHandler {

  /** for serialization */
  private static final long serialVersionUID = -1973792250544554965L;

  /** The number of rows (which always equals the number of columns). */
  private int m_size;

  /**
   * The cells, indexed [rows][columns]. Within this class a cell holds a Double
   * (fixed cost), a String (an expression that has not been parsed yet) or an
   * AttributeExpression (a per-instance cost); setCell itself does not enforce
   * this.
   */
  protected Object[][] m_matrix;

  /** The default file extension for cost matrix files */
  public static String FILE_EXTENSION = ".cost";

  /**
   * Creates a default cost matrix of a particular size. All diagonal values
   * will be 0 and all non-diagonal values 1.
   *
   * @param numOfClasses the number of classes that the cost matrix holds.
   */
  public CostMatrix(int numOfClasses) {
    m_size = numOfClasses;
    initialize();
  }

  /**
   * Creates a cost matrix that is a copy of another. The cell objects
   * themselves are shared with toCopy (shallow copy).
   *
   * @param toCopy the matrix to copy.
   */
  public CostMatrix(CostMatrix toCopy) {
    this(toCopy.size());
    for (int i = 0; i < m_size; i++) {
      for (int j = 0; j < m_size; j++) {
        setCell(i, j, toCopy.getCell(i, j));
      }
    }
  }

  /**
   * Initializes the matrix: allocates m_size x m_size cells, with 0 on the
   * diagonal and 1 everywhere else.
   */
  public void initialize() {
    m_matrix = new Object[m_size][m_size];
    for (int i = 0; i < m_size; i++) {
      for (int j = 0; j < m_size; j++) {
        setCell(i, j, Double.valueOf(i == j ? 0.0 : 1.0));
      }
    }
  }

  /**
   * The number of rows (and columns)
   *
   * @return the size of the matrix
   */
  public int size() {
    return m_size;
  }

  /**
   * Same as size
   *
   * @return the number of columns
   */
  public int numColumns() {
    return size();
  }

  /**
   * Same as size
   *
   * @return the number of rows
   */
  public int numRows() {
    return size();
  }

  /**
   * Parses every String cell into an AttributeExpression (replacing it in
   * place) and reports whether the matrix contains non-fixed costs.
   *
   * @return true if at least one cell was a String or already held an
   *         AttributeExpression
   * @throws Exception propagated from AttributeExpression if an expression
   *           cannot be parsed
   */
  private boolean replaceStrings() throws Exception {
    boolean nonDouble = false;
    for (int i = 0; i < m_size; i++) {
      for (int j = 0; j < m_size; j++) {
        Object cell = getCell(i, j);
        if (cell instanceof String) {
          AttributeExpression temp = new AttributeExpression();
          temp.convertInfixToPostfix((String) cell);
          setCell(i, j, temp);
          nonDouble = true;
        } else if (cell instanceof AttributeExpression) {
          nonDouble = true;
        }
      }
    }
    return nonDouble;
  }

  /**
   * Applies the cost matrix to a set of instances. If a random number generator
   * is supplied the instances will be resampled, otherwise they will be
   * reweighted. Adapted from code once sitting in Instances.java
   *
   * @param data the instances to reweight.
   * @param random a random number generator for resampling, if null then
   *          instances are reweighted.
   * @return a new dataset reflecting the cost of misclassification.
   * @exception Exception if the data has no class or the matrix is
   *              inappropriate.
   */
  public Instances applyCostMatrix(Instances data, Random random)
    throws Exception {

    if (data.classIndex() < 0) {
      throw new Exception("Class index is not set!");
    }
    if (size() != data.numClasses()) {
      throw new Exception("Misclassification cost matrix has wrong format!");
    }

    // are there any non-fixed, per-instance costs defined in the matrix?
    if (replaceStrings()) {
      // could reweight in the two class case
      if (data.classAttribute().numValues() > 2) {
        throw new Exception("Can't resample/reweight instances using "
          + "non-fixed cost values when there are more "
          + "than two classes!");
      }
      // Two classes: scale each instance's weight by the cost of
      // misclassifying it, i.e. the off-diagonal cell in the row of its
      // actual class.
      double[] weightOfInstances = new double[data.numInstances()];
      for (int i = 0; i < data.numInstances(); i++) {
        Instance inst = data.instance(i);
        int classValIndex = (int) inst.classValue();
        Object element = (classValIndex == 0) ? getCell(classValIndex, 1)
          : getCell(classValIndex, 0);
        double factor;
        if (element instanceof Double) {
          factor = ((Double) element).doubleValue();
        } else {
          factor = ((AttributeExpression) element).evaluateExpression(inst);
        }
        weightOfInstances[i] = inst.weight() * factor;
      }
      return reweightOrResample(data, random, weightOfInstances);
    }

    // normalize the matrix if not already (all cells are fixed costs here)
    for (int i = 0; i < m_size; i++) {
      if (!Utils.eq(((Double) getCell(i, i)).doubleValue(), 0)) {
        CostMatrix normMatrix = new CostMatrix(this);
        normMatrix.normalize();
        return normMatrix.applyCostMatrix(data, random);
      }
    }

    // total instance weight per actual class
    double[] weightOfInstancesInClass = new double[data.numClasses()];
    for (int j = 0; j < data.numInstances(); j++) {
      weightOfInstancesInClass[(int) data.instance(j).classValue()] += data
        .instance(j).weight();
    }
    double sumOfWeights = Utils.sum(weightOfInstancesInClass);

    // Using Kai Ming Ting's formula for deriving weights for
    // the classes and Breiman's heuristic for multiclass
    // problems.
    double[] weightFactor = new double[data.numClasses()];
    double sumOfWeightFactors = 0;
    for (int i = 0; i < data.numClasses(); i++) {
      // total cost of misclassifying an instance of actual class i
      double sumOfMissClassWeights = 0;
      for (int j = 0; j < data.numClasses(); j++) {
        double cost = ((Double) getCell(i, j)).doubleValue();
        if (Utils.sm(cost, 0)) {
          throw new Exception("Neg. weights in misclassification "
            + "cost matrix!");
        }
        sumOfMissClassWeights += cost;
      }
      weightFactor[i] = sumOfMissClassWeights * sumOfWeights;
      sumOfWeightFactors += sumOfMissClassWeights * weightOfInstancesInClass[i];
    }
    for (int i = 0; i < data.numClasses(); i++) {
      weightFactor[i] /= sumOfWeightFactors;
    }

    // Store new weights
    double[] weightOfInstances = new double[data.numInstances()];
    for (int i = 0; i < data.numInstances(); i++) {
      weightOfInstances[i] = data.instance(i).weight()
        * weightFactor[(int) data.instance(i).classValue()];
    }
    return reweightOrResample(data, random, weightOfInstances);
  }

  /**
   * Builds the dataset returned by applyCostMatrix from per-instance weights.
   * If random is non-null, data is resampled with the weights. Otherwise a new
   * Instances is built from data (via its copy constructor) and each
   * instance's weight is replaced.
   *
   * @param data the original instances
   * @param random generator for resampling, or null to reweight
   * @param weights the new weight for each instance of data, by index
   * @return the resampled or reweighted dataset
   * @throws Exception if resampling fails
   */
  private static Instances reweightOrResample(Instances data, Random random,
    double[] weights) throws Exception {
    if (random != null) {
      return data.resampleWithWeights(random, weights);
    }
    Instances instances = new Instances(data);
    for (int i = 0; i < data.numInstances(); i++) {
      instances.instance(i).setWeight(weights[i]);
    }
    return instances;
  }

  /**
   * Calculates the expected misclassification cost for each possible class
   * value, given class probability estimates. Entry x of the result is the sum
   * over actual classes y of classProbs[y] * cost(y, x), i.e. the expected
   * cost of predicting class x.
   *
   * @param classProbs the class probability estimates.
   * @return the expected costs.
   * @exception Exception if the wrong number of class probabilities is
   *              supplied, or the matrix contains non-fixed costs.
   */
  public double[] expectedCosts(double[] classProbs) throws Exception {

    if (classProbs.length != m_size) {
      throw new Exception("Length of probability estimates don't "
        + "match cost matrix");
    }

    double[] costs = new double[m_size];

    for (int x = 0; x < m_size; x++) {
      for (int y = 0; y < m_size; y++) {
        Object element = getCell(y, x);
        if (!(element instanceof Double)) {
          throw new Exception("Can't use non-fixed costs in "
            + "computing expected costs.");
        }
        costs[x] += classProbs[y] * ((Double) element).doubleValue();
      }
    }

    return costs;
  }

  /**
   * Calculates the expected misclassification cost for each possible class
   * value, given class probability estimates. Non-fixed costs are evaluated
   * against the supplied instance.
   *
   * @param classProbs the class probability estimates.
   * @param inst the current instance for which the class probabilities apply.
   *          Is used for computing any non-fixed cost values.
   * @return the expected costs.
   * @exception Exception if something goes wrong
   */
  public double[] expectedCosts(double[] classProbs, Instance inst)
    throws Exception {

    if (classProbs.length != m_size) {
      throw new Exception("Length of probability estimates don't "
        + "match cost matrix");
    }

    // only fixed costs: delegate to the instance-independent version
    if (!replaceStrings()) {
      return expectedCosts(classProbs);
    }

    double[] costs = new double[m_size];

    for (int x = 0; x < m_size; x++) {
      for (int y = 0; y < m_size; y++) {
        Object element = getCell(y, x);
        double costVal;
        if (!(element instanceof Double)) {
          costVal = ((AttributeExpression) element).evaluateExpression(inst);
        } else {
          costVal = ((Double) element).doubleValue();
        }
        costs[x] += classProbs[y] * costVal;
      }
    }

    return costs;
  }

  /**
   * Gets the maximum cost for a particular class value, i.e. the largest
   * entry in row classVal.
   *
   * @param classVal the class value.
   * @return the maximum cost.
   * @exception Exception if cost matrix contains non-fixed costs
   */
  public double getMaxCost(int classVal) throws Exception {

    double maxCost = Double.NEGATIVE_INFINITY;

    for (int i = 0; i < m_size; i++) {
      Object element = getCell(classVal, i);
      if (!(element instanceof Double)) {
        throw new Exception("Can't use non-fixed costs when "
          + "getting max cost.");
      }
      double cost = ((Double) element).doubleValue();
      if (cost > maxCost) {
        maxCost = cost;
      }
    }

    return maxCost;
  }

  /**
   * Gets the maximum cost for a particular class value, i.e. the largest
   * entry in row classVal, evaluating any non-fixed costs for the given
   * instance.
   *
   * @param classVal the class value.
   * @param inst the instance used to evaluate non-fixed costs
   * @return the maximum cost.
   * @exception Exception if an expression cannot be parsed or evaluated
   */
  public double getMaxCost(int classVal, Instance inst) throws Exception {

    if (!replaceStrings()) {
      return getMaxCost(classVal);
    }

    double maxCost = Double.NEGATIVE_INFINITY;
    for (int i = 0; i < m_size; i++) {
      Object element = getCell(classVal, i);
      double cost;
      if (!(element instanceof Double)) {
        cost = ((AttributeExpression) element).evaluateExpression(inst);
      } else {
        cost = ((Double) element).doubleValue();
      }
      if (cost > maxCost) {
        maxCost = cost;
      }
    }

    return maxCost;
  }

  /**
   * Normalizes the matrix so that the diagonal contains zeros. For each column
   * y, the diagonal value (y, y) is subtracted from every cell in that column.
   * Requires all cells to be fixed (Double) costs.
   */
  public void normalize() {
    // NOTE(review): this shifts columns (predicted classes). Given the
    // row = actual class convention used by expectedCosts and
    // applyCostMatrix, subtracting per row would be the shift that leaves
    // expected-cost decisions unchanged. Behaviour kept as-is pending
    // confirmation.
    for (int y = 0; y < m_size; y++) {
      double diag = ((Double) getCell(y, y)).doubleValue();
      for (int x = 0; x < m_size; x++) {
        setCell(x, y,
          Double.valueOf(((Double) getCell(x, y)).doubleValue() - diag));
      }
    }
  }

  /**
   * Loads a cost matrix in the old format from a reader. The matrix is first
   * reset to its defaults. Then each non-empty line must hold three numbers:
   * the first class index, the second class index and a positive cost, which
   * is stored at (first, second). '%' starts a comment, and diagonal entries
   * are rejected. Adapted from code once sitting in Instances.java
   *
   * @param reader the reader to get the values from.
   * @exception Exception if the matrix cannot be read correctly.
   */
  public void readOldFormat(Reader reader) throws Exception {

    StreamTokenizer tokenizer = new StreamTokenizer(reader);
    int currentToken;

    initialize();

    tokenizer.commentChar('%');
    tokenizer.eolIsSignificant(true);
    while (StreamTokenizer.TT_EOF != (currentToken = tokenizer.nextToken())) {

      // Skip empty lines
      if (currentToken == StreamTokenizer.TT_EOL) {
        continue;
      }

      // Get index of first class.
      if (currentToken != StreamTokenizer.TT_NUMBER) {
        throw new Exception("Only numbers and comments allowed "
          + "in cost file!");
      }
      double firstIndex = tokenizer.nval;
      if (!Utils.eq((int) firstIndex, firstIndex)) {
        throw new Exception("First number in line has to be "
          + "index of a class!");
      }
      // StreamTokenizer parses "-1" as a number, so check both bounds
      if ((int) firstIndex < 0 || (int) firstIndex >= size()) {
        throw new Exception("Class index out of range!");
      }

      // Get index of second class.
      if (StreamTokenizer.TT_EOF == (currentToken = tokenizer.nextToken())) {
        throw new Exception("Premature end of file!");
      }
      if (currentToken == StreamTokenizer.TT_EOL) {
        throw new Exception("Premature end of line!");
      }
      if (currentToken != StreamTokenizer.TT_NUMBER) {
        throw new Exception("Only numbers and comments allowed "
          + "in cost file!");
      }
      double secondIndex = tokenizer.nval;
      if (!Utils.eq((int) secondIndex, secondIndex)) {
        throw new Exception("Second number in line has to be "
          + "index of a class!");
      }
      if ((int) secondIndex < 0 || (int) secondIndex >= size()) {
        throw new Exception("Class index out of range!");
      }
      if ((int) secondIndex == (int) firstIndex) {
        throw new Exception("Diagonal of cost matrix non-zero!");
      }

      // Get cost factor.
      if (StreamTokenizer.TT_EOF == (currentToken = tokenizer.nextToken())) {
        throw new Exception("Premature end of file!");
      }
      if (currentToken == StreamTokenizer.TT_EOL) {
        throw new Exception("Premature end of line!");
      }
      if (currentToken != StreamTokenizer.TT_NUMBER) {
        throw new Exception("Only numbers and comments allowed "
          + "in cost file!");
      }
      double weight = tokenizer.nval;
      if (!Utils.gr(weight, 0)) {
        throw new Exception("Only positive weights allowed!");
      }
      setCell((int) firstIndex, (int) secondIndex, Double.valueOf(weight));
    }
  }

  /**
   * Reads a matrix from a reader. Lines starting with '%' and blank lines are
   * skipped. The first remaining line must contain the number of rows and
   * columns (which must be equal). Each subsequent line holds one row of
   * whitespace-separated cells. A cell that parses as a number is stored as a
   * Double; otherwise it is kept as an expression String. Tokens beyond the
   * row length are ignored. (FracPete: taken from old weka.core.Matrix class)
   *
   * @param reader the reader containing the matrix
   * @throws Exception if an error occurs
   * @see #write(Writer)
   */
  public CostMatrix(Reader reader) throws Exception {
    LineNumberReader lnr = new LineNumberReader(reader);
    String line;
    int currentRow = -1;

    while ((line = lnr.readLine()) != null) {

      // Comments
      if (line.startsWith("%")) {
        continue;
      }

      StringTokenizer st = new StringTokenizer(line);
      // Ignore blank lines
      if (!st.hasMoreTokens()) {
        continue;
      }

      if (currentRow < 0) {
        // header line: "<rows> <columns>"
        int rows = Integer.parseInt(st.nextToken());
        if (!st.hasMoreTokens()) {
          throw new Exception("Line " + lnr.getLineNumber()
            + ": expected number of columns");
        }
        int cols = Integer.parseInt(st.nextToken());
        if (rows != cols) {
          throw new Exception("Trying to create a non-square cost " + "matrix");
        }
        if (rows < 0) {
          throw new Exception("Line " + lnr.getLineNumber()
            + ": matrix size must not be negative");
        }
        m_size = rows;
        initialize();
        currentRow++;
        continue;
      }

      if (currentRow == m_size) {
        throw new Exception("Line " + lnr.getLineNumber()
          + ": too many rows provided");
      }
      for (int i = 0; i < m_size; i++) {
        if (!st.hasMoreTokens()) {
          throw new Exception("Line " + lnr.getLineNumber()
            + ": too few matrix elements provided");
        }
        String nextTok = st.nextToken();
        // a numeric token is a fixed cost; anything else is kept as an
        // expression string (parsed later by replaceStrings())
        Object cell;
        try {
          cell = Double.valueOf(nextTok);
        } catch (NumberFormatException ex) {
          cell = nextTok;
        }
        setCell(currentRow, i, cell);
      }
      currentRow++;
    }

    if (currentRow == -1) {
      throw new Exception("Line " + lnr.getLineNumber()
        + ": expected number of rows");
    } else if (currentRow != m_size) {
      throw new Exception("Line " + lnr.getLineNumber()
        + ": too few rows provided");
    }
  }

  /**
   * Writes out a matrix. The format can be read via the CostMatrix(Reader)
   * constructor. Each cell is written using its toString() representation,
   * tab-separated. (FracPete: taken from old weka.core.Matrix class)
   *
   * @param w the output Writer
   * @throws Exception if an error occurs
   */
  public void write(Writer w) throws Exception {
    w.write("% Rows\tColumns\n");
    w.write("" + m_size + "\t" + m_size + "\n");
    w.write("% Matrix elements\n");
    for (int i = 0; i < m_size; i++) {
      for (int j = 0; j < m_size; j++) {
        w.write("" + getCell(i, j) + "\t");
      }
      w.write("\n");
    }
    w.flush();
  }

  /**
   * Converts the matrix into a single-line Matlab string. The matrix is
   * enclosed in square brackets, rows are separated by semicolons and single
   * cells by blanks, e.g., [1 2; 3 4].
   *
   * @return the matrix in Matlab single line format
   */
  public String toMatlab() {
    StringBuilder result = new StringBuilder();

    result.append("[");
    for (int i = 0; i < m_size; i++) {
      if (i > 0) {
        result.append("; ");
      }
      for (int n = 0; n < m_size; n++) {
        if (n > 0) {
          result.append(" ");
        }
        result.append(getCell(i, n));
      }
    }
    result.append("]");

    return result.toString();
  }

  /**
   * Creates a matrix from the given Matlab string. The text between the first
   * '[' and the first ']' is split into rows on ';' and into cells on
   * whitespace. Numeric cells become fixed costs; other cells are kept as
   * expression Strings. An empty matrix "[]" yields a 0 x 0 cost matrix.
   *
   * @param matlab the matrix in matlab format
   * @return the matrix represented by the given string
   * @throws Exception if the brackets are missing or misplaced, or if any row
   *           does not have exactly as many cells as there are rows
   * @see #toMatlab()
   */
  public static CostMatrix parseMatlab(String matlab) throws Exception {
    // get content
    int begin = matlab.indexOf("[") + 1;
    int end = matlab.indexOf("]");
    if (end < begin) {
      throw new Exception("Matrix must be enclosed in '[' and ']': " + matlab);
    }
    String cells = matlab.substring(begin, end).trim();

    // the number of rows determines the (square) size
    StringTokenizer rowTokens = new StringTokenizer(cells, ";");
    int size = rowTokens.countTokens();
    CostMatrix result = new CostMatrix(size);

    // fill matrix
    int row = 0;
    while (rowTokens.hasMoreTokens()) {
      StringTokenizer colTokens = new StringTokenizer(rowTokens.nextToken());
      int numCols = colTokens.countTokens();
      if (numCols != size) {
        throw new Exception("Row " + (row + 1) + " has " + numCols
          + " cells, but a square cost matrix with " + size
          + " rows was expected: " + matlab);
      }
      int col = 0;
      while (colTokens.hasMoreTokens()) {
        String current = colTokens.nextToken();
        Object cell;
        try {
          cell = Double.valueOf(current);
        } catch (NumberFormatException e) {
          // not a number, so it must be an expression
          cell = current;
        }
        result.setCell(row, col, cell);
        col++;
      }
      row++;
    }

    return result;
  }

  /**
   * Set the value of a particular cell in the matrix
   *
   * @param rowIndex the row
   * @param columnIndex the column
   * @param value the value to set
   */
  public final void setCell(int rowIndex, int columnIndex, Object value) {
    m_matrix[rowIndex][columnIndex] = value;
  }

  /**
   * Return the contents of a particular cell. Note: this method returns the
   * Object stored at a particular cell.
   *
   * @param rowIndex the row
   * @param columnIndex the column
   * @return the value at the cell
   */
  public final Object getCell(int rowIndex, int columnIndex) {
    return m_matrix[rowIndex][columnIndex];
  }

  /**
   * Return the value of a cell as a double (for legacy code)
   *
   * @param rowIndex the row
   * @param columnIndex the column
   * @return the value at a particular cell as a double
   * @exception Exception if the value is not a double
   */
  public final double getElement(int rowIndex, int columnIndex)
    throws Exception {
    if (!(m_matrix[rowIndex][columnIndex] instanceof Double)) {
      throw new Exception("Cost matrix contains non-fixed costs!");
    }
    return ((Double) m_matrix[rowIndex][columnIndex]).doubleValue();
  }

  /**
   * Return the value of a cell as a double. Computes the value for non-fixed
   * costs using the supplied Instance. Any String cells in the matrix are
   * parsed into expressions first.
   *
   * @param rowIndex the row
   * @param columnIndex the column
   * @param inst the instance used to evaluate a non-fixed cost
   * @return the value from a particular cell
   * @exception Exception if something goes wrong
   */
  public final double getElement(int rowIndex, int columnIndex, Instance inst)
    throws Exception {

    if (m_matrix[rowIndex][columnIndex] instanceof Double) {
      return ((Double) m_matrix[rowIndex][columnIndex]).doubleValue();
    } else if (m_matrix[rowIndex][columnIndex] instanceof String) {
      replaceStrings();
    }

    return ((AttributeExpression) m_matrix[rowIndex][columnIndex])
      .evaluateExpression(inst);
  }

  /**
   * Set the value of a cell as a double
   *
   * @param rowIndex the row
   * @param columnIndex the column
   * @param value the value (double) to set
   */
  public final void setElement(int rowIndex, int columnIndex, double value) {
    m_matrix[rowIndex][columnIndex] = Double.valueOf(value);
  }

  /**
   * Converts a matrix to a string, right-aligning fixed costs and centring
   * expressions in columns of a common width. (FracPete: taken from old
   * weka.core.Matrix class)
   *
   * @return the converted string
   */
  @Override
  public String toString() {
    // Determine the width required for the maximum element,
    // and check for fractional display requirement.
    double maxval = 0;
    boolean fractional = false;
    boolean negative = false;
    int widthExpression = 0;
    for (int i = 0; i < size(); i++) {
      for (int j = 0; j < size(); j++) {
        Object element = getCell(i, j);
        if (element instanceof Double) {
          double value = ((Double) element).doubleValue();
          if (value < 0) {
            negative = true;
          }
          // magnitude only; the sign is accounted for separately below
          double current = Math.abs(value);
          if (current > maxval) {
            maxval = current;
          }
          double fract = Math.abs(current - Math.rint(current));
          if (!fractional && ((Math.log(fract) / Math.log(10)) >= -2)) {
            fractional = true;
          }
        } else {
          widthExpression = Math.max(widthExpression,
            element.toString().length());
        }
      }
    }

    int widthNumber = 0;
    if (maxval > 0) {
      widthNumber = (int) (Math.log(maxval) / Math.log(10)
        + (fractional ? 4 : 1));
      if (negative) {
        // reserve one character for the minus sign
        widthNumber++;
      }
    }
    int width = Math.max(widthNumber, widthExpression);

    StringBuilder text = new StringBuilder();
    for (int i = 0; i < size(); i++) {
      for (int j = 0; j < size(); j++) {
        Object element = getCell(i, j);
        if (element instanceof Double) {
          text.append(" ").append(
            Utils.doubleToString(((Double) element).doubleValue(), width,
              (fractional ? 2 : 0)));
        } else {
          String expr = element.toString();
          int diff = width - expr.length();
          if (diff > 0) {
            // centre the expression, putting the odd extra space on the left
            int left = diff % 2 + diff / 2;
            String temp = Utils.padLeft(expr, expr.length() + left);
            temp = Utils.padRight(temp, width);
            text.append(" ").append(temp);
          } else {
            text.append(" ").append(expr);
          }
        }
      }
      text.append("\n");
    }

    return text.toString();
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 10141 $");
  }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy