
org.numenta.nupic.network.Region Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of htm.java Show documentation
Show all versions of htm.java Show documentation
The Java version of Numenta's HTM technology
/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2014, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.network;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.joda.time.DateTime;
import org.numenta.nupic.algorithms.TemporalMemory;
import org.numenta.nupic.encoders.Encoder;
import org.numenta.nupic.model.Persistable;
import org.numenta.nupic.network.sensor.Sensor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import rx.Observable;
import rx.Observer;
import rx.Subscriber;
/**
*
* Regions are collections of {@link Layer}s, which are in turn collections
* of algorithmic components. Regions can be connected to each other to establish
* a hierarchy of processing. To connect one Region to another, typically one
* would do the following:
*
*
* Parameters p = Parameters.getDefaultParameters(); // May be altered as needed
* Network n = Network.create("Test Network", p);
* Region region1 = n.createRegion("r1"); // would typically add Layers to the Region after this
* Region region2 = n.createRegion("r2");
* region1.connect(region2);
*
* --OR--
*
* n.connect(region1, region2);
*
* --OR--
*
* Network.lookup("r1").connect(Network.lookup("r2"));
*
*
* @author cogmission
*
*/
public class Region implements Persistable {
private static final long serialVersionUID = 1L;
private static final Logger LOGGER = LoggerFactory.getLogger(Region.class);
private Network parentNetwork;
private Region upstreamRegion;
private Region downstreamRegion;
private Layer> tail;
private Layer> head;
private Map> layers = new HashMap<>();
private transient Observable regionObservable;
/** Marker flag to indicate that assembly is finished and Region initialized */
private boolean assemblyClosed;
/** stores the learn setting */
private boolean isLearn = true;
/** Temporary variables used to determine endpoints of observable chain */
private HashSet> sources;
private HashSet> sinks;
/** Stores the overlap of algorithms state for {@link Inference} sharing determination */
byte flagAccumulator = 0;
/**
* Indicates whether algorithms are repeated, if true then no, if false then yes
* (for {@link Inference} sharing determination) see {@link Region#configureConnection(Layer, Layer)}
* and {@link Layer#getMask()}
*/
boolean layersDistinct = true;
private Object input;
private String name;
/**
* Constructs a new {@code Region}
*
* Warning: name cannot be null or empty
*
* @param name A unique identifier for this Region (uniqueness is enforced)
* @param network The containing {@link Network}
*/
public Region(String name, Network network) {
if(name == null || name.isEmpty()) {
throw new IllegalArgumentException("Name may not be null or empty. " +
"...not that anyone here advocates name calling!");
}
this.name = name;
this.parentNetwork = network;
}
/**
* {@inheritDoc}
*/
@SuppressWarnings("unchecked")
@Override
public Region preSerialize() {
layers.values().stream().forEach(l -> l.preSerialize());
return this;
}
/**
* {@inheritDoc}
*/
@SuppressWarnings("unchecked")
@Override
public Region postDeSerialize() {
layers.values().stream().forEach(l -> l.postDeSerialize());
// Connect Layer Observable chains (which are transient so we must
// rebuild them and their subscribers)
if(isMultiLayer()) {
Layer curr = (Layer)head;
Layer prev = curr.getPrevious();
do {
connect(curr, prev);
} while((curr = prev) != null && (prev = prev.getPrevious()) != null);
}
return this;
}
/**
* Sets the parent {@link Network} of this {@code Region}
* @param network
*/
public void setNetwork(Network network) {
this.parentNetwork = network;
for(Layer> l : layers.values()) {
l.setNetwork(network);
// Set the sensor & encoder reference for global access.
if(l.hasSensor() && network != null) {
network.setSensor(l.getSensor());
network.setEncoder(l.getSensor().getEncoder());
}else if(network != null && l.getEncoder() != null) {
network.setEncoder(l.getEncoder());
}
}
}
/**
* Returns a flag indicating whether this {@code Region} contain multiple
* {@link Layer}s.
*
* @return true if so, false if not.
*/
public boolean isMultiLayer() {
return layers.size() > 1;
}
/**
* Closes the Region and completes the finalization of its assembly.
* After this call, any attempt to mutate the structure of a Region
* will result in an {@link IllegalStateException} being thrown.
*
* @return
*/
public Region close() {
if(layers.size() < 1) {
LOGGER.warn("Closing region: " + name + " before adding contents.");
return this;
}
completeAssembly();
Layer> l = tail;
do {
l.close();
}while((l = l.getNext()) != null);
return this;
}
/**
* Returns a flag indicating whether this {@code Region} has had
* its {@link #close} method called, or not.
*
* @return
*/
public boolean isClosed() {
return assemblyClosed;
}
/**
* Sets the learning mode.
* @param isLearn
*/
public void setLearn(boolean isLearn) {
this.isLearn = isLearn;
Layer> l = tail;
while(l != null) {
l.setLearn(isLearn);
l = l.getNext();
}
}
/**
* Returns the learning mode setting.
* @return
*/
public boolean isLearn() {
return isLearn;
}
/**
* Used to manually input data into a {@link Region}, the other way
* being the call to {@link Region#start()} for a Region that contains
* a {@link Layer} which in turn contains a {@link Sensor} -OR-
* subscribing a receiving Region to this Region's output Observable.
*
* @param input One of (int[], String[], {@link ManualInput}, or Map<String, Object>)
*/
@SuppressWarnings("unchecked")
public void compute(T input) {
if(!assemblyClosed) {
close();
}
this.input = input;
((Layer)tail).compute(input);
}
/**
* Returns the current input into the region. This value may change
* after every call to {@link Region#compute(Object)}.
*
* @return
*/
public Object getInput() {
return input;
}
/**
* Adds the specified {@link Layer} to this {@code Region}.
* @param l
* @return
* @throws IllegalStateException if Region is already closed
* @throws IllegalArgumentException if a Layer with the same name already exists.
*/
@SuppressWarnings("unchecked")
public Region add(Layer> l) {
if(assemblyClosed) {
throw new IllegalStateException("Cannot add Layers when Region has already been closed.");
}
if(sources == null) {
sources = new HashSet>();
sinks = new HashSet>();
}
// Set the sensor reference for global access.
if(l.hasSensor() && parentNetwork != null) {
parentNetwork.setSensor(l.getSensor());
parentNetwork.setEncoder(l.getSensor().getEncoder());
}
String layerName = name.concat(":").concat(l.getName());
if(layers.containsKey(layerName)) {
throw new IllegalArgumentException("A Layer with the name: " + l.getName() + " has already been added to this Region.");
}
l.name(layerName);
layers.put(l.getName(), (Layer)l);
l.setRegion(this);
l.setNetwork(parentNetwork);
return this;
}
/**
* Returns the String identifier for this {@code Region}
* @return
*/
public String getName() {
return name;
}
/**
* Returns an {@link Observable} which can be used to receive
* {@link Inference} emissions from this {@code Region}
* @return
*/
public Observable observe() {
if(regionObservable == null && !assemblyClosed) {
close();
}
if(head.isHalted() || regionObservable == null) {
regionObservable = head.observe();
}
return regionObservable;
}
/**
* Calls {@link Layer#start()} on this Region's input {@link Layer} if
* that layer contains a {@link Sensor}. If not, this method has no
* effect.
*
* @return flag indicating that thread was started
*/
public boolean start() {
if(!assemblyClosed) {
close();
}
if(tail.hasSensor()) {
LOGGER.info("Starting Region [" + getName() + "] input Layer thread.");
tail.start();
return true;
}else{
LOGGER.warn("Start called on Region [" + getName() + "] with no effect due to no Sensor present.");
}
return false;
}
/**
* Calls {@link Layer#restart(boolean)} on this Region's input {@link Layer} if
* that layer contains a {@link Sensor}. If not, this method has no effect. If
* "startAtIndex" is true, the Network will start at the last saved index as
* obtained from the serialized "recordNum" field; if false then the Network
* will restart from 0.
*
* @param startAtIndex flag indicating whether to start from the previous save
* point or not. If true, this region's Network will start
* at the previously stored index, if false then it will
* start with a recordNum of zero.
* @return flag indicating whether the call to restart had an effect or not.
*/
public boolean restart(boolean startAtIndex) {
if(!assemblyClosed) {
return start();
}
if(tail.hasSensor()) {
LOGGER.info("Re-Starting Region [" + getName() + "] input Layer thread.");
tail.restart(startAtIndex);
return true;
}else{
LOGGER.warn("Re-Start called on Region [" + getName() + "] with no effect due to no Sensor present.");
}
return false;
}
/**
* Returns an {@link rx.Observable} operator that when subscribed to, invokes an operation
* that stores the state of this {@code Network} while keeping the Network up and running.
* The Network will be stored at the pre-configured location (in binary form only, not JSON).
*
* @return the {@link CheckPointOp} operator
*/
CheckPointOp getCheckPointOperator() {
LOGGER.debug("Region [" + getName() + "] CheckPoint called at: " + (new DateTime()));
if(tail != null) {
return tail.getCheckPointOperator();
}else{
close();
return tail.getCheckPointOperator();
}
}
/**
* Stops each {@link Layer} contained within this {@code Region}
*/
public void halt() {
LOGGER.debug("Halt called on Region [" + getName() + "]");
if(tail != null) {
tail.halt();
}else{
close();
tail.halt();
}
LOGGER.debug("Region [" + getName() + "] halted.");
}
/**
* Returns a flag indicating whether this Region has a Layer
* whose Sensor thread is halted.
* @return true if so, false if not
*/
public boolean isHalted() {
if(tail != null) {
return tail.isHalted();
}
return false;
}
/**
* Finds any {@link Layer} containing a {@link TemporalMemory}
* and resets them.
*/
public void reset() {
for(Layer> l : layers.values()) {
if(l.hasTemporalMemory()) {
l.reset();
}
}
}
/**
* Resets the recordNum in all {@link Layer}s.
*/
public void resetRecordNum() {
for(Layer> l : layers.values()) {
l.resetRecordNum();
}
}
/**
* Connects the output of the specified {@code Region} to the
* input of this Region
*
* @param inputRegion the Region who's emissions will be observed by
* this Region.
* @return
*/
Region connect(Region inputRegion) {
inputRegion.observe().subscribe(new Observer() {
ManualInput localInf = new ManualInput();
@Override public void onCompleted() {
tail.notifyComplete();
}
@Override public void onError(Throwable e) { e.printStackTrace(); }
@SuppressWarnings("unchecked")
@Override public void onNext(Inference i) {
localInf.sdr(i.getSDR()).recordNum(i.getRecordNum()).classifierInput(i.getClassifierInput()).layerInput(i.getSDR());
if(i.getSDR().length > 0) {
((Layer)tail).compute(localInf);
}
}
});
// Set the upstream region
this.upstreamRegion = inputRegion;
inputRegion.downstreamRegion = this;
return this;
}
/**
* Returns this {@code Region}'s upstream region,
* if it exists.
*
* @return
*/
public Region getUpstreamRegion() {
return upstreamRegion;
}
/**
* Returns the {@code Region} that receives this Region's
* output.
*
* @return
*/
public Region getDownstreamRegion() {
return downstreamRegion;
}
/**
* Returns the top-most (last in execution order from
* bottom to top) {@link Layer} in this {@code Region}
*
* @return
*/
public Layer> getHead() {
return this.head;
}
/**
* Returns the bottom-most (first in execution order from
* bottom to top) {@link Layer} in this {@code Region}
*
* @return
*/
public Layer> getTail() {
return this.tail;
}
/**
* Connects two layers to each other in a unidirectional fashion
* with "toLayerName" representing the receiver or "sink" and "fromLayerName"
* representing the sender or "source".
*
* This method also forwards shared constructs up the connection chain
* such as any {@link Encoder} which may exist, and the {@link Inference} result
* container which is shared among layers.
*
* @param toLayerName the name of the sink layer
* @param fromLayerName the name of the source layer
* @return
* @throws IllegalStateException if Region is already closed
*/
@SuppressWarnings("unchecked")
public Region connect(String toLayerName, String fromLayerName) {
if(assemblyClosed) {
throw new IllegalStateException("Cannot connect Layers when Region has already been closed.");
}
Layer in = (Layer)lookup(toLayerName);
Layer out = (Layer)lookup(fromLayerName);
if(in == null) {
throw new IllegalArgumentException("Could not lookup (to) Layer with name: " + toLayerName);
}else if(out == null){
throw new IllegalArgumentException("Could not lookup (from) Layer with name: " + fromLayerName);
}
// Set source's pointer to its next Layer --> (sink : going upward).
out.next(in);
// Set the sink's pointer to its previous Layer --> (source : going downward)
in.previous(out);
// Connect out to in
configureConnection(in, out);
connect(in, out);
return this;
}
/**
* Does a straight associative lookup by first creating a composite
* key containing this {@code Region}'s name concatenated with the specified
* {@link Layer}'s name, and returning the result.
*
* @param layerName
* @return
*/
public Layer> lookup(String layerName) {
if(layerName.indexOf(":") != -1) {
return layers.get(layerName);
}
return layers.get(name.concat(":").concat(layerName));
}
/**
* Called by {@link #start()}, {@link #observe()} and {@link #connect(Region)}
* to finalize the internal chain of {@link Layer}s contained by this {@code Region}.
* This method assigns the head and tail Layers and composes the {@link Observable}
* which offers this Region's emissions to any upstream {@link Region}s.
*/
private void completeAssembly() {
if(!assemblyClosed) {
if(layers.size() == 0) return;
if(layers.size() == 1) {
head = tail = layers.values().iterator().next();
}
if(tail == null) {
Set> temp = new HashSet>(sources);
temp.removeAll(sinks);
if(temp.size() != 1) {
throw new IllegalArgumentException("Detected misconfigured Region too many or too few sinks.");
}
tail = temp.iterator().next();
}
if(head == null) {
Set> temp = new HashSet>(sinks);
temp.removeAll(sources);
if(temp.size() != 1) {
throw new IllegalArgumentException("Detected misconfigured Region too many or too few sources.");
}
head = temp.iterator().next();
}
regionObservable = head.observe();
assemblyClosed = true;
}
}
/**
* Called internally to configure the connection between two {@link Layer}
* {@link Observable}s taking care of other connection details such as passing
* the inference up the chain and any possible encoder.
*
* @param in the sink end of the connection between two layers
* @param out the source end of the connection between two layers
* @throws IllegalStateException if Region is already closed
*/
, O extends Layer> void configureConnection(I in, O out) {
if(assemblyClosed) {
throw new IllegalStateException("Cannot add Layers when Region has already been closed.");
}
Set> all = new HashSet<>(sources);
all.addAll(sinks);
byte inMask = in.getMask();
byte outMask = out.getMask();
if(!all.contains(out)) {
layersDistinct = (flagAccumulator & outMask) < 1;
flagAccumulator |= outMask;
}
if(!all.contains(in)) {
layersDistinct = (flagAccumulator & inMask) < 1;
flagAccumulator |= inMask;
}
sources.add(out);
sinks.add(in);
}
/**
* Called internally to actually connect two {@link Layer}
* {@link Observable}s taking care of other connection details such as passing
* the inference up the chain and any possible encoder.
*
* @param in the sink end of the connection between two layers
* @param out the source end of the connection between two layers
* @throws IllegalStateException if Region is already closed
*/
, O extends Layer> void connect(I in, O out) {
out.subscribe(new Subscriber() {
ManualInput localInf = new ManualInput();
@Override public void onCompleted() { in.notifyComplete(); }
@Override public void onError(Throwable e) { e.printStackTrace(); }
@Override public void onNext(Inference i) {
if(layersDistinct) {
in.compute(i);
}else{
localInf.sdr(i.getSDR()).recordNum(i.getRecordNum()).layerInput(i.getSDR());
in.compute(localInf);
}
}
});
}
/* (non-Javadoc)
* @see java.lang.Object#hashCode()
*/
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + (assemblyClosed ? 1231 : 1237);
result = prime * result + (isLearn ? 1231 : 1237);
result = prime * result + ((layers == null) ? 0 : layers.size());
result = prime * result + ((name == null) ? 0 : name.hashCode());
return result;
}
/* (non-Javadoc)
* @see java.lang.Object#equals(java.lang.Object)
*/
@Override
public boolean equals(Object obj) {
if(this == obj)
return true;
if(obj == null)
return false;
if(getClass() != obj.getClass())
return false;
Region other = (Region)obj;
if(assemblyClosed != other.assemblyClosed)
return false;
if(isLearn != other.isLearn)
return false;
if(layers == null) {
if(other.layers != null)
return false;
} else if(!layers.equals(other.layers))
return false;
if(name == null) {
if(other.name != null)
return false;
} else if(!name.equals(other.name))
return false;
return true;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy