com.simiacryptus.mindseye.lang.Tensor Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mindseye-core Show documentation
Core Neural Networks Framework
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.lang;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
import com.simiacryptus.lang.SerializableFunction;
import com.simiacryptus.ref.lang.RecycleBin;
import com.simiacryptus.ref.lang.RefAware;
import com.simiacryptus.ref.lang.RefUtil;
import com.simiacryptus.ref.lang.ReferenceCountingBase;
import com.simiacryptus.ref.wrappers.*;
import com.simiacryptus.util.FastRandom;
import com.simiacryptus.util.data.DoubleStatistics;
import com.simiacryptus.util.data.ScalarStatistics;
import org.jetbrains.annotations.NotNull;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.awt.image.BufferedImage;
import java.io.Serializable;
import java.util.*;
import java.util.function.*;
import java.util.stream.DoubleStream;
@SuppressWarnings("serial")
public final class Tensor extends ReferenceCountingBase implements Serializable, ZipSerializable {
/**
 * Serializer used by {@code setBytes(byte[])} when no explicit precision is supplied.
 */
@Nonnull
public static final DataSerializer json_precision = SerialPrecision.Float;
/**
 * Size of each axis, in order. Left null by the private no-argument constructor.
 */
@Nullable
protected final int[] dimensions;
/**
 * Per-axis index step. Most constructors derive it with getSkips (axis 0 has stride 1 and each
 * later axis multiplies in the size of the previous one); the private three-argument
 * constructor stores whatever it is given.
 */
@Nullable
protected final int[] strides;
/**
 * Element storage. May be null until first accessed through getData(), which allocates it.
 */
@Nullable
protected volatile double[] data;
/**
 * Identifier returned by getId(); a random UUID is assigned lazily while this is null.
 */
@Nullable
protected volatile UUID id;
/**
 * Creates a bare placeholder with no dimensions, strides or data.
 */
private Tensor() {
  super();
  dimensions = null;
  strides = null;
  data = null;
}
/**
 * Creates a rank-1 tensor holding a copy of the given values.
 *
 * @param ds the values; the single axis has length {@code ds.length}
 */
public Tensor(@Nonnull final double... ds) {
  this(ds, new int[]{ds.length});
}
public Tensor(@Nullable final double[] data, @Nonnull final int... dims) {
if (null != data && 0 >= data.length)
throw new IllegalArgumentException();
Tensor.length(dims);
if (Tensor.length(dims) <= 0)
throw new IllegalArgumentException();
if (null != data && Tensor.length(dims) != data.length)
throw new IllegalArgumentException(RefArrays.toString(dims) + " != " + data.length);
dimensions = 0 == dims.length ? new int[]{} : RefArrays.copyOf(dims, dims.length);
strides = Tensor.getSkips(dims);
//this.data = data;// Arrays.copyOf(data, data.length);
if (null != data) {
this.data = RecycleBin.DOUBLES.copyOf(data, data.length);
}
assert isValid();
//assert (null == data || Tensor.length(dims) == data.length);
}
/**
 * Adopts {@code data} directly (no copy), deriving strides from {@code dims}.
 *
 * @param dims size of each axis; stored as given
 * @param data backing storage, or null for lazy allocation
 */
private Tensor(@Nonnull int[] dims, @Nullable double[] data) {
  this(dims, getSkips(dims), data);
}
private Tensor(@Nullable int[] dimensions,
@Nullable int[] strides, @Nullable double[] data) {
assert dimensions != null;
if (Tensor.length(dimensions) >= Integer.MAX_VALUE)
throw new IllegalArgumentException();
assert null == data || data.length == Tensor.length(dimensions);
this.dimensions = dimensions;
this.strides = strides;
this.data = data;
assert isValid();
}
/**
 * Creates a tensor of the given shape from single-precision values.
 * <p>
 * Each value is widened to double. Non-finite inputs (NaN, infinities) are replaced by 0.
 * When {@code data} is null, storage is allocated lazily by {@link #getData()}.
 *
 * @param data initial values in index order, or {@code null}
 * @param dims size of each axis
 * @throws IllegalArgumentException if the shape is too large, or if {@code data.length}
 *                                  differs from the shape's element count
 */
public Tensor(@Nullable final float[] data, @Nonnull final int... dims) {
  final int length = Tensor.length(dims);
  if (length >= Integer.MAX_VALUE)
    throw new IllegalArgumentException();
  // Validate explicitly, as the double[] constructor does: `assert isValid()` alone lets an
  // inconsistent tensor through when assertions are disabled.
  if (null != data && length != data.length)
    throw new IllegalArgumentException(RefArrays.toString(dims) + " != " + data.length);
  dimensions = RefArrays.copyOf(dims, dims.length);
  strides = Tensor.getSkips(dims);
  if (null != data) {
    final double[] buffer = RecycleBin.DOUBLES.obtain(data.length);
    RefArrays.parallelSetAll(buffer, i -> {
      final double v = data[i];
      return Double.isFinite(v) ? v : 0;
    });
    this.data = buffer;
    assert RefArrays.stream(this.data).allMatch(Double::isFinite);
  }
  assert isValid();
}
/**
 * Creates a tensor with the given shape. Storage is allocated lazily by {@link #getData()}.
 *
 * @param dims size of each axis
 * @throws IllegalArgumentException if the shape contains no elements
 */
public Tensor(@Nonnull final int... dims) {
  this((double[]) null, dims);
  assert 0 < dims.length;
}
/**
 * Returns the live backing array of this tensor (not a copy), allocating it on first use.
 * <p>
 * Allocation uses double-checked locking on the volatile {@code data} field, so concurrent
 * first calls agree on a single array. The array is taken from
 * {@code RecycleBin.DOUBLES.obtain(length)}.
 * NOTE(review): whether a recycled array is zero-filled depends on RecycleBin; confirm
 * before relying on a fresh tensor being all zeros.
 *
 * @return the backing array; its length equals {@code Tensor.length(dimensions)}
 */
@Nonnull
public double[] getData() {
assertAlive();
if (null == data) {
synchronized (this) {
// Re-check under the lock: another thread may have allocated while we waited.
if (null == data) {
assert dimensions != null;
final int length = Tensor.length(dimensions);
data = RecycleBin.DOUBLES.obtain(length);
assert null != data;
assert length == data.length;
}
}
}
assert isValid();
assert null != data;
return data;
}
/**
 * Returns a newly allocated single-precision copy of this tensor's data.
 *
 * @return a float array of the same length as {@link #getData()}
 */
@Nonnull
public float[] getDataAsFloats() {
  final double[] source = getData();
  return Tensor.toFloats(source);
}
/**
 * Returns a defensive copy of this tensor's shape.
 *
 * @return a fresh array with the size of each axis
 */
@Nonnull
public final int[] getDimensions() {
  final int[] dims = this.dimensions;
  assert dims != null;
  return RefArrays.copyOf(dims, dims.length);
}
/**
 * Builds a {@link DoubleStatistics} summary over every element of this tensor.
 *
 * @return the statistics object after accepting all values
 */
@NotNull
public DoubleStatistics getDoubleStatistics() {
  final DoubleStatistics statistics = new DoubleStatistics();
  return statistics.accept(getData());
}
/**
 * Returns this tensor's identifier, assigning a random UUID on first use.
 * <p>
 * NOTE(review): annotated {@code @Nullable}, and in single-threaded use this never returns
 * null. A concurrent {@code setId(null)} between the locked assignment and the final read
 * could still make it return null.
 *
 * @return the identifier
 */
@Nullable
public UUID getId() {
if (id == null) {
synchronized (this) {
// Re-check under the lock so only one random id is ever assigned.
if (id == null) {
id = UUID.randomUUID();
}
}
}
return id;
}
/**
 * Overrides the identifier reported by {@link #getId()}.
 *
 * @param id the new identifier; if null, {@link #getId()} will generate a fresh random one
 */
public void setId(@Nullable UUID id) {
  this.id = id;
}
/**
 * Streams the band values of every (x, y) position, each produced by
 * {@code getPixel(x, y, bands)}.
 * <p>
 * Reads {@code dimensions[0..2]} as width, height and band count, so it throws
 * ArrayIndexOutOfBoundsException for tensors with fewer than three dimensions.
 * The outer x range is made parallel.
 * NOTE(review): generic type arguments appear to have been stripped from this listing
 * (see the {@code Function} cast); presumably the return type is {@code RefStream<double[]>}.
 * Confirm against the real source.
 */
@Nonnull
public RefStream getPixelStream() {
int[] dimensions = getDimensions();
int width = dimensions[0];
int height = dimensions[1];
int bands = dimensions[2];
return RefIntStream.range(0, width).mapToObj(x -> x).parallel()
.flatMap((Function>) x1 -> {
return RefIntStream.range(0, height).mapToObj(y -> y).map(y -> {
return this.getPixel(x1, y, bands);
});
});
}
/**
 * Builds a {@link ScalarStatistics} summary over every element of this tensor.
 *
 * @return the statistics object after adding all values
 */
public ScalarStatistics getScalarStatistics() {
  final ScalarStatistics statistics = new ScalarStatistics();
  return statistics.add(getData());
}
/**
 * Checks that this tensor has not been freed and that any allocated storage matches its shape.
 *
 * @return true if not freed and data is either unallocated or of the expected length
 */
public boolean isValid() {
  assert dimensions != null;
  if (isFreed()) {
    return false;
  }
  final double[] current = this.data;
  return current == null || current.length == Tensor.length(dimensions);
}
/**
 * Sets every element of this tensor to {@code v}.
 *
 * @param v the value to write
 */
public void setAll(double v) {
  Arrays.fill(getData(), v);
}
/**
 * Assigns elements from {@code f}, delegating to the two-argument overload with {@code true}.
 * The meaning of that flag is defined by the overload, which is not shown here.
 *
 * @param f per-coordinate value function
 */
public void setByCoord(@RefAware @Nonnull ToDoubleFunction f) {
  this.setByCoord(f, true);
}
/**
 * Loads this tensor's values from {@code bytes}, interpreted with the default serializer
 * {@link #json_precision}.
 *
 * @param bytes encoded element data
 */
public void setBytes(byte[] bytes) {
  final DataSerializer precision = json_precision;
  setBytes(bytes, precision);
}
/**
 * Sets element {@code i} to {@code f(i)} for every flat index, evaluating in parallel.
 * {@code f} must therefore tolerate concurrent invocation.
 *
 * @param f maps a flat index to the value stored there
 */
public void setParallelByIndex(@Nonnull final IntToDoubleFunction f) {
  final int n = length();
  RefIntStream.range(0, n).parallel().forEach(index -> set(index, f.applyAsDouble(index)));
}
/**
 * Deserializes a tensor from JSON.
 * <p>
 * Accepted forms:
 * <ul>
 * <li>{@code null} returns {@code null}.</li>
 * <li>An array whose first element is a primitive becomes a rank-1 tensor of those numbers.</li>
 * <li>An array of non-primitives: each element is decoded recursively, all must share the
 * same dimensions, and they are stacked along a new last axis of size {@code array.size()}.</li>
 * <li>An object with "length" (dimensions), "precision" (a SerialPrecision name), and either
 * "base64" (inline bytes) or "resource" (a key into {@code resources}), plus an optional "id".</li>
 * <li>Any other value is read as a single number.</li>
 * </ul>
 * NOTE(review): an empty JSON array fails at {@code array.get(0)}. A missing "precision" or
 * "length" field throws NullPointerException. An unknown resource key passes null to setBytes.
 *
 * @param json      the JSON to decode, may be null
 * @param resources external byte blobs keyed by resource id, required when "base64" is absent
 * @return the decoded tensor, or null when {@code json} is null
 * @throws IllegalArgumentException if nested tensors differ in shape, or resources are needed but null
 */
@Nullable
@SuppressWarnings("unused")
public static Tensor fromJson(@Nullable final JsonElement json, @Nullable Map resources) {
if (null == json)
return null;
if (json.isJsonArray()) {
final JsonArray array = json.getAsJsonArray();
final int size = array.size();
if (array.get(0).isJsonPrimitive()) {
final double[] doubles = RefIntStream.range(0, size).mapToObj(array::get).mapToDouble(JsonElement::getAsDouble).toArray();
@Nonnull
Tensor tensor = new Tensor(doubles);
assert tensor.isValid();
return tensor;
} else {
final RefList elements = RefIntStream.range(0, size).mapToObj(array::get).map(element -> {
return Tensor.fromJson(element, resources);
}).collect(RefCollectors.toList());
Tensor temp_33_0010 = elements.get(0);
@Nonnull final int[] dimensions = temp_33_0010.getDimensions();
temp_33_0010.freeRef();
// NOTE(review): allMatch short-circuits, so after the first mismatch the remaining
// elements are not visited (and not freeRef'd) here. Confirm RefStream accounts for this.
if (!elements.stream().allMatch(t -> {
boolean temp_33_0001 = RefArrays.equals(dimensions, t.getDimensions());
t.freeRef();
return temp_33_0001;
})) {
elements.freeRef();
throw new IllegalArgumentException();
}
// Result shape = element shape + [size]; element i fills the contiguous block at offset i * e.length.
@Nonnull final int[] newDdimensions = RefArrays.copyOf(dimensions, dimensions.length + 1);
newDdimensions[dimensions.length] = size;
@Nonnull final Tensor tensor = new Tensor(newDdimensions);
@Nullable final double[] data = tensor.getData();
for (int i = 0; i < size; i++) {
Tensor temp_33_0011 = elements.get(i);
@Nullable final double[] e = temp_33_0011.getData();
// NOTE(review): `e` is read after this freeRef; safe only while `elements` still holds a reference.
temp_33_0011.freeRef();
RefSystem.arraycopy(e, 0, data, i * e.length, e.length);
}
elements.freeRef();
assert tensor.isValid();
return tensor;
}
} else if (json.isJsonObject()) {
JsonObject jsonObject = json.getAsJsonObject();
@Nonnull
int[] dims = fromJsonArray(jsonObject.getAsJsonArray("length"));
@Nonnull
Tensor tensor = new Tensor(dims);
SerialPrecision precision = SerialPrecision.valueOf(jsonObject.getAsJsonPrimitive("precision").getAsString());
JsonElement base64 = jsonObject.get("base64");
if (null == base64) {
// Data lives outside the JSON document; look it up by the "resource" key.
if (null == resources) {
tensor.freeRef();
throw new IllegalArgumentException("No Data Resources");
}
CharSequence resourceId = jsonObject.getAsJsonPrimitive("resource").getAsString();
tensor.setBytes(resources.get(resourceId), precision);
} else {
tensor.setBytes(Base64.getDecoder().decode(base64.getAsString()), precision);
}
assert tensor.isValid();
JsonElement id = jsonObject.get("id");
if (null != id) {
tensor.setId(UUID.fromString(id.getAsString()));
}
return tensor;
} else {
@Nonnull
Tensor tensor = new Tensor(json.getAsJsonPrimitive().getAsDouble());
assert tensor.isValid();
return tensor;
}
}
/**
 * Computes the number of elements in a tensor with the given shape.
 * <p>
 * Bounds are enforced even when assertions are disabled. Previously a negative dimension, or a
 * product of {@link Integer#MAX_VALUE} or more, silently produced a wrong (possibly wrapped)
 * count. That defeated callers' {@code >= Integer.MAX_VALUE} guards.
 *
 * @param dims size of each axis; {@code null} or an empty array yields 0
 * @return the product of all dimensions
 * @throws IllegalArgumentException if a dimension is negative or the product is not below
 *                                  {@link Integer#MAX_VALUE}
 */
public static int length(@Nonnull int... dims) {
  if (null == dims || 0 == dims.length) return 0;
  long total = 1;
  for (final int dim : dims) {
    assert 0 <= dim : Arrays.toString(dims);
    if (dim < 0) {
      throw new IllegalArgumentException("Negative dimension: " + Arrays.toString(dims));
    }
    total *= dim;
    assert total < Integer.MAX_VALUE : Arrays.toString(dims);
    // Checking every step keeps `total` far below Long.MAX_VALUE, so the long never overflows.
    if (total >= Integer.MAX_VALUE) {
      throw new IllegalArgumentException("Too many elements: " + Arrays.toString(dims));
    }
  }
  return (int) total;
}
/**
 * Converts an image into a (width, height, 3) tensor of 8-bit channel values.
 * <p>
 * Band b holds bits {@code 8*b .. 8*b+7} of {@code img.getRGB(x, y)}. With the default
 * ARGB packing used by {@code getRGB}, band 0 is blue, band 1 green and band 2 red.
 * Columns are processed in parallel.
 *
 * @param img the source image
 * @return a new tensor with the channel values
 */
@Nonnull
public static Tensor fromRGB(@Nonnull final BufferedImage img) {
  final int width = img.getWidth();
  final int height = img.getHeight();
  @Nonnull final Tensor a = new Tensor(width, height, 3);
  RefIntStream.range(0, width).parallel().forEach(x -> {
    // One scratch coordinate array per column task, so parallel columns never share it.
    @Nonnull final int[] coords = new int[3];
    coords[0] = x;
    for (int y = 0; y < height; y++) {
      coords[1] = y;
      final int rgb = img.getRGB(x, y);
      for (int band = 0; band < 3; band++) {
        coords[2] = band;
        a.set(coords, rgb >> (8 * band) & 0xFF);
      }
    }
  });
  return a;
}
/**
 * Drains {@code stream} into an array of length {@code dim} obtained from {@code RecycleBin.DOUBLES}.
 * <p>
 * Values are written in delivery order. More than {@code dim} values causes an
 * ArrayIndexOutOfBoundsException. With fewer, the tail keeps whatever the obtained array held.
 * The write cursor is unsynchronized, so {@code stream} should be sequential.
 *
 * @param stream source of values
 * @param dim    length of the returned array
 * @return the filled array
 */
public static double[] getDoubles(@Nonnull final RefDoubleStream stream, final int dim) {
  final double[] doubles = RecycleBin.DOUBLES.obtain(dim);
  final int[] cursor = {0};
  stream.forEach(value -> doubles[cursor[0]++] = value);
  return doubles;
}
/**
 * Multiplies two tensors element-wise, broadcasting a single-element operand.
 * <p>
 * If only {@code left} has a single element, the operands are swapped so the scalar is on the
 * right. The result takes {@code left}'s dimensions (after any swap). Both input references
 * are released, including when an assertion or index error is thrown.
 *
 * @param left  first operand; its reference is consumed
 * @param right second operand, with the same length as {@code left} or a single element;
 *              its reference is consumed
 * @return a new tensor of products
 */
@Nonnull
public static Tensor product(@Nonnull final Tensor left, @Nonnull final Tensor right) {
  if (left.length() == 1 && right.length() != 1) {
    return Tensor.product(right, left);
  }
  try {
    assert left.length() == right.length() || 1 == right.length();
    @Nonnull final Tensor result = new Tensor(left.getDimensions());
    @Nullable final double[] resultData = result.getData();
    // Keep both references alive until the loop finishes. Releasing them before reading their
    // arrays risks reading storage that was reclaimed (e.g. returned to a RecycleBin) once a
    // count reaches zero.
    @Nullable final double[] leftData = left.getData();
    @Nullable final double[] rightData = right.getData();
    final boolean scalarRight = 1 == rightData.length;
    for (int i = 0; i < resultData.length; i++) {
      resultData[i] = leftData[i] * rightData[scalarRight ? 0 : i];
    }
    return result;
  } finally {
    left.freeRef();
    right.freeRef();
  }
}
/**
 * Narrows each double to float into a newly allocated array.
 *
 * @param data the source values
 * @return a float array of the same length
 */
@Nonnull
public static float[] toFloats(@Nonnull final double[] data) {
  final float[] floats = new float[data.length];
  for (int i = 0; i < floats.length; i++) {
    floats[i] = (float) data[i];
  }
  return floats;
}
/**
 * Narrows every element of {@code src} into the leading positions of {@code buffer}.
 *
 * @param src    source values
 * @param buffer destination; must be at least {@code src.length} long
 * @return {@code buffer}
 */
public static float[] copy(@Nonnull double[] src, float[] buffer) {
  int i = 0;
  for (final double value : src) {
    buffer[i++] = (float) value;
  }
  return buffer;
}
/**
 * Encodes an int array as a JSON array of numbers, preserving order.
 *
 * @param ints the values
 * @return a new JsonArray
 */
@Nonnull
public static JsonArray toJsonArray(@Nonnull int[] ints) {
  @Nonnull final JsonArray result = new JsonArray();
  for (final int value : ints) {
    result.add(new JsonPrimitive(value));
  }
  return result;
}
/**
 * Decodes a JSON array of numbers into an int array, preserving order.
 *
 * @param ints the JSON array
 * @return a new int array
 */
@Nonnull
@SuppressWarnings("unused")
public static int[] fromJsonArray(@Nonnull JsonArray ints) {
  final int count = ints.size();
  @Nonnull final int[] result = new int[count];
  for (int i = 0; i < count; i++) {
    result[i] = ints.get(i).getAsInt();
  }
  return result;
}
/**
 * Returns a copy of {@code tensor} with the axis order reversed, releasing the input reference.
 *
 * @param tensor source tensor; its reference is consumed
 * @return a new tensor with reversed dimensions
 */
@Nonnull
public static Tensor invertDimensions(@Nonnull Tensor tensor) {
  final Tensor inverted = tensor.rearrange(Tensor::reverse);
  tensor.freeRef();
  return inverted;
}
/**
 * Maps a coordinate through a permutation/reflection key.
 * <p>
 * For each output position i: a key {@code k >= 0} copies {@code data[k]}. A negative key
 * {@code -a} reflects axis {@code a} ({@code dimensions[a] - data[a] - 1}).
 * {@link Integer#MAX_VALUE} reflects axis 0, presumably because {@code -0} cannot be expressed.
 *
 * @param key        per-output-axis selector
 * @param data       the input coordinate
 * @param dimensions the input shape, used for reflections
 * @return a new coordinate array with {@code key.length} entries
 */
@Nonnull
public static int[] permute(@Nonnull int[] key, int[] data, final int[] dimensions) {
  @Nonnull final int[] result = new int[key.length];
  for (int i = 0; i < result.length; i++) {
    final int k = key[i];
    if (k >= 0 && k != Integer.MAX_VALUE) {
      result[i] = data[k];
      continue;
    }
    final int axis = k == Integer.MAX_VALUE ? 0 : -k;
    result[i] = dimensions[axis] - data[axis] - 1;
  }
  return result;
}
/**
 * Returns a reversed copy of {@code dimensions}; the input is not modified.
 *
 * @param dimensions the array to reverse
 * @return a new array with the elements in reverse order
 */
@Nonnull
public static int[] reverse(@Nonnull int[] dimensions) {
  final int n = dimensions.length;
  final int[] reversed = new int[n];
  for (int i = 0; i < n; i++) {
    reversed[i] = dimensions[n - 1 - i];
  }
  return reversed;
}
/**
 * Reverses {@code array} in place by swapping from both ends toward the middle.
 *
 * @param array the array to reverse, may be null
 * @return the same array instance, or null if {@code array} was null
 */
@Nullable
public static int[] reverseInPlace(@Nullable final int[] array) {
  if (array != null) {
    for (int lo = 0, hi = array.length - 1; lo < hi; lo++, hi--) {
      final int swap = array[lo];
      array[lo] = array[hi];
      array[hi] = swap;
    }
  }
  return array;
}
/**
 * Formats {@code doubles} by wrapping them in a temporary rank-1 tensor and pretty-printing it.
 *
 * @param doubles values to format
 * @return the formatted text
 */
public static CharSequence prettyPrint(double[] doubles) {
  @Nonnull final Tensor wrapper = new Tensor(doubles);
  final String text = wrapper.prettyPrint();
  wrapper.freeRef();
  return text;
}
/**
 * Returns a function that gathers selected elements of a tensor into a new rank-1 tensor.
 * <p>
 * Element i of the result is {@code tensor.get(reducedCoords[i])}. The function releases its
 * input tensor's reference.
 * NOTE(review): an extra reference is handed to {@code RefUtil.wrapInterface}. Presumably it
 * is released along with {@code f} (setByCoord's parameter is {@code @RefAware}); confirm. A
 * null input would make the wrapped lambda throw NullPointerException when invoked.
 * Generic type arguments appear stripped from this listing.
 *
 * @param reducedCoords coordinates to extract, in output order
 * @return the gathering function
 */
@Nonnull
public static SerializableFunction select(@Nonnull Coordinate... reducedCoords) {
return tensor -> {
Tensor reduced = new Tensor(reducedCoords.length);
final ToDoubleFunction f = RefUtil.wrapInterface(c2 -> tensor.get(reducedCoords[c2.getIndex()]),
tensor == null ? null : tensor.addRef());
reduced.setByCoord(f, false);
if (null != tensor)
tensor.freeRef();
return reduced;
};
}
/**
 * Returns {@code a + b}, consuming both references.
 * <p>
 * When the caller holds the only reference to {@code a}, the sum is accumulated into
 * {@code a} and {@code a} is returned. Otherwise a new tensor is produced and {@code a} is
 * released. In both paths {@code b} is released by the instance method called.
 *
 * @param a left operand (reference consumed, possibly reused as the result)
 * @param b right operand (reference consumed)
 * @return the sum
 */
@NotNull
public static Tensor add(Tensor a, Tensor b) {
  if (1 != a.currentRefCount()) {
    try {
      return a.add(b);
    } finally {
      a.freeRef();
    }
  }
  a.addInPlace(b);
  return a;
}
/**
 * Clamps {@code value} to the 8-bit range [0, 255]. NaN and -0.0 pass through unchanged.
 *
 * @param value input value
 * @return the clamped value
 */
private static double bound8bit(final double value) {
  if (value < 0) {
    return 0;
  }
  if (value > 0xFF) {
    return 0xFF;
  }
  return value;
}
/**
 * Clamps {@code value} to the 8-bit range [0, 255].
 *
 * @param value input value
 * @return the clamped value
 */
private static int bound8bit(final int value) {
  return Math.max(0, Math.min(0xFF, value));
}
/**
 * Computes per-axis strides: axis 0 has stride 1, and each later axis's stride is the product
 * of all preceding dimensions.
 *
 * @param dims size of each axis
 * @return a new stride array of the same length
 */
@Nonnull
private static int[] getSkips(@Nonnull final int[] dims) {
  @Nonnull final int[] skips = new int[dims.length];
  int stride = 1;
  for (int axis = 0; axis < dims.length; axis++) {
    skips[axis] = stride;
    stride *= dims[axis];
  }
  return skips;
}
/**
 * Sets flat indices {@code [fromIndex, toIndex)} to {@code val}.
 *
 * @param fromIndex first index, inclusive
 * @param toIndex   last index, exclusive
 * @param val       value to store
 */
public void fill(int fromIndex, int toIndex, double val) {
  final double[] target = getData();
  Arrays.fill(target, fromIndex, toIndex, val);
}
/**
 * Returns a fresh copy of this tensor's elements.
 *
 * @return a new array of {@code length()} values
 */
public double[] copyData() {
  final double[] source = getData();
  return Arrays.copyOf(source, length());
}
/**
 * Streams this tensor's elements. The stream is given an additional reference to this tensor
 * via {@code track(addRef())}, presumably released when the stream is closed or consumed.
 *
 * @return a stream over the backing data
 */
public RefDoubleStream doubleStream() {
  final double[] values = getData();
  return RefDoubleStream.of(values).track(addRef());
}
/**
 * Returns all band values at {@code (coords[0], coords[1])}, using dimension 2 as the band count.
 *
 * @param coords at least an x and a y coordinate; extra entries are ignored
 * @return the band values
 */
public double[] getPixel(int... coords) {
  final int x = coords[0];
  final int y = coords[1];
  return getPixel(x, y, getDimensions()[2]);
}
/**
 * Returns {@code get(x, y, c)} for {@code c} in {@code [0, bands)}.
 *
 * @param x     first coordinate
 * @param y     second coordinate
 * @param bands number of values to read; a non-positive count yields an empty array
 * @return the band values
 */
public double[] getPixel(int x, int y, int bands) {
  final double[] pixel = new double[Math.max(0, bands)];
  for (int band = 0; band < pixel.length; band++) {
    pixel[band] = get(x, y, band);
  }
  return pixel;
}
/**
 * Returns a new tensor whose dimensions are {@code fn(getDimensions())} and whose element at
 * {@code fn(c)} equals this tensor's element at {@code c}, for every coordinate c.
 * NOTE(review): the type argument (presumably {@code UnaryOperator<int[]>}) appears stripped
 * from this listing.
 *
 * @param fn maps an input coordinate (and the shape) to an output coordinate
 * @return the rearranged tensor
 */
@Nonnull
public Tensor rearrange(@Nonnull UnaryOperator fn) {
return rearrange(fn, fn.apply(getDimensions()));
}
/**
 * Copies each element of this tensor to position {@code fn(c)} of a new tensor of shape
 * {@code outputDims}.
 * <p>
 * If {@code fn} maps two inputs to the same output, the later write wins. Outputs that no
 * input maps to keep whatever the new tensor's storage initially held.
 * NOTE(review): the meaning of the {@code false} passed to coordStream is defined elsewhere
 * (presumably "not parallel"). Generic type arguments appear stripped from this listing.
 *
 * @param fn         maps an input coordinate array to an output coordinate array
 * @param outputDims shape of the result
 * @return the new tensor
 */
@Nonnull
public Tensor rearrange(@Nonnull UnaryOperator fn, int[] outputDims) {
@Nonnull
Tensor result = new Tensor(outputDims);
coordStream(false).forEach(c -> {
int[] inCoords = c.getCoords();
int[] outCoords = fn.apply(inCoords);
result.set(outCoords, get(c));
});
return result;
}
/**
 * Adds {@code tensor} element-wise into this tensor's data, then releases {@code tensor}.
 * <p>
 * The work is split into at most 8 contiguous shards (about one per 64 elements), processed
 * in parallel. Shard boundaries are computed with integer arithmetic, so the shards tile
 * {@code [0, length)} exactly. Boundaries accumulated by repeated floating-point addition can
 * round short of {@code length} and skip the final element.
 *
 * @param tensor the addend, with exactly the same dimensions; its reference is released by
 *               this method whether or not the addition succeeds
 * @throws AssertionError if the dimensions differ
 */
public void addInPlace(@Nonnull final Tensor tensor) {
  try {
    assertAlive();
    final int[] tensorDimensions = tensor.getDimensions();
    final int[] dimensions = getDimensions();
    if (!Arrays.equals(dimensions, tensorDimensions)) {
      throw new AssertionError(String.format("%s != %s", Arrays.toString(dimensions), Arrays.toString(tensorDimensions)));
    }
    final double[] toAdd = tensor.getData();
    final double[] data = getData();
    final int length = length();
    final int shards = Math.max(1, Math.min(8, length / 64));
    // Shard s covers [s * length / shards, (s + 1) * length / shards); the long math avoids overflow.
    RefIntStream.range(0, shards).parallel().forEach(shard -> {
      final int start = (int) ((long) shard * length / shards);
      final int end = (int) ((long) (shard + 1) * length / shards);
      for (int i = start; i < end; i++) {
        data[i] += toAdd[i];
      }
    });
  } finally {
    if (tensor != null) tensor.freeRef();
  }
}
/**
 * Adds {@code value} to the element addressed by {@code coords}.
 *
 * @param coords coordinate whose flat index is used
 * @param value  amount to add
 */
public void add(@Nonnull final Coordinate coords, final double value) {
  final int flatIndex = coords.getIndex();
  add(flatIndex, value);
}
/**
 * Adds {@code value} to the element at flat index {@code index}. Not atomic.
 *
 * @param index flat element index
 * @param value amount to add
 */
public final void add(final int index, final double value) {
  final double[] values = getData();
  values[index] += value;
}
/**
 * Adds {@code value} to the element at the given multi-dimensional coordinate.
 *
 * @param coords one coordinate per axis
 * @param value  amount to add
 */
public void add(@Nonnull final int[] coords, final double value) {
  final int offset = index(coords);
  add(offset, value);
}
/**
 * Returns a new tensor holding the element-wise sum of this tensor and {@code right}.
 * Shapes are compared only by an assertion. {@code right}'s reference is always released.
 *
 * @param right the addend (reference consumed)
 * @return a new tensor with this tensor's dimensions
 */
@Nonnull
public Tensor add(@Nonnull final Tensor right) {
  try {
    final int[] dims = getDimensions();
    assert RefArrays.equals(dims, right.getDimensions());
    final double[] left = getData();
    final double[] addend = right.getData();
    final int n = length();
    final double[] sum = new double[n];
    for (int i = 0; i < n; i++) {
      sum[i] = addend[i] + left[i];
    }
    return new Tensor(dims, sum);
  } finally {
    right.freeRef();
  }
}
@Nonnull
public RefStream coordStream(boolean parallel) {
//ConcurrentHashSet
© 2015 - 2025 Weber Informatics LLC | Privacy Policy