All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.iceberg.parquet.ParquetValueWriters Maven / Gradle / Ivy

There is a newer version: 1.0.0.8
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.iceberg.parquet;

import java.lang.reflect.Array;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Stream;
import org.apache.avro.util.Utf8;
import org.apache.iceberg.FieldMetrics;
import org.apache.iceberg.FloatFieldMetrics;
import org.apache.iceberg.deletes.PositionDelete;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.types.TypeUtil;
import org.apache.iceberg.util.DecimalUtil;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.column.ColumnWriteStore;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.schema.Type;

public class ParquetValueWriters {
  private ParquetValueWriters() {
  }

  public static  ParquetValueWriter option(Type type,
                                                 int definitionLevel,
                                                 ParquetValueWriter writer) {
    if (type.isRepetition(Type.Repetition.OPTIONAL)) {
      return new OptionWriter<>(definitionLevel, writer);
    }

    return writer;
  }

  public static UnboxedWriter booleans(ColumnDescriptor desc) {
    return new UnboxedWriter<>(desc);
  }

  public static UnboxedWriter tinyints(ColumnDescriptor desc) {
    return new ByteWriter(desc);
  }

  public static UnboxedWriter shorts(ColumnDescriptor desc) {
    return new ShortWriter(desc);
  }

  public static UnboxedWriter ints(ColumnDescriptor desc) {
    return new UnboxedWriter<>(desc);
  }

  public static UnboxedWriter longs(ColumnDescriptor desc) {
    return new UnboxedWriter<>(desc);
  }

  public static UnboxedWriter floats(ColumnDescriptor desc) {
    return new FloatWriter(desc);
  }

  public static UnboxedWriter doubles(ColumnDescriptor desc) {
    return new DoubleWriter(desc);
  }

  public static PrimitiveWriter strings(ColumnDescriptor desc) {
    return new StringWriter(desc);
  }

  public static PrimitiveWriter decimalAsInteger(ColumnDescriptor desc,
                                                             int precision, int scale) {
    return new IntegerDecimalWriter(desc, precision, scale);
  }

  public static PrimitiveWriter decimalAsLong(ColumnDescriptor desc,
                                                          int precision, int scale) {
    return new LongDecimalWriter(desc, precision, scale);
  }

  public static PrimitiveWriter decimalAsFixed(ColumnDescriptor desc,
                                                           int precision, int scale) {
    return new FixedDecimalWriter(desc, precision, scale);
  }

  public static PrimitiveWriter byteBuffers(ColumnDescriptor desc) {
    return new BytesWriter(desc);
  }

  public static  CollectionWriter collections(int dl, int rl, ParquetValueWriter writer) {
    return new CollectionWriter<>(dl, rl, writer);
  }

  public static  MapWriter maps(int dl, int rl,
                                            ParquetValueWriter keyWriter,
                                            ParquetValueWriter valueWriter) {
    return new MapWriter<>(dl, rl, keyWriter, valueWriter);
  }

  public abstract static class PrimitiveWriter implements ParquetValueWriter {
    @SuppressWarnings("checkstyle:VisibilityModifier")
    protected final ColumnWriter column;
    private final List> children;

    protected PrimitiveWriter(ColumnDescriptor desc) {
      this.column = ColumnWriter.newWriter(desc);
      this.children = ImmutableList.of(column);
    }

    @Override
    public void write(int repetitionLevel, T value) {
      column.write(repetitionLevel, value);
    }

    @Override
    public List> columns() {
      return children;
    }

    @Override
    public void setColumnStore(ColumnWriteStore columnStore) {
      this.column.setColumnStore(columnStore);
    }
  }

  private static class UnboxedWriter extends PrimitiveWriter {
    private UnboxedWriter(ColumnDescriptor desc) {
      super(desc);
    }

    public void writeBoolean(int repetitionLevel, boolean value) {
      column.writeBoolean(repetitionLevel, value);
    }

    public void writeInteger(int repetitionLevel, int value) {
      column.writeInteger(repetitionLevel, value);
    }

    public void writeLong(int repetitionLevel, long  value) {
      column.writeLong(repetitionLevel, value);
    }

    public void writeFloat(int repetitionLevel, float value) {
      column.writeFloat(repetitionLevel, value);
    }

    public void writeDouble(int repetitionLevel, double value) {
      column.writeDouble(repetitionLevel, value);
    }
  }

  private static class FloatWriter extends UnboxedWriter {
    private final int id;
    private long nanCount;

    private FloatWriter(ColumnDescriptor desc) {
      super(desc);
      this.id = desc.getPrimitiveType().getId().intValue();
      this.nanCount = 0;
    }

    @Override
    public void write(int repetitionLevel, Float value) {
      writeFloat(repetitionLevel, value);
      if (Float.isNaN(value)) {
        nanCount++;
      }
    }

    @Override
    public Stream metrics() {
      return Stream.of(new FloatFieldMetrics(id, nanCount));
    }
  }

  private static class DoubleWriter extends UnboxedWriter {
    private final int id;
    private long nanCount;

    private DoubleWriter(ColumnDescriptor desc) {
      super(desc);
      this.id = desc.getPrimitiveType().getId().intValue();
      this.nanCount = 0;
    }

    @Override
    public void write(int repetitionLevel, Double value) {
      writeDouble(repetitionLevel, value);
      if (Double.isNaN(value)) {
        nanCount++;
      }
    }

    @Override
    public Stream metrics() {
      return Stream.of(new FloatFieldMetrics(id, nanCount));
    }
  }

  private static class ByteWriter extends UnboxedWriter {
    private ByteWriter(ColumnDescriptor desc) {
      super(desc);
    }

    @Override
    public void write(int repetitionLevel, Byte value) {
      writeInteger(repetitionLevel, value.intValue());
    }
  }

  private static class ShortWriter extends UnboxedWriter {
    private ShortWriter(ColumnDescriptor desc) {
      super(desc);
    }

    @Override
    public void write(int repetitionLevel, Short value) {
      writeInteger(repetitionLevel, value.intValue());
    }
  }

  private static class IntegerDecimalWriter extends PrimitiveWriter {
    private final int precision;
    private final int scale;

    private IntegerDecimalWriter(ColumnDescriptor desc, int precision, int scale) {
      super(desc);
      this.precision = precision;
      this.scale = scale;
    }

    @Override
    public void write(int repetitionLevel, BigDecimal decimal) {
      Preconditions.checkArgument(decimal.scale() == scale,
          "Cannot write value as decimal(%s,%s), wrong scale: %s", precision, scale, decimal);
      Preconditions.checkArgument(decimal.precision() <= precision,
          "Cannot write value as decimal(%s,%s), too large: %s", precision, scale, decimal);

      column.writeInteger(repetitionLevel, decimal.unscaledValue().intValue());
    }
  }

  private static class LongDecimalWriter extends PrimitiveWriter {
    private final int precision;
    private final int scale;

    private LongDecimalWriter(ColumnDescriptor desc, int precision, int scale) {
      super(desc);
      this.precision = precision;
      this.scale = scale;
    }

    @Override
    public void write(int repetitionLevel, BigDecimal decimal) {
      Preconditions.checkArgument(decimal.scale() == scale,
          "Cannot write value as decimal(%s,%s), wrong scale: %s", precision, scale, decimal);
      Preconditions.checkArgument(decimal.precision() <= precision,
          "Cannot write value as decimal(%s,%s), too large: %s", precision, scale, decimal);

      column.writeLong(repetitionLevel, decimal.unscaledValue().longValue());
    }
  }

  private static class FixedDecimalWriter extends PrimitiveWriter {
    private final int precision;
    private final int scale;
    private final ThreadLocal bytes;

    private FixedDecimalWriter(ColumnDescriptor desc, int precision, int scale) {
      super(desc);
      this.precision = precision;
      this.scale = scale;
      this.bytes = ThreadLocal.withInitial(() -> new byte[TypeUtil.decimalRequiredBytes(precision)]);
    }

    @Override
    public void write(int repetitionLevel, BigDecimal decimal) {
      byte[] binary = DecimalUtil.toReusedFixLengthBytes(precision, scale, decimal, bytes.get());
      column.writeBinary(repetitionLevel, Binary.fromReusedByteArray(binary));
    }
  }

  private static class BytesWriter extends PrimitiveWriter {
    private BytesWriter(ColumnDescriptor desc) {
      super(desc);
    }

    @Override
    public void write(int repetitionLevel, ByteBuffer buffer) {
      column.writeBinary(repetitionLevel, Binary.fromReusedByteBuffer(buffer));
    }
  }

  private static class StringWriter extends PrimitiveWriter {
    private StringWriter(ColumnDescriptor desc) {
      super(desc);
    }

    @Override
    public void write(int repetitionLevel, CharSequence value) {
      if (value instanceof Utf8) {
        Utf8 utf8 = (Utf8) value;
        column.writeBinary(repetitionLevel,
            Binary.fromReusedByteArray(utf8.getBytes(), 0, utf8.getByteLength()));
      } else {
        column.writeBinary(repetitionLevel, Binary.fromString(value.toString()));
      }
    }
  }

  static class OptionWriter implements ParquetValueWriter {
    private final int definitionLevel;
    private final ParquetValueWriter writer;
    private final List> children;

    OptionWriter(int definitionLevel, ParquetValueWriter writer) {
      this.definitionLevel = definitionLevel;
      this.writer = writer;
      this.children = writer.columns();
    }

    @Override
    public void write(int repetitionLevel, T value) {
      if (value != null) {
        writer.write(repetitionLevel, value);

      } else {
        for (TripleWriter column : children) {
          column.writeNull(repetitionLevel, definitionLevel - 1);
        }
      }
    }

    @Override
    public List> columns() {
      return children;
    }

    @Override
    public void setColumnStore(ColumnWriteStore columnStore) {
      writer.setColumnStore(columnStore);
    }

    @Override
    public Stream metrics() {
      return writer.metrics();
    }
  }

  public abstract static class RepeatedWriter implements ParquetValueWriter {
    private final int definitionLevel;
    private final int repetitionLevel;
    private final ParquetValueWriter writer;
    private final List> children;

    protected RepeatedWriter(int definitionLevel, int repetitionLevel,
                             ParquetValueWriter writer) {
      this.definitionLevel = definitionLevel;
      this.repetitionLevel = repetitionLevel;
      this.writer = writer;
      this.children = writer.columns();
    }

    @Override
    public void write(int parentRepetition, L value) {
      Iterator elements = elements(value);

      if (!elements.hasNext()) {
        // write the empty list to each column
        // TODO: make sure this definition level is correct
        for (TripleWriter column : children) {
          column.writeNull(parentRepetition, definitionLevel - 1);
        }

      } else {
        boolean first = true;
        while (elements.hasNext()) {
          E element = elements.next();

          int rl = repetitionLevel;
          if (first) {
            rl = parentRepetition;
            first = false;
          }

          writer.write(rl, element);
        }
      }
    }

    @Override
    public List> columns() {
      return children;
    }

    @Override
    public void setColumnStore(ColumnWriteStore columnStore) {
      writer.setColumnStore(columnStore);
    }

    protected abstract Iterator elements(L value);

    @Override
    public Stream metrics() {
      return writer.metrics();
    }
  }

  private static class CollectionWriter extends RepeatedWriter, E> {
    private CollectionWriter(int definitionLevel, int repetitionLevel,
                             ParquetValueWriter writer) {
      super(definitionLevel, repetitionLevel, writer);
    }

    @Override
    protected Iterator elements(Collection list) {
      return list.iterator();
    }
  }

  public abstract static class RepeatedKeyValueWriter implements ParquetValueWriter {
    private final int definitionLevel;
    private final int repetitionLevel;
    private final ParquetValueWriter keyWriter;
    private final ParquetValueWriter valueWriter;
    private final List> children;

    protected RepeatedKeyValueWriter(int definitionLevel, int repetitionLevel,
                                     ParquetValueWriter keyWriter,
                                     ParquetValueWriter valueWriter) {
      this.definitionLevel = definitionLevel;
      this.repetitionLevel = repetitionLevel;
      this.keyWriter = keyWriter;
      this.valueWriter = valueWriter;
      this.children = ImmutableList.>builder()
          .addAll(keyWriter.columns())
          .addAll(valueWriter.columns())
          .build();
    }

    @Override
    public void write(int parentRepetition, M value) {
      Iterator> pairs = pairs(value);

      if (!pairs.hasNext()) {
        // write the empty map to each column
        for (TripleWriter column : children) {
          column.writeNull(parentRepetition, definitionLevel - 1);
        }

      } else {
        boolean first = true;
        while (pairs.hasNext()) {
          Map.Entry pair = pairs.next();

          int rl = repetitionLevel;
          if (first) {
            rl = parentRepetition;
            first = false;
          }

          keyWriter.write(rl, pair.getKey());
          valueWriter.write(rl, pair.getValue());
        }
      }
    }

    @Override
    public List> columns() {
      return children;
    }

    @Override
    public void setColumnStore(ColumnWriteStore columnStore) {
      keyWriter.setColumnStore(columnStore);
      valueWriter.setColumnStore(columnStore);
    }

    protected abstract Iterator> pairs(M value);

    @Override
    public Stream metrics() {
      return Stream.concat(keyWriter.metrics(), valueWriter.metrics());
    }
  }

  private static class MapWriter extends RepeatedKeyValueWriter, K, V> {
    private MapWriter(int definitionLevel, int repetitionLevel,
                      ParquetValueWriter keyWriter, ParquetValueWriter valueWriter) {
      super(definitionLevel, repetitionLevel, keyWriter, valueWriter);
    }

    @Override
    protected Iterator> pairs(Map map) {
      return map.entrySet().iterator();
    }
  }

  public abstract static class StructWriter implements ParquetValueWriter {
    private final ParquetValueWriter[] writers;
    private final List> children;

    @SuppressWarnings("unchecked")
    protected StructWriter(List> writers) {
      this.writers = (ParquetValueWriter[]) Array.newInstance(
          ParquetValueWriter.class, writers.size());

      ImmutableList.Builder> columnsBuilder = ImmutableList.builder();
      for (int i = 0; i < writers.size(); i += 1) {
        ParquetValueWriter writer = writers.get(i);
        this.writers[i] = (ParquetValueWriter) writer;
        columnsBuilder.addAll(writer.columns());
      }

      this.children = columnsBuilder.build();
    }

    @Override
    public void write(int repetitionLevel, S value) {
      for (int i = 0; i < writers.length; i += 1) {
        Object fieldValue = get(value, i);
        writers[i].write(repetitionLevel, fieldValue);
      }
    }

    @Override
    public List> columns() {
      return children;
    }

    @Override
    public void setColumnStore(ColumnWriteStore columnStore) {
      for (ParquetValueWriter writer : writers) {
        writer.setColumnStore(columnStore);
      }
    }

    protected abstract Object get(S struct, int index);

    @Override
    public Stream metrics() {
      return Arrays.stream(writers).flatMap(ParquetValueWriter::metrics);
    }
  }

  public static class PositionDeleteStructWriter extends StructWriter> {
    private final Function pathTransformFunc;

    public PositionDeleteStructWriter(StructWriter replacedWriter, Function pathTransformFunc) {
      super(Arrays.asList(replacedWriter.writers));
      this.pathTransformFunc = pathTransformFunc;
    }

    @Override
    protected Object get(PositionDelete delete, int index) {
      switch (index) {
        case 0:
          return pathTransformFunc.apply(delete.path());
        case 1:
          return delete.pos();
        case 2:
          return delete.row();
      }
      throw new IllegalArgumentException("Cannot get value for invalid index: " + index);
    }
  }
}