
com.simiacryptus.mindseye.layers.ValueLayer Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mindseye-core Show documentation
Show all versions of mindseye-core Show documentation
Core Neural Networks Framework
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.layers;
import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.simiacryptus.lang.ref.ReferenceCountingBase;
import com.simiacryptus.mindseye.lang.*;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
* This key does not require any input, and produces a constant output. This constant can be tuned by optimization
* processes.
*/
@SuppressWarnings("serial")
public class ValueLayer extends LayerBase {
@Nullable
private Tensor[] data;
/**
* Instantiates a new Const nn key.
*
* @param json the json
* @param resources the resources
*/
protected ValueLayer(@Nonnull final JsonObject json, Map resources) {
super(json);
JsonArray values = json.getAsJsonArray("values");
data = IntStream.range(0, values.size()).mapToObj(i -> Tensor.fromJson(values.get(i), resources)).toArray(i -> new Tensor[i]);
}
/**
* Instantiates a new Const nn key.
*
* @param data the data
*/
public ValueLayer(final @Nonnull Tensor... data) {
super();
this.data = Arrays.copyOf(data, data.length);
Arrays.stream(this.data)
.map(x -> new RefWrapper(x)).distinct().map(x -> (Tensor) x.obj)
.forEach(ReferenceCountingBase::addRef);
this.frozen = true;
}
public static Layer wrap(Tensor tensor) {
ValueLayer valueLayer = new ValueLayer(tensor);
tensor.freeRef();
return valueLayer;
}
/**
* From json const nn key.
*
* @param json the json
* @param rs the rs
* @return the const nn key
*/
public static ValueLayer fromJson(@Nonnull final JsonObject json, Map rs) {
return new ValueLayer(json, rs);
}
@Nonnull
@Override
public Result eval(@Nonnull final Result... array) {
assert 0 == array.length;
ValueLayer.this.addRef();
Arrays.stream(data)
.map(x -> new RefWrapper(x)).distinct().map(x -> (Tensor) x.obj)
.forEach(ReferenceCountingBase::addRef);
return new Result(TensorArray.create(ValueLayer.this.data), (@Nonnull final DeltaSet buffer, @Nonnull final TensorList data) -> {
if (!isFrozen()) {
assertAlive();
assert (1 == ValueLayer.this.data.length || ValueLayer.this.data.length == data.length());
for (int i = 0; i < data.length(); i++) {
Tensor delta = data.get(i);
Tensor value = ValueLayer.this.data[i % ValueLayer.this.data.length];
buffer.get(value.getId(), value.getData()).addInPlace(delta.getData()).freeRef();
delta.freeRef();
}
}
data.freeRef();
}) {
@Override
protected void _free() {
Arrays.stream(ValueLayer.this.data)
.map(x -> new RefWrapper(x)).distinct().map(x -> (Tensor) x.obj)
.forEach(ReferenceCountingBase::freeRef);
ValueLayer.this.freeRef();
}
@Override
public boolean isAlive() {
return !ValueLayer.this.isFrozen();
}
};
}
@Override
protected void _free() {
Arrays.stream(ValueLayer.this.data)
.map(x -> new RefWrapper(x)).distinct().map(x -> (Tensor) x.obj)
.forEach(ReferenceCountingBase::freeRef);
}
/**
* Gets data.
*
* @return the data
*/
@Nullable
public Tensor[] getData() {
return data;
}
/**
* Sets data.
*
* @param data the data
*/
public void setData(final Tensor... data) {
Arrays.stream(data)
.map(x -> new RefWrapper(x)).distinct().map(x -> (Tensor) x.obj)
.forEach(ReferenceCountingBase::addRef);
if (null != this.data) Arrays.stream(this.data)
.map(x -> new RefWrapper(x)).distinct().map(x -> (Tensor) x.obj)
.forEach(ReferenceCountingBase::freeRef);
this.data = data;
}
@Nonnull
@Override
public JsonObject getJson(Map resources, @Nonnull DataSerializer dataSerializer) {
@Nonnull final JsonObject json = super.getJsonStub();
JsonArray values = new JsonArray();
Arrays.stream(data).map(datum -> datum.toJson(resources, dataSerializer)).forEach(values::add);
json.add("values", values);
return json;
}
@Nonnull
@Override
public List state() {
return Arrays.stream(data).map(x -> x.getData()).collect(Collectors.toList());
}
public static class RefWrapper {
public final T obj;
public RefWrapper(T obj) {
this.obj = obj;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
RefWrapper> that = (RefWrapper>) o;
return obj == that.obj;
}
@Override
public int hashCode() {
return System.identityHashCode(obj);
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy