com.clickzetta.platform.arrow.ArrowSchemaConvert Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of clickzetta-java Show documentation
Show all versions of clickzetta-java Show documentation
The java SDK for clickzetta's Lakehouse
package com.clickzetta.platform.arrow;
import cz.proto.*;
import cz.proto.ingestion.v2.IngestionV2;
import org.apache.arrow.vector.UInt1Vector;
import org.apache.arrow.vector.UInt2Vector;
import org.apache.arrow.vector.UInt4Vector;
import org.apache.arrow.vector.UInt8Vector;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import java.util.*;
import java.util.stream.Collectors;
public class ArrowSchemaConvert {
private static final String FIELD_ID_KEY = "PARQUET:field_id";
public static final String TYPE_KEY = "CzType";
public static ArrowSchema convertToExternalSchema(IngestionV2.StreamSchema streamSchema, List fields) {
Map fieldIdMap = new HashMap<>();
List originalTypes = new ArrayList<>();
List arrowFields = new ArrayList<>();
Map fieldIndexMap = new HashMap<>();
int index = 0;
for (IngestionV2.DataField field : fields) {
fieldIdMap.put(field.getType().getFieldId(), field);
fieldIndexMap.put(field.getType().getFieldId(), index++);
originalTypes.add(field.getType());
arrowFields.add(toArrowField(field.getName(), field.getType(), "UTC"));
}
// get all key field index in all arrowFields List.
Set keyIndex = new HashSet<>();
if (streamSchema != null && streamSchema.hasPrimaryKeySpec()) {
for (Integer fieldId : streamSchema.getPrimaryKeySpec().getFieldIdsList()) {
IngestionV2.DataField keyDataField = fieldIdMap.get(fieldId);
keyIndex.add(fieldIndexMap.get(keyDataField.getType().getFieldId()));
}
}
if (streamSchema != null && streamSchema.hasDistSpec()) {
for (Integer fieldId : streamSchema.getDistSpec().getFieldIdsList()) {
IngestionV2.DataField distDataField = fieldIdMap.get(fieldId);
keyIndex.add(fieldIndexMap.get(distDataField.getType().getFieldId()));
}
}
if (streamSchema != null && streamSchema.hasPartitionSpec()) {
for (Integer fieldId : streamSchema.getPartitionSpec().getSrcFieldIdsList()) {
IngestionV2.DataField partitionDataField = fieldIdMap.get(fieldId);
keyIndex.add(fieldIndexMap.get(partitionDataField.getType().getFieldId()));
}
}
return new ArrowSchema(originalTypes, arrowFields, keyIndex);
}
private static Map extractKeyValueMetadata(DataType dataType) {
Map metadata = new HashMap<>();
if (dataType.getCategory() == DataTypeCategory.CHAR) {
metadata.put(TYPE_KEY, "char(" + dataType.getCharTypeInfo().getLength() + ")");
} else if (dataType.getCategory() == DataTypeCategory.VARCHAR) {
metadata.put(TYPE_KEY, "varchar(" + dataType.getVarCharTypeInfo().getLength() + ")");
} else if (dataType.getCategory() == DataTypeCategory.JSON) {
metadata.put(TYPE_KEY, "json");
}
if (dataType.getFieldId() > 0) {
metadata.put(FIELD_ID_KEY, String.valueOf(dataType.getFieldId()));
}
return metadata;
}
private static Field toArrowField(String name, DataType dataType, String timeZoneId) {
Map metadata = extractKeyValueMetadata(dataType);
switch (dataType.getCategoryValue()) {
case DataTypeCategory.ARRAY_VALUE: {
FieldType fieldType = new FieldType(dataType.getNullable(), ArrowType.List.INSTANCE, null, metadata);
return new Field(name, fieldType,
new ArrayList() {{
add(toArrowField("element", dataType.getArrayTypeInfo().getElementType(), timeZoneId));
}}
);
}
case DataTypeCategory.STRUCT_VALUE: {
FieldType fieldType = new FieldType(dataType.getNullable(), ArrowType.Struct.INSTANCE, null, metadata);
return new Field(name, fieldType,
new ArrayList<>(dataType.getStructTypeInfo().getFieldsList().stream()
.map(field -> toArrowField(field.getName(), field.getType(), timeZoneId))
.collect(Collectors.toList()))
);
}
case DataTypeCategory.MAP_VALUE:
FieldType mapType = new FieldType(dataType.getNullable(), new ArrowType.Map(false), null, metadata);
// Note: Map Type struct can not be null, Struct Type key field can not be null
DataType keyType = dataType.getMapTypeInfo().getKeyType();
DataType valueType = dataType.getMapTypeInfo().getValueType();
// TODO hack map key not nullable.
keyType = keyType.toBuilder().setNullable(false).build();
DataType mapStructType = DataType.newBuilder().setCategory(DataTypeCategory.STRUCT)
.setStructTypeInfo(StructTypeInfo.newBuilder()
.addFields(StructTypeInfo.Field.newBuilder().setName(MapVector.KEY_NAME).setType(keyType).build())
.addFields(StructTypeInfo.Field.newBuilder().setName(MapVector.VALUE_NAME).setType(valueType).build())
.build())
.setNullable(false)
.build();
return new Field(name, mapType,
new ArrayList() {{
add(toArrowField(MapVector.DATA_VECTOR_NAME, mapStructType, timeZoneId));
}}
);
default:
FieldType fieldType = new FieldType(dataType.getNullable(), toArrowType(dataType, timeZoneId), null, metadata);
return new Field(name, fieldType, null);
}
}
public static ArrowType toArrowType(DataType dataType, String timeZoneId) {
switch (dataType.getCategoryValue()) {
case DataTypeCategory.INT8_VALUE:
return new ArrowType.Int(8, true);
case DataTypeCategory.INT16_VALUE:
return new ArrowType.Int(8 * 2, true);
case DataTypeCategory.INT32_VALUE:
return new ArrowType.Int(8 * 4, true);
case DataTypeCategory.INT64_VALUE:
return new ArrowType.Int(8 * 8, true);
case DataTypeCategory.FLOAT32_VALUE:
return new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE);
case DataTypeCategory.FLOAT64_VALUE:
return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
case DataTypeCategory.DECIMAL_VALUE:
if (dataType.hasDecimalTypeInfo()) {
int precision = Integer.parseInt(String.valueOf(dataType.getDecimalTypeInfo().getPrecision()));
int scale = Integer.parseInt(String.valueOf(dataType.getDecimalTypeInfo().getScale()));
return new ArrowType.Decimal(precision, scale);
} else {
throw new UnsupportedOperationException("Decimal dataType not has decimalTypeInfo." + dataType);
}
case DataTypeCategory.BOOLEAN_VALUE:
return ArrowType.Bool.INSTANCE;
case DataTypeCategory.VARCHAR_VALUE:
case DataTypeCategory.CHAR_VALUE:
case DataTypeCategory.STRING_VALUE:
case DataTypeCategory.JSON_VALUE:
return ArrowType.Utf8.INSTANCE;
case DataTypeCategory.BINARY_VALUE:
return ArrowType.Binary.INSTANCE;
case DataTypeCategory.DATE_VALUE:
return new ArrowType.Date(DateUnit.DAY);
case DataTypeCategory.TIMESTAMP_LTZ_VALUE:
switch (dataType.getTimestampInfo().getTsUnit()) {
case SECONDS:
return new ArrowType.Timestamp(TimeUnit.SECOND, timeZoneId);
case MILLISECONDS:
return new ArrowType.Timestamp(TimeUnit.MILLISECOND, timeZoneId);
case MICROSECONDS:
return new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId);
case NANOSECONDS:
return new ArrowType.Timestamp(TimeUnit.NANOSECOND, timeZoneId);
default:
throw new UnsupportedOperationException("not supported timestamp type: " + dataType);
}
default:
throw new UnsupportedOperationException("not support dataType: " + dataType);
}
}
public static DataTypeCategory fromArrowType(ArrowType arrowType) {
DataTypeCategory dataTypeCategory = null;
Class extends ArrowType> clz = arrowType.getClass();
if (ArrowType.Int.class.equals(clz)) {
int typeWidth = ((ArrowType.Int) arrowType).getBitWidth() / 8;
switch (typeWidth) {
case UInt1Vector.TYPE_WIDTH: {
dataTypeCategory = DataTypeCategory.forNumber(DataTypeCategory.INT8_VALUE);
break;
}
case UInt2Vector.TYPE_WIDTH: {
dataTypeCategory = DataTypeCategory.forNumber(DataTypeCategory.INT16_VALUE);
break;
}
case UInt4Vector.TYPE_WIDTH: {
dataTypeCategory = DataTypeCategory.forNumber(DataTypeCategory.INT32_VALUE);
break;
}
case UInt8Vector.TYPE_WIDTH: {
dataTypeCategory = DataTypeCategory.forNumber(DataTypeCategory.INT64_VALUE);
break;
}
default:
throw new UnsupportedOperationException(
String.format("not support type width: %s of int: %s", typeWidth, arrowType));
}
} else if (ArrowType.FloatingPoint.class.equals(clz)) {
FloatingPointPrecision precision = ((ArrowType.FloatingPoint) arrowType).getPrecision();
switch (precision) {
case SINGLE: {
dataTypeCategory = DataTypeCategory.forNumber(DataTypeCategory.FLOAT32_VALUE);
break;
}
case DOUBLE: {
dataTypeCategory = DataTypeCategory.forNumber(DataTypeCategory.FLOAT64_VALUE);
break;
}
default:
throw new UnsupportedOperationException(
String.format("not support precision: %s of floatingPoint: %s", precision, arrowType));
}
} else if (ArrowType.Decimal.class.equals(clz)) {
dataTypeCategory = DataTypeCategory.forNumber(DataTypeCategory.DECIMAL_VALUE);
} else if (ArrowType.Bool.class.equals(clz)) {
dataTypeCategory = DataTypeCategory.forNumber(DataTypeCategory.BOOLEAN_VALUE);
} else if (ArrowType.Utf8.class.equals(clz)) {
dataTypeCategory = DataTypeCategory.forNumber(DataTypeCategory.STRING_VALUE);
} else if (ArrowType.Binary.class.equals(clz)) {
dataTypeCategory = DataTypeCategory.forNumber(DataTypeCategory.BINARY_VALUE);
} else if (ArrowType.Date.class.equals(clz)) {
dataTypeCategory = DataTypeCategory.forNumber(DataTypeCategory.DATE_VALUE);
} else if (ArrowType.Timestamp.class.equals(clz)) {
dataTypeCategory = DataTypeCategory.forNumber(DataTypeCategory.TIMESTAMP_LTZ_VALUE);
} else if (ArrowType.List.class.equals(clz)) {
dataTypeCategory = DataTypeCategory.forNumber(DataTypeCategory.ARRAY_VALUE);
} else if (ArrowType.Struct.class.equals(clz)) {
dataTypeCategory = DataTypeCategory.forNumber(DataTypeCategory.STRUCT_VALUE);
} else if (ArrowType.Map.class.equals(clz)) {
dataTypeCategory = DataTypeCategory.forNumber(DataTypeCategory.MAP_VALUE);
} else {
throw new UnsupportedOperationException("not support arrowType: " + arrowType);
}
return dataTypeCategory;
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy