![JAR search and dependency download from the Maven repository](/logo.png)
com.enterprisemath.math.prediction.NNCaterogyPredictorEstimator Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of em-math Show documentation
Show all versions of em-math Show documentation
Advanced mathematical algorithms.
The newest version!
package com.enterprisemath.math.prediction;
import com.enterprisemath.math.nn.Network;
import com.enterprisemath.math.nn.SupervisedFFSHLNetworkEstimator;
import com.enterprisemath.math.nn.SupervisedTrainingRecord;
import com.enterprisemath.math.statistics.Estimator;
import com.enterprisemath.math.statistics.observation.ListObservationProvider;
import com.enterprisemath.math.statistics.observation.ObservationIterator;
import com.enterprisemath.math.statistics.observation.ObservationProvider;
import com.enterprisemath.utils.DomainUtils;
import com.enterprisemath.utils.ValidationUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* Estimator for neural net category predictor.
*
* @author radek.hecl
*/
public class NNCaterogyPredictorEstimator implements Estimator {
/**
* Builder object.
*/
public static class Builder {
/**
* Names of the columns.
*/
private List columns = new ArrayList();
/**
* Sets columns names.
*
* @param columns column names
* @return this instance
*/
public Builder setColumns(List columns) {
this.columns = DomainUtils.softCopyList(columns);
return this;
}
/**
* Ads column name.
*
* @param column column name
* @return this instance
*/
public Builder addColumn(String column) {
columns.add(column);
return this;
}
/**
* Builds the result object.
*
* @return estimator
*/
public NNCaterogyPredictorEstimator build() {
return new NNCaterogyPredictorEstimator(this);
}
}
/**
* Name of the columns.
*/
private List columns;
/**
* Creates new new object.
*
* @param builder builder object
*/
public NNCaterogyPredictorEstimator(Builder builder) {
this.columns = DomainUtils.softCopyUnmodifiableList(builder.columns);
guardInvariants();
}
/**
* Guards this object to be consistent. Throws exception if this is not the case.
*/
private void guardInvariants() {
ValidationUtils.guardNotEmptyStringInCollection(columns, "columns cannot have empty string");
}
@Override
public CategoryPredictor estimate(ObservationProvider data) {
ObservationIterator dataIt = null;
// build the bank for input, outputs and observation transformer
Map> categoryValues = new HashMap>();
Set outputs = new HashSet();
dataIt = data.getIterator();
while (dataIt.isNextAvailable()) {
DataCategoyPair rec = dataIt.getNext();
outputs.add(rec.getCategory());
for (int i = 0; i < rec.getData().size(); ++i) {
Object obs = rec.getData().get(i);
if (obs instanceof String) {
if (!categoryValues.containsKey(columns.get(i))) {
categoryValues.put(columns.get(i), new HashSet());
}
categoryValues.get(columns.get(i)).add((String) obs);
}
}
}
NNCategoryPredictorObesrationTrasformer transformer = new NNCategoryPredictorObesrationTrasformer.Builder().
setColumns(columns).
setCategoryValues(categoryValues).
build();
// transform records
List nnrecords = new ArrayList();
Map> nnOutputBank = createPossibleNNOutputBank(outputs);
dataIt = data.getIterator();
while (dataIt.isNextAvailable()) {
DataCategoyPair rec = dataIt.getNext();
Map input = transformer.transformInput(rec.getData());
Map output = nnOutputBank.get(rec.getCategory());
SupervisedTrainingRecord nnrec = SupervisedTrainingRecord.create(input, output);
nnrecords.add(nnrec);
}
// estimate the network
Estimator estimator = SupervisedFFSHLNetworkEstimator.create();
Network network = estimator.estimate(ListObservationProvider.create(nnrecords));
// build up the result
return new NNCategoryPredictor.Builder().
setNetwork(network).
setObservationTransformer(transformer).
build();
}
/**
* Creates bank of possible outputs from neural network.
*
* @param outputs possible outputs
* @return bank of outputs from NN, key is normal output, value is output returned by nn
*/
private Map> createPossibleNNOutputBank(Set outputs) {
Map> res = new HashMap>();
for (String output : outputs) {
res.put(output, new HashMap());
for (String ops : outputs) {
if (ops.equals(output)) {
res.get(output).put(ops, 1d);
}
else {
res.get(output).put(ops, 0d);
}
}
}
return res;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy