GitHub user zentol commented on a diff in the pull request: https://github.com/apache/flink/pull/4767#discussion_r142640401 --- Diff: flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClient.java --- @@ -275,4 +301,121 @@ public HttpResponseStatus getHttpResponseStatus() { return httpResponseStatus; } } + + public <M extends MessageHeaders<EmptyRequestBody, WebSocketUpgradeResponseBody, U>, U extends MessageParameters, R extends ResponseBody> CompletableFuture<WebSocket> sendWebSocketRequest(String targetAddress, int targetPort, M messageHeaders, U messageParameters, Class<R> messageClazz, WebSocketListener... listeners) throws IOException { + Preconditions.checkNotNull(targetAddress); + Preconditions.checkArgument(0 <= targetPort && targetPort < 65536, "The target port " + targetPort + " is not in the range (0, 65536]."); + Preconditions.checkNotNull(messageHeaders); + Preconditions.checkNotNull(messageParameters); + Preconditions.checkState(messageParameters.isResolved(), "Message parameters were not resolved."); + + String targetUrl = MessageParameters.resolveUrl(messageHeaders.getTargetRestEndpointURL(), messageParameters); + URI webSocketURL = URI.create("ws://" + targetAddress + ":" + targetPort).resolve(targetUrl); + LOG.debug("Sending WebSocket request to {}", webSocketURL); + + final HttpHeaders headers = new DefaultHttpHeaders() + .add(HttpHeaders.Names.CONTENT_TYPE, RestConstants.REST_CONTENT_TYPE); + + Bootstrap bootstrap1 = bootstrap.clone().handler(new ClientBootstrap() { + @Override + protected void initChannel(SocketChannel channel) throws Exception { + super.initChannel(channel); + channel.pipeline() + .addLast(new WebSocketClientProtocolHandler(webSocketURL, WebSocketVersion.V13, null, false, headers, 65535)) + .addLast(new WsResponseHandler(channel, messageClazz, listeners)); + } + }); + + return CompletableFuture.supplyAsync(() -> bootstrap1.connect(targetAddress, targetPort), executor) + .thenApply((channel) -> { + try { + return
channel.sync(); + } catch (InterruptedException e) { + throw new FlinkRuntimeException(e); + } + }) + .thenApply((ChannelFuture::channel)) + .thenCompose(channel -> { + WsResponseHandler handler = channel.pipeline().get(WsResponseHandler.class); + return handler.getWebSocketFuture(); + }); + } + + private static class WsResponseHandler extends SimpleChannelInboundHandler<Object> implements WebSocket { + + private final Channel channel; + private final Class<? extends ResponseBody> messageClazz; + private final List<WebSocketListener> listeners = new CopyOnWriteArrayList<>(); + + private final CompletableFuture<WebSocket> webSocketFuture = new CompletableFuture<>(); + + CompletableFuture<WebSocket> getWebSocketFuture() { + return webSocketFuture; + } + + public WsResponseHandler(Channel channel, Class<? extends ResponseBody> messageClazz, WebSocketListener[] listeners) { + this.channel = channel; + this.messageClazz = messageClazz; + this.listeners.addAll(Arrays.asList(listeners)); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + LOG.warn("WebSocket exception", cause); + webSocketFuture.completeExceptionally(cause); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent) { + WebSocketClientProtocolHandler.ClientHandshakeStateEvent wsevt = (WebSocketClientProtocolHandler.ClientHandshakeStateEvent) evt; + switch(wsevt) { --- End diff -- Missing space after `switch`: it should read `switch (wsevt) {`.
---