com.enterprisemath.math.nn.FFSHLNetwork Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of em-math Show documentation
Show all versions of em-math Show documentation
Advanced mathematical algorithms.
The newest version!
package com.enterprisemath.math.nn;
import com.enterprisemath.utils.DomainUtils;
import com.enterprisemath.utils.ValidationUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.builder.EqualsBuilder;
import org.apache.commons.lang3.builder.HashCodeBuilder;
import org.apache.commons.lang3.builder.ToStringBuilder;
/**
* Feed forward network with single hidden layer.
*
* @author radek.hecl
*/
public class FFSHLNetwork implements Network {

    /**
     * Builder object.
     */
    public static class Builder {

        /**
         * Input neurons.
         */
        private List<Neuron> inputs = new ArrayList<Neuron>();

        /**
         * Hidden neurons.
         */
        private List<Neuron> hiddens = new ArrayList<Neuron>();

        /**
         * Output neurons.
         */
        private List<Neuron> outputs = new ArrayList<Neuron>();

        /**
         * Synapses between neurons.
         */
        private List<Synapse> synapses = new ArrayList<Synapse>();

        /**
         * Sets input neurons, replacing any previously set or added ones.
         *
         * @param inputs input neurons
         * @return this instance
         */
        public Builder setInputs(List<Neuron> inputs) {
            this.inputs = DomainUtils.softCopyList(inputs);
            return this;
        }

        /**
         * Adds input neuron.
         *
         * @param neuron input neuron
         * @return this instance
         */
        public Builder addInput(Neuron neuron) {
            inputs.add(neuron);
            return this;
        }

        /**
         * Sets hidden neurons, replacing any previously set or added ones.
         *
         * @param hiddens hidden neurons
         * @return this instance
         */
        public Builder setHiddens(List<Neuron> hiddens) {
            this.hiddens = DomainUtils.softCopyList(hiddens);
            return this;
        }

        /**
         * Adds hidden neuron.
         *
         * @param neuron hidden neuron
         * @return this instance
         */
        public Builder addHidden(Neuron neuron) {
            hiddens.add(neuron);
            return this;
        }

        /**
         * Sets output neurons, replacing any previously set or added ones.
         *
         * @param outputs output neurons
         * @return this instance
         */
        public Builder setOutputs(List<Neuron> outputs) {
            this.outputs = DomainUtils.softCopyList(outputs);
            return this;
        }

        /**
         * Adds output neuron.
         *
         * @param neuron output neuron
         * @return this instance
         */
        public Builder addOutput(Neuron neuron) {
            outputs.add(neuron);
            return this;
        }

        /**
         * Sets synapses, replacing any previously set or added ones.
         *
         * @param synapses synapses
         * @return this instance
         */
        public Builder setSynapses(List<Synapse> synapses) {
            this.synapses = DomainUtils.softCopyList(synapses);
            return this;
        }

        /**
         * Adds synapse.
         *
         * @param synapse synapse
         * @return this instance
         */
        public Builder addSynapse(Synapse synapse) {
            synapses.add(synapse);
            return this;
        }

        /**
         * Builds the result object.
         *
         * @return created object
         * @throws RuntimeException if the configured neurons and synapses are inconsistent
         *         (see the constructor)
         */
        public FFSHLNetwork build() {
            return new FFSHLNetwork(this);
        }
    }

    /**
     * Input neurons.
     */
    private final List<Neuron> inputs;

    /**
     * Hidden neurons.
     */
    private final List<Neuron> hiddens;

    /**
     * Output neurons.
     */
    private final List<Neuron> outputs;

    /**
     * Synapses between neurons.
     */
    private final List<Synapse> synapses;

    /**
     * Creates new instance.
     * <p>
     * The neuron and synapse lists are copied from the builder and then validated:
     * <ul>
     * <li>no list may contain a {@code null} element;</li>
     * <li>neuron ids must be unique across all three layers;</li>
     * <li>every synapse endpoint must reference an existing neuron;</li>
     * <li>synapses ending at a hidden neuron must start at an input neuron;</li>
     * <li>synapses ending at an output neuron must not start at an output neuron.</li>
     * </ul>
     *
     * @param builder builder object
     * @throws RuntimeException if any of the rules above is violated
     */
    public FFSHLNetwork(Builder builder) {
        inputs = DomainUtils.softCopyUnmodifiableList(builder.inputs);
        hiddens = DomainUtils.softCopyUnmodifiableList(builder.hiddens);
        outputs = DomainUtils.softCopyUnmodifiableList(builder.outputs);
        synapses = DomainUtils.softCopyUnmodifiableList(builder.synapses);
        guardInvariants();
    }

    /**
     * Guards this object to be consistent. Throws exception if this is not the case.
     */
    private void guardInvariants() {
        ValidationUtils.guardNotNullCollection(inputs, "inputs cannot have null element");
        ValidationUtils.guardNotNullCollection(hiddens, "hiddens cannot have null element");
        ValidationUtils.guardNotNullCollection(outputs, "outputs cannot have null element");
        ValidationUtils.guardNotNullCollection(synapses, "synapses cannot have null element");
        Set<String> ids = new HashSet<String>();
        Set<String> inputIds = registerIds(inputs, ids);
        Set<String> hiddenIds = registerIds(hiddens, ids);
        Set<String> outputIds = registerIds(outputs, ids);
        for (Synapse synapse : synapses) {
            ValidationUtils.guardIn(synapse.getStartId(), ids, "startId doesnt match neuron");
            ValidationUtils.guardIn(synapse.getEndId(), ids, "endId doesnt match neuron");
            guardFeedForward(synapse, inputIds, hiddenIds, outputIds);
        }
    }

    /**
     * Adds the ids of one layer to the set of all known ids, rejecting duplicates.
     *
     * @param neurons neurons of one layer
     * @param allIds ids of all neurons registered so far; extended in place
     * @return ids of the given layer
     * @throws RuntimeException if an id is already present in {@code allIds}
     */
    private static Set<String> registerIds(List<Neuron> neurons, Set<String> allIds) {
        Set<String> layerIds = new HashSet<String>();
        for (Neuron neuron : neurons) {
            if (!allIds.add(neuron.getId())) {
                throw new RuntimeException("duplicated neuron id: " + neuron.getId());
            }
            layerIds.add(neuron.getId());
        }
        return layerIds;
    }

    /**
     * Rejects synapses that {@link #process(Map)} cannot evaluate, because their start
     * neuron has no value yet when the end neuron is computed.
     *
     * @param synapse synapse to check; both endpoints are already known to exist
     * @param inputIds ids of input neurons
     * @param hiddenIds ids of hidden neurons
     * @param outputIds ids of output neurons
     * @throws IllegalArgumentException if the synapse is not feed forward in the sense above
     */
    private static void guardFeedForward(Synapse synapse, Set<String> inputIds,
            Set<String> hiddenIds, Set<String> outputIds) {
        String startId = synapse.getStartId();
        String endId = synapse.getEndId();
        if (hiddenIds.contains(endId) && !inputIds.contains(startId)) {
            throw new IllegalArgumentException(
                    "synapse into hidden neuron must start at input neuron: " + startId + " -> " + endId);
        }
        if (outputIds.contains(endId) && outputIds.contains(startId)) {
            throw new IllegalArgumentException(
                    "synapse into output neuron cannot start at output neuron: " + startId + " -> " + endId);
        }
    }

    /**
     * Propagates the given values through the network.
     * <p>
     * Each input neuron receives the value mapped to its id, or 0 when the id is missing or mapped
     * to {@code null}. Each hidden neuron receives the weighted sum of the input-neuron outputs
     * connected to it. Each output neuron receives the weighted sum of the hidden-neuron and
     * input-neuron outputs connected to it. A neuron without incoming synapses receives 0.
     * Each neuron's value is passed through its {@code getOutput}. Terms are summed in synapse
     * list order.
     *
     * @param input values for the input neurons, keyed by neuron id
     * @return output neuron values, keyed by neuron id
     */
    @Override
    public Map<String, Double> process(Map<String, Double> input) {
        // index once per call instead of rescanning all synapses for every neuron
        Map<String, List<Synapse>> incoming = groupByEndId(synapses);

        // to input
        Map<String, Double> inputValues = new HashMap<String, Double>();
        for (Neuron neuron : inputs) {
            String id = neuron.getId();
            Double given = input.get(id);
            double val = given == null ? 0 : given;
            inputValues.put(id, neuron.getOutput(val));
        }

        // from input to hidden
        Map<String, Double> hiddenValues = new HashMap<String, Double>();
        for (Neuron neuron : hiddens) {
            String id = neuron.getId();
            hiddenValues.put(id, neuron.getOutput(weightedSum(incoming.get(id), inputValues)));
        }

        // from hidden (and directly from input) to out; hidden values take precedence, as before
        Map<String, Double> upstreamValues = new HashMap<String, Double>(inputValues);
        upstreamValues.putAll(hiddenValues);
        Map<String, Double> outputValues = new HashMap<String, Double>();
        for (Neuron neuron : outputs) {
            String id = neuron.getId();
            outputValues.put(id, neuron.getOutput(weightedSum(incoming.get(id), upstreamValues)));
        }
        return outputValues;
    }

    /**
     * Groups synapses by end neuron id, preserving the original list order within each group.
     *
     * @param synapses synapses to group
     * @return map from end neuron id to its incoming synapses
     */
    private static Map<String, List<Synapse>> groupByEndId(List<Synapse> synapses) {
        Map<String, List<Synapse>> res = new HashMap<String, List<Synapse>>();
        for (Synapse synapse : synapses) {
            List<Synapse> group = res.get(synapse.getEndId());
            if (group == null) {
                group = new ArrayList<Synapse>();
                res.put(synapse.getEndId(), group);
            }
            group.add(synapse);
        }
        return res;
    }

    /**
     * Sums {@code weight * value(start)} over the given synapses.
     *
     * @param incoming incoming synapses of one neuron; {@code null} means none
     * @param sourceValues values of the neurons the synapses may start at, keyed by id
     * @return weighted sum, 0 when there are no incoming synapses
     * @throws IllegalStateException if a start neuron has no value in {@code sourceValues}
     */
    private static double weightedSum(List<Synapse> incoming, Map<String, Double> sourceValues) {
        double wsum = 0;
        if (incoming == null) {
            return wsum;
        }
        for (Synapse synapse : incoming) {
            Double val = sourceValues.get(synapse.getStartId());
            if (val == null) {
                throw new IllegalStateException("no value available for synapse start: " + synapse.getStartId());
            }
            wsum = wsum + synapse.getWeight() * val;
        }
        return wsum;
    }

    /**
     * Returns input nodes.
     *
     * @return input nodes, as copied at construction via {@code DomainUtils.softCopyUnmodifiableList}
     */
    public List<Neuron> getInputs() {
        return inputs;
    }

    /**
     * Returns hidden nodes.
     *
     * @return hidden nodes, as copied at construction via {@code DomainUtils.softCopyUnmodifiableList}
     */
    public List<Neuron> getHiddens() {
        return hiddens;
    }

    /**
     * Returns output nodes.
     *
     * @return output nodes, as copied at construction via {@code DomainUtils.softCopyUnmodifiableList}
     */
    public List<Neuron> getOutputs() {
        return outputs;
    }

    /**
     * Returns synapses.
     *
     * @return synapses, as copied at construction via {@code DomainUtils.softCopyUnmodifiableList}
     */
    public List<Synapse> getSynapses() {
        return synapses;
    }

    /**
     * Reflection-based hash code over all fields of this object.
     */
    @Override
    public int hashCode() {
        return HashCodeBuilder.reflectionHashCode(this);
    }

    /**
     * Reflection-based equality over all fields of this object.
     */
    @Override
    public boolean equals(Object obj) {
        return EqualsBuilder.reflectionEquals(this, obj);
    }

    /**
     * Reflection-based string representation of all fields of this object.
     */
    @Override
    public String toString() {
        return ToStringBuilder.reflectionToString(this);
    }
}