
org.datavec.python.PythonObject Maven / Gradle / Ivy
The newest version!
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.cpython.PyObject;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.SizeTPointer;
import org.bytedeco.numpy.PyArrayObject;
import org.json.JSONArray;
import org.json.JSONObject;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.nativeblas.NativeOpsHolder;
import java.util.*;
import static org.bytedeco.cpython.global.python.*;
import static org.bytedeco.numpy.global.numpy.*;
/**
* Swift-like Python wrapper for Java
*
* @author Fariz Rahman
*/
@Slf4j
public class PythonObject {
// Handle to the underlying CPython object wrapped by this instance.
private PyObject nativePythonObject;
// Instantiates a PythonExecutioner once, when this class is first loaded.
// NOTE(review): presumably this bootstraps the embedded interpreter before any
// C-API call is made from this class - confirm against PythonExecutioner.
static {
new PythonExecutioner();
}
/**
 * Builds a single-entry map whose {@code "default"} key holds a Python lambda
 * that turns numpy arrays into plain dictionaries.
 * <p>
 * For a numpy array the lambda yields a dict with the keys {@code address},
 * {@code shape}, {@code strides}, {@code dtype} and {@code _is_numpy_array};
 * any other object is returned unchanged.
 * NOTE(review): the map is presumably passed as keyword arguments to a
 * serializer such as {@code json.dumps} - confirm at the call site.
 *
 * @return a new mutable map containing the {@code "default"} handler
 */
private static Map<String, PythonObject> _getNDArraySerializer() {
    // str(type(x)) of a numpy array is "<class 'numpy.ndarray'>". Comparing it
    // against an empty string can never match, which would make the
    // dictionary branch of the lambda unreachable.
    PythonObject lambda = Python.eval(
            "lambda x: " +
            "{'address':" +
            "x.__array_interface__['data'][0]," +
            "'shape':x.shape,'strides':x.strides," +
            "'dtype': str(x.dtype),'_is_numpy_array': True}" +
            " if str(type(x))== \"<class 'numpy.ndarray'>\" else x");
    Map<String, PythonObject> ndarraySerializer = new HashMap<>();
    ndarraySerializer.put("default", lambda);
    return ndarraySerializer;
}
/**
 * Wraps an existing native Python object handle.
 *
 * @param pyObject the native handle to wrap; stored as-is
 */
public PythonObject(PyObject pyObject) {
    this.nativePythonObject = pyObject;
}
/**
 * Creates a Python object from an nd4j array by wrapping it in a
 * {@link NumpyArray} and delegating to {@link #PythonObject(NumpyArray)}.
 *
 * @param npArray the nd4j array to expose to Python
 */
public PythonObject(INDArray npArray) {
this(new NumpyArray(npArray));
}
/**
 * Creates a Python {@code memoryview} over the bytes referenced by a pointer.
 * <p>
 * The pointer's address and capacity describe a one-dimensional INT8
 * {@link NumpyArray} with stride 1; that array is wrapped and handed to
 * {@code Python.memoryview}, whose native handle becomes this object's handle.
 *
 * @param bp pointer whose {@code address()} and {@code capacity()} define the byte range
 */
public PythonObject(BytePointer bp){
    final long address = bp.address();
    final long numBytes = bp.capacity();
    final NumpyArray byteArray = NumpyArray.builder()
            .address(address)
            .shape(new long[]{numBytes})
            .strides(new long[]{1})
            .dtype(DataType.INT8)
            .build();
    final PythonObject wrappedBytes = new PythonObject(byteArray);
    this.nativePythonObject = Python.memoryview(wrappedBytes).nativePythonObject;
}
/**
 * Creates a numpy array object from the nd4j array held by a {@link NumpyArray}.
 * <p>
 * The nd4j data type is mapped to a numpy type code, BFLOAT16 input is cast to
 * a FLOAT copy first, and the data buffer's address pointer is passed to
 * {@code PyArray_New} together with the array's shape.
 *
 * @param npArray source array; its nd4j array supplies shape, dtype and data
 * @throws RuntimeException if the nd4j data type has no numpy mapping here
 */
public PythonObject(NumpyArray npArray) {
    final INDArray source = npArray.getNd4jArray();
    final DataType sourceType = source.dataType();

    // Translate the nd4j data type into the matching numpy type code.
    int npyTypeCode;
    switch (sourceType) {
        case DOUBLE: npyTypeCode = NPY_DOUBLE; break;
        case FLOAT:
        case BFLOAT16: npyTypeCode = NPY_FLOAT; break; // BFLOAT16 is cast to FLOAT below
        case SHORT: npyTypeCode = NPY_SHORT; break;
        case INT: npyTypeCode = NPY_INT; break;
        case LONG: npyTypeCode = NPY_INT64; break;
        case UINT16: npyTypeCode = NPY_USHORT; break;
        case UINT32: npyTypeCode = NPY_UINT; break;
        case UINT64: npyTypeCode = NPY_UINT64; break;
        case BOOL: npyTypeCode = NPY_BOOL; break;
        case BYTE: npyTypeCode = NPY_BYTE; break;
        case UBYTE: npyTypeCode = NPY_UBYTE; break;
        case HALF: npyTypeCode = NPY_HALF; break;
        default:
            throw new RuntimeException("Unsupported dtype: " + npArray.getDtype());
    }

    final long[] dims = source.shape();

    INDArray hostArray = source;
    if (sourceType == DataType.BFLOAT16) {
        log.warn("\n\nThe given nd4j array \n\n{}\n\n is of BFLOAT16 datatype. " +
                "Casting a copy of it to FLOAT and creating the respective numpy array from it.\n", source);
        hostArray = source.castTo(DataType.FLOAT);
    }

    // Sync to host memory in the case of CUDA, before passing the host memory pointer to Python
    if (hostArray.data() instanceof BaseDataBuffer) {
        ((BaseDataBuffer) hostArray.data()).syncToPrimary();
    }

    this.nativePythonObject = PyArray_New(
            PyArray_Type(),
            dims.length,
            new SizeTPointer(dims),
            npyTypeCode,
            null,
            hostArray.data().addressPointer(),
            0,
            NPY_ARRAY_CARRAY,
            null);
}
/*---primitive constructors---*/
/**
 * Returns the native Python object handle wrapped by this instance.
 *
 * @return the wrapped handle, exactly as stored
 */
public PyObject getNativePythonObject() {
    return this.nativePythonObject;
}
/**
 * Creates a Python unicode string via {@code PyUnicode_FromString}.
 *
 * @param data the Java string to convert
 */
public PythonObject(String data) {
    this.nativePythonObject = PyUnicode_FromString(data);
}
public PythonObject(int data) {
nativePythonObject = PyLong_FromLong((long) data);
}
/**
 * Creates a Python integer from a Java {@code long} via {@code PyLong_FromLong}.
 *
 * @param data the value to convert
 */
public PythonObject(long data) {
    this.nativePythonObject = PyLong_FromLong(data);
}
/**
 * Creates a Python float from a Java {@code double} via {@code PyFloat_FromDouble}.
 *
 * @param data the value to convert
 */
public PythonObject(double data) {
    this.nativePythonObject = PyFloat_FromDouble(data);
}
public PythonObject(boolean data) {
nativePythonObject = PyBool_FromLong(data ? 1 : 0);
}
/**
 * Converts a Java value into a {@link PythonObject} by dispatching on its
 * runtime type.
 * <p>
 * A {@link PythonObject} is returned as-is; every other supported type is
 * passed to the matching constructor. A {@link Pointer} is first wrapped in a
 * {@link BytePointer}.
 *
 * @param item the value to convert
 * @return the Python-side representation of {@code item}
 * @throws RuntimeException if the type of {@code item} is not supported
 *         (including {@code null}, which matches none of the checks)
 */
private static PythonObject j2pyObject(Object item) {
    if (item instanceof PythonObject) {
        return (PythonObject) item;
    }
    if (item instanceof PyObject) {
        return new PythonObject((PyObject) item);
    }
    if (item instanceof INDArray) {
        return new PythonObject((INDArray) item);
    }
    if (item instanceof NumpyArray) {
        return new PythonObject((NumpyArray) item);
    }
    if (item instanceof List) {
        return new PythonObject((List) item);
    }
    if (item instanceof Object[]) {
        return new PythonObject((Object[]) item);
    }
    if (item instanceof Map) {
        return new PythonObject((Map) item);
    }
    if (item instanceof String) {
        return new PythonObject((String) item);
    }
    if (item instanceof Double) {
        return new PythonObject((Double) item);
    }
    if (item instanceof Float) {
        return new PythonObject((Float) item);
    }
    if (item instanceof Long) {
        return new PythonObject((Long) item);
    }
    if (item instanceof Integer) {
        return new PythonObject((Integer) item);
    }
    if (item instanceof Boolean) {
        return new PythonObject((Boolean) item);
    }
    if (item instanceof Pointer) {
        final BytePointer bytes = new BytePointer((Pointer) item);
        return new PythonObject(bytes);
    }
    throw new RuntimeException("Unsupported item in list: " + item);
}
public PythonObject(Object[] data) {
PyObject pyList = PyList_New((long) data.length);
for (int i = 0; i < data.length; i++) {
PyList_SetItem(pyList, i, j2pyObject(data[i]).nativePythonObject);
}
nativePythonObject = pyList;
}
public PythonObject(List data) {
PyObject pyList = PyList_New((long) data.size());
for (int i = 0; i < data.size(); i++) {
PyList_SetItem(pyList, i, j2pyObject(data.get(i)).nativePythonObject);
}
nativePythonObject = pyList;
}
/**
 * Creates a Python dict from a Java map.
 * <p>
 * Keys may be {@link PythonObject}, {@link String}, {@link Double},
 * {@link Float}, {@link Long}, {@link Integer} or {@link Boolean}. Values may
 * additionally be {@link PyObject}, {@link INDArray}, {@link NumpyArray},
 * {@link Map} or {@link List}.
 *
 * @param data the entries to copy into the new dict
 * @throws RuntimeException if a key or value is {@code null} or of an
 *         unsupported type; the message names the offending object's class
 */
public PythonObject(Map data) {
    PyObject pyDict = PyDict_New();
    // Iterate entries directly instead of keySet() plus get() per key.
    for (Object entryObject : data.entrySet()) {
        Map.Entry<?, ?> entry = (Map.Entry<?, ?>) entryObject;

        Object k = entry.getKey();
        PythonObject pyKey;
        if (k instanceof PythonObject) {
            pyKey = (PythonObject) k;
        } else if (k instanceof String) {
            pyKey = new PythonObject((String) k);
        } else if (k instanceof Double) {
            pyKey = new PythonObject((Double) k);
        } else if (k instanceof Float) {
            pyKey = new PythonObject((Float) k);
        } else if (k instanceof Long) {
            pyKey = new PythonObject((Long) k);
        } else if (k instanceof Integer) {
            pyKey = new PythonObject((Integer) k);
        } else if (k instanceof Boolean) {
            pyKey = new PythonObject((Boolean) k);
        } else {
            // Null-safe so a null key yields a descriptive error rather than a bare NPE.
            throw new RuntimeException("Unsupported key in map: " + (k == null ? "null" : k.getClass()));
        }

        Object v = entry.getValue();
        PythonObject pyVal;
        if (v instanceof PythonObject) {
            pyVal = (PythonObject) v;
        } else if (v instanceof PyObject) {
            pyVal = new PythonObject((PyObject) v);
        } else if (v instanceof INDArray) {
            pyVal = new PythonObject((INDArray) v);
        } else if (v instanceof NumpyArray) {
            pyVal = new PythonObject((NumpyArray) v);
        } else if (v instanceof Map) {
            pyVal = new PythonObject((Map) v);
        } else if (v instanceof List) {
            pyVal = new PythonObject((List) v);
        } else if (v instanceof String) {
            pyVal = new PythonObject((String) v);
        } else if (v instanceof Double) {
            pyVal = new PythonObject((Double) v);
        } else if (v instanceof Float) {
            pyVal = new PythonObject((Float) v);
        } else if (v instanceof Long) {
            pyVal = new PythonObject((Long) v);
        } else if (v instanceof Integer) {
            pyVal = new PythonObject((Integer) v);
        } else if (v instanceof Boolean) {
            pyVal = new PythonObject((Boolean) v);
        } else {
            // Report the VALUE's class; the key's class was reported here before.
            throw new RuntimeException("Unsupported value in map: " + (v == null ? "null" : v.getClass()));
        }

        PyDict_SetItem(pyDict, pyKey.nativePythonObject, pyVal.nativePythonObject);
    }
    nativePythonObject = pyDict;
}
/*------*/
/**
 * Returns the result of Python's {@code str()} on the given object as a Java string.
 * <p>
 * The Python string is encoded to UTF-8 bytes and those bytes are decoded
 * explicitly as UTF-8, so the result does not depend on the JVM's default
 * charset. Both temporary Python objects are released before returning.
 *
 * @param pyObject the native object to stringify
 * @return the string form of {@code pyObject}
 */
private static String pyObjectToString(PyObject pyObject) {
    PyObject repr = PyObject_Str(pyObject);
    PyObject str = PyUnicode_AsEncodedString(repr, "utf-8", "~E~");
    // The bytes are UTF-8 by construction; decoding with the platform default
    // charset would corrupt non-ASCII text on non-UTF-8 systems.
    String jstr = new String(PyBytes_AsString(str).getStringBytes(),
            java.nio.charset.StandardCharsets.UTF_8);
    Py_DecRef(repr);
    Py_DecRef(str);
    return jstr;
}
/**
 * Returns the string form of the wrapped Python object, as produced by
 * {@code pyObjectToString}.
 *
 * @return the string form of the wrapped object
 */
@Override
public String toString() {
    return pyObjectToString(nativePythonObject);
}
/**
 * Converts the wrapped object to a Java {@code double} via {@code PyFloat_AsDouble}.
 *
 * @return the value reported by {@code PyFloat_AsDouble}
 */
public double toDouble() {
    final double value = PyFloat_AsDouble(nativePythonObject);
    return value;
}
/**
 * Converts the wrapped object to a Java {@code float}.
 *
 * @return the result of {@code PyFloat_AsDouble}, narrowed to {@code float}
 */
public float toFloat() {
    final double wide = PyFloat_AsDouble(nativePythonObject);
    return (float) wide;
}
/**
 * Converts the wrapped object to a Java {@code int}.
 *
 * @return the result of {@code PyLong_AsLong}, narrowed to {@code int}
 */
public int toInt() {
    final long wide = PyLong_AsLong(nativePythonObject);
    return (int) wide;
}
/**
 * Converts the wrapped object to a Java {@code long} via {@code PyLong_AsLong}.
 *
 * @return the value reported by {@code PyLong_AsLong}
 */
public long toLong() {
    final long value = PyLong_AsLong(nativePythonObject);
    return value;
}
/**
 * Converts the wrapped object to a Java {@code boolean}.
 *
 * @return {@code false} when {@code isNone()} holds; otherwise whether
 *         {@code toInt()} is non-zero
 */
public boolean toBoolean() {
    // Short-circuit keeps toInt() from being called when isNone() is true.
    return !isNone() && toInt() != 0;
}
/**
 * Describes the wrapped numpy array as a {@link NumpyArray}.
 * <p>
 * Reads the data pointer, shape, strides and element type from the native
 * array and builds a {@link NumpyArray} from the data address and the mapped
 * nd4j {@link DataType}.
 *
 * @return a {@link NumpyArray} built from the native array's address, shape,
 *         strides and data type
 * @throws PythonException if the wrapped object is not an instance of
 *         {@code numpy.ndarray}, or its element type has no mapping here
 */
public NumpyArray toNumpy() throws PythonException{
    PyObject np = PyImport_ImportModule("numpy");
    PyObject ndarray = PyObject_GetAttrString(np, "ndarray");
    int instanceCheck = PyObject_IsInstance(nativePythonObject, ndarray);
    // Release the temporary references BEFORE a possible throw, so they are
    // not leaked when the check fails.
    Py_DecRef(ndarray);
    Py_DecRef(np);
    if (instanceCheck != 1){
        throw new PythonException("Object is not a numpy array! Use Python.ndarray() to convert object to a numpy array.");
    }

    Pointer objPtr = new Pointer(nativePythonObject);
    PyArrayObject npArr = new PyArrayObject(objPtr);
    Pointer ptr = PyArray_DATA(npArr);

    long[] shape = new long[PyArray_NDIM(npArr)];
    SizeTPointer shapePtr = PyArray_SHAPE(npArr);
    if (shapePtr != null) {
        shapePtr.get(shape, 0, shape.length);
    }

    long[] strides = new long[shape.length];
    SizeTPointer stridesPtr = PyArray_STRIDES(npArr);
    if (stridesPtr != null) {
        stridesPtr.get(strides, 0, strides.length);
    }

    // Map the numpy type code onto the corresponding nd4j data type.
    int npdtype = PyArray_TYPE(npArr);
    DataType dtype;
    switch (npdtype){
        case NPY_DOUBLE:
            dtype = DataType.DOUBLE; break;
        case NPY_FLOAT:
            dtype = DataType.FLOAT; break;
        case NPY_SHORT:
            dtype = DataType.SHORT; break;
        case NPY_INT:
            dtype = DataType.INT32; break;
        case NPY_LONG:
            dtype = DataType.LONG; break;
        case NPY_UINT:
            dtype = DataType.UINT32; break;
        case NPY_BYTE:
            dtype = DataType.INT8; break;
        case NPY_UBYTE:
            dtype = DataType.UINT8; break;
        case NPY_BOOL:
            dtype = DataType.BOOL; break;
        case NPY_HALF:
            dtype = DataType.FLOAT16; break;
        case NPY_LONGLONG:
            dtype = DataType.INT64; break;
        case NPY_USHORT:
            dtype = DataType.UINT16; break;
        case NPY_ULONG:
        case NPY_ULONGLONG:
            dtype = DataType.UINT64; break;
        default:
            throw new PythonException("Unsupported array data type: " + npdtype);
    }
    return new NumpyArray(ptr.address(), shape, strides, dtype);
}
/**
 * Looks up an attribute on the wrapped object via {@code PyObject_GetAttrString}
 * and wraps the result.
 *
 * @param attr the attribute name
 * @return a new {@link PythonObject} wrapping the lookup result
 */
public PythonObject attr(String attr) {
    final PyObject value = PyObject_GetAttrString(nativePythonObject, attr);
    return new PythonObject(value);
}
public PythonObject call(Object... args) {
if (args.length > 0 && args[args.length - 1] instanceof Map) {
List
© 2015 - 2025 Weber Informatics LLC | Privacy Policy