parquet.thrift.ParquetWriteProtocol Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package parquet.thrift;
import static parquet.Log.DEBUG;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TField;
import org.apache.thrift.protocol.TList;
import org.apache.thrift.protocol.TMap;
import org.apache.thrift.protocol.TMessage;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TSet;
import org.apache.thrift.protocol.TStruct;
import org.apache.thrift.protocol.TType;
import parquet.Log;
import parquet.io.ColumnIO;
import parquet.io.GroupColumnIO;
import parquet.io.MessageColumnIO;
import parquet.io.ParquetEncodingException;
import parquet.io.PrimitiveColumnIO;
import parquet.io.api.Binary;
import parquet.io.api.RecordConsumer;
import parquet.thrift.struct.ThriftField;
import parquet.thrift.struct.ThriftType;
import parquet.thrift.struct.ThriftType.EnumType;
import parquet.thrift.struct.ThriftType.EnumValue;
import parquet.thrift.struct.ThriftType.ListType;
import parquet.thrift.struct.ThriftType.MapType;
import parquet.thrift.struct.ThriftType.SetType;
import parquet.thrift.struct.ThriftType.StructType;
/**
 * Thrift protocol that turns Thrift write calls into events on a Parquet
 * {@link RecordConsumer}.
 *
 * <p>Every public {@code write*} method forwards to {@link #currentProtocol}. The nested
 * protocol classes below emit the matching {@link RecordConsumer} events and reassign
 * {@link #currentProtocol} as they descend into a field and return from it.
 */
public class ParquetWriteProtocol extends ParquetProtocol {

  /**
   * Callbacks a nested protocol invokes around what it writes: {@link #start()} before
   * the first event of a value, {@link #end()} after its last one.
   */
  interface Events {
    void start();
    void end();
  }

  /**
   * Base class of the per-field protocols; relays {@link #start()} and {@link #end()}
   * to the {@link Events} supplied by the enclosing container.
   */
  abstract class FieldBaseWriteProtocol extends ParquetProtocol {
    private final Events returnClause;

    public FieldBaseWriteProtocol(Events returnClause) {
      this.returnClause = returnClause;
    }

    void start() {
      this.returnClause.start();
    }

    void end() {
      this.returnClause.end();
    }
  }

  /**
   * Writes an enum, received as an i32 id, as the binary form of the enum value's name.
   */
  class EnumWriteProtocol extends FieldBaseWriteProtocol {
    private final EnumType type;
    private final PrimitiveColumnIO columnIO;

    public EnumWriteProtocol(PrimitiveColumnIO columnIO, EnumType type, Events returnClause) {
      super(returnClause);
      this.columnIO = columnIO;
      this.type = type;
    }

    /**
     * @throws ParquetEncodingException if {@code i32} is not an id known to the enum type
     */
    @Override
    public void writeI32(int i32) throws TException {
      start();
      EnumValue value = type.getEnumValueById(i32);
      if (value == null) {
        throw new ParquetEncodingException("Can not find enum value of index " + i32 + " for field:" + columnIO.toString());
      }
      recordConsumer.addBinary(Binary.fromString(value.getName()));
      end();
    }
  }

  /**
   * Writes a list or a set as a group wrapping a single field (the first child of the
   * column). An empty collection produces an empty group: the field is not started.
   */
  class ListWriteProtocol extends FieldBaseWriteProtocol {
    private final ColumnIO listContent;
    private final TProtocol contentProtocol;
    /** Element count announced by the last writeListBegin/writeSetBegin. */
    private int size;

    public ListWriteProtocol(GroupColumnIO columnIO, ThriftField values, Events returnClause) {
      super(returnClause);
      this.listContent = columnIO.getChild(0);
      this.contentProtocol = getProtocol(values, listContent, new Events() {
        int consumedRecords = 0;

        @Override
        public void start() {
        }

        @Override
        public void end() {
          ++consumedRecords;
          // once every announced element has been written, hand control back to the
          // list so that it receives the closing writeListEnd/writeSetEnd
          if (consumedRecords == size) {
            currentProtocol = ListWriteProtocol.this;
            consumedRecords = 0;
          }
        }
      });
    }

    private void startListWrapper() {
      start();
      recordConsumer.startGroup();
      if (size > 0) {
        recordConsumer.startField(listContent.getType().getName(), 0);
        currentProtocol = contentProtocol;
      }
    }

    private void endListWrapper() {
      if (size > 0) {
        recordConsumer.endField(listContent.getType().getName(), 0);
      }
      recordConsumer.endGroup();
      end();
    }

    @Override
    public void writeListBegin(TList list) throws TException {
      size = list.size;
      startListWrapper();
    }

    @Override
    public void writeListEnd() throws TException {
      endListWrapper();
    }

    @Override
    public void writeSetBegin(TSet set) throws TException {
      size = set.size;
      startListWrapper();
    }

    @Override
    public void writeSetEnd() throws TException {
      endListWrapper();
    }
  }

  /**
   * Writes a map as a group wrapping a single field (the first child of the column);
   * each entry is a group holding a key field followed by a value field. An empty map
   * produces an empty group.
   */
  class MapWriteProtocol extends FieldBaseWriteProtocol {
    private final GroupColumnIO mapContent;
    private final ColumnIO key;
    private final ColumnIO value;
    // not final: the key callbacks refer to valueProtocol before it has been assigned
    private TProtocol keyProtocol;
    private TProtocol valueProtocol;
    /** Entry count announced by the last writeMapBegin. */
    private int countToConsume;

    public MapWriteProtocol(GroupColumnIO columnIO, MapType type, Events returnClause) {
      super(returnClause);
      this.mapContent = (GroupColumnIO)columnIO.getChild(0);
      this.key = mapContent.getChild(0);
      this.value = mapContent.getChild(1);
      this.keyProtocol = getProtocol(type.getKey(), this.key, new Events() {
        @Override
        public void start() {
          // a key opens a new entry group
          recordConsumer.startGroup();
          recordConsumer.startField(key.getName(), key.getIndex());
        }

        @Override
        public void end() {
          recordConsumer.endField(key.getName(), key.getIndex());
          currentProtocol = valueProtocol;
        }
      });
      this.valueProtocol = getProtocol(type.getValue(), this.value, new Events() {
        int consumed;

        @Override
        public void start() {
          recordConsumer.startField(value.getName(), value.getIndex());
        }

        @Override
        public void end() {
          consumed++;
          recordConsumer.endField(value.getName(), value.getIndex());
          // a value closes the entry group opened by its key
          recordConsumer.endGroup();
          if (consumed == countToConsume) {
            currentProtocol = MapWriteProtocol.this;
            consumed = 0;
          } else {
            currentProtocol = keyProtocol;
          }
        }
      });
    }

    @Override
    public void writeMapBegin(TMap map) throws TException {
      start();
      recordConsumer.startGroup();
      countToConsume = map.size;
      if (countToConsume > 0) {
        recordConsumer.startField(mapContent.getType().getName(), 0);
        currentProtocol = keyProtocol;
      }
    }

    @Override
    public void writeMapEnd() throws TException {
      if (countToConsume > 0) {
        recordConsumer.endField(mapContent.getType().getName(), 0);
      }
      recordConsumer.endGroup();
      end();
    }
  }

  /**
   * Writes a single primitive value between {@code start()} and {@code end()}.
   * {@code byte} and {@code i16} values are written through {@code addInteger};
   * strings and binaries are written through {@code addBinary}.
   */
  class PrimitiveWriteProtocol extends FieldBaseWriteProtocol {

    public PrimitiveWriteProtocol(PrimitiveColumnIO columnIO, Events returnClause) {
      super(returnClause);
    }

    @Override
    public void writeBool(boolean b) throws TException {
      start();
      recordConsumer.addBoolean(b);
      end();
    }

    @Override
    public void writeByte(byte b) throws TException {
      start();
      recordConsumer.addInteger(b);
      end();
    }

    @Override
    public void writeI16(short i16) throws TException {
      start();
      recordConsumer.addInteger(i16);
      end();
    }

    @Override
    public void writeI32(int i32) throws TException {
      start();
      recordConsumer.addInteger(i32);
      end();
    }

    @Override
    public void writeI64(long i64) throws TException {
      start();
      recordConsumer.addLong(i64);
      end();
    }

    @Override
    public void writeDouble(double dub) throws TException {
      start();
      recordConsumer.addDouble(dub);
      end();
    }

    @Override
    public void writeString(String str) throws TException {
      start();
      writeStringToRecordConsumer(str);
      end();
    }

    @Override
    public void writeBinary(ByteBuffer buf) throws TException {
      start();
      writeBinaryToRecordConsumer(buf);
      end();
    }
  }

  /**
   * Writes a struct as a group and routes each field to the protocol built for it.
   *
   * <p>Thrift field ids are used directly as indices into a lookup array, so the ids
   * declared by the struct type are expected to be non-negative.
   */
  class StructWriteProtocol extends FieldBaseWriteProtocol {
    private final GroupColumnIO schema;
    private final StructType thriftType;
    private final TProtocol[] children;
    /** Column of the field most recently started by {@link #writeFieldBegin(TField)}. */
    private ColumnIO currentType;
    /** Parquet column for each Thrift field id; {@code null} where the id is unused. */
    private final ColumnIO[] thriftFieldIdToParquetField;

    /**
     * @throws NullPointerException if {@code schema} is null
     * @throws RuntimeException if a Thrift field has no column of the same name in {@code schema}
     * @throws ParquetEncodingException if the protocol for a field can not be created
     */
    public StructWriteProtocol(GroupColumnIO schema, StructType thriftType, Events returnClause) {
      super(returnClause);
      if (schema == null) {
        throw new NullPointerException("schema");
      }
      this.schema = schema;
      this.thriftType = thriftType;

      int maxFieldId = 0;
      for (ThriftField field : thriftType.getChildren()) {
        maxFieldId = Math.max(maxFieldId, field.getFieldId());
      }
      thriftFieldIdToParquetField = new ColumnIO[maxFieldId + 1];
      for (int i = 0; i < thriftType.getChildren().size(); i++) {
        thriftFieldIdToParquetField[thriftType.getChildren().get(i).getFieldId()] = schema.getChild(i);
      }

      children = new TProtocol[thriftType.getChildren().size()];
      for (int i = 0; i < children.length; i++) {
        final ThriftField field = thriftType.getChildren().get(i);
        final ColumnIO columnIO = schema.getChild(field.getName());
        if (columnIO == null) {
          throw new RuntimeException("Could not find " + field.getName() + " in " + schema);
        }
        try {
          children[i] = getProtocol(field, columnIO, new Events() {
            @Override
            public void start() {
            }

            @Override
            public void end() {
              // the field value is complete: the struct takes over again
              currentProtocol = StructWriteProtocol.this;
            }
          });
        } catch (RuntimeException e) {
          throw new ParquetEncodingException("Could not create Protocol for " + field + " to " + columnIO, e);
        }
      }
    }

    @Override
    public void writeStructBegin(TStruct struct) throws TException {
      start();
      recordConsumer.startGroup();
    }

    @Override
    public void writeStructEnd() throws TException {
      recordConsumer.endGroup();
      end();
    }

    /**
     * Starts the Parquet field matching {@code field.id} and makes that field's
     * protocol current. A field of type {@code TType.STOP} is ignored.
     *
     * @throws ParquetEncodingException if the id does not map to a column of this struct
     */
    @Override
    public void writeFieldBegin(TField field) throws TException {
      if (field.type == TType.STOP) {
        return;
      }
      // explicit range check: only an unknown id is reported as "not found", any
      // other failure below keeps its own exception
      final int id = field.id;
      final ColumnIO column =
          (id >= 0 && id < thriftFieldIdToParquetField.length) ? thriftFieldIdToParquetField[id] : null;
      if (column == null) {
        throw new ParquetEncodingException("field " + field.id + " was not found in " + thriftType + " and " + schema.getType());
      }
      currentType = column;
      final int index = currentType.getIndex();
      recordConsumer.startField(currentType.getName(), index);
      currentProtocol = children[index];
    }

    @Override
    public void writeFieldStop() throws TException {
      // duplicate with struct end
    }

    @Override
    public void writeFieldEnd() throws TException {
      recordConsumer.endField(currentType.getName(), currentType.getIndex());
    }
  }

  /**
   * Root struct protocol: brackets the record with startMessage/endMessage instead of
   * startGroup/endGroup. It is created with a {@code null} {@link Events}; both
   * overrides below avoid {@code start()} and {@code end()}, which would dereference it.
   */
  class MessageWriteProtocol extends StructWriteProtocol {

    public MessageWriteProtocol(MessageColumnIO schema, StructType thriftType) {
      super(schema, thriftType, null);
    }

    @Override
    public void writeStructBegin(TStruct struct) throws TException {
      recordConsumer.startMessage();
    }

    @Override
    public void writeStructEnd() throws TException {
      recordConsumer.endMessage();
    }
  }

  private static final Log LOG = Log.getLog(ParquetWriteProtocol.class);

  private final RecordConsumer recordConsumer;
  /** Protocol that receives the next write call; reassigned by the nested protocols. */
  private TProtocol currentProtocol;

  /** Describes a struct descriptor for the debug log. */
  private String toString(TStruct struct) {
    // the root protocol never reads the descriptor, so tolerate a null one here
    return struct == null ? "null" : "<TStruct name:" + struct.name + ">";
  }

  /** Describes a list descriptor for the debug log. */
  private String toString(TList list) {
    return "<TList elemType:" + list.elemType + " size:" + list.size + ">";
  }

  /** Describes a map descriptor for the debug log. */
  private String toString(TMap map) {
    return "<TMap keyType:" + map.keyType + " valueType:" + map.valueType + " size:" + map.size + ">";
  }

  /** Describes a set descriptor for the debug log. */
  private String toString(TSet set) {
    return "<TSet elemType:" + set.elemType + " size:" + set.size + ">";
  }

  /**
   * @param recordConsumer receives the events produced by the write calls
   * @param schema column tree of the Parquet message
   * @param thriftType description of the Thrift struct written as the message
   */
  public ParquetWriteProtocol(RecordConsumer recordConsumer, MessageColumnIO schema, StructType thriftType) {
    this.recordConsumer = recordConsumer;
    this.currentProtocol = new MessageWriteProtocol(schema, thriftType);
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeMessageBegin(org.apache.thrift.protocol.TMessage)
   */
  @Override
  public void writeMessageBegin(TMessage message) throws TException {
    if (DEBUG) LOG.debug("writeMessageBegin("+message+")");
    currentProtocol.writeMessageBegin(message);
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeMessageEnd()
   */
  @Override
  public void writeMessageEnd() throws TException {
    if (DEBUG) LOG.debug("writeMessageEnd()");
    currentProtocol.writeMessageEnd();
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeStructBegin(org.apache.thrift.protocol.TStruct)
   */
  @Override
  public void writeStructBegin(TStruct struct) throws TException {
    if (DEBUG) LOG.debug("writeStructBegin("+toString(struct)+")");
    currentProtocol.writeStructBegin(struct);
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeStructEnd()
   */
  @Override
  public void writeStructEnd() throws TException {
    if (DEBUG) LOG.debug("writeStructEnd()");
    currentProtocol.writeStructEnd();
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeFieldBegin(org.apache.thrift.protocol.TField)
   */
  @Override
  public void writeFieldBegin(TField field) throws TException {
    if (DEBUG) LOG.debug("writeFieldBegin("+field+")");
    currentProtocol.writeFieldBegin(field);
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeFieldEnd()
   */
  @Override
  public void writeFieldEnd() throws TException {
    if (DEBUG) LOG.debug("writeFieldEnd()");
    currentProtocol.writeFieldEnd();
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeFieldStop()
   */
  @Override
  public void writeFieldStop() throws TException {
    if (DEBUG) LOG.debug("writeFieldStop()");
    currentProtocol.writeFieldStop();
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeMapBegin(org.apache.thrift.protocol.TMap)
   */
  @Override
  public void writeMapBegin(TMap map) throws TException {
    if (DEBUG) LOG.debug("writeMapBegin("+toString(map)+")");
    currentProtocol.writeMapBegin(map);
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeMapEnd()
   */
  @Override
  public void writeMapEnd() throws TException {
    if (DEBUG) LOG.debug("writeMapEnd()");
    currentProtocol.writeMapEnd();
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeListBegin(org.apache.thrift.protocol.TList)
   */
  @Override
  public void writeListBegin(TList list) throws TException {
    if (DEBUG) LOG.debug("writeListBegin("+toString(list)+")");
    currentProtocol.writeListBegin(list);
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeListEnd()
   */
  @Override
  public void writeListEnd() throws TException {
    if (DEBUG) LOG.debug("writeListEnd()");
    currentProtocol.writeListEnd();
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeSetBegin(org.apache.thrift.protocol.TSet)
   */
  @Override
  public void writeSetBegin(TSet set) throws TException {
    if (DEBUG) LOG.debug("writeSetBegin("+toString(set)+")");
    currentProtocol.writeSetBegin(set);
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeSetEnd()
   */
  @Override
  public void writeSetEnd() throws TException {
    if (DEBUG) LOG.debug("writeSetEnd()");
    currentProtocol.writeSetEnd();
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeBool(boolean)
   */
  @Override
  public void writeBool(boolean b) throws TException {
    if (DEBUG) LOG.debug("writeBool("+b+")");
    currentProtocol.writeBool(b);
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeByte(byte)
   */
  @Override
  public void writeByte(byte b) throws TException {
    if (DEBUG) LOG.debug("writeByte("+b+")");
    currentProtocol.writeByte(b);
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeI16(short)
   */
  @Override
  public void writeI16(short i16) throws TException {
    if (DEBUG) LOG.debug("writeI16("+i16+")");
    currentProtocol.writeI16(i16);
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeI32(int)
   */
  @Override
  public void writeI32(int i32) throws TException {
    if (DEBUG) LOG.debug("writeI32("+i32+")");
    currentProtocol.writeI32(i32);
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeI64(long)
   */
  @Override
  public void writeI64(long i64) throws TException {
    if (DEBUG) LOG.debug("writeI64("+i64+")");
    currentProtocol.writeI64(i64);
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeDouble(double)
   */
  @Override
  public void writeDouble(double dub) throws TException {
    if (DEBUG) LOG.debug("writeDouble("+dub+")");
    currentProtocol.writeDouble(dub);
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeString(java.lang.String)
   */
  @Override
  public void writeString(String str) throws TException {
    if (DEBUG) LOG.debug("writeString("+str+")");
    currentProtocol.writeString(str);
  }

  /**
   * {@inheritDoc}
   * @see parquet.thrift.ParquetProtocol#writeBinary(java.nio.ByteBuffer)
   */
  @Override
  public void writeBinary(ByteBuffer buf) throws TException {
    if (DEBUG) LOG.debug("writeBinary("+buf+")");
    currentProtocol.writeBinary(buf);
  }

  /**
   * Adds the bytes between the buffer's position and its limit as one binary value.
   * The buffer's own position and limit are left untouched.
   */
  private void writeBinaryToRecordConsumer(ByteBuffer buf) {
    final int length = buf.remaining();
    if (buf.hasArray()) {
      // position() is relative to arrayOffset() in the backing array, which is
      // non-zero for sliced buffers
      recordConsumer.addBinary(Binary.fromByteArray(buf.array(), buf.arrayOffset() + buf.position(), length));
    } else {
      // direct and read-only buffers expose no backing array: copy the bytes out
      // through a duplicate so the caller's position does not move
      final byte[] copy = new byte[length];
      buf.duplicate().get(copy);
      recordConsumer.addBinary(Binary.fromByteArray(copy, 0, length));
    }
  }

  private void writeStringToRecordConsumer(String str) {
    recordConsumer.addBinary(Binary.fromString(str));
  }

  /**
   * Builds the protocol that writes values of {@code field} into {@code columnIO}.
   *
   * @param returnClause callbacks the new protocol fires around each value it writes
   * @throws UnsupportedOperationException if the field's type has no Parquet mapping
   */
  private TProtocol getProtocol(ThriftField field, ColumnIO columnIO, Events returnClause) {
    final ThriftType type = field.getType();
    switch (type.getType()) {
      case BOOL:
      case BYTE:
      case DOUBLE:
      case I16:
      case I32:
      case I64:
      case STRING:
        return new PrimitiveWriteProtocol((PrimitiveColumnIO)columnIO, returnClause);
      case STRUCT:
        return new StructWriteProtocol((GroupColumnIO)columnIO, (StructType)type, returnClause);
      case MAP:
        return new MapWriteProtocol((GroupColumnIO)columnIO, (MapType)type, returnClause);
      case SET:
        return new ListWriteProtocol((GroupColumnIO)columnIO, ((SetType)type).getValues(), returnClause);
      case LIST:
        return new ListWriteProtocol((GroupColumnIO)columnIO, ((ListType)type).getValues(), returnClause);
      case ENUM:
        return new EnumWriteProtocol((PrimitiveColumnIO)columnIO, (EnumType)type, returnClause);
      case STOP:
      case VOID:
      default:
        throw new UnsupportedOperationException("can't convert type of " + field);
    }
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy