All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.datavec.python.PythonTransform Maven / Gradle / Ivy

The newest version!
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.datavec.python;

import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.holder.ObjectMapperHolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.UUID;

import static org.datavec.python.PythonUtils.schemaToPythonVariables;

/**
 * Row-wise Transform that applies arbitrary python code on each row
 *
 * @author Fariz Rahman
 */

@NoArgsConstructor
@Data
public class PythonTransform implements Transform {

    private String code;
    private PythonVariables inputs;
    private PythonVariables outputs;
    private String name = UUID.randomUUID().toString();
    private Schema inputSchema;
    private Schema outputSchema;
    private String outputDict;
    private boolean returnAllVariables;
    private boolean setupAndRun = false;
    private PythonJob pythonJob;


    @Builder
    public PythonTransform(String code,
                           PythonVariables inputs,
                           PythonVariables outputs,
                           String name,
                           Schema inputSchema,
                           Schema outputSchema,
                           String outputDict,
                           boolean returnAllInputs,
                           boolean setupAndRun) {
        Preconditions.checkNotNull(code, "No code found to run!");
        this.code = code;
        this.returnAllVariables = returnAllInputs;
        this.setupAndRun = setupAndRun;
        if (inputs != null)
            this.inputs = inputs;
        if (outputs != null)
            this.outputs = outputs;
        if (name != null)
            this.name = name;
        if (outputDict != null) {
            this.outputDict = outputDict;
            this.outputs = new PythonVariables();
            this.outputs.addDict(outputDict);
        }

        try {
            if (inputSchema != null) {
                this.inputSchema = inputSchema;
                if (inputs == null || inputs.isEmpty()) {
                    this.inputs = schemaToPythonVariables(inputSchema);
                }
            }

            if (outputSchema != null) {
                this.outputSchema = outputSchema;
                if (outputs == null || outputs.isEmpty()) {
                    this.outputs = schemaToPythonVariables(outputSchema);
                }
            }
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
        try{
            pythonJob = PythonJob.builder()
                    .name("a" + UUID.randomUUID().toString().replace("-", "_"))
                    .code(code)
                    .setupRunMode(setupAndRun)
                    .build();
        }
        catch(Exception e){
            throw new IllegalStateException("Error creating python job: " + e);
        }

    }


    @Override
    public void setInputSchema(Schema inputSchema) {
        Preconditions.checkNotNull(inputSchema, "No input schema found!");
        this.inputSchema = inputSchema;
        try {
            inputs = schemaToPythonVariables(inputSchema);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        if (outputSchema == null && outputDict == null) {
            outputSchema = inputSchema;
        }

    }

    @Override
    public Schema getInputSchema() {
        return inputSchema;
    }

    @Override
    public List> mapSequence(List> sequence) {
        List> out = new ArrayList<>();
        for (List l : sequence) {
            out.add(map(l));
        }
        return out;
    }

    @Override
    public Object map(Object input) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public Object mapSequence(Object sequence) {
        throw new UnsupportedOperationException("Not yet implemented");
    }


    @Override
    public List map(List writables) {
        PythonVariables pyInputs = getPyInputsFromWritables(writables);
        Preconditions.checkNotNull(pyInputs, "Inputs must not be null!");
        try {
            if (returnAllVariables) {
                return getWritablesFromPyOutputs(pythonJob.execAndReturnAllVariables(pyInputs));
            }

            if (outputDict != null) {
                pythonJob.exec(pyInputs, outputs);
                PythonVariables out = PythonUtils.expandInnerDict(outputs, outputDict);
                return getWritablesFromPyOutputs(out);
            } else {
                pythonJob.exec(pyInputs, outputs);

                return getWritablesFromPyOutputs(outputs);
            }

        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public String[] outputColumnNames() {
        return outputs.getVariables();
    }

    @Override
    public String outputColumnName() {
        return outputColumnNames()[0];
    }

    @Override
    public String[] columnNames() {
        return outputs.getVariables();
    }

    @Override
    public String columnName() {
        return columnNames()[0];
    }

    public Schema transform(Schema inputSchema) {
        return outputSchema;
    }


    private PythonVariables getPyInputsFromWritables(List writables) {
        PythonVariables ret = new PythonVariables();

        for (String name : inputs.getVariables()) {
            int colIdx = inputSchema.getIndexOfColumn(name);
            Writable w = writables.get(colIdx);
            PythonType pyType = inputs.getType(name);
            switch (pyType.getName()) {
                case INT:
                    if (w instanceof LongWritable) {
                        ret.addInt(name, ((LongWritable) w).get());
                    } else {
                        ret.addInt(name, ((IntWritable) w).get());
                    }
                    break;
                case FLOAT:
                    if (w instanceof DoubleWritable) {
                        ret.addFloat(name, ((DoubleWritable) w).get());
                    } else {
                        ret.addFloat(name, ((FloatWritable) w).get());
                    }
                    break;
                case STR:
                    ret.addStr(name, w.toString());
                    break;
                case NDARRAY:
                    ret.addNDArray(name, ((NDArrayWritable) w).get());
                    break;
                case BOOL:
                    ret.addBool(name, ((BooleanWritable) w).get());
                    break;
                default:
                    throw new RuntimeException("Unsupported input type:" + pyType);
            }

        }
        return ret;
    }

    private List getWritablesFromPyOutputs(PythonVariables pyOuts) {
        List out = new ArrayList<>();
        String[] varNames;
        varNames = pyOuts.getVariables();
        Schema.Builder schemaBuilder = new Schema.Builder();
        for (int i = 0; i < varNames.length; i++) {
            String name = varNames[i];
            PythonType pyType = pyOuts.getType(name);
            switch (pyType.getName()) {
                case INT:
                    schemaBuilder.addColumnLong(name);
                    break;
                case FLOAT:
                    schemaBuilder.addColumnDouble(name);
                    break;
                case STR:
                case DICT:
                case LIST:
                    schemaBuilder.addColumnString(name);
                    break;
                case NDARRAY:
                    INDArray arr = pyOuts.getNDArrayValue(name);
                    schemaBuilder.addColumnNDArray(name, arr.shape());
                    break;
                case BOOL:
                    schemaBuilder.addColumnBoolean(name);
                    break;
                default:
                    throw new IllegalStateException("Unable to support type " + pyType.getName());
            }
        }
        this.outputSchema = schemaBuilder.build();


        for (int i = 0; i < varNames.length; i++) {
            String name = varNames[i];
            PythonType pyType = pyOuts.getType(name);

            switch (pyType.getName()) {
                case INT:
                    out.add(new LongWritable(pyOuts.getIntValue(name)));
                    break;
                case FLOAT:
                    out.add(new DoubleWritable(pyOuts.getFloatValue(name)));
                    break;
                case STR:
                    out.add(new Text(pyOuts.getStrValue(name)));
                    break;
                case NDARRAY:
                    INDArray arr = pyOuts.getNDArrayValue(name);
                    out.add(new NDArrayWritable(arr));
                    break;
                case DICT:
                    Map dictValue = pyOuts.getDictValue(name);
                    Map noNullValues = new java.util.HashMap<>();
                    for (Map.Entry entry : dictValue.entrySet()) {
                        if (entry.getValue() != org.json.JSONObject.NULL) {
                            noNullValues.put(entry.getKey(), entry.getValue());
                        }
                    }

                    try {
                        out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(noNullValues)));
                    } catch (JsonProcessingException e) {
                        throw new IllegalStateException("Unable to serialize dictionary " + name + " to json!");
                    }
                    break;
                case LIST:
                    Object[] listValue = pyOuts.getListValue(name).toArray();
                    try {
                        out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(listValue)));
                    } catch (JsonProcessingException e) {
                        throw new IllegalStateException("Unable to serialize list vlaue " + name + " to json!");
                    }
                    break;
                case BOOL:
                    out.add(new BooleanWritable(pyOuts.getBooleanValue(name)));
                    break;
                default:
                    throw new IllegalStateException("Unable to support type " + pyType.getName());
            }
        }
        return out;
    }


}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy