/*
 * Copyright 2018 TWO SIGMA OPEN SOURCE, LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.twosigma.flint.arrow

import scala.collection.JavaConverters._

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.types.{ DateUnit, FloatingPointPrecision, TimeUnit }
import org.apache.arrow.vector.types.pojo.{ ArrowType, Field, FieldType, Schema }
import org.apache.spark.sql.types._
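/**
 * Conversions between Spark SQL and Apache Arrow types, fields and schemas.
 *
 * A minimal round-trip sketch (the column names below are illustrative only):
 * {{{
 * val sparkSchema = StructType(Seq(
 *   StructField("time", LongType, nullable = false),
 *   StructField("price", DoubleType, nullable = true)
 * ))
 * val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema)
 * assert(ArrowUtils.fromArrowSchema(arrowSchema) == sparkSchema)
 * }}}
 */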
object ArrowUtils {

  val rootAllocator = new RootAllocator(Long.MaxValue)

  // TODO: support more types.
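  /**
   * Maps a Spark SQL [[DataType]] to the corresponding Arrow [[ArrowType]].
   * Timestamps map to microsecond precision in UTC, and unsupported types
   * (including, for now, DecimalType) raise an [[UnsupportedOperationException]].
   */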
  def toArrowType(dt: DataType): ArrowType = dt match {
    case BooleanType => ArrowType.Bool.INSTANCE
    case ByteType => new ArrowType.Int(8, true)
    case ShortType => new ArrowType.Int(8 * 2, true)
    case IntegerType => new ArrowType.Int(8 * 4, true)
    case LongType => new ArrowType.Int(8 * 8, true)
    case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
    case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
    case StringType => ArrowType.Utf8.INSTANCE
    case BinaryType => ArrowType.Binary.INSTANCE
    // case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale)
    case DateType => new ArrowType.Date(DateUnit.DAY)
    case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC")
    case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}")
  }
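
  /**
   * Maps an Arrow [[ArrowType]] back to the corresponding Spark SQL [[DataType]].
   * Only signed integer widths of 8, 16, 32 and 64 bits are recognized; any Arrow
   * date or timestamp unit collapses to [[DateType]] / [[TimestampType]].
   */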
  def fromArrowType(dt: ArrowType): DataType = dt match {
    case ArrowType.Bool.INSTANCE => BooleanType
    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType
    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType
    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType
    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType
    case float: ArrowType.FloatingPoint if float.getPrecision() == FloatingPointPrecision.SINGLE => FloatType
    case float: ArrowType.FloatingPoint if float.getPrecision() == FloatingPointPrecision.DOUBLE => DoubleType
    case ArrowType.Utf8.INSTANCE => StringType
    case ArrowType.Binary.INSTANCE => BinaryType
    case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
    case _: ArrowType.Date => DateType
    case _: ArrowType.Timestamp => TimestampType
    case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt")
  }
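
  /**
   * Converts a single column to an Arrow [[Field]], recursing into array and
   * struct types so that element/child nullability is preserved.
   *
   * A minimal sketch for a nested column (the column name is illustrative):
   * {{{
   * val field = ArrowUtils.toArrowField(
   *   "tags", ArrayType(StringType, containsNull = true), nullable = false)
   * // field is a non-nullable List with a nullable Utf8 child named "element"
   * }}}
   */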
  def toArrowField(name: String, dt: DataType, nullable: Boolean): Field = {
    dt match {
      case ArrayType(elementType, containsNull) =>
        val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null)
        new Field(name, fieldType, Seq(toArrowField("element", elementType, containsNull)).asJava)
      case StructType(fields) =>
        val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
        new Field(name, fieldType,
          fields.map { field =>
            toArrowField(field.name, field.dataType, field.nullable)
          }.toSeq.asJava)
      case dataType =>
        val fieldType = new FieldType(nullable, toArrowType(dataType), null)
        new Field(name, fieldType, Seq.empty[Field].asJava)
    }
  }
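
  /**
   * Converts an Arrow [[Field]] back to a Spark SQL [[DataType]], rebuilding
   * [[ArrayType]] and [[StructType]] from the field's children.
   */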
  def fromArrowField(field: Field): DataType = {
    field.getType match {
      case ArrowType.List.INSTANCE =>
        val elementField = field.getChildren().get(0)
        val elementType = fromArrowField(elementField)
        ArrayType(elementType, containsNull = elementField.isNullable)
      case ArrowType.Struct.INSTANCE =>
        val fields = field.getChildren().asScala.map { child =>
          val dt = fromArrowField(child)
          StructField(child.getName, dt, child.isNullable)
        }
        StructType(fields)
      case arrowType => fromArrowType(arrowType)
    }
  }
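
  /**
   * Converts a Spark SQL [[StructType]] to an Arrow [[Schema]], preserving
   * field names and nullability.
   */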
  def toArrowSchema(schema: StructType): Schema = {
    new Schema(schema.map { field =>
      toArrowField(field.name, field.dataType, field.nullable)
    }.asJava)
  }
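
  /**
   * Converts an Arrow [[Schema]] back to a Spark SQL [[StructType]]. Field
   * metadata is not carried over; only names, types and nullability survive.
   */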
  def fromArrowSchema(schema: Schema): StructType = {
    StructType(schema.getFields.asScala.map { field =>
      val dt = fromArrowField(field)
      StructField(field.getName, dt, field.isNullable)
    })
  }
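
  /**
   * Checks that a Spark schema and an Arrow schema describe the same columns,
   * comparing field data types positionally and failing with an
   * [[IllegalArgumentException]] (via `require`) on the first mismatch.
   *
   * A minimal sketch (the schema contents are illustrative):
   * {{{
   * val sparkSchema = StructType(Seq(StructField("v", DoubleType)))
   * ArrowUtils.assertDataTypeEquals(sparkSchema, ArrowUtils.toArrowSchema(sparkSchema))
   * }}}
   */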
  def assertDataTypeEquals(sparkSchema: StructType, arrowSchema: Schema): Unit = {
    // The Arrow schema contains index information, which needs to be dropped before comparing.
    val actualSchema: StructType = StructType(fromArrowSchema(arrowSchema))
    require(
      sparkSchema.size == actualSchema.size,
      s"Expected schema doesn't match actual schema. actual: $actualSchema expected: $sparkSchema"
    )
    (actualSchema zip sparkSchema).foreach {
      case (f1, f2) =>
        require(
          f1.dataType.equals(f2.dataType),
          s"Expected schema doesn't match actual schema. \n " +
            s"actual: $actualSchema \n expected: $sparkSchema \n " +
            s"actual datatype: $f1 \n expected datatype: $f2 \n"
        )
    }
  }
}