
// org.numenta.nupic.network.Region (Maven / Gradle / Ivy)
/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2014, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.network;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.joda.time.DateTime;
import org.numenta.nupic.algorithms.TemporalMemory;
import org.numenta.nupic.encoders.Encoder;
import org.numenta.nupic.model.Persistable;
import org.numenta.nupic.network.sensor.Sensor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import rx.Observable;
import rx.Observer;
import rx.Subscriber;
/**
*
* Regions are collections of {@link Layer}s, which are in turn collections
* of algorithmic components. Regions can be connected to each other to establish
* a hierarchy of processing. To connect one Region to another, typically one
* would do the following:
*
*
* Parameters p = Parameters.getDefaultParameters(); // May be altered as needed
* Network n = Network.create("Test Network", p);
* Region region1 = n.createRegion("r1"); // would typically add Layers to the Region after this
* Region region2 = n.createRegion("r2");
* region1.connect(region2);
*
* --OR--
*
* n.connect(region1, region2);
*
* --OR--
*
* Network.lookup("r1").connect(Network.lookup("r2"));
*
*
* @author cogmission
*
*/
public class Region implements Persistable {
private static final long serialVersionUID = 1L;
private static final Logger LOGGER = LoggerFactory.getLogger(Region.class);
private Network parentNetwork;
private Region upstreamRegion;
private Region downstreamRegion;
private Layer> tail;
private Layer> head;
private Map> layers = new HashMap<>();
private transient Observable regionObservable;
/** Marker flag to indicate that assembly is finished and Region initialized */
private boolean assemblyClosed;
/** stores the learn setting */
private boolean isLearn = true;
/** Temporary variables used to determine endpoints of observable chain */
private HashSet> sources;
private HashSet> sinks;
/** Stores the overlap of algorithms state for {@link Inference} sharing determination */
byte flagAccumulator = 0;
/**
* Indicates whether algorithms are repeated, if true then no, if false then yes
* (for {@link Inference} sharing determination) see {@link Region#configureConnection(Layer, Layer)}
* and {@link Layer#getMask()}
*/
boolean layersDistinct = true;
private Object input;
private String name;
/**
* Constructs a new {@code Region}
*
* Warning: name cannot be null or empty
*
* @param name A unique identifier for this Region (uniqueness is enforced)
* @param network The containing {@link Network}
*/
public Region(String name, Network network) {
if(name == null || name.isEmpty()) {
throw new IllegalArgumentException("Name may not be null or empty. " +
"...not that anyone here advocates name calling!");
}
this.name = name;
this.parentNetwork = network;
}
/**
* {@inheritDoc}
*/
@SuppressWarnings("unchecked")
@Override
public Region preSerialize() {
layers.values().stream().forEach(l -> l.preSerialize());
return this;
}
/**
* {@inheritDoc}
*/
@SuppressWarnings("unchecked")
@Override
public Region postDeSerialize() {
layers.values().stream().forEach(l -> l.postDeSerialize());
// Connect Layer Observable chains (which are transient so we must
// rebuild them and their subscribers)
if(isMultiLayer()) {
Layer curr = (Layer)head;
Layer prev = curr.getPrevious();
do {
connect(curr, prev);
} while((curr = prev) != null && (prev = prev.getPrevious()) != null);
}
return this;
}
/**
* Sets the parent {@link Network} of this {@code Region}
* @param network
*/
public void setNetwork(Network network) {
this.parentNetwork = network;
for(Layer> l : layers.values()) {
l.setNetwork(network);
// Set the sensor & encoder reference for global access.
if(l.hasSensor() && network != null) {
network.setSensor(l.getSensor());
network.setEncoder(l.getSensor().getEncoder());
}else if(network != null && l.getEncoder() != null) {
network.setEncoder(l.getEncoder());
}
}
}
/**
* Returns a flag indicating whether this {@code Region} contain multiple
* {@link Layer}s.
*
* @return true if so, false if not.
*/
public boolean isMultiLayer() {
return layers.size() > 1;
}
/**
* Closes the Region and completes the finalization of its assembly.
* After this call, any attempt to mutate the structure of a Region
* will result in an {@link IllegalStateException} being thrown.
*
* @return
*/
public Region close() {
if(layers.size() < 1) {
LOGGER.warn("Closing region: " + name + " before adding contents.");
return this;
}
completeAssembly();
Layer> l = tail;
do {
l.close();
}while((l = l.getNext()) != null);
return this;
}
/**
* Returns a flag indicating whether this {@code Region} has had
* its {@link #close} method called, or not.
*
* @return
*/
public boolean isClosed() {
return assemblyClosed;
}
/**
* Sets the learning mode.
* @param isLearn
*/
public void setLearn(boolean isLearn) {
this.isLearn = isLearn;
Layer> l = tail;
while(l != null) {
l.setLearn(isLearn);
l = l.getNext();
}
}
/**
* Returns the learning mode setting.
* @return
*/
public boolean isLearn() {
return isLearn;
}
/**
* Used to manually input data into a {@link Region}, the other way
* being the call to {@link Region#start()} for a Region that contains
* a {@link Layer} which in turn contains a {@link Sensor} -OR-
* subscribing a receiving Region to this Region's output Observable.
*
* @param input One of (int[], String[], {@link ManualInput}, or Map<String, Object>)
*/
@SuppressWarnings("unchecked")
public void compute(T input) {
if(!assemblyClosed) {
close();
}
this.input = input;
((Layer)tail).compute(input);
}
/**
* Returns the current input into the region. This value may change
* after every call to {@link Region#compute(Object)}.
*
* @return
*/
public Object getInput() {
return input;
}
/**
* Adds the specified {@link Layer} to this {@code Region}.
* @param l
* @return
* @throws IllegalStateException if Region is already closed
* @throws IllegalArgumentException if a Layer with the same name already exists.
*/
@SuppressWarnings("unchecked")
public Region add(Layer> l) {
if(assemblyClosed) {
throw new IllegalStateException("Cannot add Layers when Region has already been closed.");
}
if(sources == null) {
sources = new HashSet>();
sinks = new HashSet>();
}
// Set the sensor reference for global access.
if(l.hasSensor() && parentNetwork != null) {
parentNetwork.setSensor(l.getSensor());
parentNetwork.setEncoder(l.getSensor().getEncoder());
}
String layerName = name.concat(":").concat(l.getName());
if(layers.containsKey(layerName)) {
throw new IllegalArgumentException("A Layer with the name: " + l.getName() + " has already been added to this Region.");
}
l.name(layerName);
layers.put(l.getName(), (Layer)l);
l.setRegion(this);
l.setNetwork(parentNetwork);
return this;
}
/**
* Returns the String identifier for this {@code Region}
* @return
*/
public String getName() {
return name;
}
/**
* Returns an {@link Observable} which can be used to receive
* {@link Inference} emissions from this {@code Region}
* @return
*/
public Observable observe() {
if(regionObservable == null && !assemblyClosed) {
close();
}
if(head.isHalted() || regionObservable == null) {
regionObservable = head.observe();
}
return regionObservable;
}
/**
* Calls {@link Layer#start()} on this Region's input {@link Layer} if
* that layer contains a {@link Sensor}. If not, this method has no
* effect.
*
* @return flag indicating that thread was started
*/
public boolean start() {
if(!assemblyClosed) {
close();
}
if(tail.hasSensor()) {
LOGGER.info("Starting Region [" + getName() + "] input Layer thread.");
tail.start();
return true;
}else{
LOGGER.warn("Start called on Region [" + getName() + "] with no effect due to no Sensor present.");
}
return false;
}
/**
* Calls {@link Layer#restart(boolean)} on this Region's input {@link Layer} if
* that layer contains a {@link Sensor}. If not, this method has no effect. If
* "startAtIndex" is true, the Network will start at the last saved index as
* obtained from the serialized "recordNum" field; if false then the Network
* will restart from 0.
*
* @param startAtIndex flag indicating whether to start from the previous save
* point or not. If true, this region's Network will start
* at the previously stored index, if false then it will
* start with a recordNum of zero.
* @return flag indicating whether the call to restart had an effect or not.
*/
public boolean restart(boolean startAtIndex) {
if(!assemblyClosed) {
return start();
}
if(tail.hasSensor()) {
LOGGER.info("Re-Starting Region [" + getName() + "] input Layer thread.");
tail.restart(startAtIndex);
return true;
}else{
LOGGER.warn("Re-Start called on Region [" + getName() + "] with no effect due to no Sensor present.");
}
return false;
}
/**
* Returns an {@link rx.Observable} operator that when subscribed to, invokes an operation
* that stores the state of this {@code Network} while keeping the Network up and running.
* The Network will be stored at the pre-configured location (in binary form only, not JSON).
*
* @return the {@link CheckPointOp} operator
*/
CheckPointOp getCheckPointOperator() {
LOGGER.debug("Region [" + getName() + "] CheckPoint called at: " + (new DateTime()));
if(tail != null) {
return tail.getCheckPointOperator();
}else{
close();
return tail.getCheckPointOperator();
}
}
/**
* Stops each {@link Layer} contained within this {@code Region}
*/
public void halt() {
LOGGER.debug("Halt called on Region [" + getName() + "]");
if(tail != null) {
tail.halt();
}else{
close();
tail.halt();
}
LOGGER.debug("Region [" + getName() + "] halted.");
}
/**
* Returns a flag indicating whether this Region has a Layer
* whose Sensor thread is halted.
* @return true if so, false if not
*/
public boolean isHalted() {
if(tail != null) {
return tail.isHalted();
}
return false;
}
/**
* Finds any {@link Layer} containing a {@link TemporalMemory}
* and resets them.
*/
public void reset() {
for(Layer> l : layers.values()) {
if(l.hasTemporalMemory()) {
l.reset();
}
}
}
/**
* Resets the recordNum in all {@link Layer}s.
*/
public void resetRecordNum() {
for(Layer> l : layers.values()) {
l.resetRecordNum();
}
}
/**
* Connects the output of the specified {@code Region} to the
* input of this Region
*
* @param inputRegion the Region who's emissions will be observed by
* this Region.
* @return
*/
Region connect(Region inputRegion) {
inputRegion.observe().subscribe(new Observer() {
ManualInput localInf = new ManualInput();
@Override public void onCompleted() {
tail.notifyComplete();
}
@Override public void onError(Throwable e) { e.printStackTrace(); }
@SuppressWarnings("unchecked")
@Override public void onNext(Inference i) {
localInf.sdr(i.getSDR()).recordNum(i.getRecordNum()).classifierInput(i.getClassifierInput()).layerInput(i.getSDR());
if(i.getSDR().length > 0) {
((Layer)tail).compute(localInf);
}
}
});
// Set the upstream region
this.upstreamRegion = inputRegion;
inputRegion.downstreamRegion = this;
return this;
}
/**
* Returns this {@code Region}'s upstream region,
* if it exists.
*
* @return
*/
public Region getUpstreamRegion() {
return upstreamRegion;
}
/**
* Returns the {@code Region} that receives this Region's
* output.
*
* @return
*/
public Region getDownstreamRegion() {
return downstreamRegion;
}
/**
* Returns the top-most (last in execution order from
* bottom to top) {@link Layer} in this {@code Region}
*
* @return
*/
public Layer> getHead() {
return this.head;
}
/**
* Returns the bottom-most (first in execution order from
* bottom to top) {@link Layer} in this {@code Region}
*
* @return
*/
public Layer> getTail() {
return this.tail;
}
/**
* Connects two layers to each other in a unidirectional fashion
* with "toLayerName" representing the receiver or "sink" and "fromLayerName"
* representing the sender or "source".
*
* This method also forwards shared constructs up the connection chain
* such as any {@link Encoder} which may exist, and the {@link Inference} result
* container which is shared among layers.
*
* @param toLayerName the name of the sink layer
* @param fromLayerName the name of the source layer
* @return
* @throws IllegalStateException if Region is already closed
*/
@SuppressWarnings("unchecked")
public Region connect(String toLayerName, String fromLayerName) {
if(assemblyClosed) {
throw new IllegalStateException("Cannot connect Layers when Region has already been closed.");
}
Layer in = (Layer)lookup(toLayerName);
Layer out = (Layer)lookup(fromLayerName);
if(in == null) {
throw new IllegalArgumentException("Could not lookup (to) Layer with name: " + toLayerName);
}else if(out == null){
throw new IllegalArgumentException("Could not lookup (from) Layer with name: " + fromLayerName);
}
// Set source's pointer to its next Layer --> (sink : going upward).
out.next(in);
// Set the sink's pointer to its previous Layer --> (source : going downward)
in.previous(out);
// Connect out to in
configureConnection(in, out);
connect(in, out);
return this;
}
/**
* Does a straight associative lookup by first creating a composite
* key containing this {@code Region}'s name concatenated with the specified
* {@link Layer}'s name, and returning the result.
*
* @param layerName
* @return
*/
public Layer> lookup(String layerName) {
if(layerName.indexOf(":") != -1) {
return layers.get(layerName);
}
return layers.get(name.concat(":").concat(layerName));
}
/**
* Called by {@link #start()}, {@link #observe()} and {@link #connect(Region)}
* to finalize the internal chain of {@link Layer}s contained by this {@code Region}.
* This method assigns the head and tail Layers and composes the {@link Observable}
* which offers this Region's emissions to any upstream {@link Region}s.
*/
private void completeAssembly() {
if(!assemblyClosed) {
if(layers.size() == 0) return;
if(layers.size() == 1) {
head = tail = layers.values().iterator().next();
}
if(tail == null) {
Set> temp = new HashSet>(sources);
temp.removeAll(sinks);
if(temp.size() != 1) {
throw new IllegalArgumentException("Detected misconfigured Region too many or too few sinks.");
}
tail = temp.iterator().next();
}
if(head == null) {
Set> temp = new HashSet>(sinks);
temp.removeAll(sources);
if(temp.size() != 1) {
throw new IllegalArgumentException("Detected misconfigured Region too many or too few sources.");
}
head = temp.iterator().next();
}
regionObservable = head.observe();
assemblyClosed = true;
}
}
/**
* Called internally to configure the connection between two {@link Layer}
* {@link Observable}s taking care of other connection details such as passing
* the inference up the chain and any possible encoder.
*
* @param in the sink end of the connection between two layers
* @param out the source end of the connection between two layers
* @throws IllegalStateException if Region is already closed
*/
, O extends Layer> void configureConnection(I in, O out) {
if(assemblyClosed) {
throw new IllegalStateException("Cannot add Layers when Region has already been closed.");
}
Set> all = new HashSet<>(sources);
all.addAll(sinks);
byte inMask = in.getMask();
byte outMask = out.getMask();
if(!all.contains(out)) {
layersDistinct = (flagAccumulator & outMask) < 1;
flagAccumulator |= outMask;
}
if(!all.contains(in)) {
layersDistinct = (flagAccumulator & inMask) < 1;
flagAccumulator |= inMask;
}
sources.add(out);
sinks.add(in);
}
/**
* Called internally to actually connect two {@link Layer}
* {@link Observable}s taking care of other connection details such as passing
* the inference up the chain and any possible encoder.
*
* @param in the sink end of the connection between two layers
* @param out the source end of the connection between two layers
* @throws IllegalStateException if Region is already closed
*/
, O extends Layer> void connect(I in, O out) {
out.subscribe(new Subscriber() {
ManualInput localInf = new ManualInput();
@Override public void onCompleted() { in.notifyComplete(); }
@Override public void onError(Throwable e) { e.printStackTrace(); }
@Override public void onNext(Inference i) {
if(layersDistinct) {
in.compute(i);
}else{
localInf.sdr(i.getSDR()).recordNum(i.getRecordNum()).layerInput(i.getSDR());
in.compute(localInf);
}
}
});
}
/* (non-Javadoc)
* @see java.lang.Object#hashCode()
*/
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + (assemblyClosed ? 1231 : 1237);
result = prime * result + (isLearn ? 1231 : 1237);
result = prime * result + ((layers == null) ? 0 : layers.size());
result = prime * result + ((name == null) ? 0 : name.hashCode());
return result;
}
/* (non-Javadoc)
* @see java.lang.Object#equals(java.lang.Object)
*/
@Override
public boolean equals(Object obj) {
if(this == obj)
return true;
if(obj == null)
return false;
if(getClass() != obj.getClass())
return false;
Region other = (Region)obj;
if(assemblyClosed != other.assemblyClosed)
return false;
if(isLearn != other.isLearn)
return false;
if(layers == null) {
if(other.layers != null)
return false;
} else if(!layers.equals(other.layers))
return false;
if(name == null) {
if(other.name != null)
return false;
} else if(!name.equals(other.name))
return false;
return true;
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy