
org.numenta.nupic.network.Network Maven / Gradle / Ivy
/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2014, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.network;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.function.Function;
import org.joda.time.DateTime;
import org.numenta.nupic.Parameters;
import org.numenta.nupic.Parameters.KEY;
import org.numenta.nupic.algorithms.CLAClassifier;
import org.numenta.nupic.algorithms.SpatialPooler;
import org.numenta.nupic.algorithms.TemporalMemory;
import org.numenta.nupic.encoders.MultiEncoder;
import org.numenta.nupic.model.Connections;
import org.numenta.nupic.model.Persistable;
import org.numenta.nupic.network.sensor.HTMSensor;
import org.numenta.nupic.network.sensor.ObservableSensor;
import org.numenta.nupic.network.sensor.Publisher;
import org.numenta.nupic.network.sensor.Sensor;
import org.numenta.nupic.network.sensor.SensorFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import rx.Observable;
import rx.Subscriber;
/**
*
* A {@code Network} is the fundamental component of the HTM.java Network API.
* It is comprised of {@link Region}s which are in turn comprised of {@link Layer}s;
* each Layer directly containing one or more algorithm or computational components
* such (i.e. {@link Sensor}, {@link MultiEncoder}, {@link SpatialPooler},
* {@link TemporalMemory}, {@link CLAClassifier} etc.)
*
*
* Networks in HTM.java are extremely easy to compose. For instance, here is an example
* of a network which contains everything:
*
* Parameters p = NetworkTestHarness.getParameters();
* p = p.union(NetworkTestHarness.getNetworkDemoTestEncoderParams());
*
* Network network = Network.create("test network", p)
* .add(Network.createRegion("r1")
* .add(Network.createLayer("1", p)
* .alterParameter(KEY.AUTO_CLASSIFY, Boolean.TRUE)
* .add(Anomaly.create())
* .add(new TemporalMemory())
* .add(new SpatialPooler())
* .add(Sensor.create(FileSensor::create, SensorParams.create(
* Keys::path, "", ResourceLocator.path("rec-center-hourly.csv"))))));
*
*
*
* As you can see, {@code Networks} can be composed in "fluent" style making their
* declaration much more concise.
*
* While the above Network contains only 1 Region and only 1 Layer, Networks with many Regions
* and Layers within them; may be composed. For example:
*
* Connections cons = new Connections();
*
* Network network = Network.create("test network", p)
* .add(Network.createRegion("r1")
* .add(Network.createLayer("1", p)
* .add(Anomaly.create())
* .add(new TemporalMemory())
* .add(new SpatialPooler())))
* .add(Network.createRegion("r2")
* .add(Network.createLayer("1", p)
* .using(cons)
* .alterParameter(KEY.AUTO_CLASSIFY, Boolean.TRUE)
* .add(new TemporalMemory()))
* .add(Network.createLayer("2", p)
* .using(cons)
* .add(new SpatialPooler())
* .add(Sensor.create(FileSensor::create, SensorParams.create(
* Keys::path, "", ResourceLocator.path("rec-center-hourly.csv")))))
* .connect("1", "2")) // Tell the Region to connect the two layers {@link Region#connect(String, String)}.
* .connect("r1", "r2"); // Tell the Network to connect the two regions {@link Network#connect(String, String)}.
*
* As you can see above, a {@link Connections} object is being shared among Region "r2"'s inner layers
* via the use of the {@link Layer#using(Connections)} method. Additionally, {@link Parameters} may be
* altered in place without changing the values in the original Parameters object, by calling {@link Layer#alterParameter(KEY, Object)}
* as seen in the above examples.
*
*
* Networks can be "observed", meaning they can return an {@link Observable} object
* which can be operated on in the map-reduce pattern of usage; and they can be
* subscribed to, to receive {@link Inference} objects as output from every process
* cycle. For instance:
*
* network.observe().subscribe(new Subscriber<Inference>() {
* public void onCompleted() { System.out.println("Input Completed!"); }
* public void onError(Throwable e) { e.printStackTrace(); }
* public void onNext(Inference i) {
* System.out.println(String.format("%d: %s", i.getRecordNum(), Arrays.toString(i.getSDR())));
* }
* });
*
*
*
* Likewise, each Region within a Network may be "observed" and/or subscribed to; and each
* Layer may be observed and subscribed to as well - for those instances where you would like
* to obtain the processing from an individual component within the network for use outside
* the network in some other area of your application. To find and subscribe to individual
* components try:
*
* Network network = ...
*
* Region region = network.lookup("<region name>");
* region.observe().subscribe(new Subscriber<Inference>() {
* public void onCompleted() { System.out.println("Input Completed!"); }
* public void onError(Throwable e) { e.printStackTrace(); }
* public void onNext(Inference i) {
* int[] sdr = i.getSDR();
* do something...
* }
* }
*
* Layer l2_3 = region.lookup("<layer name>");
* l2_3.observe().subscribe(new Subscriber<Inference>() {
* public void onCompleted() { System.out.println("Input Completed!"); }
* public void onError(Throwable e) { e.printStackTrace(); }
* public void onNext(Inference i) {
* int[] sdr = i.getSDR();
* do something...
* }
* }
*
*
* In addition there are many usage examples to be found in the {@link org.numenta.nupic.examples.napi.hotgym} package
* where there are tests which may be examined for details, and the {@link NetworkAPIDemo}
*
* @author David Ray
* @see Region
* @see Layer
* @see Inference
* @see ManualInput
* @see NetworkAPIDemo
*/
public class Network implements Persistable {
private static final long serialVersionUID = 1L;
private static final Logger LOGGER = LoggerFactory.getLogger(Network.class);
private String name;
private Parameters parameters;
private HTMSensor> sensor;
private MultiEncoder encoder;
private Region head;
private Region tail;
private Region sensorRegion;
private volatile Publisher publisher;
private volatile boolean isLearn = true;
private volatile boolean isThreadRunning;
private List regions = new ArrayList<>();
/** Stored check pointer function */
private transient Function checkPointFunction;
boolean shouldDoHalt = true;
Network() {}
/**
* Creates a new {@link Network}
*
* Warning: name cannot be null or empty
*
* @param name
* @param parameters
*/
public Network(String name, Parameters parameters) {
if(name == null || name.isEmpty()) {
throw new IllegalStateException("All Networks must have a name. " +
"Increases digestion, and overall happiness!");
}
this.name = name;
this.parameters = parameters;
if(parameters == null) {
throw new IllegalArgumentException("Network Parameters were null.");
}
}
/**
* Creates and returns an implementation of {@link Network}
*
* Warning: name cannot be null or empty
*
* @param name
* @param parameters
* @return
*/
public static Network create(String name, Parameters parameters) {
return new Network(name, parameters);
}
/**
* Creates and returns a child {@link Region} of this {@code Network}
*
* @param name The String identifier for the specified {@link Region}
* @return
*/
public static Region createRegion(String name) {
Network.checkName(name);
Region r = new Region(name, null);
return r;
}
/**
* Creates a {@link Layer} to hold algorithmic components and returns
* it.
*
* @param name the String identifier for the specified {@link Layer}
* @param p the {@link Parameters} to use for the specified {@link Layer}
* @return
*/
@SuppressWarnings("rawtypes")
public static Layer> createLayer(String name, Parameters p) {
Network.checkName(name);
return new Layer(name, null, p);
}
/**
* DO NOT CALL THIS METHOD! FOR INTERNAL USE ONLY!
* {@inheritDoc}
*/
@SuppressWarnings("unchecked")
@Override
public Network preSerialize() {
if(shouldDoHalt && isThreadRunning) {
halt();
}else{ // Make sure "close()" has been called on the Network
if(regions.size() == 1) {
this.tail = regions.get(0);
}
tail.close();
}
regions.stream().forEach(r -> r.preSerialize());
return this;
}
/**
* DO NOT CALL THIS METHOD! FOR INTERNAL USE ONLY!
* {@inheritDoc}
*/
@SuppressWarnings("unchecked")
@Override
public Network postDeSerialize() {
regions.stream().forEach(r -> r.setNetwork(this));
regions.stream().forEach(r -> r.postDeSerialize());
// Connect Layer Observable chains (which are transient so we must
// rebuild them and their subscribers)
if(isMultiRegion()) {
Region curr = head;
Region nxt = curr.getUpstreamRegion();
do {
curr.connect(nxt);
} while((curr = nxt) != null && (nxt = nxt.getUpstreamRegion()) != null);
}
return this;
}
/**
* INTERNAL METHOD: DO NOT CALL
*
* Called from {@link Layer} to execute a check point from within the scope of
* this {@link Network}
* checkPointFunction
* @return the serialized {@code Network} in byte array form.
*/
byte[] internalCheckPointOp() {
shouldDoHalt = false;
byte[] serializedBytes = (byte[])checkPointFunction.apply(this);
shouldDoHalt = true;
return serializedBytes;
}
/**
* Sets the reference to the check point function.
* @param f function which executes check point logic.
*/
@SuppressWarnings("unchecked")
public void setCheckPointFunction(Function f) {
this.checkPointFunction = (Function)f;
}
/**
* USED INTERNALLY, DO NOT CALL
* Returns an {@link rx.Observable} operator that when subscribed to, invokes an operation
* that stores the state of this {@code Network} while keeping the Network up and running.
* The Network will be stored at the pre-configured location (in binary form only, not JSON).
*
* @return the {@link CheckPointOp} operator
*/
CheckPointOp getCheckPointOperator() {
LOGGER.debug("Network [" + getName() + "] called checkPoint() at: " + (new DateTime()));
if(regions.size() == 1) {
this.tail = regions.get(0);
}
return tail.getCheckPointOperator();
}
/**
* Restarts this {@code Network}. The network will run from the previous save point
* of the stored Network.
*
* @see {@link #restart(boolean)} for a start at "saved-index" behavior explanation.
*/
public void restart() {
restart(true);
}
/**
* Restarts this {@code Network}. If the "startAtIndex" flag is true, the Network
* will start from the last record number (plus 1) at which the Network was saved -
* continuing on from where it left off. The Network will achieve this by rebuilding
* the underlying Stream (if necessary, i.e. not for {@link ObservableSensor}s) and skipping
* the number of records equal to the stored record number plus one, continuing from where it left off.
*
* @param startAtIndex flag indicating whether to start this {@code Network} from
* its previous save point.
*/
public void restart(boolean startAtIndex) {
if(regions.size() < 1) {
throw new IllegalStateException("Nothing to start - 0 regions");
}
Region tail = regions.get(0);
Region upstream = tail;
while((upstream = upstream.getUpstreamRegion()) != null) {
tail = upstream;
}
// Record thread start
this.isThreadRunning = tail.restart(startAtIndex);
}
/**
*
* DO NOT CALL THIS METHOD!
*
* Called internally by an {@link ObservableSensor}'s factory method's creation of a new
* {@code ObservableSensor}. This would usually happen following a halt or
* deserialization.
*
* @param p the new Publisher created upon reconstitution of a new ObservableSensor
*/
void setPublisher(Publisher p) {
this.publisher = p;
publisher.setNetwork(this);
}
/**
* Returns the new {@link Publisher} created after halt or deserialization
* of this {@code Network}, when a new Publisher must be created.
*
* @return the new Publisher created after deserialization or halt.
* @see #getPublisherSupplier()
*/
public Publisher getPublisher() {
if(publisher == null) {
throw new NullPointerException("A Supplier must be built first. " +
"please see Network.getPublisherSupplier()");
}
return publisher;
}
/**
* Returns a flag indicating whether this {@code Network} contain multiple
* {@link Region}s.
*
* @return true if so, false if not.
*/
public boolean isMultiRegion() {
return regions.size() > 1;
}
/**
* Returns the String identifier for this {@code Network}
* @return
*/
public String getName() {
return name;
}
/**
* Calling this
* method will start the main engine thread which pulls in data
* from the connected {@link Sensor}(s).
*/
public void start() {
if(regions.size() < 1) {
throw new IllegalStateException("Nothing to start - 0 regions");
}
Region tail = regions.get(0);
Region upstream = tail;
while((upstream = upstream.getUpstreamRegion()) != null) {
tail = upstream;
}
// Record thread start
this.isThreadRunning = tail.start();
}
/**
* Returns a flag indicating that the {@code Network} has an {@link Observable}
* running on a thread.
*
* @return a flag indicating if threaded.
*/
public boolean isThreadedOperation() {
return this.isThreadRunning;
}
/**
* Halts this {@code Network}, stopping all threads and closing
* all {@link SensorFactory} connections to incoming data, freeing up
* any resources associated with the input connections.
*/
public void halt() {
// Call onComplete if using an ObservableSensor to complete the stream output.
if(publisher != null) {
publisher.onComplete();
}
if(regions.size() == 1) {
this.tail = regions.get(0);
}
tail.halt();
}
/**
* Returns a flag indicating whether this Network has a Region
* whose tail (input {@link Layer}) is halted.
* @return true if so, false if not
*/
public boolean isHalted() {
if(regions.size() == 1) {
this.tail = regions.get(0);
}
return tail.isHalted();
}
/**
* Returns the index of the last record processed.
*
* @return the last recordNum processed
*/
public int getRecordNum() {
if(regions.size() == 1) {
this.tail = regions.get(0);
}
return tail.getTail().getRecordNum();
}
/**
* Sets the learning mode.
* @param isLearn
*/
public void setLearn(boolean isLearn) {
this.isLearn = isLearn;
for(Region r : regions) {
r.setLearn(isLearn);
}
}
/**
* Returns the learning mode setting.
* @return
*/
public boolean isLearn() {
return isLearn;
}
/**
* Finds any {@link Region} containing a {@link Layer} which contains a {@link TemporalMemory}
* and resets them.
*/
public void reset() {
for(Region r : regions) {
r.reset();
}
}
/**
* Resets the recordNum in all {@link Region}s.
*/
public void resetRecordNum() {
for(Region r : regions) {
r.resetRecordNum();
}
}
/**
* Returns an {@link Observable} capable of emitting {@link Inference}s
* which contain the results of this {@code Network}'s processing chain.
* @return
*/
public Observable observe() {
if(regions.size() == 1) {
this.head = regions.get(0);
}
return head.observe();
}
/**
* Returns the top-most (last in execution order from
* bottom to top) {@link Region} in this {@code Network}
*
* @return
*/
public Region getHead() {
if(regions.size() == 1) {
this.head = regions.get(0);
}
return this.head;
}
/**
* Returns the bottom-most (first in execution order from
* bottom to top) {@link Region} in this {@code Network}
*
* @return
*/
public Region getTail() {
if(regions.size() == 1) {
this.tail = regions.get(0);
}
return this.tail;
}
/**
* For internal Use: Returns a boolean flag indicating whether
* the specified {@link Layer} is the tail of the Network.
* @param l the layer to test
* @return true if so, false if not
*/
boolean isTail(Layer> l) {
if(regions.size() == 1) {
this.tail = regions.get(0);
}
return tail.getTail() == l;
}
/**
* Returns a {@link Iterator} capable of walking the tree of regions
* from the root {@link Region} down through all the child Regions. In turn,
* a {@link Region} may be queried for a {@link Iterator} which will return
* an iterator capable of traversing the Region's contained {@link Layer}s.
*
* @return
*/
public Iterator iterator() {
return getRegions().iterator();
}
/**
* Used to manually input data into a {@link Network}, the other way
* being the call to {@link Network#start()} for a Network that contains a
* Region that contains a {@link Layer} which in turn contains a {@link Sensor} -OR-
* subscribing a receiving Region to this Region's output Observable.
*
* @param input One of (int[], String[], {@link ManualInput}, or Map<String, Object>)
*/
public void compute(T input) {
if(tail == null && regions.size() == 1) {
this.tail = regions.get(0);
}
if(head == null) {
addDummySubscriber();
}
tail.compute(input);
}
/**
* Used to manually input data into a {@link Network} in a synchronous way, the other way
* being the call to {@link Network#start()} for a Network that contains a
* Region that contains a {@link Layer} which in turn contains a {@link Sensor} -OR-
* subscribing a receiving Region to this Region's output Observable.
*
* @param input One of (int[], String[], {@link ManualInput}, or Map<String, Object>)
*/
public Inference computeImmediate(T input) {
if(isThreadRunning) {
throw new IllegalStateException("Cannot call computeImmediate() when Network has been started.");
}
if(tail == null && regions.size() == 1) {
this.tail = regions.get(0);
}
if(head == null) {
addDummySubscriber();
}
tail.compute(input);
return head.getHead().getInference();
}
/**
* Added when a synchronous call is made and there is no subscriber. No
* Subscriber leads to the observable chain not being constructed, therefore
* we must always have at least one subscriber.
*/
void addDummySubscriber() {
observe().subscribe(new Subscriber() {
@Override public void onCompleted() {}
@Override public void onError(Throwable e) { e.printStackTrace(); }
@Override public void onNext(Inference i) {}
});
}
/**
* Connects the specified source to the specified sink (the order of
* processing flows from source to sink, or lower level region to higher
* level region).
* @param regionSink the receiving end of the connection
* @param regionSource the source end of the connection
*
* @return this {@code Network}
*/
public Network connect(String regionSink, String regionSource) {
Region source = lookup(regionSource);
if(source == null) {
throw new IllegalArgumentException("Region with name: " + regionSource + " not added to Network.");
}
Region sink = lookup(regionSink);
if(sink == null) {
throw new IllegalArgumentException("Region with name: " + regionSink + " not added to Network.");
}
sink.connect(source);
tail = tail == null ? source : tail;
head = head == null ? sink : head;
Region bottom = source;
while((bottom = bottom.getUpstreamRegion()) != null) {
tail = bottom;
}
Region top = sink;
while((top = top.getDownstreamRegion()) != null) {
head = top;
}
return this;
}
/**
* Adds a {@link Region} to this {@code Network}
* @param region
* @return
*/
public Network add(Region region) {
regions.add(region);
region.setNetwork(this);
return this;
}
/**
* Closes all the {@link Region} objects, in this {@link Network}
*/
public Network close() {
regions.forEach(region -> region.close());
return this;
}
/**
* Returns a {@link List} view of the contained {@link Region}s.
* @return
*/
public List getRegions() {
return new ArrayList(regions);
}
/**
* Sets a reference to the {@link Region} which contains the {@link Sensor}
* (if any).
*
* @param r
*/
public void setSensorRegion(Region r) {
this.sensorRegion = r;
}
/**
* Returns a reference to the {@link Region} which contains the {@link Sensor}
*
* @return the Region which contains the Sensor
*/
public Region getSensorRegion() {
return sensorRegion;
}
/**
* Returns the {@link Region} with the specified name
* or null if it doesn't exist within this {@code Network}
* @param regionName
* @return
*/
public Region lookup(String regionName) {
for(Region r : regions) {
if(r.getName().equals(regionName)) {
return r;
}
}
return null;
}
/**
* Returns the network-level {@link Parameters}.
* @return
*/
public Parameters getParameters() {
return parameters;
}
/**
* Sets the reference to this {@code Network}'s Sensor
* @param sensor
*/
public void setSensor(HTMSensor> sensor) {
this.sensor = sensor;
this.sensor.initEncoder(this.parameters);
}
/**
* Returns the encoder present in one of this {@code Network}'s
* {@link Sensor}s
*
* @return
*/
public HTMSensor> getSensor() {
return sensor;
}
/**
* Sets the {@link MultiEncoder} on this Network
* @param e
*/
public void setEncoder(MultiEncoder e) {
this.encoder = e;
}
/**
* Returns the {@link MultiEncoder} with which this Network is configured.
* @return
*/
public MultiEncoder getEncoder() {
return this.encoder;
}
/**
* Checks the name for suitability within a given network,
* checking for reserved characters and such.
*
* @param name
*/
private static void checkName(String name) {
if(name.indexOf(":") != -1) {
throw new IllegalArgumentException("\":\" is a reserved character.");
}
}
/* (non-Javadoc)
* @see java.lang.Object#hashCode()
*/
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + (isLearn ? 1231 : 1237);
result = prime * result + ((name == null) ? 0 : name.hashCode());
result = prime * result + ((parameters == null) ? 0 : parameters.hashCode());
result = prime * result + ((regions == null) ? 0 : regions.hashCode());
result = prime * result + ((sensor == null) ? 0 : sensor.hashCode());
return result;
}
/* (non-Javadoc)
* @see java.lang.Object#equals(java.lang.Object)
*/
@Override
public boolean equals(Object obj) {
if(this == obj)
return true;
if(obj == null)
return false;
if(getClass() != obj.getClass())
return false;
Network other = (Network)obj;
if(isLearn != other.isLearn)
return false;
if(name == null) {
if(other.name != null)
return false;
} else if(!name.equals(other.name))
return false;
if(parameters == null) {
if(other.parameters != null)
return false;
} else if(!parameters.equals(other.parameters))
return false;
if(regions == null) {
if(other.regions != null)
return false;
} else if(!regions.equals(other.regions))
return false;
if(sensor == null) {
if(other.sensor != null)
return false;
} else if(!sensor.equals(other.sensor))
return false;
return true;
}
}