Hi Team, I’m facing an issue with the execution order in the PySpark code snippet below. I’m not certain whether it’s caused by lazy evaluation, Spark plan optimization, or something else.
*Issue:* For the same data and scenario, one of the three final views sometimes returns no data on a given run. This appears to be caused by a change in execution order between runs, which in turn changes the final result. In the final steps we have three different DataFrames derived from the same base DataFrame, and I'm not sure whether that is the cause. I tried using persist to hold the intermediate results and avoid potential lazy-evaluation issues (a small sketch of what I mean is included after the join/filter cell below), but the problem still persists. Could you please review this and suggest how to guarantee a consistent execution order and consistent results?

*Note:* Please let me know if you need any clarification or additional information.

Code:

from pyspark.sql.functions import col, lower, upper, coalesce, lit, when, row_number, max as spark_max
from pyspark.sql.window import Window

source_df = spark.sql(f"""SELECT * FROM {catalog_name}.{db_name}.{table_name}""")  # Sample source query
data_df = source_df.persist()

# COMMAND ----------

type2_columns = [ ]
data_df = updateRowHashValues(data_df, type2_columns, primary_key)

# COMMAND ----------

target_df = spark.table(f"{catalog_name}.{db_name}.{table_name}").filter(col("IsCurrent") == True)
target_col_list = target_df.columns

source_with_target_df = data_df.alias("src").join(target_df.alias("tgt"), on="PatientId", how="left")
joined_df = source_with_target_df.persist()

filtered_joined_df = (
    joined_df.filter(col("tgt.DimTypeIIKey").isNull())
    .select([col("src." + c).alias(c) for c in data_df.columns])
    .drop("SourceTimeStamp")
)

new_records_df = filtered_joined_df.filter(
    lower(col("ModifiedBy")).contains("migration")
    | lower(col("ModifiedBy")).contains("migrated")
    | lower(col("CreatedBy")).contains("migration")
    | lower(col("CreatedBy")).contains("migrated")
)

new_records_source_df = (
    filtered_joined_df.alias("jdf")
    .join(new_records_df.alias("nrd"), col("jdf.PatientId") == col("nrd.PatientId"), how="left_anti")
    .select([col("jdf." + c).alias(c) for c in filtered_joined_df.columns])
    .drop("SourceTimeStamp")
)
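For context, my understanding is that persist() on its own is lazy: it only marks the DataFrame for caching, and nothing is materialized until an action runs. This is a minimal sketch of how the joined result could be pinned eagerly right after the join, using the DataFrames above (count() here is just a convenient example of an action):

```python
# persist() only marks joined_df for caching; the cache is populated the first
# time an action runs. Triggering an action immediately afterwards materializes
# the intermediate result before the downstream branches reuse it.
joined_df = source_with_target_df.persist()
joined_df.count()  # eager action; fills the cache at this point in the notebook
```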
# COMMAND ----------

if is_check_banfield_data and (not new_records_df.isEmpty()):  # is_check_banfield_data may change depending on the environment
    patient_df = spark.table(f"{banfield_catalog_name}.bfdw.patient").selectExpr(
        "patientid as patient_patientid",
        "createdate as patient_createdate",
        "changedate as patient_changedate",
        "fw_modifiedts as patient_fw_modifiedts",
    ).withColumn("patient_patientid", upper("patient_patientid"))
    banfield_patient_df = patient_df.persist()

    window_spec = Window.partitionBy("patient_patientid").orderBy(
        col("patient_changedate").desc(), col("patient_fw_modifiedts").desc()
    )
    banfield_patient_df = banfield_patient_df.withColumn("row_num", row_number().over(window_spec))
    banfield_patient_df = banfield_patient_df.filter(col("row_num") == 1).drop("row_num")

    new_records_df = new_records_df.alias("new").join(
        banfield_patient_df.alias("pat"),
        col("new.PatientId") == col("pat.patient_patientid"),
        how="left",
    )
    new_records_df = new_records_df.withColumn(
        "BusinessEffectiveStartDate",
        coalesce(col("pat.patient_createdate"), col("BusinessEffectiveStartDate")),
    ).select("new.*", "BusinessEffectiveStartDate")
    incoming_new_df = new_records_df.persist()
else:
    is_check_banfield_data = False

# COMMAND ----------

type2_changes = joined_df.filter(
    (col("src.RowHashType2") != col("tgt.RowHashType2")) & col("tgt.DimTypeIIKey").isNotNull()
).select("src.*")

type1_changes = joined_df.filter(
    (col("src.RowHashType2") == col("tgt.RowHashType2"))
    & (col("src.RowHashType1") != col("tgt.RowHashType1"))
).select("src.*")

expired_records = (
    joined_df.filter(
        (col("src.RowHashType2") != col("tgt.RowHashType2")) & col("tgt.DimTypeIIKey").isNotNull()
    )
    .select(col("tgt.*"), col("src.BusinessEffectiveStartDate").alias("NewBusinessEffectiveEndDate"))
    .withColumn("BusinessEffectiveEndDate", col("NewBusinessEffectiveEndDate"))
    .withColumn("IsCurrent", lit(False))
    .drop("NewBusinessEffectiveEndDate")
)

# COMMAND ----------

max_key = spark.table(f"{catalog_name}.{db_name}.{table_name}").agg(spark_max(surrogate_key)).collect()[0][0] or 0
starting_key = max_key + 1

target_col_list = list(
    set(target_col_list)
    - {"DwCreatedYear", "DwCreatedMonth", "DwCreatedDay", "IsDataMigrated", "IsCurrent"}
    - set(surrogate_key.split(","))
)

type2_changes = type2_changes.select(target_col_list)

if is_check_banfield_data:
    type2_changes = type2_changes.unionByName(incoming_new_df)
    patient_df.unpersist()
    new_records_df.unpersist()

type2_changes = type2_changes.unionByName(new_records_source_df)

type2_changes = type2_changes.withColumn(
    "IsDataMigrated",
    when(
        lower(col("ModifiedBy")).contains("migration")
        | lower(col("ModifiedBy")).contains("migrated")
        | lower(col("CreatedBy")).contains("migration")
        | lower(col("CreatedBy")).contains("migrated"),
        True,
    ).otherwise(False),
)

type2_changes = (
    type2_changes.withColumn("BusinessEffectiveEndDate", lit("9999-12-31").cast("date"))
    .withColumn("IsCurrent", lit(True))
    .withColumn(surrogate_key, row_number().over(Window.orderBy("PatientId")) + starting_key - 1)
)

type1_updates_columns = list(set(type1_changes.columns) - set(type2_columns))
type1_updates = type1_changes.select(*type1_updates_columns)

# These are the three final temp views used by the MERGE/INSERT statements below.
# In some runs, one of these views returns no data.
expired_records.createOrReplaceTempView(f"temp_updates_type2_expired_{db_name}_{table_name}")
type2_changes.createOrReplaceTempView(f"temp_inserts_new_records_{db_name}_{table_name}")
type1_updates.createOrReplaceTempView(f"temp_updates_type1_{db_name}_{table_name}")
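A quick sanity check I can add at this point (sketch only, using the view names defined above): since temp views are evaluated lazily each time they are queried, counting them here shows their contents before any MERGE or INSERT touches the target table.

```python
# Sketch of a sanity check at this point in the notebook: temp views re-evaluate
# their underlying plan whenever they are queried, so counting them here reflects
# the data before the MERGE/INSERT statements below modify the target table.
for view_name in [
    f"temp_updates_type2_expired_{db_name}_{table_name}",
    f"temp_inserts_new_records_{db_name}_{table_name}",
    f"temp_updates_type1_{db_name}_{table_name}",
]:
    print(view_name, spark.table(view_name).count())
```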
# COMMAND ----------

# DBTITLE 1,Type1 column changes update
existing_records_update = spark.sql(f"""
    MERGE INTO {catalog_name}.{db_name}.{table_name} AS tgt
    USING temp_updates_type1_{db_name}_{table_name} AS src
    ON tgt.PatientId = src.PatientId AND tgt.IsCurrent = true
    WHEN MATCHED THEN UPDATE SET
        tgt.col1 = src.col1,
        tgt.col2 = src.col2,
        tgt.col3 = src.col3,
        tgt.col4 = src.col4,
        ...
        tgt.RowHashType1 = src.RowHashType1""")
print(f"Total no of records updated due to Type1 columns update: {existing_records_update.select('num_updated_rows').collect()[0][0]}")

# COMMAND ----------

# DBTITLE 1,Update Expired Record
update_expired_record = spark.sql(f"""
    MERGE INTO {catalog_name}.{db_name}.{table_name} AS tgt
    USING temp_updates_type2_expired_{db_name}_{table_name} AS src
    ON tgt.PatientId = src.PatientId AND tgt.IsCurrent = true
    WHEN MATCHED THEN UPDATE SET
        tgt.IsCurrent = false,
        tgt.BusinessEffectiveEndDate = src.BusinessEffectiveEndDate,
        tgt.DwModifiedts = src.DwModifiedts,
        tgt.DwCreatedYear = year(src.DwModifiedts),
        tgt.DwCreatedMonth = month(src.DwModifiedts),
        tgt.DwCreatedDay = day(src.DwModifiedts)""")
print(f"Total no of records marked IsCurrent as False due to type2 columns update: {update_expired_record.select('num_updated_rows').collect()[0][0]}")

# COMMAND ----------

# DBTITLE 1,Insert new records as type2 value changed and first time data arrival
new_records_insertion = spark.sql(f"""
    INSERT INTO {catalog_name}.{db_name}.{table_name} (col1, ...)
    SELECT col1, ...
    FROM temp_inserts_new_records_{db_name}_{table_name}""")
print(f"Total no of new records inserted: {new_records_insertion.select('num_inserted_rows').collect()[0][0]}")

# COMMAND ----------

source_df.unpersist()
source_with_target_df.unpersist()
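One thing I am wondering about: since the MERGE and INSERT statements above write to the same table that the views are ultimately derived from, would eagerly checkpointing the joined DataFrame (instead of only persisting it) be enough to pin the data before those writes run? A rough sketch of the idea (not tested, just what I mean):

```python
# Rough sketch of the idea: localCheckpoint(eager=True) computes the DataFrame
# immediately and truncates its lineage, so later uses read the checkpointed data
# instead of re-evaluating the plan against the (by then updated) target table.
joined_df = source_with_target_df.localCheckpoint(eager=True)
```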