
com.simiacryptus.mindseye.layers.meta.SumMetaLayer Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mindseye Show documentation
Show all versions of mindseye Show documentation
A neural network library for Java 8
/*
* Copyright (c) 2017 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.layers.meta;
import com.google.gson.JsonObject;
import com.simiacryptus.mindseye.layers.DeltaSet;
import com.simiacryptus.mindseye.layers.NNLayer;
import com.simiacryptus.mindseye.layers.NNResult;
import com.simiacryptus.util.ml.Tensor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;
import java.util.stream.IntStream;
@SuppressWarnings("serial")
public class SumMetaLayer extends NNLayer {
private Tensor lastResult;
public JsonObject getJson() {
JsonObject json = super.getJsonStub();
json.add("lastResult", lastResult.getJson());
return json;
}
public static SumMetaLayer fromJson(JsonObject json) {
return new SumMetaLayer(json);
}
protected SumMetaLayer(JsonObject id) {
super(id);
lastResult = Tensor.fromJson(id.getAsJsonObject("lastResult"));
}
@SuppressWarnings("unused")
private static final Logger log = LoggerFactory.getLogger(SumMetaLayer.class);
public SumMetaLayer() {
}
@Override
public NNResult eval(final NNResult... inObj) {
NNResult input = inObj[0];
int itemCnt = input.data.length;
if(1
IntStream.range(0, itemCnt)
.mapToDouble(dataIndex -> input.data[dataIndex].get(c))
.sum());
return new NNResult(lastResult) {
@Override
public void accumulate(final DeltaSet buffer, final Tensor[] data) {
if (input.isAlive()) {
Tensor delta = data[0];
Tensor feedback[] = new Tensor[itemCnt];
Arrays.parallelSetAll(feedback, i -> new Tensor(delta.getDims()));
lastResult.mapParallel((rho, inputCoord) -> {
for (int inputItem = 0; inputItem < itemCnt; inputItem++) {
feedback[inputItem].add(inputCoord, delta.get(inputCoord));
}
return 0;
});
input.accumulate(buffer, feedback);
}
}
@Override
public boolean isAlive() {
return input.isAlive();
}
};
}
@Override
public List state() {
return Arrays.asList();
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy