jeffkbkim commented on code in PR #12845:
URL: https://github.com/apache/kafka/pull/12845#discussion_r1027036420


##########
core/src/main/scala/kafka/server/KafkaApis.scala:
##########
@@ -1647,69 +1656,51 @@ class KafkaApis(val requestChannel: RequestChannel,
     }
   }
 
-  def handleJoinGroupRequest(request: RequestChannel.Request, requestLocal: 
RequestLocal): Unit = {
-    val joinGroupRequest = request.body[JoinGroupRequest]
+  private def makeGroupCoordinatorRequestContext(
+    request: RequestChannel.Request,
+    requestLocal: RequestLocal
+  ): GroupCoordinatorRequestContext = {
+    new GroupCoordinatorRequestContext(
+      request.context.header.data.requestApiVersion,
+      request.context.header.data.clientId,
+      request.context.clientAddress,
+      requestLocal.bufferSupplier
+    )
+  }
 
-    // the callback for sending a join-group response
-    def sendResponseCallback(joinResult: JoinGroupResult): Unit = {
-      def createResponse(requestThrottleMs: Int): AbstractResponse = {
-        val responseBody = new JoinGroupResponse(
-          new JoinGroupResponseData()
-            .setThrottleTimeMs(requestThrottleMs)
-            .setErrorCode(joinResult.error.code)
-            .setGenerationId(joinResult.generationId)
-            .setProtocolType(joinResult.protocolType.orNull)
-            .setProtocolName(joinResult.protocolName.orNull)
-            .setLeader(joinResult.leaderId)
-            .setSkipAssignment(joinResult.skipAssignment)
-            .setMemberId(joinResult.memberId)
-            .setMembers(joinResult.members.asJava),
-          request.context.apiVersion
-        )
+  def handleJoinGroupRequest(
+    request: RequestChannel.Request,
+    requestLocal: RequestLocal
+  ): CompletableFuture[Unit] = {
+    val joinGroupRequest = request.body[JoinGroupRequest]
 
-        trace("Sending join group response %s for correlation id %d to client 
%s."
-          .format(responseBody, request.header.correlationId, 
request.header.clientId))
-        responseBody
-      }
-      requestHelper.sendResponseMaybeThrottle(request, createResponse)
+    def sendResponse(response: AbstractResponse): Unit = {
+      trace("Sending join group response %s for correlation id %d to client 
%s."
+        .format(response, request.header.correlationId, 
request.header.clientId))
+      requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => {
+        response.maybeSetThrottleTimeMs(requestThrottleMs)
+        response
+      })
     }
 
     if (joinGroupRequest.data.groupInstanceId != null && 
config.interBrokerProtocolVersion.isLessThan(IBP_2_3_IV0)) {
       // Only enable static membership when IBP >= 2.3, because it is not safe 
for the broker to use the static member logic
       // until we are sure that all brokers support it. If static group being 
loaded by an older coordinator, it will discard
       // the group.instance.id field, so static members could accidentally 
become "dynamic", which leads to wrong states.
-      sendResponseCallback(JoinGroupResult(JoinGroupRequest.UNKNOWN_MEMBER_ID, 
Errors.UNSUPPORTED_VERSION))
+      
sendResponse(joinGroupRequest.getErrorResponse(Errors.UNSUPPORTED_VERSION.exception))
+      CompletableFuture.completedFuture[Unit](())

Review Comment:
   `newGroupCoordinator.joinGroup()` captures exceptions inside the future
returned to `handleJoinGroupRequest()`, and the unit test covers that. If we
return `CompletableFuture.completedFuture()` here and at line 1694, would we
still capture exceptions thrown during `sendResponse()`? I'm not sure that can
happen, but it seemed worth mentioning because the unit test appears to simulate
exactly that case.
   
   I guess I'm wondering whether we handle these cases differently, and if
so, why.



##########
core/src/test/scala/unit/kafka/server/KafkaApisTest.scala:
##########
@@ -2524,196 +2530,208 @@ class KafkaApisTest {
     assertEquals(MemoryRecords.EMPTY, 
FetchResponse.recordsOrFail(partitionData))
   }
 
-  @Test
-  def testJoinGroupProtocolsOrder(): Unit = {
-    val protocols = List(
-      ("first", "first".getBytes()),
-      ("second", "second".getBytes())
+  @ParameterizedTest
+  @ApiKeyVersionsSource(apiKey = ApiKeys.JOIN_GROUP)
+  def testHandleJoinGroupRequest(version: Short): Unit = {
+    val joinGroupRequest = new JoinGroupRequestData()
+      .setGroupId("group")
+      .setMemberId("member")
+      .setProtocolType("consumer")
+      .setRebalanceTimeoutMs(1000)
+      .setSessionTimeoutMs(2000)
+
+    val requestChannelRequest = buildRequest(new 
JoinGroupRequest.Builder(joinGroupRequest).build(version))
+
+    val expectedRequestContext = new GroupCoordinatorRequestContext(
+      version,
+      requestChannelRequest.context.clientId,
+      requestChannelRequest.context.clientAddress,
+      RequestLocal.NoCaching.bufferSupplier
     )
 
-    val groupId = "group"
-    val memberId = "member1"
-    val protocolType = "consumer"
-    val rebalanceTimeoutMs = 10
-    val sessionTimeoutMs = 5
-    val capturedProtocols: ArgumentCaptor[List[(String, Array[Byte])]] = 
ArgumentCaptor.forClass(classOf[List[(String, Array[Byte])]])
+    val expectedJoinGroupRequest = new JoinGroupRequestData()
+      .setGroupId(joinGroupRequest.groupId)
+      .setMemberId(joinGroupRequest.memberId)
+      .setProtocolType(joinGroupRequest.protocolType)
+      .setRebalanceTimeoutMs(if (version >= 1) 
joinGroupRequest.rebalanceTimeoutMs else joinGroupRequest.sessionTimeoutMs)
+      .setSessionTimeoutMs(joinGroupRequest.sessionTimeoutMs)
 
-    createKafkaApis().handleJoinGroupRequest(
-      buildRequest(
-        new JoinGroupRequest.Builder(
-          new JoinGroupRequestData()
-            .setGroupId(groupId)
-            .setMemberId(memberId)
-            .setProtocolType(protocolType)
-            .setRebalanceTimeoutMs(rebalanceTimeoutMs)
-            .setSessionTimeoutMs(sessionTimeoutMs)
-            .setProtocols(new 
JoinGroupRequestData.JoinGroupRequestProtocolCollection(
-              protocols.map { case (name, protocol) => new 
JoinGroupRequestProtocol()
-                .setName(name).setMetadata(protocol)
-              }.iterator.asJava))
-        ).build()
-      ),
-      RequestLocal.withThreadConfinedCaching)
+    val future = new CompletableFuture[JoinGroupResponseData]()
+    when(newGroupCoordinator.joinGroup(
+      ArgumentMatchers.eq(expectedRequestContext),
+      ArgumentMatchers.eq(expectedJoinGroupRequest)
+    )).thenReturn(future)
 
-    verify(groupCoordinator).handleJoinGroup(
-      ArgumentMatchers.eq(groupId),
-      ArgumentMatchers.eq(memberId),
-      ArgumentMatchers.eq(None),
-      ArgumentMatchers.eq(true),
-      ArgumentMatchers.eq(true),
-      ArgumentMatchers.eq(clientId),
-      ArgumentMatchers.eq(InetAddress.getLocalHost.toString),
-      ArgumentMatchers.eq(rebalanceTimeoutMs),
-      ArgumentMatchers.eq(sessionTimeoutMs),
-      ArgumentMatchers.eq(protocolType),
-      capturedProtocols.capture(),
-      any(),
-      any(),
-      any()
+    createKafkaApis().handleJoinGroupRequest(
+      requestChannelRequest,
+      RequestLocal.NoCaching
     )
-    val capturedProtocolsList = capturedProtocols.getValue
-    assertEquals(protocols.size, capturedProtocolsList.size)
-    protocols.zip(capturedProtocolsList).foreach { case ((expectedName, 
expectedBytes), (name, bytes)) =>
-      assertEquals(expectedName, name)
-      assertArrayEquals(expectedBytes, bytes)
-    }
-  }
 
-  @Test
-  def testJoinGroupWhenAnErrorOccurs(): Unit = {
-    for (version <- ApiKeys.JOIN_GROUP.oldestVersion to 
ApiKeys.JOIN_GROUP.latestVersion) {
-      testJoinGroupWhenAnErrorOccurs(version.asInstanceOf[Short])
-    }
-  }
+    val expectedJoinGroupResponse = new JoinGroupResponseData()
+      .setMemberId("member")
+      .setGenerationId(0)
+      .setLeader("leader")
+      .setProtocolType("consumer")
+      .setProtocolName("range")
 
-  def testJoinGroupWhenAnErrorOccurs(version: Short): Unit = {
-    reset(groupCoordinator, clientRequestQuotaManager, requestChannel, 
replicaManager)
+    future.complete(expectedJoinGroupResponse)
+    val capturedResponse = verifyNoThrottling(requestChannelRequest)
+    val response = capturedResponse.getValue.asInstanceOf[JoinGroupResponse]
+    assertEquals(expectedJoinGroupResponse, response.data)
+  }
 
-    val groupId = "group"
-    val memberId = "member1"
-    val protocolType = "consumer"
-    val rebalanceTimeoutMs = 10
-    val sessionTimeoutMs = 5
+  @ParameterizedTest
+  @ApiKeyVersionsSource(apiKey = ApiKeys.JOIN_GROUP)
+  def testJoinGroupProtocolNameBackwardCompatibility(version: Short): Unit = {
+    val joinGroupRequest = new JoinGroupRequestData()
+      .setGroupId("group")
+      .setMemberId("member")
+      .setProtocolType("consumer")
+      .setRebalanceTimeoutMs(1000)
+      .setSessionTimeoutMs(2000)
+
+    val requestChannelRequest = buildRequest(new 
JoinGroupRequest.Builder(joinGroupRequest).build(version))
+
+    val expectedRequestContext = new GroupCoordinatorRequestContext(
+      version,
+      requestChannelRequest.context.clientId,
+      requestChannelRequest.context.clientAddress,
+      RequestLocal.NoCaching.bufferSupplier
+    )
 
-    val capturedCallback: ArgumentCaptor[JoinGroupCallback] = 
ArgumentCaptor.forClass(classOf[JoinGroupCallback])
+    val expectedJoinGroupRequest = new JoinGroupRequestData()
+      .setGroupId(joinGroupRequest.groupId)
+      .setMemberId(joinGroupRequest.memberId)
+      .setProtocolType(joinGroupRequest.protocolType)
+      .setRebalanceTimeoutMs(if (version >= 1) 
joinGroupRequest.rebalanceTimeoutMs else joinGroupRequest.sessionTimeoutMs)
+      .setSessionTimeoutMs(joinGroupRequest.sessionTimeoutMs)
 
-    val joinGroupRequest = new JoinGroupRequest.Builder(
-      new JoinGroupRequestData()
-        .setGroupId(groupId)
-        .setMemberId(memberId)
-        .setProtocolType(protocolType)
-        .setRebalanceTimeoutMs(rebalanceTimeoutMs)
-        .setSessionTimeoutMs(sessionTimeoutMs)
-    ).build(version)
+    val future = new CompletableFuture[JoinGroupResponseData]()
+    when(newGroupCoordinator.joinGroup(
+      ArgumentMatchers.eq(expectedRequestContext),
+      ArgumentMatchers.eq(expectedJoinGroupRequest)
+    )).thenReturn(future)
 
-    val requestChannelRequest = buildRequest(joinGroupRequest)
+    createKafkaApis().handleJoinGroupRequest(
+      requestChannelRequest,
+      RequestLocal.NoCaching
+    )
 
-    createKafkaApis().handleJoinGroupRequest(requestChannelRequest, 
RequestLocal.withThreadConfinedCaching)
+    val joinGroupResponse = new JoinGroupResponseData()
+      .setErrorCode(Errors.INCONSISTENT_GROUP_PROTOCOL.code)
+      .setMemberId("member")
+      .setProtocolName(null)
 
-    verify(groupCoordinator).handleJoinGroup(
-      ArgumentMatchers.eq(groupId),
-      ArgumentMatchers.eq(memberId),
-      ArgumentMatchers.eq(None),
-      ArgumentMatchers.eq(if (version >= 4) true else false),
-      ArgumentMatchers.eq(if (version >= 9) true else false),
-      ArgumentMatchers.eq(clientId),
-      ArgumentMatchers.eq(InetAddress.getLocalHost.toString),
-      ArgumentMatchers.eq(if (version >= 1) rebalanceTimeoutMs else 
sessionTimeoutMs),
-      ArgumentMatchers.eq(sessionTimeoutMs),
-      ArgumentMatchers.eq(protocolType),
-      ArgumentMatchers.eq(List.empty),
-      capturedCallback.capture(),
-      any(),
-      any()
-    )
-    capturedCallback.getValue.apply(JoinGroupResult(memberId, 
Errors.INCONSISTENT_GROUP_PROTOCOL))
+    val expectedJoinGroupResponse = new JoinGroupResponseData()
+      .setErrorCode(Errors.INCONSISTENT_GROUP_PROTOCOL.code)
+      .setMemberId("member")
+      .setProtocolName(if (version >= 7) null else GroupCoordinator.NoProtocol)
 
+    future.complete(joinGroupResponse)
     val capturedResponse = verifyNoThrottling(requestChannelRequest)
     val response = capturedResponse.getValue.asInstanceOf[JoinGroupResponse]
-
-    assertEquals(Errors.INCONSISTENT_GROUP_PROTOCOL, response.error)
-    assertEquals(0, response.data.members.size)
-    assertEquals(memberId, response.data.memberId)
-    assertEquals(GroupCoordinator.NoGeneration, response.data.generationId)
-    assertEquals(GroupCoordinator.NoLeader, response.data.leader)
-    assertNull(response.data.protocolType)
-
-    if (version >= 7) {
-      assertNull(response.data.protocolName)
-    } else {
-      assertEquals(GroupCoordinator.NoProtocol, response.data.protocolName)
-    }
+    assertEquals(expectedJoinGroupResponse, response.data)
   }
 
   @Test
-  def testJoinGroupProtocolType(): Unit = {
-    for (version <- ApiKeys.JOIN_GROUP.oldestVersion to 
ApiKeys.JOIN_GROUP.latestVersion) {
-      testJoinGroupProtocolType(version.asInstanceOf[Short])
-    }
-  }
+  def testHandleJoinGroupRequestFutureFailed(): Unit = {
+    val joinGroupRequest = new JoinGroupRequestData()
+      .setGroupId("group")
+      .setMemberId("member")
+      .setProtocolType("consumer")
+      .setRebalanceTimeoutMs(1000)
+      .setSessionTimeoutMs(2000)
 
-  def testJoinGroupProtocolType(version: Short): Unit = {
-    reset(groupCoordinator, clientRequestQuotaManager, requestChannel, 
replicaManager)
+    val requestChannelRequest = buildRequest(new 
JoinGroupRequest.Builder(joinGroupRequest).build())
 
-    val groupId = "group"
-    val memberId = "member1"
-    val protocolType = "consumer"
-    val protocolName = "range"
-    val rebalanceTimeoutMs = 10
-    val sessionTimeoutMs = 5
+    val expectedRequestContext = new GroupCoordinatorRequestContext(
+      ApiKeys.JOIN_GROUP.latestVersion,
+      requestChannelRequest.context.clientId,
+      requestChannelRequest.context.clientAddress,
+      RequestLocal.NoCaching.bufferSupplier
+    )
 
-    val capturedCallback: ArgumentCaptor[JoinGroupCallback] = 
ArgumentCaptor.forClass(classOf[JoinGroupCallback])
+    val future = new CompletableFuture[JoinGroupResponseData]()
+    when(newGroupCoordinator.joinGroup(
+      ArgumentMatchers.eq(expectedRequestContext),
+      ArgumentMatchers.eq(joinGroupRequest)
+    )).thenReturn(future)
 
-    val joinGroupRequest = new JoinGroupRequest.Builder(
-      new JoinGroupRequestData()
-        .setGroupId(groupId)
-        .setMemberId(memberId)
-        .setProtocolType(protocolType)
-        .setRebalanceTimeoutMs(rebalanceTimeoutMs)
-        .setSessionTimeoutMs(sessionTimeoutMs)
-    ).build(version)
+    createKafkaApis().handleJoinGroupRequest(
+      requestChannelRequest,
+      RequestLocal.NoCaching
+    )
 
-    val requestChannelRequest = buildRequest(joinGroupRequest)
+    future.completeExceptionally(Errors.REQUEST_TIMED_OUT.exception)
+    val capturedResponse = verifyNoThrottling(requestChannelRequest)
+    val response = capturedResponse.getValue.asInstanceOf[JoinGroupResponse]
+    assertEquals(Errors.REQUEST_TIMED_OUT, response.error)
+  }
 
-    createKafkaApis().handleJoinGroupRequest(requestChannelRequest, 
RequestLocal.withThreadConfinedCaching)
+  @Test
+  def testHandleJoinGroupRequestAuthorizationFailed(): Unit = {
+    val joinGroupRequest = new JoinGroupRequestData()
+      .setGroupId("group")
+      .setMemberId("member")
+      .setProtocolType("consumer")
+      .setRebalanceTimeoutMs(1000)
+      .setSessionTimeoutMs(2000)
 
-    verify(groupCoordinator).handleJoinGroup(
-      ArgumentMatchers.eq(groupId),
-      ArgumentMatchers.eq(memberId),
-      ArgumentMatchers.eq(None),
-      ArgumentMatchers.eq(if (version >= 4) true else false),
-      ArgumentMatchers.eq(if (version >= 9) true else false),
-      ArgumentMatchers.eq(clientId),
-      ArgumentMatchers.eq(InetAddress.getLocalHost.toString),
-      ArgumentMatchers.eq(if (version >= 1) rebalanceTimeoutMs else 
sessionTimeoutMs),
-      ArgumentMatchers.eq(sessionTimeoutMs),
-      ArgumentMatchers.eq(protocolType),
-      ArgumentMatchers.eq(List.empty),
-      capturedCallback.capture(),
-      any(),
-      any()
+    val requestChannelRequest = buildRequest(new 
JoinGroupRequest.Builder(joinGroupRequest).build())
+
+    val authorizer: Authorizer = mock(classOf[Authorizer])
+    when(authorizer.authorize(any[RequestContext], any[util.List[Action]]))
+      .thenReturn(Seq(AuthorizationResult.DENIED).asJava)
+
+    createKafkaApis(authorizer = Some(authorizer)).handleJoinGroupRequest(
+      requestChannelRequest,
+      RequestLocal.NoCaching
     )
-    capturedCallback.getValue.apply(JoinGroupResult(
-      members = List.empty,
-      memberId = memberId,
-      generationId = 0,
-      protocolType = Some(protocolType),
-      protocolName = Some(protocolName),
-      leaderId = memberId,
-      skipAssignment = true,
-      error = Errors.NONE
-    ))
+
     val capturedResponse = verifyNoThrottling(requestChannelRequest)
     val response = capturedResponse.getValue.asInstanceOf[JoinGroupResponse]
+    assertEquals(Errors.GROUP_AUTHORIZATION_FAILED, response.error)
+  }
 
-    assertEquals(Errors.NONE, response.error)
-    assertEquals(0, response.data.members.size)
-    assertEquals(memberId, response.data.memberId)
-    assertEquals(0, response.data.generationId)
-    assertEquals(memberId, response.data.leader)
-    assertEquals(protocolName, response.data.protocolName)
-    assertEquals(protocolType, response.data.protocolType)
-    assertTrue(response.data.skipAssignment)
+  @Test
+  def testHandleJoinGroupRequestUnexpectedException(): Unit = {
+    val joinGroupRequest = new JoinGroupRequestData()
+      .setGroupId("group")
+      .setMemberId("member")
+      .setProtocolType("consumer")
+      .setRebalanceTimeoutMs(1000)
+      .setSessionTimeoutMs(2000)
+
+    val requestChannelRequest = buildRequest(new 
JoinGroupRequest.Builder(joinGroupRequest).build())
+
+    val expectedRequestContext = new GroupCoordinatorRequestContext(
+      ApiKeys.JOIN_GROUP.latestVersion,
+      requestChannelRequest.context.clientId,
+      requestChannelRequest.context.clientAddress,
+      RequestLocal.NoCaching.bufferSupplier
+    )
+
+    val future = new CompletableFuture[JoinGroupResponseData]()
+    when(newGroupCoordinator.joinGroup(
+      ArgumentMatchers.eq(expectedRequestContext),
+      ArgumentMatchers.eq(joinGroupRequest)
+    )).thenReturn(future)
+
+    val response = new AtomicReference[JoinGroupResponse]()
+    when(requestChannel.sendResponse(any(), any(), any())).thenAnswer { _ =>
+      throw new Exception("Something went wrong")
+    }.thenAnswer { invocation =>
+      val resp = invocation.getArgument(1, classOf[JoinGroupResponse])
+      response.set(resp)
+    }
+
+    createKafkaApis().handle(
+      requestChannelRequest,
+      RequestLocal.NoCaching
+    )

Review Comment:
   Gotcha. I think I had missed this part of the `CompletableFuture` Javadoc:
   
   > Actions supplied for dependent completions of non-async methods may be 
performed by the thread that completes the current CompletableFuture, or by any 
other caller of a completion method.
   
   So the thread that completes the future may itself run all of the
(non-async) dependent stages chained on it, such as the callback that calls
`sendResponse()`.



##########
core/src/main/scala/kafka/server/KafkaApis.scala:
##########
@@ -1647,69 +1656,51 @@ class KafkaApis(val requestChannel: RequestChannel,
     }
   }
 
-  def handleJoinGroupRequest(request: RequestChannel.Request, requestLocal: 
RequestLocal): Unit = {
-    val joinGroupRequest = request.body[JoinGroupRequest]
+  private def makeGroupCoordinatorRequestContext(
+    request: RequestChannel.Request,
+    requestLocal: RequestLocal
+  ): GroupCoordinatorRequestContext = {
+    new GroupCoordinatorRequestContext(
+      request.context.header.data.requestApiVersion,
+      request.context.header.data.clientId,
+      request.context.clientAddress,
+      requestLocal.bufferSupplier
+    )
+  }
 
-    // the callback for sending a join-group response
-    def sendResponseCallback(joinResult: JoinGroupResult): Unit = {
-      def createResponse(requestThrottleMs: Int): AbstractResponse = {
-        val responseBody = new JoinGroupResponse(
-          new JoinGroupResponseData()
-            .setThrottleTimeMs(requestThrottleMs)
-            .setErrorCode(joinResult.error.code)
-            .setGenerationId(joinResult.generationId)
-            .setProtocolType(joinResult.protocolType.orNull)
-            .setProtocolName(joinResult.protocolName.orNull)
-            .setLeader(joinResult.leaderId)
-            .setSkipAssignment(joinResult.skipAssignment)
-            .setMemberId(joinResult.memberId)
-            .setMembers(joinResult.members.asJava),
-          request.context.apiVersion
-        )
+  def handleJoinGroupRequest(
+    request: RequestChannel.Request,
+    requestLocal: RequestLocal
+  ): CompletableFuture[Unit] = {

Review Comment:
   I'm also curious about this. It's one of the points that confused me :sweat_smile: 
   
   > All APIs will be async for this reason but not that most of them are 
already async today
   
   Can you elaborate on the reason, and clarify whether most of the APIs are
already async or not? I was under the impression that a single API request ->
response cycle is handled synchronously by a request handler thread.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: jira-unsubscr...@kafka.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to