This is an automated email from the ASF dual-hosted git repository.

wanghailin pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git
The following commit(s) were added to refs/heads/dev by this push:
     new 5f55e31c29 [Improve][connector-doris] Improved doris source enumerator splits allocation algorithm for subtasks (#9108)
5f55e31c29 is described below

commit 5f55e31c294a3638439fd1a05128dddbb3bdb539
Author: JeremyXin <739772...@qq.com>
AuthorDate: Wed Apr 9 18:17:42 2025 +0800

    [Improve][connector-doris] Improved doris source enumerator splits allocation algorithm for subtasks (#9108)
---
 pom.xml                                            |   7 +
 .../connectors/doris/rest/PartitionDefinition.java |   2 +-
 .../source/split/DorisSourceSplitEnumerator.java   |  22 +-
 .../split/DorisSourceSplitEnumeratorTest.java      | 144 +++++++++++++
 .../seatunnel-translation-base/pom.xml             |   7 +
 .../translation/source/ParallelSourceTest.java     | 226 ++++++++++++++-------
 6 files changed, 326 insertions(+), 82 deletions(-)

diff --git a/pom.xml b/pom.xml
index 7a5d2bd1f9..6c72cdab20 100644
--- a/pom.xml
+++ b/pom.xml
@@ -568,6 +568,13 @@
                 <scope>test</scope>
             </dependency>
 
+            <dependency>
+                <groupId>org.mockito</groupId>
+                <artifactId>mockito-inline</artifactId>
+                <version>${mockito.version}</version>
+                <scope>test</scope>
+            </dependency>
+
             <!-- The prometheus simpleclient -->
             <dependency>
                 <groupId>io.prometheus</groupId>
diff --git a/seatunnel-connectors-v2/connector-doris/src/main/java/org/apache/seatunnel/connectors/doris/rest/PartitionDefinition.java b/seatunnel-connectors-v2/connector-doris/src/main/java/org/apache/seatunnel/connectors/doris/rest/PartitionDefinition.java
index 0356eb2708..8742271b1a 100644
--- a/seatunnel-connectors-v2/connector-doris/src/main/java/org/apache/seatunnel/connectors/doris/rest/PartitionDefinition.java
+++ b/seatunnel-connectors-v2/connector-doris/src/main/java/org/apache/seatunnel/connectors/doris/rest/PartitionDefinition.java
@@ -129,7 +129,7 @@ public class PartitionDefinition implements Serializable, Comparable<PartitionDe
     @Override
     public String toString() {
         return "PartitionDefinition{"
-                + ", database='"
+                + "database='"
                 + database
                 + '\''
                 + ", table='"
diff --git a/seatunnel-connectors-v2/connector-doris/src/main/java/org/apache/seatunnel/connectors/doris/source/split/DorisSourceSplitEnumerator.java b/seatunnel-connectors-v2/connector-doris/src/main/java/org/apache/seatunnel/connectors/doris/source/split/DorisSourceSplitEnumerator.java
index 1aa10a88b5..af18ac5629 100644
--- a/seatunnel-connectors-v2/connector-doris/src/main/java/org/apache/seatunnel/connectors/doris/source/split/DorisSourceSplitEnumerator.java
+++ b/seatunnel-connectors-v2/connector-doris/src/main/java/org/apache/seatunnel/connectors/doris/source/split/DorisSourceSplitEnumerator.java
@@ -33,10 +33,13 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
 
 @Slf4j
 public class DorisSourceSplitEnumerator
@@ -52,6 +55,8 @@ public class DorisSourceSplitEnumerator
     private final Map<TablePath, DorisSourceTable> dorisSourceTables;
     private final Object stateLock = new Object();
 
+    private final AtomicInteger assignCount = new AtomicInteger(0);
+
     public DorisSourceSplitEnumerator(
             Context<DorisSourceSplit> context,
             DorisSourceConfig dorisSourceConfig,
@@ -162,15 +167,24 @@ public class DorisSourceSplitEnumerator
 
     private void addPendingSplit(Collection<DorisSourceSplit> splits) {
         int readerCount = context.currentParallelism();
-        for (DorisSourceSplit split : splits) {
-            int ownerReader = getSplitOwner(split.splitId(), readerCount);
+
+        // sorting the splits to ensure the order
+        List<DorisSourceSplit> sortedSplits =
+                splits.stream()
+                        .sorted(Comparator.comparing(DorisSourceSplit::getSplitId))
+                        .collect(Collectors.toList());
+
+        // allocate splits in load balancing mode
+        assignCount.set(0);
+        for (DorisSourceSplit split : sortedSplits) {
+            int ownerReader = getSplitOwner(assignCount.getAndIncrement(), readerCount);
             log.info("Assigning split {} to reader {} .", split.splitId(), ownerReader);
             pendingSplit.computeIfAbsent(ownerReader, f -> new ArrayList<>()).add(split);
         }
     }
 
-    private static int getSplitOwner(String tp, int numReaders) {
-        return (tp.hashCode() & Integer.MAX_VALUE) % numReaders;
+    private static int getSplitOwner(int assignCount, int numReaders) {
+        return assignCount % numReaders;
     }
 
     private void assignSplit(Collection<Integer> readers) {
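
The core of this change is the switch from hashCode-based split ownership to a sorted, round-robin
assignment. A minimal, self-contained sketch of that allocation pattern is below; the class and
variable names are illustrative only and are not part of the commit:

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    // Illustrative sketch: sort split ids for a deterministic order, then hand them out
    // round-robin (index % readerCount) so reader loads differ by at most one split.
    public class RoundRobinAssignmentSketch {
        public static void main(String[] args) {
            int readerCount = 4;
            List<String> splitIds = new ArrayList<>();
            for (int i = 0; i < 10; i++) {
                splitIds.add("split-" + i);
            }
            splitIds.sort(String::compareTo); // deterministic order, as the enumerator now enforces

            Map<Integer, List<String>> pendingSplit = new HashMap<>();
            for (int i = 0; i < splitIds.size(); i++) {
                int ownerReader = i % readerCount; // replaces the old (hashCode & MAX_VALUE) % numReaders
                pendingSplit.computeIfAbsent(ownerReader, k -> new ArrayList<>()).add(splitIds.get(i));
            }
            // Readers 0 and 1 receive 3 splits each, readers 2 and 3 receive 2 each.
            pendingSplit.forEach((reader, splits) -> System.out.println(reader + " -> " + splits));
        }
    }
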
diff --git a/seatunnel-connectors-v2/connector-doris/src/test/java/org/apache/seatunnel/connectors/doris/split/DorisSourceSplitEnumeratorTest.java b/seatunnel-connectors-v2/connector-doris/src/test/java/org/apache/seatunnel/connectors/doris/split/DorisSourceSplitEnumeratorTest.java
new file mode 100644
index 0000000000..9a2709d465
--- /dev/null
+++ b/seatunnel-connectors-v2/connector-doris/src/test/java/org/apache/seatunnel/connectors/doris/split/DorisSourceSplitEnumeratorTest.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.connectors.doris.split;
+
+import org.apache.seatunnel.shade.com.google.common.collect.Maps;
+
+import org.apache.seatunnel.api.source.SourceSplitEnumerator;
+import org.apache.seatunnel.api.table.catalog.TablePath;
+import org.apache.seatunnel.connectors.doris.config.DorisSourceConfig;
+import org.apache.seatunnel.connectors.doris.rest.PartitionDefinition;
+import org.apache.seatunnel.connectors.doris.rest.RestService;
+import org.apache.seatunnel.connectors.doris.source.DorisSourceTable;
+import org.apache.seatunnel.connectors.doris.source.split.DorisSourceSplit;
+import org.apache.seatunnel.connectors.doris.source.split.DorisSourceSplitEnumerator;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+
+import lombok.extern.slf4j.Slf4j;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static org.mockito.ArgumentMatchers.any;
+
+@Slf4j
+public class DorisSourceSplitEnumeratorTest {
+
+    private static final String DATABASE = "default";
+    private static final String TABLE = "default_table";
+    private static final String BE_ADDRESS_PREFIX = "doris-be-";
+    private static final String QUERY_PLAN = "DAABDAACDwABDAAAAAEIAA";
+
+    private static final int PARALLELISM = 4;
+
+    private static final int PARTITION_NUMS = 10;
+
+    @Test
+    public void dorisSourceSplitEnumeratorTest() {
+        DorisSourceConfig dorisSourceConfig = Mockito.mock(DorisSourceConfig.class);
+        DorisSourceTable dorisSourceTable = Mockito.mock(DorisSourceTable.class);
+
+        SourceSplitEnumerator.Context<DorisSourceSplit> context =
+                Mockito.mock(SourceSplitEnumerator.Context.class);
+
+        Mockito.when(context.registeredReaders())
+                .thenReturn(IntStream.range(0, PARALLELISM).boxed().collect(Collectors.toSet()));
+        Mockito.when(context.currentParallelism()).thenReturn(PARALLELISM);
+
+        Map<TablePath, DorisSourceTable> dorisSourceTableMap = Maps.newHashMap();
+        dorisSourceTableMap.put(new TablePath(DATABASE, null, TABLE), dorisSourceTable);
+
+        DorisSourceSplitEnumerator dorisSourceSplitEnumerator =
+                new DorisSourceSplitEnumerator(context, dorisSourceConfig, dorisSourceTableMap);
+
+        MockedStatic<RestService> restServiceMockedStatic = Mockito.mockStatic(RestService.class);
+
+        restServiceMockedStatic
+                .when(() -> RestService.findPartitions(any(), any(), any()))
+                .thenReturn(buildPartitionDefinitions());
+
+        dorisSourceSplitEnumerator.run();
+
+        ArgumentCaptor<Integer> subtaskId = ArgumentCaptor.forClass(Integer.class);
+        ArgumentCaptor<List> split = ArgumentCaptor.forClass(List.class);
+
+        Mockito.verify(context, Mockito.times(PARALLELISM))
+                .assignSplit(subtaskId.capture(), split.capture());
+
+        List<Integer> subTaskAllValues = subtaskId.getAllValues();
+        List<List> splitAllValues = split.getAllValues();
+
+        for (int i = 0; i < PARALLELISM; i++) {
+            Assertions.assertEquals(i, subTaskAllValues.get(i));
+            Assertions.assertEquals(
+                    allocateFiles(i, PARALLELISM, PARTITION_NUMS), splitAllValues.get(i).size());
+        }
+
+        // check no duplicate file assigned
+        Assertions.assertEquals(0, dorisSourceSplitEnumerator.currentUnassignedSplitSize());
+    }
+
+    private List<PartitionDefinition> buildPartitionDefinitions() {
+
+        List<PartitionDefinition> partitions = new ArrayList<>();
+
+        IntStream.range(0, PARTITION_NUMS)
+                .forEach(
+                        i -> {
+                            PartitionDefinition partitionDefinition =
+                                    new PartitionDefinition(
+                                            DATABASE,
+                                            TABLE,
+                                            BE_ADDRESS_PREFIX + i,
+                                            new HashSet<>(i),
+                                            QUERY_PLAN);
+
+                            partitions.add(partitionDefinition);
+                        });
+
+        return partitions;
+    }
+
+    /**
+     * calculate the number of files assigned each time
+     *
+     * @param id id
+     * @param parallelism parallelism
+     * @param fileSize file size
+     * @return
+     */
+    public int allocateFiles(int id, int parallelism, int fileSize) {
+        int filesPerIteration = fileSize / parallelism;
+        int remainder = fileSize % parallelism;
+
+        if (id < remainder) {
+            return filesPerIteration + 1;
+        } else {
+            return filesPerIteration;
+        }
+    }
+}
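
A quick worked check of the expected counts in this new test: with PARTITION_NUMS = 10 and PARALLELISM = 4, allocateFiles computes filesPerIteration = 10 / 4 = 2 and remainder = 10 % 4 = 2, so subtasks 0 and 1 are each expected to receive 3 splits and subtasks 2 and 3 receive 2 (3 + 3 + 2 + 2 = 10), which is exactly what the round-robin assignment over the 10 sorted splits produces.
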
diff --git a/seatunnel-translation/seatunnel-translation-base/pom.xml b/seatunnel-translation/seatunnel-translation-base/pom.xml
index ac37c22c7c..6c0cf4359e 100644
--- a/seatunnel-translation/seatunnel-translation-base/pom.xml
+++ b/seatunnel-translation/seatunnel-translation-base/pom.xml
@@ -37,5 +37,12 @@
             <version>${project.version}</version>
             <scope>test</scope>
         </dependency>
+
+        <dependency>
+            <groupId>org.apache.seatunnel</groupId>
+            <artifactId>connector-doris</artifactId>
+            <version>${project.version}</version>
+            <scope>test</scope>
+        </dependency>
     </dependencies>
 </project>
diff --git a/seatunnel-translation/seatunnel-translation-base/src/test/java/org/apache/seatunnel/translation/source/ParallelSourceTest.java b/seatunnel-translation/seatunnel-translation-base/src/test/java/org/apache/seatunnel/translation/source/ParallelSourceTest.java
index 2f0b151330..1a2a3b98a7 100644
--- a/seatunnel-translation/seatunnel-translation-base/src/test/java/org/apache/seatunnel/translation/source/ParallelSourceTest.java
+++ b/seatunnel-translation/seatunnel-translation-base/src/test/java/org/apache/seatunnel/translation/source/ParallelSourceTest.java
@@ -17,106 +17,178 @@
 
 package org.apache.seatunnel.translation.source;
 
-import org.apache.seatunnel.shade.com.typesafe.config.Config;
-
-import org.apache.seatunnel.api.common.PrepareFailException;
-import org.apache.seatunnel.connectors.seatunnel.file.config.FileSystemType;
+import org.apache.seatunnel.shade.com.google.common.collect.Maps;
+
+import org.apache.seatunnel.api.table.catalog.TablePath;
+import org.apache.seatunnel.connectors.doris.config.DorisSourceConfig;
+import org.apache.seatunnel.connectors.doris.rest.PartitionDefinition;
+import org.apache.seatunnel.connectors.doris.rest.RestService;
+import org.apache.seatunnel.connectors.doris.source.DorisSource;
+import org.apache.seatunnel.connectors.doris.source.DorisSourceTable;
+import org.apache.seatunnel.connectors.doris.source.reader.DorisSourceReader;
+import org.apache.seatunnel.connectors.doris.source.split.DorisSourceSplit;
 import org.apache.seatunnel.connectors.seatunnel.file.source.BaseFileSource;
-import org.apache.seatunnel.connectors.seatunnel.file.source.BaseFileSourceReader;
 import org.apache.seatunnel.connectors.seatunnel.file.source.split.FileSourceSplit;
+import org.apache.seatunnel.connectors.seatunnel.file.source.split.FileSourceSplitEnumerator;
 
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
 
 import lombok.extern.slf4j.Slf4j;
 
 import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static org.mockito.ArgumentMatchers.any;
 
 @Slf4j
 public class ParallelSourceTest {
 
     @Test
-    void testParallelSourceForPollingFileAllocation() throws Exception {
+    void fileParallelSourceSplitEnumeratorTest() throws Exception {
         int fileSize = 15;
         int parallelism = 4;
 
-        // create file source
-        BaseFileSource baseFileSource =
-                new BaseFileSource() {
-                    @Override
-                    public void prepare(Config pluginConfig) throws PrepareFailException {
-                        filePaths = new ArrayList<>();
-                        for (int i = 0; i < fileSize; i++) {
-                            filePaths.add("file" + i + ".txt");
-                        }
-                    }
-
-                    @Override
-                    public String getPluginName() {
-                        return FileSystemType.HDFS.getFileSystemPluginName();
-                    }
-                };
-
-        // prepare files
-        baseFileSource.prepare(null);
-
-        ParallelSource parallelSource =
-                new ParallelSource(baseFileSource, null, parallelism, "parallel-source-test", 0);
-        ParallelSource parallelSource2 =
-                new ParallelSource(baseFileSource, null, parallelism, "parallel-source-test2", 1);
-        ParallelSource parallelSource3 =
-                new ParallelSource(baseFileSource, null, parallelism, "parallel-source-test3", 2);
-        ParallelSource parallelSource4 =
-                new ParallelSource(baseFileSource, null, parallelism, "parallel-source-test4", 3);
-
-        parallelSource.open();
-        parallelSource2.open();
-        parallelSource3.open();
-        parallelSource4.open();
-
-        // execute file allocation process
-        parallelSource.splitEnumerator.run();
-        parallelSource2.splitEnumerator.run();
-        parallelSource3.splitEnumerator.run();
-        parallelSource4.splitEnumerator.run();
-
-        // Gets the splits assigned for each reader
-        List<FileSourceSplit> sourceSplits =
-                ((BaseFileSourceReader) parallelSource.reader).snapshotState(0);
-        List<FileSourceSplit> sourceSplits2 =
-                ((BaseFileSourceReader) parallelSource2.reader).snapshotState(0);
-        List<FileSourceSplit> sourceSplits3 =
-                ((BaseFileSourceReader) parallelSource3.reader).snapshotState(0);
-        List<FileSourceSplit> sourceSplits4 =
-                ((BaseFileSourceReader) parallelSource4.reader).snapshotState(0);
-
-        log.info(
-                "parallel source1 splits => {}",
-                sourceSplits.stream().map(FileSourceSplit::splitId).collect(Collectors.toList()));
-
-        log.info(
-                "parallel source2 splits => {}",
-                sourceSplits2.stream().map(FileSourceSplit::splitId).collect(Collectors.toList()));
-
-        log.info(
-                "parallel source3 splits => {}",
-                sourceSplits3.stream().map(FileSourceSplit::splitId).collect(Collectors.toList()));
-
-        log.info(
-                "parallel source4 splits => {}",
-                sourceSplits4.stream().map(FileSourceSplit::splitId).collect(Collectors.toList()));
-
-        // check that there are no duplicate file assignments
+        List<String> filePaths = new ArrayList<>();
+        for (int i = 0; i < fileSize; i++) {
+            filePaths.add("file" + i + ".txt");
+        }
+        BaseFileSource baseFileSource = Mockito.spy(BaseFileSource.class);
+
         Set<FileSourceSplit> splitSet = new HashSet<>();
-        splitSet.addAll(sourceSplits);
-        splitSet.addAll(sourceSplits2);
-        splitSet.addAll(sourceSplits3);
-        splitSet.addAll(sourceSplits4);
+        for (int i = 0; i < parallelism; i++) {
+
+            ParallelEnumeratorContext<FileSourceSplit> context =
+                    Mockito.mock(ParallelEnumeratorContext.class);
+
+            Mockito.when(context.currentParallelism()).thenReturn(parallelism);
+
+            FileSourceSplitEnumerator fileSourceSplitEnumerator =
+                    new FileSourceSplitEnumerator(context, filePaths);
+
+            Mockito.when(baseFileSource.createEnumerator(any()))
+                    .thenReturn(fileSourceSplitEnumerator);
+
+            ParallelSource parallelSource =
+                    new ParallelSource(
+                            baseFileSource, null, parallelism, "parallel-source-test" + i, i);
+
+            parallelSource.open();
+            parallelSource.splitEnumerator.run();
+
+            ArgumentCaptor<Integer> subtaskId = ArgumentCaptor.forClass(Integer.class);
+            ArgumentCaptor<List> split = ArgumentCaptor.forClass(List.class);
+
+            Mockito.verify(context, Mockito.times(parallelism))
+                    .assignSplit(subtaskId.capture(), split.capture());
+
+            List<Integer> subTaskAllValues = subtaskId.getAllValues();
+            List<List> splitAllValues = split.getAllValues();
+
+            Assertions.assertEquals(i, subTaskAllValues.get(i));
+            Assertions.assertEquals(
+                    allocateFiles(i, parallelism, fileSize), splitAllValues.get(i).size());
+
+            splitSet.addAll(splitAllValues.get(i));
+        }
+        // Check that there are no duplicate file assignments
         Assertions.assertEquals(splitSet.size(), fileSize);
     }
+
+    @Test
+    public void dorisParallelSourceSplitEnumeratorTest() throws Exception {
+        int parallelism = 4;
+        int partitionNums = 30;
+
+        DorisSourceConfig dorisSourceConfig = Mockito.mock(DorisSourceConfig.class);
+        DorisSourceTable dorisSourceTable = Mockito.mock(DorisSourceTable.class);
+
+        Map<TablePath, DorisSourceTable> dorisSourceTableMap = Maps.newHashMap();
+        dorisSourceTableMap.put(new TablePath("default", null, "default_table"), dorisSourceTable);
+
+        DorisSource dorisSource = new DorisSource(dorisSourceConfig, dorisSourceTableMap);
+
+        MockedStatic<RestService> restServiceMockedStatic = Mockito.mockStatic(RestService.class);
+        restServiceMockedStatic
+                .when(() -> RestService.findPartitions(any(), any(), any()))
+                .thenReturn(buildPartitionDefinitions(partitionNums));
+
+        Set<DorisSourceSplit> splitSet = new HashSet<>();
+        for (int i = 0; i < parallelism; i++) {
+            ParallelSource parallelSource =
+                    new ParallelSource(
+                            dorisSource, null, parallelism, "parallel-doris-source" + i, i);
+            parallelSource.open();
+
+            // execute file allocation process
+            parallelSource.splitEnumerator.run();
+            List<DorisSourceSplit> sourceSplits =
+                    ((DorisSourceReader) parallelSource.reader).snapshotState(0);
+            log.info(
+                    "parallel source{} splits => {}",
+                    i + 1,
+                    sourceSplits.stream()
+                            .map(DorisSourceSplit::splitId)
+                            .collect(Collectors.toList()));
+
+            Assertions.assertEquals(
+                    allocateFiles(i, parallelism, partitionNums), sourceSplits.size());
+
+            // collect all splits
+            splitSet.addAll(sourceSplits);
+        }
+
+        Assertions.assertEquals(splitSet.size(), partitionNums);
+    }
+
+    private List<PartitionDefinition> buildPartitionDefinitions(int partitionNums) {
+
+        List<PartitionDefinition> partitions = new ArrayList<>();
+
+        String beAddressPrefix = "doris-be-";
+
+        IntStream.range(0, partitionNums)
+                .forEach(
+                        i -> {
+                            PartitionDefinition partitionDefinition =
+                                    new PartitionDefinition(
+                                            "default",
+                                            "default_table",
+                                            beAddressPrefix + i,
+                                            new HashSet<>(i),
+                                            "QUERY_PLAN");
+
+                            partitions.add(partitionDefinition);
+                        });
+
+        return partitions;
+    }
+
+    /**
+     * calculate the number of files assigned each time
+     *
+     * @param id id
+     * @param parallelism parallelism
+     * @param fileSize file size
+     * @return
+     */
+    public int allocateFiles(int id, int parallelism, int fileSize) {
+        int filesPerIteration = fileSize / parallelism;
+        int remainder = fileSize % parallelism;
+
+        if (id < remainder) {
+            return filesPerIteration + 1;
+        } else {
+            return filesPerIteration;
+        }
+    }
 }
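
For the two scenarios exercised above, the same allocateFiles formula gives: fileSize = 15 with parallelism = 4 yields 15 / 4 = 3 with remainder 3, so subtasks 0 to 2 each expect 4 splits and subtask 3 expects 3 (4 + 4 + 4 + 3 = 15); partitionNums = 30 with parallelism = 4 yields 7 with remainder 2, so subtasks 0 and 1 expect 8 splits and subtasks 2 and 3 expect 7 (8 + 8 + 7 + 7 = 30).
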