From da5be2463902c1b97e5d6299f61fc71d0b90126b Mon Sep 17 00:00:00 2001
From: Richard Guo <guofenglinux@gmail.com>
Date: Fri, 30 Dec 2022 10:35:36 +0800
Subject: [PATCH v5] Check lateral refs within PHVs for memoize cache keys

---
 .../postgres_fdw/expected/postgres_fdw.out    | 12 ++-
 src/backend/optimizer/path/joinpath.c         | 85 ++++++++++++++++++-
 src/test/regress/expected/memoize.out         | 68 +++++++++++++++
 src/test/regress/sql/memoize.sql              | 24 ++++++
 4 files changed, 181 insertions(+), 8 deletions(-)

diff --git a/contrib/postgres_fdw/expected/postgres_fdw.out b/contrib/postgres_fdw/expected/postgres_fdw.out
index b5a38aeb21..c25fd9d076 100644
--- a/contrib/postgres_fdw/expected/postgres_fdw.out
+++ b/contrib/postgres_fdw/expected/postgres_fdw.out
@@ -3696,15 +3696,19 @@ ORDER BY ref_0."C 1";
          ->  Index Scan using t1_pkey on "S 1"."T 1" ref_0
                Output: ref_0."C 1", ref_0.c2, ref_0.c3, ref_0.c4, ref_0.c5, ref_0.c6, ref_0.c7, ref_0.c8
                Index Cond: (ref_0."C 1" < 10)
-         ->  Foreign Scan on public.ft1 ref_1
-               Output: ref_1.c3, ref_0.c2
-               Remote SQL: SELECT c3 FROM "S 1"."T 1" WHERE ((c3 = '00001'))
+         ->  Memoize
+               Output: ref_1.c3, (ref_0.c2)
+               Cache Key: ref_0.c2
+               Cache Mode: binary
+               ->  Foreign Scan on public.ft1 ref_1
+                     Output: ref_1.c3, ref_0.c2
+                     Remote SQL: SELECT c3 FROM "S 1"."T 1" WHERE ((c3 = '00001'))
    ->  Materialize
          Output: ref_3.c3
          ->  Foreign Scan on public.ft2 ref_3
                Output: ref_3.c3
                Remote SQL: SELECT c3 FROM "S 1"."T 1" WHERE ((c3 = '00001'))
-(15 rows)
+(19 rows)
 
 SELECT ref_0.c2, subq_1.*
 FROM
diff --git a/src/backend/optimizer/path/joinpath.c b/src/backend/optimizer/path/joinpath.c
index 6aca66f196..e4dae8e534 100644
--- a/src/backend/optimizer/path/joinpath.c
+++ b/src/backend/optimizer/path/joinpath.c
@@ -23,6 +23,7 @@
 #include "optimizer/optimizer.h"
 #include "optimizer/pathnode.h"
 #include "optimizer/paths.h"
+#include "optimizer/placeholder.h"
 #include "optimizer/planmain.h"
 #include "utils/typcache.h"
 
@@ -437,10 +438,11 @@ have_unsafe_outer_join_ref(PlannerInfo *root,
 static bool
 paraminfo_get_equal_hashops(PlannerInfo *root, ParamPathInfo *param_info,
 							RelOptInfo *outerrel, RelOptInfo *innerrel,
-							List **param_exprs, List **operators,
-							bool *binary_mode)
+							List *ph_lateral_vars, List **param_exprs,
+							List **operators, bool *binary_mode)
 
 {
+	List	   *lateral_vars;
 	ListCell   *lc;
 
 	*param_exprs = NIL;
@@ -520,7 +522,8 @@ paraminfo_get_equal_hashops(PlannerInfo *root, ParamPathInfo *param_info,
 	}
 
 	/* Now add any lateral vars to the cache key too */
-	foreach(lc, innerrel->lateral_vars)
+	lateral_vars = list_concat(ph_lateral_vars, innerrel->lateral_vars);
+	foreach(lc, lateral_vars)
 	{
 		Node	   *expr = (Node *) lfirst(lc);
 		TypeCacheEntry *typentry;
@@ -571,6 +574,71 @@ paraminfo_get_equal_hashops(PlannerInfo *root, ParamPathInfo *param_info,
 	return true;
 }
 
+/*
+ * extract_lateral_vars_from_PHVs
+ *	  Return a list of copies of the lateral Vars and PlaceHolderVars used
+ *	  within PlaceHolderVars due to be evaluated exactly at 'innerrelids'.
+ */
+static List *
+extract_lateral_vars_from_PHVs(PlannerInfo *root, Relids innerrelids)
+{
+	List	   *ph_lateral_vars = NIL;
+	ListCell   *lc;
+
+	/* Nothing would be found if the query contains no LATERAL RTEs */
+	if (!root->hasLateralRTEs)
+		return NIL;
+
+	foreach(lc, root->placeholder_list)
+	{
+		PlaceHolderInfo *phinfo = (PlaceHolderInfo *) lfirst(lc);
+		List	   *vars;
+		ListCell   *cell;
+
+		/* PHV is uninteresting if no lateral refs */
+		if (phinfo->ph_lateral == NULL)
+			continue;
+
+		/* PHV is uninteresting if not due to be evaluated at innerrelids */
+		if (!bms_equal(phinfo->ph_eval_at, innerrelids))
+			continue;
+
+		/* Fetch all level-zero Vars and PHVs, then keep only lateral ones */
+		vars = pull_vars_of_level((Node *) phinfo->ph_var->phexpr, 0);
+		foreach(cell, vars)
+		{
+			Node	   *node = (Node *) lfirst(cell);
+
+			/* Copy only the nodes we keep; the others are merely inspected */
+			if (IsA(node, Var))
+			{
+				Var		   *var = (Var *) node;
+
+				Assert(var->varlevelsup == 0);
+
+				if (bms_is_member(var->varno, phinfo->ph_lateral))
+					ph_lateral_vars = lappend(ph_lateral_vars, copyObject(var));
+			}
+			else if (IsA(node, PlaceHolderVar))
+			{
+				PlaceHolderVar *phv = (PlaceHolderVar *) node;
+
+				Assert(phv->phlevelsup == 0);
+
+				if (bms_is_subset(find_placeholder_info(root, phv)->ph_eval_at,
+								  phinfo->ph_lateral))
+					ph_lateral_vars = lappend(ph_lateral_vars, copyObject(phv));
+			}
+			else
+				Assert(false);
+		}
+
+		list_free(vars);
+	}
+
+	return ph_lateral_vars;
+}
+
 /*
  * get_memoize_path
  *		If possible, make and return a Memoize path atop of 'inner_path'.
@@ -586,6 +654,7 @@ get_memoize_path(PlannerInfo *root, RelOptInfo *innerrel,
 	List	   *hash_operators;
 	ListCell   *lc;
 	bool		binary_mode;
+	List	   *ph_lateral_vars;
 
 	/* Obviously not if it's disabled */
 	if (!enable_memoize)
@@ -600,6 +669,12 @@ get_memoize_path(PlannerInfo *root, RelOptInfo *innerrel,
 	if (outer_path->parent->rows < 2)
 		return NULL;
 
+	/*
+	 * Extract lateral Vars within PlaceHolderVars that are due to be evaluated
+	 * at innerrel.  These lateral Vars could be used as memoize cache keys.
+	 */
+	ph_lateral_vars = extract_lateral_vars_from_PHVs(root, innerrel->relids);
+
 	/*
 	 * We can only have a memoize node when there's some kind of cache key,
 	 * either parameterized path clauses or lateral Vars.  No cache key sounds
@@ -607,7 +682,8 @@ get_memoize_path(PlannerInfo *root, RelOptInfo *innerrel,
 	 */
 	if ((inner_path->param_info == NULL ||
 		 inner_path->param_info->ppi_clauses == NIL) &&
-		innerrel->lateral_vars == NIL)
+		innerrel->lateral_vars == NIL &&
+		ph_lateral_vars == NIL)
 		return NULL;
 
 	/*
@@ -694,6 +770,7 @@ get_memoize_path(PlannerInfo *root, RelOptInfo *innerrel,
 									outerrel->top_parent ?
 									outerrel->top_parent : outerrel,
 									innerrel,
+									ph_lateral_vars,
 									&param_exprs,
 									&hash_operators,
 									&binary_mode))
diff --git a/src/test/regress/expected/memoize.out b/src/test/regress/expected/memoize.out
index cf6886a288..933d7cf448 100644
--- a/src/test/regress/expected/memoize.out
+++ b/src/test/regress/expected/memoize.out
@@ -129,6 +129,74 @@ WHERE t1.unique1 < 10;
     20 | 0.50000000000000000000
 (1 row)
 
+-- Try with LATERAL references within PlaceHolderVars at a baserel
+SELECT explain_memoize('
+SELECT COUNT(*),AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty as c1, t2.unique1 as c2 FROM tenk1 t2) s on true
+WHERE s.c1 = s.c2 and t1.unique1 < 1000;', false);
+                                      explain_memoize                                      
+-------------------------------------------------------------------------------------------
+ Aggregate (actual rows=1 loops=N)
+   ->  Nested Loop (actual rows=1000 loops=N)
+         ->  Seq Scan on tenk1 t1 (actual rows=1000 loops=N)
+               Filter: (unique1 < 1000)
+               Rows Removed by Filter: 9000
+         ->  Memoize (actual rows=1 loops=N)
+               Cache Key: t1.twenty
+               Cache Mode: binary
+               Hits: 980  Misses: 20  Evictions: Zero  Overflows: 0  Memory Usage: NkB
+               ->  Index Only Scan using tenk1_unique1 on tenk1 t2 (actual rows=1 loops=N)
+                     Filter: (t1.twenty = unique1)
+                     Rows Removed by Filter: 9999
+                     Heap Fetches: N
+(13 rows)
+
+-- And check we get the expected results.
+SELECT COUNT(*), AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty as c1, t2.unique1 as c2 FROM tenk1 t2) s on true
+WHERE s.c1 = s.c2 and t1.unique1 < 1000;
+ count |        avg         
+-------+--------------------
+  1000 | 9.5000000000000000
+(1 row)
+
+-- Try with LATERAL references within PlaceHolderVars at a joinrel
+SELECT explain_memoize('
+SELECT COUNT(*),AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty as c1, t2.unique1 as c2 FROM tenk1 t2, tenk1 t3
+         WHERE t3.unique1 = 1) s on true
+WHERE s.c1 = s.c2 and t1.unique1 < 1000;', false);
+                                           explain_memoize                                           
+-----------------------------------------------------------------------------------------------------
+ Aggregate (actual rows=1 loops=N)
+   ->  Nested Loop (actual rows=1000 loops=N)
+         ->  Seq Scan on tenk1 t1 (actual rows=1000 loops=N)
+               Filter: (unique1 < 1000)
+               Rows Removed by Filter: 9000
+         ->  Memoize (actual rows=1 loops=N)
+               Cache Key: t1.twenty
+               Cache Mode: binary
+               Hits: 980  Misses: 20  Evictions: Zero  Overflows: 0  Memory Usage: NkB
+               ->  Nested Loop (actual rows=1 loops=N)
+                     Join Filter: (t1.twenty = t2.unique1)
+                     Rows Removed by Join Filter: 9999
+                     ->  Index Only Scan using tenk1_unique1 on tenk1 t3 (actual rows=1 loops=N)
+                           Index Cond: (unique1 = 1)
+                           Heap Fetches: N
+                     ->  Index Only Scan using tenk1_unique1 on tenk1 t2 (actual rows=10000 loops=N)
+                           Heap Fetches: N
+(17 rows)
+
+-- And check we get the expected results.
+SELECT COUNT(*),AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty as c1, t2.unique1 as c2 FROM tenk1 t2, tenk1 t3
+         WHERE t3.unique1 = 1) s on true
+WHERE s.c1 = s.c2 and t1.unique1 < 1000;
+ count |        avg         
+-------+--------------------
+  1000 | 9.5000000000000000
+(1 row)
+
 -- Reduce work_mem and hash_mem_multiplier so that we see some cache evictions
 SET work_mem TO '64kB';
 SET hash_mem_multiplier TO 1.0;
diff --git a/src/test/regress/sql/memoize.sql b/src/test/regress/sql/memoize.sql
index 1f4ab0ba3b..1ec96a520c 100644
--- a/src/test/regress/sql/memoize.sql
+++ b/src/test/regress/sql/memoize.sql
@@ -74,6 +74,30 @@ LATERAL (
 ON t1.two = t2.two
 WHERE t1.unique1 < 10;
 
+-- Try with LATERAL references within PlaceHolderVars at a baserel
+SELECT explain_memoize('
+SELECT COUNT(*),AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty as c1, t2.unique1 as c2 FROM tenk1 t2) s on true
+WHERE s.c1 = s.c2 and t1.unique1 < 1000;', false);
+
+-- And check we get the expected results.
+SELECT COUNT(*), AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty as c1, t2.unique1 as c2 FROM tenk1 t2) s on true
+WHERE s.c1 = s.c2 and t1.unique1 < 1000;
+
+-- Try with LATERAL references within PlaceHolderVars at a joinrel
+SELECT explain_memoize('
+SELECT COUNT(*),AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty as c1, t2.unique1 as c2 FROM tenk1 t2, tenk1 t3
+         WHERE t3.unique1 = 1) s on true
+WHERE s.c1 = s.c2 and t1.unique1 < 1000;', false);
+
+-- And check we get the expected results.
+SELECT COUNT(*),AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty as c1, t2.unique1 as c2 FROM tenk1 t2, tenk1 t3
+         WHERE t3.unique1 = 1) s on true
+WHERE s.c1 = s.c2 and t1.unique1 < 1000;
+
 -- Reduce work_mem and hash_mem_multiplier so that we see some cache evictions
 SET work_mem TO '64kB';
 SET hash_mem_multiplier TO 1.0;
-- 
2.31.0

