sririshindra commented on code in PR #48252:
URL: https://github.com/apache/spark/pull/48252#discussion_r1844241460


##########
sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala:
##########
@@ -148,34 +163,180 @@ object JavaTypeInference {
      // TODO: we should only collect properties that have getter and setter. However, some tests
      //   pass in scala case class as java bean class which doesn't have getter and setter.
       val properties = getJavaBeanReadableProperties(c)
-      // add type variables from inheritance hierarchy of the class
-      val classTV = JavaTypeUtils.getTypeArguments(c, classOf[Object]).asScala.toMap ++
-        typeVariables
-      // Note that the fields are ordered by name.
-      val fields = properties.map { property =>
-        val readMethod = property.getReadMethod
-        val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c, classTV)
-        // The existence of `javax.annotation.Nonnull`, means this field is not nullable.
-        val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull])
-        EncoderField(
-          property.getName,
-          encoder,
-          encoder.nullable && !hasNonNull,
-          Metadata.empty,
-          Option(readMethod.getName),
-          Option(property.getWriteMethod).map(_.getName))
+
+      // if the properties is empty and this is not a top level enclosing class, then we
+      // should not consider class as bean, as otherwise it will be treated as empty schema
+      // and loose the data on deser.
+      if (properties.isEmpty && seenTypeSet.nonEmpty) {
+        findBestEncoder(Seq(c), seenTypeSet, typeVariables, None, serializableEncodersOnly = true)
+          .getOrElse(throw ExecutionErrors.cannotFindEncoderForTypeError(t.getTypeName))
+      } else {
+        // add type variables from inheritance hierarchy of the class
+        val parentClassesTypeMap =
+          JavaTypeUtils.getTypeArguments(c, classOf[Object]).asScala.toMap
+        val classTV = parentClassesTypeMap ++ typeVariables
+        // Note that the fields are ordered by name.
+        val fields = properties.map { property =>
+          val readMethod = property.getReadMethod
+          val methodReturnType = readMethod.getGenericReturnType
+          val encoder = encoderFor(methodReturnType, seenTypeSet + c, classTV)
+          // The existence of `javax.annotation.Nonnull`, means this field is not nullable.
+          val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull])
+          EncoderField(
+            property.getName,
+            encoder,
+            encoder.nullable && !hasNonNull,
+            Metadata.empty,
+            Option(readMethod.getName),
+            Option(property.getWriteMethod).map(_.getName))
+        }
+        // implies it cannot be assumed a BeanClass.
+        // Check if its super class or interface could be represented by an Encoder
+
+        JavaBeanEncoder(ClassTag(c), fields.toImmutableArraySeq)
       }
-      JavaBeanEncoder(ClassTag(c), fields.toImmutableArraySeq)
 
     case _ =>
       throw ExecutionErrors.cannotFindEncoderForTypeError(t.toString)
   }
 
+  private def createUDTEncoderUsingAnnotation(c: Class[_]): UDTEncoder[Any] = {
+    val udt = c
+      .getAnnotation(classOf[SQLUserDefinedType])
+      .udt()
+      .getConstructor()
+      .newInstance()
+      .asInstanceOf[UserDefinedType[Any]]
+    val udtClass = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()
+    UDTEncoder(udt, udtClass)
+  }
+
+  private def createUDTEncoderUsingRegistration(c: Class[_]): UDTEncoder[Any] = {
+    val udt = UDTRegistration
+      .getUDTFor(c.getName)
+      .get
+      .getConstructor()
+      .newInstance()
+      .asInstanceOf[UserDefinedType[Any]]
+    UDTEncoder(udt, udt.getClass)
+  }
+
   def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
     val beanInfo = Introspector.getBeanInfo(beanClass)
     beanInfo.getPropertyDescriptors
       .filterNot(_.getName == "class")
       .filterNot(_.getName == "declaringClass")
       .filter(_.getReadMethod != null)
   }
+
+  private def findBestEncoder(
+      typesToCheck: Seq[Class[_]],
+      seenTypeSet: Set[Class[_]],
+      typeVariables: Map[TypeVariable[_], Type],
+      baseClass: Option[Class[_]],
+      serializableEncodersOnly: Boolean = false): Option[AgnosticEncoder[_]] =
+    if (serializableEncodersOnly) {
+      val isClientConnect = clientConnectFlag.get
+      assert(typesToCheck.size == 1)
+      typesToCheck
+        .flatMap(c => {
+          if (!isClientConnect && classOf[KryoSerializable].isAssignableFrom(c)) {

Review Comment:
   Can we add a comment explaining why `isClientConnect` being true disqualifies 
the type from being encoded with the Kryo serializer? Is there a chance that 
will change in the future? If so, can we add a TODO here so that we can remove 
this condition if and when Kryo serialization becomes available with Spark 
Connect?
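
   Even a brief note at the condition would help future readers. Something along 
these lines, for example (the wording is only my suggestion and assumes the 
reason is Connect's current lack of Kryo support):
   
   ```scala
   // Kryo serialization is not supported when inference runs through the Connect
   // client, so skip the Kryo-based fallback in that case.
   // TODO: remove the isClientConnect guard once Kryo serialization is supported
   // with Spark Connect.
   if (!isClientConnect && classOf[KryoSerializable].isAssignableFrom(c)) {
   ```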



##########
sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala:
##########
@@ -148,34 +163,180 @@ object JavaTypeInference {
      // TODO: we should only collect properties that have getter and setter. However, some tests
      //   pass in scala case class as java bean class which doesn't have getter and setter.
       val properties = getJavaBeanReadableProperties(c)
-      // add type variables from inheritance hierarchy of the class
-      val classTV = JavaTypeUtils.getTypeArguments(c, classOf[Object]).asScala.toMap ++
-        typeVariables
-      // Note that the fields are ordered by name.
-      val fields = properties.map { property =>
-        val readMethod = property.getReadMethod
-        val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c, classTV)
-        // The existence of `javax.annotation.Nonnull`, means this field is not nullable.
-        val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull])
-        EncoderField(
-          property.getName,
-          encoder,
-          encoder.nullable && !hasNonNull,
-          Metadata.empty,
-          Option(readMethod.getName),
-          Option(property.getWriteMethod).map(_.getName))
+
+      // if the properties is empty and this is not a top level enclosing class, then we
+      // should not consider class as bean, as otherwise it will be treated as empty schema
+      // and loose the data on deser.
+      if (properties.isEmpty && seenTypeSet.nonEmpty) {
+        findBestEncoder(Seq(c), seenTypeSet, typeVariables, None, serializableEncodersOnly = true)

Review Comment:
   Could you please elaborate on why the `serializableEncodersOnly` flag is needed 
here? Maybe adding a comment for that would make it clearer?
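
   For example, even a one-liner above the call would help. Roughly along these 
lines (the exact wording is up to you; the rationale below is only my guess 
from the flag name and the surrounding comment):
   
   ```scala
   // This type exposes no readable bean properties, so fall back to a
   // serialization-based encoder rather than producing an empty bean schema.
   findBestEncoder(Seq(c), seenTypeSet, typeVariables, None, serializableEncodersOnly = true)
   ```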



##########
sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala:
##########
@@ -2909,6 +3016,7 @@ object KryoData {
 
 /** Used to test Java encoder. */
 class JavaData(val a: Int) extends Serializable {
+  def this() = this(0)

Review Comment:
   It is not clear to me why the method `def this()` needs to be explicitly 
defined. I don't see it being used anywhere. Can we remove it, or could you 
please explain why it is here?
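
   (My guess, for what it's worth: bean-style deserializers usually instantiate 
the target class reflectively through a no-arg constructor, along the lines of 
the hypothetical sketch below, so removing `def this()` might break 
deserialization of `JavaData`. If that is the reason, a short comment next to 
it would make that obvious; if not, it would be good to drop it.)
   
   ```scala
   // hypothetical illustration: reflective instantiation, as a deserializer might
   // perform it, throws NoSuchMethodException without a no-arg constructor
   val instance = classOf[JavaData].getDeclaredConstructor().newInstance()
   ```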



##########
sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala:
##########
@@ -2802,6 +2821,79 @@ class DatasetSuite extends QueryTest
       }
     }
   }
+
+  test("SPARK-49789 Bean class encoding with generic type implementing 
Serializable") {
+    // just create encoder
+    val enc = Encoders.bean(classOf[MessageWrapper[_]])
+    val data = Seq("test1", "test2").map(str => {
+      val msg = new MessageWrapper[String]()
+      msg.setMessage(str)
+      msg
+    })
+    validateParamBeanDataset(classOf[MessageWrapper[String]],
+      data, mutable.Buffer(data: _*),
+      StructType(Seq(StructField("message", BinaryType, true)))
+    )
+  }
+
+  test("SPARK-49789 Bean class encoding with  generic type indirectly 
extending" +
+    " Serializable class") {
+    // just create encoder
+    Encoders.bean(classOf[BigDecimalMessageWrapper[_]])
+    val data = Seq(2d, 8d).map(doub => {
+      val bean = new BigDecimalMessageWrapper[DerivedBigDecimalExtender]()
+      bean.setMessage(new DerivedBigDecimalExtender(doub))
+      bean
+    })
+    validateParamBeanDataset(
+      classOf[BigDecimalMessageWrapper[DerivedBigDecimalExtender]],
+      data, mutable.Buffer(data: _*),
+      StructType(Seq(StructField("message", BinaryType, true))))
+  }
+
+  test("SPARK-49789. test bean class with  generictype bound of UDTType") {
+    // just create encoder
+    UDTRegistration.register(classOf[TestUDT].getName, classOf[TestUDTType].getName)
+    val enc = Encoders.bean(classOf[UDTBean[_]])
+    val baseData = Seq((1, "a"), (2, "b"))
+    val data = baseData.map(tup => {
+      val bean = new UDTBean[TestUDT]()
+      bean.setMessage(new TestUDTImplSub(tup._1, tup._2))
+      bean
+    })
+    val expectedData = baseData.map(tup => {
+      val bean = new UDTBean[TestUDT]()
+      bean.setMessage(new TestUDTImpl(tup._1, tup._2))
+      bean
+    })
+    validateParamBeanDataset(
+      classOf[UDTBean[TestUDT]],
+      data, mutable.Buffer(expectedData: _*),
+      StructType(Seq(StructField("message", new TestUDTType(), true))))
+  }
+
+  private def validateParamBeanDataset[T](
+                                           classToEncode: Class[T],
+                                           data: Seq[T],
+                                           expectedData: mutable.Buffer[T],
+                                           expectedSchema: StructType): Unit = {
+

Review Comment:
   The indentation seems to be off. Could you please fix it?
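
   For example, something closer to the usual continuation-indent style (same 
signature as in the diff, just re-indented):
   
   ```scala
   private def validateParamBeanDataset[T](
       classToEncode: Class[T],
       data: Seq[T],
       expectedData: mutable.Buffer[T],
       expectedSchema: StructType): Unit = {
   ```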



##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala:
##########
@@ -137,8 +137,16 @@ class SparkSession private[sql] (
 
   /** @inheritdoc */
   def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = {
-    val encoder = JavaTypeInference.encoderFor(beanClass.asInstanceOf[Class[Any]])
-    createDataset(encoder, data.iterator().asScala).toDF()
+    JavaTypeInference.setSparkClientFlag()
+    val encoderTry = Try {
+      JavaTypeInference.encoderFor(beanClass.asInstanceOf[Class[Any]])
+    }
+    JavaTypeInference.unsetSparkClientFlag()

Review Comment:
   Can we add a comment explaining exactly why we are setting and unsetting the 
Spark client flag here?
   My understanding is that whether we use a regular SparkSession or a session 
created through Spark Connect makes a difference in what we can infer. But the 
exact reason for the difference between a regular SparkSession and a Spark 
Connect SparkSession, and what that difference actually is, is not clear to me. 
Can we please document that here or where the connectClient field is 
initialized?
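
   Even something brief at the call site would go a long way, e.g. (the wording 
is only a suggestion, since I am not sure of the exact reason myself):
   
   ```scala
   // Encoder inference runs in the Connect client JVM here; signal JavaTypeInference
   // so it can avoid fallback encoders that are <not supported over Spark Connect /
   // the actual reason goes here>.
   JavaTypeInference.setSparkClientFlag()
   ```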



##########
sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala:
##########
@@ -2945,3 +3053,132 @@ case class SaveModeArrayCase(modes: Array[SaveMode])
 
 case class K1(a: Long)
 case class K2(a: Long, b: Long)
+
+class MessageWrapper[T <: java.io.Serializable] extends java.io.Serializable {
+  private var message: T = _
+
+  def getMessage: T = message
+
+  def setMessage(message: T): Unit = {
+    this.message = message
+  }
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case m: MessageWrapper[_] => m.message == this.message
+
+      case _ => false
+    }
+  }
+
+  override def hashCode(): Int = this.message.hashCode()
+}
+
+class BigDecimalMessageWrapper[T <: BigDecimalExtender] extends java.io.Serializable {
+  private var message: T = _
+
+  def getMessage: T = message
+
+  def setMessage(message: T): Unit = {
+    this.message = message
+  }
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case m: BigDecimalMessageWrapper[_] => m.message == this.message
+      case _ => false
+    }
+  }
+
+  override def hashCode(): Int = this.message.hashCode()
+}
+
+class BigDecimalExtender(doub: Double) extends java.math.BigDecimal(doub) {
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case m: BigDecimalExtender => super.equals(m.asInstanceOf[java.math.BigDecimal])
+      case _ => false
+    }
+  }
+
+  override def hashCode(): Int = super.hashCode()
+}
+
+class DerivedBigDecimalExtender(doub: Double) extends BigDecimalExtender(doub) {
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case m: DerivedBigDecimalExtender => super.equals(m.asInstanceOf[BigDecimalExtender])
+      case _ => false
+    }
+  }
+
+  override def hashCode(): Int = super.hashCode()
+}
+
+trait TestUDT extends Serializable {
+  def intField: Int
+
+  def stringField: String
+}
+
+class TestUDTImpl(var intF: Int, var stringF: String) extends TestUDT {
+  def this() = this(0, "")
+
+  override def intField: Int = intF
+
+  override def stringField: String = stringF
+
+  override def hashCode(): Int = intF.hashCode() + stringF.hashCode
+
+  override def equals(obj: Any): Boolean = obj match {
+    case b: TestUDT => b.intField == this.intField && b.stringField == this.stringField
+
+    case _ => false
+  }
+}
+
+class TestUDTImplSub(var iF: Int, var sF: String) extends TestUDTImpl(iF, sF) {
+  def this() = this(0, "")

Review Comment:
   Ditto



##########
sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala:
##########
@@ -148,34 +163,180 @@ object JavaTypeInference {
      // TODO: we should only collect properties that have getter and setter. However, some tests
      //   pass in scala case class as java bean class which doesn't have getter and setter.
       val properties = getJavaBeanReadableProperties(c)
-      // add type variables from inheritance hierarchy of the class
-      val classTV = JavaTypeUtils.getTypeArguments(c, classOf[Object]).asScala.toMap ++
-        typeVariables
-      // Note that the fields are ordered by name.
-      val fields = properties.map { property =>
-        val readMethod = property.getReadMethod
-        val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c, classTV)
-        // The existence of `javax.annotation.Nonnull`, means this field is not nullable.
-        val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull])
-        EncoderField(
-          property.getName,
-          encoder,
-          encoder.nullable && !hasNonNull,
-          Metadata.empty,
-          Option(readMethod.getName),
-          Option(property.getWriteMethod).map(_.getName))
+
+      // if the properties is empty and this is not a top level enclosing class, then we
+      // should not consider class as bean, as otherwise it will be treated as empty schema
+      // and loose the data on deser.

Review Comment:
   nit: Can you spell out 'deser' as 'deserialization' for more clarity?
   
   Also, I did not quite understand why, if 'this is not a top level enclosing 
class', 'it will be treated as empty schema'. What does that mean exactly? 
Could you please elaborate on that?
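
   For my own understanding, is the scenario roughly the following (hypothetical 
example of mine, not taken from the PR)?
   
   ```scala
   // `Opaque` exposes no getters, so getJavaBeanReadableProperties finds nothing for it
   class Opaque(private val payload: Array[Byte]) extends java.io.Serializable
   
   class Outer extends java.io.Serializable {
     private var value: Opaque = _
     def getValue: Opaque = value
     def setValue(v: Opaque): Unit = { value = v }
   }
   ```
   
   i.e. previously the nested `Opaque` field would have been encoded as an empty 
struct and its contents silently lost on deserialization? If so, spelling that 
out in the comment would help.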



##########
sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala:
##########
@@ -2945,3 +3053,132 @@ case class SaveModeArrayCase(modes: Array[SaveMode])
 
 case class K1(a: Long)
 case class K2(a: Long, b: Long)
+
+class MessageWrapper[T <: java.io.Serializable] extends java.io.Serializable {
+  private var message: T = _
+
+  def getMessage: T = message
+
+  def setMessage(message: T): Unit = {
+    this.message = message
+  }
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case m: MessageWrapper[_] => m.message == this.message
+
+      case _ => false
+    }
+  }
+
+  override def hashCode(): Int = this.message.hashCode()
+}
+
+class BigDecimalMessageWrapper[T <: BigDecimalExtender] extends java.io.Serializable {
+  private var message: T = _
+
+  def getMessage: T = message
+
+  def setMessage(message: T): Unit = {
+    this.message = message
+  }
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case m: BigDecimalMessageWrapper[_] => m.message == this.message
+      case _ => false
+    }
+  }
+
+  override def hashCode(): Int = this.message.hashCode()
+}
+
+class BigDecimalExtender(doub: Double) extends java.math.BigDecimal(doub) {
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case m: BigDecimalExtender => super.equals(m.asInstanceOf[java.math.BigDecimal])
+      case _ => false
+    }
+  }
+
+  override def hashCode(): Int = super.hashCode()
+}
+
+class DerivedBigDecimalExtender(doub: Double) extends BigDecimalExtender(doub) {
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case m: DerivedBigDecimalExtender => super.equals(m.asInstanceOf[BigDecimalExtender])
+      case _ => false
+    }
+  }
+
+  override def hashCode(): Int = super.hashCode()
+}
+
+trait TestUDT extends Serializable {
+  def intField: Int
+
+  def stringField: String
+}
+
+class TestUDTImpl(var intF: Int, var stringF: String) extends TestUDT {
+  def this() = this(0, "")

Review Comment:
   Ditto



##########
sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala:
##########
@@ -148,34 +163,180 @@ object JavaTypeInference {
      // TODO: we should only collect properties that have getter and setter. However, some tests
      //   pass in scala case class as java bean class which doesn't have getter and setter.
       val properties = getJavaBeanReadableProperties(c)
-      // add type variables from inheritance hierarchy of the class
-      val classTV = JavaTypeUtils.getTypeArguments(c, classOf[Object]).asScala.toMap ++
-        typeVariables
-      // Note that the fields are ordered by name.
-      val fields = properties.map { property =>
-        val readMethod = property.getReadMethod
-        val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c, classTV)
-        // The existence of `javax.annotation.Nonnull`, means this field is not nullable.
-        val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull])
-        EncoderField(
-          property.getName,
-          encoder,
-          encoder.nullable && !hasNonNull,
-          Metadata.empty,
-          Option(readMethod.getName),
-          Option(property.getWriteMethod).map(_.getName))
+
+      // if the properties is empty and this is not a top level enclosing class, then we
+      // should not consider class as bean, as otherwise it will be treated as empty schema
+      // and loose the data on deser.
+      if (properties.isEmpty && seenTypeSet.nonEmpty) {
+        findBestEncoder(Seq(c), seenTypeSet, typeVariables, None, serializableEncodersOnly = true)
+          .getOrElse(throw ExecutionErrors.cannotFindEncoderForTypeError(t.getTypeName))
+      } else {
+        // add type variables from inheritance hierarchy of the class
+        val parentClassesTypeMap =
+          JavaTypeUtils.getTypeArguments(c, classOf[Object]).asScala.toMap
+        val classTV = parentClassesTypeMap ++ typeVariables
+        // Note that the fields are ordered by name.
+        val fields = properties.map { property =>
+          val readMethod = property.getReadMethod
+          val methodReturnType = readMethod.getGenericReturnType
+          val encoder = encoderFor(methodReturnType, seenTypeSet + c, classTV)
+          // The existence of `javax.annotation.Nonnull`, means this field is not nullable.
+          val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull])
+          EncoderField(
+            property.getName,
+            encoder,
+            encoder.nullable && !hasNonNull,
+            Metadata.empty,
+            Option(readMethod.getName),
+            Option(property.getWriteMethod).map(_.getName))
+        }
+        // implies it cannot be assumed a BeanClass.
+        // Check if its super class or interface could be represented by an Encoder

Review Comment:
   These two comments are a bit confusing.
   
   The first comment says "implies it cannot be assumed a BeanClass.", but then 
why is it still initialized with a JavaBeanEncoder?
   
   The second comment says "Check if its super class or interface could be 
represented by an Encoder". It is not clear to me where exactly that check is 
being done. Could you please elaborate on it for me? Thanks.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

