All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.parquet.thrift.ParquetWriteProtocol Maven / Gradle / Ivy

There is a newer version: 1.15.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.parquet.thrift;

import java.nio.ByteBuffer;
import org.apache.hadoop.conf.Configuration;
import org.apache.parquet.conf.HadoopParquetConfiguration;
import org.apache.parquet.conf.ParquetConfiguration;
import org.apache.parquet.io.ColumnIO;
import org.apache.parquet.io.GroupColumnIO;
import org.apache.parquet.io.MessageColumnIO;
import org.apache.parquet.io.ParquetEncodingException;
import org.apache.parquet.io.PrimitiveColumnIO;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.io.api.RecordConsumer;
import org.apache.parquet.thrift.struct.ThriftField;
import org.apache.parquet.thrift.struct.ThriftType;
import org.apache.parquet.thrift.struct.ThriftType.EnumType;
import org.apache.parquet.thrift.struct.ThriftType.EnumValue;
import org.apache.parquet.thrift.struct.ThriftType.ListType;
import org.apache.parquet.thrift.struct.ThriftType.MapType;
import org.apache.parquet.thrift.struct.ThriftType.SetType;
import org.apache.parquet.thrift.struct.ThriftType.StructType;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TField;
import org.apache.thrift.protocol.TList;
import org.apache.thrift.protocol.TMap;
import org.apache.thrift.protocol.TMessage;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TSet;
import org.apache.thrift.protocol.TStruct;
import org.apache.thrift.protocol.TType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ParquetWriteProtocol extends ParquetProtocol {

  /** Configuration key read by the configuration-taking constructors to choose the list layout. */
  public static final String WRITE_THREE_LEVEL_LISTS = "parquet.thrift.write-three-level-lists";
  /** Fallback passed to the configuration lookup; also the value used when no configuration is given. */
  public static final boolean WRITE_THREE_LEVEL_LISTS_DEFAULT = false;

  // When true, list and set fields get a ThreeLevelListWriteProtocol, otherwise a TwoLevelListWriteProtocol.
  private boolean writeThreeLevelList = WRITE_THREE_LEVEL_LISTS_DEFAULT;

  /**
   * Callback pair that a field-level protocol invokes around the value it writes, letting the
   * enclosing protocol emit surrounding structure and decide which protocol becomes active next.
   */
  interface Events {

    /** Invoked by a field protocol right before it writes its value. */
    void start();

    /** Invoked by a field protocol right after it has written its value. */
    void end();
  }

  /**
   * Common base of all field-level protocols. It stores the {@link Events} callbacks supplied by
   * the enclosing protocol and exposes them as {@link #start()} and {@link #end()}.
   */
  abstract class FieldBaseWriteProtocol extends ParquetProtocol {
    // Callbacks owned by whoever created this protocol; may be null when start()/end() are never used.
    private final Events events;

    /**
     * @param returnClause callbacks to fire before and after this protocol writes a value
     */
    public FieldBaseWriteProtocol(Events returnClause) {
      this.events = returnClause;
    }

    /** Notifies the enclosing protocol that a value is about to be written. */
    void start() {
      events.start();
    }

    /** Notifies the enclosing protocol that the value has been written. */
    void end() {
      events.end();
    }
  }

  /**
   * Field protocol for Thrift enums: the numeric id received through {@link #writeI32(int)} is
   * resolved against the enum type and its name is written as a binary string.
   */
  class EnumWriteProtocol extends FieldBaseWriteProtocol {

    private final EnumType type;
    private final PrimitiveColumnIO columnIO;

    /**
     * @param columnIO     column the enum is written to (used in error messages)
     * @param type         Thrift enum definition used to resolve ids to names
     * @param returnClause callbacks fired around the written value
     */
    public EnumWriteProtocol(PrimitiveColumnIO columnIO, EnumType type, Events returnClause) {
      super(returnClause);
      this.type = type;
      this.columnIO = columnIO;
    }

    /**
     * Writes the name of the enum value whose id is {@code i32}.
     *
     * @throws ParquetEncodingException if the enum type has no value with that id
     */
    @Override
    public void writeI32(int i32) throws TException {
      start();
      final EnumValue resolved = type.getEnumValueById(i32);
      if (resolved != null) {
        final Binary enumName = Binary.fromString(resolved.getName());
        recordConsumer.addBinary(enumName);
        end();
        return;
      }
      throw new ParquetEncodingException(
          "Can not find enum value of index " + i32 + " for field:" + columnIO.toString());
    }
  }

  /**
   * Shared base for the list layouts. Thrift lists and sets are handled identically: the announced
   * element count is remembered in {@link #size} and the subclass opens/closes the wrapper.
   */
  abstract class ListWriteProtocol extends FieldBaseWriteProtocol {
    /** Element count announced by the most recent list/set begin call. */
    protected int size;

    public ListWriteProtocol(Events returnClause) {
      super(returnClause);
    }

    /** Opens the Parquet structure that wraps the elements. */
    protected abstract void startListWrapper();

    /** Closes the Parquet structure opened by {@link #startListWrapper()}. */
    protected abstract void endListWrapper();

    // Records the element count, then lets the subclass open its wrapper.
    private void beginCollection(int elementCount) {
      size = elementCount;
      startListWrapper();
    }

    @Override
    public void writeListBegin(TList list) throws TException {
      beginCollection(list.size);
    }

    @Override
    public void writeListEnd() throws TException {
      endListWrapper();
    }

    @Override
    public void writeSetBegin(TSet set) throws TException {
      beginCollection(set.size);
    }

    @Override
    public void writeSetEnd() throws TException {
      endListWrapper();
    }
  }

  /**
   * List/set layout in which the list group directly contains one field that receives the
   * elements, with no additional per-element group.
   */
  class TwoLevelListWriteProtocol extends ListWriteProtocol {
    // First child of the list group: the field the elements are written into.
    private ColumnIO listContent;
    // Protocol that writes a single element into listContent.
    private TProtocol contentProtocol;

    public TwoLevelListWriteProtocol(GroupColumnIO columnIO, ThriftField values, Events returnClause) {
      super(returnClause);
      this.listContent = columnIO.getChild(0);
      this.contentProtocol = getProtocol(values, listContent, new Events() {
        // Elements written so far for the list currently in progress.
        int consumedRecords = 0;

        @Override
        public void start() {}

        @Override
        public void end() {
          ++consumedRecords;
          if (consumedRecords == size) {
            // All announced elements written: make this protocol active again so it receives the
            // list/set end call, and reset the counter for the next list.
            currentProtocol = TwoLevelListWriteProtocol.this;
            consumedRecords = 0;
          }
        }
      });
    }

    /** Opens the list group and, for non-empty lists, the element field; then activates the element protocol. */
    @Override
    protected void startListWrapper() {
      start();
      recordConsumer.startGroup();
      // An empty list writes an empty group: the element field is never opened.
      if (size > 0) {
        recordConsumer.startField(listContent.getType().getName(), 0);
        currentProtocol = contentProtocol;
      }
    }

    /** Closes the element field (if it was opened) and the list group. */
    @Override
    protected void endListWrapper() {
      if (size > 0) {
        recordConsumer.endField(listContent.getType().getName(), 0);
      }
      recordConsumer.endGroup();
      end();
    }
  }

  /**
   * List/set layout in which the list group contains a field named "list", and every element is
   * wrapped in its own group holding a single field named "element".
   */
  class ThreeLevelListWriteProtocol extends ListWriteProtocol {
    // First child of the list group; itself a group.
    private GroupColumnIO listContent;
    // First child of listContent: the column the element protocol is built for.
    private ColumnIO elementContent;
    // Protocol that writes a single element.
    private TProtocol contentProtocol;

    public ThreeLevelListWriteProtocol(GroupColumnIO columnIO, ThriftField values, Events returnClause) {
      super(returnClause);
      this.listContent = (GroupColumnIO) columnIO.getChild(0);
      this.elementContent = listContent.getChild(0);
      this.contentProtocol = getProtocol(values, elementContent, new Events() {
        // Elements written so far for the list currently in progress.
        int consumedRecords = 0;

        @Override
        public void start() {
          // Open the per-element wrapper group and its "element" field.
          recordConsumer.startGroup();
          recordConsumer.startField("element", 0);
        }

        @Override
        public void end() {
          recordConsumer.endField("element", 0);
          recordConsumer.endGroup();
          ++consumedRecords;
          if (consumedRecords == size) {
            // All announced elements written: this protocol must receive the list/set end call.
            currentProtocol = ThreeLevelListWriteProtocol.this;
            consumedRecords = 0;
          } else {
            // More elements follow: keep routing calls to the element protocol.
            currentProtocol = contentProtocol;
          }
        }
      });
    }

    /** Opens the list group and, for non-empty lists, the "list" field; then activates the element protocol. */
    @Override
    protected void startListWrapper() {
      start();
      recordConsumer.startGroup();
      // An empty list writes an empty group: the "list" field is never opened.
      if (size > 0) {
        recordConsumer.startField("list", 0);
        currentProtocol = contentProtocol;
      }
    }

    /** Closes the "list" field (if it was opened) and the list group. */
    @Override
    protected void endListWrapper() {
      if (size > 0) {
        recordConsumer.endField("list", 0);
      }
      recordConsumer.endGroup();
      end();
    }
  }

  /**
   * Field protocol for Thrift maps. The map group contains one field (the first child of the map
   * column); each entry is written as a group holding a key field followed by a value field.
   * Control alternates between the key protocol and the value protocol until the announced number
   * of entries has been written.
   */
  class MapWriteProtocol extends FieldBaseWriteProtocol {

    // First child of the map group: the group that holds one key/value pair.
    private GroupColumnIO mapContent;
    // Children 0 and 1 of mapContent.
    private ColumnIO key;
    private ColumnIO value;
    private TProtocol keyProtocol;
    private TProtocol valueProtocol;
    // Entry count announced by the most recent writeMapBegin call.
    private int countToConsume;

    public MapWriteProtocol(GroupColumnIO columnIO, MapType type, Events returnClause) {
      super(returnClause);
      this.mapContent = (GroupColumnIO) columnIO.getChild(0);
      this.key = mapContent.getChild(0);
      this.value = mapContent.getChild(1);
      this.keyProtocol = getProtocol(type.getKey(), this.key, new Events() {
        @Override
        public void start() {
          // A key opens a new entry group.
          recordConsumer.startGroup();
          recordConsumer.startField(key.getName(), key.getIndex());
        }

        @Override
        public void end() {
          recordConsumer.endField(key.getName(), key.getIndex());
          // The value of the same entry comes next.
          currentProtocol = valueProtocol;
        }
      });
      this.valueProtocol = getProtocol(type.getValue(), this.value, new Events() {
        // Entries completed so far for the map currently in progress.
        int consumed;

        @Override
        public void start() {
          recordConsumer.startField(value.getName(), value.getIndex());
        }

        @Override
        public void end() {
          consumed++;
          recordConsumer.endField(value.getName(), value.getIndex());
          // The value closes the entry group opened by the key.
          recordConsumer.endGroup();
          if (consumed == countToConsume) {
            // All announced entries written: this protocol must receive writeMapEnd.
            currentProtocol = MapWriteProtocol.this;
            consumed = 0;
          } else {
            currentProtocol = keyProtocol;
          }
        }
      });
    }

    /** Opens the map group and, for non-empty maps, the entry field; then activates the key protocol. */
    @Override
    public void writeMapBegin(TMap map) throws TException {
      start();
      recordConsumer.startGroup();
      countToConsume = map.size;
      // An empty map writes an empty group: the entry field is never opened.
      if (countToConsume > 0) {
        recordConsumer.startField(mapContent.getType().getName(), 0);
        currentProtocol = keyProtocol;
      }
    }

    /** Closes the entry field (if it was opened) and the map group. */
    @Override
    public void writeMapEnd() throws TException {
      if (countToConsume > 0) {
        recordConsumer.endField(mapContent.getType().getName(), 0);
      }
      recordConsumer.endGroup();
      end();
    }
  }

  /**
   * Field protocol for primitive Thrift values. Every write is bracketed by {@code start()} and
   * {@code end()} so the enclosing protocol can react before and after the value is emitted.
   */
  class PrimitiveWriteProtocol extends FieldBaseWriteProtocol {

    /**
     * @param columnIO     target column (not used by this class)
     * @param returnClause callbacks fired around every written value
     */
    public PrimitiveWriteProtocol(PrimitiveColumnIO columnIO, Events returnClause) {
      super(returnClause);
    }

    @Override
    public void writeBool(boolean b) throws TException {
      start();
      recordConsumer.addBoolean(b);
      end();
    }

    /** Bytes are widened and written as integers. */
    @Override
    public void writeByte(byte b) throws TException {
      final int widened = b;
      start();
      recordConsumer.addInteger(widened);
      end();
    }

    /** Shorts are widened and written as integers. */
    @Override
    public void writeI16(short i16) throws TException {
      final int widened = i16;
      start();
      recordConsumer.addInteger(widened);
      end();
    }

    @Override
    public void writeI32(int i32) throws TException {
      start();
      recordConsumer.addInteger(i32);
      end();
    }

    @Override
    public void writeI64(long i64) throws TException {
      start();
      recordConsumer.addLong(i64);
      end();
    }

    @Override
    public void writeDouble(double dub) throws TException {
      start();
      recordConsumer.addDouble(dub);
      end();
    }

    /** Strings go through the outer class's string helper. */
    @Override
    public void writeString(String str) throws TException {
      start();
      writeStringToRecordConsumer(str);
      end();
    }

    /** Binary payloads go through the outer class's binary helper. */
    @Override
    public void writeBinary(ByteBuffer buf) throws TException {
      start();
      writeBinaryToRecordConsumer(buf);
      end();
    }
  }

  /**
   * Field protocol for Thrift structs. A struct is written as a Parquet group; each Thrift field is
   * routed to a child protocol created for the matching Parquet column.
   */
  class StructWriteProtocol extends FieldBaseWriteProtocol {

    private final GroupColumnIO schema;
    private final StructType thriftType;
    // One protocol per Thrift child, in the order of thriftType.getChildren().
    private final TProtocol[] children;
    // Column of the field currently being written; set by writeFieldBegin, read by writeFieldEnd.
    private ColumnIO currentType;
    // Lookup table indexed by Thrift field id; ids that are not declared map to null.
    private final ColumnIO[] thriftFieldIdToParquetField;

    /**
     * @param schema       Parquet group the struct is written to; must not be null
     * @param thriftType   Thrift definition of the struct
     * @param returnClause callbacks fired around the struct
     * @throws NullPointerException     if {@code schema} is null
     * @throws RuntimeException         if a Thrift field has no column of the same name in {@code schema}
     * @throws ParquetEncodingException if the protocol for a child field cannot be created
     */
    public StructWriteProtocol(GroupColumnIO schema, StructType thriftType, Events returnClause) {
      super(returnClause);
      if (schema == null) {
        throw new NullPointerException("schema");
      }
      this.schema = schema;
      this.thriftType = thriftType;

      // Size the id lookup table to the largest declared field id.
      int maxFieldId = 0;
      for (ThriftField field : thriftType.getChildren()) {
        maxFieldId = Math.max(maxFieldId, field.getFieldId());
      }
      thriftFieldIdToParquetField = new ColumnIO[maxFieldId + 1];
      for (int i = 0; i < thriftType.getChildren().size(); i++) {
        thriftFieldIdToParquetField[thriftType.getChildren().get(i).getFieldId()] = schema.getChild(i);
      }

      children = new TProtocol[thriftType.getChildren().size()];
      for (int i = 0; i < children.length; i++) {
        final ThriftField field = thriftType.getChildren().get(i);
        final ColumnIO columnIO = schema.getChild(field.getName());
        if (columnIO == null) {
          throw new RuntimeException("Could not find " + field.getName() + " in " + schema);
        }
        try {
          children[i] = getProtocol(field, columnIO, new Events() {
            @Override
            public void start() {}

            @Override
            public void end() {
              // Once the field value is complete the struct handles the next call again.
              currentProtocol = StructWriteProtocol.this;
            }
          });
        } catch (RuntimeException e) {
          throw new ParquetEncodingException("Could not create Protocol for " + field + " to " + columnIO, e);
        }
      }
    }

    @Override
    public void writeStructBegin(TStruct struct) throws TException {
      start();
      recordConsumer.startGroup();
    }

    @Override
    public void writeStructEnd() throws TException {
      recordConsumer.endGroup();
      end();
    }

    /**
     * Opens the Parquet field that corresponds to the Thrift field id and makes the matching child
     * protocol active. A field of type {@code TType.STOP} is ignored.
     *
     * @throws ParquetEncodingException if the field id is not declared in the Thrift struct
     */
    @Override
    public void writeFieldBegin(TField field) throws TException {
      if (field.type == TType.STOP) {
        return;
      }
      // Explicit range check instead of catching ArrayIndexOutOfBoundsException, so that an
      // unrelated index failure raised further down is not reported as an unknown field.
      final int id = field.id;
      final boolean inRange = id >= 0 && id < thriftFieldIdToParquetField.length;
      currentType = inRange ? thriftFieldIdToParquetField[id] : null;
      if (currentType == null) {
        throw new ParquetEncodingException(
            "field " + field.id + " was not found in " + thriftType + " and " + schema.getType());
      }
      final int index = currentType.getIndex();
      recordConsumer.startField(currentType.getName(), index);
      currentProtocol = children[index];
    }

    @Override
    public void writeFieldStop() throws TException {
      // duplicate with struct end
    }

    /** Closes the field opened by the preceding {@link #writeFieldBegin(TField)}. */
    @Override
    public void writeFieldEnd() throws TException {
      recordConsumer.endField(currentType.getName(), currentType.getIndex());
    }
  }

  /**
   * Root-level variant of {@link StructWriteProtocol}: the outermost struct is written as a Parquet
   * message rather than a group. No {@link Events} are supplied because both struct callbacks are
   * overridden here and never call {@code start()} or {@code end()}.
   */
  class MessageWriteProtocol extends StructWriteProtocol {

    /**
     * @param schema     message-level column IO for the record
     * @param thriftType Thrift definition of the root struct
     */
    public MessageWriteProtocol(MessageColumnIO schema, StructType thriftType) {
      super(schema, thriftType, null);
    }

    /** Begins a new message on the record consumer. */
    @Override
    public void writeStructBegin(TStruct struct) throws TException {
      recordConsumer.startMessage();
    }

    /** Ends the message on the record consumer. */
    @Override
    public void writeStructEnd() throws TException {
      recordConsumer.endMessage();
    }
  }

  private static final Logger LOG = LoggerFactory.getLogger(ParquetWriteProtocol.class);

  // Sink that receives every start/end/add call issued by the nested protocols.
  private final RecordConsumer recordConsumer;
  // Protocol that handles the next Thrift write call; the nested protocols reassign it as
  // fields, collections and structs are entered and left.
  private TProtocol currentProtocol;

  // Debug-log rendering of a struct header: always the empty string, so the log line carries no details.
  private String toString(TStruct struct) {
    return "";
  }

  // Debug-log rendering of a list header: always the empty string.
  private String toString(TList list) {
    return "";
  }

  // Debug-log rendering of a map header: always the empty string.
  private String toString(TMap map) {
    return "";
  }

  /**
   * Creates a write protocol configured from a Hadoop {@link Configuration}.
   *
   * <p>A null {@code configuration} is forwarded as a null {@link ParquetConfiguration} instead of
   * being wrapped, so the delegate constructor's null check applies and
   * {@link #WRITE_THREE_LEVEL_LISTS_DEFAULT} is used.
   *
   * @param configuration  Hadoop configuration to read {@link #WRITE_THREE_LEVEL_LISTS} from; may be null
   * @param recordConsumer sink for the Parquet record events
   * @param schema         message-level column IO of the target schema
   * @param thriftType     Thrift definition of the root struct
   */
  public ParquetWriteProtocol(
      Configuration configuration, RecordConsumer recordConsumer, MessageColumnIO schema, StructType thriftType) {
    this(
        configuration == null ? (ParquetConfiguration) null : new HadoopParquetConfiguration(configuration),
        recordConsumer,
        schema,
        thriftType);
  }

  /**
   * Creates a write protocol configured from a {@link ParquetConfiguration}.
   *
   * @param configuration  source of the {@link #WRITE_THREE_LEVEL_LISTS} flag; when null the
   *                       default {@link #WRITE_THREE_LEVEL_LISTS_DEFAULT} is used
   * @param recordConsumer sink for the Parquet record events
   * @param schema         message-level column IO of the target schema
   * @param thriftType     Thrift definition of the root struct
   */
  public ParquetWriteProtocol(
      ParquetConfiguration configuration,
      RecordConsumer recordConsumer,
      MessageColumnIO schema,
      StructType thriftType) {
    this.recordConsumer = recordConsumer;
    this.writeThreeLevelList = configuration == null
        ? WRITE_THREE_LEVEL_LISTS_DEFAULT
        : configuration.getBoolean(WRITE_THREE_LEVEL_LISTS, WRITE_THREE_LEVEL_LISTS_DEFAULT);
    // The list layout flag must be set before the protocol tree is built, because list and set
    // fields pick their layout at construction time.
    this.currentProtocol = new MessageWriteProtocol(schema, thriftType);
  }

  /**
   * Creates a write protocol with default settings (no configuration is consulted).
   *
   * @param recordConsumer sink for the Parquet record events
   * @param schema         message-level column IO of the target schema
   * @param thriftType     Thrift definition of the root struct
   */
  public ParquetWriteProtocol(RecordConsumer recordConsumer, MessageColumnIO schema, StructType thriftType) {
    // A null configuration makes the delegate keep the default list layout.
    this((ParquetConfiguration) null, recordConsumer, schema, thriftType);
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the call at debug level and forwards it to the currently active protocol.
   *
   * @param message the Thrift message header
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeMessageBegin(org.apache.thrift.protocol.TMessage)
   */
  @Override
  public void writeMessageBegin(TMessage message) throws TException {
    LOG.debug("writeMessageBegin({})", message);
    currentProtocol.writeMessageBegin(message);
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the call at debug level and forwards it to the currently active protocol.
   *
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeMessageEnd()
   */
  @Override
  public void writeMessageEnd() throws TException {
    LOG.debug("writeMessageEnd()");
    currentProtocol.writeMessageEnd();
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the call at debug level (when enabled) and forwards it to the currently active protocol.
   *
   * @param struct the Thrift struct header
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeStructBegin(org.apache.thrift.protocol.TStruct)
   */
  @Override
  public void writeStructBegin(TStruct struct) throws TException {
    if (LOG.isDebugEnabled()) {
      LOG.debug("writeStructBegin(" + toString(struct) + ")");
    }
    currentProtocol.writeStructBegin(struct);
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the call at debug level and forwards it to the currently active protocol.
   *
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeStructEnd()
   */
  @Override
  public void writeStructEnd() throws TException {
    LOG.debug("writeStructEnd()");
    currentProtocol.writeStructEnd();
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the call at debug level and forwards it to the currently active protocol.
   *
   * @param field the Thrift field header
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeFieldBegin(org.apache.thrift.protocol.TField)
   */
  @Override
  public void writeFieldBegin(TField field) throws TException {
    LOG.debug("writeFieldBegin({})", field);
    currentProtocol.writeFieldBegin(field);
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the call at debug level and forwards it to the currently active protocol.
   *
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeFieldEnd()
   */
  @Override
  public void writeFieldEnd() throws TException {
    LOG.debug("writeFieldEnd()");
    currentProtocol.writeFieldEnd();
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the call at debug level and forwards it to the currently active protocol.
   *
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeFieldStop()
   */
  @Override
  public void writeFieldStop() throws TException {
    LOG.debug("writeFieldStop()");
    currentProtocol.writeFieldStop();
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the call at debug level (when enabled) and forwards it to the currently active protocol.
   *
   * @param map the Thrift map header
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeMapBegin(org.apache.thrift.protocol.TMap)
   */
  @Override
  public void writeMapBegin(TMap map) throws TException {
    if (LOG.isDebugEnabled()) {
      LOG.debug("writeMapBegin(" + toString(map) + ")");
    }
    currentProtocol.writeMapBegin(map);
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the call at debug level and forwards it to the currently active protocol.
   *
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeMapEnd()
   */
  @Override
  public void writeMapEnd() throws TException {
    LOG.debug("writeMapEnd()");
    currentProtocol.writeMapEnd();
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the call at debug level (when enabled) and forwards it to the currently active protocol.
   *
   * @param list the Thrift list header
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeListBegin(org.apache.thrift.protocol.TList)
   */
  @Override
  public void writeListBegin(TList list) throws TException {
    if (LOG.isDebugEnabled()) {
      LOG.debug("writeListBegin(" + toString(list) + ")");
    }
    currentProtocol.writeListBegin(list);
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the call at debug level and forwards it to the currently active protocol.
   *
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeListEnd()
   */
  @Override
  public void writeListEnd() throws TException {
    LOG.debug("writeListEnd()");
    currentProtocol.writeListEnd();
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the call at debug level and forwards it to the currently active protocol.
   *
   * @param set the Thrift set header
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeSetBegin(org.apache.thrift.protocol.TSet)
   */
  @Override
  public void writeSetBegin(TSet set) throws TException {
    LOG.debug("writeSetBegin({})", set);
    currentProtocol.writeSetBegin(set);
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the call at debug level and forwards it to the currently active protocol.
   *
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeSetEnd()
   */
  @Override
  public void writeSetEnd() throws TException {
    LOG.debug("writeSetEnd()");
    currentProtocol.writeSetEnd();
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the value at debug level and forwards it to the currently active protocol.
   *
   * @param b the boolean value to write
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeBool(boolean)
   */
  @Override
  public void writeBool(boolean b) throws TException {
    LOG.debug("writeBool({})", b);
    currentProtocol.writeBool(b);
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the value at debug level and forwards it to the currently active protocol.
   *
   * @param b the byte value to write
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeByte(byte)
   */
  @Override
  public void writeByte(byte b) throws TException {
    LOG.debug("writeByte({})", b);
    currentProtocol.writeByte(b);
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the value at debug level and forwards it to the currently active protocol.
   *
   * @param i16 the 16-bit integer value to write
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeI16(short)
   */
  @Override
  public void writeI16(short i16) throws TException {
    LOG.debug("writeI16({})", i16);
    currentProtocol.writeI16(i16);
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the value at debug level and forwards it to the currently active protocol.
   *
   * @param i32 the 32-bit integer value to write
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeI32(int)
   */
  @Override
  public void writeI32(int i32) throws TException {
    LOG.debug("writeI32({})", i32);
    currentProtocol.writeI32(i32);
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the value at debug level and forwards it to the currently active protocol.
   *
   * @param i64 the 64-bit integer value to write
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeI64(long)
   */
  @Override
  public void writeI64(long i64) throws TException {
    LOG.debug("writeI64({})", i64);
    currentProtocol.writeI64(i64);
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the value at debug level and forwards it to the currently active protocol.
   *
   * @param dub the double value to write
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeDouble(double)
   */
  @Override
  public void writeDouble(double dub) throws TException {
    LOG.debug("writeDouble({})", dub);
    currentProtocol.writeDouble(dub);
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the value at debug level and forwards it to the currently active protocol.
   *
   * @param str the string value to write
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeString(java.lang.String)
   */
  @Override
  public void writeString(String str) throws TException {
    LOG.debug("writeString({})", str);
    currentProtocol.writeString(str);
  }

  /**
   * {@inheritDoc}
   *
   * <p>Logs the buffer at debug level and forwards it to the currently active protocol.
   *
   * @param buf the buffer holding the bytes to write
   * @throws TException propagated from the active protocol
   * @see org.apache.parquet.thrift.ParquetProtocol#writeBinary(java.nio.ByteBuffer)
   */
  @Override
  public void writeBinary(ByteBuffer buf) throws TException {
    LOG.debug("writeBinary({})", buf);
    currentProtocol.writeBinary(buf);
  }

  /**
   * Writes the remaining bytes of {@code buf} (position to limit) to the record consumer as a
   * binary value. The buffer's position and limit are left unchanged.
   *
   * @param buf buffer whose remaining bytes are written
   */
  private void writeBinaryToRecordConsumer(ByteBuffer buf) {
    final int length = buf.remaining();
    if (buf.hasArray()) {
      // The buffer's position is relative to arrayOffset(), which is non-zero for sliced buffers
      // and for buffers wrapping part of an array.
      recordConsumer.addBinary(
          Binary.fromReusedByteArray(buf.array(), buf.arrayOffset() + buf.position(), length));
    } else {
      // Direct and read-only buffers expose no backing array; copy the bytes through a duplicate
      // so the caller's position is not advanced.
      final byte[] copy = new byte[length];
      buf.duplicate().get(copy);
      recordConsumer.addBinary(Binary.fromReusedByteArray(copy, 0, length));
    }
  }

  /**
   * Converts {@code str} to a {@link Binary} and writes it to the record consumer.
   *
   * @param str string value to write
   */
  private void writeStringToRecordConsumer(String str) {
    final Binary encoded = Binary.fromString(str);
    recordConsumer.addBinary(encoded);
  }

  /**
   * Builds the field-level protocol matching the Thrift type of {@code field}.
   *
   * @param field        Thrift field whose type selects the protocol
   * @param columnIO     Parquet column the field is written to; cast to the primitive or group
   *                     column type the selected protocol requires
   * @param returnClause callbacks the new protocol fires around the values it writes
   * @return the protocol that writes values of this field
   * @throws UnsupportedOperationException if the Thrift type cannot be written
   */
  private TProtocol getProtocol(ThriftField field, ColumnIO columnIO, Events returnClause) {
    final ThriftType type = field.getType();
    switch (type.getType()) {
      case BOOL:
      case BYTE:
      case DOUBLE:
      case I16:
      case I32:
      case I64:
      case STRING:
      case UUID:
        return new PrimitiveWriteProtocol((PrimitiveColumnIO) columnIO, returnClause);
      case ENUM:
        return new EnumWriteProtocol((PrimitiveColumnIO) columnIO, (EnumType) type, returnClause);
      case STRUCT:
        return new StructWriteProtocol((GroupColumnIO) columnIO, (StructType) type, returnClause);
      case MAP:
        return new MapWriteProtocol((GroupColumnIO) columnIO, (MapType) type, returnClause);
      case LIST:
        return getListWriteProtocol((GroupColumnIO) columnIO, ((ListType) type).getValues(), returnClause);
      case SET:
        return getListWriteProtocol((GroupColumnIO) columnIO, ((SetType) type).getValues(), returnClause);
      case STOP:
      case VOID:
      default:
        throw new UnsupportedOperationException("can't convert type of " + field);
    }
  }

  /**
   * Creates the list/set protocol for the layout selected by {@code writeThreeLevelList}.
   *
   * @param columnIO     group column of the list or set field
   * @param values       Thrift field describing the elements
   * @param returnClause callbacks fired around the whole collection
   * @return a three-level protocol when the flag is set, otherwise a two-level protocol
   */
  private TProtocol getListWriteProtocol(GroupColumnIO columnIO, ThriftField values, Events returnClause) {
    if (writeThreeLevelList) {
      return new ThreeLevelListWriteProtocol(columnIO, values, returnClause);
    }
    return new TwoLevelListWriteProtocol(columnIO, values, returnClause);
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy