com.github.chen0040.mlp.ann.classifiers.MLPClassifier Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of java-ann-mlp Show documentation
Show all versions of java-ann-mlp Show documentation
Multi-Layer Perceptron with BP learning implemented in Java
package com.github.chen0040.mlp.ann.classifiers;
import com.github.chen0040.data.frame.DataFrame;
import com.github.chen0040.data.frame.DataRow;
import lombok.Getter;
import lombok.Setter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
/**
 * Multi-class classifier backed by a multi-layer perceptron with BP learning.
 *
 * <p>Configure the topology with {@link #setHiddenLayers(int...)} and the training
 * hyper-parameters with {@code setEpoches} / {@code setLearningRate} (Lombok-generated),
 * train with {@link #fit(DataFrame)}, then predict with {@link #classify(DataRow)}.
 *
 * <p>Created by xschen on 21/8/15.
 */
public class MLPClassifier implements Cloneable {

    public static final String HIDDEN_LAYER1 = "hiddenLayer1";
    public static final String HIDDEN_LAYER2 = "hiddenLayer2";
    public static final String HIDDEN_LAYER3 = "hiddenLayer3";
    public static final String HIDDEN_LAYER4 = "hiddenLayer4";
    public static final String HIDDEN_LAYER5 = "hiddenLayer5";
    public static final String HIDDEN_LAYER6 = "hiddenLayer6";
    public static final String HIDDEN_LAYER7 = "hiddenLayer7";

    /** Attribute names of the supported hidden layers, indexed by layer position. */
    private static final String[] HIDDEN_LAYER_NAMES = {
        HIDDEN_LAYER1, HIDDEN_LAYER2, HIDDEN_LAYER3, HIDDEN_LAYER4,
        HIDDEN_LAYER5, HIDDEN_LAYER6, HIDDEN_LAYER7
    };

    private static final Logger logger = LoggerFactory.getLogger(MLPClassifier.class);

    /** The underlying network; {@code null} until {@link #fit(DataFrame)} or {@link #copy} sets it. */
    private MLPWithLabelOutput mlp;

    /** Known class labels; the position of a label is its output index in {@link #classify}. */
    private List<String> classLabels = new ArrayList<>();

    /** Number of training epochs passed to the network's {@code train} call. */
    @Getter
    @Setter
    private int epoches = 1000;

    /** Learning rate handed to the network when it is built in {@link #fit(DataFrame)}. */
    @Getter
    @Setter
    private double learningRate = 0.2;

    /** Neuron count per hidden layer, keyed by the {@code HIDDEN_LAYER*} attribute names. */
    private Map<String, Integer> hiddenLayer = new HashMap<>();

    /**
     * Creates a classifier with a single hidden layer of 6 neurons.
     */
    public MLPClassifier() {
        // Written directly instead of via setHiddenLayers(6) so that the constructor
        // does not invoke an overridable method.
        hiddenLayer.put(HIDDEN_LAYER1, 6);
    }

    /**
     * Copies the state of {@code rhs2} into this classifier.
     *
     * <p>The network is cloned, and the class labels, hidden-layer configuration, epoch count
     * and learning rate are copied into fresh containers, so this instance and {@code rhs2}
     * do not share mutable state afterwards.
     *
     * @param rhs2 the classifier to copy from
     * @throws CloneNotSupportedException if the underlying network cannot be cloned
     */
    public void copy(MLPClassifier rhs2) throws CloneNotSupportedException {
        // Clone the network first: if this throws, this instance is left unmodified.
        MLPWithLabelOutput clonedNetwork =
            rhs2.mlp == null ? null : (MLPWithLabelOutput) rhs2.mlp.clone();

        // New containers rather than clear()/addAll(): during clone() the fields of this
        // instance still alias those of rhs2, and clearing them would wipe the source.
        classLabels = new ArrayList<>(rhs2.classLabels);
        hiddenLayer = new HashMap<>(rhs2.hiddenLayer);
        epoches = rhs2.epoches;
        learningRate = rhs2.learningRate;

        mlp = clonedNetwork;
        if (mlp != null) {
            // Re-bind the label supplier so the cloned network reads this instance's labels.
            mlp.classLabelsModel = this::getClassLabels;
        }
    }

    /**
     * Returns the class labels known to this classifier.
     *
     * <p>The returned list is the live internal list, not a copy. {@link #fit(DataFrame)}
     * only scans the training data for labels when this list is empty.
     *
     * @return the class labels, in output-index order
     */
    public List<String> getClassLabels() {
        return classLabels;
    }

    /**
     * Creates an independent copy of this classifier.
     *
     * @return a clone that shares no mutable collections with this instance
     * @throws CloneNotSupportedException if the underlying network cannot be cloned
     */
    @Override
    public Object clone() throws CloneNotSupportedException {
        MLPClassifier clone = (MLPClassifier) super.clone();
        clone.copy(this);
        return clone;
    }

    /**
     * Returns the configured hidden-layer sizes.
     *
     * @return the neuron counts of all hidden layers with a positive size, in layer order
     */
    public List<Integer> getHiddenLayers() {
        return parseHiddenLayers();
    }

    /** Maps a zero-based layer position to its attribute name. */
    private String hiddenLayerName(int i) {
        return HIDDEN_LAYER_NAMES[i];
    }

    /**
     * Replaces the hidden-layer configuration.
     *
     * <p>Any previously configured layers are discarded, so the topology afterwards is
     * exactly the one given here.
     *
     * @param hiddenLayers neuron count for each hidden layer, first layer first; at most 7
     *     values. Non-positive counts are skipped when the network is built.
     * @throws IllegalArgumentException if more than 7 layers are given
     */
    public void setHiddenLayers(int... hiddenLayers) {
        if (hiddenLayers.length > HIDDEN_LAYER_NAMES.length) {
            throw new IllegalArgumentException(
                "at most " + HIDDEN_LAYER_NAMES.length + " hidden layers are supported, got "
                    + hiddenLayers.length);
        }
        // Drop stale entries; otherwise a shorter configuration would keep the tail
        // of a longer one set earlier.
        hiddenLayer.clear();
        for (int i = 0; i < hiddenLayers.length; ++i) {
            hiddenLayer.put(hiddenLayerName(i), hiddenLayers[i]);
        }
    }

    /**
     * Predicts the class label of a single row.
     *
     * <p>The row is passed through the network and the label at the index of the largest
     * output value is returned.
     *
     * @param tuple the row to classify
     * @return the predicted label, or {@code null} if no label could be selected (the network
     *     produced no usable output, or the selected index has no matching label)
     * @throws IllegalStateException if the classifier has not been trained yet
     */
    public String classify(DataRow tuple) {
        if (mlp == null) {
            throw new IllegalStateException("classify called before fit: no trained network available");
        }

        double[] target = mlp.transform(tuple.toArray());

        int selectedIndex = -1;
        double maxValue = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < target.length; ++i) {
            double value = target[i];
            // NaN never compares greater, so an all-NaN output leaves selectedIndex at -1.
            if (value > maxValue) {
                maxValue = value;
                selectedIndex = i;
            }
        }

        List<String> labels = getClassLabels();
        if (selectedIndex < 0 || selectedIndex >= labels.size()) {
            logger.error("transform failed due to label not found");
            return null;
        }
        return labels.get(selectedIndex);
    }

    /**
     * Rebuilds {@link #classLabels} from the distinct categorical targets found in
     * {@code batch}. Rows without a categorical target column are ignored.
     */
    private void scan4ClassLabels(DataFrame batch) {
        int m = batch.rowCount();
        Set<String> distinctLabels = new HashSet<>();
        for (int i = 0; i < m; ++i) {
            DataRow tuple = batch.row(i);
            if (!tuple.getCategoricalTargetColumnNames().isEmpty()) {
                distinctLabels.add(tuple.categoricalTarget());
            }
        }
        classLabels.clear();
        classLabels.addAll(distinctLabels);
    }

    /** Collects the positive neuron counts of all seven layer slots, in layer order. */
    private List<Integer> parseHiddenLayers() {
        List<Integer> hiddenLayers = new ArrayList<>();
        for (int i = 0; i < HIDDEN_LAYER_NAMES.length; ++i) {
            int neuronCount = getAttribute(hiddenLayerName(i));
            if (neuronCount > 0) {
                hiddenLayers.add(neuronCount);
            }
        }
        return hiddenLayers;
    }

    /** Returns the neuron count stored for {@code layerName}, or 0 if none is configured. */
    private int getAttribute(String layerName) {
        return hiddenLayer.getOrDefault(layerName, 0);
    }

    /**
     * Builds a new network from the current configuration and trains it on {@code batch}.
     *
     * <p>If no class labels are known yet they are collected from the categorical targets of
     * {@code batch}; already-present labels are kept as they are. The input dimension is taken
     * from the first row of {@code batch}.
     *
     * @param batch the training data; must contain at least one row
     * @throws IllegalArgumentException if {@code batch} has no rows
     */
    public void fit(DataFrame batch) {
        if (batch.rowCount() == 0) {
            // Fail before touching any state; row(0) below would otherwise blow up
            // after the previous network had already been replaced.
            throw new IllegalArgumentException("cannot fit on an empty data frame");
        }

        if (getClassLabels().isEmpty()) {
            scan4ClassLabels(batch);
        }
        logger.info("class labels: {}", classLabels.size());

        mlp = new MLPWithLabelOutput();
        mlp.classLabelsModel = this::getClassLabels;

        int dimension = batch.row(0).toArray().length;
        List<Integer> hiddenLayers = parseHiddenLayers();

        mlp.setLearningRate(learningRate);
        mlp.createInputLayer(dimension);
        for (int hiddenLayerNeuronCount : hiddenLayers) {
            mlp.addHiddenLayer(hiddenLayerNeuronCount);
        }
        mlp.createOutputLayer(getClassLabels().size());

        mlp.train(batch, epoches);
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy