## Environment and packages

In [10]:
import os, sys
import findspark
import time
import faker

__References__:
1. https://www.geeksforgeeks.org/python-faker-library/

ensure that SPARK_HOME is set otherwise write the path to SPARK_HOME as `findspark.init("SPARK_HOME_PATH")`

In [4]:
findspark.init()

In [5]:
import pyspark
from pyspark.sql.functions import *
from pyspark.sql.types import *

In [6]:
try:
    sparkSession.stop()
except:
    print("no SPARK session to stop")
finally:
    sparkSession = pyspark.sql.SparkSession.builder.appName("test").getOrCreate()

no SPARK session to stop


In [7]:
sparkSession

In [8]:
sys.version

'3.6.7 |Anaconda custom (64-bit)| (default, Oct 23 2018, 14:01:38) \n[GCC 4.2.1 Compatible Clang 4.0.1 (tags/RELEASE_401/final)]'

## Creating Sample Data

In [24]:
from pyspark.sql import types as T
from pyspark.sql import functions as F

In [35]:
from faker import Faker 
fake = Faker() 

In [36]:
sparkSession.udf.register("latitude", lambda x: fake.latitude(), T.DecimalType(15,7))
sparkSession.udf.register("longitude", lambda x: fake.longitude(), T.DecimalType(15,7))

<function __main__.<lambda>(x)>

In [39]:
%%time

sparkSession.range(4).selectExpr("id", "latitude(id) as latitude", "longitude(id) as longitude").show()

+---+----------+-----------+
| id|  latitude|  longitude|
+---+----------+-----------+
|  0|57.0346335|114.0692670|
|  1|57.0346335|114.0692670|
|  2|57.0346335|114.0692670|
|  3|57.0346335|114.0692670|
+---+----------+-----------+

CPU times: user 3.8 ms, sys: 1.73 ms, total: 5.54 ms
Wall time: 278 ms


In [44]:
%%time

sparkSession.range(10000000).selectExpr("id", "latitude(id) as latitude", "longitude(id) as longitude") \
            .write.mode("overwrite").parquet("~/Downloads/testsourcedata/")

CPU times: user 5.51 ms, sys: 2.97 ms, total: 8.48 ms
Wall time: 48 s


## Creating the actual function

__Pandas UDF__

In [55]:
@F.pandas_udf(T.StringType(), F.PandasUDFType.SCALAR)
def generate_mgrs_series(lat_lon_str, level):

    import mgrs
    m = mgrs.MGRS()

    precision_level = 0
    levelval = level[0]

    if levelval == 1000:
        precision_level = 2
    if levelval == 100:
        precision_level = 3

    def convert(ll_str):
            lat, lon = ll_str.split('_')

            return m.toMGRS(lat, lon, 
               MGRSPrecision = precision_level)

    return lat_lon_str.apply(lambda x: convert(x))

__Regular UDF__:

In [104]:
def generate_mgrs_series_udf(lat_lon_str, level):

    import mgrs
    m = mgrs.MGRS()

    precision_level = 0
    levelval = level[0]

    if levelval == 1000:
        precision_level = 2
    if levelval == 100:
        precision_level = 3

    def convert(ll_str):
            lat, lon = ll_str.split('_')

            return m.toMGRS(lat, lon, 
               MGRSPrecision = precision_level)

    #return lat_lon_str.apply(lambda x: convert(x))
    return convert(lat_lon_str).decode()

In [107]:
sparkSession.udf.register("mgrs", generate_mgrs_series_udf, T.StringType())

<function __main__.generate_mgrs_series_udf(lat_lon_str, level)>

## Testing the function

__Pandas UDF__

In [72]:
sparkSession.read.parquet("~/Downloads/testsourcedata/") \
                .withColumn("mgrs", generate_mgrs_series(concat("latitude", lit('_'), "longitude"), F.lit(1000))) \
                .show(4, truncate=False)

+-------+-----------+------------+---------+
|id     |latitude   |longitude   |mgrs     |
+-------+-----------+------------+---------+
|1250000|57.0346335 |114.0692670 |50VLJ2225|
|1250001|-40.3437270|-80.6874540 |17GNR2634|
|1250002|-59.9482215|-119.8964430|11ELP3850|
|1250003|-11.3446500|-22.6893000 |27LUH1545|
+-------+-----------+------------+---------+
only showing top 4 rows



__Regular UDF__

In [109]:
sparkSession.read.parquet("~/Downloads/testsourcedata/") \
                .selectExpr("*", "mgrs(concat(latitude, '_', longitude), array(1000)) as mgrs") \
                .show(4, truncate=False)

+-------+-----------+------------+---------+
|id     |latitude   |longitude   |mgrs     |
+-------+-----------+------------+---------+
|1250000|57.0346335 |114.0692670 |50VLJ2225|
|1250001|-40.3437270|-80.6874540 |17GNR2634|
|1250002|-59.9482215|-119.8964430|11ELP3850|
|1250003|-11.3446500|-22.6893000 |27LUH1545|
+-------+-----------+------------+---------+
only showing top 4 rows



testing the query plan

In [120]:
sparkSession.read.parquet("~/Downloads/testsourcedata/") \
                .withColumn("mgrs", generate_mgrs_series(concat("latitude", lit('_'), "longitude"), F.lit(1000))) \
                .explain()

== Physical Plan ==
*(2) Project [id#956L, latitude#957, longitude#958, pythonUDF0#968 AS mgrs#963]
+- ArrowEvalPython [generate_mgrs_series(concat(cast(latitude#957 as string), _, cast(longitude#958 as string)), 1000)], [id#956L, latitude#957, longitude#958, pythonUDF0#968]
   +- *(1) FileScan parquet [id#956L,latitude#957,longitude#958] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/gouravsengupta/Downloads/~/Downloads/testsourcedata], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:bigint,latitude:decimal(15,7),longitude:decimal(15,7)>


In [121]:
sparkSession.read.parquet("~/Downloads/testsourcedata/") \
                .selectExpr("*", "mgrs(concat(latitude, '_', longitude), array(1000)) as mgrs") \
                .explain()

== Physical Plan ==
*(2) Project [id#969L, latitude#970, longitude#971, pythonUDF0#981 AS mgrs#975]
+- BatchEvalPython [mgrs(concat(cast(latitude#970 as string), _, cast(longitude#971 as string)), [1000])], [id#969L, latitude#970, longitude#971, pythonUDF0#981]
   +- *(1) FileScan parquet [id#969L,latitude#970,longitude#971] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/gouravsengupta/Downloads/~/Downloads/testsourcedata], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:bigint,latitude:decimal(15,7),longitude:decimal(15,7)>


testing out the count time

In [117]:
%%time 

sparkSession.read.parquet("~/Downloads/testsourcedata/") \
                .withColumn("mgrs", generate_mgrs_series(concat("latitude", lit('_'), "longitude"), F.lit(1000))) \
                .count()

CPU times: user 4.26 ms, sys: 1.97 ms, total: 6.22 ms
Wall time: 202 ms


10000000

In [118]:
%%time

sparkSession.read.parquet("~/Downloads/testsourcedata/") \
                .selectExpr("*", "mgrs(concat(latitude, '_', longitude), array(1000)) as mgrs") \
                .count()

CPU times: user 5.9 ms, sys: 3.57 ms, total: 9.46 ms
Wall time: 160 ms


10000000

testing out the write time

In [128]:
%%time 

sparkSession.read.parquet("~/Downloads/testsourcedata/") \
                .withColumn("mgrs", generate_mgrs_series(concat("latitude", lit('_'), "longitude"), F.lit(1000))) \
                .write.mode("overwrite").parquet("~/Downloads/testtargetdata_pandas_udf")

CPU times: user 6.21 ms, sys: 2.47 ms, total: 8.68 ms
Wall time: 23.3 s


In [127]:
%%time 

sparkSession.read.parquet("~/Downloads/testsourcedata/") \
                .selectExpr("*", "mgrs(concat(latitude, '_', longitude), array(1000)) as mgrs") \
                .write.mode("overwrite").parquet("~/Downloads/testtargetdata_udf")

CPU times: user 5.21 ms, sys: 2.69 ms, total: 7.9 ms
Wall time: 34.6 s
