com.github.wihoho.training.LDA Maven / Gradle / Ivy
package com.github.wihoho.training;
import com.github.wihoho.jama.EigenvalueDecomposition;
import com.github.wihoho.jama.Matrix;
import java.util.*;
public class LDA extends FeatureExtraction {
public LDA(ArrayList trainingSet, ArrayList labels,
int numOfComponents) throws Exception {
int n = trainingSet.size(); // sample size
Set tempSet = new HashSet(labels);
int c = tempSet.size(); // class size
assert numOfComponents >= n - c : "the input components is smaller than n - c!";
assert n >= 2 * c : "n is smaller than 2c!";
// process in PCA
PCA pca = new PCA(trainingSet, labels, n - c);
// classify
Matrix meanTotal = new Matrix(n - c, 1);
HashMap> map = new HashMap>();
ArrayList pcaTrain = pca
.getProjectedTrainingSet();
for (int i = 0; i < pcaTrain.size(); i++) {
String key = pcaTrain.get(i).label;
meanTotal.plusEquals(pcaTrain.get(i).matrix);
if (!map.containsKey(key)) {
ArrayList temp = new ArrayList();
temp.add(pcaTrain.get(i).matrix);
map.put(key, temp);
} else {
ArrayList temp = map.get(key);
temp.add(pcaTrain.get(i).matrix);
map.put(key, temp);
}
}
meanTotal.times((double) 1 / n);
// calculate Sw, Sb
Matrix Sw = new Matrix(n - c, n - c);
Matrix Sb = new Matrix(n - c, n - c);
tempSet = map.keySet();
Iterator it = tempSet.iterator();
while (it.hasNext()) {
String s = (String) it.next();
ArrayList matrixWithinThatClass = map.get(s);
Matrix meanOfCurrentClass = getMean(matrixWithinThatClass);
for (int i = 0; i < matrixWithinThatClass.size(); i++) {
Matrix temp1 = matrixWithinThatClass.get(i).minus(
meanOfCurrentClass);
temp1 = temp1.times(temp1.transpose());
Sw.plusEquals(temp1);
}
Matrix temp = meanOfCurrentClass.minus(meanTotal);
temp = temp.times(temp.transpose()).times(
matrixWithinThatClass.size());
Sb.plusEquals(temp);
}
// calculate the eigenvalues and vectors of Sw^-1 * Sb
Matrix targetForEigen = Sw.inverse().times(Sb);
EigenvalueDecomposition feature = targetForEigen.eig();
double[] d = feature.getd();
assert d.length >= c - 1 : "Ensure that the number of eigenvalues is larger than c - 1";
int[] indexes = getIndexesOfKEigenvalues(d, c - 1);
Matrix eigenVectors = feature.getV();
Matrix selectedEigenVectors = eigenVectors.getMatrix(0,
eigenVectors.getRowDimension() - 1, indexes);
this.W = pca.getW().times(selectedEigenVectors);
// Construct projectedTrainingMatrix
this.projectedTrainingSet = new ArrayList();
for (int i = 0; i < trainingSet.size(); i++) {
ProjectedTrainingMatrix ptm = new ProjectedTrainingMatrix(this.W
.transpose()
.times(trainingSet.get(i).minus(pca.meanMatrix)),
labels.get(i));
this.projectedTrainingSet.add(ptm);
}
this.meanMatrix = pca.meanMatrix;
}
private class mix implements Comparable {
int index;
double value;
mix(int i, double v) {
index = i;
value = v;
}
public int compareTo(Object o) {
double target = ((mix) o).value;
if (value > target)
return -1;
else if (value < target)
return 1;
return 0;
}
}
private int[] getIndexesOfKEigenvalues(double[] d, int k) {
mix[] mixes = new mix[d.length];
int i;
for (i = 0; i < d.length; i++)
mixes[i] = new mix(i, d[i]);
Arrays.sort(mixes);
int[] result = new int[k];
for (i = 0; i < k; i++)
result[i] = mixes[i].index;
return result;
}
static Matrix getMean(ArrayList m) {
int num = m.size();
int row = m.get(0).getRowDimension();
int column = m.get(0).getColumnDimension();
assert column == 1 : "expected column does not equal to 1!";
Matrix mean = new Matrix(row, column);
for (int i = 0; i < num; i++) {
mean.plusEquals(m.get(i));
}
mean = mean.times((double) 1 / num);
return mean;
}
@Override
public Matrix getW() {
return this.W;
}
@Override
public ArrayList getProjectedTrainingSet() {
return this.projectedTrainingSet;
}
@Override
public Matrix getMeanMatrix() {
// TODO Auto-generated method stub
return meanMatrix;
}
}