All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.simiacryptus.mindseye.opt.line.BisectionSearch Maven / Gradle / Ivy

There is a newer version: 2.1.0
Show newest version
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.opt.line;

import com.simiacryptus.mindseye.lang.PointSample;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.ref.wrappers.RefString;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

/**
 * Line-search strategy that first brackets a step size and then bisects the bracket on the sign
 * of the directional derivative.
 *
 * <p>{@link #step} expands or shrinks a right bound until a probe is neither higher than the left
 * value nor still descending. {@link #iterate} then halves that bracket until the derivative is
 * (near) zero or the bracket becomes narrow. The chosen step size is remembered in
 * {@link #getCurrentRate()} and seeds the next search.
 */
public class BisectionSearch implements LineSearchStrategy {

  /** Fallback seed rate, used when {@link #currentRate} is not a positive number (0, negative, NaN). */
  private static final double DEFAULT_RATE = 1.0;

  /** Hard upper bound for the initial bracket (initial value of the right limit in {@link #step}). */
  private double maxRate = 1e20;
  /** Step size chosen by the last search; the next search's soft right bound starts at twice this. */
  private double currentRate = DEFAULT_RATE;
  /** Derivative magnitude at or below which {@link #iterate} treats a point as the minimum. */
  private double zeroTol = 1e-20;
  /** Relative bracket width below which {@link #step} gives up expanding and returns the left point. */
  private double spanTol = 1e-3;

  /** @return the step size chosen by the most recent search (or the configured seed). */
  public double getCurrentRate() {
    return currentRate;
  }

  /**
   * Sets the seed step size; the next search starts its soft right bound at twice this value.
   *
   * @return this instance, for chaining
   */
  @Nonnull
  public BisectionSearch setCurrentRate(final double currentRate) {
    this.currentRate = currentRate;
    return this;
  }

  /** @return the hard upper bound on the initial bracket. */
  public double getMaxRate() {
    return maxRate;
  }

  /**
   * Sets the hard upper bound on the initial bracket.
   *
   * @return this instance, for chaining
   */
  @Nonnull
  public BisectionSearch setMaxRate(double maxRate) {
    this.maxRate = maxRate;
    return this;
  }

  /** @return the relative bracket width below which bracketing is abandoned. */
  public double getSpanTol() {
    return spanTol;
  }

  /**
   * Sets the relative bracket width below which bracketing is abandoned.
   *
   * @return this instance, for chaining
   */
  @Nonnull
  public BisectionSearch setSpanTol(double spanTol) {
    this.spanTol = spanTol;
    return this;
  }

  /** @return the derivative magnitude treated as zero during bisection. */
  public double getZeroTol() {
    return zeroTol;
  }

  /**
   * Sets the derivative magnitude treated as zero during bisection.
   *
   * @return this instance, for chaining
   */
  @Nonnull
  public BisectionSearch setZeroTol(final double zeroTol) {
    this.zeroTol = zeroTol;
    return this;
  }

  /**
   * Searches along {@code cursor} for a step size and returns the sample at that step.
   *
   * <p>Bracketing phase: the left bound starts at 0. Each probe is the midpoint between the left
   * bound and the smaller of the hard right limit ({@link #getMaxRate()}, then tightened) and a soft
   * limit seeded at twice {@link #getCurrentRate()}.
   * <ul>
   *   <li>A probe whose value exceeds the left value becomes the new hard right limit.</li>
   *   <li>A probe with a negative derivative becomes the new left bound and doubles the soft
   *       limit.</li>
   *   <li>Any other probe ends bracketing.</li>
   * </ul>
   * If the relative width of [left, hard limit] drops below {@link #getSpanTol()}, the sample at the
   * left bound is returned directly. Bracketing also stops after roughly 100 probes. The resulting
   * bracket is then refined by {@link #iterate}.
   *
   * <p>On normal return the caller's reference to {@code cursor} has been released, either directly
   * or by {@link #iterate}.
   *
   * @param cursor  line to search; its reference is consumed
   * @param monitor receives progress log messages
   * @return the sample at the chosen step size
   */
  @Override
  public PointSample step(@Nonnull final LineSearchCursor cursor, @Nonnull final TrainingMonitor monitor) {

    double leftX = 0;
    final LineSearchPoint searchPoint = cursor.step(leftX, monitor);
    assert searchPoint != null;
    monitor.log(RefString.format("F(%s) = %s", leftX, searchPoint.addRef()));
    double leftValue = searchPoint.getPointSum();
    searchPoint.freeRef();
    double rightRight = getMaxRate();
    double rightX;
    double rightLineDeriv;
    double rightValue;
    // A zero/negative/NaN seed would make the soft limit 0 (doubling never grows it) or NaN, pinning
    // every probe at x = 0 or NaN; fall back to the default rate instead. (NaN > 0 is false.)
    final double seedRate = currentRate > 0 ? currentRate : DEFAULT_RATE;
    double rightRightSoft = seedRate * 2;
    LineSearchPoint rightPoint = null;
    int loopCount = 0;
    while (true) {
      rightX = (leftX + Math.min(rightRight, rightRightSoft)) / 2;
      if (null != rightPoint) rightPoint.freeRef();
      rightPoint = cursor.step(rightX, monitor);
      assert rightPoint != null;
      monitor.log(RefString.format("F(%s)@%s = %s", rightX, loopCount, rightPoint.addRef()));
      rightLineDeriv = rightPoint.derivative;
      rightValue = rightPoint.getPointSum();
      if (loopCount++ > 100) {
        monitor.log(RefString.format("Loop overflow"));
        break;
      }
      // Relative width of [leftX, rightRight]; once tiny, the hard limit is not converging.
      if ((rightRight - leftX) * 2.0 / (leftX + rightRight) < spanTol) {
        monitor.log(RefString.format("Right limit is nonconvergent at %s/%s", leftX, rightRight));
        currentRate = leftX;
        rightPoint.freeRef();
        final PointSample result = takePoint(cursor.step(leftX, monitor));
        cursor.freeRef();
        return result;
      }
      if (rightValue > leftValue) {
        // Overshot: the minimum lies left of this probe.
        rightRight = rightX;
        monitor.log(RefString.format("Right is at most %s", rightX));
      } else if (rightLineDeriv < 0) {
        // Still descending: move the left bound up and look further out.
        rightRightSoft *= 2.0;
        leftValue = rightValue;
        leftX = rightX;
        monitor.log(RefString.format("Right is at least %s", rightX));
      } else {
        // Not higher and not descending: [leftX, rightX] brackets the minimum.
        break;
      }
    }
    rightPoint.freeRef();
    monitor.log(RefString.format("Starting bisection search from %s to %s", leftX, rightX));
    return takePoint(iterate(cursor, monitor, leftX, rightX));
  }

  /**
   * Bisects {@code [leftX, rightX]} on the sign of the directional derivative.
   *
   * <p>A midpoint with derivative below {@code -zeroTol} replaces the left bound; above
   * {@code zeroTol} it replaces the right bound. The search stops when any of these holds:
   * <ul>
   *   <li>the derivative is within {@code [-zeroTol, zeroTol]} (or NaN);</li>
   *   <li>the midpoint equals the bound it would replace (floating-point resolution exhausted);</li>
   *   <li>the relative bracket width drops below 0.1;</li>
   *   <li>roughly 1000 probes have been made.</li>
   * </ul>
   * Every exit records the final midpoint in {@link #getCurrentRate()}.
   *
   * <p>The caller's reference to {@code cursor} is released before this method returns.
   *
   * @param cursor  line to search; its reference is consumed
   * @param monitor receives progress log messages
   * @param leftX   left end of the bracket
   * @param rightX  right end of the bracket
   * @return the last probed point; the caller owns the returned reference
   */
  @Nullable
  public LineSearchPoint iterate(@Nonnull final LineSearchCursor cursor, @Nonnull final TrainingMonitor monitor,
                                 double leftX, double rightX) {
    LineSearchPoint searchPoint = null;
    int loopCount = 0;
    try {
      while (true) {
        double thisX = (rightX + leftX) / 2;
        if (null != searchPoint) searchPoint.freeRef();
        searchPoint = cursor.step(thisX, monitor);
        assert searchPoint != null;
        monitor.log(RefString.format("F(%s) = %s", thisX, searchPoint.addRef()));
        if (loopCount++ > 1000) {
          // Record the rate like every other exit so the next search is seeded from here.
          monitor.log(RefString.format("End (loop overflow) at %s", thisX));
          currentRate = thisX;
          return searchPoint.addRef();
        }
        if (searchPoint.derivative < -zeroTol) {
          if (leftX == thisX) {
            monitor.log(RefString.format("End (static left) at %s", thisX));
            currentRate = thisX;
            return searchPoint.addRef();
          }
          leftX = thisX;
        } else if (searchPoint.derivative > zeroTol) {
          if (rightX == thisX) {
            monitor.log(RefString.format("End (static right) at %s", thisX));
            currentRate = thisX;
            return searchPoint.addRef();
          }
          rightX = thisX;
        } else {
          monitor.log(RefString.format("End (at min) at %s", thisX));
          currentRate = thisX;
          return searchPoint.addRef();
        }
        // log10(width) < -1 <=> relative width < 0.1. This fixed threshold is independent of spanTol.
        if (Math.log10((rightX - leftX) * 2.0 / (leftX + rightX)) < -1) {
          monitor.log(RefString.format("End (narrow range) at %s to %s", leftX, rightX));
          currentRate = thisX;
          return searchPoint.addRef();
        }
      }
    } finally {
      if (null != searchPoint) searchPoint.freeRef();
      cursor.freeRef();
    }
  }

  /**
   * Extracts the sample from a line-search point and releases the point's reference.
   *
   * @param point point whose reference is consumed; must not be null
   * @return the sample obtained from {@code point.getPoint()}
   */
  private static PointSample takePoint(@Nullable final LineSearchPoint point) {
    assert point != null;
    final PointSample sample = point.getPoint();
    point.freeRef();
    return sample;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy