shehabgamin commented on PR #16409:
URL: https://github.com/apache/datafusion/pull/16409#issuecomment-2972618052

   Not sure if it makes sense to commit the script I used, so I'll paste it here for now:
   ```python
   """
   WARNING:
       - This script extracts only basic, straightforward tests.
       - It is not comprehensive and will not capture most function tests.
       - Intended as a quick-and-dirty tool for generating minimal Spark function tests.
       - Run this script from the root directory of the Sail project.
   """
   import glob
   import json
   import os
   import re
   
   from pyspark.sql import SparkSession
   
   # From project root in Sail: https://github.com/lakehq/sail
   FUNCTIONS_PATH = "crates/sail-spark-connect/tests/gold_data/function/"
   
   
   def extract_simple_function_arguments(query):
       """
       Extract arguments from simple function calls of pattern:
       SELECT SOME_FUNC(ARG0, ARG1, .... ARGN);
       Only accepts basic literal arguments - no arrays, nested functions, etc.
       Example queries NOT accepted:
           - query = "SELECT any(col) FROM VALUES (NULL), (true), (false) AS tab(col);"
           - query = "SELECT array_append(CAST(null as Array<Int>), 2);"
           - query = "SELECT array_append(array('b', 'd', 'c', 'a'), 'd');"
           - query = "SELECT * FROM explode(collection => array(10, 20));"
           - query = "SELECT cast('10' as int);"
       Example queries accepted:
           - query = "SELECT ceil(5);"
           - query = "SELECT ceil(3.1411, -3);"
           - query = "SELECT now();"
       """
       if any(f in query.lower() for f in ["cast", "map", "from", "raise_error", "regexp", "rlike", " in "]):
           return None
       pattern = r'SELECT\s+\w+\s*\(([^)]*)\)\s*;\s*'
       match = re.search(pattern, query, re.IGNORECASE | re.DOTALL)
       if not match:
           return None
       args_string = match.group(1).strip()
       if not args_string:  # Empty function call
           return []
       # Filter out complex arguments - reject if it contains brackets, parens, etc.
       if any(char in args_string for char in ['[', ']', '(', ')']):
           return None
       # Split on commas that are outside quoted strings
       arguments = re.split(r',(?=(?:[^"\']*["\'][^"\']*["\'])*[^"\']*$)', args_string)
       arguments = [arg.strip() for arg in arguments if arg.strip()]
       return arguments
   
   
   def extract_function_name(query):
       pattern = r'SELECT\s+(\w+)\s*\(([^)]*)\)\s*;\s*'
       match = re.search(pattern, query, re.IGNORECASE | re.DOTALL)
       if match:
           return match.group(1).strip()
       return None
   
   
   def create_typed_query(query, func_name, type_results):
       if not type_results:
           # No typed arguments (e.g. zero-argument functions like now()); return the
           # same three-element shape as below so the output loop can write it.
           return [f"# Original Query: {query}", f"# PySpark 3.5.5 Result: {type_results}", query]
       typed_args = []
       for key, spark_type in type_results.items():
           if key.startswith('typeof(') and key.endswith(')'):
               arg = key[7:-1]
               typed_args.append(f"{arg}::{spark_type}")
       typed_query = f"SELECT {func_name}({', '.join(typed_args)});"
       return [f"# Original Query: {query}", f"# PySpark 3.5.5 Result: {type_results}", typed_query]
   
   
   def main():
       spark = SparkSession.builder.getOrCreate()
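       # Maps each gold-data file name to {function name -> list of generated test entries}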
       function_dict = {}
       json_files = glob.glob(os.path.join(FUNCTIONS_PATH, "*.json"))
       num_queries = 0
       for file_path in json_files:
           with open(file_path, "r") as f:
               data = json.load(f)
           directory_name = os.path.basename(file_path).removesuffix('.json')
           if directory_name not in function_dict:
               function_dict[directory_name] = {}
           for test in data["tests"]:
               if len(test["input"]["schema"]["fields"]) != 1:
                   # Skip generator tests with multiple fields
                   continue
               if "exception" in test:
                   # Skip tests that are expected to raise exceptions
                   continue
               query = test["input"]["query"].strip()
               arguments = extract_simple_function_arguments(query)
               if arguments is not None:
                   func_name = extract_function_name(query)
                   if func_name is not None:
                       func_call = re.sub('select', '', query, flags=re.IGNORECASE).strip().rstrip(';').strip()
                       # Run the original call plus typeof() for the result and each argument;
                       # the argument types are used below to build a typed version of the query.
                       if arguments:
                           typeof_parts = [f"typeof({arg})" for arg in arguments]
                           combined_query = f"SELECT {func_call}, typeof({func_call}), {', '.join(typeof_parts)};"
                       else:
                           combined_query = f"SELECT {func_call}, typeof({func_call});"
                       print(f"ORIGINAL QUERY: {query}\nRUNNING QUERY: {combined_query}")
                       try:
                           result = spark.sql(combined_query).collect()
                       except Exception as e:
                           if "CANNOT_PARSE_DATATYPE" in str(e):
                               print(f"Skipping query due to unsupported datatype: {e}")
                               continue
                           else:
                               raise
                       if len(result) != 1:
                           spark.stop()
                           raise ValueError(f"Unexpected result length: {len(result)} for query: {combined_query}")
                       result_row = result[0]
                       type_results = {}
                       for i in range(2, len(result_row)):
                           col_name = result_row.__fields__[i]
                           type_results[col_name.lower()] = result_row[i]
                       typed_query = create_typed_query(query, func_name, type_results)
                       if func_name.lower() not in function_dict[directory_name]:
                           function_dict[directory_name][func_name.lower()] = []
                       function_dict[directory_name][func_name.lower()].append(typed_query)
                       num_queries += 1
       print(f"Processed {num_queries} queries from {len(json_files)} JSON files.")
       base_dir = os.path.join("tmp", "slt")
       for directory, functions in function_dict.items():
           dir_path = os.path.join(base_dir, directory)
           os.makedirs(dir_path, exist_ok=True)
           for func_name, queries in functions.items():
               file_path = os.path.join(dir_path, f"{func_name}.slt")
               with open(file_path, 'w') as f:
                   for query_data in queries:
                       f.write(f"#{query_data[0]}\n")
                       f.write(f"#{query_data[1]}\n")
                       f.write("#query\n")
                       f.write(f"#{query_data[2]}\n")
                       f.write("\n")
       spark.stop()
       return function_dict
   
   
   if __name__ == "__main__":
       main()
   ```
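
   For reference: based on the write loop above, the script writes its output under `tmp/slt/<gold-data file>/<function>.slt`, with every generated entry emitted as a commented-out stub. For a query like `SELECT ceil(5);` an entry would look roughly like the following (the `typeof` value shown is an assumption; the actual value depends on the Spark build you run against):

   ```
   ## Original Query: SELECT ceil(5);
   ## PySpark 3.5.5 Result: {'typeof(5)': 'int'}
   #query
   #SELECT ceil(5::int);
   ```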

