
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package weka.classifiers.neural.lvq;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.neural.lvq.initialise.InitialisationFactory;
import weka.classifiers.neural.lvq.model.CodebookVector;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.SelectedTag;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import java.text.NumberFormat;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Vector;
/**
* Date: 22/05/2004
* File: HierarchalLVQ.java
*
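* Hierarchical LVQ: a base LVQ/SOM model is built first, and BMUs that perform
* poorly are given dedicated sub-models (see {@link #globalInfo()} for details).
*
* A minimal usage sketch (the ARFF path and parameter values below are illustrative only):
* <pre>{@code
* Instances data = new weka.core.converters.ConverterUtils.DataSource("some-dataset.arff").getDataSet();
* data.setClassIndex(data.numAttributes() - 1);  // a nominal class attribute is required
* HierarchalLvq classifier = new HierarchalLvq();
* classifier.setErrorPercentage(10.0); // training error a BMU must reach to become a sub-model candidate
* classifier.setHitPercentage(1.0);    // share of total training hits a BMU must reach to be a candidate
* classifier.buildClassifier(data);
* double[] distribution = classifier.distributionForInstance(data.instance(0));
* }</pre>
*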
* @author Jason Brownlee
*/
public class HierarchalLvq extends AbstractClassifier
implements WeightedInstancesHandler {
/**
* Index of the base LVQ algorithm parameter
*/
private final static int PARAM_BASE_ALGORITHM = 0;
/**
* Index of the sub-classifier parameter
*/
private final static int PARAM_SUB_CLASSIFIER = 1;
/**
* Index of the error percentage parameter
*/
private final static int PARAM_ERROR_PERCENTAGE = 2;
/**
* Index of the hit percentage parameter
*/
private final static int PARAM_HIT_PERCENTAGE = 3;
/**
* Common LVQ algorithm parameters
*/
private final static String[] PARAMETERS =
{
"B", // base lvq algorithm
"S", // sub algorithm
"E", // error percentage
"H" // hit percentage
};
private final static String[] PARAMETER_NOTES =
{
" ", // base lvq algorithm
"", // sub algorithm
"", // error percentage
"" // hit percentage
};
/**
* Descriptions of common LVQ algorithm parameters
*/
private final static String[] PARAM_DESCRIPTIONS =
{
"LVQ algorithm used to construct the base LVQ model.",
"Algorithm used to construct the BMU sub-models.",
"Percentage of training error a BMU must achieve to be considered a candidate for a sub-model.",
"Percentage of total training hits a BMU must achieve to be considered a candidate for a sub-model."
};
/**
* Total number of classes in dataset
*/
protected int numClasses;
/**
* Total number of attributes in dataset
*/
protected int numAttributes;
/**
* Base LVQ model used for clustering
*/
private AlgorithmAncestor baseAlgorithm;
/**
* Type and configuration of classifier to use for sub models
*/
private Classifier subModelType;
/**
* Percentage of data running through a bmu for it to be considered a cluster
*/
private double hitPercentage;
/**
* Percentage of error a bmu must exhibit to be considered a cluster
*/
private double errorPercentage;
/**
* Collection of sub models used instead of BMU's, indexed on bmu id
*/
private Classifier[] subModels;
/**
* Training data used for training sub models, indexed on bmu id
*/
private Instances[] subModelTrainingData;
/**
* Whether or not a sub model is used for each bmu id
*/
private boolean[] subModelUsed;
/**
* Matrix of BMU usage calculated after base model construction using the training data, indexed on BMU id; column 0 holds correct hits, column 1 holds incorrect hits
*/
private int[][] trainingBmuUsage;
/**
* Total number of bmu hits (sum of training bmu matrix)
*/
private long totalTrainingBmuHits;
/**
* Accuracy of sub models on training data
*/
private double[] subModelAccuracy;
/**
* Formatter used for producing useful information to the user
*/
private final static NumberFormat format = NumberFormat.getPercentInstance();
public HierarchalLvq() {
// prepare defaults
baseAlgorithm = new Olvq1();
baseAlgorithm.setInitialisationMode(new SelectedTag(InitialisationFactory.INITALISE_TRAINING_EVEN, InitialisationFactory.TAGS_MODEL_INITALISATION));
subModelType = new MultipassLvq();
hitPercentage = 1.0;
errorPercentage = 10.0;
}
private void evaluateAndPruneClassifiers()
throws Exception {
for (int i = 0; i < subModelUsed.length; i++) {
if (subModelUsed[i]) {
// determine bmu's quality
double bmuQuality = calculateBmuQuality(i);
// determine the sub-model's quality
subModelAccuracy[i] = calculateSubModelQuality(i);
// check if the quality of the sub model is worse than the bmu
if (bmuQuality >= subModelAccuracy[i]) {
// prune the sub-model
subModelUsed[i] = false;
}
// else keep the sub-model
}
}
}
private String prepareSubModelAccuracyReport() {
StringBuffer buffer = new StringBuffer(1024);
buffer.append("-- BMU Sub-Model Accuracy --\n");
buffer.append("bmu,\t%bmu,\t%model,\t%better,\tpruned\n");
for (int i = 0; i < subModelUsed.length; i++) {
if (isBmuIdCandidate(i)) {
// determine bmu's quality
double bmuQuality = calculateBmuQuality(i);
double improvement = subModelAccuracy[i] - bmuQuality;
buffer.append(i);
buffer.append(",\t");
buffer.append(format.format(bmuQuality / 100.0));
buffer.append(",\t");
buffer.append(format.format(subModelAccuracy[i] / 100.0));
buffer.append(",\t");
buffer.append(format.format(improvement / 100.0));
buffer.append(",\t\t");
if (!subModelUsed[i]) {
buffer.append("true");
}
buffer.append("\n");
}
}
return buffer.toString();
}
private double calculateSubModelQuality(int aBmuId)
throws Exception {
Evaluation eval = new Evaluation(subModelTrainingData[aBmuId]);
eval.evaluateModel(subModels[aBmuId], subModelTrainingData[aBmuId]);
return eval.pctCorrect();
}
private double calculateBmuQuality(int aBmuId) {
int totalHits = (trainingBmuUsage[aBmuId][0] + trainingBmuUsage[aBmuId][1]);
// guard against a total hit count of zero
if (totalHits == 0) {
return 0.0;
}
// total / total possible
double percentCorrect = ((double) trainingBmuUsage[aBmuId][0] / (double) totalHits);
// make useable
percentCorrect *= 100.0;
return percentCorrect;
}
private void prepareClassifiersForCandidateClusters(Instances trainingDataset)
throws Exception {
for (int i = 0; i < subModelUsed.length; i++) {
if (subModelUsed[i]) {
// initialise and train the classifier
subModels[i] = prepareClusterClassifier(subModelTrainingData[i], trainingDataset, i);
}
}
}
private Classifier prepareClusterClassifier(Instances aClusterTrainingSet, Instances aTrainingSet, int aClusterNumber)
throws Exception {
// clone the selected model type
Classifier clusterClassifier = AbstractClassifier.makeCopies(subModelType, 1)[0];
// train the model
clusterClassifier.buildClassifier(aClusterTrainingSet);
return clusterClassifier;
}
private void prepareDataForCandidateClusters(Instances trainingDataset) {
// sort all training data by bmu
LinkedList[] tmpList = new LinkedList[subModelTrainingData.length];
for (int i = 0; i < trainingDataset.numInstances(); i++) {
Instance instance = trainingDataset.instance(i);
CodebookVector codebook = baseAlgorithm.getModel().getBmu(instance);
int id = codebook.getId();
if (tmpList[id] == null) {
tmpList[id] = new LinkedList();
}
tmpList[id].add(instance);
}
// convert datasets for known clusters into usable training data
for (int i = 0; i < subModelUsed.length; i++) {
if (subModelUsed[i]) {
// check for no data in cluster
if (tmpList[i] == null || tmpList[i].isEmpty()) {
subModelUsed[i] = false;
}
else {
subModelTrainingData[i] = linkedListToInstances(tmpList[i], trainingDataset);
}
}
tmpList[i] = null; // reduce memory on the fly
}
}
private Instances linkedListToInstances(LinkedList aListOfInstances, Instances aInstances) {
Instances instances = new Instances(aInstances, aListOfInstances.size());
for (Iterator iter = aListOfInstances.iterator(); iter.hasNext(); ) {
Instance element = (Instance) iter.next();
instances.add(element);
}
return instances;
}
private void selectBMUCandidateClusters() {
for (int i = 0; i < trainingBmuUsage.length; i++) {
if (isBmuIdCandidate(i)) {
subModelUsed[i] = true;
}
}
}
private boolean isBmuIdCandidate(int aBmuId) {
double error = getBmusPercentageError(aBmuId);
double hits = getBmusHitPercentage(aBmuId);
return isCandidate(error, hits);
}
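/*
* Worked example (with the defaults set in the constructor: hitPercentage = 1.0,
* errorPercentage = 10.0): a BMU that received 50 of the 1000 total training hits
* (5% >= 1%) and misclassified 8 of those 50 (16% >= 10%) is a candidate for a sub-model.
*/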
private boolean isCandidate(double error, double hits) {
// must have > n% of total hits
if (hits >= hitPercentage) {
// must have >= n% of hits are error
if (error >= errorPercentage) {
return true;
}
}
return false;
}
private String prepareSubModelSelectionReport() {
StringBuffer buffer = new StringBuffer(1024);
buffer.append("-- BMU Sub-Model Selection Report --\n");
buffer.append("bmu,\t%error,\t%hits,\tpruned\n");
for (int i = 0; i < trainingBmuUsage.length; i++) {
double error = getBmusPercentageError(i);
double hits = getBmusHitPercentage(i);
if (isCandidate(error, hits)) {
buffer.append(i);
buffer.append(",\t");
buffer.append(format.format(error / 100.0));
buffer.append(",\t");
buffer.append(format.format(hits / 100.0));
buffer.append(",\t");
// mark models that were pruned
if (!subModelUsed[i]) {
buffer.append("true");
}
buffer.append("\n");
}
}
return buffer.toString();
}
private double getBmusHitPercentage(int aBmuId) {
int totalHits = (trainingBmuUsage[aBmuId][0] + trainingBmuUsage[aBmuId][1]);
// guard against a total hit count of zero
if (totalHits == 0) {
return 0.0;
}
// total / total possible
double percentHits = ((double) totalHits / (double) totalTrainingBmuHits);
// make useable
percentHits *= 100.0;
return percentHits;
}
private double getBmusPercentageError(int aBmuId) {
int totalHits = (trainingBmuUsage[aBmuId][0] + trainingBmuUsage[aBmuId][1]);
// guard against a total hit count of zero
if (totalHits == 0) {
return 0.0;
}
// error / bmu hits
double error = ((double) trainingBmuUsage[aBmuId][1] / (double) totalHits);
// make useable
error *= 100.0;
return error;
}
private void getBmuHits(Instances trainingDataset)
throws Exception {
trainingBmuUsage = baseAlgorithm.getTrainingBmuUsage();
totalTrainingBmuHits = baseAlgorithm.getTotalTrainingBmuHits();
}
private void cleanup() {
// training data is no longer needed
subModelTrainingData = null;
}
private void initialiseBaseModel() {
int totalCodebookVectors = baseAlgorithm.getTotalCodebookVectors();
baseAlgorithm.setDebug(m_Debug);
// prepare other bits
subModelTrainingData = new Instances[totalCodebookVectors];
subModels = new Classifier[totalCodebookVectors];
subModelUsed = new boolean[totalCodebookVectors];
subModelAccuracy = new double[totalCodebookVectors];
}
/**
* Calculate the class distribution for the provided instance
*
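* For example, with 3 classes and voting disabled, an instance whose BMU predicts
* class index 1 yields {0.0, 1.0, 0.0}; with voting enabled the BMU's normalised class
* hit counts are returned, and if the BMU has a sub-model that model's distribution is used.
*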
* @param instance - an instance to calculate the class distribution for
* @return double [] - class distribution for the instance - all values are 0, except for the
* index of the predicted class, which has the value of 1
* @throws Exception
*/
public double[] distributionForInstance(Instance instance)
throws Exception {
if (baseAlgorithm == null) {
throw new Exception("Model has not been prepared");
}
// verify number of classes
else if (instance.numClasses() != numClasses) {
throw new Exception("Number of classes in instance (" + instance.numClasses() + ") does not match expected (" + numClasses + ").");
}
// verify the number of attributes
else if (instance.numAttributes() != numAttributes) {
throw new Exception("Number of attributes in instance (" + instance.numAttributes() + ") does not match expected (" + numAttributes + ").");
}
// get the bmu
double[] classDistribution = null;
CodebookVector bmu = baseAlgorithm.getModel().getBmu(instance);
// check if the bmu is used for classification
if (!subModelUsed[bmu.getId()]) {
classDistribution = new double[numClasses];
// no sub-model for this BMU: use the BMU's prediction (or its class hit distribution when voting)
if (baseAlgorithm.getUseVoting()) {
// return the class distribution
int[] distribution = bmu.getClassHitDistribution();
int total = 0;
// calculate the total hits
for (int i = 0; i < distribution.length; i++) {
total += distribution[i];
}
// calculate percentages for each class
for (int i = 0; i < classDistribution.length; i++) {
classDistribution[i] = ((double) distribution[i] / (double) total);
}
}
else {
int index = (int) bmu.getClassification();
classDistribution[index] = 1.0;
}
}
// use the sub model
else {
classDistribution = subModels[bmu.getId()].distributionForInstance(instance);
}
return classDistribution;
}
/**
* Returns the Capabilities of this classifier.
*
* @return the capabilities of this object
* @see Capabilities
*/
@Override
public Capabilities getCapabilities() {
Capabilities result = super.getCapabilities();
result.disableAll();
// attributes
result.enable(Capability.NUMERIC_ATTRIBUTES);
result.enable(Capability.DATE_ATTRIBUTES);
result.enable(Capability.NOMINAL_ATTRIBUTES);
// class
result.enable(Capability.NOMINAL_CLASS);
result.enable(Capability.MISSING_CLASS_VALUES);
result.setMinimumNumberInstances(1);
return result;
}
/**
* Build a hierarchical LVQ model from the provided training dataset. The base LVQ
* model is constructed first, BMUs whose hit and error percentages exceed the
* configured thresholds are selected as sub-model candidates, a sub-model is trained
* on the instances mapped to each candidate BMU, and sub-models that do not
* out-perform their parent BMU on the training data are pruned.
*
* @param instances - training dataset.
* @throws Exception
*/
public void buildClassifier(Instances instances)
throws Exception {
// validate user provided arguments
validateAlgorithmArguments();
// prepare the dataset
Instances trainingDataset = prepareDataset(instances);
// prepare elements based on base model
initialiseBaseModel();
// build the base model
baseAlgorithm.buildClassifier(trainingDataset);
// extract bmu usage
getBmuHits(trainingDataset);
// select bmu candidate clusters
selectBMUCandidateClusters();
// prepare data for candidate clusters
prepareDataForCandidateClusters(trainingDataset);
// prepare classifiers for candidate clusters
prepareClassifiersForCandidateClusters(trainingDataset);
// evaluate candidate clusters, prune models that perform worse than BMU
evaluateAndPruneClassifiers();
// clean up data no longer needed
cleanup();
}
/**
* Verify the dataset can be used with the LVQ algorithm and store details about
* the nature of the data.
* Rules:
*
* - Class must be assigned
* - Class must be nominal
* - Must be at least 1 training instance
* - Must have attributes besides the class attribute
*
* @param instances - training dataset
* @return - all instances that can be used for training (instances with a missing class are removed)
* @throws Exception
*/
protected Instances prepareDataset(final Instances instances)
throws Exception {
Instances trainingInstances = new Instances(instances);
trainingInstances.deleteWithMissingClass();
getCapabilities().testWithFail(trainingInstances);
numClasses = trainingInstances.numClasses();
numAttributes = trainingInstances.numAttributes();
// return training instances
return trainingInstances;
}
protected void validateAlgorithmArguments() throws Exception {
if (baseAlgorithm == null) {
throw new Exception("An LVQ algorithm used to construct the base model must be specified.");
}
else if (subModelType == null) {
throw new Exception("An algorithm used to construct sub-models for BMUs must be specified.");
}
else if (errorPercentage < 0.0 || errorPercentage > 100.0) {
throw new Exception("Error percentage must be in the range of 0.0 to 100.0.");
}
else if (hitPercentage < 0.0 || hitPercentage > 100.0) {
throw new Exception("Hit percentage must be in the range of 0.0 to 100.0.");
}
}
public Enumeration listOptions() {
Vector list = new Vector(PARAMETERS.length);
for (int i = 0; i < PARAMETERS.length; i++) {
String param = "-" + PARAMETERS[i] + " " + PARAMETER_NOTES[i];
list.add(new Option("\t" + PARAM_DESCRIPTIONS[i], PARAMETERS[i], 1, param));
}
return list.elements();
}
public void setOptions(String[] options)
throws Exception {
for (int i = 0; i < PARAMETERS.length; i++) {
String data = Utils.getOption(PARAMETERS[i].charAt(0), options);
if (data == null || data.length() == 0) {
continue;
}
switch (i) {
case PARAM_BASE_ALGORITHM: {
setBaseLVQAlgorithm(prepareClassifierFromParameterString(data));
break;
}
case PARAM_SUB_CLASSIFIER: {
setSubModelAlgorithm(prepareClassifierFromParameterString(data));
break;
}
case PARAM_ERROR_PERCENTAGE: {
errorPercentage = Double.parseDouble(data);
break;
}
case PARAM_HIT_PERCENTAGE: {
hitPercentage = Double.parseDouble(data);
break;
}
default: {
throw new Exception("Invalid option offset: " + i);
}
}
}
}
protected boolean hasValue(String aString) {
return (aString != null && aString.length() != 0);
}
public String[] getOptions() {
LinkedList list = new LinkedList();
list.add("-" + PARAMETERS[PARAM_BASE_ALGORITHM]);
list.add(getClassifierSpec(baseAlgorithm));
list.add("-" + PARAMETERS[PARAM_SUB_CLASSIFIER]);
list.add(getClassifierSpec(subModelType));
list.add("-" + PARAMETERS[PARAM_ERROR_PERCENTAGE]);
list.add(Double.toString(errorPercentage));
list.add("-" + PARAMETERS[PARAM_HIT_PERCENTAGE]);
list.add(Double.toString(hitPercentage));
return (String[]) list.toArray(new String[list.size()]);
}
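/*
* Example: with the defaults set in the constructor, the returned options resemble
* "-B weka.classifiers.neural.lvq.Olvq1 <its options> -S weka.classifiers.neural.lvq.MultipassLvq <its options> -E 10.0 -H 1.0",
* where the -B and -S values are the full classifier specifications produced by getClassifierSpec().
*/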
protected String getClassifierSpec(Classifier c) {
String name = c.getClass().getName();
String params = "";
if (c instanceof OptionHandler)
params = Utils.joinOptions(((OptionHandler) c).getOptions());
return (name + " " + params).trim();
}
/**
* Constructs a classifier instance from a specification string of the form
* "classname [options]", as used for the -B and -S options.
*
* @param s - classifier specification string
* @return - classifier instance
* @throws Exception
*/
private Classifier prepareClassifierFromParameterString(String s)
throws Exception {
String[] classifierSpec = null;
String classifierName = null;
// split the string into its components
classifierSpec = Utils.splitOptions(s);
// verify some components were specified
if (classifierSpec.length == 0) {
throw new Exception("Invalid classifier specification string");
}
// copy the name, then clear it from the list (it will not be a valid param for itself)
classifierName = classifierSpec[0];
classifierSpec[0] = "";
// construct the classifier with its params
return AbstractClassifier.forName(classifierName, classifierSpec);
}
public String globalInfo() {
StringBuffer buffer = new StringBuffer();
buffer.append("Hierarchical version of the LVQ/SOM algorithm where per-BMU models are used in some cases. ");
buffer.append("For those BMUs that perform poorly, a sub-model is created to handle all classification ");
buffer.append("tasks for data instances that map onto that BMU. First the base LVQ/SOM model is constructed ");
buffer.append("and evaluated, then BMUs that are candidates for sub-models are identified based on their hit and error percentages. ");
buffer.append("Finally, sub-models for candidate BMUs are created and evaluated. Those sub-models that ");
buffer.append("out-perform their parent BMU (on the training data) are kept. \nUnlimited nesting of LVQ models ");
buffer.append("can be achieved by selecting the HierarchalLVQ algorithm as the sub-model implementation.");
return buffer.toString();
}
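/*
* Nesting sketch (the variable names are illustrative): unlimited nesting is achieved
* by using another HierarchalLvq instance as the sub-model type, for example:
*
* HierarchalLvq outer = new HierarchalLvq();
* HierarchalLvq inner = new HierarchalLvq();
* outer.setSubModelAlgorithm(inner); // sub-models are themselves hierarchical LVQ models
*/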
public String toString() {
StringBuffer buffer = new StringBuffer();
if (super.m_Debug) {
// bmu hits report
if (baseAlgorithm.prepareBmuStatistis) {
buffer.append(baseAlgorithm.prepareTrainingBMUReport());
buffer.append("\n");
}
// class distributions for each codebook vector
buffer.append(baseAlgorithm.prepareIndividualClassDistributionReport());
buffer.append("\n");
// quantisation error
buffer.append(baseAlgorithm.quantisationErrorReport());
buffer.append("\n");
// codebook vectors
buffer.append(baseAlgorithm.prepareCodebookVectorReport());
buffer.append("\n");
}
// sub model selections
buffer.append(prepareSubModelSelectionReport());
buffer.append("\n");
// sub model accuracy
buffer.append(prepareSubModelAccuracyReport());
buffer.append("\n");
// build times
buffer.append(baseAlgorithm.prepareBuildTimeReport());
buffer.append("\n");
// distribution report
buffer.append(baseAlgorithm.prepareClassDistributionReport("-- Class Distribution --"));
buffer.append("\n");
return buffer.toString();
}
public void setBaseLVQAlgorithm(Classifier aClassifier) {
if (aClassifier instanceof AlgorithmAncestor) {
baseAlgorithm = (AlgorithmAncestor) aClassifier;
}
else {
throw new IllegalArgumentException("Base algorithm must be a single pass LVQ or single pass SOM algorithm.");
}
}
public Classifier getBaseLVQAlgorithm() {
return baseAlgorithm;
}
public void setSubModelAlgorithm(Classifier aClassifier) {
subModelType = aClassifier;
}
public Classifier getSubModelAlgorithm() {
return subModelType;
}
public void setErrorPercentage(double aPercentage) {
errorPercentage = aPercentage;
}
public double getErrorPercentage() {
return errorPercentage;
}
public void setHitPercentage(double aPercentage) {
hitPercentage = aPercentage;
}
public double getHitPercentage() {
return hitPercentage;
}
public String baseLVQAlgorithmTipText() {
return PARAM_DESCRIPTIONS[PARAM_BASE_ALGORITHM];
}
public String subModelAlgorithmTipText() {
return PARAM_DESCRIPTIONS[PARAM_SUB_CLASSIFIER];
}
public String errorPercentageTipText() {
return PARAM_DESCRIPTIONS[PARAM_ERROR_PERCENTAGE];
}
public String hitPercentageTipText() {
return PARAM_DESCRIPTIONS[PARAM_HIT_PERCENTAGE];
}
/**
* Entry point into the algorithm for direct usage
*
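* A command-line sketch (the dataset path is illustrative; -B and -S take full classifier
* specifications, so quote them if they carry their own options):
* <pre>
* java weka.classifiers.neural.lvq.HierarchalLvq -t some-dataset.arff \
*     -B "weka.classifiers.neural.lvq.Olvq1" \
*     -S "weka.classifiers.neural.lvq.MultipassLvq" \
*     -E 10.0 -H 1.0
* </pre>
*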
* @param args
*/
public static void main(String[] args) {
runClassifier(new HierarchalLvq(), args);
}
}