All downloads are free. The search and download functionality uses the official Maven repository.

parquet.thrift.ParquetWriteProtocol Maven / Gradle / Ivy

/* 
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package parquet.thrift;

import static parquet.Log.DEBUG;

import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;

import org.apache.thrift.TException;
import org.apache.thrift.protocol.TField;
import org.apache.thrift.protocol.TList;
import org.apache.thrift.protocol.TMap;
import org.apache.thrift.protocol.TMessage;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TSet;
import org.apache.thrift.protocol.TStruct;
import org.apache.thrift.protocol.TType;

import parquet.Log;
import parquet.io.ColumnIO;
import parquet.io.GroupColumnIO;
import parquet.io.MessageColumnIO;
import parquet.io.ParquetEncodingException;
import parquet.io.PrimitiveColumnIO;
import parquet.io.api.Binary;
import parquet.io.api.RecordConsumer;
import parquet.thrift.struct.ThriftField;
import parquet.thrift.struct.ThriftType;
import parquet.thrift.struct.ThriftType.EnumType;
import parquet.thrift.struct.ThriftType.EnumValue;
import parquet.thrift.struct.ThriftType.ListType;
import parquet.thrift.struct.ThriftType.MapType;
import parquet.thrift.struct.ThriftType.SetType;
import parquet.thrift.struct.ThriftType.StructType;

public class ParquetWriteProtocol extends ParquetProtocol {

  /**
   * Callbacks a nested protocol fires around the value it writes.
   * The creator of the nested protocol supplies the implementation and uses it
   * to open/close the surrounding field and to take control back afterwards.
   */
  interface Events {

    /** Invoked right before the nested protocol emits its value. */
    void start();

    /** Invoked right after the nested protocol has emitted its value. */
    void end();

  }

  /**
   * Base class of the per-field protocols. It only stores the {@link Events}
   * callbacks handed in by the parent and exposes them as {@link #start()} and
   * {@link #end()} for subclasses to call around the value they write.
   */
  abstract class FieldBaseWriteProtocol extends ParquetProtocol {
    private final Events callbacks;

    /**
     * @param returnClause callbacks to fire before and after each written value
     */
    public FieldBaseWriteProtocol(Events returnClause) {
      callbacks = returnClause;
    }

    /** Signals the parent that a value is about to be written. */
    void start() {
      callbacks.start();
    }

    /** Signals the parent that the value has been written. */
    void end() {
      callbacks.end();
    }
  }

  /**
   * Writes a Thrift enum, which arrives as an i32, as the binary form of the
   * enum value's name.
   */
  class EnumWriteProtocol extends FieldBaseWriteProtocol {

    private final EnumType type;
    private final PrimitiveColumnIO columnIO;

    /**
     * @param columnIO     column the enum is written to; only used in error messages
     * @param type         Thrift enum definition used to resolve ids to names
     * @param returnClause callbacks to fire around the written value
     */
    public EnumWriteProtocol(PrimitiveColumnIO columnIO, EnumType type, Events returnClause) {
      super(returnClause);
      this.type = type;
      this.columnIO = columnIO;
    }

    /**
     * Resolves {@code i32} to its enum value and writes the value's name.
     *
     * @throws ParquetEncodingException if the enum type has no value with that id
     */
    @Override
    public void writeI32(int i32) throws TException {
      start();
      final EnumValue resolved = type.getEnumValueById(i32);
      if (resolved != null) {
        recordConsumer.addBinary(Binary.fromString(resolved.getName()));
        end();
        return;
      }
      throw new ParquetEncodingException("Can not find enum value of index " + i32 + " for field:" + columnIO.toString());
    }

  }

  /**
   * Writes Thrift lists and sets as a group wrapping one repeated field
   * (the first child of the list's column). Control is handed to the element
   * protocol for the duration of the elements and comes back here once the
   * announced number of elements has been written.
   */
  class ListWriteProtocol extends FieldBaseWriteProtocol {

    private final ColumnIO listContent;
    private final TProtocol contentProtocol;
    private int size;

    /**
     * @param columnIO     group column of the list; its first child holds the elements
     * @param values       Thrift description of the element type
     * @param returnClause callbacks to fire around the whole list
     */
    public ListWriteProtocol(GroupColumnIO columnIO, ThriftField values, Events returnClause) {
      super(returnClause);
      this.listContent = columnIO.getChild(0);
      this.contentProtocol = getProtocol(values, listContent, new Events() {
        int written = 0;

        @Override
        public void start() {
          // nothing to do before an element
        }

        @Override
        public void end() {
          // after the last announced element, reset the counter and take control back
          if (++written == size) {
            written = 0;
            currentProtocol = ListWriteProtocol.this;
          }
        }
      });
    }

    /** Name of the repeated field that carries the elements. */
    private String contentFieldName() {
      return listContent.getType().getName();
    }

    /** Opens the wrapping group and, for a non-empty collection, the element field. */
    private void startListWrapper() {
      start();
      recordConsumer.startGroup();
      if (size <= 0) {
        // empty collection: no field is opened and control stays here
        return;
      }
      recordConsumer.startField(contentFieldName(), 0);
      currentProtocol = contentProtocol;
    }

    /** Closes the element field (if it was opened) and the wrapping group. */
    private void endListWrapper() {
      final boolean fieldWasOpened = size > 0;
      if (fieldWasOpened) {
        recordConsumer.endField(contentFieldName(), 0);
      }
      recordConsumer.endGroup();
      end();
    }

    @Override
    public void writeListBegin(TList list) throws TException {
      this.size = list.size;
      startListWrapper();
    }

    @Override
    public void writeListEnd() throws TException {
      endListWrapper();
    }

    @Override
    public void writeSetBegin(TSet set) throws TException {
      // sets share the list layout
      this.size = set.size;
      startListWrapper();
    }

    @Override
    public void writeSetEnd() throws TException {
      endListWrapper();
    }

  }

  /**
   * Writes a Thrift map as a group wrapping one field (the first child of the
   * map's column) whose entries are groups made of a key field and a value field.
   * While entries are being written, control alternates between the key protocol
   * and the value protocol; it returns to this protocol after the last value.
   */
  class MapWriteProtocol extends FieldBaseWriteProtocol {

    // group column holding the entries (first child of the map column)
    private GroupColumnIO mapContent;
    // first and second child of mapContent
    private ColumnIO key;
    private ColumnIO value;
    private TProtocol keyProtocol;
    private TProtocol valueProtocol;
    // number of entries announced by the last writeMapBegin
    private int countToConsume;

    /**
     * @param columnIO     group column of the map
     * @param type         Thrift map type providing the key and value field descriptions
     * @param returnClause callbacks to fire around the whole map
     */
    public MapWriteProtocol(GroupColumnIO columnIO, MapType type, Events returnClause) {
      super(returnClause);
      this.mapContent = (GroupColumnIO)columnIO.getChild(0);
      this.key = mapContent.getChild(0);
      this.value = mapContent.getChild(1);
      this.keyProtocol = getProtocol(type.getKey(), this.key, new Events() {
        @Override
        public void start() {
          // a key opens a new entry group
          recordConsumer.startGroup();
          recordConsumer.startField(key.getName(), key.getIndex());
        }
        @Override
        public void end() {
          recordConsumer.endField(key.getName(), key.getIndex());
          // the value of this entry comes next
          currentProtocol = valueProtocol;
        }
      });
      this.valueProtocol = getProtocol(type.getValue(), this.value, new Events() {
        // entries completed since the counter was last reset
        int consumed;
        @Override
        public void start() {
          recordConsumer.startField(value.getName(), value.getIndex());
        }
        @Override
        public void end() {
          consumed ++;
          recordConsumer.endField(value.getName(), value.getIndex());
          // a value closes the entry group opened by its key
          recordConsumer.endGroup();
          if (consumed == countToConsume) {
            // last entry: hand control back to the map protocol and reset for the next map
            currentProtocol = MapWriteProtocol.this;
            consumed = 0;
          } else {
            // more entries follow: expect the next key
            currentProtocol = keyProtocol;
          }

        }
      });
    }

    /**
     * Opens the wrapping group and records the entry count. For a non-empty map
     * it also opens the entries field and hands control to the key protocol.
     */
    @Override
    public void writeMapBegin(TMap map) throws TException {
      start();
      recordConsumer.startGroup();
      countToConsume = map.size;
      if (countToConsume > 0) {
        recordConsumer.startField(mapContent.getType().getName(), 0);
        currentProtocol = keyProtocol;
      }
    }

    /**
     * Closes the entries field (only opened for a non-empty map) and the wrapping group.
     */
    @Override
    public void writeMapEnd() throws TException {
      if (countToConsume > 0) {
        recordConsumer.endField(mapContent.getType().getName(), 0);
      }
      recordConsumer.endGroup();
      end();
    }

  }

  /**
   * Writes a single primitive value. Every write is bracketed by the
   * {@code start()} / {@code end()} callbacks of the parent protocol.
   */
  class PrimitiveWriteProtocol extends FieldBaseWriteProtocol {

    /**
     * @param columnIO     column of the primitive; not used by this class
     * @param returnClause callbacks to fire around the written value
     */
    public PrimitiveWriteProtocol(PrimitiveColumnIO columnIO, Events returnClause) {
      super(returnClause);
    }

    @Override
    public void writeBool(final boolean b) throws TException {
      start();
      recordConsumer.addBoolean(b);
      end();
    }

    /** Bytes are widened and written as integers. */
    @Override
    public void writeByte(final byte b) throws TException {
      start();
      recordConsumer.addInteger((int) b);
      end();
    }

    /** Shorts are widened and written as integers. */
    @Override
    public void writeI16(final short i16) throws TException {
      start();
      recordConsumer.addInteger((int) i16);
      end();
    }

    @Override
    public void writeI32(final int i32) throws TException {
      start();
      recordConsumer.addInteger(i32);
      end();
    }

    @Override
    public void writeI64(final long i64) throws TException {
      start();
      recordConsumer.addLong(i64);
      end();
    }

    @Override
    public void writeDouble(final double dub) throws TException {
      start();
      recordConsumer.addDouble(dub);
      end();
    }

    /** Strings go through the enclosing protocol's string helper. */
    @Override
    public void writeString(final String str) throws TException {
      start();
      writeStringToRecordConsumer(str);
      end();
    }

    /** Binary values go through the enclosing protocol's binary helper. */
    @Override
    public void writeBinary(final ByteBuffer buf) throws TException {
      start();
      writeBinaryToRecordConsumer(buf);
      end();
    }

  }

  /**
   * Writes a Thrift struct as a Parquet group. Each struct field is routed, by
   * its Thrift field id, to the matching column and to a child protocol created
   * for that field; the child hands control back here once its value is written.
   */
  class StructWriteProtocol extends FieldBaseWriteProtocol {

    private final GroupColumnIO schema;
    private final StructType thriftType;
    // one protocol per struct field, in the order of thriftType.getChildren()
    private final TProtocol[] children;
    // column of the field currently being written; set by writeFieldBegin
    private ColumnIO currentType;
    // lookup table indexed by Thrift field id; ids that are not declared stay null
    private final ColumnIO[] thriftFieldIdToParquetField;

    /**
     * @param schema       group column the struct is written to
     * @param thriftType   Thrift description of the struct
     * @param returnClause callbacks to fire around the whole struct
     * @throws NullPointerException     if {@code schema} is null
     * @throws RuntimeException         if a field has no column of the same name in {@code schema}
     * @throws ParquetEncodingException if the protocol of a field can not be created
     */
    public StructWriteProtocol(GroupColumnIO schema, StructType thriftType, Events returnClause) {
      super(returnClause);
      if (schema == null) {
        throw new NullPointerException("schema");
      }
      this.schema = schema;
      this.thriftType = thriftType;

      int maxFieldId = 0;
      for (ThriftField field : thriftType.getChildren()) {
        maxFieldId = Math.max(maxFieldId, field.getFieldId());
      }
      thriftFieldIdToParquetField = new ColumnIO[maxFieldId + 1];
      // NOTE(review): this table pairs the i-th Thrift field with the i-th column, while the
      // child protocols below look their column up by name; both presumably agree because the
      // schema is derived from the Thrift type - confirm before feeding a reordered schema.
      for (int i = 0; i < thriftType.getChildren().size(); i++) {
        thriftFieldIdToParquetField[thriftType.getChildren().get(i).getFieldId()] = schema.getChild(i);
      }

      children = new TProtocol[thriftType.getChildren().size()];
      for (int i = 0; i < children.length; i++) {
        final ThriftField field = thriftType.getChildren().get(i);
        final ColumnIO columnIO = schema.getChild(field.getName());
        if (columnIO == null) {
          throw new RuntimeException("Could not find " + field.getName() + " in " + schema);
        }
        try {
          children[i] = getProtocol(field, columnIO, new Events() {
            @Override
            public void start() {
              // the field is opened by writeFieldBegin, nothing to do here
            }

            @Override
            public void end() {
              // value written: the struct takes control back
              currentProtocol = StructWriteProtocol.this;
            }
          });
        } catch (RuntimeException e) {
          throw new ParquetEncodingException("Could not create Protocol for " + field + " to " + columnIO, e);
        }
      }
    }

    @Override
    public void writeStructBegin(TStruct struct) throws TException {
      start();
      recordConsumer.startGroup();
    }

    @Override
    public void writeStructEnd() throws TException {
      recordConsumer.endGroup();
      end();
    }

    /**
     * Opens the column matching {@code field.id} and hands control to that
     * field's protocol. A {@code STOP} field is ignored.
     *
     * @throws ParquetEncodingException if the id is negative, beyond the highest
     *         declared id, or not declared by the struct
     */
    @Override
    public void writeFieldBegin(TField field) throws TException {
      if (field.type == TType.STOP) {
        return;
      }
      // Explicit range check instead of catching ArrayIndexOutOfBoundsException, so that an
      // index error raised further down is not misreported as an unknown field.
      final int id = field.id;
      final boolean inRange = id >= 0 && id < thriftFieldIdToParquetField.length;
      final ColumnIO target = inRange ? thriftFieldIdToParquetField[id] : null;
      if (target == null) {
        throw new ParquetEncodingException("field " + field.id + " was not found in " + thriftType + " and " + schema.getType());
      }
      currentType = target;
      final int index = currentType.getIndex();
      recordConsumer.startField(currentType.getName(), index);
      currentProtocol = children[index];
    }

    @Override
    public void writeFieldStop() throws TException {
      // duplicate with struct end
    }

    /** Closes the column opened by the preceding {@link #writeFieldBegin(TField)}. */
    @Override
    public void writeFieldEnd() throws TException {
      recordConsumer.endField(currentType.getName(), currentType.getIndex());
    }
  }

  /**
   * Root protocol: the top-level struct is written as a Parquet message rather
   * than as a group. No callbacks are passed to the parent class ({@code null}),
   * and both struct methods are overridden so the inherited versions, which
   * fire those callbacks, are not used.
   */
  class MessageWriteProtocol extends StructWriteProtocol {

    /**
     * @param schema     column tree of the whole message
     * @param thriftType Thrift description of the top-level struct
     */
    public MessageWriteProtocol(MessageColumnIO schema, StructType thriftType) {
      super(schema, thriftType, null);
    }

    /** Opens the message. */
    @Override
    public void writeStructBegin(final TStruct struct) throws TException {
      recordConsumer.startMessage();
    }

    /** Closes the message. */
    @Override
    public void writeStructEnd() throws TException {
      recordConsumer.endMessage();
    }

  }

  // Logger used by the DEBUG-guarded call tracing in the write* methods of this class.
  private static final Log LOG = Log.getLog(ParquetWriteProtocol.class);


  // Sink that receives every record event emitted by the nested protocols.
  private final RecordConsumer recordConsumer;
  // Protocol that handles the next write* call; the nested protocols reassign it
  // whenever they hand control to a child or take it back.
  private TProtocol currentProtocol;

  /**
   * Renders a struct descriptor for the debug trace.
   *
   * @param struct descriptor to render, may be null
   * @return the struct name wrapped in a short tag, or an empty string for null
   */
  private String toString(TStruct struct) {
    if (struct == null) {
      return "";
    }
    return "<TStruct name:" + struct.name + ">";
  }

  /**
   * Renders a list descriptor for the debug trace.
   *
   * @param list descriptor to render, may be null
   * @return element type and size wrapped in a short tag, or an empty string for null
   */
  private String toString(TList list) {
    if (list == null) {
      return "";
    }
    return "<TList elemType:" + list.elemType + " size:" + list.size + ">";
  }

  /**
   * Renders a map descriptor for the debug trace.
   *
   * @param map descriptor to render, may be null
   * @return key type, value type and size wrapped in a short tag, or an empty string for null
   */
  private String toString(TMap map) {
    if (map == null) {
      return "";
    }
    return "<TMap keyType:" + map.keyType + " valueType:" + map.valueType + " size:" + map.size + ">";
  }

  public ParquetWriteProtocol(RecordConsumer recordConsumer, MessageColumnIO schema, StructType thriftType) {
    this.currentProtocol = new MessageWriteProtocol(schema, thriftType);
    this.recordConsumer = recordConsumer;
  }

  /**
   * {@inheritDoc}
   * <p>Traces the call when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeMessageBegin(org.apache.thrift.protocol.TMessage)
   */
  @Override
  public void writeMessageBegin(TMessage message) throws TException {
    if (DEBUG) {
      LOG.debug("writeMessageBegin(" + message + ")");
    }
    currentProtocol.writeMessageBegin(message);
  }

  /**
   * {@inheritDoc}
   * <p>Traces the call when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeMessageEnd()
   */
  @Override
  public void writeMessageEnd() throws TException {
    if (DEBUG) {
      LOG.debug("writeMessageEnd()");
    }
    currentProtocol.writeMessageEnd();
  }

  /**
   * {@inheritDoc}
   * <p>Traces the call when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeStructBegin(org.apache.thrift.protocol.TStruct)
   */
  @Override
  public void writeStructBegin(TStruct struct) throws TException {
    if (DEBUG) {
      LOG.debug("writeStructBegin(" + toString(struct) + ")");
    }
    currentProtocol.writeStructBegin(struct);
  }

  /**
   * {@inheritDoc}
   * <p>Traces the call when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeStructEnd()
   */
  @Override
  public void writeStructEnd() throws TException {
    if (DEBUG) {
      LOG.debug("writeStructEnd()");
    }
    currentProtocol.writeStructEnd();
  }

  /**
   * {@inheritDoc}
   * <p>Traces the call when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeFieldBegin(org.apache.thrift.protocol.TField)
   */
  @Override
  public void writeFieldBegin(TField field) throws TException {
    if (DEBUG) {
      LOG.debug("writeFieldBegin(" + field + ")");
    }
    currentProtocol.writeFieldBegin(field);
  }

  /**
   * {@inheritDoc}
   * <p>Traces the call when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeFieldEnd()
   */
  @Override
  public void writeFieldEnd() throws TException {
    if (DEBUG) {
      LOG.debug("writeFieldEnd()");
    }
    currentProtocol.writeFieldEnd();
  }

  /**
   * {@inheritDoc}
   * <p>Traces the call when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeFieldStop()
   */
  @Override
  public void writeFieldStop() throws TException {
    if (DEBUG) {
      LOG.debug("writeFieldStop()");
    }
    currentProtocol.writeFieldStop();
  }

  /**
   * {@inheritDoc}
   * <p>Traces the call when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeMapBegin(org.apache.thrift.protocol.TMap)
   */
  @Override
  public void writeMapBegin(TMap map) throws TException {
    if (DEBUG) {
      LOG.debug("writeMapBegin(" + toString(map) + ")");
    }
    currentProtocol.writeMapBegin(map);
  }

  /**
   * {@inheritDoc}
   * <p>Traces the call when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeMapEnd()
   */
  @Override
  public void writeMapEnd() throws TException {
    if (DEBUG) {
      LOG.debug("writeMapEnd()");
    }
    currentProtocol.writeMapEnd();
  }

  /**
   * {@inheritDoc}
   * <p>Traces the call when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeListBegin(org.apache.thrift.protocol.TList)
   */
  @Override
  public void writeListBegin(TList list) throws TException {
    if (DEBUG) {
      LOG.debug("writeListBegin(" + toString(list) + ")");
    }
    currentProtocol.writeListBegin(list);
  }


  /**
   * {@inheritDoc}
   * <p>Traces the call when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeListEnd()
   */
  @Override
  public void writeListEnd() throws TException {
    if (DEBUG) {
      LOG.debug("writeListEnd()");
    }
    currentProtocol.writeListEnd();
  }


  /**
   * {@inheritDoc}
   * <p>Traces the call when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeSetBegin(org.apache.thrift.protocol.TSet)
   */
  @Override
  public void writeSetBegin(TSet set) throws TException {
    if (DEBUG) {
      LOG.debug("writeSetBegin(" + set + ")");
    }
    currentProtocol.writeSetBegin(set);
  }

  /**
   * {@inheritDoc}
   * <p>Traces the call when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeSetEnd()
   */
  @Override
  public void writeSetEnd() throws TException {
    if (DEBUG) {
      LOG.debug("writeSetEnd()");
    }
    currentProtocol.writeSetEnd();
  }

  /**
   * {@inheritDoc}
   * <p>Traces the value when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeBool(boolean)
   */
  @Override
  public void writeBool(boolean b) throws TException {
    if (DEBUG) {
      LOG.debug("writeBool(" + b + ")");
    }
    currentProtocol.writeBool(b);
  }

  /**
   * {@inheritDoc}
   * <p>Traces the value when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeByte(byte)
   */
  @Override
  public void writeByte(byte b) throws TException {
    if (DEBUG) {
      LOG.debug("writeByte(" + b + ")");
    }
    currentProtocol.writeByte(b);
  }

  /**
   * {@inheritDoc}
   * <p>Traces the value when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeI16(short)
   */
  @Override
  public void writeI16(short i16) throws TException {
    if (DEBUG) {
      LOG.debug("writeI16(" + i16 + ")");
    }
    currentProtocol.writeI16(i16);
  }

  /**
   * {@inheritDoc}
   * <p>Traces the value when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeI32(int)
   */
  @Override
  public void writeI32(int i32) throws TException {
    if (DEBUG) {
      LOG.debug("writeI32(" + i32 + ")");
    }
    currentProtocol.writeI32(i32);
  }

  /**
   * {@inheritDoc}
   * <p>Traces the value when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeI64(long)
   */
  @Override
  public void writeI64(long i64) throws TException {
    if (DEBUG) {
      LOG.debug("writeI64(" + i64 + ")");
    }
    currentProtocol.writeI64(i64);
  }

  /**
   * {@inheritDoc}
   * <p>Traces the value when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeDouble(double)
   */
  @Override
  public void writeDouble(double dub) throws TException {
    if (DEBUG) {
      LOG.debug("writeDouble(" + dub + ")");
    }
    currentProtocol.writeDouble(dub);
  }

  /**
   * {@inheritDoc}
   * <p>Traces the value when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeString(java.lang.String)
   */
  @Override
  public void writeString(String str) throws TException {
    if (DEBUG) {
      LOG.debug("writeString(" + str + ")");
    }
    currentProtocol.writeString(str);
  }

  /**
   * {@inheritDoc}
   * <p>Traces the buffer when {@code DEBUG} is on, then forwards it to the current delegate.
   *
   * @see parquet.thrift.ParquetProtocol#writeBinary(java.nio.ByteBuffer)
   */
  @Override
  public void writeBinary(ByteBuffer buf) throws TException {
    if (DEBUG) {
      LOG.debug("writeBinary(" + buf + ")");
    }
    currentProtocol.writeBinary(buf);
  }

  /**
   * Writes the bytes between the position and the limit of {@code buf} as one
   * binary value. The buffer's position and limit are left untouched.
   *
   * @param buf buffer holding the value; may be array-backed, direct or read-only
   */
  private void writeBinaryToRecordConsumer(ByteBuffer buf) {
    final int length = buf.remaining();
    if (buf.hasArray()) {
      // position() is relative to the buffer, not to the backing array: a sliced or
      // offset-wrapped buffer starts at arrayOffset() inside that array.
      recordConsumer.addBinary(Binary.fromByteArray(buf.array(), buf.arrayOffset() + buf.position(), length));
    } else {
      // direct and read-only buffers do not expose an array; copy through a duplicate
      // so the caller's position is not moved
      final byte[] copy = new byte[length];
      buf.duplicate().get(copy);
      recordConsumer.addBinary(Binary.fromByteArray(copy, 0, length));
    }
  }

  /**
   * Converts {@code str} to its binary form and writes it as one binary value.
   *
   * @param str string to write
   */
  private void writeStringToRecordConsumer(String str) {
    final Binary encoded = Binary.fromString(str);
    recordConsumer.addBinary(encoded);
  }

  /**
   * Creates the protocol able to write {@code field} into {@code columnIO}.
   *
   * @param field        Thrift field whose type selects the protocol
   * @param columnIO     column the field is written to; cast to a primitive or
   *                     group column depending on the field type
   * @param returnClause callbacks the new protocol fires around what it writes
   * @return a new protocol for the field
   * @throws UnsupportedOperationException for STOP, VOID and any type not listed below
   */
  private TProtocol getProtocol(ThriftField field, ColumnIO columnIO, Events returnClause) {
    final ThriftType thriftType = field.getType();
    switch (thriftType.getType()) {
    case BOOL:
    case BYTE:
    case DOUBLE:
    case I16:
    case I32:
    case I64:
    case STRING:
      return new PrimitiveWriteProtocol((PrimitiveColumnIO)columnIO, returnClause);
    case ENUM:
      return new EnumWriteProtocol((PrimitiveColumnIO)columnIO, (EnumType)thriftType, returnClause);
    case STRUCT:
      return new StructWriteProtocol((GroupColumnIO)columnIO, (StructType)thriftType, returnClause);
    case MAP:
      return new MapWriteProtocol((GroupColumnIO)columnIO, (MapType)thriftType, returnClause);
    case LIST:
      return new ListWriteProtocol((GroupColumnIO)columnIO, ((ListType)thriftType).getValues(), returnClause);
    case SET:
      // sets are written with the list layout
      return new ListWriteProtocol((GroupColumnIO)columnIO, ((SetType)thriftType).getValues(), returnClause);
    case STOP:
    case VOID:
    default:
      throw new UnsupportedOperationException("can't convert type of " + field);
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy