All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.hpccsystems.spark.HpccRDD Maven / Gradle / Ivy

/*******************************************************************************
 *     HPCC SYSTEMS software Copyright (C) 2018 HPCC Systems®.
 *
 *     Licensed under the Apache License, Version 2.0 (the "License");
 *     you may not use this file except in compliance with the License.
 *     You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *     Unless required by applicable law or agreed to in writing, software
 *     distributed under the License is distributed on an "AS IS" BASIS,
 *     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *     See the License for the specific language governing permissions and
 *     limitations under the License.
 *******************************************************************************/
package org.hpccsystems.spark;

import java.io.Serializable;
import java.util.Iterator;

import java.util.Arrays;

import org.apache.spark.Dependency;
import org.apache.spark.InterruptibleIterator;
import org.apache.spark.Partition;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.execution.python.EvaluatePython;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructType;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.LogManager;

import org.hpccsystems.dfs.client.DataPartition;
import org.hpccsystems.dfs.client.HpccRemoteFileReader;

import org.hpccsystems.commons.ecl.FieldDef;

import scala.collection.JavaConverters;
import scala.collection.Seq;
import scala.collection.mutable.ArrayBuffer;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import net.razorvine.pickle.Unpickler;

/**
 * The implementation of the RDD
 *
 */
public class HpccRDD extends RDD implements Serializable
{
    private static final long          serialVersionUID = 1L;
    private static final Logger        log              = LogManager.getLogger(HpccRDD.class);
    private static final ClassTag CT_RECORD        = ClassTag$.MODULE$.apply(Row.class);

    public static int                  DEFAULT_CONNECTION_TIMEOUT = 120;

    private InternalPartition[]        parts;
    private FieldDef                   originalRecordDef = null;
    private FieldDef                   projectedRecordDef = null;
    private int                        connectionTimeout = DEFAULT_CONNECTION_TIMEOUT;
    private int                        recordLimit = -1;

    private static void registerPicklingFunctions()
    {
        EvaluatePython.registerPicklers();
        Unpickler.registerConstructor("pyspark.sql.types", "Row", new RowConstructor());
        Unpickler.registerConstructor("pyspark.sql.types", "_create_row", new RowConstructor());
    }

    private class InternalPartition implements Partition
    {
        private static final long serialVersionUID = 1L;

        public DataPartition      partition;

        public int hashCode()
        {
            return this.index();
        }

        public int index()
        {
            return partition.index();
        }
    }

    /**
     * @param sc
     * @param dataParts
     * @param originalRD 
    */
    public HpccRDD(SparkContext sc, DataPartition[] dataParts, FieldDef originalRD)
    {
        this(sc,dataParts,originalRD,originalRD);
    }
    
    /**
     * @param sc
     * @param dataParts
     * @param originalRD 
     * @param projectedRD 
    */
    public HpccRDD(SparkContext sc, DataPartition[] dataParts, FieldDef originalRD, FieldDef projectedRD)
    {
        this(sc,dataParts,originalRD,projectedRD,DEFAULT_CONNECTION_TIMEOUT,-1);
    }

    /**
     * @param sc
     * @param dataParts
     * @param originalRD 
     * @param projectedRD 
     * @param limit 
    */
    public HpccRDD(SparkContext sc, DataPartition[] dataParts, FieldDef originalRD, FieldDef projectedRD, int connectTimeout, int limit)
    {
        super(sc, new ArrayBuffer>(), CT_RECORD);
        this.parts = new InternalPartition[dataParts.length];
        for (int i = 0; i < dataParts.length; i++)
        {
            this.parts[i] = new InternalPartition();
            this.parts[i].partition = dataParts[i];
        }

        this.originalRecordDef = originalRD;
        this.projectedRecordDef = projectedRD; 
        this.connectionTimeout = connectTimeout;
        this.recordLimit = limit;
    }

    /**
     * Wrap this RDD as a JavaRDD so the Java API can be used.
     * @return a JavaRDD wrapper of the HpccRDD.
     */
    public JavaRDD asJavaRDD()
    {
        JavaRDD jRDD = new JavaRDD(this, CT_RECORD);
        return jRDD;
    }

    /**
     * Transform to an RDD of labeled points for MLLib supervised learning.
     * @param labelName the field name of the label datg
     * @param dimNames the field names for the dimensions
     * @throws IllegalArgumentException
     * @return
     */
    public RDD makeMLLibLabeledPoint(String labelName, String[] dimNames) throws IllegalArgumentException
    {
        StructType schema = null;
        try
        {
            schema = SparkSchemaTranslator.toSparkSchema(this.projectedRecordDef);
        }
        catch (Exception e)
        {
            throw new IllegalArgumentException(e.getMessage());
        }

        // Precompute indices for the requested fields
        // Throws illegal argument exception if field cannot be found
        int labelIndex = schema.fieldIndex(labelName);
        int[] dimIndices = new int[dimNames.length];
        for (int i = 0; i < dimIndices.length; i++)
        {
            dimIndices[i] = schema.fieldIndex(dimNames[i]);
        }

        // Map each row to a labeled point using the precomputed indices
        JavaRDD jRDD = this.asJavaRDD();
        return jRDD.map((row) ->
        {
            double label = row.getDouble(labelIndex);
            double[] dims = new double[dimIndices.length];
            for (int i = 0; i < dimIndices.length; i++)
            {
                dims[i] = row.getDouble(dimIndices[i]);
            }
            return new LabeledPoint(label, new DenseVector(dims));
        }).rdd();
    }

    /**
     * Transform to mllib.linalg.Vectors for ML Lib machine learning.
     * @param dimNames the field names for the dimensions
     * @throws IllegalArgumentException
     * @return
     */
    public RDD makeMLLibVector(String[] dimNames) throws IllegalArgumentException
    {
        StructType schema = null;
        try
        {
            schema = SparkSchemaTranslator.toSparkSchema(this.projectedRecordDef);
        }
        catch (Exception e)
        {
            throw new IllegalArgumentException(e.getMessage());
        }

        // Precompute indices for the requested fields
        // Throws illegal argument exception if field cannot be found
        int[] dimIndices = new int[dimNames.length];
        for (int i = 0; i < dimIndices.length; i++)
        {
            dimIndices[i] = schema.fieldIndex(dimNames[i]);
        }

        // Map each row to a vector using the precomputed indices
        JavaRDD jRDD = this.asJavaRDD();
        return jRDD.map((row) ->
        {
            double[] dims = new double[dimIndices.length];
            for (int i = 0; i < dimIndices.length; i++)
            {
                dims[i] = row.getDouble(dimIndices[i]);
            }
            return (Vector) new DenseVector(dims);
        }).rdd();
    }

    /* (non-Javadoc)
     * @see org.apache.spark.rdd.RDD#compute(org.apache.spark.Partition, org.apache.spark.TaskContext)
     */
    @Override
    public InterruptibleIterator compute(Partition p_arg, TaskContext ctx)
    {
        HpccRDD.registerPicklingFunctions();

        final InternalPartition this_part = (InternalPartition) p_arg;
        final FieldDef originalRD = this.originalRecordDef;
        final FieldDef projectedRD = this.projectedRecordDef;

        if (originalRD == null)
        {
            log.error("Original record defintion is null. Aborting.");
            return null;
        }
        
        if (projectedRD == null)
        {
            log.error("Projected record defintion is null. Aborting.");
            return null;
        }

        scala.collection.Iterator iter = null;
        try
        {
            final HpccRemoteFileReader fileReader = new HpccRemoteFileReader(this_part.partition, originalRD, new GenericRowRecordBuilder(projectedRD), connectionTimeout, recordLimit);
            ctx.addTaskCompletionListener(taskContext -> 
            {
                if (fileReader != null)
                {
                    try
                    {
                        fileReader.close();
                    }
                    catch(Exception e) {}
                }
            });

            iter = JavaConverters.asScalaIteratorConverter(fileReader).asScala();
        }
        catch (Exception e)
        {
            log.error("Failed to create remote file reader with error: " + e.getMessage());
            return null;
        }

        return new InterruptibleIterator(ctx, iter);
    }

    @Override
    public Seq getPreferredLocations(Partition split)
    {
        final InternalPartition part = (InternalPartition) split;
        return JavaConverters.asScalaBufferConverter(Arrays.asList(part.partition.getCopyLocations()[0])).asScala().seq();
    }

    /* (non-Javadoc)
     * @see org.apache.spark.rdd.RDD#getPartitions()
     */
    @Override
    public Partition[] getPartitions()
    {
        return parts;
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy