All downloads are free. Search and download functionality uses the official Maven repository.

com.spotify.parquet.tensorflow.TensorflowExampleSchemaConverter Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2023 Spotify AB
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.spotify.parquet.tensorflow;

import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.*;

import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType;
import org.apache.parquet.schema.Type;
import org.apache.parquet.schema.Types;
import org.tensorflow.metadata.v0.Feature;
import org.tensorflow.metadata.v0.FeatureType;
import org.tensorflow.metadata.v0.FixedShape;
import org.tensorflow.metadata.v0.Schema;
import org.tensorflow.metadata.v0.ValueCount;
import org.tensorflow.metadata.v0.ValueCountList;

public class TensorflowExampleSchemaConverter {

  public TensorflowExampleSchemaConverter(Configuration conf) {}

  public MessageType convert(Schema tfSchema) {
    return new MessageType("example", convertFeatures(tfSchema.getFeatureList()));
  }

  private List convertFeatures(List features) {
    List types = new ArrayList();
    for (Feature feature : features) {
      if (feature.getType().equals(FeatureType.TYPE_UNKNOWN)) {
        continue;
      }
      types.add(convertFeature(feature));
    }
    return types;
  }

  private Type.Repetition repetitionShape(FixedShape shape) {
    if (shape.getDimCount() == 1 && shape.getDim(0).getSize() == 1) {
      return Type.Repetition.REQUIRED;
    } else {
      return Type.Repetition.REPEATED;
    }
  }

  private Type.Repetition repetitionValueCount(ValueCount valueCount) {
    long min = valueCount.getMin();
    long max = valueCount.getMax();
    if (min == 0 && max == 1) {
      return Type.Repetition.OPTIONAL;
    } else if (min == 1 && max == 1) {
      return Type.Repetition.REQUIRED;
    } else {
      return Type.Repetition.REPEATED;
    }
  }

  private Type.Repetition repetitionValueCounts(ValueCountList valueCounts) {
    if (valueCounts.getValueCountCount() == 1) {
      return repetitionValueCount(valueCounts.getValueCount(0));
    } else {
      return Type.Repetition.REPEATED;
    }
  }

  private Type convertFeature(Feature feature) {
    String name = feature.getName();
    Types.PrimitiveBuilder builder;

    Type.Repetition repetition;
    if (feature.hasShape()) {
      repetition = repetitionShape(feature.getShape());
    } else if (feature.hasValueCount()) {
      repetition = repetitionValueCount(feature.getValueCount());
    } else {
      repetition = repetitionValueCounts(feature.getValueCounts());
    }

    FeatureType type = feature.getType();
    if (type.equals(FeatureType.INT)) {
      builder = Types.primitive(INT64, repetition);
    } else if (type.equals(FeatureType.FLOAT)) {
      builder = Types.primitive(FLOAT, repetition);
    } else if (type.equals(FeatureType.BYTES)) {
      builder = Types.primitive(BINARY, repetition);
    } else {
      throw new UnsupportedOperationException("Cannot convert tensorflow type " + type);
    }
    return builder.named(name);
  }

  public Schema convert(MessageType parquetSchema) {
    return Schema.newBuilder().addAllFeature(convertFields(parquetSchema.getFields())).build();
  }

  private List convertFields(List parquetFields) {
    List features = new ArrayList<>();
    for (Type parquetType : parquetFields) {
      Feature feature = convertField(parquetType);
      features.add(feature);
    }

    return features;
  }

  private ValueCount convertRepetition(Type.Repetition repetition) {
    switch (repetition) {
      case REQUIRED:
        return ValueCount.newBuilder().setMin(1).setMax(1).build();
      case OPTIONAL:
        return ValueCount.newBuilder().setMin(0).setMax(1).build();
      default:
        return null;
    }
  }

  private Feature convertField(final Type parquetType) {
    if (!parquetType.isPrimitive()) {
      throw new UnsupportedOperationException("Only primitive fields are supported");
    } else {
      final String featureName = parquetType.getName();
      final PrimitiveType asPrimitive = parquetType.asPrimitiveType();
      final PrimitiveType.PrimitiveTypeName parquetPrimitiveTypeName =
          asPrimitive.getPrimitiveTypeName();
      final Feature feature =
          parquetPrimitiveTypeName.convert(
              new PrimitiveType.PrimitiveTypeNameConverter() {
                @Override
                public Feature convertINT64(PrimitiveType.PrimitiveTypeName primitiveTypeName) {
                  return Feature.newBuilder().setName(featureName).setType(FeatureType.INT).build();
                }

                @Override
                public Feature convertINT96(PrimitiveType.PrimitiveTypeName primitiveTypeName)
                    throws RuntimeException {
                  throw new UnsupportedOperationException(
                      "Unsupported primitive type: " + primitiveTypeName);
                }

                @Override
                public Feature convertFIXED_LEN_BYTE_ARRAY(
                    PrimitiveType.PrimitiveTypeName primitiveTypeName) throws RuntimeException {
                  throw new UnsupportedOperationException(
                      "Unsupported primitive type: " + primitiveTypeName);
                }

                @Override
                public Feature convertBOOLEAN(PrimitiveType.PrimitiveTypeName primitiveTypeName)
                    throws UnsupportedOperationException {
                  throw new UnsupportedOperationException(
                      "Unsupported primitive type: " + primitiveTypeName);
                }

                @Override
                public Feature convertFLOAT(PrimitiveType.PrimitiveTypeName primitiveTypeName) {
                  return Feature.newBuilder()
                      .setName(featureName)
                      .setType(FeatureType.FLOAT)
                      .build();
                }

                @Override
                public Feature convertDOUBLE(PrimitiveType.PrimitiveTypeName primitiveTypeName)
                    throws RuntimeException {
                  throw new UnsupportedOperationException(
                      "Unsupported primitive type: " + primitiveTypeName);
                }

                @Override
                public Feature convertINT32(PrimitiveType.PrimitiveTypeName primitiveTypeName)
                    throws RuntimeException {
                  throw new UnsupportedOperationException(
                      "Unsupported primitive type: " + primitiveTypeName);
                }

                @Override
                public Feature convertBINARY(PrimitiveType.PrimitiveTypeName primitiveTypeName) {
                  return Feature.newBuilder()
                      .setName(featureName)
                      .setType(FeatureType.BYTES)
                      .build();
                }
              });
      final ValueCount valueCount = convertRepetition(asPrimitive.getRepetition());
      return valueCount == null ? feature : feature.toBuilder().setValueCount(valueCount).build();
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy