[ 
https://issues.apache.org/jira/browse/KAFKA-7169?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16580251#comment-16580251
 ] 

ASF GitHub Bot commented on KAFKA-7169:
---------------------------------------

rajinisivaram closed pull request #5497: KAFKA-7169: Validate SASL extensions 
through callback on server side
URL: https://github.com/apache/kafka/pull/5497
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git 
a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerExtensionsValidatorCallback.java
 
b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerExtensionsValidatorCallback.java
new file mode 100644
index 00000000000..97ac4d96838
--- /dev/null
+++ 
b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerExtensionsValidatorCallback.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.security.oauthbearer;
+
+import org.apache.kafka.common.security.auth.SaslExtensions;
+
+import javax.security.auth.callback.Callback;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.apache.kafka.common.utils.CollectionUtils.subtractMap;
+
+/**
+ * A {@code Callback} for use by the {@code SaslServer} implementation when it
+ * needs to validate the SASL extensions for the OAUTHBEARER mechanism.
+ * Callback handlers should use the {@link #valid(String)}
+ * method to communicate valid extensions back to the SASL server.
+ * Callback handlers should use the
+ * {@link #error(String, String)} method to communicate validation errors back to
+ * the SASL server.
+ * As per RFC-7628 (https://tools.ietf.org/html/rfc7628#section-3.1), unknown extensions must be ignored by the server.
+ * The callback handler implementation should simply ignore unknown extensions,
+ * calling neither {@link #error(String, String)} nor {@link #valid(String)}.
+ * Callback handlers should communicate other problems by raising an {@code IOException}.
+ * <p>
+ * The OAuth bearer token is provided in the callback for better context in extension validation.
+ * It is very important that token validation is done in its own {@link OAuthBearerValidatorCallback}
+ * regardless of provided extensions, as they are inherently insecure.
+ */
+public class OAuthBearerExtensionsValidatorCallback implements Callback {
+    private final OAuthBearerToken token;
+    private final SaslExtensions inputExtensions;
+    private final Map<String, String> validatedExtensions = new HashMap<>();
+    private final Map<String, String> invalidExtensions = new HashMap<>();
+
+    public OAuthBearerExtensionsValidatorCallback(OAuthBearerToken token, SaslExtensions extensions) {
+        this.token = Objects.requireNonNull(token);
+        this.inputExtensions = Objects.requireNonNull(extensions);
+    }
+
+    /**
+     * @return {@link OAuthBearerToken} the OAuth bearer token of the client
+     */
+    public OAuthBearerToken token() {
+        return token;
+    }
+
+    /**
+     * @return {@link SaslExtensions} consisting of the unvalidated extension 
names and values that were sent by the client
+     */
+    public SaslExtensions inputExtensions() {
+        return inputExtensions;
+    }
+
+    /**
+     * @return an unmodifiable {@link Map} consisting of the extension names and values that were validated and recognized by the server
+     */
+    public Map<String, String> validatedExtensions() {
+        return Collections.unmodifiableMap(validatedExtensions);
+    }
+
+    /**
+     * @return An immutable {@link Map} consisting of the name->error messages 
of extensions which failed validation
+     */
+    public Map<String, String> invalidExtensions() {
+        return Collections.unmodifiableMap(invalidExtensions);
+    }
+
+    /**
+     * @return An immutable {@link Map} consisting of the extensions that have 
neither been validated nor invalidated
+     */
+    public Map<String, String> ignoredExtensions() {
+        return Collections.unmodifiableMap(subtractMap(subtractMap(inputExtensions.map(), invalidExtensions), validatedExtensions));
+    }
+
+    /**
+     * Validates a specific extension in the original {@code inputExtensions} 
map
+     * @param extensionName - the name of the extension which was validated
+     */
+    public void valid(String extensionName) {
+        if (!inputExtensions.map().containsKey(extensionName))
+            throw new IllegalArgumentException(String.format("Extension %s was not found in the original extensions", extensionName));
+        validatedExtensions.put(extensionName, inputExtensions.map().get(extensionName));
+    }
+    /**
+     * Set the error value for a specific extension key-value pair if 
validation has failed
+     *
+     * @param invalidExtensionName
+     *            the mandatory extension name which caused the validation 
failure
+     * @param errorMessage
+     *            error message describing why the validation failed
+     */
+    public void error(String invalidExtensionName, String errorMessage) {
+        if (Objects.requireNonNull(invalidExtensionName).isEmpty())
+            throw new IllegalArgumentException("extension name must not be empty");
+        this.invalidExtensions.put(invalidExtensionName, errorMessage);
+    }
+}
diff --git 
a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModule.java
 
b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModule.java
index 57fa5d20925..ac273ccbb0c 100644
--- 
a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModule.java
+++ 
b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModule.java
@@ -97,7 +97,8 @@
  * <p>
  * You can also add custom unsecured SASL extensions when using the default, 
builtin {@link AuthenticateCallbackHandler}
  * implementation through using the configurable option {@code 
unsecuredLoginExtension_<extensionname>}. Note that there
- * are validations for the key/values in order to conform to the OAuth 
standard, including the reserved key at
+ * are validations for the key/values in order to conform to the 
SASL/OAUTHBEARER standard
+ * (https://tools.ietf.org/html/rfc7628#section-3.1), including the reserved 
key at
  * {@link 
org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse#AUTH_KEY}.
  * The {@code OAuthBearerLoginModule} instance also asks its configured {@link 
AuthenticateCallbackHandler}
  * implementation to handle an instance of {@link SaslExtensionsCallback} and 
return an instance of {@link SaslExtensions}.
diff --git 
a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServer.java
 
b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServer.java
index 6573f695307..8f89c147328 100644
--- 
a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServer.java
+++ 
b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServer.java
@@ -33,9 +33,11 @@
 import org.apache.kafka.common.errors.SaslAuthenticationException;
 import org.apache.kafka.common.security.auth.SaslExtensions;
 import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import 
org.apache.kafka.common.security.oauthbearer.OAuthBearerExtensionsValidatorCallback;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
 import 
org.apache.kafka.common.security.oauthbearer.OAuthBearerValidatorCallback;
+import org.apache.kafka.common.utils.Utils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -82,8 +84,7 @@ public OAuthBearerSaslServer(CallbackHandler callbackHandler) 
{
     @Override
     public byte[] evaluateResponse(byte[] response) throws SaslException, 
SaslAuthenticationException {
         if (response.length == 1 && response[0] == 
OAuthBearerSaslClient.BYTE_CONTROL_A && errorMessage != null) {
-            if (log.isDebugEnabled())
-                log.debug("Received %x01 response from client after it 
received our error");
+            log.debug("Received %x01 response from client after it received 
our error");
             throw new SaslAuthenticationException(errorMessage);
         }
         errorMessage = null;
@@ -152,17 +153,13 @@ public void dispose() throws SaslException {
         try {
             callbackHandler.handle(new Callback[] {callback});
         } catch (IOException | UnsupportedCallbackException e) {
-            String msg = String.format("%s: %s", INTERNAL_ERROR_ON_SERVER, 
e.getMessage());
-            if (log.isDebugEnabled())
-                log.debug(msg, e);
-            throw new SaslException(msg);
+            handleCallbackError(e);
         }
         OAuthBearerToken token = callback.token();
         if (token == null) {
             errorMessage = jsonErrorResponse(callback.errorStatus(), 
callback.errorScope(),
                     callback.errorOpenIDConfiguration());
-            if (log.isDebugEnabled())
-                log.debug(errorMessage);
+            log.debug(errorMessage);
             return errorMessage.getBytes(StandardCharsets.UTF_8);
         }
         /*
@@ -173,14 +170,36 @@ public void dispose() throws SaslException {
             throw new SaslAuthenticationException(String.format(
                     "Authentication failed: Client requested an authorization 
id (%s) that is different from the token's principal name (%s)",
                     authorizationId, token.principalName()));
+
+        Map<String, String> validExtensions = processExtensions(token, 
extensions);
+
         tokenForNegotiatedProperty = token;
-        this.extensions = extensions;
+        this.extensions = new SaslExtensions(validExtensions);
         complete = true;
-        if (log.isDebugEnabled())
-            log.debug("Successfully authenticate User={}", 
token.principalName());
+        log.debug("Successfully authenticate User={}", token.principalName());
         return new byte[0];
     }
 
+    private Map<String, String> processExtensions(OAuthBearerToken token, SaslExtensions extensions) throws SaslException {
+        OAuthBearerExtensionsValidatorCallback extensionsCallback = new OAuthBearerExtensionsValidatorCallback(token, extensions);
+        try {
+            callbackHandler.handle(new Callback[] {extensionsCallback});
+        } catch (UnsupportedCallbackException e) {
+            // backwards compatibility - no extensions will be added
+        } catch (IOException e) {
+            handleCallbackError(e);
+        }
+        if (!extensionsCallback.invalidExtensions().isEmpty()) {
+            String errorMessage = String.format("Authentication failed: %d extensions are invalid! They are: %s",
+                    extensionsCallback.invalidExtensions().size(),
+                    Utils.mkString(extensionsCallback.invalidExtensions(), "", "", ": ", "; "));
+            log.debug(errorMessage);
+            throw new SaslAuthenticationException(errorMessage);
+        }
+
+        return extensionsCallback.validatedExtensions();
+    }
+
     private static String jsonErrorResponse(String errorStatus, String 
errorScope, String errorOpenIDConfiguration) {
         String jsonErrorResponse = String.format("{\"status\":\"%s\"", 
errorStatus);
         if (errorScope != null)
@@ -192,6 +211,12 @@ private static String jsonErrorResponse(String 
errorStatus, String errorScope, S
         return jsonErrorResponse;
     }
 
+    private void handleCallbackError(Exception e) throws SaslException {
+        String msg = String.format("%s: %s", INTERNAL_ERROR_ON_SERVER, 
e.getMessage());
+        log.debug(msg, e);
+        throw new SaslException(msg);
+    }
+
     public static String[] mechanismNamesCompatibleWithPolicy(Map<String, ?> 
props) {
         return props != null && 
"true".equals(String.valueOf(props.get(Sasl.POLICY_NOPLAINTEXT))) ? new 
String[] {}
                 : new String[] {OAuthBearerLoginModule.OAUTHBEARER_MECHANISM};
diff --git 
a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredValidatorCallbackHandler.java
 
b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredValidatorCallbackHandler.java
index ac7d0199d5c..2e21cf43f4e 100644
--- 
a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredValidatorCallbackHandler.java
+++ 
b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredValidatorCallbackHandler.java
@@ -27,6 +27,7 @@
 import javax.security.auth.login.AppConfigurationEntry;
 
 import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import 
org.apache.kafka.common.security.oauthbearer.OAuthBearerExtensionsValidatorCallback;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
 import 
org.apache.kafka.common.security.oauthbearer.OAuthBearerValidatorCallback;
 import org.apache.kafka.common.utils.Time;
@@ -68,11 +69,14 @@
  *      unsecuredValidatorAllowableClockSkewMs="3000";
  * };
  * </pre>
- * 
+ * It also recognizes {@link OAuthBearerExtensionsValidatorCallback} and 
validates every extension passed to it.
+ *
  * This class is the default when the SASL mechanism is OAUTHBEARER and no 
value
  * is explicitly set via the
  * {@code 
listener.name.sasl_[plaintext|ssl].oauthbearer.sasl.server.callback.handler.class}
  * broker configuration property.
+ * It is worth noting that this class is not suitable for production use due 
to the use of unsecured JWT tokens and
+ * validation of every given extension.
  */
 public class OAuthBearerUnsecuredValidatorCallbackHandler implements 
AuthenticateCallbackHandler {
     private static final Logger log = 
LoggerFactory.getLogger(OAuthBearerUnsecuredValidatorCallbackHandler.class);
@@ -134,6 +138,9 @@ public void handle(Callback[] callbacks) throws 
IOException, UnsupportedCallback
                     validationCallback.error(failureScope != null ? 
"insufficient_scope" : "invalid_token",
                             failureScope, failureReason.failureOpenIdConfig());
                 }
+            } else if (callback instanceof 
OAuthBearerExtensionsValidatorCallback) {
+                OAuthBearerExtensionsValidatorCallback extensionsCallback = 
(OAuthBearerExtensionsValidatorCallback) callback;
+                
extensionsCallback.inputExtensions().map().forEach((extensionName, v) -> 
extensionsCallback.valid(extensionName));
             } else
                 throw new UnsupportedCallbackException(callback);
         }
diff --git 
a/clients/src/main/java/org/apache/kafka/common/utils/CollectionUtils.java 
b/clients/src/main/java/org/apache/kafka/common/utils/CollectionUtils.java
index dd4b46a0c63..04fce647063 100644
--- a/clients/src/main/java/org/apache/kafka/common/utils/CollectionUtils.java
+++ b/clients/src/main/java/org/apache/kafka/common/utils/CollectionUtils.java
@@ -22,11 +22,21 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
 
 public final class CollectionUtils {
 
     private CollectionUtils() {}
 
+    /**
+     * Given two maps (A, B), returns all the key-value pairs in A whose keys are not contained in B
+     */
+    public static <K, V> Map<K, V> subtractMap(Map<? extends K, ? extends V> minuend, Map<? extends K, ? extends V> subtrahend) {
+        return minuend.entrySet().stream()
+                .filter(entry -> !subtrahend.containsKey(entry.getKey()))
+                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+    }
+
     /**
      * group data by topic
      * @param data Data to be partitioned
diff --git 
a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerExtensionsValidatorCallbackTest.java
 
b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerExtensionsValidatorCallbackTest.java
new file mode 100644
index 00000000000..f65031ff38a
--- /dev/null
+++ 
b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerExtensionsValidatorCallbackTest.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.security.oauthbearer;
+
+import org.apache.kafka.common.security.auth.SaslExtensions;
+import org.junit.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+public class OAuthBearerExtensionsValidatorCallbackTest {
+    private static final OAuthBearerToken TOKEN = new OAuthBearerTokenMock();
+
+    @Test
+    public void testValidatedExtensionsAreReturned() {
+        Map<String, String> extensions = new HashMap<>();
+        extensions.put("hello", "bye");
+
+        OAuthBearerExtensionsValidatorCallback callback = new 
OAuthBearerExtensionsValidatorCallback(TOKEN, new SaslExtensions(extensions));
+
+        assertTrue(callback.validatedExtensions().isEmpty());
+        assertTrue(callback.invalidExtensions().isEmpty());
+        callback.valid("hello");
+        assertFalse(callback.validatedExtensions().isEmpty());
+        assertEquals("bye", callback.validatedExtensions().get("hello"));
+        assertTrue(callback.invalidExtensions().isEmpty());
+    }
+
+    @Test
+    public void testInvalidExtensionsAndErrorMessagesAreReturned() {
+        Map<String, String> extensions = new HashMap<>();
+        extensions.put("hello", "bye");
+
+        OAuthBearerExtensionsValidatorCallback callback = new 
OAuthBearerExtensionsValidatorCallback(TOKEN, new SaslExtensions(extensions));
+
+        assertTrue(callback.validatedExtensions().isEmpty());
+        assertTrue(callback.invalidExtensions().isEmpty());
+        callback.error("hello", "error");
+        assertFalse(callback.invalidExtensions().isEmpty());
+        assertEquals("error", callback.invalidExtensions().get("hello"));
+        assertTrue(callback.validatedExtensions().isEmpty());
+    }
+
+    /**
+     * Extensions that are neither validated nor invalidated must not be present in either map
+     */
+    @Test
+    public void testUnvalidatedExtensionsAreIgnored() {
+        Map<String, String> extensions = new HashMap<>();
+        extensions.put("valid", "valid");
+        extensions.put("error", "error");
+        extensions.put("nothing", "nothing");
+
+        OAuthBearerExtensionsValidatorCallback callback = new 
OAuthBearerExtensionsValidatorCallback(TOKEN, new SaslExtensions(extensions));
+        callback.error("error", "error");
+        callback.valid("valid");
+
+        assertFalse(callback.validatedExtensions().containsKey("nothing"));
+        assertFalse(callback.invalidExtensions().containsKey("nothing"));
+        assertEquals("nothing", callback.ignoredExtensions().get("nothing"));
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void testCannotValidateExtensionWhichWasNotGiven() {
+        Map<String, String> extensions = new HashMap<>();
+        extensions.put("hello", "bye");
+
+        OAuthBearerExtensionsValidatorCallback callback = new 
OAuthBearerExtensionsValidatorCallback(TOKEN, new SaslExtensions(extensions));
+
+        callback.valid("???");
+    }
+}
diff --git 
a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerTokenMock.java
 
b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerTokenMock.java
new file mode 100644
index 00000000000..994c923a4c1
--- /dev/null
+++ 
b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerTokenMock.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.security.oauthbearer;
+
+import java.util.Set;
+
+public class OAuthBearerTokenMock implements OAuthBearerToken {
+    @Override
+    public String value() {
+        return null;
+    }
+
+    @Override
+    public Set<String> scope() {
+        return null;
+    }
+
+    @Override
+    public long lifetimeMs() {
+        return 0;
+    }
+
+    @Override
+    public String principalName() {
+        return null;
+    }
+
+    @Override
+    public Long startTimeMs() {
+        return null;
+    }
+}
diff --git 
a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerTest.java
 
b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerTest.java
index fc96f9f8adc..5df7968dc68 100644
--- 
a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerTest.java
@@ -36,9 +36,12 @@
 import org.apache.kafka.common.security.JaasContext;
 import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 import org.apache.kafka.common.security.auth.SaslExtensions;
+import 
org.apache.kafka.common.security.oauthbearer.OAuthBearerExtensionsValidatorCallback;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
+import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenMock;
+import 
org.apache.kafka.common.security.oauthbearer.OAuthBearerValidatorCallback;
 import 
org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerConfigException;
 import 
org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerUnsecuredLoginCallbackHandler;
 import 
org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerUnsecuredValidatorCallbackHandler;
@@ -62,10 +65,28 @@
                 JaasContext.loadClientContext(CONFIGS).configurationEntries());
     }
     private static final AuthenticateCallbackHandler 
VALIDATOR_CALLBACK_HANDLER;
+    private static final AuthenticateCallbackHandler 
EXTENSIONS_VALIDATOR_CALLBACK_HANDLER;
     static {
         VALIDATOR_CALLBACK_HANDLER = new 
OAuthBearerUnsecuredValidatorCallbackHandler();
         VALIDATOR_CALLBACK_HANDLER.configure(CONFIGS, 
OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
                 JaasContext.loadClientContext(CONFIGS).configurationEntries());
+        // only validate extensions "firstKey" and "secondKey"
+        EXTENSIONS_VALIDATOR_CALLBACK_HANDLER = new 
OAuthBearerUnsecuredValidatorCallbackHandler() {
+            @Override
+            public void handle(Callback[] callbacks) throws 
UnsupportedCallbackException {
+                for (Callback callback : callbacks) {
+                    if (callback instanceof OAuthBearerValidatorCallback) {
+                        OAuthBearerValidatorCallback validationCallback = 
(OAuthBearerValidatorCallback) callback;
+                        validationCallback.token(new OAuthBearerTokenMock());
+                    } else if (callback instanceof 
OAuthBearerExtensionsValidatorCallback) {
+                        OAuthBearerExtensionsValidatorCallback 
extensionsCallback = (OAuthBearerExtensionsValidatorCallback) callback;
+                        extensionsCallback.valid("firstKey");
+                        extensionsCallback.valid("secondKey");
+                    } else
+                        throw new UnsupportedCallbackException(callback);
+                }
+            }
+        };
     }
     private OAuthBearerSaslServer saslServer;
 
@@ -78,9 +99,13 @@ public void setUp() {
     public void noAuthorizationIdSpecified() throws Exception {
         byte[] nextChallenge = saslServer
                 .evaluateResponse(clientInitialResponse(null));
+        // also asserts that no authentication error is thrown if 
OAuthBearerExtensionsValidatorCallback is not supported
         assertTrue("Next challenge is not empty", nextChallenge.length == 0);
     }
 
+    /**
+     * SASL Extensions that are validated by the callback handler should be 
accessible through the {@code #getNegotiatedProperty()} method
+     */
     @Test
     public void savesCustomExtensionAsNegotiatedProperty() throws Exception {
         Map<String, String> customExtensions = new HashMap<>();
@@ -95,16 +120,53 @@ public void savesCustomExtensionAsNegotiatedProperty() 
throws Exception {
         assertEquals("value2", saslServer.getNegotiatedProperty("secondKey"));
     }
 
+    /**
+     * SASL Extensions that were not recognized (neither validated nor 
invalidated)
+     * by the callback handler must not be accessible through the {@code 
#getNegotiatedProperty()} method
+     */
     @Test
-    public void returnsNullForNonExistentProperty() throws Exception {
+    public void unrecognizedExtensionsAreNotSaved() throws Exception {
+        saslServer = new 
OAuthBearerSaslServer(EXTENSIONS_VALIDATOR_CALLBACK_HANDLER);
         Map<String, String> customExtensions = new HashMap<>();
         customExtensions.put("firstKey", "value1");
+        customExtensions.put("secondKey", "value1");
+        customExtensions.put("thirdKey", "value1");
 
         byte[] nextChallenge = saslServer
                 .evaluateResponse(clientInitialResponse(null, false, 
customExtensions));
 
         assertTrue("Next challenge is not empty", nextChallenge.length == 0);
-        assertNull(saslServer.getNegotiatedProperty("secondKey"));
+        assertNull("Extensions not recognized by the server must be ignored", 
saslServer.getNegotiatedProperty("thirdKey"));
+    }
+
+    /**
+     * If the callback handler handles the 
`OAuthBearerExtensionsValidatorCallback`
+     *  and finds an invalid extension, SaslServer should throw an 
authentication exception
+     */
+    @Test(expected = SaslAuthenticationException.class)
+    public void throwsAuthenticationExceptionOnInvalidExtensions() throws 
Exception {
+        OAuthBearerUnsecuredValidatorCallbackHandler invalidHandler = new 
OAuthBearerUnsecuredValidatorCallbackHandler() {
+            @Override
+            public void handle(Callback[] callbacks) throws 
UnsupportedCallbackException {
+                for (Callback callback : callbacks) {
+                    if (callback instanceof OAuthBearerValidatorCallback) {
+                        OAuthBearerValidatorCallback validationCallback = 
(OAuthBearerValidatorCallback) callback;
+                        validationCallback.token(new OAuthBearerTokenMock());
+                    } else if (callback instanceof 
OAuthBearerExtensionsValidatorCallback) {
+                        OAuthBearerExtensionsValidatorCallback 
extensionsCallback = (OAuthBearerExtensionsValidatorCallback) callback;
+                        extensionsCallback.error("firstKey", "is not valid");
+                        extensionsCallback.error("secondKey", "is not valid 
either");
+                    } else
+                        throw new UnsupportedCallbackException(callback);
+                }
+            }
+        };
+        saslServer = new OAuthBearerSaslServer(invalidHandler);
+        Map<String, String> customExtensions = new HashMap<>();
+        customExtensions.put("firstKey", "value");
+        customExtensions.put("secondKey", "value");
+
+        saslServer.evaluateResponse(clientInitialResponse(null, false, 
customExtensions));
     }
 
     @Test
diff --git 
a/clients/src/test/java/org/apache/kafka/common/utils/CollectionUtilsTest.java 
b/clients/src/test/java/org/apache/kafka/common/utils/CollectionUtilsTest.java
new file mode 100644
index 00000000000..7abf08a4e76
--- /dev/null
+++ 
b/clients/src/test/java/org/apache/kafka/common/utils/CollectionUtilsTest.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.utils;
+
+import org.junit.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.apache.kafka.common.utils.CollectionUtils.subtractMap;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.assertNotSame;
+
+public class CollectionUtilsTest {
+
+    @Test
+    public void testSubtractMapRemovesSecondMapsKeys() {
+        Map<String, String> mainMap = new HashMap<>();
+        mainMap.put("one", "1");
+        mainMap.put("two", "2");
+        mainMap.put("three", "3");
+        Map<String, String> secondaryMap = new HashMap<>();
+        secondaryMap.put("one", "4");
+        secondaryMap.put("two", "5");
+
+        Map<String, String> newMap = subtractMap(mainMap, secondaryMap);
+
+        assertEquals(3, mainMap.size());  // original map should not be 
modified
+        assertEquals(1, newMap.size());
+        assertTrue(newMap.containsKey("three"));
+        assertEquals("3", newMap.get("three"));
+    }
+
+    @Test
+    public void testSubtractMapDoesntRemoveAnythingWhenEmptyMap() {
+        Map<String, String> mainMap = new HashMap<>();
+        mainMap.put("one", "1");
+        mainMap.put("two", "2");
+        mainMap.put("three", "3");
+        Map<String, String> secondaryMap = new HashMap<>();
+
+        Map<String, String> newMap = subtractMap(mainMap, secondaryMap);
+
+        assertEquals(3, newMap.size());
+        assertEquals("1", newMap.get("one"));
+        assertEquals("2", newMap.get("two"));
+        assertEquals("3", newMap.get("three"));
+        assertNotSame(newMap, mainMap);
+    }
+}
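
For readers who want to try the map arithmetic outside the Kafka tree, the following is a minimal standalone sketch of how the ignored extensions fall out of two subtractMap calls, as ignoredExtensions() does in the diff above. The class name IgnoredExtensionsSketch and the ignored(...) helper are hypothetical and exist only for illustration; only subtractMap mirrors code from the pull request.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical standalone class; it does not depend on Kafka.
final class IgnoredExtensionsSketch {

    // Same shape as the subtractMap helper in the diff: returns the entries of
    // minuend whose keys are not present in subtrahend. Neither input is modified.
    static <K, V> Map<K, V> subtractMap(Map<? extends K, ? extends V> minuend,
                                        Map<? extends K, ? extends V> subtrahend) {
        return minuend.entrySet().stream()
                .filter(entry -> !subtrahend.containsKey(entry.getKey()))
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
    }

    // Illustrative helper: ignored = input - invalid - validated.
    static Map<String, String> ignored(Map<String, String> input,
                                       Map<String, String> invalid,
                                       Map<String, String> validated) {
        return Collections.unmodifiableMap(subtractMap(subtractMap(input, invalid), validated));
    }

    public static void main(String[] args) {
        Map<String, String> input = new HashMap<>();
        input.put("valid", "valid");
        input.put("error", "error");
        input.put("nothing", "nothing");

        Map<String, String> validated = new HashMap<>();
        validated.put("valid", "valid");

        // The invalid map holds name -> error message, so only its keys matter here.
        Map<String, String> invalid = new HashMap<>();
        invalid.put("error", "some error message");

        System.out.println(ignored(input, invalid, validated)); // prints {nothing=nothing}
    }
}
```

Because the subtraction is keyed on names only, an extension counts as handled as soon as the handler has called either valid or error for it, whatever value or message is stored.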


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


> Add support for Custom SASL extensions in OAuth authentication
> --------------------------------------------------------------
>
>                 Key: KAFKA-7169
>                 URL: https://issues.apache.org/jira/browse/KAFKA-7169
>             Project: Kafka
>          Issue Type: Improvement
>            Reporter: Stanislav Kozlovski
>            Assignee: Stanislav Kozlovski
>            Priority: Minor
>             Fix For: 2.1.0
>
>
> KIP: 
> [here|https://cwiki.apache.org/confluence/display/KAFKA/KIP-342%3A+Add+support+for+Custom+SASL+extensions+in+OAuth+authentication]
> Kafka currently supports non-configurable SASL extensions in its SCRAM 
> authentication protocol for delegation token validation. It would be useful 
> to provide configurable SASL extensions for the OAuthBearer authentication 
> mechanism as well, such that clients could attach arbitrary data for the 
> principal authenticating into Kafka. This way, a custom principal can hold 
> information derived from the authentication mechanism, which could prove 
> useful for better tracing and troubleshooting, for example. This can be done 
> in a way which allows for easier extendability in future SASL mechanisms.
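
As a rough illustration of the server-side policy this change introduces (validated extensions are kept, unknown ones are silently ignored, and any invalid one fails authentication), here is a minimal sketch that does not use the Kafka classes. Everything in it is an assumption made for illustration: the Verdicts type stands in for the callback, the allow-list names traceId and tenant are invented, and IllegalStateException stands in for the authentication exception thrown by the real server.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

// Hypothetical standalone model of the validation policy; not the Kafka API.
final class ExtensionPolicySketch {

    // Stand-in for the callback: records per-extension verdicts.
    static final class Verdicts {
        final Map<String, String> validated = new HashMap<>();
        final Map<String, String> invalid = new HashMap<>();
    }

    // Assumed allow-list of extension names this server understands.
    static final Set<String> KNOWN = new HashSet<>(Arrays.asList("traceId", "tenant"));

    // Handler side: judge each known extension, skip unknown ones entirely.
    static Verdicts validate(Map<String, String> clientExtensions) {
        Verdicts verdicts = new Verdicts();
        for (Map.Entry<String, String> e : clientExtensions.entrySet()) {
            if (!KNOWN.contains(e.getKey()))
                continue; // unknown extension: neither valid nor invalid
            if (e.getValue().isEmpty())
                verdicts.invalid.put(e.getKey(), "value must not be empty");
            else
                verdicts.validated.put(e.getKey(), e.getValue());
        }
        return verdicts;
    }

    // Server side: fail if anything was rejected, otherwise expose only validated entries.
    static Map<String, String> negotiated(Map<String, String> clientExtensions) {
        Verdicts verdicts = validate(clientExtensions);
        if (!verdicts.invalid.isEmpty())
            throw new IllegalStateException("Authentication failed: invalid extensions " + verdicts.invalid.keySet());
        return verdicts.validated;
    }

    public static void main(String[] args) {
        Map<String, String> sent = new HashMap<>();
        sent.put("traceId", "abc");
        sent.put("somethingElse", "x"); // not on the allow-list, so it is dropped
        System.out.println(negotiated(sent)); // prints {traceId=abc}
    }
}
```

The design point is that the client-supplied map is never trusted wholesale: only entries the handler explicitly approved become visible as negotiated properties.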



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)
