com.simiacryptus.mindseye.lang.DeltaSet Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mindseye-core Show documentation
Show all versions of mindseye-core Show documentation
Core Neural Networks Framework
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.lang;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.*;
import javax.annotation.Nonnull;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.ToDoubleFunction;
public class DeltaSet extends DoubleBufferSet> {
public DeltaSet() {
}
public DeltaSet(@Nonnull final DoubleBufferSet> toCopy) {
super(toCopy);
// assert stream().allMatch(x -> {
// try {
// return x instanceof Delta;
// } finally {
// if (null != x) x.freeRef();
// }
// });
}
public DeltaSet(@Nonnull final RefMap> collect) {
super(collect);
// assert stream().allMatch(x -> {
// boolean temp_37_0002 = x instanceof Delta;
// if (null != x)
// x.freeRef();
// return temp_37_0002;
// });
}
public double getMagnitude() {
RefHashSet>> temp_37_0011 = map.entrySet();
RefStream>> stream = temp_37_0011.stream();
if (100 < map.size()) {
stream = stream.parallel();
}
final double[] elementArray = stream.mapToDouble(entry -> {
final DoubleBuffer value = entry.getValue();
RefUtil.freeRef(entry);
double temp_37_0003 = value.deltaStatistics().sumSq();
value.freeRef();
return temp_37_0003;
}).toArray();
temp_37_0011.freeRef();
return Math.sqrt(RefArrays.stream(elementArray).sum());
}
public void accumulate(final double alpha) {
stream().forEach((Delta d) -> {
d.accumulate(alpha);
d.freeRef();
});
}
public DeltaSet add(@Nonnull final DeltaSet right) {
DeltaSet copy = this.copy();
copy.addInPlace(right);
return copy;
}
public void addInPlace(@Nonnull DeltaSet right) {
right.map.forEach((layer, buffer) -> {
Delta temp_37_0013 = get(layer, buffer.target);
assert temp_37_0013 != null;
temp_37_0013.addInPlace(buffer);
temp_37_0013.freeRef();
});
right.freeRef();
}
@Nonnull
public StateSet asState() {
@Nonnull final StateSet returnValue = new StateSet<>();
map.forEach(RefUtil.wrapInterface((BiConsumer super K, ? super Delta>) (layer, delta) -> {
delta.assertAlive();
State kState = returnValue.get(layer, delta.target);
assert kState != null;
kState.set(delta.delta);
kState.freeRef();
delta.freeRef();
}, returnValue.addRef()));
return returnValue;
}
@Nonnull
@Override
public DeltaSet copy() {
return new DeltaSet(this.map(x -> {
try {
return x.copy();
} finally {
x.freeRef();
}
}));
}
public double dot(@Nonnull final DoubleBufferSet> right) {
RefHashSet>> entries = map.entrySet();
RefStream>> stream = entries.stream();
if (100 < map.size()) {
stream = stream.parallel();
}
double temp_37_0010 = stream
.mapToDouble(RefUtil.wrapInterface((ToDoubleFunction super Map.Entry>>) entry -> {
final K key = entry.getKey();
final Delta value = entry.getValue();
RefUtil.freeRef(entry);
final Delta rValue = right.map.get(key);
if (null != rValue) {
double temp_37_0005 = value.dot(rValue);
if (null != value) value.freeRef();
return temp_37_0005;
} else {
if (null != value) value.freeRef();
return 0;
}
}, right)).sum();
entries.freeRef();
return temp_37_0010;
}
@Nonnull
@Override
public DeltaSet map(@Nonnull final RefFunction, Delta> mapper) {
return new DeltaSet<>(super.map(mapper));
}
@Nonnull
public DeltaSet scale(final double f) {
return map(x -> {
Delta temp_37_0007 = x.scale(f);
x.freeRef();
return temp_37_0007;
});
}
@Nonnull
public DeltaSet subtract(@Nonnull final DeltaSet right) {
DeltaSet scale = right.scale(-1);
right.freeRef();
return this.add(scale);
}
@Nonnull
public DeltaSet unit() {
return scale(1.0 / getMagnitude());
}
public @SuppressWarnings("unused")
void _free() {
super._free();
}
@Nonnull
public @Override
@SuppressWarnings("unused")
DeltaSet addRef() {
return (DeltaSet) super.addRef();
}
public DeltaSet allFinite(double defaultValue) {
return map((Delta delta) -> {
Delta map = delta.map(d -> Double.isFinite(d) ? d : defaultValue);
delta.freeRef();
return map;
});
}
@Nonnull
@Override
protected Delta factory(@Nonnull final K layer, final double[] target) {
return new Delta(layer, target);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy