This is an automated email from the ASF dual-hosted git repository.

joemcdonnell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/impala.git

commit 3781132ef6d339a2b8e2b4444c0b79bc79f84a5c
Author: gaurav1086 <[email protected]>
AuthorDate: Thu Jan 30 12:09:51 2025 -0800

    IMPALA-13675: OAuth AuthN Support for Impala Shell
    
    This patch adds support for fetching access tokens
    from the OAuth server using the OAuth client_id and
    client_secret when an access token is not provided.
    It covers the client_credentials grant flow.
    The client_secret can either be supplied through
    --oauth_client_secret_cmd (e.g. reading it from a file)
    or entered at an interactive prompt.
    
    Added a test-only impala-shell option, --oauth_mock_response_cmd,
    which mocks the OAuth server response. It is hidden from the
    help output and is intended to be used only for testing.
    Also hid the existing option --hs2_x_forward from the
    impala-shell --help output.
    
    Testing (Okta OAuth server):
    - Added custom_cluster tests in test_shell_oauth_auth.py:
        test_oauth_auth_with_clientid_and_secret_success
        test_oauth_auth_with_clientid_and_secret_failure
    - Tested manually by providing --oauth_client_id <client_id> and
      --oauth_client_secret_cmd="cat password_file.txt".
    - Tested manually by providing --oauth_client_id <client_id> and no
      --oauth_client_secret_cmd, so that the shell prompts the user
      to enter the client_secret.
    
    Example command: impala-shell.sh -a
    --auth_creds_ok_in_clear --protocol="hs2-http"
    --oauth_client_id="client_id"
    --oauth_client_secret_cmd="cat client_secret.txt"
    --oauth_server="dev.us.auth01.com"
    --oauth_endpoint="/oauth/token"
    
    Change-Id: I84e26d54f6a53696660728efb239ffd43de4c55d
    Reviewed-on: http://gerrit.cloudera.org:8080/22424
    Reviewed-by: Impala Public Jenkins <[email protected]>
    Tested-by: Impala Public Jenkins <[email protected]>
---
 bin/rat_exclude_files.txt                      |  3 +
 shell/impala_shell/impala_shell.py             | 87 +++++++++++++++++++++++--
 shell/impala_shell/option_parser.py            | 20 +++++-
 testdata/jwt/oauth_client_secret               |  1 +
 testdata/jwt/okta_oauth_payload_invalid        |  1 +
 testdata/jwt/okta_oauth_payload_valid          |  1 +
 tests/custom_cluster/test_shell_jwt_auth.py    |  5 ++
 tests/custom_cluster/test_shell_oauth_auth.py  | 90 +++++++++++++++++++++++++-
 tests/shell/test_shell_commandline_jwt_auth.py |  8 +--
 9 files changed, 204 insertions(+), 12 deletions(-)

diff --git a/bin/rat_exclude_files.txt b/bin/rat_exclude_files.txt
index dc4ff98a0..c9652ed7b 100644
--- a/bin/rat_exclude_files.txt
+++ b/bin/rat_exclude_files.txt
@@ -192,6 +192,9 @@ testdata/jwt/*.json
 testdata/jwt/jwt_expired
 testdata/jwt/jwt_signed
 testdata/jwt/jwt_signed_untrusted
+testdata/jwt/oauth_client_secret
+testdata/jwt/okta_oauth_payload_valid
+testdata/jwt/okta_oauth_payload_invalid
 testdata/tzdb/2017c.zip
 testdata/tzdb/2017c-corrupt.zip
 testdata/tzdb_tiny/*
diff --git a/shell/impala_shell/impala_shell.py 
b/shell/impala_shell/impala_shell.py
index 04729211e..6de925ccb 100755
--- a/shell/impala_shell/impala_shell.py
+++ b/shell/impala_shell/impala_shell.py
@@ -36,6 +36,8 @@ import sys
 import textwrap
 import time
 import traceback
+import json
+from six.moves import http_client
 
 import prettytable
 import sqlparse
@@ -229,6 +231,13 @@ class ImpalaShell(cmd.Cmd, object):
     self.jwt = options.jwt
     self.use_oauth = options.use_oauth
     self.oauth = options.oauth
+    self.oauth_server = options.oauth_server
+    self.oauth_client_id = options.oauth_client_id
+    self.oauth_client_secret_cmd = options.oauth_client_secret_cmd
+    self.oauth_client_secret = options.oauth_client_secret
+    self.oauth_endpoint = options.oauth_endpoint
+    self.oauth_mock_response_cmd = options.oauth_mock_response_cmd
+    self.oauth_mock_response = options.oauth_mock_response
     # When running tests in strict mode, the server uses the ldap
     # protocol but can allow any password.
     if options.use_ldap_test_password:
@@ -1008,7 +1017,9 @@ class ImpalaShell(cmd.Cmd, object):
       self.jwt = getpass.getpass("Enter JWT: ")
 
     if self.use_oauth and self.oauth is None:
-      self.oauth = getpass.getpass("Enter OAUTH: ")
+      self.oauth_get_access_token()
+      if self.oauth is None and self.oauth_mock_response is None:
+        self.oauth = getpass.getpass("Enter OAUTH: ")
 
     if not args: args = socket.getfqdn()
     tokens = args.split(" ")
@@ -1874,6 +1885,60 @@ class ImpalaShell(cmd.Cmd, object):
         if not self.ignore_query_failure: return False
     return True
 
+  def oauth_get_access_token(self):
+    """Fetches OAuth access token from the OAuth Auth Server
+    using client_id and client_secret. If however, the
+    oauth_mock_response is set, then it returns the contents
+    of the file passed as parameter for oauth_mock_response_cmd.
+    """
+    json_body = dict()
+    if self.use_oauth and self.oauth is None:
+      if self.oauth_server is None:
+        print("Error: OAuth Server is empty")
+        sys.exit(1)
+      if self.oauth_endpoint is None:
+        print("Error: OAuth endpoint is empty")
+        sys.exit(1)
+      if self.oauth_client_id is None:
+        print("Error: OAuth Client id is empty")
+        sys.exit(1)
+
+      # If client secret is not provided through a command, then request it 
from user
+      if self.oauth_client_secret is None:
+        self.oauth_client_secret = getpass.getpass("Enter OAuth Client Secret: 
")
+
+      # Fetch the access tokens first using client id and client secret
+      if self.oauth_mock_response is None:
+        try:
+          # Retreive the oauth access token from the OAuth Server
+          conn = http_client.HTTPSConnection(self.oauth_server, timeout=10)
+          payload = "{\"client_id\":\"" + self.oauth_client_id + \
+            "\",\"client_secret\":\"" + self.oauth_client_secret + \
+            "\",\"audience\":\"https://"; + self.oauth_server + \
+            "/api/v2/\",\"grant_type\":\"client_credentials\"}"
+          headers = {'content-type': "application/json", "charset": "utf-8"}
+          conn.request("POST", self.oauth_endpoint, payload.encode('utf-8'), 
headers)
+          res = conn.getresponse()
+          if (res.status != 200):
+            print("HTTP error: ", res.status, res.read().decode("utf-8"))
+            sys.exit(1)
+          data = res.read()
+          json_body = json.loads(data.decode("utf-8"))
+        except Exception as e:
+          print("Error getting OAuth access tokens", e)
+          sys.exit(1)
+        finally:
+          if conn:
+            conn.close()
+      else:
+        # Fetch mock response
+        json_body = json.loads(self.oauth_mock_response)
+
+      if "access_token" in json_body.keys():
+        self.oauth = json_body["access_token"]
+      else:
+        print("Error: OAuth access token not found in json payload")
+        sys.exit(1)
 
 TIPS = [
   "Press TAB twice to see a list of available commands.",
@@ -2203,7 +2268,7 @@ def impala_shell_main():
     auth_method_count += 1
 
   if auth_method_count > 1:
-    print("Please specify at most one authentication mechanism (-k, -l, or 
-j)",
+    print("Please specify at most one authentication mechanism (-k, -l, -j, or 
-a)",
           file=sys.stderr)
     raise FatalShellException()
 
@@ -2256,6 +2321,11 @@ def impala_shell_main():
           file=sys.stderr)
     raise FatalShellException()
 
+  if not options.use_oauth and options.oauth_client_secret_cmd:
+    print("Option --oauth_client_secret_cmd requires using OAUTH 
authentication "
+          "mechanism (-a)", file=sys.stderr)
+    raise FatalShellException()
+
   if options.hs2_fp_format:
     try:
       _validate_hs2_fp_format_specification(options.hs2_fp_format)
@@ -2313,8 +2383,17 @@ def impala_shell_main():
     options.jwt = read_password_cmd(options.jwt_cmd, "JWT", True)
 
   options.oauth = None
-  if options.use_oauth and options.oauth_cmd:
-    options.oauth = read_password_cmd(options.oauth_cmd, "OAUTH", True)
+  options.oauth_client_secret = None
+  options.oauth_mock_response = None
+  if options.use_oauth:
+    if options.oauth_cmd:
+      options.oauth = read_password_cmd(options.oauth_cmd, "OAUTH", True)
+    elif options.oauth_client_secret_cmd:
+      options.oauth_client_secret = 
read_password_cmd(options.oauth_client_secret_cmd,
+          "OAuth client secret", True)
+    if options.oauth_mock_response_cmd:
+      options.oauth_mock_response = 
read_password_cmd(options.oauth_mock_response_cmd,
+          "OAuth mock response", True)
 
   if options.ssl:
     if options.ca_cert is None:
diff --git a/shell/impala_shell/option_parser.py 
b/shell/impala_shell/option_parser.py
index a06cb55c8..fe5f5ab78 100644
--- a/shell/impala_shell/option_parser.py
+++ b/shell/impala_shell/option_parser.py
@@ -235,7 +235,7 @@ def get_option_parser(defaults):
   parser.add_option("-a", "--oauth", dest="use_oauth",
                     action="store_true",
                     help="Use OAuth to authenticate with Impala. Impala must 
be"
-                    "configured to allow Oauth authentication. \t\t")
+                    "configured to allow OAuth authentication. \t\t")
   parser.add_option("-u", "--user", dest="user",
                     help="User to authenticate with.")
   parser.add_option("--ssl", dest="ssl",
@@ -275,10 +275,23 @@ def get_option_parser(defaults):
                     "unencrypted, and may be vulnerable to attack.")
   parser.add_option("--ldap_password_cmd", dest="ldap_password_cmd",
                     help="Shell command to run to retrieve the LDAP password")
+  parser.add_option("--oauth_client_id", dest="oauth_client_id",
+                    help="User to authenticate with OAuth auth server")
+  parser.add_option("--oauth_client_secret_cmd", 
dest="oauth_client_secret_cmd",
+                    help="Shell command to run to retrieve OAuth client 
secret")
   parser.add_option("--jwt_cmd", dest="jwt_cmd",
                     help="Shell command to run to retrieve the JWT")
   parser.add_option("--oauth_cmd", dest="oauth_cmd",
                     help="Shell command to run to retrieve the Oauth Token")
+  parser.add_option("--oauth_server", dest="oauth_server",
+                    help="OAuth Server url to get access and refresh tokens. 
Impala must"
+                    "be configured to allow OAuth authentication")
+  parser.add_option("--oauth_endpoint", dest="oauth_endpoint",
+                    help="OAuth Server endpoint to get access and refresh 
tokens. Impala"
+                    "must be configured to allow OAuth authentication")
+  # This option is used to create mock oauth auth server response for testing.
+  parser.add_option("--oauth_mock_response_cmd", 
dest="oauth_mock_response_cmd",
+                    help=SUPPRESS_HELP)
   parser.add_option("--var", dest="keyval", action="append",
                     help="Defines a variable to be used within the Impala 
session."
                          " Can be used multiple times to set different 
variables."
@@ -358,10 +371,11 @@ def get_option_parser(defaults):
                     "values when using the HS2 protocol. The default behaviour 
makes the "
                     "values handled by Python's str() built-in method. Use 
'16G' to "
                     "match the Beeswax protocol's floating-point output 
format.")
+  # When using the hs2-http protocol, set this value in the X-Forwarded-For 
header.
+  # This is primarily for testing purposes.
   parser.add_option("--hs2_x_forward", type="str",
                     dest="hs2_x_forward", default=None,
-                    help="When using the hs2-http protocol, set this value in 
the "
-                    "X-Forwarded-For header. This is primarily for testing 
purposes.")
+                    help=SUPPRESS_HELP)
   parser.add_option("--beeswax_compat_num_rows", 
dest="beeswax_compat_num_rows",
                     action="store_true",
                     help="If specified, always print num rows report at the 
end of query "
diff --git a/testdata/jwt/oauth_client_secret b/testdata/jwt/oauth_client_secret
new file mode 100644
index 000000000..34779ff15
--- /dev/null
+++ b/testdata/jwt/oauth_client_secret
@@ -0,0 +1 @@
+wZOQjHUCNyAEPkEhpQ0MTzcEnrX9fqQv
diff --git a/testdata/jwt/okta_oauth_payload_invalid 
b/testdata/jwt/okta_oauth_payload_invalid
new file mode 100644
index 000000000..fa3c05a59
--- /dev/null
+++ b/testdata/jwt/okta_oauth_payload_invalid
@@ -0,0 +1 @@
+{"missing_access_token": "no token", "expires_in": "86400", "token_type": 
"Bearer"}
diff --git a/testdata/jwt/okta_oauth_payload_valid 
b/testdata/jwt/okta_oauth_payload_valid
new file mode 100644
index 000000000..1297776b8
--- /dev/null
+++ b/testdata/jwt/okta_oauth_payload_valid
@@ -0,0 +1 @@
+{"access_token": 
"eyJhbGciOiJSUzI1NiIsImtpZCI6IjIwMjMwNTA5LTE2MDQxNSIsInR5cGUiOiJKV1QifQ.eyJhdWQiOiJpbXBhbGEtdGVzdHMiLCJleHAiOjE5OTkwMDgyNTUsImlhdCI6MTY4MzY0ODI1NSwiaXNzIjoiZmlsZTovL3Rlc3RzL3V0aWwvand0L2p3dF91dGlsLnB5Iiwia2lkIjoiMjAyMzA1MDktMTYwNDE1Iiwic3ViIjoidGVzdC11c2VyIn0.dWMOkcBrwRansZrCZrlbYzr9alIQ23qlnw4t8Kx_v87CBB90qtmTV88nZAh4APtTE8IUnP0e45R2XyDoH3a8UVrrSOkEzI47wJ0I3GqSc_R_MsGoeGlKreZmcjGhY_ceOo7RWYaBdzsAZe1YXcKJbq2sQJ3issfjBa_fWt0Qhy0DvzssUf3V-g5nQUM3W3pOULiFtMhA8YmIdheHalRz3D_
 [...]
diff --git a/tests/custom_cluster/test_shell_jwt_auth.py 
b/tests/custom_cluster/test_shell_jwt_auth.py
index 0bfd25479..f37a9ab7c 100644
--- a/tests/custom_cluster/test_shell_jwt_auth.py
+++ b/tests/custom_cluster/test_shell_jwt_auth.py
@@ -47,6 +47,11 @@ class TestImpalaShellJWTAuth(CustomClusterTestSuite):
                   "-jwt_token_auth=true -jwt_allow_without_tls=true "
                   .format(JWKS_JSON_PATH))
 
+  IMPALAD_OAUTH_ARGS = ("-v 2 -oauth_jwks_file_path={0} -oauth_token_auth=true 
"
+                        "-oauth_jwt_custom_claim_username=sub "
+                        "-oauth_allow_without_tls=true "
+                  .format(JWKS_JSON_PATH))
+
   # Name of the Impala metric containing the total count of hs2-http 
connections opened.
   HS2_HTTP_CONNS = 
"impala.thrift-server.hiveserver2-http-frontend.total-connections"
 
diff --git a/tests/custom_cluster/test_shell_oauth_auth.py 
b/tests/custom_cluster/test_shell_oauth_auth.py
index bb559a43c..4f6f788e7 100644
--- a/tests/custom_cluster/test_shell_oauth_auth.py
+++ b/tests/custom_cluster/test_shell_oauth_auth.py
@@ -41,11 +41,19 @@ class TestImpalaShellOAuthAuth(CustomClusterTestSuite):
   OAUTH_SIGNED_PATH = os.path.join(JWKS_JWTS_DIR, 'jwt_signed')
   OAUTH_EXPIRED_PATH = os.path.join(JWKS_JWTS_DIR, 'jwt_expired')
   OAUTH_INVALID_JWK = os.path.join(JWKS_JWTS_DIR, 'jwt_signed_untrusted')
+  OAUTH_CLIENT_SECRET = os.path.join(JWKS_JWTS_DIR, 'oauth_client_secret')
+  OAUTH_VALID_PAYLOAD = os.path.join(JWKS_JWTS_DIR, 'okta_oauth_payload_valid')
+  OAUTH_INVALID_PAYLOAD = os.path.join(JWKS_JWTS_DIR, 
'okta_oauth_payload_invalid')
 
   IMPALAD_ARGS = ("-v 2 -oauth_jwks_file_path={0} 
-oauth_jwt_custom_claim_username=sub "
     "-oauth_token_auth=true -oauth_allow_without_tls=true "
     .format(JWKS_JSON_PATH))
 
+  IMPALAD_OAUTH_ARGS = ("-v 2 -oauth_jwks_file_path={0} -oauth_token_auth=true 
"
+                        "-oauth_jwt_custom_claim_username=sub "
+                        "-oauth_allow_without_tls=true "
+                        .format(JWKS_JSON_PATH))
+
   # Name of the Impala metric containing the total count of hs2-http 
connections opened.
   HS2_HTTP_CONNS = 
"impala.thrift-server.hiveserver2-http-frontend.total-connections"
 
@@ -194,8 +202,88 @@ class TestImpalaShellOAuthAuth(CustomClusterTestSuite):
     assert "HTTP code 401: Unauthorized" in result.stderr
     assert "Not connected to Impala, could not execute queries." in 
result.stderr
 
+  @pytest.mark.execute_serially
+  @CustomClusterTestSuite.with_args(
+    impalad_args=IMPALAD_OAUTH_ARGS,
+    impala_log_dir="{oauth_auth_success}",
+    tmp_dir_placeholders=["oauth_auth_success"],
+    disable_log_buffering=True,
+    cluster_size=1)
+  def test_oauth_auth_with_clientid_and_secret_success(self, vector):
+    """Asserts the Impala shell can authenticate to Impala using OAuth 
authentication.
+    Also executes a query to ensure the authentication was successful."""
+    # Run a query and wait for it to complete.
+    args = ['--protocol', vector.get_value('protocol'), '-a',
+            '--oauth_client_id', 'oauth-test-user', 
'--oauth_client_secret_cmd',
+            'cat {0}'.format(TestImpalaShellOAuthAuth.OAUTH_CLIENT_SECRET),
+            '--oauth_server', 'localhost:8000',
+            '--oauth_endpoint', '/oauth/token',
+            '--oauth_mock_response_cmd',
+            'cat {0}'.format(TestImpalaShellOAuthAuth.OAUTH_VALID_PAYLOAD),
+            '-q', 'select version()', '--auth_creds_ok_in_clear']
+    result = run_impala_shell_cmd(vector, args)
+
+    # Shut down cluster to ensure logs flush to disk.
+    self._stop_impala_cluster()
+
+    # Ensure OAuth auth was enabled by checking the coordinator startup flags 
logged
+    # in the coordinator's INFO logfile
+    self.assert_impalad_log_contains("INFO",
+        '--oauth_jwks_file_path={0}'.format(self.JWKS_JSON_PATH), 
expected_count=1)
+    # Ensure OAuth auth was successful by checking impala coordinator logs
+    self.assert_impalad_log_contains("INFO",
+        'effective username: test-user', expected_count=1)
+    self.assert_impalad_log_contains("INFO",
+        r'connected_user \(string\) = "test-user"', expected_count=1)
+
+    # Ensure the query ran successfully.
+    assert "version()" in result.stdout
+    assert "impalad version" in result.stdout
+
+  @pytest.mark.execute_serially
+  @CustomClusterTestSuite.with_args(
+    impalad_args=IMPALAD_OAUTH_ARGS,
+    impala_log_dir="{oauth_auth_failure}",
+    tmp_dir_placeholders=["oauth_auth_failure"],
+    disable_log_buffering=True,
+    cluster_size=1)
+  def test_oauth_auth_with_clientid_and_secret_failure(self, vector):
+    """Asserts the Impala shell fails to authenticate with Impala if it can't
+    retrieve the OAuth access token from the OAuth Server."""
+    # Run a query and wait for it to complete.
+    args = ['--protocol', vector.get_value('protocol'), '-a',
+            '--oauth_client_id', 'oauth-test-user', 
'--oauth_client_secret_cmd',
+            'cat {0}'.format(TestImpalaShellOAuthAuth.OAUTH_CLIENT_SECRET),
+            '--oauth_server', 'localhost:8000',
+            '--oauth_endpoint', '/oauth/token',
+            '--oauth_mock_response_cmd',
+            'cat {0}'.format(TestImpalaShellOAuthAuth.OAUTH_INVALID_PAYLOAD),
+            '-q', 'select version()', '--auth_creds_ok_in_clear']
+    result = run_impala_shell_cmd(vector, args, expect_success=False)
+
+    # Since Impala failed to authenticate, both success and fail count
+    # will be zero.
+    self.__assert_success_fail_metric(success_count=0, fail_count=0)
+
+    # Shut down cluster to ensure logs flush to disk.
+    self._stop_impala_cluster()
+
+    # Ensure OAuth auth was enabled by checking the coordinator startup flags 
logged
+    # in the coordinator's INFO logfile
+    self.assert_impalad_log_contains("INFO",
+        '--oauth_jwks_file_path={0}'.format(self.JWKS_JSON_PATH), 
expected_count=1)
+    # Ensure OAuth auth was not received in the logs
+    self.assert_impalad_log_contains("INFO",
+        'effective username: test-user', expected_count=0)
+    self.assert_impalad_log_contains("INFO",
+        r'connected_user \(string\) = "test-user"', expected_count=0)
+
+    # Ensure the query did not run successfully.
+    assert "version()" not in result.stdout
+    assert "impalad version" not in result.stdout
+
   def __assert_success_fail_metric(self, success_count=0, fail_count=0):
-    """Impala emits metrics that count the number of successful and failed 
OAUth
+    """Impala emits metrics that count the number of successful and failed 
OAuth
     authentications. This function asserts the OAuth auth success/fail 
counters from the
     coordinator match the expected values."""
     actual = self.cluster.get_first_impalad().service.get_metric_values([
diff --git a/tests/shell/test_shell_commandline_jwt_auth.py 
b/tests/shell/test_shell_commandline_jwt_auth.py
index d21fe10a2..5dc2412a9 100644
--- a/tests/shell/test_shell_commandline_jwt_auth.py
+++ b/tests/shell/test_shell_commandline_jwt_auth.py
@@ -78,23 +78,23 @@ class TestImpalaShellJwtAuth(ImpalaTestSuite):
   def test_multiple_auth_ldap_jwt(self, vector):
     """Asserts that ldap and jwt auth cannot both be enabled."""
     result = run_impala_shell_cmd(vector, ['-l', '-j'], expect_success=False)
-    assert "Please specify at most one authentication mechanism (-k, -l, or 
-j)" \
+    assert "Please specify at most one authentication mechanism (-k, -l, -j, 
or -a)" \
            in result.stderr
 
   def test_multiple_auth_ldap_kerberos(self, vector):
     """Asserts that ldap and kerberos auth cannot both be enabled."""
     result = run_impala_shell_cmd(vector, ['-l', '-k'], expect_success=False)
-    assert "Please specify at most one authentication mechanism (-k, -l, or 
-j)" \
+    assert "Please specify at most one authentication mechanism (-k, -l, -j, 
or -a)" \
            in result.stderr
 
   def test_multiple_auth_jwt_kerberos(self, vector):
     """Asserts that jwt and kerberos auth cannot both be enabled."""
     result = run_impala_shell_cmd(vector, ['-j', '-k'], expect_success=False)
-    assert "Please specify at most one authentication mechanism (-k, -l, or 
-j)" \
+    assert "Please specify at most one authentication mechanism (-k, -l, -j, 
or -a)" \
            in result.stderr
 
   def test_multiple_auth_ldap_jwt_kerberos(self, vector):
     """Asserts ldap, jwt, and kerberos auth cannot all be enabled."""
     result = run_impala_shell_cmd(vector, ['-l', '-j', '-k'], 
expect_success=False)
-    assert "Please specify at most one authentication mechanism (-k, -l, or 
-j)" \
+    assert "Please specify at most one authentication mechanism (-k, -l, -j, 
or -a)" \
            in result.stderr

Reply via email to