All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.ballistacompute.protobuf.0.2.5.source-code.ProtobufDeserializer.kt Maven / Gradle / Ivy

The newest version!
package org.ballistacompute.protobuf

import org.ballistacompute.datasource.CsvDataSource
import org.ballistacompute.datatypes.ArrowTypes
import org.ballistacompute.logical.*
import java.lang.RuntimeException

/**
 * Converts protobuf messages ([LogicalPlanNode], [LogicalExprNode] and the protobuf [Schema])
 * into the logical plan, logical expression and schema types of `org.ballistacompute`.
 *
 * Each `fromProto` overload throws when it meets a message variant it does not recognise,
 * instead of returning a partial result.
 */
class ProtobufDeserializer {

    /**
     * Converts a [LogicalPlanNode], and recursively its `input` chain, into a [LogicalPlan].
     *
     * The operator sub-messages are tested in this order: scan, selection, projection, limit,
     * aggregate. The first one that is set wins.
     *
     * @throws RuntimeException if no supported operator is set, or if an aggregate operator
     *   contains an aggregate expression that does not deserialize to an [AggregateExpr].
     */
    fun fromProto(node: LogicalPlanNode): LogicalPlan = when {
        node.hasScan() -> {
            val schema = fromProto(node.scan.schema)
            // NOTE(review): 1024 is hard-coded; presumably CsvDataSource's batch size — confirm
            // and consider making it configurable.
            val ds = CsvDataSource(node.scan.path, schema, 1024)
            // `projectionList` is already a List<String>. The former
            // `asByteStringList().map { it.toString() }` yielded ByteString's debug description
            // (not the UTF-8 text), so no projected column name could ever match.
            Scan(node.scan.path, ds, node.scan.projectionList.toList())
        }
        node.hasSelection() ->
            Selection(fromProto(node.input), fromProto(node.selection.expr))
        node.hasProjection() ->
            Projection(fromProto(node.input), node.projection.exprList.map { fromProto(it) })
        node.hasLimit() ->
            Limit(fromProto(node.input), node.limit.limit)
        node.hasAggregate() -> {
            val input = fromProto(node.input)
            val groupExpr = node.aggregate.groupExprList.map { fromProto(it) }
            val aggrExpr = node.aggregate.aggrExprList.map { exprNode ->
                // Fail with a descriptive message instead of an opaque ClassCastException.
                fromProto(exprNode) as? AggregateExpr
                    ?: throw RuntimeException("Failed to parse aggregate expression: $exprNode")
            }
            Aggregate(input, groupExpr, aggrExpr)
        }
        else -> throw RuntimeException("Failed to parse logical operator: $node")
    }

    /**
     * Converts a [LogicalExprNode] into a [LogicalExpr].
     *
     * The variants are tested in this order: binary expression, column index, column name,
     * string / long / double literal, aggregate expression. The column and literal variants
     * are selected through the `hasColumnIndex`, `hasColumnName`, ... properties of the message.
     *
     * @throws RuntimeException if the node matches none of the supported variants, or carries
     *   an unknown binary operator or aggregate function.
     */
    fun fromProto(node: LogicalExprNode): LogicalExpr = when {
        node.hasBinaryExpr() -> binaryExprFromProto(node)
        node.hasColumnIndex -> ColumnIndex(node.columnIndex)
        node.hasColumnName -> col(node.columnName)
        node.hasLiteralString -> lit(node.literalString)
        node.hasLiteralLong -> lit(node.literalLong)
        node.hasLiteralDouble -> lit(node.literalDouble)
        node.hasAggregateExpr() -> aggregateExprFromProto(node)
        else -> throw RuntimeException("Failed to parse logical expression: $node")
    }

    /**
     * Builds the binary expression held in [node]'s `binaryExpr`, dispatching on its `op` string.
     * Both operands are deserialized before the operator is inspected.
     */
    private fun binaryExprFromProto(node: LogicalExprNode): LogicalExpr {
        val binaryNode = node.binaryExpr
        val ll = fromProto(binaryNode.l)
        val rr = fromProto(binaryNode.r)
        return when (binaryNode.op) {
            "eq" -> Eq(ll, rr)
            "neq" -> Neq(ll, rr)
            "lt" -> Lt(ll, rr)
            "lteq" -> LtEq(ll, rr)
            "gt" -> Gt(ll, rr)
            "gteq" -> GtEq(ll, rr)
            "and" -> And(ll, rr)
            "or" -> Or(ll, rr)
            "add" -> Add(ll, rr)
            "subtract" -> Subtract(ll, rr)
            "multiply" -> Multiply(ll, rr)
            "divide" -> Divide(ll, rr)
            else -> throw RuntimeException("Failed to parse logical binary expression: $node")
        }
    }

    /**
     * Builds the aggregate expression held in [node]'s `aggregateExpr`. Only MIN, MAX, SUM and
     * AVG are mapped.
     */
    private fun aggregateExprFromProto(node: LogicalExprNode): LogicalExpr {
        val aggr = node.aggregateExpr
        val expr = fromProto(aggr.expr)
        return when (aggr.aggrFunction) {
            AggregateFunction.MIN -> Min(expr)
            AggregateFunction.MAX -> Max(expr)
            AggregateFunction.SUM -> Sum(expr)
            AggregateFunction.AVG -> Avg(expr)
            // Kept as a catch-all: generated protobuf enums can carry values (e.g. UNRECOGNIZED)
            // that have no logical counterpart here.
            else -> throw RuntimeException("Failed to parse logical aggregate expression: ${aggr.aggrFunction}")
        }
    }

    /**
     * Converts a protobuf [Schema] into an `org.ballistacompute.datatypes.Schema` by first
     * building an Arrow schema. Every field is created nullable, with no dictionary encoding
     * and no children.
     *
     * @throws IllegalStateException if a column uses an Arrow type that is not mapped below.
     */
    fun fromProto(schema: Schema): org.ballistacompute.datatypes.Schema {
        val arrowFields = schema.columnsList.map { column ->
            // TODO add all types
            val dt = when (column.arrowTypeValue) {
                ArrowType.UTF8_VALUE -> ArrowTypes.StringType

                ArrowType.INT8_VALUE -> ArrowTypes.Int8Type
                ArrowType.INT16_VALUE -> ArrowTypes.Int16Type
                ArrowType.INT32_VALUE -> ArrowTypes.Int32Type
                ArrowType.INT64_VALUE -> ArrowTypes.Int64Type

                ArrowType.UINT8_VALUE -> ArrowTypes.UInt8Type
                ArrowType.UINT16_VALUE -> ArrowTypes.UInt16Type
                ArrowType.UINT32_VALUE -> ArrowTypes.UInt32Type
                ArrowType.UINT64_VALUE -> ArrowTypes.UInt64Type

                ArrowType.FLOAT_VALUE -> ArrowTypes.FloatType
                ArrowType.DOUBLE_VALUE -> ArrowTypes.DoubleType

                else -> throw IllegalStateException("Failed to parse Arrow data type enum from protobuf: ${column.arrowTypeValue}")
            }

            // FieldType(nullable = true, type = dt, dictionary = null)
            val fieldType = org.apache.arrow.vector.types.pojo.FieldType(true, dt, null)
            org.apache.arrow.vector.types.pojo.Field(column.name, fieldType, listOf())
        }

        return org.ballistacompute.datatypes.SchemaConverter.fromArrow(org.apache.arrow.vector.types.pojo.Schema(arrowFields))
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy