shehabgamin commented on PR #16409: URL: https://github.com/apache/datafusion/pull/16409#issuecomment-2972618052
Not sure if it makes sense to commit the script I used, so I'll paste it here for now:

```python
"""
WARNING:
- This script extracts only basic, straightforward tests.
- It is not comprehensive and will not capture most function tests.
- It is intended as a quick-and-dirty tool for generating minimal Spark function tests.
- Run this script from the root directory of the Sail project.
"""

import glob
import json
import os
import re

from pyspark.sql import SparkSession

# Relative to the project root in Sail: https://github.com/lakehq/sail
FUNCTIONS_PATH = "crates/sail-spark-connect/tests/gold_data/function/"


def extract_simple_function_arguments(query):
    """
    Extract arguments from simple function calls of the pattern:

        SELECT SOME_FUNC(ARG0, ARG1, ..., ARGN);

    Only accepts basic literal arguments - no arrays, nested functions, etc.

    Example queries NOT accepted:
    - "SELECT any(col) FROM VALUES (NULL), (true), (false) AS tab(col);"
    - "SELECT array_append(CAST(null as Array<Int>), 2);"
    - "SELECT array_append(array('b', 'd', 'c', 'a'), 'd');"
    - "SELECT * FROM explode(collection => array(10, 20));"
    - "SELECT cast('10' as int);"

    Example queries accepted:
    - "SELECT ceil(5);"
    - "SELECT ceil(3.1411, -3);"
    - "SELECT now();"
    """
    if any(f in query.lower() for f in ["cast", "map", "from", "raise_error", "regexp", "rlike", " in "]):
        return None

    pattern = r'SELECT\s+\w+\s*\(([^)]*)\)\s*;\s*'
    match = re.search(pattern, query, re.IGNORECASE | re.DOTALL)
    if not match:
        return None

    args_string = match.group(1).strip()
    if not args_string:
        # Empty function call, e.g. now()
        return []

    # Filter out complex arguments: reject anything containing brackets or parens.
    if any(char in args_string for char in ['[', ']', '(', ')']):
        return None

    # Split on commas that are not inside quoted strings.
    arguments = re.split(r',(?=(?:[^"\']*["\'][^"\']*["\'])*[^"\']*$)', args_string)
    arguments = [arg.strip() for arg in arguments if arg.strip()]
    return arguments


def extract_function_name(query):
    pattern = r'SELECT\s+(\w+)\s*\(([^)]*)\)\s*;\s*'
    match = re.search(pattern, query, re.IGNORECASE | re.DOTALL)
    if match:
        return match.group(1).strip()
    return None


def create_typed_query(query, func_name, type_results):
    # Rebuild the query with an explicit `::type` cast on each argument, using
    # the types reported by typeof(). Always return the same 3-element shape so
    # the writer in main() can index [0], [1], [2] unconditionally (zero-argument
    # functions simply keep the original query text).
    typed_query = query
    if type_results:
        typed_args = []
        for key, spark_type in type_results.items():
            if key.startswith('typeof(') and key.endswith(')'):
                arg = key[7:-1]
                typed_args.append(f"{arg}::{spark_type}")
        typed_query = f"SELECT {func_name}({', '.join(typed_args)});"
    return [f"# Original Query: {query}", f"# PySpark 3.5.5 Result: {type_results}", typed_query]


def main():
    spark = SparkSession.builder.getOrCreate()
    function_dict = {}
    json_files = glob.glob(os.path.join(FUNCTIONS_PATH, "*.json"))
    num_queries = 0

    for file_path in json_files:
        with open(file_path, "r") as f:
            data = json.load(f)

        directory_name = os.path.basename(file_path).removesuffix('.json')
        if directory_name not in function_dict:
            function_dict[directory_name] = {}

        for test in data["tests"]:
            if len(test["input"]["schema"]["fields"]) != 1:
                # Skip generator tests that produce multiple fields
                continue
            if "exception" in test:
                # Skip tests that are expected to raise exceptions
                continue

            query = test["input"]["query"].strip()
            arguments = extract_simple_function_arguments(query)
            if arguments is None:
                continue
            func_name = extract_function_name(query)
            if func_name is None:
                continue

            # Strip the SELECT keyword and trailing semicolon to get the bare call.
            func_call = re.sub('select', '', query, flags=re.IGNORECASE).strip().rstrip(';').strip()
            if arguments:
                typeof_parts = [f"typeof({arg})" for arg in arguments]
                combined_query = f"SELECT {func_call}, typeof({func_call}), {', '.join(typeof_parts)};"
            else:
                combined_query = f"SELECT {func_call}, typeof({func_call});"

            print(f"ORIGINAL QUERY: {query}\nRUNNING QUERY: {combined_query}")
            try:
                result = spark.sql(combined_query).collect()
            except Exception as e:
                if "CANNOT_PARSE_DATATYPE" in str(e):
                    print(f"Skipping query due to unsupported datatype: {e}")
                    continue
                raise

            if len(result) != 1:
                spark.stop()
                raise ValueError(f"Unexpected result length: {len(result)} for query: {combined_query}")

            # Column 0 is the function result and column 1 is typeof(result);
            # columns 2..N hold typeof(arg) for each argument.
            result_row = result[0]
            type_results = {}
            for i in range(2, len(result_row)):
                col_name = result_row.__fields__[i]
                type_results[col_name.lower()] = result_row[i]

            typed_query = create_typed_query(query, func_name, type_results)
            if func_name.lower() not in function_dict[directory_name]:
                function_dict[directory_name][func_name.lower()] = []
            function_dict[directory_name][func_name.lower()].append(typed_query)
            num_queries += 1

    print(f"Processed {num_queries} queries from {len(json_files)} JSON files.")

    # Write one commented-out .slt file per function, grouped by source JSON file.
    base_dir = os.path.join("tmp", "slt")
    for directory, functions in function_dict.items():
        dir_path = os.path.join(base_dir, directory)
        os.makedirs(dir_path, exist_ok=True)
        for func_name, queries in functions.items():
            file_path = os.path.join(dir_path, f"{func_name}.slt")
            with open(file_path, 'w') as f:
                for query_data in queries:
                    f.write(f"#{query_data[0]}\n")
                    f.write(f"#{query_data[1]}\n")
                    f.write("#query\n")
                    f.write(f"#{query_data[2]}\n")
                    f.write("\n")

    spark.stop()
    return function_dict


if __name__ == "__main__":
    main()
```
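To illustrate the output: a gold-data query like `SELECT ceil(5);` should produce a stanza along these lines in `tmp/slt/<json-file-name>/ceil.slt` (illustrative only; the exact type names are whatever PySpark's `typeof` reports, e.g. `int` for an integer literal):

```
## Original Query: SELECT ceil(5);
## PySpark 3.5.5 Result: {'typeof(5)': 'int'}
#query
#SELECT ceil(5::int);
```

Everything is written commented out, so the idea is that the stanzas get uncommented and completed with expected results by hand.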