This is an automated email from the ASF dual-hosted git repository.

wanghailin pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git


The following commit(s) were added to refs/heads/dev by this push:
     new 5f55e31c29 [Improve][connector-doris] Improved doris source enumerator splits allocation algorithm for subtasks (#9108)
5f55e31c29 is described below

commit 5f55e31c294a3638439fd1a05128dddbb3bdb539
Author: JeremyXin <739772...@qq.com>
AuthorDate: Wed Apr 9 18:17:42 2025 +0800

    [Improve][connector-doris] Improved doris source enumerator splits allocation algorithm for subtasks (#9108)
---
 pom.xml                                            |   7 +
 .../connectors/doris/rest/PartitionDefinition.java |   2 +-
 .../source/split/DorisSourceSplitEnumerator.java   |  22 +-
 .../split/DorisSourceSplitEnumeratorTest.java      | 144 +++++++++++++
 .../seatunnel-translation-base/pom.xml             |   7 +
 .../translation/source/ParallelSourceTest.java     | 226 ++++++++++++++-------
 6 files changed, 326 insertions(+), 82 deletions(-)
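
In short: the Doris split enumerator used to pick the owning reader by hashing each split id, which could pile many splits onto a few subtasks when the hashes clustered; it now sorts the splits by id and deals them out round-robin with a running counter, so reader loads differ by at most one split. The sketch below distills the before/after strategies into one standalone class (illustrative only: the two owner methods mirror the old and new getSplitOwner logic in the diff, while the class name and the demo main are invented for this example).

    import java.util.ArrayList;
    import java.util.Comparator;
    import java.util.List;

    public class SplitOwnerSketch {

        // Old strategy: owner derived from the split id's hash code.
        static int ownerByHash(String splitId, int numReaders) {
            return (splitId.hashCode() & Integer.MAX_VALUE) % numReaders;
        }

        // New strategy: owner is the running assignment index modulo the reader count.
        static int ownerByRoundRobin(int assignIndex, int numReaders) {
            return assignIndex % numReaders;
        }

        public static void main(String[] args) {
            List<String> splitIds = new ArrayList<>();
            for (int i = 0; i < 10; i++) {
                splitIds.add("tablet-" + i);
            }
            // The enumerator sorts by split id before assigning, so do the same here.
            splitIds.sort(Comparator.naturalOrder());

            int readers = 4;
            for (int i = 0; i < splitIds.size(); i++) {
                System.out.printf(
                        "%-9s hash owner=%d  round-robin owner=%d%n",
                        splitIds.get(i),
                        ownerByHash(splitIds.get(i), readers),
                        ownerByRoundRobin(i, readers));
            }
        }
    }

With the hash strategy the per-reader counts depend on how the split ids happen to hash; with the counter strategy reader r simply receives every readers-th split starting at position r.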

diff --git a/pom.xml b/pom.xml
index 7a5d2bd1f9..6c72cdab20 100644
--- a/pom.xml
+++ b/pom.xml
@@ -568,6 +568,13 @@
             <scope>test</scope>
         </dependency>
 
+        <dependency>
+            <groupId>org.mockito</groupId>
+            <artifactId>mockito-inline</artifactId>
+            <version>${mockito.version}</version>
+            <scope>test</scope>
+        </dependency>
+
         <!-- The prometheus simpleclient -->
         <dependency>
             <groupId>io.prometheus</groupId>
diff --git a/seatunnel-connectors-v2/connector-doris/src/main/java/org/apache/seatunnel/connectors/doris/rest/PartitionDefinition.java b/seatunnel-connectors-v2/connector-doris/src/main/java/org/apache/seatunnel/connectors/doris/rest/PartitionDefinition.java
index 0356eb2708..8742271b1a 100644
--- a/seatunnel-connectors-v2/connector-doris/src/main/java/org/apache/seatunnel/connectors/doris/rest/PartitionDefinition.java
+++ b/seatunnel-connectors-v2/connector-doris/src/main/java/org/apache/seatunnel/connectors/doris/rest/PartitionDefinition.java
@@ -129,7 +129,7 @@ public class PartitionDefinition implements Serializable, Comparable<PartitionDe
     @Override
     public String toString() {
         return "PartitionDefinition{"
-                + ", database='"
+                + "database='"
                 + database
                 + '\''
                 + ", table='"
diff --git a/seatunnel-connectors-v2/connector-doris/src/main/java/org/apache/seatunnel/connectors/doris/source/split/DorisSourceSplitEnumerator.java b/seatunnel-connectors-v2/connector-doris/src/main/java/org/apache/seatunnel/connectors/doris/source/split/DorisSourceSplitEnumerator.java
index 1aa10a88b5..af18ac5629 100644
--- a/seatunnel-connectors-v2/connector-doris/src/main/java/org/apache/seatunnel/connectors/doris/source/split/DorisSourceSplitEnumerator.java
+++ b/seatunnel-connectors-v2/connector-doris/src/main/java/org/apache/seatunnel/connectors/doris/source/split/DorisSourceSplitEnumerator.java
@@ -33,10 +33,13 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
 
 @Slf4j
 public class DorisSourceSplitEnumerator
@@ -52,6 +55,8 @@ public class DorisSourceSplitEnumerator
     private final Map<TablePath, DorisSourceTable> dorisSourceTables;
     private final Object stateLock = new Object();
 
+    private final AtomicInteger assignCount = new AtomicInteger(0);
+
     public DorisSourceSplitEnumerator(
             Context<DorisSourceSplit> context,
             DorisSourceConfig dorisSourceConfig,
@@ -162,15 +167,24 @@ public class DorisSourceSplitEnumerator
 
     private void addPendingSplit(Collection<DorisSourceSplit> splits) {
         int readerCount = context.currentParallelism();
-        for (DorisSourceSplit split : splits) {
-            int ownerReader = getSplitOwner(split.splitId(), readerCount);
+
+        // sorting the splits to ensure the order
+        List<DorisSourceSplit> sortedSplits =
+                splits.stream()
+                        .sorted(Comparator.comparing(DorisSourceSplit::getSplitId))
+                        .collect(Collectors.toList());
+
+        // allocate splits in load balancing mode
+        assignCount.set(0);
+        for (DorisSourceSplit split : sortedSplits) {
+            int ownerReader = getSplitOwner(assignCount.getAndIncrement(), readerCount);
             log.info("Assigning split {} to reader {} .", split.splitId(), ownerReader);
             pendingSplit.computeIfAbsent(ownerReader, f -> new ArrayList<>()).add(split);
         }
     }
 
-    private static int getSplitOwner(String tp, int numReaders) {
-        return (tp.hashCode() & Integer.MAX_VALUE) % numReaders;
+    private static int getSplitOwner(int assignCount, int numReaders) {
+        return assignCount % numReaders;
     }
 
     private void assignSplit(Collection<Integer> readers) {
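
Because splits are now dealt round-robin over a sorted list, the per-reader split count is deterministic: splits / readers, plus one extra for the first splits % readers readers. For example, 10 splits over 4 readers land as 3, 3, 2, 2, and 30 over 4 as 8, 8, 7, 7, which is what the new tests further down assert through their allocateFiles helper. A quick standalone check of that formula (hypothetical class name, not part of the diff):

    import java.util.stream.IntStream;

    public class RoundRobinCountCheck {

        // Splits that reader `reader` receives when `splits` splits are dealt
        // round-robin to `readers` readers: one extra for the first (splits % readers) readers.
        static int expectedCount(int reader, int readers, int splits) {
            return splits / readers + (reader < splits % readers ? 1 : 0);
        }

        public static void main(String[] args) {
            int readers = 4;
            for (int splits : new int[] {10, 15, 30}) {
                String counts =
                        IntStream.range(0, readers)
                                .mapToObj(r -> String.valueOf(expectedCount(r, readers, splits)))
                                .reduce((a, b) -> a + ", " + b)
                                .orElse("");
                System.out.println(splits + " splits over " + readers + " readers -> " + counts);
            }
        }
    }
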
diff --git a/seatunnel-connectors-v2/connector-doris/src/test/java/org/apache/seatunnel/connectors/doris/split/DorisSourceSplitEnumeratorTest.java b/seatunnel-connectors-v2/connector-doris/src/test/java/org/apache/seatunnel/connectors/doris/split/DorisSourceSplitEnumeratorTest.java
new file mode 100644
index 0000000000..9a2709d465
--- /dev/null
+++ b/seatunnel-connectors-v2/connector-doris/src/test/java/org/apache/seatunnel/connectors/doris/split/DorisSourceSplitEnumeratorTest.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.connectors.doris.split;
+
+import org.apache.seatunnel.shade.com.google.common.collect.Maps;
+
+import org.apache.seatunnel.api.source.SourceSplitEnumerator;
+import org.apache.seatunnel.api.table.catalog.TablePath;
+import org.apache.seatunnel.connectors.doris.config.DorisSourceConfig;
+import org.apache.seatunnel.connectors.doris.rest.PartitionDefinition;
+import org.apache.seatunnel.connectors.doris.rest.RestService;
+import org.apache.seatunnel.connectors.doris.source.DorisSourceTable;
+import org.apache.seatunnel.connectors.doris.source.split.DorisSourceSplit;
+import org.apache.seatunnel.connectors.doris.source.split.DorisSourceSplitEnumerator;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+
+import lombok.extern.slf4j.Slf4j;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static org.mockito.ArgumentMatchers.any;
+
+@Slf4j
+public class DorisSourceSplitEnumeratorTest {
+
+    private static final String DATABASE = "default";
+    private static final String TABLE = "default_table";
+    private static final String BE_ADDRESS_PREFIX = "doris-be-";
+    private static final String QUERY_PLAN = "DAABDAACDwABDAAAAAEIAA";
+
+    private static final int PARALLELISM = 4;
+
+    private static final int PARTITION_NUMS = 10;
+
+    @Test
+    public void dorisSourceSplitEnumeratorTest() {
+        DorisSourceConfig dorisSourceConfig = Mockito.mock(DorisSourceConfig.class);
+        DorisSourceTable dorisSourceTable = Mockito.mock(DorisSourceTable.class);
+
+        SourceSplitEnumerator.Context<DorisSourceSplit> context =
+                Mockito.mock(SourceSplitEnumerator.Context.class);
+
+        Mockito.when(context.registeredReaders())
+                .thenReturn(IntStream.range(0, PARALLELISM).boxed().collect(Collectors.toSet()));
+        Mockito.when(context.currentParallelism()).thenReturn(PARALLELISM);
+
+        Map<TablePath, DorisSourceTable> dorisSourceTableMap = Maps.newHashMap();
+        dorisSourceTableMap.put(new TablePath(DATABASE, null, TABLE), dorisSourceTable);
+
+        DorisSourceSplitEnumerator dorisSourceSplitEnumerator =
+                new DorisSourceSplitEnumerator(context, dorisSourceConfig, dorisSourceTableMap);
+
+        MockedStatic<RestService> restServiceMockedStatic = Mockito.mockStatic(RestService.class);
+
+        restServiceMockedStatic
+                .when(() -> RestService.findPartitions(any(), any(), any()))
+                .thenReturn(buildPartitionDefinitions());
+
+        dorisSourceSplitEnumerator.run();
+
+        ArgumentCaptor<Integer> subtaskId = ArgumentCaptor.forClass(Integer.class);
+        ArgumentCaptor<List> split = ArgumentCaptor.forClass(List.class);
+
+        Mockito.verify(context, Mockito.times(PARALLELISM))
+                .assignSplit(subtaskId.capture(), split.capture());
+
+        List<Integer> subTaskAllValues = subtaskId.getAllValues();
+        List<List> splitAllValues = split.getAllValues();
+
+        for (int i = 0; i < PARALLELISM; i++) {
+            Assertions.assertEquals(i, subTaskAllValues.get(i));
+            Assertions.assertEquals(
+                    allocateFiles(i, PARALLELISM, PARTITION_NUMS), splitAllValues.get(i).size());
+        }
+
+        // check no duplicate file assigned
+        Assertions.assertEquals(0, dorisSourceSplitEnumerator.currentUnassignedSplitSize());
+    }
+
+    private List<PartitionDefinition> buildPartitionDefinitions() {
+
+        List<PartitionDefinition> partitions = new ArrayList<>();
+
+        IntStream.range(0, PARTITION_NUMS)
+                .forEach(
+                        i -> {
+                            PartitionDefinition partitionDefinition =
+                                    new PartitionDefinition(
+                                            DATABASE,
+                                            TABLE,
+                                            BE_ADDRESS_PREFIX + i,
+                                            new HashSet<>(i),
+                                            QUERY_PLAN);
+
+                            partitions.add(partitionDefinition);
+                        });
+
+        return partitions;
+    }
+
+    /**
+     * calculate the number of files assigned each time
+     *
+     * @param id id
+     * @param parallelism parallelism
+     * @param fileSize file size
+     * @return
+     */
+    public int allocateFiles(int id, int parallelism, int fileSize) {
+        int filesPerIteration = fileSize / parallelism;
+        int remainder = fileSize % parallelism;
+
+        if (id < remainder) {
+            return filesPerIteration + 1;
+        } else {
+            return filesPerIteration;
+        }
+    }
+}
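
The test above stubs the static RestService.findPartitions call with Mockito's MockedStatic, which is what the new mockito-inline dependency in the root pom enables. A minimal sketch of that pattern on its own, assuming mockito-inline (or a Mockito version with the inline mock maker) is on the test classpath (the Greeter class is invented for illustration; the try-with-resources form also deregisters the static mock when the block ends, so it cannot leak into other tests):

    import org.junit.jupiter.api.Assertions;
    import org.junit.jupiter.api.Test;
    import org.mockito.MockedStatic;
    import org.mockito.Mockito;

    class Greeter {
        static String greet() {
            return "hello";
        }
    }

    public class StaticMockSketch {

        @Test
        void staticCallIsStubbedOnlyInsideTheScope() {
            try (MockedStatic<Greeter> mocked = Mockito.mockStatic(Greeter.class)) {
                mocked.when(Greeter::greet).thenReturn("stubbed");
                Assertions.assertEquals("stubbed", Greeter.greet());
            }
            // Outside the try block the real implementation is active again.
            Assertions.assertEquals("hello", Greeter.greet());
        }
    }
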
diff --git a/seatunnel-translation/seatunnel-translation-base/pom.xml b/seatunnel-translation/seatunnel-translation-base/pom.xml
index ac37c22c7c..6c0cf4359e 100644
--- a/seatunnel-translation/seatunnel-translation-base/pom.xml
+++ b/seatunnel-translation/seatunnel-translation-base/pom.xml
@@ -37,5 +37,12 @@
             <version>${project.version}</version>
             <scope>test</scope>
         </dependency>
+
+        <dependency>
+            <groupId>org.apache.seatunnel</groupId>
+            <artifactId>connector-doris</artifactId>
+            <version>${project.version}</version>
+            <scope>test</scope>
+        </dependency>
     </dependencies>
 </project>
diff --git a/seatunnel-translation/seatunnel-translation-base/src/test/java/org/apache/seatunnel/translation/source/ParallelSourceTest.java b/seatunnel-translation/seatunnel-translation-base/src/test/java/org/apache/seatunnel/translation/source/ParallelSourceTest.java
index 2f0b151330..1a2a3b98a7 100644
--- a/seatunnel-translation/seatunnel-translation-base/src/test/java/org/apache/seatunnel/translation/source/ParallelSourceTest.java
+++ b/seatunnel-translation/seatunnel-translation-base/src/test/java/org/apache/seatunnel/translation/source/ParallelSourceTest.java
@@ -17,106 +17,178 @@
 
 package org.apache.seatunnel.translation.source;
 
-import org.apache.seatunnel.shade.com.typesafe.config.Config;
-
-import org.apache.seatunnel.api.common.PrepareFailException;
-import org.apache.seatunnel.connectors.seatunnel.file.config.FileSystemType;
+import org.apache.seatunnel.shade.com.google.common.collect.Maps;
+
+import org.apache.seatunnel.api.table.catalog.TablePath;
+import org.apache.seatunnel.connectors.doris.config.DorisSourceConfig;
+import org.apache.seatunnel.connectors.doris.rest.PartitionDefinition;
+import org.apache.seatunnel.connectors.doris.rest.RestService;
+import org.apache.seatunnel.connectors.doris.source.DorisSource;
+import org.apache.seatunnel.connectors.doris.source.DorisSourceTable;
+import org.apache.seatunnel.connectors.doris.source.reader.DorisSourceReader;
+import org.apache.seatunnel.connectors.doris.source.split.DorisSourceSplit;
 import org.apache.seatunnel.connectors.seatunnel.file.source.BaseFileSource;
-import org.apache.seatunnel.connectors.seatunnel.file.source.BaseFileSourceReader;
 import org.apache.seatunnel.connectors.seatunnel.file.source.split.FileSourceSplit;
+import org.apache.seatunnel.connectors.seatunnel.file.source.split.FileSourceSplitEnumerator;
 
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
 
 import lombok.extern.slf4j.Slf4j;
 
 import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static org.mockito.ArgumentMatchers.any;
 
 @Slf4j
 public class ParallelSourceTest {
 
     @Test
-    void testParallelSourceForPollingFileAllocation() throws Exception {
+    void fileParallelSourceSplitEnumeratorTest() throws Exception {
         int fileSize = 15;
         int parallelism = 4;
 
-        // create file source
-        BaseFileSource baseFileSource =
-                new BaseFileSource() {
-                    @Override
-                    public void prepare(Config pluginConfig) throws PrepareFailException {
-                        filePaths = new ArrayList<>();
-                        for (int i = 0; i < fileSize; i++) {
-                            filePaths.add("file" + i + ".txt");
-                        }
-                    }
-
-                    @Override
-                    public String getPluginName() {
-                        return FileSystemType.HDFS.getFileSystemPluginName();
-                    }
-                };
-
-        // prepare files
-        baseFileSource.prepare(null);
-
-        ParallelSource parallelSource =
-                new ParallelSource(baseFileSource, null, parallelism, "parallel-source-test", 0);
-        ParallelSource parallelSource2 =
-                new ParallelSource(baseFileSource, null, parallelism, "parallel-source-test2", 1);
-        ParallelSource parallelSource3 =
-                new ParallelSource(baseFileSource, null, parallelism, "parallel-source-test3", 2);
-        ParallelSource parallelSource4 =
-                new ParallelSource(baseFileSource, null, parallelism, "parallel-source-test4", 3);
-
-        parallelSource.open();
-        parallelSource2.open();
-        parallelSource3.open();
-        parallelSource4.open();
-
-        // execute file allocation process
-        parallelSource.splitEnumerator.run();
-        parallelSource2.splitEnumerator.run();
-        parallelSource3.splitEnumerator.run();
-        parallelSource4.splitEnumerator.run();
-
-        // Gets the splits assigned for each reader
-        List<FileSourceSplit> sourceSplits =
-                ((BaseFileSourceReader) parallelSource.reader).snapshotState(0);
-        List<FileSourceSplit> sourceSplits2 =
-                ((BaseFileSourceReader) parallelSource2.reader).snapshotState(0);
-        List<FileSourceSplit> sourceSplits3 =
-                ((BaseFileSourceReader) parallelSource3.reader).snapshotState(0);
-        List<FileSourceSplit> sourceSplits4 =
-                ((BaseFileSourceReader) parallelSource4.reader).snapshotState(0);
-
-        log.info(
-                "parallel source1 splits => {}",
-                sourceSplits.stream().map(FileSourceSplit::splitId).collect(Collectors.toList()));
-
-        log.info(
-                "parallel source2 splits => {}",
-                sourceSplits2.stream().map(FileSourceSplit::splitId).collect(Collectors.toList()));
-
-        log.info(
-                "parallel source3 splits => {}",
-                sourceSplits3.stream().map(FileSourceSplit::splitId).collect(Collectors.toList()));
-
-        log.info(
-                "parallel source4 splits => {}",
-                sourceSplits4.stream().map(FileSourceSplit::splitId).collect(Collectors.toList()));
-
-        // check that there are no duplicate file assignments
+        List<String> filePaths = new ArrayList<>();
+        for (int i = 0; i < fileSize; i++) {
+            filePaths.add("file" + i + ".txt");
+        }
+        BaseFileSource baseFileSource = Mockito.spy(BaseFileSource.class);
+
         Set<FileSourceSplit> splitSet = new HashSet<>();
-        splitSet.addAll(sourceSplits);
-        splitSet.addAll(sourceSplits2);
-        splitSet.addAll(sourceSplits3);
-        splitSet.addAll(sourceSplits4);
+        for (int i = 0; i < parallelism; i++) {
+
+            ParallelEnumeratorContext<FileSourceSplit> context =
+                    Mockito.mock(ParallelEnumeratorContext.class);
+
+            Mockito.when(context.currentParallelism()).thenReturn(parallelism);
+
+            FileSourceSplitEnumerator fileSourceSplitEnumerator =
+                    new FileSourceSplitEnumerator(context, filePaths);
+
+            Mockito.when(baseFileSource.createEnumerator(any()))
+                    .thenReturn(fileSourceSplitEnumerator);
+
+            ParallelSource parallelSource =
+                    new ParallelSource(
+                            baseFileSource, null, parallelism, "parallel-source-test" + i, i);
+
+            parallelSource.open();
+            parallelSource.splitEnumerator.run();
+
+            ArgumentCaptor<Integer> subtaskId = ArgumentCaptor.forClass(Integer.class);
+            ArgumentCaptor<List> split = ArgumentCaptor.forClass(List.class);
+
+            Mockito.verify(context, Mockito.times(parallelism))
+                    .assignSplit(subtaskId.capture(), split.capture());
+
+            List<Integer> subTaskAllValues = subtaskId.getAllValues();
+            List<List> splitAllValues = split.getAllValues();
+
+            Assertions.assertEquals(i, subTaskAllValues.get(i));
+            Assertions.assertEquals(
+                    allocateFiles(i, parallelism, fileSize), splitAllValues.get(i).size());
+
+            splitSet.addAll(splitAllValues.get(i));
+        }
 
+        // Check that there are no duplicate file assign
         Assertions.assertEquals(splitSet.size(), fileSize);
     }
+
+    @Test
+    public void dorisParallelSourceSplitEnumeratorTest() throws Exception {
+        int parallelism = 4;
+        int partitionNums = 30;
+
+        DorisSourceConfig dorisSourceConfig = Mockito.mock(DorisSourceConfig.class);
+        DorisSourceTable dorisSourceTable = Mockito.mock(DorisSourceTable.class);
+
+        Map<TablePath, DorisSourceTable> dorisSourceTableMap = Maps.newHashMap();
+        dorisSourceTableMap.put(new TablePath("default", null, "default_table"), dorisSourceTable);
+
+        DorisSource dorisSource = new DorisSource(dorisSourceConfig, dorisSourceTableMap);
+
+        MockedStatic<RestService> restServiceMockedStatic = Mockito.mockStatic(RestService.class);
+        restServiceMockedStatic
+                .when(() -> RestService.findPartitions(any(), any(), any()))
+                .thenReturn(buildPartitionDefinitions(partitionNums));
+
+        Set<DorisSourceSplit> splitSet = new HashSet<>();
+        for (int i = 0; i < parallelism; i++) {
+            ParallelSource parallelSource =
+                    new ParallelSource(
+                            dorisSource, null, parallelism, "parallel-doris-source" + i, i);
+            parallelSource.open();
+
+            // execute file allocation process
+            parallelSource.splitEnumerator.run();
+            List<DorisSourceSplit> sourceSplits =
+                    ((DorisSourceReader) parallelSource.reader).snapshotState(0);
+            log.info(
+                    "parallel source{} splits => {}",
+                    i + 1,
+                    sourceSplits.stream()
+                            .map(DorisSourceSplit::splitId)
+                            .collect(Collectors.toList()));
+
+            Assertions.assertEquals(
+                    allocateFiles(i, parallelism, partitionNums), sourceSplits.size());
+
+            // collect all splits
+            splitSet.addAll(sourceSplits);
+        }
+
+        Assertions.assertEquals(splitSet.size(), partitionNums);
+    }
+
+    private List<PartitionDefinition> buildPartitionDefinitions(int partitionNUms) {
+
+        List<PartitionDefinition> partitions = new ArrayList<>();
+
+        String beAddressPrefix = "doris-be-";
+
+        IntStream.range(0, partitionNUms)
+                .forEach(
+                        i -> {
+                            PartitionDefinition partitionDefinition =
+                                    new PartitionDefinition(
+                                            "default",
+                                            "default_table",
+                                            beAddressPrefix + i,
+                                            new HashSet<>(i),
+                                            "QUERY_PLAN");
+
+                            partitions.add(partitionDefinition);
+                        });
+
+        return partitions;
+    }
+
+    /**
+     * calculate the number of files assigned each time
+     *
+     * @param id id
+     * @param parallelism parallelism
+     * @param fileSize file size
+     * @return
+     */
+    public int allocateFiles(int id, int parallelism, int fileSize) {
+        int filesPerIteration = fileSize / parallelism;
+        int remainder = fileSize % parallelism;
+
+        if (id < remainder) {
+            return filesPerIteration + 1;
+        } else {
+            return filesPerIteration;
+        }
+    }
 }
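
Likewise, the reworked file-source test spies on BaseFileSource and stubs createEnumerator so each loop iteration gets a fresh FileSourceSplitEnumerator wired to a mocked context. A minimal sketch of that spy-and-stub idiom (the Planner class is invented for illustration; doReturn(..).when(spy) is used here because it stubs without first invoking the real method on the spy):

    import static org.mockito.ArgumentMatchers.any;

    import java.util.Arrays;
    import java.util.List;

    import org.junit.jupiter.api.Assertions;
    import org.junit.jupiter.api.Test;
    import org.mockito.Mockito;

    public class SpyStubSketch {

        // Hypothetical collaborator whose real implementation would hit an external system.
        static class Planner {
            List<String> plan(String table) {
                throw new UnsupportedOperationException("talks to a real backend");
            }
        }

        @Test
        void spyReturnsCannedSplits() {
            Planner planner = Mockito.spy(Planner.class);
            // Stub the factory-style call with canned data instead of touching a backend.
            Mockito.doReturn(Arrays.asList("split-0", "split-1")).when(planner).plan(any());

            Assertions.assertEquals(2, planner.plan("any_table").size());
        }
    }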
