aokolnychyi commented on code in PR #50761:
URL: https://github.com/apache/spark/pull/50761#discussion_r2070887598
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableConstraint.scala:
##########
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.analysis
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{CheckInvariant, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2WriteCommand, Validate}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
+import org.apache.spark.sql.connector.catalog.CatalogManager
+import org.apache.spark.sql.connector.catalog.constraints.Check
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+
+class ResolveTableConstraint(val catalogManager: CatalogManager) extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
+    _.containsPattern(COMMAND), ruleId) {
+    case v2Write: V2WriteCommand
+        if v2Write.table.resolved && v2Write.query.resolved &&
+          !v2Write.query.isInstanceOf[Validate] && v2Write.outputResolved =>
+      v2Write.table match {
+        case r: DataSourceV2Relation
+            if r.table.constraints() != null && r.table.constraints().nonEmpty =>
+          val checks = r.table.constraints().collect {
+            case c: Check => c
+          }
+          val checkInvariants = checks.map { c =>
+            val parsed =
+              catalogManager.v1SessionCatalog.parser.parseExpression(c.predicateSql())
+            val columnExtractors = mutable.Map[String, Expression]()
+            buildColumnExtractors(parsed, columnExtractors)
+            CheckInvariant(parsed, columnExtractors.toSeq, c.name(), c.predicateSql())
+          }.toSeq
+          v2Write.withNewQuery(Validate(checkInvariants, v2Write.query))
+        case _ =>
+          v2Write
+      }
+  }
+
+  private def buildColumnExtractors(
+      expr: Expression,
+      columnExtractors: mutable.Map[String, Expression]): Unit = {
+    expr match {
+      case u: UnresolvedExtractValue =>
+        // When extracting a value from a Map or Array type, we display only the specific extracted
+        // value rather than the entire Map or Array structure for clarity and readability.
+        columnExtractors(u.sql) = u
+      case u: UnresolvedAttribute =>
+        columnExtractors(u.name) = u
+

Review Comment:
   Minor: Is the extra empty line intentional?



##########
common/utils/src/main/resources/error/error-conditions.json:
##########
@@ -544,6 +544,14 @@
     ],
     "sqlState" : "56000"
   },
+  "CHECK_CONSTRAINT_VIOLATION" : {
+    "message" : [
+      "CHECK constraint <constraintName> <expression> violated by row with values:",
+      "<values>",
+      ""

Review Comment:
   Do we need `""` here?
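[Editor's note: for context on the trailing `""` question, Spark renders an error condition's `message` array by joining its entries with newlines, so a trailing empty entry only adds a trailing line break. A minimal Scala sketch of that behavior, not the actual `ErrorClassesJsonReader`:]

```scala
// Sketch only: "message" entries are joined with newlines before
// parameter substitution, so a trailing "" leaves the rendered
// template ending in '\n'. Dropping it removes that final break.
val messageLines = Seq(
  "CHECK constraint <constraintName> <expression> violated by row with values:",
  "<values>",
  "")
val template = messageLines.mkString("\n")
assert(template.endsWith("\n"))
```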
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableConstraint.scala:
##########
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.analysis
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{CheckInvariant, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2WriteCommand, Validate}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
+import org.apache.spark.sql.connector.catalog.CatalogManager
+import org.apache.spark.sql.connector.catalog.constraints.Check
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+
+class ResolveTableConstraint(val catalogManager: CatalogManager) extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
+    _.containsPattern(COMMAND), ruleId) {
+    case v2Write: V2WriteCommand
+        if v2Write.table.resolved && v2Write.query.resolved &&
+          !v2Write.query.isInstanceOf[Validate] && v2Write.outputResolved =>
+      v2Write.table match {
+        case r: DataSourceV2Relation
+            if r.table.constraints() != null && r.table.constraints().nonEmpty =>

Review Comment:
   Minor: I usually prefer omitting `()` for such getters.
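[Editor's note: a tiny illustration of the style nit above. `Table#constraints()` is a Java-defined getter, and Scala 2.13 lets callers omit the empty parentheses on Java methods; `String#length()` stands in here as a runnable example:]

```scala
// Scala accepts both call styles for a parameterless Java method.
val s = "spark"
assert(s.length() == 5) // with parens
assert(s.length == 5)   // without; same idea as writing r.table.constraints.nonEmpty
```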
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableConstraint.scala:
##########
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.analysis
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{CheckInvariant, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2WriteCommand, Validate}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
+import org.apache.spark.sql.connector.catalog.CatalogManager
+import org.apache.spark.sql.connector.catalog.constraints.Check
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+
+class ResolveTableConstraint(val catalogManager: CatalogManager) extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
+    _.containsPattern(COMMAND), ruleId) {
+    case v2Write: V2WriteCommand
+        if v2Write.table.resolved && v2Write.query.resolved &&
+          !v2Write.query.isInstanceOf[Validate] && v2Write.outputResolved =>
+      v2Write.table match {
+        case r: DataSourceV2Relation
+            if r.table.constraints() != null && r.table.constraints().nonEmpty =>
+          val checks = r.table.constraints().collect {

Review Comment:
   Do we reject others? What if we get back a PK that must be enforced?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableConstraint.scala:
##########
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.analysis
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{CheckInvariant, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2WriteCommand, Validate}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
+import org.apache.spark.sql.connector.catalog.CatalogManager
+import org.apache.spark.sql.connector.catalog.constraints.Check
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+
+class ResolveTableConstraint(val catalogManager: CatalogManager) extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
+    _.containsPattern(COMMAND), ruleId) {
+    case v2Write: V2WriteCommand
+        if v2Write.table.resolved && v2Write.query.resolved &&
+          !v2Write.query.isInstanceOf[Validate] && v2Write.outputResolved =>
+      v2Write.table match {
+        case r: DataSourceV2Relation
+            if r.table.constraints() != null && r.table.constraints().nonEmpty =>
+          val checks = r.table.constraints().collect {
+            case c: Check => c
+          }
+          val checkInvariants = checks.map { c =>
+            val parsed =

Review Comment:
   Providing SQL is optional. The CHECK constraint can also supply an expression. If so, we have to first try to convert it to Catalyst using `V2ExpressionUtils`.
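[Editor's note: a hedged sketch of the fallback order the reviewer suggests. It assumes `Check` exposes the connector-provided V2 predicate via a nullable `predicate()` accessor and that `V2ExpressionUtils.toCatalystOpt` is the conversion entry point; the helper name is hypothetical:]

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog.constraints.Check

// Hypothetical helper: prefer the connector-supplied V2 expression and
// convert it to Catalyst; only fall back to parsing the SQL text.
private def resolveCheckExpression(
    c: Check,
    query: LogicalPlan,
    parser: ParserInterface): Expression = {
  Option(c.predicate()) // assumed accessor; may be null when only SQL is given
    .flatMap(p => V2ExpressionUtils.toCatalystOpt(p, query))
    .getOrElse(parser.parseExpression(c.predicateSql()))
}
```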
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableConstraint.scala:
##########
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.analysis
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{CheckInvariant, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2WriteCommand, Validate}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
+import org.apache.spark.sql.connector.catalog.CatalogManager
+import org.apache.spark.sql.connector.catalog.constraints.Check
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+
+class ResolveTableConstraint(val catalogManager: CatalogManager) extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
+    _.containsPattern(COMMAND), ruleId) {
+    case v2Write: V2WriteCommand
+        if v2Write.table.resolved && v2Write.query.resolved &&
+          !v2Write.query.isInstanceOf[Validate] && v2Write.outputResolved =>
+      v2Write.table match {
+        case r: DataSourceV2Relation
+            if r.table.constraints() != null && r.table.constraints().nonEmpty =>
+          val checks = r.table.constraints().collect {
+            case c: Check => c
+          }
+          val checkInvariants = checks.map { c =>
+            val parsed =
+              catalogManager.v1SessionCatalog.parser.parseExpression(c.predicateSql())
+            val columnExtractors = mutable.Map[String, Expression]()

Review Comment:
   Why init outside of the method and make the method return nothing instead of returning the map from the method?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala:
##########
@@ -1664,3 +1664,16 @@ case class Call(
   override protected def withNewChildInternal(newChild: LogicalPlan): Call =
     copy(procedure = newChild)
 }
+
+case class Validate(
+    conditions: Seq[CheckInvariant],
+    child: LogicalPlan) extends UnaryNode {
+
+  assert(conditions.nonEmpty, "CheckData must have at least one condition")

Review Comment:
   Should this be Validate or CheckData?
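[Editor's note: on the earlier `buildColumnExtractors` comment, a sketch of the return-the-map shape. The recursive default case is an assumption, since the hunks above are truncated before it:]

```scala
import scala.collection.mutable

import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions.Expression

// Sketch: build the map locally and return it, instead of mutating a
// caller-supplied map and returning Unit.
private def buildColumnExtractors(expr: Expression): Map[String, Expression] = {
  val columnExtractors = mutable.Map[String, Expression]()
  def collect(e: Expression): Unit = e match {
    case u: UnresolvedExtractValue =>
      // Show only the extracted value, not the whole Map/Array structure.
      columnExtractors(u.sql) = u
    case u: UnresolvedAttribute =>
      columnExtractors(u.name) = u
    case other =>
      other.children.foreach(collect) // assumed recursion over children
  }
  collect(expr)
  columnExtractors.toMap
}
```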
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableConstraint.scala:
##########
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.analysis
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{CheckInvariant, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2WriteCommand, Validate}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
+import org.apache.spark.sql.connector.catalog.CatalogManager
+import org.apache.spark.sql.connector.catalog.constraints.Check
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+
+class ResolveTableConstraint(val catalogManager: CatalogManager) extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
+    _.containsPattern(COMMAND), ruleId) {
+    case v2Write: V2WriteCommand

Review Comment:
   Can we make sure we have tests for DELETE, UPDATE, and MERGE too? There are two flavors:
   - `DeltaBasedXXXSuites`
   - `GroupBasedXXXSuites`



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala:
##########
@@ -259,3 +263,94 @@ case class ForeignKeyConstraint(
     copy(userProvidedCharacteristic = c)
   }
 }
+
+/**
+ * An expression that validates a specific invariant on a column, before writing into table.
+ *
+ * @param child The fully resolved expression to be evaluated to check the constraint.
+ * @param columnExtractors Extractors for each referenced column. Used to generate readable errors.
+ * @param constraintName The name of the constraint.
+ * @param predicateSql The SQL representation of the constraint.
+ */
+case class CheckInvariant(
+    child: Expression,
+    columnExtractors: Seq[(String, Expression)],
+    constraintName: String,
+    predicateSql: String)
+  extends Expression with NonSQLExpression {
+
+  override def children: Seq[Expression] = child +: columnExtractors.map(_._2)
+  override def dataType: DataType = NullType
+  override def foldable: Boolean = false
+  override def nullable: Boolean = true
+
+  override def eval(input: InternalRow): Any = {
+    val result = child.eval(input)
+    if (result == false) {
+      val values = columnExtractors.map {
+        case (column, extractor) => column -> extractor.eval(input)
+      }.toMap
+      throw QueryExecutionErrors.checkViolation(constraintName, predicateSql, values)
+    }
+    null
+  }
+
+  /**
+   * Generate the code to extract values for the columns referenced in a violated CHECK constraint.
+   * We build parallel lists of full column names and their extracted values in the row which
+   * violates the constraint, to be passed to the [[InvariantViolationException]] constructor
+   * in [[generateExpressionValidationCode()]].
+   *
+   * Note that this code is a bit expensive, so it shouldn't be run until we already
+   * know the constraint has been violated.
+   */
+  private def generateColumnValuesCode(

Review Comment:
   Do we have tests with codegen on/off?
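[Editor's note: a hedged sketch of pinning both execution paths, assuming a `QueryTest`-style suite mixing in `SQLTestUtils` (for `withSQLConf`/`sql`/`intercept`). The table `t` and its CHECK constraint are placeholders; the condition name comes from the JSON hunk above:]

```scala
import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.internal.SQLConf

// Run the same violating write with expression codegen forced on, then off.
Seq("CODEGEN_ONLY", "NO_CODEGEN").foreach { mode =>
  withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> mode) {
    val e = intercept[SparkRuntimeException] {
      sql("INSERT INTO t VALUES (-1)") // assumes t declares CHECK (col > 0)
    }
    assert(e.getCondition == "CHECK_CONSTRAINT_VIOLATION")
  }
}
```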
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org