nebojsa-db commented on code in PR #47331:
URL: https://github.com/apache/spark/pull/47331#discussion_r1679332626
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala:
##########
@@ -892,132 +892,108 @@ case class MapFromEntries(child: Expression)
     copy(child = newChild)
 }
 
+// Sorts all MapType expressions based on the ordering of their keys.
+// This is used when GROUP BY is done with a MapType (possibly nested) column.
 case class MapSort(base: Expression)
-  extends UnaryExpression with NullIntolerant with QueryErrorsBase {
+  extends UnaryExpression with NullIntolerant with QueryErrorsBase with CodegenFallback {
 
-  val keyType: DataType = base.dataType.asInstanceOf[MapType].keyType
-  val valueType: DataType = base.dataType.asInstanceOf[MapType].valueType
+  override lazy val canonicalized: Expression = base.canonicalized
+
+  override lazy val deterministic: Boolean = base.deterministic
 
   override def child: Expression = base
 
   override def dataType: DataType = base.dataType
 
-  override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
-    case m: MapType if RowOrdering.isOrderable(m.keyType) =>
-      TypeCheckResult.TypeCheckSuccess
+  def recursiveCheckDataTypes(dataType: DataType): TypeCheckResult = dataType match {
+    case a: ArrayType => recursiveCheckDataTypes(a.elementType)
+    case StructType(fields) =>
+      fields.collect(sf => recursiveCheckDataTypes(sf.dataType)).filter(_.isFailure).headOption
+        .getOrElse(TypeCheckResult.TypeCheckSuccess)
+    case m: MapType if RowOrdering.isOrderable(m.keyType) => TypeCheckResult.TypeCheckSuccess
     case _: MapType =>
       DataTypeMismatch(
         errorSubClass = "INVALID_ORDERING_TYPE",
         messageParameters = Map(
           "functionName" -> toSQLId(prettyName),
-          "dataType" -> toSQLType(base.dataType)
+          "dataType" -> toSQLType(dataType)
         )
       )
-    case _ =>
-      DataTypeMismatch(
+    case _ => TypeCheckResult.TypeCheckSuccess
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (!dataType.existsRecursively(_.isInstanceOf[MapType])) {
+      return DataTypeMismatch(
         errorSubClass = "UNEXPECTED_INPUT_TYPE",
         messageParameters = Map(
           "paramIndex" -> ordinalNumber(0),
           "requiredType" -> toSQLType(MapType),
           "inputSql" -> toSQLExpr(base),
           "inputType" -> toSQLType(base.dataType))
       )
-  }
-
-  override def nullSafeEval(array: Any): Any = {
-    // put keys and their respective values inside a tuple and sort them
-    // according to the key ordering. Extract the new sorted k/v pairs to form a sorted map
-
-    val mapData = array.asInstanceOf[MapData]
-    val numElements = mapData.numElements()
-    val keys = mapData.keyArray()
-    val values = mapData.valueArray()
-
-    val ordering = PhysicalDataType.ordering(keyType)
-
-    val sortedMap = Array
-      .tabulate(numElements)(i => (keys.get(i, keyType).asInstanceOf[Any],
-        values.get(i, valueType).asInstanceOf[Any]))
-      .sortBy(_._1)(ordering)
-
-    new ArrayBasedMapData(new GenericArrayData(sortedMap.map(_._1)),
-      new GenericArrayData(sortedMap.map(_._2)))
-  }
-
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    nullSafeCodeGen(ctx, ev, b => sortCodegen(ctx, ev, b))
-  }
-
-  private def sortCodegen(ctx: CodegenContext, ev: ExprCode,
-      base: String): String = {
-
-    val arrayBasedMapData = classOf[ArrayBasedMapData].getName
-    val genericArrayData = classOf[GenericArrayData].getName
-
-    val numElements = ctx.freshName("numElements")
-    val keys = ctx.freshName("keys")
-    val values = ctx.freshName("values")
-    val sortArray = ctx.freshName("sortArray")
-    val i = ctx.freshName("i")
-    val o1 = ctx.freshName("o1")
-    val o1entry = ctx.freshName("o1entry")
-    val o2 = ctx.freshName("o2")
-    val o2entry = ctx.freshName("o2entry")
-    val c = ctx.freshName("c")
-    val newKeys = ctx.freshName("newKeys")
-    val newValues = ctx.freshName("newValues")
-
-    val boxedKeyType = CodeGenerator.boxedType(keyType)
-    val boxedValueType = CodeGenerator.boxedType(valueType)
-    val javaKeyType = CodeGenerator.javaType(keyType)
+    }
 
-    val simpleEntryType = s"java.util.AbstractMap.SimpleEntry<$boxedKeyType, $boxedValueType>"
+    if (dataType.existsRecursively(dt =>
+      dt.isInstanceOf[MapType] && !RowOrdering.isOrderable(dt.asInstanceOf[MapType].keyType))) {
+      DataTypeMismatch(
+        errorSubClass = "INVALID_ORDERING_TYPE",
+        messageParameters = Map(
+          "functionName" -> toSQLId(prettyName),
+          "dataType" -> toSQLType(dataType)
+        )
+      )
+    }
 
-    val comp = if (CodeGenerator.isPrimitiveType(keyType)) {
-      val v1 = ctx.freshName("v1")
-      val v2 = ctx.freshName("v2")
-      s"""
-         |$javaKeyType $v1 = (($boxedKeyType) $o1).${javaKeyType}Value();
-         |$javaKeyType $v2 = (($boxedKeyType) $o2).${javaKeyType}Value();
-         |int $c = ${ctx.genComp(keyType, v1, v2)};
-       """.stripMargin
-    } else {
-      s"int $c = ${ctx.genComp(keyType, s"(($javaKeyType) $o1)", s"(($javaKeyType) $o2)")};"
+    TypeCheckResult.TypeCheckSuccess
+  }
+
+  // Evaluates the expression recursively by taking into
+  // account complex types and nesting
+  def nullSafeEvalRecursive(input: Any, dataType: DataType): Any = {
+
+    dataType match {
+      // For ArrayType recursively call evaluate for
+      // all its children since MapType can be nested
+      // as array element
+      case ArrayType(elementType, _) =>

Review Comment:
   We are working with input data here, but ArrayTransform expects an Expression as its input; is this doable?
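To make the question above concrete, here is a minimal sketch of how the recursion could stay entirely on runtime ArrayData/MapData, so that no ArrayTransform expression is needed at eval time. This is purely illustrative and not code from this PR: the object and helper names are made up, and struct fields as well as null entries are skipped for brevity.

```scala
import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types.{ArrayType, DataType, MapType}

// Hypothetical helper, not the PR's implementation: sort every map reachable
// from `input` by its keys, recursing through arrays. Struct fields and null
// elements are intentionally not handled here.
object MapSortSketch {
  def sortMapsIn(input: Any, dataType: DataType): Any = dataType match {
    case ArrayType(elementType, _) =>
      // Recurse into each element, since maps can be nested inside arrays.
      val arr = input.asInstanceOf[ArrayData]
      new GenericArrayData(Array.tabulate(arr.numElements()) { i =>
        sortMapsIn(arr.get(i, elementType), elementType)
      })
    case MapType(keyType, valueType, _) =>
      // Sort this map's entries by key, recursing into the values as well
      // in case they contain further nested maps.
      val map = input.asInstanceOf[MapData]
      val keys = map.keyArray()
      val values = map.valueArray()
      val ordering = PhysicalDataType.ordering(keyType)
      val sorted = Array
        .tabulate(map.numElements()) { i =>
          (keys.get(i, keyType).asInstanceOf[Any],
            sortMapsIn(values.get(i, valueType), valueType))
        }
        .sortBy(_._1)(ordering)
      new ArrayBasedMapData(
        new GenericArrayData(sorted.map(_._1)), new GenericArrayData(sorted.map(_._2)))
    case _ => input
  }
}
```

The alternative the comment points at would presumably move the rewrite to the expression level, for example by wrapping array-typed children in ArrayTransform at the point where MapSort is inserted, rather than recursing over the data inside nullSafeEval.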