All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.simiacryptus.mindseye.opt.line.QuadraticSearch Maven / Gradle / Ivy

There is a newer version: 2.1.0
Show newest version
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.opt.line;

import com.simiacryptus.lang.ref.ReferenceCountingBase;
import com.simiacryptus.mindseye.lang.IterativeStopException;
import com.simiacryptus.mindseye.lang.PointSample;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.mindseye.opt.orient.DescribeOrientationWrapper;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

/**
 * This exact line search method uses a linear interpolation of the derivative to find the extremum, where dy/dx = 0.
 * Bracketing conditions are established using logic that largely ignores derivatives, due to heuristic observations.
 */
public class QuadraticSearch implements LineSearchStrategy {

  // Factor applied to a derivative when LocateInitialRightPoint decides whether to keep growing the step.
  private final double initialDerivFactor = 0.95;
  // Absolute threshold used by isSame(double, double, double).
  private double absoluteTolerance = 1e-12;
  // Rate chosen by the most recent step(); when positive it seeds the next initial right-point probe.
  private double currentRate = 0.0;
  // Lower clamp applied to candidate rates and to currentRate.
  private double minRate = 1e-10;
  // Upper clamp applied to candidate rates and to currentRate.
  private double maxRate = 1e10;
  // Relative threshold used by isSame(double, double, double), scaled by the larger operand magnitude.
  private double relativeTolerance = 1e-2;
  // Multiplier applied to the selected rate by filter(); 1.0 returns the selected point unchanged.
  private double stepSize = 1.0;

  /**
   * Performs one line search along the cursor.
   * <p>
   * The cursor is first evaluated at rate 0. If the derivative there is exactly zero, that sample is
   * returned immediately. Otherwise an initial right-hand point is located, and the loop repeatedly
   * <ol>
   * <li>fits a straight line through the derivatives at the left and right points and takes its root
   * as the next candidate rate (falling back to the midpoint or to twice the right rate when the
   * root is unusable), clamped to [minRate, maxRate];</li>
   * <li>evaluates the cursor at the candidate;</li>
   * <li>replaces the left or right point with the candidate.</li>
   * </ol>
   * The loop ends when the candidate is indistinguishable from an end point, when the candidate has a
   * larger mean than the end point it would replace, when the two end points coincide, or when the
   * iteration budget is used up. The selected sample is passed through {@code filter} before being returned.
   *
   * @param cursor  the cursor used to evaluate points along the search line
   * @param monitor the monitor that receives progress messages
   * @return the selected sample, carrying a reference owned by the caller
   * @throws IterativeStopException if two samples at effectively the same rate report different means
   */
  public PointSample _step(@Nonnull final LineSearchCursor cursor, @Nonnull final TrainingMonitor monitor) {
    double thisX = 0;
    LineSearchPoint thisPoint = cursor.step(thisX, monitor);
    final LineSearchPoint initialPoint = thisPoint.addRef();
    double leftX = thisX;
    LineSearchPoint leftPoint = thisPoint.addRef();
    monitor.log(String.format("F(%s) = %s", leftX, leftPoint));
    if (0 == leftPoint.derivative) {
      // Zero derivative at the origin: hand the starting sample straight back.
      initialPoint.freeRef();
      thisPoint.freeRef();
      PointSample point = leftPoint.point;
      point.addRef();
      leftPoint.freeRef();
      return point;
    }

    @Nonnull final LocateInitialRightPoint locateInitialRightPoint;
    boolean located = false;
    try {
      locateInitialRightPoint = new LocateInitialRightPoint(cursor, monitor, leftPoint).apply();
      located = true;
    } finally {
      if (!located) {
        // Locating the right point failed; release the three references held on the starting
        // sample, which the main try/finally below would otherwise have released.
        leftPoint.freeRef();
        thisPoint.freeRef();
        initialPoint.freeRef();
      }
    }
    @Nonnull LineSearchPoint rightPoint = locateInitialRightPoint.getRightPoint();
    rightPoint.addRef();
    double rightX = locateInitialRightPoint.getRightX();

    try {
      int loops = 0;
      while (true) {
        // Linear model of the derivative through both end points: d(x) = a * x + b, root at -b / a.
        final double a = (rightPoint.derivative - leftPoint.derivative) / (rightX - leftX);
        final double b = rightPoint.derivative - a * rightX;
        thisX = -b / a;
        final boolean isBracketed = Math.signum(leftPoint.derivative) != Math.signum(rightPoint.derivative);
        if (!Double.isFinite(thisX) || isBracketed && (leftX > thisX || rightX < thisX)) {
          // Unusable root, or a root outside an established bracket: bisect instead.
          thisX = (rightX + leftX) / 2;
        }
        if (!isBracketed && thisX < 0) {
          // No bracket and the model points backwards: extend to the right instead.
          thisX = rightX * 2;
        }
        if (thisX < getMinRate()) thisX = getMinRate();
        if (thisX > getMaxRate()) thisX = getMaxRate();
        if (isSame(leftX, thisX, 1.0)) {
          monitor.log(String.format("Converged to left"));
          return filter(cursor, leftPoint.point, monitor);
        } else if (isSame(thisX, rightX, 1.0)) {
          monitor.log(String.format("Converged to right"));
          return filter(cursor, rightPoint.point, monitor);
        }
        // Null the local before stepping so the finally block cannot release it twice if step throws.
        thisPoint.freeRef();
        thisPoint = null;
        thisPoint = cursor.step(thisX, monitor);
        if (isSame(cursor, monitor, leftPoint, thisPoint)) {
          monitor.log(String.format("%s ~= %s", leftX, thisX));
          return filter(cursor, leftPoint.point, monitor);
        }
        if (isSame(cursor, monitor, thisPoint, rightPoint)) {
          monitor.log(String.format("%s ~= %s", thisX, rightX));
          return filter(cursor, rightPoint.point, monitor);
        }
        // The sample evaluated above is reused here; the cursor is not evaluated a second time at thisX.
        final boolean isLeft;
        if (!isBracketed) {
          // Without a bracket the candidate replaces whichever end point it lies closer to.
          isLeft = Math.abs(rightPoint.point.rate - thisPoint.point.rate) > Math.abs(leftPoint.point.rate - thisPoint.point.rate);
        } else {
          // With a bracket a negative derivative makes the candidate the new left end.
          isLeft = thisPoint.derivative < 0;
        }
        //monitor.log(String.format("isLeft=%s; isBracketed=%s; leftPoint=%s; rightPoint=%s", isLeft, isBracketed, leftPoint, rightPoint));
        monitor.log(String.format("F(%s) = %s, evalInputDelta = %s", thisX, thisPoint, thisPoint.point.getMean() - initialPoint.point.getMean()));
        if (loops++ > 10) {
          monitor.log(String.format("Loops = %s", loops));
          return filter(cursor, thisPoint.point, monitor);
        }
        if (isSame(cursor, monitor, leftPoint, rightPoint)) {
          monitor.log(String.format("%s ~= %s", leftX, rightX));
          return filter(cursor, thisPoint.point, monitor);
        }
        if (isLeft) {
          if (thisPoint.point.getMean() > leftPoint.point.getMean()) {
            monitor.log(String.format("%s > %s", thisPoint.point.getMean(), leftPoint.point.getMean()));
            return filter(cursor, leftPoint.point, monitor);
          }
          if (!isBracketed && leftPoint.point.getMean() < rightPoint.point.getMean()) {
            // The old left point becomes the right end before the candidate takes over the left.
            rightX = leftX;
            rightPoint.freeRef();
            rightPoint = leftPoint;
            rightPoint.addRef();
          }
          leftPoint.freeRef();
          leftPoint = thisPoint.addRef();
          leftX = thisX;
          monitor.log(String.format("Left bracket at %s", thisX));
        } else {
          if (thisPoint.point.getMean() > rightPoint.point.getMean()) {
            monitor.log(String.format("%s > %s", thisPoint.point.getMean(), rightPoint.point.getMean()));
            return filter(cursor, rightPoint.point, monitor);
          }
          if (!isBracketed && rightPoint.point.getMean() < leftPoint.point.getMean()) {
            // The old right point becomes the left end before the candidate takes over the right.
            leftX = rightX;
            leftPoint.freeRef();
            leftPoint = rightPoint;
            leftPoint.addRef();
          }
          rightX = thisX;
          rightPoint.freeRef();
          rightPoint = thisPoint.addRef();
          monitor.log(String.format("Right bracket at %s", thisX));
        }
      }
    } finally {
      if (null != leftPoint) leftPoint.freeRef();
      if (null != rightPoint) rightPoint.freeRef();
      if (null != thisPoint) thisPoint.freeRef();
      if (null != initialPoint) initialPoint.freeRef();
      locateInitialRightPoint.freeRef();
    }
  }

