All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.flink.table.sources.TableSourceUtil.scala Maven / Gradle / Ivy

There is a newer version: 1.5.1
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.sources

import org.apache.flink.table.api.{TableException, ValidationException}
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.expressions.{Cast, ResolvedFieldReference}
import org.apache.flink.table.types.{DataType, DataTypes, InternalType, RowType}

import com.google.common.collect.ImmutableList
import org.apache.calcite.plan.RelOptCluster
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.logical.LogicalValues
import org.apache.calcite.rex.{RexLiteral, RexNode}
import org.apache.calcite.tools.RelBuilder

import scala.collection.JavaConverters._

/** Util class for [[TableSource]]. */
object TableSourceUtil {

  /**
    * Checks whether the given [[TableSource]] declares at least one rowtime attribute.
    *
    * @param tableSource The [[TableSource]] to inspect.
    * @return True if at least one rowtime attribute name is found, false otherwise.
    */
  def hasRowtimeAttribute(tableSource: TableSource): Boolean = {
    val rowtimeNames = getRowtimeAttributes(tableSource)
    !rowtimeNames.isEmpty
  }

  /**
    * Checks whether the given [[TableSource]] declares a proctime attribute.
    *
    * @param tableSource The [[TableSource]] to inspect.
    * @return True if a proctime attribute name is found, false otherwise.
    */
  def hasProctimeAttribute(tableSource: TableSource): Boolean = {
    val proctimeName = getProctimeAttribute(tableSource)
    proctimeName.isDefined
  }

  /**
    * Validates a [[TableSource]].
    *
    * The checks run in the following order; the first violation raises a
    * [[ValidationException]]:
    *
    *  - a field registered as time attribute is of type TIMESTAMP,
    *  - every other schema field can be resolved to a field of the return type and both
    *    fields have the same type,
    *  - at most one table field is mapped if the return type is not a [[RowType]],
    *  - at most one rowtime attribute is declared; it exists in the schema, is of type
    *    TIMESTAMP, and the argument fields of its timestamp extractor can be resolved and
    *    are accepted by the extractor,
    *  - the proctime attribute (if any) exists in the schema and is of type TIMESTAMP,
    *  - no field is proctime and rowtime attribute at the same time.
    *
    * @param tableSource The [[TableSource]] which is validated.
    * @throws ValidationException if one of the checks fails.
    */
  def validateTableSource(tableSource: TableSource): Unit = {

    val schema = tableSource.getTableSchema
    val tableFieldNames = schema.getColumnNames
    val tableFieldTypes = schema.getTypes

    // look up the time attributes once and reuse them for all checks below
    val rowtimeAttributes = getRowtimeAttributes(tableSource)
    val proctimeAttribute = getProctimeAttribute(tableSource)

    // validate that schema fields can be resolved to a return type field of correct type
    // and count the fields that are backed by a field of the return type
    val mappedFieldCnt = tableFieldTypes.zip(tableFieldNames).count {
      case (DataTypes.TIMESTAMP, name: String) if proctimeAttribute.contains(name) =>
        // OK, field was mapped to proctime attribute; it is not backed by an input field
        false
      case (DataTypes.TIMESTAMP, name: String) if rowtimeAttributes.contains(name) =>
        // OK, field was mapped to rowtime attribute; it is not backed by an input field
        false
      case (t: InternalType, name) =>
        // a time indicator that reaches this case does not have type TIMESTAMP
        if (proctimeAttribute.contains(name)) {
          throw new ValidationException(s"Processing time field '$name' has invalid type $t. " +
            s"Processing time attributes must be of type ${DataTypes.TIMESTAMP}.")
        }
        if (rowtimeAttributes.contains(name)) {
          throw new ValidationException(s"Rowtime field '$name' has invalid type $t. " +
            s"Rowtime attributes must be of type ${DataTypes.TIMESTAMP}.")
        }
        // check that field can be resolved in input type
        val (physicalName, _, tpe) = resolveInputField(name, tableSource)
        // validate that mapped fields are of the same type
        if (tpe != t) {
          throw new ValidationException(s"Type $t of table field '$name' does not " +
            s"match with type $tpe of the field '$physicalName' of the TableSource return type.")
        }
        true
    }

    // ensure that only one field is mapped to an atomic type
    val returnType = tableSource.getReturnType
    if (!returnType.toInternalType.isInstanceOf[RowType] && mappedFieldCnt > 1) {
      throw new ValidationException(
        s"More than one table field matched to atomic input type $returnType.")
    }

    // validate rowtime attributes
    tableSource match {
      case r: DefinedRowtimeAttributes =>
        val descriptors = r.getRowtimeAttributeDescriptors
        if (descriptors.size() > 1) {
          throw new ValidationException(
            "Currently, only a single rowtime attribute is supported. " +
              "Please remove all but one RowtimeAttributeDescriptor.")
        } else if (descriptors.size() == 1) {
          val descriptor = descriptors.get(0)
          val rowtimeAttribute = descriptor.getAttributeName
          val rowtimeIdx = tableFieldNames.indexOf(rowtimeAttribute)
          // ensure that field exists
          if (rowtimeIdx < 0) {
            throw new ValidationException("Found a RowtimeAttributeDescriptor for field " +
              s"'$rowtimeAttribute' but field '$rowtimeAttribute' does not exist in table.")
          }
          // ensure that field is of type TIMESTAMP
          if (schema.getFieldType(rowtimeIdx).get() != DataTypes.TIMESTAMP) {
            throw new ValidationException("Found a RowtimeAttributeDescriptor for field " +
              s"'$rowtimeAttribute' but field '$rowtimeAttribute' is not of type TIMESTAMP.")
          }
          // look up extractor input fields in return type
          val extractorInputFields = descriptor.getTimestampExtractor.getArgumentFields
          val physicalTypes = resolveInputFields(extractorInputFields, tableSource).map(_._3)
          // let the timestamp extractor validate the types of its argument fields
          descriptor.getTimestampExtractor.validateArgumentFields(physicalTypes)
        }
      case _ => // nothing to validate
    }

    // validate proctime attribute (None if the source defines no proctime attribute)
    proctimeAttribute.foreach { proctimeName =>
      val proctimeIdx = tableFieldNames.indexOf(proctimeName)
      // ensure that field exists
      if (proctimeIdx < 0) {
        throw new ValidationException("Found a ProctimeAttribute for field " +
          s"'$proctimeName' but field '$proctimeName' does not exist in table.")
      }
      // ensure that field is of type TIMESTAMP
      if (schema.getFieldType(proctimeIdx).get() != DataTypes.TIMESTAMP) {
        throw new ValidationException("Found a ProctimeAttribute for field " +
          s"'$proctimeName' but field '$proctimeName' is not of type TIMESTAMP.")
      }
    }

    // ensure that proctime and rowtime attribute do not overlap
    proctimeAttribute.filter(rowtimeAttributes.contains(_)).foreach { name =>
      throw new ValidationException(s"Field '$name' must not be " +
        "processing time and rowtime attribute at the same time.")
    }
  }

  /**
    * Computes the indices that map the input type of the DataStream to the schema of the table.
    *
    * The mapping is based on the field names and fails if a table field cannot be
    * mapped to a field of the input type. Time attributes of type TIMESTAMP are not mapped
    * to an input field but to a marker constant of [[DataTypes]] (stream or batch variant,
    * depending on `isStreamTable`).
    *
    * @param tableSource The table source for which the table schema is mapped to the input type.
    * @param isStreamTable True if the mapping is computed for a streaming table, false otherwise.
    * @param selectedFields The indexes of the table schema fields for which a mapping is
    *                       computed. If None, a mapping for all fields is computed.
    * @return An index mapping from input type to table schema.
    * @throws ValidationException if a time attribute is not of type TIMESTAMP, if a field
    *                             cannot be resolved or has a different type than the resolved
    *                             input field, or if more than one field is mapped to an input
    *                             type that is not a [[RowType]].
    */
  def computeIndexMapping(
      tableSource: TableSource,
      isStreamTable: Boolean,
      selectedFields: Option[Array[Int]]): Array[Int] = {
    val inputType = tableSource.getReturnType.toInternalType
    val tableSchema = tableSource.getTableSchema

    // get names and types of the selected fields (all fields if no selection is given)
    val allFieldNames = tableSchema.getColumnNames
    val allFieldTypes = tableSchema.getTypes
    val (tableFieldNames, tableFieldTypes) = selectedFields match {
      case Some(indexes) =>
        (indexes.map(allFieldNames(_)), indexes.map(allFieldTypes(_)))
      case None =>
        (allFieldNames, allFieldTypes)
    }

    // look up the time attributes once and reuse them for every field
    val rowtimeAttributes = getRowtimeAttributes(tableSource)
    val proctimeAttribute = getProctimeAttribute(tableSource)

    // markers which are used instead of a physical index for time attributes
    val proctimeMarker =
      if (isStreamTable) DataTypes.PROCTIME_STREAM_MARKER else DataTypes.PROCTIME_BATCH_MARKER
    val rowtimeMarker =
      if (isStreamTable) DataTypes.ROWTIME_STREAM_MARKER else DataTypes.ROWTIME_BATCH_MARKER

    // compute mapping of selected fields and time attributes
    val mapping: Array[Int] = tableFieldTypes.zip(tableFieldNames).map {
      case (DataTypes.TIMESTAMP, name: String) if proctimeAttribute.contains(name) =>
        proctimeMarker
      case (DataTypes.TIMESTAMP, name: String) if rowtimeAttributes.contains(name) =>
        rowtimeMarker
      case (t: InternalType, name) =>
        // a time indicator that reaches this case does not have type TIMESTAMP
        if (proctimeAttribute.contains(name)) {
          throw new ValidationException(s"Processing time field '$name' has invalid type $t. " +
            s"Processing time attributes must be of type ${DataTypes.TIMESTAMP}.")
        }
        if (rowtimeAttributes.contains(name)) {
          throw new ValidationException(s"Rowtime field '$name' has invalid type $t. " +
            s"Rowtime attributes must be of type ${DataTypes.TIMESTAMP}.")
        }

        val (physicalName, idx, tpe) = resolveInputField(name, tableSource)
        // validate that mapped fields are of the same type
        if (tpe != t) {
          throw new ValidationException(s"Type $t of table field '$name' does not " +
            s"match with type $tpe of the field '$physicalName' of the TableSource return type.")
        }
        idx
    }

    // ensure that only one field is mapped to an atomic type;
    // only non-negative entries are counted as mapped input fields
    if (!inputType.isInstanceOf[RowType] && mapping.count(_ >= 0) > 1) {
      throw new ValidationException(
        s"More than one table field matched to atomic input type $inputType.")
    }

    mapping
  }

  /**
    * Returns the Calcite schema of a [[TableSource]].
    *
    * For streaming tables, the types of the rowtime and proctime attributes are replaced by
    * the corresponding time indicator types and the fields are marked as not nullable.
    *
    * @param tableSource The [[TableSource]] for which the Calcite schema is generated.
    * @param selectedFields The indices of all selected fields. None, if all fields are selected.
    * @param streaming Flag to determine whether the schema of a stream or batch table is created.
    * @param typeFactory The type factory to create the schema.
    * @return The Calcite schema for the selected fields of the given [[TableSource]].
    * @throws ValidationException if a time attribute of a streaming table does not exist in
    *                             the table schema.
    */
  def getRelDataType(
      tableSource: TableSource,
      selectedFields: Option[Array[Int]],
      streaming: Boolean,
      typeFactory: FlinkTypeFactory): RelDataType = {

    val schema = tableSource.getTableSchema
    val fieldNames = schema.getFieldNames
    var fieldTypes = schema.getFieldTypes
    var fieldNullables = schema.getFieldNullables

    // Looks up the schema position of a time attribute.
    // A missing field must fail here: patching at index -1 would silently replace the
    // type and nullability of the first field instead.
    def timeAttributeIndex(kind: String, attribute: String): Int = {
      val idx = fieldNames.indexOf(attribute)
      if (idx < 0) {
        throw new ValidationException(
          s"Found a $kind attribute for field '$attribute' but field '$attribute' " +
            "does not exist in table.")
      }
      idx
    }

    if (streaming) {
      // adjust the type of time attributes for streaming tables

      // patch rowtime fields with time indicator type
      getRowtimeAttributes(tableSource).foreach { rowtimeField =>
        val idx = timeAttributeIndex("rowtime", rowtimeField)
        fieldTypes = fieldTypes.patch(idx, Seq(DataTypes.ROWTIME_INDICATOR), 1)
        fieldNullables = fieldNullables.patch(idx, Seq(false), 1)
      }
      // patch proctime field with time indicator type
      getProctimeAttribute(tableSource).foreach { proctimeField =>
        val idx = timeAttributeIndex("proctime", proctimeField)
        fieldTypes = fieldTypes.patch(idx, Seq(DataTypes.PROCTIME_INDICATOR), 1)
        fieldNullables = fieldNullables.patch(idx, Seq(false), 1)
      }
    }

    // filter field names, types, and nullability by the selected fields (if any)
    val (selectedFieldNames, selectedFieldTypes, selectedFieldNullables) =
      selectedFields match {
        case Some(indexes) =>
          (
            indexes.map(fieldNames(_)),
            indexes.map(fieldTypes(_)),
            indexes.map(fieldNullables(_)))
        case None =>
          (fieldNames, fieldTypes, fieldNullables)
      }
    typeFactory.buildRelDataType(selectedFieldNames, selectedFieldTypes, selectedFieldNullables)
  }

  /**
    * Looks up the [[RowtimeAttributeDescriptor]] of a [[TableSource]].
    *
    * A descriptor is only returned if the source is a [[DefinedRowtimeAttributes]] with exactly
    * one descriptor and the described field is part of the selection.
    *
    * @param tableSource The [[TableSource]] for which the [[RowtimeAttributeDescriptor]] is
    *                    returned.
    * @param selectedFields The fields which are selected from the [[TableSource]].
    *                       If None, all fields are selected.
    * @return The [[RowtimeAttributeDescriptor]] of the [[TableSource]], or None if the source
    *         has no descriptor or the rowtime field is not among the selected fields.
    */
  def getRowtimeAttributeDescriptor(
      tableSource: TableSource,
      selectedFields: Option[Array[Int]]): Option[RowtimeAttributeDescriptor] = {

    tableSource match {
      case withRowtime: DefinedRowtimeAttributes =>
        val candidates = withRowtime.getRowtimeAttributeDescriptors
        candidates.size() match {
          case 0 =>
            // no rowtime attribute declared
            None
          case n if n > 1 =>
            throw new ValidationException("Table with has more than a single rowtime attribute..")
          case _ =>
            // exactly one rowtime attribute descriptor
            val single = candidates.get(0)
            selectedFields match {
              case None =>
                // without a selection all fields are selected
                Some(single)
              case Some(selection) =>
                // position of the rowtime attribute in the table schema
                val attributeIdx = tableSource.getTableSchema.getColumnNames.indexOf(
                  single.getAttributeName)
                // only return the descriptor if its field is among the selected fields
                if (selection.contains(attributeIdx)) Some(single) else None
            }
        }
      case _ =>
        None
    }
  }

  /**
    * Builds the [[RexNode]] expression which extracts the rowtime timestamp of a
    * [[TableSource]].
    *
    * The expression of the timestamp extractor is cast to `resultType` and converted with the
    * given [[RelBuilder]]. The builder is cleared after the conversion.
    *
    * @param tableSource The [[TableSource]] for which the expression is extracted.
    * @param selectedFields The selected fields of the [[TableSource]].
    *                       If None, all fields are selected.
    * @param cluster The [[RelOptCluster]] of the current optimization process.
    * @param relBuilder The [[RelBuilder]] to build the [[RexNode]].
    * @param resultType The result type of the timestamp expression.
    * @return The [[RexNode]] expression to extract the timestamp of the table source, or None
    *         if no rowtime attribute descriptor is found for the selected fields.
    */
  def getRowtimeExtractionExpression(
      tableSource: TableSource,
      selectedFields: Option[Array[Int]],
      cluster: RelOptCluster,
      relBuilder: RelBuilder,
      resultType: InternalType): Option[RexNode] = {

    val flinkTypeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]

    /**
      * Builds an empty values node whose schema covers the positions 0 to the largest index
      * of the given fields. Positions without a given field get an empty name and the
      * type BYTE.
      */
    def emptyValuesWithPhysicalSchema(resolved: Array[(String, Int, InternalType)]): RelNode = {
      val highestIdx = resolved.map(_._2).max
      val byPosition: Map[Int, (String, InternalType)] =
        resolved.map { case (name, idx, tpe) => idx -> ((name, tpe)) }.toMap
      val (names, types) = (0 to highestIdx)
        .map(pos => byPosition.getOrElse(pos, ("", DataTypes.BYTE)))
        .unzip
      val schema: RelDataType = flinkTypeFactory.buildRelDataType(names, types)
      LogicalValues.create(
        cluster,
        schema,
        ImmutableList.of().asInstanceOf[ImmutableList[ImmutableList[RexLiteral]]])
    }

    getRowtimeAttributeDescriptor(tableSource, selectedFields).map { descriptor =>
      val extractor = descriptor.getTimestampExtractor

      val fieldAccesses: Array[ResolvedFieldReference] =
        if (extractor.getArgumentFields.isEmpty) {
          // the extractor does not access any field
          Array.empty[ResolvedFieldReference]
        } else {
          val physicalFields = resolveInputFields(extractor.getArgumentFields, tableSource)
          // the field references are converted against the node on top of the builder,
          // hence an empty values node with the physical schema is pushed first
          relBuilder.push(emptyValuesWithPhysicalSchema(physicalFields))
          physicalFields.map { case (name, _, tpe) => ResolvedFieldReference(name, tpe) }
        }

      val timestampExpr = extractor.getExpression(fieldAccesses)
      // cast to the requested type and convert the expression to a RexNode
      val rexNode = Cast(timestampExpr, resultType).toRexNode(relBuilder)
      relBuilder.clear()
      rexNode
    }
  }

  /**
    * Determines the indexes of the physical fields which are required to compute the given
    * logical fields.
    *
    * Proctime fields do not contribute a physical field. A rowtime field contributes the
    * physical fields which are accessed by its timestamp extractor.
    *
    * @param tableSource The [[TableSource]] for which the physical indexes are computed.
    * @param logicalFieldIndexes The indexes of the accessed logical fields for which the physical
    *                            indexes are computed.
    * @return The indexes of the physical fields which are accessed to forward and compute the
    *         logical fields.
    */
  def getPhysicalIndexes(
      tableSource: TableSource,
      logicalFieldIndexes: Array[Int]): Array[Int] = {

    // mapping from logical to physical positions; the stream variant is requested, so time
    // attributes show up as the stream markers matched below
    val logicalToPhysical = computeIndexMapping(tableSource, isStreamTable = true, None)

    // every logical index becomes either a physical index or a time indicator marker
    val physicalOrMarker = logicalFieldIndexes.map(logicalIdx => logicalToPhysical(logicalIdx))

    physicalOrMarker.flatMap {
      case DataTypes.PROCTIME_STREAM_MARKER =>
        // a proctime field does not access a physical field
        Seq()
      case DataTypes.ROWTIME_STREAM_MARKER =>
        // the rowtime field is computed from the argument fields of the timestamp extractor
        val extractorArguments = getRowtimeAttributeDescriptor(tableSource, None) match {
          case Some(descriptor) =>
            descriptor.getTimestampExtractor.getArgumentFields
          case None =>
            throw new TableException("Computed field mapping includes a rowtime marker but the " +
              "TableSource does not provide a RowtimeAttributeDescriptor. " +
              "This is a bug and should be reported.")
        }
        // resolve the argument field names to physical positions
        resolveInputFields(extractorArguments, tableSource).map(_._2)
      case physicalIdx =>
        Seq(physicalIdx)
    }
  }

  /**
    * Collects the names of all rowtime attributes of the [[TableSource]].
    *
    * @param tableSource The [[TableSource]] to inspect.
    * @return The attribute names of all rowtime attribute descriptors, or an empty array if
    *         the source is not a [[DefinedRowtimeAttributes]].
    */
  private def getRowtimeAttributes(tableSource: TableSource): Array[String] = {
    tableSource match {
      case withRowtime: DefinedRowtimeAttributes =>
        val descriptors = withRowtime.getRowtimeAttributeDescriptors.asScala
        descriptors.map(descriptor => descriptor.getAttributeName).toArray
      case _ =>
        Array.empty[String]
    }
  }

  /**
    * Returns the proctime attribute of the [[TableSource]] if it is defined.
    *
    * @param tableSource The [[TableSource]] to inspect.
    * @return The name of the proctime attribute, or None if the source is not a
    *         [[DefinedProctimeAttribute]] or its proctime attribute is null.
    */
  private def getProctimeAttribute(tableSource: TableSource): Option[String] = {
    tableSource match {
      case p: DefinedProctimeAttribute =>
        // Option(...) maps null to None and reads the attribute only once, so the checked
        // value and the returned value cannot differ
        Option(p.getProctimeAttribute)
      case _ =>
        None
    }
  }

  /**
    * Resolves a field name of the logical schema to the corresponding physical field in the
    * return type of a [[TableSource]].
    *
    * If the source is a [[DefinedFieldMapping]] with a non-null mapping, the logical name is
    * first translated by the mapping. The resulting name is then looked up in the return type.
    *
    * @param fieldName The logical field to look up.
    * @param tableSource The table source in which to look for the field.
    * @return The name, index, and type information of the physical field.
    */
  private def resolveInputField(
      fieldName: String,
      tableSource: TableSource): (String, Int, InternalType) = {

    val returnType = tableSource.getReturnType

    // determine the name of the physical field and the message for a failed lookup
    val (physicalName, notFoundMsg) = tableSource match {
      case mapped: DefinedFieldMapping if mapped.getFieldMapping != null =>
        // translate the logical name with the field mapping
        val target = mapped.getFieldMapping.get(fieldName)
        if (target == null) {
          throw new ValidationException(
            s"Field '$fieldName' could not be resolved by the field mapping.")
        }
        (
          target,
          s"Table field '$fieldName' was resolved to TableSource return type field " +
            s"'$target', but field '$target' was not found in the return " +
            s"type $returnType of the TableSource. " +
            s"Please verify the field mapping of the TableSource.")
      case _ =>
        // without a mapping the logical name is the physical name
        (
          fieldName,
          s"Table field '$fieldName' was not found in the return type $returnType of the " +
            s"TableSource.")
    }

    // look up the physical field in the return type
    returnType.toInternalType match {
      case row: RowType =>
        // the field must exist in the row type
        val position = row.getFieldIndex(physicalName)
        if (position < 0) {
          throw new ValidationException(notFoundMsg)
        }
        (physicalName, position, row.getInternalTypeAt(position).toInternalType)

      case atomic: InternalType =>
        // not a row type: the whole type is returned as the field at position 0
        (physicalName, 0, atomic)
    }
  }

  /**
    * Resolves a list of field names of the [[TableSource]]'s
    * [[org.apache.flink.table.api.TableSchema]] to the physical fields in the return type
    * [[DataType]] of the [[TableSource]].
    *
    * Each name is resolved individually, in the given order.
    *
    * @param fieldNames The field names to look up.
    * @param tableSource The table source in which to look for the fields.
    * @return The name, index, and type information of each physical field.
    */
  private def resolveInputFields(
      fieldNames: Array[String],
      tableSource: TableSource): Array[(String, Int, InternalType)] = {
    for (logicalName <- fieldNames) yield resolveInputField(logicalName, tableSource)
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy