All downloads are free. Search and download functionality uses the official Maven repository.

me.lyh.parquet.tensorflow.Schema Maven / Gradle / Ivy

The newest version!
package me.lyh.parquet.tensorflow;

import com.google.protobuf.ByteString;
import me.lyh.parquet.tensorflow.ExampleConverter.FeatureConverter;
import org.apache.parquet.Preconditions;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.io.api.RecordConsumer;
import org.apache.parquet.schema.*;
import org.tensorflow.proto.example.Feature;
import org.tensorflow.proto.example.Features;
import shaded.parquet.com.fasterxml.jackson.core.JsonProcessingException;
import shaded.parquet.com.fasterxml.jackson.databind.ObjectMapper;

import java.io.IOException;
import java.util.*;

public class Schema {
  /**
   * Supported feature value types, each paired with the Parquet primitive type used to store it.
   *
   * <p>Every constant knows how to write the values of a {@link Feature} to a
   * {@link RecordConsumer} and how to create the matching {@link FeatureConverter}.
   */
  public enum Type {
    /** Values read from the feature's {@code Int64List}, stored as Parquet {@code INT64}. */
    INT64(PrimitiveType.PrimitiveTypeName.INT64) {
      @Override
      void write(String name, int index, Repetition repetition, RecordConsumer recordConsumer,
                 Feature feature) {
        List<Long> xs = feature.getInt64List().getValueList();
        repetition.checkSize(xs.size());
        // An empty list writes nothing at all: no startField/endField pair is emitted.
        if (!xs.isEmpty()) {
          recordConsumer.startField(name, index);
          xs.forEach(recordConsumer::addLong);
          recordConsumer.endField(name, index);
        }
      }

      @Override
      FeatureConverter newConverter(Repetition repetition) {
        return new ExampleConverter.Int64Converter(repetition);
      }
    },
    /** Values read from the feature's {@code FloatList}, stored as Parquet {@code FLOAT}. */
    FLOAT(PrimitiveType.PrimitiveTypeName.FLOAT) {
      @Override
      void write(String name, int index, Repetition repetition, RecordConsumer recordConsumer,
                 Feature feature) {
        List<Float> xs = feature.getFloatList().getValueList();
        repetition.checkSize(xs.size());
        // An empty list writes nothing at all: no startField/endField pair is emitted.
        if (!xs.isEmpty()) {
          recordConsumer.startField(name, index);
          xs.forEach(recordConsumer::addFloat);
          recordConsumer.endField(name, index);
        }
      }

      @Override
      FeatureConverter newConverter(Repetition repetition) {
        return new ExampleConverter.FloatConverter(repetition);
      }
    },
    /** Values read from the feature's {@code BytesList}, stored as Parquet {@code BINARY}. */
    BYTES(PrimitiveType.PrimitiveTypeName.BINARY) {
      @Override
      void write(String name, int index, Repetition repetition, RecordConsumer recordConsumer,
                 Feature feature) {
        List<ByteString> xs = feature.getBytesList().getValueList();
        repetition.checkSize(xs.size());
        // An empty list writes nothing at all: no startField/endField pair is emitted.
        if (!xs.isEmpty()) {
          recordConsumer.startField(name, index);
          xs.stream()
              .map(b -> Binary.fromConstantByteArray(b.toByteArray()))
              .forEach(recordConsumer::addBinary);
          recordConsumer.endField(name, index);
        }
      }

      @Override
      FeatureConverter newConverter(Repetition repetition) {
        return new ExampleConverter.BytesConverter(repetition);
      }
    };

    /** Parquet primitive type this feature type is stored as. */
    private final PrimitiveType.PrimitiveTypeName parquet;

    Type(PrimitiveType.PrimitiveTypeName parquet) {
      this.parquet = parquet;
    }

    /**
     * Maps a Parquet primitive type name to the corresponding feature type.
     *
     * @param parquet Parquet primitive type name
     * @return {@link #INT64}, {@link #FLOAT} or {@link #BYTES}
     * @throws IllegalArgumentException if {@code parquet} is not INT64, FLOAT or BINARY
     */
    public static Type fromParquet(PrimitiveType.PrimitiveTypeName parquet) {
      switch (parquet) {
        case INT64: return INT64;
        case FLOAT: return FLOAT;
        case BINARY: return BYTES;
      }
      throw new IllegalArgumentException("Unsupported primitive type: " + parquet);
    }

    /**
     * Writes the values of {@code feature} as field {@code name} at position {@code index}.
     *
     * <p>The number of values is first validated with {@link Repetition#checkSize(int)}; the
     * field is only started and ended when there is at least one value.
     *
     * @param name field name passed to {@code startField}/{@code endField}
     * @param index field index passed to {@code startField}/{@code endField}
     * @param repetition repetition used to validate the number of values
     * @param recordConsumer consumer receiving the values
     * @param feature feature whose value list is written
     */
    abstract void write(String name, int index, Repetition repetition,
                        RecordConsumer recordConsumer, Feature feature);

    /**
     * Creates the converter for this type.
     *
     * @param repetition repetition handed to the converter's constructor
     * @return a new converter instance
     */
    abstract FeatureConverter newConverter(Repetition repetition);
  }

  /**
   * Field repetition, paired with the equivalent Parquet repetition.
   *
   * <p>Each constant validates how many values a field may carry through
   * {@link #checkSize(int)}.
   */
  public enum Repetition {
    /** Exactly one value is expected. */
    REQUIRED(org.apache.parquet.schema.Type.Repetition.REQUIRED) {
      @Override
      void checkSize(int count) {
        Preconditions.checkState(count == 1, "Required field size != 1: " + count);
      }
    },
    /** Zero or one value is expected. */
    OPTIONAL(org.apache.parquet.schema.Type.Repetition.OPTIONAL) {
      @Override
      void checkSize(int count) {
        // The message previously said "Required", which misreported the failing repetition.
        Preconditions.checkState(count <= 1, "Optional field size > 1: " + count);
      }
    },
    /** Any number of values is accepted; the size check is a no-op. */
    REPEATED(org.apache.parquet.schema.Type.Repetition.REPEATED) {
      @Override
      void checkSize(int count) {}
    };

    /** Parquet repetition equivalent to this constant. */
    private final org.apache.parquet.schema.Type.Repetition parquet;

    Repetition(org.apache.parquet.schema.Type.Repetition parquet) {
      this.parquet = parquet;
    }

    /**
     * Maps a Parquet repetition to the constant of the same name.
     *
     * @param parquet Parquet repetition
     * @return the matching constant
     * @throws IllegalStateException if no case matches
     */
    public static Repetition fromParquet(org.apache.parquet.schema.Type.Repetition parquet) {
      switch (parquet) {
        case REQUIRED: return REQUIRED;
        case OPTIONAL: return OPTIONAL;
        case REPEATED: return REPEATED;
      }
      throw new IllegalStateException("This should never happen");
    }

    /**
     * Validates the number of values of a field against this repetition, failing the
     * {@code Preconditions.checkState} check when the count is not allowed.
     *
     * @param count number of values present
     */
    abstract void checkSize(int count);

  }

  /** A named schema field with a value {@link Type} and a {@link Repetition}. */
  public static class Field {
    private String name;
    private Type type;
    private Repetition repetition;

    // Not called in the visible code; presumably kept for Jackson deserialization through
    // Schema.fromJson — TODO confirm before removing.
    private Field() {}

    private Field(String name, Type type, Repetition repetition) {
      this.name = name;
      this.type = type;
      this.repetition = repetition;
    }

    /** @return the field name */
    public String getName() {
      return name;
    }

    /** @return the field's value type */
    public Type getType() {
      return type;
    }

    /** @return the field's repetition */
    public Repetition getRepetition() {
      return repetition;
    }

    /**
     * Converts this field to a Parquet primitive type with the same name and repetition.
     *
     * <p>INT64 fields additionally get the {@code intType(64, true)} logical type annotation;
     * all other types are emitted without an annotation.
     *
     * @return the Parquet primitive type for this field
     */
    public PrimitiveType toParquet() {
      // The builder must be parameterized so that named(...) yields a PrimitiveType.
      Types.PrimitiveBuilder<PrimitiveType> builder =
          Types.primitive(type.parquet, repetition.parquet);
      return type.parquet == PrimitiveType.PrimitiveTypeName.INT64
          ? builder.as(LogicalTypeAnnotation.intType(64, true)).named(name)
          : builder.named(name);
    }

    /**
     * Builds a field from a Parquet type, keeping its name, primitive type and repetition.
     *
     * @param parquet Parquet type; must be primitive
     * @return the equivalent field
     * @throws IllegalArgumentException if the primitive type has no matching {@link Type}
     *     (a non-primitive {@code parquet} fails the {@code Preconditions.checkArgument} check)
     */
    public static Field fromParquet(org.apache.parquet.schema.Type parquet) {
      Preconditions.checkArgument(parquet.isPrimitive(), "Only primitive fields are supported");
      return new Field(
          parquet.getName(),
          Type.fromParquet(parquet.asPrimitiveType().getPrimitiveTypeName()),
          Repetition.fromParquet(parquet.getRepetition()));
    }

    /**
     * Writes this field's feature from {@code features} to {@code recordConsumer}.
     *
     * <p>When {@code features} has no entry for this field's name, the default
     * {@link Feature} instance is written instead.
     *
     * @param index field index passed on to the type's writer
     * @param recordConsumer consumer receiving the values
     * @param features feature map to look this field up in
     */
    public void write(int index, RecordConsumer recordConsumer, Features features) {
      Feature feature = features.getFeatureOrDefault(name, Feature.getDefaultInstance());
      type.write(name, index, repetition, recordConsumer, feature);
    }

    /** @return a new converter for this field's type and repetition */
    public FeatureConverter newConverter() {
      return type.newConverter(repetition);
    }
  }

  ////////////////////////////////////////

  /**
   * Fluent builder for {@link Schema}. Fields are kept in the order they are added and
   * field names must be unique.
   */
  public static class Builder {
    /** Names added so far, used to reject duplicates. */
    private final Set<String> names;
    /** Fields in insertion order. */
    private final List<Field> fields;

    private Builder() {
      names = new HashSet<>();
      fields = new ArrayList<>();
    }

    /**
     * Adds a required field.
     *
     * @param name field name; must not have been added before
     * @param type field value type
     * @return this builder
     */
    public Builder required(String name, Type type) {
      return addField(name, type, Repetition.REQUIRED);
    }

    /**
     * Adds an optional field.
     *
     * @param name field name; must not have been added before
     * @param type field value type
     * @return this builder
     */
    public Builder optional(String name, Type type) {
      return addField(name, type, Repetition.OPTIONAL);
    }

    /**
     * Adds a repeated field.
     *
     * @param name field name; must not have been added before
     * @param type field value type
     * @return this builder
     */
    public Builder repeated(String name, Type type) {
      return addField(name, type, Repetition.REPEATED);
    }

    /**
     * Builds a schema from the fields added so far.
     *
     * @param name schema name
     * @return a new schema holding a snapshot of the current fields
     */
    public Schema named(String name) {
      // Copy the list so that further additions to this builder cannot alter a built schema.
      return new Schema(name, new ArrayList<>(fields));
    }

    /** Appends a field, rejecting names that were already added. */
    private Builder addField(String name, Type type, Repetition repetition) {
      // Set.add returns false for an existing element and leaves the set unchanged,
      // so a single lookup both detects duplicates and records new names.
      boolean added = names.add(name);
      Preconditions.checkArgument(added, "Duplicate field name %s", name);
      fields.add(new Field(name, type, repetition));
      return this;
    }
  }

  /**
   * Entry point for assembling a schema field by field.
   *
   * @return a fresh builder with no fields
   */
  public static Builder newBuilder() {
    final Builder empty = new Builder();
    return empty;
  }

  ////////////////////////////////////////

  /** Schema name, used as the Parquet message name. */
  private String name;
  /** Fields in declaration order. */
  private List<Field> fields;

  // Not called in the visible code; presumably kept for Jackson deserialization through
  // fromJson — TODO confirm before removing.
  private Schema() {}

  private Schema(String name, List<Field> fields) {
    this.name = name;
    this.fields = fields;
  }

  /** @return the schema name */
  public String getName() {
    return name;
  }

  /** @return the schema's fields; this is the internal list itself, not a copy */
  public List<Field> getFields() {
    return fields;
  }

  ////////////////////////////////////////

  /**
   * Converts this schema to a Parquet message type with the same name, adding each field's
   * Parquet type in order.
   *
   * @return the Parquet message type
   */
  public MessageType toParquet() {
    Types.MessageTypeBuilder builder = Types.buildMessage();
    for (Field field : fields) {
      builder.addField(field.toParquet());
    }
    return builder.named(name);
  }

  /**
   * Rebuilds a schema from a Parquet message type, converting its fields one by one in order
   * and reusing the message name as the schema name.
   *
   * @param schema Parquet message type to convert
   * @return the equivalent schema
   */
  public static Schema fromParquet(MessageType schema) {
    final Builder result = Schema.newBuilder();
    schema.getFields().stream()
        .map(Field::fromParquet)
        .forEach(converted ->
            result.addField(converted.name, converted.type, converted.repetition));
    return result.named(schema.getName());
  }

  ////////////////////////////////////////

  // Single mapper instance shared by toJson and fromJson.
  private static final ObjectMapper mapper = new ObjectMapper();

  /**
   * Serializes this schema to a JSON string.
   *
   * @return the JSON text produced by the shared mapper
   * @throws RuntimeException wrapping the {@link JsonProcessingException} if serialization fails
   */
  public String toJson() {
    final String json;
    try {
      json = mapper.writeValueAsString(this);
    } catch (JsonProcessingException failure) {
      throw new RuntimeException(failure);
    }
    return json;
  }

  /**
   * Parses a schema from its JSON text using the shared mapper.
   *
   * @param json JSON text to read
   * @return the deserialized schema
   * @throws IOException if the mapper cannot read {@code json} as a schema
   */
  public static Schema fromJson(String json) throws IOException {
    final Schema parsed = mapper.readValue(json, Schema.class);
    return parsed;
  }

  ////////////////////////////////////////

  /** Renders this schema as the string form of its Parquet message type. */
  @Override
  public String toString() {
    final MessageType asParquet = toParquet();
    return asParquet.toString();
  }

  /**
   * Two schemas are equal when they have the same runtime class and their Parquet message
   * types are equal.
   */
  @Override
  public boolean equals(Object o) {
    if (o == this) {
      return true;
    }
    if (o == null || o.getClass() != getClass()) {
      return false;
    }
    final MessageType mine = toParquet();
    final MessageType theirs = ((Schema) o).toParquet();
    return mine.equals(theirs);
  }

  /** Hash code of the Parquet message type, so it agrees with {@link #equals(Object)}. */
  @Override
  public int hashCode() {
    final MessageType asParquet = toParquet();
    return asParquet.hashCode();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy