All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.mindseye.opt.orient.QQN Maven / Gradle / Ivy

/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.opt.orient;

import com.simiacryptus.mindseye.eval.Trainable;
import com.simiacryptus.mindseye.lang.DeltaSet;
import com.simiacryptus.mindseye.lang.PointSample;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.mindseye.opt.line.LineSearchCursor;
import com.simiacryptus.mindseye.opt.line.LineSearchCursorBase;
import com.simiacryptus.mindseye.opt.line.LineSearchPoint;
import com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor;
import com.simiacryptus.ref.wrappers.RefString;

import javax.annotation.Nonnull;
import java.util.UUID;

/**
 * Orientation strategy that blends a rescaled descent direction with the direction produced by an
 * inner {@link LBFGS} strategy along a quadratic path.
 *
 * <p>For a step size {@code t}, the cursor returned by {@link #orient} moves the weights by
 * {@code (t - t^2) * g + t^2 * d}. Here {@code d} is the LBFGS direction and {@code g} is the
 * negated delta of the origin sample, rescaled to the magnitude of {@code d}. The path therefore
 * leaves the origin tangent to {@code g} and reaches {@code d} at {@code t = 1}.
 *
 * <p>The plain LBFGS cursor is returned instead in two cases:
 * <ul>
 *   <li>the two directions already have nearly equal magnitude (relative difference of at most
 *       1%);</li>
 *   <li>the descent direction has zero magnitude and cannot be rescaled.</li>
 * </ul>
 */
public class QQN extends OrientationStrategyBase {

  /** Direction-type label reported by the quadratic cursors this strategy creates. */
  public static final String CURSOR_NAME = "QQN";
  /** Inner LBFGS strategy: supplies the quasi-Newton direction and owns the sample history. */
  private final LBFGS inner = new LBFGS();

  /**
   * Returns the maximum history size configured on the inner LBFGS strategy.
   *
   * @return the inner strategy's maximum history
   */
  public int getMaxHistory() {
    return inner.getMaxHistory();
  }

  /**
   * Sets the maximum history size on the inner LBFGS strategy.
   *
   * @param maxHistory the new maximum history
   */
  public void setMaxHistory(final int maxHistory) {
    inner.setMaxHistory(maxHistory);
  }

  /**
   * Returns the minimum history size configured on the inner LBFGS strategy.
   *
   * @return the inner strategy's minimum history
   */
  public int getMinHistory() {
    return inner.getMinHistory();
  }

  /**
   * Sets the minimum history size on the inner LBFGS strategy.
   *
   * @param minHistory the new minimum history
   */
  public void setMinHistory(final int minHistory) {
    inner.setMinHistory(minHistory);
  }

  /**
   * Builds a line-search cursor starting at {@code origin}.
   *
   * <p>The method proceeds in three steps:
   * <ol>
   *   <li>Record {@code origin} in the inner LBFGS history.</li>
   *   <li>Obtain the LBFGS cursor.</li>
   *   <li>Unless the LBFGS and descent directions have nearly the same magnitude, or the descent
   *       direction is zero, wrap both directions in a quadratic-path cursor.</li>
   * </ol>
   *
   * @param subject the trainable being optimized; the caller's reference is consumed
   * @param origin  the starting sample; the caller's reference is consumed
   * @param monitor receives a log line when the quadratic cursor is used
   * @return a cursor along the quadratic QQN path, or the plain LBFGS cursor
   */
  @Override
  public LineSearchCursor orient(@Nonnull final Trainable subject, @Nonnull final PointSample origin,
                                 @Nonnull final TrainingMonitor monitor) {
    inner.addToHistory(origin.addRef(), monitor);
    final SimpleLineSearchCursor lbfgsCursor = inner.orient(subject.addRef(),
        origin.addRef(), monitor);
    assert lbfgsCursor.direction != null;
    final DeltaSet lbfgs = lbfgsCursor.direction.addRef();
    @Nonnull final DeltaSet gd = origin.delta.scale(-1.0);
    origin.freeRef();
    final double lbfgsMag = lbfgs.getMagnitude();
    final double gdMag = gd.getMagnitude();
    // A zero-magnitude descent direction cannot be rescaled to lbfgsMag: dividing by gdMag would
    // produce an infinite scale factor and a NaN path. Fall back to the LBFGS cursor instead.
    if (gdMag > 0 && Math.abs(lbfgsMag - gdMag) / (lbfgsMag + gdMag) > 1e-2) {
      @Nonnull final DeltaSet scaledGradient = gd.scale(lbfgsMag / gdMag);
      gd.freeRef();
      monitor.log(RefString.format("Returning Quadratic Cursor %s GD, %s QN", gdMag, lbfgsMag));
      try {
        return new LineSearchCursorBase() {

          {
            // References held for the lifetime of this cursor; released in _free().
            subject.addRef();
            scaledGradient.addRef();
            lbfgs.addRef();
            lbfgsCursor.addRef();
            inner.addRef();
          }

          @Nonnull
          @Override
          public CharSequence getDirectionType() {
            return CURSOR_NAME;
          }

          /** Returns the offset {@code (t - t^2) * g + t^2 * d} along the quadratic path. */
          @Nonnull
          @Override
          public DeltaSet position(final double t) {
            if (!Double.isFinite(t)) {
              throw new IllegalArgumentException();
            }
            final DeltaSet gradientTerm = scaledGradient.scale(t - t * t);
            final DeltaSet offset = gradientTerm.add(lbfgs.scale(t * t));
            gradientTerm.freeRef();
            return offset;
          }

          @Override
          public void reset() {
            lbfgsCursor.reset();
          }

          /**
           * Moves to {@code position(t)}, measures the subject there, and records the sample in
           * the inner LBFGS history.
           *
           * @return the measured sample paired with the path derivative dotted with its delta
           */
          @Nonnull
          @Override
          public LineSearchPoint step(final double t, @Nonnull final TrainingMonitor monitor) {
            if (!Double.isFinite(t)) {
              throw new IllegalArgumentException();
            }
            // Reset through the LBFGS cursor, then apply this step's offset.
            reset();
            final DeltaSet offset = position(t);
            offset.accumulate(1);
            offset.freeRef();
            @Nonnull final PointSample sample = subject.measure(monitor);
            sample.setRate(t);
            inner.addToHistory(sample.addRef(), monitor);
            // d/dt of position(t): (1 - 2t) * g + 2t * d.
            final DeltaSet gradientTerm = scaledGradient.scale(1 - 2 * t);
            @Nonnull final DeltaSet tangent = gradientTerm.add(lbfgs.scale(2 * t));
            gradientTerm.freeRef();
            final double dot = tangent.dot(sample.delta.addRef());
            tangent.freeRef();
            return new LineSearchPoint(sample, dot);
          }

          @Override
          public void _free() {
            super._free();
            subject.freeRef();
            scaledGradient.freeRef();
            lbfgs.freeRef();
            lbfgsCursor.freeRef();
            inner.freeRef();
          }
        };
      } finally {
        // Release this method's own references; the cursor took its own in its initializer.
        subject.freeRef();
        scaledGradient.freeRef();
        lbfgs.freeRef();
        lbfgsCursor.freeRef();
      }
    } else {
      lbfgs.freeRef();
      gd.freeRef();
      subject.freeRef();
      return lbfgsCursor;
    }
  }

  /** Resets the inner LBFGS strategy. */
  @Override
  public void reset() {
    inner.reset();
  }

  /** Releases the inner LBFGS strategy after base-class cleanup. */
  @Override
  public void _free() {
    super._free();
    inner.freeRef();
  }

  @Nonnull
  @Override
  @SuppressWarnings("unused")
  public QQN addRef() {
    return (QQN) super.addRef();
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy