// org.bytedeco.javacpp.helper.tensorflow Maven / Gradle / Ivy
// The newest version!
/*
* Copyright (C) 2015-2018 Samuel Audet
*
* Licensed either under the Apache License, Version 2.0, or (at your option)
* under the terms of the GNU General Public License as published by
* the Free Software Foundation (subject to the "Classpath" exception),
* either version 2, or any later version (collectively, the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.gnu.org/licenses/
* http://www.gnu.org/software/classpath/license.html
*
* or as provided in the LICENSE.txt file that accompanied this code.
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.bytedeco.javacpp.helper;
import java.nio.ByteBuffer;
import java.nio.Buffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerScope;
import org.bytedeco.javacpp.ShortPointer;
import org.bytedeco.javacpp.annotation.ByRef;
import org.bytedeco.javacpp.annotation.Cast;
import org.bytedeco.javacpp.annotation.Index;
import org.bytedeco.javacpp.annotation.Name;
import org.bytedeco.javacpp.annotation.Namespace;
import org.bytedeco.javacpp.indexer.ByteIndexer;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.Indexable;
import org.bytedeco.javacpp.indexer.Indexer;
import org.bytedeco.javacpp.indexer.IntIndexer;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.bytedeco.javacpp.indexer.ShortIndexer;
import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacpp.indexer.UShortIndexer;
// required by javac to resolve circular dependencies
import org.bytedeco.javacpp.tensorflow.*;
import static org.bytedeco.javacpp.tensorflow.DT_BFLOAT16;
import static org.bytedeco.javacpp.tensorflow.DT_BOOL;
import static org.bytedeco.javacpp.tensorflow.DT_COMPLEX64;
import static org.bytedeco.javacpp.tensorflow.DT_DOUBLE;
import static org.bytedeco.javacpp.tensorflow.DT_FLOAT;
import static org.bytedeco.javacpp.tensorflow.DT_INT16;
import static org.bytedeco.javacpp.tensorflow.DT_INT32;
import static org.bytedeco.javacpp.tensorflow.DT_INT64;
import static org.bytedeco.javacpp.tensorflow.DT_INT8;
import static org.bytedeco.javacpp.tensorflow.DT_QINT32;
import static org.bytedeco.javacpp.tensorflow.DT_QINT8;
import static org.bytedeco.javacpp.tensorflow.DT_QUINT8;
import static org.bytedeco.javacpp.tensorflow.DT_STRING;
import static org.bytedeco.javacpp.tensorflow.DT_UINT8;
import static org.bytedeco.javacpp.tensorflow.NewSession;
import static org.bytedeco.javacpp.tensorflow.TF_AllocateTensor;
import static org.bytedeco.javacpp.tensorflow.TF_Buffer;
import static org.bytedeco.javacpp.tensorflow.TF_Graph;
import static org.bytedeco.javacpp.tensorflow.TF_ImportGraphDefOptions;
import static org.bytedeco.javacpp.tensorflow.TF_Session;
import static org.bytedeco.javacpp.tensorflow.TF_SessionOptions;
import static org.bytedeco.javacpp.tensorflow.TF_Status;
import static org.bytedeco.javacpp.tensorflow.TF_Tensor;
import static org.bytedeco.javacpp.tensorflow.TF_DeleteBuffer;
import static org.bytedeco.javacpp.tensorflow.TF_DeleteGraph;
import static org.bytedeco.javacpp.tensorflow.TF_DeleteImportGraphDefOptions;
import static org.bytedeco.javacpp.tensorflow.TF_DeleteSession;
import static org.bytedeco.javacpp.tensorflow.TF_DeleteSessionOptions;
import static org.bytedeco.javacpp.tensorflow.TF_DeleteStatus;
import static org.bytedeco.javacpp.tensorflow.TF_DeleteTensor;
import static org.bytedeco.javacpp.tensorflow.TF_NewBuffer;
import static org.bytedeco.javacpp.tensorflow.TF_NewGraph;
import static org.bytedeco.javacpp.tensorflow.TF_NewImportGraphDefOptions;
import static org.bytedeco.javacpp.tensorflow.TF_NewBufferFromString;
import static org.bytedeco.javacpp.tensorflow.TF_NewSession;
import static org.bytedeco.javacpp.tensorflow.TF_NewSessionOptions;
import static org.bytedeco.javacpp.tensorflow.TF_NewStatus;
import static org.bytedeco.javacpp.tensorflow.TF_NewTensor;
/**
*
* @author Samuel Audet
*/
public class tensorflow extends org.bytedeco.javacpp.presets.tensorflow {
/**
 * Helper base class for {@code TF_Status} whose factory method pairs TF_NewStatus()
 * with a registered deallocator that calls TF_DeleteStatus().
 */
public static abstract class AbstractTF_Status extends Pointer {
    /** Deallocator that hands the wrapped status to TF_DeleteStatus() unless it is null, then nulls itself. */
    protected static class DeleteDeallocator extends TF_Status implements Pointer.Deallocator {
        DeleteDeallocator(TF_Status s) {
            super(s);
        }

        @Override public void deallocate() {
            if (!isNull()) {
                TF_DeleteStatus(this);
            }
            setNull();
        }
    }

    public AbstractTF_Status(Pointer p) {
        super(p);
    }

    /**
     * Creates a status with TF_NewStatus() and attaches a deallocator to it.
     * @return the TF_Status created, or null when TF_NewStatus() returned null.
     *         Do not call TF_DeleteStatus() on it.
     */
    public static TF_Status newStatus() {
        TF_Status status = TF_NewStatus();
        if (status == null) {
            return null;
        }
        status.deallocator(new DeleteDeallocator(status));
        return status;
    }

    /** Runs the deallocator when one is registered; has no effect otherwise. */
    public void delete() {
        deallocate();
    }
}
/**
 * Helper base class for {@code TF_Buffer} whose factory methods pair TF_NewBuffer() and
 * TF_NewBufferFromString() with a registered deallocator that calls TF_DeleteBuffer().
 */
public static abstract class AbstractTF_Buffer extends Pointer {
    /** Deallocator that hands the wrapped buffer to TF_DeleteBuffer() unless it is null, then nulls itself. */
    protected static class DeleteDeallocator extends TF_Buffer implements Pointer.Deallocator {
        DeleteDeallocator(TF_Buffer s) { super(s); }

        @Override public void deallocate() {
            if (!isNull()) {
                TF_DeleteBuffer(this);
            }
            setNull();
        }
    }

    public AbstractTF_Buffer(Pointer p) { super(p); }

    /**
     * Calls TF_NewBuffer(), and registers a deallocator.
     * @return TF_Buffer created. Do not call TF_DeleteBuffer() on it.
     */
    public static TF_Buffer newBuffer() {
        TF_Buffer b = TF_NewBuffer();
        if (b != null) {
            b.deallocator(new DeleteDeallocator(b));
        }
        return b;
    }

    /** Returns {@code newBufferFromString(new BytePointer(proto))}. */
    public static TF_Buffer newBufferFromString(byte[] proto) {
        return newBufferFromString(new BytePointer(proto));
    }

    /**
     * Calls TF_NewBufferFromString(), and registers a deallocator.
     * The length passed to the native function is the number of bytes between the
     * position and the limit of {@code proto}.
     * @param proto memory holding the serialized data, read from its current position up to its limit
     * @return TF_Buffer created. Do not call TF_DeleteBuffer() on it.
     */
    public static TF_Buffer newBufferFromString(Pointer proto) {
        // Pointer.limit() and position() are expressed in elements, while the native
        // function wants a byte count starting at the address of the current position.
        int elementSize = proto.sizeof();
        if (elementSize <= 0) {
            // sizeof() is unknown for this pointer type: keep treating the limit as bytes
            elementSize = 1;
        }
        long byteLength = Math.max(0L, proto.limit() - proto.position()) * elementSize;
        TF_Buffer b = TF_NewBufferFromString(proto, byteLength);
        if (b != null) {
            b.deallocator(new DeleteDeallocator(b));
        }
        return b;
    }

    /**
     * Calls the deallocator, if registered, otherwise has no effect.
     */
    public void delete() {
        deallocate();
    }
}
/**
 * Helper base class for {@code TF_Tensor} whose factory methods pair TF_NewTensor() and
 * TF_AllocateTensor() with a registered deallocator that calls TF_DeleteTensor().
 */
public static abstract class AbstractTF_Tensor extends Pointer {
    /** Deallocator that hands the wrapped tensor to TF_DeleteTensor() unless it is null, then nulls itself. */
    protected static class DeleteDeallocator extends TF_Tensor implements Pointer.Deallocator {
        DeleteDeallocator(TF_Tensor s) { super(s); }

        @Override public void deallocate() {
            if (!isNull()) {
                TF_DeleteTensor(this);
            }
            setNull();
        }
    }

    /** TensorFlow crashes if we don't pass it a deallocator, so... */
    protected static Deallocator_Pointer_long_Pointer dummyDeallocator = new Deallocator_Pointer_long_Pointer() {
        @Override public void call(Pointer data, long len, Pointer arg) { }
    };

    static {
        // Detach the shared callback from the innermost PointerScope, if there is one
        // (presumably so that closing that scope does not release it).
        PointerScope s = PointerScope.getInnerScope();
        if (s != null) {
            s.detach(dummyDeallocator);
        }
    }

    /** A reference to prevent deallocation. */
    protected Pointer pointer;

    public AbstractTF_Tensor(Pointer p) { super(p); }

    /**
     * Calls TF_NewTensor(), and registers a deallocator.
     * The length passed to the native function is the number of bytes between the
     * position and the limit of {@code data}.
     * @param dtype data type of the tensor, passed through to TF_NewTensor()
     * @param dims  dimensions of the tensor; its length is passed as the number of dimensions
     * @param data  memory backing the tensor; the returned tensor keeps a reference to it
     * @return TF_Tensor created. Do not call TF_DeleteTensor() on it.
     */
    public static TF_Tensor newTensor(int dtype, long[] dims, Pointer data) {
        // Pointer.limit() and position() are expressed in elements, while TF_NewTensor()
        // wants a byte count starting at the address of the current position.
        int elementSize = data.sizeof();
        if (elementSize <= 0) {
            // sizeof() is unknown for this pointer type: keep treating the limit as bytes
            elementSize = 1;
        }
        long byteLength = Math.max(0L, data.limit() - data.position()) * elementSize;
        TF_Tensor t = TF_NewTensor(dtype, dims, dims.length, data, byteLength, dummyDeallocator, null);
        if (t != null) {
            t.pointer = data;
            t.deallocator(new DeleteDeallocator(t));
        }
        return t;
    }

    /**
     * Calls TF_AllocateTensor(), and registers a deallocator.
     * @param dtype  data type of the tensor, passed through to TF_AllocateTensor()
     * @param dims   dimensions of the tensor; its length is passed as the number of dimensions
     * @param length passed through to TF_AllocateTensor() as its last argument
     * @return TF_Tensor created. Do not call TF_DeleteTensor() on it.
     */
    public static TF_Tensor allocateTensor(int dtype, long[] dims, long length) {
        TF_Tensor t = TF_AllocateTensor(dtype, dims, dims.length, length);
        if (t != null) {
            t.deallocator(new DeleteDeallocator(t));
        }
        return t;
    }

    /**
     * Calls the deallocator, if registered, otherwise has no effect.
     */
    public void delete() {
        deallocate();
    }
}
/**
 * Helper base class for {@code TF_SessionOptions} whose factory method pairs
 * TF_NewSessionOptions() with a registered deallocator that calls TF_DeleteSessionOptions().
 */
public static abstract class AbstractTF_SessionOptions extends Pointer {
    /** Deallocator that hands the wrapped options to TF_DeleteSessionOptions() unless they are null, then nulls itself. */
    protected static class DeleteDeallocator extends TF_SessionOptions implements Pointer.Deallocator {
        DeleteDeallocator(TF_SessionOptions s) {
            super(s);
        }

        @Override public void deallocate() {
            if (!isNull()) {
                TF_DeleteSessionOptions(this);
            }
            setNull();
        }
    }

    public AbstractTF_SessionOptions(Pointer p) {
        super(p);
    }

    /**
     * Creates session options with TF_NewSessionOptions() and attaches a deallocator to them.
     * @return the TF_SessionOptions created, or null when TF_NewSessionOptions() returned null.
     *         Do not call TF_DeleteSessionOptions() on it.
     */
    public static TF_SessionOptions newSessionOptions() {
        TF_SessionOptions options = TF_NewSessionOptions();
        if (options == null) {
            return null;
        }
        options.deallocator(new DeleteDeallocator(options));
        return options;
    }

    /** Runs the deallocator when one is registered; has no effect otherwise. */
    public void delete() {
        deallocate();
    }
}
/**
 * Helper base class for {@code TF_Graph} whose factory method pairs TF_NewGraph()
 * with a registered deallocator that calls TF_DeleteGraph().
 */
public static abstract class AbstractTF_Graph extends Pointer {
    /** Deallocator that hands the wrapped graph to TF_DeleteGraph() unless it is null, then nulls itself. */
    protected static class DeleteDeallocator extends TF_Graph implements Pointer.Deallocator {
        DeleteDeallocator(TF_Graph s) {
            super(s);
        }

        @Override public void deallocate() {
            if (!isNull()) {
                TF_DeleteGraph(this);
            }
            setNull();
        }
    }

    public AbstractTF_Graph(Pointer p) {
        super(p);
    }

    /**
     * Creates a graph with TF_NewGraph() and attaches a deallocator to it.
     * @return the TF_Graph created, or null when TF_NewGraph() returned null.
     *         Do not call TF_DeleteGraph() on it.
     */
    public static TF_Graph newGraph() {
        TF_Graph graph = TF_NewGraph();
        if (graph == null) {
            return null;
        }
        graph.deallocator(new DeleteDeallocator(graph));
        return graph;
    }

    /** Runs the deallocator when one is registered; has no effect otherwise. */
    public void delete() {
        deallocate();
    }
}
/**
 * Helper base class for {@code TF_ImportGraphDefOptions} whose factory method pairs
 * TF_NewImportGraphDefOptions() with a registered deallocator that calls
 * TF_DeleteImportGraphDefOptions().
 */
public static abstract class AbstractTF_ImportGraphDefOptions extends Pointer {
    /** Deallocator that hands the wrapped options to TF_DeleteImportGraphDefOptions() unless they are null, then nulls itself. */
    protected static class DeleteDeallocator extends TF_ImportGraphDefOptions implements Pointer.Deallocator {
        DeleteDeallocator(TF_ImportGraphDefOptions s) {
            super(s);
        }

        @Override public void deallocate() {
            if (!isNull()) {
                TF_DeleteImportGraphDefOptions(this);
            }
            setNull();
        }
    }

    public AbstractTF_ImportGraphDefOptions(Pointer p) {
        super(p);
    }

    /**
     * Creates import options with TF_NewImportGraphDefOptions() and attaches a deallocator to them.
     * @return the TF_ImportGraphDefOptions created, or null when TF_NewImportGraphDefOptions()
     *         returned null. Do not call TF_DeleteImportGraphDefOptions() on it.
     */
    public static TF_ImportGraphDefOptions newImportGraphDefOptions() {
        TF_ImportGraphDefOptions options = TF_NewImportGraphDefOptions();
        if (options == null) {
            return null;
        }
        options.deallocator(new DeleteDeallocator(options));
        return options;
    }

    /** Runs the deallocator when one is registered; has no effect otherwise. */
    public void delete() {
        deallocate();
    }
}
/**
 * Helper base class for {@code TF_Session} whose factory method pairs TF_NewSession()
 * with a registered deallocator that calls TF_DeleteSession().
 */
public static abstract class AbstractTF_Session extends Pointer {
    /** Deallocator that hands the wrapped session to TF_DeleteSession() unless it is null, then nulls itself. */
    protected static class DeleteDeallocator extends TF_Session implements Pointer.Deallocator {
        DeleteDeallocator(TF_Session s) {
            super(s);
        }

        @Override public void deallocate() {
            if (!isNull()) {
                // TF_DeleteSession() takes a status argument, so a fresh one is created for this call
                TF_DeleteSession(this, TF_Status.newStatus());
            }
            setNull();
        }
    }

    /** Reference to the graph given to {@link #newSession}, kept to prevent its deallocation. */
    protected TF_Graph graph;
    /** Reference to the options given to {@link #newSession}, kept to prevent their deallocation. */
    protected TF_SessionOptions opts;
    /** Reference to the status given to {@link #newSession}, kept to prevent its deallocation. */
    protected TF_Status status;

    public AbstractTF_Session(Pointer p) {
        super(p);
    }

    /**
     * Creates a session with TF_NewSession(), stores references to the arguments in it,
     * and attaches a deallocator to it.
     * @return the TF_Session created, or null when TF_NewSession() returned null.
     *         Do not call TF_DeleteSession() on it.
     */
    public static TF_Session newSession(TF_Graph graph, TF_SessionOptions opts, TF_Status status) {
        TF_Session session = TF_NewSession(graph, opts, status);
        if (session == null) {
            return null;
        }
        session.graph = graph;
        session.opts = opts;
        session.status = status;
        session.deallocator(new DeleteDeallocator(session));
        return session;
    }

    /** Runs the deallocator when one is registered; has no effect otherwise. */
    public void delete() {
        deallocate();
    }
}
/**
* Pointer type declared with the native name {@code std::string}. Allocation, assignment,
* size, resize, element access and {@code data()} are native methods bound by JavaCPP.
* {@link #position(long)} is overridden to return this type, which lets callers step through
* consecutive elements, as the String[] factory of AbstractTensor does.
*/
@Name("std::string") public static class StringArray extends Pointer {
// Loads the native library before any native method of this class is used.
static { Loader.load(); }
/** Wraps the address held by {@code p} via the {@code Pointer(Pointer)} constructor. */
public StringArray(Pointer p) { super(p); }
/** Allocates through the no-argument native {@code allocate()}. */
public StringArray() { allocate(); }
private native void allocate();
/** Allocates through the native {@code allocate(StringArray)}, passing {@code p} by reference. */
public StringArray(StringArray p) { allocate(p); }
private native void allocate(@ByRef StringArray p);
/** Allocates through the native {@code allocate(BytePointer, long)}; {@code s} is cast to {@code char*}. */
public StringArray(BytePointer s, long count) { allocate(s, count); }
private native void allocate(@Cast("char*") BytePointer s, long count);
/** Allocates through the native {@code allocate(String)}. */
public StringArray(String s) { allocate(s); }
private native void allocate(String s);
// Both put() overloads below are bound to the native operator=.
public native @Name("operator=") @ByRef StringArray put(@ByRef StringArray str);
public native @Name("operator=") @ByRef StringArray put(String str);
/** Same as {@code Pointer.position(long)}, but with the return value cast to StringArray. */
@Override public StringArray position(long position) {
return (StringArray)super.position(position);
}
public native @Cast("size_t") long size();
public native void resize(@Cast("size_t") long n);
// Indexed accessors: get() is annotated with @Index and its native result is cast to char.
@Index public native @Cast("char") int get(@Cast("size_t") long pos);
public native StringArray put(@Cast("size_t") long pos, int c);
public native @Cast("const char*") BytePointer data();
/**
* Copies {@code size()} bytes (capped at Integer.MAX_VALUE) from {@code data()} into a byte array
* and decodes them with {@code new String(byte[])}, that is, with the platform's default charset.
*/
@Override public String toString() {
long length = size();
// A Java array cannot hold more than Integer.MAX_VALUE elements, so longer content is truncated.
byte[] bytes = new byte[length < Integer.MAX_VALUE ? (int)length : Integer.MAX_VALUE];
data().get(bytes);
return new String(bytes);
}
}
public static abstract class AbstractTensor extends Pointer implements Indexable {
static { Loader.load(); }
public AbstractTensor(Pointer p) { super(p); }
public static Tensor create(float[] data, TensorShape shape) { Tensor t = new Tensor(DT_FLOAT, shape); FloatBuffer b = t.createBuffer(); b.put(data); return t; }
public static Tensor create(double[] data, TensorShape shape) { Tensor t = new Tensor(DT_DOUBLE, shape); DoubleBuffer b = t.createBuffer(); b.put(data); return t; }
public static Tensor create(int[] data, TensorShape shape) { Tensor t = new Tensor(DT_INT32, shape); IntBuffer b = t.createBuffer(); b.put(data); return t; }
public static Tensor create(short[] data, TensorShape shape) { Tensor t = new Tensor(DT_INT16, shape); ShortBuffer b = t.createBuffer(); b.put(data); return t; }
public static Tensor create(byte[] data, TensorShape shape) { Tensor t = new Tensor(DT_INT8, shape); ByteBuffer b = t.createBuffer(); b.put(data); return t; }
public static Tensor create(long[] data, TensorShape shape) { Tensor t = new Tensor(DT_INT64, shape); LongBuffer b = t.createBuffer(); b.put(data); return t; }
public static Tensor create(String[] data, TensorShape shape) {
Tensor t = new Tensor(DT_STRING, shape);
StringArray a = t.createStringArray();
for (int i = 0; i < a.capacity(); i++) {
a.position(i).put(data[i]);
}
return t;
}
public abstract int dtype();
public abstract int dims();
public abstract long dim_size(int d);
public abstract long NumElements();
public abstract long TotalBytes();
public abstract BytePointer tensor_data();
/** Returns {@code createBuffer(0)}. */
public B createBuffer() {
return (B)createBuffer(0);
}
/** Returns {@link #tensor_data()} wrapped in a {@link Buffer} of appropriate type starting at given index. */
public B createBuffer(long index) {
BytePointer ptr = tensor_data();
long size = TotalBytes();
switch (dtype()) {
case DT_COMPLEX64:
case DT_FLOAT: return (B)new FloatPointer(ptr).position(index).capacity(size/4).asBuffer();
case DT_DOUBLE: return (B)new DoublePointer(ptr).position(index).capacity(size/8).asBuffer();
case DT_QINT32:
case DT_INT32: return (B)new IntPointer(ptr).position(index).capacity(size/4).asBuffer();
case DT_BOOL:
case DT_QUINT8:
case DT_UINT8:
case DT_QINT8:
case DT_INT8: return (B)ptr.position(index).capacity(size).asBuffer();
case DT_BFLOAT16:
case DT_INT16: return (B)new ShortPointer(ptr).position(index).capacity(size/2).asBuffer();
case DT_INT64: return (B)new LongPointer(ptr).position(index).capacity(size/8).asBuffer();
case DT_STRING:
default: assert false;
}
return null;
}
/** Returns {@code createIndexer(true)}. */
public I createIndexer() {
return (I)createIndexer(true);
}
@Override public I createIndexer(boolean direct) {
BytePointer ptr = tensor_data();
int dims = dims();
long size = TotalBytes();
boolean complex = dtype() == DT_COMPLEX64;
boolean scalar = dims == 0;
dims = (complex ? 1 : 0) + (scalar ? 1 : dims);
long[] sizes = new long[dims];
long[] strides = new long[dims];
sizes[dims - 1] = complex ? 2 : (scalar ? 1 : dim_size(dims - 1));
strides[dims - 1] = 1;
for (int i = dims - 2; i >= 0; i--) {
sizes[i] = scalar ? 1 : dim_size(i);
strides[i] = sizes[i + 1] * strides[i + 1];
}
switch (dtype()) {
case DT_COMPLEX64:
case DT_FLOAT: return (I)FloatIndexer.create(new FloatPointer(ptr).capacity(size/4), sizes, strides, direct).indexable(this);
case DT_DOUBLE: return (I)DoubleIndexer.create(new DoublePointer(ptr).capacity(size/8), sizes, strides, direct).indexable(this);
case DT_QINT32:
case DT_INT32: return (I)IntIndexer.create(new IntPointer(ptr).capacity(size/4), sizes, strides, direct).indexable(this);
case DT_BOOL:
case DT_QUINT8:
case DT_UINT8: return (I)UByteIndexer.create(ptr.capacity(size), sizes, strides, direct).indexable(this);
case DT_QINT8:
case DT_INT8: return (I)ByteIndexer.create(ptr.capacity(size), sizes, strides, direct).indexable(this);
case DT_BFLOAT16: return (I)UShortIndexer.create(new ShortPointer(ptr).capacity(size/2), sizes, strides, direct).indexable(this);
case DT_INT16: return (I)ShortIndexer.create(new ShortPointer(ptr).capacity(size/2), sizes, strides, direct).indexable(this);
case DT_INT64: return (I)LongIndexer.create(new LongPointer(ptr).capacity(size/8), sizes, strides, direct).indexable(this);
case DT_STRING:
default: assert false;
}
return null;
}
/** Returns {@code new StringArray(tensor_data()).capacity(NumElements()).limit(NumElements())} when {@code dtype() == DT_STRING}. */
public StringArray createStringArray() {
if (dtype() != DT_STRING) {
return null;
}
long size = NumElements();
return new StringArray(tensor_data()).capacity(size).limit(size);
}
}
/**
 * Helper base class for {@code Session} that can create the native session from
 * {@code SessionOptions} and register a deallocator calling the native {@code delete}.
 */
public static abstract class AbstractSession extends Pointer {
    static { Loader.load(); }

    /** A reference to the options, kept to prevent their deallocation. */
    SessionOptions options;

    public AbstractSession(Pointer p) {
        super(p);
    }

    /**
     * Calls {@code NewSession(options, this)} and, when the returned status is ok and this
     * pointer is not null afterwards, registers a deallocator.
     */
    public AbstractSession(SessionOptions options) {
        this.options = options;
        boolean created = NewSession(options, (Session)this).ok();
        if (created && !isNull()) {
            deallocator(new DeleteDeallocator((Session)this));
        }
    }

    /** Native {@code delete} applied to the given session. */
    @Namespace public static native void delete(Session session);

    /** Deallocator that passes the wrapped session to {@link #delete(Session)} unless it is null, then nulls itself. */
    protected static class DeleteDeallocator extends Session implements Pointer.Deallocator {
        DeleteDeallocator(Session p) {
            super(p);
        }

        @Override public void deallocate() {
            if (!isNull()) {
                Session.delete(this);
            }
            setNull();
        }
    }
}
}