All downloads are free. Search and download functionality uses the official Maven repository.

com.enterprisemath.math.nn.FFSHLNetwork Maven / Gradle / Ivy

The newest version!
package com.enterprisemath.math.nn;

import com.enterprisemath.utils.DomainUtils;
import com.enterprisemath.utils.ValidationUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.builder.EqualsBuilder;
import org.apache.commons.lang3.builder.HashCodeBuilder;
import org.apache.commons.lang3.builder.ToStringBuilder;

/**
 * Feed forward network with single hidden layer.
 * <p>
 * Evaluation ({@link #process(Map)}) runs in three stages:
 * <ol>
 * <li>every input neuron transforms its raw input value;</li>
 * <li>every hidden neuron transforms the weighted sum of input neuron values connected to it;</li>
 * <li>every output neuron transforms the weighted sum of input and hidden neuron values connected to it.</li>
 * </ol>
 * Synapses that end in an input neuron are accepted but are not used by {@link #process(Map)}.
 * Synapses that would make evaluation impossible are rejected at construction time. These are
 * synapses into a hidden neuron that do not start in an input neuron, and synapses between two
 * output neurons.
 *
 * @author radek.hecl
 */
public class FFSHLNetwork implements Network {

    /**
     * Builder object.
     */
    public static class Builder {

        /**
         * Input neurons.
         */
        private List<Neuron> inputs = new ArrayList<Neuron>();

        /**
         * Hidden neurons.
         */
        private List<Neuron> hiddens = new ArrayList<Neuron>();

        /**
         * Output neurons.
         */
        private List<Neuron> outputs = new ArrayList<Neuron>();

        /**
         * Synapses between neurons.
         */
        private List<Synapse> synapses = new ArrayList<Synapse>();

        /**
         * Sets input neurons, replacing any previously set or added ones.
         *
         * @param inputs input neurons
         * @return this instance
         */
        public Builder setInputs(List<Neuron> inputs) {
            this.inputs = DomainUtils.softCopyList(inputs);
            return this;
        }

        /**
         * Adds input neuron.
         *
         * @param neuron input neuron
         * @return this instance
         */
        public Builder addInput(Neuron neuron) {
            inputs.add(neuron);
            return this;
        }

        /**
         * Sets hidden neurons, replacing any previously set or added ones.
         *
         * @param hiddens hidden neurons
         * @return this instance
         */
        public Builder setHiddens(List<Neuron> hiddens) {
            this.hiddens = DomainUtils.softCopyList(hiddens);
            return this;
        }

        /**
         * Adds hidden neuron.
         *
         * @param neuron hidden neuron
         * @return this instance
         */
        public Builder addHidden(Neuron neuron) {
            hiddens.add(neuron);
            return this;
        }

        /**
         * Sets output neurons, replacing any previously set or added ones.
         *
         * @param outputs output neurons
         * @return this instance
         */
        public Builder setOutputs(List<Neuron> outputs) {
            this.outputs = DomainUtils.softCopyList(outputs);
            return this;
        }

        /**
         * Adds output neuron.
         *
         * @param neuron output neuron
         * @return this instance
         */
        public Builder addOutput(Neuron neuron) {
            outputs.add(neuron);
            return this;
        }

        /**
         * Sets synapses, replacing any previously set or added ones.
         *
         * @param synapses synapses
         * @return this instance
         */
        public Builder setSynapses(List<Synapse> synapses) {
            this.synapses = DomainUtils.softCopyList(synapses);
            return this;
        }

        /**
         * Adds synapse.
         *
         * @param synapse synapse
         * @return this instance
         */
        public Builder addSynapse(Synapse synapse) {
            synapses.add(synapse);
            return this;
        }

        /**
         * Builds the result object.
         *
         * @return created object
         * @throws RuntimeException if the configured network is inconsistent
         */
        public FFSHLNetwork build() {
            return new FFSHLNetwork(this);
        }
    }

    /**
     * Input neurons.
     */
    private final List<Neuron> inputs;

    /**
     * Hidden neurons.
     */
    private final List<Neuron> hiddens;

    /**
     * Output neurons.
     */
    private final List<Neuron> outputs;

    /**
     * Synapses between neurons.
     */
    private final List<Synapse> synapses;

    /**
     * Creates new instance.
     *
     * @param builder builder object
     * @throws RuntimeException if the network is inconsistent (null elements, duplicated neuron ids,
     *         synapses referencing unknown neurons or violating the feed forward topology)
     */
    public FFSHLNetwork(Builder builder) {
        inputs = DomainUtils.softCopyUnmodifiableList(builder.inputs);
        hiddens = DomainUtils.softCopyUnmodifiableList(builder.hiddens);
        outputs = DomainUtils.softCopyUnmodifiableList(builder.outputs);
        synapses = DomainUtils.softCopyUnmodifiableList(builder.synapses);
        guardInvariants();
    }

    /**
     * Guards this object to be consistent. Throws exception if this is not the case.
     */
    private void guardInvariants() {
        ValidationUtils.guardNotNullCollection(inputs, "inputs cannot have null element");
        ValidationUtils.guardNotNullCollection(hiddens, "hiddens cannot have null element");
        ValidationUtils.guardNotNullCollection(outputs, "outputs cannot have null element");
        ValidationUtils.guardNotNullCollection(synapses, "synapses cannot have null element");
        // ids must be unique across all three layers, hence one shared set
        Set<String> ids = new HashSet<String>();
        Set<String> inputIds = collectUniqueIds(inputs, ids);
        Set<String> hiddenIds = collectUniqueIds(hiddens, ids);
        Set<String> outputIds = collectUniqueIds(outputs, ids);
        for (Synapse synapse : synapses) {
            ValidationUtils.guardIn(synapse.getStartId(), ids, "startId doesnt match neuron");
            ValidationUtils.guardIn(synapse.getEndId(), ids, "endId doesnt match neuron");
            // process() evaluates hidden neurons from input values only; any other source has no value yet
            if (hiddenIds.contains(synapse.getEndId()) && !inputIds.contains(synapse.getStartId())) {
                throw new IllegalArgumentException("synapse into hidden neuron must start in input neuron: "
                        + synapse.getStartId() + " -> " + synapse.getEndId());
            }
            // output values are computed last, so an output neuron cannot feed another output neuron
            if (outputIds.contains(synapse.getEndId()) && outputIds.contains(synapse.getStartId())) {
                throw new IllegalArgumentException("synapse into output neuron must start in input or hidden neuron: "
                        + synapse.getStartId() + " -> " + synapse.getEndId());
            }
        }
    }

    /**
     * Collects ids of the given neurons and registers them in the set of all ids seen so far.
     *
     * @param neurons neurons of one layer
     * @param allIds ids registered so far across all layers; updated in place
     * @return ids of the given neurons
     * @throws RuntimeException if an id was already registered
     */
    private static Set<String> collectUniqueIds(List<Neuron> neurons, Set<String> allIds) {
        Set<String> res = new HashSet<String>();
        for (Neuron neuron : neurons) {
            if (!allIds.add(neuron.getId())) {
                throw new RuntimeException("duplicated neuron id: " + neuron.getId());
            }
            res.add(neuron.getId());
        }
        return res;
    }

    /**
     * Evaluates the network.
     *
     * @param input raw input values keyed by input neuron id; missing or null values are treated as 0;
     *        the map itself must not be null when the network has input neurons
     * @return values of output neurons keyed by output neuron id
     */
    @Override
    public Map<String, Double> process(Map<String, Double> input) {
        // to input
        Map<String, Double> inputValues = new HashMap<String, Double>();
        for (Neuron neuron : inputs) {
            String id = neuron.getId();
            Double raw = input.get(id);
            double val = raw == null ? 0 : raw;
            inputValues.put(id, neuron.getOutput(val));
        }
        // from input to hidden
        Map<String, Double> hiddenValues = new HashMap<String, Double>();
        for (Neuron neuron : hiddens) {
            hiddenValues.put(neuron.getId(), neuron.getOutput(weightedSum(neuron.getId(), inputValues)));
        }
        // from input and hidden to out; ids are unique across layers, so merging cannot collide
        Map<String, Double> upstreamValues = new HashMap<String, Double>(inputValues);
        upstreamValues.putAll(hiddenValues);
        Map<String, Double> outputValues = new HashMap<String, Double>();
        for (Neuron neuron : outputs) {
            outputValues.put(neuron.getId(), neuron.getOutput(weightedSum(neuron.getId(), upstreamValues)));
        }
        return outputValues;
    }

    /**
     * Sums weight times source value over all synapses ending in the given neuron, in synapse order.
     *
     * @param neuronId id of the neuron receiving the signal
     * @param sourceValues already computed values of every neuron that may start such a synapse
     * @return weighted sum, 0 if no synapse ends in the neuron
     */
    private double weightedSum(String neuronId, Map<String, Double> sourceValues) {
        double wsum = 0;
        for (Synapse synapse : synapses) {
            if (neuronId.equals(synapse.getEndId())) {
                wsum = wsum + synapse.getWeight() * sourceValues.get(synapse.getStartId());
            }
        }
        return wsum;
    }

    /**
     * Returns input nodes.
     *
     * @return input nodes
     */
    public List<Neuron> getInputs() {
        return inputs;
    }

    /**
     * Returns hidden nodes.
     *
     * @return hidden nodes
     */
    public List<Neuron> getHiddens() {
        return hiddens;
    }

    /**
     * Returns output nodes.
     *
     * @return output nodes
     */
    public List<Neuron> getOutputs() {
        return outputs;
    }

    /**
     * Returns synapses.
     *
     * @return synapses
     */
    public List<Synapse> getSynapses() {
        return synapses;
    }

    @Override
    public int hashCode() {
        return HashCodeBuilder.reflectionHashCode(this);
    }

    @Override
    public boolean equals(Object obj) {
        return EqualsBuilder.reflectionEquals(this, obj);
    }

    @Override
    public String toString() {
        return ToStringBuilder.reflectionToString(this);
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy