Github user xccui commented on a diff in the pull request: https://github.com/apache/flink/pull/4894#discussion_r147574828 --- Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/TableSourceUtil.scala --- @@ -0,0 +1,518 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.sources + +import java.sql.Timestamp + +import com.google.common.collect.ImmutableList +import org.apache.calcite.plan.RelOptCluster +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.logical.LogicalValues +import org.apache.calcite.rex.{RexLiteral, RexNode} +import org.apache.calcite.tools.RelBuilder +import org.apache.flink.api.common.typeinfo.{AtomicType, SqlTimeTypeInfo, TypeInformation} +import org.apache.flink.api.common.typeutils.CompositeType +import org.apache.flink.table.api.{TableException, Types, ValidationException} +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.expressions.{Cast, ResolvedFieldReference} +import org.apache.flink.table.typeutils.TimeIndicatorTypeInfo + +import scala.collection.JavaConverters._ + +/** Util class for [[TableSource]]. */ +object TableSourceUtil { + + /** Returns true if the [[TableSource]] has a rowtime attribute. */ + def hasRowtimeAttribute(tableSource: TableSource[_]): Boolean = + getRowtimeAttributes(tableSource).nonEmpty + + /** Returns true if the [[TableSource]] has a proctime attribute. */ + def hasProctimeAttribute(tableSource: TableSource[_]): Boolean = + getProctimeAttributes(tableSource).nonEmpty + + /** + * Validates a TableSource. + * + * - checks that all fields of the schema can be resolved + * - checks that resolved fields have the correct type + * - checks that the time attributes are correctly configured. + * + * @param tableSource The [[TableSource]] for which the time attributes are checked. + */ + def validateTableSource(tableSource: TableSource[_]): Unit = { + + val schema = tableSource.getTableSchema + val tableFieldNames = schema.getColumnNames + val tableFieldTypes = schema.getTypes + + // get rowtime and proctime attributes + val rowtimeAttributes = getRowtimeAttributes(tableSource) + val proctimeAttributes = getProctimeAttributes(tableSource) + + // validate that schema fields can be resolved to a return type field of correct type + var mappedFieldCnt = 0 + tableFieldTypes.zip(tableFieldNames).foreach { + case (t: SqlTimeTypeInfo[_], name: String) + if t.getTypeClass == classOf[Timestamp] && proctimeAttributes.contains(name) => + // OK, field was mapped to proctime attribute + case (t: SqlTimeTypeInfo[_], name: String) + if t.getTypeClass == classOf[Timestamp] && rowtimeAttributes.contains(name) => + // OK, field was mapped to rowtime attribute + case (t: TypeInformation[_], name) => + // check if field is registered as time indicator + if (getProctimeAttributes(tableSource).contains(name)) { + throw new ValidationException(s"Processing time field '$name' has invalid type $t. " + + s"Processing time attributes must be of type ${Types.SQL_TIMESTAMP}.") + } + if (getRowtimeAttributes(tableSource).contains(name)) { + throw new ValidationException(s"Rowtime field '$name' has invalid type $t. " + + s"Rowtime attributes must be of type ${Types.SQL_TIMESTAMP}.") + } + // check that field can be resolved in input type + val (physicalName, _, tpe) = resolveInputField(name, tableSource) + // validate that mapped fields are are same type + if (tpe != t) { + throw ValidationException(s"Type $t of table field '$name' does not " + + s"match with type $tpe of the field '$physicalName' of the TableSource return type.") + } + mappedFieldCnt += 1 + } + // ensure that only one field is mapped to an atomic type + if (tableSource.getReturnType.isInstanceOf[AtomicType[_]] && mappedFieldCnt > 1) { + throw ValidationException( + s"More than one table field matched to atomic input type ${tableSource.getReturnType}.") + } + + // validate rowtime attributes + tableSource match { + case r: DefinedRowtimeAttributes => + val descriptors = r.getRowtimeAttributeDescriptors + if (descriptors.size() > 1) { + throw ValidationException("Currently, only a single rowtime attribute is supported. " + + s"Please remove all but one RowtimeAttributeDescriptor.") + } else if (descriptors.size() == 1) { + val descriptor = descriptors.get(0) + val rowtimeAttribute = descriptor.getAttributeName + val rowtimeIdx = schema.getColumnNames.indexOf(rowtimeAttribute) + // ensure that field exists + if (rowtimeIdx < 0) { + throw ValidationException(s"Found a RowtimeAttributeDescriptor for field " + + s"'$rowtimeAttribute' but field '$rowtimeAttribute' does not exist in table.") + } + // ensure that field is of type TIMESTAMP + if (schema.getTypes(rowtimeIdx) != Types.SQL_TIMESTAMP) { + throw ValidationException(s"Found a RowtimeAttributeDescriptor for field " + + s"'$rowtimeAttribute' but field '$rowtimeAttribute' is not of type TIMESTAMP.") + } + // look up extractor input fields in return type + val extractorInputFields = descriptor.getTimestampExtractor.getArgumentFields + val physicalTypes = resolveInputFields(extractorInputFields, tableSource).map(_._3) + // validate timestamp extractor + descriptor.getTimestampExtractor.validateArgumentFields(physicalTypes) + } + case _ => // nothing to validate + } + + // validate proctime attribute + tableSource match { + case p: DefinedProctimeAttribute if p.getProctimeAttribute != null => + val proctimeAttribute = p.getProctimeAttribute + val proctimeIdx = schema.getColumnNames.indexOf(proctimeAttribute) + // ensure that field exists + if (proctimeIdx < 0) { + throw ValidationException(s"Found a RowtimeAttributeDescriptor for field " + + s"'$proctimeAttribute' but field '$proctimeAttribute' does not exist in table.") + } + // ensure that field is of type TIMESTAMP + if (schema.getTypes(proctimeIdx) != Types.SQL_TIMESTAMP) { + throw ValidationException(s"Found a RowtimeAttributeDescriptor for field " + + s"'$proctimeAttribute' but field '$proctimeAttribute' is not of type TIMESTAMP.") + } + case _ => // nothing to validate + } + + // ensure that proctime and rowtime attribute do not overlap + val overlap = getProctimeAttributes(tableSource).intersect(getRowtimeAttributes(tableSource)) + if (overlap.nonEmpty) { + throw new ValidationException(s"Fields ${overlap.mkString("[", ", ", "]")} must not be " + + s"processing time and rowtime attribute at the same time.") + } + } + + /** + * Computes the indices that map the input type of the DataStream to the schema of the table. + * + * The mapping is based on the field names and fails if a table field cannot be + * mapped to a field of the input type. + * + * @param tableSource The table source for which the table schema is mapped to the input type. + * @param isStreamTable True if the mapping is computed for a streaming table, false otherwise. + * @param selectedFields The indexes of the table schema fields for which a mapping is + * computed. If None, a mapping for all fields is computed. + * @return An index mapping from input type to table schema. + */ + def computeIndexMapping( + tableSource: TableSource[_], + isStreamTable: Boolean, + selectedFields: Option[Array[Int]]): Array[Int] = { + val inputType = tableSource.getReturnType + val tableSchema = tableSource.getTableSchema + + // get names of selected fields + val tableFieldNames = if (selectedFields.isDefined) { + val names = tableSchema.getColumnNames + selectedFields.get.map(names(_)) + } else { + tableSchema.getColumnNames + } + + // get types of selected fields + val tableFieldTypes = if (selectedFields.isDefined) { + val types = tableSchema.getTypes + selectedFields.get.map(types(_)) + } else { + tableSchema.getTypes + } + + // get rowtime and proctime attributes + val rowtimeAttributes = getRowtimeAttributes(tableSource) + val proctimeAttributes = getProctimeAttributes(tableSource) + + // compute mapping of selected fields and time attributes + val mapping: Array[Int] = tableFieldTypes.zip(tableFieldNames).map { + case (t: SqlTimeTypeInfo[_], name: String) + if t.getTypeClass == classOf[Timestamp] && proctimeAttributes.contains(name) => + if (isStreamTable) { + TimeIndicatorTypeInfo.PROCTIME_STREAM_MARKER + } else { + TimeIndicatorTypeInfo.PROCTIME_BATCH_MARKER + } + case (t: SqlTimeTypeInfo[_], name: String) + if t.getTypeClass == classOf[Timestamp] && rowtimeAttributes.contains(name) => + if (isStreamTable) { + TimeIndicatorTypeInfo.ROWTIME_STREAM_MARKER + } else { + TimeIndicatorTypeInfo.ROWTIME_BATCH_MARKER + } + case (t: TypeInformation[_], name) => + // check if field is registered as time indicator + if (getProctimeAttributes(tableSource).contains(name)) { + throw new ValidationException(s"Processing time field '$name' has invalid type $t. " + + s"Processing time attributes must be of type ${Types.SQL_TIMESTAMP}.") + } + if (getRowtimeAttributes(tableSource).contains(name)) { + throw new ValidationException(s"Rowtime field '$name' has invalid type $t. " + + s"Rowtime attributes must be of type ${Types.SQL_TIMESTAMP}.") + } + + val (physicalName, idx, tpe) = resolveInputField(name, tableSource) + // validate that mapped fields are are same type + if (tpe != t) { + throw ValidationException(s"Type $t of table field '$name' does not " + + s"match with type $tpe of the field '$physicalName' of the TableSource return type.") + } + idx + } + + // ensure that only one field is mapped to an atomic type + if (inputType.isInstanceOf[AtomicType[_]] && mapping.count(_ >= 0) > 1) { + throw ValidationException( + s"More than one table field matched to atomic input type $inputType.") + } + + mapping + } + + /** + * Returns the Calcite schema of a [[TableSource]]. + * + * @param tableSource The [[TableSource]] for which the Calcite schema is generated. + * @param selectedFields The indicies of all selected fields. None, if all fields are selected. + * @param streaming Flag to determine whether the schema of a stream or batch table is created. + * @param typeFactory The type factory to create the schema. + * @return The Calcite schema for the selected fields of the given [[TableSource]]. + */ + def getRelDataType( + tableSource: TableSource[_], + selectedFields: Option[Array[Int]], + streaming: Boolean, + typeFactory: FlinkTypeFactory): RelDataType = { + + val fieldNames = tableSource.getTableSchema.getColumnNames + var fieldTypes = tableSource.getTableSchema.getTypes + + if (streaming) { + // adjust the type of time attributes for streaming tables + val rowtimeAttributes = getRowtimeAttributes(tableSource) + val proctimeAttributes = getProctimeAttributes(tableSource) + + // patch rowtime fields with time indicator type + rowtimeAttributes.foreach { rowtimeField => + val idx = fieldNames.indexOf(rowtimeField) + fieldTypes = fieldTypes.patch(idx, Seq(TimeIndicatorTypeInfo.ROWTIME_INDICATOR), 1) + } + // patch proctime field with time indicator type + proctimeAttributes.foreach { proctimeField => + val idx = fieldNames.indexOf(proctimeField) + fieldTypes = fieldTypes.patch(idx, Seq(TimeIndicatorTypeInfo.PROCTIME_INDICATOR), 1) + } + } + + val (selectedFieldNames, selectedFieldTypes) = if (selectedFields.isDefined) { + // filter field names and types by selected fields + (selectedFields.get.map(fieldNames(_)), selectedFields.get.map(fieldTypes(_))) + } else { + (fieldNames, fieldTypes) + } + typeFactory.buildLogicalRowType(selectedFieldNames, selectedFieldTypes) + } + + /** + * Returns the [[RowtimeAttributeDescriptor]] of a [[TableSource]]. + * + * @param tableSource The [[TableSource]] for which the [[RowtimeAttributeDescriptor]] is + * returned. + * @param selectedFields The fields which are selected from the [[TableSource]]. + * If None, all fields are selected. + * @return The [[RowtimeAttributeDescriptor]] of the [[TableSource]]. + */ + def getRowtimeAttributeDescriptor( + tableSource: TableSource[_], + selectedFields: Option[Array[Int]]): Option[RowtimeAttributeDescriptor] = { + + tableSource match { + case r: DefinedRowtimeAttributes => + val descriptors = r.getRowtimeAttributeDescriptors + if (descriptors.size() == 0) { + None + } else if (descriptors.size > 1) { + throw ValidationException("Table with has more than a single rowtime attribute..") + } else { + // exactly one rowtime attribute descriptor + if (selectedFields.isEmpty) { + // all fields are selected. + Some(descriptors.get(0)) + } else { + val descriptor = descriptors.get(0) + // look up index of row time attribute in schema + val fieldIdx = tableSource.getTableSchema.getColumnNames.indexOf( + descriptor.getAttributeName) + // is field among selected fields? + if (selectedFields.get.contains(fieldIdx)) { + Some(descriptor) + } else { + None + } + } + } + case _ => None + } + } + + /** + * Obtains the [[RexNode]] expression to extract the rowtime timestamp for a [[TableSource]]. + * + * @param tableSource The [[TableSource]] for which the expression is extracted. + * @param selectedFields The selected fields of the [[TableSource]]. + * If None, all fields are selected. + * @param cluster The [[RelOptCluster]] of the current optimization process. + * @param relBuilder The [[RelBuilder]] to build the [[RexNode]]. + * @param resultType The result type of the timestamp expression. + * @return The [[RexNode]] expression to extract the timestamp of the table source. + */ + def getRowtimeExtractionExpression( + tableSource: TableSource[_], + selectedFields: Option[Array[Int]], + cluster: RelOptCluster, + relBuilder: RelBuilder, + resultType: TypeInformation[_]): Option[RexNode] = { + + val typeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory] + + /** + * Creates a RelNode with a schema that corresponds on the given fields + * Fields for which no information is available, will have default values. + */ + def createSchemaRelNode(fields: Array[(String, Int, TypeInformation[_])]): RelNode = { + val maxIdx = fields.map(_._2).max + val idxMap: Map[Int, (String, TypeInformation[_])] = Map( + fields.map(f => f._2 -> (f._1, f._3)): _*) + val (physicalFields, physicalTypes) = (0 to maxIdx) + .map(i => idxMap.getOrElse(i, ("", Types.BYTE))).unzip + val physicalSchema: RelDataType = typeFactory.buildLogicalRowType( + physicalFields, + physicalTypes) + LogicalValues.create( + cluster, + physicalSchema, + ImmutableList.of().asInstanceOf[ImmutableList[ImmutableList[RexLiteral]]]) + } + + val rowtimeDesc = getRowtimeAttributeDescriptor(tableSource, selectedFields) + rowtimeDesc.map { r => + val tsExtractor = r.getTimestampExtractor + val resolvedFields = resolveInputFields(tsExtractor.getArgumentFields, tableSource) + + // push an empty values node with the physical schema on the relbuilder + relBuilder.push(createSchemaRelNode(resolvedFields)) + // get extraction expression + val fieldAccesses = resolvedFields.map(f => ResolvedFieldReference(f._1, f._3)) + val expression = tsExtractor.getExpression(fieldAccesses) + // add cast to requested type and convert expression to RexNode + val rexExpression = Cast(expression, resultType).toRexNode(relBuilder) + relBuilder.clear() + rexExpression + } + } + + /** + * Returns the indexes of the physical fields that required to compute the given logical fields. + * + * @param tableSource The [[TableSource]] for which the physical indexes are computed. + * @param logicalFieldIndexes The indexes of the accessed logical fields for which the physical + * indexes are computed. + * @return The indexes of the physical fields are accessed to forward and compute the logical + * fields. + */ + def getPhysicalIndexes( + tableSource: TableSource[_], + logicalFieldIndexes: Array[Int]): Array[Int] = { + + // get the mapping from logical to physical positions. + // stream / batch distinction not important here + val fieldMapping = computeIndexMapping(tableSource, isStreamTable = true, None) + + logicalFieldIndexes + // resolve logical indexes to physical indexes + .map(fieldMapping(_)) + // resolve time indicator markers to physical indexes + .flatMap { + case TimeIndicatorTypeInfo.PROCTIME_STREAM_MARKER => + // proctime field do not access a physical field + Seq() + case TimeIndicatorTypeInfo.ROWTIME_STREAM_MARKER => + // rowtime field is computed. + // get names of fields which are accessed by the expression to compute the rowtime field. + val rowtimeAttributeDescriptor = getRowtimeAttributeDescriptor(tableSource, None) + val accessedFields = if (rowtimeAttributeDescriptor.isDefined) { + rowtimeAttributeDescriptor.get.getTimestampExtractor.getArgumentFields + } else { + throw TableException("Computed field mapping includes a rowtime marker but the " + + "TableSource does not provide a RowtimeAttributeDescriptor. " + + "This is a bug and should be reported.") + } + // resolve field names to physical fields + resolveInputFields(accessedFields, tableSource).map(_._2) + case idx => + Seq(idx) + } + } + + /** Returns a list with all rowtime attribute names of the [[TableSource]]. */ + private def getRowtimeAttributes(tableSource: TableSource[_]): Array[String] = { + tableSource match { + case r: DefinedRowtimeAttributes => + r.getRowtimeAttributeDescriptors.asScala.map(_.getAttributeName).toArray + case _ => + Array() + } + } + + /** Returns a list with all proctime attribute names of the [[TableSource]]. */ + private def getProctimeAttributes(tableSource: TableSource[_]): Array[String] = { --- End diff -- Is it necessary to return an array for the proctime attribute?
---