Dear all,

This is the first patch for my GSoC project, "Query Tool automatic mode
detection".

This patch implements the initial (basic) version of the project.
In this version, a query resultset is updatable if and only if:
- All of its columns belong to a single table
- It contains no duplicate columns
- All of the table's primary key columns are present in it

Inserts, updates and deletes work automatically when the resultset is
updatable.

When the resultset is updatable, the 'Save' button in the Query Tool saves
the changes made to the resultset; otherwise it saves the query to a file,
as before. The 'Save As' button is unchanged.

Throughout the summer I will keep improving this version and adding
features, in order of priority (supporting duplicate columns, and columns
produced by functions or aggregations as read-only columns in the results,
seems like a good next step).

Please share your feedback on these changes, along with any hints or
comments that would improve my code in any respect.

I also have a couple of questions:
- Should the Save button in the Query Tool work the way it does in this
patch, or should there be a new, dedicated button for saving the query to
a file?

- What documentation or unit tests should I write? Any guidelines would be
appreciated.

Thanks a lot!


-- 
*Yosry Muhammad Yosry*

Computer Engineering student,
The Faculty of Engineering,
Cairo University (2021).
Class representative of CMP 2021.
https://www.linkedin.com/in/yosrym93/
diff --git a/web/pgadmin/tools/sqleditor/__init__.py b/web/pgadmin/tools/sqleditor/__init__.py
index 16f7133f..22050451 100644
--- a/web/pgadmin/tools/sqleditor/__init__.py
+++ b/web/pgadmin/tools/sqleditor/__init__.py
@@ -384,6 +384,8 @@ def poll(trans_id):
     rset = None
     has_oids = False
     oids = None
+    additional_messages = None
+    notifies = None
 
     # Check the transaction and connection status
     status, error_msg, conn, trans_obj, session_obj = \
@@ -422,6 +424,22 @@ def poll(trans_id):
 
             st, result = conn.async_fetchmany_2darray(ON_DEMAND_RECORD_COUNT)
 
+            # There may be additional messages even if result is present
+            # eg: Function can provide result as well as RAISE messages
+            messages = conn.messages()
+            if messages:
+                additional_messages = ''.join(messages)
+            notifies = conn.get_notifies()
+
+            # Procedure/Function output may comes in the form of Notices
+            # from the database server, so we need to append those outputs
+            # with the original result.
+            if result is None:
+                result = conn.status_message()
+                # Prepend any server notices to the command status message
+                if result is not None and additional_messages:
+                    result = additional_messages + result
+
             if st:
                 if 'primary_keys' in session_obj:
                     primary_keys = session_obj['primary_keys']
@@ -438,10 +456,22 @@ def poll(trans_id):
                 )
                 session_obj['client_primary_key'] = client_primary_key
 
-                if columns_info is not None:
+                # If trans_obj is a QueryToolCommand then check for updatable
+                # resultsets and primary keys
+                if isinstance(trans_obj, QueryToolCommand):
+                    trans_obj.check_for_updatable_resultset_and_primary_keys()
+                    pk_names, primary_keys = trans_obj.get_primary_keys()
+                    # If primary_keys exist, add them to the session_obj to
+                    # allow for saving any changes to the data
+                    if primary_keys is not None:
+                        session_obj['primary_keys'] = primary_keys
 
-                    command_obj = pickle.loads(session_obj['command_obj'])
-                    if hasattr(command_obj, 'obj_id'):
+                if columns_info is not None:
+                    # If it is a QueryToolCommand that has obj_id attribute
+                    # then it should also be editable
+                    if hasattr(trans_obj, 'obj_id') and \
+                        (not isinstance(trans_obj, QueryToolCommand) or
+                         trans_obj.can_edit()):
                         # Get the template path for the column
                         template_path = 'columns/sql/#{0}#'.format(
                             conn.manager.version
@@ -449,7 +479,7 @@ def poll(trans_id):
 
                         SQL = render_template(
                             "/".join([template_path, 'nodes.sql']),
-                            tid=command_obj.obj_id,
+                            tid=trans_obj.obj_id,
                             has_oids=True
                         )
                         # rows with attribute not_null
@@ -520,26 +550,8 @@ def poll(trans_id):
         status = 'NotConnected'
         result = error_msg
 
-    # There may be additional messages even if result is present
-    # eg: Function can provide result as well as RAISE messages
-    additional_messages = None
-    notifies = None
-    if status == 'Success':
-        messages = conn.messages()
-        if messages:
-            additional_messages = ''.join(messages)
-        notifies = conn.get_notifies()
-
-    # Procedure/Function output may comes in the form of Notices from the
-    # database server, so we need to append those outputs with the
-    # original result.
-    if status == 'Success' and result is None:
-        result = conn.status_message()
-        if (result != 'SELECT 1' or result != 'SELECT 0') and \
-           result is not None and additional_messages:
-            result = additional_messages + result
-
     transaction_status = conn.transaction_status()
+
     return make_json_response(
         data={
             'status': status, 'result': result,
@@ -750,7 +762,6 @@ def save(trans_id):
                 return make_json_response(
                     data={'status': status, 'result': u"{}".format(msg)}
                 )
-
         status, res, query_res, _rowid = trans_obj.save(
             changed_data,
             session_obj['columns_info'],
diff --git a/web/pgadmin/tools/sqleditor/command.py b/web/pgadmin/tools/sqleditor/command.py
index d4b0700f..bccb6b38 100644
--- a/web/pgadmin/tools/sqleditor/command.py
+++ b/web/pgadmin/tools/sqleditor/command.py
@@ -19,6 +19,10 @@ from flask import render_template
 from flask_babelex import gettext
 from pgadmin.utils.ajax import forbidden
 from pgadmin.utils.driver import get_driver
+from pgadmin.tools.sqleditor.utils.constant_definition import ASYNC_OK
+from pgadmin.tools.sqleditor.utils.is_query_resultset_updatable \
+    import is_query_resultset_updatable
+from pgadmin.tools.sqleditor.utils.save_changed_data import save_changed_data
 
 from config import PG_DEFAULT_DRIVER
 
@@ -668,244 +672,11 @@ class TableCommand(GridCommand):
         else:
             conn = default_conn
 
-        status = False
-        res = None
-        query_res = dict()
-        count = 0
-        list_of_rowid = []
-        operations = ('added', 'updated', 'deleted')
-        list_of_sql = {}
-        _rowid = None
-
-        if conn.connected():
-
-            # Start the transaction
-            conn.execute_void('BEGIN;')
-
-            # Iterate total number of records to be updated/inserted
-            for of_type in changed_data:
-                # No need to go further if its not add/update/delete operation
-                if of_type not in operations:
-                    continue
-                # if no data to be save then continue
-                if len(changed_data[of_type]) < 1:
-                    continue
-
-                column_type = {}
-                column_data = {}
-                for each_col in columns_info:
-                    if (
-                        columns_info[each_col]['not_null'] and
-                        not columns_info[each_col]['has_default_val']
-                    ):
-                        column_data[each_col] = None
-                        column_type[each_col] =\
-                            columns_info[each_col]['type_name']
-                    else:
-                        column_type[each_col] = \
-                            columns_info[each_col]['type_name']
-
-                # For newly added rows
-                if of_type == 'added':
-                    # Python dict does not honour the inserted item order
-                    # So to insert data in the order, we need to make ordered
-                    # list of added index We don't need this mechanism in
-                    # updated/deleted rows as it does not matter in
-                    # those operations
-                    added_index = OrderedDict(
-                        sorted(
-                            changed_data['added_index'].items(),
-                            key=lambda x: int(x[0])
-                        )
-                    )
-                    list_of_sql[of_type] = []
-
-                    # When new rows are added, only changed columns data is
-                    # sent from client side. But if column is not_null and has
-                    # no_default_value, set column to blank, instead
-                    # of not null which is set by default.
-                    column_data = {}
-                    pk_names, primary_keys = self.get_primary_keys()
-                    has_oids = 'oid' in column_type
-
-                    for each_row in added_index:
-                        # Get the row index to match with the added rows
-                        # dict key
-                        tmp_row_index = added_index[each_row]
-                        data = changed_data[of_type][tmp_row_index]['data']
-                        # Remove our unique tracking key
-                        data.pop(client_primary_key, None)
-                        data.pop('is_row_copied', None)
-                        list_of_rowid.append(data.get(client_primary_key))
-
-                        # Update columns value with columns having
-                        # not_null=False and has no default value
-                        column_data.update(data)
-
-                        sql = render_template(
-                            "/".join([self.sql_path, 'insert.sql']),
-                            data_to_be_saved=column_data,
-                            primary_keys=None,
-                            object_name=self.object_name,
-                            nsp_name=self.nsp_name,
-                            data_type=column_type,
-                            pk_names=pk_names,
-                            has_oids=has_oids
-                        )
-
-                        select_sql = render_template(
-                            "/".join([self.sql_path, 'select.sql']),
-                            object_name=self.object_name,
-                            nsp_name=self.nsp_name,
-                            primary_keys=primary_keys,
-                            has_oids=has_oids
-                        )
-
-                        list_of_sql[of_type].append({
-                            'sql': sql, 'data': data,
-                            'client_row': tmp_row_index,
-                            'select_sql': select_sql
-                        })
-                        # Reset column data
-                        column_data = {}
-
-                # For updated rows
-                elif of_type == 'updated':
-                    list_of_sql[of_type] = []
-                    for each_row in changed_data[of_type]:
-                        data = changed_data[of_type][each_row]['data']
-                        pk = changed_data[of_type][each_row]['primary_keys']
-                        sql = render_template(
-                            "/".join([self.sql_path, 'update.sql']),
-                            data_to_be_saved=data,
-                            primary_keys=pk,
-                            object_name=self.object_name,
-                            nsp_name=self.nsp_name,
-                            data_type=column_type
-                        )
-                        list_of_sql[of_type].append({'sql': sql, 'data': data})
-                        list_of_rowid.append(data.get(client_primary_key))
-
-                # For deleted rows
-                elif of_type == 'deleted':
-                    list_of_sql[of_type] = []
-                    is_first = True
-                    rows_to_delete = []
-                    keys = None
-                    no_of_keys = None
-                    for each_row in changed_data[of_type]:
-                        rows_to_delete.append(changed_data[of_type][each_row])
-                        # Fetch the keys for SQL generation
-                        if is_first:
-                            # We need to covert dict_keys to normal list in
-                            # Python3
-                            # In Python2, it's already a list & We will also
-                            # fetch column names using index
-                            keys = list(
-                                changed_data[of_type][each_row].keys()
-                            )
-                            no_of_keys = len(keys)
-                            is_first = False
-                    # Map index with column name for each row
-                    for row in rows_to_delete:
-                        for k, v in row.items():
-                            # Set primary key with label & delete index based
-                            # mapped key
-                            try:
-                                row[changed_data['columns']
-                                    [int(k)]['name']] = v
-                            except ValueError:
-                                continue
-                            del row[k]
-
-                    sql = render_template(
-                        "/".join([self.sql_path, 'delete.sql']),
-                        data=rows_to_delete,
-                        primary_key_labels=keys,
-                        no_of_keys=no_of_keys,
-                        object_name=self.object_name,
-                        nsp_name=self.nsp_name
-                    )
-                    list_of_sql[of_type].append({'sql': sql, 'data': {}})
-
-            for opr, sqls in list_of_sql.items():
-                for item in sqls:
-                    if item['sql']:
-                        row_added = None
-
-                        # Fetch oids/primary keys
-                        if 'select_sql' in item and item['select_sql']:
-                            status, res = conn.execute_dict(
-                                item['sql'], item['data'])
-                        else:
-                            status, res = conn.execute_void(
-                                item['sql'], item['data'])
-
-                        if not status:
-                            conn.execute_void('ROLLBACK;')
-                            # If we roll backed every thing then update the
-                            # message for each sql query.
-                            for val in query_res:
-                                if query_res[val]['status']:
-                                    query_res[val]['result'] = \
-                                        'Transaction ROLLBACK'
-
-                            # If list is empty set rowid to 1
-                            try:
-                                if list_of_rowid:
-                                    _rowid = list_of_rowid[count]
-                                else:
-                                    _rowid = 1
-                            except Exception:
-                                _rowid = 0
-
-                            return status, res, query_res, _rowid
-
-                        # Select added row from the table
-                        if 'select_sql' in item:
-                            status, sel_res = conn.execute_dict(
-                                item['select_sql'], res['rows'][0])
-
-                            if not status:
-                                conn.execute_void('ROLLBACK;')
-                                # If we roll backed every thing then update
-                                # the message for each sql query.
-                                for val in query_res:
-                                    if query_res[val]['status']:
-                                        query_res[val]['result'] = \
-                                            'Transaction ROLLBACK'
-
-                                # If list is empty set rowid to 1
-                                try:
-                                    if list_of_rowid:
-                                        _rowid = list_of_rowid[count]
-                                    else:
-                                        _rowid = 1
-                                except Exception:
-                                    _rowid = 0
-
-                                return status, sel_res, query_res, _rowid
-
-                            if 'rows' in sel_res and len(sel_res['rows']) > 0:
-                                row_added = {
-                                    item['client_row']: sel_res['rows'][0]}
-
-                        rows_affected = conn.rows_affected()
-
-                        # store the result of each query in dictionary
-                        query_res[count] = {
-                            'status': status,
-                            'result': None if row_added else res,
-                            'sql': sql, 'rows_affected': rows_affected,
-                            'row_added': row_added
-                        }
-
-                        count += 1
-
-            # Commit the transaction if there is no error found
-            conn.execute_void('COMMIT;')
-
-        return status, res, query_res, _rowid
+        return save_changed_data(changed_data=changed_data,
+                                 columns_info=columns_info,
+                                 command_obj=self,
+                                 client_primary_key=client_primary_key,
+                                 conn=conn)
 
 
 class ViewCommand(GridCommand):
@@ -1089,18 +860,87 @@ class QueryToolCommand(BaseCommand, FetchedRowTracker):
         self.auto_rollback = False
         self.auto_commit = True
 
+        # Attributes needed to be able to edit updatable resultsets
+        self.is_updatable_resultset = False
+        self.primary_keys = None  # {pk column name: pk type name}
+        self.pk_names = None  # quoted pk column names, comma-separated
+
     def get_sql(self, default_conn=None):
         return None
 
     def get_all_columns_with_order(self, default_conn=None):
         return None
 
+    def get_primary_keys(self):
+        return self.pk_names, self.primary_keys  # quoted names, {name: type}
+
     def can_edit(self):
-        return False
+        return self.is_updatable_resultset
 
     def can_filter(self):
         return False
 
+    def check_for_updatable_resultset_and_primary_keys(self):
+        """
+            This function is used to check whether the last successful query
+            produced updatable results and sets the necessary flags and
+            attributes accordingly
+        """
+        # Reset the state of any previous query first, so an early return
+        # never leaves a stale "updatable" flag or stale primary keys behind
+        self.is_updatable_resultset = False
+        self.primary_keys = None
+        self.pk_names = None
+
+        # Fetch the connection object
+        driver = get_driver(PG_DEFAULT_DRIVER)
+        manager = driver.connection_manager(self.sid)
+        conn = manager.connection(did=self.did, conn_id=self.conn_id)
+
+        # Check that the query results are ready first
+        status, result = conn.poll(
+            formatted_exception_msg=True, no_result=True)
+        if status != ASYNC_OK:
+            return
+
+        # Get the path to the sql templates
+        sql_path = 'sqleditor/sql/#{0}#'.format(manager.version)
+
+        self.is_updatable_resultset, self.primary_keys, pk_names, table_oid = \
+            is_query_resultset_updatable(conn, sql_path)
+
+        # Store pk_names as a comma-separated string of quoted identifiers
+        if pk_names is not None:
+            self.pk_names = ','.join(
+                driver.qtIdent(conn, pk_name) for pk_name in pk_names)
+
+        # Add attributes required to be able to update table data
+        if self.is_updatable_resultset:
+            self.__set_updatable_resultset_attributes(sql_path=sql_path,
+                                                      table_oid=table_oid,
+                                                      conn=conn)
+
+    def save(self,
+             changed_data,
+             columns_info,
+             client_primary_key='__temp_PK',
+             default_conn=None):
+        # Saving is only allowed when the last resultset was updatable
+        if not self.is_updatable_resultset:
+            return False, gettext('The resultset is not updatable.'), \
+                None, None
+
+        if default_conn is None:
+            driver = get_driver(PG_DEFAULT_DRIVER)
+            manager = driver.connection_manager(self.sid)
+            conn = manager.connection(did=self.did, conn_id=self.conn_id)
+        else:
+            conn = default_conn
+        return save_changed_data(changed_data=changed_data,
+                                 columns_info=columns_info, conn=conn,
+                                 command_obj=self,
+                                 client_primary_key=client_primary_key)
+
     def set_connection_id(self, conn_id):
         self.conn_id = conn_id
 
@@ -1109,3 +949,29 @@ class QueryToolCommand(BaseCommand, FetchedRowTracker):
 
     def set_auto_commit(self, auto_commit):
         self.auto_commit = auto_commit
+
+    def __set_updatable_resultset_attributes(self, sql_path,
+                                             table_oid, conn):
+        # Remember the SQL template path and the OID of the edited table
+        self.sql_path = sql_path
+        self.obj_id = table_oid
+
+        if not conn.connected():
+            raise Exception(gettext(
+                'Not connected to server or connection with the server '
+                'has been closed.')
+            )
+
+        # Look up the schema (namespace) and relation name of the table
+        status, result = conn.execute_dict(render_template(
+            "/".join([self.sql_path, 'objectname.sql']),
+            obj_id=self.obj_id
+        ))
+        if not status:
+            raise Exception(result)
+
+        # Only the first row returned by objectname.sql is used
+        first_row = result['rows'][0]
+        self.nsp_name = first_row['nspname']
+        self.object_name = first_row['relname']
+
diff --git a/web/pgadmin/tools/sqleditor/static/js/sqleditor.js b/web/pgadmin/tools/sqleditor/static/js/sqleditor.js
index 6af098b4..3bccd447 100644
--- a/web/pgadmin/tools/sqleditor/static/js/sqleditor.js
+++ b/web/pgadmin/tools/sqleditor/static/js/sqleditor.js
@@ -2376,6 +2376,18 @@ define('tools.querytool', [
         else
           self.can_edit = true;
 
+        /* If the query results are updatable then keep track of newly added
+         * rows
+         */
+        if (self.is_query_tool && self.can_edit) {
+          // keep track of newly added rows
+          self.rows_to_disable = new Array();
+          // Temporarily hold new rows added
+          self.temp_new_rows = new Array();
+          self.has_more_rows = false;
+          self.fetching_rows = false;
+        }
+
         /* If user can filter the data then we should enabled
          * Filter and Limit buttons.
          */
@@ -2818,12 +2830,15 @@ define('tools.querytool', [
        * the ajax call to save the data into the database server.
        * and will open save file dialog conditionally.
        */
-      _save: function(view, controller, save_as) {
+      _save: function(view, controller, save_as=false) {
         var self = this,
           save_data = true;
 
-        // Open save file dialog if query tool
-        if (self.is_query_tool) {
+        // Open save file dialog if query tool and:
+        // - results are not editable
+        // or
+        // - 'save as' is pressed instead of 'save'
+        if (self.is_query_tool && (!self.can_edit || save_as)) {
           var current_file = self.gridView.current_file;
           if (!_.isUndefined(current_file) && !save_as) {
             self._save_file_handler(current_file);
diff --git a/web/pgadmin/tools/sqleditor/templates/sqleditor/sql/default/primary_keys.sql b/web/pgadmin/tools/sqleditor/templates/sqleditor/sql/default/primary_keys.sql
index 60d0e56f..a96c928f 100644
--- a/web/pgadmin/tools/sqleditor/templates/sqleditor/sql/default/primary_keys.sql
+++ b/web/pgadmin/tools/sqleditor/templates/sqleditor/sql/default/primary_keys.sql
@@ -1,8 +1,8 @@
 {# ============= Fetch the primary keys for given object id ============= #}
 {% if obj_id %}
-SELECT at.attname, ty.typname
+SELECT at.attname, at.attnum, ty.typname
 FROM pg_attribute at LEFT JOIN pg_type ty ON (ty.oid = at.atttypid)
 WHERE attrelid={{obj_id}}::oid AND attnum = ANY (
     (SELECT con.conkey FROM pg_class rel LEFT OUTER JOIN pg_constraint con ON con.conrelid=rel.oid
     AND con.contype='p' WHERE rel.relkind IN ('r','s','t') AND rel.oid = {{obj_id}}::oid)::oid[])
-{% endif %}
\ No newline at end of file
+{% endif %}
diff --git a/web/pgadmin/tools/sqleditor/utils/is_query_resultset_updatable.py b/web/pgadmin/tools/sqleditor/utils/is_query_resultset_updatable.py
new file mode 100644
index 00000000..ed60f1e9
--- /dev/null
+++ b/web/pgadmin/tools/sqleditor/utils/is_query_resultset_updatable.py
@@ -0,0 +1,79 @@
+##########################################################################
+#
+# pgAdmin 4 - PostgreSQL Tools
+#
+# Copyright (C) 2013 - 2019, The pgAdmin Development Team
+# This software is released under the PostgreSQL Licence
+#
+##########################################################################
+
+"""
+    Check if the result-set of a query is updatable, A resultset is
+    updatable (as of this version) if:
+        - All columns belong to the same table.
+        - All the primary key columns of the table are present in the resultset
+        - No duplicate columns
+"""
+from flask import render_template
+try:
+    from collections import OrderedDict
+except ImportError:
+    from ordereddict import OrderedDict
+
+
+def is_query_resultset_updatable(conn, sql_path):
+    """
+        This function is used to check whether the last successful query
+        produced updatable results.
+
+        Args:
+            conn: Connection object.
+            sql_path: the path to the sql templates.
+        Returns:
+            (True, primary_keys, pk_names, table_oid) when updatable,
+            otherwise (False, None, None, None).
+    """
+    not_updatable = False, None, None, None
+
+    # Fetch the column info (presumably None when there is no resultset)
+    columns_info = conn.get_column_info()
+    if not columns_info:
+        return not_updatable
+
+    # First check that all the columns belong to a single table; table_oid
+    # is None for columns that are not plain table references (expressions)
+    table_oid = columns_info[0]['table_oid']
+    if table_oid is None or \
+            any(column['table_oid'] != table_oid for column in columns_info):
+        return not_updatable
+    column_numbers = [column['table_column'] for column in columns_info]
+
+    # Check for duplicate columns
+    if len(column_numbers) != len(set(column_numbers)):
+        return not_updatable
+    if not conn.connected():
+        return not_updatable
+
+    # Then check that all the primary keys of the table are present
+    query = render_template(
+        "/".join([sql_path, 'primary_keys.sql']),
+        obj_id=table_oid
+    )
+    status, result = conn.execute_dict(query)
+    if not status:
+        return not_updatable
+
+    primary_keys = OrderedDict()
+    pk_names = []
+    pk_column_numbers = []
+    for row in result['rows']:
+        primary_keys[row['attname']] = row['typname']
+        pk_names.append(row['attname'])
+        pk_column_numbers.append(row['attnum'])
+
+    # A table without a primary key can't be edited safely, and all() of an
+    # empty sequence is True, so that case has to be rejected explicitly
+    if pk_column_numbers and all(
+            attnum in column_numbers for attnum in pk_column_numbers):
+        return True, primary_keys, pk_names, table_oid
+    return not_updatable
diff --git a/web/pgadmin/tools/sqleditor/utils/save_changed_data.py b/web/pgadmin/tools/sqleditor/utils/save_changed_data.py
new file mode 100644
index 00000000..f22c5da3
--- /dev/null
+++ b/web/pgadmin/tools/sqleditor/utils/save_changed_data.py
@@ -0,0 +1,268 @@
+##########################################################################
+#
+# pgAdmin 4 - PostgreSQL Tools
+#
+# Copyright (C) 2013 - 2019, The pgAdmin Development Team
+# This software is released under the PostgreSQL Licence
+#
+##########################################################################
+
+from flask import render_template
+try:
+    from collections import OrderedDict
+except ImportError:
+    from ordereddict import OrderedDict
+
+
+def save_changed_data(changed_data, columns_info, conn, command_obj,
+                      client_primary_key):
+    """
+    This function is used to save the data into the database.
+    Depending on condition it will either update or insert the
+    new row into the database.
+
+    All statements are rendered first and then executed inside a single
+    BEGIN/COMMIT block; the first failing statement, or a failing COMMIT,
+    rolls the whole block back.
+
+    Args:
+        changed_data: Contains data to be saved, keyed by operation
+            ('added', 'updated', 'deleted'), plus the 'added_index' and
+            'columns' helper entries read while rendering the SQL.
+        columns_info: Mapping of column name -> column info; only each
+            column's 'type_name' is used here.
+        conn: The connection object
+        command_obj: The transaction object (command_obj or trans_obj)
+        client_primary_key: Key under which the client stores its own
+            row id in the data of each added/updated row.
+
+    Returns:
+        tuple: (status, res, query_res, rowid). status/res come from the
+        last statement executed (status is False when nothing ran or the
+        connection is down), query_res maps execution order to each
+        statement's result, and rowid identifies the failing row on error
+        (None on success).
+    """
+    status = False
+    res = None
+    query_res = {}
+    count = 0
+    _rowid = None
+
+    if not conn.connected():
+        return status, res, query_res, _rowid
+
+    list_of_sql, list_of_rowid = _render_statements(
+        changed_data, columns_info, command_obj, client_primary_key)
+
+    # Start the transaction only after every statement has been rendered,
+    # so that a rendering error cannot leave a transaction open.
+    conn.execute_void('BEGIN;')
+    try:
+        for sqls in list_of_sql.values():
+            for item in sqls:
+                if not item['sql']:
+                    continue
+                row_added = None
+                select_sql = item.get('select_sql')
+
+                # Inserts use execute_dict: their first result row supplies
+                # the parameters used to select the new row back below.
+                if select_sql:
+                    status, res = conn.execute_dict(
+                        item['sql'], item['data'])
+                else:
+                    status, res = conn.execute_void(
+                        item['sql'], item['data'])
+
+                if not status:
+                    _rowid = _rollback_failed_save(
+                        conn, query_res, list_of_rowid, count)
+                    return status, res, query_res, _rowid
+
+                # Select the added row from the table, but only when the
+                # insert returned a row to locate it by (indexing an empty
+                # result used to raise and leave the transaction open).
+                if select_sql and res and res.get('rows'):
+                    status, sel_res = conn.execute_dict(
+                        select_sql, res['rows'][0])
+
+                    if not status:
+                        _rowid = _rollback_failed_save(
+                            conn, query_res, list_of_rowid, count)
+                        return status, sel_res, query_res, _rowid
+
+                    if 'rows' in sel_res and len(sel_res['rows']) > 0:
+                        row_added = {
+                            item['client_row']: sel_res['rows'][0]}
+
+                # store the result of each query in dictionary
+                query_res[count] = {
+                    'status': status,
+                    'result': None if row_added else res,
+                    'sql': item['sql'],
+                    'rows_affected': conn.rows_affected(),
+                    'row_added': row_added
+                }
+                count += 1
+    except Exception:
+        # Never leave the connection inside an open transaction when an
+        # unexpected error escapes; the error itself is re-raised.
+        conn.execute_void('ROLLBACK;')
+        raise
+
+    # Commit the transaction if there is no error found. Deferred
+    # constraints are checked at COMMIT, so COMMIT itself can fail.
+    commit_status, commit_res = conn.execute_void('COMMIT;')
+    if not commit_status:
+        _rowid = _rollback_failed_save(
+            conn, query_res, list_of_rowid, count)
+        return commit_status, commit_res, query_res, _rowid
+
+    return status, res, query_res, _rowid
+
+
+def _render_statements(changed_data, columns_info, command_obj,
+                       client_primary_key):
+    """
+    Render the SQL for every pending add/update/delete operation.
+
+    Returns:
+        tuple: (list_of_sql, list_of_rowid). list_of_sql maps each
+        operation to a list of {'sql', 'data'} items (added rows also
+        carry 'client_row' and 'select_sql'); list_of_rowid holds the
+        client row ids of the added and updated rows, in statement order.
+    """
+    operations = ('added', 'updated', 'deleted')
+    list_of_sql = {}
+    list_of_rowid = []
+    # Data type of every column, passed to the templates as data_type.
+    column_type = {col: columns_info[col]['type_name']
+                   for col in columns_info}
+
+    for of_type in changed_data:
+        # No need to go further if its not add/update/delete operation,
+        # or if there is no data to be saved for it.
+        if of_type not in operations or not changed_data[of_type]:
+            continue
+        list_of_sql[of_type] = []
+
+        # For newly added rows
+        if of_type == 'added':
+            # Rows are keyed by their (string) added index; sort them
+            # numerically so they are inserted in the order they were
+            # added. Order does not matter for updated/deleted rows.
+            added_index = OrderedDict(
+                sorted(
+                    changed_data['added_index'].items(),
+                    key=lambda x: int(x[0])
+                )
+            )
+            pk_names, primary_keys = command_obj.get_primary_keys()
+            has_oids = 'oid' in column_type
+
+            # The select statement does not depend on the row, so it is
+            # rendered once for all added rows.
+            select_sql = render_template(
+                "/".join([command_obj.sql_path, 'select.sql']),
+                object_name=command_obj.object_name,
+                nsp_name=command_obj.nsp_name,
+                primary_keys=primary_keys,
+                has_oids=has_oids
+            )
+
+            for tmp_row_index in added_index.values():
+                data = changed_data[of_type][tmp_row_index]['data']
+                # Record the client row id *before* removing the
+                # client-only tracking keys (reading it after the pop
+                # always yielded None).
+                list_of_rowid.append(data.pop(client_primary_key, None))
+                data.pop('is_row_copied', None)
+
+                sql = render_template(
+                    "/".join([command_obj.sql_path, 'insert.sql']),
+                    data_to_be_saved=data,
+                    primary_keys=None,
+                    object_name=command_obj.object_name,
+                    nsp_name=command_obj.nsp_name,
+                    data_type=column_type,
+                    pk_names=pk_names,
+                    has_oids=has_oids
+                )
+                list_of_sql[of_type].append({
+                    'sql': sql, 'data': data,
+                    'client_row': tmp_row_index,
+                    'select_sql': select_sql
+                })
+
+        # For updated rows
+        elif of_type == 'updated':
+            for row in changed_data[of_type].values():
+                data = row['data']
+                sql = render_template(
+                    "/".join([command_obj.sql_path, 'update.sql']),
+                    data_to_be_saved=data,
+                    primary_keys=row['primary_keys'],
+                    object_name=command_obj.object_name,
+                    nsp_name=command_obj.nsp_name,
+                    data_type=column_type
+                )
+                list_of_sql[of_type].append({'sql': sql, 'data': data})
+                list_of_rowid.append(data.get(client_primary_key))
+
+        # For deleted rows: one statement deletes all of them.
+        else:
+            rows_to_delete = list(changed_data[of_type].values())
+            # The key labels come from the first row, before the
+            # index -> column name mapping below is applied.
+            keys = list(rows_to_delete[0].keys())
+
+            # Map index based keys to the column name. Iterate over a
+            # snapshot of the items because the row is modified in
+            # place, and a dict must not change while it is iterated.
+            for row in rows_to_delete:
+                for k, v in list(row.items()):
+                    try:
+                        row[changed_data['columns'][int(k)]['name']] = v
+                    except ValueError:
+                        continue
+                    del row[k]
+
+            sql = render_template(
+                "/".join([command_obj.sql_path, 'delete.sql']),
+                data=rows_to_delete,
+                primary_key_labels=keys,
+                no_of_keys=len(keys),
+                object_name=command_obj.object_name,
+                nsp_name=command_obj.nsp_name
+            )
+            list_of_sql[of_type].append({'sql': sql, 'data': {}})
+
+    return list_of_sql, list_of_rowid
+
+
+def _rollback_failed_save(conn, query_res, list_of_rowid, count):
+    """
+    Roll back the save transaction after a failure.
+
+    Statements that had already succeeded are re-labelled
+    'Transaction ROLLBACK', since their effect has been undone.
+
+    Returns:
+        The client row id at position ``count`` of ``list_of_rowid``;
+        1 when that list is empty, or 0 when ``count`` is out of range.
+    """
+    conn.execute_void('ROLLBACK;')
+    for val in query_res:
+        if query_res[val]['status']:
+            query_res[val]['result'] = 'Transaction ROLLBACK'
+
+    # NOTE(review): ``count`` counts executed statements, while the
+    # single delete statement has no entry in ``list_of_rowid``, so the
+    # id is only exact when no delete ran before the failing statement.
+    if not list_of_rowid:
+        return 1
+    try:
+        return list_of_rowid[count]
+    except IndexError:
+        return 0
diff --git a/web/pgadmin/tools/sqleditor/utils/start_running_query.py b/web/pgadmin/tools/sqleditor/utils/start_running_query.py
index a5399774..ece11f9c 100644
--- a/web/pgadmin/tools/sqleditor/utils/start_running_query.py
+++ b/web/pgadmin/tools/sqleditor/utils/start_running_query.py
@@ -45,6 +45,9 @@ class StartRunningQuery:
         if type(session_obj) is Response:
             return session_obj
 
+        # Remove any existing primary keys in session_obj
+        session_obj.pop('primary_keys', None)
+
         transaction_object = pickle.loads(session_obj['command_obj'])
         can_edit = False
         can_filter = False

Reply via email to