All downloads are free. Search and download functionality uses the official Maven repository.

io.deephaven.engine.util.PyCallableWrapperJpyImpl Maven / Gradle / Ivy

There is a newer version: 0.37.1
Show newest version
package io.deephaven.engine.util;

import io.deephaven.engine.table.impl.select.python.ArgumentsChunked;
import io.deephaven.internal.log.LoggerFactory;
import io.deephaven.io.logger.Logger;
import org.jpy.PyModule;
import org.jpy.PyObject;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * When given a pyObject that is a callable, we stick it inside the callable wrapper, which implements a call() varargs
 * method, so that we can call it using __call__ without all of the JPy nastiness.
 */
public class PyCallableWrapperJpyImpl implements PyCallableWrapper {
    private static final Logger log = LoggerFactory.getLogger(PyCallableWrapperJpyImpl.class);

    private static final PyObject NUMBA_VECTORIZED_FUNC_TYPE = getNumbaVectorizedFuncType();
    private static final PyObject NUMBA_GUVECTORIZED_FUNC_TYPE = getNumbaGUVectorizedFuncType();

    private static final PyModule dh_udf_module = PyModule.importModule("deephaven._udf");

    private static final Map> numpyType2JavaClass = new HashMap<>();

    static {
        numpyType2JavaClass.put('b', byte.class);
        numpyType2JavaClass.put('h', short.class);
        numpyType2JavaClass.put('H', char.class);
        numpyType2JavaClass.put('i', int.class);
        numpyType2JavaClass.put('l', long.class);
        numpyType2JavaClass.put('f', float.class);
        numpyType2JavaClass.put('d', double.class);
        numpyType2JavaClass.put('?', boolean.class);
        numpyType2JavaClass.put('U', String.class);
        numpyType2JavaClass.put('M', Instant.class);
        numpyType2JavaClass.put('O', Object.class);
    }

    /**
     * Ensure that the class initializer runs.
     */
    public static void init() {}

    // TODO: support for vectorizing functions that return arrays
    // https://github.com/deephaven/deephaven-core/issues/4649
    private static final Set> vectorizableReturnTypes = Set.of(
            boolean.class, boolean[].class,
            Boolean.class, Boolean[].class,
            byte.class, byte[].class,
            short.class, short[].class,
            char.class, char[].class,
            int.class, int[].class,
            long.class, long[].class,
            float.class, float[].class,
            double.class, double[].class,
            String.class, String[].class,
            Instant.class, Instant[].class,
            PyObject.class, PyObject[].class,
            Object.class, Object[].class);

    @Override
    public boolean isVectorizableReturnType() {
        parseSignature();
        return vectorizableReturnTypes.contains(returnType);
    }

    private final PyObject pyCallable;

    private String signature = null;
    private List> paramTypes;
    private Class returnType;
    private boolean vectorizable = false;
    private boolean vectorized = false;
    private Collection chunkArguments;
    private boolean numbaVectorized;
    private PyObject unwrapped;
    private PyObject pyUdfDecoratedCallable;

    public PyCallableWrapperJpyImpl(PyObject pyCallable) {
        this.pyCallable = pyCallable;
    }

    @Override
    public PyObject getAttribute(String name) {
        return this.pyCallable.getAttribute(name);
    }

    @Override
    public  T getAttribute(String name, Class valueType) {
        return this.pyCallable.getAttribute(name, valueType);
    }

    public ArgumentsChunked buildArgumentsChunked(List columnNames) {
        for (ChunkArgument arg : chunkArguments) {
            if (arg instanceof ColumnChunkArgument) {
                String columnName = ((ColumnChunkArgument) arg).getColumnName();
                int chunkSourceIndex = columnNames.indexOf(columnName);
                if (chunkSourceIndex < 0) {
                    throw new IllegalArgumentException("Column source not found: " + columnName);
                }
                ((ColumnChunkArgument) arg).setSourceChunkIndex(chunkSourceIndex);
            }
        }
        return new ArgumentsChunked(chunkArguments, returnType, numbaVectorized);
    }

    /**
     * This assumes that the Python interpreter won't be re-initialized during a session, if this turns out to be a
     * false assumption, then we'll need to make this initialization code 'python restart' proof.
     */
    private static PyObject getNumbaVectorizedFuncType() {
        try {
            return PyModule.importModule("numba.np.ufunc.dufunc").getAttribute("DUFunc");
        } catch (Exception e) {
            if (log.isDebugEnabled()) {
                log.debug("Numba isn't installed in the Python environment.");
            }
            return null;
        }
    }

    private static PyObject getNumbaGUVectorizedFuncType() {
        try {
            return PyModule.importModule("numba.np.ufunc.gufunc").getAttribute("GUFunc");
        } catch (Exception e) {
            if (log.isDebugEnabled()) {
                log.debug("Numba isn't installed in the Python environment.");
            }
            return null;
        }
    }

    private void prepareSignature() {
        boolean isNumbaVectorized = pyCallable.getType().equals(NUMBA_VECTORIZED_FUNC_TYPE);
        boolean isNumbaGUVectorized = pyCallable.equals(NUMBA_GUVECTORIZED_FUNC_TYPE);
        if (isNumbaGUVectorized || isNumbaVectorized) {
            List params = pyCallable.getAttribute("types").asList();
            if (params.isEmpty()) {
                throw new IllegalArgumentException(
                        "numba vectorized/guvectorized function must have an explicit signature: " + pyCallable);
            }
            // numba allows a vectorized function to have multiple signatures
            if (params.size() > 1) {
                throw new UnsupportedOperationException(
                        pyCallable
                                + " has multiple signatures; this is not currently supported for numba vectorized/guvectorized functions");
            }
            unwrapped = pyCallable;
            // since vectorization doesn't support array type parameters, don't flag numba guvectorized as vectorized
            numbaVectorized = isNumbaVectorized;
            vectorized = isNumbaVectorized;
        } else if (pyCallable.hasAttribute("dh_vectorized")) {
            unwrapped = pyCallable.getAttribute("callable");
            numbaVectorized = false;
            vectorized = true;
        } else {
            unwrapped = pyCallable;
            numbaVectorized = false;
            vectorized = false;
        }
        pyUdfDecoratedCallable = dh_udf_module.call("_py_udf", unwrapped);
        signature = pyUdfDecoratedCallable.getAttribute("signature").toString();
    }

    @Override
    public void parseSignature() {
        if (signature != null) {
            return;
        }

        prepareSignature();

        // the 'types' field of a vectorized function follows the pattern of '[ilhfdb?O]*->[ilhfdb?O]',
        // eg. [ll->d] defines two int64 (long) arguments and a double return type.
        if (signature == null || signature.isEmpty()) {
            throw new IllegalStateException("Signature should always be available.");
        }

        List> paramTypes = new ArrayList<>();
        for (char numpyTypeChar : signature.toCharArray()) {
            if (numpyTypeChar != '-') {
                Class paramType = numpyType2JavaClass.get(numpyTypeChar);
                if (paramType == null) {
                    throw new IllegalStateException(
                            "Parameters of vectorized functions should always be of integral, floating point, boolean, String, or Object type: "
                                    + numpyTypeChar + " of " + signature);
                }
                paramTypes.add(paramType);
            } else {
                break;
            }
        }

        this.paramTypes = paramTypes;

        returnType = pyUdfDecoratedCallable.getAttribute("return_type", null);
        if (returnType == null) {
            throw new IllegalStateException(
                    "Python functions should always have an integral, floating point, boolean, String, arrays, or Object return type");
        }

        if (returnType == boolean.class) {
            this.returnType = Boolean.class;
        }
    }

    // In vectorized mode, we want to call the vectorized function directly.
    public PyObject vectorizedCallable() {
        if (numbaVectorized || vectorized) {
            return pyCallable;
        } else {
            return dh_udf_module.call("_dh_vectorize", unwrapped);
        }
    }

    // In non-vectorized mode, we want to call the udf decorated function or the original function.
    @Override
    public Object call(Object... args) {
        PyObject pyCallable = this.pyUdfDecoratedCallable != null ? this.pyUdfDecoratedCallable : this.pyCallable;
        return PythonScopeJpyImpl.convert(pyCallable.callMethod("__call__", args));
    }

    @Override
    public List> getParamTypes() {
        return paramTypes;
    }

    @Override
    public boolean isVectorized() {
        return vectorized;
    }

    @Override
    public boolean isVectorizable() {
        return vectorizable;
    }

    @Override
    public void setVectorizable(boolean vectorizable) {
        this.vectorizable = vectorizable;
    }

    @Override
    public void initializeChunkArguments() {
        this.chunkArguments = new ArrayList<>();
    }

    @Override
    public void addChunkArgument(ChunkArgument chunkArgument) {
        this.chunkArguments.add(chunkArgument);
    }

    @Override
    public Class getReturnType() {
        return returnType;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy