huaxingao commented on code in PR #49493:
URL: https://github.com/apache/spark/pull/49493#discussion_r1988353194
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala:
##########
@@ -132,52 +164,119 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
       metadataAttrs: Seq[Attribute],
       originalRowIdValues: Seq[Expression]): Seq[Expression] = {
     val rowValues = assignments.map(_.value)
-    Seq(Literal(UPDATE_OPERATION)) ++ rowValues ++ metadataAttrs ++ originalRowIdValues
+    val metadataValues = nullifyMetadataOnUpdate(metadataAttrs)
+    Seq(Literal(UPDATE_OPERATION)) ++ rowValues ++ metadataValues ++ originalRowIdValues
+  }
+
+  protected def deltaReinsertOutput(
+      assignments: Seq[Assignment],
+      metadataAttrs: Seq[Attribute],
+      originalRowIdValues: Seq[Expression] = Seq.empty): Seq[Expression] = {
+    val rowValues = assignments.map(_.value)
+    val metadataValues = nullifyMetadataOnReinsert(metadataAttrs)
+    val extraNullValues = originalRowIdValues.map(e => Literal(null, e.dataType))
+    Seq(Literal(REINSERT_OPERATION)) ++ rowValues ++ metadataValues ++ extraNullValues
+  }
+
+  protected def addOperationColumn(operation: Int, plan: LogicalPlan): LogicalPlan = {
+    val operationType = Alias(Literal(operation, IntegerType), OPERATION_COLUMN)()
+    Project(operationType +: plan.output, plan)
+  }
+
+  protected def buildReplaceDataProjections(
+      plan: LogicalPlan,
+      rowAttrs: Seq[Attribute],
+      metadataAttrs: Seq[Attribute]): ReplaceDataProjections = {
+    val outputs = extractOutputs(plan)
+
+    val outputsWithRow = filterOutputs(outputs, Set(WRITE_WITH_METADATA_OPERATION, WRITE_OPERATION))
+    val rowProjection = newLazyProjection(plan, outputsWithRow, rowAttrs)
+
+    val metadataProjection = if (metadataAttrs.nonEmpty) {
+      val outputsWithMetadata = filterOutputs(outputs, Set(WRITE_WITH_METADATA_OPERATION))
+      Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs))
+    } else {
+      None
+    }
+
+    ReplaceDataProjections(rowProjection, metadataProjection)
   }
 
   protected def buildWriteDeltaProjections(
       plan: LogicalPlan,
       rowAttrs: Seq[Attribute],
       rowIdAttrs: Seq[Attribute],
       metadataAttrs: Seq[Attribute]): WriteDeltaProjections = {
+    val outputs = extractOutputs(plan)
 
     val rowProjection = if (rowAttrs.nonEmpty) {
-      Some(newLazyProjection(plan, rowAttrs))
+      val outputsWithRow = filterOutputs(outputs, DELTA_OPERATIONS_WITH_ROW)
+      Some(newLazyProjection(plan, outputsWithRow, rowAttrs))
     } else {
       None
     }
 
-    val rowIdProjection = newLazyRowIdProjection(plan, rowIdAttrs)
+    val outputsWithRowId = filterOutputs(outputs, DELTA_OPERATIONS_WITH_ROW_ID)
+    val rowIdProjection = newLazyRowIdProjection(plan, outputsWithRowId, rowIdAttrs)
 
     val metadataProjection = if (metadataAttrs.nonEmpty) {
-      Some(newLazyProjection(plan, metadataAttrs))
+      val outputsWithMetadata = filterOutputs(outputs, DELTA_OPERATIONS_WITH_METADATA)
+      Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs))
     } else {
       None
     }
 
     WriteDeltaProjections(rowProjection, rowIdProjection, metadataProjection)
   }
 
+  private def extractOutputs(plan: LogicalPlan): Seq[Seq[Expression]] = {
+    plan match {
+      case p: Project => Seq(p.projectList)
+      case e: Expand => e.projections
+      case m: MergeRows => m.outputs
+      case _ => throw SparkException.internalError("Can't extract outputs from plan: " + plan)
+    }
+  }
+
+  private def filterOutputs(
+      outputs: Seq[Seq[Expression]],
+      operations: Set[Int]): Seq[Seq[Expression]] = {
+    outputs.filter {
+      case Literal(operation: Integer, _) +: _ => operations.contains(operation)
+      case Alias(Literal(operation: Integer, _), _) +: _ => operations.contains(operation)
+      case other => throw SparkException.internalError("Can't determine operation: " + other)
+    }
+  }
+
   private def newLazyProjection(
       plan: LogicalPlan,
+      outputs: Seq[Seq[Expression]],
       attrs: Seq[Attribute]): ProjectingInternalRow = {
     val colOrdinals = attrs.map(attr => findColOrdinal(plan, attr.name))
-    val schema = DataTypeUtils.fromAttributes(attrs)
-    ProjectingInternalRow(schema, colOrdinals)
+    createProjectingInternalRow(outputs, colOrdinals, attrs)
   }
 
   // if there are assignment to row ID attributes, original values are projected as special columns
   // this method honors such special columns if present
   private def newLazyRowIdProjection(
       plan: LogicalPlan,
+      outputs: Seq[Seq[Expression]],
       rowIdAttrs: Seq[Attribute]): ProjectingInternalRow = {
     val colOrdinals = rowIdAttrs.map { attr =>
       val originalValueIndex = findColOrdinal(plan, ORIGINAL_ROW_ID_VALUE_PREFIX + attr.name)
       if (originalValueIndex != -1) originalValueIndex else findColOrdinal(plan, attr.name)
     }
-    val schema = DataTypeUtils.fromAttributes(rowIdAttrs)
+    createProjectingInternalRow(outputs, colOrdinals, rowIdAttrs)
+  }
+
+  private def createProjectingInternalRow(
+      outputs: Seq[Seq[Expression]],
+      colOrdinals: Seq[Int],
+      attrs: Seq[Attribute]): ProjectingInternalRow = {
+    val schema = StructType(attrs.zipWithIndex.map { case (attr, index) =>
+      val nullable = outputs.exists(output => output(colOrdinals(index)).nullable)

Review Comment:
   @aokolnychyi Do we only need this for metadata columns? For regular columns, shall we use `attr.nullable` instead?
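   For illustration only, a rough sketch of the alternative being asked about, assuming the quoted helper goes on to build StructFields from the computed nullability (the part of the diff cut off above) and that the caller can tell metadata columns apart from regular ones; the isMetadataCol predicate is hypothetical and not part of this PR:

       // Hypothetical variant, not the PR's implementation: derive nullability from the
       // output expressions only for metadata columns and trust attr.nullable otherwise.
       private def createProjectingInternalRow(
           outputs: Seq[Seq[Expression]],
           colOrdinals: Seq[Int],
           attrs: Seq[Attribute],
           isMetadataCol: Attribute => Boolean): ProjectingInternalRow = {
         val schema = StructType(attrs.zipWithIndex.map { case (attr, index) =>
           val nullable = if (isMetadataCol(attr)) {
             // metadata values may be nulled out for some operations, so inspect the outputs
             outputs.exists(output => output(colOrdinals(index)).nullable)
           } else {
             attr.nullable
           }
           StructField(attr.name, attr.dataType, nullable, attr.metadata)
         })
         ProjectingInternalRow(schema, colOrdinals)
       }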
-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org