All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.mindseye.lang.StateSet Maven / Gradle / Ivy

/*
 * Copyright (c) 2018 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.mindseye.lang;

import javax.annotation.Nonnull;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * A collection of State objects being staged for particular layers. Provides indexing capabilities to reference the
 * deltas based on physical references (to double[] objects) and based on logical referants (i.e. layers) Provides
 * collection-arithmetic operations appropriate to the State's 'point' geometric archtype.
 *
 * @param  the type parameter
 */
public class StateSet extends DoubleBufferSet> {

  /**
   * Instantiates a new State setByCoord.
   */
  public StateSet() {
  }

  /**
   * Instantiates a new State setByCoord as a copy of the target data buffers in the input setByCoord
   *
   * @param toCopy the to copy
   */
  public StateSet(@Nonnull final DeltaSet toCopy) {
    assert toCopy.stream().allMatch(x -> Arrays.stream(x.getDelta()).allMatch(Double::isFinite));
    toCopy.getMap().forEach((layer, layerDelta) -> {
      this.get(layer, layerDelta.target).backup().freeRef();
    });
    assert stream().allMatch(x -> Arrays.stream(x.getDelta()).allMatch(Double::isFinite));
    assert stream().allMatch(x -> x instanceof State);
  }

  /**
   * Instantiates a new State setByCoord.
   *
   * @param toCopy the to copy
   */
  public StateSet(@Nonnull final DoubleBufferSet> toCopy) {
    super(toCopy);
    assert stream().allMatch(x -> x instanceof State);
  }

  /**
   * Instantiates a new State setByCoord.
   *
   * @param collect the collect
   */
  public StateSet(@Nonnull final Map> collect) {
    super(collect);
  }

  /**
   * Union state setByCoord.
   *
   * @param    the type parameter
   * @param left  the left
   * @param right the right
   * @return the state setByCoord
   */
  public static  StateSet union(@Nonnull final DoubleBufferSet> left, @Nonnull final DoubleBufferSet> right) {
    final Map> collect = Stream.concat(
        left.map.entrySet().stream(),
        right.map.entrySet().stream()
    ).collect(Collectors.groupingBy((@Nonnull final Map.Entry> e1) -> e1.getKey(),
        Collectors.mapping((@Nonnull final Map.Entry> x) -> x.getValue(), Collectors.collectingAndThen(
            Collectors.reducing((@Nonnull final State a, @Nonnull final State b) -> {
              assert a.target == b.target;
              assert a.key.equals(b.key);
              return a;
            }), x -> x.get()))));
    return new StateSet(collect);
  }

  /**
   * Add evalInputDelta setByCoord.
   *
   * @param right the right
   * @return the evalInputDelta setByCoord
   */
  @Nonnull
  public StateSet add(@Nonnull final DeltaSet right) {
    @Nonnull final DeltaSet deltas = new DeltaSet();
    map.forEach(100, (@Nonnull final K layer, @Nonnull final State buffer) -> {
      deltas.get(layer, buffer.target).set(buffer.getDelta()).freeRef();
    });
    right.map.forEach(100, (@Nonnull final K layer, @Nonnull final Delta buffer) -> {
      deltas.get(layer, buffer.target).addInPlace(buffer.getDelta()).freeRef();
    });
    @Nonnull StateSet kStateSet = deltas.asState();
    deltas.freeRef();
    return kStateSet;
  }

  /**
   * As vector evalInputDelta setByCoord.
   *
   * @return the evalInputDelta setByCoord
   */
  @Nonnull
  public DeltaSet asVector() {
    @Nonnull final HashMap> newMap = new HashMap<>();
    map.forEach((layer, state) -> newMap.put(layer, new Delta(layer, state.target, RecycleBin.DOUBLES.copyOf(state.delta, state.delta.length))));
    @Nonnull DeltaSet deltaSet = new DeltaSet<>(newMap);
    newMap.values().forEach(v -> v.freeRef());
    return deltaSet;
  }

  @Nonnull
  @Override
  public StateSet copy() {
    return map(x -> x.copy());
  }

  /**
   * Backup copy state setBytes.
   *
   * @return the state setBytes
   */
  @Nonnull
  public StateSet backupCopy() {
    return map(l -> l.backupCopy());
  }

  /**
   * Backup state setBytes.
   *
   * @return the state setBytes
   */
  @Nonnull
  public StateSet backup() {
    Stream>> stream = map.entrySet().stream();
    if (map.size() > 100) {
      stream = stream.parallel();
    }
    stream.forEach(e -> e.getValue().backup());
    return this;
  }

  /**
   * Restore state setBytes.
   *
   * @return the state setBytes
   */
  @Nonnull
  public StateSet restore() {
    Stream>> stream = map.entrySet().stream();
    if (map.size() > 100) {
      stream = stream.parallel();
    }
    stream.forEach(e -> e.getValue().restore());
    return this;
  }

  @Nonnull
  @Override
  protected State factory(@Nonnull final K layer, final double[] target) {
    return new State(layer, target);
  }

  /**
   * Is different boolean.
   *
   * @return the boolean
   */
  public boolean isDifferent() {
    return stream().parallel().anyMatch(x -> !x.areEqual());
  }

  @Nonnull
  @Override
  public StateSet map(@Nonnull final Function, State> mapper) {
    Stream>> stream = map.entrySet().stream();
    if (map.size() > 100) {
      stream = stream.parallel();
    }
    final Map> newMap = stream.collect(Collectors.toMap(e -> e.getKey(), e -> mapper.apply(e.getValue())));
    @Nonnull StateSet kStateSet = new StateSet<>(newMap);
    newMap.values().forEach(x -> x.freeRef());
    return kStateSet;
  }

  /**
   * Subtract evalInputDelta setByCoord.
   *
   * @param right the right
   * @return the evalInputDelta setByCoord
   */
  @Nonnull
  public StateSet subtract(@Nonnull final DeltaSet right) {
    return this.add(right.scale(-1));
  }

  /**
   * Subtract evalInputDelta setByCoord.
   *
   * @param right the right
   * @return the evalInputDelta setByCoord
   */
  @Nonnull
  public DeltaSet subtract(@Nonnull final StateSet right) {
    @Nonnull DeltaSet rvec = right.asVector();
    @Nonnull DeltaSet scale = rvec.scale(-1);
    rvec.freeRef();
    @Nonnull StateSet add = this.add(scale);
    scale.freeRef();
    @Nonnull DeltaSet addVector = add.asVector();
    add.freeRef();
    return addVector;
  }


//  /**
//   * Union evalInputDelta setByCoord.
//   *
//   * @param right the right
//   * @return the evalInputDelta setByCoord
//   */
//  @Nonnull
//  public DoubleBufferSet> union(@Nonnull final DoubleBufferSet> right) {
//    return StateSet.union(this, right);
//  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy