This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new ef5f7c73a [SEDONA-669] Fix timestamp_nz for GeoParquet reader and writer (#1661)
ef5f7c73a is described below

commit ef5f7c73aab0a66869522568d81647ff8dbde446
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Wed Oct 30 15:14:27 2024 +0800

    [SEDONA-669] Fix timestamp_nz for GeoParquet reader and writer (#1661)
    
    * Fix timestamp_nz for geoparquet format
    
    * Backport the fix to Spark 3.4
    
    * Overwrite the geoparquet output directory to avoid test failures when running with other tests
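    
    For context, a minimal round-trip sketch mirroring the test added by this patch (illustration only, not part of the commit); it assumes a Sedona-enabled SparkSession named `spark`, and the output path is a placeholder:
    
        import java.time.LocalDateTime
        import org.apache.spark.sql.Row
        import org.apache.spark.sql.functions.expr
        import org.apache.spark.sql.types.{IntegerType, StructField, StructType, TimestampNTZType}
    
        // Small DataFrame with a TimestampNTZ column plus a geometry column.
        val schema = StructType(Seq(
          StructField("id", IntegerType, nullable = false),
          StructField("ts", TimestampNTZType, nullable = false)))
        val rows = Seq(Row(1, LocalDateTime.of(2024, 10, 4, 12, 34, 56)))
        val df = spark
          .createDataFrame(spark.sparkContext.parallelize(rows), schema)
          .withColumn("geom", expr("ST_Point(id, id)"))
    
        // With this fix the TimestampNTZ column survives a geoparquet round trip.
        df.write.format("geoparquet").mode("overwrite").save("/tmp/geoparquet_ntz")
        val back = spark.read.format("geoparquet").load("/tmp/geoparquet_ntz")
        assert(back.schema("ts").dataType == TimestampNTZType)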
---
 .../datasources/parquet/GeoParquetFileFormat.scala |  3 ++
 .../parquet/GeoParquetRowConverter.scala           | 32 ++++++++++++++++++++++
 .../parquet/GeoParquetSchemaConverter.scala        | 12 ++++++++
 .../parquet/GeoParquetWriteSupport.scala           |  5 ++++
 .../org/apache/sedona/sql/geoparquetIOTests.scala  | 32 ++++++++++++++++++++--
 .../parquet/GeoParquetRowConverter.scala           | 32 ++++++++++++++++++++++
 .../parquet/GeoParquetSchemaConverter.scala        | 12 ++++++++
 .../parquet/GeoParquetWriteSupport.scala           |  5 ++++
 .../org/apache/sedona/sql/geoparquetIOTests.scala  | 29 ++++++++++++++++++++
 9 files changed, 159 insertions(+), 3 deletions(-)

diff --git a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
index 325a72098..cdb9834b8 100644
--- a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
+++ b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
@@ -202,6 +202,9 @@ class GeoParquetFileFormat(val spatialFilter: Option[GeoParquetSpatialFilter])
     hadoopConf.setBoolean(
       SQLConf.PARQUET_INT96_AS_TIMESTAMP.key,
       sparkSession.sessionState.conf.isParquetINT96AsTimestamp)
+    hadoopConf.setBoolean(
+      SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key,
+      sparkSession.sessionState.conf.parquetInferTimestampNTZEnabled)
 
     val broadcastedHadoopConf =
       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
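
The hunk above propagates Spark's Parquet TimestampNTZ inference setting to the GeoParquet reader. A hedged usage sketch (illustration only; assumes a SparkSession `spark` and a placeholder path) of toggling that setting per session:

    import org.apache.spark.sql.internal.SQLConf

    // Disable NTZ inference for this session; INT64 timestamp columns written
    // with isAdjustedToUTC=false are then expected to come back as TimestampType
    // rather than TimestampNTZType.
    spark.conf.set(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key, "false")
    val df = spark.read.format("geoparquet").load("/tmp/geoparquet_ntz") // placeholder path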
diff --git a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala
index 3e04a0a29..c50172874 100644
--- a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala
+++ b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.datasources.parquet
 
 import org.apache.parquet.column.Dictionary
 import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter}
+import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit
+import org.apache.parquet.schema.LogicalTypeAnnotation.TimestampLogicalTypeAnnotation
 import org.apache.parquet.schema.OriginalType.LIST
 import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
 import org.apache.parquet.schema.{GroupType, OriginalType, Type}
@@ -312,6 +314,25 @@ private[parquet] class GeoParquetRowConverter(
           }
         }
 
+      case TimestampNTZType
+          if canReadAsTimestampNTZ(parquetType) &&
+            parquetType.getLogicalTypeAnnotation
+              .asInstanceOf[TimestampLogicalTypeAnnotation]
+              .getUnit == TimeUnit.MICROS =>
+        new ParquetPrimitiveConverter(updater)
+
+      case TimestampNTZType
+          if canReadAsTimestampNTZ(parquetType) &&
+            parquetType.getLogicalTypeAnnotation
+              .asInstanceOf[TimestampLogicalTypeAnnotation]
+              .getUnit == TimeUnit.MILLIS =>
+        new ParquetPrimitiveConverter(updater) {
+          override def addLong(value: Long): Unit = {
+            val micros = DateTimeUtils.millisToMicros(value)
+            updater.setLong(micros)
+          }
+        }
+
       case DateType =>
         new ParquetPrimitiveConverter(updater) {
           override def addInt(value: Int): Unit = {
@@ -379,6 +400,17 @@ private[parquet] class GeoParquetRowConverter(
     }
   }
 
+  // Only INT64 column with Timestamp logical annotation `isAdjustedToUTC=false`
+  // can be read as Spark's TimestampNTZ type. This is to avoid mistakes in reading the timestamp
+  // values.
+  private def canReadAsTimestampNTZ(parquetType: Type): Boolean =
+    schemaConverter.isTimestampNTZEnabled() &&
+      parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 &&
+      parquetType.getLogicalTypeAnnotation.isInstanceOf[TimestampLogicalTypeAnnotation] &&
+      !parquetType.getLogicalTypeAnnotation
+        .asInstanceOf[TimestampLogicalTypeAnnotation]
+        .isAdjustedToUTC
+
   /**
    * Parquet converter for strings. A dictionary is used to minimize string decoding cost.
    */
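
Illustration only (not part of the patch): the Parquet column shapes the new canReadAsTimestampNTZ guard accepts and rejects, built with parquet-mr's schema builder:

    import org.apache.parquet.schema.{LogicalTypeAnnotation, Types}
    import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit
    import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64
    import org.apache.parquet.schema.Type.Repetition

    // INT64 annotated as timestamp(isAdjustedToUTC = false, MICROS):
    // accepted by the guard and mapped to Spark's TimestampNTZType.
    val ntzMicros = Types.primitive(INT64, Repetition.OPTIONAL)
      .as(LogicalTypeAnnotation.timestampType(false, TimeUnit.MICROS))
      .named("ts_ntz")

    // INT64 annotated as timestamp(isAdjustedToUTC = true, MICROS):
    // rejected, because the value denotes an instant, not a local timestamp.
    val utcMicros = Types.primitive(INT64, Repetition.OPTIONAL)
      .as(LogicalTypeAnnotation.timestampType(true, TimeUnit.MICROS))
      .named("ts_utc")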
diff --git a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala
index eab20875a..10dd9e01d 100644
--- a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala
+++ b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala
@@ -42,6 +42,8 @@ import org.apache.spark.sql.types._
 *   Whether unannotated BINARY fields should be assumed to be Spark SQL [[StringType]] fields.
  * @param assumeInt96IsTimestamp
 *   Whether unannotated INT96 fields should be assumed to be Spark SQL [[TimestampType]] fields.
+ * @param inferTimestampNTZ
+ *   Whether TimestampNTZType type is enabled.
  * @param parameters
  *   Options for reading GeoParquet files.
  */
@@ -49,6 +51,7 @@ class GeoParquetToSparkSchemaConverter(
     keyValueMetaData: java.util.Map[String, String],
     assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get,
     assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get,
+    inferTimestampNTZ: Boolean = SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.defaultValue.get,
     parameters: Map[String, String]) {
 
   private val geoParquetMetaData: GeoParquetMetaData =
@@ -61,6 +64,7 @@ class GeoParquetToSparkSchemaConverter(
     keyValueMetaData = keyValueMetaData,
     assumeBinaryIsString = conf.isParquetBinaryAsString,
     assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp,
+    inferTimestampNTZ = conf.parquetInferTimestampNTZEnabled,
     parameters = parameters)
 
   def this(
@@ -70,8 +74,16 @@ class GeoParquetToSparkSchemaConverter(
     keyValueMetaData = keyValueMetaData,
     assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean,
     assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean,
+    inferTimestampNTZ = conf.get(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key).toBoolean,
     parameters = parameters)
 
+  /**
+   * Returns true if TIMESTAMP_NTZ type is enabled in this ParquetToSparkSchemaConverter.
+   */
+  def isTimestampNTZEnabled(): Boolean = {
+    inferTimestampNTZ
+  }
+
   /**
    * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]].
    */
diff --git a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala
index 3a6a89773..9d6b36740 100644
--- a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala
+++ b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala
@@ -308,6 +308,11 @@ class GeoParquetWriteSupport extends WriteSupport[InternalRow] with Logging {
               recordConsumer.addLong(millis)
         }
 
+      case TimestampNTZType =>
+        // For TimestampNTZType column, Spark always output as INT64 with Timestamp annotation in
+        // MICROS time unit.
+        (row: SpecializedGetters, ordinal: Int) => recordConsumer.addLong(row.getLong(ordinal))
+
       case BinaryType =>
         (row: SpecializedGetters, ordinal: Int) =>
           recordConsumer.addBinary(Binary.fromReusedByteArray(row.getBinary(ordinal)))
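
Note on the write path above (a sketch, not part of the patch): Spark keeps TimestampNTZ values as microseconds of the local date-time since the epoch, and that Long is what recordConsumer.addLong receives; roughly:

    import java.time.{LocalDateTime, ZoneOffset}

    // Approximate encoding of a TimestampNTZ value as the micros Long stored
    // internally by Spark (assumption: local time counted against a fixed offset).
    def localDateTimeToMicros(ldt: LocalDateTime): Long =
      ldt.toEpochSecond(ZoneOffset.UTC) * 1000000L + ldt.getNano / 1000L

    val micros = localDateTimeToMicros(LocalDateTime.of(2024, 10, 4, 12, 34, 56))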
diff --git a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala
index ccfd560c8..f5bd8b486 100644
--- a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala
+++ b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala
@@ -32,15 +32,15 @@ import org.apache.spark.sql.functions.{col, expr}
 import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
 import org.apache.spark.sql.sedona_sql.expressions.st_constructors.{ST_Point, ST_PolygonFromEnvelope}
 import org.apache.spark.sql.sedona_sql.expressions.st_predicates.ST_Intersects
-import org.apache.spark.sql.types.IntegerType
-import org.apache.spark.sql.types.StructField
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType, TimestampNTZType}
 import org.json4s.jackson.parseJson
 import org.locationtech.jts.geom.Geometry
 import org.locationtech.jts.io.WKTReader
 import org.scalatest.BeforeAndAfterAll
 
 import java.io.File
+import java.time.LocalDateTime
+import java.time.format.DateTimeFormatter
 import java.util.Collections
 import java.util.concurrent.atomic.AtomicLong
 import scala.collection.JavaConverters._
@@ -732,6 +732,32 @@ class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll {
     }
   }
 
+  describe("Spark types tests") {
+    it("should support timestamp_ntz") {
+      // Write geoparquet files with a TimestampNTZ column
+      val schema = StructType(
+        Seq(
+          StructField("id", IntegerType, nullable = false),
+          StructField("timestamp_ntz", TimestampNTZType, nullable = false)))
+      val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")
+      val data = Seq(
+        Row(1, LocalDateTime.parse("2024-10-04 12:34:56", formatter)),
+        Row(2, LocalDateTime.parse("2024-10-04 15:30:00", formatter)))
+      val df = sparkSession
+        .createDataFrame(sparkSession.sparkContext.parallelize(data), schema)
+        .withColumn("geom", expr("ST_Point(id, id)"))
+      df.write.format("geoparquet").mode("overwrite").save(geoparquetoutputlocation)
+
+      // Read it back
+      val df2 =
+        sparkSession.read.format("geoparquet").load(geoparquetoutputlocation).sort(col("id"))
+      assert(df2.schema.fields(1).dataType == TimestampNTZType)
+      val data1 = df.sort(col("id")).collect()
+      val data2 = df2.collect()
+      assert(data1 sameElements data2)
+    }
+  }
+
   def validateGeoParquetMetadata(path: String)(body: org.json4s.JValue => Unit): Unit = {
     val parquetFiles = new File(path).listFiles().filter(_.getName.endsWith(".parquet"))
     parquetFiles.foreach { filePath =>
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala
index 07fc77e2c..44c65ab3e 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.datasources.parquet
 
 import org.apache.parquet.column.Dictionary
 import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter}
+import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit
+import org.apache.parquet.schema.LogicalTypeAnnotation.TimestampLogicalTypeAnnotation
 import org.apache.parquet.schema.OriginalType.LIST
 import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
 import org.apache.parquet.schema.{GroupType, OriginalType, Type}
@@ -315,6 +317,25 @@ private[parquet] class GeoParquetRowConverter(
           }
         }
 
+      case TimestampNTZType
+          if canReadAsTimestampNTZ(parquetType) &&
+            parquetType.getLogicalTypeAnnotation
+              .asInstanceOf[TimestampLogicalTypeAnnotation]
+              .getUnit == TimeUnit.MICROS =>
+        new ParquetPrimitiveConverter(updater)
+
+      case TimestampNTZType
+          if canReadAsTimestampNTZ(parquetType) &&
+            parquetType.getLogicalTypeAnnotation
+              .asInstanceOf[TimestampLogicalTypeAnnotation]
+              .getUnit == TimeUnit.MILLIS =>
+        new ParquetPrimitiveConverter(updater) {
+          override def addLong(value: Long): Unit = {
+            val micros = DateTimeUtils.millisToMicros(value)
+            updater.setLong(micros)
+          }
+        }
+
       case DateType =>
         new ParquetPrimitiveConverter(updater) {
           override def addInt(value: Int): Unit = {
@@ -382,6 +403,17 @@ private[parquet] class GeoParquetRowConverter(
     }
   }
 
+  // Only INT64 column with Timestamp logical annotation `isAdjustedToUTC=false`
+  // can be read as Spark's TimestampNTZ type. This is to avoid mistakes in reading the timestamp
+  // values.
+  private def canReadAsTimestampNTZ(parquetType: Type): Boolean =
+    schemaConverter.isTimestampNTZEnabled() &&
+      parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 &&
+      parquetType.getLogicalTypeAnnotation.isInstanceOf[TimestampLogicalTypeAnnotation] &&
+      !parquetType.getLogicalTypeAnnotation
+        .asInstanceOf[TimestampLogicalTypeAnnotation]
+        .isAdjustedToUTC
+
   /**
    * Parquet converter for strings. A dictionary is used to minimize string decoding cost.
    */
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala
index eab20875a..10dd9e01d 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala
@@ -42,6 +42,8 @@ import org.apache.spark.sql.types._
 *   Whether unannotated BINARY fields should be assumed to be Spark SQL [[StringType]] fields.
  * @param assumeInt96IsTimestamp
 *   Whether unannotated INT96 fields should be assumed to be Spark SQL [[TimestampType]] fields.
+ * @param inferTimestampNTZ
+ *   Whether TimestampNTZType type is enabled.
  * @param parameters
  *   Options for reading GeoParquet files.
  */
@@ -49,6 +51,7 @@ class GeoParquetToSparkSchemaConverter(
     keyValueMetaData: java.util.Map[String, String],
     assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get,
     assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get,
+    inferTimestampNTZ: Boolean = SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.defaultValue.get,
     parameters: Map[String, String]) {
 
   private val geoParquetMetaData: GeoParquetMetaData =
@@ -61,6 +64,7 @@ class GeoParquetToSparkSchemaConverter(
     keyValueMetaData = keyValueMetaData,
     assumeBinaryIsString = conf.isParquetBinaryAsString,
     assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp,
+    inferTimestampNTZ = conf.parquetInferTimestampNTZEnabled,
     parameters = parameters)
 
   def this(
@@ -70,8 +74,16 @@ class GeoParquetToSparkSchemaConverter(
     keyValueMetaData = keyValueMetaData,
     assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean,
     assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean,
+    inferTimestampNTZ = conf.get(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key).toBoolean,
     parameters = parameters)
 
+  /**
+   * Returns true if TIMESTAMP_NTZ type is enabled in this ParquetToSparkSchemaConverter.
+   */
+  def isTimestampNTZEnabled(): Boolean = {
+    inferTimestampNTZ
+  }
+
   /**
    * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]].
    */
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala
index fb5c92163..18f9f4f5c 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala
@@ -309,6 +309,11 @@ class GeoParquetWriteSupport extends WriteSupport[InternalRow] with Logging {
               recordConsumer.addLong(millis)
         }
 
+      case TimestampNTZType =>
+        // For TimestampNTZType column, Spark always output as INT64 with Timestamp annotation in
+        // MICROS time unit.
+        (row: SpecializedGetters, ordinal: Int) => recordConsumer.addLong(row.getLong(ordinal))
+
       case BinaryType =>
         (row: SpecializedGetters, ordinal: Int) =>
           recordConsumer.addBinary(Binary.fromReusedByteArray(row.getBinary(ordinal)))
diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala
index ccfd560c8..a6e74730a 100644
--- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala
+++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.sedona_sql.expressions.st_predicates.ST_Intersects
 import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.sql.types.StructField
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.TimestampNTZType
 import org.json4s.jackson.parseJson
 import org.locationtech.jts.geom.Geometry
 import org.locationtech.jts.io.WKTReader
@@ -43,6 +44,8 @@ import org.scalatest.BeforeAndAfterAll
 import java.io.File
 import java.util.Collections
 import java.util.concurrent.atomic.AtomicLong
+import java.time.LocalDateTime
+import java.time.format.DateTimeFormatter
 import scala.collection.JavaConverters._
 
 class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll {
@@ -732,6 +735,32 @@ class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll {
     }
   }
 
+  describe("Spark types tests") {
+    it("should support timestamp_ntz") {
+      // Write geoparquet files with a TimestampNTZ column
+      val schema = StructType(
+        Seq(
+          StructField("id", IntegerType, nullable = false),
+          StructField("timestamp_ntz", TimestampNTZType, nullable = false)))
+      val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")
+      val data = Seq(
+        Row(1, LocalDateTime.parse("2024-10-04 12:34:56", formatter)),
+        Row(2, LocalDateTime.parse("2024-10-04 15:30:00", formatter)))
+      val df = sparkSession
+        .createDataFrame(sparkSession.sparkContext.parallelize(data), schema)
+        .withColumn("geom", expr("ST_Point(id, id)"))
+      df.write.format("geoparquet").mode("overwrite").save(geoparquetoutputlocation)
+
+      // Read it back
+      val df2 =
+        sparkSession.read.format("geoparquet").load(geoparquetoutputlocation).sort(col("id"))
+      assert(df2.schema.fields(1).dataType == TimestampNTZType)
+      val data1 = df.sort(col("id")).collect()
+      val data2 = df2.collect()
+      assert(data1 sameElements data2)
+    }
+  }
+
   def validateGeoParquetMetadata(path: String)(body: org.json4s.JValue => Unit): Unit = {
     val parquetFiles = new File(path).listFiles().filter(_.getName.endsWith(".parquet"))
     parquetFiles.foreach { filePath =>
