On Tue, Apr 1, 2025 at 7:02 PM Peter Geoghegan <p...@bowt.ie> wrote:
> Though I think it should be "" safe even when "key->sk_attno >
> firstchangingattnum" "", to highlight that the rule here is
> significantly more permissive than even the nearby range skip array
> case, which is still safe when (key->sk_attno == firstchangingattnum).

Mark Dilger reported a bug in commit 8a510275 on Saturday, which I
fixed in commit b75fedca from Monday. Mark's repro was a little bit
complicated, though.

Attached is a Python script that performs fuzz testing of nbtree skip
scan. It is capable of quickly finding the same bug as the one that
Mark reported. The script generates random, complicated multi-column
index scans on a (a, b, c, d) index on a test table, and verifies that
each query gives the same answer as an equivalent sequential scan
plan. This works quite well as a general smoke test. I find that if I
deliberately add somewhat plausible bugs to the code in
_bt_set_startikey, the fuzz testing script is usually able to identify
wrong answers to queries in under a minute.

I don't expect that this script will actually discover any real bugs
-- I ran it for long enough to get the sense that that was unlikely.
But it seemed like a worthwhile exercise.

-- 
Peter Geoghegan
#!/usr/bin/env python3
from itertools import product
import math
import psycopg2
import random
import time

def biased_random_int(min_val, max_val):
    """
    Return a random integer in the closed range [min_val, max_val].

    The sample comes from a symmetric Beta(2, 2) distribution stretched
    across the range and then rounded. Results therefore cluster around the
    midpoint and become progressively less likely toward either bound.

    Args:
        min_val: Smallest value that can be returned
        max_val: Largest value that can be returned

    Returns:
        int: The sampled integer
    """
    # Beta(2, 2) has density 6x(1 - x): zero at the edges, peaked at 0.5
    fraction = random.betavariate(2, 2)
    span = max_val - min_val
    return round(min_val + fraction * span)

class PostgreSQLSkipScanTester:
    """
    Fuzz-test nbtree skip scan by comparing index-scan and sequential-scan
    answers for randomly generated WHERE clauses over an (a, b, c, d) index.
    """

    # Inclusive (low, high) bounds for random constants per column.  They are
    # slightly wider than the ranges used for the inserted data, so constants
    # just outside the data range are generated too.  Columns not listed here
    # fall back to the "d" range.
    _VALUE_RANGES = {'a': (-1, 21), 'b': (-1, 21), 'c': (-1, 101)}
    _DEFAULT_VALUE_RANGE = (-1, 10_001)

    # Offsets into self.inequality_operators that bound a column in the same
    # direction: '<' (0) pairs with '<=' (1), '>=' (2) pairs with '>' (3).
    _INEQUALITY_PARTNERS = {0: 1, 1: 0, 2: 3, 3: 2}

    def __init__(self, conn_params, table_name, num_rows, num_samples, num_tests):
        """
        Args:
            conn_params: Keyword arguments passed to psycopg2.connect()
            table_name: Test table name; it is interpolated directly into SQL,
                so it must be a trusted identifier
            num_rows: Number of random rows to insert
            num_samples: Number of sample plans to print before fuzzing
            num_tests: Number of random fuzzing queries to run
        """
        self.conn_params = conn_params
        self.table_name = table_name
        self.num_rows = num_rows
        self.num_samples = num_samples
        self.num_tests = num_tests
        self.columns = ['a', 'b', 'c', 'd']
        self.equality_operators = ['=', 'IN', 'IS NULL']
        self.inequality_operators = ['<', '<=', '>=', '>', 'IS NOT NULL']
        self.conn = None

    def connect(self):
        """Establish connection to PostgreSQL database"""
        try:
            self.conn = psycopg2.connect(**self.conn_params)
            print("Successfully connected")
        except Exception as e:
            print(f"Connection error: {e}")
            raise

    def setup_test_environment(self):
        """
        Create the test table and its (a, b, c, d) index, and populate it
        with random data (each column is NULL about 5% of the time).
        Then VACUUM ANALYZE the table and set max_parallel_workers_per_gather
        to 0 for this session.

        Leaves the connection in autocommit mode.
        """
        try:
            with self.conn.cursor() as cursor:
                # Create test table
                cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
                cursor.execute(f"""
                    create table {self.table_name} (
                        id serial primary key,
                        a integer,
                        b integer,
                        c integer,
                        d integer)
                """)

                # Create the composite index first (makes suffix truncation
                # effective)
                cursor.execute(f"""
                    CREATE INDEX idx_{self.table_name}_abcd
                    ON {self.table_name} (a, b, c, d)
                """)
                # Insert random data (with some NULL values)
                for _ in range(self.num_rows):
                    # Randomly decide whether to include NULL values for each column
                    a_val = None if random.random() < 0.05 else random.randint(1, 20)
                    b_val = None if random.random() < 0.05 else random.randint(1, 20)
                    c_val = None if random.random() < 0.05 else random.randint(1, 100)
                    d_val = None if random.random() < 0.05 else random.randint(1, 10_000)

                    cursor.execute(f"""
                        INSERT INTO {self.table_name} (a, b, c, d)
                        VALUES (%s, %s, %s, %s)
                    """, (a_val, b_val, c_val, d_val))

                self.conn.commit()

                # VACUUM cannot run inside a transaction block
                self.conn.set_session(autocommit=True)
                cursor.execute(f"""
                    vacuum analyze {self.table_name}
                """)

                print(f"Created test table {self.table_name} with {self.num_rows} rows")
                cursor.execute("""
                    set max_parallel_workers_per_gather to 0;
                """)

        except Exception as e:
            self.conn.rollback()
            print(f"Setup error: {e}")
            raise

    def _random_value(self, column):
        """Return a random constant for *column*, drawn from its _VALUE_RANGES entry."""
        low, high = self._VALUE_RANGES.get(column, self._DEFAULT_VALUE_RANGE)
        return biased_random_int(low, high)

    def generate_cond(self, column, operator_type=None,
                      equality_weights=None, inequality_weights=None):
        """
        Generate a single condition for a column using the specified operator type.

        Args:
            column: The column name to generate condition for
            operator_type: 'equality', 'inequality', or None (uniformly random
                over all operators)
            equality_weights: Weights aligned with self.equality_operators;
                defaults to [0.78, 0.20, 0.02] (IS NULL used much less often)
            inequality_weights: Weights aligned with self.inequality_operators;
                defaults to [0.20, 0.20, 0.20, 0.17, 0.03] (IS NOT NULL used
                much less often)

        Returns:
            A condition string like "a > 5", "c IN (1, 7)" or "b IS NULL"
        """
        # None defaults avoid sharing one mutable list across calls
        if equality_weights is None:
            equality_weights = [0.78, 0.20, 0.02]
        if inequality_weights is None:
            inequality_weights = [0.20, 0.20, 0.20, 0.17, 0.03]

        if operator_type == 'equality':
            op = random.choices(self.equality_operators, weights=equality_weights)[0]
        elif operator_type == 'inequality':
            op = random.choices(self.inequality_operators, weights=inequality_weights)[0]
        else:
            op = random.choice(self.equality_operators + self.inequality_operators)

        if op in ('IS NULL', 'IS NOT NULL'):
            return f"{column} {op}"
        if op == 'IN':
            nelements = random.randint(2, 20)
            # A set drops duplicate draws, so the list may be shorter than nelements
            elements = {self._random_value(column) for _ in range(nelements)}
            values_str = ", ".join(str(x) for x in sorted(elements))
            return f"{column} IN ({values_str})"
        return f"{column} {op} {self._random_value(column)}"

    def find_matching_operator_indices(self, all_conditions):
        """
        Return offsets into self.inequality_operators that already appear in
        *all_conditions*, plus their same-direction partners.

        Matching is by substring, so "<=" also matches "<" (and ">=" matches
        ">"); pairing makes the result the same either way.  Callers zero
        out the weights at these offsets to avoid redundant bounds.

        Args:
            all_conditions: Iterable of condition strings

        Returns:
            set[int]: Offsets that should not be used again
        """
        matching_indices = set()
        for index, operator in enumerate(self.inequality_operators):
            if any(operator in condition for condition in all_conditions):
                matching_indices.add(index)
                partner = self._INEQUALITY_PARTNERS.get(index)
                if partner is not None:
                    matching_indices.add(partner)
        return matching_indices

    def sort_constraints(self, constraints):
        """
        Sort constraints with column name treated most significant, followed by operator order.
        Lower bound operators (>, >=) sort before upper bound operators (<, <=).  This presents
        things in a consistent order, that seems more readable.

        Args:
            constraints: List of constraint strings (e.g., ["a <= 5", "a > 8", "a >= 4", "b < 5"])

        Returns:
            list: Sorted list of constraints (stable for ties)
        """
        # Lower value = sorts earlier; other operators (=, IN, IS ...) sort last
        operator_priority = {
            ">": 0,
            ">=": 1,
            "<": 2,
            "<=": 3
        }

        def get_sort_key(constraint):
            parts = constraint.split()
            if len(parts) >= 2:
                return (parts[0], operator_priority.get(parts[1], 999))
            return (constraint, 999)  # Fallback for unparseable constraints

        return sorted(constraints, key=get_sort_key)

    @staticmethod
    def _bounds_contradict(lower_op, lower_val, upper_op, upper_val):
        """
        Return True if "col lower_op lower_val AND col upper_op upper_val" has
        no integer solution (lower_op is '>' or '>='; upper_op is '<' or '<=').
        """
        smallest = lower_val + 1 if lower_op == '>' else lower_val
        largest = upper_val - 1 if upper_op == '<' else upper_val
        return smallest > largest

    def resolve_contradictions(self, conditions):
        """
        Allowing contradictory quals seems to be a poor use of available test cycles.
        Resolves contradictions between range bounds on the same (integer)
        column that are impossible to satisfy simultaneously.

        The fix-up swaps the two constants.  If the range is still empty
        afterwards (e.g. "a > 5 AND a < 5"), both operators are also relaxed
        to their inclusive forms.  Conditions are tracked by position, so
        columns with several bounds on one side are handled safely.

        Args:
            conditions: List of condition strings (e.g., ["b > 16", "b < 9", "d IS NOT NULL"])

        Returns:
            list: New list with contradictions resolved (input is not mutated)
        """
        modified_conditions = list(conditions)

        # column -> {'lower': [...], 'upper': [...]}, entries are mutable
        # [position, operator, value] so later pairs see earlier fix-ups
        bounds = {}
        for pos, condition in enumerate(conditions):
            parts = condition.split()
            if len(parts) != 3 or parts[1] not in ('>', '<', '>=', '<='):
                continue
            try:
                value = int(parts[2])
            except ValueError:
                continue  # Non-integer constant: leave the condition alone
            column, operator = parts[0], parts[1]
            side = 'lower' if operator in ('>', '>=') else 'upper'
            bounds.setdefault(column, {'lower': [], 'upper': []})[side].append(
                [pos, operator, value])

        for column, sides in bounds.items():
            for lower in sides['lower']:
                for upper in sides['upper']:
                    lower_pos, lower_op, lower_val = lower
                    upper_pos, upper_op, upper_val = upper
                    if not self._bounds_contradict(lower_op, lower_val, upper_op, upper_val):
                        continue

                    # Swapping constants puts the smaller value on the lower bound
                    new_lower_val = min(lower_val, upper_val)
                    new_upper_val = max(lower_val, upper_val)
                    if self._bounds_contradict(lower_op, new_lower_val, upper_op, new_upper_val):
                        # Swap alone can't help (e.g. equal constants)
                        lower_op, upper_op = '>=', '<='

                    lower[1:] = [lower_op, new_lower_val]
                    upper[1:] = [upper_op, new_upper_val]
                    modified_conditions[lower_pos] = f"{column} {lower_op} {new_lower_val}"
                    modified_conditions[upper_pos] = f"{column} {upper_op} {new_upper_val}"

        return modified_conditions

    def generate_random_query(self):
        """
        Generate a random WHERE clause (without the WHERE keyword).

        Each chosen column gets either one equality-type condition (=, IN,
        IS NULL) or one to three inequality-type conditions.  Within a
        column, a bound direction is not repeated (e.g. no "<" after "<=").
        This restriction does not carry over to other columns.

        Returns:
            str: Sorted conditions joined with " AND ", with contradictory
            range bounds resolved
        """
        assert len(self.columns) == 4
        weights = [0.05, 0.05, 0.05, 0.85]  # Probabilities for 1, 2, 3, or 4 columns
        num_columns_with_conditions = random.choices([1, 2, 3, 4], weights=weights)[0]
        columns_with_conditions = random.sample(self.columns, num_columns_with_conditions)

        # Avoid just having one column with conditions when that column is "a";
        # make it on "b", instead
        if num_columns_with_conditions == 1 and columns_with_conditions[0] == 'a':
            columns_with_conditions[0] = 'b'

        columns_with_equality = set()
        all_conditions = []

        # First pass: each column has a 50% chance of one equality condition
        for col in columns_with_conditions:
            if random.random() < 0.5:
                all_conditions.append(self.generate_cond(col, 'equality'))
                columns_with_equality.add(col)

        # Second pass: remaining columns get one or more inequality conditions
        for col in columns_with_conditions:
            if col in columns_with_equality:
                continue

            # Probabilities for 1, 2, or 3 inequality conditions
            num_conditions = random.choices([1, 2, 3], weights=[0.25, 0.7, 0.05])[0]
            # Offset-wise weights for self.inequality_operators
            inequality_weights = [0.20, 0.20, 0.20, 0.17, 0.03]
            col_conditions = []

            for _ in range(num_conditions):
                # Only this column's own conditions make an operator redundant.
                # Checking every column's conditions would starve later
                # columns of operators (or of any condition at all).
                zero_weights = self.find_matching_operator_indices(col_conditions)
                if zero_weights == {0, 1, 2, 3, 4}:
                    break
                for offset in zero_weights:
                    inequality_weights[offset] = 0
                col_conditions.append(
                    self.generate_cond(col, 'inequality',
                                       inequality_weights=inequality_weights))

            all_conditions.extend(col_conditions)

        # If we somehow ended up with no conditions (unlikely), add one
        if not all_conditions:
            col = random.choice(self.columns)
            all_conditions.append(self.generate_cond(col))

        all_conditions = self.sort_constraints(all_conditions)
        all_conditions = self.resolve_contradictions(all_conditions)
        return " AND ".join(all_conditions)

    def _explain_and_fetch(self, cursor, where_clause):
        """Run EXPLAIN ANALYZE and then the query itself; return (plan_rows, result_rows)."""
        query = f"SELECT * FROM {self.table_name} WHERE {where_clause}"
        cursor.execute(f"EXPLAIN ANALYZE {query}")
        plan = cursor.fetchall()
        cursor.execute(query)
        return plan, cursor.fetchall()

    def execute_test_query(self, where_clause):
        """
        Execute a test query with both sequential scan and index scan.

        The planner GUCs are RESET even if a query fails, and the cursor is
        closed on exit.

        Returns:
            dict: plans, result rows, row counts and whether results match
        """
        with self.conn.cursor() as cursor:
            try:
                # Force sequential scan
                cursor.execute("SET enable_indexscan = off; SET enable_bitmapscan = off;")
                seq_plan, seq_results = self._explain_and_fetch(cursor, where_clause)

                # Force index scan
                cursor.execute("SET enable_indexscan = on; SET enable_seqscan = off; SET enable_bitmapscan = off;")
                idx_plan, idx_results = self._explain_and_fetch(cursor, where_clause)
            finally:
                cursor.execute("RESET enable_indexscan; RESET enable_seqscan; RESET enable_bitmapscan;")

        return {
            'where_clause': where_clause,
            'seq_plan': seq_plan,
            'idx_plan': idx_plan,
            'seq_results': seq_results,
            'idx_results': idx_results,
            # Rows lead with the unique id, so sorting never compares NULLs
            'results_match': sorted(seq_results) == sorted(idx_results),
            'seq_count': len(seq_results),
            'idx_count': len(idx_results)
        }

    def verify_scan_results(self, test_result):
        """Verify that sequential scan and index scan results match"""
        if not test_result['results_match']:
            print("\n❌ TEST FAILED: Results do not match!")
            print(f"Query: SELECT * FROM {self.table_name} WHERE {test_result['where_clause']}")
            print(f"Sequential scan found {test_result['seq_count']} rows")
            print(f"Index scan found {test_result['idx_count']} rows")
            return False
        return True

    def _has_multiple_inequalities(self, where_clause):
        """
        Return True if some column has more than one inequality condition
        (<, <=, >=, >, or IS NOT NULL) in *where_clause*.

        Expects the " AND "-joined "col op value" / "col IS NOT NULL" form
        produced by generate_random_query().  Operators are compared as whole
        tokens, so "a <= 5" counts once (not also as "a <").
        """
        range_ops = {'<', '<=', '>=', '>'}
        per_column = {}
        for condition in where_clause.split(" AND "):
            parts = condition.split()
            if len(parts) < 2 or parts[0] not in self.columns:
                continue
            if parts[1] in range_ops or condition.endswith(" IS NOT NULL"):
                per_column[parts[0]] = per_column.get(parts[0], 0) + 1
        return any(count > 1 for count in per_column.values())

    def run_fuzzing_queries(self):
        """Run a batch of random test queries and verify results"""
        print(f"\nRunning {self.num_tests} random test queries...")

        start_time = time.time()
        failures = 0
        multiple_inequality_count = 0

        for i in range(1, self.num_tests + 1):
            where_clause = self.generate_random_query()

            if self._has_multiple_inequalities(where_clause):
                multiple_inequality_count += 1

            test_result = self.execute_test_query(where_clause)

            if not self.verify_scan_results(test_result):
                failures += 1

            if i % 10 == 0:
                print(f"Completed {i} tests. Failures: {failures}")

        duration = time.time() - start_time

        print(f"\nCompleted {self.num_tests} tests in {duration:.2f} seconds")
        print(f"Queries with multiple inequalities on the same column: {multiple_inequality_count}")
        print(f"Total failures: {failures}")

        if failures == 0:
            print("✅ All tests passed!")
        else:
            print(f"❌ {failures} tests failed!")

        return failures == 0

    def dump_plan_samples(self):
        """Analyze and print execution plans for a few sample queries"""
        print(f"\nAnalyzing execution plans for {self.num_samples} sample queries...")

        for i in range(self.num_samples):
            where_clause = self.generate_random_query()
            test_result = self.execute_test_query(where_clause)

            print(f"\nQuery {i+1}: SELECT * FROM {self.table_name} WHERE {where_clause}")
            print("\nSequential scan plan:")
            for line in test_result['seq_plan']:
                print(line[0])

            print("\nIndex scan plan:")
            for line in test_result['idx_plan']:
                print(line[0])

            print(f"\nResults match: {test_result['results_match']}")
            print(f"Row count: {test_result['seq_count']}")
            print("-" * 80)

    def cleanup(self):
        """Commit any pending work and close the connection (errors are reported, not raised)."""
        if self.conn:
            try:
                self.conn.commit()
                self.conn.close()
                print("Closed connection")
            except Exception as e:
                print(f"Cleanup error: {e}")

    def run_all_tests(self):
        """Run all test types; return True only if every fuzzing query matched."""
        try:
            self.connect()
            self.setup_test_environment()

            # Analyze some sample execution plans first, to preview the work
            # that run_fuzzing_queries() will do (runs quickly)
            self.dump_plan_samples()

            # The real work happens in run_fuzzing_queries() (takes a while)
            return self.run_fuzzing_queries()

        except Exception as e:
            print(f"Test error: {e}")
            return False
        finally:
            self.cleanup()


if __name__ == "__main__":
    import sys

    # Connection parameters - adjust as needed
    conn_params = {
        "host": "localhost",
        "database": "regression",
        "user": "pg",
    }

    # Create and run the tester
    tester = PostgreSQLSkipScanTester(
        conn_params=conn_params,
        table_name="skip_scan_test",
        num_rows=100_000, # rows in `table_name` table
        num_samples=10, # Number of plan samples to dump (previews test query structure)
        num_tests=5_000 # Number of test queries
    )

    success = tester.run_all_tests()

    if success:
        print("\n✅ All tests completed successfully")
    else:
        print("\n❌ Test failures detected")

    # Report the outcome through the exit status so CI/shell wrappers notice failures
    sys.exit(0 if success else 1)

Reply via email to