All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.mindseye.opt.IterativeTrainer Maven / Gradle / Ivy

There is a newer version: 2.1.0
Show newest version
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.opt;

import com.simiacryptus.lang.TimedResult;
import com.simiacryptus.lang.ref.ReferenceCountingBase;
import com.simiacryptus.mindseye.eval.Trainable;
import com.simiacryptus.mindseye.lang.IterativeStopException;
import com.simiacryptus.mindseye.lang.PointSample;
import com.simiacryptus.mindseye.layers.StochasticComponent;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch;
import com.simiacryptus.mindseye.opt.line.FailsafeLineSearchCursor;
import com.simiacryptus.mindseye.opt.line.LineSearchCursor;
import com.simiacryptus.mindseye.opt.line.LineSearchStrategy;
import com.simiacryptus.mindseye.opt.orient.LBFGS;
import com.simiacryptus.mindseye.opt.orient.OrientationStrategy;
import com.simiacryptus.util.Util;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.time.temporal.TemporalUnit;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

/**
 * The basic type of training loop, which integrates a {@link Trainable} object with an orientation strategy and a
 * line search strategy.
 *
 * <p>Each outer iteration re-seeds the subject via {@link #shuffle()} and re-measures it. Each sub-iteration then
 * re-measures, asks the {@link OrientationStrategy} for a search direction, and runs a line search along that
 * direction via {@link #step(LineSearchCursor, CharSequence, PointSample)}. Training stops when:
 * <ul>
 *   <li>the timeout elapses;</li>
 *   <li>the mean fitness is no longer above {@link #getTerminateThreshold()} (checked at the start of each outer
 *   iteration);</li>
 *   <li>the iteration budget is exhausted; or</li>
 *   <li>a sub-iteration fails to improve the fitness and the subject refuses to re-seed.</li>
 * </ul>
 *
 * <p>This object is reference counted. It holds a reference to the subject and to the orientation strategy, and
 * releases both in {@link #_free()}.
 */
public class IterativeTrainer extends ReferenceCountingBase {
  private static final Logger log = LoggerFactory.getLogger(IterativeTrainer.class);

  /**
   * Maximum number of additional measurements {@link #measure()} attempts while the measured fitness is not finite.
   */
  private static final int MAX_MEASURE_RETRIES = 10;

  /**
   * Line search strategies cached per direction type, so each strategy keeps its state across iterations.
   */
  private final Map<CharSequence, LineSearchStrategy> lineSearchStrategyMap = new HashMap<>();
  private final Trainable subject;
  private AtomicInteger currentIteration = new AtomicInteger(0);
  private int iterationsPerSample = 100;
  private Function<CharSequence, LineSearchStrategy> lineSearchFactory = (s) -> new ArmijoWolfeSearch();
  private int maxIterations = Integer.MAX_VALUE;
  private TrainingMonitor monitor = new TrainingMonitor();
  private OrientationStrategy orientation = new LBFGS();
  private double terminateThreshold;
  private Duration timeout;

  /**
   * Instantiates a new Iterative trainer with a 5 minute timeout and a termination threshold of 0.
   *
   * <p>The trainer adds its own reference to {@code subject}; the caller keeps (and must still release) its own
   * reference.
   *
   * @param subject the trainable to optimize
   */
  public IterativeTrainer(final Trainable subject) {
    this.subject = subject;
    this.subject.addRef();
    timeout = Duration.of(5, ChronoUnit.MINUTES);
    terminateThreshold = 0;
  }

  /**
   * Creates a trainer for {@code trainable} and releases the caller's reference to it, effectively transferring
   * ownership of that reference to the new trainer.
   *
   * @param trainable the trainable to optimize; the caller's reference is consumed
   * @return the new trainer
   */
  public static IterativeTrainer wrap(Trainable trainable) {
    IterativeTrainer trainer = new IterativeTrainer(trainable);
    trainable.freeRef();
    return trainer;
  }

  /**
   * Gets the iteration counter. {@link #run()} increments it at the start of every sub-iteration.
   *
   * @return the (mutable, shared) iteration counter
   */
  public AtomicInteger getCurrentIteration() {
    return currentIteration;
  }

  /**
   * Sets the iteration counter, e.g. to share one counter between several trainers.
   *
   * @param currentIteration the counter to use
   * @return this trainer, for call chaining
   */
  @Nonnull
  public IterativeTrainer setCurrentIteration(final AtomicInteger currentIteration) {
    this.currentIteration = currentIteration;
    return this;
  }

