com.simiacryptus.mindseye.eval.BasicTrainable Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mindseye-core Show documentation
Core Neural Networks Framework
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.eval;
import com.simiacryptus.lang.TimedResult;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.ref.lang.RefIgnore;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.lang.ReferenceCountingBase;
import com.simiacryptus.ref.wrappers.RefDoubleStream;
import com.simiacryptus.ref.wrappers.RefIntStream;
import com.simiacryptus.ref.wrappers.RefList;
import com.simiacryptus.ref.wrappers.RefString;
import org.jetbrains.annotations.NotNull;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.DoubleSummaryStatistics;
import java.util.UUID;
/**
 * A trainable that evaluates one network against a single in-memory data set.
 *
 * <p>Each call to {@link #setData(RefList)} stores the data and builds one
 * {@link Result} proxy per network input. {@link #measure(TrainingMonitor)}
 * then evaluates the network on those proxies, sums every output value, and
 * returns the sum together with the accumulated deltas as a normalized
 * {@link PointSample}.
 *
 * <p>The optional mask selects, per input index, whether the input is wrapped
 * as a {@link MutableResult} (mask entry {@code true}) or as a
 * {@link ConstantResult} (no mask, index past the end of the mask, or mask
 * entry {@code false}).
 */
public class BasicTrainable extends ReferenceCountingBase implements DataTrainable, TrainableDataMask {
  /** Network under evaluation; the reference passed to the constructor is released in {@link #_free()}. */
  @Nullable
  private final Layer network;
  /** Current data set, one {@code Tensor[]} row per example; {@code null} until {@link #setData(RefList)}. */
  @Nullable
  private RefList<Tensor[]> data;
  /** Per-input mask; see the class description for how entries are interpreted. */
  @Nullable
  private boolean[] mask = null;
  /** Values above 1 make {@link #measure(TrainingMonitor)} log a timing line. */
  private int verbosity = 0;
  /** Number of tensors in the first data row, or 0 when the data is empty. */
  private int inputSize;
  /** Cached network inputs, rebuilt by {@link #setData(RefList)}. */
  private Result[] inputProxies;

  /**
   * Creates a trainable with no data.
   *
   * @param network the network to evaluate; the reference is stored as-is and
   *                freed when this object is freed
   */
  public BasicTrainable(@Nullable final Layer network) {
    this.network = network;
    data = null;
    inputSize = 0;
  }

  /**
   * Returns the current data as an array of rows.
   *
   * @return the contents of the data list as a {@code Tensor[][]}
   */
  @Nonnull
  @Override
  public Tensor[][] getData() {
    assert data != null;
    return data.toArray(new Tensor[][]{});
  }

  /**
   * Replaces the data set and rebuilds the cached input proxies.
   *
   * <p>The previously held list and proxies are freed; the supplied list is
   * stored without taking an additional reference.
   *
   * @param data the new data set
   */
  public synchronized void setData(@Nonnull final RefList<Tensor[]> data) {
    if (null != this.data)
      this.data.freeRef();
    inputSize = getInputs(data);
    this.data = data;
    RefUtil.freeRef(this.inputProxies);
    this.inputProxies = getInputProxies();
  }

  /**
   * Builds one {@link Result} per input index from the current data and mask.
   *
   * @return a new array with {@code inputSize} entries
   * @throws IllegalArgumentException if no data has been set
   */
  @Nonnull
  protected Result[] getInputProxies() {
    if (null == data) {
      throw new IllegalArgumentException("No data has been set");
    }
    // An empty list is allowed here: setData records an input count of 0 for
    // it, so the range below is empty and no proxies are created.
    return RefIntStream.range(0, inputSize)
        .mapToObj(inputIndex -> {
          final Tensor[] tensors = select(inputIndex);
          if (mask(inputIndex)) {
            return new ConstantResult(new TensorArray(tensors));
          } else {
            return new MutableResult(tensors);
          }
        }).toArray(Result[]::new);
  }

  /**
   * Returns a new reference to the network.
   *
   * @return {@code network.addRef()}, or {@code null} when there is no network
   */
  @Override
  public Layer getLayer() {
    return network == null ? null : network.addRef();
  }

  /**
   * Returns the mask array itself (not a copy).
   *
   * @return the current mask, possibly {@code null}
   */
  @Nullable
  @Override
  public boolean[] getMask() {
    return mask;
  }

  /**
   * Stores the mask.
   *
   * <p>NOTE(review): this only assigns the field. Proxies already built by
   * {@link #setData(RefList)} are not rebuilt here, so a new mask appears to
   * take effect only on the next {@code setData} call - confirm that callers
   * rely on that ordering.
   *
   * @param mask per-input flags
   */
  @Override
  public void setMask(final boolean... mask) {
    this.mask = mask;
  }

  /**
   * Sets the verbosity level.
   *
   * @param verbose the new level
   */
  public void setVerbosity(int verbose) {
    verbosity = verbose;
  }

  /**
   * Returns the number of tensors in the first row of {@code data}.
   *
   * @param data the data list; not freed by this method
   * @return the length of row 0, or 0 when {@code data} is {@code null} or empty
   */
  @RefIgnore
  private static int getInputs(@NotNull @RefIgnore RefList<Tensor[]> data) {
    if (null == data) return 0;
    if (data.isEmpty()) return 0;
    Tensor[] tensors = data.get(0);
    int length = tensors.length;
    // Release the row obtained from get(0) once its length has been read.
    RefUtil.freeRef(tensors);
    return length;
  }

  /**
   * Evaluates the network on the current data and, when verbose, logs timing.
   *
   * @param monitor receives a log line when non-null and verbosity exceeds 1
   * @return the normalized sample produced by {@link #eval()}
   */
  @Override
  public PointSample measure(@Nullable final TrainingMonitor monitor) {
    assert data != null;
    assert !data.isEmpty();
    @Nonnull final TimedResult<PointSample> timedResult = TimedResult.time(() -> eval());
    PointSample result = timedResult.getResult();
    if (null != monitor && verbosity() > 1) {
      monitor.log(RefString.format("Evaluated %s items in %.4fs (%s/%s)", data.size(), timedResult.timeNanos / 1e9,
          result.getMean(), result.delta.getMagnitude()));
    }
    timedResult.freeRef();
    return result;
  }

  /**
   * Returns the verbosity level.
   *
   * @return the current level
   */
  public int verbosity() {
    return verbosity;
  }

  /**
   * Releases the cached proxies, the data list and the network reference.
   */
  @Override
  public void _free() {
    super._free();
    if (null != inputProxies) {
      RefUtil.freeRef(inputProxies);
      inputProxies = null;
    }
    if (null != data) {
      data.freeRef();
      data = null;
    }
    inputSize = 0;
    if (null != network) {
      network.freeRef();
    }
  }

  /**
   * Adds a reference and returns this object typed as {@code BasicTrainable}.
   *
   * @return this instance
   */
  @Nonnull
  @Override
  @SuppressWarnings("unused")
  public BasicTrainable addRef() {
    return (BasicTrainable) super.addRef();
  }

  /**
   * Runs the network on the cached proxies and collects the result.
   *
   * <p>The sample's sum is the total of every value in the network output;
   * the deltas come from {@code result.accumulate}, and the sample count is
   * the number of data rows.
   *
   * @return the normalized sample
   */
  @Nonnull
  protected PointSample eval() {
    assert network != null;
    assert data != null;
    final Result result = network.eval(RefUtil.addRef(this.inputProxies));
    assert result != null;
    final TensorList resultData = result.getData();
    @Nonnull final DeltaSet<UUID> deltaSet = new DeltaSet<UUID>();
    final DoubleSummaryStatistics statistics = resultData.stream().flatMapToDouble(x -> {
      RefDoubleStream doubleStream = x.doubleStream();
      x.freeRef();
      return doubleStream;
    }).summaryStatistics();
    final double sum = statistics.getSum();
    result.accumulate(deltaSet.addRef());
    result.freeRef();
    StateSet<UUID> stateSet = new StateSet<>(deltaSet.addRef());
    resultData.freeRef();
    return normalize(new PointSample(deltaSet, stateSet, sum, 0.0, data.size()));
  }

  /**
   * Tells whether the given input should be wrapped as a constant.
   *
   * @param inputIndex index of the network input
   * @return {@code true} when there is no mask, the index is past the end of
   *         the mask, or the mask entry is {@code false}
   */
  private boolean mask(int inputIndex) {
    return null == mask || inputIndex >= mask.length || !mask[inputIndex];
  }

  /**
   * Collects the tensor at {@code inputIndex} from every data row.
   *
   * @param inputIndex column to extract
   * @return one added reference per row, in row order
   */
  @NotNull
  private Tensor[] select(int inputIndex) {
    return data.stream().map(batchData -> {
      try {
        return batchData[inputIndex].addRef();
      } finally {
        // The row handed to the lambda is released after its element is retained.
        RefUtil.freeRef(batchData);
      }
    }).toArray(Tensor[]::new);
  }

  /**
   * Normalizes a sample and frees the original.
   *
   * @param pointSample the sample to normalize; freed by this method
   * @return the result of {@code pointSample.normalize()}
   */
  @NotNull
  private PointSample normalize(PointSample pointSample) {
    PointSample normalize = pointSample.normalize();
    pointSample.freeRef();
    return normalize;
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy