This is an automated email from the ASF dual-hosted git repository.
arawat pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/impala.git
The following commit(s) were added to refs/heads/master by this push:
new c044bdd49 IMPALA-14083: Connected user and session user mismatch when
cookie based authentication is used with SPNEGO
c044bdd49 is described below
commit c044bdd49d20a83de3c04f6eeb9f2477eeee4815
Author: Abhishek Rawat <[email protected]>
AuthorDate: Thu Jun 5 16:59:11 2025 -0700
IMPALA-14083: Connected user and session user mismatch when cookie based
authentication is used with SPNEGO
IMPALA-11298 allowed comparing short user name for connected user and
session user to support proxy clients like Hue which could potentially
use different physical hosts for queries/requests from the same session.
When cookie based authentication is used, the 'kerberos_user_short' is
not set on the ConnectionContext and as a result 'connected_user_short'
is not set in SessionState. This can cause a mismatch when comparing
short user names from ConnectionContext and SessionState. This happens
because the original connection authenticated using SPNEGO will have
'kerberos_user_short' in the ConnectionContext, while the other
connections authenticated using cookies won't have 'kerberos_user_short'
set in the ConnectionContext.
This patch addresses this issue by setting 'kerberos_user_short' in
ConnectionContext, when using auth cookies generated post SPNEGO. This
information is retrieved from 'impala.auth' cookie itself, which now
also stores the 'a=<AUTH_MECHANISM>' in the cookie's value.
Testing:
- Added a SpnegoAuthTest which simulates 'knox' like proxy client and
uses SPNEGO to connect to Impala and also uses authentication cookies.
The test runs concurrent sql clients similar to real world scenarios.
Without the fix the test fails with error:
The user authorized on the connection '<username>' does not match the
session username ''
Change-Id: Id7223e449c32484bfd2295f7a9e728b7c02637e9
Reviewed-on: http://gerrit.cloudera.org:8080/22986
Tested-by: Impala Public Jenkins <[email protected]>
Reviewed-by: Jason Fehr <[email protected]>
---
be/src/rpc/authentication-util.cc | 15 +-
be/src/rpc/authentication-util.h | 23 +-
be/src/rpc/authentication.cc | 31 +-
be/src/util/webserver-test.cc | 32 +-
be/src/util/webserver.cc | 23 +-
be/src/util/webserver.h | 3 +-
.../customcluster/KerberosKdcEnvironment.java | 4 +
.../impala/customcluster/SpnegoAuthTest.java | 415 ++++++++++++++++++++
.../impala/customcluster/SpnegoTokenGenerator.java | 82 ++++
.../customcluster/THttpClientWithHeaders.java | 427 +++++++++++++++++++++
10 files changed, 1023 insertions(+), 32 deletions(-)
diff --git a/be/src/rpc/authentication-util.cc
b/be/src/rpc/authentication-util.cc
index 05ddc0982..85382891e 100644
--- a/be/src/rpc/authentication-util.cc
+++ b/be/src/rpc/authentication-util.cc
@@ -51,6 +51,7 @@ static const string COOKIE_SEPARATOR = "&";
static const string USERNAME_KEY = "u=";
static const string TIMESTAMP_KEY = "t=";
static const string RAND_KEY = "r=";
+static const string AUTH_MECH_KEY = "a=";
// Cookies generated and processed by the HTTP server will be of the form:
// COOKIE_NAME=<cookie>
@@ -68,7 +69,7 @@ static const int MAX_COOKIES_TO_CHECK = 5;
Status AuthenticateCookie(
const AuthenticationHash& hash, const string& cookie_header,
- string* username, string* rand) {
+ string* username, string* authMech, string* rand) {
// The 'Cookie' header allows sending multiple name/value pairs separated by
';'.
vector<string> cookies = strings::Split(cookie_header, ";");
if (cookies.size() > MAX_COOKIES_TO_CHECK) {
@@ -107,9 +108,9 @@ Status AuthenticateCookie(
return Status("The signature is incorrect.");
}
- // Split the cookie value into username, timestamp, and random number.
+ // Split the cookie value into username, timestamp, random number and auth
mechanism.
vector<string> cookie_value_split = Split(cookie_value, COOKIE_SEPARATOR);
- if (cookie_value_split.size() != 3) {
+ if (cookie_value_split.size() != 4) {
return Status("The cookie value has an invalid format.");
}
string timestamp;
@@ -133,6 +134,9 @@ Status AuthenticateCookie(
return Status("The cookie rand value has an invalid format.");
}
}
+ if (!TryStripPrefixString(cookie_value_split[3], AUTH_MECH_KEY,
authMech)) {
+ return Status("The cookie authMech value has an invalid format.");
+ }
// We've successfully authenticated.
return Status::OK();
} else {
@@ -144,7 +148,7 @@ Status AuthenticateCookie(
}
string GenerateCookie(const string& username, const AuthenticationHash& hash,
- std::string* srand) {
+ const std::string& authMech, std::string* srand) {
// Its okay to use rand() here even though its a weak RNG because being able
to guess
// the random numbers generated won't help an attacker. The important thing
is that
// we're using a strong RNG to create the key and a strong HMAC function.
@@ -154,7 +158,8 @@ string GenerateCookie(const string& username, const
AuthenticationHash& hash,
*srand = cookie_rand_s;
}
string cookie_value = StrCat(USERNAME_KEY, username, COOKIE_SEPARATOR,
TIMESTAMP_KEY,
- MonotonicMillis(), COOKIE_SEPARATOR, RAND_KEY, cookie_rand_s);
+ MonotonicMillis(), COOKIE_SEPARATOR, RAND_KEY, cookie_rand_s,
COOKIE_SEPARATOR,
+ AUTH_MECH_KEY, authMech);
uint8_t signature[AuthenticationHash::HashLen()];
Status compute_status =
hash.Compute(reinterpret_cast<const uint8_t*>(cookie_value.data()),
diff --git a/be/src/rpc/authentication-util.h b/be/src/rpc/authentication-util.h
index eec497612..37a62b3cb 100644
--- a/be/src/rpc/authentication-util.h
+++ b/be/src/rpc/authentication-util.h
@@ -23,17 +23,30 @@ namespace impala {
class AuthenticationHash;
+// HTTP Auth Mechanisms used for generating Auth Cookies
+inline const std::string HTTP_AUTH_MECH_HTPASSWD = "HTPASSWD";
+inline const std::string HTTP_AUTH_MECH_LDAP = "LDAP";
+inline const std::string HTTP_AUTH_MECH_TRUSTED_DOMAIN = "TRUSTED_DOMAIN";
+inline const std::string HTTP_AUTH_MECH_TRUSTED_HEADER = "TRUSTED_HEADER";
+inline const std::string HTTP_AUTH_MECH_SPNEGO = "SPNEGO";
+inline const std::string HTTP_AUTH_MECH_SAML = "SAML";
+inline const std::string HTTP_AUTH_MECH_JWT = "JWT";
+inline const std::string HTTP_AUTH_MECH_OAUTH = "OAUTH";
+
// Takes a single 'key=value' pair from a 'Cookie' header and attempts to
verify its
// signature with 'hash'. If verification is successful and the cookie is
still valid,
-// sets 'username' and 'rand' (if specified) to the corresponding values and
returns OK.
+// sets 'username', 'authMech' and 'rand' (if specified) to the corresponding
values
+// and returns OK.
Status AuthenticateCookie(
const AuthenticationHash& hash, const std::string& cookie_header,
- std::string* username, std::string* rand = nullptr);
+ std::string* username, std::string* authMech, std::string* rand = nullptr);
-// Generates and returns a cookie containing the username set on
'connection_context' and
-// a signature generated with 'hash'. If specified, sets 'rand' to the 'r='
cookie value.
+// Generates and returns a cookie containing the username set on
'connection_context',
+// a signature generated with 'hash' and the authentication mechanism used for
+// authenticating the given user with username. If specified, sets 'rand' to
the 'r='
+// cookie value.
std::string GenerateCookie(const std::string& username, const
AuthenticationHash& hash,
- std::string* rand = nullptr);
+ const std::string& authMech, std::string* rand = nullptr);
// Returns a empty cookie. Returned in a 'Set-Cookie' when cookie auth fails
to indicate
// to the client that the cookie should be deleted.
diff --git a/be/src/rpc/authentication.cc b/be/src/rpc/authentication.cc
index 7f1739385..8a879d57f 100644
--- a/be/src/rpc/authentication.cc
+++ b/be/src/rpc/authentication.cc
@@ -662,9 +662,18 @@ static int SaslGetPath(void* context, const char** path) {
bool CookieAuth(ThriftServer::ConnectionContext* connection_context,
const AuthenticationHash& hash, const std::string& cookie_header) {
string username;
- Status cookie_status = AuthenticateCookie(hash, cookie_header, &username);
+ string authMech;
+ Status cookie_status = AuthenticateCookie(hash, cookie_header, &username,
&authMech);
if (cookie_status.ok()) {
connection_context->username = username;
+ if (authMech == HTTP_AUTH_MECH_SPNEGO) {
+ connection_context->kerberos_user_principal = username;
+ connection_context->kerberos_user_short =
+ GetShortUsernameFromKerberosPrincipal(username);
+ VLOG(2) << "Connection authenticated with "
+ << "short username \"" <<
connection_context->kerberos_user_short << "\" "
+ << "parsed from principal \"" << username << "\" ";
+ }
return true;
}
@@ -723,7 +732,9 @@ static bool
TrustedDomainCheck(ThriftServer::ConnectionContext* connection_conte
}
// Create a cookie to return.
connection_context->return_headers.push_back(
- Substitute("Set-Cookie: $0",
GenerateCookie(connection_context->username, hash)));
+ Substitute("Set-Cookie: $0",
+ GenerateCookie(connection_context->username, hash,
+ HTTP_AUTH_MECH_TRUSTED_DOMAIN)));
return true;
}
@@ -734,7 +745,9 @@ static bool
HandleTrustedAuthHeader(ThriftServer::ConnectionContext* connection_
}
// Create a cookie to return.
connection_context->return_headers.push_back(
- Substitute("Set-Cookie: $0",
GenerateCookie(connection_context->username, hash)));
+ Substitute("Set-Cookie: $0",
+ GenerateCookie(connection_context->username, hash,
+ HTTP_AUTH_MECH_TRUSTED_HEADER)));
return true;
}
@@ -755,7 +768,8 @@ bool BasicAuth(ThriftServer::ConnectionContext*
connection_context,
connection_context->username = username;
// Create a cookie to return.
connection_context->return_headers.push_back(
- Substitute("Set-Cookie: $0", GenerateCookie(username, hash)));
+ Substitute("Set-Cookie: $0",
+ GenerateCookie(username, hash, HTTP_AUTH_MECH_LDAP)));
if (!FLAGS_test_cookie.empty()) {
connection_context->return_headers.push_back(
Substitute("Set-Cookie: $0", FLAGS_test_cookie));
@@ -803,7 +817,7 @@ error_description=\"$0 \"", status.GetDetail()));
// Create a cookie to return.
connection_context->return_headers.push_back(
- Substitute("Set-Cookie: $0", GenerateCookie(username, hash)));
+ Substitute("Set-Cookie: $0", GenerateCookie(username, hash,
HTTP_AUTH_MECH_JWT)));
return true;
}
@@ -845,7 +859,7 @@ error_description=\"$0 \"", status.GetDetail()));
// Create a cookie to return.
connection_context->return_headers.push_back(
- Substitute("Set-Cookie: $0", GenerateCookie(username, hash)));
+ Substitute("Set-Cookie: $0", GenerateCookie(username, hash,
HTTP_AUTH_MECH_OAUTH)));
return true;
}
@@ -913,7 +927,8 @@ bool NegotiateAuth(ThriftServer::ConnectionContext*
connection_context,
connection_context->kerberos_user_short = short_user;
// Create a cookie to return.
connection_context->return_headers.push_back(
- Substitute("Set-Cookie: $0", GenerateCookie(username, hash)));
+ Substitute("Set-Cookie: $0",
+ GenerateCookie(username, hash, HTTP_AUTH_MECH_SPNEGO)));
}
}
} else {
@@ -1035,7 +1050,7 @@ bool ValidateSaml2Bearer(ThriftServer::ConnectionContext*
connection_context,
connection_context->username = username;
// Create a cookie to return.
connection_context->return_headers.push_back(
- Substitute("Set-Cookie: $0", GenerateCookie(username, hash)));
+ Substitute("Set-Cookie: $0", GenerateCookie(username, hash,
HTTP_AUTH_MECH_SAML)));
return true;
}
diff --git a/be/src/util/webserver-test.cc b/be/src/util/webserver-test.cc
index 7123dba12..ad8d44c8e 100644
--- a/be/src/util/webserver-test.cc
+++ b/be/src/util/webserver-test.cc
@@ -460,20 +460,37 @@ public:
const filesystem::path& path() { return path_; }
string token() {
const char* rand_key = "&r=";
+ const char* auth_mech_key = "&a=";
string rand, line;
ifstream cookie_file(path_.string());
while (cookie_file) {
getline(cookie_file, line);
size_t rand_idx = line.rfind(rand_key);
- if (rand_idx != string::npos) {
- // Relies on the random value being the last element in the cookie.
- rand = line.substr(rand_idx + strlen(rand_key));
+ size_t auth_mech_idx = line.rfind(auth_mech_key);
+ if (rand_idx != string::npos && auth_mech_idx != string::npos) {
+ // Relies on the random value being followed by auth mech in the
cookie.
+ size_t rand_val_idx = rand_idx + strlen(rand_key);
+ rand = line.substr(rand_val_idx, auth_mech_idx - rand_val_idx);
break;
}
}
return rand;
}
-
+ string auth_mech() {
+ const char* auth_mech_key = "&a=";
+ string authmech, line;
+ ifstream cookie_file(path_.string());
+ while (cookie_file) {
+ getline(cookie_file, line);
+ size_t auth_mech_idx = line.rfind(auth_mech_key);
+ if (auth_mech_idx != string::npos) {
+ // Relies on auth mech being the last element in the cookie.
+ authmech = line.substr(auth_mech_idx + strlen(auth_mech_key));
+ break;
+ }
+ }
+ return authmech;
+ }
private:
const filesystem::path dir_, path_;
};
@@ -531,6 +548,8 @@ TEST(Webserver, TestGetWithSpnego) {
// curl does not do the initial attempt without authentication, so there
is no
// additional failed auth attempt.
CheckAuthMetrics(&metrics, 1, (curl_7_64_or_above ? 1 : 2), 1, 0);
+ // Validate authentication mechanism stored in the cookie
+ ASSERT_EQ(cookie.auth_mech(), "SPNEGO");
webserver.Stop();
MetricGroup metrics2("webserver-test");
@@ -578,6 +597,8 @@ TEST(Webserver, TestPostWithSpnego) {
CookieJar cookie;
// GET with SPNEGO succeeds and returns a cookie.
ASSERT_EQ(system(curl("--negotiate -u : -c " +
cookie.path().string()).c_str()), 0);
+ // Validate authentication mechanism stored in the cookie.
+ ASSERT_EQ(cookie.auth_mech(), "SPNEGO");
// Verify we got a cookie and can read the random token.
string token = cookie.token();
ASSERT_FALSE(token.empty());
@@ -629,13 +650,14 @@ TEST(Webserver, StartWithPasswordFileTest) {
// GET with user and password succeeds and returns a cookie.
ASSERT_EQ(system(curl(Substitute("--digest -u test:test -c $0",
cookie.path().string())).c_str()), 0);
+ // Validate authentication mechanism stored in the cookie
+ ASSERT_EQ(cookie.auth_mech(), "HTPASSWD");
// Verify we got a cookie and can read the random token.
string token = cookie.token();
ASSERT_FALSE(token.empty());
// Post with the cookie fails due to CSRF protection.
ASSERT_EQ(curl_status_code(Substitute("--digest -u test:test -b $0 -d ''",
cookie.path().string()).c_str()), "403");
-
// Include the cookie's random token as csrf_token and request should
succeed.
ASSERT_EQ(system(curl(Substitute("--digest -u test:test -b $0 -d
'csrf_token=$1'",
cookie.path().string(), token)).c_str()), 0);
diff --git a/be/src/util/webserver.cc b/be/src/util/webserver.cc
index 13923a103..70b190d1a 100644
--- a/be/src/util/webserver.cc
+++ b/be/src/util/webserver.cc
@@ -738,9 +738,11 @@ sq_callback_result_t
Webserver::BeginRequestCallback(struct sq_connection* conne
if (!authenticated && use_cookies_) {
const char* cookie_header = sq_get_header(connection, "Cookie");
string username;
+ string auth_mech;
if (cookie_header != nullptr) {
Status cookie_status =
- AuthenticateCookie(hash_, cookie_header, &username,
&cookie_rand_value);
+ AuthenticateCookie(hash_, cookie_header, &username, &auth_mech,
+ &cookie_rand_value);
if (cookie_status.ok()) {
authenticated = true;
cookie_authenticated = true;
@@ -760,7 +762,8 @@ sq_callback_result_t Webserver::BeginRequestCallback(struct
sq_connection* conne
// as browsers automatically include HTPASSWD credentials in requests, so
add and use
// cookies to avoid requiring the custom header.
authenticated = true;
- AddCookie(request_info->remote_user, &response_headers,
&cookie_rand_value);
+ AddCookie(request_info->remote_user, &response_headers,
HTTP_AUTH_MECH_HTPASSWD,
+ &cookie_rand_value);
}
// Connections originating from trusted domains should not require
authentication.
@@ -788,7 +791,8 @@ sq_callback_result_t Webserver::BeginRequestCallback(struct
sq_connection* conne
if (TrustedDomainCheck(origin, connection, request_info)) {
total_trusted_domain_check_success_->Increment(1);
authenticated = true;
- AddCookie(request_info->remote_user, &response_headers,
&cookie_rand_value);
+ AddCookie(request_info->remote_user, &response_headers,
+ HTTP_AUTH_MECH_TRUSTED_DOMAIN, &cookie_rand_value);
}
}
}
@@ -801,7 +805,8 @@ sq_callback_result_t Webserver::BeginRequestCallback(struct
sq_connection* conne
if (GetUsernameFromAuthHeader(connection, request_info, err_msg)) {
total_trusted_auth_header_check_success_->Increment(1);
authenticated = true;
- AddCookie(request_info->remote_user, &response_headers,
&cookie_rand_value);
+ AddCookie(request_info->remote_user, &response_headers,
+ HTTP_AUTH_MECH_TRUSTED_HEADER, &cookie_rand_value);
} else {
LOG(ERROR) << "Found trusted auth header but " << err_msg;
}
@@ -814,7 +819,8 @@ sq_callback_result_t Webserver::BeginRequestCallback(struct
sq_connection* conne
HandleSpnego(connection, request_info, &response_headers);
if (spnego_result == SQ_CONTINUE_HANDLING) {
// Spnego negotiation was successful.
- AddCookie(request_info->remote_user, &response_headers,
&cookie_rand_value);
+ AddCookie(request_info->remote_user, &response_headers,
+ HTTP_AUTH_MECH_SPNEGO, &cookie_rand_value);
} else {
// Spnego negotiation is incomplete or failed, stop processing the
request.
return spnego_result;
@@ -825,7 +831,8 @@ sq_callback_result_t Webserver::BeginRequestCallback(struct
sq_connection* conne
if (basic_status.ok()) {
// Basic auth was successful.
total_basic_auth_success_->Increment(1);
- AddCookie(request_info->remote_user, &response_headers,
&cookie_rand_value);
+ AddCookie(request_info->remote_user, &response_headers,
+ HTTP_AUTH_MECH_LDAP, &cookie_rand_value);
} else {
total_basic_auth_failure_->Increment(1);
if (!sq_get_header(connection, "Authorization")) {
@@ -1158,7 +1165,7 @@ Status Webserver::HandleBasic(struct sq_connection*
connection,
}
void Webserver::AddCookie(const char* user, vector<string>* response_headers,
- string* cookie_rand_value) {
+ const string& authMech, string* cookie_rand_value) {
if (use_cookies_) {
// If cookie auth failed and we generated a 'delete cookie' header, remove
it.
auto eq = [](const string& header) { return header.rfind("Set-Cookie", 0)
== 0; };
@@ -1168,7 +1175,7 @@ void Webserver::AddCookie(const char* user,
vector<string>* response_headers,
}
// Generate a cookie to return.
response_headers->push_back(Substitute("Set-Cookie: $0",
- GenerateCookie(user, hash_, cookie_rand_value)));
+ GenerateCookie(user, hash_, authMech, cookie_rand_value)));
}
}
diff --git a/be/src/util/webserver.h b/be/src/util/webserver.h
index bfa71f171..eb34c317a 100644
--- a/be/src/util/webserver.h
+++ b/be/src/util/webserver.h
@@ -221,7 +221,8 @@ class Webserver {
// Adds a 'Set-Cookie' header to 'response_headers', if cookie support is
enabled.
// Returns the random value portion of the cookie in 'rand' for use in CSRF
prevention.
- void AddCookie(const char* user, vector<string>* response_headers, string*
rand);
+ void AddCookie(const char* user, vector<string>* response_headers,
+ const string& authMech, string* rand);
// Get username from Authorization header.
bool GetUsernameFromAuthHeader(struct sq_connection* connection,
diff --git
a/fe/src/test/java/org/apache/impala/customcluster/KerberosKdcEnvironment.java
b/fe/src/test/java/org/apache/impala/customcluster/KerberosKdcEnvironment.java
index e43d6ed69..d62202862 100644
---
a/fe/src/test/java/org/apache/impala/customcluster/KerberosKdcEnvironment.java
+++
b/fe/src/test/java/org/apache/impala/customcluster/KerberosKdcEnvironment.java
@@ -53,6 +53,10 @@ class KerberosKdcEnvironment extends ExternalResource {
this.testFolder = testFolder;
}
+ public String getTestFolderPath() throws IOException {
+ return testFolder.getRoot().getCanonicalPath();
+ }
+
@Override
protected void before() throws Throwable {
testFolder.create();
diff --git
a/fe/src/test/java/org/apache/impala/customcluster/SpnegoAuthTest.java
b/fe/src/test/java/org/apache/impala/customcluster/SpnegoAuthTest.java
new file mode 100644
index 000000000..09aec13db
--- /dev/null
+++ b/fe/src/test/java/org/apache/impala/customcluster/SpnegoAuthTest.java
@@ -0,0 +1,415 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.impala.customcluster;
+
+import static org.apache.impala.testutil.LdapUtil.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import com.google.common.collect.ImmutableMap;
+
+import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+
+import java.nio.file.Files;
+import java.nio.file.Paths;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Base64;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.directory.server.core.annotations.CreateDS;
+import org.apache.directory.server.core.annotations.CreatePartition;
+import org.apache.directory.server.annotations.CreateLdapServer;
+import org.apache.directory.server.annotations.CreateTransport;
+import org.apache.directory.server.core.annotations.ApplyLdifFiles;
+import org.apache.directory.server.core.integ.CreateLdapServerRule;
+import org.apache.hive.service.rpc.thrift.*;
+import org.apache.impala.testutil.WebClient;
+import org.apache.thrift.transport.THttpClient;
+import org.apache.thrift.protocol.TBinaryProtocol;
+import org.ietf.jgss.*;
+import org.junit.ClassRule;
+import org.junit.rules.TemporaryFolder;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+@CreateDS(name = "myDS",
+ partitions = { @CreatePartition(name = "test", suffix = "dc=myorg,dc=com")
})
+@CreateLdapServer(
+ transports = { @CreateTransport(protocol = "LDAP", address = "localhost")
})
+@ApplyLdifFiles({"users.ldif"})
+/**
+ * Tests that hiveserver2 operations over the http interface work as expected
when
+ * SPNEGO authentication is being used.
+ */
+public class SpnegoAuthTest {
+ private static final Logger LOG =
LoggerFactory.getLogger(SpnegoAuthTest.class);
+
+ @ClassRule
+ public static CreateLdapServerRule serverRule = new CreateLdapServerRule();
+ @ClassRule
+ public static KerberosKdcEnvironment kerberosKdcEnvironment =
+ new KerberosKdcEnvironment(new TemporaryFolder());
+
+ WebClient client_ = new WebClient();
+
+ protected Map<String, String> getLdapFlags() {
+ String ldapUri = String.format("ldap://localhost:%s",
+ serverRule.getLdapServer().getPort());
+ String passwordCommand = String.format("'echo -n %s'", TEST_PASSWORD_1);
+ return ImmutableMap.<String, String>builder()
+ .put("enable_ldap_auth", "true")
+ .put("ldap_uri", ldapUri)
+ .put("ldap_bind_pattern", "cn=#UID,ou=Users,dc=myorg,dc=com")
+ .put("ldap_passwords_in_clear_ok", "true")
+ .put("ldap_bind_dn", TEST_USER_DN_1)
+ .put("ldap_bind_password_cmd", passwordCommand)
+ .build();
+ }
+
+ protected int startImpalaCluster(String args) throws IOException,
InterruptedException {
+ return kerberosKdcEnvironment.startImpalaClusterWithArgs(args);
+ }
+
+ public static String flagsToArgs(Map<String, String> flags) {
+ return flags.entrySet().stream()
+ .map(entry -> "--" + entry.getKey() + "=" + entry.getValue() + " ")
+ .collect(Collectors.joining());
+ }
+
+ @SafeVarargs
+ public static Map<String, String> mergeFlags(Map<String, String>... flags) {
+ return Arrays.stream(flags)
+ .filter(Objects::nonNull)
+ .flatMap(map -> map.entrySet().stream())
+ .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+ }
+
+ static void verifySuccess(TStatus status) throws Exception {
+ if (status.getStatusCode() == TStatusCode.SUCCESS_STATUS
+ || status.getStatusCode() == TStatusCode.SUCCESS_WITH_INFO_STATUS) {
+ return;
+ }
+ throw new Exception(status.toString());
+ }
+
+ /**
+ * Executes 'query', fetches the results and closes the 'query'. Expects
there to be
+ * exactly one string returned, which will be equal to 'expectedResult'.
+ */
+ static void execAndFetch(TCLIService.Iface client,
+ TSessionHandle sessionHandle, String query, String expectedResult)
+ throws Exception {
+ TOperationHandle handle = null;
+ try {
+ TExecuteStatementReq execReq = new TExecuteStatementReq(sessionHandle,
query);
+ TExecuteStatementResp execResp = client.ExecuteStatement(execReq);
+ verifySuccess(execResp.getStatus());
+ handle = execResp.getOperationHandle();
+
+ TFetchResultsReq fetchReq = new TFetchResultsReq(
+ handle, TFetchOrientation.FETCH_NEXT, 1000);
+ TFetchResultsResp fetchResp = client.FetchResults(fetchReq);
+ verifySuccess(fetchResp.getStatus());
+ List<TColumn> columns = fetchResp.getResults().getColumns();
+ assertEquals(1, columns.size());
+ if (expectedResult != null) {
+ assertEquals(expectedResult,
columns.get(0).getStringVal().getValues().get(0));
+ }
+ } finally {
+ if (handle != null) {
+ TCloseOperationReq closeReq = new TCloseOperationReq(handle);
+ TCloseOperationResp closeResp = client.CloseOperation(closeReq);
+ verifySuccess(closeResp.getStatus());
+ }
+ }
+ }
+
+ private void verifyNegotiateAuthMetrics(
+ long expectedBasicAuthSuccess, long expectedBasicAuthFailure) throws
Exception {
+ long actualBasicAuthSuccess = (long) client_.getMetric(
+
"impala.thrift-server.hiveserver2-http-frontend.total-negotiate-auth-success");
+ assertEquals(expectedBasicAuthSuccess, actualBasicAuthSuccess);
+ long actualBasicAuthFailure = (long) client_.getMetric(
+
"impala.thrift-server.hiveserver2-http-frontend.total-negotiate-auth-failure");
+ assertEquals(expectedBasicAuthFailure, actualBasicAuthFailure);
+ }
+
+ private void verifyCookieAuthMetrics(
+ long expectedCookieAuthSuccess, long expectedCookieAuthFailure) throws
Exception {
+ long actualCookieAuthSuccess = (long) client_.getMetric(
+
"impala.thrift-server.hiveserver2-http-frontend.total-cookie-auth-success");
+ assertEquals(expectedCookieAuthSuccess, actualCookieAuthSuccess);
+ long actualCookieAuthFailure = (long) client_.getMetric(
+
"impala.thrift-server.hiveserver2-http-frontend.total-cookie-auth-failure");
+ assertEquals(expectedCookieAuthFailure, actualCookieAuthFailure);
+ }
+
+ @Test
+ /**
+ * Tests Authentication flow using a proxy client such as Knox, which uses
SPNEGO Auth
+ * to connect to Impala and impersonates other users. Initial Authentication
is done
+ * through SPNEGO and follow-on requests are authenticated using Auth
cookies. The test
+ * uses multiple clients sharing the same Auth cookie similar to what a
proxy client
+ * would do and as a result adds coverage for interesting scenarios where
OpenSession
+ * RPC could also use Auth Cookies.
+ */
+ public void testImpersonation() throws Exception, Throwable {
+ Map<String, String> flags = mergeFlags(
+ // enable Kerberos authentication
+ kerberosKdcEnvironment.getKerberosAuthFlags(),
+ getLdapFlags(),
+ // custom LDAP filters
+ ImmutableMap.of(
+ "ldap_group_filter", String.format("%s,another-group",
TEST_USER_GROUP),
+ "ldap_user_filter", String.format("%s,%s,another-user",
+ TEST_USER_1, TEST_USER_3),
+ "ldap_group_dn_pattern", GROUP_DN_PATTERN,
+ "ldap_group_membership_key", "uniqueMember",
+ "ldap_group_class_key", "groupOfUniqueNames",
+ "allow_custom_ldap_filters_with_kerberos_auth", "true",
+ // set proxy user: allow TEST_USER_4 to act as a proxy user for
others
+ "authorized_proxy_user_config", String.format("%s=*", TEST_USER_4)
+ )
+ );
+ // Start Impala with configured flags.
+ int ret = startImpalaCluster(flagsToArgs(flags));
+ assertEquals(0, ret); // cluster should start up
+
+ // Open a session and authenticate using SPNEGO.
+ THttpClientWithHeaders transport =
+ new THttpClientWithHeaders("http://localhost:28000");
+ Map<String, String> headers = new HashMap<String, String>();
+ // Authenticate as the proxy user 'Test4Ldap'
+ headers.put("Authorization", "Negotiate " + getSpnegoToken(TEST_USER_4));
+ transport.setCustomHeaders(headers);
+ transport.open();
+ TCLIService.Iface client = new TCLIService.Client(new
TBinaryProtocol(transport));
+
+ // Open a session without specifying a 'doas', should fail as the proxy
user won't
+ // pass the filters.
+ TOpenSessionReq openReq = new TOpenSessionReq();
+ TOpenSessionResp openResp = client.OpenSession(openReq);
+ assertEquals(TStatusCode.ERROR_STATUS,
openResp.getStatus().getStatusCode());
+ int negotiateAuthFailureCount = 0;
+ int negotiateAuthSuccessCount = 1;
+ verifyNegotiateAuthMetrics(negotiateAuthSuccessCount,
negotiateAuthFailureCount);
+ int cookieAuthFailureCount = 0;
+ int cookieAuthSuccessCount = 0;
+ verifyCookieAuthMetrics(cookieAuthSuccessCount, cookieAuthFailureCount);
+
+ // SPNEGO doesn't like replay tokens, so use new tokens.
+ headers.remove("Authorization");
+ headers.put("Authorization", "Negotiate " + getSpnegoToken(TEST_USER_4));
+ // Open a session with a 'doas' that will pass both filters, should
succeed.
+ Map<String, String> config = new HashMap<String, String>();
+ config.put("impala.doas.user", TEST_USER_1);
+ openReq.setConfiguration(config);
+ openResp = client.OpenSession(openReq);
+ assertEquals(TStatusCode.SUCCESS_STATUS,
openResp.getStatus().getStatusCode());
+ negotiateAuthSuccessCount++;
+ verifyNegotiateAuthMetrics(negotiateAuthSuccessCount,
negotiateAuthFailureCount);
+ verifyCookieAuthMetrics(cookieAuthSuccessCount, cookieAuthFailureCount);
+
+ // Use Auth Cookie for the remaining sessions and connections.
+ Map<String, List<String>> responseHeaders = transport.getResponseHeaders();
+ List<String> cookies = responseHeaders.get("Set-Cookie");
+ if (cookies != null) {
+ for (String cookie : cookies) {
+ String authMech = extractCookieAuthMech(cookie);
+ assertNotNull(authMech);
+ assertEquals("SPNEGO", authMech);
+ headers.put("Cookie", cookie);
+ }
+ } else {
+ fail("'Set-Cookie' cookie not returned from Impala");
+ }
+
+ // Simulate 4 concurrent clients, with each running 100 exec and fetch
RPCs.
+ final int numClients = 4;
+ final int numQueries = 100;
+ ExecutorService executor = Executors.newFixedThreadPool(numClients);
+ List<Future<Void>> futures = new ArrayList<>();
+ for (int i = 0; i < numClients; i++) {
+ final int clientId = i;
+ Future<Void> future = executor.submit(() -> {
+ simulateClient(headers, config, clientId, numQueries);
+ return null;
+ });
+ futures.add(future);
+ }
+
+ executor.shutdown();
+ executor.awaitTermination(5, TimeUnit.MINUTES);
+
+ // Check for exceptions from client threads
+ for (int i = 0; i < futures.size(); i++) {
+ try {
+ futures.get(i).get();
+ } catch (ExecutionException e) {
+ Throwable cause = e.getCause();
+ System.err.println("Client " + i + " failed: " + cause.getMessage());
+ cause.printStackTrace();
+ fail("Client " + i + " failed: " + cause.getMessage());
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ System.err.println("Main thread interrupted.");
+ }
+ }
+ // Each client uses one OpenSession RPC using cookie based authentication.
+ // Each query runs one Exec, one Fetch and one Close RPC, using 3 cookie
based
+ // authentications per query.
+ cookieAuthSuccessCount += numClients * (1 + numQueries * 3);
+ verifyCookieAuthMetrics(cookieAuthSuccessCount, cookieAuthFailureCount);
+ verifyNegotiateAuthMetrics(negotiateAuthSuccessCount,
negotiateAuthFailureCount);
+ }
+
+  /**
+   * Generates and returns Base64 encoded SPNEGO token for the input user.
+   * Spawns a separate JVM running SpnegoTokenGenerator so the child process
+   * picks up the test KDC's krb5 config and credentials cache.
+   */
+  private static String getSpnegoToken(String user) throws Exception {
+    // Create a test user principal and generate a Kerberos credentials cache
+    // (ccache) for it.
+    String ccachePath =
+        kerberosKdcEnvironment.createUserPrincipalAndCredentialsCache(user);
+    File tokenFile =
+        new File(kerberosKdcEnvironment.getTestFolderPath() + "/spngeoToken.bin");
+    // A separate process is used to generate the SPNEGO token because some of
+    // the Java security classes are initialized much earlier and cannot read
+    // the required kerberos config set up by the test.
+    ProcessBuilder builder = new ProcessBuilder(
+        "java", "-cp", System.getProperty("java.class.path"),
+        "-Djava.security.krb5.conf=" + kerberosKdcEnvironment.getKrb5ConfigPath(),
+        "-Dsun.security.krb5.debug=true",
+        "-Djava.security.debug=gssloginconfig,configfile,configparser,logincontext,JGSS",
+        "-Djavax.security.auth.useSubjectCredsOnly=false",
+        "org.apache.impala.customcluster.SpnegoTokenGenerator",
+        tokenFile.getCanonicalPath());
+    // Point the child JVM at the freshly created credentials cache.
+    builder.environment().put("KRB5CCNAME", "FILE:" + ccachePath);
+    builder.inheritIO();
+
+    int exitCode = builder.start().waitFor();
+    // A non-zero exit code indicates token generation failed.
+    assertEquals(0, exitCode);
+
+    return Base64.getEncoder()
+        .encodeToString(readTokenFromFile(tokenFile.getCanonicalPath()));
+  }
+
+  /**
+   * Helper function to read the raw token bytes from the token file generated
+   * by SpnegoTokenGenerator.
+   * @param path path of the binary token file
+   * @return the full contents of the file
+   * @throws IOException if the file cannot be opened or read
+   */
+  private static byte[] readTokenFromFile(String path) throws IOException {
+    ByteArrayOutputStream buffer = new ByteArrayOutputStream();
+    // try-with-resources guarantees the stream is closed even if read() throws;
+    // the previous explicit close() leaked the stream on exception.
+    try (InputStream is = new FileInputStream(path)) {
+      byte[] temp = new byte[4096];
+      int bytesRead;
+      while ((bytesRead = is.read(temp)) != -1) {
+        buffer.write(temp, 0, bytesRead);
+      }
+    }
+    return buffer.toByteArray();
+  }
+
+  /**
+   * Simulates a client opening a session and running 'numQueries' queries
+   * within that session, authenticating each RPC with the auth cookie carried
+   * in 'headers'.
+   * @param headers HTTP headers (including the 'Cookie' header) for every RPC
+   * @param config session configuration passed to OpenSession
+   * @param clientId id used only for log messages / error reporting
+   * @param numQueries number of exec/fetch/close cycles to run
+   */
+  private static void simulateClient(Map<String, String> headers,
+      Map<String, String> config, int clientId, int numQueries) throws Exception {
+    // Create and open the transport.
+    THttpClientWithHeaders transport =
+        new THttpClientWithHeaders("http://localhost:28000");
+    transport.setCustomHeaders(headers);
+    transport.open();
+    try {
+      // Create client stub.
+      TCLIService.Iface client =
+          new TCLIService.Client(new TBinaryProtocol(transport));
+
+      // Open a session.
+      TOpenSessionReq openReq = new TOpenSessionReq();
+      openReq.setConfiguration(config);
+      TOpenSessionResp openResp = client.OpenSession(openReq);
+      if (openResp.getStatus().getStatusCode() != TStatusCode.SUCCESS_STATUS) {
+        throw new RuntimeException("Failed to open session for client " + clientId);
+      }
+      System.out.println("Client " + clientId + " opened session successfully.");
+
+      // Execute queries, sleeping a random short interval between them to
+      // better interleave the concurrent clients.
+      for (int i = 0; i < numQueries; i++) {
+        execAndFetch(client, openResp.getSessionHandle(),
+            "select logged_in_user()", "Test1Ldap");
+        int sleepMillis = ThreadLocalRandom.current().nextInt(10, 100);
+        Thread.sleep(sleepMillis);
+      }
+    } finally {
+      // Always release the transport, even when an RPC fails, so a failing
+      // client does not leak its connection for the rest of the test run.
+      transport.close();
+    }
+    System.out.println("Client " + clientId + " finished.");
+  }
+
+  /**
+   * Extracts the authentication mechanism ('a=<mech>') from the value of the
+   * 'impala.auth' cookie. Returns null for a null/empty cookie; asserts that
+   * a non-empty cookie has the expected shape.
+   */
+  private static String extractCookieAuthMech(String cookie) throws Exception {
+    if (cookie == null || cookie.isEmpty()) return null;
+    // Expected cookie shape:
+    // impala.auth=<base64signature>&<cookie_value>;HttpOnly;Max-Age=86400;Secure
+    String[] attrs = cookie.split(";");
+    if (attrs.length == 0) return null;
+
+    // The first attribute is impala.auth=<base64signature>&<cookie_value>,
+    // where <cookie_value> looks like:
+    // [email protected]&t=549158755&r=1800557187&a=SPNEGO
+    String[] valueParts = attrs[0].trim().split("&");
+    assertEquals(5, valueParts.length);
+    // The last '&'-separated field must be the 'a=<mech>' pair.
+    String[] mechPair = valueParts[4].trim().split("=");
+    assertEquals(2, mechPair.length);
+    assertEquals("a", mechPair[0]);
+    return mechPair[1];
+  }
+}
diff --git
a/fe/src/test/java/org/apache/impala/customcluster/SpnegoTokenGenerator.java
b/fe/src/test/java/org/apache/impala/customcluster/SpnegoTokenGenerator.java
new file mode 100644
index 000000000..25526d5d4
--- /dev/null
+++ b/fe/src/test/java/org/apache/impala/customcluster/SpnegoTokenGenerator.java
@@ -0,0 +1,82 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.impala.customcluster;
+
+import org.ietf.jgss.*;
+
+import java.io.FileOutputStream;
+import java.io.IOException;
+
+/**
+ * Standalone helper that initiates a GSS security context using the Kerberos
+ * credentials cache pointed to by KRB5CCNAME and writes the resulting token to
+ * the file named by the first argument. Exits non-zero on any failure so the
+ * parent test process can detect token generation errors.
+ */
+public class SpnegoTokenGenerator {
+  public static void main(String[] args) {
+    try {
+      if (args.length < 1) {
+        System.err.println("Missing argument: <output-token-file>");
+        System.exit(1);
+      }
+      String outputPath = args[0];
+      // OID for the Kerberos V5 mechanism.
+      Oid krb5Oid = new Oid("1.2.840.113554.1.2.2");
+
+      // Create GSSManager
+      GSSManager manager = GSSManager.getInstance();
+
+      // This OID corresponds to NT_KRB5_PRINCIPAL (Kerberos V5 principal name).
+      Oid krb5PrincipalOid = new Oid("1.2.840.113554.1.2.2.1");
+
+      // Full service principal with realm
+      String servicePrincipal = "impala/[email protected]";
+
+      // Create GSSName with full principal name
+      GSSName serverName = manager.createName(servicePrincipal, krb5PrincipalOid);
+
+      // Create the security context; null credentials means the default
+      // credentials are taken from the ccache (KRB5CCNAME).
+      GSSContext context = manager.createContext(
+          serverName,
+          krb5Oid,
+          null, // use default credentials from ccache
+          GSSContext.DEFAULT_LIFETIME
+      );
+
+      // Request context properties before initiating; initSecContext below
+      // triggers the actual ticket acquisition.
+      context.requestMutualAuth(true);
+      context.requestCredDeleg(false);
+
+      byte[] token = context.initSecContext(new byte[0], 0, 0);
+      if (token != null) {
+        // try-with-resources closes the stream even on write failure; the
+        // previous version never closed the FileOutputStream.
+        try (FileOutputStream fos = new FileOutputStream(outputPath)) {
+          fos.write(token);
+          System.out.println("Token written to " + outputPath);
+        } catch (IOException e) {
+          System.err.println("Failed to write token to file: " + e.getMessage());
+          e.printStackTrace();
+          System.exit(1);
+        }
+      } else {
+        System.err.println("Failed to obtain SPNEGO token.");
+        System.exit(1);
+      }
+      context.dispose();
+    }
+    catch (Exception e) {
+      e.printStackTrace();
+      System.exit(1);
+    }
+  }
+}
diff --git
a/fe/src/test/java/org/apache/impala/customcluster/THttpClientWithHeaders.java
b/fe/src/test/java/org/apache/impala/customcluster/THttpClientWithHeaders.java
new file mode 100644
index 000000000..2666b18e6
--- /dev/null
+++
b/fe/src/test/java/org/apache/impala/customcluster/THttpClientWithHeaders.java
@@ -0,0 +1,427 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+// This code is copied from apache/thrift and modified to return
+// HTTP response headers from transport.
+// Original Source:
+//
https://github.com/apache/thrift/blob/v0.16.0/lib/java/src/org/apache/thrift/transport/
+// THttpClient.java
+
+package org.apache.impala.customcluster;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.InputStream;
+import java.io.IOException;
+
+import java.net.URL;
+import java.net.HttpURLConnection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.http.HttpEntity;
+import org.apache.http.HttpHost;
+import org.apache.http.HttpResponse;
+import org.apache.http.HttpStatus;
+import org.apache.http.client.HttpClient;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.entity.ByteArrayEntity;
+import org.apache.http.params.CoreConnectionPNames;
+import org.apache.thrift.TConfiguration;
+import org.apache.thrift.transport.TEndpointTransport;
+import org.apache.thrift.transport.TTransport;
+import org.apache.thrift.transport.TTransportException;
+import org.apache.thrift.transport.TTransportFactory;
+
+/**
+ * HTTP implementation of the TTransport interface. Used for working with a
+ * Thrift web services implementation (using for example TServlet).
+ *
+ * This class offers two implementations of the HTTP transport.
+ * One uses HttpURLConnection instances, the other HttpClient from Apache
+ * Http Components.
+ * The chosen implementation depends on the constructor used to
+ * create the THttpClient instance.
+ * Using the THttpClient(String url) constructor or passing null as the
+ * HttpClient to THttpClient(String url, HttpClient client) will create an
+ * instance which will use HttpURLConnection.
+ *
+ * When using HttpClient, the following configuration leads to 5-15%
+ * better performance than the HttpURLConnection implementation:
+ *
+ * http.protocol.version=HttpVersion.HTTP_1_1
+ * http.protocol.content-charset=UTF-8
+ * http.protocol.expect-continue=false
+ * http.connection.stalecheck=false
+ *
+ * Also note that under high load, the HttpURLConnection implementation
+ * may exhaust the open file descriptor limit.
+ *
+ * @see <a
href="https://issues.apache.org/jira/browse/THRIFT-970">THRIFT-970</a>
+ */
+
+public class THttpClientWithHeaders extends TEndpointTransport {
+
+  private URL url_ = null;
+
+  // Buffers the outgoing thrift message until flush() posts it.
+  private final ByteArrayOutputStream requestBuffer_ = new
ByteArrayOutputStream();
+
+  // Buffered response body of the most recent request; null before the first
+  // flush() and after close().
+  private InputStream inputStream_ = null;
+
+  private int connectTimeout_ = 0;
+
+  private int readTimeout_ = 0;
+
+  private Map<String,String> customHeaders_ = null;
+
+  // Used for storing response headers. This is not in the
+  // THttpClient.java class in the apache/thrift repository.
+  // NOTE: only populated by the HttpURLConnection-based flush() path; stays
+  // null when the HttpClient-based path (flushUsingHttpClient) is used.
+  private Map<String,List<String>> responseHeaders_ = null;
+
+  // Both non-null only when an HttpClient was supplied at construction.
+  private final HttpHost host;
+
+  private final HttpClient client;
+
+ /* Not compatible with thrift 0.11.0 which is used when
USE_APACHE_COMPONENTS=true.
+ public static class Factory extends TTransportFactory {
+
+ private final String url;
+ private final HttpClient client;
+
+ public Factory(String url) {
+ this.url = url;
+ this.client = null;
+ }
+
+ public Factory(String url, HttpClient client) {
+ this.url = url;
+ this.client = client;
+ }
+
+ @Override
+ public TTransport getTransport(TTransport trans) {
+ try {
+ if (null != client) {
+ return new THttpClientWithHeaders(trans.getConfiguration(), url,
client);
+ } else {
+ return new THttpClientWithHeaders(trans.getConfiguration(), url);
+ }
+ } catch (TTransportException tte) {
+ return null;
+ }
+ }
+ }*/
+
+  /**
+   * Creates a transport that uses HttpURLConnection for requests.
+   */
+  public THttpClientWithHeaders(TConfiguration config, String url)
+      throws TTransportException {
+    super(config);
+    try {
+      url_ = new URL(url);
+      this.client = null;
+      this.host = null;
+    } catch (IOException iox) {
+      throw new TTransportException(iox);
+    }
+  }
+
+  /**
+   * Creates a transport that uses HttpURLConnection, with a default
+   * TConfiguration.
+   */
+  public THttpClientWithHeaders(String url) throws TTransportException {
+    super(new TConfiguration());
+    try {
+      url_ = new URL(url);
+      this.client = null;
+      this.host = null;
+    } catch (IOException iox) {
+      throw new TTransportException(iox);
+    }
+  }
+
+  /**
+   * Creates a transport that sends requests through the given Apache
+   * HttpClient instance.
+   */
+  public THttpClientWithHeaders(TConfiguration config, String url, HttpClient
client)
+      throws TTransportException {
+    super(config);
+    try {
+      url_ = new URL(url);
+      this.client = client;
+      this.host = new HttpHost(url_.getHost(), -1 == url_.getPort()
+          ? url_.getDefaultPort()
+          : url_.getPort(), url_.getProtocol());
+    } catch (IOException iox) {
+      throw new TTransportException(iox);
+    }
+  }
+
+  /**
+   * Creates a transport that sends requests through the given Apache
+   * HttpClient instance, with a default TConfiguration.
+   */
+  public THttpClientWithHeaders(String url, HttpClient client)
+      throws TTransportException {
+    super(new TConfiguration());
+    try {
+      url_ = new URL(url);
+      this.client = client;
+      this.host = new HttpHost(url_.getHost(), -1 == url_.getPort()
+          ? url_.getDefaultPort()
+          : url_.getPort(), url_.getProtocol());
+    } catch (IOException iox) {
+      throw new TTransportException(iox);
+    }
+  }
+
+  /**
+   * Sets the connection timeout in milliseconds (0 means no explicit timeout).
+   */
+  public void setConnectTimeout(int timeout) {
+    connectTimeout_ = timeout;
+    if (null != this.client) {
+      // WARNING, this modifies the HttpClient params, this might have an
impact elsewhere
+      // if the same HttpClient is used for something else.
+      client.getParams().setParameter(
+          CoreConnectionPNames.CONNECTION_TIMEOUT, connectTimeout_);
+    }
+  }
+
+  /**
+   * Sets the socket read timeout in milliseconds (0 means no explicit timeout).
+   */
+  public void setReadTimeout(int timeout) {
+    readTimeout_ = timeout;
+    if (null != this.client) {
+      // WARNING, this modifies the HttpClient params, this might have an
impact elsewhere
+      // if the same HttpClient is used for something else.
+      client.getParams().setParameter(CoreConnectionPNames.SO_TIMEOUT,
readTimeout_);
+    }
+  }
+
+  /**
+   * Replaces the full map of custom HTTP headers sent with every request.
+   */
+  public void setCustomHeaders(Map<String,String> headers) {
+    customHeaders_ = headers;
+  }
+
+  /**
+   * Adds or overwrites a single custom HTTP header sent with every request.
+   */
+  public void setCustomHeader(String key, String value) {
+    if (customHeaders_ == null) {
+      customHeaders_ = new HashMap<String, String>();
+    }
+    customHeaders_.put(key, value);
+  }
+
+  // No-op: the HTTP connection is established lazily on each flush().
+  public void open() {}
+
+  // Releases the buffered response of the previous request, if any.
+  public void close() {
+    if (null != inputStream_) {
+      try {
+        inputStream_.close();
+      } catch (IOException ioe) {
+      }
+      inputStream_ = null;
+    }
+  }
+
+  // Always "open", since a fresh connection is made per request.
+  public boolean isOpen() {
+    return true;
+  }
+
+  /**
+   * Reads from the buffered response of the last flushed request.
+   * @throws TTransportException if no request has been made yet, the response
+   *         is exhausted, or an I/O error occurs.
+   */
+  public int read(byte[] buf, int off, int len) throws TTransportException {
+    if (inputStream_ == null) {
+      throw new TTransportException("Response buffer is empty, no request.");
+    }
+
+    checkReadBytesAvailable(len);
+
+    try {
+      int ret = inputStream_.read(buf, off, len);
+      if (ret == -1) {
+        throw new TTransportException("No more data available.");
+      }
+      countConsumedMessageBytes(ret);
+
+      return ret;
+    } catch (IOException iox) {
+      throw new TTransportException(iox);
+    }
+  }
+
+  // Buffers outgoing bytes until flush() sends them as one HTTP POST.
+  public void write(byte[] buf, int off, int len) {
+    requestBuffer_.write(buf, off, len);
+  }
+
+  /**
+   * copy from org.apache.http.util.EntityUtils#consume. Android has its own
httpcore
+   * that doesn't have a consume.
+   */
+  private static void consume(final HttpEntity entity) throws IOException {
+    if (entity == null) {
+      return;
+    }
+    if (entity.isStreaming()) {
+      InputStream instream = entity.getContent();
+      if (instream != null) {
+        instream.close();
+      }
+    }
+  }
+
+  // Sends the buffered request via the Apache HttpClient path.
+  // NOTE: unlike the HttpURLConnection path in flush(), this path does not
+  // capture response headers, so getResponseHeaders() stays null when an
+  // HttpClient was supplied at construction.
+  private void flushUsingHttpClient() throws TTransportException {
+
+    if (null == this.client) {
+      throw new TTransportException("Null HttpClient, aborting.");
+    }
+
+    // Extract request and reset buffer
+    byte[] data = requestBuffer_.toByteArray();
+    requestBuffer_.reset();
+
+    HttpPost post = null;
+
+    InputStream is = null;
+
+    try {
+      // Set request to path + query string
+      post = new HttpPost(this.url_.getFile());
+
+      //
+      // Headers are added to the HttpPost instance, not
+      // to HttpClient.
+      //
+
+      post.setHeader("Content-Type", "application/x-thrift");
+      post.setHeader("Accept", "application/x-thrift");
+      post.setHeader("User-Agent", "Java/THttpClient/HC");
+
+      if (null != customHeaders_) {
+        for (Map.Entry<String, String> header : customHeaders_.entrySet()) {
+          post.setHeader(header.getKey(), header.getValue());
+        }
+      }
+
+      post.setEntity(new ByteArrayEntity(data));
+
+      HttpResponse response = this.client.execute(this.host, post);
+      int responseCode = response.getStatusLine().getStatusCode();
+
+      //
+      // Retrieve the inputstream BEFORE checking the status code so
+      // resources get freed in the finally clause.
+      //
+
+      is = response.getEntity().getContent();
+
+      if (responseCode != HttpStatus.SC_OK) {
+        throw new TTransportException("HTTP Response code: " + responseCode);
+      }
+
+      // Read the responses into a byte array so we can release the connection
+      // early. This implies that the whole content will have to be read in
+      // memory, and that momentarily we might use up twice the memory (while
the
+      // thrift struct is being read up the chain).
+      // Proceeding differently might lead to exhaustion of connections and
thus
+      // to app failure.
+
+      byte[] buf = new byte[1024];
+      ByteArrayOutputStream baos = new ByteArrayOutputStream();
+
+      int len = 0;
+      do {
+        len = is.read(buf);
+        if (len > 0) {
+          baos.write(buf, 0, len);
+        }
+      } while (-1 != len);
+
+      try {
+        // Indicate we're done with the content.
+        consume(response.getEntity());
+      } catch (IOException ioe) {
+        // We ignore this exception, it might only mean the server has no
+        // keep-alive capability.
+      }
+
+      inputStream_ = new ByteArrayInputStream(baos.toByteArray());
+    } catch (IOException ioe) {
+      // Abort method so the connection gets released back to the connection
manager
+      if (null != post) {
+        post.abort();
+      }
+      throw new TTransportException(ioe);
+    } finally {
+      resetConsumedMessageSize(-1);
+      if (null != is) {
+        // Close the entity's input stream, this will release the underlying
connection
+        try {
+          is.close();
+        } catch (IOException ioe) {
+          throw new TTransportException(ioe);
+        }
+      }
+      if (post != null) {
+        post.releaseConnection();
+      }
+    }
+  }
+
+  // Sends the buffered request. Delegates to flushUsingHttpClient() when an
+  // HttpClient was supplied; otherwise uses a fresh HttpURLConnection per
+  // request (and this path is the only one that captures response headers).
+  public void flush() throws TTransportException {
+
+    if (null != this.client) {
+      flushUsingHttpClient();
+      return;
+    }
+
+    // Extract request and reset buffer
+    byte[] data = requestBuffer_.toByteArray();
+    requestBuffer_.reset();
+
+    try {
+      // Create connection object
+      HttpURLConnection connection = (HttpURLConnection)url_.openConnection();
+
+      // Timeouts, only if explicitly set
+      if (connectTimeout_ > 0) {
+        connection.setConnectTimeout(connectTimeout_);
+      }
+      if (readTimeout_ > 0) {
+        connection.setReadTimeout(readTimeout_);
+      }
+
+      // Make the request
+      connection.setRequestMethod("POST");
+      connection.setRequestProperty("Content-Type", "application/x-thrift");
+      connection.setRequestProperty("Accept", "application/x-thrift");
+      connection.setRequestProperty("User-Agent", "Java/THttpClient");
+      if (customHeaders_ != null) {
+        for (Map.Entry<String, String> header : customHeaders_.entrySet()) {
+          connection.setRequestProperty(header.getKey(), header.getValue());
+        }
+      }
+      connection.setDoOutput(true);
+      connection.connect();
+      connection.getOutputStream().write(data);
+
+      int responseCode = connection.getResponseCode();
+      if (responseCode != HttpURLConnection.HTTP_OK) {
+        throw new TTransportException("HTTP Response code: " + responseCode);
+      }
+
+      // Read the responses
+      inputStream_ = connection.getInputStream();
+      // Capture the response headers.
+      // This is not in the THttpClient.java class in the apache/thrift
repository.
+      responseHeaders_ = connection.getHeaderFields();
+
+    } catch (IOException iox) {
+      throw new TTransportException(iox);
+    } finally {
+      resetConsumedMessageSize(-1);
+    }
+  }
+
+  // Getter function for HTTP response headers. This is not in the
+  // THttpClient.java class in the apache/thrift repository.
+  // Returns null until flush() has completed a request on the
+  // HttpURLConnection path; never populated on the HttpClient path.
+  public Map<String, List<String>> getResponseHeaders() {
+    return responseHeaders_;
+  }