  /**
   * Gets the number of sub-iterations run per outer iteration (i.e. between re-seeds).
   *
   * @return the iterations per sample; a value {@code <= 0} means unbounded
   */
  public int getIterationsPerSample() {
    return iterationsPerSample;
  }

  /**
   * Sets the number of sub-iterations run per outer iteration (i.e. between re-seeds).
   *
   * @param iterationsPerSample the iterations per sample; a value {@code <= 0} means unbounded
   * @return this trainer, for call chaining
   */
  @Nonnull
  public IterativeTrainer setIterationsPerSample(final int iterationsPerSample) {
    this.iterationsPerSample = iterationsPerSample;
    return this;
  }

  /**
   * Gets the factory used to build a line search strategy for a direction type not seen before.
   *
   * @return the line search factory
   */
  public Function<CharSequence, LineSearchStrategy> getLineSearchFactory() {
    return lineSearchFactory;
  }

  /**
   * Sets the factory used to build a line search strategy for a direction type not seen before.
   *
   * <p>Strategies already cached by this trainer are not rebuilt.
   *
   * @param lineSearchFactory maps a direction type to a new line search strategy
   * @return this trainer, for call chaining
   */
  @Nonnull
  public IterativeTrainer setLineSearchFactory(final Function<CharSequence, LineSearchStrategy> lineSearchFactory) {
    this.lineSearchFactory = lineSearchFactory;
    return this;
  }

  /**
   * Gets the maximum total number of sub-iterations.
   *
   * @return the max iterations
   */
  public int getMaxIterations() {
    return maxIterations;
  }

  /**
   * Sets the maximum total number of sub-iterations.
   *
   * @param maxIterations the max iterations
   * @return this trainer, for call chaining
   */
  @Nonnull
  public IterativeTrainer setMaxIterations(final int maxIterations) {
    this.maxIterations = maxIterations;
    return this;
  }

  /**
   * Gets the monitor that receives log messages and step notifications.
   *
   * @return the monitor
   */
  public TrainingMonitor getMonitor() {
    return monitor;
  }

  /**
   * Sets the monitor that receives log messages and step notifications.
   *
   * @param monitor the monitor
   * @return this trainer, for call chaining
   */
  @Nonnull
  public IterativeTrainer setMonitor(final TrainingMonitor monitor) {
    this.monitor = monitor;
    return this;
  }

  /**
   * Gets the orientation strategy used to pick each search direction.
   *
   * @return the orientation strategy (no reference is added)
   */
  public OrientationStrategy getOrientation() {
    return orientation;
  }

  /**
   * Replaces the orientation strategy.
   *
   * <p>The previously held strategy is released. No reference is added to {@code orientation}, so the caller's
   * reference is handed over to this trainer, which releases it when freed.
   *
   * @param orientation the new orientation strategy; the caller's reference is consumed
   * @return this trainer, for call chaining
   */
  @Nonnull
  public IterativeTrainer setOrientation(final OrientationStrategy orientation) {
    if (null != this.orientation) this.orientation.freeRef();
    this.orientation = orientation;
    return this;
  }

  /**
   * Gets the termination threshold.
   *
   * <p>Training continues only while the mean fitness is strictly greater than this value.
   *
   * @return the terminate threshold
   */
  public double getTerminateThreshold() {
    return terminateThreshold;
  }

  /**
   * Sets the termination threshold.
   *
   * <p>Training continues only while the mean fitness is strictly greater than this value.
   *
   * @param terminateThreshold the terminate threshold
   * @return this trainer, for call chaining
   */
  @Nonnull
  public IterativeTrainer setTerminateThreshold(final double terminateThreshold) {
    this.terminateThreshold = terminateThreshold;
    return this;
  }

  /**
   * Gets the wall-clock budget for {@link #run()}.
   *
   * @return the timeout
   */
  public Duration getTimeout() {
    return timeout;
  }

  /**
   * Sets the wall-clock budget for {@link #run()}.
   *
   * @param timeout the timeout
   * @return this trainer, for call chaining
   */
  @Nonnull
  public IterativeTrainer setTimeout(final Duration timeout) {
    this.timeout = timeout;
    return this;
  }

