This is an automated email from the ASF dual-hosted git repository.

karan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new 2c21eb86cd4 Follow up for MSQ deadlock retry issue.  (#18467)
2c21eb86cd4 is described below

commit 2c21eb86cd4614cb86eb49b50fa97cdb06e7a8f2
Author: Karan Kumar <[email protected]>
AuthorDate: Wed Sep 3 12:59:31 2025 +0530

    Follow up for MSQ deadlock retry issue.  (#18467)
    
    * Initial push
    
    * Minor cosmetic changes.
    
    * Minor cosmetic changes.
    
    * Minor cosmetic changes.
    
    * Update multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncher.java
    
    Co-authored-by: Copilot <[email protected]>
    
    * Checkstyle things
    
    ---------
    
    Co-authored-by: Copilot <[email protected]>
---
 .../msq/dart/controller/DartControllerContext.java |   4 +-
 .../msq/dart/controller/DartWorkerManager.java     |  13 +-
 .../apache/druid/msq/exec/ControllerContext.java   |  10 +-
 .../org/apache/druid/msq/exec/ControllerImpl.java  |  50 ++--
 .../org/apache/druid/msq/exec/WorkerManager.java   |  12 +-
 .../msq/indexing/IndexerControllerContext.java     |   5 +-
 .../druid/msq/indexing/MSQWorkerTaskLauncher.java  | 153 +++++++----
 .../msq/dart/controller/DartWorkerManagerTest.java |  34 ++-
 .../org/apache/druid/msq/exec/MSQTasksTest.java    |   3 +-
 ...ts.java => MSQWorkerTaskLauncherRetryTest.java} | 289 +++++++++------------
 .../msq/indexing/MSQWorkerTaskLauncherTest.java    |  18 +-
 .../druid/msq/test/MSQTestControllerContext.java   |   5 +-
 12 files changed, 335 insertions(+), 261 deletions(-)

diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContext.java b/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContext.java
index 6a1b65c09f5..24637eb54e8 100644
--- a/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContext.java
+++ b/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContext.java
@@ -35,6 +35,7 @@ import org.apache.druid.msq.exec.ControllerMemoryParameters;
 import org.apache.druid.msq.exec.MSQMetriceEventBuilder;
 import org.apache.druid.msq.exec.MemoryIntrospector;
 import org.apache.druid.msq.exec.SegmentSource;
+import org.apache.druid.msq.exec.WorkerFailureListener;
 import org.apache.druid.msq.exec.WorkerManager;
 import org.apache.druid.msq.indexing.IndexerControllerContext;
 import org.apache.druid.msq.indexing.MSQSpec;
@@ -203,7 +204,8 @@ public class DartControllerContext implements 
ControllerContext
   public WorkerManager newWorkerManager(
       String queryId,
       MSQSpec querySpec,
-      ControllerQueryKernelConfig queryKernelConfig
+      ControllerQueryKernelConfig queryKernelConfig,
+      WorkerFailureListener workerFailureListener
   )
   {
     // We're ignoring WorkerFailureListener. Dart worker failures are routed 
into the controller by
diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartWorkerManager.java b/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartWorkerManager.java
index 5ce31286a5f..8d479671a94 100644
--- a/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartWorkerManager.java
+++ b/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartWorkerManager.java
@@ -24,6 +24,7 @@ import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.SettableFuture;
 import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
 import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
+import it.unimi.dsi.fastutil.ints.IntObjectPair;
 import it.unimi.dsi.fastutil.objects.Object2IntMap;
 import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
 import org.apache.druid.common.guava.FutureUtils;
@@ -34,10 +35,10 @@ import org.apache.druid.java.util.common.logger.Logger;
 import org.apache.druid.msq.dart.worker.DartWorkerClient;
 import org.apache.druid.msq.exec.ControllerContext;
 import org.apache.druid.msq.exec.WorkerClient;
-import org.apache.druid.msq.exec.WorkerFailureListener;
 import org.apache.druid.msq.exec.WorkerManager;
 import org.apache.druid.msq.exec.WorkerStats;
 import org.apache.druid.msq.indexing.WorkerCount;
+import org.apache.druid.msq.indexing.error.MSQFault;
 import org.apache.druid.utils.CloseableUtils;
 
 import java.util.ArrayList;
@@ -86,7 +87,7 @@ public class DartWorkerManager implements WorkerManager
   }
 
   @Override
-  public ListenableFuture<?> start(WorkerFailureListener workerFailureListener)
+  public ListenableFuture<?> start()
   {
     if (!state.compareAndSet(State.NEW, State.STARTED)) {
       throw new ISE("Cannot start from state[%s]", state.get());
@@ -96,7 +97,7 @@ public class DartWorkerManager implements WorkerManager
   }
 
   @Override
-  public void launchWorkersIfNeeded(int workerCount)
+  public Set<IntObjectPair<MSQFault>> launchWorkersIfNeeded(int workerCount)
   {
     // Nothing to do, just validate the count.
     if (workerCount > workerIds.size()) {
@@ -106,10 +107,12 @@ public class DartWorkerManager implements WorkerManager
           workerIds.size()
       );
     }
+    // Dart workers are always available, so no failures to report
+    return Collections.emptySet();
   }
 
   @Override
-  public void waitForWorkers(Set<Integer> workerNumbers)
+  public Set<IntObjectPair<MSQFault>> waitForWorkers(Set<Integer> 
workerNumbers)
   {
     // Nothing to wait for, just validate the numbers.
     for (final int workerNumber : workerNumbers) {
@@ -121,6 +124,8 @@ public class DartWorkerManager implements WorkerManager
         );
       }
     }
+    // Dart workers are always available, so no failures to report
+    return Collections.emptySet();
   }
 
   @Override
diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerContext.java b/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerContext.java
index 7c0f7c2300a..a25c944cf63 100644
--- a/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerContext.java
+++ b/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerContext.java
@@ -99,14 +99,16 @@ public interface ControllerContext
   /**
    * Provides services about workers: starting, canceling, obtaining status.
    *
-   * @param queryId           query ID
-   * @param querySpec         query spec
-   * @param queryKernelConfig config from {@link #queryKernelConfig(MSQSpec)}
+   * @param queryId               query ID
+   * @param querySpec             query spec
+   * @param queryKernelConfig     config from {@link 
#queryKernelConfig(MSQSpec)}
+   * @param workerFailureListener listener that receives callbacks when 
workers fail
    */
   WorkerManager newWorkerManager(
       String queryId,
       MSQSpec querySpec,
-      ControllerQueryKernelConfig queryKernelConfig
+      ControllerQueryKernelConfig queryKernelConfig,
+      WorkerFailureListener workerFailureListener
   );
 
   /**
diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java b/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java
index 7bd6d79b557..57b142303a3 100644
--- a/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java
+++ b/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java
@@ -33,6 +33,7 @@ import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
 import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
 import it.unimi.dsi.fastutil.ints.IntArraySet;
 import it.unimi.dsi.fastutil.ints.IntList;
+import it.unimi.dsi.fastutil.ints.IntObjectPair;
 import it.unimi.dsi.fastutil.ints.IntSet;
 import org.apache.druid.client.broker.BrokerClient;
 import org.apache.druid.common.guava.FutureUtils;
@@ -213,7 +214,6 @@ import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 import java.util.function.Function;
@@ -301,7 +301,6 @@ public class ControllerImpl implements Controller
   @Nullable
   private MSQSegmentReport segmentReport;
 
-  private final AtomicLong mainThreadId = new AtomicLong();
 
   public ControllerImpl(
       final LegacyMSQSpec querySpec,
@@ -394,7 +393,6 @@ public class ControllerImpl implements Controller
     final TaskState taskStateForReport;
     final MSQErrorReport errorForReport;
 
-    mainThreadId.set(Thread.currentThread().getId());
 
     try {
       // Planning-related: convert the native query from MSQSpec into a 
multi-stage QueryDefinition.
@@ -749,7 +747,8 @@ public class ControllerImpl implements Controller
     workerManager = context.newWorkerManager(
         context.queryId(),
         querySpec,
-        queryKernelConfig
+        queryKernelConfig,
+        getWorkerFailureListener()
     );
 
     if (queryKernelConfig.isFaultTolerant() && !(workerManager instanceof 
RetryCapableWorkerManager)) {
@@ -782,19 +781,14 @@ public class ControllerImpl implements Controller
     return queryDef;
   }
 
-  private WorkerFailureListener getWorkerFailureListener(ControllerQueryKernel 
controllerQueryKernel)
+  private WorkerFailureListener getWorkerFailureListener()
   {
     return (failedTask, fault) -> {
       throwIfNonRetriableFault(fault);
-      if (Thread.currentThread().getId() == mainThreadId.get()) {
-        // this is called from the main controller thread, so we can directly 
access the kernel.
-        addToRetryQueue(controllerQueryKernel, failedTask.getWorkerNumber(), 
fault);
-      } else {
-        // since this is called from the task launcher thread, we need to add 
it to the kernel manipulation queue so that only the controller thread can 
manipulate the kernel.
-        addToKernelManipulationQueue(kernel -> {
-          addToRetryQueue(kernel, failedTask.getWorkerNumber(), fault);
-        });
-      }
+      // since this is called from the task launcher thread, we need to add it 
to the kernel manipulation queue so that only the controller thread can 
manipulate the kernel.
+      addToKernelManipulationQueue(kernel -> {
+        addToRetryQueue(kernel, failedTask.getWorkerNumber(), fault);
+      });
     };
   }
 
@@ -1397,7 +1391,10 @@ public class ControllerImpl implements Controller
     final List<ListenableFuture<Void>> workerFutures = new 
ArrayList<>(workersCopy.size());
 
     try {
-      workerManager.waitForWorkers(workers);
+      Set<IntObjectPair<MSQFault>> workerFaultSet;
+      while (!(workerFaultSet = 
workerManager.waitForWorkers(workers)).isEmpty()) {
+        retryWorkersOrFailJob(queryKernel, workerFaultSet);
+      }
     }
     catch (InterruptedException e) {
       Thread.currentThread().interrupt();
@@ -1435,6 +1432,14 @@ public class ControllerImpl implements Controller
     }
   }
 
+  private void retryWorkersOrFailJob(ControllerQueryKernel queryKernel, 
Set<IntObjectPair<MSQFault>> workerFaultSet)
+  {
+    for (IntObjectPair<MSQFault> workerFault : workerFaultSet) {
+      throwIfNonRetriableFault(workerFault.right());
+      addToRetryQueue(queryKernel, workerFault.firstInt(), 
workerFault.right());
+    }
+  }
+
   private void startWorkForStage(
       final QueryDefinition queryDef,
       final ControllerQueryKernel queryKernel,
@@ -2234,7 +2239,7 @@ public class ControllerImpl implements Controller
     private final ControllerQueryKernel queryKernel;
 
     /**
-     * Return value of {@link WorkerManager#start(WorkerFailureListener)} )}. 
Set by {@link #startTaskLauncher()}.
+     * Return value of {@link WorkerManager#start()}. Set by {@link 
#startTaskLauncher()}.
      */
     private ListenableFuture<?> workerTaskLauncherFuture;
 
@@ -2380,7 +2385,10 @@ public class ControllerImpl implements Controller
       }
 
       // wait till the workers identified above are fully ready
-      workerManager.waitForWorkers(workersNeedToBeFullyStarted);
+      Set<IntObjectPair<MSQFault>> workerFaultSet;
+      while (!(workerFaultSet = 
workerManager.waitForWorkers(workersNeedToBeFullyStarted)).isEmpty()) {
+        retryWorkersOrFailJob(queryKernel, workerFaultSet);
+      }
 
       for (Map.Entry<StageId, Map<Integer, WorkOrder>> stageWorkOrders : 
stageWorkerOrders.entrySet()) {
         contactWorkersForStage(
@@ -2445,7 +2453,7 @@ public class ControllerImpl implements Controller
       // Start tasks.
       log.debug("Query [%s] starting task launcher.", queryDef.getQueryId());
 
-      workerTaskLauncherFuture = 
workerManager.start(getWorkerFailureListener(queryKernel));
+      workerTaskLauncherFuture = workerManager.start();
       closer.register(() -> workerManager.stop(true));
 
       workerTaskLauncherFuture.addListener(
@@ -2559,7 +2567,11 @@ public class ControllerImpl implements Controller
               stageDef.doesShuffle() ? stageDef.getShuffleSpec().kind() : 
"none"
           );
 
-          workerManager.launchWorkersIfNeeded(workerCount);
+
+          Set<IntObjectPair<MSQFault>> workerFaultSet;
+          while (!(workerFaultSet = 
workerManager.launchWorkersIfNeeded(workerCount)).isEmpty()) {
+            retryWorkersOrFailJob(queryKernel, workerFaultSet);
+          }
           stageRuntimesForLiveReports.put(stageId.getStageNumber(), new 
Interval(DateTimes.nowUtc(), DateTimes.MAX));
           startWorkForStage(queryDef, queryKernel, stageId.getStageNumber(), 
segmentsToGenerate);
         }
diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerManager.java b/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerManager.java
index db25641fb13..ff76867dd30 100644
--- a/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerManager.java
+++ b/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerManager.java
@@ -20,7 +20,9 @@
 package org.apache.druid.msq.exec;
 
 import com.google.common.util.concurrent.ListenableFuture;
+import it.unimi.dsi.fastutil.ints.IntObjectPair;
 import org.apache.druid.msq.indexing.WorkerCount;
+import org.apache.druid.msq.indexing.error.MSQFault;
 
 import java.util.List;
 import java.util.Map;
@@ -42,19 +44,23 @@ public interface WorkerManager
    * resolves to an exception if one of the workers fails without being 
explicitly canceled, or if something else
    * goes wrong.
    */
-  ListenableFuture<?> start(WorkerFailureListener workerFailureListener);
+  ListenableFuture<?> start();
 
   /**
    * Launch additional workers, if needed, to bring the number of running 
workers up to {@code workerCount}.
    * Blocks until the requested workers are launched. If enough workers are 
already running, this method does nothing.
+   *
+   * @return Worker numbers and the fault that caused them to fail. An empty 
set means all requested workers were launched successfully.
    */
-  void launchWorkersIfNeeded(int workerCount) throws InterruptedException;
+  Set<IntObjectPair<MSQFault>> launchWorkersIfNeeded(int workerCount) throws 
InterruptedException;
 
   /**
    * Blocks until workers with the provided worker numbers (indexes into 
{@link #getWorkerIds()} are ready to be
    * contacted for work.
+   *
+   * @return Worker numbers and the fault that caused them to fail. An empty 
set means all requested workers are ready to be contacted.
    */
-  void waitForWorkers(Set<Integer> workerNumbers) throws InterruptedException;
+  Set<IntObjectPair<MSQFault>> waitForWorkers(Set<Integer> workerNumbers) 
throws InterruptedException;
 
 
   /**
diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerControllerContext.java b/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerControllerContext.java
index 317b7d4602d..2244eb7f56f 100644
--- a/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerControllerContext.java
+++ b/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerControllerContext.java
@@ -38,6 +38,7 @@ import org.apache.druid.msq.exec.MSQMetriceEventBuilder;
 import org.apache.druid.msq.exec.MemoryIntrospector;
 import org.apache.druid.msq.exec.SegmentSource;
 import org.apache.druid.msq.exec.WorkerClient;
+import org.apache.druid.msq.exec.WorkerFailureListener;
 import org.apache.druid.msq.exec.WorkerManager;
 import org.apache.druid.msq.guice.MultiStageQuery;
 import 
org.apache.druid.msq.indexing.MSQWorkerTaskLauncher.MSQWorkerTaskLauncherConfig;
@@ -217,13 +218,15 @@ public class IndexerControllerContext implements 
ControllerContext
   public WorkerManager newWorkerManager(
       final String queryId,
       final MSQSpec querySpec,
-      final ControllerQueryKernelConfig queryKernelConfig
+      final ControllerQueryKernelConfig queryKernelConfig,
+      final WorkerFailureListener workerFailureListener
   )
   {
     return new MSQWorkerTaskLauncher(
         queryId,
         taskDataSource,
         overlordClient,
+        workerFailureListener,
         makeTaskContext(querySpec, queryKernelConfig, taskContext),
         // 10 minutes +- 2 minutes jitter
         TimeUnit.SECONDS.toMillis(600 + 
ThreadLocalRandom.current().nextInt(-4, 5) * 30L),
diff --git a/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncher.java b/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncher.java
index 5155eb25493..d2babe86101 100644
--- a/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncher.java
+++ b/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncher.java
@@ -19,17 +19,18 @@
 
 package org.apache.druid.msq.indexing;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableList;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.SettableFuture;
 import com.google.errorprone.annotations.concurrent.GuardedBy;
 import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
 import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
+import it.unimi.dsi.fastutil.ints.IntObjectPair;
 import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
 import it.unimi.dsi.fastutil.ints.IntSet;
 import org.apache.druid.client.indexing.TaskStatusResponse;
 import org.apache.druid.common.guava.FutureUtils;
-import org.apache.druid.error.DruidException;
 import org.apache.druid.indexer.TaskLocation;
 import org.apache.druid.indexer.TaskState;
 import org.apache.druid.indexer.TaskStatus;
@@ -54,6 +55,7 @@ import org.apache.druid.rpc.indexing.OverlordClient;
 
 import java.time.Duration;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -64,6 +66,7 @@ import java.util.OptionalLong;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutorService;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
@@ -149,37 +152,62 @@ public class MSQWorkerTaskLauncher implements 
RetryCapableWorkerManager
   private final Set<Integer> failedInactiveWorkers = 
ConcurrentHashMap.newKeySet();
 
   private final ConcurrentHashMap<Integer, List<String>> workerToTaskIds = new 
ConcurrentHashMap<>();
-  private final AtomicReference<WorkerFailureListener> 
workerFailureListenerRef = new AtomicReference<>();
+  private final WorkerFailureListener workerFailureListener;
 
   private final AtomicLong recentFullyStartedWorkerTimeInMillis = new 
AtomicLong(System.currentTimeMillis());
 
+  private final long taskIdsLockTimeout;
+
   public MSQWorkerTaskLauncher(
       final String controllerTaskId,
       final String dataSource,
       final OverlordClient overlordClient,
+      final WorkerFailureListener workerFailureListener,
       final Map<String, Object> taskContextOverrides,
       final long maxTaskStartDelayMillis,
       final MSQWorkerTaskLauncherConfig config
   )
+  {
+    this(
+        controllerTaskId,
+        dataSource,
+        overlordClient,
+        workerFailureListener,
+        taskContextOverrides,
+        maxTaskStartDelayMillis,
+        config,
+        TimeUnit.SECONDS.toMillis(60)
+    );
+  }
+
+  @VisibleForTesting
+  protected MSQWorkerTaskLauncher(
+      final String controllerTaskId,
+      final String dataSource,
+      final OverlordClient overlordClient,
+      final WorkerFailureListener workerFailureListener,
+      final Map<String, Object> taskContextOverrides,
+      final long maxTaskStartDelayMillis,
+      final MSQWorkerTaskLauncherConfig config,
+      final long taskIdsLockTimeout
+  )
   {
     this.controllerTaskId = controllerTaskId;
     this.dataSource = dataSource;
     this.overlordClient = overlordClient;
+    this.workerFailureListener = workerFailureListener;
     this.taskContextOverrides = taskContextOverrides;
     this.exec = Execs.singleThreaded(
         "multi-stage-query-task-launcher[" + 
StringUtils.encodeForFormat(controllerTaskId) + "]-%s"
     );
     this.maxTaskStartDelayMillis = maxTaskStartDelayMillis;
     this.config = config;
+    this.taskIdsLockTimeout = taskIdsLockTimeout;
   }
 
   @Override
-  public ListenableFuture<?> start(WorkerFailureListener workerFailureListener)
+  public ListenableFuture<?> start()
   {
-    if (!this.workerFailureListenerRef.compareAndSet(null, 
workerFailureListener)) {
-      throw DruidException.defensive("WorkerFailureListener already set for 
MSQWorkerTaskLauncher");
-    }
-
     if (state.compareAndSet(State.NEW, State.STARTED)) {
       exec.submit(() -> {
         try {
@@ -248,9 +276,11 @@ public class MSQWorkerTaskLauncher implements 
RetryCapableWorkerManager
   }
 
   @Override
-  public void launchWorkersIfNeeded(final int workerCount)
+  public Set<IntObjectPair<MSQFault>> launchWorkersIfNeeded(final int 
workerCount)
       throws InterruptedException
   {
+    Set<IntObjectPair<MSQFault>> failedWorkers = new HashSet<>();
+    
     synchronized (taskIds) {
       retryInactiveTasksIfNeeded(workerCount);
 
@@ -264,23 +294,25 @@ public class MSQWorkerTaskLauncher implements 
RetryCapableWorkerManager
           FutureUtils.getUnchecked(stopFuture, false);
           throw new ISE("Stopped");
         }
-        // add failed tasks to retry the queue
-        if (workerFailureListenerRef.get() != null) {
-          for (TaskTracker taskTracker : taskTrackers.values()) {
-            if (taskTracker.isRetrying()) {
-              invokeFailureListener(
-                  taskTracker,
-                  new WorkerFailedFault(
-                      taskTracker.msqWorkerTask.getId(),
-                      taskTracker.statusRef.get().getErrorMsg()
-                  )
-              );
-            }
+        
+        // Check for failed workers and collect them
+        for (TaskTracker taskTracker : taskTrackers.values()) {
+          if (taskTracker.isRetryCandidate()) {
+            failedWorkers.add(IntObjectPair.of(
+                taskTracker.getWorkerNumber(),
+                generateFailureFault(taskTracker.msqWorkerTask.getId(), 
taskTracker.statusRef.get())
+            ));
           }
         }
-        taskIds.wait();
+        if (!failedWorkers.isEmpty()) {
+          return failedWorkers;
+        }
+        taskIds.wait(taskIdsLockTimeout);
       }
     }
+    
+    // this should always be empty
+    return Collections.emptySet();
   }
 
   public void retryInactiveTasksIfNeeded(int taskCount)
@@ -319,9 +351,11 @@ public class MSQWorkerTaskLauncher implements 
RetryCapableWorkerManager
   }
 
   @Override
-  public void waitForWorkers(Set<Integer> workerNumbers)
+  public Set<IntObjectPair<MSQFault>> waitForWorkers(Set<Integer> 
workerNumbers)
       throws InterruptedException
   {
+    Set<IntObjectPair<MSQFault>> failedWorkers = new HashSet<>();
+    
     synchronized (taskIds) {
       while (!fullyStartedTasks.containsAll(workerNumbers)) {
         if (stopFuture.isDone() || stopFuture.isCancelled()) {
@@ -329,22 +363,29 @@ public class MSQWorkerTaskLauncher implements 
RetryCapableWorkerManager
           throw new ISE("Stopped");
         }
 
-        if (workerFailureListenerRef.get() != null) {
-          for (TaskTracker taskTracker : taskTrackers.values()) {
-            if (taskTracker.isRetrying() && 
workerNumbers.contains(taskTracker.workerNumber)) {
-              invokeFailureListener(taskTracker,
-                                    new WorkerFailedFault(
-                                        taskTracker.msqWorkerTask.getId(),
-                                        
taskTracker.statusRef.get().getErrorMsg()
-                                    )
-              );
-            }
+        // Check for failed workers in the requested set
+        for (TaskTracker taskTracker : taskTrackers.values()) {
+          if (taskTracker.isRetryCandidate() && 
workerNumbers.contains(taskTracker.getWorkerNumber())) {
+            failedWorkers.add(
+                IntObjectPair.of(
+                    taskTracker.getWorkerNumber(),
+                    generateFailureFault(taskTracker.msqWorkerTask.getId(), 
taskTracker.statusRef.get())
+                )
+            );
           }
         }
 
-        taskIds.wait();
+        // return if we found any failed workers. The caller needs to launch 
them and call waitForWorkers again
+        if (!failedWorkers.isEmpty()) {
+          return failedWorkers;
+        }
+
+        taskIds.wait(taskIdsLockTimeout);
+
       }
     }
+    // this should always be empty
+    return Collections.emptySet();
   }
 
   @Override
@@ -590,37 +631,45 @@ public class MSQWorkerTaskLauncher implements 
RetryCapableWorkerManager
     for (Map.Entry<String, TaskTracker> taskEntry : 
taskTrackersByWorkerNumber()) {
       final String taskId = taskEntry.getKey();
       final TaskTracker tracker = taskEntry.getValue();
-      if (tracker.isRetrying()) {
+      if (tracker.isRetryCandidate()) {
         continue;
       }
-
-      if (tracker.statusRef.get() == null) {
-        tracker.enableRetrying();
-        removeWorkerFromFullyStartedWorkers(tracker);
-        final String errorMessage = StringUtils.format("Task [%s] status 
missing", taskId);
-        log.info(errorMessage + ". Trying to relaunch the worker");
-        invokeFailureListener(tracker, UnknownFault.forMessage(errorMessage));
-
-      } else if (tracker.didRunTimeOut(maxTaskStartDelayMillis) && 
!canceledWorkerTasks.contains(taskId)) {
+      if (tracker.statusRef.get() != null
+          && tracker.didRunTimeOut(maxTaskStartDelayMillis)
+          && !canceledWorkerTasks.contains(taskId)) {
         removeWorkerFromFullyStartedWorkers(tracker);
         throw new MSQException(new TaskStartTimeoutFault(
             this.getWorkerCount().getPendingWorkerCount(),
             numTasks + 1,
             maxTaskStartDelayMillis
         ));
-      } else if (tracker.didFail() && !canceledWorkerTasks.contains(taskId)) {
-        tracker.enableRetrying();
-        removeWorkerFromFullyStartedWorkers(tracker);
-        TaskStatus taskStatus = tracker.statusRef.get();
-        log.info("Task[%s] failed because %s. Trying to relaunch the worker", 
taskId, taskStatus.getErrorMsg());
-        invokeFailureListener(tracker, new WorkerFailedFault(taskId, 
taskStatus.getErrorMsg()));
+      } else if (tracker.statusRef.get() == null || (tracker.didFail() && 
!canceledWorkerTasks.contains(taskId))) {
+        startRetryingTasksIfNeeded(tracker, taskId);
       }
     }
   }
 
+  private void startRetryingTasksIfNeeded(TaskTracker tracker, String taskId)
+  {
+    tracker.enableRetry();
+    removeWorkerFromFullyStartedWorkers(tracker);
+    MSQFault msqFault = generateFailureFault(taskId, tracker.statusRef.get());
+    log.info("Task[%s] failed caused of [%s]. Trying to relaunch the worker", 
taskId, msqFault);
+    invokeFailureListener(tracker, msqFault);
+  }
+
+  private MSQFault generateFailureFault(String taskId, TaskStatus taskStatus)
+  {
+    if (taskStatus == null) {
+      final String errorMessage = StringUtils.format("Task [%s] status 
missing", taskId);
+      return UnknownFault.forMessage(errorMessage);
+    } else {
+      return new WorkerFailedFault(taskId, taskStatus.getErrorMsg());
+    }
+  }
+
   private void invokeFailureListener(TaskTracker tracker, MSQFault msqFault)
   {
-    WorkerFailureListener workerFailureListener = 
workerFailureListenerRef.get();
     if (workerFailureListener != null) {
       workerFailureListener.onFailure(
           tracker.msqWorkerTask,
@@ -890,7 +939,7 @@ public class MSQWorkerTaskLauncher implements 
RetryCapableWorkerManager
     /**
      * Enables retrying for the task
      */
-    public void enableRetrying()
+    public void enableRetry()
     {
       isRetryingRef.set(true);
     }
@@ -898,7 +947,7 @@ public class MSQWorkerTaskLauncher implements 
RetryCapableWorkerManager
     /**
      * Checks is the task is retrying,
      */
-    public boolean isRetrying()
+    public boolean isRetryCandidate()
     {
       return isRetryingRef.get();
     }
diff --git a/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartWorkerManagerTest.java b/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartWorkerManagerTest.java
index 10aca82bf98..4dc928e66b2 100644
--- a/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartWorkerManagerTest.java
+++ b/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartWorkerManagerTest.java
@@ -116,9 +116,10 @@ public class DartWorkerManagerTest
   @Test
   public void test_launchWorkersIfNeeded()
   {
-    workerManager.launchWorkersIfNeeded(0); // Does nothing, less than 
WORKERS.size()
-    workerManager.launchWorkersIfNeeded(1); // Does nothing, less than 
WORKERS.size()
-    workerManager.launchWorkersIfNeeded(2); // Does nothing, equal to 
WORKERS.size()
+    // Test successful launch returns empty set
+    Assertions.assertTrue(workerManager.launchWorkersIfNeeded(0).isEmpty()); 
// Does nothing, less than WORKERS.size()
+    Assertions.assertTrue(workerManager.launchWorkersIfNeeded(1).isEmpty()); 
// Does nothing, less than WORKERS.size()
+    Assertions.assertTrue(workerManager.launchWorkersIfNeeded(2).isEmpty()); 
// Does nothing, equal to WORKERS.size()
     Assert.assertThrows(
         DruidException.class,
         () -> workerManager.launchWorkersIfNeeded(3)
@@ -129,7 +130,26 @@ public class DartWorkerManagerTest
   public void test_waitForWorkers()
   {
     workerManager.launchWorkersIfNeeded(2);
-    workerManager.waitForWorkers(IntSet.of(0, 1)); // Returns immediately
+    // Test successful wait returns empty set
+    Assertions.assertTrue(workerManager.waitForWorkers(IntSet.of(0, 
1)).isEmpty()); // Returns immediately
+  }
+
+  @Test
+  public void test_launchWorkersIfNeeded_returnsEmptySet()
+  {
+    // Dart workers are always available, so should always return empty set
+    var result = workerManager.launchWorkersIfNeeded(2);
+    Assertions.assertTrue(result.isEmpty());
+    Assertions.assertEquals(0, result.size());
+  }
+
+  @Test
+  public void test_waitForWorkers_returnsEmptySet()
+  {
+    // Dart workers are always available, so should always return empty set
+    var result = workerManager.waitForWorkers(IntSet.of(0, 1));
+    Assertions.assertTrue(result.isEmpty());
+    Assertions.assertEquals(0, result.size());
   }
 
   @Test
@@ -140,7 +160,7 @@ public class DartWorkerManagerTest
     Mockito.when(workerClient.stopWorker(WORKERS.get(1)))
            .thenReturn(Futures.immediateFuture(null));
 
-    final ListenableFuture<?> future = workerManager.start(null);
+    final ListenableFuture<?> future = workerManager.start();
     workerManager.stop(false);
 
     // Ensure the future from start() resolves.
@@ -155,7 +175,7 @@ public class DartWorkerManagerTest
     Mockito.when(workerClient.stopWorker(WORKERS.get(1)))
            .thenReturn(Futures.immediateFuture(null));
 
-    final ListenableFuture<?> future = workerManager.start(null);
+    final ListenableFuture<?> future = workerManager.start();
     workerManager.stop(true);
 
     // Ensure the future from start() resolves.
@@ -170,7 +190,7 @@ public class DartWorkerManagerTest
     Mockito.when(workerClient.stopWorker(WORKERS.get(1)))
            .thenReturn(Futures.immediateFuture(null));
 
-    final ListenableFuture<?> future = workerManager.start(null);
+    final ListenableFuture<?> future = workerManager.start();
     workerManager.stop(true);
 
     // Ensure the future from start() resolves.
diff --git a/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQTasksTest.java b/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQTasksTest.java
index 02e11444c2e..6291a9d7f35 100644
--- a/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQTasksTest.java
+++ b/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQTasksTest.java
@@ -230,13 +230,14 @@ public class MSQTasksTest
         CONTROLLER_ID,
         "foo",
         new TasksTestOverlordClient(numSlots),
+        null, // WorkerFailureListener
         ImmutableMap.of(),
         TimeUnit.SECONDS.toMillis(5),
         new MSQWorkerTaskLauncherConfig()
     );
 
     try {
-      msqWorkerTaskLauncher.start(null);
+      msqWorkerTaskLauncher.start();
       msqWorkerTaskLauncher.launchWorkersIfNeeded(numTasks);
       fail();
     }
diff --git a/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncherRetryTests.java b/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncherRetryTest.java
similarity index 60%
rename from multi-stage-query/src/test/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncherRetryTests.java
rename to multi-stage-query/src/test/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncherRetryTest.java
index f2e2477dd47..cecadc55fb9 100644
--- a/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncherRetryTests.java
+++ b/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncherRetryTest.java
@@ -20,27 +20,29 @@
 package org.apache.druid.msq.indexing;
 
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
-import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import it.unimi.dsi.fastutil.ints.IntObjectPair;
 import org.apache.druid.client.indexing.IndexingTotalWorkerCapacityInfo;
 import org.apache.druid.client.indexing.IndexingWorkerInfo;
 import org.apache.druid.client.indexing.TaskPayloadResponse;
 import org.apache.druid.client.indexing.TaskStatusResponse;
-import org.apache.druid.error.DruidException;
 import org.apache.druid.indexer.RunnerTaskState;
 import org.apache.druid.indexer.TaskLocation;
 import org.apache.druid.indexer.TaskState;
 import org.apache.druid.indexer.TaskStatus;
 import org.apache.druid.indexer.TaskStatusPlus;
 import org.apache.druid.indexer.report.TaskReport;
-import org.apache.druid.indexing.overlord.TaskQueue;
 import org.apache.druid.indexing.overlord.supervisor.SupervisorSpec;
 import org.apache.druid.indexing.overlord.supervisor.SupervisorStatus;
 import org.apache.druid.java.util.common.UOE;
 import org.apache.druid.java.util.common.parsers.CloseableIterator;
 import org.apache.druid.metadata.LockFilterPolicy;
 import org.apache.druid.msq.exec.MSQTasks;
+import org.apache.druid.msq.exec.WorkerFailureListener;
+import org.apache.druid.msq.indexing.error.MSQFault;
+import org.apache.druid.msq.indexing.error.WorkerFailedFault;
 import org.apache.druid.rpc.ServiceRetryPolicy;
 import org.apache.druid.rpc.UpdateResponse;
 import org.apache.druid.rpc.indexing.OverlordClient;
@@ -57,215 +59,178 @@ import org.junit.jupiter.api.Test;
 import javax.annotation.Nullable;
 import java.net.URI;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ConcurrentSkipListSet;
-import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
 
-public class MSQWorkerTaskLauncherRetryTests
+public class MSQWorkerTaskLauncherRetryTest
 {
-
   private static final TaskLocation RUNNING_TASK_LOCATION = new TaskLocation("host", 1, 2, null);
 
   @Test
-  public void mainThreadBlockingSimulationTest() throws Exception
+  public void testLaunchWorkersIfNeeded_returnsFailedWorkers() throws InterruptedException
   {
-    final ExecutorService executors = Executors.newSingleThreadExecutor(new ThreadFactoryBuilder().setDaemon(false)
-                                                                               
                   .setNameFormat(
-                                                                               
                       "Controller-simulator-%d")
-                                                                               
                   .build());
-
-    final TestOverlordClient overlordClient = new TestOverlordClient();
-    final int failedWorkerNumber = 2;
-    final CountDownLatch workerFailedLatch = new CountDownLatch(1);
-    final CountDownLatch workerStartedLatch = new CountDownLatch(1);
-    overlordClient.addFailedWorker(2);
-    overlordClient.addUnknownLocationWorker(1);
-
-    final MSQWorkerTaskLauncher msqWorkerTaskLauncher = new MSQWorkerTaskLauncher(
+    TestOverlordClient overlordClient = new TestOverlordClient();
+    overlordClient.addFailedWorker(1);
+
+    AtomicInteger failureCallbackCount = new AtomicInteger(0);
+    WorkerFailureListener workerFailureListener = (task, fault) -> {
+      failureCallbackCount.incrementAndGet();
+      Assertions.assertEquals(1, task.getWorkerNumber());
+      Assertions.assertTrue(fault instanceof WorkerFailedFault);
+    };
+
+    MSQWorkerTaskLauncher launcher = new MSQWorkerTaskLauncher(
         "controller-id",
         "foo",
         overlordClient,
+        workerFailureListener,
         ImmutableMap.of(),
         TimeUnit.SECONDS.toMillis(5),
-        new MSQWorkerTaskLauncher.MSQWorkerTaskLauncherConfig()
+        new MSQWorkerTaskLauncher.MSQWorkerTaskLauncherConfig(),
+        2
     );
 
-    try {
-      final long workerThreadId = Thread.currentThread().getId();
+    launcher.start();
 
-      startTaskLauncher(
-          msqWorkerTaskLauncher,
-          failedWorkerNumber,
-          workerFailedLatch,
-          overlordClient,
-          workerThreadId,
-          workerStartedLatch
-      );
+    // Should return failed workers in the set
+    Set<IntObjectPair<MSQFault>> failedWorkers = launcher.launchWorkersIfNeeded(2);
 
-      MockConsumer mockConsumer = new MockConsumer(
-          msqWorkerTaskLauncher,
-          3,
-          workerStartedLatch
-      );
-      Future<?> futures = executors.submit(mockConsumer);
-      // hook called but worker not queued for relaunch.
-      workerFailedLatch.await();
-      Assertions.assertEquals(1, workerStartedLatch.getCount());
-      // we would need to call hooks to allow the main thread to proceed since we are using an exec service to so the thread id's would not match.
-      enableWorkerRelaunch(overlordClient, failedWorkerNumber, msqWorkerTaskLauncher, workerStartedLatch);
-      // future should be completed in 5 seconds else throw an exception.
-      Assertions.assertNull(futures.get(5, TimeUnit.SECONDS));
-    }
-    finally {
-      msqWorkerTaskLauncher.stop(true);
-      executors.shutdownNow();
+    // The method should not invoke the failure listener directly anymore, 
+    // but should return the failed workers
+    Assertions.assertFalse(failedWorkers.isEmpty());
+
+    // Check that the failed worker is in the returned set
+    boolean foundFailedWorker = false;
+    for (IntObjectPair<MSQFault> failedWorker : failedWorkers) {
+      if (failedWorker.leftInt() == 1) {
+        foundFailedWorker = true;
+        Assertions.assertTrue(failedWorker.right() instanceof WorkerFailedFault);
+      }
     }
-  }
+    Assertions.assertTrue(foundFailedWorker, "Failed worker should be in the returned set");
 
-  private static void enableWorkerRelaunch(
-      TestOverlordClient overlordClient,
-      int failedWorkerNumber,
-      MSQWorkerTaskLauncher msqWorkerTaskLauncher,
-      CountDownLatch workerStartedLatch
-  )
-  {
-    overlordClient.removeUnknownLocationWorker(1);
-    overlordClient.removefailedWorker(failedWorkerNumber);
-    msqWorkerTaskLauncher.submitForRelaunch(failedWorkerNumber);
-    workerStartedLatch.countDown();
+    launcher.stop(true);
   }
 
-  private static void startTaskLauncher(
-      MSQWorkerTaskLauncher msqWorkerTaskLauncher,
-      int failedWorkerNumber,
-      CountDownLatch workerFailedLatch,
-      TestOverlordClient overlordClient,
-      long workerThreadId,
-      CountDownLatch workerStartedLatch
-  )
+  @Test
+  public void testWaitForWorkers_returnsFailedWorkers() throws InterruptedException
   {
-    msqWorkerTaskLauncher.start((task, fault) -> {
-      Assertions.assertEquals(failedWorkerNumber, task.getWorkerNumber());
-      workerFailedLatch.countDown();
-      if (workerThreadId == Thread.currentThread().getId()) {
-        // If the worker thread is the same as the main thread, we can directly relaunch the worker.
-        enableWorkerRelaunch(overlordClient, failedWorkerNumber, msqWorkerTaskLauncher, workerStartedLatch);
-      }
-    });
-  }
+    TestOverlordClient overlordClient = new TestOverlordClient();
+    overlordClient.addFailedWorker(0);
 
+    AtomicInteger failureCallbackCount = new AtomicInteger(0);
+    WorkerFailureListener workerFailureListener = (task, fault) -> {
+      failureCallbackCount.incrementAndGet();
+    };
 
-  @Test
-  public void mainThreadNonBlockingSimulationTest() throws Exception
-  {
-    final TestOverlordClient overlordClient = new TestOverlordClient();
-    final int failedWorkerNumber = 2;
-    final CountDownLatch workerFailedLatch = new CountDownLatch(1);
-    final CountDownLatch workerStartedLatch = new CountDownLatch(1);
-    overlordClient.addFailedWorker(2);
-    overlordClient.addUnknownLocationWorker(1);
-
-    final MSQWorkerTaskLauncher msqWorkerTaskLauncher = new MSQWorkerTaskLauncher(
+    MSQWorkerTaskLauncher launcher = new MSQWorkerTaskLauncher(
         "controller-id",
         "foo",
         overlordClient,
+        workerFailureListener,
         ImmutableMap.of(),
         TimeUnit.SECONDS.toMillis(5),
-        new MSQWorkerTaskLauncher.MSQWorkerTaskLauncherConfig()
+        new MSQWorkerTaskLauncher.MSQWorkerTaskLauncherConfig(),
+        2
     );
 
-    try {
-      final long workerThreadId = Thread.currentThread().getId();
+    launcher.start();
 
-      startTaskLauncher(
-          msqWorkerTaskLauncher,
-          failedWorkerNumber,
-          workerFailedLatch,
-          overlordClient,
-          workerThreadId,
-          workerStartedLatch
-      );
+    // Launch workers first
+    launcher.launchWorkersIfNeeded(1);
 
+    // Should return failed workers in the set when waiting for specific workers
+    Set<IntObjectPair<MSQFault>> failedWorkers = launcher.waitForWorkers(ImmutableSet.of(0));
 
-      MockConsumer mockConsumer = new MockConsumer(
-          msqWorkerTaskLauncher,
-          3,
-          workerStartedLatch
-      );
-      mockConsumer.run();
-      // failed latch  called
-      workerFailedLatch.await();
-      // worker started.
-      workerStartedLatch.await();
-    }
-    finally {
-      msqWorkerTaskLauncher.stop(true);
+    // The method should not invoke the failure listener directly anymore,
+    // but should return the failed workers
+    Assertions.assertFalse(failedWorkers.isEmpty());
+
+    // Check that the failed worker is in the returned set
+    boolean foundFailedWorker = false;
+    for (IntObjectPair<MSQFault> failedWorker : failedWorkers) {
+      if (failedWorker.leftInt() == 0) {
+        foundFailedWorker = true;
+        Assertions.assertTrue(failedWorker.right() instanceof WorkerFailedFault);
+      }
     }
-  }
+    Assertions.assertTrue(foundFailedWorker, "Failed worker should be in the returned set");
 
+    launcher.stop(true);
+  }
 
-  private static class MockConsumer implements Runnable
+  @Test
+  public void testLaunchWorkersIfNeeded_returnsEmptySet_whenNoFailures() throws InterruptedException
   {
+    TestOverlordClient overlordClient = new TestOverlordClient();
+    // Don't add any failed workers
 
-    private final MSQWorkerTaskLauncher msqWorkerTaskLauncher;
-    private final int taskCount;
-    private final CountDownLatch workerStartedLatch;
+    WorkerFailureListener workerFailureListener = (task, fault) -> {
+      Assertions.fail("Should not call failure listener when no workers fail");
+    };
 
-    public MockConsumer(
-        MSQWorkerTaskLauncher msqWorkerTaskLauncher,
-        int tasksCount,
-        CountDownLatch workerStartedLatch
-    )
-    {
-      this.msqWorkerTaskLauncher = msqWorkerTaskLauncher;
-      this.taskCount = tasksCount;
-      this.workerStartedLatch = workerStartedLatch;
-    }
+    MSQWorkerTaskLauncher launcher = new MSQWorkerTaskLauncher(
+        "controller-id",
+        "foo",
+        overlordClient,
+        workerFailureListener,
+        ImmutableMap.of(),
+        TimeUnit.SECONDS.toMillis(5),
+        new MSQWorkerTaskLauncher.MSQWorkerTaskLauncherConfig(),
+        2
+    );
 
+    launcher.start();
 
-    @Override
-    public void run()
-    {
-      // start stages
-      try {
-        msqWorkerTaskLauncher.launchWorkersIfNeeded(taskCount);
-        workerStartedLatch.await();
-      }
-      catch (InterruptedException e) {
-        throw new RuntimeException(e);
-      }
-      Set<Integer> workerNumbers = new HashSet<>();
-      for (int i = 0; i < taskCount; i++) {
-        workerNumbers.add(i);
-      }
+    // Should return empty set when no workers fail
+    Set<IntObjectPair<MSQFault>> failedWorkers = launcher.launchWorkersIfNeeded(1);
+    Assertions.assertTrue(failedWorkers.isEmpty());
 
-      // submit work worders
-      try {
-        msqWorkerTaskLauncher.waitForWorkers(workerNumbers);
-      }
-      catch (InterruptedException e) {
-        throw new RuntimeException(e);
-      }
-    }
+    launcher.stop(true);
   }
 
+  @Test
+  public void testWaitForWorkers_returnsEmptySet_whenNoFailures() throws InterruptedException
+  {
+    TestOverlordClient overlordClient = new TestOverlordClient();
+    // Don't add any failed workers
+
+    WorkerFailureListener workerFailureListener = (task, fault) -> {
+      Assertions.fail("Should not call failure listener when no workers fail");
+    };
+
+    MSQWorkerTaskLauncher launcher = new MSQWorkerTaskLauncher(
+        "controller-id",
+        "foo",
+        overlordClient,
+        workerFailureListener,
+        ImmutableMap.of(),
+        TimeUnit.SECONDS.toMillis(5),
+        new MSQWorkerTaskLauncher.MSQWorkerTaskLauncherConfig(),
+        2
+    );
+
+    launcher.start();
+
+    // Launch workers first
+    launcher.launchWorkersIfNeeded(1);
+
+    // Should return empty set when no workers fail
+    Set<IntObjectPair<MSQFault>> failedWorkers = launcher.waitForWorkers(ImmutableSet.of(0));
+    Assertions.assertTrue(failedWorkers.isEmpty());
+
+    launcher.stop(true);
+  }
 
   private static class TestOverlordClient implements OverlordClient
   {
     private final ConcurrentSkipListSet<Integer> unknownLocationWorkers = new ConcurrentSkipListSet<>();
     private final ConcurrentSkipListSet<Integer> failedWorkers = new ConcurrentSkipListSet<>();
 
-    public TestOverlordClient()
-    {
-    }
-
     @Override
     public ListenableFuture<URI> findCurrentLeader()
     {
@@ -281,11 +246,7 @@ public class MSQWorkerTaskLauncherRetryTests
     @Override
     public ListenableFuture<Void> cancelTask(String taskId)
     {
-      if (failedWorkers.contains(MSQTasks.workerFromTaskId(taskId))) {
-        return Futures.immediateFuture(null);
-      } else {
-        throw DruidException.defensive("Task %s should not be cancelled", taskId);
-      }
+      return Futures.immediateFuture(null);
     }
 
     @Override
@@ -305,7 +266,7 @@ public class MSQWorkerTaskLauncherRetryTests
       for (String taskId : taskIds) {
         int workerNumber = MSQTasks.workerFromTaskId(taskId);
         if (failedWorkers.contains(workerNumber)) {
-          taskStatusMap.put(taskId, TaskStatus.failure(taskId, TaskQueue.FAILED_TO_RUN_TASK_SEE_OVERLORD_MSG));
+          taskStatusMap.put(taskId, TaskStatus.failure(taskId, "Task failed"));
         } else if (unknownLocationWorkers.contains(workerNumber)) {
           taskStatusMap.put(taskId, TaskStatus.running(taskId).withLocation(TaskLocation.unknown()));
         } else {
diff --git a/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncherTest.java b/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncherTest.java
index 3cd1b5b98f8..2e5c4859904 100644
--- a/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncherTest.java
+++ b/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncherTest.java
@@ -21,7 +21,6 @@ package org.apache.druid.msq.indexing;
 
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
-import org.apache.druid.error.DruidException;
 import org.apache.druid.msq.exec.WorkerFailureListener;
 import org.apache.druid.msq.indexing.MSQWorkerTaskLauncher.MSQWorkerTaskLauncherConfig;
 import org.apache.druid.rpc.indexing.OverlordClient;
@@ -44,6 +43,7 @@ public class MSQWorkerTaskLauncherTest
         "controller-id",
         "foo",
         Mockito.mock(OverlordClient.class),
+        getWorkerFailureListener(),
         ImmutableMap.of(),
         TimeUnit.SECONDS.toMillis(5),
         new MSQWorkerTaskLauncherConfig()
@@ -59,11 +59,21 @@ public class MSQWorkerTaskLauncherTest
     Assert.assertEquals(target.getWorkersToRelaunch(), ImmutableSet.of(1));
   }
 
+  
   @Test
-  public void testMultipleWorkerFailureRegistration()
+  public void testLaunchWorkersIfNeeded_returnsEmptySet_whenNoFailures() throws InterruptedException
   {
-    target.start(getWorkerFailureListener());
-    Assert.assertThrows(DruidException.class, () -> target.start(getWorkerFailureListener()));
+    // Test that launchWorkersIfNeeded returns empty set when no workers fail
+    var result = target.launchWorkersIfNeeded(0);
+    Assert.assertTrue(result.isEmpty());
+  }
+  
+  @Test
+  public void testWaitForWorkers_returnsEmptySet_whenNoFailures() throws InterruptedException
+  {
+    // Test that waitForWorkers returns empty set when no workers fail
+    var result = target.waitForWorkers(ImmutableSet.of());
+    Assert.assertTrue(result.isEmpty());
   }
 
   private static WorkerFailureListener getWorkerFailureListener()
diff --git a/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java b/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java
index dd618faafd6..6587703dfef 100644
--- a/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java
+++ b/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java
@@ -53,6 +53,7 @@ import org.apache.druid.msq.exec.MSQMetriceEventBuilder;
 import org.apache.druid.msq.exec.SegmentSource;
 import org.apache.druid.msq.exec.Worker;
 import org.apache.druid.msq.exec.WorkerClient;
+import org.apache.druid.msq.exec.WorkerFailureListener;
 import org.apache.druid.msq.exec.WorkerImpl;
 import org.apache.druid.msq.exec.WorkerManager;
 import org.apache.druid.msq.exec.WorkerMemoryParameters;
@@ -346,7 +347,8 @@ public class MSQTestControllerContext implements ControllerContext, DartControll
   public WorkerManager newWorkerManager(
       String queryId,
       MSQSpec querySpec,
-      ControllerQueryKernelConfig queryKernelConfig
+      ControllerQueryKernelConfig queryKernelConfig,
+      WorkerFailureListener workerFailureListener
   )
   {
     MSQWorkerTaskLauncherConfig taskLauncherConfig = new MSQWorkerTaskLauncherConfig();
@@ -358,6 +360,7 @@ public class MSQTestControllerContext implements ControllerContext, DartControll
         controller.queryId(),
         "test-datasource",
         overlordClient,
+        workerFailureListener,
         IndexerControllerContext.makeTaskContext(querySpec, queryKernelConfig, ImmutableMap.of()),
         0,
         taskLauncherConfig


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to