On Sun, Jun 14, 2009 at 05:59:58PM +0200, Petr Jelinek wrote:
> David Fetter wrote:
>> I was discussing this with Andrew Gierth in IRC, who thought that
>> putting RETURNING inside the WITH clause would be relatively easy, at
>> least for the parser and planner.  For the executor, he suggested that
>> one approach might be to make INSERT, UPDATE and DELETE into their own
>> nodes.
>
> David asked me to post his (and my) experimental work-in-progress
> patch for this here. The patch in its current state does not work. It
> dies in the executor on:
> ERROR:  attribute 1 has wrong type
> DETAIL:  Table has type tid, but query expects integer.
> Since I know nothing about Postgres' executor, I am only guessing that it
> thinks the query is a SELECT instead of a DELETE ... RETURNING.
> Also, I think repeating the expression
> query->commandType == CMD_SELECT ? query->targetList : query->returningList
> in several places might not be the right way to go.

Thanks :)  I took a different approach in the attached patch.

> Anyway, it's a beginning, and maybe somebody who knows what he is doing
> could help or continue the work.

This patch fails the regression tests, and it hangs or crashes when
attempting to run a writeable CTE.

Any help getting it into better shape would be greatly appreciated :)

Cheers,
David.
-- 
David Fetter <da...@fetter.org> http://fetter.org/
Phone: +1 415 235 3778  AIM: dfetter666  Yahoo!: dfetter
Skype: davidfetter      XMPP: david.fet...@gmail.com

Remember to vote!
Consider donating to Postgres: http://www.postgresql.org/about/donate
diff --git a/src/backend/nodes/nodeFuncs.c b/src/backend/nodes/nodeFuncs.c
index 015dfdc..bcfaf06 100644
--- a/src/backend/nodes/nodeFuncs.c
+++ b/src/backend/nodes/nodeFuncs.c
@@ -2354,6 +2354,50 @@ bool
                                        return true;
                        }
                        break;
+               case T_InsertStmt:
+                       {
+                               InsertStmt *stmt = (InsertStmt *) node;
+
+                               if (walker(stmt->relation, context))
+                                       return true;
+                               if (walker(stmt->cols, context))
+                                       return true;
+                               if (walker(stmt->selectStmt, context))
+                                       return true;
+                               if (walker(stmt->returningList, context))
+                                       return true;
+                       }
+                       break;
+               case T_UpdateStmt:
+                       {
+                               UpdateStmt *stmt = (UpdateStmt *) node;
+
+                               if (walker(stmt->relation, context))
+                                       return true;
+                               if (walker(stmt->targetList, context))
+                                       return true;
+                               if (walker(stmt->whereClause, context))
+                                       return true;
+                               if (walker(stmt->fromClause, context))
+                                       return true;
+                               if (walker(stmt->returningList, context))
+                                       return true;
+                       }
+                       break;
+               case T_DeleteStmt:
+                       {
+                               DeleteStmt *stmt = (DeleteStmt *) node;
+
+                               if (walker(stmt->relation, context))
+                                       return true;
+                               if (walker(stmt->usingClause, context))
+                                       return true;
+                               if (walker(stmt->whereClause, context))
+                                       return true;
+                               if (walker(stmt->returningList, context))
+                                       return true;
+                       }
+                       break;
                case T_A_Expr:
                        {
                                A_Expr     *expr = (A_Expr *) node;
diff --git a/src/backend/parser/gram.y b/src/backend/parser/gram.y
index 9a45355..9e66536 100644
--- a/src/backend/parser/gram.y
+++ b/src/backend/parser/gram.y
@@ -7028,7 +7028,8 @@ cte_list:
                | cte_list ',' common_table_expr                { $$ = 
lappend($1, $3); }
                ;
 
-common_table_expr:  name opt_name_list AS select_with_parens
+common_table_expr:
+        name opt_name_list AS select_with_parens
                        {
                                CommonTableExpr *n = makeNode(CommonTableExpr);
                                n->ctename = $1;
@@ -7037,6 +7038,33 @@ common_table_expr:  name opt_name_list AS 
select_with_parens
                                n->location = @1;
                                $$ = (Node *) n;
                        }
+        | name opt_name_list AS '(' InsertStmt ')'
+                       {
+                               CommonTableExpr *n = makeNode(CommonTableExpr);
+                               n->ctename = $1;
+                               n->aliascolnames = $2;
+                               n->ctequery = $5;
+                               n->location = @1;
+                               $$ = (Node *) n;
+                       }
+        | name opt_name_list AS '(' UpdateStmt ')'
+                       {
+                               CommonTableExpr *n = makeNode(CommonTableExpr);
+                               n->ctename = $1;
+                               n->aliascolnames = $2;
+                               n->ctequery = $5;
+                               n->location = @1;
+                               $$ = (Node *) n;
+                       }
+        | name opt_name_list AS '(' DeleteStmt ')'
+                       {
+                               CommonTableExpr *n = makeNode(CommonTableExpr);
+                               n->ctename = $1;
+                               n->aliascolnames = $2;
+                               n->ctequery = $5;
+                               n->location = @1;
+                               $$ = (Node *) n;
+                       }
                ;
 
 into_clause:
diff --git a/src/backend/parser/parse_cte.c b/src/backend/parser/parse_cte.c
index 988e8eb..2347b28 100644
--- a/src/backend/parser/parse_cte.c
+++ b/src/backend/parser/parse_cte.c
@@ -246,23 +246,40 @@ transformWithClause(ParseState *pstate, WithClause 
*withClause)
 static void
 analyzeCTE(ParseState *pstate, CommonTableExpr *cte)
 {
-       Query      *query;
+       Query           *query;
+       List            *ctelist;
 
        /* Analysis not done already */
-       Assert(IsA(cte->ctequery, SelectStmt));
+       /* This needs to be one of SelectStmt, InsertStmt, UpdateStmt, 
DeleteStmt instead of:
+        * Assert(IsA(cte->ctequery, SelectStmt)); */
 
        query = parse_sub_analyze(cte->ctequery, pstate);
        cte->ctequery = (Node *) query;
 
+       if (query->commandType == CMD_SELECT)
+               ctelist = query->targetList;
+       else
+       {
+               ctelist = query->returningList;
+       }
+
        /*
         * Check that we got something reasonable.      Many of these 
conditions are
         * impossible given restrictions of the grammar, but check 'em anyway.
-        * (These are the same checks as in transformRangeSubselect.)
+        * (In addition to the same checks as in transformRangeSubselect,
+        * this adds checks for (INSERT|UPDATE|DELETE)...RETURNING.)
         */
        if (!IsA(query, Query) ||
                query->commandType != CMD_SELECT ||
-               query->utilityStmt != NULL)
-               elog(ERROR, "unexpected non-SELECT command in subquery in 
WITH");
+               query->utilityStmt != NULL ||
+               ((query->commandType == CMD_INSERT ||
+                 query->commandType == CMD_UPDATE ||
+                 query->commandType == CMD_DELETE) &&
+                query->returningList == NULL))
+               ereport(ERROR,
+                               (errcode(ERRCODE_SYNTAX_ERROR),
+                                errmsg("unexpected non-row-returning command 
in subquery in WITH"),
+                                parser_errposition(pstate, 0)));
        if (query->intoClause)
                ereport(ERROR,
                                (errcode(ERRCODE_SYNTAX_ERROR),
@@ -273,7 +290,7 @@ analyzeCTE(ParseState *pstate, CommonTableExpr *cte)
        if (!cte->cterecursive)
        {
                /* Compute the output column names/types if not done yet */
-               analyzeCTETargetList(pstate, cte, query->targetList);
+               analyzeCTETargetList(pstate, cte, ctelist);
        }
        else
        {
@@ -291,7 +308,7 @@ analyzeCTE(ParseState *pstate, CommonTableExpr *cte)
                lctyp = list_head(cte->ctecoltypes);
                lctypmod = list_head(cte->ctecoltypmods);
                varattno = 0;
-               foreach(lctlist, query->targetList)
+               foreach(lctlist, ctelist)
                {
                        TargetEntry *te = (TargetEntry *) lfirst(lctlist);
                        Node       *texpr;
diff --git a/src/backend/parser/parse_target.c 
b/src/backend/parser/parse_target.c
index 08b8edb..9af7d91 100644
--- a/src/backend/parser/parse_target.c
+++ b/src/backend/parser/parse_target.c
@@ -310,10 +310,12 @@ markTargetListOrigin(ParseState *pstate, TargetEntry *tle,
                        {
                                CommonTableExpr *cte = GetCTEForRTE(pstate, 
rte, netlevelsup);
                                TargetEntry *ste;
+                               Query           *query;
 
                                /* should be analyzed by now */
                                Assert(IsA(cte->ctequery, Query));
-                               ste = get_tle_by_resno(((Query *) 
cte->ctequery)->targetList,
+                               query = (Query *) cte->ctequery;
+                               ste = get_tle_by_resno((query->commandType == 
CMD_SELECT) ? query->targetList : query->returningList,
                                                                           
attnum);
                                if (ste == NULL || ste->resjunk)
                                        elog(ERROR, "subquery %s does not have 
attribute %d",
@@ -1233,11 +1235,19 @@ expandRecordVariable(ParseState *pstate, Var *var, int 
levelsup)
                        {
                                CommonTableExpr *cte = GetCTEForRTE(pstate, 
rte, netlevelsup);
                                TargetEntry *ste;
+                               Query           *query;
+                               List            *ctelist;
 
                                /* should be analyzed by now */
                                Assert(IsA(cte->ctequery, Query));
-                               ste = get_tle_by_resno(((Query *) 
cte->ctequery)->targetList,
-                                                                          
attnum);
+                               query = (Query *) cte->ctequery;
+                               if (query->commandType == CMD_SELECT)
+                                       ctelist = query->targetList;
+                               else
+                               {
+                                       ctelist = query->returningList;
+                               }
+                               ste = get_tle_by_resno(ctelist, attnum);
                                if (ste == NULL || ste->resjunk)
                                        elog(ERROR, "subquery %s does not have 
attribute %d",
                                                 rte->eref->aliasname, attnum);
diff --git a/src/backend/utils/adt/ruleutils.c 
b/src/backend/utils/adt/ruleutils.c
index d302fb8..68c98d4 100644
--- a/src/backend/utils/adt/ruleutils.c
+++ b/src/backend/utils/adt/ruleutils.c
@@ -3800,9 +3800,17 @@ get_name_for_var_field(Var *var, int fieldno,
                                }
                                if (lc != NULL)
                                {
-                                       Query      *ctequery = (Query *) 
cte->ctequery;
-                                       TargetEntry *ste = 
get_tle_by_resno(ctequery->targetList,
-                                                                               
                                attnum);
+                                       Query           *ctequery = (Query *) 
cte->ctequery;
+                                       List            *ctelist;
+
+                                       if (ctequery->commandType == 
CMD_SELECT) 
+                                               ctelist = ctequery->targetList;
+                                       else
+                                       {
+                                               ctelist = 
ctequery->returningList;
+                                       }
+
+                                       TargetEntry *ste = 
get_tle_by_resno(ctelist, attnum);
 
                                        if (ste == NULL || ste->resjunk)
                                                elog(ERROR, "subquery %s does 
not have attribute %d",
diff --git a/src/test/regress/expected/with.out 
b/src/test/regress/expected/with.out
index 4a2f18c..cb603ca 100644
--- a/src/test/regress/expected/with.out
+++ b/src/test/regress/expected/with.out
@@ -912,3 +912,23 @@ ERROR:  recursive query "foo" column 1 has type 
numeric(3,0) in non-recursive te
 LINE 2:    (SELECT i::numeric(3,0) FROM (VALUES(1),(2)) t(i)
                    ^
 HINT:  Cast the output of the non-recursive term to the correct type.
+
+-- DELETE inside the CTE
+CREATE TEMPORARY TABLE t(i INTEGER);
+INSERT INTO t(i) SELECT * FROM generate_series(1,10);
+
+WITH RECURSIVE foo(i) AS (
+    DELETE FROM t RETURNING i
+)
+SELECT i FROM foo ORDER BY i;
+  1
+  2
+  3
+  4
+  5
+  6
+  7
+  8
+  9
+ 10
+(10 rows)
diff --git a/src/test/regress/sql/with.sql b/src/test/regress/sql/with.sql
index c736441..eb83aab 100644
--- a/src/test/regress/sql/with.sql
+++ b/src/test/regress/sql/with.sql
@@ -469,3 +469,12 @@ WITH RECURSIVE foo(i) AS
    UNION ALL
    SELECT (i+1)::numeric(10,0) FROM foo WHERE i < 10)
 SELECT * FROM foo;
+
+-- DELETE inside the CTE
+CREATE TEMPORARY TABLE t(i INTEGER);
+INSERT INTO t(i) SELECT * FROM generate_series(1,10);
+
+WITH RECURSIVE foo(i) AS (
+    DELETE FROM t RETURNING i
+)
+SELECT i FROM foo ORDER BY i;
-- 
Sent via pgsql-hackers mailing list (pgsql-hackers@postgresql.org)
To make changes to your subscription:
http://www.postgresql.org/mailpref/pgsql-hackers

Reply via email to