All Downloads are FREE. Search and download functionality uses the official Maven repository.

com.simiacryptus.mindseye.lang.DeltaSet Maven / Gradle / Ivy

There is a newer version: 2.1.0
Show newest version
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.lang;

import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.*;

import javax.annotation.Nonnull;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.ToDoubleFunction;

public class DeltaSet extends DoubleBufferSet> {

  public DeltaSet() {
  }

  public DeltaSet(@Nonnull final DoubleBufferSet> toCopy) {
    super(toCopy);
//    assert stream().allMatch(x -> {
//      try {
//        return x instanceof Delta;
//      } finally {
//        if (null != x) x.freeRef();
//      }
//    });
  }

  public DeltaSet(@Nonnull final RefMap> collect) {
    super(collect);
//    assert stream().allMatch(x -> {
//      boolean temp_37_0002 = x instanceof Delta;
//      if (null != x)
//        x.freeRef();
//      return temp_37_0002;
//    });
  }

  public double getMagnitude() {
    RefHashSet>> temp_37_0011 = map.entrySet();
    RefStream>> stream = temp_37_0011.stream();
    if (100 < map.size()) {
      stream = stream.parallel();
    }
    final double[] elementArray = stream.mapToDouble(entry -> {
      final DoubleBuffer value = entry.getValue();
      RefUtil.freeRef(entry);
      double temp_37_0003 = value.deltaStatistics().sumSq();
      value.freeRef();
      return temp_37_0003;
    }).toArray();
    temp_37_0011.freeRef();
    return Math.sqrt(RefArrays.stream(elementArray).sum());
  }


  public void accumulate(final double alpha) {
    stream().forEach((Delta d) -> {
      d.accumulate(alpha);
      d.freeRef();
    });
  }

  public DeltaSet add(@Nonnull final DeltaSet right) {
    DeltaSet copy = this.copy();
    copy.addInPlace(right);
    return copy;
  }

  public void addInPlace(@Nonnull DeltaSet right) {
    right.map.forEach((layer, buffer) -> {
      Delta temp_37_0013 = get(layer, buffer.target);
      assert temp_37_0013 != null;
      temp_37_0013.addInPlace(buffer);
      temp_37_0013.freeRef();
    });
    right.freeRef();
  }

  @Nonnull
  public StateSet asState() {
    @Nonnull final StateSet returnValue = new StateSet<>();
    map.forEach(RefUtil.wrapInterface((BiConsumer>) (layer, delta) -> {
      delta.assertAlive();
      State kState = returnValue.get(layer, delta.target);
      assert kState != null;
      kState.set(delta.delta);
      kState.freeRef();
      delta.freeRef();
    }, returnValue.addRef()));
    return returnValue;
  }

  @Nonnull
  @Override
  public DeltaSet copy() {
    return new DeltaSet(this.map(x -> {
      try {
        return x.copy();
      } finally {
        x.freeRef();
      }
    }));
  }

  public double dot(@Nonnull final DoubleBufferSet> right) {
    RefHashSet>> entries = map.entrySet();
    RefStream>> stream = entries.stream();
    if (100 < map.size()) {
      stream = stream.parallel();
    }
    double temp_37_0010 = stream
        .mapToDouble(RefUtil.wrapInterface((ToDoubleFunction>>) entry -> {
          final K key = entry.getKey();
          final Delta value = entry.getValue();
          RefUtil.freeRef(entry);
          final Delta rValue = right.map.get(key);
          if (null != rValue) {
            double temp_37_0005 = value.dot(rValue);
            if (null != value) value.freeRef();
            return temp_37_0005;
          } else {
            if (null != value) value.freeRef();
            return 0;
          }
        }, right)).sum();
    entries.freeRef();
    return temp_37_0010;
  }

  @Nonnull
  @Override
  public DeltaSet map(@Nonnull final RefFunction, Delta> mapper) {
    return new DeltaSet<>(super.map(mapper));
  }

  @Nonnull
  public DeltaSet scale(final double f) {
    return map(x -> {
      Delta temp_37_0007 = x.scale(f);
      x.freeRef();
      return temp_37_0007;
    });
  }

  @Nonnull
  public DeltaSet subtract(@Nonnull final DeltaSet right) {
    DeltaSet scale = right.scale(-1);
    right.freeRef();
    return this.add(scale);
  }

  @Nonnull
  public DeltaSet unit() {
    return scale(1.0 / getMagnitude());
  }

  public @SuppressWarnings("unused")
  void _free() {
    super._free();
  }

  @Nonnull
  public @Override
  @SuppressWarnings("unused")
  DeltaSet addRef() {
    return (DeltaSet) super.addRef();
  }

  public DeltaSet allFinite(double defaultValue) {
    return map((Delta delta) -> {
      Delta map = delta.map(d -> Double.isFinite(d) ? d : defaultValue);
      delta.freeRef();
      return map;
    });
  }

  @Nonnull
  @Override
  protected Delta factory(@Nonnull final K layer, final double[] target) {
    return new Delta(layer, target);
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy