All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.enterprisemath.math.prediction.NNCaterogyPredictorEstimator Maven / Gradle / Ivy

The newest version!
package com.enterprisemath.math.prediction;

import com.enterprisemath.math.nn.Network;
import com.enterprisemath.math.nn.SupervisedFFSHLNetworkEstimator;
import com.enterprisemath.math.nn.SupervisedTrainingRecord;
import com.enterprisemath.math.statistics.Estimator;
import com.enterprisemath.math.statistics.observation.ListObservationProvider;
import com.enterprisemath.math.statistics.observation.ObservationIterator;
import com.enterprisemath.math.statistics.observation.ObservationProvider;
import com.enterprisemath.utils.DomainUtils;
import com.enterprisemath.utils.ValidationUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Estimator for neural net category predictor.
 *
 * @author radek.hecl
 */
public class NNCaterogyPredictorEstimator implements Estimator {

    /**
     * Builder object.
     */
    public static class Builder {

        /**
         * Names of the columns.
         */
        private List columns = new ArrayList();

        /**
         * Sets columns names.
         *
         * @param columns column names
         * @return this instance
         */
        public Builder setColumns(List columns) {
            this.columns = DomainUtils.softCopyList(columns);
            return this;
        }

        /**
         * Ads column name.
         *
         * @param column column name
         * @return this instance
         */
        public Builder addColumn(String column) {
            columns.add(column);
            return this;
        }

        /**
         * Builds the result object.
         *
         * @return estimator
         */
        public NNCaterogyPredictorEstimator build() {
            return new NNCaterogyPredictorEstimator(this);
        }
    }

    /**
     * Name of the columns.
     */
    private List columns;

    /**
     * Creates new new object.
     *
     * @param builder builder object
     */
    public NNCaterogyPredictorEstimator(Builder builder) {
        this.columns = DomainUtils.softCopyUnmodifiableList(builder.columns);
        guardInvariants();
    }

    /**
     * Guards this object to be consistent. Throws exception if this is not the case.
     */
    private void guardInvariants() {
        ValidationUtils.guardNotEmptyStringInCollection(columns, "columns cannot have empty string");
    }

    @Override
    public CategoryPredictor estimate(ObservationProvider data) {
        ObservationIterator dataIt = null;

        // build the bank for input, outputs and observation transformer
        Map> categoryValues = new HashMap>();
        Set outputs = new HashSet();
        dataIt = data.getIterator();
        while (dataIt.isNextAvailable()) {
            DataCategoyPair rec = dataIt.getNext();
            outputs.add(rec.getCategory());
            for (int i = 0; i < rec.getData().size(); ++i) {
                Object obs = rec.getData().get(i);
                if (obs instanceof String) {
                    if (!categoryValues.containsKey(columns.get(i))) {
                        categoryValues.put(columns.get(i), new HashSet());
                    }
                    categoryValues.get(columns.get(i)).add((String) obs);
                }
            }
        }
        NNCategoryPredictorObesrationTrasformer transformer = new NNCategoryPredictorObesrationTrasformer.Builder().
                setColumns(columns).
                setCategoryValues(categoryValues).
                build();

        // transform records
        List nnrecords = new ArrayList();
        Map> nnOutputBank = createPossibleNNOutputBank(outputs);
        dataIt = data.getIterator();
        while (dataIt.isNextAvailable()) {
            DataCategoyPair rec = dataIt.getNext();
            Map input = transformer.transformInput(rec.getData());
            Map output = nnOutputBank.get(rec.getCategory());
            SupervisedTrainingRecord nnrec = SupervisedTrainingRecord.create(input, output);
            nnrecords.add(nnrec);
        }

        // estimate the network
        Estimator estimator = SupervisedFFSHLNetworkEstimator.create();
        Network network = estimator.estimate(ListObservationProvider.create(nnrecords));

        // build up the result
        return new NNCategoryPredictor.Builder().
                setNetwork(network).
                setObservationTransformer(transformer).
                build();
    }

    /**
     * Creates bank of possible outputs from neural network.
     *
     * @param outputs possible outputs
     * @return bank of outputs from NN, key is normal output, value is output returned by nn
     */
    private Map> createPossibleNNOutputBank(Set outputs) {
        Map> res = new HashMap>();
        for (String output : outputs) {
            res.put(output, new HashMap());
            for (String ops : outputs) {
                if (ops.equals(output)) {
                    res.get(output).put(ops, 1d);
                }
                else {
                    res.get(output).put(ops, 0d);
                }
            }
        }
        return res;
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy