com.simiacryptus.mindseye.lang.DoubleBufferSet Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mindseye-core Show documentation
Core Neural Networks Framework
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.lang;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.lang.ReferenceCountingBase;
import com.simiacryptus.ref.wrappers.*;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.Map;
import java.util.function.Supplier;
/**
 * Abstract keyed collection of {@code V} buffers, one per key {@code K}, stored in a
 * reference-counted {@link RefHashMap}. Buffers are created on demand through the subclass
 * hook {@link #factory}. Derived sets produced by {@code map(RefFunction)} are {@link Delegate}
 * instances that forward {@code factory} to the set they were derived from.
 *
 * <p>NOTE(review): the generic type-parameter lists (declaring {@code K}, {@code V} and the
 * nested class's {@code T}) are missing from the declarations as shown here, although the body
 * refers to those names throughout. Restore them from the upstream source before compiling.
 *
 * <p>Only the private {@code get(K, Supplier)} synchronizes (on {@code map}); the other
 * accessors do not. No thread-safety guarantee is implied for concurrent mutation.
 */
public abstract class DoubleBufferSet> extends ReferenceCountingBase {
/** Logger used for the debug message emitted when a new buffer is created. */
static final Logger log = LoggerFactory.getLogger(DoubleBufferSet.class);
/** Backing store from key to buffer; released in {@link #_free()}. */
@Nonnull
protected final RefHashMap map = new RefHashMap<>();
/** Creates an empty set. */
public DoubleBufferSet() {
}
/**
 * Creates a set holding the same key/buffer entries as {@code toCopy}. The buffer objects
 * themselves are shared, not duplicated: entries are inserted with {@code putAll}. Use
 * {@link #copy()} to also copy each buffer.
 *
 * @param toCopy source set; this constructor releases the caller's reference to it
 */
public DoubleBufferSet(@Nonnull final DoubleBufferSet toCopy) {
this(toCopy.getMap());
toCopy.freeRef();
}
/**
 * Creates a set initialised with the entries of {@code collect}.
 *
 * @param collect entries to insert via {@code putAll}
 */
// NOTE(review): collect is never released in this constructor. Confirm whether
// RefHashMap.putAll consumes the argument's reference; if it does not, the map passed in
// (e.g. the view from getMap() or the map built in map()) leaks.
public DoubleBufferSet(@Nonnull final RefMap collect) {
map.putAll(collect);
}
/**
 * Returns an unmodifiable view of the backing map.
 *
 * @return view wrapping an added reference to {@code map}; presumably the caller is
 *     responsible for releasing it
 */
@Nonnull
public RefMap getMap() {
return RefCollections.unmodifiableMap(map.addRef());
}
/**
 * Returns a new set with the same keys, whose buffers are {@code copy()}s of this set's buffers.
 * Each buffer reference handed to the mapper is released after it has been copied.
 *
 * @return a {@link Delegate} (built by {@code map(RefFunction)}) whose parent is this set
 */
@Nonnull
@SuppressWarnings("unchecked")
public DoubleBufferSet copy() {
// The (V) cast suggests DoubleBuffer.copy() is not declared to return V; hence the
// unchecked suppression on this method.
return map(x -> {
try {
return (V) x.copy();
} finally {
x.freeRef();
}
});
}
/**
 * Returns the buffer for {@code layer}. If none exists yet, one is created by
 * {@link #factory} over {@code ptr}.
 *
 * @param layer key; a null key raises IllegalArgumentException in the private overload
 * @param ptr   backing array passed to the factory when a new buffer is created
 * @return the existing or newly created buffer. The {@code @Nullable} annotation here is looser
 *     than the {@code @NotNull} private overload this delegates to.
 */
@javax.annotation.Nullable
public V get(final K layer, final double[] ptr) {
final V delta = get(layer, () -> factory(layer, ptr));
assert delta.key.equals(layer);
// With assertions enabled, this fails if an existing buffer for this key was created over
// a different array instance (identity comparison, not contents).
assert delta.target == ptr;
return delta;
}
/**
 * Looks up the buffer for {@code layer} without creating one.
 *
 * @param layer key to look up
 * @return the stored buffer, or {@code null} if there is no entry for {@code layer}
 */
@javax.annotation.Nullable
public V get(final K layer) {
final V delta = map.get(layer);
// delta is null on this path, so this returns null.
if (null == delta) return delta;
assert delta.key.equals(layer);
return delta;
}
/**
 * Returns the buffer for {@code layer}, creating it over {@code tensor.getData()} if absent,
 * then releases the caller's reference to {@code tensor}.
 *
 * @param layer  key
 * @param tensor supplier of the backing array; released by this method
 * @return the existing or newly created buffer
 */
// NOTE(review): a newly created buffer keeps targeting the array returned by getData() after
// the tensor reference is released. Confirm Tensor.freeRef does not recycle that array.
@javax.annotation.Nullable
public V get(final K layer, @Nonnull final Tensor tensor) {
V delta = get(layer, tensor.getData());
tensor.freeRef();
return delta;
}
/**
 * Builds a new set with the same keys, whose buffers are {@code mapper.apply(value)}.
 * The entry stream is made parallel when there are more than 100 entries, so {@code mapper}
 * may be invoked concurrently in that case.
 *
 * @param mapper per-buffer transform
 * @return a {@link Delegate} holding an added reference to this set as its parent
 */
@Nonnull
public DoubleBufferSet map(@Nonnull final RefFunction mapper) {
RefHashSet> entries = map.entrySet();
try {
RefStream> stream = entries.stream();
if (map.size() > 100) {
stream = stream.parallel();
}
// Each extractor below releases the entry reference it receives. Presumably
// RefCollectors.toMap hands each function its own reference; confirm.
final RefMap newMap = stream.collect(RefCollectors.toMap(e -> {
try {
return e.getKey();
} finally {
RefUtil.freeRef(e);
}
}, e -> {
try {
return mapper.apply(e.getValue());
} finally {
RefUtil.freeRef(e);
}
}));
return new Delegate(this.addRef(), newMap);
} finally {
entries.freeRef();
}
}
/**
 * Returns a stream of the distinct, non-null buffers in this set.
 *
 * @return stream over the backing map's values
 */
@Nonnull
public RefStream stream() {
RefHashSet values = map.values();
RefStream stream = values.stream().filter(v -> {
// The predicate releases the reference it was given for non-null values.
if (null != v) {
v.freeRef();
return true;
} else {
return false;
}
})
.distinct()
// Disabled: ordering by identity hash of each buffer's target array.
// .sorted(RefComparator.comparingInt(v -> {
// int hashCode = RefSystem.identityHashCode(v.target);
// v.freeRef();
// return hashCode;
// }))
;
// NOTE(review): values is released before the returned stream is consumed. Presumably the
// RefStream holds its own reference to the underlying collection; confirm.
values.freeRef();
return stream;
}
/** Runs the base-class cleanup, then releases the backing map. */
public void _free() {
super._free();
map.freeRef();
}
/** Adds a reference and returns this set with its concrete type. */
@Nonnull
public @Override
@SuppressWarnings("unused")
DoubleBufferSet addRef() {
return (DoubleBufferSet) super.addRef();
}
/** @return number of entries in the backing map */
public int size() {
return map.size();
}
/**
 * Returns a new set in which every value that is not finite (NaN or +/-Infinity, per
 * {@link Double#isFinite}) is replaced with {@code defaultValue}. The replacement is applied to
 * each value passed through {@code DoubleBuffer.map}; presumably this is element-wise.
 *
 * @param defaultValue replacement for non-finite values
 * @return a {@link Delegate} produced by {@code map(RefFunction)}; each source buffer
 *     reference handed to the mapper is released
 */
public DoubleBufferSet allFinite(double defaultValue) {
return map((V doubleBuffer) -> {
V v = (V) doubleBuffer.map(d -> Double.isFinite(d) ? d : defaultValue);
doubleBuffer.freeRef();
return v;
});
}
/** @return the key set of the backing map, exactly as returned by {@code RefHashMap.keySet()} */
public RefSet keySet() {
return map.keySet();
}
/**
 * Subclass hook that creates the buffer for {@code layer} backed by {@code target}.
 * Called from {@code get(K, double[])} when no buffer exists yet for the key.
 *
 * @param layer  key of the new buffer
 * @param target backing array for the new buffer
 * @return the new buffer; must be non-null (asserted by the caller)
 */
protected abstract V factory(final K layer, final double[] target);
/**
 * Returns the buffer for {@code layer}. If absent, it is created via {@code factory} inside
 * {@code map.computeIfAbsent}, while holding the monitor of {@code map}. The
 * {@code factory} reference is released before returning.
 *
 * @throws IllegalArgumentException if {@code factory} or {@code layer} is null
 */
@NotNull
private V get(@Nullable final K layer, @Nullable final Supplier factory) {
if (null == factory)
throw new IllegalArgumentException();
if (null == layer)
throw new IllegalArgumentException();
try {
synchronized (map) {
return map.computeIfAbsent(layer, l -> {
// NOTE(review): l is released here, but l.getClass() is still read in the log call
// below. getClass() works on the Java object regardless, but moving this freeRef after
// the log statement would avoid touching a released reference.
RefUtil.freeRef(l);
V delta = factory.get();
assert null != delta;
// Logs the key's class (not the key itself) and the new buffer's array length.
if (log.isDebugEnabled())
log.debug(RefString.format("Init key buffer for %s - %s params", l.getClass(), delta.target.length));
return delta;
});
}
} finally {
RefUtil.freeRef(factory);
}
}
/**
 * Set derived from another set (its parent), as produced by {@code map(RefFunction)}.
 * Buffer creation is forwarded to the parent's {@code factory}. The parent reference is
 * released when this set is freed.
 */
protected static class Delegate> extends DoubleBufferSet {
/** Set whose {@code factory} creates new buffers; {@code factory()} requires it to be non-null. */
@Nullable
private final DoubleBufferSet parent;
/**
 * Creates an empty delegate over {@code parent}.
 *
 * @param parent set supplying {@code factory}; ownership of this reference is taken
 */
// NOTE(review): the temporary RefHashMap passed up is not released here; this is the same
// putAll ownership question as in the base-class constructor.
public Delegate(final DoubleBufferSet parent) {
this(parent, new RefHashMap<>());
}
/**
 * Creates a delegate initialised with the entries of {@code newMap}.
 *
 * @param parent set supplying {@code factory}; stored, and released in {@link #_free()}
 * @param newMap initial entries, inserted via the base-class constructor
 */
public Delegate(@Nullable final DoubleBufferSet parent, @Nonnull final RefMap newMap) {
super(newMap);
this.parent = parent;
}
/** Releases the parent (if any), then the base-class resources. */
public @SuppressWarnings("unused")
void _free() {
if (null != parent)
parent.freeRef();
super._free();
}
/** Adds a reference and returns this delegate with its concrete type. */
@Nonnull
public @Override
@SuppressWarnings("unused")
Delegate addRef() {
return (Delegate) super.addRef();
}
/**
 * Forwards buffer creation to the parent's {@code factory}. A null parent fails the assertion
 * when assertions are enabled; otherwise it throws NullPointerException.
 */
@Override
protected T factory(final K layer, final double[] target) {
assert parent != null;
return parent.factory(layer, target);
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy