villebro commented on code in PR #36368: URL: https://github.com/apache/superset/pull/36368#discussion_r2777729527
########## superset/tasks/manager.py: ########## @@ -0,0 +1,762 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Task manager for the Global Task Framework (GTF)""" + +from __future__ import annotations + +import logging +import threading +import time +from typing import Any, Callable, TYPE_CHECKING + +import redis +from superset_core.api.tasks import TaskProperties, TaskScope + +from superset.async_events.cache_backend import ( + RedisCacheBackend, + RedisSentinelCacheBackend, +) +from superset.extensions import cache_manager +from superset.tasks.constants import ABORT_STATES, TERMINAL_STATES +from superset.tasks.utils import generate_random_task_key + +if TYPE_CHECKING: + from flask import Flask + + from superset.models.tasks import Task + +logger = logging.getLogger(__name__) + + +class AbortListener: + """ + Handle for a background abort listener. + + Returned by TaskManager.listen_for_abort() to allow stopping the listener. + """ + + def __init__( + self, + task_uuid: str, + thread: threading.Thread, + stop_event: threading.Event, + pubsub: redis.client.PubSub | None = None, + ) -> None: + self._task_uuid = task_uuid + self._thread = thread + self._stop_event = stop_event + self._pubsub = pubsub + + def stop(self) -> None: + """Stop the abort listener.""" + self._stop_event.set() + + # Close pub/sub subscription if active + if self._pubsub is not None: + try: + self._pubsub.unsubscribe() + self._pubsub.close() + except Exception as ex: + logger.debug("Error closing pub/sub during stop: %s", ex) + + # Wait for thread to finish (with timeout to avoid blocking indefinitely) + if self._thread.is_alive(): + self._thread.join(timeout=2.0) + + # Check if thread is still running after timeout + if self._thread.is_alive(): + # Thread is a daemon, so it will be killed when process exits. + # Log warning but continue - cleanup will still proceed. + logger.warning( + "Abort listener thread for task %s did not terminate within " + "2 seconds. Thread will be terminated when process exits.", + self._task_uuid, + ) + else: + logger.debug("Stopped abort listener for task %s", self._task_uuid) + else: + logger.debug("Stopped abort listener for task %s", self._task_uuid) + + +class TaskManager: + """ + Handles task creation, scheduling, and abort notifications. + + The TaskManager is responsible for: + 1. Creating task entries in the metastore (Task model) + 2. Scheduling task execution via Celery + 3. Handling deduplication (returning existing active task if duplicate) + 4. Managing real-time abort notifications (optional) + + Redis pub/sub is opt-in via SIGNAL_CACHE_CONFIG configuration. When not + configured, tasks use database polling for abort detection. + """ + + # Class-level state (initialized once via init_app) + _channel_prefix: str = "gtf:abort:" + _completion_channel_prefix: str = "gtf:complete:" + _initialized: bool = False + + # Backward compatibility alias - prefer importing from superset.tasks.constants + TERMINAL_STATES = TERMINAL_STATES + + @classmethod + def init_app(cls, app: Flask) -> None: + """ + Initialize the TaskManager with Flask app config. + + Redis connection is managed by CacheManager - this just reads channel prefixes. + + :param app: Flask application instance + """ + if cls._initialized: + return + + cls._channel_prefix = app.config.get("TASKS_ABORT_CHANNEL_PREFIX", "gtf:abort:") + cls._completion_channel_prefix = app.config.get( + "TASKS_COMPLETION_CHANNEL_PREFIX", "gtf:complete:" + ) + + cls._initialized = True + + @classmethod + def _get_cache(cls) -> RedisCacheBackend | RedisSentinelCacheBackend | None: + """ + Get the signal cache backend. + + :returns: The signal cache backend, or None if not configured + """ + return cache_manager.signal_cache + + @classmethod + def is_pubsub_available(cls) -> bool: + """ + Check if Redis pub/sub backend is configured and available. + + :returns: True if Redis is available for pub/sub, False otherwise + """ + return cls._get_cache() is not None + + @classmethod + def get_abort_channel(cls, task_uuid: str) -> str: + """ + Get the abort channel name for a task. + + :param task_uuid: UUID of the task + :returns: Channel name for the task's abort notifications + """ + return f"{cls._channel_prefix}{task_uuid}" + + @classmethod + def publish_abort(cls, task_uuid: str) -> bool: + """ + Publish an abort message to the task's channel. + + :param task_uuid: UUID of the task to abort + :returns: True if message was published, False if Redis unavailable + """ + cache = cls._get_cache() + if not cache: + return False + + try: + channel = cls.get_abort_channel(task_uuid) + subscriber_count = cache.publish(channel, "abort") + logger.debug( + "Published abort to channel %s (%d subscribers)", + channel, + subscriber_count, + ) + return True + except redis.RedisError as ex: + logger.error("Failed to publish abort for task %s: %s", task_uuid, ex) + return False + + @classmethod + def get_completion_channel(cls, task_uuid: str) -> str: + """ + Get the completion channel name for a task. + + :param task_uuid: UUID of the task + :returns: Channel name for the task's completion notifications + """ + return f"{cls._completion_channel_prefix}{task_uuid}" + + @classmethod + def publish_completion(cls, task_uuid: str, status: str) -> bool: + """ + Publish a completion message to the task's channel. + + Called when task reaches terminal state (SUCCESS, FAILURE, ABORTED, TIMED_OUT). + This notifies any waiters (e.g., sync callers waiting for an existing task). + + :param task_uuid: UUID of the completed task + :param status: Final status of the task + :returns: True if message was published, False if Redis unavailable + """ + cache = cls._get_cache() + if not cache: + return False + + try: + channel = cls.get_completion_channel(task_uuid) + subscriber_count = cache.publish(channel, status) + logger.debug( + "Published completion to channel %s (status=%s, %d subscribers)", + channel, + status, + subscriber_count, + ) + return True + except redis.RedisError as ex: + logger.error("Failed to publish completion for task %s: %s", task_uuid, ex) + return False + + @classmethod + def wait_for_completion( + cls, + task_uuid: str, + timeout: float | None = None, + poll_interval: float = 1.0, + app: Any = None, + ) -> "Task": + """ + Block until task reaches terminal state. + + Uses Redis pub/sub if configured for low-latency, low-CPU waiting. + Uses database polling if Redis is not configured. + + :param task_uuid: UUID of the task to wait for + :param timeout: Maximum time to wait in seconds (None = no limit) + :param poll_interval: Interval for database polling (seconds) + :param app: Flask app for database access + :returns: Task in terminal state + :raises TimeoutError: If timeout expires before task completes + :raises ValueError: If task not found + """ + from superset.daos.tasks import TaskDAO + + start_time = time.monotonic() + + def time_remaining() -> float | None: + if timeout is None: + return None + elapsed = time.monotonic() - start_time + remaining = timeout - elapsed + return remaining if remaining > 0 else 0 + + def get_task() -> "Task | None": + if app: + with app.app_context(): + return TaskDAO.find_one_or_none(uuid=task_uuid) + return TaskDAO.find_one_or_none(uuid=task_uuid) + + # Check current state first + task = get_task() + if not task: + raise ValueError(f"Task {task_uuid} not found") + + if task.status in cls.TERMINAL_STATES: + return task + + logger.info( Review Comment: I changed the abort listener/poller logs to `debug`, but I kept the others, as they seem relevant for debugging purposes -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
