All Downloads are FREE. Search and download functionality uses the official Maven repository.

com.enterprisemath.math.nn.SupervisedFFSHLNetworkEstimator Maven / Gradle / Ivy

The newest version!
package com.enterprisemath.math.nn;

import com.enterprisemath.math.statistics.EmptyEstimatorStepListener;
import com.enterprisemath.math.statistics.Estimator;
import com.enterprisemath.math.statistics.EstimatorStepListener;
import com.enterprisemath.math.statistics.observation.ObservationIterator;
import com.enterprisemath.math.statistics.observation.ObservationProvider;
import com.enterprisemath.utils.DomainUtils;
import com.enterprisemath.utils.ValidationUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Supervised estimator for a feed forward network with a single hidden layer.
 * <p>
 * All training records are read from the provided observations first. Then one candidate network
 * is trained for each of several hidden layer sizes. The candidate with the lowest sum of squared
 * errors, measured on the same training records, is returned. The candidate hidden layer sizes are
 * {@code k * (numInputs + numOutputs)} for {@code k = 1..5}. Here {@code numInputs} and
 * {@code numOutputs} are the numbers of distinct input and output keys seen across all records.
 * <p>
 * The step listener is notified every time a better candidate is found, and once more with the
 * final result.
 *
 * @author radek.hecl
 */
public class SupervisedFFSHLNetworkEstimator implements Estimator {

    /**
     * Number of hidden layer size candidates that are trained. The k-th candidate (1-based) uses
     * {@code k * (numInputs + numOutputs)} hidden nodes.
     */
    private static final int NUM_HIDDEN_LAYER_CANDIDATES = 5;

    /**
     * Builder object.
     */
    public static class Builder {

        /**
         * Step listener. Defaults to {@code EmptyEstimatorStepListener.create()}.
         */
        private EstimatorStepListener stepListener = EmptyEstimatorStepListener.create();

        /**
         * Sets step listener.
         *
         * @param stepListener step listener; must not be null (validated when the estimator is built)
         * @return this instance
         */
        public Builder setStepListener(EstimatorStepListener stepListener) {
            this.stepListener = stepListener;
            return this;
        }

        /**
         * Builds the result object.
         *
         * @return created object
         */
        public SupervisedFFSHLNetworkEstimator build() {
            return new SupervisedFFSHLNetworkEstimator(this);
        }
    }

    /**
     * Step listener which is notified about intermediate and final results.
     */
    private final EstimatorStepListener stepListener;

    /**
     * Creates new instance.
     *
     * @param builder builder object
     */
    public SupervisedFFSHLNetworkEstimator(Builder builder) {
        stepListener = builder.stepListener;
        guardInvariants();
    }

    /**
     * Guards this object to be consistent. Throws exception if this is not the case.
     */
    private void guardInvariants() {
        ValidationUtils.guardNotNull(stepListener, "stepListener cannot be null");
    }

    /**
     * Trains several candidate networks on the given data and returns the best one.
     * <p>
     * Every output value in the data is validated to lie in the interval [0, 1]. A candidate
     * replaces the current best one only when its sum of squared errors is strictly lower.
     * The listener is therefore called once per improvement, plus once at the end. As a result,
     * the final network may be reported twice.
     *
     * @param data training observations; every observation is expected to be a
     *             {@link SupervisedTrainingRecord}
     * @return network with the lowest training error; {@code null} only if every candidate
     *         produced a NaN or infinite error
     */
    @Override
    public Network estimate(ObservationProvider data) {
        List<SupervisedTrainingRecord> records = new ArrayList<SupervisedTrainingRecord>();
        Set<String> inputs = new HashSet<String>();
        Set<String> outputs = new HashSet<String>();
        ObservationIterator<SupervisedTrainingRecord> iterator = data.getIterator();
        while (iterator.isNextAvailable()) {
            SupervisedTrainingRecord rec = iterator.getNext();
            inputs.addAll(rec.getInputs().keySet());
            outputs.addAll(rec.getOutputs().keySet());
            for (Double val : rec.getOutputs().values()) {
                ValidationUtils.guardNotNegativeDouble(val, "output values must be in interval [0, 1]");
                ValidationUtils.guardGreaterOrEqualDouble(1d, val, "output values must be in interval [0, 1]");
            }
            records.add(rec);
        }
        List<String> inputList = DomainUtils.softCopyList(inputs);
        List<String> outputList = DomainUtils.softCopyList(outputs);

        // Hidden layer sizes are multiples of the total number of distinct input and output keys.
        int baseHiddenNodes = inputs.size() + outputs.size();
        // Passed to training as its error bound; presumably the error at which training may stop.
        // TODO confirm against BackpropagationUtils.train.
        double maxErr = Math.min(0.1, outputs.size() * 0.01);

        Network res = null;
        double bestErr = Double.POSITIVE_INFINITY;
        for (int multiplier = 1; multiplier <= NUM_HIDDEN_LAYER_CANDIDATES; ++multiplier) {
            int numHiddenNodes = baseHiddenNodes * multiplier;
            Network network = BackpropagationUtils.train(inputList, numHiddenNodes, outputList, records, 100000, 0.001, maxErr);
            double err = sumOfSquaredErrors(network, records);
            // Strict comparison: a NaN error never replaces the current best candidate.
            if (err < bestErr) {
                bestErr = err;
                res = network;
                stepListener.stepDone(res);
            }
        }
        stepListener.stepDone(res);
        return res;
    }

    /**
     * Computes the sum of squared differences between the network outputs and the target outputs
     * over all records.
     * <p>
     * Output keys are collected across all records, so a single record may lack a target for some
     * network output. Such outputs are skipped, because there is no value to compare against.
     *
     * @param network network to evaluate
     * @param records training records providing inputs and target outputs
     * @return sum of squared errors
     */
    private static double sumOfSquaredErrors(Network network, List<SupervisedTrainingRecord> records) {
        double sum = 0;
        for (SupervisedTrainingRecord rec : records) {
            Map<String, Double> targets = rec.getOutputs();
            Map<String, Double> predicted = network.process(rec.getInputs());
            for (Map.Entry<String, Double> entry : predicted.entrySet()) {
                Double target = targets.get(entry.getKey());
                if (target == null) {
                    continue;
                }
                double diff = entry.getValue() - target;
                sum += diff * diff;
            }
        }
        return sum;
    }

    /**
     * Creates new instance with default settings.
     *
     * @return created instance
     */
    public static SupervisedFFSHLNetworkEstimator create() {
        return new SupervisedFFSHLNetworkEstimator.Builder().
                build();
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy