This is an automated email from the ASF dual-hosted git repository.

manishswaminathan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 48283e9580 Eliminate duplicate cancel attempts in PerQueryCPUMemAccountant (#16299)
48283e9580 is described below

commit 48283e9580b4633d68ffabf2dafbf66a9ef5d6d0
Author: Rajat Venkatesh <[email protected]>
AuthorDate: Fri Jul 11 18:04:27 2025 +0530

    Eliminate duplicate cancel attempts in PerQueryCPUMemAccountant (#16299)
    
    * Add basic 1 query tests
    
    * Add more tests
    
    * Add ability to remember cancel queries.
    
    * Clean up if conditions in killMostExpensiveQuery
    
    * Fix test failures.
    
    * Address review comments.
---
 .../PerQueryCPUMemAccountantFactory.java           | 266 ++++++++++----------
 .../PerQueryCPUMemAccountCancelTest.java           | 189 +++++++++++++++
 .../accounting/PerQueryCPUMemAccountantTest.java   | 268 +++++++++++++++++++++
 ...actoryTest.java => QueryMonitorConfigTest.java} |   2 +-
 .../core/accounting/TestResourceAccountant.java    |  95 ++++++++
 5 files changed, 697 insertions(+), 123 deletions(-)

diff --git a/pinot-core/src/main/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountantFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountantFactory.java
index 6a0180720f..704301d767 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountantFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountantFactory.java
@@ -115,7 +115,7 @@ public class PerQueryCPUMemAccountantFactory implements 
ThreadAccountantFactory
         = ThreadLocal.withInitial(() -> {
           CPUMemThreadLevelAccountingObjects.ThreadEntry ret =
               new CPUMemThreadLevelAccountingObjects.ThreadEntry();
-          _threadEntriesMap.put(Thread.currentThread(), ret);
+          addThreadEntry(Thread.currentThread(), ret);
           LOGGER.debug("Adding thread to _threadLocalEntry: {}", 
Thread.currentThread().getName());
           return ret;
         }
@@ -132,6 +132,8 @@ public class PerQueryCPUMemAccountantFactory implements 
ThreadAccountantFactory
 
     protected final Set<String> _inactiveQuery;
 
+    protected Set<String> _cancelSentQueries;
+
     // the periodical task that aggregates and preempts queries
     protected final WatcherTask _watcherTask;
 
@@ -140,6 +142,20 @@ public class PerQueryCPUMemAccountantFactory implements 
ThreadAccountantFactory
 
     protected final InstanceType _instanceType;
 
+    protected PerQueryCPUMemResourceUsageAccountant(PinotConfiguration config, 
boolean isThreadCPUSamplingEnabled,
+        boolean isThreadMemorySamplingEnabled, boolean 
isThreadSamplingEnabledForMSE, Set<String> inactiveQuery,
+        String instanceId, InstanceType instanceType) {
+      _config = config;
+      _isThreadCPUSamplingEnabled = isThreadCPUSamplingEnabled;
+      _isThreadMemorySamplingEnabled = isThreadMemorySamplingEnabled;
+      _isThreadSamplingEnabledForMSE = isThreadSamplingEnabledForMSE;
+      _inactiveQuery = inactiveQuery;
+      _instanceId = instanceId;
+      _instanceType = instanceType;
+      _cancelSentQueries = new HashSet<>();
+      _watcherTask = createWatcherTask();
+    }
+
     public PerQueryCPUMemResourceUsageAccountant(PinotConfiguration config, 
String instanceId,
         InstanceType instanceType) {
 
@@ -174,8 +190,12 @@ public class PerQueryCPUMemAccountantFactory implements 
ThreadAccountantFactory
 
       // task/query tracking
       _inactiveQuery = new HashSet<>();
+      _cancelSentQueries = new HashSet<>();
+      _watcherTask = createWatcherTask();
+    }
 
-      _watcherTask = new WatcherTask();
+    protected WatcherTask createWatcherTask() {
+      return new WatcherTask();
     }
 
     @Override
@@ -185,26 +205,19 @@ public class PerQueryCPUMemAccountantFactory implements 
ThreadAccountantFactory
 
     /**
      * This function aggregates resource usage from all active threads and 
groups by queryId.
-     * It is inspired by {@link 
PerQueryCPUMemResourceUsageAccountant::aggregate}. The major difference is that
-     * it only reads from thread entries and does not update them.
      * @return A map of query id, QueryResourceTracker.
      */
     @Override
     public Map<String, ? extends QueryResourceTracker> getQueryResources() {
+      return getQueryResourcesImpl();
+    }
+
+    protected Map<String, AggregatedStats> getQueryResourcesImpl() {
       HashMap<String, AggregatedStats> ret = new HashMap<>();
 
       // for each {pqr, pqw}
-      for (Map.Entry<Thread, CPUMemThreadLevelAccountingObjects.ThreadEntry> 
entry : _threadEntriesMap.entrySet()) {
-        // sample current usage
-        CPUMemThreadLevelAccountingObjects.ThreadEntry threadEntry = 
entry.getValue();
-        long currentCPUSample = _isThreadCPUSamplingEnabled
-            ? threadEntry._currentThreadCPUTimeSampleMS : 0;
-        long currentMemSample = _isThreadMemorySamplingEnabled
-            ? threadEntry._currentThreadMemoryAllocationSampleBytes : 0;
-        // sample current running task status
+      for (CPUMemThreadLevelAccountingObjects.ThreadEntry threadEntry : 
_threadEntriesMap.values()) {
         CPUMemThreadLevelAccountingObjects.TaskEntry currentTaskStatus = 
threadEntry.getCurrentThreadTaskStatus();
-        Thread thread = entry.getKey();
-        LOGGER.trace("tid: {}, task: {}", thread.getId(), currentTaskStatus);
 
         // if current thread is not idle
         if (currentTaskStatus != null) {
@@ -213,10 +226,14 @@ public class PerQueryCPUMemAccountantFactory implements 
ThreadAccountantFactory
           if (queryId != null) {
             Thread anchorThread = currentTaskStatus.getAnchorThread();
             boolean isAnchorThread = currentTaskStatus.isAnchorThread();
+            long currentCPUSample = _isThreadCPUSamplingEnabled ? 
threadEntry._currentThreadCPUTimeSampleMS : 0;
+            long currentMemSample =
+                _isThreadMemorySamplingEnabled ? 
threadEntry._currentThreadMemoryAllocationSampleBytes : 0;
             ret.compute(queryId,
                 (k, v) -> v == null ? new AggregatedStats(currentCPUSample, 
currentMemSample, anchorThread,
                     isAnchorThread, threadEntry._errorStatus, queryId)
-                    : v.merge(currentCPUSample, currentMemSample, 
isAnchorThread, threadEntry._errorStatus));
+                    : v.merge(currentCPUSample, currentMemSample, 
isAnchorThread,
+                        threadEntry._errorStatus));
           }
         }
       }
@@ -368,6 +385,10 @@ public class PerQueryCPUMemAccountantFactory implements 
ThreadAccountantFactory
       return _threadLocalEntry.get();
     }
 
+    public void addThreadEntry(Thread thread, 
CPUMemThreadLevelAccountingObjects.ThreadEntry threadEntry) {
+      _threadEntriesMap.put(thread, threadEntry);
+    }
+
     /**
      * clears thread accounting info once a runner/worker thread has finished 
a particular run
      */
@@ -406,6 +427,7 @@ public class PerQueryCPUMemAccountantFactory implements 
ThreadAccountantFactory
           _finishedTaskMemStatsAggregator.remove(inactiveQueryId);
           _concurrentTaskMemStatsAggregator.remove(inactiveQueryId);
         }
+        _cancelSentQueries.remove(inactiveQueryId);
       }
       _inactiveQuery.clear();
       if (_isThreadCPUSamplingEnabled) {
@@ -418,29 +440,36 @@ public class PerQueryCPUMemAccountantFactory implements 
ThreadAccountantFactory
       }
     }
 
+    public Set<String> getInactiveQueries() {
+      return Collections.unmodifiableSet(_inactiveQuery);
+    }
+
+    public Set<String> getCancelSentQueries() {
+      return Collections.unmodifiableSet(_cancelSentQueries);
+    }
+
     /**
-     * aggregated the stats if the query killing process is triggered
-     * @param isTriggered if the query killing process is triggered
-     * @return aggregated stats of active queries if triggered
+     * This function moves finished tasks through 2 stages.
+     * Initially, task metadata is stored in 
threadEntry._currentThreadTaskStatus etc.
+     * In the first step, it moves this metadata to 
threadEntry._previousThreadTaskStatus etc.
+     *
+     * At the same time, the function moves the information in 
_threadEntry._previousThreadTaskStatus into the second
+     * state. It is aggregated into _finishedTaskCPUStatsAggregator and 
_finishedTaskMemStatsAggregator.
+     *
+     * Finally it cleans up _inactiveQueries AND _threadEntriesMap.
      */
-    public Map<String, AggregatedStats> aggregate(boolean isTriggered) {
-      HashMap<String, AggregatedStats> ret = null;
-      if (isTriggered) {
-        ret = new HashMap<>();
-      }
+    public void reapFinishedTasks() {
+      Set<String> cancellingQueries = new HashSet<>();
 
       // for each {pqr, pqw}
       for (Map.Entry<Thread, CPUMemThreadLevelAccountingObjects.ThreadEntry> 
entry : _threadEntriesMap.entrySet()) {
         // sample current usage
         CPUMemThreadLevelAccountingObjects.ThreadEntry threadEntry = 
entry.getValue();
-
         long currentCPUSample = _isThreadCPUSamplingEnabled ? 
threadEntry._currentThreadCPUTimeSampleMS : 0;
         long currentMemSample =
             _isThreadMemorySamplingEnabled ? 
threadEntry._currentThreadMemoryAllocationSampleBytes : 0;
         // sample current running task status
         CPUMemThreadLevelAccountingObjects.TaskEntry currentTaskStatus = 
threadEntry.getCurrentThreadTaskStatus();
-        Thread thread = entry.getKey();
-        LOGGER.trace("tid: {}, task: {}", thread.getId(), currentTaskStatus);
 
         // get last task on the thread
         CPUMemThreadLevelAccountingObjects.TaskEntry lastQueryTask = 
threadEntry._previousThreadTaskStatus;
@@ -477,49 +506,33 @@ public class PerQueryCPUMemAccountantFactory implements 
ThreadAccountantFactory
           String queryId = currentTaskStatus.getQueryId();
           // update inactive queries for cleanInactive()
           _inactiveQuery.remove(queryId);
-          // if triggered, accumulate active query task stats
-          if (isTriggered) {
-            Thread anchorThread = currentTaskStatus.getAnchorThread();
-            boolean isAnchorThread = currentTaskStatus.isAnchorThread();
-            ret.compute(queryId, (k, v) -> v == null
-                ? new AggregatedStats(currentCPUSample, currentMemSample, 
anchorThread,
-                isAnchorThread, threadEntry._errorStatus, queryId)
-                : v.merge(currentCPUSample, currentMemSample, isAnchorThread, 
threadEntry._errorStatus));
+          // If query is in cancelling set, retain it.
+          if (_cancelSentQueries.contains(queryId)) {
+            cancellingQueries.add(queryId);
           }
         }
 
+        Thread thread = entry.getKey();
         if (!thread.isAlive()) {
           _threadEntriesMap.remove(thread);
           LOGGER.debug("Removing thread from _threadLocalEntry: {}", 
thread.getName());
         }
       }
-
-      // if triggered, accumulate stats of finished tasks of each active query
-      if (isTriggered) {
-        for (Map.Entry<String, AggregatedStats> queryIdResult : 
ret.entrySet()) {
-          String activeQueryId = queryIdResult.getKey();
-          long accumulatedCPUValue = _isThreadCPUSamplingEnabled
-              ? _finishedTaskCPUStatsAggregator.getOrDefault(activeQueryId, 
0L) : 0;
-          long concurrentCPUValue = _isThreadCPUSamplingEnabled
-              ? _concurrentTaskCPUStatsAggregator.getOrDefault(activeQueryId, 
0L) : 0;
-          long accumulatedMemValue = _isThreadMemorySamplingEnabled
-              ? _finishedTaskMemStatsAggregator.getOrDefault(activeQueryId, 
0L) : 0;
-          long concurrentMemValue = _isThreadMemorySamplingEnabled
-              ? _concurrentTaskMemStatsAggregator.getOrDefault(activeQueryId, 
0L) : 0;
-          queryIdResult.getValue().merge(accumulatedCPUValue + 
concurrentCPUValue,
-              accumulatedMemValue + concurrentMemValue, false, null);
-        }
-      }
-      return ret;
+      _cancelSentQueries = cancellingQueries;
     }
 
-    public void postAggregation(Map<String, AggregatedStats> 
aggregatedUsagePerActiveQuery) {
+    protected void postAggregation(Map<String, AggregatedStats> 
aggregatedUsagePerActiveQuery) {
     }
 
     protected void logQueryResourceUsage(Map<String, ? extends 
QueryResourceTracker> aggregatedUsagePerActiveQuery) {
       LOGGER.warn("Query aggregation results {} for the previous kill.", 
aggregatedUsagePerActiveQuery);
     }
 
+    protected void cancelQuery(AggregatedStats queryResourceTracker) {
+      _cancelSentQueries.add(queryResourceTracker.getQueryId());
+      queryResourceTracker.getAnchorThread().interrupt();
+    }
+
     protected void logTerminatedQuery(QueryResourceTracker 
queryResourceTracker, long totalHeapMemoryUsage) {
       LOGGER.warn("Query {} terminated. Memory Usage: {}. Cpu Usage: {}. Total 
Heap Usage: {}",
           queryResourceTracker.getQueryId(), 
queryResourceTracker.getAllocatedBytes(),
@@ -619,7 +632,7 @@ public class PerQueryCPUMemAccountantFactory implements 
ThreadAccountantFactory
       protected int _sleepTime;
       protected int _numQueriesKilledConsecutively = 0;
       protected Map<String, AggregatedStats> _aggregatedUsagePerActiveQuery;
-      private TriggeringLevel _triggeringLevel;
+      protected TriggeringLevel _triggeringLevel;
 
       // metrics class
       private final AbstractMetrics _metrics;
@@ -715,46 +728,56 @@ public class PerQueryCPUMemAccountantFactory implements 
ThreadAccountantFactory
       @Override
       public void run() {
         while (true) {
-          QueryMonitorConfig config = _queryMonitorConfig.get();
-
-          LOGGER.debug("Running timed task for PerQueryCPUMemAccountant.");
-          _triggeringLevel = TriggeringLevel.Normal;
-          _sleepTime = config.getNormalSleepTime();
-          _aggregatedUsagePerActiveQuery = null;
           try {
-            // Get the metrics used for triggering the kill
-            collectTriggerMetrics();
-            // Prioritize the panic check, kill ALL QUERIES immediately if 
triggered
-            if (outOfMemoryPanicTrigger()) {
-              continue;
-            }
-            // Check for other triggers
-            evalTriggers();
-            // Refresh thread usage and aggregate to per query usage if 
triggered
-            _aggregatedUsagePerActiveQuery = 
aggregate(_triggeringLevel.ordinal() > TriggeringLevel.Normal.ordinal());
-            // post aggregation function
-            postAggregation(_aggregatedUsagePerActiveQuery);
-            // Act on one triggered actions
-            triggeredActions();
-          } catch (Exception e) {
-            LOGGER.error("Caught exception while executing stats aggregation 
and query kill", e);
+            runOnce();
           } finally {
-            LOGGER.debug(_aggregatedUsagePerActiveQuery == null ? 
"_aggregatedUsagePerActiveQuery : null"
-                : _aggregatedUsagePerActiveQuery.toString());
-            LOGGER.debug("_threadEntriesMap size: {}", 
_threadEntriesMap.size());
-
-            // Publish server heap usage metrics
-            if (config.isPublishHeapUsageMetric()) {
-              _metrics.setValueOfGlobalGauge(_memoryUsageGauge, _usedBytes);
-            }
-            // Clean inactive query stats
-            cleanInactive();
             // Sleep for sometime
             reschedule();
           }
         }
       }
 
+      public void runOnce() {
+        QueryMonitorConfig config = _queryMonitorConfig.get();
+
+        LOGGER.debug("Running timed task for PerQueryCPUMemAccountant.");
+        _triggeringLevel = TriggeringLevel.Normal;
+        _sleepTime = config.getNormalSleepTime();
+        _aggregatedUsagePerActiveQuery = null;
+        try {
+          // Get the metrics used for triggering the kill
+          collectTriggerMetrics();
+          // Prioritize the panic check, kill ALL QUERIES immediately if 
triggered
+          if (outOfMemoryPanicTrigger()) {
+            return;
+          }
+          // Check for other triggers
+          evalTriggers();
+          // Refresh thread usage and aggregate to per query usage if triggered
+          reapFinishedTasks();
+          if (_triggeringLevel.ordinal() > TriggeringLevel.Normal.ordinal()) {
+            _aggregatedUsagePerActiveQuery = getQueryResourcesImpl();
+          }
+          // post aggregation function
+          postAggregation(_aggregatedUsagePerActiveQuery);
+          // Act on one triggered actions
+          triggeredActions();
+        } catch (Exception e) {
+          LOGGER.error("Caught exception while executing stats aggregation and 
query kill", e);
+        } finally {
+          LOGGER.debug(_aggregatedUsagePerActiveQuery == null ? 
"_aggregatedUsagePerActiveQuery : null"
+              : _aggregatedUsagePerActiveQuery.toString());
+          LOGGER.debug("_threadEntriesMap size: {}", _threadEntriesMap.size());
+
+          // Publish server heap usage metrics
+          if (config.isPublishHeapUsageMetric()) {
+            _metrics.setValueOfGlobalGauge(_memoryUsageGauge, _usedBytes);
+          }
+          // Clean inactive query stats
+          cleanInactive();
+        }
+      }
+
       private void collectTriggerMetrics() {
         _usedBytes = MEMORY_MX_BEAN.getHeapMemoryUsage().getUsed();
         LOGGER.debug("Heap used bytes {}", _usedBytes);
@@ -774,8 +797,8 @@ public class PerQueryCPUMemAccountantFactory implements 
ThreadAccountantFactory
           _metrics.addMeteredGlobalValue(_heapMemoryPanicExceededMeter, 1);
           LOGGER.error("Heap used bytes {}, greater than _panicLevel {}, 
Killed all queries and triggered gc!",
               _usedBytes, panicLevel);
-          // call aggregate here as will throw exception and
-          aggregate(false);
+          // read finished tasks here as will throw exception and
+          reapFinishedTasks();
           return true;
         }
         return false;
@@ -785,7 +808,7 @@ public class PerQueryCPUMemAccountantFactory implements 
ThreadAccountantFactory
        * Evaluate triggering levels of query preemption
        * Triggers should be mutually exclusive and evaluated following level 
high -> low
        */
-      private void evalTriggers() {
+      protected void evalTriggers() {
         QueryMonitorConfig config = _queryMonitorConfig.get();
 
         if (config.isCpuTimeBasedKillingEnabled()) {
@@ -806,7 +829,7 @@ public class PerQueryCPUMemAccountantFactory implements 
ThreadAccountantFactory
       /**
        * Perform actions at specific triggering levels
        */
-      private void triggeredActions() {
+      protected void triggeredActions() {
         switch (_triggeringLevel) {
           case HeapMemoryCritical:
             LOGGER.warn("Heap used bytes {} exceeds critical level {}", 
_usedBytes,
@@ -868,8 +891,12 @@ public class PerQueryCPUMemAccountantFactory implements 
ThreadAccountantFactory
        * use XX:+ExplicitGCInvokesConcurrent to avoid a full gc when system.gc 
is triggered
        */
       private void killMostExpensiveQuery() {
+        if (!_isThreadMemorySamplingEnabled) {
+          LOGGER.warn("Unable to terminate queries as  memory tracking is not 
enabled");
+          return;
+        }
         QueryMonitorConfig config = _queryMonitorConfig.get();
-        if (!_aggregatedUsagePerActiveQuery.isEmpty()
+        if (_aggregatedUsagePerActiveQuery != null && 
!_aggregatedUsagePerActiveQuery.isEmpty()
             && _numQueriesKilledConsecutively >= config.getGcBackoffCount()) {
           _numQueriesKilledConsecutively = 0;
           System.gc();
@@ -884,37 +911,32 @@ public class PerQueryCPUMemAccountantFactory implements 
ThreadAccountantFactory
           LOGGER.error("After GC, heap used bytes {} still exceeds 
_criticalLevelAfterGC level {}", _usedBytes,
               config.getCriticalLevelAfterGC());
         }
-        if (!(_isThreadMemorySamplingEnabled || _isThreadCPUSamplingEnabled)) {
-          LOGGER.warn("But unable to kill query because neither memory nor cpu 
tracking is enabled");
-          return;
-        }
         // Critical heap memory usage while no queries running
-        if (_aggregatedUsagePerActiveQuery.isEmpty()) {
-          LOGGER.debug("No active queries to kill");
-          return;
-        }
-        AggregatedStats maxUsageTuple;
-        if (_isThreadMemorySamplingEnabled) {
-          maxUsageTuple = 
Collections.max(_aggregatedUsagePerActiveQuery.values(),
-              Comparator.comparing(AggregatedStats::getAllocatedBytes));
-          boolean shouldKill = config.isOomKillQueryEnabled()
-              && maxUsageTuple._allocatedBytes > 
config.getMinMemoryFootprintForKill();
-          if (shouldKill) {
-            maxUsageTuple._exceptionAtomicReference
-                .set(new RuntimeException(String.format(
-                    " Query %s got killed because using %d bytes of memory on 
%s: %s, exceeding the quota",
-                    maxUsageTuple._queryId, maxUsageTuple.getAllocatedBytes(), 
_instanceType, _instanceId)));
-            interruptRunnerThread(maxUsageTuple.getAnchorThread());
-            logTerminatedQuery(maxUsageTuple, _usedBytes);
-          } else if (!config.isOomKillQueryEnabled()) {
-            LOGGER.warn("Query {} got picked because using {} bytes of memory, 
actual kill committed false "
-                    + "because oomKillQueryEnabled is false",
-                maxUsageTuple._queryId, maxUsageTuple._allocatedBytes);
-          } else {
-            LOGGER.warn("But all queries are below quota, no query killed");
+        if (_aggregatedUsagePerActiveQuery != null && 
!_aggregatedUsagePerActiveQuery.isEmpty()) {
+          AggregatedStats maxUsageTuple;
+          maxUsageTuple = _aggregatedUsagePerActiveQuery.values().stream()
+              .filter(stats -> 
!_cancelSentQueries.contains(stats.getQueryId()))
+              
.max(Comparator.comparing(AggregatedStats::getAllocatedBytes)).orElse(null);
+          if (maxUsageTuple != null) {
+            boolean shouldKill =
+                config.isOomKillQueryEnabled() && 
maxUsageTuple._allocatedBytes > config.getMinMemoryFootprintForKill();
+            if (shouldKill) {
+              maxUsageTuple._exceptionAtomicReference.set(new RuntimeException(
+                  String.format(" Query %s got killed because using %d bytes 
of memory on %s: %s, exceeding the quota",
+                      maxUsageTuple._queryId, 
maxUsageTuple.getAllocatedBytes(), _instanceType, _instanceId)));
+              interruptRunnerThread(maxUsageTuple);
+              logTerminatedQuery(maxUsageTuple, _usedBytes);
+            } else if (!config.isOomKillQueryEnabled()) {
+              LOGGER.warn("Query {} got picked because using {} bytes of 
memory, actual kill committed false "
+                  + "because oomKillQueryEnabled is false", 
maxUsageTuple._queryId, maxUsageTuple._allocatedBytes);
+            } else {
+              LOGGER.warn("But all queries are below quota, no query killed");
+            }
           }
+          logQueryResourceUsage(_aggregatedUsagePerActiveQuery);
+        } else {
+          LOGGER.debug("No active queries to kill");
         }
-        logQueryResourceUsage(_aggregatedUsagePerActiveQuery);
       }
 
       private void killCPUTimeExceedQueries() {
@@ -930,15 +952,15 @@ public class PerQueryCPUMemAccountantFactory implements 
ThreadAccountantFactory
                 String.format("Query %s got killed on %s: %s because using %d "
                         + "CPU time exceeding limit of %d ns CPU time", 
value._queryId, _instanceType, _instanceId,
                     value.getCpuTimeNs(), 
config.getCpuTimeBasedKillingThresholdNS())));
-            interruptRunnerThread(value.getAnchorThread());
+            cancelQuery(value);
             logTerminatedQuery(value, _usedBytes);
           }
         }
         logQueryResourceUsage(_aggregatedUsagePerActiveQuery);
       }
 
-      private void interruptRunnerThread(Thread thread) {
-        thread.interrupt();
+      private void interruptRunnerThread(AggregatedStats queryResourceTracker) 
{
+        cancelQuery(queryResourceTracker);
         if (_queryMonitorConfig.get().isQueryKilledMetricEnabled()) {
           _metrics.addMeteredGlobalValue(_queryKilledMeter, 1);
         }
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountCancelTest.java b/pinot-core/src/test/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountCancelTest.java
new file mode 100644
index 0000000000..3abd3bf6af
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountCancelTest.java
@@ -0,0 +1,189 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.accounting;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import org.apache.pinot.spi.accounting.QueryResourceTracker;
+import org.apache.pinot.spi.env.PinotConfiguration;
+import org.apache.pinot.spi.utils.CommonConstants;
+import org.apache.pinot.util.TestUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+
+
+public class PerQueryCPUMemAccountCancelTest {
+  static class AlwaysTerminateMostExpensiveQueryAccountant extends 
TestResourceAccountant {
+    private static final Logger LOGGER = 
LoggerFactory.getLogger(AlwaysTerminateMostExpensiveQueryAccountant.class);
+    private final List<String> _cancelLog = new ArrayList<>();
+
+    AlwaysTerminateMostExpensiveQueryAccountant(
+        Map<Thread, CPUMemThreadLevelAccountingObjects.ThreadEntry> 
threadEntries) {
+      super(threadEntries);
+    }
+
+    @Override
+    public WatcherTask createWatcherTask() {
+      return new TerminatingWatcherTask();
+    }
+
+    @Override
+    public void cancelQuery(AggregatedStats queryResourceTracker) {
+      _cancelSentQueries.add(queryResourceTracker.getQueryId());
+      _cancelLog.add(queryResourceTracker.getQueryId());
+    }
+
+    public List<String> getCancelLog() {
+      return _cancelLog;
+    }
+
+    class TerminatingWatcherTask extends WatcherTask {
+      TerminatingWatcherTask() {
+        PinotConfiguration config = new PinotConfiguration();
+
+        
config.setProperty(CommonConstants.Accounting.CONFIG_OF_MIN_MEMORY_FOOTPRINT_TO_KILL_RATIO,
 0.01);
+
+        
config.setProperty(CommonConstants.Accounting.CONFIG_OF_PANIC_LEVEL_HEAP_USAGE_RATIO,
+            CommonConstants.Accounting.DFAULT_PANIC_LEVEL_HEAP_USAGE_RATIO);
+
+        
config.setProperty(CommonConstants.Accounting.CONFIG_OF_CRITICAL_LEVEL_HEAP_USAGE_RATIO,
+            
CommonConstants.Accounting.DEFAULT_CRITICAL_LEVEL_HEAP_USAGE_RATIO);
+
+        
config.setProperty(CommonConstants.Accounting.CONFIG_OF_CRITICAL_LEVEL_HEAP_USAGE_RATIO_DELTA_AFTER_GC,
+            
CommonConstants.Accounting.DEFAULT_CONFIG_OF_CRITICAL_LEVEL_HEAP_USAGE_RATIO_DELTA_AFTER_GC);
+
+        
config.setProperty(CommonConstants.Accounting.CONFIG_OF_GC_BACKOFF_COUNT,
+            CommonConstants.Accounting.DEFAULT_GC_BACKOFF_COUNT);
+
+        
config.setProperty(CommonConstants.Accounting.CONFIG_OF_ALARMING_LEVEL_HEAP_USAGE_RATIO,
+            
CommonConstants.Accounting.DEFAULT_ALARMING_LEVEL_HEAP_USAGE_RATIO);
+
+        config.setProperty(CommonConstants.Accounting.CONFIG_OF_SLEEP_TIME_MS,
+            CommonConstants.Accounting.DEFAULT_SLEEP_TIME_MS);
+
+        
config.setProperty(CommonConstants.Accounting.CONFIG_OF_GC_WAIT_TIME_MS,
+            CommonConstants.Accounting.DEFAULT_CONFIG_OF_GC_WAIT_TIME_MS);
+
+        
config.setProperty(CommonConstants.Accounting.CONFIG_OF_SLEEP_TIME_DENOMINATOR,
+            CommonConstants.Accounting.DEFAULT_SLEEP_TIME_DENOMINATOR);
+
+        
config.setProperty(CommonConstants.Accounting.CONFIG_OF_OOM_PROTECTION_KILLING_QUERY,
 true);
+
+        
config.setProperty(CommonConstants.Accounting.CONFIG_OF_PUBLISHING_JVM_USAGE,
+            CommonConstants.Accounting.DEFAULT_PUBLISHING_JVM_USAGE);
+
+        
config.setProperty(CommonConstants.Accounting.CONFIG_OF_CPU_TIME_BASED_KILLING_ENABLED,
+            CommonConstants.Accounting.DEFAULT_CPU_TIME_BASED_KILLING_ENABLED);
+
+        
config.setProperty(CommonConstants.Accounting.CONFIG_OF_CPU_TIME_BASED_KILLING_THRESHOLD_MS,
+            
CommonConstants.Accounting.DEFAULT_CPU_TIME_BASED_KILLING_THRESHOLD_MS);
+
+        
config.setProperty(CommonConstants.Accounting.CONFIG_OF_QUERY_KILLED_METRIC_ENABLED,
+            CommonConstants.Accounting.DEFAULT_QUERY_KILLED_METRIC_ENABLED);
+
+        QueryMonitorConfig queryMonitorConfig = new QueryMonitorConfig(config, 
1000);
+        _queryMonitorConfig.set(queryMonitorConfig);
+      }
+
+      @Override
+      public void runOnce() {
+        _aggregatedUsagePerActiveQuery = null;
+        try {
+          evalTriggers();
+          reapFinishedTasks();
+          _aggregatedUsagePerActiveQuery = getQueryResourcesImpl();
+          triggeredActions();
+        } catch (Exception e) {
+          LOGGER.error("Caught exception while executing stats aggregation and 
query kill", e);
+        } finally {
+          // Clean inactive query stats
+          cleanInactive();
+        }
+      }
+
+      @Override
+      public void evalTriggers() {
+        _triggeringLevel = TriggeringLevel.HeapMemoryCritical;
+      }
+    }
+  }
+
+  @Test
+  void testCancelSingleQuery() {
+    Map<Thread, CPUMemThreadLevelAccountingObjects.ThreadEntry> threadEntries 
= new HashMap<>();
+    CountDownLatch threadLatch = new CountDownLatch(1);
+    String queryId = "testQueryAggregation";
+    TestResourceAccountant.getQueryThreadEntries(queryId, threadLatch, 
threadEntries);
+
+    AlwaysTerminateMostExpensiveQueryAccountant accountant =
+        new AlwaysTerminateMostExpensiveQueryAccountant(threadEntries);
+    Map<String, ? extends QueryResourceTracker> queryResourceTrackerMap = 
accountant.getQueryResources();
+    assertEquals(queryResourceTrackerMap.size(), 1);
+    QueryResourceTracker queryResourceTracker = 
queryResourceTrackerMap.get(queryId);
+    assertEquals(queryResourceTracker.getAllocatedBytes(), 5500);
+
+    // Cancel a query.
+    accountant.getWatcherTask().runOnce();
+    assertEquals(accountant.getCancelLog().size(), 1);
+
+    // Try once more. There should still be only one cancel.
+    accountant.getWatcherTask().runOnce();
+    assertEquals(accountant.getCancelLog().size(), 1);
+    threadLatch.countDown();
+    TestUtils.waitForCondition(aVoid -> {
+      accountant.reapFinishedTasks();
+      return accountant.getCancelSentQueries().isEmpty();
+    }, 100L, 1000L, "CancelSentList was not cleared");
+  }
+
+  @Test
+  void testCancelTwoQuery() {
+    Map<Thread, CPUMemThreadLevelAccountingObjects.ThreadEntry> threadEntries 
= new HashMap<>();
+    CountDownLatch threadLatch = new CountDownLatch(1);
+    String queryId = "testQueryOne";
+    TestResourceAccountant.getQueryThreadEntries(queryId, threadLatch, 
threadEntries);
+    String queryId2 = "testQueryTwo";
+    TestResourceAccountant.getQueryThreadEntries(queryId2, threadLatch, 
threadEntries);
+
+    AlwaysTerminateMostExpensiveQueryAccountant accountant =
+        new AlwaysTerminateMostExpensiveQueryAccountant(threadEntries);
+    Map<String, ? extends QueryResourceTracker> queryResourceTrackerMap = 
accountant.getQueryResources();
+    assertEquals(queryResourceTrackerMap.size(), 2);
+    assertEquals(queryResourceTrackerMap.get(queryId).getAllocatedBytes(), 
5500);
+    assertEquals(queryResourceTrackerMap.get(queryId2).getAllocatedBytes(), 
5500);
+
+    // Cancel a query.
+    accountant.getWatcherTask().runOnce();
+    assertEquals(accountant.getCancelLog().size(), 1);
+
+    accountant.getWatcherTask().runOnce();
+    assertEquals(accountant.getCancelLog().size(), 2);
+    threadLatch.countDown();
+    TestUtils.waitForCondition(aVoid -> {
+      accountant.reapFinishedTasks();
+      return accountant.getCancelSentQueries().isEmpty();
+    }, 100L, 1000L, "CancelSentList was not cleared");
+  }
+}
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountantTest.java b/pinot-core/src/test/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountantTest.java
new file mode 100644
index 0000000000..5cefb3fc37
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountantTest.java
@@ -0,0 +1,268 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.accounting;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import org.apache.pinot.spi.accounting.QueryResourceTracker;
+import org.apache.pinot.spi.accounting.ThreadExecutionContext;
+import org.apache.pinot.spi.utils.CommonConstants;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertTrue;
+
+
+/**
+ * Tests how {@link TestResourceAccountant} aggregates per-thread memory samples into per-query totals.
+ *
+ * <p>Every test seeds a query through {@link TestResourceAccountant#getQueryThreadEntries}, which registers three
+ * threads holding 1000, 2000 and 2500 sampled bytes (5500 in total). The helper threads block on a latch, so each
+ * test releases the latch in a {@code finally} block to avoid leaving threads behind when an assertion fails.
+ */
+public class PerQueryCPUMemAccountantTest {
+
+  /**
+   * A freshly seeded query reports the sum of the samples of all of its threads.
+   */
+  @Test
+  void testQueryAggregation() {
+    Map<Thread, CPUMemThreadLevelAccountingObjects.ThreadEntry> threadEntries = new HashMap<>();
+    CountDownLatch threadLatch = new CountDownLatch(1);
+    try {
+      String queryId = "testQueryAggregation";
+      TestResourceAccountant.getQueryThreadEntries(queryId, threadLatch, threadEntries);
+
+      TestResourceAccountant accountant = new TestResourceAccountant(threadEntries);
+      Map<String, ? extends QueryResourceTracker> queryResourceTrackerMap = accountant.getQueryResources();
+      assertEquals(queryResourceTrackerMap.size(), 1);
+      QueryResourceTracker queryResourceTracker = queryResourceTrackerMap.get(queryId);
+      assertEquals(queryResourceTracker.getAllocatedBytes(), 5500);
+    } finally {
+      threadLatch.countDown();
+    }
+  }
+
+  /**
+   * {@link PerQueryCPUMemAccountantFactory.PerQueryCPUMemResourceUsageAccountant#reapFinishedTasks} stores the
+   * previous task's status. If it is not called before a thread picks up a new task, the usage of the replaced task
+   * is lost: 1000 + 2000 + 1500 = 4500 instead of 5500 + 1500.
+   */
+  @Test
+  void testQueryAggregationCreateNewTask() {
+    Map<Thread, CPUMemThreadLevelAccountingObjects.ThreadEntry> threadEntries = new HashMap<>();
+    CountDownLatch threadLatch = new CountDownLatch(1);
+    try {
+      String queryId = "testQueryAggregationCreateNewTask";
+      TestResourceAccountant.getQueryThreadEntries(queryId, threadLatch, threadEntries);
+      TestResourceAccountant accountant = new TestResourceAccountant(threadEntries);
+
+      TestResourceAccountant.TaskThread anchorThread =
+          accountant.getTaskThread(queryId, CommonConstants.Accounting.ANCHOR_TASK_ID);
+      assertNotNull(anchorThread);
+
+      // Replace task id = 3 (2500 bytes) with a new task id 5 (1500 bytes)
+      TestResourceAccountant.TaskThread workerEntry = accountant.getTaskThread(queryId, 3);
+      assertNotNull(workerEntry);
+
+      // New Task
+      CPUMemThreadLevelAccountingObjects.ThreadEntry threadEntry = workerEntry._threadEntry;
+      threadEntry._currentThreadTaskStatus.set(newSseTask(queryId, 5, anchorThread._workerThread));
+      threadEntry._currentThreadMemoryAllocationSampleBytes = 1500;
+
+      Map<String, ? extends QueryResourceTracker> queryResourceTrackerMap = accountant.getQueryResources();
+      assertEquals(queryResourceTrackerMap.size(), 1);
+      QueryResourceTracker queryResourceTracker = queryResourceTrackerMap.get(queryId);
+      assertEquals(queryResourceTracker.getAllocatedBytes(), 4500);
+    } finally {
+      threadLatch.countDown();
+    }
+  }
+
+  /**
+   * Without a prior reap, a thread that switches to a new task and is then set to idle no longer contributes to the
+   * query total: only the two untouched threads (1000 + 2000 = 3000) are reported.
+   */
+  @Test
+  void testQueryAggregationSetToIdle() {
+    Map<Thread, CPUMemThreadLevelAccountingObjects.ThreadEntry> threadEntries = new HashMap<>();
+    CountDownLatch threadLatch = new CountDownLatch(1);
+    try {
+      String queryId = "testQueryAggregationSetToIdle";
+      TestResourceAccountant.getQueryThreadEntries(queryId, threadLatch, threadEntries);
+      TestResourceAccountant accountant = new TestResourceAccountant(threadEntries);
+
+      TestResourceAccountant.TaskThread anchorThread =
+          accountant.getTaskThread(queryId, CommonConstants.Accounting.ANCHOR_TASK_ID);
+      assertNotNull(anchorThread);
+
+      // Thread that currently runs task id = 3 (2500 bytes)
+      TestResourceAccountant.TaskThread workerEntry = accountant.getTaskThread(queryId, 3);
+      assertNotNull(workerEntry);
+
+      // Switch the thread to task id 5 and immediately mark it idle
+      CPUMemThreadLevelAccountingObjects.ThreadEntry threadEntry = workerEntry._threadEntry;
+      threadEntry._currentThreadTaskStatus.set(newSseTask(queryId, 5, anchorThread._workerThread));
+      threadEntry.setToIdle();
+
+      Map<String, ? extends QueryResourceTracker> queryResourceTrackerMap = accountant.getQueryResources();
+      assertEquals(queryResourceTrackerMap.size(), 1);
+      QueryResourceTracker queryResourceTracker = queryResourceTrackerMap.get(queryId);
+      assertEquals(queryResourceTracker.getAllocatedBytes(), 3000);
+    } finally {
+      threadLatch.countDown();
+    }
+  }
+
+  /**
+   * {@link PerQueryCPUMemAccountantFactory.PerQueryCPUMemResourceUsageAccountant#reapFinishedTasks} stores the
+   * previous task's status. If it is called, then the resources of finished tasks should also be provided:
+   * 5500 from the original tasks plus 1500 from the new task = 7000.
+   */
+  @Test
+  void testQueryAggregationReapAndCreateNewTask() {
+    Map<Thread, CPUMemThreadLevelAccountingObjects.ThreadEntry> threadEntries = new HashMap<>();
+    CountDownLatch threadLatch = new CountDownLatch(1);
+    try {
+      String queryId = "testQueryAggregationReapAndCreateNewTask";
+      TestResourceAccountant.getQueryThreadEntries(queryId, threadLatch, threadEntries);
+      TestResourceAccountant accountant = new TestResourceAccountant(threadEntries);
+      accountant.reapFinishedTasks();
+
+      TestResourceAccountant.TaskThread anchorThread =
+          accountant.getTaskThread(queryId, CommonConstants.Accounting.ANCHOR_TASK_ID);
+      assertNotNull(anchorThread);
+
+      // Replace task id = 3 (2500 bytes) with a new task id 5 (1500 bytes)
+      TestResourceAccountant.TaskThread workerEntry = accountant.getTaskThread(queryId, 3);
+      assertNotNull(workerEntry);
+
+      // New Task
+      CPUMemThreadLevelAccountingObjects.ThreadEntry threadEntry = workerEntry._threadEntry;
+      threadEntry._currentThreadTaskStatus.set(newSseTask(queryId, 5, anchorThread._workerThread));
+      threadEntry._currentThreadMemoryAllocationSampleBytes = 1500;
+
+      accountant.reapFinishedTasks();
+
+      Map<String, ? extends QueryResourceTracker> queryResourceTrackerMap = accountant.getQueryResources();
+      assertEquals(queryResourceTrackerMap.size(), 1);
+      QueryResourceTracker queryResourceTracker = queryResourceTrackerMap.get(queryId);
+      assertEquals(queryResourceTracker.getAllocatedBytes(), 7000);
+    } finally {
+      threadLatch.countDown();
+    }
+  }
+
+  /**
+   * When tasks are reaped before and after a thread is set to idle, the query total still reports the full 5500
+   * bytes.
+   */
+  @Test
+  void testQueryAggregationReapAndSetToIdle() {
+    Map<Thread, CPUMemThreadLevelAccountingObjects.ThreadEntry> threadEntries = new HashMap<>();
+    CountDownLatch threadLatch = new CountDownLatch(1);
+    try {
+      String queryId = "testQueryAggregationReapAndSetToIdle";
+      TestResourceAccountant.getQueryThreadEntries(queryId, threadLatch, threadEntries);
+      TestResourceAccountant accountant = new TestResourceAccountant(threadEntries);
+      accountant.reapFinishedTasks();
+
+      TestResourceAccountant.TaskThread anchorThread =
+          accountant.getTaskThread(queryId, CommonConstants.Accounting.ANCHOR_TASK_ID);
+      assertNotNull(anchorThread);
+
+      // Replace task id = 3 (2500 bytes) with null
+      TestResourceAccountant.TaskThread workerEntry = accountant.getTaskThread(queryId, 3);
+      assertNotNull(workerEntry);
+
+      // Set to Idle
+      CPUMemThreadLevelAccountingObjects.ThreadEntry threadEntry = workerEntry._threadEntry;
+      threadEntry.setToIdle();
+
+      accountant.reapFinishedTasks();
+
+      Map<String, ? extends QueryResourceTracker> queryResourceTrackerMap = accountant.getQueryResources();
+      assertEquals(queryResourceTrackerMap.size(), 1);
+      QueryResourceTracker queryResourceTracker = queryResourceTrackerMap.get(queryId);
+      assertEquals(queryResourceTracker.getAllocatedBytes(), 5500);
+    } finally {
+      threadLatch.countDown();
+    }
+  }
+
+  /**
+   * Tracks the inactive-query set across {@code cleanInactive()} and {@code reapFinishedTasks()} calls while the
+   * query is still running.
+   */
+  @Test
+  void testInActiveQuerySet() {
+    Map<Thread, CPUMemThreadLevelAccountingObjects.ThreadEntry> threadEntries = new HashMap<>();
+    CountDownLatch threadLatch = new CountDownLatch(1);
+    try {
+      String queryId = "testQueryAggregation";
+      TestResourceAccountant.getQueryThreadEntries(queryId, threadLatch, threadEntries);
+
+      TestResourceAccountant accountant = new TestResourceAccountant(threadEntries);
+      Map<String, ? extends QueryResourceTracker> queryResourceTrackerMap = accountant.getQueryResources();
+      assertEquals(queryResourceTrackerMap.size(), 1);
+      assertTrue(accountant.getInactiveQueries().isEmpty());
+      accountant.reapFinishedTasks();
+
+      // Pick up a new task so that the next reap sees a finished task for this query.
+      TestResourceAccountant.TaskThread anchorThread =
+          accountant.getTaskThread(queryId, CommonConstants.Accounting.ANCHOR_TASK_ID);
+      assertNotNull(anchorThread);
+
+      // Replace task id = 3 (2500 bytes) with a new task id 5 (1500 bytes)
+      TestResourceAccountant.TaskThread workerEntry = accountant.getTaskThread(queryId, 3);
+      assertNotNull(workerEntry);
+
+      // New Task
+      CPUMemThreadLevelAccountingObjects.ThreadEntry threadEntry = workerEntry._threadEntry;
+      threadEntry._currentThreadTaskStatus.set(newSseTask(queryId, 5, anchorThread._workerThread));
+      threadEntry._currentThreadMemoryAllocationSampleBytes = 1500;
+
+      accountant.reapFinishedTasks();
+
+      // A call to cleanInactive() surprisingly adds the query id to the set.
+      accountant.cleanInactive();
+      assertEquals(accountant.getInactiveQueries().size(), 1);
+      assertTrue(accountant.getInactiveQueries().contains(queryId));
+      // A call to reapFinishedTasks should remove the query id from the inactive queries set.
+      accountant.reapFinishedTasks();
+      assertTrue(accountant.getInactiveQueries().isEmpty());
+    } finally {
+      threadLatch.countDown();
+    }
+  }
+
+  /**
+   * A worker thread of the first query picks up a task that belongs to a second query. After reaping, the first
+   * query still reports 5500 bytes and the second query reports its own 5500 plus the 3500 of the new task.
+   */
+  @Test
+  void testQueryAggregationAddNewQueryTask() {
+    Map<Thread, CPUMemThreadLevelAccountingObjects.ThreadEntry> threadEntries = new HashMap<>();
+    CountDownLatch threadLatch = new CountDownLatch(1);
+    CountDownLatch newQueryThreadLatch = new CountDownLatch(1);
+    try {
+      String queryId = "testQueryAggregationAddNewQueryTask";
+      TestResourceAccountant.getQueryThreadEntries(queryId, threadLatch, threadEntries);
+      TestResourceAccountant accountant = new TestResourceAccountant(threadEntries);
+      accountant.reapFinishedTasks();
+
+      // Start a new query.
+      String newQueryId = "newQuery";
+      Map<Thread, CPUMemThreadLevelAccountingObjects.ThreadEntry> newQueryThreadEntries = new HashMap<>();
+      TestResourceAccountant.getQueryThreadEntries(newQueryId, newQueryThreadLatch, newQueryThreadEntries);
+      for (Map.Entry<Thread, CPUMemThreadLevelAccountingObjects.ThreadEntry> entry : newQueryThreadEntries.entrySet()) {
+        accountant.addThreadEntry(entry.getKey(), entry.getValue());
+      }
+
+      // Create a new task for newQuery
+      TestResourceAccountant.TaskThread anchorThread =
+          accountant.getTaskThread(newQueryId, CommonConstants.Accounting.ANCHOR_TASK_ID);
+      assertNotNull(anchorThread);
+
+      // Replace task id = 3 (2500 bytes) of first query with a new task id 5 of new query (3500 bytes)
+      TestResourceAccountant.TaskThread workerEntry = accountant.getTaskThread(queryId, 3);
+      assertNotNull(workerEntry);
+
+      // New Task
+      CPUMemThreadLevelAccountingObjects.ThreadEntry threadEntry = workerEntry._threadEntry;
+      threadEntry._currentThreadTaskStatus.set(newSseTask(newQueryId, 5, anchorThread._workerThread));
+      threadEntry._currentThreadMemoryAllocationSampleBytes = 3500;
+
+      accountant.reapFinishedTasks();
+
+      Map<String, ? extends QueryResourceTracker> queryResourceTrackerMap = accountant.getQueryResources();
+      assertEquals(queryResourceTrackerMap.size(), 2);
+      QueryResourceTracker queryResourceTracker = queryResourceTrackerMap.get(queryId);
+      assertEquals(queryResourceTracker.getAllocatedBytes(), 5500);
+      QueryResourceTracker newQueryResourceTracker = queryResourceTrackerMap.get(newQueryId);
+      assertEquals(newQueryResourceTracker.getAllocatedBytes(), 9000);
+    } finally {
+      // Release the helper threads of both queries regardless of the test outcome.
+      threadLatch.countDown();
+      newQueryThreadLatch.countDown();
+    }
+  }
+
+  /**
+   * Builds an SSE task entry in the default workload for the given query, task id and anchor thread.
+   */
+  private static CPUMemThreadLevelAccountingObjects.TaskEntry newSseTask(String queryId, int taskId,
+      Thread anchorThread) {
+    return new CPUMemThreadLevelAccountingObjects.TaskEntry(queryId, taskId, ThreadExecutionContext.TaskType.SSE,
+        anchorThread, CommonConstants.Accounting.DEFAULT_WORKLOAD_NAME);
+  }
+}
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountantFactoryTest.java
 
b/pinot-core/src/test/java/org/apache/pinot/core/accounting/QueryMonitorConfigTest.java
similarity index 99%
rename from 
pinot-core/src/test/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountantFactoryTest.java
rename to 
pinot-core/src/test/java/org/apache/pinot/core/accounting/QueryMonitorConfigTest.java
index 8dc68670c6..f9b66a7220 100644
--- 
a/pinot-core/src/test/java/org/apache/pinot/core/accounting/PerQueryCPUMemAccountantFactoryTest.java
+++ 
b/pinot-core/src/test/java/org/apache/pinot/core/accounting/QueryMonitorConfigTest.java
@@ -32,7 +32,7 @@ import static org.testng.Assert.assertFalse;
 import static org.testng.Assert.assertTrue;
 
 
-public class PerQueryCPUMemAccountantFactoryTest {
+public class QueryMonitorConfigTest {
   private static final double EXPECTED_MIN_MEMORY_FOOTPRINT_FOR_KILL = 0.05;
   private static final double EXPECTED_PANIC_LEVEL = 0.9f;
   private static final double EXPECTED_CRITICAL_LEVEL = 0.95f;
diff --git 
a/pinot-core/src/test/java/org/apache/pinot/core/accounting/TestResourceAccountant.java
 
b/pinot-core/src/test/java/org/apache/pinot/core/accounting/TestResourceAccountant.java
new file mode 100644
index 0000000000..ba49824c45
--- /dev/null
+++ 
b/pinot-core/src/test/java/org/apache/pinot/core/accounting/TestResourceAccountant.java
@@ -0,0 +1,95 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.accounting;
+
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.CountDownLatch;
+import java.util.stream.Collectors;
+import org.apache.pinot.spi.accounting.ThreadExecutionContext;
+import org.apache.pinot.spi.config.instance.InstanceType;
+import org.apache.pinot.spi.env.PinotConfiguration;
+import org.apache.pinot.spi.utils.CommonConstants;
+
+
+/**
+ * Test-only accountant that is pre-populated with hand-built thread entries, plus helpers to build those entries
+ * and look them up again.
+ */
+class TestResourceAccountant extends PerQueryCPUMemAccountantFactory.PerQueryCPUMemResourceUsageAccountant {
+  /**
+   * Creates the accountant and copies every given thread entry into {@code _threadEntriesMap}.
+   *
+   * @param threadEntries thread entries to register, keyed by the thread they belong to
+   */
+  TestResourceAccountant(Map<Thread, CPUMemThreadLevelAccountingObjects.ThreadEntry> threadEntries) {
+    super(new PinotConfiguration(), false, true, true, new HashSet<>(), "test", InstanceType.SERVER);
+    _threadEntriesMap.putAll(threadEntries);
+  }
+
+  /**
+   * Starts three threads for {@code queryId} and adds one entry per thread to {@code threadEntries}: the anchor
+   * task with 1000 sampled bytes, task 2 with 2000 and task 3 with 2500 (5500 in total). All threads block on
+   * {@code threadLatch} until it is counted down.
+   *
+   * @param queryId query id stored in every task entry
+   * @param threadLatch latch the started threads wait on; the caller must count it down to let them finish
+   * @param threadEntries map that receives the created entries
+   */
+  static void getQueryThreadEntries(String queryId, CountDownLatch threadLatch,
+      Map<Thread, CPUMemThreadLevelAccountingObjects.ThreadEntry> threadEntries) {
+    // Only the started thread is needed here: the anchor's task entry has to reference the anchor thread itself,
+    // so it is built below once the thread exists.
+    Thread anchorWorker =
+        getTaskThread(queryId, CommonConstants.Accounting.ANCHOR_TASK_ID, threadLatch, null)._workerThread;
+
+    CPUMemThreadLevelAccountingObjects.ThreadEntry anchorEntry = new CPUMemThreadLevelAccountingObjects.ThreadEntry();
+    anchorEntry._currentThreadTaskStatus.set(
+        new CPUMemThreadLevelAccountingObjects.TaskEntry(queryId, CommonConstants.Accounting.ANCHOR_TASK_ID,
+            ThreadExecutionContext.TaskType.SSE, anchorWorker, CommonConstants.Accounting.DEFAULT_WORKLOAD_NAME));
+    anchorEntry._currentThreadMemoryAllocationSampleBytes = 1000;
+    threadEntries.put(anchorWorker, anchorEntry);
+
+    TaskThread taskThread2 = getTaskThread(queryId, 2, threadLatch, anchorWorker);
+    threadEntries.put(taskThread2._workerThread, taskThread2._threadEntry);
+    taskThread2._threadEntry._currentThreadMemoryAllocationSampleBytes = 2000;
+
+    TaskThread taskThread3 = getTaskThread(queryId, 3, threadLatch, anchorWorker);
+    threadEntries.put(taskThread3._workerThread, taskThread3._threadEntry);
+    taskThread3._threadEntry._currentThreadMemoryAllocationSampleBytes = 2500;
+  }
+
+  /**
+   * Creates a thread entry whose current task is an SSE task with the given ids and starts a thread that waits on
+   * {@code threadLatch}.
+   *
+   * @param queryId query id of the task
+   * @param taskId task id of the task
+   * @param threadLatch latch the started thread waits on
+   * @param anchorThread anchor thread recorded in the task entry, may be {@code null}
+   * @return the new entry paired with the started thread
+   */
+  private static TaskThread getTaskThread(String queryId, int taskId, CountDownLatch threadLatch, Thread anchorThread) {
+    CPUMemThreadLevelAccountingObjects.ThreadEntry threadEntry = new CPUMemThreadLevelAccountingObjects.ThreadEntry();
+    threadEntry._currentThreadTaskStatus.set(
+        new CPUMemThreadLevelAccountingObjects.TaskEntry(queryId, taskId, ThreadExecutionContext.TaskType.SSE,
+            anchorThread, CommonConstants.Accounting.DEFAULT_WORKLOAD_NAME));
+    Thread workerThread = new Thread(() -> {
+      try {
+        threadLatch.await();
+      } catch (InterruptedException e) {
+        Thread.currentThread().interrupt();
+      }
+    });
+    workerThread.start();
+    return new TaskThread(threadEntry, workerThread);
+  }
+
+  /**
+   * Finds a registered thread whose current task matches the given query id and task id.
+   *
+   * @param queryId query id to match
+   * @param taskId task id to match
+   * @return the matching thread entry paired with its thread
+   * @throws IndexOutOfBoundsException if no registered thread currently runs such a task
+   */
+  public TaskThread getTaskThread(String queryId, int taskId) {
+    Map.Entry<Thread, CPUMemThreadLevelAccountingObjects.ThreadEntry> workerEntry =
+        _threadEntriesMap.entrySet().stream().filter(e -> {
+          // Read the status once; it may be null for a thread without a current task.
+          CPUMemThreadLevelAccountingObjects.TaskEntry status = e.getValue()._currentThreadTaskStatus.get();
+          return status != null && status.getTaskId() == taskId && Objects.equals(status.getQueryId(), queryId);
+        }).collect(Collectors.toList()).get(0);
+    return new TaskThread(workerEntry.getValue(), workerEntry.getKey());
+  }
+
+  /**
+   * Pairs a thread entry with the thread it is registered under.
+   */
+  public static class TaskThread {
+    public final CPUMemThreadLevelAccountingObjects.ThreadEntry _threadEntry;
+    public final Thread _workerThread;
+
+    public TaskThread(CPUMemThreadLevelAccountingObjects.ThreadEntry threadEntry, Thread workerThread) {
+      _threadEntry = threadEntry;
+      _workerThread = workerThread;
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to