com.simiacryptus.mindseye.opt.orient.LayerReweightingStrategy Maven / Gradle / Ivy
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.opt.orient;
import com.simiacryptus.mindseye.eval.Trainable;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.RefHashMap;
import com.simiacryptus.ref.wrappers.RefMap;
import com.simiacryptus.util.ArrayUtil;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.UUID;
import java.util.function.BiConsumer;
public abstract class LayerReweightingStrategy extends OrientationStrategyBase {
@Nullable
public final OrientationStrategy inner;
public LayerReweightingStrategy(@Nullable final OrientationStrategy inner) {
OrientationStrategy temp_32_0001 = inner == null
? null
: inner.addRef();
this.inner = temp_32_0001 == null ? null : temp_32_0001.addRef();
if (null != temp_32_0001)
temp_32_0001.freeRef();
if (null != inner)
inner.freeRef();
}
@Nullable
public abstract Double getRegionPolicy(Layer layer);
@Override
public SimpleLineSearchCursor orient(@Nullable final Trainable subject, @Nullable final PointSample measurement,
final TrainingMonitor monitor) {
assert inner != null;
final SimpleLineSearchCursor orient = inner.orient(subject == null ? null : subject.addRef(), measurement, monitor);
assert orient.direction != null;
final DeltaSet direction = orient.direction.addRef();
RefMap> temp_32_0003 = direction.getMap();
temp_32_0003.forEach(RefUtil.wrapInterface(
(BiConsumer super UUID, ? super Delta>) (
uuid, buffer) -> {
if (null == buffer.getDelta()) {
buffer.freeRef();
return;
}
assert subject != null;
DAGNetwork dagNetwork = (DAGNetwork) subject.getLayer();
if (null != dagNetwork) {
RefMap temp_32_0004 = dagNetwork.getLayersById();
dagNetwork.freeRef();
Layer layer = temp_32_0004.get(uuid);
temp_32_0004.freeRef();
final Double weight = getRegionPolicy(layer);
if (null != weight && 0 < weight) {
final DoubleBuffer deltaBuffer = direction.get(uuid, buffer.target);
assert deltaBuffer != null;
@Nonnull final double[] adjusted = ArrayUtil.multiply(deltaBuffer.getDelta(), weight);
for (int i = 0; i < adjusted.length; i++) {
deltaBuffer.getDelta()[i] = adjusted[i];
}
deltaBuffer.freeRef();
}
}
buffer.freeRef();
}, subject, direction));
temp_32_0003.freeRef();
return orient;
}
@Override
public void _free() {
super._free();
if (null != inner)
inner.freeRef();
}
@Nonnull
public @Override
@SuppressWarnings("unused")
LayerReweightingStrategy addRef() {
return (LayerReweightingStrategy) super.addRef();
}
public static class HashMapLayerReweightingStrategy extends LayerReweightingStrategy {
@Nonnull
private final RefHashMap map = new RefHashMap<>();
public HashMapLayerReweightingStrategy(final OrientationStrategy inner) {
super(inner);
}
@Nonnull
public RefMap getMap() {
return map.addRef();
}
@Nullable
@Override
public Double getRegionPolicy(final Layer layer) {
return map.get(layer);
}
@Override
public void reset() {
assert inner != null;
inner.reset();
}
public @SuppressWarnings("unused")
void _free() {
super._free();
map.freeRef();
}
@Nonnull
public @Override
@SuppressWarnings("unused")
HashMapLayerReweightingStrategy addRef() {
return (HashMapLayerReweightingStrategy) super.addRef();
}
}
}