// All downloads are free. Search and download functionality uses the official Maven repository.
//
// me.lyh.parquet.tensorflow.ExampleScanner — Maven / Gradle / Ivy
//
// The newest version!
package me.lyh.parquet.tensorflow;

import org.apache.parquet.Preconditions;
import org.tensorflow.proto.example.Example;

import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

/**
 * Accumulates per-feature statistics over TensorFlow {@link Example} records and infers a
 * {@link Schema} from them.
 *
 * <p>For every feature name seen, the scanner records its value type, how many scanned examples
 * carried a non-empty value list for it, and the largest value-list length observed in a single
 * example. {@link #getSchema()} maps those statistics to field repetitions:
 *
 * <ul>
 *   <li>required: every scanned example had exactly one value;
 *   <li>optional: at most one value, but at least one example lacked the feature or had it empty;
 *   <li>repeated: some example had more than one value.
 * </ul>
 *
 * <p>State is kept in unsynchronized collections, so an instance must not be shared between
 * threads without external synchronization.
 */
public class ExampleScanner {
  /** Number of examples passed to {@link #scan(Example)} so far. */
  private long total = 0L;
  /** Passed to {@code Schema.Builder#named} when the schema is built. */
  private final String name;
  /** Feature names in first-seen order; iteration order here is the schema's field order. */
  private final Set<String> fields = new LinkedHashSet<>();
  /** Per feature: number of examples whose value list for it was non-empty. */
  private final Map<String, Long> nonZeroCounts = new HashMap<>();
  /** Per feature: its value type, which must agree across all examples. */
  private final Map<String, Schema.Type> types = new HashMap<>();
  /** Per feature: the largest value-list length seen in any single example. */
  private final Map<String, Integer> maxCounts = new HashMap<>();

  /**
   * Creates an empty scanner.
   *
   * @param name the name given to the schema returned by {@link #getSchema()}
   */
  public ExampleScanner(String name) {
    this.name = name;
  }

  /**
   * Records the features of one example.
   *
   * @param example the example to inspect
   * @return this scanner, for call chaining
   * @throws IllegalArgumentException if a feature's value type differs from the type recorded for
   *     the same feature by an earlier example
   */
  public ExampleScanner scan(Example example) {
    total++;

    // The lambda parameter is named "field" so it does not shadow the schema-name field "name".
    example.getFeatures().getFeatureMap().forEach((field, feature) -> {
      fields.add(field);

      Schema.Type newType = null;
      // -1 means "no count recorded"; each KindCase handled below overwrites it.
      int count = -1;
      switch (feature.getKindCase()) {
        case BYTES_LIST:
          newType = Schema.Type.BYTES;
          count = feature.getBytesList().getValueCount();
          break;
        case FLOAT_LIST:
          newType = Schema.Type.FLOAT;
          count = feature.getFloatList().getValueCount();
          break;
        case INT64_LIST:
          newType = Schema.Type.INT64;
          count = feature.getInt64List().getValueCount();
          break;
        case KIND_NOT_SET:
          // No value list at all: counts as an empty occurrence and carries no type information.
          count = 0;
          break;
      }

      Schema.Type type = types.get(field);
      if (type != null && newType != null) {
        Preconditions.checkArgument(
            type == newType,
            "Incompatible types for field %s: %s != %s",
            field, type, newType);
      }
      if (newType != null) {
        types.put(field, newType);
      }

      if (count != -1) {
        if (count >= 1) {
          nonZeroCounts.merge(field, 1L, Long::sum);
        }
        maxCounts.merge(field, count, Math::max);
      }
    });
    return this;
  }

  /**
   * Builds a schema from the statistics gathered so far.
   *
   * <p>Fields appear in the order they were first seen. A field whose value lists were empty in
   * every scanned example (maximum length 0) matches none of the repetition rules and is left out
   * of the schema.
   *
   * @return the inferred schema, named with the name passed to the constructor
   * @throws NullPointerException if a feature was seen but never with a value type (for example,
   *     only as {@code KIND_NOT_SET})
   */
  public Schema getSchema() {
    Schema.Builder builder = Schema.newBuilder();
    // Loop variable is "field" so the final builder.named(name) unambiguously uses the field.
    for (String field : fields) {
      Schema.Type type = types.get(field);
      Preconditions.checkNotNull(type, String.format("Field type for %s", field));

      long nonZeroCount = nonZeroCounts.getOrDefault(field, 0L);
      // Required only if every scanned example had a non-empty value list for this feature.
      int min = nonZeroCount < total ? 0 : 1;
      // Present whenever a type is present: both are recorded in the same scan() call.
      int max = maxCounts.get(field);

      if (min == 1 && max == 1) {
        builder = builder.required(field, type);
      } else if (min == 0 && max == 1) {
        builder = builder.optional(field, type);
      } else if (max > 1) {
        builder = builder.repeated(field, type);
      }
      // NOTE(review): max == 0 falls through and the field is silently omitted; confirm this is
      // intended rather than mapping always-empty features to optional or repeated.
    }
    return builder.named(name);
  }
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy