All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.mindseye.opt.orient.LayerReweightingStrategy Maven / Gradle / Ivy

/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.opt.orient;

import com.simiacryptus.mindseye.eval.Trainable;
import com.simiacryptus.mindseye.lang.*;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.RefHashMap;
import com.simiacryptus.ref.wrappers.RefMap;
import com.simiacryptus.util.ArrayUtil;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.UUID;
import java.util.function.BiConsumer;

public abstract class LayerReweightingStrategy extends OrientationStrategyBase {

  @Nullable
  public final OrientationStrategy inner;

  public LayerReweightingStrategy(@Nullable final OrientationStrategy inner) {
    OrientationStrategy temp_32_0001 = inner == null
        ? null
        : inner.addRef();
    this.inner = temp_32_0001 == null ? null : temp_32_0001.addRef();
    if (null != temp_32_0001)
      temp_32_0001.freeRef();
    if (null != inner)
      inner.freeRef();
  }

  @Nullable
  public abstract Double getRegionPolicy(Layer layer);

  @Override
  public SimpleLineSearchCursor orient(@Nullable final Trainable subject, @Nullable final PointSample measurement,
                                       final TrainingMonitor monitor) {
    assert inner != null;
    final SimpleLineSearchCursor orient = inner.orient(subject == null ? null : subject.addRef(), measurement, monitor);
    assert orient.direction != null;
    final DeltaSet direction = orient.direction.addRef();
    RefMap> temp_32_0003 = direction.getMap();
    temp_32_0003.forEach(RefUtil.wrapInterface(
        (BiConsumer>) (
            uuid, buffer) -> {
          if (null == buffer.getDelta()) {
            buffer.freeRef();
            return;
          }
          assert subject != null;
          DAGNetwork dagNetwork = (DAGNetwork) subject.getLayer();
          if (null != dagNetwork) {
            RefMap temp_32_0004 = dagNetwork.getLayersById();
            dagNetwork.freeRef();
            Layer layer = temp_32_0004.get(uuid);
            temp_32_0004.freeRef();
            final Double weight = getRegionPolicy(layer);
            if (null != weight && 0 < weight) {
              final DoubleBuffer deltaBuffer = direction.get(uuid, buffer.target);
              assert deltaBuffer != null;
              @Nonnull final double[] adjusted = ArrayUtil.multiply(deltaBuffer.getDelta(), weight);
              for (int i = 0; i < adjusted.length; i++) {
                deltaBuffer.getDelta()[i] = adjusted[i];
              }
              deltaBuffer.freeRef();
            }
          }
          buffer.freeRef();
        }, subject, direction));
    temp_32_0003.freeRef();
    return orient;
  }

  @Override
  public void _free() {
    super._free();
    if (null != inner)
      inner.freeRef();
  }

  @Nonnull
  public @Override
  @SuppressWarnings("unused")
  LayerReweightingStrategy addRef() {
    return (LayerReweightingStrategy) super.addRef();
  }

  public static class HashMapLayerReweightingStrategy extends LayerReweightingStrategy {

    @Nonnull
    private final RefHashMap map = new RefHashMap<>();

    public HashMapLayerReweightingStrategy(final OrientationStrategy inner) {
      super(inner);
    }

    @Nonnull
    public RefMap getMap() {
      return map.addRef();
    }

    @Nullable
    @Override
    public Double getRegionPolicy(final Layer layer) {
      return map.get(layer);
    }

    @Override
    public void reset() {
      assert inner != null;
      inner.reset();
    }

    public @SuppressWarnings("unused")
    void _free() {
      super._free();
      map.freeRef();
    }

    @Nonnull
    public @Override
    @SuppressWarnings("unused")
    HashMapLayerReweightingStrategy addRef() {
      return (HashMapLayerReweightingStrategy) super.addRef();
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy