Changeset: b588a6977aae for MonetDB
URL: https://dev.monetdb.org/hg/MonetDB/rev/b588a6977aae
Modified Files:
        clients/odbc/driver/SQLConnect.c
Branch: odbc-tls
Log Message:

Refactor MNDBConnect

One change in behavior: in the old code, the dbname precedence was:
    1. 'dbname' parameter
    2. existing database name
    3. database name from data source

That order seemed odd, so the precedence is now:
    1. 'dbname' parameter
    2. database name from data source
    3. existing database name


diffs (truncated from 367 to 300 lines):

diff --git a/clients/odbc/driver/SQLConnect.c b/clients/odbc/driver/SQLConnect.c
--- a/clients/odbc/driver/SQLConnect.c
+++ b/clients/odbc/driver/SQLConnect.c
@@ -82,6 +82,59 @@ get_serverinfo(ODBCDbc *dbc)
        mapi_close_handle(hdl);
 }
 
+// Return a newly allocated NUL-terminated config value taken from either the
+// argument or the data source. Return a copy of 'default_value' if no value
+// can be found, and NULL on allocation error or on an invalid 'argument_len'.
+//
+// If non-NULL, parameter 'argument' points to an argument that may or may not
+// be NUL-terminated. The length parameter 'argument_len' can either be the
+// length of the argument or one of the following special values:
+//
+//    SQL_NULL_DATA: consider the argument NULL
+//    SQL_NTS:       the argument is actually NUL-terminated
+//
+// Any other negative length is invalid and makes the function return NULL
+// without allocating (MNDBConnect currently reports any NULL result as an
+// allocation error).
+//
+// An empty argument (length 0, or "" with SQL_NTS) counts as absent, so the
+// data source and then 'default_value' are consulted. This keeps the behavior
+// of the code before the refactoring, where an empty user name, password,
+// host name, database name or MAPIPORT value fell back to the data source.
+//
+// Parameters 'dsn' and 'entry', if not NULL and not empty, indicate which data
+// source field to look up in "odbc.ini". A NULL 'default_value' is treated
+// as "".
+static char*
+getConfig(
+       const void *argument, ssize_t argument_len,
+       const char *dsn, const char *entry,
+       const char *default_value)
+{
+       if (default_value == NULL)
+               default_value = "";
+
+       if (argument != NULL && argument_len != SQL_NULL_DATA) {
+               if (argument_len == SQL_NTS) {
+                       argument_len = (ssize_t) strlen((const char *) argument);
+               } else if (argument_len < 0) {
+                       // invalid length: never hand a wrapped size to malloc
+                       return NULL;
+               }
+               if (argument_len > 0) {
+                       // argument is present: create a NUL-terminated copy
+                       char *value = malloc((size_t) argument_len + 1);
+                       if (value == NULL)
+                               return NULL;
+                       memcpy(value, argument, (size_t) argument_len);
+                       value[argument_len] = '\0';
+                       return value;
+               }
+               // empty argument: treat as absent and fall through
+       }
+
+       if (dsn && *dsn && entry && *entry) {
+               // look up in the data source
+               const int size = 1024; // should be plenty
+               char *buffer = malloc(size);
+               if (buffer == NULL)
+                       return NULL;
+               int n = SQLGetPrivateProfileString(dsn, entry, "", buffer, size, "odbc.ini");
+               if (n > 0)
+                       return buffer;  // found some
+               free(buffer);           // found none
+       }
+
+       return strdup(default_value);
+}
+
 SQLRETURN
 MNDBConnect(ODBCDbc *dbc,
            const SQLCHAR *ServerName,
@@ -95,38 +148,37 @@ MNDBConnect(ODBCDbc *dbc,
            const char *dbname,
            int mapToLongVarchar)
 {
-       SQLRETURN rc = SQL_SUCCESS;
+       // These will be passed to addDbcError if you 'goto failure'.
+       // If unset, 'goto failure' will assume an allocation error.
+       const char *error_state = NULL;
+       const char *error_explanation = NULL;
+
+       // These will be free'd / destroyed at the 'end' label at the bottom of 
this function
        char *dsn = NULL;
-       char uid[32];
-       char pwd[32];
-       char buf[256];
-       char db[32];
-       int n;
-       Mapi mid;
+       char *uid = NULL;
+       char *pwd = NULL;
+       char *db = NULL;
+       char *hostdup = NULL;
+       char *portdup = NULL;
+       Mapi mid = NULL;
+
+       // These do not need to be free'd
+       const char *mapiport_env;
 
        /* check connection state, should not be connected */
        if (dbc->Connected) {
-               /* Connection name in use */
-               addDbcError(dbc, "08002", NULL, 0);
-               return SQL_ERROR;
+               error_state = "08002";
+               goto failure;
        }
 
-       /* convert input string parameters to normal null terminated C strings 
*/
-       fixODBCstring(ServerName, NameLength1, SQLSMALLINT,
-                     addDbcError, dbc, return SQL_ERROR);
-       if (NameLength1 > 0) {
-               dsn = dupODBCstring(ServerName, (size_t) NameLength1);
-               if (dsn == NULL) {
-                       /* Memory allocation error */
-                       addDbcError(dbc, "HY001", NULL, 0);
-                       return SQL_ERROR;
-               }
-       }
+       dsn = getConfig(ServerName, NameLength1, NULL, NULL, "");
+       if (dsn == NULL)
+               goto failure;
 
 #ifdef ODBCDEBUG
        if ((ODBCdebug == NULL || *ODBCdebug == 0) && dsn && *dsn) {
                char logfile[2048];
-               n = SQLGetPrivateProfileString(dsn, "logfile", "",
+               int n = SQLGetPrivateProfileString(dsn, "logfile", "",
                                               logfile, sizeof(logfile),
                                               "odbc.ini");
                if (n > 0) {
@@ -151,141 +203,127 @@ MNDBConnect(ODBCDbc *dbc,
        }
 #endif
 
-       if (dsn && *dsn)
-               n = SQLGetPrivateProfileString(dsn, "uid", "monetdb",
-                                              uid, sizeof(uid), "odbc.ini");
-       else
-               n = 0;
-       fixODBCstring(UserName, NameLength2, SQLSMALLINT,
-                     addDbcError, dbc, if (dsn) free(dsn); return SQL_ERROR);
-       if (n == 0 && NameLength2 == 0) {
-               if (dsn)
-                       free(dsn);
-               /* Invalid authorization specification */
-               addDbcError(dbc, "28000", NULL, 0);
-               return SQL_ERROR;
-       }
-       if (NameLength2 > 0) {
-               if ((size_t)NameLength2 >= sizeof(uid))
-                       NameLength2 = sizeof(uid) - 1;
-               strncpy(uid, (char *) UserName, NameLength2);
-               uid[NameLength2] = 0;
+       uid = getConfig(UserName, NameLength2, dsn, "uid", "monetdb");
+       if (uid == NULL)
+               goto failure;
+       if (*uid == '\0') {
+               error_state = "28000";
+               error_explanation = "user name not set";
+               goto failure;
        }
-       if (dsn && *dsn)
-               n = SQLGetPrivateProfileString(dsn, "pwd", "monetdb",
-                                              pwd, sizeof(pwd), "odbc.ini");
-       else
-               n = 0;
-       fixODBCstring(Authentication, NameLength3, SQLSMALLINT,
-                     addDbcError, dbc, if (dsn) free(dsn); return SQL_ERROR);
-       if (n == 0 && NameLength3 == 0) {
-               if (dsn)
-                       free(dsn);
-               /* Invalid authorization specification */
-               addDbcError(dbc, "28000", NULL, 0);
-               return SQL_ERROR;
-       }
-       if (NameLength3 > 0) {
-               if ((size_t)NameLength3 >= sizeof(pwd))
-                       NameLength3 = sizeof(pwd) - 1;
-               strncpy(pwd, (char *) Authentication, NameLength3);
-               pwd[NameLength3] = 0;
+
+       pwd = getConfig(Authentication, NameLength3, dsn, "pwd", "monetdb");
+       if (pwd == NULL)
+               goto failure;
+       if (*pwd == '\0') {
+               error_state = "28000";
+               error_explanation = "password not set";
+               goto failure;
        }
 
-       if (dbname == NULL || *dbname == 0) {
-               dbname = dbc->dbname;
-       }
-       if (dbname == NULL || *dbname == 0) {
-               if (dsn && *dsn) {
-                       n = SQLGetPrivateProfileString(dsn, "database", "", db,
-                                                      sizeof(db), "odbc.ini");
-                       if (n > 0)
-                               dbname = db;
-               }
-       }
-       if (dbname && !*dbname)
-               dbname = NULL;
+       // In the old code, the dbname precedence was:
+       // 1. 'dbname' parameter
+       // 2. existing database name
+       // 3. database name from data source
+       //
+       // That seemed odd, so now it's
+       // 1. 'dbname' parameter
+       // 2. database name from data source
+       // 3. existing database name
+       db = getConfig(dbname, SQL_NTS, dsn, "database", dbc->dbname ? 
dbc->dbname : "");
+       if (db == NULL)
+               goto failure;
+
+       // In the old code we had Windows-specific code that
+       // ran _wgetenv(L"MAPIPORT").
+       // However, even on Windows getenv() is probably fine for a variable 
that's
+       // supposed to only hold digits.
+       mapiport_env = getenv("MAPIPORT");
 
-#ifdef NATIVE_WIN32
-       wchar_t *s;
-       if (port == 0 && (s = _wgetenv(L"MAPIPORT")) != NULL)
-               port = _wtoi(s);
-#else
-       char *s;
-       if (port == 0 && (s = getenv("MAPIPORT")) != NULL)
-               port = atoi(s);
-#endif
-       if (port == 0 && dsn && *dsn) {
-               n = SQLGetPrivateProfileString(dsn, "port", MAPI_PORT_STR,
-                                              buf, sizeof(buf), "odbc.ini");
-               if (n > 0)
-                       port = atoi(buf);
+       // Port precedence:
+       // 2. 'port' parameter
+       // 1. MAPIPORT env var
+       // 3. data source
+       // 4. MAPI_PORT_STR ("50000")
+       if (port == 0) {
+               portdup = getConfig(mapiport_env, SQL_NTS, dsn, "port", 
MAPI_PORT_STR);
+               if (portdup == NULL)
+                       goto failure;
+               char *end;
+               long longport = strtol(portdup, &end, 10);
+               if (*portdup == '\0' || *end != '\0' || longport < 1 || 
longport > 65535) {
+                       error_state = "HY009"; // invalid argument
+                       error_explanation = mapiport_env != NULL
+                               ? "invalid port setting in MAPIPORT environment 
variable"
+                               : "invalid port setting in data source";
+                       goto failure;
+               }
+               port = longport;
        }
-       if (port == 0)
-               port = MAPI_PORT;
 
-       if (host == NULL || *host == 0) {
-               host = "localhost";
-               if (dsn && *dsn) {
-                       n = SQLGetPrivateProfileString(dsn, "host", "localhost",
-                                                      buf, sizeof(buf),
-                                                      "odbc.ini");
-                       if (n > 0)
-                               host = buf;
-               }
-       }
+       hostdup = getConfig(host, SQL_NTS, dsn, "host", "localhost");
+       if (hostdup == NULL)
+               goto failure;
+
 
 #ifdef ODBCDEBUG
        ODBCLOG("SQLConnect: DSN=%s UID=%s PWD=%s host=%s port=%d 
database=%s\n",
-               dsn ? dsn : "(null)", uid, pwd, host, port,
-               dbname ? dbname : "(null)");
+               dsn, uid, pwd, hostdup, port, db);
 #endif
 
-       /* connect to a server on host via port */
-       /* FIXME: use dbname from ODBC connect string/options here */
-       mid = mapi_mapi(host, port, uid, pwd, "sql", dbname);
+       // Create mid and execute a bunch of commands before checking for 
errors.
+       mid = mapi_mapi(hostdup, port, uid, pwd, "sql", db);
        if (mid) {
                mapi_setAutocommit(mid, dbc->sql_attr_autocommit == 
SQL_AUTOCOMMIT_ON);
                mapi_set_size_header(mid, true);
                mapi_reconnect(mid);
        }
        if (mid == NULL || mapi_error(mid)) {
-               /* Client unable to establish connection */
-               addDbcError(dbc, "08001", mid ? mapi_error_str(mid) : NULL, 0);
-               rc = SQL_ERROR;
-               /* clean up */
-               if (mid)
-                       mapi_destroy(mid);
-               if (dsn != NULL)
-                       free(dsn);
-       } else {
-               /* store internal information and clean up buffers */
-               dbc->Connected = true;
-               dbc->mid = mid;
-               if (dbc->dsn != NULL)
-                       free(dbc->dsn);
-               dbc->dsn = dsn;
_______________________________________________
checkin-list mailing list -- checkin-list@monetdb.org
To unsubscribe send an email to checkin-list-le...@monetdb.org

Reply via email to