Please wait. This can take a few minutes ...
Downloading a project requires considerable resources. Please understand that we have to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download/modify it as often as you want.
org.numenta.nupic.encoders.MultiEncoder Maven / Gradle / Ivy
/* ---------------------------------------------------------------------
* Numenta Platform for Intelligent Computing (NuPIC)
* Copyright (C) 2014, Numenta, Inc. Unless you have an agreement
* with Numenta, Inc., for a separate license for this software code, the
* following terms and conditions apply:
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero Public License version 3 as
* published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
* See the GNU Affero Public License for more details.
*
* You should have received a copy of the GNU Affero Public License
* along with this program. If not, see http://www.gnu.org/licenses.
*
* http://numenta.org/licenses/
* ---------------------------------------------------------------------
*/
package org.numenta.nupic.encoders;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.numenta.nupic.FieldMetaType;
import org.numenta.nupic.Parameters;
import org.numenta.nupic.util.Tuple;
import gnu.trove.map.TIntObjectMap;
import gnu.trove.map.hash.TIntObjectHashMap;
/**
* A MultiEncoder encodes a dictionary or object with
* multiple components. A MultiEncode contains a number
* of sub-encoders, each of which encodes a separate component.
*
* @see Encoder
* @see Encoding
* @see Parameters
*
* @author wlmiller
*/
public class MultiEncoder extends Encoder {
private static final long serialVersionUID = 1L;
protected TIntObjectMap indexToCategory = new TIntObjectHashMap();
protected List categoryList;
protected int width;
protected static final String CATEGORY_DELIMITER = ";";
/**
* Constructs a new {@code MultiEncoder}
*/
private MultiEncoder() {}
/**
* Returns a builder for building MultiEncoders.
* This builder may be reused to produce multiple builders
*
* @return a {@code MultiEncoder.Builder}
*/
public static Encoder.Builder builder() {
return new MultiEncoder.Builder();
}
public void init() {
}
@SuppressWarnings({ "unchecked", "rawtypes" })
@Override
public void setFieldStats(String fieldName, Map fieldStatistics) {
for (EncoderTuple t : getEncoders(this)) {
String name = t.getName();
Encoder encoder = t.getEncoder();
encoder.setFieldStats(name, fieldStatistics);
}
}
/**
* {@inheritDoc}
*/
@SuppressWarnings({ "rawtypes", "unchecked" })
@Override
public void encodeIntoArray(Object input, int[] output) {
for (EncoderTuple t : getEncoders(this)) {
String name = t.getName();
Encoder encoder = t.getEncoder();
int offset = t.getOffset();
int[] tempArray = new int[encoder.getWidth()];
try {
Object o = getInputValue(input, name);
encoder.encodeIntoArray(o, tempArray);
}catch(Exception e) {
throw new IllegalStateException(e);
}
System.arraycopy(tempArray, 0, output, offset, tempArray.length);
}
}
@SuppressWarnings({ "rawtypes", "unchecked" })
public int[] encodeField(String fieldName, Object value) {
for (EncoderTuple t : getEncoders(this)) {
String name = t.getName();
Encoder encoder = t.getEncoder();
if (name.equals(fieldName)) {
return encoder.encode(value);
}
}
return new int[]{};
}
@SuppressWarnings({ "rawtypes", "unchecked" })
public List encodeEachField(Object input) {
List encodings = new ArrayList();
for (EncoderTuple t : getEncoders(this)) {
String name = t.getName();
Encoder encoder = t.getEncoder();
encodings.add(encoder.encode(getInputValue(input, name)));
}
return encodings;
}
@SuppressWarnings({ "rawtypes", "unchecked" })
public void addEncoder(String name, Encoder child) {
super.addEncoder(this, name, child, width);
for (Object d : child.getDescription()) {
Tuple dT = (Tuple) d;
description.add(new Tuple(dT.get(0), (int)dT.get(1) + getWidth()));
}
width += child.getWidth();
}
/**
* Configures this {@code MultiEncoder} using the specified settings map.
*
* @param fieldEncodings
* @return the assembled {@code MultiEncoder}
*/
public MultiEncoder addMultipleEncoders(Map> fieldEncodings) {
return MultiEncoderAssembler.assemble(this, fieldEncodings);
}
/**
* Convenience method to return the {@code Encoder} contained within this
* {@link MultiEncoder}, of a specific type.
* @param fmt the {@link FieldMetaType} specifying the type to return.
* @return the Encoder of the specified type or null if one isn't found.
*/
@SuppressWarnings("unchecked")
public > T getEncoderOfType(FieldMetaType fmt) {
Encoder> retVal = null;
for(Tuple t : getEncoders(this)) {
Encoder> enc = (Encoder>)t.get(1);
Set subTypes = enc.getDecoderOutputFieldTypes();
if(subTypes.contains(fmt)) {
retVal = enc; break;
}
}
return (T)retVal;
}
/**
* Open up for internal Network API use.
* Returns an {@link Encoder.Builder} which corresponds to the specified name.
* @param encoderName
* @return
*/
public Encoder.Builder,?> getBuilder(String encoderName) {
switch(encoderName) {
case "CategoryEncoder":
return CategoryEncoder.builder();
case "CoordinateEncoder":
return CoordinateEncoder.builder();
case "GeospatialCoordinateEncoder":
return GeospatialCoordinateEncoder.geobuilder();
case "LogEncoder":
return LogEncoder.builder();
case "PassThroughEncoder":
return PassThroughEncoder.builder();
case "ScalarEncoder":
return ScalarEncoder.builder();
case "AdaptiveScalarEncoder":
return AdaptiveScalarEncoder.builder();
case "SparsePassThroughEncoder":
return SparsePassThroughEncoder.sparseBuilder();
case "SDRCategoryEncoder":
return SDRCategoryEncoder.builder();
case "RandomDistributedScalarEncoder":
return RandomDistributedScalarEncoder.builder();
case "DateEncoder":
return DateEncoder.builder();
case "DeltaEncoder":
return DeltaEncoder.deltaBuilder();
case "SDRPassThroughEncoder" :
return SDRPassThroughEncoder.sptBuilder();
default:
throw new IllegalArgumentException("Invalid encoder: " + encoderName);
}
}
@SuppressWarnings({ "rawtypes", "unchecked" })
public void setValue(Encoder.Builder builder, String param, Object value) {
switch(param) {
case "n":
builder.n(((Number)value).intValue());
break;
case "w":
builder.w(((Number)value).intValue());
break;
case "minVal":
builder.minVal(((Number)value).doubleValue());
break;
case "maxVal":
builder.maxVal(((Number)value).doubleValue());
break;
case "radius":
builder.radius(((Number)value).doubleValue());
break;
case "resolution":
builder.resolution(((Number)value).doubleValue());
break;
case "periodic":
builder.periodic((boolean) value);
break;
case "clipInput":
builder.clipInput((boolean) value);
break;
case "forced":
builder.forced((boolean) value);
break;
case "name":
builder.name((String) value);
break;
case "categoryList":
if(value instanceof String) {
String strVal = (String)value;
if(strVal.indexOf(CATEGORY_DELIMITER) == -1) {
throw new IllegalArgumentException("Category field not delimited with '" + CATEGORY_DELIMITER + "' character.");
}
value = Arrays.asList(strVal.split("[\\s]*\\" + CATEGORY_DELIMITER + "[\\s]*"));
}
if(builder instanceof CategoryEncoder.Builder) {
((CategoryEncoder.Builder) builder).categoryList((List) value);
}else{
((SDRCategoryEncoder.Builder) builder).categoryList((List) value);
}
break;
default:
throw new IllegalArgumentException("Invalid parameter: " + param);
}
}
@Override
public int getWidth() {
return width;
}
@Override
public int getN() {
return width;
}
@Override
public int getW() {
return width;
}
@Override
public String getName() {
if (name == null) return "";
else return name;
}
@Override
public boolean isDelta() {
return false;
}
@SuppressWarnings("rawtypes")
@Override
public void setLearning(boolean learningEnabled) {
for (EncoderTuple t : getEncoders(this)) {
Encoder encoder = t.getEncoder();
encoder.setLearningEnabled(learningEnabled);
}
}
@Override
public List getBucketValues(Class returnType) {
return null;
}
/**
* Returns a {@link EncoderBuilder} for constructing {@link MultiEncoder}s
*
* The base class architecture is put together in such a way where boilerplate
* initialization can be kept to a minimum for implementing subclasses, while avoiding
* the mistake-proneness of extremely long argument lists.
*
*/
public static class Builder extends Encoder.Builder {
private Builder() {}
@Override
public MultiEncoder build() {
//Must be instantiated so that super class can initialize
//boilerplate variables.
encoder = new MultiEncoder();
//Call super class here
super.build();
////////////////////////////////////////////////////////
// Implementing classes would do setting of specific //
// vars here together with any sanity checking //
////////////////////////////////////////////////////////
//Call init
((MultiEncoder)encoder).init();
return (MultiEncoder)encoder;
}
}
}