com.simiacryptus.mindseye.lang.StateSet Maven / Gradle / Ivy
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.lang;
import com.simiacryptus.ref.lang.RecycleBin;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.wrappers.*;
import javax.annotation.Nonnull;
import java.util.Map;
import java.util.function.BiConsumer;
public class StateSet extends DoubleBufferSet> {
public StateSet() {
}
public StateSet(@Nonnull final DeltaSet toCopy) {
// assert toCopy.stream().allMatch(x -> {
// boolean temp_41_0001 = RefArrays.stream(x.getDelta()).allMatch(Double::isFinite);
// x.freeRef();
// return temp_41_0001;
// });
RefMap> temp_41_0018 = toCopy.getMap();
temp_41_0018.forEach((layer, layerDelta) -> {
State state = get(layer, layerDelta.target);
assert state != null;
state.backup();
RefUtil.freeRef(state);
layerDelta.freeRef();
});
temp_41_0018.freeRef();
toCopy.freeRef();
// assert stream().allMatch(x -> {
// boolean temp_41_0002 = Arrays.stream(x.getDelta()).allMatch(Double::isFinite);
// x.freeRef();
// return temp_41_0002;
// });
// assert stream().allMatch(x -> {
// boolean temp_41_0003 = x instanceof State;
// if (null != x)
// x.freeRef();
// return temp_41_0003;
// });
}
public StateSet(@Nonnull final DoubleBufferSet> toCopy) {
super(toCopy);
// assert stream().allMatch(x -> {
// boolean temp_41_0004 = x instanceof State;
// if (null != x)
// x.freeRef();
// return temp_41_0004;
// });
}
public StateSet(@Nonnull final RefMap> collect) {
super(collect);
}
public boolean isDifferent() {
return stream().parallel().anyMatch(x -> {
boolean temp_41_0005 = !x.areEqual();
x.freeRef();
return temp_41_0005;
});
}
@Nonnull
public static StateSet union(@Nonnull final DoubleBufferSet> left,
@Nonnull final DoubleBufferSet> right) {
RefHashSet>> temp_41_0020 = left.map.entrySet();
left.freeRef();
RefHashSet>> temp_41_0021 = right.map.entrySet();
right.freeRef();
final RefMap> collect = RefStream.concat(temp_41_0020.stream(), temp_41_0021.stream())
.collect(RefCollectors.groupingBy((final Map.Entry> entry) -> {
K key = entry.getKey();
RefUtil.freeRef(entry);
return key;
}, RefCollectors.mapping(
(final Map.Entry> entry) -> {
State value = entry.getValue();
RefUtil.freeRef(entry);
return value;
},
RefCollectors.collectingAndThen(
RefCollectors.reducing((@Nonnull final State a, @Nonnull final State b) -> {
assert a.target == b.target;
assert a.key.equals(b.key);
b.freeRef();
return a;
}),
optional -> RefUtil.get(optional)))));
temp_41_0021.freeRef();
temp_41_0020.freeRef();
return new StateSet(collect);
}
public boolean containsAll(RefMap> deltaMap) {
RefSet keySet = deltaMap.keySet();
RefMap> weightsMap = getMap();
try {
return keySet.stream().allMatch(x -> weightsMap.containsKey(x));
} finally {
weightsMap.freeRef();
keySet.freeRef();
deltaMap.freeRef();
}
}
@Nonnull
public StateSet add(@Nonnull final DeltaSet right) {
@Nonnull final DeltaSet deltas = new DeltaSet();
map.forEach(
RefUtil.wrapInterface((BiConsumer>) (@Nonnull final K layer, @Nonnull final State buffer) -> {
Delta temp_41_0022 = deltas.get(layer, buffer.target);
assert temp_41_0022 != null;
temp_41_0022.set(buffer.getDelta());
temp_41_0022.freeRef();
buffer.freeRef();
}, deltas.addRef()));
right.map.forEach(
RefUtil.wrapInterface((BiConsumer>) (@Nonnull final K layer, @Nonnull final Delta buffer) -> {
Delta temp_41_0023 = deltas.get(layer, buffer.target);
assert temp_41_0023 != null;
temp_41_0023.addInPlace(buffer.getDelta());
temp_41_0023.freeRef();
buffer.freeRef();
}, deltas.addRef()));
right.freeRef();
StateSet temp_41_0008 = deltas.asState();
deltas.freeRef();
return temp_41_0008;
}
@Nonnull
public DeltaSet asVector() {
@Nonnull final RefHashMap> newMap = new RefHashMap<>();
map.forEach(RefUtil.wrapInterface((BiConsumer super K, ? super State>) (layer, state) -> {
RefUtil.freeRef(newMap.put(RefUtil.addRef(layer),
new Delta(layer, state.target, RecycleBin.DOUBLES.copyOf(state.delta, state.delta.length))));
state.freeRef();
}, RefUtil.addRef(newMap)));
return new DeltaSet<>(newMap);
}
@Nonnull
@Override
public StateSet copy() {
return map(x -> {
State temp_41_0010 = x.copy();
x.freeRef();
return temp_41_0010;
});
}
// /**
// * Union evalInputDelta setByCoord.
// *
// * @param right the right
// * @return the evalInputDelta setByCoord
// */
// @Nonnull
// public DoubleBufferSet> union(@Nonnull final DoubleBufferSet> right) {
// return StateSet.union(this, right);
// }
public void restore() {
RefHashSet>> temp_41_0024 = map.entrySet();
RefStream>> stream = temp_41_0024.stream();
if (map.size() > 100) {
stream = stream.parallel();
}
stream.forEach(e -> {
State temp_41_0025 = e.getValue();
temp_41_0025.restore();
temp_41_0025.freeRef();
RefUtil.freeRef(e);
});
temp_41_0024.freeRef();
}
@Nonnull
@Override
public StateSet map(@Nonnull final RefFunction, State> mapper) {
RefHashSet>> temp_41_0026 = map.entrySet();
RefStream>> stream = temp_41_0026.stream();
if (map.size() > 100) {
stream = stream.parallel();
}
final RefMap> newMap = stream.collect(RefCollectors.toMap(e -> {
K temp_41_0011 = e.getKey();
RefUtil.freeRef(e);
return temp_41_0011;
}, e -> {
State temp_41_0012 = mapper.apply(e.getValue());
RefUtil.freeRef(e);
return temp_41_0012;
}));
temp_41_0026.freeRef();
return new StateSet<>(newMap);
}
@Nonnull
public StateSet subtract(@Nonnull final DeltaSet right) {
StateSet temp_41_0017 = this.add(right.scale(-1));
right.freeRef();
return temp_41_0017;
}
@Nonnull
public DeltaSet subtract(@Nonnull final StateSet right) {
@Nonnull
DeltaSet rvec = right.asVector();
right.freeRef();
@Nonnull
DeltaSet scale = rvec.scale(-1);
rvec.freeRef();
@Nonnull
StateSet add = this.add(scale);
DeltaSet temp_41_0014 = add.asVector();
add.freeRef();
return temp_41_0014;
}
public @SuppressWarnings("unused")
void _free() {
super._free();
}
@Nonnull
public @Override
@SuppressWarnings("unused")
StateSet addRef() {
return (StateSet) super.addRef();
}
@Nonnull
@Override
protected State factory(@Nonnull final K layer, final double[] target) {
return new State(layer, target);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy