/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hudi.sync.common.util;

import org.apache.hudi.common.util.ValidationUtils;
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.OriginalType;
import org.apache.parquet.schema.PrimitiveType;
import org.apache.parquet.schema.Type;

import java.util.stream.Collectors;

import static org.apache.parquet.schema.Type.Repetition.OPTIONAL;

/**
 * Converts a Parquet schema to the JSON string of the equivalent Spark schema.
 * This code is adapted from org.apache.spark.sql.execution.datasources.parquet.ParquetToSparkSchemaConverter
 * in the Spark project.
 */
public class Parquet2SparkSchemaUtils {

  public static String convertToSparkSchemaJson(GroupType parquetSchema) {
    String fieldsJsonString = parquetSchema.getFields().stream().map(field -> {
      switch (field.getRepetition()) {
        case OPTIONAL:
          return "{\"name\":\"" + field.getName() + "\",\"type\":" + convertFieldType(field)
                  + ",\"nullable\":true,\"metadata\":{}}";
        case REQUIRED:
          return "{\"name\":\"" + field.getName() + "\",\"type\":" + convertFieldType(field)
                  + ",\"nullable\":false,\"metadata\":{}}";
        case REPEATED:
          // A bare repeated field (legacy two-level list encoding) maps to a
          // non-null array of non-null elements.
          String arrayTypeJson = arrayType(field, false);
          return "{\"name\":\"" + field.getName() + "\",\"type\":" + arrayTypeJson
                  + ",\"nullable\":false,\"metadata\":{}}";
        default:
          throw new UnsupportedOperationException("Unsupported conversion of " + field + " to Spark SQL type");
      }
    }).collect(Collectors.joining(","));
    return "{\"type\":\"struct\",\"fields\":[" + fieldsJsonString + "]}";
  }
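
  // Example output of convertToSparkSchemaJson for the Parquet schema
  //   message spark_schema { required int64 id; optional binary name (UTF8); }
  // would be:
  //   {"type":"struct","fields":[{"name":"id","type":"long","nullable":false,"metadata":{}},
  //    {"name":"name","type":"string","nullable":true,"metadata":{}}]}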

  private static String convertFieldType(Type field) {
    if (field instanceof PrimitiveType) {
      return "\"" + convertPrimitiveType((PrimitiveType) field) + "\"";
    } else {
      assert field instanceof GroupType;
      return convertGroupField((GroupType) field);
    }
  }

  private static String convertPrimitiveType(PrimitiveType field) {
    PrimitiveType.PrimitiveTypeName typeName = field.getPrimitiveTypeName();
    OriginalType originalType = field.getOriginalType();

    switch (typeName) {
      case BOOLEAN: return "boolean";
      case FLOAT: return "float";
      case DOUBLE: return "double";
      case INT32:
        if (originalType == null) {
          return "integer";
        }
        switch (originalType) {
          case INT_8: return "byte";
          case INT_16: return "short";
          case INT_32: return "integer";
          case DATE: return "date";
          case DECIMAL:
            return "decimal(" + field.getDecimalMetadata().getPrecision() + ","
                    + field.getDecimalMetadata().getScale() + ")";
          default: throw new UnsupportedOperationException("Unsupported conversion of " + typeName + " to Spark SQL type");
        }
      case INT64:
        if (originalType == null) {
          return "long";
        }
        switch (originalType) {
          case INT_64: return "long";
          case DECIMAL:
            return "decimal(" + field.getDecimalMetadata().getPrecision() + ","
                    + field.getDecimalMetadata().getScale() + ")";
          case TIMESTAMP_MICROS:
          case TIMESTAMP_MILLIS:
            return "timestamp";
          default:
            throw new UnsupportedOperationException("Unsupported conversion of " + typeName + " to Spark SQL type");
        }
      case INT96:
        // INT96 is the legacy Impala/Hive nanosecond timestamp encoding;
        // Spark reads it as a timestamp.
        return "timestamp";

      case BINARY:
        if (originalType == null) {
          return "binary";
        }
        switch (originalType) {
          case UTF8:
          case ENUM:
          case JSON:
            return "string";
          case BSON: return "binary";
          case DECIMAL:
            return "decimal(" + field.getDecimalMetadata().getPrecision() + ","
                    + field.getDecimalMetadata().getScale() + ")";
          default:
            throw new UnsupportedOperationException("Unsupported conversion of " + typeName + " to Spark SQL type");
        }

      case FIXED_LEN_BYTE_ARRAY:
        // Reject a missing annotation explicitly: switching on a null enum
        // would throw a NullPointerException instead of a meaningful error.
        if (originalType == null) {
          throw new UnsupportedOperationException("Unsupported conversion of " + typeName + " to Spark SQL type");
        }
        switch (originalType) {
          case DECIMAL:
            return "decimal(" + field.getDecimalMetadata().getPrecision() + ","
                  + field.getDecimalMetadata().getScale() + ")";
          default:
            throw new UnsupportedOperationException("Unsupported conversion of " + typeName + " to Spark SQL type");
        }
      default:
        throw new UnsupportedOperationException("Unsupported conversion of " + typeName + " to Spark SQL type");
    }
  }

  private static String convertGroupField(GroupType field) {
    if (field.getOriginalType() == null) {
      return convertToSparkSchemaJson(field);
    }
    switch (field.getOriginalType()) {
      case LIST:
        ValidationUtils.checkArgument(field.getFieldCount() == 1, "Illegal List type: " + field);
        Type repeatedType = field.getType(0);
        if (isElementType(repeatedType, field.getName())) {
          // Legacy two-level encoding: the repeated field itself is the element type.
          return arrayType(repeatedType, false);
        } else {
          // Standard three-level encoding: the repeated group wraps a single element field.
          Type elementType = repeatedType.asGroupType().getType(0);
          boolean optional = elementType.isRepetition(OPTIONAL);
          return arrayType(elementType, optional);
        }
      case MAP:
      case MAP_KEY_VALUE:
        // Maps are encoded as a single repeated key/value group; Spark assumes
        // map keys are never null, so only the value's repetition drives nullability.
        GroupType keyValueType = field.getType(0).asGroupType();
        Type keyType = keyValueType.getType(0);
        Type valueType = keyValueType.getType(1);
        boolean valueOptional = valueType.isRepetition(OPTIONAL);
        return "{\"type\":\"map\", \"keyType\":" + convertFieldType(keyType)
                + ",\"valueType\":" + convertFieldType(valueType)
                + ",\"valueContainsNull\":" + valueOptional + "}";
      default:
        throw new UnsupportedOperationException("Unsupported conversion of " + field + " to Spark SQL type");
    }
  }
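
  // Example: convertGroupField (above) turns the Parquet group
  //   optional group scores (MAP) { repeated group key_value { required binary key (UTF8); optional double value; } }
  // into
  //   {"type":"map", "keyType":"string","valueType":"double","valueContainsNull":true}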

  private static String arrayType(Type elementType, boolean containsNull) {
    return "{\"type\":\"array\", \"elementType\":" + convertFieldType(elementType) + ",\"containsNull\":" + containsNull + "}";
  }
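
  /*
   * isElementType (below) mirrors the backward-compatibility rules of the
   * Parquet format spec, as Spark's Parquet schema converter does: the repeated
   * field of a LIST is the element itself when it is a primitive, a group with
   * more than one field, or carries one of the legacy names "array" or
   * "<parentName>_tuple". For example:
   *
   *   optional group tags (LIST) { repeated int32 array; }   // two-level: element is the repeated field itself
   *   optional group tags (LIST) { repeated group list {
   *       optional int32 element; } }                        // three-level: element is nested one level down
   */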

  private static boolean isElementType(Type repeatedType, String parentName) {
    return repeatedType.isPrimitive() || repeatedType.asGroupType().getFieldCount() > 1
      || repeatedType.getName().equals("array") || repeatedType.getName().equals(parentName + "_tuple");
  }
}
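
/**
 * Minimal usage sketch (illustrative, not part of the Hudi source): builds a
 * small Parquet schema with Parquet's standard MessageTypeParser API and prints
 * the Spark-style JSON produced by Parquet2SparkSchemaUtils. The schema string
 * and the class name below are made up for this example.
 */
class Parquet2SparkSchemaUtilsExample {

  public static void main(String[] args) {
    // Two primitive fields plus a standard three-level LIST group.
    org.apache.parquet.schema.MessageType parquetSchema =
        org.apache.parquet.schema.MessageTypeParser.parseMessageType(
            "message spark_schema {\n"
                + "  required int64 id;\n"
                + "  optional binary name (UTF8);\n"
                + "  optional group tags (LIST) {\n"
                + "    repeated group list {\n"
                + "      optional binary element (UTF8);\n"
                + "    }\n"
                + "  }\n"
                + "}");

    // Prints a struct with a non-null long, a nullable string, and a nullable
    // array of nullable strings.
    System.out.println(Parquet2SparkSchemaUtils.convertToSparkSchemaJson(parquetSchema));
  }
}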