[ https://issues.apache.org/jira/browse/KAFKA-7169?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16570428#comment-16570428 ]

ASF GitHub Bot commented on KAFKA-7169:
---------------------------------------

rajinisivaram closed pull request #5379: KAFKA-7169: Custom SASL extensions for OAuthBearer authentication mechanism
URL: https://github.com/apache/kafka/pull/5379
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/clients/src/main/java/org/apache/kafka/common/security/auth/SaslExtensions.java b/clients/src/main/java/org/apache/kafka/common/security/auth/SaslExtensions.java
new file mode 100644
index 00000000000..75cac0533ea
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/security/auth/SaslExtensions.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.security.auth;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A simple immutable value object class holding customizable SASL extensions.
+ * <p>
+ * The name/value pairs are copied on construction and exposed only through an unmodifiable view, so
+ * neither the caller's original map nor {@link #map()} can be used to change an instance.
+ * {@link #equals(Object)}, {@link #hashCode()} and {@link #toString()} all delegate to that map.
+ */
+public class SaslExtensions {
+    // Unmodifiable view over a private copy of the extension name -> value pairs.
+    private final Map<String, String> extensionsMap;
+
+    /**
+     * Creates an instance holding a snapshot of the given extensions.
+     *
+     * @param extensionsMap the extension names and their values; copied, so later changes to it are
+     *                      not visible through this instance. NOTE(review): a {@code null} map makes
+     *                      the {@code HashMap} copy constructor throw {@code NullPointerException}.
+     */
+    public SaslExtensions(Map<String, String> extensionsMap) {
+        this.extensionsMap = Collections.unmodifiableMap(new 
HashMap<>(extensionsMap));
+    }
+
+    /**
+     * Returns an <strong>immutable</strong> map of the extension names and their values.
+     *
+     * @return the unmodifiable extensions map held by this instance
+     */
+    public Map<String, String> map() {
+        return extensionsMap;
+    }
+
+    // Equal only to an instance of exactly the same class that holds an equal map.
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        return extensionsMap.equals(((SaslExtensions) o).extensionsMap);
+    }
+
+    @Override
+    public String toString() {
+        return extensionsMap.toString();
+    }
+
+    // Consistent with equals(): both are derived solely from the wrapped map.
+    @Override
+    public int hashCode() {
+        return extensionsMap.hashCode();
+    }
+
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/security/auth/SaslExtensionsCallback.java b/clients/src/main/java/org/apache/kafka/common/security/auth/SaslExtensionsCallback.java
new file mode 100644
index 00000000000..d07be320625
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/security/auth/SaslExtensionsCallback.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.common.security.auth;
+
+import javax.security.auth.callback.Callback;
+
+/**
+ * Optional callback used for SASL mechanisms if any extensions need to be set
+ * in the SASL exchange.
+ * <p>
+ * A callback handler fills it in through {@link #extensions(SaslExtensions)}; until that happens
+ * {@link #extensions()} returns {@code null}.
+ */
+public class SaslExtensionsCallback implements Callback {
+    // Set by whichever callback handler handles this callback; null until set.
+    private SaslExtensions extensions;
+
+    /**
+     * Returns a {@link SaslExtensions} consisting of the extension names and values that are sent by the client to
+     * the server in the initial client SASL authentication message.
+     *
+     * @return the extensions set on this callback, or {@code null} if none have been set
+     */
+    public SaslExtensions extensions() {
+        return extensions;
+    }
+
+    /**
+     * Sets the SASL extensions on this callback.
+     *
+     * @param extensions the extensions to associate with this callback; stored as-is (no copy, no null check)
+     */
+    public void extensions(SaslExtensions extensions) {
+        this.extensions = extensions;
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java
index 5b2a28181cd..8b830c0c888 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java
@@ -30,14 +30,19 @@
 import javax.security.sasl.RealmCallback;
 
 import org.apache.kafka.common.config.SaslConfigs;
-import org.apache.kafka.common.security.scram.ScramExtensionsCallback;
+import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
 import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.auth.SaslExtensions;
+import org.apache.kafka.common.security.scram.ScramExtensionsCallback;
+import org.apache.kafka.common.security.scram.internals.ScramMechanism;
 
 /**
  * Default callback handler for Sasl clients. The callbacks required for the 
SASL mechanism
  * configured for the client should be supported by this callback handler. See
  * <a href="https://docs.oracle.com/javase/8/docs/technotes/guides/security/sasl/sasl-refguide.html">Java SASL API</a>
  * for the list of SASL callback handlers required for each SASL mechanism.
+ *
+ * For adding custom SASL extensions, a {@link SaslExtensions} may be added to 
the subject's public credentials
  */
 public class SaslClientCallbackHandler implements AuthenticateCallbackHandler {
 
@@ -78,9 +83,15 @@ public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
                 if (ac.isAuthorized())
                     ac.setAuthorizedID(authzId);
             } else if (callback instanceof ScramExtensionsCallback) {
-                ScramExtensionsCallback sc = (ScramExtensionsCallback) 
callback;
-                if (!SaslConfigs.GSSAPI_MECHANISM.equals(mechanism) && subject 
!= null && !subject.getPublicCredentials(Map.class).isEmpty()) {
-                    sc.extensions((Map<String, String>) 
subject.getPublicCredentials(Map.class).iterator().next());
+                if (ScramMechanism.isScram(mechanism) && subject != null && 
!subject.getPublicCredentials(Map.class).isEmpty()) {
+                    Map<String, String> extensions = (Map<String, String>) 
subject.getPublicCredentials(Map.class).iterator().next();
+                    ((ScramExtensionsCallback) 
callback).extensions(extensions);
+                }
+            } else if (callback instanceof SaslExtensionsCallback) {
+                if (!SaslConfigs.GSSAPI_MECHANISM.equals(mechanism) &&
+                        subject != null && 
!subject.getPublicCredentials(SaslExtensions.class).isEmpty()) {
+                    SaslExtensions extensions = 
subject.getPublicCredentials(SaslExtensions.class).iterator().next();
+                    ((SaslExtensionsCallback) callback).extensions(extensions);
                 }
             }  else {
                 throw new UnsupportedCallbackException(callback, "Unrecognized 
SASL ClientCallback");
diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModule.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModule.java
index 07382a86e3f..57fa5d20925 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModule.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModule.java
@@ -17,6 +17,7 @@
 package org.apache.kafka.common.security.oauthbearer;
 
 import java.io.IOException;
+import java.util.Collections;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.Objects;
@@ -31,6 +32,8 @@
 import org.apache.kafka.common.config.SaslConfigs;
 import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 import org.apache.kafka.common.security.auth.Login;
+import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
+import org.apache.kafka.common.security.auth.SaslExtensions;
 import 
org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerSaslClientProvider;
 import 
org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerSaslServerProvider;
 import org.slf4j.Logger;
@@ -91,6 +94,16 @@
  * </tr>
  * </table>
  * <p>
+ * <p>
+ * You can also add custom unsecured SASL extensions when using the default, 
builtin {@link AuthenticateCallbackHandler}
+ * implementation through using the configurable option {@code 
unsecuredLoginExtension_<extensionname>}. Note that there
+ * are validations for the key/values in order to conform to the OAuth 
standard, including the reserved key at
+ * {@link 
org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse#AUTH_KEY}.
+ * The {@code OAuthBearerLoginModule} instance also asks its configured {@link 
AuthenticateCallbackHandler}
+ * implementation to handle an instance of {@link SaslExtensionsCallback} and 
return an instance of {@link SaslExtensions}.
+ * The configured callback handler does not need to handle this callback, 
though -- any {@code UnsupportedCallbackException}
+ * that is thrown is ignored, and no SASL extensions will be associated with 
the login.
+ * <p>
  * Production use cases will require writing an implementation of
  * {@link AuthenticateCallbackHandler} that can handle an instance of
  * {@link OAuthBearerTokenCallback} and declaring it via either the
@@ -227,10 +240,13 @@
      */
     public static final String OAUTHBEARER_MECHANISM = "OAUTHBEARER";
     private static final Logger log = 
LoggerFactory.getLogger(OAuthBearerLoginModule.class);
+    private static final SaslExtensions EMPTY_EXTENSIONS = new 
SaslExtensions(Collections.emptyMap());
     private Subject subject = null;
     private AuthenticateCallbackHandler callbackHandler = null;
     private OAuthBearerToken tokenRequiringCommit = null;
     private OAuthBearerToken myCommittedToken = null;
+    private SaslExtensions extensionsRequiringCommit = null;
+    private SaslExtensions myCommittedExtensions = null;
 
     static {
         OAuthBearerSaslClientProvider.initialize(); // not part of public API
@@ -256,22 +272,51 @@ public boolean login() throws LoginException {
             throw new IllegalStateException(String.format(
                     "Already have a committed token with private credential 
token count=%d; must login on another login context or logout here first before 
reusing the same login context",
                     committedTokenCount()));
-        OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
+
+        identifyToken();
+        identifyExtensions();
+
+        log.info("Login succeeded; invoke commit() to commit it; current 
committed token count={}",
+                committedTokenCount());
+        return true;
+    }
+
+    /**
+     * Asks the configured callback handler for a token via an {@link OAuthBearerTokenCallback} and stages it in
+     * {@code tokenRequiringCommit}; {@code commit()} is what later adds it to the Subject's private credentials.
+     *
+     * @throws LoginException if the handler throws {@code IOException} or {@code UnsupportedCallbackException},
+     *                        or if it leaves the token unset (the message is the callback's error description)
+     */
+    private void identifyToken() throws LoginException {
+        OAuthBearerTokenCallback tokenCallback = new 
OAuthBearerTokenCallback();
         try {
-            callbackHandler.handle(new Callback[] {callback});
+            callbackHandler.handle(new Callback[] {tokenCallback});
         } catch (IOException | UnsupportedCallbackException e) {
             log.error(e.getMessage(), e);
-            throw new LoginException("An internal error occurred");
+            throw new LoginException("An internal error occurred while 
retrieving token from callback handler");
         }
-        tokenRequiringCommit = callback.token();
+
+        tokenRequiringCommit = tokenCallback.token();
+        // No token: log the error details the handler left on the callback and fail the login.
         if (tokenRequiringCommit == null) {
-            log.info(String.format("Login failed: %s : %s (URI=%s)", 
callback.errorCode(), callback.errorDescription(),
-                    callback.errorUri()));
-            throw new LoginException(callback.errorDescription());
+            log.info("Login failed: {} : {} (URI={})", 
tokenCallback.errorCode(), tokenCallback.errorDescription(),
+                    tokenCallback.errorUri());
+            throw new LoginException(tokenCallback.errorDescription());
+        }
+    }
+
+    /**
+     * Asks the configured callback handler for SASL extensions via a {@link SaslExtensionsCallback} and stages them
+     * in {@code extensionsRequiringCommit}. Nothing is attached to the Subject here; {@code commit()} is what adds the
+     * staged extensions to the Subject's public credentials. A handler that does not support the callback results in
+     * an empty set of extensions.
+     *
+     * @throws LoginException if the handler throws {@code IOException}, or if the callback is handled but its
+     *                        extensions are left {@code null}
+     */
+    private void identifyExtensions() throws LoginException {
+        SaslExtensionsCallback extensionsCallback = new 
SaslExtensionsCallback();
+        try {
+            callbackHandler.handle(new Callback[] {extensionsCallback});
+            extensionsRequiringCommit = extensionsCallback.extensions();
+        } catch (IOException e) {
+            log.error(e.getMessage(), e);
+            throw new LoginException("An internal error occurred while 
retrieving SASL extensions from callback handler");
+        } catch (UnsupportedCallbackException e) {
+            // Supporting extensions is optional for the handler, so fall back to an empty set.
+            extensionsRequiringCommit = EMPTY_EXTENSIONS;
+            log.info("CallbackHandler {} does not support SASL extensions. No 
extensions will be added", callbackHandler.getClass().getName());
+        }
+        // Also reached when the handler accepts the callback but never sets extensions on it
+        // (SaslExtensionsCallback starts out with null extensions).
+        if (extensionsRequiringCommit ==  null) {
+            log.error("SASL Extensions cannot be null. Check whether your 
callback handler is explicitly setting them as null.");
+            throw new LoginException("Extensions cannot be null.");
         }
-        log.info("Login succeeded; invoke commit() to commit it; current 
committed token count={}",
-                committedTokenCount());
-        return true;
     }
 
     @Override
@@ -294,6 +339,12 @@ public boolean logout() {
             }
         }
         log.info("Done logging out my token; committed token count is now {}", 
committedTokenCount());
+
+        log.info("Logging out my extensions");
+        if (subject.getPublicCredentials().removeIf(e -> myCommittedExtensions 
== e))
+            myCommittedExtensions = null;
+        log.info("Done logging out my extensions");
+
         return true;
     }
 
@@ -304,11 +355,17 @@ public boolean commit() throws LoginException {
                 log.debug("Nothing here to commit");
             return false;
         }
+
         log.info("Committing my token; current committed token count = {}", 
committedTokenCount());
         subject.getPrivateCredentials().add(tokenRequiringCommit);
         myCommittedToken = tokenRequiringCommit;
         tokenRequiringCommit = null;
         log.info("Done committing my token; committed token count is now {}", 
committedTokenCount());
+
+        subject.getPublicCredentials().add(extensionsRequiringCommit);
+        myCommittedExtensions = extensionsRequiringCommit;
+        extensionsRequiringCommit = null;
+
         return true;
     }
 
@@ -317,6 +374,7 @@ public boolean abort() throws LoginException {
         if (tokenRequiringCommit != null) {
             log.info("Login aborted");
             tokenRequiringCommit = null;
+            extensionsRequiringCommit = null;
             return true;
         }
         if (log.isDebugEnabled())
diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponse.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponse.java
index 8d4b18aede6..ef16ea237d4 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponse.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponse.java
@@ -16,11 +16,11 @@
  */
 package org.apache.kafka.common.security.oauthbearer.internals;
 
+import org.apache.kafka.common.security.auth.SaslExtensions;
 import org.apache.kafka.common.utils.Utils;
 
 import javax.security.sasl.SaslException;
 import java.nio.charset.StandardCharsets;
-import java.util.HashMap;
 import java.util.Map;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
@@ -31,15 +31,19 @@
     private static final String SASLNAME = "(?:[\\x01-\\x7F&&[^=,]]|=2C|=3D)+";
     private static final String KEY = "[A-Za-z]+";
     private static final String VALUE = "[\\x21-\\x7E \t\r\n]+";
+
     private static final String KVPAIRS = String.format("(%s=%s%s)*", KEY, 
VALUE, SEPARATOR);
     private static final Pattern AUTH_PATTERN = 
Pattern.compile("(?<scheme>[\\w]+)[ ]+(?<token>[-_\\.a-zA-Z0-9]+)");
     private static final Pattern CLIENT_INITIAL_RESPONSE_PATTERN = 
Pattern.compile(
             String.format("n,(a=(?<authzid>%s))?,%s(?<kvpairs>%s)%s", 
SASLNAME, SEPARATOR, KVPAIRS, SEPARATOR));
-    private static final String AUTH_KEY = "auth";
+    public static final String AUTH_KEY = "auth";
 
     private final String tokenValue;
     private final String authorizationId;
-    private final Map<String, String> properties;
+    private SaslExtensions saslExtensions;
+
+    public static final Pattern EXTENSION_KEY_PATTERN = Pattern.compile(KEY);
+    public static final Pattern EXTENSION_VALUE_PATTERN = 
Pattern.compile(VALUE);
 
     public OAuthBearerClientInitialResponse(byte[] response) throws 
SaslException {
         String responseMsg = new String(response, StandardCharsets.UTF_8);
@@ -49,10 +53,12 @@ public OAuthBearerClientInitialResponse(byte[] response) throws SaslException {
         String authzid = matcher.group("authzid");
         this.authorizationId = authzid == null ? "" : authzid;
         String kvPairs = matcher.group("kvpairs");
-        this.properties = Utils.parseMap(kvPairs, "=", SEPARATOR);
+        Map<String, String> properties = Utils.parseMap(kvPairs, "=", 
SEPARATOR);
         String auth = properties.get(AUTH_KEY);
         if (auth == null)
             throw new SaslException("Invalid OAUTHBEARER client first message: 
'auth' not specified");
+        properties.remove(AUTH_KEY);
+        this.saslExtensions = validateExtensions(new 
SaslExtensions(properties));
 
         Matcher authMatcher = AUTH_PATTERN.matcher(auth);
         if (!authMatcher.matches())
@@ -65,20 +71,29 @@ public OAuthBearerClientInitialResponse(byte[] response) throws SaslException {
         this.tokenValue = authMatcher.group("token");
     }
 
-    public OAuthBearerClientInitialResponse(String tokenValue) {
-        this(tokenValue, "", new HashMap<>());
+    /**
+     * Creates a client initial response that carries no authorization id.
+     *
+     * @param tokenValue the token value to send
+     * @param extensions the SASL extensions to send; validated by the delegated constructor
+     * @throws SaslException if the extensions fail validation
+     */
+    public OAuthBearerClientInitialResponse(String tokenValue, SaslExtensions 
extensions) throws SaslException {
+        this(tokenValue, "", extensions);
     }
 
-    public OAuthBearerClientInitialResponse(String tokenValue, String 
authorizationId, Map<String, String> props) {
+    /**
+     * Creates a client initial response for sending.
+     *
+     * @param tokenValue      the token value to send
+     * @param authorizationId the authorization id to send; {@code null} is treated as empty (no id)
+     * @param extensions      the SASL extensions to send
+     * @throws SaslException if {@link #validateExtensions(SaslExtensions)} rejects the extensions
+     */
+    public OAuthBearerClientInitialResponse(String tokenValue, String 
authorizationId, SaslExtensions extensions) throws SaslException {
         this.tokenValue = tokenValue;
         this.authorizationId = authorizationId == null ? "" : authorizationId;
-        this.properties = new HashMap<>(props);
+        // NOTE(review): a null extensions argument is dereferenced inside validateExtensions (NullPointerException).
+        this.saslExtensions = validateExtensions(extensions);
+    }
+
+    /**
+     * Returns the SASL extensions carried by this response. They never include the reserved {@code auth} pair:
+     * it is stripped when parsing and rejected by validation when constructing.
+     *
+     * @return the validated extensions
+     */
+    public SaslExtensions extensions() {
+        return saslExtensions;
     }
 
+    /**
+     * Serializes this response as a client initial response message, encoded as UTF-8. The message is
+     * {@code "n,"}, then {@code "a=<authorizationId>"} only when the id is non-empty, then {@code ","}, then
+     * {@code SEPARATOR + "auth=Bearer <token>"}, then {@code SEPARATOR} plus the extensions only when there are any,
+     * and finally two {@code SEPARATOR}s.
+     *
+     * @return the encoded message
+     */
     public byte[] toBytes() {
         String authzid = authorizationId.isEmpty() ? "" : "a=" + 
authorizationId;
-        String message = String.format("n,%s,%sauth=Bearer %s%s%s", authzid,
-                SEPARATOR, tokenValue, SEPARATOR, SEPARATOR);
+        String extensions = extensionsMessage();
+        // Extensions come after the auth pair, so a non-empty list needs its own leading separator.
+        if (extensions.length() > 0)
+            extensions = SEPARATOR + extensions;
+
+        String message = String.format("n,%s,%sauth=Bearer %s%s%s%s", authzid,
+                SEPARATOR, tokenValue, extensions, SEPARATOR, SEPARATOR);
+
         return message.getBytes(StandardCharsets.UTF_8);
     }
 
@@ -90,7 +105,32 @@ public String authorizationId() {
         return authorizationId;
     }
 
-    public String propertyValue(String name) {
-        return properties.get(name);
+    /**
+     * Validates that the given extensions conform to the standard. They should also not contain the reserved key name
+     * {@link OAuthBearerClientInitialResponse#AUTH_KEY}.
+     * <p>
+     * Every name must fully match {@link #EXTENSION_KEY_PATTERN} and every value {@link #EXTENSION_VALUE_PATTERN}.
+     *
+     * @param extensions the extensions to check; dereferenced immediately, so must not be {@code null}
+     * @return the same {@code extensions} instance, unchanged
+     * @throws SaslException on the reserved name, or on the first malformed name or value encountered
+     * @see <a href="https://tools.ietf.org/html/rfc7628#section-3.1">RFC 7628,
+     *  Section 3.1</a>
+     */
+    public static SaslExtensions validateExtensions(SaslExtensions extensions) 
throws SaslException {
+        if 
(extensions.map().containsKey(OAuthBearerClientInitialResponse.AUTH_KEY))
+            throw new SaslException("Extension name " + 
OAuthBearerClientInitialResponse.AUTH_KEY + " is invalid");
+
+        for (Map.Entry<String, String> entry : extensions.map().entrySet()) {
+            String extensionName = entry.getKey();
+            String extensionValue = entry.getValue();
+
+            if (!EXTENSION_KEY_PATTERN.matcher(extensionName).matches())
+                throw new SaslException("Extension name " + extensionName + " 
is invalid");
+            if (!EXTENSION_VALUE_PATTERN.matcher(extensionValue).matches())
+                throw new SaslException("Extension value (" + extensionValue + 
") for extension " + extensionName + " is invalid");
+        }
+        return extensions;
+    }
+
+    /**
+     * Converts the SASL extensions to an OAuth protocol-friendly string. Judging by the {@code Utils.mkString}
+     * arguments, this is presumably {@code name=value} pairs joined by {@code SEPARATOR} with no surrounding
+     * delimiters; {@code toBytes()} expects an empty string when there are no extensions. Verify against
+     * {@code Utils.mkString}.
+     *
+     * @return the encoded extensions
+     */
+    private String extensionsMessage() {
+        return Utils.mkString(saslExtensions.map(), "", "", "=", SEPARATOR);
     }
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClient.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClient.java
index 4d4ee57b3a8..16db3c8b382 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClient.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClient.java
@@ -30,6 +30,8 @@
 import javax.security.sasl.SaslException;
 
 import org.apache.kafka.common.errors.IllegalSaslStateException;
+import org.apache.kafka.common.security.auth.SaslExtensions;
+import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
 import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
@@ -42,7 +44,8 @@
  * implementation requires an instance of {@code AuthenticateCallbackHandler}
  * that can handle an instance of {@link OAuthBearerTokenCallback} and return
  * the {@link OAuthBearerToken} generated by the {@code login()} event on the
- * {@code LoginContext}.
+ * {@code LoginContext}. Said handler can also optionally handle an instance 
of {@link SaslExtensionsCallback}
+ * to return any extensions generated by the {@code login()} event on the 
{@code LoginContext}.
  *
  * @see <a href="https://tools.ietf.org/html/rfc6750#section-2.1">RFC 6750,
  *      Section 2.1</a>
@@ -87,8 +90,11 @@ public boolean hasInitialResponse() {
                     if (challenge != null && challenge.length != 0)
                         throw new SaslException("Expected empty challenge");
                     callbackHandler().handle(new Callback[] {callback});
+                    SaslExtensions extensions = retrieveCustomExtensions();
+
                     setState(State.RECEIVE_SERVER_FIRST_MESSAGE);
-                    return new 
OAuthBearerClientInitialResponse(callback.token().value()).toBytes();
+
+                    return new 
OAuthBearerClientInitialResponse(callback.token().value(), 
extensions).toBytes();
                 case RECEIVE_SERVER_FIRST_MESSAGE:
                     if (challenge != null && challenge.length != 0) {
                         String jsonErrorResponse = new String(challenge, 
StandardCharsets.UTF_8);
@@ -150,6 +156,20 @@ private void setState(State state) {
         this.state = state;
     }
 
+    /**
+     * Asks the callback handler for custom SASL extensions via a {@link SaslExtensionsCallback}.
+     * <p>
+     * NOTE(review): this returns {@code null} when the handler throws {@code UnsupportedCallbackException} or never
+     * sets extensions on the callback. The caller passes the result straight to
+     * {@code new OAuthBearerClientInitialResponse(String, SaslExtensions)}, whose validation dereferences it.
+     * Either confirm that a {@code null} result cannot occur in practice, or substitute an empty
+     * {@link SaslExtensions}.
+     *
+     * @return the extensions set by the handler, possibly {@code null}
+     * @throws SaslException wrapping any exception other than {@code UnsupportedCallbackException} thrown by the handler
+     */
+    private SaslExtensions retrieveCustomExtensions() throws SaslException {
+        SaslExtensionsCallback extensionsCallback = new 
SaslExtensionsCallback();
+        try {
+            callbackHandler().handle(new Callback[] {extensionsCallback});
+        } catch (UnsupportedCallbackException e) {
+            // Extensions are optional: log and carry on with whatever the callback holds (null).
+            log.debug("Extensions callback is not supported by client callback 
handler {}, no extensions will be added",
+                    callbackHandler());
+        } catch (Exception e) {
+            throw new SaslException("SASL extensions could not be obtained", 
e);
+        }
+
+        return extensionsCallback.extensions();
+    }
+
     public static class OAuthBearerSaslClientFactory implements 
SaslClientFactory {
         @Override
         public SaslClient createSaslClient(String[] mechanisms, String 
authorizationId, String protocol,
diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientCallbackHandler.java
index 586c5239165..ab2b7163256 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientCallbackHandler.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientCallbackHandler.java
@@ -28,7 +28,9 @@
 import javax.security.auth.callback.UnsupportedCallbackException;
 import javax.security.auth.login.AppConfigurationEntry;
 
+import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
 import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.auth.SaslExtensions;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
@@ -38,7 +40,9 @@
  * {@link OAuthBearerTokenCallback} and retrieves OAuth 2 Bearer Token that was
  * created when the {@code OAuthBearerLoginModule} logged in by looking for an
  * instance of {@link OAuthBearerToken} in the {@code Subject}'s private
- * credentials.
+ * credentials. This class also recognizes {@link SaslExtensionsCallback} and 
retrieves any SASL extensions that were
+ * created when the {@code OAuthBearerLoginModule} logged in by looking for an 
instance of {@link SaslExtensions}
+ * in the {@code Subject}'s public credentials
  * <p>
  * Use of this class is configured automatically and does not need to be
  * explicitly set via the {@code sasl.client.callback.handler.class}
@@ -70,6 +74,8 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback
         for (Callback callback : callbacks) {
             if (callback instanceof OAuthBearerTokenCallback)
                 handleCallback((OAuthBearerTokenCallback) callback);
+            else if (callback instanceof SaslExtensionsCallback)
+                handleCallback((SaslExtensionsCallback) callback, 
Subject.getSubject(AccessController.getContext()));
             else
                 throw new UnsupportedCallbackException(callback);
         }
@@ -93,4 +99,14 @@ private void handleCallback(OAuthBearerTokenCallback callback) throws IOExceptio
                             privateCredentials.size()));
         callback.token(privateCredentials.iterator().next());
     }
+
+    /**
+     * Sets on the callback the first {@link SaslExtensions} returned when iterating the public credentials of the
+     * Subject. When the Subject is {@code null} or holds no {@code SaslExtensions}, the callback is left untouched,
+     * so its extensions stay as they were ({@code null} for a fresh {@link SaslExtensionsCallback}).
+     *
+     * @param extensionsCallback the callback to populate
+     * @param subject            the Subject to read the extensions from; may be {@code null}
+     */
+    private static void handleCallback(SaslExtensionsCallback 
extensionsCallback, Subject subject) {
+        if (subject != null && 
!subject.getPublicCredentials(SaslExtensions.class).isEmpty()) {
+            SaslExtensions extensions = 
subject.getPublicCredentials(SaslExtensions.class).iterator().next();
+            extensionsCallback.extensions(extensions);
+        }
+    }
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServer.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServer.java
index aacc8fa3cbb..6573f695307 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServer.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServer.java
@@ -31,6 +31,7 @@
 import javax.security.sasl.SaslServerFactory;
 
 import org.apache.kafka.common.errors.SaslAuthenticationException;
+import org.apache.kafka.common.security.auth.SaslExtensions;
 import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
@@ -46,6 +47,7 @@
  * for example).
  */
 public class OAuthBearerSaslServer implements SaslServer {
+
     private static final Logger log = 
LoggerFactory.getLogger(OAuthBearerSaslServer.class);
     private static final String NEGOTIATED_PROPERTY_KEY_TOKEN = 
OAuthBearerLoginModule.OAUTHBEARER_MECHANISM + ".token";
     private static final String INTERNAL_ERROR_ON_SERVER = "Authentication 
could not be performed due to an internal error on the server";
@@ -55,6 +57,7 @@
     private boolean complete;
     private OAuthBearerToken tokenForNegotiatedProperty = null;
     private String errorMessage = null;
+    private SaslExtensions extensions;
 
     public OAuthBearerSaslServer(CallbackHandler callbackHandler) {
         if (!(Objects.requireNonNull(callbackHandler) instanceof 
AuthenticateCallbackHandler))
@@ -84,6 +87,7 @@ public OAuthBearerSaslServer(CallbackHandler callbackHandler) {
             throw new SaslAuthenticationException(errorMessage);
         }
         errorMessage = null;
+
         OAuthBearerClientInitialResponse clientResponse;
         try {
             clientResponse = new OAuthBearerClientInitialResponse(response);
@@ -91,7 +95,8 @@ public OAuthBearerSaslServer(CallbackHandler callbackHandler) {
             log.debug(e.getMessage());
             throw e;
         }
-        return process(clientResponse.tokenValue(), 
clientResponse.authorizationId());
+
+        return process(clientResponse.tokenValue(), 
clientResponse.authorizationId(), clientResponse.extensions());
     }
 
     @Override
@@ -110,7 +115,10 @@ public String getMechanismName() {
     public Object getNegotiatedProperty(String propName) {
         if (!complete)
             throw new IllegalStateException("Authentication exchange has not 
completed");
-        return NEGOTIATED_PROPERTY_KEY_TOKEN.equals(propName) ? 
tokenForNegotiatedProperty : null;
+        if (NEGOTIATED_PROPERTY_KEY_TOKEN.equals(propName))
+            return tokenForNegotiatedProperty;
+
+        // Any other name is looked up among the client's SASL extensions (null if absent). Extension names are
+        // validated against [A-Za-z]+, so none can collide with the "OAUTHBEARER.token" key handled above.
+        return extensions.map().get(propName);
     }
 
     @Override
@@ -136,9 +144,10 @@ public boolean isComplete() {
     public void dispose() throws SaslException {
         complete = false;
         tokenForNegotiatedProperty = null;
+        // Also drop the extensions from the finished exchange. getNegotiatedProperty() throws while complete is
+        // false, so it never reads this null.
+        extensions = null;
     }
 
-    private byte[] process(String tokenValue, String authorizationId) throws 
SaslException {
+    private byte[] process(String tokenValue, String authorizationId, 
SaslExtensions extensions) throws SaslException {
         OAuthBearerValidatorCallback callback = new 
OAuthBearerValidatorCallback(tokenValue);
         try {
             callbackHandler.handle(new Callback[] {callback});
@@ -165,6 +174,7 @@ public void dispose() throws SaslException {
                     "Authentication failed: Client requested an authorization 
id (%s) that is different from the token's principal name (%s)",
                     authorizationId, token.principalName()));
         tokenForNegotiatedProperty = token;
+        this.extensions = extensions;
         complete = true;
         if (log.isDebugEnabled())
             log.debug("Successfully authenticate User={}", 
token.principalName());
diff --git 
a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandler.java
 
b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandler.java
index 67a75ae9e1d..88399ace85c 100644
--- 
a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandler.java
+++ 
b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandler.java
@@ -22,6 +22,7 @@
 import java.util.Base64;
 import java.util.Base64.Encoder;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -31,18 +32,23 @@
 import javax.security.auth.callback.Callback;
 import javax.security.auth.callback.UnsupportedCallbackException;
 import javax.security.auth.login.AppConfigurationEntry;
+import javax.security.sasl.SaslException;
 
 import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.config.ConfigException;
 import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
+import org.apache.kafka.common.security.auth.SaslExtensions;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
+import 
org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse;
 import org.apache.kafka.common.utils.Time;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
  * A {@code CallbackHandler} that recognizes {@link OAuthBearerTokenCallback}
- * and returns an unsecured OAuth 2 bearer token.
+ * to return an unsecured OAuth 2 bearer token and {@link 
SaslExtensionsCallback} to return SASL extensions
  * <p>
  * Claims and their values on the returned token can be specified using
  * {@code unsecuredLoginStringClaim_<claimname>},
@@ -52,6 +58,11 @@
  * name and value except '{@code iat}' and '{@code exp}', both of which are
  * calculated automatically.
  * <p>
+ * <p>
+ * You can also add custom unsecured SASL extensions using
+ * {@code unsecuredLoginExtension_<extensionname>}. Extension keys and values 
are subject to regex validation.
+ * The extension key must also not be equal to the reserved key {@link 
OAuthBearerClientInitialResponse#AUTH_KEY}
+ * <p>
  * This implementation also accepts the following options:
  * <ul>
  * <li>{@code unsecuredLoginPrincipalClaimName} set to a custom claim name if
@@ -72,7 +83,8 @@
  *      org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule 
Required
  *      unsecuredLoginStringClaim_sub="thePrincipalName"
  *      unsecuredLoginListClaim_scope="|scopeValue1|scopeValue2"
- *      unsecuredLoginLifetimeSeconds="60";
+ *      unsecuredLoginLifetimeSeconds="60"
+ *      unsecuredLoginExtension_traceId="123";
  * };
  * </pre>
  * 
@@ -96,6 +108,7 @@
     private static final String STRING_CLAIM_PREFIX = OPTION_PREFIX + 
"StringClaim_";
     private static final String NUMBER_CLAIM_PREFIX = OPTION_PREFIX + 
"NumberClaim_";
     private static final String LIST_CLAIM_PREFIX = OPTION_PREFIX + 
"ListClaim_";
+    private static final String EXTENSION_PREFIX = OPTION_PREFIX + 
"Extension_";
     private static final String QUOTE = "\"";
     private Time time = Time.SYSTEM;
     private Map<String, String> moduleOptions = null;
@@ -140,7 +153,13 @@ public void handle(Callback[] callbacks) throws 
IOException, UnsupportedCallback
         for (Callback callback : callbacks) {
             if (callback instanceof OAuthBearerTokenCallback)
                 try {
-                    handleCallback((OAuthBearerTokenCallback) callback);
+                    handleTokenCallback((OAuthBearerTokenCallback) callback);
+                } catch (KafkaException e) {
+                    throw new IOException(e.getMessage(), e);
+                }
+            else if (callback instanceof SaslExtensionsCallback)
+                try {
+                    handleExtensionsCallback((SaslExtensionsCallback) 
callback);
                 } catch (KafkaException e) {
                     throw new IOException(e.getMessage(), e);
                 }
@@ -154,7 +173,7 @@ public void close() {
         // empty
     }
 
-    private void handleCallback(OAuthBearerTokenCallback callback) throws 
IOException {
+    private void handleTokenCallback(OAuthBearerTokenCallback callback) throws 
IOException {
         if (callback.token() != null)
             throw new IllegalArgumentException("Callback had a token already");
         String principalClaimNameValue = 
optionValue(PRINCIPAL_CLAIM_NAME_OPTION);
@@ -190,6 +209,30 @@ private void handleCallback(OAuthBearerTokenCallback 
callback) throws IOExceptio
         }
     }
 
+    /**
+     *  Add and validate all the configured extensions.
+     *  Extension keys, apart from passing regex validation, must not be equal to 
the reserved key {@link OAuthBearerClientInitialResponse#AUTH_KEY}.
+     */
+    private void handleExtensionsCallback(SaslExtensionsCallback callback) {
+        Map<String, String> extensions = new HashMap<>();
+        for (Map.Entry<String, String> configEntry : 
this.moduleOptions.entrySet()) {
+            String key = configEntry.getKey();
+            if (!key.startsWith(EXTENSION_PREFIX))
+                continue;
+
+            extensions.put(key.substring(EXTENSION_PREFIX.length()), 
configEntry.getValue());
+        }
+
+        SaslExtensions saslExtensions = new SaslExtensions(extensions);
+        try {
+            
OAuthBearerClientInitialResponse.validateExtensions(saslExtensions);
+        } catch (SaslException e) {
+            throw new ConfigException(e.getMessage());
+        }
+
+        callback.extensions(saslExtensions);
+    }
+
     private String commaPrependedStringNumberAndListClaimsJsonText() throws 
OAuthBearerConfigException {
         StringBuilder sb = new StringBuilder();
         for (String key : moduleOptions.keySet()) {
diff --git 
a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java
 
b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java
index debe163e36b..b83c94e04bc 100644
--- 
a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java
+++ 
b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java
@@ -14,13 +14,13 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.kafka.common.security.scram;
 
 import javax.security.auth.callback.Callback;
 import java.util.Collections;
 import java.util.Map;
 
+
 /**
  * Optional callback used for SCRAM mechanisms if any extensions need to be set
  * in the SASL/SCRAM exchange.
@@ -29,18 +29,18 @@
     private Map<String, String> extensions = Collections.emptyMap();
 
     /**
-     * Returns the extension names and values that are sent by the client to
+     * Returns a map of the extension names and values that are sent by the 
client to
      * the server in the initial client SCRAM authentication message.
-     * Default is an empty map.
+     * Default is an empty unmodifiable map.
      */
     public Map<String, String> extensions() {
         return extensions;
     }
 
     /**
-     * Sets the SCRAM extensions on this callback.
+     * Sets the SCRAM extensions on this callback. Maps passed in should be 
unmodifiable.
      */
     public void extensions(Map<String, String> extensions) {
         this.extensions = extensions;
     }
-}
\ No newline at end of file
+}
diff --git 
a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramExtensions.java
 
b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramExtensions.java
index 5028329feb1..7b518908abe 100644
--- 
a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramExtensions.java
+++ 
b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramExtensions.java
@@ -16,15 +16,14 @@
  */
 package org.apache.kafka.common.security.scram.internals;
 
+import org.apache.kafka.common.security.auth.SaslExtensions;
 import org.apache.kafka.common.security.scram.ScramLoginModule;
 import org.apache.kafka.common.utils.Utils;
 
 import java.util.Collections;
 import java.util.Map;
-import java.util.Set;
 
-public class ScramExtensions {
-    private final Map<String, String> extensionMap;
+public class ScramExtensions extends SaslExtensions {
 
     public ScramExtensions() {
         this(Collections.<String, String>emptyMap());
@@ -35,23 +34,10 @@ public ScramExtensions(String extensions) {
     }
 
     public ScramExtensions(Map<String, String> extensionMap) {
-        this.extensionMap = extensionMap;
-    }
-
-    public String extensionValue(String name) {
-        return extensionMap.get(name);
-    }
-
-    public Set<String> extensionNames() {
-        return extensionMap.keySet();
+        super(extensionMap);
     }
 
     public boolean tokenAuthenticated() {
-        return 
Boolean.parseBoolean(extensionMap.get(ScramLoginModule.TOKEN_AUTH_CONFIG));
-    }
-
-    @Override
-    public String toString() {
-        return Utils.mkString(extensionMap, "", "", "=", ",");
+        return 
Boolean.parseBoolean(map().get(ScramLoginModule.TOKEN_AUTH_CONFIG));
     }
-}
\ No newline at end of file
+}
diff --git 
a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramMessages.java
 
b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramMessages.java
index b56d7592661..05512962906 100644
--- 
a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramMessages.java
+++ 
b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramMessages.java
@@ -16,6 +16,8 @@
  */
 package org.apache.kafka.common.security.scram.internals;
 
+import org.apache.kafka.common.utils.Utils;
+
 import java.nio.charset.StandardCharsets;
 import java.util.Base64;
 import java.util.Map;
@@ -112,7 +114,8 @@ public ScramExtensions extensions() {
         }
 
         public String clientFirstMessageBare() {
-            String extensionStr = extensions.toString();
+            String extensionStr = Utils.mkString(extensions.map(), "", "", 
"=", ",");
+
             if (extensionStr.isEmpty())
                 return String.format("n=%s,r=%s", saslName, nonce);
             else
diff --git 
a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java
 
b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java
index d464e895b97..b11300abc21 100644
--- 
a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java
+++ 
b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java
@@ -98,9 +98,9 @@ public ScramSaslServer(ScramMechanism mechanism, Map<String, 
?> props, CallbackH
                 case RECEIVE_CLIENT_FIRST_MESSAGE:
                     this.clientFirstMessage = new ClientFirstMessage(response);
                     this.scramExtensions = clientFirstMessage.extensions();
-                    if 
(!SUPPORTED_EXTENSIONS.containsAll(scramExtensions.extensionNames())) {
+                    if 
(!SUPPORTED_EXTENSIONS.containsAll(scramExtensions.map().keySet())) {
                         log.debug("Unsupported extensions will be ignored, 
supported {}, provided {}",
-                                SUPPORTED_EXTENSIONS, 
scramExtensions.extensionNames());
+                                SUPPORTED_EXTENSIONS, 
scramExtensions.map().keySet());
                     }
                     String serverNonce = formatter.secureRandomString();
                     try {
@@ -183,7 +183,7 @@ public Object getNegotiatedProperty(String propName) {
             throw new IllegalStateException("Authentication exchange has not 
completed");
 
         if (SUPPORTED_EXTENSIONS.contains(propName))
-            return scramExtensions.extensionValue(propName);
+            return scramExtensions.map().get(propName);
         else
             return null;
     }
diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java 
b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
index 07f91a73f3d..6e0b693cc1b 100755
--- a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
+++ b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
@@ -498,6 +498,12 @@ public static String formatBytes(long bytes) {
         return sb.toString();
     }
 
+    /**
+     *  Converts a {@code Map} into a string by concatenating its keys and 
values with the given separators.
+     *  Example:
+     *      {@code mkString({ key: "hello", keyTwo: "hi" }, "|START|", 
"|END|", "=", ",")
+     *          => "|START|key=hello,keyTwo=hi|END|"}
+     */
     public static <K, V> String mkString(Map<K, V> map, String begin, String 
end,
                                          String keyValueSeparator, String 
elementSeparator) {
         StringBuilder bld = new StringBuilder();
@@ -512,6 +518,13 @@ public static String formatBytes(long bytes) {
         return bld.toString();
     }
 
+    /**
+     *  Parses a string of separated key-value pairs into a {@code Map<String, String>}.
+     *
+     *  Example:
+     *      {@code parseMap("key=hey,keyTwo=hi,keyThree=hello", "=", ",") => { 
key: "hey", keyTwo: "hi", keyThree: "hello" }}
+     *
+     */
     public static Map<String, String> parseMap(String mapStr, String 
keyValueSeparator, String elementSeparator) {
         Map<String, String> map = new HashMap<>();
 
diff --git 
a/clients/src/test/java/org/apache/kafka/common/security/SaslExtensionsTest.java
 
b/clients/src/test/java/org/apache/kafka/common/security/SaslExtensionsTest.java
new file mode 100644
index 00000000000..77a45235ea5
--- /dev/null
+++ 
b/clients/src/test/java/org/apache/kafka/common/security/SaslExtensionsTest.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.security;
+
+import org.apache.kafka.common.security.auth.SaslExtensions;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.junit.Assert.assertNull;
+
+public class SaslExtensionsTest {
+    Map<String, String> map;
+
+    @Before
+    public void setUp() {
+        this.map = new HashMap<>();
+        this.map.put("what", "42");
+        this.map.put("who", "me");
+    }
+
+    @Test(expected = UnsupportedOperationException.class)
+    public void testReturnedMapIsImmutable() {
+        SaslExtensions extensions = new SaslExtensions(this.map);
+        extensions.map().put("hello", "test");
+    }
+
+    @Test
+    public void testCannotAddValueToMapReferenceAndGetFromExtensions() {
+        SaslExtensions extensions = new SaslExtensions(this.map);
+
+        assertNull(extensions.map().get("hello"));
+        this.map.put("hello", "42");
+        assertNull(extensions.map().get("hello"));
+    }
+}
diff --git 
a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModuleTest.java
 
b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModuleTest.java
index d883e5e9dcc..a9620fa6937 100644
--- 
a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModuleTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModuleTest.java
@@ -19,6 +19,8 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotSame;
 import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.assertNotNull;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -36,16 +38,24 @@
 
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
+import org.apache.kafka.common.security.auth.SaslExtensions;
 import org.easymock.EasyMock;
 import org.junit.Test;
 
 public class OAuthBearerLoginModuleTest {
-    private static class TestTokenCallbackHandler implements 
AuthenticateCallbackHandler {
+
+    public static final SaslExtensions RAISE_UNSUPPORTED_CB_EXCEPTION_FLAG = 
null;
+
+    private static class TestCallbackHandler implements 
AuthenticateCallbackHandler {
         private final OAuthBearerToken[] tokens;
         private int index = 0;
+        private int extensionsIndex = 0;
+        private final SaslExtensions[] extensions;
 
-        public TestTokenCallbackHandler(OAuthBearerToken[] tokens) {
+        public TestCallbackHandler(OAuthBearerToken[] tokens, SaslExtensions[] 
extensions) {
             this.tokens = Objects.requireNonNull(tokens);
+            this.extensions = extensions;
         }
 
         @Override
@@ -57,7 +67,13 @@ public void handle(Callback[] callbacks) throws IOException, 
UnsupportedCallback
                     } catch (KafkaException e) {
                         throw new IOException(e.getMessage(), e);
                     }
-                else
+                else if (callback instanceof SaslExtensionsCallback) {
+                    try {
+                        handleExtensionsCallback((SaslExtensionsCallback) 
callback);
+                    } catch (KafkaException e) {
+                        throw new IOException(e.getMessage(), e);
+                    }
+                } else
                     throw new UnsupportedCallbackException(callback);
             }
         }
@@ -81,6 +97,19 @@ private void handleCallback(OAuthBearerTokenCallback 
callback) throws IOExceptio
             else
                 throw new IOException("no more tokens");
         }
+
+        private void handleExtensionsCallback(SaslExtensionsCallback callback) 
throws IOException, UnsupportedCallbackException {
+            if (extensions.length > extensionsIndex) {
+                SaslExtensions extension = extensions[extensionsIndex++];
+
+                if (extension == RAISE_UNSUPPORTED_CB_EXCEPTION_FLAG) {
+                    throw new UnsupportedCallbackException(callback);
+                }
+
+                callback.extensions(extension);
+            } else
+                throw new IOException("no more extensions");
+        }
     }
 
     @Test
@@ -92,12 +121,16 @@ public void 
login1Commit1Login2Commit2Logout1Login3Commit3Logout2() throws Login
          */
         Subject subject = new Subject();
         Set<Object> privateCredentials = subject.getPrivateCredentials();
+        Set<Object> publicCredentials = subject.getPublicCredentials();
 
         // Create callback handler
         OAuthBearerToken[] tokens = new OAuthBearerToken[] 
{EasyMock.mock(OAuthBearerToken.class),
             EasyMock.mock(OAuthBearerToken.class), 
EasyMock.mock(OAuthBearerToken.class)};
+        SaslExtensions[] extensions = new SaslExtensions[] 
{EasyMock.mock(SaslExtensions.class),
+            EasyMock.mock(SaslExtensions.class), 
EasyMock.mock(SaslExtensions.class)};
         EasyMock.replay(tokens[0], tokens[1], tokens[2]); // expect nothing
-        TestTokenCallbackHandler testTokenCallbackHandler = new 
TestTokenCallbackHandler(tokens);
+        EasyMock.replay(extensions[0], extensions[2]);
+        TestCallbackHandler testTokenCallbackHandler = new 
TestCallbackHandler(tokens, extensions);
 
         // Create login modules
         OAuthBearerLoginModule loginModule1 = new OAuthBearerLoginModule();
@@ -112,47 +145,68 @@ public void 
login1Commit1Login2Commit2Logout1Login3Commit3Logout2() throws Login
 
         // Should start with nothing
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
         loginModule1.login();
         // Should still have nothing until commit() is called
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
         loginModule1.commit();
-        // Now we should have the first token
+        // Now we should have the first token and extensions
         assertEquals(1, privateCredentials.size());
+        assertEquals(1, publicCredentials.size());
         assertSame(tokens[0], privateCredentials.iterator().next());
+        assertSame(extensions[0], publicCredentials.iterator().next());
 
         // Now login on loginModule2 to get the second token
+        // loginModule2 receives the second (non-null) extensions from the 
callback handler
         loginModule2.login();
-        // Should still have just the first token
+        // Should still have just the first token and extensions
         assertEquals(1, privateCredentials.size());
+        assertEquals(1, publicCredentials.size());
         assertSame(tokens[0], privateCredentials.iterator().next());
+        assertSame(extensions[0], publicCredentials.iterator().next());
         loginModule2.commit();
         // Should have the first and second tokens at this point
         assertEquals(2, privateCredentials.size());
+        assertEquals(2, publicCredentials.size());
         Iterator<Object> iterator = privateCredentials.iterator();
+        Iterator<Object> publicIterator = publicCredentials.iterator();
         assertNotSame(tokens[2], iterator.next());
         assertNotSame(tokens[2], iterator.next());
+        assertNotSame(extensions[2], publicIterator.next());
+        assertNotSame(extensions[2], publicIterator.next());
         // finally logout() on loginModule1
         loginModule1.logout();
-        // Now we should have just the second token
+        // Now we should have just the second token and extension
         assertEquals(1, privateCredentials.size());
+        assertEquals(1, publicCredentials.size());
         assertSame(tokens[1], privateCredentials.iterator().next());
+        assertSame(extensions[1], publicCredentials.iterator().next());
 
         // Now login on loginModule3 to get the third token
         loginModule3.login();
-        // Should still have just the second token
+        // Should still have just the second token and extensions
         assertEquals(1, privateCredentials.size());
+        assertEquals(1, publicCredentials.size());
         assertSame(tokens[1], privateCredentials.iterator().next());
+        assertSame(extensions[1], publicCredentials.iterator().next());
         loginModule3.commit();
         // Should have the second and third tokens at this point
         assertEquals(2, privateCredentials.size());
+        assertEquals(2, publicCredentials.size());
         iterator = privateCredentials.iterator();
+        publicIterator = publicCredentials.iterator();
         assertNotSame(tokens[0], iterator.next());
         assertNotSame(tokens[0], iterator.next());
+        assertNotSame(extensions[0], publicIterator.next());
+        assertNotSame(extensions[0], publicIterator.next());
         // finally logout() on loginModule2
         loginModule2.logout();
         // Now we should have just the third token
         assertEquals(1, privateCredentials.size());
+        assertEquals(1, publicCredentials.size());
         assertSame(tokens[2], privateCredentials.iterator().next());
+        assertSame(extensions[2], publicCredentials.iterator().next());
     }
 
     @Test
@@ -163,12 +217,16 @@ public void login1Commit1Logout1Login2Commit2Logout2() 
throws LoginException {
          */
         Subject subject = new Subject();
         Set<Object> privateCredentials = subject.getPrivateCredentials();
+        Set<Object> publicCredentials = subject.getPublicCredentials();
 
         // Create callback handler
         OAuthBearerToken[] tokens = new OAuthBearerToken[] 
{EasyMock.mock(OAuthBearerToken.class),
             EasyMock.mock(OAuthBearerToken.class)};
+        SaslExtensions[] extensions = new SaslExtensions[] 
{EasyMock.mock(SaslExtensions.class),
+            EasyMock.mock(SaslExtensions.class)};
         EasyMock.replay(tokens[0], tokens[1]); // expect nothing
-        TestTokenCallbackHandler testTokenCallbackHandler = new 
TestTokenCallbackHandler(tokens);
+        EasyMock.replay(extensions[0], extensions[1]);
+        TestCallbackHandler testTokenCallbackHandler = new 
TestCallbackHandler(tokens, extensions);
 
         // Create login modules
         OAuthBearerLoginModule loginModule1 = new OAuthBearerLoginModule();
@@ -180,27 +238,36 @@ public void login1Commit1Logout1Login2Commit2Logout2() 
throws LoginException {
 
         // Should start with nothing
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
         loginModule1.login();
         // Should still have nothing until commit() is called
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
         loginModule1.commit();
         // Now we should have the first token
         assertEquals(1, privateCredentials.size());
+        assertEquals(1, publicCredentials.size());
         assertSame(tokens[0], privateCredentials.iterator().next());
+        assertSame(extensions[0], publicCredentials.iterator().next());
         loginModule1.logout();
         // Should have nothing again
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
 
         loginModule2.login();
         // Should still have nothing until commit() is called
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
         loginModule2.commit();
         // Now we should have the second token
         assertEquals(1, privateCredentials.size());
+        assertEquals(1, publicCredentials.size());
         assertSame(tokens[1], privateCredentials.iterator().next());
+        assertSame(extensions[1], publicCredentials.iterator().next());
         loginModule2.logout();
         // Should have nothing again
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
     }
 
     @Test
@@ -210,12 +277,16 @@ public void loginAbortLoginCommitLogout() throws 
LoginException {
          */
         Subject subject = new Subject();
         Set<Object> privateCredentials = subject.getPrivateCredentials();
+        Set<Object> publicCredentials = subject.getPublicCredentials();
 
         // Create callback handler
         OAuthBearerToken[] tokens = new OAuthBearerToken[] 
{EasyMock.mock(OAuthBearerToken.class),
             EasyMock.mock(OAuthBearerToken.class)};
+        SaslExtensions[] extensions = new SaslExtensions[] 
{EasyMock.mock(SaslExtensions.class),
+            EasyMock.mock(SaslExtensions.class)};
         EasyMock.replay(tokens[0], tokens[1]); // expect nothing
-        TestTokenCallbackHandler testTokenCallbackHandler = new 
TestTokenCallbackHandler(tokens);
+        EasyMock.replay(extensions[0], extensions[1]);
+        TestCallbackHandler testTokenCallbackHandler = new 
TestCallbackHandler(tokens, extensions);
 
         // Create login module
         OAuthBearerLoginModule loginModule = new OAuthBearerLoginModule();
@@ -224,23 +295,30 @@ public void loginAbortLoginCommitLogout() throws 
LoginException {
 
         // Should start with nothing
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
         loginModule.login();
         // Should still have nothing until commit() is called
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
         loginModule.abort();
         // Should still have nothing since we aborted
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
 
         loginModule.login();
         // Should still have nothing until commit() is called
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
         loginModule.commit();
         // Now we should have the second token
         assertEquals(1, privateCredentials.size());
+        assertEquals(1, publicCredentials.size());
         assertSame(tokens[1], privateCredentials.iterator().next());
+        assertSame(extensions[1], publicCredentials.iterator().next());
         loginModule.logout();
         // Should have nothing again
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
     }
 
     @Test
@@ -251,12 +329,16 @@ public void 
login1Commit1Login2Abort2Login3Commit3Logout3() throws LoginExceptio
          */
         Subject subject = new Subject();
         Set<Object> privateCredentials = subject.getPrivateCredentials();
+        Set<Object> publicCredentials = subject.getPublicCredentials();
 
         // Create callback handler
         OAuthBearerToken[] tokens = new OAuthBearerToken[] 
{EasyMock.mock(OAuthBearerToken.class),
             EasyMock.mock(OAuthBearerToken.class), 
EasyMock.mock(OAuthBearerToken.class)};
+        SaslExtensions[] extensions = new SaslExtensions[] 
{EasyMock.mock(SaslExtensions.class),
+            EasyMock.mock(SaslExtensions.class), 
EasyMock.mock(SaslExtensions.class)};
         EasyMock.replay(tokens[0], tokens[1], tokens[2]); // expect nothing
-        TestTokenCallbackHandler testTokenCallbackHandler = new 
TestTokenCallbackHandler(tokens);
+        EasyMock.replay(extensions[0], extensions[1], extensions[2]);
+        TestCallbackHandler testTokenCallbackHandler = new 
TestCallbackHandler(tokens, extensions);
 
         // Create login modules
         OAuthBearerLoginModule loginModule1 = new OAuthBearerLoginModule();
@@ -271,38 +353,81 @@ public void 
login1Commit1Login2Abort2Login3Commit3Logout3() throws LoginExceptio
 
         // Should start with nothing
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
         loginModule1.login();
         // Should still have nothing until commit() is called
         assertEquals(0, privateCredentials.size());
+        assertEquals(0, publicCredentials.size());
         loginModule1.commit();
         // Now we should have the first token
         assertEquals(1, privateCredentials.size());
+        assertEquals(1, publicCredentials.size());
         assertSame(tokens[0], privateCredentials.iterator().next());
+        assertSame(extensions[0], publicCredentials.iterator().next());
 
         // Now go get the second token
         loginModule2.login();
         // Should still have first token
         assertEquals(1, privateCredentials.size());
+        assertEquals(1, publicCredentials.size());
         assertSame(tokens[0], privateCredentials.iterator().next());
+        assertSame(extensions[0], publicCredentials.iterator().next());
         loginModule2.abort();
         // Should still have just the first token because we aborted
         assertEquals(1, privateCredentials.size());
         assertSame(tokens[0], privateCredentials.iterator().next());
+        assertEquals(1, publicCredentials.size());
+        assertSame(extensions[0], publicCredentials.iterator().next());
 
         // Now go get the third token
         loginModule2.login();
         // Should still have first token
         assertEquals(1, privateCredentials.size());
         assertSame(tokens[0], privateCredentials.iterator().next());
+        assertEquals(1, publicCredentials.size());
+        assertSame(extensions[0], publicCredentials.iterator().next());
         loginModule2.commit();
         // Should have first and third tokens at this point
         assertEquals(2, privateCredentials.size());
         Iterator<Object> iterator = privateCredentials.iterator();
         assertNotSame(tokens[1], iterator.next());
         assertNotSame(tokens[1], iterator.next());
+        assertEquals(2, publicCredentials.size());
+        Iterator<Object> publicIterator = publicCredentials.iterator();
+        assertNotSame(extensions[1], publicIterator.next());
+        assertNotSame(extensions[1], publicIterator.next());
         loginModule1.logout();
         // Now we should have just the third token
         assertEquals(1, privateCredentials.size());
         assertSame(tokens[2], privateCredentials.iterator().next());
+        assertEquals(1, publicCredentials.size());
+        assertSame(extensions[2], publicCredentials.iterator().next());
+    }
+
+    /**
+     * 2.1.0 added customizable SASL extensions and a new callback type.
+     * Ensure that old, custom-written callbackHandlers that do not handle the 
callback work
+     */
+    @Test
+    public void commitDoesNotThrowOnUnsupportedExtensionsCallback() throws 
LoginException {
+        Subject subject = new Subject();
+
+        // Create callback handler
+        OAuthBearerToken[] tokens = new OAuthBearerToken[] 
{EasyMock.mock(OAuthBearerToken.class),
+                EasyMock.mock(OAuthBearerToken.class), 
EasyMock.mock(OAuthBearerToken.class)};
+        EasyMock.replay(tokens[0], tokens[1], tokens[2]); // expect nothing
+        TestCallbackHandler testTokenCallbackHandler = new 
TestCallbackHandler(tokens, new SaslExtensions[] 
{RAISE_UNSUPPORTED_CB_EXCEPTION_FLAG});
+
+        // Create login modules
+        OAuthBearerLoginModule loginModule1 = new OAuthBearerLoginModule();
+        loginModule1.initialize(subject, testTokenCallbackHandler, 
Collections.emptyMap(),
+                Collections.emptyMap());
+
+        loginModule1.login();
+        // Should populate public credentials with SaslExtensions and not 
throw an exception
+        loginModule1.commit();
+        SaslExtensions extensions = 
subject.getPublicCredentials(SaslExtensions.class).iterator().next();
+        assertNotNull(extensions);
+        assertTrue(extensions.map().isEmpty());
     }
 }
diff --git 
a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponseTest.java
 
b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponseTest.java
index eccf2dd2ed4..3de6408accd 100644
--- 
a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponseTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponseTest.java
@@ -18,12 +18,49 @@
 
 import static org.junit.Assert.assertEquals;
 
+import org.apache.kafka.common.security.auth.SaslExtensions;
 import org.junit.Test;
 
+import javax.security.sasl.SaslException;
 import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+import java.util.Map;
 
 public class OAuthBearerClientInitialResponseTest {
 
+    /*
+        Test how a client would build a response
+     */
+    @Test
+    public void testBuildClientResponseToBytes() throws Exception {
+        String expectedMesssage = "n,,\u0001auth=Bearer 
123.345.567\u0001nineteen=42\u0001\u0001";
+
+        Map<String, String> extensions = new HashMap<>();
+        extensions.put("nineteen", "42");
+        OAuthBearerClientInitialResponse response = new 
OAuthBearerClientInitialResponse("123.345.567", new SaslExtensions(extensions));
+
+        String message = new String(response.toBytes(), 
StandardCharsets.UTF_8);
+
+        assertEquals(expectedMesssage, message);
+    }
+
+    @Test
+    public void testBuildServerResponseToBytes() throws Exception {
+        String serverMessage = "n,,\u0001auth=Bearer 
123.345.567\u0001nineteen=42\u0001\u0001";
+        OAuthBearerClientInitialResponse response = new 
OAuthBearerClientInitialResponse(serverMessage.getBytes(StandardCharsets.UTF_8));
+
+        String message = new String(response.toBytes(), 
StandardCharsets.UTF_8);
+
+        assertEquals(serverMessage, message);
+    }
+
+    @Test(expected = SaslException.class)
+    public void testThrowsSaslExceptionOnInvalidExtensionKey() throws 
Exception {
+        Map<String, String> extensions = new HashMap<>();
+        extensions.put("19", "42"); // keys can only be a-z
+        new OAuthBearerClientInitialResponse("123.345.567", new 
SaslExtensions(extensions));
+    }
+
     @Test
     public void testToken() throws Exception {
         String message = "n,,\u0001auth=Bearer 123.345.567\u0001\u0001";
@@ -41,13 +78,13 @@ public void testAuthorizationId() throws Exception {
     }
 
     @Test
-    public void testProperties() throws Exception {
+    public void testExtensions() throws Exception {
         String message = "n,,\u0001propA=valueA1, valueA2\u0001auth=Bearer 
567\u0001propB=valueB\u0001\u0001";
         OAuthBearerClientInitialResponse response = new 
OAuthBearerClientInitialResponse(message.getBytes(StandardCharsets.UTF_8));
         assertEquals("567", response.tokenValue());
         assertEquals("", response.authorizationId());
-        assertEquals("valueA1, valueA2", response.propertyValue("propA"));
-        assertEquals("valueB", response.propertyValue("propB"));
+        assertEquals("valueA1, valueA2", 
response.extensions().map().get("propA"));
+        assertEquals("valueB", response.extensions().map().get("propB"));
     }
 
     // The example in the RFC uses 
`vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg==` as the token
@@ -59,7 +96,7 @@ public void testRfc7688Example() throws Exception {
         OAuthBearerClientInitialResponse response = new 
OAuthBearerClientInitialResponse(message.getBytes(StandardCharsets.UTF_8));
         assertEquals("vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg", 
response.tokenValue());
         assertEquals("u...@example.com", response.authorizationId());
-        assertEquals("server.example.com", response.propertyValue("host"));
-        assertEquals("143", response.propertyValue("port"));
+        assertEquals("server.example.com", 
response.extensions().map().get("host"));
+        assertEquals("143", response.extensions().map().get("port"));
     }
 }
diff --git 
a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientTest.java
 
b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientTest.java
new file mode 100644
index 00000000000..55a86245da4
--- /dev/null
+++ 
b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientTest.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.security.oauthbearer.internals;
+
+import org.apache.kafka.common.config.ConfigException;
+import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.auth.SaslExtensions;
+import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
+import 
org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerUnsecuredJws;
+import org.easymock.EasyMockSupport;
+import org.junit.Test;
+
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.auth.login.AppConfigurationEntry;
+import javax.security.sasl.SaslException;
+import java.nio.charset.StandardCharsets;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+public class OAuthBearerSaslClientTest extends EasyMockSupport {
+
+    private static final Map<String, String> TEST_PROPERTIES = new 
LinkedHashMap<String, String>() {
+        {
+            put("One", "1");
+            put("Two", "2");
+            put("Three", "3");
+        }
+    };
+    private SaslExtensions testExtensions = new 
SaslExtensions(TEST_PROPERTIES);
+    private final String errorMessage = "Error as expected!";
+
+    public class ExtensionsCallbackHandler implements 
AuthenticateCallbackHandler {
+        private boolean configured = false;
+        private boolean toThrow;
+
+        ExtensionsCallbackHandler(boolean toThrow) {
+            this.toThrow = toThrow;
+        }
+
+        public boolean configured() {
+            return configured;
+        }
+
+        @Override
+        public void configure(Map<String, ?> configs, String saslMechanism, 
List<AppConfigurationEntry> jaasConfigEntries) {
+            configured = true;
+        }
+
+        @Override
+        public void handle(Callback[] callbacks) throws 
UnsupportedCallbackException {
+            for (Callback callback : callbacks) {
+                if (callback instanceof OAuthBearerTokenCallback)
+                    ((OAuthBearerTokenCallback) 
callback).token(createMock(OAuthBearerUnsecuredJws.class));
+                else if (callback instanceof SaslExtensionsCallback) {
+                    if (toThrow)
+                        throw new ConfigException(errorMessage);
+                    else
+                        ((SaslExtensionsCallback) 
callback).extensions(testExtensions);
+                } else
+                    throw new UnsupportedCallbackException(callback);
+            }
+        }
+
+        @Override
+        public void close() {
+        }
+    }
+
+    @Test
+    public void testAttachesExtensionsToFirstClientMessage() throws Exception {
+        String expectedToken = new String(new 
OAuthBearerClientInitialResponse(null, testExtensions).toBytes(), 
StandardCharsets.UTF_8);
+
+        OAuthBearerSaslClient client = new OAuthBearerSaslClient(new 
ExtensionsCallbackHandler(false));
+
+        String message = new String(client.evaluateChallenge("".getBytes()), 
StandardCharsets.UTF_8);
+
+        assertEquals(expectedToken, message);
+    }
+
+    @Test
+    public void testNoExtensionsDoesNotAttachAnythingToFirstClientMessage() 
throws Exception {
+        TEST_PROPERTIES.clear();
+        testExtensions = new SaslExtensions(TEST_PROPERTIES);
+        String expectedToken = new String(new 
OAuthBearerClientInitialResponse(null, new 
SaslExtensions(TEST_PROPERTIES)).toBytes(), StandardCharsets.UTF_8);
+        OAuthBearerSaslClient client = new OAuthBearerSaslClient(new 
ExtensionsCallbackHandler(false));
+
+        String message = new String(client.evaluateChallenge("".getBytes()), 
StandardCharsets.UTF_8);
+
+        assertEquals(expectedToken, message);
+    }
+
+    @Test
+    public void 
testWrapsExtensionsCallbackHandlingErrorInSaslExceptionInFirstClientMessage() {
+        OAuthBearerSaslClient client = new OAuthBearerSaslClient(new 
ExtensionsCallbackHandler(true));
+        try {
+            client.evaluateChallenge("".getBytes());
+            fail("Should have failed with " + SaslException.class.getName());
+        } catch (SaslException e) {
+            // assert it has caught our expected exception
+            assertEquals(ConfigException.class, e.getCause().getClass());
+            assertEquals(errorMessage, e.getCause().getMessage());
+        }
+
+    }
+}
diff --git 
a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerTest.java
 
b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerTest.java
index 6b53e963af7..fc96f9f8adc 100644
--- 
a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerTest.java
@@ -18,6 +18,7 @@
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.assertNull;
 
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
@@ -34,6 +35,7 @@
 import org.apache.kafka.common.errors.SaslAuthenticationException;
 import org.apache.kafka.common.security.JaasContext;
 import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.auth.SaslExtensions;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
@@ -68,7 +70,7 @@
     private OAuthBearerSaslServer saslServer;
 
     @Before
-    public void setUp() throws Exception {
+    public void setUp() {
         saslServer = new OAuthBearerSaslServer(VALIDATOR_CALLBACK_HANDLER);
     }
 
@@ -79,6 +81,32 @@ public void noAuthorizationIdSpecified() throws Exception {
         assertTrue("Next challenge is not empty", nextChallenge.length == 0);
     }
 
+    @Test
+    public void savesCustomExtensionAsNegotiatedProperty() throws Exception {
+        Map<String, String> customExtensions = new HashMap<>();
+        customExtensions.put("firstKey", "value1");
+        customExtensions.put("secondKey", "value2");
+
+        byte[] nextChallenge = saslServer
+                .evaluateResponse(clientInitialResponse(null, false, 
customExtensions));
+
+        assertTrue("Next challenge is not empty", nextChallenge.length == 0);
+        assertEquals("value1", saslServer.getNegotiatedProperty("firstKey"));
+        assertEquals("value2", saslServer.getNegotiatedProperty("secondKey"));
+    }
+
+    @Test
+    public void returnsNullForNonExistentProperty() throws Exception {
+        Map<String, String> customExtensions = new HashMap<>();
+        customExtensions.put("firstKey", "value1");
+
+        byte[] nextChallenge = saslServer
+                .evaluateResponse(clientInitialResponse(null, false, 
customExtensions));
+
+        assertTrue("Next challenge is not empty", nextChallenge.length == 0);
+        assertNull(saslServer.getNegotiatedProperty("secondKey"));
+    }
+
     @Test
     public void authorizatonIdEqualsAuthenticationId() throws Exception {
         byte[] nextChallenge = saslServer
@@ -93,7 +121,7 @@ public void authorizatonIdNotEqualsAuthenticationId() throws 
Exception {
 
     @Test
     public void illegalToken() throws Exception {
-        byte[] bytes = saslServer.evaluateResponse(clientInitialResponse(null, 
true));
+        byte[] bytes = saslServer.evaluateResponse(clientInitialResponse(null, 
true, Collections.emptyMap()));
         String challenge = new String(bytes, StandardCharsets.UTF_8);
         assertEquals("{\"status\":\"invalid_token\"}", challenge);
     }
@@ -105,11 +133,17 @@ public void illegalToken() throws Exception {
 
     private byte[] clientInitialResponse(String authorizationId, boolean 
illegalToken)
             throws OAuthBearerConfigException, IOException, 
UnsupportedCallbackException, LoginException {
+        return clientInitialResponse(authorizationId, false, 
Collections.emptyMap());
+    }
+
+    private byte[] clientInitialResponse(String authorizationId, boolean 
illegalToken, Map<String, String> customExtensions)
+            throws OAuthBearerConfigException, IOException, 
UnsupportedCallbackException {
         OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
         LOGIN_CALLBACK_HANDLER.handle(new Callback[] {callback});
         OAuthBearerToken token = callback.token();
         String compactSerialization = token.value();
+
         String tokenValue = compactSerialization + (illegalToken ? "AB" : "");
-        return new OAuthBearerClientInitialResponse(tokenValue, 
authorizationId, Collections.emptyMap()).toBytes();
+        return new OAuthBearerClientInitialResponse(tokenValue, 
authorizationId, new SaslExtensions(customExtensions)).toBytes();
     }
 }
diff --git 
a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandlerTest.java
 
b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandlerTest.java
index a5c216d860c..be01fe3d7cf 100644
--- 
a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandlerTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandlerTest.java
@@ -31,6 +31,7 @@
 import javax.security.auth.callback.UnsupportedCallbackException;
 import javax.security.auth.login.LoginException;
 
+import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
 import org.apache.kafka.common.security.authenticator.TestJaasConfig;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
 import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
@@ -38,6 +39,39 @@
 import org.junit.Test;
 
 public class OAuthBearerUnsecuredLoginCallbackHandlerTest {
+
+    @Test
+    public void addsExtensions() throws IOException, 
UnsupportedCallbackException {
+        Map<String, String> options = new HashMap<>();
+        options.put("unsecuredLoginExtension_testId", "1");
+        OAuthBearerUnsecuredLoginCallbackHandler callbackHandler = 
createCallbackHandler(options, new MockTime());
+        SaslExtensionsCallback callback = new SaslExtensionsCallback();
+
+        callbackHandler.handle(new Callback[] {callback});
+
+        assertEquals("1", callback.extensions().map().get("testId"));
+    }
+
+    @Test(expected = IOException.class)
+    public void throwsErrorOnInvalidExtensionName() throws IOException, 
UnsupportedCallbackException {
+        Map<String, String> options = new HashMap<>();
+        options.put("unsecuredLoginExtension_test.Id", "1");
+        OAuthBearerUnsecuredLoginCallbackHandler callbackHandler = 
createCallbackHandler(options, new MockTime());
+        SaslExtensionsCallback callback = new SaslExtensionsCallback();
+
+        callbackHandler.handle(new Callback[] {callback});
+    }
+
+    @Test(expected = IOException.class)
+    public void throwsErrorOnInvalidExtensionValue() throws IOException, 
UnsupportedCallbackException {
+        Map<String, String> options = new HashMap<>();
+        options.put("unsecuredLoginExtension_testId", "Çalifornia");
+        OAuthBearerUnsecuredLoginCallbackHandler callbackHandler = 
createCallbackHandler(options, new MockTime());
+        SaslExtensionsCallback callback = new SaslExtensionsCallback();
+
+        callbackHandler.handle(new Callback[] {callback});
+    }
+
     @Test
     public void minimalToken() throws IOException, 
UnsupportedCallbackException {
         Map<String, String> options = new HashMap<>();
diff --git a/docs/security.html b/docs/security.html
index 7f765090f6f..743d673fe80 100644
--- a/docs/security.html
+++ b/docs/security.html
@@ -750,6 +750,13 @@ <h3><a id="security_sasl" href="#security_sasl">7.3 
Authentication using SASL</a
                  automatically generated).</td>
                  </tr>
                  <tr>
+                 
<td><tt>unsecuredLoginExtension_&lt;extensionname&gt;="value"</tt></td>
+                 <td>Creates a <tt>String</tt> extension with the given name 
and value.
+                 For example: <tt>unsecuredLoginExtension_traceId="123"</tt>. 
A valid extension name
+                 is any sequence of lowercase or uppercase alphabetic 
characters. In addition, the "auth" extension name is reserved.
+                 A valid extension value is any combination of characters with 
ASCII codes 1-127.</td>
+                 </tr>
+                 <tr>
                  <td><tt>unsecuredLoginPrincipalClaimName</tt></td>
                  <td>Set to a custom claim name if you wish the name of the 
<tt>String</tt>
                  claim holding the principal name to be something other than 
'<tt>sub</tt>'.</td>


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


> Add support for Custom SASL extensions in OAuth authentication
> --------------------------------------------------------------
>
>                 Key: KAFKA-7169
>                 URL: https://issues.apache.org/jira/browse/KAFKA-7169
>             Project: Kafka
>          Issue Type: Improvement
>            Reporter: Stanislav Kozlovski
>            Assignee: Stanislav Kozlovski
>            Priority: Minor
>
> KIP: 
> [here|https://cwiki.apache.org/confluence/display/KAFKA/KIP-342%3A+Add+support+for+Custom+SASL+extensions+in+OAuth+authentication]
> Kafka currently supports non-configurable SASL extensions in its SCRAM 
> authentication protocol for delegation token validation. It would be useful 
> to provide configurable SASL extensions for the OAuthBearer authentication 
> mechanism as well, such that clients could attach arbitrary data for the 
> principal authenticating into Kafka. This way, a custom principal can hold 
> information derived from the authentication mechanism, which could prove 
> useful for better tracing and troubleshooting, for example. This can be done 
> in a way that allows for easier extensibility to future SASL mechanisms.



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to