All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.mindseye.lang.DeltaSet Maven / Gradle / Ivy

/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.lang;

import javax.annotation.Nonnull;
import java.util.Arrays;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Stream;

public class DeltaSet extends DoubleBufferSet> {

  public DeltaSet() {
  }

  public DeltaSet(@Nonnull final DoubleBufferSet> toCopy) {
    super(toCopy);
    assert stream().allMatch(x -> x instanceof Delta);
  }

  public DeltaSet(@Nonnull final Map> collect) {
    super(collect);
    assert stream().allMatch(x -> x instanceof Delta);
  }

  @Nonnull
  public DeltaSet accumulate(final double alpha) {
    stream().forEach(d -> d.accumulate(alpha));
    return this;
  }


  @Nonnull
  public DeltaSet add(@Nonnull final DeltaSet right) {
    return this.copy().addInPlace(right);
  }

  @Nonnull
  public DeltaSet addInPlace(@Nonnull final DeltaSet right) {
    right.map.forEach((layer, buffer) -> {
      get(layer, buffer.target).addInPlace(buffer).freeRef();
    });
    return this;
  }

  @Nonnull
  public StateSet asState() {
    @Nonnull final StateSet returnValue = new StateSet<>();
    map.forEach((layer, delta) -> {
      delta.assertAlive();
      State kState = returnValue.get(layer, delta.target);
      kState.set(delta.delta);
      kState.freeRef();
    });
    return returnValue;
  }

  @Nonnull
  @Override
  public DeltaSet copy() {
    return this.map(x -> x.copy());
  }

  public double dot(@Nonnull final DoubleBufferSet> right) {
    Stream>> stream = map.entrySet().stream();
    if (100 < map.size()) {
      stream = stream.parallel();
    }
    return stream.mapToDouble(entry -> {
      final K key = entry.getKey();
      final Delta value = entry.getValue();
      final Delta rValue = right.map.get(key);
      if (null != rValue) {
        return value.dot(rValue);
      } else {
        return 0;
      }
    }).sum();
  }

  @Nonnull
  @Override
  protected Delta factory(@Nonnull final K layer, final double[] target) {
    return new Delta(layer, target);
  }

  public double getMagnitude() {
    Stream>> stream = map.entrySet().stream();
    if (100 < map.size()) {
      stream = stream.parallel();
    }
    final double[] elementArray = stream.mapToDouble(entry -> {
      final DoubleBuffer value = entry.getValue();
      final double v = value.deltaStatistics().sumSq();
      return v;
    }).toArray();
    return Math.sqrt(Arrays.stream(elementArray).sum());
  }

  @Nonnull
  @Override
  public DeltaSet map(final Function, Delta> mapper) {
    @Nonnull DoubleBufferSet> map = super.map(mapper);
    @Nonnull DeltaSet kDeltaSet = new DeltaSet<>(map);
    map.freeRef();
    return kDeltaSet;
  }

  @Nonnull
  public DeltaSet scale(final double f) {
    return map(x -> x.scale(f));
  }

  @Nonnull
  public DeltaSet subtract(@Nonnull final DeltaSet right) {
    DeltaSet scale = right.scale(-1);
    DeltaSet add = this.add(scale);
    scale.freeRef();
    return add;
  }

  @Nonnull
  public DeltaSet unit() {
    return scale(1.0 / getMagnitude());
  }

  @Override
  public DeltaSet addRef() {
    return (DeltaSet) super.addRef();
  }

}





© 2015 - 2025 Weber Informatics LLC | Privacy Policy