Changeset: e870c4975155 for MonetDB
URL: https://dev.monetdb.org/hg/MonetDB?cmd=changeset;node=e870c4975155
Modified Files:
        testing/sqltest.py
Branch: mtest
Log Message:

add monetdbe to sqltest


diffs (287 lines):

diff --git a/testing/sqltest.py b/testing/sqltest.py
--- a/testing/sqltest.py
+++ b/testing/sqltest.py
@@ -7,6 +7,10 @@ import os
 import sys
 import unittest
 import pymonetdb
+try:
+    import monetdbe
+except ImportError:  # monetdbe is only needed for ':memory:' databases
+    monetdbe = None
 import difflib
 from abc import ABCMeta, abstractmethod
 import MonetDBtesting.process as process
@@ -86,7 +87,6 @@ class PyMonetDBConnectionContext(object)
         self.database = database
         self.language = language
         self.dbh = None
-        self.crs = None
         self.language = language
 
     def __enter__(self):
@@ -98,7 +98,6 @@ class PyMonetDBConnectionContext(object)
                                      port=self.port,
                                      database=self.database,
                                      autocommit=True)
-            self.crs = self.dbh.cursor()
         else:
             self.dbh = malmapi.Connection()
             self.dbh.connect(
@@ -108,22 +107,25 @@ class PyMonetDBConnectionContext(object)
                              port=self.port,
                              database=self.database,
                              language=self.language)
-            self.crs = MapiCursor(self.dbh)
         return self
 
     def __exit__(self, exc_type, exc_value, traceback):
         self.close()
 
+    def cursor(self):
+        """Return a new cursor on the open connection; close() does not close it."""
+        if self.language != 'sql':
+            return MapiCursor(self.dbh)
+        return self.dbh.cursor()
+
     def close(self):
-        if self.crs:
-            self.crs.close()
-            self.crs = None
         if self.dbh:
             self.dbh.close()
             self.dbh = None
 
 class RunnableTestResult(metaclass=ABCMeta):
     """Abstract class for sql result"""
+    did_run = False  # set to True by run() once it has executed
 
     @abstractmethod
     def run(self, query:str, *args, stdin=None):
@@ -172,7 +174,7 @@ class TestCaseResult(object):
             print('', file=err_file)
 
     def assertFailed(self, err_code=None, err_message=None):
-        """assert on query failed with optional err_code if provided"""
+        """assert on query failed with optional err_code and err_message if provided"""
         if self.test_run_error is None:
             msg = "expected to fail but didn't"
             self.fail(msg)
@@ -205,7 +207,49 @@ class TestCaseResult(object):
         return self
 
     def assertResultHashTo(self, hash_value):
-        raise NotImplementedError()
+        raise NotImplementedError
+
+    def assertValue(self, row, col, val):
+        """assert on a value matched against row, col in the result"""
+        received = None
+        row = int(row)
+        col = int(col)
+        try:
+            received = self.data[row][col]
+        except (IndexError, TypeError):
+            # row/col out of range, or no result data to index into
+            pass
+        if type(val) is type(received):
+            if val != received:
+                msg = 'expected "{}", received "{}" in row={}, col={}'.format(val, received, row, col)
+                self.fail(msg, data=self.data)
+        else:
+            # handle type mismatch
+            msg = 'expected type {} and value "{}", received type {} and value "{}" in row={}, col={}'.format(type(val), str(val), type(received), str(received), row, col)
+            self.fail(msg, data=self.data)
+        return self
+
+    def assertDataResultMatch(self, data=None, index=None):
+        """Assert on a match of a subset of the result. When index is provided it
+        starts comparing from that row index onward; otherwise the first
+        expected row is looked up in the result to find the start index.
+        """
+        def to_row(item):
+            # expected rows given as lists are compared as tuples
+            if type(item) is list:
+                return tuple(item)
+            return item
+        data = [to_row(item) for item in ([] if data is None else data)]
+        if index is None and data:
+            first = data[0]
+            for i, v in enumerate(self.data or []):
+                if first == v:
+                    index = i
+                    break
+        if not sequence_match(data, self.data, index):
+            msg = '{}\nexpected to match query result starting at index={}, but it didn\'t'.format(piped_representation(data), index)
+            self.fail(msg, data=self.data)
+        return self
 
 class MclientTestResult(TestCaseResult, RunnableTestResult):
     """Holder of a sql execution result as returned from mclinet"""
@@ -289,6 +331,11 @@ class MclientTestResult(TestCaseResult, 
             self.fail(msg)
         return self
 
+    def assertDataResultMatch(self, data=None, index=None):
+        raise NotImplementedError('assertDataResultMatch is not implemented for mclient results')
+
+    def assertValue(self, row, col, val):
+        raise NotImplementedError('assertValue is not implemented for mclient results')
 
 class PyMonetDBTestResult(TestCaseResult, RunnableTestResult):
     """Holder of sql execution information. Managed by SQLTestCase."""
@@ -323,56 +370,53 @@ class PyMonetDBTestResult(TestCaseResult
                 self.query = query
                 try:
                     with self.test_case.conn_ctx as ctx:
-                        ctx.crs.execute(query)
-                        self.rowcount = ctx.crs.rowcount
-                        self.rows = ctx.crs._rows
-                        if ctx.crs.description:
-                            self.data = ctx.crs.fetchall()
-                            self.description = ctx.crs.description
+                        crs = ctx.cursor()
+                        try:
+                            crs.execute(query)
+                            self.rowcount = crs.rowcount
+                            self.rows = crs._rows
+                            if crs.description:
+                                self.data = crs.fetchall()
+                                self.description = crs.description
+                        finally:
+                            # the connection context's close() no longer closes cursors
+                            crs.close()
                 except (pymonetdb.Error, ValueError) as e:
                     self.test_run_error = e
                     self.err_code, self.err_message = self._parse_error(e.args[0])
             self.did_run = True
         return self
 
-    def assertValue(self, row, col, val):
-        """assert on a value matched against row, col in the result"""
-        received = None
-        row = int(row)
-        col = int(col)
-        try:
-            received = self.data[row][col]
-        except IndexError:
-            pass
-        if type(val) is type(received):
-            if val != received:
-                msg = 'expected "{}", received "{}" in row={}, col={}'.format(val, received, row, col)
-                self.fail(msg, data=self.data)
-        else:
-            # handle type mismatch
-            msg = 'expected type {} and value "{}", received type {} and value "{}" in row={}, col={}'.format(type(val), str(val), type(received), str(received), row, col)
-            self.fail(msg, data=self.data)
-        return self
+
+class MonetDBeTestResult(TestCaseResult, RunnableTestResult):
+    """Holder of sql execution information for in-memory monetdbe connections. Managed by SQLTestCase."""
+    def __init__(self, test_case):
+        super().__init__(test_case)
+        self.did_run = False
+
+    def _parse_error(self, err: str):
+        # TODO monetdbe error messages are not parsed yet; run() does not call this
+        pass
 
-    def assertDataResultMatch(self, data=[], index=None):
-        """Assert on a match of a subset of the result. When index is provided it
-        starts comparig from that row index onward.
-        """
-        def mapfn(next):
-            if type(next) is list:
-                return tuple(next)
-            return next
-        data = list(map(mapfn, data))
-        if index is None:
-            if len(data) > 0:
-                first = data[0]
-                for i, v in enumerate(self.data):
-                    if first == v:
-                        index = i
-                        break
-        if not sequence_match(data, self.data, index):
-            msg = '{}\nexpected to match query result starting at index={}, but it didn\'t'.format(piped_representation(data), index)
-            self.fail(msg, data=self.data)
+    def run(self, query:str, *args, stdin=None):
+        if self.did_run is False:
+            if query:
+                self.query = query
+                try:
+                    conn = self.test_case.conn_ctx
+                    crs = conn.cursor()
+                    try:
+                        crs.execute(query)
+                        self.rowcount = int(crs.rowcount)
+                        if crs.description:
+                            self.description = crs.description
+                            self.data = crs.fetchall()
+                    finally:
+                        crs.close()
+                except (monetdbe.Error, ValueError) as e:
+                    self.test_run_error = e
+                    # TODO parse error
+            self.did_run = True
         return self
 
 class SQLDump():
@@ -403,6 +438,7 @@ class SQLTestCase():
         self.err_file = err_file
         self.test_results = []
         self._conn_ctx = None
+        self.in_memory = False
 
     def __enter__(self):
         return self
@@ -411,6 +447,8 @@ class SQLTestCase():
         self.exit()
 
     def exit(self):
+        if self._conn_ctx:
+            self._conn_ctx.close()
         self._conn_ctx = None
         for res in self.test_results:
             if len(res.assertion_errors) > 0:
@@ -424,10 +462,15 @@ class SQLTestCase():
         print(msg, file=self.err_file)
         print('', file=self.err_file)
 
-
     def connect(self,
             username='monetdb', password='monetdb',
             hostname='localhost', port=MAPIPORT, database=TSTDB, language='sql'):
+        if database == ':memory:':
+            self.in_memory = True
+            # TODO use username, password, port and language when supported from monetdbe
+            self._conn_ctx = monetdbe.connect(':memory:', autocommit=True)
+        else:
+            self.in_memory = False
             self._conn_ctx = PyMonetDBConnectionContext(
                                  username=username,
                                  password=password,
@@ -435,9 +478,11 @@ class SQLTestCase():
                                  port=port,
                                  database=database,
                                  language=language)
-            return self._conn_ctx
+        return self._conn_ctx
 
     def default_conn_ctx(self):
+        if self.in_memory:
+            return monetdbe.connect(':memory:', autocommit=True)
         return PyMonetDBConnectionContext()
 
     @property
@@ -447,6 +492,8 @@ class SQLTestCase():
     def execute(self, query:str, *args, client='pymonetdb', stdin=None):
         if client == 'mclient':
             res = MclientTestResult(self)
+        elif self.in_memory:  # ':memory:' databases are opened through monetdbe in connect()
+            res = MonetDBeTestResult(self)
         else:
             res = PyMonetDBTestResult(self)
         res.run(query, *args, stdin=stdin)
@@ -471,9 +518,12 @@ class SQLTestCase():
         return res
 
     def drop(self):
+        if self.in_memory:
+            # TODO dropping user tables is not implemented for in-memory (monetdbe) databases
+            return
         try:
             with self.conn_ctx as ctx:
-                crs = ctx.crs
+                crs = ctx.cursor()
                 crs.execute('select s.name, t.name, tt.table_type_name from sys.tables t, sys.schemas s, sys.table_types tt where not t.system and t.schema_id = s.id and t.type = tt.table_type_id')
                 for row in crs.fetchall():
                     crs.execute('drop {} "{}"."{}" cascade'.format(row[2], row[0], row[1]))
_______________________________________________
checkin-list mailing list
checkin-list@monetdb.org
https://www.monetdb.org/mailman/listinfo/checkin-list

Reply via email to