// Source listing: weka.classifiers.immune.airs.algorithm.AIRS2Trainer (artifact: wekaclassalgos)
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* Created on 30/12/2004
*
*/
package weka.classifiers.immune.airs.algorithm;
import weka.classifiers.immune.airs.algorithm.classification.MajorityVote;
import weka.classifiers.immune.airs.algorithm.initialisation.RandomInstancesInitialisation;
import weka.classifiers.immune.airs.algorithm.samplegeneration.StimulationProportionalMutation;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;
import java.text.NumberFormat;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Random;
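/**
* Type: AIRS2Trainer
* File: AIRS2Trainer.java
* Date: 30/12/2004
*
* Description: Trainer for the AIRS2 algorithm. It normalises the training
* data, seeds a memory cell pool, refines that pool one training instance at
* a time and wraps the result in a majority-vote classifier.
*
* @author Jason Brownlee
*/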
public class AIRS2Trainer implements AISTrainer {
/** Multiplied with {@link #affinityThreshold} to obtain the memory cell replacement cutoff. */
protected final double affinityThresholdScalar;
/** Multiplier applied to a cell's stimulation when allocating resources and counting clones. */
protected final double clonalRate;
/** Additional multiplier applied when counting the clones of a memory cell. */
protected final double hyperMutationRate;
/** Upper bound on the resources the ARB pool may hold after competition. */
protected final double totalResources;
/** Mean ARB stimulation at which the refinement loop stops. */
protected final double stimulationThreshold;
/** Instance count handed to the affinity threshold calculation. */
protected final int affinityThresholdNumInstances;
/** Random source shared with sample generation, initialisation and the threshold calculation. */
protected final Random rand;
/** Size handed to the model initialisation when the memory cell pool is seeded. */
protected final int memoryCellPoolInitialSize;
/** The k handed to the majority-vote classifier. */
protected final int kNN;
/** Affinity measure, created from the training set in algorithmPreperation. */
protected AffinityFunction affinityFunction;
/** Generator of mutated clones, created in algorithmPreperation. */
protected SampleGenerator arbSampleGeneration;
/** Affinity threshold; calculated in train or supplied through setAffinityThreshold. */
protected double affinityThreshold;
/** The memory cell pool that is built up during training and given to the classifier. */
protected CellPool memoryCellPool;
// stats
// The "mean" fields are used as running sums during training and are turned
// into means by prepareStatistics once training has finished.
/** ARB clones produced; divided by the refinement iteration count at the end. */
protected double meanClonesArb;
/** Memory cell clones produced; divided by the training instance count at the end. */
protected double meanClonesMemCell;
/** Resources left after each competition; divided by the refinement iteration count at the end. */
protected double meanAllocatedResources;
/** ARB pool size after each refinement iteration; divided by the refinement iteration count at the end. */
protected double meanArbPoolSize;
/** Refinement iterations per training instance. */
protected double meanArbRefinementIterations;
/** Number of ARBs removed from the pool during resource competition. */
protected long totalArbDeletions;
/** Number of memory cells removed because a candidate replaced them. */
protected long totalMemoryCellReplacements;
/** Number of refinement iterations executed over all training instances. */
protected long totalArbRefinementIterations;
/** Number of training instances of the last run, set by prepareStatistics. */
protected long totalTrainingInstances;
/**
 * Creates a trainer from its fixed configuration. Every argument is stored
 * unchanged in the corresponding final field; nothing is validated here.
 *
 * @param aAffinityThresholdScalar       scalar applied to the affinity threshold
 * @param aClonalRate                    clonal rate
 * @param aHyperMutationRate             hyper mutation rate
 * @param aTotalResources                total resources available to the ARB pool
 * @param aStimulationValue              stimulation threshold that ends ARB refinement
 * @param aNumInstancesAffinityThreshold instance count used for the affinity threshold calculation
 * @param aRand                          random number source
 * @param aMemoryCellPoolInitialSize     initial size of the memory cell pool
 * @param aKNN                           k used by the resulting classifier
 */
public AIRS2Trainer(
        double aAffinityThresholdScalar,
        double aClonalRate,
        double aHyperMutationRate,
        double aTotalResources,
        double aStimulationValue,
        int aNumInstancesAffinityThreshold,
        Random aRand,
        int aMemoryCellPoolInitialSize,
        int aKNN) {
    this.affinityThresholdScalar = aAffinityThresholdScalar;
    this.clonalRate = aClonalRate;
    this.hyperMutationRate = aHyperMutationRate;
    this.totalResources = aTotalResources;
    this.stimulationThreshold = aStimulationValue;
    this.affinityThresholdNumInstances = aNumInstancesAffinityThreshold;
    this.rand = aRand;
    this.memoryCellPoolInitialSize = aMemoryCellPoolInitialSize;
    this.kNN = aKNN;
}
/**
 * Creates the helpers that depend on the training data: the affinity
 * function first, then the sample generator.
 *
 * @param aInstances the data set the helpers are built for
 */
public void algorithmPreperation(Instances aInstances) {
    final AffinityFunction function = new AffinityFunction(aInstances);
    this.affinityFunction = function;
    // created second, so an overriding factory can already see the affinity function
    this.arbSampleGeneration = this.prepareSampleGeneration(aInstances);
}
/**
 * Factory for the generator that produces mutated clones.
 *
 * @param aInstances the training data (not used by this implementation)
 * @return a stimulation proportional mutation generator driven by {@code rand}
 */
protected SampleGenerator prepareSampleGeneration(Instances aInstances) {
    final SampleGenerator generator = new StimulationProportionalMutation(this.rand);
    return generator;
}
/**
 * Builds a plain-text report of the statistics collected during training,
 * one statistic per line, formatted with {@code Utils.format}.
 *
 * @return the training summary
 */
public String getTrainingSummary() {
    final NumberFormat f = Utils.format;
    // guard against 0/0 (NaN) when no refinement iteration was ever executed
    final double meanPrunings = (totalArbRefinementIterations == 0L)
            ? 0.0
            : (double) totalArbDeletions / (double) totalArbRefinementIterations;

    final StringBuilder buffer = new StringBuilder();
    buffer.append(" - Training Summary - \n");
    buffer.append("Affinity Threshold:..............................").append(f.format(affinityThreshold)).append("\n");
    buffer.append("Total training instances:........................").append(f.format(totalTrainingInstances)).append("\n");
    buffer.append("Total memory cell replacements:..................").append(f.format(totalMemoryCellReplacements)).append("\n");
    buffer.append("Mean ARB clones per refinement iteration:........").append(f.format(meanClonesArb)).append("\n");
    buffer.append("Mean total resources per refinement iteration:...").append(f.format(meanAllocatedResources)).append("\n");
    buffer.append("Mean pool size per refinement iteration:.........").append(f.format(meanArbPoolSize)).append("\n");
    buffer.append("Mean memory cell clones per antigen:.............").append(f.format(meanClonesMemCell)).append("\n");
    buffer.append("Mean ARB refinement iterations per antigen:......").append(f.format(meanArbRefinementIterations)).append("\n");
    buffer.append("Mean ARB prunings per refinement iteration:......").append(f.format(meanPrunings)).append("\n");
    return buffer.toString();
}
/**
 * Trains a classifier on the given data. The data is passed through a Weka
 * {@code Normalize} filter, the data-dependent helpers are prepared, the
 * affinity threshold is calculated and the actual training is delegated to
 * {@code internalTrain}.
 *
 * @param instances the raw training data
 * @return the trained classifier
 * @throws Exception if filtering or training fails
 */
public AISModelClassifier train(Instances instances)
        throws Exception {
    // run the data through the normalisation filter
    final Normalize normaliser = new Normalize();
    normaliser.setInputFormat(instances);
    final Instances normalisedData = Filter.useFilter(instances, normaliser);

    // build affinity function and sample generator for the filtered data
    algorithmPreperation(normalisedData);

    // derive the affinity threshold from the filtered data
    affinityThreshold = Utils.calculateAffinityThreshold(
            normalisedData, affinityThresholdNumInstances, rand, affinityFunction);

    // the filter travels with the model
    return internalTrain(normalisedData, normaliser);
}
/**
 * Overrides the affinity threshold.
 *
 * @param a the threshold value to store
 */
public void setAffinityThreshold(double a) {
    this.affinityThreshold = a;
}
/**
 * Runs the training loop on an already prepared training set.
 *
 * For each instance the best matching memory cell of the same class is
 * looked up. Without a match a new memory cell is created from the
 * instance. A match with a stimulation of exactly 1.0 is treated as
 * identical and skipped. Otherwise ARBs are generated and refined, and the
 * resulting candidate is offered to the memory cell pool.
 *
 * @param trainingSet the training data
 * @param normalise   the filter that is handed on to the classifier
 * @return the classifier built around the trained memory cell pool
 * @throws Exception declared for overriding implementations
 */
protected AISModelClassifier internalTrain(
        Instances trainingSet,
        Normalize normalise)
        throws Exception {
    // start from clean accumulators: prepareStatistics turns the sums into
    // means in place, so leftovers from an earlier run would corrupt the result
    meanClonesArb = 0.0;
    meanClonesMemCell = 0.0;
    meanAllocatedResources = 0.0;
    meanArbPoolSize = 0.0;
    meanArbRefinementIterations = 0.0;
    totalArbDeletions = 0L;
    totalMemoryCellReplacements = 0L;
    totalArbRefinementIterations = 0L;
    totalTrainingInstances = 0L;

    // initialise model
    initialise(trainingSet);

    // train model on each instance
    for (int i = 0; i < trainingSet.numInstances(); i++) {
        Instance current = trainingSet.instance(i);
        // identify best match from memory pool
        Cell bestMatch = identifyMemoryPoolBestMatch(current);
        if (bestMatch == null) {
            // no memory cell of this class yet
            addNewMemoryCell(current);
        } else if (bestMatch.getStimulation() != 1.0) {
            // the pool is only needed when refinement actually happens
            CellPool arbCellPool = new CellPool(new LinkedList<Cell>());
            // generate arbs and add to arb pool
            generateARBs(arbCellPool, bestMatch, current);
            // perform ARB refinement
            Cell candidate = runARBRefinement(arbCellPool, current);
            // respond to candidate
            respondToCandidateMemoryCell(bestMatch, candidate, current);
        }
        // a stimulation of exactly 1.0 is an identical match: nothing to do
    }

    // prepare statistics
    prepareStatistics(trainingSet.numInstances());
    // prepare the classifier
    return getClassifier(normalise);
}
/**
 * Turns the sums accumulated during training into means.
 *
 * A sum whose divisor is zero is left untouched instead of being divided,
 * so the summary shows 0 rather than NaN when no refinement iteration ran
 * or when the training set was empty.
 *
 * @param aNumTrainingInstances number of instances that were trained on
 */
protected void prepareStatistics(int aNumTrainingInstances) {
    totalTrainingInstances = aNumTrainingInstances;

    // per refinement iteration
    if (totalArbRefinementIterations != 0L) {
        meanClonesArb /= totalArbRefinementIterations;
        meanAllocatedResources /= totalArbRefinementIterations;
        meanArbPoolSize /= totalArbRefinementIterations;
    }

    // per training instance
    if (totalTrainingInstances != 0L) {
        meanClonesMemCell /= totalTrainingInstances;
        meanArbRefinementIterations = ((double) totalArbRefinementIterations / (double) totalTrainingInstances);
    } else {
        meanArbRefinementIterations = 0.0;
    }
}
/**
 * Refines the ARB pool for one training instance. Each iteration lets the
 * ARBs compete for resources, checks the stopping criterion and, if it is
 * not met, adds mutated clones of every surviving ARB to the pool.
 *
 * The loop has no iteration limit; it ends only when the stopping criterion
 * is met or one of the called methods throws.
 *
 * @param aArbCellPool the ARB pool to refine (modified in place)
 * @param aInstance    the training instance being presented
 * @return the candidate memory cell selected by the last competition
 */
protected Cell runARBRefinement(
        CellPool aArbCellPool,
        Instance aInstance) {
    boolean stopCondition;
    Cell candidateMemoryCell;
    do {
        // perform competition for resources
        candidateMemoryCell = performARBCompetitionForResources(aArbCellPool, aInstance);
        // calculate if stop condition has been met
        stopCondition = isStoppingCriterion(aArbCellPool, aInstance);
        if (!stopCondition) {
            // variation (mutated clones); collected separately so the pool
            // is not modified while it is being iterated
            LinkedList<Cell> arbs = new LinkedList<Cell>();
            for (Cell c : aArbCellPool.getCells()) {
                arbs.addAll(generateARBVarients(aInstance, c));
            }
            aArbCellPool.add(arbs);
        }
        // stats
        meanArbPoolSize += aArbCellPool.size();
        meanArbRefinementIterations++;
        totalArbRefinementIterations++;
    }
    while (!stopCondition);
    return candidateMemoryCell;
}
/**
 * Wraps the trained memory cell pool in a majority-vote classifier.
 *
 * @param aNormalise the filter the classifier is constructed with
 * @return a {@code MajorityVote} built from {@code kNN}, the filter, the
 *         memory cell pool and the affinity function
 */
protected AISModelClassifier getClassifier(Normalize aNormalise) {
    return new MajorityVote(kNN, aNormalise, memoryCellPool, affinityFunction);
}
/**
 * Decides whether a refined candidate enters the memory cell pool.
 *
 * The candidate's stimulation is recalculated against the instance. If it
 * is greater than the best match's stimulation the candidate is added, and
 * the best match is additionally removed when the affinity between the two
 * is below the replacement cutoff.
 *
 * @param bestMatchMemoryCell the current best matching memory cell
 * @param candidateMemoryCell the candidate produced by ARB refinement
 * @param aInstance           the training instance being presented
 */
protected void respondToCandidateMemoryCell(
        Cell bestMatchMemoryCell,
        Cell candidateMemoryCell,
        Instance aInstance) {
    // refresh the candidate's stimulation for this instance
    final double candidateScore = stimulation(candidateMemoryCell, aInstance);
    final double incumbentScore = bestMatchMemoryCell.getStimulation();
    if (!(candidateScore > incumbentScore)) {
        // candidate is no improvement
        return;
    }

    // the candidate joins the memory pool
    memoryCellPool.add(candidateMemoryCell);

    // the previous best only goes when it is close enough to the candidate
    final double distance = affinityFunction.affinityNormalised(bestMatchMemoryCell, candidateMemoryCell);
    final double cutoff = getMemoryCellReplacementCutoff();
    if (distance < cutoff) {
        memoryCellPool.delete(bestMatchMemoryCell);
        totalMemoryCellReplacements++;
    }
}
/**
 * Produces the mutated clones of one ARB.
 *
 * @param aInstance the training instance being presented
 * @param aArb      the ARB to clone
 * @return the new clones; empty when the ARB is not entitled to any
 */
protected LinkedList<Cell> generateARBVarients(Instance aInstance, Cell aArb) {
    LinkedList<Cell> newARBs = new LinkedList<Cell>();
    // determine the number of clones to produce
    int numClones = arbNumClones(aArb);
    // generate clones
    for (int i = 0; i < numClones; i++) {
        // generate mutated clone
        Cell mutatedClone = arbSampleGeneration.generateSample(aArb, aInstance);
        newARBs.add(mutatedClone);
    }
    // stats
    meanClonesArb += numClones;
    return newARBs;
}
/**
 * Checks whether ARB refinement may stop, which is the case when the mean
 * stimulation of the pool has reached the stimulation threshold.
 *
 * @param aArbCellPool the ARB pool to inspect
 * @param aInstance    the training instance being presented (not used here)
 * @return {@code true} if the mean stimulation is at least the threshold
 * @throws RuntimeException if the mean stimulation is NaN (for example for
 *                          an empty pool), because the refinement loop
 *                          could then never terminate
 */
protected boolean isStoppingCriterion(
        CellPool aArbCellPool,
        Instance aInstance) {
    // sum stimulation values
    double stimulationSum = 0.0;
    for (Iterator<Cell> iter = aArbCellPool.iterator(); iter.hasNext(); ) {
        Cell c = iter.next();
        stimulationSum += c.getStimulation();
    }
    double meanStimulation = stimulationSum / aArbCellPool.size();

    // safety: NaN never satisfies the comparison below
    if (Double.isNaN(meanStimulation)) {
        throw new RuntimeException("Infinite loop condition detected, mean stimulation is NaN.");
    }
    // met once the mean is >= the stimulation threshold
    return meanStimulation >= stimulationThreshold;
}
/**
 * Lets the ARBs compete for the limited resources.
 *
 * Stimulations are normalised and resources allocated. While the allocated
 * total exceeds {@code totalResources}, resources are taken from the last
 * cell of the pool: the cell is removed when it holds no more than the
 * excess, otherwise its resources are reduced by the excess.
 *
 * NOTE(review): this relies on {@code orderByResources} leaving the cell
 * with the most resources first and the fewest last -- confirm in CellPool.
 *
 * @param aArbCellPool the ARB pool (modified in place)
 * @param aInstance    the training instance being presented
 * @return the first cell of the pool after competition
 * @throws java.util.NoSuchElementException if the pool runs empty
 */
protected Cell performARBCompetitionForResources(
        CellPool aArbCellPool,
        Instance aInstance) {
    double numResAllowed = totalResources;
    // stimulate arbs, normalise stimulation
    stimulationNormalisation(aArbCellPool, aInstance);
    // allocate resources to arbs based on stimulation, order by resources
    double resources = calculateResourceAllocations(aArbCellPool, aInstance);

    // trim from the tail until the pool fits the resource budget
    LinkedList<Cell> cells = aArbCellPool.getCells();
    while (resources > numResAllowed) {
        double numResourceToRemove = (resources - numResAllowed);
        Cell last = cells.getLast();
        double lastResources = last.getNumResources();
        if (lastResources <= numResourceToRemove) {
            // the whole cell goes
            cells.removeLast();
            totalArbDeletions++;
            resources -= lastResources;
        } else {
            // the cell survives with reduced resources
            last.setNumResources(lastResources - numResourceToRemove);
            resources -= numResourceToRemove;
        }
    }

    // best ARB will always have the most resources
    Cell mostStimulatedSameClass = cells.getFirst();
    // stats
    meanAllocatedResources += resources;
    return mostStimulatedSameClass;
}
/**
 * Gives every cell resources equal to its stimulation times the clonal
 * rate, then asks the pool to order itself by resources.
 *
 * @param cellPool  the pool whose cells receive resources
 * @param aInstance the training instance being presented (not used here)
 * @return the sum of all allocated resources
 */
protected double calculateResourceAllocations(
        CellPool cellPool,
        Instance aInstance) {
    double resources = 0.0;
    for (Iterator<Cell> iter = cellPool.iterator(); iter.hasNext(); ) {
        Cell c = iter.next();
        double r = (c.getStimulation() * clonalRate);
        c.setNumResources(r);
        resources += r;
    }
    // order by allocated resources
    cellPool.orderByResources();
    return resources;
}
/**
 * Fills the ARB pool for one training instance: a copy of the best matching
 * memory cell plus its mutated clones.
 *
 * @param arbCellPool          the pool to fill
 * @param aBestMatchMemoryCell the memory cell that matched the instance best
 * @param aInstance            the training instance being presented
 */
protected void generateARBs(
        CellPool arbCellPool,
        Cell aBestMatchMemoryCell,
        Instance aInstance) {
    // the pool starts with a copy of the best match itself
    final Cell seed = new Cell(aBestMatchMemoryCell);
    arbCellPool.add(seed);

    // then one mutated clone per entitled clone
    final int cloneCount = memoryCellNumClones(aBestMatchMemoryCell);
    for (int remaining = cloneCount; remaining > 0; remaining--) {
        arbCellPool.add(arbSampleGeneration.generateSample(aBestMatchMemoryCell, aInstance));
    }

    // stats
    meanClonesMemCell += cloneCount;
}
/**
 * Finds the memory cell of the instance's class that comes first once the
 * memory pool has been stimulated and sorted for the instance.
 *
 * As a side effect every memory cell's stimulation is recalculated and the
 * pool's cell list is re-sorted.
 *
 * @param aInstance the training instance being presented
 * @return the first cell of the same class in sorted order, or {@code null}
 *         if the pool holds no cell of that class
 */
protected Cell identifyMemoryPoolBestMatch(Instance aInstance) {
    // get memory pool sorted by stimulation
    LinkedList<Cell> stimulatedSorted = stimulation(memoryCellPool.getCells(), aInstance);
    // process list until a member of the same class is located
    for (Cell c : stimulatedSorted) {
        if (Utils.isSameClass(aInstance, c)) {
            return c;
        }
    }
    return null;
}
/**
 * Creates a memory cell from the instance, adds it to the memory cell pool
 * and stores its stimulation for that instance.
 *
 * @param aInstance the training instance without a matching memory cell
 * @return the newly created memory cell
 */
protected Cell addNewMemoryCell(Instance aInstance) {
    final Cell created = new Cell(aInstance);
    memoryCellPool.add(created);
    // stimulate against its own source instance and record the value
    final double score = stimulation(created, aInstance);
    created.setStimulation(score);
    return created;
}
/**
 * Seeds the memory cell pool with the cells produced by the model
 * initialisation strategy.
 *
 * @param aTrainingSet the training data the initial cells are generated from
 */
protected void initialise(Instances aTrainingSet) {
    memoryCellPool = new CellPool(
            getModelInitialisation().generateCellsList(aTrainingSet, memoryCellPoolInitialSize));
}
/**
 * Factory for the strategy that seeds the memory cell pool.
 *
 * @return a random-instances initialisation driven by {@code rand}
 */
protected ModelInitialisation getModelInitialisation() {
    final ModelInitialisation strategy = new RandomInstancesInitialisation(this.rand);
    return strategy;
}
/**
 * The number of clones that an ARB can produce: its stimulation times the
 * clonal rate, rounded to the nearest integer.
 *
 * @param aArb the ARB whose clone count is wanted
 * @return the number of clones
 */
protected int arbNumClones(Cell aArb) {
    final double entitlement = aArb.getStimulation() * clonalRate;
    final long rounded = Math.round(entitlement);
    return (int) rounded;
}
/**
 * The number of clones that a memory cell can produce: its stimulation
 * times the clonal rate times the hyper mutation rate, rounded to the
 * nearest integer.
 *
 * @param aArb the memory cell whose clone count is wanted
 * @return the number of clones
 */
protected int memoryCellNumClones(Cell aArb) {
    // same left-to-right multiplication order as a single expression
    final double perClonalRate = aArb.getStimulation() * clonalRate;
    final double entitlement = perClonalRate * hyperMutationRate;
    final long rounded = Math.round(entitlement);
    return (int) rounded;
}
/**
 * The affinity below which a candidate replaces the previous best match.
 *
 * @return the affinity threshold multiplied by the affinity threshold scalar
 */
protected double getMemoryCellReplacementCutoff() {
    final double cutoff = affinityThreshold * affinityThresholdScalar;
    return cutoff;
}
/**
 * Stimulates every cell of the pool with the instance and rescales the
 * stimulations linearly so that the smallest becomes 0 and the largest 1.
 *
 * @param cells     the pool whose cells are stimulated and rescaled
 * @param aInstance the training instance being presented
 * @throws RuntimeException if all stimulations are equal (range of zero)
 *                          or a rescaled value falls outside [0, 1]
 */
protected void stimulationNormalisation(
        CellPool cells,
        Instance aInstance) {
    double min = Double.POSITIVE_INFINITY;
    double max = Double.NEGATIVE_INFINITY;
    // stimulate every cell and determine min and max
    for (Iterator<Cell> iter = cells.iterator(); iter.hasNext(); ) {
        Cell c = iter.next();
        double s = stimulation(c, aInstance);
        if (s < min) {
            min = s;
        }
        if (s > max) {
            max = s;
        }
    }
    // normalise
    double range = (max - min);
    if (range == 0) {
        throw new RuntimeException("Infinite loop condition detected: range of stimulation values is zero.");
    }
    for (Iterator<Cell> iter = cells.iterator(); iter.hasNext(); ) {
        Cell c = iter.next();
        double normalised = (c.getStimulation() - min) / range;
        c.setStimulation(normalised);
        // validation
        if (normalised < 0 || normalised > 1) {
            throw new RuntimeException("Normalised stimulation outside range!");
        }
    }
}
protected LinkedList stimulation(LinkedList| cells, Instance aInstance) {
// calculate stimulation for all the cells
for (Cell c : cells) {
stimulation(c, aInstance);
}
// order the population by stimulation
Collections.sort(cells, CellPool.stimulationComparator);
return cells;
}
/**
 * Calculates the stimulation of a cell by an instance as one minus their
 * normalised affinity, and stores the value on the cell.
 *
 * @param aCell     the cell to stimulate (its stimulation is overwritten)
 * @param aInstance the stimulating instance
 * @return the stimulation that was stored on the cell
 */
protected double stimulation(Cell aCell, Instance aInstance) {
    // low affinity value means high stimulation
    final double stimulationValue = 1.0 - affinityFunction.affinityNormalised(aInstance, aCell);
    aCell.setStimulation(stimulationValue);
    return stimulationValue;
}
}
// end of AIRS2Trainer listing