  /**
   * Measures the subject at its current state.
   *
   * <p>If the mean fitness is not finite, the subject is re-measured up to {@value #MAX_MEASURE_RETRIES} more
   * times. Each discarded sample is released.
   *
   * @return a finite point sample; the caller owns the returned reference
   * @throws IterativeStopException if the mean is still not finite after all retries
   */
  @Nullable
  public PointSample measure() {
    @Nullable PointSample currentPoint = null;
    int retries = 0;
    do {
      if (null != currentPoint) {
        currentPoint.freeRef();
      }
      currentPoint = subject.measure(monitor);
      // Post-increment: the first non-finite result gets retried, up to MAX_MEASURE_RETRIES extra measurements.
    } while (!Double.isFinite(currentPoint.getMean()) && retries++ < MAX_MEASURE_RETRIES);
    if (!Double.isFinite(currentPoint.getMean())) {
      currentPoint.freeRef();
      throw new IterativeStopException();
    }
    return currentPoint;
  }

  /**
   * Re-seeds the training process with a seed derived from {@link System#nanoTime()}.
   *
   * <p>This resets the orientation strategy and re-seeds the subject. If the subject's layer is a
   * {@link DAGNetwork}, every {@link StochasticComponent} in it is also shuffled with the same seed.
   */
  public void shuffle() {
    long seed = System.nanoTime();
    monitor.log(String.format("Reset training subject: %s", seed));
    orientation.reset();
    subject.reseed(seed);
    if (subject.getLayer() instanceof DAGNetwork) {
      ((DAGNetwork) subject.getLayer()).visitLayers(layer -> {
        if (layer instanceof StochasticComponent)
          ((StochasticComponent) layer).shuffle(seed);
      });
    }
  }

  /**
   * Runs {@link #run()} and then releases this trainer's reference, even if training throws.
   *
   * @return the final mean fitness, as returned by {@link #run()}
   */
  public double runAndFree() {
    try {
      return run();
    } finally {
      freeRef();
    }
  }

  /**
   * Runs the training loop until one of the stop conditions described on this class is met.
   *
   * <p>If the subject's layer is a {@link DAGNetwork}, noise is cleared from its stochastic components on normal
   * completion.
   *
   * @return the mean fitness of the final point, or {@code NaN} if no point is held
   * @throws IterativeStopException if the initial measurement is not finite
   * @throws RuntimeException wrapping any failure raised inside the training loop
   */
  public double run() {
    long startTime = System.currentTimeMillis();
    final long timeoutMs = startTime + timeout.toMillis();
    long lastIterationTime = System.nanoTime();
    shuffle();
    @Nullable PointSample currentPoint = measure();
    try {
      mainLoop:
      while (timeoutMs > System.currentTimeMillis()
          && terminateThreshold < currentPoint.getMean()
          && maxIterations > currentIteration.get()) {
        shuffle();
        currentPoint.freeRef();
        // Null out before re-measuring so the finally block never frees an already-released sample
        // if measure() throws.
        currentPoint = null;
        currentPoint = measure();
        assert 0 < currentPoint.delta.getMap().size() : "Nothing to optimize";
        subiterationLoop:
        for (int subiteration = 0; subiteration < iterationsPerSample || iterationsPerSample <= 0; subiteration++) {
          if (timeoutMs < System.currentTimeMillis()) {
            break mainLoop;
          }
          if (currentIteration.incrementAndGet() > maxIterations) {
            break mainLoop;
          }
          currentPoint.freeRef();
          currentPoint = null;
          currentPoint = measure();
          @Nullable final PointSample _currentPoint = currentPoint;
          @Nonnull final TimedResult<LineSearchCursor> timedOrientation =
              TimedResult.time(() -> orientation.orient(subject, _currentPoint, monitor));
          final LineSearchCursor direction = timedOrientation.result;
          final CharSequence directionType = direction.getDirectionType();
          // 'previous' aliases currentPoint; the extra reference keeps it alive after currentPoint is replaced.
          @Nullable final PointSample previous = currentPoint;
          previous.addRef();
          try {
            @Nonnull final TimedResult<PointSample> timedLineSearch =
                TimedResult.time(() -> step(direction, directionType, previous));
            currentPoint.freeRef();
            currentPoint = null;
            currentPoint = timedLineSearch.result;
            final long now = System.nanoTime();
            final CharSequence perfString = String.format("Total: %.4f; Orientation: %.4f; Line Search: %.4f",
                (now - lastIterationTime) / 1e9, timedOrientation.timeNanos / 1e9, timedLineSearch.timeNanos / 1e9);
            lastIterationTime = now;
            monitor.log(String.format("Fitness changed from %s to %s", previous.getMean(), currentPoint.getMean()));
            if (previous.getMean() <= currentPoint.getMean()) {
              if (previous.getMean() < currentPoint.getMean()) {
                monitor.log(String.format("Resetting Iteration %s", perfString));
                currentPoint.freeRef();
                currentPoint = null;
                // NOTE(review): the LineSearchPoint returned by step(...) is not itself released here;
                // confirm whether it is reference counted.
                currentPoint = direction.step(0, monitor).point;
              } else {
                monitor.log(String.format("Static Iteration %s", perfString));
              }
              if (subject.reseed(System.nanoTime())) {
                monitor.log(String.format("Iteration %s failed, retrying. Error: %s",
                    currentIteration.get(), currentPoint.getMean()));
                monitor.log(String.format("Previous Error: %s -> %s",
                    previous.getRate(), previous.getMean()));
                break subiterationLoop;
              } else {
                monitor.log(String.format("Iteration %s failed, aborting. Error: %s",
                    currentIteration.get(), currentPoint.getMean()));
                monitor.log(String.format("Previous Error: %s -> %s",
                    previous.getRate(), previous.getMean()));
                break mainLoop;
              }
            } else {
              // perfString is passed as an argument, not concatenated into the pattern, so a '%' in it is harmless.
              monitor.log(String.format("Iteration %s complete. Error: %s %s",
                  currentIteration.get(), currentPoint.getMean(), perfString));
            }
            monitor.onStepComplete(new Step(currentPoint, currentIteration.get()));
          } finally {
            previous.freeRef();
            direction.freeRef();
          }
        }
      }
      if (subject.getLayer() instanceof DAGNetwork) {
        ((DAGNetwork) subject.getLayer()).visitLayers(layer -> {
          if (layer instanceof StochasticComponent) ((StochasticComponent) layer).clearNoise();
        });
      }
      return null == currentPoint ? Double.NaN : currentPoint.getMean();
    } catch (Throwable e) {
      monitor.log(String.format("Error %s", Util.toString(e)));
      throw new RuntimeException(e);
    } finally {
      monitor.log(String.format("Final threshold in iteration %s: %s (> %s) after %.3fs (< %.3fs)",
          currentIteration.get(),
          null == currentPoint ? null : currentPoint.getMean(),
          terminateThreshold,
          (System.currentTimeMillis() - startTime) / 1000.0,
          timeout.toMillis() / 1000.0
      ));
      if (null != currentPoint) currentPoint.freeRef();
    }
  }

  /**
   * Sets the wall-clock budget for {@link #run()}.
   *
   * @param number the amount of time
   * @param units  the unit of {@code number}
   * @return this trainer, for call chaining
   */
  @Nonnull
  public IterativeTrainer setTimeout(final int number, @Nonnull final TemporalUnit units) {
    timeout = Duration.of(number, units);
    return this;
  }

  /**
   * Sets the wall-clock budget for {@link #run()}.
   *
   * @param number the amount of time
   * @param units  the unit of {@code number}, converted via {@code Util.cvt}
   * @return this trainer, for call chaining
   */
  @Nonnull
  public IterativeTrainer setTimeout(final int number, @Nonnull final TimeUnit units) {
    return setTimeout(number, Util.cvt(units));
  }

  /**
   * Performs one line search along {@code direction}, starting from {@code previous}.
   *
   * <p>The line search strategy is looked up by {@code directionType}. On first use of a direction type, a
   * strategy is built with the line search factory and cached. The cursor is wrapped in a
   * {@link FailsafeLineSearchCursor}, and the best point it observed is returned.
   *
   * @param direction     the search direction
   * @param directionType the key under which the line search strategy is cached
   * @param previous      the starting point, passed to the failsafe cursor
   * @return the best point found; the caller owns the returned reference
   */
  public PointSample step(@Nonnull final LineSearchCursor direction, final CharSequence directionType, @Nonnull final PointSample previous) {
    final LineSearchStrategy lineSearchStrategy = lineSearchStrategyMap.computeIfAbsent(directionType, type -> {
      log.info("Constructing line search parameters: {}", type);
      return lineSearchFactory.apply(direction.getDirectionType());
    });
    @Nonnull final FailsafeLineSearchCursor wrapped = new FailsafeLineSearchCursor(direction, previous, monitor);
    try {
      lineSearchStrategy.step(wrapped, monitor).freeRef();
      return wrapped.getBest(monitor);
    } finally {
      // Release the wrapper even if the line search throws.
      wrapped.freeRef();
    }
  }

  /**
   * Releases the references held on the subject and on the orientation strategy (if any).
   */
  @Override
  protected void _free() {
    this.subject.freeRef();
    if (null != this.orientation) this.orientation.freeRef();
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy