Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
com.simiacryptus.mindseye.lang.Tensor Maven / Gradle / Ivy
/*
* Copyright (c) 2018 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.lang;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
import org.apache.commons.lang3.ArrayUtils;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.awt.image.BufferedImage;
import java.io.Serializable;
import java.util.*;
import java.util.function.*;
import java.util.stream.*;
/**
* A multi-dimensional array of data, represented internally as a single flat double[] array. This class is central to data
* handling in MindsEye, and may contain some odd-looking or surprising optimizations.
*/
@SuppressWarnings("serial")
public final class Tensor extends ReferenceCountingBase implements Serializable {
/**
* Default serializer used by {@link #setBytes(byte[])} when no precision is given.
* This is a mutable global: reassigning it changes the default for every caller.
*/
@Nonnull
public static DataSerializer json_precision = SerialPrecision.Float;
/**
* Extent of each dimension. {@link #getDimensions()} hands out copies of this array.
*/
@Nullable
protected final int[] dimensions;
/**
* Per-dimension element strides as computed by {@code getSkips}: dimension 0 has stride 1 and
* each later stride is the product of all preceding extents (first index varies fastest).
*/
@Nullable
protected final int[] strides;
/**
* Flat element storage. It may be null until first accessed through {@link #getData()}, and
* {@code _free()} resets it to null after offering it back to {@code RecycleBin.DOUBLES}.
*/
@Nullable
protected volatile double[] data;
/**
* Optional identifier. It is written by {@code toJson(Map, DataSerializer)} (object form only)
* and restored by {@code fromJson}.
*/
@Nullable
protected volatile UUID id;
/**
 * Internal placeholder constructor: shape, strides and storage are all left {@code null}.
 */
private Tensor() {
    super();
    dimensions = null;
    strides = null;
    data = null;
}
/**
 * Creates a one-dimensional tensor of shape {@code [ds.length]} holding a copy of {@code ds}.
 *
 * @param ds the element values
 */
public Tensor(@Nonnull final double... ds) {
    this(ds, new int[]{ds.length});
}
/**
 * Creates a tensor with the given shape, optionally initialized from a copy of {@code data}.
 *
 * @param data initial values in flat order (first index varies fastest), or {@code null} to
 *             allocate storage lazily on first access
 * @param dims the extent of each dimension
 * @throws IllegalArgumentException if the element count does not fit in an {@code int}, or if
 *                                  {@code data.length} differs from the product of {@code dims}
 */
public Tensor(@Nullable final double[] data, @Nonnull final int... dims) {
    // Accumulate in a long so an oversized shape is rejected. The int-returning
    // Tensor.length(...) wraps on overflow, so comparing its result with
    // Integer.MAX_VALUE could never fire. Bounding the running product also keeps
    // the long multiplication itself from overflowing.
    long size = 1;
    for (final int dim : dims) {
        size *= dim;
        if (Math.abs(size) > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Tensor too large: " + Arrays.toString(dims));
        }
    }
    if (null != data && size != data.length) {
        throw new IllegalArgumentException(Arrays.toString(dims) + " != " + data.length);
    }
    dimensions = Arrays.copyOf(dims, dims.length);
    strides = Tensor.getSkips(dims);
    if (null != data) {
        // Defensive copy: the caller keeps ownership of its array.
        this.data = RecycleBin.DOUBLES.copyOf(data, data.length);
    }
    assert isValid();
}
/**
 * Internal constructor that adopts {@code dims} and {@code data} without copying them and
 * derives the strides from the shape.
 */
private Tensor(int[] dims, @Nullable double[] data) {
    this(dims, getSkips(dims), data);
}
/**
 * Internal constructor that adopts all three arrays as-is (no defensive copies).
 *
 * @param dimensions the extent of each dimension
 * @param strides    per-dimension strides consistent with {@code dimensions}
 * @param data       flat storage of matching length, or {@code null} to allocate lazily
 * @throws IllegalArgumentException if the element count does not fit in a Java array
 */
private Tensor(int[] dimensions, int[] strides, @Nullable double[] data) {
    // Compute the element count in a long: Tensor.length(...) returns an int that wraps on
    // overflow, so the previous ">= Integer.MAX_VALUE" test on it was effectively dead.
    long size = 1;
    for (final int dim : dimensions) {
        size *= dim;
        if (Math.abs(size) >= Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Tensor too large: " + Arrays.toString(dimensions));
        }
    }
    assert null == data || data.length == size;
    this.dimensions = dimensions;
    this.strides = strides;
    this.data = data;
    assert isValid();
}
/**
 * Creates a tensor with the given shape, initialized by widening {@code data} to doubles.
 * Non-finite inputs (NaN and infinities) are stored as {@code 0}.
 *
 * @param data initial values in flat order, or {@code null} to allocate storage lazily
 * @param dims the extent of each dimension
 * @throws IllegalArgumentException if the element count does not fit in a Java array, or if
 *                                  {@code data.length} differs from the product of {@code dims}
 */
public Tensor(@Nullable final float[] data, @Nonnull final int... dims) {
    // Long accumulation: the int-returning Tensor.length(...) wraps, which made the old
    // ">= Integer.MAX_VALUE" check unable to detect overflow.
    long size = 1;
    for (final int dim : dims) {
        size *= dim;
        if (Math.abs(size) >= Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Tensor too large: " + Arrays.toString(dims));
        }
    }
    // Mirror the double[] constructor. Otherwise a mismatched buffer is only caught by the
    // isValid() assertion, which means not at all when assertions are disabled.
    if (null != data && size != data.length) {
        throw new IllegalArgumentException(Arrays.toString(dims) + " != " + data.length);
    }
    dimensions = Arrays.copyOf(dims, dims.length);
    strides = Tensor.getSkips(dims);
    if (null != data) {
        final double[] values = RecycleBin.DOUBLES.obtain(data.length);
        Arrays.parallelSetAll(values, i -> {
            final double v = data[i];
            return Double.isFinite(v) ? v : 0;
        });
        this.data = values;
        assert Arrays.stream(values).allMatch(Double::isFinite);
    }
    assert isValid();
}
/**
 * Creates a tensor of the given shape whose storage is allocated on first access.
 *
 * @param dims the extent of each dimension; expected to be non-empty
 */
public Tensor(@Nonnull final int... dims) {
    this((double[]) null, dims);
    assert 0 < dims.length;
}
/**
 * Decodes a tensor from JSON.
 *
 * <p>Accepted encodings:
 * <ul>
 *   <li>an array of numbers: a one-dimensional tensor (an empty array yields shape {@code [0]});</li>
 *   <li>an array of arrays: each element is decoded recursively, all elements must share one shape,
 *       and the element index becomes a new trailing dimension;</li>
 *   <li>an object with {@code length}, {@code precision}, an optional {@code id}, and either inline
 *       {@code base64} data or a {@code resource} key looked up in {@code resources};</li>
 *   <li>anything else is read as a single number.</li>
 * </ul>
 *
 * @param json      the JSON to decode; may be {@code null}
 * @param resources binary payloads keyed by resource id; required only for the {@code resource} form
 * @return the decoded tensor, or {@code null} when {@code json} is {@code null}
 * @throws IllegalArgumentException if nested elements disagree on shape, or if a referenced resource
 *                                  map or entry is missing
 */
@Nullable
public static Tensor fromJson(@Nullable final JsonElement json, @Nullable Map<CharSequence, byte[]> resources) {
    if (null == json) return null;
    if (json.isJsonArray()) {
        final JsonArray array = json.getAsJsonArray();
        final int size = array.size();
        // An empty array has no first element to inspect. Decode it as an empty vector
        // instead of failing with IndexOutOfBoundsException.
        if (0 == size || array.get(0).isJsonPrimitive()) {
            final double[] doubles = new double[size];
            for (int i = 0; i < size; i++) {
                doubles[i] = array.get(i).getAsDouble();
            }
            @Nonnull Tensor tensor = new Tensor(doubles);
            assert tensor.isValid();
            return tensor;
        }
        final List<Tensor> elements = new ArrayList<>(size);
        try {
            for (int i = 0; i < size; i++) {
                elements.add(Tensor.fromJson(array.get(i), resources));
            }
            @Nonnull final int[] dimensions = elements.get(0).getDimensions();
            if (!elements.stream().allMatch(t -> Arrays.equals(dimensions, t.getDimensions()))) {
                throw new IllegalArgumentException("Inconsistent element dimensions in JSON array");
            }
            // The element index becomes the slowest-varying (last) dimension, so element i
            // occupies one contiguous run of the flat buffer.
            @Nonnull final int[] newDimensions = Arrays.copyOf(dimensions, dimensions.length + 1);
            newDimensions[dimensions.length] = size;
            @Nonnull final Tensor tensor = new Tensor(newDimensions);
            @Nonnull final double[] data = tensor.getData();
            for (int i = 0; i < size; i++) {
                @Nonnull final double[] e = elements.get(i).getData();
                System.arraycopy(e, 0, data, i * e.length, e.length);
            }
            assert tensor.isValid();
            return tensor;
        } finally {
            // Release the intermediate tensors on both the success and the error path.
            for (final Tensor element : elements) {
                if (null != element) element.freeRef();
            }
        }
    } else if (json.isJsonObject()) {
        JsonObject jsonObject = json.getAsJsonObject();
        @Nonnull int[] dims = fromJsonArray(jsonObject.getAsJsonArray("length"));
        SerialPrecision precision = SerialPrecision.valueOf(jsonObject.getAsJsonPrimitive("precision").getAsString());
        // Resolve the payload before allocating, so a missing resource cannot leak a tensor.
        final byte[] bytes;
        JsonElement base64 = jsonObject.get("base64");
        if (null == base64) {
            if (null == resources) throw new IllegalArgumentException("No Data Resources");
            CharSequence resourceId = jsonObject.getAsJsonPrimitive("resource").getAsString();
            bytes = resources.get(resourceId);
            if (null == bytes) throw new IllegalArgumentException("Missing resource: " + resourceId);
        } else {
            bytes = Base64.getDecoder().decode(base64.getAsString());
        }
        @Nonnull Tensor tensor = new Tensor(dims);
        tensor.setBytes(bytes, precision);
        assert tensor.isValid();
        JsonElement id = jsonObject.get("id");
        if (null != id) {
            tensor.setId(UUID.fromString(id.getAsString()));
        }
        return tensor;
    } else {
        @Nonnull Tensor tensor = new Tensor(json.getAsJsonPrimitive().getAsDouble());
        assert tensor.isValid();
        return tensor;
    }
}
/**
 * Clamps {@code value} to the 8-bit channel range [0, 255]. NaN passes through unchanged.
 */
private static double bound8bit(final double value) {
    if (value < 0) {
        return 0;
    }
    if (value > 0xFF) {
        return 0xFF;
    }
    return value;
}
/**
 * Clamps {@code value} to the 8-bit channel range [0, 255].
 */
private static int bound8bit(final int value) {
    return Math.max(0, Math.min(0xFF, value));
}
/**
 * Returns the number of elements in a tensor of the given shape, i.e. the product of
 * {@code dims}. An empty shape yields {@code 1}.
 *
 * @param dims the extent of each dimension
 * @return the element count
 * @throws IllegalArgumentException if the product does not fit in an {@code int}
 */
public static int length(@Nonnull int... dims) {
    long total = 1;
    for (final int dim : dims) {
        total *= dim;
        // Previously the product was silently truncated by the final (int) cast. Bounding the
        // running product here also keeps the long multiplication from overflowing.
        if (total > Integer.MAX_VALUE || total < Integer.MIN_VALUE) {
            throw new IllegalArgumentException("Tensor length overflows int: " + Arrays.toString(dims));
        }
    }
    return (int) total;
}
/**
 * Converts an image into a {@code width x height x 3} tensor. Band 0 holds the lowest byte of
 * each packed {@code getRGB} pixel, band 1 the second byte, and band 2 the third byte.
 * Columns are filled in parallel.
 *
 * @param img the source image
 * @return a new tensor of shape {@code [width, height, 3]}
 */
@Nonnull
public static Tensor fromRGB(@Nonnull final BufferedImage img) {
    final int width = img.getWidth();
    final int height = img.getHeight();
    @Nonnull final Tensor result = new Tensor(width, height, 3);
    IntStream.range(0, width).parallel().forEach(x -> {
        // One scratch coordinate per column, so parallel columns never share it.
        @Nonnull final int[] position = new int[3];
        position[0] = x;
        for (int y = 0; y < height; y++) {
            final int packed = img.getRGB(x, y);
            position[1] = y;
            position[2] = 0;
            result.set(position, packed & 0xFF);
            position[2] = 1;
            result.set(position, (packed >> 8) & 0xFF);
            position[2] = 2;
            result.set(position, (packed >> 16) & 0xFF);
        }
    });
    return result;
}
/**
 * Drains {@code stream} into a buffer of length {@code dim} obtained from {@code RecycleBin.DOUBLES}.
 *
 * <p>Values are written in encounter order via {@code forEachOrdered}. This matters for the
 * parallel streams produced by {@code coordStream(true)}: a plain {@code forEach} there raced on
 * the shared write index and could place values at the wrong positions.
 *
 * @param stream the values to collect; must not yield more than {@code dim} elements
 * @param dim    the length of the buffer to fill
 * @return the filled buffer
 */
public static double[] getDoubles(@Nonnull final DoubleStream stream, final int dim) {
    final double[] doubles = RecycleBin.DOUBLES.obtain(dim);
    // forEachOrdered guarantees each action happens-before the next, so the unsynchronized
    // counter is safe even when the stream is parallel.
    stream.forEachOrdered(new DoubleConsumer() {
        int j = 0;

        @Override
        public void accept(final double value) {
            doubles[j++] = value;
        }
    });
    return doubles;
}
/**
 * Computes per-dimension strides with the first index varying fastest: stride 0 is 1 and each
 * later stride is the product of all preceding extents.
 */
@Nonnull
private static int[] getSkips(@Nonnull final int[] dims) {
    @Nonnull final int[] strides = new int[dims.length];
    int step = 1;
    for (int d = 0; d < dims.length; d++) {
        strides[d] = step;
        step *= dims[d];
    }
    return strides;
}
/**
 * Element-wise product. An operand with exactly one element is broadcast as a scalar; otherwise
 * both operands must have the same element count. The result takes the shape of the non-scalar
 * operand (or of {@code left} when both are scalars).
 *
 * @param left  the first factor
 * @param right the second factor
 * @return a new tensor holding the products
 * @throws IllegalArgumentException if the element counts are incompatible
 */
@Nonnull
public static Tensor product(@Nonnull final Tensor left, @Nonnull final Tensor right) {
    if (left.length() == 1 && right.length() != 1) return Tensor.product(right, left);
    // Previously only asserted. With assertions off, a longer right operand was silently truncated.
    if (left.length() != right.length() && 1 != right.length()) {
        throw new IllegalArgumentException(left.length() + " != " + right.length());
    }
    @Nonnull final Tensor result = new Tensor(left.getDimensions());
    @Nonnull final double[] resultData = result.getData();
    @Nonnull final double[] leftData = left.getData();
    @Nonnull final double[] rightData = right.getData();
    final boolean scalarRight = 1 == rightData.length;
    for (int i = 0; i < resultData.length; i++) {
        resultData[i] = leftData[i] * rightData[scalarRight ? 0 : i];
    }
    return result;
}
/**
 * Widens {@code data} into a buffer obtained from {@code RecycleBin.DOUBLES}.
 *
 * @param data the values to widen
 * @return a buffer whose first {@code data.length} entries hold the widened values
 */
public static double[] toDoubles(@Nonnull final float[] data) {
    final double[] widened = RecycleBin.DOUBLES.obtain(data.length);
    int i = 0;
    for (final float value : data) {
        widened[i++] = value;
    }
    return widened;
}
/**
 * Narrows each value of {@code data} to {@code float} in a freshly allocated array.
 *
 * @param data the values to narrow
 * @return a new array of the same length
 */
@Nonnull
public static float[] toFloats(@Nonnull final double[] data) {
    @Nonnull final float[] narrowed = new float[data.length];
    int i = 0;
    while (i < data.length) {
        narrowed[i] = (float) data[i];
        i++;
    }
    return narrowed;
}
/**
 * Encodes an int array as a JSON array of numbers, preserving order.
 *
 * @param ints the values to encode
 * @return a new JSON array
 */
@Nonnull
public static JsonArray toJsonArray(@Nonnull int[] ints) {
    @Nonnull final JsonArray encoded = new JsonArray();
    for (final int value : ints) {
        encoded.add(new JsonPrimitive(value));
    }
    return encoded;
}
/**
 * Decodes a JSON array of numbers into an int array, preserving order.
 *
 * @param ints the JSON array to decode
 * @return a new int array of the same length
 */
@Nonnull
public static int[] fromJsonArray(@Nonnull JsonArray ints) {
    final int count = ints.size();
    @Nonnull final int[] decoded = new int[count];
    for (int i = 0; i < count; i++) {
        decoded[i] = ints.get(i).getAsInt();
    }
    return decoded;
}
/**
 * Returns a new tensor whose axis order is reversed (the last dimension becomes the first).
 *
 * @param tensor the source tensor
 * @return a rearranged copy
 */
@Nonnull
public static Tensor reverseDimensions(@Nonnull Tensor tensor) {
    final UnaryOperator<int[]> reverseAxes = Tensor::reverse;
    return tensor.rearrange(reverseAxes);
}
/**
 * Builds a permuted (and optionally mirrored) coordinate.
 *
 * <p>Entry {@code key[i]} selects the source axis for output position {@code i}:
 * <ul>
 *   <li>a non-negative value {@code k} copies {@code data[k]};</li>
 *   <li>a negative value {@code -k} mirrors axis {@code k}, giving {@code dimensions[k] - data[k] - 1};</li>
 *   <li>{@link Integer#MAX_VALUE} mirrors axis 0, which a plain {@code -0} cannot express.</li>
 * </ul>
 *
 * @param key        the axis mapping
 * @param data       the input coordinate
 * @param dimensions the input extents, used for mirroring
 * @return a new coordinate of length {@code key.length}
 */
@Nonnull
public static int[] permute(@Nonnull int[] key, int[] data, final int[] dimensions) {
    @Nonnull final int[] result = new int[key.length];
    for (int i = 0; i < key.length; i++) {
        final int source = key[i];
        if (source >= 0 && source != Integer.MAX_VALUE) {
            result[i] = data[source];
        } else {
            final int axis = source == Integer.MAX_VALUE ? 0 : -source;
            result[i] = dimensions[axis] - data[axis] - 1;
        }
    }
    return result;
}
/**
 * Returns a reversed copy of {@code dimensions}; the input is not modified.
 *
 * @param dimensions the array to reverse
 * @return a new array with the elements in reverse order
 */
@Nonnull
public static int[] reverse(@Nonnull int[] dimensions) {
    final int n = dimensions.length;
    @Nonnull final int[] reversed = new int[n];
    for (int i = 0; i < n; i++) {
        reversed[i] = dimensions[n - 1 - i];
    }
    return reversed;
}
/**
 * Pretty-prints {@code doubles} as a one-dimensional tensor. The temporary tensor is released
 * before returning.
 *
 * @param doubles the values to print
 * @return the formatted text
 */
public static CharSequence prettyPrint(double[] doubles) {
    return new Tensor(doubles).prettyPrintAndFree();
}
/**
 * Reads bands {@code 0 .. bands-1} at position {@code (x, y)} of a three-dimensional tensor.
 *
 * @param tensor the source tensor
 * @param x      the first coordinate
 * @param y      the second coordinate
 * @param bands  how many bands to read; non-positive values yield an empty array
 * @return the band values at {@code (x, y)}
 */
public static double[] getPixel(final Tensor tensor, final int x, final int y, final int bands) {
    final double[] pixel = new double[Math.max(0, bands)];
    for (int band = 0; band < pixel.length; band++) {
        pixel[band] = tensor.get(x, y, band);
    }
    return pixel;
}
/**
 * Sums over the third dimension of a three-dimensional tensor.
 *
 * @return a new tensor of shape {@code [dims[0], dims[1], 1]}
 */
@Nonnull
public Tensor sumChannels() {
    final int[] shape = getDimensions();
    final Tensor source = this;
    final Tensor summed = new Tensor(shape[0], shape[1], 1);
    return summed.setByCoord(c -> {
        final int[] xy = c.getCoords();
        // DoubleStream.sum() is kept on purpose: its compensated summation differs from a naive loop.
        return IntStream.range(0, shape[2]).mapToDouble(band -> source.get(xy[0], xy[1], band)).sum();
    });
}
/**
 * Streams the band vector of every {@code (x, y)} position of a three-dimensional tensor.
 * The stream is parallel over {@code x}, so encounter order should not be relied on for
 * side effects.
 *
 * @return a stream of {@code width * height} arrays, each of length {@code dims[2]}
 */
@Nonnull
public Stream<double[]> getPixelStream() {
    final int[] dimensions = getDimensions();
    final int width = dimensions[0];
    final int height = dimensions[1];
    final int bands = dimensions[2];
    // Typed return (previously a raw Stream) so callers need no unchecked casts.
    return IntStream.range(0, width).boxed().parallel()
        .flatMap(x -> IntStream.range(0, height).mapToObj(y -> getPixel(this, x, y, bands)));
}
/**
 * Returns a copy scaled so that its root-mean-square equals {@code rms}.
 *
 * @param rms the target RMS
 * @return a new tensor. If this tensor's RMS is zero (all elements zero), an unscaled copy is
 *         returned, because no scale factor can reach a non-zero target.
 */
public Tensor rescaleRms(final double rms) {
    final double current = rms();
    // Scaling by rms / 0 = Infinity would turn every 0 into 0 * Infinity = NaN.
    if (0 == current) return scale(1);
    return scale(rms / current);
}
/**
 * Replaces each value by its relative rank: the index {@code Arrays.binarySearch} finds for it in
 * the sorted values, divided by the element count. Results lie in [0, 1).
 *
 * @return a new tensor of rank fractions
 */
public Tensor normalizeDistribution() {
    final double[] sorted = getData().clone();
    Arrays.sort(sorted);
    final double count = sorted.length;
    return map(v -> Math.abs(Arrays.binarySearch(sorted, v) / count));
}
/**
 * Rearranges elements using a coordinate mapping. The output shape is the mapping applied to
 * this tensor's shape.
 *
 * @param fn maps an input coordinate (and the input shape) to an output coordinate
 * @return a new rearranged tensor
 */
@Nonnull
public Tensor rearrange(@Nonnull UnaryOperator<int[]> fn) {
    // Typed parameter (previously raw UnaryOperator, whose apply() returned Object).
    return rearrange(fn, fn.apply(getDimensions()));
}
/**
 * Copies every element of this tensor into a new tensor of shape {@code outputDims}, at the
 * coordinate produced by {@code fn}.
 *
 * @param fn         maps an input coordinate to an output coordinate. It is called on a scratch
 *                   array reused across elements, so it must not retain its argument.
 * @param outputDims the shape of the result
 * @return the new tensor
 */
@Nonnull
public Tensor rearrange(@Nonnull UnaryOperator<int[]> fn, int[] outputDims) {
    @Nonnull Tensor result = new Tensor(outputDims);
    coordStream(false).forEach(c -> {
        int[] inCoords = c.getCoords();
        int[] outCoords = fn.apply(inCoords);
        result.set(outCoords, get(c));
    });
    return result;
}
/**
 * Checks that this tensor is not finalized and that its storage, if already allocated, matches
 * the shape.
 *
 * @return {@code true} when the tensor is usable
 */
public boolean isValid() {
    if (isFinalized()) {
        return false;
    }
    final double[] values = this.data;
    return null == values || values.length == Tensor.length(dimensions);
}
/**
 * Adds {@code tensor} element-wise into this tensor, in parallel. Shapes are only checked by
 * assertion.
 *
 * @param tensor the addend, expected to have the same shape
 */
public void addInPlace(@Nonnull final Tensor tensor) {
    assert Arrays.equals(getDimensions(), tensor.getDimensions()) : Arrays.toString(getDimensions()) + " != " + Arrays.toString(tensor.getDimensions());
    setParallelByIndex(i -> get(i) + tensor.get(i));
}
/**
 * Adds {@code value} to the element at the flat index carried by {@code coords}.
 *
 * @param coords the target coordinate
 * @param value  the amount to add
 */
public void add(@Nonnull final Coordinate coords, final double value) {
    final int offset = index(coords);
    add(offset, value);
}
/**
 * Adds {@code value} to the element at flat position {@code index}.
 *
 * @param index the flat element index
 * @param value the amount to add
 * @return this tensor, for chaining
 */
@Nonnull
public final Tensor add(final int index, final double value) {
    final double[] values = getData();
    values[index] = values[index] + value;
    return this;
}
/**
 * Adds {@code value} to the element at the given multi-dimensional coordinate.
 *
 * @param coords the coordinate, one entry per dimension
 * @param value  the amount to add
 */
public void add(@Nonnull final int[] coords, final double value) {
    final int offset = index(coords);
    add(offset, value);
}
/**
 * Returns a new tensor holding the element-wise sum of this tensor and {@code right}.
 * Shapes are only checked by assertion.
 *
 * @param right the addend
 * @return the sum as a new tensor
 */
@Nullable
public Tensor add(@Nonnull final Tensor right) {
    assert Arrays.equals(getDimensions(), right.getDimensions());
    return mapCoords(coord -> get(coord) + right.get(coord));
}
/**
 * Returns the element-wise sum of this tensor and {@code right}, consuming this tensor's
 * reference. {@code right} is not released.
 *
 * <p>When the caller holds the only reference, the sum is accumulated in place and this instance
 * is returned. Otherwise a new tensor is produced and this reference is released.
 *
 * @param right the addend
 * @return the sum
 */
@Nullable
public Tensor addAndFree(@Nonnull final Tensor right) {
    assertAlive();
    right.assertAlive();
    if (1 != currentRefCount()) {
        assert Arrays.equals(getDimensions(), right.getDimensions());
        return mapCoordsAndFree(coord -> get(coord) + right.get(coord));
    }
    // Sole owner: reuse this storage instead of allocating a new tensor.
    addInPlace(right);
    return this;
}
/**
 * Streams every coordinate of this tensor in flat-index order (first dimension varying fastest).
 *
 * <p>For a sequential stream ({@code parallel == false}) the <em>same</em> mutable
 * {@link Coordinate} instance is returned for every element and overwritten by the next one, so
 * consumers must not retain it. For a parallel stream each element is a {@code copy()}.
 *
 * @param parallel whether the returned stream is parallel
 * @return a stream of {@link #length()} coordinates
 */
@Nonnull
public Stream<Coordinate> coordStream(boolean parallel) {
    final Iterator<Coordinate> iterator = new Iterator<Coordinate>() {
        int cnt = 0;
        @Nonnull
        final Coordinate coordinate = new Coordinate();
        @Nonnull
        final int[] val = new int[dimensions.length];
        @Nonnull
        final int[] safeCopy = new int[dimensions.length];

        @Override
        public boolean hasNext() {
            return cnt < length();
        }

        @Nonnull
        @Override
        public synchronized Coordinate next() {
            // Iterator contract: signal exhaustion instead of producing out-of-range coordinates.
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            if (0 < cnt) {
                // Odometer increment: bump dimension 0 and carry into later dimensions on wrap-around.
                for (int i = 0; i < val.length; i++) {
                    if (++val[i] >= dimensions[i]) {
                        val[i] = 0;
                    } else {
                        break;
                    }
                }
            }
            System.arraycopy(val, 0, safeCopy, 0, val.length);
            coordinate.setIndex(cnt++);
            coordinate.setCoords(safeCopy);
            return parallel ? coordinate.copy() : coordinate;
        }
    };
    return StreamSupport.stream(Spliterators.spliterator(iterator, length(), Spliterator.ORDERED), parallel);
}
/**
 * Returns the element count: the storage length if it has been allocated, otherwise the product
 * of the dimensions.
 *
 * @return the number of elements
 */
public int length() {
    assertAlive();
    final double[] values = data;
    return null == values ? Tensor.length(dimensions) : values.length;
}
/**
 * Returns an independent copy of this tensor (shape and values).
 *
 * @return a new tensor that shares no arrays with this one
 */
@Nonnull
public Tensor copy() {
    assertAlive();
    final double[] source = getData();
    // Use the adopting private (int[], double[]) constructor. The public (double[], int...)
    // constructor copies its input again, so the previous version filled two buffers and
    // discarded one of them.
    return new Tensor(Arrays.copyOf(dimensions, dimensions.length), RecycleBin.DOUBLES.copyOf(source, source.length));
}
/**
 * Releases the storage: offers it back to {@code RecycleBin.DOUBLES} when the bin wants that
 * size, then clears the reference.
 */
@Override
protected void _free() {
    final double[] released = data;
    if (null == released) {
        return;
    }
    if (RecycleBin.DOUBLES.want(released.length)) {
        RecycleBin.DOUBLES.recycle(released, released.length);
    }
    data = null;
}
/**
 * Two tensors are equal when both are still referenced and have identical shapes and values.
 * A released tensor is equal only to itself.
 */
@Override
public boolean equals(@Nullable final Object obj) {
    if (this == obj) return true;
    if (null == obj || getClass() != obj.getClass()) return false;
    final Tensor other = (Tensor) obj;
    if (0 == currentRefCount() || 0 == other.currentRefCount()) return false;
    return Arrays.equals(dimensions, other.dimensions) && Arrays.equals(getData(), other.getData());
}
/**
 * Reads the element at the flat index carried by {@code coords}.
 *
 * @param coords the coordinate
 * @return the element value
 */
public double get(@Nonnull final Coordinate coords) {
    return getData()[index(coords)];
}
/**
 * Reads the element at flat position {@code index}.
 *
 * @param index the flat element index
 * @return the element value
 */
public double get(final int index) {
    final double[] values = getData();
    return values[index];
}
/**
 * Reads the element at a two-dimensional coordinate.
 *
 * @param c1 the first coordinate
 * @param c2 the second coordinate
 * @return the element value
 */
public double get(final int c1, final int c2) {
    final double[] values = getData();
    return values[index(c1, c2)];
}
/**
 * Reads the element at a three-dimensional coordinate.
 *
 * @param c1 the first coordinate
 * @param c2 the second coordinate
 * @param c3 the third coordinate
 * @return the element value
 */
public double get(final int c1, final int c2, final int c3) {
    final double[] values = getData();
    return values[index(c1, c2, c3)];
}
/**
 * Reads the element at a coordinate of four or more dimensions.
 *
 * @param c1     the first coordinate
 * @param c2     the second coordinate
 * @param c3     the third coordinate
 * @param c4     the fourth coordinate
 * @param coords the fifth and later coordinates
 * @return the element value
 */
public double get(final int c1, final int c2, final int c3, final int c4, final int... coords) {
    final double[] values = getData();
    return values[index(c1, c2, c3, c4, coords)];
}
/**
 * Copies all elements, in flat order, into the start of {@code bufferArray}.
 *
 * @param bufferArray the destination; must hold at least {@link #length()} values
 */
public void get(@Nonnull final double[] bufferArray) {
    final double[] source = getData();
    System.arraycopy(source, 0, bufferArray, 0, length());
}
/**
 * Reads the element at a multi-dimensional coordinate.
 *
 * @param coords the coordinate, one entry per dimension
 * @return the element value
 */
public double get(@Nonnull final int[] coords) {
    final double[] values = getData();
    return values[index(coords)];
}
/**
* Returns the live backing storage (not a copy), allocating it on first use. Writes to the
* returned array modify this tensor.
*
* <p>Allocation uses double-checked locking on the volatile {@code data} field. The unsynchronized
* null check keeps the common path lock-free, and the re-check under {@code synchronized (this)}
* ensures only one thread obtains a buffer from {@code RecycleBin.DOUBLES}.
* NOTE(review): whether a freshly obtained buffer is zeroed depends on RecycleBin; confirm before
* relying on new tensors reading as 0.
*
* @return the backing array, expected (by assertion) to have {@code Tensor.length(dimensions)} entries
*/
@Nonnull
public double[] getData() {
assertAlive();
// Fast path: storage already allocated (no locking).
if (null == data) {
synchronized (this) {
// Re-check under the lock: another thread may have allocated in the meantime.
if (null == data) {
final int length = Tensor.length(dimensions);
data = RecycleBin.DOUBLES.obtain(length);
assert null != data;
assert length == data.length;
}
}
}
assert isValid();
assert null != data;
return data;
}
/**
 * Returns a copy of the shape, so callers may modify the result freely.
 *
 * @return the extent of each dimension
 */
@Nonnull
public final int[] getDimensions() {
    return dimensions.clone();
}
/**
 * Hash over values and shape, consistent with {@link #equals(Object)}. Note that this reads
 * the data via {@link #getData()}.
 */
@Override
public int hashCode() {
    return 31 * (31 + Arrays.hashCode(getData())) + Arrays.hashCode(dimensions);
}
/**
 * Returns the values narrowed to {@code float}, in flat order, in a new array.
 *
 * @return the narrowed values
 */
@Nonnull
public float[] getDataAsFloats() {
    final double[] values = getData();
    return Tensor.toFloats(values);
}
/**
 * Converts a one-dimensional coordinate into a flat index.
 *
 * @param c1 the first coordinate
 * @return the flat index
 */
public int index(final int c1) {
    return c1 * strides[0];
}
/**
 * Converts a two-dimensional coordinate into a flat index.
 *
 * @param c1 the first coordinate
 * @param c2 the second coordinate
 * @return the flat index
 */
public int index(final int c1, final int c2) {
    return c1 * strides[0] + c2 * strides[1];
}
/**
 * Converts a three-dimensional coordinate into a flat index.
 *
 * @param c1 the first coordinate
 * @param c2 the second coordinate
 * @param c3 the third coordinate
 * @return the flat index
 */
public int index(final int c1, final int c2, final int c3) {
    return c1 * strides[0] + c2 * strides[1] + c3 * strides[2];
}
/**
 * Returns the flat index already stored in {@code coords}. The per-dimension coordinates are
 * not consulted.
 *
 * @param coords the coordinate
 * @return its flat index
 */
public int index(@Nonnull final Coordinate coords) {
    final int flatIndex = coords.getIndex();
    return flatIndex;
}
/**
 * Converts a coordinate of four or more dimensions into a flat index.
 *
 * @param c1     the first coordinate
 * @param c2     the second coordinate
 * @param c3     the third coordinate
 * @param c4     the fourth coordinate
 * @param coords the fifth and later coordinates ({@code coords[i]} is dimension {@code 4 + i});
 *               may be {@code null} or empty. Entries beyond the tensor's rank are ignored.
 * @return the flat index
 */
public int index(final int c1, final int c2, final int c3, final int c4, @Nullable final int... coords) {
    int v = strides[0] * c1 + strides[1] * c2 + strides[2] * c3 + strides[3] * c4;
    if (null != coords) {
        // The varargs array starts at the fifth dimension, so coords[i] pairs with strides[4 + i].
        // The previous coords[4 + i] overran the array on every 5-D access.
        for (int i = 0; i < coords.length && 4 + i < strides.length; i++) {
            v += strides[4 + i] * coords[i];
        }
    }
    return v;
}
/**
 * Returns the L1 norm: the sum of absolute element values.
 *
 * <p>The previous implementation returned the signed sum, which duplicates {@link #sum()} and
 * disagrees with the name and with {@link #l2()}.
 *
 * @return the L1 norm
 */
public double l1() {
    return Arrays.stream(getData()).map(Math::abs).sum();
}
/**
 * Returns the Euclidean (L2) norm of the elements.
 *
 * @return the square root of the sum of squares
 */
public double l2() {
    final double squares = Arrays.stream(getData()).map(x -> x * x).sum();
    return Math.sqrt(squares);
}
/**
 * Converts a multi-dimensional coordinate into a flat index. Only the first
 * {@code min(rank, coords.length)} entries contribute.
 *
 * @param coords the coordinate
 * @return the flat index
 */
public int index(@Nonnull final int[] coords) {
    final int usable = Math.min(strides.length, coords.length);
    int offset = 0;
    for (int d = 0; d < usable; d++) {
        offset += strides[d] * coords[d];
    }
    return offset;
}
/**
 * Returns a new tensor of the same shape with {@code f} applied to every element. Elements are
 * evaluated in parallel, so {@code f} must be thread-safe.
 *
 * @param f the element transform
 * @return the transformed tensor
 */
@Nullable
public Tensor map(@Nonnull final DoubleUnaryOperator f) {
    @Nonnull final double[] source = getData();
    final Tensor result = new Tensor(dimensions);
    Arrays.parallelSetAll(result.getData(), i -> f.applyAsDouble(source[i]));
    return result;
}
/**
 * Returns a new tensor of the same shape with {@code f} applied (sequentially) to every element,
 * then releases this tensor's reference.
 *
 * @param f the element transform
 * @return the transformed tensor
 */
@Nullable
public Tensor mapAndFree(@Nonnull final DoubleUnaryOperator f) {
    @Nonnull final double[] source = getData();
    // Write straight into the result's storage. Building a temporary array and handing it to the
    // copying (double[], int...) constructor allocated and filled the output twice.
    final Tensor result = new Tensor((double[]) null, dimensions);
    @Nonnull final double[] target = result.getData();
    for (int i = 0; i < source.length; i++) {
        target[i] = f.applyAsDouble(source[i]);
    }
    this.freeRef();
    return result;
}
/**
 * Returns a new tensor of the same shape whose value at each coordinate is {@code f(coordinate)},
 * evaluated sequentially.
 *
 * @param f the per-coordinate function; must not retain its (reused) argument
 * @return the new tensor
 */
@Nullable
public Tensor mapCoords(@Nonnull final ToDoubleFunction<Coordinate> f) {
    return mapCoords(f, false);
}
/**
 * Like {@link #mapCoords(ToDoubleFunction)}, and additionally releases this tensor's reference.
 *
 * @param f the per-coordinate function
 * @return the new tensor
 */
@Nullable
public Tensor mapCoordsAndFree(@Nonnull final ToDoubleFunction<Coordinate> f) {
    return mapCoordsAndFree(f, false);
}
/**
 * Returns a new tensor of the same shape whose value at each coordinate is {@code f(coordinate)}.
 *
 * @param f        the per-coordinate function (typed; previously a raw {@code ToDoubleFunction})
 * @param parallel whether to evaluate {@code f} on a parallel coordinate stream
 * @return the new tensor
 */
@Nullable
public Tensor mapCoords(@Nonnull final ToDoubleFunction<Coordinate> f, boolean parallel) {
    return new Tensor(Tensor.getDoubles(coordStream(parallel).mapToDouble(f), length()), dimensions);
}
/**
 * Like {@link #mapCoords(ToDoubleFunction, boolean)}, and additionally releases this tensor's
 * reference once the result has been built.
 *
 * @param f        the per-coordinate function
 * @param parallel whether to evaluate {@code f} on a parallel coordinate stream
 * @return the new tensor
 */
@Nullable
public Tensor mapCoordsAndFree(@Nonnull final ToDoubleFunction<Coordinate> f, boolean parallel) {
    Tensor tensor = new Tensor(Tensor.getDoubles(coordStream(parallel).mapToDouble(f), length()), dimensions);
    freeRef();
    return tensor;
}
/**
 * Returns a new tensor of the same shape whose element {@code i} is {@code f.eval(get(i), i)}.
 *
 * @param f receives each value together with its flat index
 * @return the new tensor
 */
@Nullable
public Tensor mapIndex(@Nonnull final TupleOperator f) {
    final DoubleStream values = IntStream.range(0, length()).mapToDouble(i -> f.eval(get(i), i));
    return new Tensor(Tensor.getDoubles(values, length()), dimensions);
}
/**
 * Returns the arithmetic mean of the elements.
 *
 * @return {@code sum() / length()}
 */
public double mean() {
    final double total = sum();
    return total / length();
}
/**
 * Returns a new tensor of the same shape with {@code f} applied to every element in parallel.
 * {@code f} must therefore be thread-safe.
 *
 * @param f the element transform
 * @return the transformed tensor
 */
@Nullable
public Tensor mapParallel(@Nonnull final DoubleUnaryOperator f) {
    @Nonnull final double[] source = getData();
    // The previous body built a sequential IntStream, so despite its name this never ran in
    // parallel. Each index writes only its own slot, so parallel evaluation is race-free.
    final Tensor result = new Tensor((double[]) null, dimensions);
    Arrays.parallelSetAll(result.getData(), i -> f.applyAsDouble(source[i]));
    return result;
}
/**
 * Returns the element-wise difference {@code this - right} as a new tensor.
 *
 * @param right the subtrahend
 * @return the difference
 * @throws IllegalArgumentException if the shapes differ
 */
@Nonnull
public Tensor minus(@Nonnull final Tensor right) {
    final int[] leftShape = getDimensions();
    final int[] rightShape = right.getDimensions();
    if (!Arrays.equals(leftShape, rightShape)) {
        throw new IllegalArgumentException(Arrays.toString(leftShape) + " != " + Arrays.toString(rightShape));
    }
    @Nonnull final Tensor difference = new Tensor(leftShape);
    final double[] a = getData();
    final double[] b = right.getData();
    Arrays.parallelSetAll(difference.getData(), i -> a[i] - b[i]);
    return difference;
}
/**
 * Returns a human-readable rendering of this tensor ({@code toString(true)}).
 *
 * @return the formatted text
 */
public String prettyPrint() {
    final boolean pretty = true;
    return toString(pretty);
}
/**
 * Pretty-prints this tensor and then releases this reference.
 *
 * @return the formatted text
 */
public String prettyPrintAndFree() {
    final String text = prettyPrint();
    freeRef();
    return text;
}
/**
 * Returns a new tensor with every element multiplied by {@code d}, computed sequentially.
 *
 * @param d the factor
 * @return the scaled tensor
 */
@Nonnull
public Tensor multiply(final double d) {
    @Nonnull final Tensor scaled = new Tensor(getDimensions());
    final double[] target = scaled.getData();
    final double[] source = getData();
    for (int i = 0; i < source.length; i++) {
        target[i] = source[i] * d;
    }
    return scaled;
}
/**
 * Returns the root-mean-square of the elements.
 *
 * @return {@code sqrt(sumSq() / length())}
 */
public double rms() {
    final double meanSquare = sumSq() / length();
    return Math.sqrt(meanSquare);
}
/**
 * Combines this tensor with {@code right} element-wise, in parallel: {@code out[i] = f(this[i], right[i])}.
 * {@code f} must therefore be thread-safe.
 *
 * @param right the second operand; must have the same shape
 * @param f     the element combiner
 * @return the combined tensor
 * @throws IllegalArgumentException if the shapes differ
 */
@Nullable
public Tensor reduceParallel(@Nonnull final Tensor right, @Nonnull final DoubleBinaryOperator f) {
    if (!Arrays.equals(right.getDimensions(), getDimensions())) {
        throw new IllegalArgumentException(Arrays.toString(right.getDimensions()) + " != " + Arrays.toString(getDimensions()));
    }
    @Nonnull final double[] dataL = getData();
    @Nonnull final double[] dataR = right.getData();
    // The previous sequential IntStream never ran in parallel despite the method's name.
    final Tensor result = new Tensor((double[]) null, dimensions);
    Arrays.parallelSetAll(result.getData(), i -> f.applyAsDouble(dataL[i], dataR[i]));
    return result;
}
/**
 * Rounds every element to {@code precision} significant decimal digits.
 * A precision above 8 is treated as "no rounding".
 *
 * @param precision the number of significant digits, at least 1
 * @return a new tensor in every case, so the caller always owns the result
 * @throws IllegalArgumentException if {@code precision < 1}
 */
@Nullable
public Tensor round(final int precision) {
    if (precision > 8) {
        // Previously returned `this`. A caller that frees the result as it would a new tensor
        // then released this instance as well.
        return copy();
    }
    if (precision < 1) throw new IllegalArgumentException();
    return round(precision, 10);
}
/**
 * Rounds every element to {@code precision} significant digits in the given base.
 * Zero, NaN and infinite values are passed through unchanged; negative values are rounded by
 * magnitude.
 *
 * @param precision the number of significant digits
 * @param base      the numeric base (e.g. 10)
 * @return a new rounded tensor
 */
@Nullable
public Tensor round(final int precision, final int base) {
    return map(v -> {
        // log() is NaN for negatives and -Infinity for zero, so the old code turned both into NaN.
        if (0 == v || !Double.isFinite(v)) return v;
        final double units = Math.pow(base, Math.ceil(Math.log(Math.abs(v)) / Math.log(base)) - precision);
        return Math.round(v / units) * units;
    });
}
/**
 * Returns a new tensor with every element multiplied by {@code d} (computed in parallel via {@code map}).
 *
 * @param d the factor
 * @return the scaled tensor
 */
@Nullable
public Tensor scale(final double d) {
    return map(x -> d * x);
}
/**
 * Multiplies every element of this tensor by {@code d}, in place.
 *
 * @param d the factor
 * @return this tensor, for chaining
 */
@Nonnull
public Tensor scaleInPlace(final double d) {
    final double[] values = getData();
    int i = 0;
    while (i < values.length) {
        values[i] *= d;
        i++;
    }
    return this;
}
/**
 * Stores {@code value} at the flat index carried by {@code coords}. Non-finite values are
 * silently ignored by this overload.
 *
 * @param coords the target coordinate
 * @param value  the value to store
 */
public void set(@Nonnull final Coordinate coords, final double value) {
    if (!Double.isFinite(value)) {
        return;
    }
    set(coords.getIndex(), value);
}
/**
 * Overwrites all elements from the start of {@code data}.
 *
 * @param data the source values; must hold at least {@link #length()} entries
 * @return this tensor, for chaining
 */
@Nonnull
public Tensor set(final double[] data) {
    final double[] target = getData();
    for (int i = 0; i < target.length; i++) {
        target[i] = data[i];
    }
    return this;
}
/**
 * Fills every element, in index order, with successive values from {@code f}.
 *
 * @param f the value source
 * @return this tensor, for chaining
 */
@Nonnull
public Tensor set(@Nonnull final DoubleSupplier f) {
    final double[] target = getData();
    for (int i = 0; i < target.length; i++) {
        target[i] = f.getAsDouble();
    }
    return this;
}
/**
 * Stores {@code value} at a two-dimensional coordinate.
 *
 * @param coord1 the first coordinate
 * @param coord2 the second coordinate
 * @param value  the value to store; asserted finite
 */
public void set(final int coord1, final int coord2, final double value) {
    assert Double.isFinite(value);
    final int offset = index(coord1, coord2);
    set(offset, value);
}
/**
 * Stores {@code value} at a three-dimensional coordinate.
 *
 * @param coord1 the first coordinate
 * @param coord2 the second coordinate
 * @param coord3 the third coordinate
 * @param value  the value to store; asserted finite
 */
public void set(final int coord1, final int coord2, final int coord3, final double value) {
    assert Double.isFinite(value);
    final int offset = index(coord1, coord2, coord3);
    set(offset, value);
}
/**
 * Stores {@code value} at a four-dimensional coordinate.
 *
 * @param coord1 the first coordinate
 * @param coord2 the second coordinate
 * @param coord3 the third coordinate
 * @param coord4 the fourth coordinate
 * @param value  the value to store; asserted finite
 */
public void set(final int coord1, final int coord2, final int coord3, final int coord4, final double value) {
    assert Double.isFinite(value);
    final int offset = index(coord1, coord2, coord3, coord4);
    set(offset, value);
}
/**
 * Stores {@code value} at flat position {@code index}. No finiteness check is applied.
 *
 * @param index the flat element index
 * @param value the value to store
 * @return this tensor, for chaining
 */
@Nonnull
public Tensor set(final int index, final double value) {
    final double[] values = getData();
    values[index] = value;
    return this;
}
/**
 * Stores {@code value} at a multi-dimensional coordinate.
 *
 * @param coords the coordinate, one entry per dimension
 * @param value  the value to store; asserted finite
 */
public void set(@Nonnull final int[] coords, final double value) {
    assert Double.isFinite(value);
    final int offset = index(coords);
    set(offset, value);
}
/**
 * Sets each element {@code i} to {@code f(i)}, evaluated in parallel.
 *
 * @param f the index-to-value function; must be thread-safe
 * @return this tensor, for chaining
 */
@Nonnull
public Tensor set(@Nonnull final IntToDoubleFunction f) {
    final double[] target = getData();
    Arrays.parallelSetAll(target, f);
    return this;
}
/**
 * Copies all values from {@code right} into this tensor. Equal lengths are only checked by assertion.
 *
 * @param right the source tensor
 * @return this tensor, for chaining
 */
public Tensor set(@Nonnull final Tensor right) {
    assertAlive();
    assert length() == right.length();
    final double[] source = right.getData();
    Arrays.parallelSetAll(getData(), i -> source[i]);
    return this;
}
/**
 * Sets every element to {@code v}.
 *
 * @param v the fill value
 * @return this tensor, for chaining
 */
public Tensor setAll(final double v) {
    Arrays.fill(getData(), v);
    return this;
}
/**
 * Sets each element to {@code f(coordinate)}, evaluated over a parallel coordinate stream.
 * Non-finite results are skipped (see {@code set(Coordinate, double)}).
 *
 * @param f the per-coordinate function (typed; previously a raw {@code ToDoubleFunction}); must be thread-safe
 * @return this tensor, for chaining
 */
@Nonnull
public Tensor setByCoord(@Nonnull final ToDoubleFunction<Coordinate> f) {
    return setByCoord(f, true);
}
/**
 * Sets each element to {@code f(coordinate)}. Non-finite results are skipped
 * (see {@code set(Coordinate, double)}).
 *
 * @param f        the per-coordinate function (typed; previously a raw {@code ToDoubleFunction})
 * @param parallel whether to evaluate {@code f} on a parallel coordinate stream
 * @return this tensor, for chaining
 */
@Nonnull
public Tensor setByCoord(@Nonnull final ToDoubleFunction<Coordinate> f, boolean parallel) {
    coordStream(parallel).forEach(c -> set(c, f.applyAsDouble(c)));
    return this;
}
/**
 * Returns the plain (uncompensated, index-order) sum of all elements.
 *
 * @return the element sum
 */
public double sum() {
    final double[] values = getData();
    double total = 0;
    for (int i = 0; i < values.length; i++) {
        total += values[i];
    }
    return total;
}
/**
 * Returns the plain (index-order) sum of squared elements.
 *
 * @return the sum of squares
 */
public double sumSq() {
    final double[] values = getData();
    double total = 0;
    for (int i = 0; i < values.length; i++) {
        final double x = values[i];
        total += x * x;
    }
    return total;
}
/**
 * Sets each element {@code i} to {@code f(i)}, evaluated on a parallel index stream.
 *
 * @param f the index-to-value function; must be thread-safe
 */
public void setParallelByIndex(@Nonnull final IntToDoubleFunction f) {
    final int count = length();
    IntStream.range(0, count).parallel().forEach(i -> set(i, f.applyAsDouble(i)));
}
/**
 * Renders band 0 as an 8-bit grayscale image.
 *
 * @return the image
 */
@Nonnull
public BufferedImage toGrayImage() {
    final int firstBand = 0;
    return toGrayImage(firstBand);
}
/**
 * Renders one band as an 8-bit grayscale image, clamping values to [0, 255].
 *
 * <p>Both {@code [width, height, bands]} and plain {@code [width, height]} tensors are supported.
 * For the latter only band 0 exists.
 *
 * @param band the band to render
 * @return the image
 * @throws IllegalArgumentException if a non-zero band is requested from a two-dimensional tensor
 */
@Nonnull
public BufferedImage toGrayImage(final int band) {
    final int[] dims = getDimensions();
    final int width = dims[0];
    final int height = dims[1];
    // toImage() and toImages() route 2-D tensors here. index(x, y, band) reads strides[2] and
    // therefore threw for them, so read those tensors by (x, y) instead.
    final boolean hasBands = dims.length > 2;
    if (!hasBands && 0 != band) {
        throw new IllegalArgumentException("band " + band + " requested from a 2-D tensor");
    }
    @Nonnull final BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
    final java.awt.image.WritableRaster raster = image.getRaster();
    for (int x = 0; x < width; x++) {
        for (int y = 0; y < height; y++) {
            final double v = hasBands ? get(x, y, band) : get(x, y);
            raster.setSample(x, y, 0, v < 0 ? 0 : v > 255 ? 255 : v);
        }
    }
    return image;
}
/**
 * Renders this tensor as a single image. A {@code [w, h, 3]} tensor becomes an RGB image;
 * {@code [w, h, 1]} and {@code [w, h]} tensors become grayscale images.
 *
 * @return the image
 */
@Nonnull
public BufferedImage toImage() {
    @Nonnull final int[] dims = getDimensions();
    if (3 != dims.length) {
        assert 2 == dims.length;
        return toGrayImage();
    }
    if (3 == dims[2]) {
        return toRgbImage();
    }
    assert 1 == dims[2];
    return toGrayImage();
}
/**
 * Renders this tensor as one or more images:
 * <ul>
 *   <li>3 bands: one RGB image;</li>
 *   <li>1 band, or a two-dimensional tensor: one grayscale image;</li>
 *   <li>any other multiple of 3: one RGB image per consecutive triple of bands;</li>
 *   <li>otherwise: one grayscale image per band.</li>
 * </ul>
 *
 * @return the images (typed; previously a raw {@code List})
 */
@Nonnull
public List<BufferedImage> toImages() {
    @Nonnull final int[] dims = getDimensions();
    if (3 != dims.length) {
        assert 2 == dims.length : "order: " + dims.length;
        return Arrays.asList(toGrayImage());
    }
    final int bands = dims[2];
    if (3 == bands) return Arrays.asList(toRgbImage());
    if (1 == bands) return Arrays.asList(toGrayImage());
    @Nonnull final List<BufferedImage> list = new ArrayList<>();
    if (0 == bands % 3) {
        for (int i = 0; i < bands; i += 3) {
            list.add(toRgbImage(i, i + 1, i + 2));
        }
    } else {
        for (int i = 0; i < bands; i++) {
            list.add(toGrayImage(i));
        }
    }
    return list;
}
/**
 * Serializes this tensor to JSON.
 *
 * <p>Tensors with at most 1024 elements are written inline as nested JSON arrays. Larger
 * tensors are written as an object with {@code length} (the dimensions), an optional
 * {@code id}, {@code precision}, and the encoded data. If {@code resources} is non-null, the
 * bytes are stored there under a fresh random key referenced by {@code resource}; otherwise they
 * are embedded as {@code base64}.
 *
 * @param resources      optional sink for binary payloads; may be null
 * @param dataSerializer encoder for the element data (used only for the large-tensor form)
 * @return the JSON representation
 */
@Nonnull
public JsonElement toJson(@Nullable Map resources, @Nonnull DataSerializer dataSerializer) {
  if (length() <= 1024) {
    return toJson(new int[]{});
  }
  @Nonnull final JsonObject json = new JsonObject();
  json.add("length", toJsonArray(getDimensions()));
  if (null != id) {
    json.addProperty("id", id.toString());
  }
  @Nonnull final byte[] payload = getBytes(dataSerializer);
  // NOTE(review): this cast throws ClassCastException for any DataSerializer that is not a
  // SerialPrecision.
  json.addProperty("precision", ((SerialPrecision) dataSerializer).name());
  if (null == resources) {
    json.addProperty("base64", Base64.getEncoder().encodeToString(payload));
  } else {
    @Nonnull final String resourceKey = UUID.randomUUID().toString();
    json.addProperty("resource", resourceKey);
    resources.put(resourceKey, payload);
  }
  return json;
}
/**
 * Overwrites this tensor's data from {@code bytes}, decoded with the current value of
 * {@code json_precision}.
 *
 * @param bytes the encoded data
 * @return this tensor
 */
@Nonnull
public Tensor setBytes(byte[] bytes) {
  final DataSerializer defaultSerializer = json_precision;
  return setBytes(bytes, defaultSerializer);
}
/**
 * Encodes this tensor's data array to bytes.
 *
 * @param precision the serializer used to encode the data
 * @return the encoded bytes
 */
@Nonnull
public byte[] getBytes(@Nonnull DataSerializer precision) {
  final double[] values = getData();
  return precision.toBytes(values);
}
/**
 * Decodes {@code bytes} into this tensor's existing data array.
 *
 * @param bytes     the encoded data
 * @param precision the serializer used to decode the data
 * @return this tensor
 */
@Nonnull
public Tensor setBytes(byte[] bytes, @Nonnull DataSerializer precision) {
  final double[] target = getData();
  precision.copy(bytes, target);
  return this;
}
/**
 * Recursively builds nested JSON arrays for the element(s) under the given coordinate suffix.
 *
 * <p>The outermost array iterates the last dimension. Each recursion level prepends its index,
 * so the innermost array iterates dimension 0.
 *
 * @param coords the coordinates fixed so far (for the trailing dimensions)
 * @return a number when all coordinates are fixed, otherwise a nested array
 */
@Nonnull
private JsonElement toJson(@Nonnull final int[] coords) {
  final int depth = coords.length;
  if (depth == dimensions.length) {
    final double value = get(coords);
    return new JsonPrimitive(value);
  }
  @Nonnull final JsonArray array = new JsonArray();
  final int extent = dimensions[dimensions.length - (depth + 1)];
  for (int i = 0; i < extent; i++) {
    @Nonnull final int[] child = new int[depth + 1];
    child[0] = i;
    System.arraycopy(coords, 0, child, 1, depth);
    array.add(toJson(child));
  }
  return array;
}
/**
 * Renders bands 0, 1 and 2 as an RGB image.
 *
 * @return the RGB image
 */
@Nonnull
public BufferedImage toRgbImage() {
  return this.toRgbImage(0, 1, 2);
}
/**
 * Renders three bands as a {@link BufferedImage#TYPE_INT_RGB} image.
 *
 * <p>Each sample is passed through {@code bound8bit}. The pixel is packed as
 * {@code redBand + (greenBand << 8) + (blueBand << 16)}, so {@code redBand} occupies the lowest
 * byte. NOTE(review): {@code setRGB} treats the lowest byte as blue, so this presumably mirrors
 * the tensor's own channel order; confirm before changing.
 *
 * <p>If dimension 2 has size 1, band 0 is replicated into all three channels and the band
 * arguments are ignored.
 *
 * @param redBand   band packed into bits 0-7
 * @param greenBand band packed into bits 8-15
 * @param blueBand  band packed into bits 16-23
 * @return the RGB image
 */
@Nonnull
public BufferedImage toRgbImage(final int redBand, final int greenBand, final int blueBand) {
  assertAlive();
  @Nonnull final int[] shape = getDimensions();
  @Nonnull final BufferedImage rgb = new BufferedImage(shape[0], shape[1], BufferedImage.TYPE_INT_RGB);
  final int w = rgb.getWidth();
  final int h = rgb.getHeight();
  for (int px = 0; px < w; px++) {
    for (int py = 0; py < h; py++) {
      if (shape[2] != 1) {
        final double r = Tensor.bound8bit(this.get(px, py, redBand));
        final double g = Tensor.bound8bit(this.get(px, py, greenBand));
        final double b = Tensor.bound8bit(this.get(px, py, blueBand));
        rgb.setRGB(px, py, (int) (r + ((int) g << 8) + ((int) b << 16)));
      } else {
        // Single band: replicate the clamped value into all three channels.
        rgb.setRGB(px, py, Tensor.bound8bit((int) this.get(px, py, 0)) * 0x010101);
      }
    }
  }
  return rgb;
}
/**
 * Renders three bands plus an alpha mask as a {@link BufferedImage#TYPE_INT_ARGB} image.
 *
 * <p>Every sample is passed through {@code bound8bit}. The pixel is packed as
 * {@code redBand + (greenBand << 8) + (blueBand << 16) + (alpha << 24)}, so {@code redBand}
 * occupies the lowest byte. The alpha value is band 0 of {@code alphaMask}.
 *
 * @param redBand   band packed into bits 0-7
 * @param greenBand band packed into bits 8-15
 * @param blueBand  band packed into bits 16-23
 * @param alphaMask tensor whose band 0 supplies alpha; expected to match this tensor's width
 *                  and height
 * @return the ARGB image
 */
@Nonnull
public BufferedImage toRgbImageAlphaMask(final int redBand, final int greenBand, final int blueBand, Tensor alphaMask) {
  // Check liveness of both inputs up front, as toRgbImage does for this tensor.
  assertAlive();
  alphaMask.assertAlive();
  @Nonnull final int[] dims = getDimensions();
  @Nonnull final int[] maskDims = alphaMask.getDimensions();
  assert maskDims[0] == dims[0] : "alpha mask width " + maskDims[0] + " != " + dims[0];
  assert maskDims[1] == dims[1] : "alpha mask height " + maskDims[1] + " != " + dims[1];
  @Nonnull final BufferedImage img = new BufferedImage(dims[0], dims[1], BufferedImage.TYPE_INT_ARGB);
  final int width = img.getWidth();
  final int height = img.getHeight();
  for (int x = 0; x < width; x++) {
    for (int y = 0; y < height; y++) {
      final double red = Tensor.bound8bit(this.get(x, y, redBand));
      final double green = Tensor.bound8bit(this.get(x, y, greenBand));
      final double blue = Tensor.bound8bit(this.get(x, y, blueBand));
      final double alpha = Tensor.bound8bit(alphaMask.get(x, y, 0));
      // alpha << 24 can be negative; the summed value still has the intended bit pattern.
      img.setRGB(x, y, (int) (red + ((int) green << 8) + ((int) blue << 16) + ((int) alpha << 24)));
    }
  }
  return img;
}
/**
 * Returns a debug string.
 *
 * <p>The string is the identity hash of the data array in hex (or {@code "0"} when no data is
 * allocated), then {@code "@"}, then a compact, truncated listing of the elements.
 */
@Nonnull
@Override
public String toString() {
  final String dataTag = null == data ? "0" : Integer.toHexString(System.identityHashCode(data));
  return dataTag + "@" + toString(false);
}
/**
 * Recursively formats the sub-tensor under the given coordinate prefix.
 *
 * <p>The outermost level iterates dimension 0. At most 15 children are rendered per level. If
 * more than 10 are produced, only the first 8 are kept, followed by {@code "..."}.
 *
 * @param prettyPrint if true, the outer levels are rendered one child per line with tab
 *                    indentation
 * @param coords      the coordinates fixed so far (for the leading dimensions)
 * @return the formatted text
 */
private String toString(final boolean prettyPrint, @Nonnull final int... coords) {
  if (coords.length == dimensions.length) {
    return Double.toString(get(coords));
  }
  List<String> list = IntStream.range(0, dimensions[coords.length]).mapToObj(i -> {
    @Nonnull final int[] newCoord = Arrays.copyOf(coords, coords.length + 1);
    newCoord[coords.length] = i;
    return toString(prettyPrint, newCoord);
  }).limit(15).collect(Collectors.toList());
  if (list.size() > 10) {
    // Copy instead of mutating a subList view of the collected list.
    list = new ArrayList<>(list.subList(0, 8));
    list.add("...");
  }
  // list now holds at most 10 entries.
  if (prettyPrint && coords.length < dimensions.length - 2) {
    final String body = list.stream()
        .map(s -> "\t" + s.replace("\n", "\n\t"))
        .collect(Collectors.joining(",\n"));
    return "[\n" + body + "\n]";
  }
  final String separator = prettyPrint ? ", " : ",";
  return "[ " + String.join(separator, list) + " ]";
}
/**
 * Returns a tensor with this tensor's dimensions reversed.
 *
 * <p>This delegates to the {@code reverseDimensions(Tensor)} overload, passing this tensor.
 *
 * @return the result of {@code reverseDimensions(this)}
 */
@Nonnull
public Tensor reverseDimensions() {
  final Tensor source = this;
  return reverseDimensions(source);
}
/**
 * Returns a new tensor whose axes are rearranged according to {@code key}.
 *
 * <p>To compute the output shape, each key entry is replaced by its absolute value, and
 * {@code Integer.MAX_VALUE} is replaced by 0. The unmodified key is then used to map
 * coordinates. The exact meaning of signed entries is defined by {@code permute}, which is not
 * shown here.
 *
 * @param key the axis permutation
 * @return the permuted tensor
 */
@Nonnull
public Tensor permuteDimensions(int... key) {
  assertAlive();
  final int[] sourceDims = getDimensions();
  final int[] magnitudes = new int[key.length];
  for (int i = 0; i < key.length; i++) {
    magnitudes[i] = key[i] == Integer.MAX_VALUE ? 0 : Math.abs(key[i]);
  }
  final int[] targetDims = permute(magnitudes, sourceDims, sourceDims);
  return rearrange(coords -> permute(key, coords, sourceDims), targetDims);
}
/**
 * Calls {@link #permuteDimensions(int...)}, then releases one reference to this tensor.
 *
 * <p>If the permutation throws, no reference is released.
 *
 * @param key the axis permutation
 * @return the permuted tensor
 */
public Tensor permuteDimensionsAndFree(int... key) {
  final Tensor permuted = this.permuteDimensions(key);
  freeRef();
  return permuted;
}
/**
 * Returns a new tensor with the given dimensions and a copy of this tensor's data.
 *
 * <p>The element count must stay the same. This method never returns null.
 *
 * @param dims the new dimensions; must be non-empty, and their product must equal
 *             {@code length()}
 * @return a new tensor over a copy of the data
 * @throws IllegalArgumentException if {@code dims} is empty or the element counts differ
 */
@Nonnull
public Tensor reshapeCast(@Nonnull int... dims) {
  if (0 == dims.length) {
    throw new IllegalArgumentException("dims must not be empty");
  }
  if (length(dims) != length()) {
    throw new IllegalArgumentException(Arrays.toString(dims) + " != " + length());
  }
  // Named 'source' so it does not shadow the volatile 'data' field.
  final double[] source = getData();
  return new Tensor(dims, null == source ? null : RecycleBin.DOUBLES.copyOf(source, source.length));
}
/**
 * Calls {@link #reshapeCast(int...)}, then releases one reference to this tensor.
 *
 * @param dims the new dimensions
 * @return the reshaped copy
 */
@Nullable
public Tensor reshapeCastAndFree(@Nonnull int... dims) {
  final Tensor reshaped = this.reshapeCast(dims);
  this.freeRef();
  return reshaped;
}
/**
 * Invokes {@code fn} with each element's value and its coordinate.
 *
 * @param fn       the callback
 * @param parallel whether the coordinate stream runs in parallel
 */
public void forEach(@Nonnull CoordOperator fn, boolean parallel) {
  coordStream(parallel).forEach(coord -> fn.eval(get(coord), coord));
}
/**
 * Computes the dot product of this tensor's data with {@code right}'s data.
 *
 * @param right the other tensor; its data array must have the same length as this one's
 * @return the sum of element-wise products
 * @throws IllegalArgumentException if the data lengths differ. Previously, a longer
 *                                  {@code right} was silently truncated and a shorter one
 *                                  caused an ArrayIndexOutOfBoundsException.
 */
public double dot(final Tensor right) {
  final double[] l = getData();
  final double[] r = right.getData();
  if (l.length != r.length) {
    throw new IllegalArgumentException("Length mismatch: " + l.length + " != " + r.length);
  }
  double v = 0;
  for (int i = 0; i < l.length; i++) {
    v += l[i] * r[i];
  }
  return v;
}
/**
 * Scales this tensor by the reciprocal of its L2 norm, {@code sqrt(sumSq())}.
 *
 * <p>For an all-zero tensor the factor passed to {@code scale} is {@code 1.0 / 0.0}, which is
 * +Infinity.
 *
 * @return the result of {@code scale(1 / norm)}
 */
public Tensor unit() {
  final double norm = Math.sqrt(sumSq());
  return scale(1.0 / norm);
}
/**
 * Extracts one band of a 3-D tensor into a new {@code width x height x 1} tensor.
 *
 * @param band index into dimension 2; must be in {@code [0, dims[2])}
 * @return a new single-band tensor
 */
public Tensor selectBand(final int band) {
  assert band >= 0;
  final int[] shape = getDimensions();
  assert 3 == shape.length;
  assert band < shape[2];
  final Tensor single = new Tensor(shape[0], shape[1], 1);
  return single.setByCoord(coord -> {
    final int[] xy = coord.getCoords();
    return get(xy[0], xy[1], band);
  });
}
/**
 * Renders this tensor via {@link #toImage()}, then releases one reference to it.
 *
 * @return the rendered image
 */
public BufferedImage toImageAndFree() {
  final BufferedImage rendered = this.toImage();
  this.freeRef();
  return rendered;
}
/**
 * Returns a tensor the caller exclusively owns, releasing this reference if needed.
 *
 * <p>If the current reference count is exactly 1, this tensor is returned as-is. Otherwise a
 * {@link #copy()} is made and one reference to this tensor is released.
 *
 * @return this tensor or a fresh copy
 */
public Tensor copyAndFree() {
  if (currentRefCount() != 1) {
    final Tensor duplicate = copy();
    freeRef();
    return duplicate;
  }
  return this;
}
// /**
// * Resize as img tensor.
// *
// * @param width the width
// * @param height the height
// * @return the tensor
// */
// public Tensor resizeAsImg(final int width, final int height) {
// if (getDimensions()[0] == width && getDimensions()[1] == height) {
// addRef();
// return this;
// }
// return Tensor.fromRGB(TestUtil.resize(toImage(), width, height));
// }
/**
 * Returns this tensor's id, assigning a random UUID on first access if none is set.
 *
 * <p>The null check is repeated inside {@code synchronized (this)} so that concurrent first
 * callers assign only one UUID. Null is returned only if {@code setId(null)} runs concurrently
 * with the final read. NOTE(review): the {@code id} field's declaration is not visible here;
 * confirm it is volatile, or that the unsynchronized first read is otherwise safe.
 *
 * @return the tensor id
 */
@Nullable
public UUID getId() {
if(id == null) {
synchronized (this) {
// Re-check under the lock: another thread may have assigned id meanwhile.
if(id == null) {
id = UUID.randomUUID();
}
}
}
return id;
}
/**
 * Sets, or clears with {@code null}, this tensor's id.
 *
 * <p>A cleared id is regenerated lazily by {@link #getId()}.
 *
 * @param id the new id, or null
 * @return this tensor
 */
public Tensor setId(@Nullable UUID id) {
  this.id = id;
  return this;
}
/**
 * Callback receiving one tensor element's value together with its coordinate, as invoked by
 * {@link Tensor#forEach(CoordOperator, boolean)}.
 */
@FunctionalInterface
public interface CoordOperator {
  /**
   * Consumes one element.
   *
   * @param value the element's value
   * @param index the element's coordinate
   */
  void eval(double value, Coordinate index);
}
/**
 * Maps an element value and its flat index to a new value.
 */
@FunctionalInterface
public interface TupleOperator {
  /**
   * Computes a value from an element and its flat index.
   *
   * @param value the element's value
   * @param index the element's flat index
   * @return the computed value
   */
  double eval(double value, int index);
}
}