AHeise commented on a change in pull request #11725:
URL: https://github.com/apache/flink/pull/11725#discussion_r421296506



##########
File path: flink-core/src/test/java/org/apache/flink/util/PropertiesUtilTest.java
##########
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.util;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Properties;
+
+import static org.apache.flink.util.PropertiesUtil.flatten;
+
+/**
+ * Tests for the {@link PropertiesUtil}.
+ */
+public class PropertiesUtilTest {
+
+       @Test
+       public void testFlatten() {
+               // default Properties is null
+               Properties prop1 = new Properties();
+               prop1.put("key1", "value1");
+
+               // default Properties is prop1
+               Properties prop2 = new Properties(prop1);
+               prop2.put("key2", "value2");
+
+               // default Properties is prop2
+               Properties prop3 = new Properties(prop2);
+               prop3.put("key3", "value3");
+
+               Properties flattened = flatten(prop3);
+               Assert.assertEquals(flattened.get("key1"), prop3.getProperty("key1"));
+               Assert.assertEquals(flattened.get("key2"), prop3.getProperty("key2"));
+               Assert.assertEquals(flattened.get("key3"), prop3.getProperty("key3"));
+               Assert.assertNotEquals(flattened.get("key1"), prop3.get("key1"));
+               Assert.assertNotEquals(flattened.get("key2"), prop3.get("key2"));
+               Assert.assertEquals(flattened.get("key3"), prop3.get("key3"));

Review comment:
       I'd remove these assertions and only keep the last 3.
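   
   The assertions in this test hinge on one detail of `java.util.Properties`: `getProperty` falls back through the chain of defaults, while `get` (inherited from `Hashtable`) only sees the object's own entries. A minimal plain-Java illustration without JUnit, using the same three-level chain as the test (the class name `PropertiesDefaultsDemo` is made up for this sketch):

```java
import java.util.Properties;

// Hypothetical demo class: shows how Properties resolves keys through defaults.
class PropertiesDefaultsDemo {

    // Builds the same chain as the test: prop3 -> prop2 -> prop1.
    static Properties buildChain() {
        Properties prop1 = new Properties();
        prop1.setProperty("key1", "value1");

        Properties prop2 = new Properties(prop1);
        prop2.setProperty("key2", "value2");

        Properties prop3 = new Properties(prop2);
        prop3.setProperty("key3", "value3");
        return prop3;
    }

    public static void main(String[] args) {
        Properties prop3 = buildChain();
        // getProperty() recurses into the defaults: prints "value1".
        System.out.println(prop3.getProperty("key1"));
        // get() only looks at prop3's own table: prints "null".
        System.out.println(prop3.get("key1"));
        // Keys set directly on prop3 are visible either way: prints "value3".
        System.out.println(prop3.get("key3"));
    }
}
```

   Because `prop3.get("key1")` is `null` while a flattened copy holds `"value1"` at the first level, the last three assertions are the ones that actually distinguish a flattened `Properties` from the original.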

##########
File path: flink-core/src/main/java/org/apache/flink/util/PropertiesUtil.java
##########
@@ -108,6 +109,27 @@ public static boolean getBoolean(Properties config, String key, boolean defaultV
                }
        }
 
+       /**
+        * Flatten a recursive {@link Properties} to a first level property map.
+        * In some cases, {KafkaProducer#propsToMap} for example, Properties is used purely as a HashMap
+        * without considering its default properties.
+        *
+        * @param config Properties to be flattened
+        * @return Properties without defaults; all properties are put in the first-level
+        */
+       public static Properties flatten(Properties config) {

Review comment:
       > The flattened properties are actually used in the Kafka client lib, not that easy to fix.
   
   Does that mean that Kafka is actually not processing the recursive Properties correctly? We should probably file a bug report then.
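   
   Whether the Kafka client itself drops defaults is not established by this thread. The underlying Java behavior, though, is easy to demonstrate: iterating a `Properties` through its `Map` view (`entrySet()`) sees only first-level entries, whereas `stringPropertyNames()` also returns inherited keys. The `flatten` below is one possible implementation built on that, not necessarily the one in this PR; `toMapByEntries` is a hypothetical stand-in for any code that converts `Properties` to a `Map` by iterating its entries (the actual body of `KafkaProducer#propsToMap` is not shown here).

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;

class FlattenSketch {

    // Copies every visible key (own entries plus the whole defaults chain) into a
    // fresh Properties object that has no defaults of its own.
    static Properties flatten(Properties config) {
        Properties flattened = new Properties();
        // stringPropertyNames() includes keys from the defaults, but only keys
        // whose key and value are both Strings.
        for (String key : config.stringPropertyNames()) {
            flattened.setProperty(key, config.getProperty(key));
        }
        return flattened;
    }

    // Hypothetical Map conversion by entry iteration: sees first-level entries only.
    static Map<String, Object> toMapByEntries(Properties props) {
        Map<String, Object> map = new HashMap<>();
        for (Map.Entry<Object, Object> e : props.entrySet()) {
            map.put(String.valueOf(e.getKey()), e.getValue());
        }
        return map;
    }

    public static void main(String[] args) {
        Properties defaults = new Properties();
        defaults.setProperty("bootstrap.servers", "localhost:9092");
        Properties config = new Properties(defaults);
        config.setProperty("acks", "all");

        // The inherited key is lost without flattening: prints "false".
        System.out.println(toMapByEntries(config).containsKey("bootstrap.servers"));
        // After flattening it is a first-level entry: prints "true".
        System.out.println(toMapByEntries(flatten(config)).containsKey("bootstrap.servers"));
    }
}
```

   One caveat of this approach: `stringPropertyNames()` skips entries whose key or value is not a `String`, so non-String values put via `put` would not survive the flattening.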

##########
File path: flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunction.java
##########
@@ -166,6 +167,13 @@ protected TXN currentTransaction() {
         */
        protected abstract void invoke(TXN transaction, IN value, Context context) throws Exception;
 
+       /**
+        * Handle watermark within a transaction.
+        */
+       protected void invoke(TXN transaction, Watermark watermark) throws Exception {
+               throw new UnsupportedOperationException("invokeWithWatermark should not be invoked");
+       }
+

Review comment:
       Not necessary. `FlinkKafkaShuffleProducer` can use the protected `currentTransaction()`.

##########
File path: flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunction.java
##########
@@ -235,6 +243,10 @@ public final void invoke(
                invoke(currentTransactionHolder.handle, value, context);
        }
 
+       public final void invoke(Watermark watermark) throws Exception {
+               invoke(currentTransactionHolder.handle, watermark);
+       }
+

Review comment:
       Not necessary. `FlinkKafkaShuffleProducer` can use the protected `currentTransaction()`.
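   
   A rough sketch of the suggested alternative, using a simplified, hypothetical stand-in for `TwoPhaseCommitSinkFunction` (not Flink's actual class, which also manages transaction lifecycles across checkpoints): the subclass reaches the open transaction through the protected `currentTransaction()` accessor, so the base class needs neither the new `invoke(TXN, Watermark)` hook nor the public `invoke(Watermark)`. The names `SimpleTwoPhaseCommitSink`, `ShuffleSinkSketch`, and `invokeWatermark` are illustrative only.

```java
// Simplified, hypothetical stand-in for TwoPhaseCommitSinkFunction; it models only
// the protected accessor to the currently open transaction.
abstract class SimpleTwoPhaseCommitSink<TXN, IN> {
    private final TXN currentTransaction;

    protected SimpleTwoPhaseCommitSink(TXN transaction) {
        this.currentTransaction = transaction;
    }

    // Mirrors the protected currentTransaction() accessor mentioned in the review.
    protected TXN currentTransaction() {
        return currentTransaction;
    }

    protected abstract void invoke(TXN transaction, IN value);

    public final void invoke(IN value) {
        invoke(currentTransaction, value);
    }
}

// Hypothetical shuffle sink: the watermark entry point lives in the subclass and
// writes into the same open transaction via currentTransaction().
class ShuffleSinkSketch extends SimpleTwoPhaseCommitSink<StringBuilder, String> {
    ShuffleSinkSketch() {
        super(new StringBuilder());
    }

    @Override
    protected void invoke(StringBuilder transaction, String value) {
        transaction.append("rec:").append(value).append(';');
    }

    public void invokeWatermark(long watermarkTimestamp) {
        currentTransaction().append("wm:").append(watermarkTimestamp).append(';');
    }

    String contents() {
        return currentTransaction().toString();
    }
}
```

   The design point is that watermark handling stays entirely inside the Kafka shuffle subclass, so the generic sink base class keeps its existing API.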

##########
File path: flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffleProducer.java
##########
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.serialization.TypeInformationSerializationSchema;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaException;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.PropertiesUtil;
+
+import org.apache.kafka.clients.producer.ProducerRecord;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Properties;
+
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffle.PARTITION_NUMBER;
+
+/**
+ * Flink Kafka Shuffle Producer Function.
+ * It differs from {@link FlinkKafkaProducer} in the way it handles elements and watermarks.
+ */
+@Internal
+public class FlinkKafkaShuffleProducer<IN, KEY> extends FlinkKafkaProducer<IN> {
+       private final KafkaSerializer<IN> kafkaSerializer;
+       private final KeySelector<IN, KEY> keySelector;
+       private final int numberOfPartitions;
+
+       FlinkKafkaShuffleProducer(
+               String defaultTopicId,

Review comment:
       Nit: Method parameters must be double-indented. Unfortunately, IntelliJ cannot do this automatically :(.
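   
   For reference, this is the double-indented layout on a wrapped parameter list. The class and its fields are hypothetical; only the layout matters, and spaces are used here for readability although the diff itself is tab-indented. With a single indent, the parameters would line up with the constructor body; the extra level keeps them visually apart.

```java
// Hypothetical class illustrating the "double indentation" rule for wrapped
// parameter lists: parameters sit two levels deeper than the signature,
// the body one level deeper.
class DoubleIndentExample {
    private final String defaultTopicId;
    private final int kafkaProducersPoolSize;

    DoubleIndentExample(
            String defaultTopicId,
            int kafkaProducersPoolSize) {
        this.defaultTopicId = defaultTopicId;
        this.kafkaProducersPoolSize = kafkaProducersPoolSize;
    }

    String describe() {
        return defaultTopicId + "/" + kafkaProducersPoolSize;
    }
}
```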

##########
File path: flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffleProducer.java
##########
@@ -0,0 +1,197 @@
+       FlinkKafkaShuffleProducer(
+               String defaultTopicId,
+               TypeInformationSerializationSchema<IN> schema,
+               Properties props,
+               KeySelector<IN, KEY> keySelector,
+               Semantic semantic,
+               int kafkaProducersPoolSize) {
+               super(defaultTopicId, (element, timestamp) -> null, props, semantic, kafkaProducersPoolSize);
+
+               this.kafkaSerializer = new KafkaSerializer<>(schema.getSerializer());
+               this.keySelector = keySelector;
+
+               Preconditions.checkArgument(
+                       props.getProperty(PARTITION_NUMBER) != null,
+                       "Missing partition number for Kafka Shuffle");
+               numberOfPartitions = PropertiesUtil.getInt(props, PARTITION_NUMBER, Integer.MIN_VALUE);
+       }
+
+       /**
+        * This is the function invoked to handle each element.
+        * @param transaction transaction state;
+        *                    elements are written to Kafka in transactions to guarantee different levels of data consistency
+        * @param next element to handle
+        * @param context context needed to handle the element
+        * @throws FlinkKafkaException for kafka error
+        */
+       @Override
+       public void invoke(KafkaTransactionState transaction, IN next, Context context) throws FlinkKafkaException {
+               checkErroneous();
+
+               // write timestamp to Kafka if timestamp is available
+               Long timestamp = context.timestamp();
+
+               int[] partitions = getPartitions(transaction);
+               int partitionIndex;
+               try {
+                       partitionIndex = KeyGroupRangeAssignment
+                               .assignKeyToParallelOperator(keySelector.getKey(next), partitions.length, partitions.length);
+               } catch (Exception e) {
+                       throw new RuntimeException("Fail to assign a partition number to record");
+               }
+
+               ProducerRecord<byte[], byte[]> record = new ProducerRecord<>(
+                       defaultTopicId, partitionIndex, timestamp, null, kafkaSerializer.serializeRecord(next, timestamp));
+               pendingRecords.incrementAndGet();
+               transaction.getProducer().send(record, callback);
+       }
+
+       /**
+        * This is the function invoked to handle each watermark.
+        * @param transaction transaction state;
+        *                    watermarks are written to Kafka (if needed) in transactions
+        * @param watermark watermark to handle
+        * @throws FlinkKafkaException for kafka error
+        */
+       @Override
+       public void invoke(KafkaTransactionState transaction, Watermark watermark) throws FlinkKafkaException {
+               checkErroneous();
+
+               int[] partitions = getPartitions(transaction);
+               int subtask = getRuntimeContext().getIndexOfThisSubtask();
+
+               // broadcast watermark
+               long timestamp = watermark.getTimestamp();
+               for (int partition : partitions) {
+                       ProducerRecord<byte[], byte[]> record = new ProducerRecord<>(
+                               defaultTopicId, partition, timestamp, null, kafkaSerializer.serializeWatermark(watermark, subtask));
+                       pendingRecords.incrementAndGet();
+                       transaction.getProducer().send(record, callback);
+               }
+       }
+
+       private int[] getPartitions(KafkaTransactionState transaction) {
+               int[] partitions = topicPartitionsMap.get(defaultTopicId);
+               if (partitions == null) {
+                       partitions = getPartitionsByTopic(defaultTopicId, transaction.getProducer());
+                       topicPartitionsMap.put(defaultTopicId, partitions);
+               }
+
+               Preconditions.checkArgument(partitions.length == numberOfPartitions);
+
+               return partitions;
+       }
+
+       /**
+        * Flink Kafka Shuffle Serializer.
+        */
+       public static final class KafkaSerializer<IN> implements Serializable {
+               public static final int TAG_REC_WITH_TIMESTAMP = 0;
+               public static final int TAG_REC_WITHOUT_TIMESTAMP = 1;
+               public static final int TAG_WATERMARK = 2;

Review comment:
       Is this supposed to be public or not? It should probably be package-private.
   
   I was also thinking of pulling it out as a top-level class, which would then also incorporate the deserialization logic from the next commit.
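   
   A rough sketch of that refactoring: one package-private, top-level class owning both directions of the shuffle format. Only the three tag values come from the diff; the layout after the tag (a `long` timestamp, an `int` producer subtask index for watermarks, raw payload bytes standing in for the `TypeSerializer` output) is an assumption, and the name `KafkaShuffleCodecSketch` is hypothetical. Plain `java.io` streams replace Flink's `DataOutputSerializer`/`DataInputDeserializer` so the snippet runs on its own.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

// Hypothetical package-private codec; tag values match the diff, the layout is assumed.
final class KafkaShuffleCodecSketch {
    static final int TAG_REC_WITH_TIMESTAMP = 0;
    static final int TAG_REC_WITHOUT_TIMESTAMP = 1;
    static final int TAG_WATERMARK = 2;

    private KafkaShuffleCodecSketch() {
    }

    // A decoded message: either a record (payload, optional timestamp) or a watermark.
    static final class Element {
        final int tag;
        final Long timestamp;  // record timestamp (may be null) or watermark timestamp
        final int subtask;     // producer subtask index for watermarks, -1 otherwise
        final byte[] payload;  // serialized record value for records, null otherwise

        Element(int tag, Long timestamp, int subtask, byte[] payload) {
            this.tag = tag;
            this.timestamp = timestamp;
            this.subtask = subtask;
            this.payload = payload;
        }

        boolean isRecord() {
            return tag != TAG_WATERMARK;
        }

        boolean isWatermark() {
            return tag == TAG_WATERMARK;
        }
    }

    // Layout: tag byte, optional long timestamp, then the raw value bytes.
    static byte[] serializeRecord(byte[] value, Long timestamp) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            if (timestamp == null) {
                out.writeByte(TAG_REC_WITHOUT_TIMESTAMP);
            } else {
                out.writeByte(TAG_REC_WITH_TIMESTAMP);
                out.writeLong(timestamp);
            }
            out.write(value);
        } catch (IOException e) {
            throw new UncheckedIOException(e); // not expected for in-memory streams
        }
        return bytes.toByteArray();
    }

    // Layout: tag byte, int producer subtask index, long watermark timestamp.
    static byte[] serializeWatermark(long timestamp, int subtask) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeByte(TAG_WATERMARK);
            out.writeInt(subtask);
            out.writeLong(timestamp);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    static Element deserialize(byte[] message) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(message))) {
            int tag = in.readByte();
            if (tag == TAG_WATERMARK) {
                int subtask = in.readInt();
                long watermarkTimestamp = in.readLong();
                return new Element(tag, watermarkTimestamp, subtask, null);
            }
            Long timestamp = null;
            if (tag == TAG_REC_WITH_TIMESTAMP) {
                timestamp = in.readLong();
            } else if (tag != TAG_REC_WITHOUT_TIMESTAMP) {
                throw new IllegalArgumentException("Unknown tag: " + tag);
            }
            // Everything after the header is the serialized value.
            byte[] payload = new byte[in.available()];
            in.readFully(payload);
            return new Element(tag, timestamp, -1, payload);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

   Note that in this diff the fetcher lives in `org.apache.flink.streaming.connectors.kafka.internal` while the producer lives in `org.apache.flink.streaming.connectors.kafka.shuffle`, so a package-private codec would only work if both users end up in the same package.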

##########
File path: flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/KafkaShuffleFetcher.java
##########
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.internal;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaCommitCallback;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.SerializedValue;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.common.TopicPartition;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+
+import java.io.Serializable;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Properties;
+
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITHOUT_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITH_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_WATERMARK;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Fetch data from Kafka for Kafka Shuffle.
+ */
+@Internal
+public class KafkaShuffleFetcher<T> extends AbstractFetcher<T, TopicPartition> {
+       private static final Logger LOG = LoggerFactory.getLogger(KafkaShuffleFetcher.class);
+
+       private final WatermarkHandler watermarkHandler;
+       // ------------------------------------------------------------------------
+
+       /** The schema to convert between Kafka's byte messages, and Flink's objects. */
+       private final KafkaShuffleElementDeserializer<T> deserializer;
+
+       /** Serializer to serialize record. */
+       private final TypeSerializer<T> serializer;
+
+       /** The handover of data and exceptions between the consumer thread and the task thread. */
+       private final Handover handover;
+
+       /** The thread that runs the actual KafkaConsumer and hands the record batches to this fetcher. */
+       private final KafkaConsumerThread consumerThread;
+
+       /** Flag to mark the main work loop as alive. */
+       private volatile boolean running = true;
+
+       public KafkaShuffleFetcher(
+               SourceFunction.SourceContext<T> sourceContext,
+               Map<KafkaTopicPartition, Long> assignedPartitionsWithInitialOffsets,
+               SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic,
+               SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated,
+               ProcessingTimeService processingTimeProvider,
+               long autoWatermarkInterval,
+               ClassLoader userCodeClassLoader,
+               String taskNameWithSubtasks,
+               TypeSerializer<T> serializer,
+               Properties kafkaProperties,
+               long pollTimeout,
+               MetricGroup subtaskMetricGroup,
+               MetricGroup consumerMetricGroup,
+               boolean useMetrics,
+               int producerParallelism) throws Exception {
+               super(
+                       sourceContext,
+                       assignedPartitionsWithInitialOffsets,
+                       watermarksPeriodic,
+                       watermarksPunctuated,
+                       processingTimeProvider,
+                       autoWatermarkInterval,
+                       userCodeClassLoader,
+                       consumerMetricGroup,
+                       useMetrics);
+
+               this.deserializer = new KafkaShuffleElementDeserializer<>();
+               this.serializer = serializer;
+               this.handover = new Handover();
+               this.consumerThread = new KafkaConsumerThread(
+                       LOG,
+                       handover,
+                       kafkaProperties,
+                       unassignedPartitionsQueue,
+                       getFetcherName() + " for " + taskNameWithSubtasks,
+                       pollTimeout,
+                       useMetrics,
+                       consumerMetricGroup,
+                       subtaskMetricGroup);
+               this.watermarkHandler = new WatermarkHandler(producerParallelism);
+       }
+
+       // ------------------------------------------------------------------------
+       //  Fetcher work methods
+       // ------------------------------------------------------------------------
+
+       @Override
+       public void runFetchLoop() throws Exception {
+               try {
+                       final Handover handover = this.handover;
+
+                       // kick off the actual Kafka consumer
+                       consumerThread.start();
+
+                       while (running) {
+                               // this blocks until we get the next records
+                               // it automatically re-throws exceptions encountered in the consumer thread
+                               final ConsumerRecords<byte[], byte[]> records = handover.pollNext();
+
+                               // get the records for each topic partition
+                               for (KafkaTopicPartitionState<TopicPartition> partition : subscribedPartitionStates()) {
+                                       List<ConsumerRecord<byte[], byte[]>> partitionRecords =
+                                               records.records(partition.getKafkaPartitionHandle());
+
+                                       for (ConsumerRecord<byte[], byte[]> record : partitionRecords) {
+                                               final KafkaShuffleElement<T> element = deserializer.deserialize(serializer, record);
+
+                                               // TODO: do we need to check the end of stream if reaching the end watermark?

Review comment:
       I'd assume so, or else bounded inputs won't work well.
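   
   A hedged sketch of what such an end-of-stream check could look like. It assumes that every producer subtask eventually broadcasts a final watermark with timestamp `Long.MAX_VALUE` (the timestamp of Flink's `Watermark.MAX_WATERMARK`) and that each consumer subtask reads exactly one partition; `ShuffleWatermarkTracker` and its methods are hypothetical and stand in for the PR's `WatermarkHandler`, whose implementation is not part of this diff.

```java
import java.util.Arrays;
import java.util.Optional;

// Hypothetical consumer-side tracker: each producer subtask broadcasts its
// watermark, so the consumer may only advance to the minimum across all of them.
final class ShuffleWatermarkTracker {
    // Same timestamp as Flink's Watermark.MAX_WATERMARK.
    static final long END_OF_STREAM = Long.MAX_VALUE;

    private final long[] subtaskWatermarks;
    private long currentWatermark = Long.MIN_VALUE;

    ShuffleWatermarkTracker(int producerParallelism) {
        subtaskWatermarks = new long[producerParallelism];
        Arrays.fill(subtaskWatermarks, Long.MIN_VALUE);
    }

    // Records a watermark from one producer subtask and returns the combined
    // watermark only if it advanced.
    Optional<Long> onWatermark(int subtask, long timestamp) {
        subtaskWatermarks[subtask] = Math.max(subtaskWatermarks[subtask], timestamp);
        long min = Arrays.stream(subtaskWatermarks).min().getAsLong();
        if (min > currentWatermark) {
            currentWatermark = min;
            return Optional.of(min);
        }
        return Optional.empty();
    }

    // True once every producer subtask has sent the final watermark.
    boolean reachedEndOfStream() {
        return currentWatermark == END_OF_STREAM;
    }
}
```

   In `runFetchLoop`, a check like `reachedEndOfStream()` after emitting a watermark could set `running = false`, so the loop terminates on bounded input instead of polling indefinitely.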

##########
File path: flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/KafkaShuffleFetcher.java
##########
@@ -0,0 +1,344 @@
+       // ------------------------------------------------------------------------
+       //  Fetcher work methods
+       // ------------------------------------------------------------------------
+
+       @Override
+       public void runFetchLoop() throws Exception {
+               try {
+                       final Handover handover = this.handover;
+
+                       // kick off the actual Kafka consumer
+                       consumerThread.start();
+
+                       while (running) {
+                               // this blocks until we get the next records
+                               // it automatically re-throws exceptions encountered in the consumer thread
+                               final ConsumerRecords<byte[], byte[]> records = handover.pollNext();
+
+                               // get the records for each topic partition
+                               for (KafkaTopicPartitionState<TopicPartition> partition : subscribedPartitionStates()) {
+                                       List<ConsumerRecord<byte[], byte[]>> partitionRecords =
+                                               records.records(partition.getKafkaPartitionHandle());
+
+                                       for (ConsumerRecord<byte[], byte[]> record : partitionRecords) {
+                                               final KafkaShuffleElement<T> element = deserializer.deserialize(serializer, record);
+
+                                               // TODO: do we need to check the end of stream if reaching the end watermark?
+
+                                               if (element.isRecord()) {
+                                                       // timestamp is inherited from upstream
+                                                       // If using ProcessingTime, the timestamp is ignored (upstream does not include a timestamp either)
+                                                       // If using IngestionTime, the timestamp is overwritten
+                                                       // If using EventTime, the timestamp is used
+                                                       synchronized (checkpointLock) {
+                                                               KafkaShuffleRecord<T> elementAsRecord = element.asRecord();
+                                                               sourceContext.collectWithTimestamp(
+                                                                       elementAsRecord.value,
+                                                                       elementAsRecord.timestamp == null ? record.timestamp() : elementAsRecord.timestamp);
+                                                               partition.setOffset(record.offset());
+                                                       }
+                                               } else if (element.isWatermark()) {
+                                                       final KafkaShuffleWatermark watermark = element.asWatermark();
+                                                       Optional<Watermark> newWatermark = watermarkHandler.checkAndGetNewWatermark(watermark);
+                                                       newWatermark.ifPresent(sourceContext::emitWatermark);
+                                               }
+                                       }
+                               }
+                       }
+               }
+               finally {
+                       // this signals the consumer thread that no more work is to be done
+                       consumerThread.shutdown();
+               }
+
+               // on a clean exit, wait for the runner thread
+               try {
+                       consumerThread.join();
+               }
+               catch (InterruptedException e) {
+                       // may be the result of a wake-up interruption after an exception.
+                       // we ignore this here and only restore the interruption state
+                       Thread.currentThread().interrupt();
+               }
+       }
+
+       @Override
+       public void cancel() {
+               // flag the main thread to exit. A thread interrupt will come anyways.
+               running = false;
+               handover.close();
+               consumerThread.shutdown();
+       }
+
+       @Override
+       protected TopicPartition createKafkaPartitionHandle(KafkaTopicPartition partition) {
+               return new TopicPartition(partition.getTopic(), partition.getPartition());
+       }
+
+       @Override
+       protected void doCommitInternalOffsetsToKafka(
+               Map<KafkaTopicPartition, Long> offsets,
+               @Nonnull KafkaCommitCallback commitCallback) throws Exception {
+
+               @SuppressWarnings("unchecked")
+               List<KafkaTopicPartitionState<TopicPartition>> partitions = subscribedPartitionStates();
+
+               Map<TopicPartition, OffsetAndMetadata> offsetsToCommit = new HashMap<>(partitions.size());
+
+               for (KafkaTopicPartitionState<TopicPartition> partition : partitions) {
+                       Long lastProcessedOffset = offsets.get(partition.getKafkaTopicPartition());
+                       if (lastProcessedOffset != null) {
+                               checkState(lastProcessedOffset >= 0, "Illegal offset value to commit");
+
+                               // committed offsets through the KafkaConsumer need to be 1 more than the last processed offset.
+                               // This does not affect Flink's checkpoints/saved state.
+                               long offsetToCommit = lastProcessedOffset + 1;
+
+                               offsetsToCommit.put(partition.getKafkaPartitionHandle(), new OffsetAndMetadata(offsetToCommit));
+                               partition.setCommittedOffset(offsetToCommit);
+                       }
+               }
+
+               // record the work to be committed by the main consumer thread and make sure the consumer notices that
+               consumerThread.setOffsetsToCommit(offsetsToCommit, commitCallback);
+       }
+
+       private String getFetcherName() {
+               return "Kafka Shuffle Fetcher";
+       }
+
+       private abstract static class KafkaShuffleElement<T> {
+
+               public boolean isRecord() {
+                       return getClass() == KafkaShuffleRecord.class;
+               }
+
+               boolean isWatermark() {
+                       return getClass() == KafkaShuffleWatermark.class;
+               }
+
+               KafkaShuffleRecord<T> asRecord() {
+                       return (KafkaShuffleRecord<T>) this;
+               }
+
+               KafkaShuffleWatermark asWatermark() {
+                       return (KafkaShuffleWatermark) this;
+               }
+       }
+
+       private static class KafkaShuffleWatermark<T> extends KafkaShuffleElement<T> {
+               final int subtask;
+               final long watermark;
+
+               KafkaShuffleWatermark(int subtask, long watermark) {
+                       this.subtask = subtask;
+                       this.watermark = watermark;
+               }
+       }
+
+       private static class KafkaShuffleRecord<T> extends KafkaShuffleElement<T> {
+               final T value;
+               final Long timestamp;
+
+               KafkaShuffleRecord(T value) {
+                       this.value = value;
+                       this.timestamp = null;
+               }
+
+               KafkaShuffleRecord(long timestamp, T value) {
+                       this.value = value;
+                       this.timestamp = timestamp;
+               }
+       }
+

Review comment:
       `KafkaShuffleElement` seems over-engineered. I think a simple holder for timestamp + object would be enough, combined with plain `instanceof` checks.
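A rough sketch of what the suggested simplification could look like, assuming the fetcher only needs to tell records apart from watermarks. The names `ShuffleElements`, `TimestampedRecord`, `SubtaskWatermark` and `describe` are hypothetical and not part of the PR; the fallback to the Kafka record timestamp mirrors what `runFetchLoop` already does with `record.timestamp()`.

```java
/** Hypothetical, simplified element model for the Kafka shuffle (not the PR's classes). */
final class ShuffleElements {

    /** Plain holder: a deserialized value plus an optional timestamp. */
    static final class TimestampedRecord<T> {
        final T value;
        final Long timestamp; // null when the producer did not serialize a timestamp

        TimestampedRecord(T value, Long timestamp) {
            this.value = value;
            this.timestamp = timestamp;
        }
    }

    /** A watermark emitted by one producer subtask. */
    static final class SubtaskWatermark {
        final int subtask;
        final long watermark;

        SubtaskWatermark(int subtask, long watermark) {
            this.subtask = subtask;
            this.watermark = watermark;
        }
    }

    /** Dispatches with instanceof instead of isRecord()/asRecord() helper methods. */
    static String describe(Object element, long kafkaTimestamp) {
        if (element instanceof TimestampedRecord) {
            TimestampedRecord<?> rec = (TimestampedRecord<?>) element;
            // fall back to the Kafka record timestamp when none was serialized
            long ts = rec.timestamp == null ? kafkaTimestamp : rec.timestamp;
            return "record " + rec.value + "@" + ts;
        } else if (element instanceof SubtaskWatermark) {
            SubtaskWatermark wm = (SubtaskWatermark) element;
            return "watermark " + wm.watermark + " from subtask " + wm.subtask;
        }
        throw new IllegalArgumentException("Unknown element type: " + element);
    }
}
```

With this shape, the abstract base class and the cast helpers disappear; the trade-off is that an unknown type is only detected at runtime in the final `else` branch.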

##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/KafkaShuffleFetcher.java
##########
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.internal;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaCommitCallback;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.SerializedValue;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.common.TopicPartition;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+
+import java.io.Serializable;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Properties;
+
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITHOUT_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITH_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_WATERMARK;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Fetch data from Kafka for Kafka Shuffle.
+ */
+@Internal
+public class KafkaShuffleFetcher<T> extends AbstractFetcher<T, TopicPartition> {
+       private static final Logger LOG = LoggerFactory.getLogger(KafkaShuffleFetcher.class);
+
+       private final WatermarkHandler watermarkHandler;
+       // ------------------------------------------------------------------------
+
+       /** The schema to convert between Kafka's byte messages, and Flink's objects. */
+       private final KafkaShuffleElementDeserializer<T> deserializer;
+
+       /** Serializer to serialize record. */
+       private final TypeSerializer<T> serializer;
+
+       /** The handover of data and exceptions between the consumer thread and the task thread. */
+       private final Handover handover;
+
+       /** The thread that runs the actual KafkaConsumer and hand the record batches to this fetcher. */
+       private final KafkaConsumerThread consumerThread;
+
+       /** Flag to mark the main work loop as alive. */
+       private volatile boolean running = true;
+
+       public KafkaShuffleFetcher(
+               SourceFunction.SourceContext<T> sourceContext,

Review comment:
       nit: indent.

##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffleProducer.java
##########
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.serialization.TypeInformationSerializationSchema;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaException;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.PropertiesUtil;
+
+import org.apache.kafka.clients.producer.ProducerRecord;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Properties;
+
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffle.PARTITION_NUMBER;
+
+/**
+ * Flink Kafka Shuffle Producer Function.
+ * It is different from {@link FlinkKafkaProducer} in the way handling elements and watermarks
+ */
+@Internal
+public class FlinkKafkaShuffleProducer<IN, KEY> extends FlinkKafkaProducer<IN> {
+       private final KafkaSerializer<IN> kafkaSerializer;
+       private final KeySelector<IN, KEY> keySelector;
+       private final int numberOfPartitions;
+
+       FlinkKafkaShuffleProducer(
+               String defaultTopicId,
+               TypeInformationSerializationSchema<IN> schema,
+               Properties props,
+               KeySelector<IN, KEY> keySelector,
+               Semantic semantic,
+               int kafkaProducersPoolSize) {
+               super(defaultTopicId, (element, timestamp) -> null, props, semantic, kafkaProducersPoolSize);
+
+               this.kafkaSerializer = new KafkaSerializer<>(schema.getSerializer());
+               this.keySelector = keySelector;
+
+               Preconditions.checkArgument(
+                       props.getProperty(PARTITION_NUMBER) != null,
+                       "Missing partition number for Kafka Shuffle");
+               numberOfPartitions = PropertiesUtil.getInt(props, PARTITION_NUMBER, Integer.MIN_VALUE);
+       }
+
+       /**
+        * This is the function invoked to handle each element.
+        * @param transaction transaction state;
+        *                    elements are written to Kafka in transactions to guarantee different level of data consistency
+        * @param next element to handle
+        * @param context context needed to handle the element
+        * @throws FlinkKafkaException for kafka error
+        */
+       @Override
+       public void invoke(KafkaTransactionState transaction, IN next, Context context) throws FlinkKafkaException {
+               checkErroneous();
+
+               // write timestamp to Kafka if timestamp is available
+               Long timestamp = context.timestamp();
+
+               int[] partitions = getPartitions(transaction);
+               int partitionIndex;
+               try {
+                       partitionIndex = KeyGroupRangeAssignment
+                               .assignKeyToParallelOperator(keySelector.getKey(next), partitions.length, partitions.length);
+               } catch (Exception e) {
+                       throw new RuntimeException("Fail to assign a partition number to record");
+               }
+
+               ProducerRecord<byte[], byte[]> record = new ProducerRecord<>(

Review comment:
       Can you explain once again why we store the timestamp directly in the `ProducerRecord` and also serialize it? That seems redundant.
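To put a number on the redundancy, here is a hypothetical tagged encoding in which the timestamp is written into the payload only when present. The tag values, the `TaggedEncoding` class and the `String` payload are assumptions for illustration; the PR's `KafkaSerializer` defines its own `TAG_*` constants. Two Kafka facts may argue for keeping the payload copy: when `ProducerRecord`'s timestamp is `null` the producer fills in the current time, so "no timestamp" cannot be recovered from Kafka alone, and a topic configured with `message.timestamp.type=LogAppendTime` replaces the producer timestamp with the broker's append time.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

/** Hypothetical tagged wire format; tag values are assumptions, not the PR's constants. */
final class TaggedEncoding {
    static final byte TAG_REC_WITH_TIMESTAMP = 0;
    static final byte TAG_REC_WITHOUT_TIMESTAMP = 1;

    /** Writes the tag, then the timestamp only if present, then the value. */
    static byte[] encode(String value, Long timestamp) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            DataOutputStream out = new DataOutputStream(bytes);
            if (timestamp == null) {
                out.writeByte(TAG_REC_WITHOUT_TIMESTAMP);
            } else {
                out.writeByte(TAG_REC_WITH_TIMESTAMP);
                out.writeLong(timestamp); // 8 extra bytes per record
            }
            out.writeUTF(value); // 2-byte length prefix + modified UTF-8 bytes
            out.flush();
            return bytes.toByteArray();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Returns the serialized timestamp, or kafkaTimestamp if none was written. */
    static long decodeTimestamp(byte[] payload, long kafkaTimestamp) {
        try {
            DataInputStream in = new DataInputStream(new ByteArrayInputStream(payload));
            byte tag = in.readByte();
            if (tag == TAG_REC_WITH_TIMESTAMP) {
                return in.readLong();
            }
            if (tag == TAG_REC_WITHOUT_TIMESTAMP) {
                return kafkaTimestamp; // same fallback the fetcher uses
            }
            throw new IllegalArgumentException("Unknown tag " + tag);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

For a one-character value the payload grows from 4 to 12 bytes, so the duplicated timestamp costs 8 bytes per record on top of whatever Kafka already stores in the record header.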

##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java
##########
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.operators.Keys;
+import org.apache.flink.api.common.serialization.TypeInformationSerializationSchema;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.DataStreamUtils;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.transformations.SinkTransformation;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.util.keys.KeySelectorUtil;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.PropertiesUtil;
+
+import java.util.Properties;
+
+/**
+ * Use Kafka as a persistent shuffle by wrapping a Kafka Source/Sink pair together.
+ */
+class FlinkKafkaShuffle {
+       static final String PRODUCER_PARALLELISM = "producer parallelism";
+       static final String PARTITION_NUMBER = "partition number";
+
+       /**
+        * Write to and read from a kafka shuffle with the partition decided by keys.
+        * Consumers should read partitions equal to the key group indices they are assigned.
+        * The number of partitions is the maximum parallelism of the receiving operator.
+        * This version only supports numberOfPartitions = consumerParallelism.
+        *
+        * @param inputStream input stream to the kafka
+        * @param topic kafka topic
+        * @param producerParallelism parallelism of producer
+        * @param numberOfPartitions number of partitions
+        * @param properties Kafka properties
+        * @param fields key positions from inputStream
+        * @param <T> input type
+        */
+       static <T> KeyedStream<T, Tuple> persistentKeyBy(

Review comment:
       `public`
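The Javadoc above ties Kafka partitions to key groups. Flink assigns a key to key group `hash(key.hashCode()) % maxParallelism` and a key group to operator index `keyGroup * parallelism / maxParallelism`; since the producer passes `partitions.length` as both arguments, the partition index then equals the key group. The sketch below illustrates that identity; `KeyGroupSketch` and its `mix` function are hypothetical, and `mix` is a simple stand-in, not Flink's murmur-based hash.

```java
/** Hypothetical illustration of key-group-based partition assignment. */
final class KeyGroupSketch {

    /** Stand-in hash mix (NOT Flink's MathUtils.murmurHash); returns a non-negative int. */
    static int mix(int hash) {
        int h = hash * 0x9E3779B1; // golden-ratio multiplier, chosen for illustration
        h ^= (h >>> 16);
        return h & 0x7FFFFFFF; // clear the sign bit
    }

    /** Key group in [0, maxParallelism). */
    static int keyGroup(Object key, int maxParallelism) {
        return mix(key.hashCode()) % maxParallelism;
    }

    /** Same formula Flink uses to map a key group to an operator index. */
    static int operatorIndex(int keyGroup, int maxParallelism, int parallelism) {
        return keyGroup * parallelism / maxParallelism;
    }
}
```

With `parallelism == maxParallelism` the division cancels and `operatorIndex(kg, p, p) == kg`; with, say, 128 key groups and parallelism 4, key groups 0 to 31 all land on partition 0, which is why this version requires `numberOfPartitions` to equal the consumer parallelism.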

##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/KafkaShuffleFetcher.java
##########
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.internal;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaCommitCallback;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.SerializedValue;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.common.TopicPartition;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+
+import java.io.Serializable;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Properties;
+
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITHOUT_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITH_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_WATERMARK;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Fetch data from Kafka for Kafka Shuffle.
+ */
+@Internal
+public class KafkaShuffleFetcher<T> extends AbstractFetcher<T, TopicPartition> {
+       private static final Logger LOG = LoggerFactory.getLogger(KafkaShuffleFetcher.class);
+
+       private final WatermarkHandler watermarkHandler;
+       // ------------------------------------------------------------------------
+
+       /** The schema to convert between Kafka's byte messages, and Flink's objects. */
+       private final KafkaShuffleElementDeserializer<T> deserializer;
+
+       /** Serializer to serialize record. */
+       private final TypeSerializer<T> serializer;
+
+       /** The handover of data and exceptions between the consumer thread and the task thread. */
+       private final Handover handover;
+
+       /** The thread that runs the actual KafkaConsumer and hand the record batches to this fetcher. */
+       private final KafkaConsumerThread consumerThread;
+
+       /** Flag to mark the main work loop as alive. */
+       private volatile boolean running = true;
+
+       public KafkaShuffleFetcher(
+               SourceFunction.SourceContext<T> sourceContext,
+               Map<KafkaTopicPartition, Long> assignedPartitionsWithInitialOffsets,
+               SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic,
+               SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated,
+               ProcessingTimeService processingTimeProvider,
+               long autoWatermarkInterval,
+               ClassLoader userCodeClassLoader,
+               String taskNameWithSubtasks,
+               TypeSerializer<T> serializer,
+               Properties kafkaProperties,
+               long pollTimeout,
+               MetricGroup subtaskMetricGroup,
+               MetricGroup consumerMetricGroup,
+               boolean useMetrics,
+               int producerParallelism) throws Exception {
+               super(
+                       sourceContext,
+                       assignedPartitionsWithInitialOffsets,
+                       watermarksPeriodic,
+                       watermarksPunctuated,
+                       processingTimeProvider,
+                       autoWatermarkInterval,
+                       userCodeClassLoader,
+                       consumerMetricGroup,
+                       useMetrics);
+
+               this.deserializer = new KafkaShuffleElementDeserializer<>();
+               this.serializer = serializer;
+               this.handover = new Handover();
+               this.consumerThread = new KafkaConsumerThread(
+                       LOG,
+                       handover,
+                       kafkaProperties,
+                       unassignedPartitionsQueue,
+                       getFetcherName() + " for " + taskNameWithSubtasks,
+                       pollTimeout,
+                       useMetrics,
+                       consumerMetricGroup,
+                       subtaskMetricGroup);
+               this.watermarkHandler = new WatermarkHandler(producerParallelism);
+       }
+
+       // ------------------------------------------------------------------------
+       //  Fetcher work methods
+       // ------------------------------------------------------------------------
+
+       @Override
+       public void runFetchLoop() throws Exception {
+               try {
+                       final Handover handover = this.handover;
+
+                       // kick off the actual Kafka consumer
+                       consumerThread.start();
+
+                       while (running) {
+                               // this blocks until we get the next records
+                               // it automatically re-throws exceptions encountered in the consumer thread
+                               final ConsumerRecords<byte[], byte[]> records = handover.pollNext();
+
+                               // get the records for each topic partition
+                               for (KafkaTopicPartitionState<TopicPartition> partition : subscribedPartitionStates()) {
+                                       List<ConsumerRecord<byte[], byte[]>> partitionRecords =
+                                               records.records(partition.getKafkaPartitionHandle());
+
+                                       for (ConsumerRecord<byte[], byte[]> record : partitionRecords) {
+                                               final KafkaShuffleElement<T> element = deserializer.deserialize(serializer, record);
+
+                                               // TODO: do we need to check the end of stream if reaching the end watermark?
+
+                                               if (element.isRecord()) {
+                                                       // timestamp is inherent from upstream
+                                                       // If using ProcessTime, timestamp is going to be ignored (upstream does not include timestamp as well)
+                                                       // If using IngestionTime, timestamp is going to be overwritten
+                                                       // If using EventTime, timestamp is going to be used
+                                                       synchronized (checkpointLock) {
+                                                               KafkaShuffleRecord<T> elementAsRecord = element.asRecord();
+                                                               sourceContext.collectWithTimestamp(
+                                                                       elementAsRecord.value,
+                                                                       elementAsRecord.timestamp == null ? record.timestamp() : elementAsRecord.timestamp);
+                                                               partition.setOffset(record.offset());
+                                                       }
+                                               } else if (element.isWatermark()) {
+                                                       final KafkaShuffleWatermark watermark = element.asWatermark();
+                                                       Optional<Watermark> newWatermark = watermarkHandler.checkAndGetNewWatermark(watermark);
+                                                       newWatermark.ifPresent(sourceContext::emitWatermark);

Review comment:
       Perform under checkpoint lock?
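One possible answer, sketched under assumptions: the body of `WatermarkHandler` is not part of this excerpt, so the tracker below is a hypothetical version that keeps the latest watermark per producer subtask and only reports the minimum across subtasks when it advances. The emission is wrapped in the same lock object as the record path, mirroring the `synchronized (checkpointLock)` block around `collectWithTimestamp`. `MinWatermarkTracker`, `LockedEmitter` and the `emitted` list (standing in for `sourceContext.emitWatermark`) are all illustrative names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

/** Hypothetical per-subtask watermark merging (not the PR's WatermarkHandler). */
final class MinWatermarkTracker {
    private final long[] perSubtask;
    private long lastEmitted = Long.MIN_VALUE;

    MinWatermarkTracker(int producerParallelism) {
        perSubtask = new long[producerParallelism];
        Arrays.fill(perSubtask, Long.MIN_VALUE); // no watermark seen yet
    }

    /** Stores a subtask's watermark; returns the global minimum only if it advanced. */
    Optional<Long> update(int subtask, long watermark) {
        perSubtask[subtask] = Math.max(perSubtask[subtask], watermark);
        long min = Long.MAX_VALUE;
        for (long w : perSubtask) {
            min = Math.min(min, w);
        }
        if (min > lastEmitted) {
            lastEmitted = min;
            return Optional.of(min);
        }
        return Optional.empty();
    }
}

/** Emits watermarks while holding the same lock the record path uses. */
final class LockedEmitter {
    final Object checkpointLock = new Object();
    final List<Long> emitted = new ArrayList<>(); // stand-in for sourceContext.emitWatermark
    private final MinWatermarkTracker tracker;

    LockedEmitter(int producerParallelism) {
        tracker = new MinWatermarkTracker(producerParallelism);
    }

    void onWatermark(int subtask, long watermark) {
        synchronized (checkpointLock) {
            tracker.update(subtask, watermark).ifPresent(emitted::add);
        }
    }
}
```

With two producer subtasks, the sequence (0, 10), (1, 5), (1, 20), (0, 8) emits 5 and then 10: nothing is emitted until every subtask has reported, and a late, smaller watermark from subtask 0 does not move the minimum backwards.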

##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java
##########
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.operators.Keys;
+import org.apache.flink.api.common.serialization.TypeInformationSerializationSchema;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.DataStreamUtils;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.transformations.SinkTransformation;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.util.keys.KeySelectorUtil;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.PropertiesUtil;
+
+import java.util.Properties;
+
+/**
+ * Use Kafka as a persistent shuffle by wrapping a Kafka Source/Sink pair together.
+ */
+class FlinkKafkaShuffle {

Review comment:
       If this is meant to be API, it should probably be `public` and annotated with either `@PublicEvolving` or `@Experimental`.

##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/KafkaShuffleFetcher.java
##########
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.internal;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaCommitCallback;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.SerializedValue;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.common.TopicPartition;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+
+import java.io.Serializable;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Properties;
+
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITHOUT_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITH_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_WATERMARK;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Fetch data from Kafka for Kafka Shuffle.
+ */
+@Internal
+public class KafkaShuffleFetcher<T> extends AbstractFetcher<T, TopicPartition> {
+       private static final Logger LOG = LoggerFactory.getLogger(KafkaShuffleFetcher.class);
+
+       private final WatermarkHandler watermarkHandler;
+       // ------------------------------------------------------------------------
+
+       /** The schema to convert between Kafka's byte messages, and Flink's objects. */
+       private final KafkaShuffleElementDeserializer<T> deserializer;
+
+       /** Serializer to serialize record. */
+       private final TypeSerializer<T> serializer;
+
+       /** The handover of data and exceptions between the consumer thread and the task thread. */
+       private final Handover handover;
+
+       /** The thread that runs the actual KafkaConsumer and hand the record batches to this fetcher. */
+       private final KafkaConsumerThread consumerThread;
+
+       /** Flag to mark the main work loop as alive. */
+       private volatile boolean running = true;
+
+       public KafkaShuffleFetcher(
+               SourceFunction.SourceContext<T> sourceContext,
+               Map<KafkaTopicPartition, Long> assignedPartitionsWithInitialOffsets,
+               SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic,
+               SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated,
+               ProcessingTimeService processingTimeProvider,
+               long autoWatermarkInterval,
+               ClassLoader userCodeClassLoader,
+               String taskNameWithSubtasks,
+               TypeSerializer<T> serializer,
+               Properties kafkaProperties,
+               long pollTimeout,
+               MetricGroup subtaskMetricGroup,
+               MetricGroup consumerMetricGroup,
+               boolean useMetrics,
+               int producerParallelism) throws Exception {
+               super(
+                       sourceContext,
+                       assignedPartitionsWithInitialOffsets,
+                       watermarksPeriodic,
+                       watermarksPunctuated,
+                       processingTimeProvider,
+                       autoWatermarkInterval,
+                       userCodeClassLoader,
+                       consumerMetricGroup,
+                       useMetrics);
+
+               this.deserializer = new KafkaShuffleElementDeserializer<>();
+               this.serializer = serializer;
+               this.handover = new Handover();
+               this.consumerThread = new KafkaConsumerThread(
+                       LOG,
+                       handover,
+                       kafkaProperties,
+                       unassignedPartitionsQueue,
+                       getFetcherName() + " for " + taskNameWithSubtasks,
+                       pollTimeout,
+                       useMetrics,
+                       consumerMetricGroup,
+                       subtaskMetricGroup);
+               this.watermarkHandler = new WatermarkHandler(producerParallelism);
+       }
+
+       // ------------------------------------------------------------------------
+       //  Fetcher work methods
+       // ------------------------------------------------------------------------
+
+       @Override
+       public void runFetchLoop() throws Exception {
+               try {
+                       final Handover handover = this.handover;
+
+                       // kick off the actual Kafka consumer
+                       consumerThread.start();
+
+                       while (running) {
+                               // this blocks until we get the next records
+                               // it automatically re-throws exceptions encountered in the consumer thread
+                               final ConsumerRecords<byte[], byte[]> records = handover.pollNext();
+
+                               // get the records for each topic partition
+                               for (KafkaTopicPartitionState<TopicPartition> partition : subscribedPartitionStates()) {
+                                       List<ConsumerRecord<byte[], byte[]>> partitionRecords =
+                                               records.records(partition.getKafkaPartitionHandle());
+
+                                       for (ConsumerRecord<byte[], byte[]> record : partitionRecords) {
+                                               final KafkaShuffleElement<T> element = deserializer.deserialize(serializer, record);
+
+                                               // TODO: do we need to check the end of stream if reaching the end watermark?
+
+                                               if (element.isRecord()) {
+                                                       // timestamp is inherent from upstream
+                                                       // If using ProcessTime, timestamp is going to be ignored (upstream does not include timestamp as well)
+                                                       // If using IngestionTime, timestamp is going to be overwritten
+                                                       // If using EventTime, timestamp is going to be used
+                                                       synchronized (checkpointLock) {
+                                                               KafkaShuffleRecord<T> elementAsRecord = element.asRecord();
+                                                               sourceContext.collectWithTimestamp(
+                                                                       elementAsRecord.value,
+                                                                       elementAsRecord.timestamp == null ? record.timestamp() : elementAsRecord.timestamp);
+                                                               partition.setOffset(record.offset());
+                                                       }
+                                               } else if (element.isWatermark()) {
+                                                       final KafkaShuffleWatermark watermark = element.asWatermark();
+                                                       Optional<Watermark> newWatermark = watermarkHandler.checkAndGetNewWatermark(watermark);
+                                                       newWatermark.ifPresent(sourceContext::emitWatermark);
+                                               }
+                                       }
+                               }
+                       }
+               }
+               finally {
+                       // this signals the consumer thread that no more work is to be done
+                       consumerThread.shutdown();
+               }
+
+               // on a clean exit, wait for the runner thread
+               try {
+                       consumerThread.join();
+               }
+               catch (InterruptedException e) {
+                       // may be the result of a wake-up interruption after an exception.
+                       // we ignore this here and only restore the interruption state
+                       Thread.currentThread().interrupt();
+               }
+       }
+
+       @Override
+       public void cancel() {
+               // flag the main thread to exit. A thread interrupt will come anyways.
+               running = false;
+               handover.close();
+               consumerThread.shutdown();
+       }
+
+       @Override
+       protected TopicPartition createKafkaPartitionHandle(KafkaTopicPartition partition) {
+               return new TopicPartition(partition.getTopic(), partition.getPartition());
+       }
+
+       @Override
+       protected void doCommitInternalOffsetsToKafka(
+               Map<KafkaTopicPartition, Long> offsets,
+               @Nonnull KafkaCommitCallback commitCallback) throws Exception {
+
+               @SuppressWarnings("unchecked")
+               List<KafkaTopicPartitionState<TopicPartition>> partitions = subscribedPartitionStates();
+
+               Map<TopicPartition, OffsetAndMetadata> offsetsToCommit = new HashMap<>(partitions.size());
+
+               for (KafkaTopicPartitionState<TopicPartition> partition : partitions) {
+                       Long lastProcessedOffset = offsets.get(partition.getKafkaTopicPartition());
+                       if (lastProcessedOffset != null) {
+                               checkState(lastProcessedOffset >= 0, "Illegal offset value to commit");
+
+                               // committed offsets through the KafkaConsumer need to be 1 more than the last processed offset.
+                               // This does not affect Flink's checkpoints/saved state.
+                               long offsetToCommit = lastProcessedOffset + 1;
+
+                               offsetsToCommit.put(partition.getKafkaPartitionHandle(), new OffsetAndMetadata(offsetToCommit));
+                               partition.setCommittedOffset(offsetToCommit);
+                       }
+               }
+
+               // record the work to be committed by the main consumer thread and make sure the consumer notices that
+               consumerThread.setOffsetsToCommit(offsetsToCommit, commitCallback);
+       }
+
+       private String getFetcherName() {
+               return "Kafka Shuffle Fetcher";
+       }
+
+       private abstract static class KafkaShuffleElement<T> {
+
+               public boolean isRecord() {
+                       return getClass() == KafkaShuffleRecord.class;
+               }
+
+               boolean isWatermark() {
+                       return getClass() == KafkaShuffleWatermark.class;
+               }
+
+               KafkaShuffleRecord<T> asRecord() {
+                       return (KafkaShuffleRecord<T>) this;
+               }
+
+               KafkaShuffleWatermark asWatermark() {
+                       return (KafkaShuffleWatermark) this;
+               }
+       }
+
+       private static class KafkaShuffleWatermark<T> extends KafkaShuffleElement<T> {
+               final int subtask;
+               final long watermark;
+
+               KafkaShuffleWatermark(int subtask, long watermark) {
+                       this.subtask = subtask;
+                       this.watermark = watermark;
+               }
+       }
+
+       private static class KafkaShuffleRecord<T> extends KafkaShuffleElement<T> {
+               final T value;
+               final Long timestamp;
+
+               KafkaShuffleRecord(T value) {
+                       this.value = value;
+                       this.timestamp = null;
+               }
+
+               KafkaShuffleRecord(long timestamp, T value) {
+                       this.value = value;
+                       this.timestamp = timestamp;
+               }
+       }
+
+       private static class KafkaShuffleElementDeserializer<T> implements Serializable {
+               private transient DataInputDeserializer dis;
+
+               KafkaShuffleElementDeserializer() {
+                       this.dis = new DataInputDeserializer();
+               }
+
+               KafkaShuffleElement<T> deserialize(TypeSerializer<T> serializer, ConsumerRecord<byte[], byte[]> record)
+                       throws Exception {
+                       byte[] value = record.value();
+                       dis.setBuffer(value);
+                       int tag = IntSerializer.INSTANCE.deserialize(dis);
+
+                       if (tag == TAG_REC_WITHOUT_TIMESTAMP) {
+                               return new KafkaShuffleRecord<>(serializer.deserialize(dis));
+                       } else if (tag == TAG_REC_WITH_TIMESTAMP) {
+                               return new KafkaShuffleRecord<>(LongSerializer.INSTANCE.deserialize(dis), serializer.deserialize(dis));

Review comment:
       Again, why do we serialize the timestamp in the payload instead of taking it from `ConsumerRecord`?
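For context on this question: `KafkaShuffleElementDeserializer` reads a 4-byte tag, then an 8-byte timestamp only when the tag is `TAG_REC_WITH_TIMESTAMP`, then the value; the fetcher falls back to `record.timestamp()` when no timestamp was embedded. The JDK-only sketch below models that framing so the per-record cost is visible. Assumptions: the tag values are hypothetical (the real constants live in `FlinkKafkaShuffleProducer.KafkaSerializer` and are not shown in this diff), `DataOutputStream` stands in for Flink's `IntSerializer`/`LongSerializer` (both big-endian), and `writeUTF` stands in for `TypeSerializer<T>`.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

/** Models the payload layout [int tag][long timestamp, optional][value]. */
public class ShuffleFramingSketch {
    // Hypothetical tag values: the real constants are defined in
    // FlinkKafkaShuffleProducer.KafkaSerializer and are not shown in this diff.
    static final int TAG_REC_WITH_TIMESTAMP = 0;
    static final int TAG_REC_WITHOUT_TIMESTAMP = 1;

    /** Encodes one record; DataOutputStream is big-endian, like Flink's Int/LongSerializer. */
    static byte[] encode(String value, Long timestamp) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            if (timestamp == null) {
                out.writeInt(TAG_REC_WITHOUT_TIMESTAMP);
            } else {
                out.writeInt(TAG_REC_WITH_TIMESTAMP);
                out.writeLong(timestamp);
            }
            out.writeUTF(value); // stand-in for TypeSerializer<T>#serialize
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    /** Returns the embedded timestamp, or the fallback (the role record.timestamp() plays in the fetcher). */
    static long timestampOrFallback(byte[] payload, long fallback) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(payload))) {
            int tag = in.readInt();
            if (tag == TAG_REC_WITH_TIMESTAMP) {
                return in.readLong();
            }
            if (tag == TAG_REC_WITHOUT_TIMESTAMP) {
                return fallback;
            }
            throw new IllegalArgumentException("Unknown tag: " + tag);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        byte[] withTimestamp = encode("a", 42L);
        byte[] withoutTimestamp = encode("a", null);
        System.out.println(withTimestamp.length - withoutTimestamp.length); // 8
        System.out.println(timestampOrFallback(withTimestamp, -1L));        // 42
        System.out.println(timestampOrFallback(withoutTimestamp, 7L));      // 7
    }
}
```

In this model, taking the timestamp from `ConsumerRecord` would save the 8-byte field on every record. One caveat worth checking before switching: on a topic configured with `message.timestamp.type=LogAppendTime`, the broker overwrites the producer-supplied timestamp, which this sketch does not model.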

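The fetcher above hands each `KafkaShuffleWatermark` (which carries the producer `subtask` and its `watermark`) to `watermarkHandler.checkAndGetNewWatermark`, but `WatermarkHandler` itself is not part of this hunk. A plausible approach, sketched here purely as an assumption, is the rule Flink applies to operators with several inputs: keep the latest watermark per producer subtask and emit their minimum only when it advances. `WatermarkMergerSketch` and `onWatermark` are hypothetical names, not the PR's code.

```java
import java.util.Arrays;
import java.util.OptionalLong;

/** Hypothetical per-producer watermark merge: emit min over subtasks once it advances. */
public class WatermarkMergerSketch {
    private final long[] perSubtask;
    private long lastEmitted = Long.MIN_VALUE;

    public WatermarkMergerSketch(int producerParallelism) {
        perSubtask = new long[producerParallelism];
        Arrays.fill(perSubtask, Long.MIN_VALUE); // MIN_VALUE means "not reported yet"
    }

    /** Records a watermark from one producer subtask; returns a combined watermark if it advanced. */
    public OptionalLong onWatermark(int subtask, long watermark) {
        // A single subtask's watermark should not go backwards; ignore regressions.
        perSubtask[subtask] = Math.max(perSubtask[subtask], watermark);
        long min = Long.MAX_VALUE;
        for (long w : perSubtask) {
            min = Math.min(min, w);
        }
        if (min > lastEmitted) {
            lastEmitted = min;
            return OptionalLong.of(min);
        }
        return OptionalLong.empty();
    }

    public static void main(String[] args) {
        WatermarkMergerSketch merger = new WatermarkMergerSketch(2);
        System.out.println(merger.onWatermark(0, 10)); // OptionalLong.empty (subtask 1 not reported)
        System.out.println(merger.onWatermark(1, 5));  // OptionalLong[5]
        System.out.println(merger.onWatermark(1, 20)); // OptionalLong[10]
        System.out.println(merger.onWatermark(0, 8));  // OptionalLong.empty (regression ignored)
    }
}
```

Taking the minimum means the consumer never emits a watermark that one of the `producerParallelism` producers could still violate.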
##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java
##########
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.operators.Keys;
+import org.apache.flink.api.common.serialization.TypeInformationSerializationSchema;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.DataStreamUtils;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.transformations.SinkTransformation;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.util.keys.KeySelectorUtil;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.PropertiesUtil;
+
+import java.util.Properties;
+
+/**
+ * Use Kafka as a persistent shuffle by wrapping a Kafka Source/Sink pair together.
+ */
+class FlinkKafkaShuffle {
+       static final String PRODUCER_PARALLELISM = "producer parallelism";
+       static final String PARTITION_NUMBER = "partition number";
+
+       /**
+        * Write to and read from a kafka shuffle with the partition decided by keys.
+        * Consumers should read partitions equal to the key group indices they are assigned.
+        * The number of partitions is the maximum parallelism of the receiving operator.
+        * This version only supports numberOfPartitions = consumerParallelism.
+        *
+        * @param inputStream input stream to the kafka
+        * @param topic kafka topic
+        * @param producerParallelism parallelism of producer
+        * @param numberOfPartitions number of partitions

Review comment:
       Shouldn't that be the same?
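The Javadoc above ties Kafka partitions to key groups: the partition count is the receiving operator's maximum parallelism, and this version requires numberOfPartitions = consumerParallelism. The JDK-only sketch below shows why that constraint makes partition index, key group, and consumer subtask coincide. It re-implements, as an assumption, the two formulas used by Flink's `KeyGroupRangeAssignment` (key group = murmur hash of `hashCode()` modulo max parallelism; subtask = keyGroup * parallelism / maxParallelism); where exact compatibility matters, call `KeyGroupRangeAssignment.assignToKeyGroup` and `computeOperatorIndexForKeyGroup` directly instead.

```java
/** Illustrative re-implementation of key-group assignment; not Flink's own class. */
public class KeyGroupPartitionSketch {

    /** 32-bit murmur-style scramble, intended to mirror Flink's MathUtils.murmurHash. */
    static int murmurHash(int code) {
        code *= 0xcc9e2d51;
        code = Integer.rotateLeft(code, 15);
        code *= 0x1b873593;
        code = Integer.rotateLeft(code, 13);
        code = code * 5 + 0xe6546b64;
        code ^= 4;
        // final avalanche mix
        code ^= code >>> 16;
        code *= 0x85ebca6b;
        code ^= code >>> 13;
        code *= 0xc2b2ae35;
        code ^= code >>> 16;
        if (code >= 0) {
            return code;
        }
        return code != Integer.MIN_VALUE ? -code : 0;
    }

    /** Key group (here: Kafka partition) of a key, in [0, maxParallelism). */
    static int keyGroupFor(Object key, int maxParallelism) {
        return murmurHash(key.hashCode()) % maxParallelism;
    }

    /** Consumer subtask that owns a key group. */
    static int subtaskFor(int keyGroup, int maxParallelism, int parallelism) {
        return keyGroup * parallelism / maxParallelism;
    }

    public static void main(String[] args) {
        int numberOfPartitions = 3; // = max parallelism = consumer parallelism in this version
        for (int key = 0; key < 5; key++) {
            int partition = keyGroupFor(key, numberOfPartitions);
            int subtask = subtaskFor(partition, numberOfPartitions, numberOfPartitions);
            System.out.println("key " + key + " -> partition " + partition + " -> subtask " + subtask);
        }
    }
}
```

With maxParallelism == parallelism, `subtaskFor` reduces to keyGroup * p / p = keyGroup, so consumer subtask i reads exactly partition i. With, say, 128 key groups and 3 consumers, each subtask would instead own a contiguous range of partitions.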

##########
File path: 
flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
##########
@@ -264,7 +264,7 @@ public FlinkKafkaConsumerBase(
         * @param properties - Kafka configuration properties to be adjusted
         * @param offsetCommitMode offset commit mode
         */
-       static void adjustAutoCommitConfig(Properties properties, OffsetCommitMode offsetCommitMode) {
+       public static void adjustAutoCommitConfig(Properties properties, OffsetCommitMode offsetCommitMode) {

Review comment:
       protected?

##########
File path: 
flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleITCase.java
##########
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.connectors.kafka.KafkaConsumerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaProducerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaTestEnvironmentImpl;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionAssigner;
+import org.apache.flink.test.util.SuccessException;
+import org.apache.flink.util.Collector;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.util.Properties;
+
+import static org.apache.flink.streaming.api.TimeCharacteristic.EventTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.IngestionTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.ProcessingTime;
+import static org.apache.flink.test.util.TestUtils.tryExecute;
+
+/**
+ * Simple End to End Test for Kafka.
+ */
+public class KafkaShuffleITCase extends KafkaConsumerTestBase {
+
+       @BeforeClass
+       public static void prepare() throws Exception {
+               KafkaProducerTestBase.prepare();
+               ((KafkaTestEnvironmentImpl) kafkaServer).setProducerSemantic(FlinkKafkaProducer.Semantic.AT_LEAST_ONCE);
+       }
+
+       /**
+        * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+        * To test no data is lost or duplicated end-2-end with the default time characteristic: ProcessingTime
+        */
+       @Test(timeout = 30000L)
+       public void testSimpleProcessingTime() throws Exception {
+               simpleEndToEndTest("test_simple_processing_time", 100000, ProcessingTime);
+       }
+
+       /**
+        * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+        * To test no data is lost or duplicated end-2-end with time characteristic: IngestionTime
+        */
+       @Test(timeout = 30000L)
+       public void testSimpleIngestionTime() throws Exception {
+               simpleEndToEndTest("test_simple_ingestion_time", 100000, IngestionTime);
+       }
+
+       /**
+        * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+        * To test no data is lost or duplicated end-2-end with time characteristic: EventTime
+        */
+       @Test(timeout = 30000L)
+       public void testSimpleEventTime() throws Exception {
+               simpleEndToEndTest("test_simple_event_time", 100000, EventTime);
+       }
+
+       /**
+        * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+        * To test data is partitioned to the right partition with time characteristic: ProcessingTime
+        */
+       @Test(timeout = 30000L)
+       public void testAssignedToPartitionProcessingTime() throws Exception {
+               testAssignedToPartition("test_assigned_to_partition_processing_time", 100000, ProcessingTime);
+       }
+
+       /**
+        * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+        * To test data is partitioned to the right partition with time characteristic: IngestionTime
+        */
+       @Test(timeout = 30000L)
+       public void testAssignedToPartitionIngestionTime() throws Exception {
+               testAssignedToPartition("test_assigned_to_partition_ingestion_time", 100000, IngestionTime);
+       }
+
+       /**
+        * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+        * To test data is partitioned to the right partition with time characteristic: EventTime
+        */
+       @Test(timeout = 30000L)
+       public void testAssignedToPartitionEventTime() throws Exception {
+               testAssignedToPartition("test_assigned_to_partition_event_time", 100000, EventTime);
+       }
+
+       /**
+        * Schema: (key, timestamp, source instance Id).
+        * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1
+        * To test no data is lost or duplicated end-2-end
+        */
+       private void simpleEndToEndTest(String topic, int elementCount, TimeCharacteristic timeCharacteristic)
+               throws Exception {
+               final int numberOfPartitions = 1;
+               final int producerParallelism = 1;
+
+               createTestTopic(topic, numberOfPartitions, 1);
+
+               final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+               env.setParallelism(producerParallelism);
+               env.setRestartStrategy(RestartStrategies.noRestart());
+               env.setStreamTimeCharacteristic(timeCharacteristic);
+
+               DataStream<Tuple3<Integer, Long, String>> source =
+                       env.addSource(new KafkaSourceFunction(elementCount)).setParallelism(producerParallelism);
+
+               DataStream<Tuple3<Integer, Long, String>> input = (timeCharacteristic == EventTime) ?
+                       source.assignTimestampsAndWatermarks(new PunctuatedExtractor()).setParallelism(producerParallelism) : source;
+
+               Properties properties = kafkaServer.getStandardProperties();
+               FlinkKafkaShuffle
+                       .persistentKeyBy(input, topic, producerParallelism, numberOfPartitions, properties, 0)
+                       .map(new ElementCountNoMoreThanValidator(elementCount * producerParallelism)).setParallelism(1)
+                       .map(new ElementCountNoLessThanValidator(elementCount * producerParallelism)).setParallelism(1);
+
+               tryExecute(env, topic);
+
+               deleteTestTopic(topic);
+       }
+
+       /**
+        * Schema: (key, timestamp, source instance Id).
+        * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3
+        * To test data is partitioned to the right partition
+        */
+       private void testAssignedToPartition(String topic, int elementCount, TimeCharacteristic timeCharacteristic)
+               throws Exception {
+               final int numberOfPartitions = 3;
+               final int producerParallelism = 2;
+
+               createTestTopic(topic, numberOfPartitions, 1);
+
+               final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+               env.setParallelism(producerParallelism);
+               env.setRestartStrategy(RestartStrategies.noRestart());
+               env.setStreamTimeCharacteristic(EventTime);
+
+               DataStream<Tuple3<Integer, Long, String>> source =
+                       env.addSource(new KafkaSourceFunction(elementCount)).setParallelism(producerParallelism);
+
+               DataStream<Tuple3<Integer, Long, String>> input = (timeCharacteristic == EventTime) ?
+                       source.assignTimestampsAndWatermarks(new PunctuatedExtractor()).setParallelism(producerParallelism) : source;
+
+               // ------- Write data to Kafka partition basesd on FlinkKafkaPartitioner ------
+               Properties properties = kafkaServer.getStandardProperties();
+
+               KeyedStream<Tuple3<Integer, Long, String>, Tuple> keyedStream = FlinkKafkaShuffle
+                       .persistentKeyBy(input, topic, producerParallelism, numberOfPartitions, properties, 0);
+               keyedStream
+                       .process(new PartitionValidator(keyedStream.getKeySelector(), numberOfPartitions, topic))
+                       .setParallelism(numberOfPartitions)
+                       .map(new ElementCountNoMoreThanValidator(elementCount * producerParallelism)).setParallelism(1)
+                       .map(new ElementCountNoLessThanValidator(elementCount * producerParallelism)).setParallelism(1);
+
+               tryExecute(env, "KafkaShuffle partition assignment test");
+
+               deleteTestTopic(topic);

Review comment:
       Extract to avoid duplicate code with `simpleEndToEndTest`.

##########
File path: 
flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleITCase.java
##########
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.connectors.kafka.KafkaConsumerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaProducerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaTestEnvironmentImpl;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionAssigner;
+import org.apache.flink.test.util.SuccessException;
+import org.apache.flink.util.Collector;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.util.Properties;
+
+import static org.apache.flink.streaming.api.TimeCharacteristic.EventTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.IngestionTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.ProcessingTime;
+import static org.apache.flink.test.util.TestUtils.tryExecute;
+
+/**
+ * Simple End to End Test for Kafka.
+ */
+public class KafkaShuffleITCase extends KafkaConsumerTestBase {
+
+       @BeforeClass
+       public static void prepare() throws Exception {
+               KafkaProducerTestBase.prepare();
+               ((KafkaTestEnvironmentImpl) kafkaServer).setProducerSemantic(FlinkKafkaProducer.Semantic.AT_LEAST_ONCE);
+       }
+
+       /**
+        * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+        * To test no data is lost or duplicated end-2-end with the default time characteristic: ProcessingTime
+        */
+       @Test(timeout = 30000L)

Review comment:
       Instead of setting a timeout on every test method, I'd go with a JUnit rule:
   ```
        import org.junit.Rule;
        import org.junit.rules.Timeout;
        import java.util.concurrent.TimeUnit;

        @Rule
        public final Timeout timeout = Timeout.builder()
                        .withTimeout(30, TimeUnit.SECONDS)
                        .build();
   ```
   
   and then use plain `@Test` on the tests. That's easier to maintain when we need to increase the timeout on Azure.

##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java
##########
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.operators.Keys;
+import org.apache.flink.api.common.serialization.TypeInformationSerializationSchema;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.DataStreamUtils;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.transformations.SinkTransformation;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.util.keys.KeySelectorUtil;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.PropertiesUtil;
+
+import java.util.Properties;
+
+/**
+ * Use Kafka as a persistent shuffle by wrapping a Kafka Source/Sink pair together.
+ */
+class FlinkKafkaShuffle {
+       static final String PRODUCER_PARALLELISM = "producer parallelism";
+       static final String PARTITION_NUMBER = "partition number";
+
+       /**
+        * Write to and read from a kafka shuffle with the partition decided by keys.
+        * Consumers should read partitions equal to the key group indices they are assigned.
+        * The number of partitions is the maximum parallelism of the receiving operator.
+        * This version only supports numberOfPartitions = consumerParallelism.
+        *
+        * @param inputStream input stream to the kafka
+        * @param topic kafka topic
+        * @param producerParallelism parallelism of producer
+        * @param numberOfPartitions number of partitions
+        * @param properties Kafka properties
+        * @param fields key positions from inputStream
+        * @param <T> input type
+        */
+       static <T> KeyedStream<T, Tuple> persistentKeyBy(
+               DataStream<T> inputStream,

Review comment:
       nit: indent.
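On the mapping described in the Javadoc above (partition chosen by key, one partition per key group of the receiving operator, and `numberOfPartitions == consumerParallelism` in this version): the following is a Flink-free sketch of how a key could travel to a partition and a consumer subtask. It is modeled on the arithmetic of Flink's `KeyGroupRangeAssignment` (hash the key's `hashCode()`, take it modulo the max parallelism to get a key group, then scale the key group to an operator index). `ShuffleMappingSketch` and its methods are hypothetical, and `mix()` is a murmur3-style stand-in for Flink's `MathUtils.murmurHash` whose values may not match bit for bit.

```java
/**
 * Hypothetical, Flink-free sketch of the key -> key group -> subtask mapping.
 * Not Flink's code: mix() is a murmur3-style stand-in for MathUtils.murmurHash.
 */
final class ShuffleMappingSketch {

    private ShuffleMappingSketch() {}

    /** Murmur3-style 32-bit mixing of a hash code, made non-negative. */
    static int mix(int code) {
        code *= 0xcc9e2d51;
        code = Integer.rotateLeft(code, 15);
        code *= 0x1b873593;
        code = Integer.rotateLeft(code, 13);
        code = code * 5 + 0xe6546b64;
        code ^= 4;
        // fmix32-style finalizer spreads the bits
        code ^= code >>> 16;
        code *= 0x85ebca6b;
        code ^= code >>> 13;
        code *= 0xc2b2ae35;
        code ^= code >>> 16;
        // Clear the sign bit so the modulo below is never negative.
        return code & Integer.MAX_VALUE;
    }

    /** Key group in [0, maxParallelism); with numberOfPartitions == maxParallelism it is also the partition. */
    static int keyGroupFor(Object key, int maxParallelism) {
        return mix(key.hashCode()) % maxParallelism;
    }

    /** Index of the consumer subtask that owns the given key group. */
    static int subtaskFor(int keyGroup, int maxParallelism, int parallelism) {
        return keyGroup * parallelism / maxParallelism;
    }

    public static void main(String[] args) {
        // Mirrors the test setup: 3 partitions, consumer parallelism 3.
        int partitions = 3;
        for (int key = 0; key < 5; key++) {
            int keyGroup = keyGroupFor(key, partitions);
            int subtask = subtaskFor(keyGroup, partitions, partitions);
            System.out.printf("key=%d partition=%d subtask=%d%n", key, keyGroup, subtask);
        }
    }
}
```

Because the hash stand-in may differ from Flink's, the sketch only shows the shape of the mapping: when max parallelism equals parallelism, key group, partition, and subtask index coincide. A test asserting concrete partitions should call `KeyGroupRangeAssignment` directly, as the ITCase's imports suggest it does.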

##########
File path: 
flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleITCase.java
##########
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.connectors.kafka.KafkaConsumerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaProducerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaTestEnvironmentImpl;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionAssigner;
+import org.apache.flink.test.util.SuccessException;
+import org.apache.flink.util.Collector;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.util.Properties;
+
+import static org.apache.flink.streaming.api.TimeCharacteristic.EventTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.IngestionTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.ProcessingTime;
+import static org.apache.flink.test.util.TestUtils.tryExecute;
+
+/**
+ * Simple End to End Test for Kafka.
+ */
+public class KafkaShuffleITCase extends KafkaConsumerTestBase {

Review comment:
       The implemented tests are really good. I'm missing two cases, though:
   * Out-of-order events (add randomness to the source timestamps)
   * Failure and recovery tests. See https://github.com/apache/flink/blob/f239d680e9b8f3f5ace621b7806e0bb7e14d3fdd/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java for a possible approach.
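Both cases could start from small, deterministic building blocks. The following Flink-free sketch is one possible shape, not the approach of the linked `UnalignedCheckpointITCase`: `outOfOrderTimestamps` perturbs the strictly increasing timestamps that `KafkaSourceFunction` currently emits by a bounded random amount, and `FailOnce` throws a single artificial exception after a given number of records. All names and parameters (`maxOutOfOrderness`, `seed`, `failAfter`) are hypothetical. In the ITCase, the jitter would go into `KafkaSourceFunction.run`, the failure switch into a map function whose shared flag is often kept in a static field so it survives the task restart, and `RestartStrategies.noRestart()` would need to become a strategy that allows at least one restart, with checkpointing enabled so that recovery is actually exercised.

```java
import java.util.Random;
import java.util.concurrent.atomic.AtomicBoolean;

/** Hypothetical helpers for the two missing test cases; not part of the PR. */
final class ShuffleTestVariantsSketch {

    private ShuffleTestVariantsSketch() {}

    /**
     * Timestamps base, base + 1, ... each pulled back by a random jitter in
     * [0, maxOutOfOrderness], i.e. out of order but with bounded lateness.
     */
    static long[] outOfOrderTimestamps(long base, int count, int maxOutOfOrderness, long seed) {
        Random random = new Random(seed); // fixed seed keeps failures reproducible
        long[] timestamps = new long[count];
        for (int i = 0; i < count; i++) {
            int jitter = random.nextInt(maxOutOfOrderness + 1);
            timestamps[i] = base + i - jitter;
        }
        return timestamps;
    }

    /** Throws at most once, on the failAfter-th record, across all instances sharing hasFailed. */
    static final class FailOnce {
        private final AtomicBoolean hasFailed;
        private final long failAfter;
        private long seen;

        FailOnce(AtomicBoolean hasFailed, long failAfter) {
            this.hasFailed = hasFailed;
            this.failAfter = failAfter;
        }

        void onRecord() {
            seen++;
            // compareAndSet guarantees a single failure even if several instances race.
            if (seen >= failAfter && hasFailed.compareAndSet(false, true)) {
                throw new IllegalStateException("Artificial failure after " + seen + " records");
            }
        }
    }
}
```

One design consequence worth checking: `PunctuatedExtractor` emits a watermark equal to every element's timestamp, so with jittered timestamps it would likely need to lag by `maxOutOfOrderness` (for example, a watermark of `timestamp - maxOutOfOrderness`) so that the jittered records are not treated as late.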

##########
File path: 
flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleITCase.java
##########
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.connectors.kafka.KafkaConsumerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaProducerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaTestEnvironmentImpl;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionAssigner;
+import org.apache.flink.test.util.SuccessException;
+import org.apache.flink.util.Collector;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.util.Properties;
+
+import static org.apache.flink.streaming.api.TimeCharacteristic.EventTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.IngestionTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.ProcessingTime;
+import static org.apache.flink.test.util.TestUtils.tryExecute;
+
+/**
+ * Simple End to End Test for Kafka.
+ */
+public class KafkaShuffleITCase extends KafkaConsumerTestBase {
+
+       @BeforeClass
+       public static void prepare() throws Exception {
+               KafkaProducerTestBase.prepare();
+               ((KafkaTestEnvironmentImpl) kafkaServer).setProducerSemantic(FlinkKafkaProducer.Semantic.AT_LEAST_ONCE);
+       }
+
+       /**
+        * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+        * To test no data is lost or duplicated end-2-end with the default time characteristic: ProcessingTime
+        */
+       @Test(timeout = 30000L)
+       public void testSimpleProcessingTime() throws Exception {
+               simpleEndToEndTest("test_simple_processing_time", 100000, ProcessingTime);
+       }
+
+       /**
+        * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+        * To test no data is lost or duplicated end-2-end with time characteristic: IngestionTime
+        */
+       @Test(timeout = 30000L)
+       public void testSimpleIngestionTime() throws Exception {
+               simpleEndToEndTest("test_simple_ingestion_time", 100000, IngestionTime);
+       }
+
+       /**
+        * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+        * To test no data is lost or duplicated end-2-end with time characteristic: EventTime
+        */
+       @Test(timeout = 30000L)
+       public void testSimpleEventTime() throws Exception {
+               simpleEndToEndTest("test_simple_event_time", 100000, EventTime);
+       }
+
+       /**
+        * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+        * To test data is partitioned to the right partition with time characteristic: ProcessingTime
+        */
+       @Test(timeout = 30000L)
+       public void testAssignedToPartitionProcessingTime() throws Exception {
+               testAssignedToPartition("test_assigned_to_partition_processing_time", 100000, ProcessingTime);
+       }
+
+       /**
+        * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+        * To test data is partitioned to the right partition with time characteristic: IngestionTime
+        */
+       @Test(timeout = 30000L)
+       public void testAssignedToPartitionIngestionTime() throws Exception {
+               testAssignedToPartition("test_assigned_to_partition_ingestion_time", 100000, IngestionTime);
+       }
+
+       /**
+        * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+        * To test data is partitioned to the right partition with time characteristic: EventTime
+        */
+       @Test(timeout = 30000L)
+       public void testAssignedToPartitionEventTime() throws Exception {
+               testAssignedToPartition("test_assigned_to_partition_event_time", 100000, EventTime);
+       }
+
+       /**
+        * Schema: (key, timestamp, source instance Id).
+        * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1
+        * To test no data is lost or duplicated end-2-end
+        */
+       private void simpleEndToEndTest(String topic, int elementCount, TimeCharacteristic timeCharacteristic)
+               throws Exception {
+               final int numberOfPartitions = 1;
+               final int producerParallelism = 1;
+
+               createTestTopic(topic, numberOfPartitions, 1);
+
+               final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+               env.setParallelism(producerParallelism);
+               env.setRestartStrategy(RestartStrategies.noRestart());
+               env.setStreamTimeCharacteristic(timeCharacteristic);
+
+               DataStream<Tuple3<Integer, Long, String>> source =
+                       env.addSource(new KafkaSourceFunction(elementCount)).setParallelism(producerParallelism);
+
+               DataStream<Tuple3<Integer, Long, String>> input = (timeCharacteristic == EventTime) ?
+                       source.assignTimestampsAndWatermarks(new PunctuatedExtractor()).setParallelism(producerParallelism) : source;
+
+               Properties properties = kafkaServer.getStandardProperties();
+               FlinkKafkaShuffle
+                       .persistentKeyBy(input, topic, producerParallelism, numberOfPartitions, properties, 0)
+                       .map(new ElementCountNoMoreThanValidator(elementCount * producerParallelism)).setParallelism(1)
+                       .map(new ElementCountNoLessThanValidator(elementCount * producerParallelism)).setParallelism(1);
+
+               tryExecute(env, topic);
+
+               deleteTestTopic(topic);
+       }
+
+       /**
+        * Schema: (key, timestamp, source instance Id).
+        * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3
+        * To test data is partitioned to the right partition
+        */
+       private void testAssignedToPartition(String topic, int elementCount, TimeCharacteristic timeCharacteristic)
+               throws Exception {
+               final int numberOfPartitions = 3;
+               final int producerParallelism = 2;
+
+               createTestTopic(topic, numberOfPartitions, 1);
+
+               final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+               env.setParallelism(producerParallelism);
+               env.setRestartStrategy(RestartStrategies.noRestart());
+               env.setStreamTimeCharacteristic(timeCharacteristic);
+
+               DataStream<Tuple3<Integer, Long, String>> source =
+                       env.addSource(new KafkaSourceFunction(elementCount)).setParallelism(producerParallelism);
+
+               DataStream<Tuple3<Integer, Long, String>> input = (timeCharacteristic == EventTime) ?
+                       source.assignTimestampsAndWatermarks(new PunctuatedExtractor()).setParallelism(producerParallelism) : source;
+
+               // ------- Write data to Kafka partition based on FlinkKafkaPartitioner ------
+               Properties properties = kafkaServer.getStandardProperties();
+
+               KeyedStream<Tuple3<Integer, Long, String>, Tuple> keyedStream = FlinkKafkaShuffle
+                       .persistentKeyBy(input, topic, producerParallelism, numberOfPartitions, properties, 0);
+               keyedStream
+                       .process(new PartitionValidator(keyedStream.getKeySelector(), numberOfPartitions, topic))
+                       .setParallelism(numberOfPartitions)
+                       .map(new ElementCountNoMoreThanValidator(elementCount * producerParallelism)).setParallelism(1)
+                       .map(new ElementCountNoLessThanValidator(elementCount * producerParallelism)).setParallelism(1);
+
+               tryExecute(env, "KafkaShuffle partition assignment test");
+
+               deleteTestTopic(topic);
+       }
+
+       private static class PunctuatedExtractor implements AssignerWithPunctuatedWatermarks<Tuple3<Integer, Long, String>> {
+               private static final long serialVersionUID = 1L;
+
+               @Override
+               public long extractTimestamp(Tuple3<Integer, Long, String> element, long previousTimestamp) {
+                       return element.f1;
+               }
+
+               @Override
+               public Watermark checkAndGetNextWatermark(Tuple3<Integer, Long, String> lastElement, long extractedTimestamp) {
+                       return new Watermark(extractedTimestamp);
+               }
+       }
+
+       private static class KafkaSourceFunction extends RichParallelSourceFunction<Tuple3<Integer, Long, String>> {
+               private volatile boolean running = true;
+               private int elementCount;
+
+               KafkaSourceFunction(int elementCount) {
+                       this.elementCount = elementCount;
+               }
+
+               @Override
+               public void run(SourceContext<Tuple3<Integer, Long, String>> ctx) {
+                       long timestamp = 1584349939799L;
+                       int instanceId = getRuntimeContext().getIndexOfThisSubtask();
+                       for (int i = 0; i < elementCount && running; i++) {
+                               ctx.collect(new Tuple3<>(i, timestamp++, "source-instance-" + instanceId));
+                       }
+               }
+
+               @Override
+               public void cancel() {
+                       running = false;
+               }
+       }
+
+       private static class ElementCountNoMoreThanValidator
+               implements MapFunction<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>> {

Review comment:
       nit: also double-indent.
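The hunk cuts off before the validator bodies, so their exact behavior is not visible here. Going only by the class names and by how they are chained (expected count `elementCount * producerParallelism`, since each of the `producerParallelism` source subtasks emits `elementCount` records), a Flink-free sketch of the two checks might look as follows, with the `implements` clause on a double-indented continuation line as the nit requests. `CountBoundsSketch`, `NoMoreThan`, and `NoLessThan` are hypothetical; `IntConsumer` stands in for Flink's `MapFunction`, and in the ITCase the success signal is presumably a `SuccessException` (imported by the test), which `tryExecute` treats as a successful run.

```java
import java.util.function.IntConsumer;

/** Hypothetical sketch of the two element-count checks; not the PR's implementation. */
final class CountBoundsSketch {

    private CountBoundsSketch() {}

    /** Fails on the first record beyond the expected count (catches duplicates). */
    static final class NoMoreThan
            implements IntConsumer { // continuation line double-indented
        private final int expected;
        private int count;

        NoMoreThan(int expected) {
            this.expected = expected;
        }

        @Override
        public void accept(int record) {
            if (++count > expected) {
                throw new IllegalStateException(
                        "Received " + count + " records, expected at most " + expected);
            }
        }
    }

    /** Flags success once the expected count is reached (no data lost). */
    static final class NoLessThan
            implements IntConsumer {
        private final int expected;
        private int count;
        private boolean reached;

        NoLessThan(int expected) {
            this.expected = expected;
        }

        @Override
        public void accept(int record) {
            if (++count == expected) {
                // The ITCase would presumably throw a SuccessException here to end the job.
                reached = true;
            }
        }

        boolean hasReachedExpected() {
            return reached;
        }
    }
}
```

Chaining the upper bound before the lower bound, as the test does, means a duplicate fails the job before the success signal can fire, and a lost record keeps the success signal from ever firing, so the job would run into the test timeout.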

##########
File path: 
flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleITCase.java
##########
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.connectors.kafka.KafkaConsumerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaProducerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaTestEnvironmentImpl;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionAssigner;
+import org.apache.flink.test.util.SuccessException;
+import org.apache.flink.util.Collector;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.util.Properties;
+
+import static org.apache.flink.streaming.api.TimeCharacteristic.EventTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.IngestionTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.ProcessingTime;
+import static org.apache.flink.test.util.TestUtils.tryExecute;
+
+/**
+ * Simple End to End Test for Kafka.
+ */
+public class KafkaShuffleITCase extends KafkaConsumerTestBase {
+
+       @BeforeClass
+       public static void prepare() throws Exception {
+               KafkaProducerTestBase.prepare();
+               ((KafkaTestEnvironmentImpl) kafkaServer).setProducerSemantic(FlinkKafkaProducer.Semantic.AT_LEAST_ONCE);
+       }
+
+       /**
+        * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+        * To test no data is lost or duplicated end-2-end with the default time characteristic: ProcessingTime
+        */
+       @Test(timeout = 30000L)
+       public void testSimpleProcessingTime() throws Exception {
+               simpleEndToEndTest("test_simple_processing_time", 100000, ProcessingTime);
+       }
+
+       /**
+        * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+        * To test no data is lost or duplicated end-2-end with time characteristic: IngestionTime
+        */
+       @Test(timeout = 30000L)
+       public void testSimpleIngestionTime() throws Exception {
+               simpleEndToEndTest("test_simple_ingestion_time", 100000, IngestionTime);
+       }
+
+       /**
+        * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+        * To test no data is lost or duplicated end-2-end with time characteristic: EventTime
+        */
+       @Test(timeout = 30000L)
+       public void testSimpleEventTime() throws Exception {
+               simpleEndToEndTest("test_simple_event_time", 100000, EventTime);
+       }
+
+       /**
+        * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+        * To test data is partitioned to the right partition with time characteristic: ProcessingTime
+        */
+       @Test(timeout = 30000L)
+       public void testAssignedToPartitionProcessingTime() throws Exception {
+               testAssignedToPartition("test_assigned_to_partition_processing_time", 100000, ProcessingTime);
+       }
+
+       /**
+        * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+        * To test data is partitioned to the right partition with time characteristic: IngestionTime
+        */
+       @Test(timeout = 30000L)
+       public void testAssignedToPartitionIngestionTime() throws Exception {
+               testAssignedToPartition("test_assigned_to_partition_ingestion_time", 100000, IngestionTime);
+       }
+
+       /**
+        * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+        * To test data is partitioned to the right partition with time characteristic: EventTime
+        */
+       @Test(timeout = 30000L)
+       public void testAssignedToPartitionEventTime() throws Exception {
+               testAssignedToPartition("test_assigned_to_partition_event_time", 100000, EventTime);
+       }
+
+       /**
+        * Schema: (key, timestamp, source instance Id).
+        * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1
+        * To test no data is lost or duplicated end-2-end
+        */
+       private void simpleEndToEndTest(String topic, int elementCount, TimeCharacteristic timeCharacteristic)

Review comment:
       We use end2end in a different context, where we use a complete Flink distribution to execute the test. 
   
   I'd simply call it `testKafkaShuffle` to avoid any misunderstanding.

##########
File path: 
flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleITCase.java
##########
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.connectors.kafka.KafkaConsumerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaProducerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaTestEnvironmentImpl;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionAssigner;
+import org.apache.flink.test.util.SuccessException;
+import org.apache.flink.util.Collector;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.util.Properties;
+
+import static org.apache.flink.streaming.api.TimeCharacteristic.EventTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.IngestionTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.ProcessingTime;
+import static org.apache.flink.test.util.TestUtils.tryExecute;
+
+/**
+ * Simple End to End Test for Kafka.
+ */
+public class KafkaShuffleITCase extends KafkaConsumerTestBase {
+
+       @BeforeClass
+       public static void prepare() throws Exception {
+               KafkaProducerTestBase.prepare();
+               ((KafkaTestEnvironmentImpl) kafkaServer).setProducerSemantic(FlinkKafkaProducer.Semantic.AT_LEAST_ONCE);
+       }
+
+       /**
+        * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+        * To test no data is lost or duplicated end-2-end with the default time characteristic: ProcessingTime
+        */
+       @Test(timeout = 30000L)
+       public void testSimpleProcessingTime() throws Exception {
+               simpleEndToEndTest("test_simple_processing_time", 100000, ProcessingTime);
+       }
+
+       /**
+        * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+        * To test no data is lost or duplicated end-2-end with time characteristic: IngestionTime
+        */
+       @Test(timeout = 30000L)
+       public void testSimpleIngestionTime() throws Exception {
+               simpleEndToEndTest("test_simple_ingestion_time", 100000, IngestionTime);
+       }
+
+       /**
+        * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+        * To test no data is lost or duplicated end-2-end with time characteristic: EventTime
+        */
+       @Test(timeout = 30000L)
+       public void testSimpleEventTime() throws Exception {
+               simpleEndToEndTest("test_simple_event_time", 100000, EventTime);
+       }
+
+       /**
+        * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+        * To test data is partitioned to the right partition with time characteristic: ProcessingTime
+        */
+       @Test(timeout = 30000L)
+       public void testAssignedToPartitionProcessingTime() throws Exception {
+               testAssignedToPartition("test_assigned_to_partition_processing_time", 100000, ProcessingTime);
+       }
+
+       /**
+        * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+        * To test data is partitioned to the right partition with time characteristic: IngestionTime
+        */
+       @Test(timeout = 30000L)
+       public void testAssignedToPartitionIngestionTime() throws Exception {
+               testAssignedToPartition("test_assigned_to_partition_ingestion_time", 100000, IngestionTime);
+       }
+
+       /**
+        * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+        * To test data is partitioned to the right partition with time characteristic: EventTime
+        */
+       @Test(timeout = 30000L)
+       public void testAssignedToPartitionEventTime() throws Exception {
+               testAssignedToPartition("test_assigned_to_partition_event_time", 100000, EventTime);
+       }
+
+       /**
+        * Schema: (key, timestamp, source instance Id).
+        * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1
+        * To test no data is lost or duplicated end-2-end
+        */
+       private void simpleEndToEndTest(String topic, int elementCount, TimeCharacteristic timeCharacteristic)
+               throws Exception {
+               final int numberOfPartitions = 1;
+               final int producerParallelism = 1;
+
+               createTestTopic(topic, numberOfPartitions, 1);
+
+               final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+               env.setParallelism(producerParallelism);
+               env.setRestartStrategy(RestartStrategies.noRestart());
+               env.setStreamTimeCharacteristic(timeCharacteristic);
+
+               DataStream<Tuple3<Integer, Long, String>> source =
+                       env.addSource(new KafkaSourceFunction(elementCount)).setParallelism(producerParallelism);
+
+               DataStream<Tuple3<Integer, Long, String>> input = (timeCharacteristic == EventTime) ?
+                       source.assignTimestampsAndWatermarks(new PunctuatedExtractor()).setParallelism(producerParallelism) : source;
+
+               Properties properties = kafkaServer.getStandardProperties();
+               FlinkKafkaShuffle
+                       .persistentKeyBy(input, topic, producerParallelism, numberOfPartitions, properties, 0)
+                       .map(new ElementCountNoMoreThanValidator(elementCount * producerParallelism)).setParallelism(1)
+                       .map(new ElementCountNoLessThanValidator(elementCount * producerParallelism)).setParallelism(1);
+
+               tryExecute(env, topic);
+
+               deleteTestTopic(topic);
+       }
+
+       /**
+        * Schema: (key, timestamp, source instance Id).
+        * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3
+        * To test data is partitioned to the right partition
+        */
+       private void testAssignedToPartition(String topic, int elementCount, TimeCharacteristic timeCharacteristic)
+               throws Exception {
+               final int numberOfPartitions = 3;
+               final int producerParallelism = 2;
+
+               createTestTopic(topic, numberOfPartitions, 1);
+
+               final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+               env.setParallelism(producerParallelism);
+               env.setRestartStrategy(RestartStrategies.noRestart());
+               env.setStreamTimeCharacteristic(timeCharacteristic);
+
+               DataStream<Tuple3<Integer, Long, String>> source =
+                       env.addSource(new KafkaSourceFunction(elementCount)).setParallelism(producerParallelism);
+
+               DataStream<Tuple3<Integer, Long, String>> input = (timeCharacteristic == EventTime) ?
+                       source.assignTimestampsAndWatermarks(new PunctuatedExtractor()).setParallelism(producerParallelism) : source;
+
+               // ------- Write data to Kafka partition based on FlinkKafkaPartitioner ------
+               Properties properties = kafkaServer.getStandardProperties();
+
+               KeyedStream<Tuple3<Integer, Long, String>, Tuple> keyedStream = FlinkKafkaShuffle
+                       .persistentKeyBy(input, topic, producerParallelism, numberOfPartitions, properties, 0);
+               keyedStream
+                       .process(new PartitionValidator(keyedStream.getKeySelector(), numberOfPartitions, topic))
+                       .setParallelism(numberOfPartitions)
+                       .map(new ElementCountNoMoreThanValidator(elementCount * producerParallelism)).setParallelism(1)
+                       .map(new ElementCountNoLessThanValidator(elementCount * producerParallelism)).setParallelism(1);
+
+               tryExecute(env, "KafkaShuffle partition assignment test");
+
+               deleteTestTopic(topic);
+       }
+
+       private static class PunctuatedExtractor implements AssignerWithPunctuatedWatermarks<Tuple3<Integer, Long, String>> {
+               private static final long serialVersionUID = 1L;
+
+               @Override
+               public long extractTimestamp(Tuple3<Integer, Long, String> element, long previousTimestamp) {
+                       return element.f1;
+               }
+
+               @Override
+               public Watermark checkAndGetNextWatermark(Tuple3<Integer, Long, String> lastElement, long extractedTimestamp) {
+                       return new Watermark(extractedTimestamp);
+               }
+       }
+
+       private static class KafkaSourceFunction extends RichParallelSourceFunction<Tuple3<Integer, Long, String>> {
+               private volatile boolean running = true;
+               private int elementCount;
+
+               KafkaSourceFunction(int elementCount) {
+                       this.elementCount = elementCount;
+               }
+
+               @Override
+               public void run(SourceContext<Tuple3<Integer, Long, String>> ctx) {
+                       long timestamp = 1584349939799L;
+                       int instanceId = getRuntimeContext().getIndexOfThisSubtask();
+                       for (int i = 0; i < elementCount && running; i++) {
+                               ctx.collect(new Tuple3<>(i, timestamp++, "source-instance-" + instanceId));
+                       }
+               }
+
+               @Override
+               public void cancel() {
+                       running = false;
+               }
+       }
+
+       private static class ElementCountNoMoreThanValidator
+               implements MapFunction<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>> {
+               private final int totalCount;
+               private int counter = 0;
+
+               ElementCountNoMoreThanValidator(int totalCount) {
+                       this.totalCount = totalCount;
+               }
+
+               @Override
+               public Tuple3<Integer, Long, String> map(Tuple3<Integer, Long, String> element) throws Exception {
+                       counter++;
+
+                       if (counter > totalCount) {
+                               throw new Exception("Error: number of elements more than expected");
+                       }
+
+                       return element;
+               }
+       }
+
+       private static class ElementCountNoLessThanValidator
+               implements MapFunction<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>> {
+               private final int totalCount;
+               private int counter = 0;
+
+               ElementCountNoLessThanValidator(int totalCount) {
+                       this.totalCount = totalCount;
+               }
+
+               @Override
+               public Tuple3<Integer, Long, String> map(Tuple3<Integer, Long, String> element) throws Exception {
+                       counter++;
+
+                       if (counter == totalCount) {
+                               throw new SuccessException();
+                       }
+
+                       return element;
+               }
+       }
+
+       private static class PartitionValidator
+               extends KeyedProcessFunction<Tuple, Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>> {
+
+               private final KeySelector<Tuple3<Integer, Long, String>, Tuple> keySelector;
+               private final int numberOfPartitions;
+               private final String topic;
+
+               private int previousPartition;
+
+               PartitionValidator(
+                       KeySelector<Tuple3<Integer, Long, String>, Tuple> keySelector, int numberOfPartitions, String topic) {
+                       this.keySelector = keySelector;
+                       this.numberOfPartitions = numberOfPartitions;
+                       this.topic = topic;
+                       this.previousPartition = -1;
+               }
+
+               @Override
+               public void processElement(
+                       Tuple3<Integer, Long, String> in, Context ctx, Collector<Tuple3<Integer, Long, String>> out)

Review comment:
       nit: chop args
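For context on what `PartitionValidator` guards against: each consumer subtask should only see keys that were written to its own partition. A minimal sketch of such a key-to-partition mapping follows, assuming the shuffle producer routes keys the same way `keyBy` routes them to subtasks (key, then key group in `[0, maxParallelism)`, then `keyGroup * numberOfPartitions / maxParallelism`). `PartitionAssignmentSketch` and `SinglePartitionCheck` are hypothetical names, and the key hash is simplified: Flink applies a murmur hash on top of `hashCode()`, so the concrete indices here will not match Flink's.

```java
/**
 * Hypothetical sketch: maps a key to a partition via key-group range assignment.
 * The hash is simplified (Flink murmur-hashes key.hashCode()), so indices differ from Flink's.
 */
final class PartitionAssignmentSketch {

    private PartitionAssignmentSketch() {
    }

    /** Key -> key group in [0, maxParallelism), using a simplified hash. */
    static int keyGroupFor(Object key, int maxParallelism) {
        return Math.floorMod(key.hashCode(), maxParallelism);
    }

    /** Key group -> partition index: each partition owns a contiguous range of key groups. */
    static int partitionForKeyGroup(int keyGroup, int maxParallelism, int numberOfPartitions) {
        return keyGroup * numberOfPartitions / maxParallelism;
    }

    static int partitionFor(Object key, int maxParallelism, int numberOfPartitions) {
        return partitionForKeyGroup(keyGroupFor(key, maxParallelism), maxParallelism, numberOfPartitions);
    }

    /** Fails as soon as one consumer instance sees keys that belong to different partitions. */
    static final class SinglePartitionCheck {
        private final int maxParallelism;
        private final int numberOfPartitions;
        private int previousPartition = -1; // -1 means "no element seen yet"

        SinglePartitionCheck(int maxParallelism, int numberOfPartitions) {
            this.maxParallelism = maxParallelism;
            this.numberOfPartitions = numberOfPartitions;
        }

        void check(Object key) {
            int partition = partitionFor(key, maxParallelism, numberOfPartitions);
            if (previousPartition != -1 && partition != previousPartition) {
                throw new IllegalStateException(
                    "key " + key + " maps to partition " + partition + ", expected " + previousPartition);
            }
            previousPartition = partition;
        }
    }
}
```

Dividing the key-group range, rather than taking `hash % numberOfPartitions`, keeps each partition responsible for a contiguous block of key groups, which is the same property keyed state relies on for rescaling.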

##########
File path: 
flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/KafkaShuffleFetcher.java
##########
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.internal;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaCommitCallback;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.SerializedValue;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.common.TopicPartition;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+
+import java.io.Serializable;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Properties;
+
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITHOUT_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITH_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_WATERMARK;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Fetch data from Kafka for Kafka Shuffle.
+ */
+@Internal
+public class KafkaShuffleFetcher<T> extends AbstractFetcher<T, TopicPartition> {
+       private static final Logger LOG = LoggerFactory.getLogger(KafkaShuffleFetcher.class);
+
+       private final WatermarkHandler watermarkHandler;
+       // ------------------------------------------------------------------------
+
+       /** Deserializer that converts Kafka's byte messages into shuffle elements (records or watermarks). */
+       private final KafkaShuffleElementDeserializer<T> deserializer;
+
+       /** Type serializer used to deserialize the record values. */
+       private final TypeSerializer<T> serializer;
+
+       /** The handover of data and exceptions between the consumer thread and the task thread. */
+       private final Handover handover;
+
+       /** The thread that runs the actual KafkaConsumer and hands the record batches to this fetcher. */
+       private final KafkaConsumerThread consumerThread;
+
+       /** Flag to mark the main work loop as alive. */
+       private volatile boolean running = true;
+
+       public KafkaShuffleFetcher(
+               SourceFunction.SourceContext<T> sourceContext,
+               Map<KafkaTopicPartition, Long> assignedPartitionsWithInitialOffsets,
+               SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic,
+               SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated,
+               ProcessingTimeService processingTimeProvider,
+               long autoWatermarkInterval,
+               ClassLoader userCodeClassLoader,
+               String taskNameWithSubtasks,
+               TypeSerializer<T> serializer,
+               Properties kafkaProperties,
+               long pollTimeout,
+               MetricGroup subtaskMetricGroup,
+               MetricGroup consumerMetricGroup,
+               boolean useMetrics,
+               int producerParallelism) throws Exception {
+               super(
+                       sourceContext,
+                       assignedPartitionsWithInitialOffsets,
+                       watermarksPeriodic,
+                       watermarksPunctuated,
+                       processingTimeProvider,
+                       autoWatermarkInterval,
+                       userCodeClassLoader,
+                       consumerMetricGroup,
+                       useMetrics);
+
+               this.deserializer = new KafkaShuffleElementDeserializer<>();
+               this.serializer = serializer;
+               this.handover = new Handover();
+               this.consumerThread = new KafkaConsumerThread(
+                       LOG,
+                       handover,
+                       kafkaProperties,
+                       unassignedPartitionsQueue,
+                       getFetcherName() + " for " + taskNameWithSubtasks,
+                       pollTimeout,
+                       useMetrics,
+                       consumerMetricGroup,
+                       subtaskMetricGroup);
+               this.watermarkHandler = new WatermarkHandler(producerParallelism);
+       }
+
+       // ------------------------------------------------------------------------
+       //  Fetcher work methods
+       // ------------------------------------------------------------------------
+
+       @Override
+       public void runFetchLoop() throws Exception {
+               try {
+                       final Handover handover = this.handover;
+
+                       // kick off the actual Kafka consumer
+                       consumerThread.start();
+
+                       while (running) {
+                               // this blocks until we get the next records
+                               // it automatically re-throws exceptions encountered in the consumer thread
+                               final ConsumerRecords<byte[], byte[]> records = handover.pollNext();
+
+                               // get the records for each topic partition
+                               for (KafkaTopicPartitionState<TopicPartition> partition : subscribedPartitionStates()) {
+                                       List<ConsumerRecord<byte[], byte[]>> partitionRecords =
+                                               records.records(partition.getKafkaPartitionHandle());
+
+                                       for (ConsumerRecord<byte[], byte[]> record : partitionRecords) {
+                                               final KafkaShuffleElement<T> element = deserializer.deserialize(serializer, record);
+
+                                               // TODO: do we need to check the end of stream if reaching the end watermark?
+
+                                               if (element.isRecord()) {
+                                                       // timestamp is inherited from upstream
+                                                       // If using ProcessingTime, the timestamp is ignored (upstream does not include a timestamp either)
+                                                       // If using IngestionTime, the timestamp is overwritten
+                                                       // If using EventTime, the timestamp is used
+                                                       synchronized (checkpointLock) {
+                                                               KafkaShuffleRecord<T> elementAsRecord = element.asRecord();
+                                                               sourceContext.collectWithTimestamp(
+                                                                       elementAsRecord.value,
+                                                                       elementAsRecord.timestamp == null ? record.timestamp() : elementAsRecord.timestamp);
+                                                               partition.setOffset(record.offset());
+                                                       }
+                                               } else if (element.isWatermark()) {
+                                                       final KafkaShuffleWatermark watermark = element.asWatermark();
+                                                       Optional<Watermark> newWatermark = watermarkHandler.checkAndGetNewWatermark(watermark);
+                                                       newWatermark.ifPresent(sourceContext::emitWatermark);
+                                               }
+                                       }
+                               }
+                       }
+               }
+               finally {
+                       // this signals the consumer thread that no more work is to be done
+                       consumerThread.shutdown();
+               }
+
+               // on a clean exit, wait for the runner thread
+               try {
+                       consumerThread.join();
+               }
+               catch (InterruptedException e) {
+                       // may be the result of a wake-up interruption after an exception.
+                       // we ignore this here and only restore the interruption state
+                       Thread.currentThread().interrupt();
+               }
+       }
+
+       @Override
+       public void cancel() {
+               // flag the main thread to exit. A thread interrupt will come anyway.
+               running = false;
+               handover.close();
+               consumerThread.shutdown();
+       }
+
+       @Override
+       protected TopicPartition createKafkaPartitionHandle(KafkaTopicPartition partition) {
+               return new TopicPartition(partition.getTopic(), partition.getPartition());
+       }
+
+       @Override
+       protected void doCommitInternalOffsetsToKafka(
+               Map<KafkaTopicPartition, Long> offsets,
+               @Nonnull KafkaCommitCallback commitCallback) throws Exception {
+
+               @SuppressWarnings("unchecked")
+               List<KafkaTopicPartitionState<TopicPartition>> partitions = subscribedPartitionStates();
+
+               Map<TopicPartition, OffsetAndMetadata> offsetsToCommit = new HashMap<>(partitions.size());
+
+               for (KafkaTopicPartitionState<TopicPartition> partition : partitions) {
+                       Long lastProcessedOffset = offsets.get(partition.getKafkaTopicPartition());
+                       if (lastProcessedOffset != null) {
+                               checkState(lastProcessedOffset >= 0, "Illegal offset value to commit");
+
+                               // committed offsets through the KafkaConsumer need to be 1 more than the last processed offset.
+                               // This does not affect Flink's checkpoints/saved state.
+                               long offsetToCommit = lastProcessedOffset + 1;
+
+                               offsetsToCommit.put(partition.getKafkaPartitionHandle(), new OffsetAndMetadata(offsetToCommit));
+                               partition.setCommittedOffset(offsetToCommit);
+                       }
+               }
+
+               // record the work to be committed by the main consumer thread and make sure the consumer notices that
+               consumerThread.setOffsetsToCommit(offsetsToCommit, commitCallback);
+       }
+
+       private String getFetcherName() {
+               return "Kafka Shuffle Fetcher";
+       }
+
+       private abstract static class KafkaShuffleElement<T> {
+
+               public boolean isRecord() {
+                       return getClass() == KafkaShuffleRecord.class;
+               }
+
+               boolean isWatermark() {
+                       return getClass() == KafkaShuffleWatermark.class;
+               }
+
+               KafkaShuffleRecord<T> asRecord() {
+                       return (KafkaShuffleRecord<T>) this;
+               }
+
+               KafkaShuffleWatermark asWatermark() {
+                       return (KafkaShuffleWatermark) this;
+               }
+       }
+
+       private static class KafkaShuffleWatermark<T> extends KafkaShuffleElement<T> {
+               final int subtask;
+               final long watermark;
+
+               KafkaShuffleWatermark(int subtask, long watermark) {
+                       this.subtask = subtask;
+                       this.watermark = watermark;
+               }
+       }
+
+       private static class KafkaShuffleRecord<T> extends KafkaShuffleElement<T> {
+               final T value;
+               final Long timestamp;
+
+               KafkaShuffleRecord(T value) {
+                       this.value = value;
+                       this.timestamp = null;
+               }
+
+               KafkaShuffleRecord(long timestamp, T value) {
+                       this.value = value;
+                       this.timestamp = timestamp;
+               }
+       }
+
+       private static class KafkaShuffleElementDeserializer<T> implements Serializable {
+               private transient DataInputDeserializer dis;
+
+               KafkaShuffleElementDeserializer() {
+                       this.dis = new DataInputDeserializer();
+               }
+
+               KafkaShuffleElement<T> deserialize(TypeSerializer<T> serializer, ConsumerRecord<byte[], byte[]> record)
+                       throws Exception {
+                       byte[] value = record.value();
+                       dis.setBuffer(value);
+                       int tag = IntSerializer.INSTANCE.deserialize(dis);
+
+                       if (tag == TAG_REC_WITHOUT_TIMESTAMP) {
+                               return new KafkaShuffleRecord<>(serializer.deserialize(dis));
+                       } else if (tag == TAG_REC_WITH_TIMESTAMP) {
+                               return new KafkaShuffleRecord<>(LongSerializer.INSTANCE.deserialize(dis), serializer.deserialize(dis));
+                       } else if (tag == TAG_WATERMARK) {
+                               return new KafkaShuffleWatermark<>(
+                                       IntSerializer.INSTANCE.deserialize(dis), LongSerializer.INSTANCE.deserialize(dis));
+                       }
+
+                       throw new UnsupportedOperationException("Unsupported tag format");
+               }
+       }
+
+       /**
+        * WatermarkHandler to generate watermarks.
+        */
+       private static class WatermarkHandler {
+               private final int producerParallelism;
+               private final Map<Integer, Long> subtaskWatermark;
+
+               private long currentMinWatermark = Long.MIN_VALUE;
+
+               WatermarkHandler(int numberOfSubtask) {
+                       this.producerParallelism = numberOfSubtask;
+                       this.subtaskWatermark = new HashMap<>(numberOfSubtask);
+               }
+
+               public Optional<Watermark> checkAndGetNewWatermark(KafkaShuffleWatermark newWatermark) {
+                       // watermarks are non-decreasing for the same partition and PRODUCER subtask
+                       Long currentSubTaskWatermark = subtaskWatermark.get(newWatermark.subtask);
+
+                       Preconditions.checkState(
+                               (currentSubTaskWatermark == null) || (currentSubTaskWatermark <= newWatermark.watermark),
+                               "Watermark should always increase");
+
+                       subtaskWatermark.put(newWatermark.subtask, newWatermark.watermark);
+
+                       if (subtaskWatermark.values().size() < producerParallelism) {
+                               return Optional.empty();
+                       }

Review comment:
       What happens if one partition has ended and we receive no watermarks 
anymore? Are the watermarks of the other partitions still propagated properly? 
Almost feels like using `StatusWatermarkValve` would be handy.
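To make the concern concrete: `WatermarkHandler` waits until every producer subtask has reported and then presumably emits the minimum, so a single producer subtask that stops sending watermarks freezes event time for the whole consumer. A minimal sketch of the alternative follows, assuming the consumer can learn somehow (an explicit end-of-stream marker, a timeout) that a producer subtask went idle. It tracks the latest watermark per producer subtask plus an idle flag and emits the minimum over non-idle subtasks only, which is the idea behind `StatusWatermarkValve`; it does not use that class's API. `IdleAwareWatermarkValve`, `onWatermark`, and `markIdle` are hypothetical names.

```java
import java.util.Arrays;
import java.util.OptionalLong;

/**
 * Hypothetical sketch in the spirit of StatusWatermarkValve (not its API): aligns watermarks
 * from a fixed number of producer subtasks ("channels") and leaves idle channels out of the minimum.
 */
final class IdleAwareWatermarkValve {
    private final long[] channelWatermarks;
    private final boolean[] idle;
    private long lastEmitted = Long.MIN_VALUE;

    IdleAwareWatermarkValve(int numberOfChannels) {
        channelWatermarks = new long[numberOfChannels];
        Arrays.fill(channelWatermarks, Long.MIN_VALUE); // Long.MIN_VALUE = nothing received yet
        idle = new boolean[numberOfChannels];
    }

    /** Records a channel's watermark; returns the aligned watermark if it advanced. */
    OptionalLong onWatermark(int channel, long watermark) {
        if (watermark < channelWatermarks[channel]) {
            throw new IllegalStateException("Watermark of channel " + channel + " decreased");
        }
        channelWatermarks[channel] = watermark;
        idle[channel] = false; // a new watermark makes the channel active again
        return tryAdvance();
    }

    /** Marks a channel idle; how idleness is detected is left to the caller. */
    OptionalLong markIdle(int channel) {
        idle[channel] = true;
        return tryAdvance();
    }

    private OptionalLong tryAdvance() {
        long min = Long.MAX_VALUE;
        boolean anyActive = false;
        for (int i = 0; i < channelWatermarks.length; i++) {
            if (!idle[i]) {
                anyActive = true;
                min = Math.min(min, channelWatermarks[i]);
            }
        }
        // An active channel that has not reported yet keeps min at Long.MIN_VALUE, so nothing is emitted.
        if (!anyActive || min <= lastEmitted) {
            return OptionalLong.empty();
        }
        lastEmitted = min;
        return OptionalLong.of(min);
    }
}
```

Compared with counting reported subtasks, the idle flag lets the remaining subtasks keep advancing event time. A subtask that becomes active again while behind simply holds the output back until it catches up, so an emitted watermark never regresses.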

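For reference, `KafkaShuffleElementDeserializer` above reads an `int` tag and then either the value, a `long` timestamp followed by the value, or an `int` producer subtask followed by a `long` watermark. A round-trip sketch of that layout with plain `java.io` streams follows, assuming the producer writes the fields in the same order the fetcher reads them. The tag values are placeholders (the real `TAG_*` constants live in `FlinkKafkaShuffleProducer.KafkaSerializer` and are not shown in this diff), `writeUTF` stands in for `TypeSerializer<T>`, and `ShuffleWireFormatSketch` is a hypothetical name.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

/**
 * Hypothetical round trip of the tagged layout:
 *   [int tag][value]                        record without timestamp
 *   [int tag][long timestamp][value]        record with timestamp
 *   [int tag][int subtask][long watermark]  watermark
 */
final class ShuffleWireFormatSketch {
    // Placeholder tag values, not the producer's real constants.
    static final int TAG_REC_WITH_TIMESTAMP = 0;
    static final int TAG_REC_WITHOUT_TIMESTAMP = 1;
    static final int TAG_WATERMARK = 2;

    private ShuffleWireFormatSketch() {
    }

    static byte[] encodeRecord(String value, Long timestamp) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            if (timestamp == null) {
                out.writeInt(TAG_REC_WITHOUT_TIMESTAMP);
            } else {
                out.writeInt(TAG_REC_WITH_TIMESTAMP);
                out.writeLong(timestamp); // the timestamp precedes the value
            }
            out.writeUTF(value); // stand-in for TypeSerializer<T>#serialize
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    static byte[] encodeWatermark(int subtask, long watermark) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeInt(TAG_WATERMARK);
            out.writeInt(subtask);
            out.writeLong(watermark);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    /** Decodes one message and describes it as a string (for illustration only). */
    static String decode(byte[] message) {
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(message));
        try {
            int tag = in.readInt();
            if (tag == TAG_REC_WITHOUT_TIMESTAMP) {
                return "record(" + in.readUTF() + ")";
            } else if (tag == TAG_REC_WITH_TIMESTAMP) {
                long timestamp = in.readLong();
                return "record(" + in.readUTF() + ", ts=" + timestamp + ")";
            } else if (tag == TAG_WATERMARK) {
                int subtask = in.readInt();
                long watermark = in.readLong();
                return "watermark(subtask=" + subtask + ", ts=" + watermark + ")";
            }
            throw new UnsupportedOperationException("Unsupported tag format: " + tag);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

Because the reader relies purely on field order, the producer and fetcher must agree on it exactly; reading the value before the timestamp would silently misinterpret the bytes rather than fail.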



----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