  /**
   * Builds the message used when two samples at effectively the same rate disagree on their mean.
   * Each sample's rate is re-evaluated through the cursor and the fresh mean is compared with the
   * recorded one.
   *
   * @param cursor  the cursor used to re-evaluate the two rates
   * @param monitor the monitor that receives the verification messages
   * @param a       the first sample; checked first
   * @param b       the second sample; only checked when {@code a} reproduces
   * @return a message naming the non-reproducible rate, or reporting a discontinuity when both reproduce
   */
  private String diagnose(@Nonnull final LineSearchCursor cursor, @Nonnull final TrainingMonitor monitor, @Nonnull final LineSearchPoint a, @Nonnull final LineSearchPoint b) {
    final boolean validA;
    final LineSearchPoint verifyA = cursor.step(a.point.rate, monitor);
    try {
      validA = isSame(a.point.getMean(), verifyA.point.getMean(), 1.0);
      monitor.log(String.format("Verify %s: %s (%s)", a.point.rate, verifyA.point.getMean(), validA));
    } finally {
      verifyA.freeRef();
    }
    if (!validA) {
      // NOTE(review): the value returned by render, if any, is discarded here - confirm whether it should be logged.
      DescribeOrientationWrapper.render(a.point.weights, a.point.delta);
      return "Non-Reproducable Point Found: " + a.point.rate;
    }
    final boolean validB;
    final LineSearchPoint verifyB = cursor.step(b.point.rate, monitor);
    try {
      validB = isSame(b.point.getMean(), verifyB.point.getMean(), 1.0);
      monitor.log(String.format("Verify %s: %s (%s)", b.point.rate, verifyB.point.getMean(), validB));
    } finally {
      verifyB.freeRef();
    }
    // a is known to reproduce at this point, so only b's outcome selects the message.
    if (validB) {
      return "Function Discontinuity Found";
    }
    return "Non-Reproducable Point Found: " + b.point.rate;
  }

  /**
   * Applies the configured step-size multiplier to a selected sample.
   *
   * @param cursor  the cursor used to evaluate the rescaled rate when the multiplier is not 1.0
   * @param point   the sample selected by the search
   * @param monitor the monitor passed through to the cursor
   * @return {@code point} itself when the multiplier is 1.0, otherwise the sample evaluated at
   * {@code point.rate * stepSize}; either way a reference is added for the caller
   */
  private PointSample filter(@Nonnull final LineSearchCursor cursor, @Nonnull final PointSample point, final TrainingMonitor monitor) {
    if (stepSize != 1.0) {
      final LineSearchPoint rescaled = cursor.step(point.rate * stepSize, monitor);
      final PointSample result = rescaled.point;
      // Keep the sample alive past the release of its enclosing line-search point.
      result.addRef();
      rescaled.freeRef();
      return result;
    }
    point.addRef();
    return point;
  }

  /**
   * Returns the absolute threshold used when comparing two values for sameness.
   *
   * @return the absolute tolerance
   */
  public double getAbsoluteTolerance() {
    return this.absoluteTolerance;
  }

  /**
   * Sets the absolute threshold used when comparing two values for sameness.
   *
   * @param absoluteTolerance the new absolute tolerance
   * @return this instance, for call chaining
   */
  @Nonnull
  public QuadraticSearch setAbsoluteTolerance(final double absoluteTolerance) {
    this.absoluteTolerance = absoluteTolerance;
    return this;
  }

  /**
   * Returns the rate recorded by the most recent search, or the value last set explicitly.
   *
   * @return the current rate
   */
  public double getCurrentRate() {
    return this.currentRate;
  }

  /**
   * Sets the rate that seeds the next search's initial right-point probe when positive.
   *
   * @param currentRate the new current rate
   * @return this instance, for call chaining
   */
  @Nonnull
  public QuadraticSearch setCurrentRate(final double currentRate) {
    this.currentRate = currentRate;
    return this;
  }

  /**
   * Returns the lower clamp applied to candidate rates.
   *
   * @return the minimum rate
   */
  public double getMinRate() {
    return this.minRate;
  }

  /**
   * Sets the lower clamp applied to candidate rates.
   *
   * @param minRate the new minimum rate
   * @return this instance, for call chaining
   */
  @Nonnull
  public QuadraticSearch setMinRate(final double minRate) {
    this.minRate = minRate;
    return this;
  }

  /**
   * Returns the relative threshold used when comparing two values for sameness.
   *
   * @return the relative tolerance
   */
  public double getRelativeTolerance() {
    return this.relativeTolerance;
  }

  /**
   * Sets the relative threshold used when comparing two values for sameness.
   *
   * @param relativeTolerance the new relative tolerance
   * @return this instance, for call chaining
   */
  @Nonnull
  public QuadraticSearch setRelativeTolerance(final double relativeTolerance) {
    this.relativeTolerance = relativeTolerance;
    return this;
  }

  /**
   * Returns the multiplier applied to the selected rate before the result is returned.
   *
   * @return the step-size multiplier
   */
  public double getStepSize() {
    return this.stepSize;
  }

  /**
   * Sets the multiplier applied to the selected rate before the result is returned.
   *
   * @param stepSize the new step-size multiplier; 1.0 leaves the selected point unchanged
   * @return this instance, for call chaining
   */
  @Nonnull
  public QuadraticSearch setStepSize(final double stepSize) {
    this.stepSize = stepSize;
    return this;
  }

  /**
   * Decides whether two numbers are close enough to be treated as equal.
   * The difference, divided by {@code slack}, must fall below either the absolute tolerance or the
   * relative tolerance scaled by the larger of the two magnitudes.
   *
   * @param a     the first value
   * @param b     the second value
   * @param slack divisor applied to the difference; larger values loosen the comparison
   * @return true when the values are considered the same
   */
  protected boolean isSame(final double a, final double b, final double slack) {
    final double gap = Math.abs(a - b) / slack;
    if (gap < absoluteTolerance) {
      return true;
    }
    final double magnitude = Math.max(Math.abs(a), Math.abs(b));
    return gap < magnitude * relativeTolerance;
  }

  /**
   * Decides whether two line-search points sit at effectively the same rate, and checks that points
   * which do so also agree on their mean.
   *
   * @param cursor  the cursor handed to the diagnosis when the means disagree
   * @param monitor the monitor that receives the diagnosis message
   * @param a       the first point
   * @param b       the second point
   * @return true when the rates match, false when they differ
   * @throws IterativeStopException when the rates match but the means differ beyond the tolerance
   *                                (compared with a slack of 10)
   */
  protected boolean isSame(@Nonnull final LineSearchCursor cursor, @Nonnull final TrainingMonitor monitor, @Nonnull final LineSearchPoint a, @Nonnull final LineSearchPoint b) {
    if (!isSame(a.point.rate, b.point.rate, 1.0)) {
      return false;
    }
    if (isSame(a.point.getMean(), b.point.getMean(), 10.0)) {
      return true;
    }
    // Same rate but different means: report why and abort the iteration.
    final String message = diagnose(cursor, monitor, a, b);
    monitor.log(message);
    throw new IterativeStopException(message);
  }

  /**
   * Runs a line search and records the rate of the returned sample as the current rate.
   * The current rate is clamped to [minRate, maxRate] before the search starts.
   *
   * @param cursor  the cursor used to evaluate points along the search line
   * @param monitor the monitor that receives progress messages
   * @return the sample selected by the search
   */
  @Override
  public PointSample step(@Nonnull final LineSearchCursor cursor, @Nonnull final TrainingMonitor monitor) {
    final double lowerBound = getMinRate();
    if (currentRate < lowerBound) {
      currentRate = lowerBound;
    }
    final double upperBound = getMaxRate();
    if (currentRate > upperBound) {
      currentRate = upperBound;
    }
    final PointSample result = _step(cursor, monitor);
    setCurrentRate(result.rate);
    return result;
  }

  /**
   * Returns the upper clamp applied to candidate rates.
   *
   * @return the maximum rate
   */
  public double getMaxRate() {
    return this.maxRate;
  }

  /**
   * Sets the upper clamp applied to candidate rates.
   *
   * @param maxRate the new maximum rate
   * @return this instance, for call chaining
   */
  @Nonnull
  public QuadraticSearch setMaxRate(final double maxRate) {
    this.maxRate = maxRate;
    return this;
  }

  /**
   * Finds the initial right-hand point for the search by repeatedly shrinking or growing a trial rate.
   * This is an inner (non-static) class because it uses the enclosing search's tolerances, rate limits
   * and current rate. It holds counted references to the cursor, the initial point and the latest
   * trial point, all released in {@link #_free()}.
   */
  private class LocateInitialRightPoint extends ReferenceCountingBase {
    @Nonnull
    private final LineSearchCursor cursor;
    @Nonnull
    private final LineSearchPoint initialPoint;
    @Nonnull
    private final TrainingMonitor monitor;
    // Latest trial point; owned by this object.
    private LineSearchPoint thisPoint;
    // Rate at which thisPoint was evaluated.
    private double thisX;

    /**
     * Evaluates the first trial point. The trial rate is the enclosing search's current rate when that
     * is positive; otherwise it is derived from the left point as |mean * 1e-4 / derivative|.
     *
     * @param cursor    the cursor used to evaluate trial points
     * @param monitor   the monitor that receives progress messages
     * @param leftPoint the left (starting) point of the search
     */
    public LocateInitialRightPoint(@Nonnull final LineSearchCursor cursor, @Nonnull final TrainingMonitor monitor, @Nonnull final LineSearchPoint leftPoint) {
      this.cursor = cursor;
      this.monitor = monitor;
      initialPoint = leftPoint;
      thisX = getCurrentRate() > 0 ? getCurrentRate() : Math.abs(leftPoint.point.getMean() * 1e-4 / leftPoint.derivative);
      thisPoint = cursor.step(thisX, monitor);
      monitor.log(String.format("F(%s) = %s, evalInputDelta = %s", thisX, thisPoint, thisPoint.point.getMean() - initialPoint.point.getMean()));
      this.cursor.addRef();
      this.initialPoint.addRef();
    }

    /**
     * Adjusts the trial rate until it is acceptable as the right-hand point.
     * <ul>
     * <li>If the trial has a larger mean than the initial point and the rate is above the minimum, the
     * rate is divided by 13.</li>
     * <li>Otherwise, if the trial derivative is still below {@code initialDerivFactor} times the
     * derivative at the initial point and the rate is below the maximum, the rate is multiplied by 7.</li>
     * <li>Otherwise the current trial is accepted.</li>
     * </ul>
     * The loop also stops when the trial coincides with the initial point or with the previous trial, or
     * when the loop budget is used up.
     *
     * @return this instance; the result is available from {@link #getRightPoint()} and {@link #getRightX()}
     * @throws IterativeStopException if two trials at effectively the same rate report different means
     */
    @Nonnull
    public LocateInitialRightPoint apply() {
      assertAlive();
      @Nullable LineSearchPoint lastPoint = null;
      try {
        int loops = 0;
        while (true) {
          if (null != lastPoint) lastPoint.freeRef();
          lastPoint = thisPoint;
          lastPoint.addRef();
          if (isSame(cursor, monitor, initialPoint, thisPoint)) {
            monitor.log(String.format("%s ~= %s", initialPoint.point.rate, thisX));
            return this;
          } else if (thisPoint.point.getMean() > initialPoint.point.getMean() && thisX > minRate) {
            thisX = thisX / 13;
          } else if (thisPoint.derivative < initialDerivFactor * initialPoint.derivative && thisX < maxRate) {
            // The slope has not yet flattened relative to the starting slope, so keep growing the step.
            // The comparison must be against the initial point: comparing the trial derivative with a
            // multiple of itself makes the factor meaningless.
            thisX = thisX * 7;
          } else {
            monitor.log(String.format("%s <= %s", thisPoint.point.getMean(), initialPoint.point.getMean()));
            return this;
          }

          // Evaluate the new trial before releasing the old one, so the field never refers to a
          // released point if the evaluation throws.
          final LineSearchPoint nextPoint = cursor.step(thisX, monitor);
          if (null != thisPoint) thisPoint.freeRef();
          thisPoint = nextPoint;
          if (isSame(cursor, monitor, lastPoint, thisPoint)) {
            monitor.log(String.format("%s ~= %s", lastPoint.point.rate, thisX));
            return this;
          }
          monitor.log(String.format("F(%s) = %s, evalInputDelta = %s", thisX, thisPoint, thisPoint.point.getMean() - initialPoint.point.getMean()));
          if (loops++ > 50) {
            monitor.log(String.format("Loops = %s", loops));
            return this;
          }
        }
      } finally {
        if (null != lastPoint) lastPoint.freeRef();
      }
    }

    /**
     * Returns the latest trial point without adding a reference; callers that keep it must add their own.
     *
     * @return the right point
     */
    public LineSearchPoint getRightPoint() {
      return thisPoint;
    }

    /**
     * Returns the rate of the latest trial point.
     *
     * @return the right x
     */
    public double getRightX() {
      return thisX;
    }

    /**
     * Releases the references held on the trial point, the cursor and the initial point.
     */
    @Override
    protected void _free() {
      this.thisPoint.freeRef();
      this.cursor.freeRef();
      this.initialPoint.freeRef();
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy