All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.parquet.avro.AvroWriteSupport Maven / Gradle / Ivy

The newest version!
/* 
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.parquet.avro;

import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericFixed;
import org.apache.avro.generic.IndexedRecord;
import org.apache.avro.util.Utf8;
import org.apache.hadoop.conf.Configuration;
import org.apache.parquet.hadoop.api.WriteSupport;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.io.api.RecordConsumer;
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.Type;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.parquet.Preconditions;

/**
 * Avro implementation of {@link WriteSupport} for generic, specific, and
 * reflect models. Use {@link AvroParquetWriter} or
 * {@link AvroParquetOutputFormat} rather than using this class directly.
 */
public class AvroWriteSupport extends WriteSupport {

  public static final String AVRO_DATA_SUPPLIER = "parquet.avro.write.data.supplier";

  public static void setAvroDataSupplier(
      Configuration configuration, Class suppClass) {
    configuration.set(AVRO_DATA_SUPPLIER, suppClass.getName());
  }

  static final String AVRO_SCHEMA = "parquet.avro.schema";
  private static final Schema MAP_KEY_SCHEMA = Schema.create(Schema.Type.STRING);

  public static final String WRITE_OLD_LIST_STRUCTURE =
      "parquet.avro.write-old-list-structure";
  static final boolean WRITE_OLD_LIST_STRUCTURE_DEFAULT = true;

  private static final String MAP_REPEATED_NAME = "key_value";
  private static final String MAP_KEY_NAME = "key";
  private static final String MAP_VALUE_NAME = "value";
  private static final String LIST_REPEATED_NAME = "list";
  private static final String OLD_LIST_REPEATED_NAME = "array";
  static final String LIST_ELEMENT_NAME = "element";

  private RecordConsumer recordConsumer;
  private MessageType rootSchema;
  private Schema rootAvroSchema;
  private GenericData model;
  private ListWriter listWriter;

  public AvroWriteSupport() {
  }

  /**
   * @deprecated use {@link AvroWriteSupport(MessageType, Schema, Configuration)}
   */
  @Deprecated
  public AvroWriteSupport(MessageType schema, Schema avroSchema) {
    this.rootSchema = schema;
    this.rootAvroSchema = avroSchema;
    this.model = null;
  }

  public AvroWriteSupport(MessageType schema, Schema avroSchema,
                          GenericData model) {
    this.rootSchema = schema;
    this.rootAvroSchema = avroSchema;
    this.model = model;
  }

  /**
   * @see org.apache.parquet.avro.AvroParquetOutputFormat#setSchema(org.apache.hadoop.mapreduce.Job, org.apache.avro.Schema)
   */
  public static void setSchema(Configuration configuration, Schema schema) {
    configuration.set(AVRO_SCHEMA, schema.toString());
  }

  @Override
  public WriteContext init(Configuration configuration) {
    if (rootAvroSchema == null) {
      this.rootAvroSchema = new Schema.Parser().parse(configuration.get(AVRO_SCHEMA));
      this.rootSchema = new AvroSchemaConverter().convert(rootAvroSchema);
    }

    if (model == null) {
      this.model = getDataModel(configuration);
    }

    boolean writeOldListStructure = configuration.getBoolean(
        WRITE_OLD_LIST_STRUCTURE, WRITE_OLD_LIST_STRUCTURE_DEFAULT);
    if (writeOldListStructure) {
      this.listWriter = new TwoLevelListWriter();
    } else {
      this.listWriter = new ThreeLevelListWriter();
    }

    Map extraMetaData = new HashMap();
    extraMetaData.put(AvroReadSupport.AVRO_SCHEMA_METADATA_KEY, rootAvroSchema.toString());
    return new WriteContext(rootSchema, extraMetaData);
  }

  @Override
  public void prepareForWrite(RecordConsumer recordConsumer) {
    this.recordConsumer = recordConsumer;
  }

  // overloaded version for backward compatibility
  @SuppressWarnings("unchecked")
  public void write(IndexedRecord record) {
    recordConsumer.startMessage();
    writeRecordFields(rootSchema, rootAvroSchema, record);
    recordConsumer.endMessage();
  }

  @Override
  public void write(T record) {
    recordConsumer.startMessage();
    writeRecordFields(rootSchema, rootAvroSchema, record);
    recordConsumer.endMessage();
  }

  private void writeRecord(GroupType schema, Schema avroSchema,
                           Object record) {
    recordConsumer.startGroup();
    writeRecordFields(schema, avroSchema, record);
    recordConsumer.endGroup();
  }

  private void writeRecordFields(GroupType schema, Schema avroSchema,
                                 Object record) {
    List fields = schema.getFields();
    List avroFields = avroSchema.getFields();
    int index = 0; // parquet ignores Avro nulls, so index may differ
    for (int avroIndex = 0; avroIndex < avroFields.size(); avroIndex++) {
      Schema.Field avroField = avroFields.get(avroIndex);
      if (avroField.schema().getType().equals(Schema.Type.NULL)) {
        continue;
      }
      Type fieldType = fields.get(index);
      Object value = model.getField(record, avroField.name(), avroIndex);
      if (value != null) {
        recordConsumer.startField(fieldType.getName(), index);
        writeValue(fieldType, avroField.schema(), value);
        recordConsumer.endField(fieldType.getName(), index);
      } else if (fieldType.isRepetition(Type.Repetition.REQUIRED)) {
        throw new RuntimeException("Null-value for required field: " + avroField.name());
      }
      index++;
    }
  }

  private  void writeMap(GroupType schema, Schema avroSchema,
                            Map map) {
    GroupType innerGroup = schema.getType(0).asGroupType();
    Type keyType = innerGroup.getType(0);
    Type valueType = innerGroup.getType(1);

    recordConsumer.startGroup(); // group wrapper (original type MAP)
    if (map.size() > 0) {
      recordConsumer.startField(MAP_REPEATED_NAME, 0);

      for (Map.Entry entry : map.entrySet()) {
        recordConsumer.startGroup(); // repeated group key_value, middle layer
        recordConsumer.startField(MAP_KEY_NAME, 0);
        writeValue(keyType, MAP_KEY_SCHEMA, entry.getKey());
        recordConsumer.endField(MAP_KEY_NAME, 0);
        V value = entry.getValue();
        if (value != null) {
          recordConsumer.startField(MAP_VALUE_NAME, 1);
          writeValue(valueType, avroSchema.getValueType(), value);
          recordConsumer.endField(MAP_VALUE_NAME, 1);
        } else if (!valueType.isRepetition(Type.Repetition.OPTIONAL)) {
          throw new RuntimeException("Null map value for " + avroSchema.getName());
        }
        recordConsumer.endGroup();
      }

      recordConsumer.endField(MAP_REPEATED_NAME, 0);
    }
    recordConsumer.endGroup();
  }

  private void writeUnion(GroupType parquetSchema, Schema avroSchema, 
                          Object value) {
    recordConsumer.startGroup();

    // ResolveUnion will tell us which of the union member types to
    // deserialise.
    int avroIndex = model.resolveUnion(avroSchema, value);

    // For parquet's schema we skip nulls
    GroupType parquetGroup = parquetSchema.asGroupType();
    int parquetIndex = avroIndex;
    for (int i = 0; i < avroIndex; i++) {
      if (avroSchema.getTypes().get(i).getType().equals(Schema.Type.NULL)) {
        parquetIndex--;
      }
    }

    // Sparsely populated method of encoding unions, each member has its own
    // set of columns.
    String memberName = "member" + parquetIndex;
    recordConsumer.startField(memberName, parquetIndex);
    writeValue(parquetGroup.getType(parquetIndex),
               avroSchema.getTypes().get(avroIndex), value);
    recordConsumer.endField(memberName, parquetIndex);

    recordConsumer.endGroup();
  }

  @SuppressWarnings("unchecked")
  private void writeValue(Type type, Schema avroSchema, Object value) {
    Schema nonNullAvroSchema = AvroSchemaConverter.getNonNull(avroSchema);
    Schema.Type avroType = nonNullAvroSchema.getType();
    if (avroType.equals(Schema.Type.BOOLEAN)) {
      recordConsumer.addBoolean((Boolean) value);
    } else if (avroType.equals(Schema.Type.INT)) {
      if (value instanceof Character) {
        recordConsumer.addInteger((Character) value);
      } else {
        recordConsumer.addInteger(((Number) value).intValue());
      }
    } else if (avroType.equals(Schema.Type.LONG)) {
      recordConsumer.addLong(((Number) value).longValue());
    } else if (avroType.equals(Schema.Type.FLOAT)) {
      recordConsumer.addFloat(((Number) value).floatValue());
    } else if (avroType.equals(Schema.Type.DOUBLE)) {
      recordConsumer.addDouble(((Number) value).doubleValue());
    } else if (avroType.equals(Schema.Type.BYTES)) {
      if (value instanceof byte[]) {
        recordConsumer.addBinary(Binary.fromReusedByteArray((byte[]) value));
      } else {
        recordConsumer.addBinary(Binary.fromReusedByteBuffer((ByteBuffer) value));
      }
    } else if (avroType.equals(Schema.Type.STRING)) {
      recordConsumer.addBinary(fromAvroString(value));
    } else if (avroType.equals(Schema.Type.RECORD)) {
      writeRecord(type.asGroupType(), nonNullAvroSchema, value);
    } else if (avroType.equals(Schema.Type.ENUM)) {
      recordConsumer.addBinary(Binary.fromString(value.toString()));
    } else if (avroType.equals(Schema.Type.ARRAY)) {
      listWriter.writeList(type.asGroupType(), nonNullAvroSchema, value);
    } else if (avroType.equals(Schema.Type.MAP)) {
      writeMap(type.asGroupType(), nonNullAvroSchema, (Map) value);
    } else if (avroType.equals(Schema.Type.UNION)) {
      writeUnion(type.asGroupType(), nonNullAvroSchema, value);
    } else if (avroType.equals(Schema.Type.FIXED)) {
      recordConsumer.addBinary(Binary.fromReusedByteArray(((GenericFixed) value).bytes()));
    }
  }

  private Binary fromAvroString(Object value) {
    if (value instanceof Utf8) {
      Utf8 utf8 = (Utf8) value;
      return Binary.fromReusedByteArray(utf8.getBytes(), 0, utf8.getByteLength());
    }
    return Binary.fromString(value.toString());
  }

  private static GenericData getDataModel(Configuration conf) {
    Class suppClass = conf.getClass(
        AVRO_DATA_SUPPLIER, SpecificDataSupplier.class, AvroDataSupplier.class);
    return ReflectionUtils.newInstance(suppClass, conf).get();
  }

  private abstract class ListWriter {

    protected abstract void writeCollection(
        GroupType type, Schema schema, Collection collection);

    protected abstract void writeObjectArray(
        GroupType type, Schema schema, Object[] array);

    protected abstract void startArray();

    protected abstract void endArray();

    public void writeList(GroupType schema, Schema avroSchema, Object value) {
      recordConsumer.startGroup(); // group wrapper (original type LIST)
      if (value instanceof Collection) {
        writeCollection(schema, avroSchema, (Collection) value);
      } else {
        Class arrayClass = value.getClass();
        Preconditions.checkArgument(arrayClass.isArray(),
            "Cannot write unless collection or array: " + arrayClass.getName());
        writeJavaArray(schema, avroSchema, arrayClass, value);
      }
      recordConsumer.endGroup();
    }

    public void writeJavaArray(GroupType schema, Schema avroSchema,
                               Class arrayClass, Object value) {
      Class elementClass = arrayClass.getComponentType();

      if (!elementClass.isPrimitive()) {
        writeObjectArray(schema, avroSchema, (Object[]) value);
        return;
      }

      switch (avroSchema.getElementType().getType()) {
        case BOOLEAN:
          Preconditions.checkArgument(elementClass == boolean.class,
              "Cannot write as boolean array: " + arrayClass.getName());
          writeBooleanArray((boolean[]) value);
          break;
        case INT:
          if (elementClass == byte.class) {
            writeByteArray((byte[]) value);
          } else if (elementClass == char.class) {
            writeCharArray((char[]) value);
          } else if (elementClass == short.class) {
            writeShortArray((short[]) value);
          } else if (elementClass == int.class) {
            writeIntArray((int[]) value);
          } else {
            throw new IllegalArgumentException(
                "Cannot write as an int array: " + arrayClass.getName());
          }
          break;
        case LONG:
          Preconditions.checkArgument(elementClass == long.class,
              "Cannot write as long array: " + arrayClass.getName());
          writeLongArray((long[]) value);
          break;
        case FLOAT:
          Preconditions.checkArgument(elementClass == float.class,
              "Cannot write as float array: " + arrayClass.getName());
          writeFloatArray((float[]) value);
          break;
        case DOUBLE:
          Preconditions.checkArgument(elementClass == double.class,
              "Cannot write as double array: " + arrayClass.getName());
          writeDoubleArray((double[]) value);
          break;
        default:
          throw new IllegalArgumentException("Cannot write " +
              avroSchema.getElementType() + " array: " + arrayClass.getName());
      }
    }

    protected void writeBooleanArray(boolean[] array) {
      if (array.length > 0) {
        startArray();
        for (boolean element : array) {
          recordConsumer.addBoolean(element);
        }
        endArray();
      }
    }

    protected void writeByteArray(byte[] array) {
      if (array.length > 0) {
        startArray();
        for (byte element : array) {
          recordConsumer.addInteger(element);
        }
        endArray();
      }
    }

    protected void writeShortArray(short[] array) {
      if (array.length > 0) {
        startArray();
        for (short element : array) {
          recordConsumer.addInteger(element);
        }
        endArray();
      }
    }

    protected void writeCharArray(char[] array) {
      if (array.length > 0) {
        startArray();
        for (char element : array) {
          recordConsumer.addInteger(element);
        }
        endArray();
      }
    }

    protected void writeIntArray(int[] array) {
      if (array.length > 0) {
        startArray();
        for (int element : array) {
          recordConsumer.addInteger(element);
        }
        endArray();
      }
    }

    protected void writeLongArray(long[] array) {
      if (array.length > 0) {
        startArray();
        for (long element : array) {
          recordConsumer.addLong(element);
        }
        endArray();
      }
    }

    protected void writeFloatArray(float[] array) {
      if (array.length > 0) {
        startArray();
        for (float element : array) {
          recordConsumer.addFloat(element);
        }
        endArray();
      }
    }

    protected void writeDoubleArray(double[] array) {
      if (array.length > 0) {
        startArray();
        for (double element : array) {
          recordConsumer.addDouble(element);
        }
        endArray();
      }
    }
  }

  /**
   * For backward-compatibility. This preserves how lists were written in 1.x.
   */
  private class TwoLevelListWriter extends ListWriter {
    @Override
    public void writeCollection(GroupType schema, Schema avroSchema,
                                Collection array) {
      if (array.size() > 0) {
        recordConsumer.startField(OLD_LIST_REPEATED_NAME, 0);
        for (Object elt : array) {
          writeValue(schema.getType(0), avroSchema.getElementType(), elt);
        }
        recordConsumer.endField(OLD_LIST_REPEATED_NAME, 0);
      }
    }

    @Override
    protected void writeObjectArray(GroupType type, Schema schema,
                                    Object[] array) {
      if (array.length > 0) {
        recordConsumer.startField(OLD_LIST_REPEATED_NAME, 0);
        for (Object element : array) {
          writeValue(type.getType(0), schema.getElementType(), element);
        }
        recordConsumer.endField(OLD_LIST_REPEATED_NAME, 0);
      }
    }

    @Override
    protected void startArray() {
      recordConsumer.startField(OLD_LIST_REPEATED_NAME, 0);
    }

    @Override
    protected void endArray() {
      recordConsumer.endField(OLD_LIST_REPEATED_NAME, 0);
    }
  }

  private class ThreeLevelListWriter extends ListWriter {
    @Override
    protected void writeCollection(GroupType type, Schema schema, Collection collection) {
      if (collection.size() > 0) {
        recordConsumer.startField(LIST_REPEATED_NAME, 0);
        GroupType repeatedType = type.getType(0).asGroupType();
        Type elementType = repeatedType.getType(0);
        for (Object element : collection) {
          recordConsumer.startGroup(); // repeated group array, middle layer
          if (element != null) {
            recordConsumer.startField(LIST_ELEMENT_NAME, 0);
            writeValue(elementType, schema.getElementType(), element);
            recordConsumer.endField(LIST_ELEMENT_NAME, 0);
          } else if (!elementType.isRepetition(Type.Repetition.OPTIONAL)) {
            throw new RuntimeException(
                "Null list element for " + schema.getName());
          }
          recordConsumer.endGroup();
        }
        recordConsumer.endField(LIST_REPEATED_NAME, 0);
      }
    }

    @Override
    protected void writeObjectArray(GroupType type, Schema schema,
                                    Object[] array) {
      if (array.length > 0) {
        recordConsumer.startField(LIST_REPEATED_NAME, 0);
        GroupType repeatedType = type.getType(0).asGroupType();
        Type elementType = repeatedType.getType(0);
        for (Object element : array) {
          recordConsumer.startGroup(); // repeated group array, middle layer
          if (element != null) {
            recordConsumer.startField(LIST_ELEMENT_NAME, 0);
            writeValue(elementType, schema.getElementType(), element);
            recordConsumer.endField(LIST_ELEMENT_NAME, 0);
          } else if (!elementType.isRepetition(Type.Repetition.OPTIONAL)) {
            throw new RuntimeException(
                "Null list element for " + schema.getName());
          }
          recordConsumer.endGroup();
        }
        recordConsumer.endField(LIST_REPEATED_NAME, 0);
      }
    }

    @Override
    protected void startArray() {
      recordConsumer.startField(LIST_REPEATED_NAME, 0);
      recordConsumer.startGroup(); // repeated group array, middle layer
      recordConsumer.startField(LIST_ELEMENT_NAME, 0);
    }

    @Override
    protected void endArray() {
      recordConsumer.endField(LIST_ELEMENT_NAME, 0);
      recordConsumer.endGroup();
      recordConsumer.endField(LIST_REPEATED_NAME, 0);
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy