All downloads are free. Search and download functionality uses the official Maven repository.

org.datavec.python.PythonCondition Maven / Gradle / Ivy

The newest version!
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.datavec.python;

import org.datavec.api.transform.condition.Condition;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*;
import java.util.List;

import static org.datavec.python.PythonUtils.schemaToPythonVariables;
import static org.nd4j.common.base.Preconditions.checkNotNull;
import static org.nd4j.common.base.Preconditions.checkState;

/**
 * Lets a condition be defined as a python method f that takes no arguments
 * and returns a boolean indicating whether or not to filter a row.
 * The values of all columns in current row are available as global variables to f.
 *
 * @author Fariz Rahman
 */
public class PythonCondition implements Condition {

    private Schema inputSchema;
    private PythonVariables pyInputs;
    private PythonTransform pythonTransform;
    private String code;


    public PythonCondition(String pythonCode) {
        checkNotNull("Python code must not be null!", pythonCode);
        checkState(!pythonCode.isEmpty(), "Python code must not be empty!");
        code = pythonCode;
    }


    @Override
    public void setInputSchema(Schema inputSchema) {
        this.inputSchema = inputSchema;
        try {
            pyInputs = schemaToPythonVariables(inputSchema);
            PythonVariables pyOuts = new PythonVariables();
            pyOuts.addInt("out");
            pythonTransform = PythonTransform.builder()
                    .code(code + "\n\nout=f()\nout=0 if out is None else int(out)")
                    .inputs(pyInputs)
                    .outputs(pyOuts)
                    .build();

        } catch (Exception e) {
            throw new RuntimeException(e);
        }


    }

    @Override
    public Schema getInputSchema() {
        return inputSchema;
    }

    @Override
    public String[] outputColumnNames() {
        String[] columnNames = new String[inputSchema.numColumns()];
        inputSchema.getColumnNames().toArray(columnNames);
        return columnNames;
    }

    @Override
    public String outputColumnName() {
        return outputColumnNames()[0];
    }

    @Override
    public String[] columnNames() {
        return outputColumnNames();
    }

    @Override
    public String columnName() {
        return outputColumnName();
    }

    @Override
    public Schema transform(Schema inputSchema) {
        return inputSchema;
    }

    @Override
    public boolean condition(List list) {
        PythonVariables inputs = getPyInputsFromWritables(list);
        try {
            pythonTransform.getPythonJob().exec(inputs, pythonTransform.getOutputs());
            boolean ret = pythonTransform.getOutputs().getIntValue("out") != 0;
            return ret;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public boolean condition(Object input) {
        return condition(input);
    }

    @Override
    public boolean conditionSequence(List> list) {
        throw new UnsupportedOperationException("not supported");
    }


    @Override
    public boolean conditionSequence(Object input) {
        throw new UnsupportedOperationException("not supported");
    }

    private PythonVariables getPyInputsFromWritables(List writables) {
        PythonVariables ret = new PythonVariables();

        for (int i = 0; i < inputSchema.numColumns(); i++) {
            String name = inputSchema.getName(i);
            Writable w = writables.get(i);
            PythonType pyType = pyInputs.getType(inputSchema.getName(i));
            switch (pyType.getName()) {
                case INT:
                    if (w instanceof LongWritable) {
                        ret.addInt(name, ((LongWritable) w).get());
                    } else {
                        ret.addInt(name, ((IntWritable) w).get());
                    }

                    break;
                case FLOAT:
                    ret.addFloat(name, ((DoubleWritable) w).get());
                    break;
                case STR:
                    ret.addStr(name, w.toString());
                    break;
                case NDARRAY:
                    ret.addNDArray(name, ((NDArrayWritable) w).get());
                    break;
            }
        }

        return ret;
    }


}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy