sririshindra commented on code in PR #48252: URL: https://github.com/apache/spark/pull/48252#discussion_r1844241460
########## sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala: ########## @@ -148,34 +163,180 @@ object JavaTypeInference { // TODO: we should only collect properties that have getter and setter. However, some tests // pass in scala case class as java bean class which doesn't have getter and setter. val properties = getJavaBeanReadableProperties(c) - // add type variables from inheritance hierarchy of the class - val classTV = JavaTypeUtils.getTypeArguments(c, classOf[Object]).asScala.toMap ++ - typeVariables - // Note that the fields are ordered by name. - val fields = properties.map { property => - val readMethod = property.getReadMethod - val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c, classTV) - // The existence of `javax.annotation.Nonnull`, means this field is not nullable. - val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull]) - EncoderField( - property.getName, - encoder, - encoder.nullable && !hasNonNull, - Metadata.empty, - Option(readMethod.getName), - Option(property.getWriteMethod).map(_.getName)) + + // if the properties is empty and this is not a top level enclosing class, then we + // should not consider class as bean, as otherwise it will be treated as empty schema + // and loose the data on deser. + if (properties.isEmpty && seenTypeSet.nonEmpty) { + findBestEncoder(Seq(c), seenTypeSet, typeVariables, None, serializableEncodersOnly = true) + .getOrElse(throw ExecutionErrors.cannotFindEncoderForTypeError(t.getTypeName)) + } else { + // add type variables from inheritance hierarchy of the class + val parentClassesTypeMap = + JavaTypeUtils.getTypeArguments(c, classOf[Object]).asScala.toMap + val classTV = parentClassesTypeMap ++ typeVariables + // Note that the fields are ordered by name. + val fields = properties.map { property => + val readMethod = property.getReadMethod + val methodReturnType = readMethod.getGenericReturnType + val encoder = encoderFor(methodReturnType, seenTypeSet + c, classTV) + // The existence of `javax.annotation.Nonnull`, means this field is not nullable. + val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull]) + EncoderField( + property.getName, + encoder, + encoder.nullable && !hasNonNull, + Metadata.empty, + Option(readMethod.getName), + Option(property.getWriteMethod).map(_.getName)) + } + // implies it cannot be assumed a BeanClass. 
+ // Check if its super class or interface could be represented by an Encoder + + JavaBeanEncoder(ClassTag(c), fields.toImmutableArraySeq) } - JavaBeanEncoder(ClassTag(c), fields.toImmutableArraySeq) case _ => throw ExecutionErrors.cannotFindEncoderForTypeError(t.toString) } + private def createUDTEncoderUsingAnnotation(c: Class[_]): UDTEncoder[Any] = { + val udt = c + .getAnnotation(classOf[SQLUserDefinedType]) + .udt() + .getConstructor() + .newInstance() + .asInstanceOf[UserDefinedType[Any]] + val udtClass = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt() + UDTEncoder(udt, udtClass) + } + + private def createUDTEncoderUsingRegistration(c: Class[_]): UDTEncoder[Any] = { + val udt = UDTRegistration + .getUDTFor(c.getName) + .get + .getConstructor() + .newInstance() + .asInstanceOf[UserDefinedType[Any]] + UDTEncoder(udt, udt.getClass) + } + def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { val beanInfo = Introspector.getBeanInfo(beanClass) beanInfo.getPropertyDescriptors .filterNot(_.getName == "class") .filterNot(_.getName == "declaringClass") .filter(_.getReadMethod != null) } + + private def findBestEncoder( + typesToCheck: Seq[Class[_]], + seenTypeSet: Set[Class[_]], + typeVariables: Map[TypeVariable[_], Type], + baseClass: Option[Class[_]], + serializableEncodersOnly: Boolean = false): Option[AgnosticEncoder[_]] = + if (serializableEncodersOnly) { + val isClientConnect = clientConnectFlag.get + assert(typesToCheck.size == 1) + typesToCheck + .flatMap(c => { + if (!isClientConnect && classOf[KryoSerializable].isAssignableFrom(c)) { Review Comment: Can we add a comment explaining why isClientConnect being true disqualifies the type from being encoded with the Kryo serializer? Is there a chance that will change in the future? If so, can we add a TODO here so that we can remove this condition if and when Kryo serialization is available with Spark Connect? ########## sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala: ########## @@ -148,34 +163,180 @@ object JavaTypeInference { // TODO: we should only collect properties that have getter and setter. However, some tests // pass in scala case class as java bean class which doesn't have getter and setter. val properties = getJavaBeanReadableProperties(c) - // add type variables from inheritance hierarchy of the class - val classTV = JavaTypeUtils.getTypeArguments(c, classOf[Object]).asScala.toMap ++ - typeVariables - // Note that the fields are ordered by name. - val fields = properties.map { property => - val readMethod = property.getReadMethod - val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c, classTV) - // The existence of `javax.annotation.Nonnull`, means this field is not nullable. - val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull]) - EncoderField( - property.getName, - encoder, - encoder.nullable && !hasNonNull, - Metadata.empty, - Option(readMethod.getName), - Option(property.getWriteMethod).map(_.getName)) + + // if the properties is empty and this is not a top level enclosing class, then we + // should not consider class as bean, as otherwise it will be treated as empty schema + // and loose the data on deser. + if (properties.isEmpty && seenTypeSet.nonEmpty) { + findBestEncoder(Seq(c), seenTypeSet, typeVariables, None, serializableEncodersOnly = true) Review Comment: Could you please elaborate on why the serializableEncodersOnly flag is needed here? Maybe adding a comment for that would make it clearer.
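For reference, here is a minimal sketch of the situation this fallback targets, built from a trimmed copy of the MessageWrapper bean added in this PR's DatasetSuite; the expected BinaryType schema is taken from that test, so treat it as an illustration rather than a definitive statement of the behavior:

    import org.apache.spark.sql.Encoders

    // Trimmed from the PR's test code: the generic field is only known to be
    // java.io.Serializable, and the resolved field type exposes no readable bean
    // properties of its own.
    class MessageWrapper[T <: java.io.Serializable] extends java.io.Serializable {
      private var message: T = _
      def getMessage: T = message
      def setMessage(message: T): Unit = { this.message = message }
    }

    object SerializableFallbackSketch {
      def main(args: Array[String]): Unit = {
        // Per the new DatasetSuite test, the schema is expected to be
        // StructType(Seq(StructField("message", BinaryType, true))) rather than an
        // empty struct that would drop the value on deserialization.
        val enc = Encoders.bean(classOf[MessageWrapper[_]])
        println(enc.schema)
      }
    }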
########## sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala: ########## @@ -2909,6 +3016,7 @@ object KryoData { /** Used to test Java encoder. */ class JavaData(val a: Int) extends Serializable { + def this() = this(0) Review Comment: It is not clear to me why the method 'def this()' needs to be explicitly defined. I don't see it being used anywhere. Can we remove it, or could you please explain why it is here? ########## sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala: ########## @@ -2802,6 +2821,79 @@ class DatasetSuite extends QueryTest } } } + + test("SPARK-49789 Bean class encoding with generic type implementing Serializable") { + // just create encoder + val enc = Encoders.bean(classOf[MessageWrapper[_]]) + val data = Seq("test1", "test2").map(str => { + val msg = new MessageWrapper[String]() + msg.setMessage(str) + msg + }) + validateParamBeanDataset(classOf[MessageWrapper[String]], + data, mutable.Buffer(data: _*), + StructType(Seq(StructField("message", BinaryType, true))) + ) + } + + test("SPARK-49789 Bean class encoding with generic type indirectly extending" + + " Serializable class") { + // just create encoder + Encoders.bean(classOf[BigDecimalMessageWrapper[_]]) + val data = Seq(2d, 8d).map(doub => { + val bean = new BigDecimalMessageWrapper[DerivedBigDecimalExtender]() + bean.setMessage(new DerivedBigDecimalExtender(doub)) + bean + }) + validateParamBeanDataset( + classOf[BigDecimalMessageWrapper[DerivedBigDecimalExtender]], + data, mutable.Buffer(data: _*), + StructType(Seq(StructField("message", BinaryType, true)))) + } + + test("SPARK-49789. test bean class with generictype bound of UDTType") { + // just create encoder + UDTRegistration.register(classOf[TestUDT].getName, classOf[TestUDTType].getName) + val enc = Encoders.bean(classOf[UDTBean[_]]) + val baseData = Seq((1, "a"), (2, "b")) + val data = baseData.map(tup => { + val bean = new UDTBean[TestUDT]() + bean.setMessage(new TestUDTImplSub(tup._1, tup._2)) + bean + }) + val expectedData = baseData.map(tup => { + val bean = new UDTBean[TestUDT]() + bean.setMessage(new TestUDTImpl(tup._1, tup._2)) + bean + }) + validateParamBeanDataset( + classOf[UDTBean[TestUDT]], + data, mutable.Buffer(expectedData: _*), + StructType(Seq(StructField("message", new TestUDTType(), true)))) + } + + private def validateParamBeanDataset[T]( + classToEncode: Class[T], + data: Seq[T], + expectedData: mutable.Buffer[T], + expectedSchema: StructType): Unit = { + Review Comment: Indentation seems to be off. Could you please fix that? ########## connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala: ########## @@ -137,8 +137,16 @@ class SparkSession private[sql] ( /** @inheritdoc */ def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = { - val encoder = JavaTypeInference.encoderFor(beanClass.asInstanceOf[Class[Any]]) - createDataset(encoder, data.iterator().asScala).toDF() + JavaTypeInference.setSparkClientFlag() + val encoderTry = Try { + JavaTypeInference.encoderFor(beanClass.asInstanceOf[Class[Any]]) + } + JavaTypeInference.unsetSparkClientFlag() Review Comment: Can we add a comment explaining exactly why we are setting and unsetting the Spark Client Flag here? My understanding is that whether we use a regular SparkSession or a SparkSession created via Spark Connect makes a difference in terms of what we can infer.
But the exact reason why there is a difference between a regular SparkSession and the SparkSession corresponding to Connect, and what that difference is, is not clear to me. Can we please document that here or where the connectClient field is initialized? ########## sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala: ########## @@ -2945,3 +3053,132 @@ case class SaveModeArrayCase(modes: Array[SaveMode]) case class K1(a: Long) case class K2(a: Long, b: Long) + +class MessageWrapper[T <: java.io.Serializable] extends java.io.Serializable { + private var message: T = _ + + def getMessage: T = message + + def setMessage(message: T): Unit = { + this.message = message + } + + override def equals(obj: Any): Boolean = { + obj match { + case m: MessageWrapper[_] => m.message == this.message + + case _ => false + } + } + + override def hashCode(): Int = this.message.hashCode() +} + +class BigDecimalMessageWrapper[T <: BigDecimalExtender] extends java.io.Serializable { + private var message: T = _ + + def getMessage: T = message + + def setMessage(message: T): Unit = { + this.message = message + } + + override def equals(obj: Any): Boolean = { + obj match { + case m: BigDecimalMessageWrapper[_] => m.message == this.message + case _ => false + } + } + + override def hashCode(): Int = this.message.hashCode() +} + +class BigDecimalExtender(doub: Double) extends java.math.BigDecimal(doub) { + override def equals(obj: Any): Boolean = { + obj match { + case m: BigDecimalExtender => super.equals(m.asInstanceOf[java.math.BigDecimal]) + case _ => false + } + } + + override def hashCode(): Int = super.hashCode() +} + +class DerivedBigDecimalExtender(doub: Double) extends BigDecimalExtender(doub) { + override def equals(obj: Any): Boolean = { + obj match { + case m: DerivedBigDecimalExtender => super.equals(m.asInstanceOf[BigDecimalExtender]) + case _ => false + } + } + + override def hashCode(): Int = super.hashCode() +} + +trait TestUDT extends Serializable { + def intField: Int + + def stringField: String +} + +class TestUDTImpl(var intF: Int, var stringF: String) extends TestUDT { + def this() = this(0, "") + + override def intField: Int = intF + + override def stringField: String = stringF + + override def hashCode(): Int = intF.hashCode() + stringF.hashCode + + override def equals(obj: Any): Boolean = obj match { + case b: TestUDT => b.intField == this.intField && b.stringField == this.stringField + + case _ => false + } +} + +class TestUDTImplSub(var iF: Int, var sF: String) extends TestUDTImpl(iF, sF) { + def this() = this(0, "") Review Comment: Ditto ########## sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala: ########## @@ -148,34 +163,180 @@ object JavaTypeInference { // TODO: we should only collect properties that have getter and setter. However, some tests // pass in scala case class as java bean class which doesn't have getter and setter. val properties = getJavaBeanReadableProperties(c) - // add type variables from inheritance hierarchy of the class - val classTV = JavaTypeUtils.getTypeArguments(c, classOf[Object]).asScala.toMap ++ - typeVariables - // Note that the fields are ordered by name. - val fields = properties.map { property => - val readMethod = property.getReadMethod - val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c, classTV) - // The existence of `javax.annotation.Nonnull`, means this field is not nullable.
- val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull]) - EncoderField( - property.getName, - encoder, - encoder.nullable && !hasNonNull, - Metadata.empty, - Option(readMethod.getName), - Option(property.getWriteMethod).map(_.getName)) + + // if the properties is empty and this is not a top level enclosing class, then we + // should not consider class as bean, as otherwise it will be treated as empty schema + // and loose the data on deser. Review Comment: nit: Can you rename deser to 'deserialization' for more clarity? Also, I did not quite understand why, if 'this is not a top level enclosing class', 'it will be treated as empty schema'. What does that mean exactly? Could you please elaborate on that? ########## sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala: ########## @@ -2945,3 +3053,132 @@ case class SaveModeArrayCase(modes: Array[SaveMode]) case class K1(a: Long) case class K2(a: Long, b: Long) + +class MessageWrapper[T <: java.io.Serializable] extends java.io.Serializable { + private var message: T = _ + + def getMessage: T = message + + def setMessage(message: T): Unit = { + this.message = message + } + + override def equals(obj: Any): Boolean = { + obj match { + case m: MessageWrapper[_] => m.message == this.message + + case _ => false + } + } + + override def hashCode(): Int = this.message.hashCode() +} + +class BigDecimalMessageWrapper[T <: BigDecimalExtender] extends java.io.Serializable { + private var message: T = _ + + def getMessage: T = message + + def setMessage(message: T): Unit = { + this.message = message + } + + override def equals(obj: Any): Boolean = { + obj match { + case m: BigDecimalMessageWrapper[_] => m.message == this.message + case _ => false + } + } + + override def hashCode(): Int = this.message.hashCode() +} + +class BigDecimalExtender(doub: Double) extends java.math.BigDecimal(doub) { + override def equals(obj: Any): Boolean = { + obj match { + case m: BigDecimalExtender => super.equals(m.asInstanceOf[java.math.BigDecimal]) + case _ => false + } + } + + override def hashCode(): Int = super.hashCode() +} + +class DerivedBigDecimalExtender(doub: Double) extends BigDecimalExtender(doub) { + override def equals(obj: Any): Boolean = { + obj match { + case m: DerivedBigDecimalExtender => super.equals(m.asInstanceOf[BigDecimalExtender]) + case _ => false + } + } + + override def hashCode(): Int = super.hashCode() +} + +trait TestUDT extends Serializable { + def intField: Int + + def stringField: String +} + +class TestUDTImpl(var intF: Int, var stringF: String) extends TestUDT { + def this() = this(0, "") Review Comment: Ditto ########## sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala: ########## @@ -148,34 +163,180 @@ object JavaTypeInference { // TODO: we should only collect properties that have getter and setter. However, some tests // pass in scala case class as java bean class which doesn't have getter and setter. val properties = getJavaBeanReadableProperties(c) - // add type variables from inheritance hierarchy of the class - val classTV = JavaTypeUtils.getTypeArguments(c, classOf[Object]).asScala.toMap ++ - typeVariables - // Note that the fields are ordered by name. - val fields = properties.map { property => - val readMethod = property.getReadMethod - val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c, classTV) - // The existence of `javax.annotation.Nonnull`, means this field is not nullable.
- val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull]) - EncoderField( - property.getName, - encoder, - encoder.nullable && !hasNonNull, - Metadata.empty, - Option(readMethod.getName), - Option(property.getWriteMethod).map(_.getName)) + + // if the properties is empty and this is not a top level enclosing class, then we + // should not consider class as bean, as otherwise it will be treated as empty schema + // and loose the data on deser. + if (properties.isEmpty && seenTypeSet.nonEmpty) { + findBestEncoder(Seq(c), seenTypeSet, typeVariables, None, serializableEncodersOnly = true) + .getOrElse(throw ExecutionErrors.cannotFindEncoderForTypeError(t.getTypeName)) + } else { + // add type variables from inheritance hierarchy of the class + val parentClassesTypeMap = + JavaTypeUtils.getTypeArguments(c, classOf[Object]).asScala.toMap + val classTV = parentClassesTypeMap ++ typeVariables + // Note that the fields are ordered by name. + val fields = properties.map { property => + val readMethod = property.getReadMethod + val methodReturnType = readMethod.getGenericReturnType + val encoder = encoderFor(methodReturnType, seenTypeSet + c, classTV) + // The existence of `javax.annotation.Nonnull`, means this field is not nullable. + val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull]) + EncoderField( + property.getName, + encoder, + encoder.nullable && !hasNonNull, + Metadata.empty, + Option(readMethod.getName), + Option(property.getWriteMethod).map(_.getName)) + } + // implies it cannot be assumed a BeanClass. + // Check if its super class or interface could be represented by an Encoder Review Comment: These two comments are a bit confusing. The first comment says "implies it cannot be assumed a BeanClass." But then why is it initialized with JavaBeanEncoder? The second comment says "Check if its super class or interface could be represented by an Encoder". It is not clear to me where exactly this check is being done. Could you please elaborate on it for me? Thanks.
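To make the "super class or interface" part of this discussion concrete, here is an illustrative sketch built from trimmed copies of the BigDecimalMessageWrapper test classes added in this PR; the BinaryType expectation comes from the corresponding new DatasetSuite test, so this is a sketch of the intended behavior rather than an authoritative description of the encoder lookup:

    import org.apache.spark.sql.Encoders

    // Trimmed from the PR's test code: BigDecimalExtender exposes no bean properties
    // itself, but java.math.BigDecimal up its hierarchy implements java.io.Serializable.
    class BigDecimalExtender(d: Double) extends java.math.BigDecimal(d)

    class BigDecimalMessageWrapper[T <: BigDecimalExtender] extends java.io.Serializable {
      private var message: T = _
      def getMessage: T = message
      def setMessage(message: T): Unit = { this.message = message }
    }

    object SuperTypeFallbackSketch {
      def main(args: Array[String]): Unit = {
        // Per the new DatasetSuite test, the "message" field is expected to resolve to
        // BinaryType via the Serializable super hierarchy rather than to an empty bean struct.
        println(Encoders.bean(classOf[BigDecimalMessageWrapper[_]]).schema)
      }
    }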