All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.simiacryptus.mindseye.eval.SampledArrayTrainable Maven / Gradle / Ivy

There is a newer version: 2.1.0
Show newest version
/*
 * Copyright (c) 2018 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.eval;

import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.Tensor;
import com.simiacryptus.util.Util;
import com.simiacryptus.util.function.WeakCachedSupplier;

import javax.annotation.Nonnull;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * This type handles the data selection part of stochastic gradient descent training. Between each epoch, a "reset"
 * method is called to re-sample the training data and pass it to the heapCopy Trainable implementation.
 */
public class SampledArrayTrainable extends TrainableWrapper implements SampledTrainable, TrainableDataMask {

  private final List> trainingData;
  private int minSamples = 0;
  private long seed = Util.R.get().nextInt();
  private int trainingSize;

  /**
   * Instantiates a new Stochastic array trainable.
   *
   * @param trainingData the training data
   * @param network      the network
   * @param trainingSize the training size
   */
  public SampledArrayTrainable(@Nonnull final List> trainingData, final Layer network, final int trainingSize) {
    this(trainingData, network, trainingSize, trainingSize);
  }

  /**
   * Instantiates a new Stochastic array trainable.
   *
   * @param trainingData the training data
   * @param network      the network
   * @param trainingSize the training size
   * @param batchSize    the batch size
   */
  public SampledArrayTrainable(@Nonnull final List> trainingData, final Layer network, final int trainingSize, final int batchSize) {
    super(new ArrayTrainable(null, network, batchSize));
    if (0 == trainingData.size()) throw new IllegalArgumentException();
    this.trainingData = trainingData;
    this.trainingSize = trainingSize;
    reseed(System.nanoTime());
  }

  /**
   * Instantiates a new Stochastic array trainable.
   *
   * @param trainingData the training data
   * @param network      the network
   * @param trainingSize the training size
   */
  public SampledArrayTrainable(@Nonnull final Tensor[][] trainingData, final Layer network, final int trainingSize) {
    this(trainingData, network, trainingSize, trainingSize);
  }

  /**
   * Instantiates a new Stochastic array trainable.
   *
   * @param trainingData the training data
   * @param network      the network
   * @param trainingSize the training size
   * @param batchSize    the batch size
   */
  public SampledArrayTrainable(@Nonnull final Tensor[][] trainingData, final Layer network, final int trainingSize, final int batchSize) {
    super(new ArrayTrainable(network, batchSize));
    if (0 == trainingData.length) throw new IllegalArgumentException();
    this.trainingData = Arrays.stream(trainingData).map(obj -> new WeakCachedSupplier<>(() -> obj)).collect(Collectors.toList());
    this.trainingSize = trainingSize;
    reseed(System.nanoTime());
  }

  @Nonnull
  @Override
  public SampledCachedTrainable cached() {
    return new SampledCachedTrainable<>(this);
  }

  /**
   * Gets min samples.
   *
   * @return the min samples
   */
  public int getMinSamples() {
    return minSamples;
  }

  /**
   * Sets min samples.
   *
   * @param minSamples the min samples
   * @return the min samples
   */
  @Nonnull
  public SampledArrayTrainable setMinSamples(final int minSamples) {
    this.minSamples = minSamples;
    return this;
  }

  @Override
  public int getTrainingSize() {
    return Math.max(minSamples, Math.min(trainingData.size(), trainingSize));
  }

  /**
   * Refresh sampled data.
   */
  protected void refreshSampledData() {
    assert 0 < trainingData.size();
    Tensor[][] trainingData;
    if (0 < getTrainingSize() && getTrainingSize() < this.trainingData.size() - 1) {
      @Nonnull final Random random = new Random(seed);
      trainingData = IntStream.generate(() -> random.nextInt(this.trainingData.size()))
          .distinct()
          .mapToObj(i -> this.trainingData.get(i))
          .filter(x -> x != null && x.get() != null)
          .limit(getTrainingSize()).map(x -> x.get())
          .toArray(i -> new Tensor[i][]);
    } else {
      trainingData = this.trainingData.stream()
          .filter(x -> x != null && x.get() != null)
          .limit(getTrainingSize()).map(x -> x.get())
          .toArray(i -> new Tensor[i][]);
    }
    getInner().setTrainingData(trainingData);
  }

  @Override
  public boolean reseed(final long seed) {
    setSeed(Util.R.get().nextInt());
    getInner().reseed(seed);
    super.reseed(seed);
    return true;
  }

  private void setSeed(final int newValue) {
    if (seed == newValue) return;
    seed = newValue;
    refreshSampledData();
  }

  @Nonnull
  @Override
  public SampledTrainable setTrainingSize(final int trainingSize) {
    this.trainingSize = trainingSize;
    refreshSampledData();
    return this;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy