All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.mindseye.eval.L12Normalizer Maven / Gradle / Ivy

There is a newer version: 2.1.0
Show newest version
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.eval;

import com.google.common.util.concurrent.AtomicDouble;
import com.simiacryptus.mindseye.lang.Delta;
import com.simiacryptus.mindseye.lang.DeltaSet;
import com.simiacryptus.mindseye.lang.Layer;
import com.simiacryptus.mindseye.lang.PointSample;
import com.simiacryptus.mindseye.network.DAGNetwork;
import com.simiacryptus.mindseye.opt.TrainingMonitor;
import com.simiacryptus.ref.wrappers.RefCollection;
import com.simiacryptus.ref.wrappers.RefCollectors;
import com.simiacryptus.ref.wrappers.RefList;
import com.simiacryptus.ref.wrappers.RefMap;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.UUID;

public abstract class L12Normalizer extends TrainableBase {
  @Nullable
  public final Trainable inner;
  private final boolean hideAdj = false;

  public L12Normalizer(@Nullable final Trainable inner) {
    Trainable temp_01_0001 = inner == null ? null : inner.addRef();
    this.inner = temp_01_0001 == null ? null : temp_01_0001.addRef();
    if (null != temp_01_0001)
      temp_01_0001.freeRef();
    if (null != inner)
      inner.freeRef();
  }

  @javax.annotation.Nullable
  public Layer toLayer(UUID id) {
    assert inner != null;
    DAGNetwork layer = (DAGNetwork) inner.getLayer();
    if (null == layer) return null;
    RefMap layersById = layer.getLayersById();
    Layer temp_01_0004 = layersById.get(id);
    layersById.freeRef();
    layer.freeRef();
    return temp_01_0004;
  }

  public RefCollection getLayers(@Nonnull final RefCollection layers) {
    RefList temp_01_0003 = layers.stream().map(this::toLayer)
        //.filter(layer -> layer instanceof FullyConnectedLayer)
        .collect(RefCollectors.toList());
    layers.freeRef();
    return temp_01_0003;
  }

  @Nonnull
  @Override
  public PointSample measure(final TrainingMonitor monitor) {
    assert inner != null;
    final PointSample innerMeasure = inner.measure(monitor);
    @Nonnull final DeltaSet normalizationVector = new DeltaSet();
    AtomicDouble valueAdj = new AtomicDouble(0);
    RefCollection layers = getLayers(innerMeasure.delta.keySet());
    layers.forEach(layer -> {
      Delta temp_01_0008 = innerMeasure.delta.get(layer.getId());
      assert temp_01_0008 != null;
      final double[] weights = temp_01_0008.target;
      temp_01_0008.freeRef();
      Delta temp_01_0009 = normalizationVector.get(layer.getId(), weights);
      assert temp_01_0009 != null;
      @Nullable final double[] gradientAdj = temp_01_0009.getDelta();
      temp_01_0009.freeRef();
      final double factor_L1 = getL1(layer.addRef());
      final double factor_L2 = getL2(layer);
      assert null != gradientAdj;
      for (int i = 0; i < gradientAdj.length; i++) {
        final double sign = weights[i] < 0 ? -1.0 : 1.0;
        gradientAdj[i] += factor_L1 * sign + 2 * factor_L2 * weights[i];
        valueAdj.addAndGet((factor_L1 * sign + factor_L2 * weights[i]) * weights[i]);
      }
    });

    layers.freeRef();
    final DeltaSet deltaSet = innerMeasure.delta.add(normalizationVector);
    final PointSample pointSample = new PointSample(deltaSet.addRef(),
        innerMeasure.weights.addRef(),
        innerMeasure.sum + (hideAdj ? 0 : valueAdj.get()),
        innerMeasure.rate,
        innerMeasure.count);
    deltaSet.freeRef();
    innerMeasure.freeRef();
    PointSample temp_01_0002 = pointSample.normalize();
    pointSample.freeRef();
    return temp_01_0002;
  }

  @Override
  public boolean reseed(final long seed) {
    assert inner != null;
    return inner.reseed(seed);
  }

  public void _free() {
    super._free();
    if (null != inner)
      inner.freeRef();
  }

  @Nonnull
  public @Override
  @SuppressWarnings("unused")
  L12Normalizer addRef() {
    return (L12Normalizer) super.addRef();
  }

  protected abstract double getL1(Layer layer);

  protected abstract double getL2(Layer layer);

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy