All downloads are free. Search and download functionality uses the official Maven repository.

com.clickzetta.platform.arrow.ArrowSchemaConvert Maven / Gradle / Ivy

There is a newer version: 2.0.0
Show newest version
package com.clickzetta.platform.arrow;

import cz.proto.*;
import cz.proto.ingestion.v2.IngestionV2;
import org.apache.arrow.vector.UInt1Vector;
import org.apache.arrow.vector.UInt2Vector;
import org.apache.arrow.vector.UInt4Vector;
import org.apache.arrow.vector.UInt8Vector;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;

import java.util.*;
import java.util.stream.Collectors;

/**
 * Static helpers that translate between the ingestion schema model ({@link DataType},
 * {@link DataTypeCategory}) and Apache Arrow schema objects ({@link Field}, {@link ArrowType}).
 */
public class ArrowSchemaConvert {

  /** Metadata key under which a positive field id is stored on the generated Arrow field. */
  private static final String FIELD_ID_KEY = "PARQUET:field_id";

  /** Metadata key that records the original type text for char / varchar / json columns. */
  public static final String TYPE_KEY = "CzType";

  /** Time zone id attached to timestamp columns built by {@link #convertToExternalSchema}. */
  private static final String DEFAULT_TIME_ZONE_ID = "UTC";

  /**
   * Builds an {@link ArrowSchema} from the given data fields.
   *
   * <p>Every field is converted to an Arrow {@link Field}. In addition, the positions (within
   * {@code fields}) of all fields referenced by the stream schema's primary key spec, dist spec
   * and partition spec are collected into a single key-index set.
   *
   * @param streamSchema stream schema carrying the optional key specs; may be {@code null}, in
   *     which case the key-index set is empty
   * @param fields data fields to convert, in column order; must not be {@code null}
   * @return the schema holding the original types, the Arrow fields and the key indexes
   * @throws IllegalArgumentException if a spec references a field id that is absent from
   *     {@code fields}
   * @throws UnsupportedOperationException if a field has a type that cannot be mapped to Arrow
   */
  public static ArrowSchema convertToExternalSchema(
      IngestionV2.StreamSchema streamSchema, List<IngestionV2.DataField> fields) {
    Objects.requireNonNull(fields, "fields");

    List<DataType> originalTypes = new ArrayList<>(fields.size());
    List<Field> arrowFields = new ArrayList<>(fields.size());
    // field id -> position of the field inside `fields` (the last one wins on duplicate ids).
    Map<Integer, Integer> fieldIndexById = new HashMap<>();
    int index = 0;
    for (IngestionV2.DataField field : fields) {
      DataType type = field.getType();
      fieldIndexById.put(type.getFieldId(), index++);
      originalTypes.add(type);
      arrowFields.add(toArrowField(field.getName(), type, DEFAULT_TIME_ZONE_ID));
    }

    // Positions of every key-like column (primary key, distribution, partition source).
    Set<Integer> keyIndex = new HashSet<>();
    if (streamSchema != null) {
      if (streamSchema.hasPrimaryKeySpec()) {
        addKeyIndexes(
            streamSchema.getPrimaryKeySpec().getFieldIdsList(), "primary key spec",
            fieldIndexById, keyIndex);
      }
      if (streamSchema.hasDistSpec()) {
        addKeyIndexes(
            streamSchema.getDistSpec().getFieldIdsList(), "dist spec",
            fieldIndexById, keyIndex);
      }
      if (streamSchema.hasPartitionSpec()) {
        addKeyIndexes(
            streamSchema.getPartitionSpec().getSrcFieldIdsList(), "partition spec",
            fieldIndexById, keyIndex);
      }
    }
    return new ArrowSchema(originalTypes, arrowFields, keyIndex);
  }

  /**
   * Resolves each field id to its position and adds that position to {@code keyIndex}.
   *
   * @throws IllegalArgumentException if a field id has no entry in {@code fieldIndexById}
   */
  private static void addKeyIndexes(
      List<Integer> fieldIds,
      String specName,
      Map<Integer, Integer> fieldIndexById,
      Set<Integer> keyIndex) {
    for (Integer fieldId : fieldIds) {
      Integer position = fieldIndexById.get(fieldId);
      if (position == null) {
        throw new IllegalArgumentException(
            String.format("%s references unknown field id: %s", specName, fieldId));
      }
      keyIndex.add(position);
    }
  }

  /**
   * Collects the key/value metadata to attach to the Arrow field of {@code dataType}.
   *
   * <p>char, varchar and json columns get a {@link #TYPE_KEY} entry describing the original
   * type; any type with a positive field id additionally gets a {@code PARQUET:field_id} entry.
   *
   * @param dataType type to describe
   * @return a new mutable map, possibly empty
   */
  private static Map<String, String> extractKeyValueMetadata(DataType dataType) {
    Map<String, String> metadata = new HashMap<>();
    DataTypeCategory category = dataType.getCategory();
    if (category == DataTypeCategory.CHAR) {
      metadata.put(TYPE_KEY, "char(" + dataType.getCharTypeInfo().getLength() + ")");
    } else if (category == DataTypeCategory.VARCHAR) {
      metadata.put(TYPE_KEY, "varchar(" + dataType.getVarCharTypeInfo().getLength() + ")");
    } else if (category == DataTypeCategory.JSON) {
      metadata.put(TYPE_KEY, "json");
    }
    if (dataType.getFieldId() > 0) {
      metadata.put(FIELD_ID_KEY, String.valueOf(dataType.getFieldId()));
    }
    return metadata;
  }

  /**
   * Converts one data type into an Arrow {@link Field}, recursing into array, struct and map
   * children.
   *
   * @param name name of the resulting Arrow field
   * @param dataType type to convert
   * @param timeZoneId time zone id passed down to timestamp columns
   * @return the Arrow field, including nullability and metadata
   * @throws UnsupportedOperationException if the type (or a nested type) cannot be mapped
   */
  private static Field toArrowField(String name, DataType dataType, String timeZoneId) {
    Map<String, String> metadata = extractKeyValueMetadata(dataType);
    boolean nullable = dataType.getNullable();
    switch (dataType.getCategoryValue()) {
      case DataTypeCategory.ARRAY_VALUE: {
        FieldType fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null, metadata);
        Field element =
            toArrowField("element", dataType.getArrayTypeInfo().getElementType(), timeZoneId);
        return new Field(name, fieldType, Collections.singletonList(element));
      }
      case DataTypeCategory.STRUCT_VALUE: {
        FieldType fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null, metadata);
        List<Field> children = dataType.getStructTypeInfo().getFieldsList().stream()
            .map(child -> toArrowField(child.getName(), child.getType(), timeZoneId))
            .collect(Collectors.toList());
        return new Field(name, fieldType, children);
      }
      case DataTypeCategory.MAP_VALUE: {
        FieldType mapType = new FieldType(nullable, new ArrowType.Map(false), null, metadata);
        // Note: Map Type struct can not be null, Struct Type key field can not be null
        DataType valueType = dataType.getMapTypeInfo().getValueType();
        // TODO hack map key not nullable.
        DataType keyType =
            dataType.getMapTypeInfo().getKeyType().toBuilder().setNullable(false).build();
        // The map's single child is a non-nullable struct of (key, value) entries.
        DataType mapStructType = DataType.newBuilder()
            .setCategory(DataTypeCategory.STRUCT)
            .setStructTypeInfo(StructTypeInfo.newBuilder()
                .addFields(StructTypeInfo.Field.newBuilder()
                    .setName(MapVector.KEY_NAME).setType(keyType).build())
                .addFields(StructTypeInfo.Field.newBuilder()
                    .setName(MapVector.VALUE_NAME).setType(valueType).build())
                .build())
            .setNullable(false)
            .build();
        Field entries = toArrowField(MapVector.DATA_VECTOR_NAME, mapStructType, timeZoneId);
        return new Field(name, mapType, Collections.singletonList(entries));
      }
      default: {
        FieldType fieldType =
            new FieldType(nullable, toArrowType(dataType, timeZoneId), null, metadata);
        return new Field(name, fieldType, null);
      }
    }
  }

  /**
   * Maps a non-nested data type to the corresponding {@link ArrowType}.
   *
   * @param dataType type to convert; array, struct and map categories are not handled here
   * @param timeZoneId time zone id used for timestamp types
   * @return the matching Arrow type
   * @throws UnsupportedOperationException if the category is not supported, a decimal type has
   *     no decimal type info, or the timestamp unit is not supported
   */
  public static ArrowType toArrowType(DataType dataType, String timeZoneId) {
    switch (dataType.getCategoryValue()) {
      case DataTypeCategory.INT8_VALUE:
        return new ArrowType.Int(8, true);
      case DataTypeCategory.INT16_VALUE:
        return new ArrowType.Int(16, true);
      case DataTypeCategory.INT32_VALUE:
        return new ArrowType.Int(32, true);
      case DataTypeCategory.INT64_VALUE:
        return new ArrowType.Int(64, true);
      case DataTypeCategory.FLOAT32_VALUE:
        return new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE);
      case DataTypeCategory.FLOAT64_VALUE:
        return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
      case DataTypeCategory.DECIMAL_VALUE:
        if (!dataType.hasDecimalTypeInfo()) {
          throw new UnsupportedOperationException(
              "Decimal dataType not has decimalTypeInfo." + dataType);
        }
        // NOTE(review): the numeric type returned by getPrecision()/getScale() is not visible
        // here; the String round trip narrows to int and throws NumberFormatException if the
        // value does not fit. Kept as-is to preserve that behavior.
        int precision =
            Integer.parseInt(String.valueOf(dataType.getDecimalTypeInfo().getPrecision()));
        int scale = Integer.parseInt(String.valueOf(dataType.getDecimalTypeInfo().getScale()));
        return new ArrowType.Decimal(precision, scale);
      case DataTypeCategory.BOOLEAN_VALUE:
        return ArrowType.Bool.INSTANCE;
      case DataTypeCategory.VARCHAR_VALUE:
      case DataTypeCategory.CHAR_VALUE:
      case DataTypeCategory.STRING_VALUE:
      case DataTypeCategory.JSON_VALUE:
        // All textual categories share the Arrow UTF-8 type.
        return ArrowType.Utf8.INSTANCE;
      case DataTypeCategory.BINARY_VALUE:
        return ArrowType.Binary.INSTANCE;
      case DataTypeCategory.DATE_VALUE:
        return new ArrowType.Date(DateUnit.DAY);
      case DataTypeCategory.TIMESTAMP_LTZ_VALUE:
        switch (dataType.getTimestampInfo().getTsUnit()) {
          case SECONDS:
            return new ArrowType.Timestamp(TimeUnit.SECOND, timeZoneId);
          case MILLISECONDS:
            return new ArrowType.Timestamp(TimeUnit.MILLISECOND, timeZoneId);
          case MICROSECONDS:
            return new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId);
          case NANOSECONDS:
            return new ArrowType.Timestamp(TimeUnit.NANOSECOND, timeZoneId);
          default:
            throw new UnsupportedOperationException("not supported timestamp type: " + dataType);
        }
      default:
        throw new UnsupportedOperationException("not support dataType: " + dataType);
    }
  }

  /**
   * Maps an {@link ArrowType} back to its {@link DataTypeCategory}.
   *
   * <p>Matching is done on the exact runtime class of {@code arrowType}. Integer types are
   * classified by byte width only.
   *
   * @param arrowType Arrow type to classify
   * @return the matching category
   * @throws UnsupportedOperationException if the Arrow type, integer width or floating point
   *     precision has no mapping
   */
  public static DataTypeCategory fromArrowType(ArrowType arrowType) {
    Class<?> clz = arrowType.getClass();
    if (ArrowType.Int.class.equals(clz)) {
      // NOTE(review): signedness is not inspected, so unsigned Arrow ints map to the same
      // categories as signed ones - confirm this is intended.
      int typeWidth = ((ArrowType.Int) arrowType).getBitWidth() / 8;
      switch (typeWidth) {
        case UInt1Vector.TYPE_WIDTH:
          return DataTypeCategory.INT8;
        case UInt2Vector.TYPE_WIDTH:
          return DataTypeCategory.INT16;
        case UInt4Vector.TYPE_WIDTH:
          return DataTypeCategory.INT32;
        case UInt8Vector.TYPE_WIDTH:
          return DataTypeCategory.INT64;
        default:
          throw new UnsupportedOperationException(
              String.format("not support type width: %s of int: %s", typeWidth, arrowType));
      }
    }
    if (ArrowType.FloatingPoint.class.equals(clz)) {
      FloatingPointPrecision precision = ((ArrowType.FloatingPoint) arrowType).getPrecision();
      switch (precision) {
        case SINGLE:
          return DataTypeCategory.FLOAT32;
        case DOUBLE:
          return DataTypeCategory.FLOAT64;
        default:
          throw new UnsupportedOperationException(
              String.format("not support precision: %s of floatingPoint: %s", precision, arrowType));
      }
    }
    if (ArrowType.Decimal.class.equals(clz)) {
      return DataTypeCategory.DECIMAL;
    }
    if (ArrowType.Bool.class.equals(clz)) {
      return DataTypeCategory.BOOLEAN;
    }
    if (ArrowType.Utf8.class.equals(clz)) {
      return DataTypeCategory.STRING;
    }
    if (ArrowType.Binary.class.equals(clz)) {
      return DataTypeCategory.BINARY;
    }
    if (ArrowType.Date.class.equals(clz)) {
      return DataTypeCategory.DATE;
    }
    if (ArrowType.Timestamp.class.equals(clz)) {
      return DataTypeCategory.TIMESTAMP_LTZ;
    }
    if (ArrowType.List.class.equals(clz)) {
      return DataTypeCategory.ARRAY;
    }
    if (ArrowType.Struct.class.equals(clz)) {
      return DataTypeCategory.STRUCT;
    }
    if (ArrowType.Map.class.equals(clz)) {
      return DataTypeCategory.MAP;
    }
    throw new UnsupportedOperationException("not support arrowType: " + arrowType);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy