// com.github.wihoho.Trainer (Maven / Gradle / Ivy artifact)
package com.github.wihoho;
import com.github.wihoho.constant.FeatureType;
import com.github.wihoho.jama.Matrix;
import com.github.wihoho.training.*;
import com.google.common.base.Preconditions;
import lombok.experimental.Builder;
import java.util.ArrayList;
import java.util.Objects;
/**
 * Trains a recognition model from labelled sample matrices and classifies new matrices.
 *
 * <p>Training projects every sample with the {@link FeatureExtraction} selected by
 * {@link #featureType} (PCA, LDA or LPP). Recognition projects the query the same way and
 * assigns a label with {@link KNN#assignLabel} over the projected training samples.
 *
 * <p>Typical use: build an instance, call {@link #add(Matrix, String)} once per sample, call
 * {@link #train()}, then call {@link #recognize(Matrix)}. This class performs no
 * synchronization.
 */
@Builder
public class Trainer {
    /** Distance metric passed to the k-nearest-neighbour search in {@link #recognize(Matrix)}. */
    Metric metric;
    /** Selects which feature extraction implementation {@link #train()} instantiates. */
    FeatureType featureType;
    /** Extractor created by {@link #train()}; supplies the projection matrix and the mean. */
    FeatureExtraction featureExtraction;
    /** Number of components passed to the feature extraction; must be positive. */
    int numberOfComponents;
    /** Number of nearest neighbours to consider when assigning a label; must be positive. */
    int k;
    /** Training samples, index-aligned with {@link #trainingLabels}. */
    ArrayList<Matrix> trainingSet;
    /** Labels of the training samples, index-aligned with {@link #trainingSet}. */
    ArrayList<String> trainingLabels;
    /** Projected training samples produced by {@link #train()}; {@code null} until then. */
    ArrayList<ProjectedTrainingMatrix> model;

    /**
     * Appends one labelled training sample, creating the backing lists on first use.
     *
     * @param matrix the training sample; must not be {@code null}
     * @param label  the label associated with {@code matrix}
     * @throws NullPointerException if {@code matrix} is {@code null}
     */
    public void add(Matrix matrix, String label) {
        Objects.requireNonNull(matrix, "matrix");
        // Initialize each list independently: the builder may have supplied only one of them.
        if (Objects.isNull(trainingSet)) {
            trainingSet = new ArrayList<>();
        }
        if (Objects.isNull(trainingLabels)) {
            trainingLabels = new ArrayList<>();
        }
        trainingSet.add(matrix);
        trainingLabels.add(label);
    }

    /**
     * Builds the feature extraction selected by {@link #featureType} from the collected samples
     * and stores the projected training set as the model used by {@link #recognize(Matrix)}.
     *
     * @throws NullPointerException  if {@code metric}, {@code featureType}, or the training data
     *                               has not been set
     * @throws IllegalStateException if {@code numberOfComponents} is not positive, no samples
     *                               were added, the number of samples and labels differ, or
     *                               {@code featureType} is not supported
     * @throws Exception             propagated from the feature extraction constructors
     */
    public void train() throws Exception {
        Preconditions.checkNotNull(metric, "metric must be set");
        Preconditions.checkNotNull(featureType, "featureType must be set");
        // A primitive int can never be null; the meaningful "was it configured" check is > 0.
        Preconditions.checkState(numberOfComponents > 0,
                "numberOfComponents must be positive, was %s", numberOfComponents);
        Preconditions.checkNotNull(trainingSet, "no training samples have been added");
        Preconditions.checkNotNull(trainingLabels, "no training labels have been added");
        Preconditions.checkState(!trainingSet.isEmpty(), "no training samples have been added");
        Preconditions.checkState(trainingSet.size() == trainingLabels.size(),
                "trainingSet has %s samples but trainingLabels has %s labels",
                trainingSet.size(), trainingLabels.size());

        switch (featureType) {
            case PCA:
                featureExtraction = new PCA(trainingSet, trainingLabels, numberOfComponents);
                break;
            case LDA:
                featureExtraction = new LDA(trainingSet, trainingLabels, numberOfComponents);
                break;
            case LPP:
                featureExtraction = new LPP(trainingSet, trainingLabels, numberOfComponents);
                break;
            default:
                // Fail fast instead of reusing a stale or null extractor below.
                throw new IllegalStateException("Unsupported feature type: " + featureType);
        }
        model = featureExtraction.getProjectedTrainingSet();
    }

    /**
     * Classifies {@code matrix} against the trained model.
     *
     * <p>The input is centred by subtracting the extractor's mean matrix and projected with the
     * transposed projection matrix {@code W}. It is then labelled by {@link KNN#assignLabel} using
     * {@link #k} neighbours and {@link #metric}.
     *
     * @param matrix the sample to classify; must not be {@code null}
     * @return the label chosen by the k-nearest-neighbour search
     * @throws NullPointerException  if {@code matrix} is {@code null}
     * @throws IllegalStateException if the model has not been trained or {@code k} is not positive
     */
    public String recognize(Matrix matrix) {
        Preconditions.checkNotNull(matrix, "matrix");
        Preconditions.checkState(featureExtraction != null && model != null,
                "train() must be called before recognize()");
        Preconditions.checkState(k > 0, "k must be positive, was %s", k);
        Matrix centered = matrix.minus(featureExtraction.getMeanMatrix());
        Matrix testCase = featureExtraction.getW().transpose().times(centered);
        return KNN.assignLabel(model.toArray(new ProjectedTrainingMatrix[0]), testCase, k, metric);
    }
}