All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.simiacryptus.mindseye.opt.orient.TrustRegionStrategy Maven / Gradle / Ivy

There is a newer version: 2.1.0
Show newest version
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.opt.orient;

import com.simiacryptus.lang.ref.ReferenceCountingBase;
import com.simiacryptus.mindseye.eval.Trainable;
import com.simiacryptus.mindseye.lang.DeltaSet;
import com.simiacryptus.mindseye.lang.DoubleBuffer;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.PointSample;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.mindseye.opt.line.LineSearchCursor;
import com.simiacryptus.mindseye.opt.line.LineSearchCursorBase;
import com.simiacryptus.mindseye.opt.line.LineSearchPoint;
import com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor;
import com.simiacryptus.mindseye.opt.region.TrustRegion;
import com.simiacryptus.util.ArrayUtil;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.LinkedList;
import java.util.List;
import java.util.UUID;
import java.util.stream.IntStream;
import java.util.stream.Stream;

public abstract class TrustRegionStrategy extends OrientationStrategyBase {


  public final OrientationStrategy inner;
  private final List history = new LinkedList<>();
  private int maxHistory = 10;

  public TrustRegionStrategy() {
    this(new LBFGS());
  }

  protected TrustRegionStrategy(final OrientationStrategy inner) {
    this.inner = inner;
  }

  public static double dot(@Nonnull final List> a, @Nonnull final List> b) {
    assert a.size() == b.size();
    return IntStream.range(0, a.size()).mapToDouble(i -> a.get(i).dot(b.get(i))).sum();
  }

  @Override
  protected void _free() {
    history.forEach(ReferenceCountingBase::freeRef);
    this.inner.freeRef();
  }

  public int getMaxHistory() {
    return maxHistory;
  }

  @Nonnull
  public TrustRegionStrategy setMaxHistory(final int maxHistory) {
    this.maxHistory = maxHistory;
    return this;
  }

  public abstract TrustRegion getRegionPolicy(Layer layer);

  @Nonnull
  @Override
  public LineSearchCursor orient(@Nonnull final Trainable subject, final PointSample origin, final TrainingMonitor monitor) {
    history.add(0, origin.addRef());
    while (history.size() > maxHistory) {
      history.remove(history.size() - 1).freeRef();
    }
    final SimpleLineSearchCursor cursor = inner.orient(subject, origin, monitor);
    return new TrustRegionCursor(cursor, subject);
  }

  @Override
  public void reset() {
    inner.reset();
  }

  private class TrustRegionCursor extends LineSearchCursorBase {
    private final SimpleLineSearchCursor cursor;
    private final Trainable subject;

    public TrustRegionCursor(SimpleLineSearchCursor cursor, Trainable subject) {
      this.cursor = cursor;
      this.subject = subject;
    }

    @Override
    public PointSample afterStep(@Nonnull PointSample step) {
      super.afterStep(step);
      return cursor.afterStep(step);
    }

    @Nonnull
    @Override
    public CharSequence getDirectionType() {
      return cursor.getDirectionType() + "+Trust";
    }

    @Nonnull
    @Override
    public DeltaSet position(final double alpha) {
      //reset();
      @Nonnull final DeltaSet adjustedPosVector = cursor.position(alpha);
      project(adjustedPosVector, new TrainingMonitor());
      return adjustedPosVector;
    }

    public Layer toLayer(UUID id) {
      DAGNetwork layer = (DAGNetwork) subject.getLayer();
      if (null == layer) return null;
      return layer.getLayersById().get(id);
    }

    @Nonnull
    public DeltaSet project(@Nonnull final DeltaSet deltaIn, final TrainingMonitor monitor) {
      final DeltaSet originalAlphaDerivative = cursor.direction;
      @Nonnull final DeltaSet newAlphaDerivative = originalAlphaDerivative.copy();
      deltaIn.getMap().forEach((id, buffer) -> {
        @Nullable final double[] delta = buffer.getDelta();
        if (null == delta) return;
        final double[] currentPosition = buffer.target;
        @Nullable final double[] originalAlphaD = originalAlphaDerivative.get(id, currentPosition).getDeltaAndFree();
        @Nullable final double[] newAlphaD = newAlphaDerivative.get(id, currentPosition).getDeltaAndFree();
        @Nonnull final double[] proposedPosition = ArrayUtil.add(currentPosition, delta);
        final TrustRegion region = getRegionPolicy(toLayer(id));
        if (null != region) {
          final Stream zz = history.stream().map((@Nonnull final PointSample pointSample) -> {
            final DoubleBuffer d = pointSample.weights.getMap().get(id);
            @Nullable final double[] z = null == d ? null : d.getDelta();
            return z;
          });
          final double[] projectedPosition = region.project(zz.filter(x -> null != x).toArray(i -> new double[i][]), proposedPosition);
          if (projectedPosition != proposedPosition) {
            for (int i = 0; i < projectedPosition.length; i++) {
              delta[i] = projectedPosition[i] - currentPosition[i];
            }
            @Nonnull final double[] normal = ArrayUtil.subtract(projectedPosition, proposedPosition);
            final double normalMagSq = ArrayUtil.dot(normal, normal);
//              monitor.log(String.format("%s: evalInputDelta = %s, projectedPosition = %s, proposedPosition = %s, currentPosition = %s, normalMagSq = %s", key,
//                ArrayUtil.dot(evalInputDelta,evalInputDelta),
//                ArrayUtil.dot(projectedPosition,projectedPosition),
//                ArrayUtil.dot(proposedPosition,proposedPosition),
//                ArrayUtil.dot(currentPosition,currentPosition),
//                normalMagSq));
            if (0 < normalMagSq) {
              final double a = ArrayUtil.dot(originalAlphaD, normal);
              if (a != -1) {
                @Nonnull final double[] tangent = ArrayUtil.add(originalAlphaD, ArrayUtil.multiply(normal, -a / normalMagSq));
                for (int i = 0; i < tangent.length; i++) {
                  newAlphaD[i] = tangent[i];
                }
//                  double newAlphaDerivSq = ArrayUtil.dot(tangent, tangent);
//                  double originalAlphaDerivSq = ArrayUtil.dot(originalAlphaD, originalAlphaD);
//                  assert(newAlphaDerivSq <= originalAlphaDerivSq);
//                  assert(Math.abs(ArrayUtil.dot(tangent, normal)) <= 1e-4);
//                  monitor.log(String.format("%s: normalMagSq = %s, newAlphaDerivSq = %s, originalAlphaDerivSq = %s", key, normalMagSq, newAlphaDerivSq, originalAlphaDerivSq));
              }
            }


          }
        }
      });
      return newAlphaDerivative;
    }

    @Override
    public void reset() {
      cursor.reset();
    }

    @Nonnull
    @Override
    public LineSearchPoint step(final double alpha, final TrainingMonitor monitor) {
      cursor.reset();
      @Nonnull final DeltaSet adjustedPosVector = cursor.position(alpha);
      @Nonnull final DeltaSet adjustedGradient = project(adjustedPosVector, monitor);
      adjustedPosVector.accumulate(1);
      adjustedPosVector.freeRef();
      @Nonnull final PointSample sample = afterStep(subject.measure(monitor).setRate(alpha));
      double dot = adjustedGradient.dot(sample.delta);
      adjustedGradient.freeRef();
      LineSearchPoint lineSearchPoint = new LineSearchPoint(sample, dot);
      sample.freeRef();
      return lineSearchPoint;
    }

    @Override
    public void _free() {
      cursor.freeRef();
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy