All downloads are free. Search and download functionality uses the official Maven repository.

me.lyh.parquet.tensorflow.ExampleReadSupport Maven / Gradle / Ivy

The newest version!
package me.lyh.parquet.tensorflow;

import org.apache.hadoop.conf.Configuration;
import org.apache.parquet.Preconditions;
import org.apache.parquet.hadoop.api.InitContext;
import org.apache.parquet.hadoop.api.ReadSupport;
import org.apache.parquet.io.api.GroupConverter;
import org.apache.parquet.io.api.RecordMaterializer;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.MessageTypeParser;
import org.apache.parquet.schema.Type;
import org.apache.parquet.schema.Types;
import org.tensorflow.proto.example.Example;

import java.util.*;
import java.util.stream.Collectors;

/**
 * Parquet {@link ReadSupport} that materializes records as TensorFlow {@link Example} protos.
 *
 * <p>The requested schema is resolved in {@link #init(InitContext)} from the first source that
 * is available, in this order:
 * <ol>
 *   <li>the {@link Schema} given to the constructor;</li>
 *   <li>the field names given to the constructor;</li>
 *   <li>the schema string stored under {@link ExampleParquetInputFormat#SCHEMA_KEY};</li>
 *   <li>the comma-separated field names stored under
 *       {@link ExampleParquetInputFormat#FIELDS_KEY};</li>
 *   <li>the complete file schema.</li>
 * </ol>
 */
public class ExampleReadSupport extends ReadSupport<Example> {
  /** Explicit read schema, or {@code null} when none was given to the constructor. */
  private final Schema schema;

  /** Names of the fields to project, or {@code null} when none were given to the constructor. */
  private final Set<String> fields;

  /** Creates a read support whose schema is resolved from the configuration or the file. */
  public ExampleReadSupport() {
    this.schema = null;
    this.fields = null;
  }

  /**
   * Creates a read support that reads with an explicit schema.
   *
   * @param schema the schema to request; when {@code null}, resolution falls back to the
   *     configuration and then to the file schema
   */
  public ExampleReadSupport(Schema schema) {
    this.schema = schema;
    this.fields = null;
  }

  /**
   * Creates a read support that projects the file schema onto the given top-level fields.
   *
   * @param fields names of the fields to read; copied, so later changes to the collection are
   *     not observed
   * @throws NullPointerException if {@code fields} is {@code null}
   */
  public ExampleReadSupport(Collection<String> fields) {
    this.schema = null;
    this.fields = new HashSet<>(Objects.requireNonNull(fields, "fields"));
  }

  /**
   * Resolves the schema to request from the file.
   *
   * <p>This method does not modify the instance, so calling it again with a different
   * configuration re-evaluates the configuration instead of reusing an earlier result.
   *
   * @param context supplies the job configuration and the file schema
   * @return a read context carrying the requested schema and no extra metadata
   * @throws IllegalStateException presumably, via {@code Preconditions.checkState}, if a
   *     requested field name does not exist in the file schema
   */
  @Override
  public ReadContext init(InitContext context) {
    MessageType messageType;
    if (schema != null) {
      messageType = schema.toParquet();
    } else if (fields != null) {
      messageType = projectFileSchema(context, fields);
    } else {
      String schemaString = context.getConfiguration().get(ExampleParquetInputFormat.SCHEMA_KEY);
      String fieldsString = context.getConfiguration().get(ExampleParquetInputFormat.FIELDS_KEY);
      if (schemaString != null) {
        messageType = MessageTypeParser.parseMessageType(schemaString);
      } else if (fieldsString != null) {
        // Keep the parsed names local: storing them in an instance field would make every
        // later init() call ignore its own configuration.
        messageType = projectFileSchema(context, parseFields(fieldsString));
      } else {
        messageType = context.getFileSchema();
      }
    }

    return new ReadContext(messageType, Collections.emptyMap());
  }

  /**
   * Builds the materializer that assembles one {@link Example} per record, using a converter
   * derived from the requested schema of {@code readContext}.
   *
   * @param configuration the job configuration (not used here)
   * @param keyValueMetaData the file's key/value metadata (not used here)
   * @param fileSchema the schema of the file being read (not used here)
   * @param readContext the context returned by {@link #init(InitContext)}
   * @return a materializer backed by a single {@link ExampleConverter}
   */
  @Override
  public RecordMaterializer<Example> prepareForRead(Configuration configuration,
                                                    Map<String, String> keyValueMetaData,
                                                    MessageType fileSchema,
                                                    ReadContext readContext) {
    return new RecordMaterializer<Example>() {
      private final ExampleConverter exampleConverter =
          new ExampleConverter(Schema.fromParquet(readContext.getRequestedSchema()));

      @Override
      public Example getCurrentRecord() {
        return exampleConverter.get();
      }

      @Override
      public GroupConverter getRootConverter() {
        return exampleConverter;
      }
    };
  }

  /**
   * Splits a comma-separated list of field names, trimming whitespace around each name so
   * that values such as {@code "a, b"} resolve to {@code a} and {@code b}.
   */
  private static Set<String> parseFields(String fieldsString) {
    return Arrays.stream(fieldsString.split(","))
        .map(String::trim)
        .collect(Collectors.toSet());
  }

  /**
   * Returns a copy of the file schema restricted to the named top-level fields, keeping the
   * file's field order and message name.
   *
   * @throws IllegalStateException presumably, via {@code Preconditions.checkState}, if any
   *     requested name is missing from the file schema
   */
  private static MessageType projectFileSchema(InitContext context, Set<String> fields) {
    MessageType fileSchema = context.getFileSchema();
    // Sorted copy so that the error message lists unknown names in a stable order.
    Set<String> unmatched = new TreeSet<>(fields);

    Types.MessageTypeBuilder builder = Types.buildMessage();
    for (Type field : fileSchema.getFields()) {
      // remove() reports whether the name was requested, so one lookup is enough.
      if (unmatched.remove(field.getName())) {
        builder.addField(field);
      }
    }

    Preconditions.checkState(unmatched.isEmpty(), "Invalid fields: " + unmatched);
    return builder.named(fileSchema.getName());
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy