com.simiacryptus.mindseye.layers.LoggingWrapperLayer Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mindseye-core Show documentation
Show all versions of mindseye-core Show documentation
Core Neural Networks Framework
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.layers;
import com.google.gson.JsonObject;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.RefArrays;
import com.simiacryptus.ref.wrappers.RefIntStream;
import com.simiacryptus.ref.wrappers.RefString;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.Map;
import java.util.UUID;
import java.util.function.IntFunction;
@SuppressWarnings("serial")
public final class LoggingWrapperLayer extends WrapperLayer {
static final Logger log = LoggerFactory.getLogger(LoggingWrapperLayer.class);
protected LoggingWrapperLayer(@Nonnull final JsonObject json, Map rs) {
super(json, rs);
}
public LoggingWrapperLayer(final Layer inner) {
super(inner);
}
@Nonnull
@SuppressWarnings("unused")
public static LoggingWrapperLayer fromJson(@Nonnull final JsonObject json, Map rs) {
return new LoggingWrapperLayer(json, rs);
}
@Nonnull
public static String getString(@Nonnull Tensor tensor) {
try {
return RefArrays.toString(tensor.getDimensions());
} finally {
tensor.freeRef();
}
}
@Nonnull
@Override
public Result eval(@Nonnull final Result... inObj) {
final LoggingWrapperLayer loggingWrapperLayer = this.addRef();
final Result[] wrappedInput = RefIntStream.range(0, inObj.length)
.mapToObj(RefUtil.wrapInterface((IntFunction) i -> {
final Result inputToWrap = inObj[i].addRef();
boolean alive = inputToWrap.isAlive();
TensorList data = inputToWrap.getData();
Result.Accumulator accumulator = new Accumulator2(inner.addRef(), i, inputToWrap.getAccumulator());
inputToWrap.freeRef();
return new Result(data, accumulator, alive);
}, loggingWrapperLayer.addRef(), RefUtil.addRef(inObj)))
.toArray(Result[]::new);
for (int i = 0; i < inObj.length; i++) {
final TensorList tensorList = inObj[i].getData();
@Nonnull final String formatted = RefUtil.get(
tensorList.stream().map(LoggingWrapperLayer::getString).reduce((a, b) -> a + "\n" + b)
);
tensorList.freeRef();
log.info(RefString.format("Input %s for key %s: \n\t%s", i, inner.getName(),
formatted.replaceAll("\n", "\n\t")));
}
RefUtil.freeRef(inObj);
@Nullable final Result output = inner.eval(wrappedInput);
{
assert output != null;
final TensorList tensorList = output.getData();
@Nonnull final String formatted = RefUtil.get(tensorList.stream().map(LoggingWrapperLayer::getString).reduce((a, b) -> a + "\n" + b));
tensorList.freeRef();
log.info(RefString.format("Output for key %s: \n\t%s", inner.getName(), formatted.replaceAll("\n", "\n\t")));
}
boolean alive = output.isAlive();
TensorList data = output.getData();
Result.Accumulator accumulator = new Accumulator(inner.addRef(), output.getAccumulator());
output.freeRef();
loggingWrapperLayer.freeRef();
return new Result(data, accumulator, alive);
}
public @SuppressWarnings("unused")
void _free() {
super._free();
}
@Nonnull
public @Override
@SuppressWarnings("unused")
LoggingWrapperLayer addRef() {
return (LoggingWrapperLayer) super.addRef();
}
private static class Accumulator extends Result.Accumulator {
private Layer inner;
private Result.Accumulator accumulator;
public Accumulator(Layer inner, Result.Accumulator accumulator) {
this.inner = inner;
this.accumulator = accumulator;
}
@Override
public void accept(@Nullable DeltaSet buffer, @Nonnull TensorList data) {
@Nonnull final String formatted = RefUtil.get(data.stream().map(tensor -> {
return getString(tensor);
}).reduce((a, b) -> a + "\n" + b));
log.info(RefString.format("Feedback Input for key %s: \n\t%s", inner.getName(),
formatted.replaceAll("\n", "\n\t")));
Result.Accumulator accumulator = this.accumulator;
try {
accumulator.accept(buffer, data);
} finally {
accumulator.freeRef();
}
}
public @SuppressWarnings("unused")
void _free() {
super._free();
accumulator.freeRef();
inner.freeRef();
}
}
private static class Accumulator2 extends Result.Accumulator {
private final int i;
private Layer inner;
private Result.Accumulator accumulator;
public Accumulator2(Layer inner, int i, Result.Accumulator accumulator) {
this.i = i;
this.inner = inner;
this.accumulator = accumulator;
}
@Override
public void accept(@Nullable DeltaSet buffer, @Nonnull TensorList data) {
@Nonnull final String formatted = RefUtil.get(data.stream().map(LoggingWrapperLayer::getString).reduce((a, b) -> a + "\n" + b));
log.info(RefString.format("Feedback Output %s for key %s: \n\t%s", i, inner.getName(),
formatted.replaceAll("\n", "\n\t")));
Result.Accumulator accumulator = this.accumulator;
try {
accumulator.accept(buffer, data);
} finally {
accumulator.freeRef();
}
}
public @SuppressWarnings("unused")
void _free() {
super._free();
accumulator.freeRef();
inner.freeRef();
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy