com.simiacryptus.mindseye.eval.SampledArrayTrainable Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mindseye-core Show documentation
Show all versions of mindseye-core Show documentation
Core Neural Networks Framework
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.eval;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.*;
import com.simiacryptus.util.Util;
import com.simiacryptus.util.function.WeakCachedSupplier;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.Random;
import java.util.function.Supplier;
public class SampledArrayTrainable extends TrainableWrapper
implements SampledTrainable, TrainableDataMask {
@Nonnull
private final RefList extends RefSupplier> trainingData;
private int minSamples = 0;
private long seed = Util.R.get().nextInt();
private int trainingSize;
public SampledArrayTrainable(@Nonnull final RefList extends RefSupplier> trainingData, final Layer network,
final int trainingSize) {
this(trainingData, network, trainingSize, trainingSize);
}
public SampledArrayTrainable(@Nonnull final RefList extends RefSupplier> trainingData, @Nullable final Layer network,
final int trainingSize, final int batchSize) {
super(new ArrayTrainable(null, network == null ? null : network.addRef(), batchSize));
if (null != network)
network.freeRef();
if (0 == trainingData.size()) {
trainingData.freeRef();
throw new IllegalArgumentException();
}
this.trainingData = trainingData;
this.trainingSize = trainingSize;
reseed(RefSystem.nanoTime());
}
public SampledArrayTrainable(@Nonnull final Tensor[][] trainingData, final Layer network, final int trainingSize) {
this(trainingData, network, trainingSize, trainingSize);
}
public SampledArrayTrainable(@Nonnull final Tensor[][] trainingData, @Nullable final Layer network, final int trainingSize,
final int batchSize) {
super(new ArrayTrainable(network == null ? null : network.addRef(), batchSize));
if (null != network)
network.freeRef();
RefUtil.freeRef(getInner());
if (0 == trainingData.length) {
RefUtil.freeRef(trainingData);
throw new IllegalArgumentException();
}
this.trainingData = RefArrays.stream(trainingData).map(obj -> {
return RefUtil.wrapInterface((RefSupplier) new WeakCachedSupplier(
() -> RefUtil.addRef(obj)), obj);
}).collect(RefCollectors.toList());
this.trainingSize = trainingSize;
reseed(RefSystem.nanoTime());
}
public int getMinSamples() {
return minSamples;
}
public void setMinSamples(int minSamples) {
this.minSamples = minSamples;
}
@Override
public int getTrainingSize() {
return Math.max(minSamples, Math.min(trainingData.size(), trainingSize));
}
@Override
public void setTrainingSize(final int trainingSize) {
this.trainingSize = trainingSize;
refreshSampledData();
}
private void setSeed(final int newValue) {
if (seed == newValue)
return;
seed = newValue;
refreshSampledData();
}
@Nonnull
@Override
public SampledCachedTrainable extends SampledTrainable> cached() {
return new SampledCachedTrainable<>(this.addRef());
}
@Override
public boolean reseed(final long seed) {
setSeed(Util.R.get().nextInt());
ArrayTrainable temp_00_0004 = getInner();
assert temp_00_0004 != null;
temp_00_0004.reseed(seed);
temp_00_0004.freeRef();
super.reseed(seed);
return true;
}
public @SuppressWarnings("unused")
void _free() {
super._free();
trainingData.freeRef();
}
@Nonnull
public @Override
@SuppressWarnings("unused")
SampledArrayTrainable addRef() {
return (SampledArrayTrainable) super.addRef();
}
protected void refreshSampledData() {
assert 0 < trainingData.size();
Tensor[][] trainingData = null;
if (0 < getTrainingSize() && getTrainingSize() < this.trainingData.size() - 1) {
@Nonnull final Random random = new Random(seed);
if (null != trainingData) RefUtil.freeRef(trainingData);
trainingData = RefIntStream.generate(() -> random.nextInt(this.trainingData.size())).distinct()
.mapToObj(this.trainingData::get).filter(x -> {
if (x != null) {
Tensor[] tensors = x.get();
try {
if (tensors != null) {
return true;
}
} finally {
RefUtil.freeRef(tensors);
}
}
return false;
}).limit(getTrainingSize()).map(Supplier::get).toArray(Tensor[][]::new);
} else {
if (null != trainingData) RefUtil.freeRef(trainingData);
trainingData = this.trainingData.stream().filter(refSupplier -> {
if (refSupplier != null) {
Tensor[] tensors = refSupplier.get();
try {
if (tensors != null) {
return true;
}
} finally {
RefUtil.freeRef(tensors);
}
}
return false;
}).limit(getTrainingSize()).map(Supplier::get).toArray(Tensor[][]::new);
}
ArrayTrainable inner = getInner();
assert inner != null;
inner.setTrainingData(trainingData);
inner.freeRef();
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy