/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.opt.orient;
import com.simiacryptus.lang.ref.ReferenceCountingBase;
import com.simiacryptus.mindseye.eval.Trainable;
import com.simiacryptus.mindseye.lang.DeltaSet;
import com.simiacryptus.mindseye.lang.DoubleBuffer;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.PointSample;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.mindseye.opt.line.LineSearchCursor;
import com.simiacryptus.mindseye.opt.line.LineSearchCursorBase;
import com.simiacryptus.mindseye.opt.line.LineSearchPoint;
import com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor;
import com.simiacryptus.mindseye.opt.region.TrustRegion;
import com.simiacryptus.util.ArrayUtil;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.LinkedList;
import java.util.List;
import java.util.UUID;
import java.util.stream.IntStream;
import java.util.stream.Stream;
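/**
 * An {@link OrientationStrategy} decorator that constrains the steps produced by an
 * inner strategy (LBFGS by default) to per-layer {@link TrustRegion} policies: every
 * candidate position from the inner line-search cursor is projected back into the
 * feasible region, using a bounded history of recent weight states as reference points.
 *
 * A minimal usage sketch (illustrative only; it assumes the SingleOrthant policy from
 * com.simiacryptus.mindseye.opt.region is the desired constraint):
 * <pre>{@code
 * TrustRegionStrategy strategy = new TrustRegionStrategy() {
 *   @Override
 *   public TrustRegion getRegionPolicy(Layer layer) {
 *     return new SingleOrthant();
 *   }
 * };
 * }</pre>
 */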
public abstract class TrustRegionStrategy extends OrientationStrategyBase<LineSearchCursor> {

  public final OrientationStrategy<? extends SimpleLineSearchCursor> inner;
  private final List<PointSample> history = new LinkedList<>();
  private int maxHistory = 10;

  public TrustRegionStrategy() {
    this(new LBFGS());
  }

  protected TrustRegionStrategy(final OrientationStrategy<? extends SimpleLineSearchCursor> inner) {
    this.inner = inner;
  }
  public static double dot(@Nonnull final List<DoubleBuffer<UUID>> a, @Nonnull final List<DoubleBuffer<UUID>> b) {
    assert a.size() == b.size();
    return IntStream.range(0, a.size()).mapToDouble(i -> a.get(i).dot(b.get(i))).sum();
  }
  @Override
  protected void _free() {
    history.forEach(ReferenceCountingBase::freeRef);
    this.inner.freeRef();
  }

  public int getMaxHistory() {
    return maxHistory;
  }

  @Nonnull
  public TrustRegionStrategy setMaxHistory(final int maxHistory) {
    this.maxHistory = maxHistory;
    return this;
  }
  public abstract TrustRegion getRegionPolicy(Layer layer);
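  /**
   * Records the origin in the point history (evicting the oldest entries beyond
   * maxHistory), then wraps the inner strategy's cursor in a trust-region cursor.
   */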
  @Nonnull
  @Override
  public LineSearchCursor orient(@Nonnull final Trainable subject, final PointSample origin, final TrainingMonitor monitor) {
    history.add(0, origin.addRef());
    while (history.size() > maxHistory) {
      history.remove(history.size() - 1).freeRef();
    }
    final SimpleLineSearchCursor cursor = inner.orient(subject, origin, monitor);
    return new TrustRegionCursor(cursor, subject);
  }

  @Override
  public void reset() {
    inner.reset();
  }
  private class TrustRegionCursor extends LineSearchCursorBase {
    private final SimpleLineSearchCursor cursor;
    private final Trainable subject;

    public TrustRegionCursor(SimpleLineSearchCursor cursor, Trainable subject) {
      this.cursor = cursor;
      this.subject = subject;
    }

    @Override
    public PointSample afterStep(@Nonnull PointSample step) {
      super.afterStep(step);
      return cursor.afterStep(step);
    }

    @Nonnull
    @Override
    public CharSequence getDirectionType() {
      return cursor.getDirectionType() + "+Trust";
    }

    @Nonnull
    @Override
    public DeltaSet<UUID> position(final double alpha) {
      //reset();
      @Nonnull final DeltaSet<UUID> adjustedPosVector = cursor.position(alpha);
      project(adjustedPosVector, new TrainingMonitor());
      return adjustedPosVector;
    }
    public Layer toLayer(UUID id) {
      DAGNetwork layer = (DAGNetwork) subject.getLayer();
      if (null == layer) return null;
      return layer.getLayersById().get(id);
    }
    @Nonnull
    public DeltaSet<UUID> project(@Nonnull final DeltaSet<UUID> deltaIn, final TrainingMonitor monitor) {
      final DeltaSet<UUID> originalAlphaDerivative = cursor.direction;
      @Nonnull final DeltaSet<UUID> newAlphaDerivative = originalAlphaDerivative.copy();
      deltaIn.getMap().forEach((id, buffer) -> {
        @Nullable final double[] delta = buffer.getDelta();
        if (null == delta) return;
        final double[] currentPosition = buffer.target;
        @Nullable final double[] originalAlphaD = originalAlphaDerivative.get(id, currentPosition).getDeltaAndFree();
        @Nullable final double[] newAlphaD = newAlphaDerivative.get(id, currentPosition).getDeltaAndFree();
        @Nonnull final double[] proposedPosition = ArrayUtil.add(currentPosition, delta);
        final TrustRegion region = getRegionPolicy(toLayer(id));
        if (null != region) {
          final Stream<double[]> zz = history.stream().map((@Nonnull final PointSample pointSample) -> {
            final DoubleBuffer<UUID> d = pointSample.weights.getMap().get(id);
            @Nullable final double[] z = null == d ? null : d.getDelta();
            return z;
          });
          final double[] projectedPosition = region.project(zz.filter(x -> null != x).toArray(i -> new double[i][]), proposedPosition);
          if (projectedPosition != proposedPosition) {
            // The region moved the point; rewrite the delta to land on the projection.
            for (int i = 0; i < projectedPosition.length; i++) {
              delta[i] = projectedPosition[i] - currentPosition[i];
            }
            @Nonnull final double[] normal = ArrayUtil.subtract(projectedPosition, proposedPosition);
            final double normalMagSq = ArrayUtil.dot(normal, normal);
            // monitor.log(String.format("%s: evalInputDelta = %s, projectedPosition = %s, proposedPosition = %s, currentPosition = %s, normalMagSq = %s", key,
            //     ArrayUtil.dot(evalInputDelta, evalInputDelta),
            //     ArrayUtil.dot(projectedPosition, projectedPosition),
            //     ArrayUtil.dot(proposedPosition, proposedPosition),
            //     ArrayUtil.dot(currentPosition, currentPosition),
            //     normalMagSq));
            if (0 < normalMagSq) {
              final double a = ArrayUtil.dot(originalAlphaD, normal);
              if (a != -1) {
                // Subtract the normal component: tangent = d - (d . n / |n|^2) * n
                @Nonnull final double[] tangent = ArrayUtil.add(originalAlphaD, ArrayUtil.multiply(normal, -a / normalMagSq));
                for (int i = 0; i < tangent.length; i++) {
                  newAlphaD[i] = tangent[i];
                }
                // double newAlphaDerivSq = ArrayUtil.dot(tangent, tangent);
                // double originalAlphaDerivSq = ArrayUtil.dot(originalAlphaD, originalAlphaD);
                // assert (newAlphaDerivSq <= originalAlphaDerivSq);
                // assert (Math.abs(ArrayUtil.dot(tangent, normal)) <= 1e-4);
                // monitor.log(String.format("%s: normalMagSq = %s, newAlphaDerivSq = %s, originalAlphaDerivSq = %s", key, normalMagSq, newAlphaDerivSq, originalAlphaDerivSq));
              }
            }
          }
        }
      });
      return newAlphaDerivative;
    }
    @Override
    public void reset() {
      cursor.reset();
    }
    @Nonnull
    @Override
    public LineSearchPoint step(final double alpha, final TrainingMonitor monitor) {
      cursor.reset();
      @Nonnull final DeltaSet<UUID> adjustedPosVector = cursor.position(alpha);
      @Nonnull final DeltaSet<UUID> adjustedGradient = project(adjustedPosVector, monitor);
      adjustedPosVector.accumulate(1);
      adjustedPosVector.freeRef();
      @Nonnull final PointSample sample = afterStep(subject.measure(monitor).setRate(alpha));
      double dot = adjustedGradient.dot(sample.delta);
      adjustedGradient.freeRef();
      LineSearchPoint lineSearchPoint = new LineSearchPoint(sample, dot);
      sample.freeRef();
      return lineSearchPoint;
    }

    @Override
    public void _free() {
      cursor.freeRef();
    }
  }
}