Changeset: a32c685fa62b for MonetDB
URL: https://dev.monetdb.org/hg/MonetDB/rev/a32c685fa62b
Modified Files:
        sql/backends/monet5/UDF/capi/Tests/capi11.test
        sql/backends/monet5/UDF/pyapi3/Tests/pyapi3_18.test
        sql/backends/monet5/UDF/pyapi3/Tests/pyapi3_19.test
        sql/common/sql_types.c
        sql/server/rel_psm.c
        sql/server/rel_schema.c
        sql/server/rel_select.c
        sql/server/sql_parser.y
        sql/test/SQLancer/Tests/sqlancer08.test
Branch: Aug2024
Log Message:

Handle the generic DECIMAL type (i.e. without precision/scale) later in the
process. When creating columns we still use the old default of DECIMAL(18,3).
For functions, the decimal type is resolved based on the input type.
For all *api SQL extensions, we now enforce the use of fully specified decimals.


diffs (truncated from 327 to 300 lines):

diff --git a/sql/backends/monet5/UDF/capi/Tests/capi11.test 
b/sql/backends/monet5/UDF/capi/Tests/capi11.test
--- a/sql/backends/monet5/UDF/capi/Tests/capi11.test
+++ b/sql/backends/monet5/UDF/capi/Tests/capi11.test
@@ -1,7 +1,4 @@
-statement ok
-START TRANSACTION
-
-statement ok
+statement error 42000!CREATE FUNCTION: the function 'capi11' uses a generic 
DECIMAL type, UDFs require precision and scale
 CREATE FUNCTION capi11(inp DECIMAL) RETURNS DECIMAL(11,1) LANGUAGE C {
     size_t i;
     result->initialize(result, inp.count);
@@ -14,6 +11,33 @@ CREATE FUNCTION capi11(inp DECIMAL) RETU
     }
 }
 
+statement error 42000!CREATE FUNCTION: the function '_dbl2dec' returns a 
generic DECIMAL type, UDFs require precision and scale
+CREATE FUNCTION _dbl2dec(inp DOUBLE) RETURNS DECIMAL LANGUAGE C {
+    size_t i;
+    result->initialize(result, inp.count);
+    for(i = 0; i < inp.count; i++) {
+        result->data[i] = inp.data[i] * result->scale;
+    }
+}
+
+
+statement ok
+START TRANSACTION
+
+statement ok
+CREATE FUNCTION capi11(inp DECIMAL(18,3)) RETURNS DECIMAL(11,1) LANGUAGE C {
+    size_t i;
+    result->initialize(result, inp.count);
+    for(i = 0; i < inp.count; i++) {
+        if (inp.data[i] == inp.null_value) {
+            result->data[i] = result->null_value;
+        } else {
+            result->data[i] = (inp.data[i] / inp.scale) * result->scale;
+        }
+    }
+}
+
+
 statement ok
 CREATE TABLE decimals(d DECIMAL(18,3))
 
@@ -32,7 +56,7 @@ statement ok
 DROP FUNCTION capi11
 
 statement ok
-CREATE FUNCTION _dec2dbl(inp DECIMAL) RETURNS DOUBLE LANGUAGE C {
+CREATE FUNCTION _dec2dbl(inp DECIMAL(18,3)) RETURNS DOUBLE LANGUAGE C {
     size_t i;
     result->initialize(result, inp.count);
     for(i = 0; i < inp.count; i++) {
@@ -59,7 +83,7 @@ statement ok
 DROP FUNCTION _dec2dbl
 
 statement ok
-CREATE FUNCTION _dbl2dec(inp DOUBLE) RETURNS DECIMAL LANGUAGE C {
+CREATE FUNCTION _dbl2dec(inp DOUBLE) RETURNS DECIMAL(18,3) LANGUAGE C {
     size_t i;
     result->initialize(result, inp.count);
     for(i = 0; i < inp.count; i++) {
diff --git a/sql/backends/monet5/UDF/pyapi3/Tests/pyapi3_18.test 
b/sql/backends/monet5/UDF/pyapi3/Tests/pyapi3_18.test
--- a/sql/backends/monet5/UDF/pyapi3/Tests/pyapi3_18.test
+++ b/sql/backends/monet5/UDF/pyapi3/Tests/pyapi3_18.test
@@ -1,3 +1,15 @@
+statement error 42000!CREATE FUNCTION: the function 'pyapi_decimal' uses a 
generic DECIMAL type, UDFs require precision and scale
+CREATE FUNCTION pyapi_decimal(d DECIMAL) RETURNS DOUBLE LANGUAGE PYTHON { 
return d; }
+
+statement error 42000!CREATE UNION FUNCTION: the function 'pyapi_ret_decimal' 
returns a generic DECIMAL type, UDFs require precision and scale
+CREATE FUNCTION pyapi_ret_decimal() RETURNS TABLE(d DECIMAL)
+LANGUAGE PYTHON
+{
+    result = dict()
+    result['d'] = 100.33
+    return result
+}
+
 statement ok
 START TRANSACTION
 
@@ -65,7 +77,7 @@ statement ok rowcount 1
 INSERT INTO decimal_table VALUES (123.4567)
 
 statement ok
-CREATE FUNCTION pyapi_decimal(d DECIMAL) RETURNS DOUBLE LANGUAGE PYTHON { 
return d; }
+CREATE FUNCTION pyapi_decimal(d DECIMAL(18, 3)) RETURNS DOUBLE LANGUAGE PYTHON 
{ return d; }
 
 query R rowsort
 SELECT pyapi_decimal(d) FROM decimal_table
@@ -118,7 +130,7 @@ SELECT * FROM pyapi_ret_timestamp()
 2000-01-01 12:00:00.000001
 
 statement ok
-CREATE FUNCTION pyapi_ret_decimal() RETURNS TABLE(d DECIMAL)
+CREATE FUNCTION pyapi_ret_decimal() RETURNS TABLE(d DECIMAL(18, 3))
 LANGUAGE PYTHON
 {
     result = dict()
@@ -144,14 +156,14 @@ statement ok
 DROP FUNCTION pyapi_ret_decimal
 
 statement ok
-CREATE FUNCTION pyapi_ret_decimal() RETURNS TABLE(d DECIMAL)
+CREATE FUNCTION pyapi_ret_decimal() RETURNS TABLE(d DECIMAL(18, 3))
 LANGUAGE PYTHON
 {
     return numpy.arange(100001) # return 100k decimal values
 }
 
 statement ok
-CREATE FUNCTION pyapi_inp_decimal(d DECIMAL) RETURNS DOUBLE
+CREATE FUNCTION pyapi_inp_decimal(d DECIMAL(18, 3)) RETURNS DOUBLE
 LANGUAGE PYTHON
 {
     return numpy.mean(d) # average 100k decimal values
diff --git a/sql/backends/monet5/UDF/pyapi3/Tests/pyapi3_19.test 
b/sql/backends/monet5/UDF/pyapi3/Tests/pyapi3_19.test
--- a/sql/backends/monet5/UDF/pyapi3/Tests/pyapi3_19.test
+++ b/sql/backends/monet5/UDF/pyapi3/Tests/pyapi3_19.test
@@ -121,7 +121,7 @@ statement ok
 START TRANSACTION
 
 statement ok
-CREATE FUNCTION pyapi19_create_table() returns table (i integer, j integer, k 
double, l float, m smallint, n bigint, o STRING, p DECIMAL) LANGUAGE P
+CREATE FUNCTION pyapi19_create_table() returns table (i integer, j integer, k 
double, l float, m smallint, n bigint, o STRING, p DECIMAL(18, 3)) LANGUAGE P
 {
     result = dict();
     result['i'] = numpy.arange(100000, 0, -1);
@@ -136,7 +136,7 @@ CREATE FUNCTION pyapi19_create_table() r
 }
 
 statement ok
-CREATE FUNCTION pyapi19_load_table() returns table (i integer, j integer, k 
double, l float, m smallint, n bigint, o STRING, p DECIMAL) LANGUAGE PYTHON_MAP
+CREATE FUNCTION pyapi19_load_table() returns table (i integer, j integer, k 
double, l float, m smallint, n bigint, o STRING, p DECIMAL(18, 3)) LANGUAGE 
PYTHON_MAP
 {
     res = _conn.execute('SELECT * FROM pyapi19_integers;')
     return res
diff --git a/sql/common/sql_types.c b/sql/common/sql_types.c
--- a/sql/common/sql_types.c
+++ b/sql/common/sql_types.c
@@ -244,6 +244,8 @@ sql_init_subtype(sql_subtype *res, sql_t
        if (t->digits && res->digits > t->digits)
                res->digits = t->digits;
        res->scale = scale;
+       if (!digits && !scale && t->eclass == EC_DEC)
+               res->scale = res->digits = 0;
 }
 
 sql_subtype *
@@ -255,6 +257,15 @@ sql_create_subtype(allocator *sa, sql_ty
        return res;
 }
 
+static sql_subtype *
+create_subtype(allocator *sa, sql_type *t)
+{
+       sql_subtype *res = SA_ZNEW(sa, sql_subtype);
+
+       sql_init_subtype(res, t, t->digits, 0);
+       return res;
+}
+
 static bool
 localtypes_cmp(int nlt, int olt)
 {
@@ -799,10 +810,10 @@ sql_create_func_(allocator *sa, const ch
 
        for (int i = 0; i < nargs; i++) {
                sql_type *tpe = va_arg(valist, sql_type*);
-               list_append(ops, create_arg(sa, NULL, sql_create_subtype(sa, 
tpe, 0, 0), ARG_IN));
+               list_append(ops, create_arg(sa, NULL, create_subtype(sa, tpe), 
ARG_IN));
        }
        if (res)
-               fres = create_arg(sa, NULL, sql_create_subtype(sa, res, 0, 0), 
ARG_OUT);
+               fres = create_arg(sa, NULL, create_subtype(sa, res), ARG_OUT);
        base_init(sa, &t->base, local_id++, false, name);
 
        t->imp = sa_strdup(sa, imp);
diff --git a/sql/server/rel_psm.c b/sql/server/rel_psm.c
--- a/sql/server/rel_psm.c
+++ b/sql/server/rel_psm.c
@@ -853,6 +853,35 @@ rel_create_function(allocator *sa, const
        return rel;
 }
 
+static bool
+has_generic_decimal(list *types)
+{
+       if (!list_empty(types)) {
+               for(node *n = types->h; n; n = n->next) {
+                       sql_subtype *st = n->data;
+
+                       if (st->type->eclass == EC_DEC && !st->digits && 
!st->scale)
+                               return true;
+               }
+       }
+       return false;
+}
+
+static bool
+has_generic_decimal_result(list *types)
+{
+       if (!list_empty(types)) {
+               for(node *n = types->h; n; n = n->next) {
+                       sql_arg *a = n->data;
+
+                       if (a->type.type->eclass == EC_DEC && !a->type.digits 
&& !a->type.scale)
+                               return true;
+               }
+       }
+       return false;
+}
+
+
 static sql_rel *
 rel_create_func(sql_query *query, dlist *qname, dlist *params, symbol *res, 
dlist *ext_name, dlist *body, sql_ftype type, sql_flang lang, int replace)
 {
@@ -953,6 +982,8 @@ rel_create_func(sql_query *query, dlist 
                sql->session->status = 0; /* if the function was not found 
clean the error */
                sql->errstr[0] = '\0';
        }
+       if (lang > FUNC_LANG_SQL && has_generic_decimal(type_list))
+               return sql_error(sql, 02, SQLSTATE(42000) "CREATE %s: the 
function '%s' uses a generic DECIMAL type, UDFs require precision and scale", 
F, fname);
 
        list_destroy(type_list);
 
@@ -976,6 +1007,8 @@ rel_create_func(sql_query *query, dlist 
        if (res && !(restype = result_type(sql, res)))
                return sql_error(sql, 01, SQLSTATE(42000) "CREATE %s: failed to 
get restype", F);
 
+       if (lang > FUNC_LANG_SQL && has_generic_decimal_result(restype))
+               return sql_error(sql, 02, SQLSTATE(42000) "CREATE %s: the 
function '%s' returns a generic DECIMAL type, UDFs require precision and 
scale", F, fname);
        if (body && LANG_EXT(lang)) {
                const char *lang_body = body->h->data.sval, *mod = "unknown", 
*slang = "Unknown", *imp = "Unknown";
                switch (lang) {
diff --git a/sql/server/rel_schema.c b/sql/server/rel_schema.c
--- a/sql/server/rel_schema.c
+++ b/sql/server/rel_schema.c
@@ -1094,6 +1094,9 @@ create_column(sql_query *query, symbol *
        if (l->h->next->next)
                opt_list = l->h->next->next->data.lval;
 
+       if (ctype && ctype->type->eclass == EC_DEC && !ctype->digits && 
!ctype->scale) /* default 18,3 */
+               ctype = sql_bind_subtype(query->sql->sa, "decimal", 18, 3);
+
        if (cname && ctype) {
                sql_column *cs = NULL;
 
diff --git a/sql/server/rel_select.c b/sql/server/rel_select.c
--- a/sql/server/rel_select.c
+++ b/sql/server/rel_select.c
@@ -4073,10 +4073,12 @@ rel_cast(sql_query *query, sql_rel **rel
                sql_subtype *et = exp_subtype(e);
                if (et->type->eclass == EC_NUM) {
                        unsigned int min_precision = atom_num_digits(e->l);
+                       if (!tpe->digits && !tpe->scale)
+                               tpe->digits = min_precision;
                        if (min_precision > tpe->digits)
                                return sql_error(sql, 02, SQLSTATE(42000) 
"Precision (%d) should be at least (%d)", tpe->digits, min_precision);
-                       tpe = sql_bind_subtype(sql->sa, "decimal", 
min_precision, et->scale);
-               } else if (EC_VARCHAR(et->type->eclass)) {
+                       tpe = sql_bind_subtype(sql->sa, "decimal", tpe->digits, 
et->scale);
+               } else if (EC_VARCHAR(et->type->eclass) && !tpe->digits && 
!tpe->scale) {
                        char *s = E_ATOM_STRING(e);
                        unsigned int min_precision = 0, min_scale = 0;
                        bool dot_seen = false;
@@ -4091,6 +4093,12 @@ rel_cast(sql_query *query, sql_rel **rel
                        }
                        tpe = sql_bind_subtype(sql->sa, "decimal", 
min_precision, min_scale);
                }
+       } else if (tpe->type->eclass == EC_DEC && !tpe->digits && !tpe->scale) {
+               sql_subtype *et = exp_subtype(e);
+               if (et->type->eclass == EC_NUM)
+                       tpe = sql_bind_subtype(sql->sa, "decimal", et->digits, 
0);
+               else /* fallback */
+                       tpe = sql_bind_subtype(sql->sa, "decimal", 18, 3);
        }
 
        if (e)
diff --git a/sql/server/sql_parser.y b/sql/server/sql_parser.y
--- a/sql/server/sql_parser.y
+++ b/sql/server/sql_parser.y
@@ -5752,7 +5752,7 @@ data_type:
  |  BIGINT             { sql_find_subtype(&$$, "bigint", 0, 0); }
  |  HUGEINT            { sql_find_subtype(&$$, "hugeint", 0, 0); }
 
- |  sqlDECIMAL         { sql_find_subtype(&$$, "decimal", 18, 3); }
+ |  sqlDECIMAL         { sql_find_subtype(&$$, "decimal", 0, 0); }
  |  sqlDECIMAL '(' nonzero ')'
                        {
                          int d = $3;
@@ -7019,7 +7019,7 @@ odbc_data_type:
     | SQL_DATE
        { sql_find_subtype(&$$, "date", 0, 0); }
     | SQL_DECIMAL
-       { sql_find_subtype(&$$, "decimal", 18, 3); }
+       { sql_find_subtype(&$$, "decimal", 0, 0); }
     | SQL_DOUBLE
_______________________________________________
checkin-list mailing list -- checkin-list@monetdb.org
To unsubscribe send an email to checkin-list-le...@monetdb.org

Reply via email to