com.enterprisemath.math.nn.SupervisedFFSHLNetworkEstimator Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of em-math Show documentation
Advanced mathematical algorithms.
The newest version!
package com.enterprisemath.math.nn;
import com.enterprisemath.math.statistics.EmptyEstimatorStepListener;
import com.enterprisemath.math.statistics.Estimator;
import com.enterprisemath.math.statistics.EstimatorStepListener;
import com.enterprisemath.math.statistics.observation.ObservationIterator;
import com.enterprisemath.math.statistics.observation.ObservationProvider;
import com.enterprisemath.utils.DomainUtils;
import com.enterprisemath.utils.ValidationUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Supervised estimator for a feed-forward network with a single hidden layer.
 * <p>
 * Several candidate networks are trained. They differ only in the size of the hidden layer,
 * which is a multiple of (number of inputs + number of outputs). The candidate with the lowest
 * sum of squared errors over the training records is returned.
 *
 * @author radek.hecl
 */
public class SupervisedFFSHLNetworkEstimator implements Estimator<SupervisedTrainingRecord, Network> {

    /**
     * Largest multiplier of (number of inputs + number of outputs) tried as the hidden layer size.
     * Candidates use the multipliers 1, 2, ..., this value.
     */
    private static final int MAX_HIDDEN_NODES_MULTIPLIER = 5;

    /**
     * Builder object.
     */
    public static class Builder {

        /**
         * Step listener. Defaults to {@code EmptyEstimatorStepListener.create()}.
         */
        private EstimatorStepListener<Network> stepListener = EmptyEstimatorStepListener.create();

        /**
         * Sets the step listener, which receives the intermediate and final networks.
         *
         * @param stepListener step listener, must not be null (checked when the estimator is built)
         * @return this instance
         */
        public Builder setStepListener(EstimatorStepListener<Network> stepListener) {
            this.stepListener = stepListener;
            return this;
        }

        /**
         * Builds the result object.
         *
         * @return created object
         */
        public SupervisedFFSHLNetworkEstimator build() {
            return new SupervisedFFSHLNetworkEstimator(this);
        }
    }

    /**
     * Step listener.
     */
    private final EstimatorStepListener<Network> stepListener;

    /**
     * Creates new instance.
     * Validation is delegated to {@code ValidationUtils}, which rejects a null step listener.
     *
     * @param builder builder object
     */
    public SupervisedFFSHLNetworkEstimator(Builder builder) {
        stepListener = builder.stepListener;
        guardInvariants();
    }

    /**
     * Guards this object to be consistent. Throws exception if this is not the case.
     */
    private void guardInvariants() {
        ValidationUtils.guardNotNull(stepListener, "stepListener cannot be null");
    }

    /**
     * Estimates the network from the given training records.
     * <p>
     * All records are first loaded into memory. Input and output names are collected as the
     * union over all records. Every output value must lie in [0, 1]; this is checked via
     * {@code ValidationUtils}.
     * <p>
     * One network is then trained for each candidate hidden layer size. The step listener is
     * notified each time a candidate improves on the best error so far. It is notified once more
     * with the final result, which repeats the last improvement.
     *
     * @param data training data
     * @return network with the lowest sum of squared errors on the training records
     * @throws IllegalArgumentException if {@code data} provides no records
     */
    @Override
    public Network estimate(ObservationProvider<SupervisedTrainingRecord> data) {
        List<SupervisedTrainingRecord> records = new ArrayList<SupervisedTrainingRecord>();
        Set<String> inputs = new HashSet<String>();
        Set<String> outputs = new HashSet<String>();
        ObservationIterator<SupervisedTrainingRecord> iterator = data.getIterator();
        while (iterator.isNextAvailable()) {
            SupervisedTrainingRecord rec = iterator.getNext();
            inputs.addAll(rec.getInputs().keySet());
            outputs.addAll(rec.getOutputs().keySet());
            for (Double val : rec.getOutputs().values()) {
                ValidationUtils.guardNotNegativeDouble(val, "output values must be in interval [0, 1]");
                ValidationUtils.guardGreaterOrEqualDouble(1d, val, "output values must be in interval [0, 1]");
            }
            records.add(rec);
        }
        // Without records every candidate would have 0 hidden nodes and a zero error limit.
        if (records.isEmpty()) {
            throw new IllegalArgumentException("data must contain at least one record");
        }
        List<String> inputList = DomainUtils.softCopyList(inputs);
        List<String> outputList = DomainUtils.softCopyList(outputs);
        int baseHiddenNodes = inputs.size() + outputs.size();
        // Error limit handed to training: 0.01 per output, capped at 0.1.
        double maxErr = Math.min(0.1, outputs.size() * 0.01);

        Network res = null;
        double bestErr = Double.POSITIVE_INFINITY;
        for (int multiplier = 1; multiplier <= MAX_HIDDEN_NODES_MULTIPLIER; ++multiplier) {
            int numHiddenNodes = baseHiddenNodes * multiplier;
            // NOTE(review): 100000 and 0.001 look like the iteration limit and step size of
            // BackpropagationUtils.train - confirm against that method before tuning.
            Network network = BackpropagationUtils.train(inputList, numHiddenNodes, outputList, records,
                    100000, 0.001, maxErr);
            double err = sumOfSquaredErrors(network, records);
            if (err < bestErr) {
                bestErr = err;
                res = network;
                stepListener.stepDone(res);
            }
        }
        stepListener.stepDone(res);
        return res;
    }

    /**
     * Computes the sum of squared differences between the network outputs and the expected
     * outputs over all records.
     * <p>
     * Only outputs present in a record are compared. Output names are collected as a union over
     * all records, so a single record may lack some of the outputs the network produces.
     * Looking those up in the record would unbox a null.
     *
     * @param network network to evaluate
     * @param records records with inputs and expected outputs
     * @return sum of squared errors
     */
    private static double sumOfSquaredErrors(Network network, List<SupervisedTrainingRecord> records) {
        double sum = 0;
        for (SupervisedTrainingRecord rec : records) {
            Map<String, Double> networkRes = network.process(rec.getInputs());
            for (Map.Entry<String, Double> expected : rec.getOutputs().entrySet()) {
                double diff = networkRes.get(expected.getKey()) - expected.getValue();
                sum += diff * diff;
            }
        }
        return sum;
    }

    /**
     * Creates new instance with the default (empty) step listener.
     *
     * @return created instance
     */
    public static SupervisedFFSHLNetworkEstimator create() {
        return new SupervisedFFSHLNetworkEstimator.Builder().
                build();
    }
}