This is an automated email from the ASF dual-hosted git repository.
turaga pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e0b55c21e27 Centralized runtime control of Edge Worker concurrency in distributed deployments (#62896)
e0b55c21e27 is described below
commit e0b55c21e27edd0e5a5b7b90f72be5373d01da90
Author: Dheeraj Turaga <[email protected]>
AuthorDate: Sat Mar 7 12:32:00 2026 -0600
Centralized runtime control of Edge Worker concurrency in distributed deployments (#62896)
---
.../src/airflow/providers/edge3/cli/definition.py | 15 ++++
.../airflow/providers/edge3/cli/edge_command.py | 24 ++++++
.../src/airflow/providers/edge3/cli/worker.py | 7 ++
.../0002_3_2_0_add_concurrency_to_edge_worker.py | 49 ++++++++++++
.../edge3/src/airflow/providers/edge3/models/db.py | 28 +++++++
.../airflow/providers/edge3/models/edge_worker.py | 21 +++++
.../providers/edge3/worker_api/datamodels.py | 7 ++
.../providers/edge3/worker_api/routes/worker.py | 5 +-
.../edge3/worker_api/v2-edge-generated.yaml | 7 ++
.../edge3/tests/unit/edge3/cli/test_definition.py | 18 ++++-
.../edge3/tests/unit/edge3/cli/test_worker.py | 27 +++++++
providers/edge3/tests/unit/edge3/models/test_db.py | 75 +++++++++++++++++-
.../unit/edge3/worker_api/routes/test_worker.py | 90 +++++++++++++++++++++-
13 files changed, 367 insertions(+), 6 deletions(-)
diff --git a/providers/edge3/src/airflow/providers/edge3/cli/definition.py b/providers/edge3/src/airflow/providers/edge3/cli/definition.py
index 5cd611cc3a4..a1819e9a9a0 100644
--- a/providers/edge3/src/airflow/providers/edge3/cli/definition.py
+++ b/providers/edge3/src/airflow/providers/edge3/cli/definition.py
@@ -59,6 +59,12 @@ ARG_QUEUES_MANAGE = Arg(
help="Comma delimited list of queues to add or remove.",
required=True,
)
+ARG_CONCURRENCY_REQUIRED = Arg(
+ ("-c", "--concurrency"),
+ type=int,
+ help="The number of worker processes. Must be a positive integer.",
+ required=True,
+)
ARG_WAIT_MAINT = Arg(
("-w", "--wait"),
default=False,
@@ -229,6 +235,15 @@ EDGE_COMMANDS: list[ActionCommand] = [
func=lazy_load_command("airflow.providers.edge3.cli.edge_command.shutdown_all_workers"),
args=(ARG_YES,),
),
+ ActionCommand(
+ name="set-worker-concurrency",
+ help="Set the concurrency of a remote edge worker.",
+
func=lazy_load_command("airflow.providers.edge3.cli.edge_command.set_remote_worker_concurrency"),
+ args=(
+ ARG_REQUIRED_EDGE_HOSTNAME,
+ ARG_CONCURRENCY_REQUIRED,
+ ),
+ ),
]
diff --git a/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py b/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py
index 864d6851ace..c18fbdc7f90 100644
--- a/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py
+++ b/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py
@@ -427,3 +427,27 @@ def remove_worker_queues(args) -> None:
except TypeError as e:
logger.error(str(e))
raise SystemExit
+
+
+@cli_utils.action_cli(check_db=False)
+@providers_configuration_loaded
+def set_remote_worker_concurrency(args) -> None:
+ """Set the concurrency of a remote edge worker."""
+ _check_valid_db_connection()
+ _check_if_registered_edge_host(hostname=args.edge_hostname)
+ from airflow.providers.edge3.models.edge_worker import
set_worker_concurrency
+
+ if args.concurrency <= 0:
+ raise SystemExit("Error: Concurrency must be a positive integer.")
+
+ try:
+ set_worker_concurrency(args.edge_hostname, args.concurrency)
+ logger.info(
+ "Concurrency set to %d for Edge Worker host %s by %s.",
+ args.concurrency,
+ args.edge_hostname,
+ getuser(),
+ )
+ except TypeError as e:
+ logger.error(str(e))
+ raise SystemExit
diff --git a/providers/edge3/src/airflow/providers/edge3/cli/worker.py b/providers/edge3/src/airflow/providers/edge3/cli/worker.py
index c4aa1d735c5..6ac43e388f8 100644
--- a/providers/edge3/src/airflow/providers/edge3/cli/worker.py
+++ b/providers/edge3/src/airflow/providers/edge3/cli/worker.py
@@ -401,6 +401,13 @@ class EdgeWorker:
new_maintenance_comments,
)
self.queues = worker_info.queues
+ if worker_info.concurrency is not None and worker_info.concurrency
!= self.concurrency:
+ logger.info(
+ "Concurrency updated from %d to %d by remote request.",
+ self.concurrency,
+ worker_info.concurrency,
+ )
+ self.concurrency = worker_info.concurrency
if worker_info.state == EdgeWorkerState.MAINTENANCE_REQUEST:
logger.info("Maintenance mode requested!")
self.maintenance_mode = True
diff --git a/providers/edge3/src/airflow/providers/edge3/migrations/versions/0002_3_2_0_add_concurrency_to_edge_worker.py b/providers/edge3/src/airflow/providers/edge3/migrations/versions/0002_3_2_0_add_concurrency_to_edge_worker.py
new file mode 100644
index 00000000000..edf6cc7f39c
--- /dev/null
+++ b/providers/edge3/src/airflow/providers/edge3/migrations/versions/0002_3_2_0_add_concurrency_to_edge_worker.py
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Add concurrency column to edge_worker table.
+
+Revision ID: b3c4d5e6f7a8
+Revises: 9d34dfc2de06
+Create Date: 2026-03-04 00:00:00.000000
+"""
+
+from __future__ import annotations
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "b3c4d5e6f7a8"
+down_revision = "9d34dfc2de06"
+branch_labels = None
+depends_on = None
+edge3_version = "3.2.0"
+
+
+def upgrade() -> None:
+ bind = op.get_bind()
+ inspector = sa.inspect(bind)
+ existing_columns = {col["name"] for col in
inspector.get_columns("edge_worker")}
+ if "concurrency" not in existing_columns:
+ op.add_column("edge_worker", sa.Column("concurrency", sa.Integer(),
nullable=True))
+
+
+def downgrade() -> None:
+ op.drop_column("edge_worker", "concurrency")
diff --git a/providers/edge3/src/airflow/providers/edge3/models/db.py b/providers/edge3/src/airflow/providers/edge3/models/db.py
index 1c98cb40235..e98b3cc24b9 100644
--- a/providers/edge3/src/airflow/providers/edge3/models/db.py
+++ b/providers/edge3/src/airflow/providers/edge3/models/db.py
@@ -31,6 +31,7 @@ PACKAGE_DIR = Path(__file__).parents[1]
_REVISION_HEADS_MAP: dict[str, str] = {
"3.0.0": "9d34dfc2de06",
+ "3.2.0": "b3c4d5e6f7a8",
}
@@ -45,6 +46,33 @@ class EdgeDBManager(BaseDBManager):
supports_table_dropping = True
revision_heads_map = _REVISION_HEADS_MAP
+ def initdb(self):
+ """
+ Initialize the database, handling pre-alembic installations.
+
+ If the edge3 tables already exist but the alembic version table does
not
+ (e.g. created via create_all before the migration chain was
introduced),
+ stamp to the first revision and run the incremental upgrade so every
+ migration is applied rather than jumping straight to head.
+ """
+ db_exists = self.get_current_revision()
+ if db_exists:
+ self.upgradedb()
+ else:
+ from airflow import settings
+
+ existing_tables = set(inspect(settings.engine).get_table_names())
+ if any(table in existing_tables for table in self.metadata.tables):
+ script = self.get_script_object()
+ base_revision = next(r.revision for r in
script.walk_revisions() if r.down_revision is None)
+ config = self.get_alembic_config()
+ from alembic import command
+
+ command.stamp(config, base_revision)
+ self.upgradedb()
+ else:
+ self.create_db_from_orm()
+
def drop_tables(self, connection):
"""Drop only edge3 tables in reverse dependency order."""
if not self.supports_table_dropping:
diff --git a/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py b/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py
index e4b0698e7e0..5b037f2903c 100644
--- a/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py
+++ b/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py
@@ -103,6 +103,7 @@ class EdgeWorkerModel(Base, LoggingMixin):
jobs_success: Mapped[int] = mapped_column(Integer, default=0)
jobs_failed: Mapped[int] = mapped_column(Integer, default=0)
sysinfo: Mapped[str | None] = mapped_column(String(256))
+ concurrency: Mapped[int | None] = mapped_column(Integer, nullable=True)
def __init__(
self,
@@ -392,3 +393,23 @@ def remove_worker_queues(worker_name: str, queues: list[str], session: Session =
logger.error(error_message)
raise TypeError(error_message)
worker.remove_queues(queues)
+
+
+@provide_session
+def set_worker_concurrency(worker_name: str, concurrency: int, session:
Session = NEW_SESSION) -> None:
+ """Set the concurrency of an edge worker."""
+ query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name ==
worker_name)
+ worker: EdgeWorkerModel | None = session.scalar(query)
+ if not worker:
+ raise ValueError(f"Edge Worker {worker_name} not found in list of
registered workers")
+ if worker.state in (
+ EdgeWorkerState.OFFLINE,
+ EdgeWorkerState.OFFLINE_MAINTENANCE,
+ EdgeWorkerState.UNKNOWN,
+ ):
+ error_message = (
+ f"Cannot set concurrency for edge worker {worker_name} as it is in
{worker.state} state!"
+ )
+ logger.error(error_message)
+ raise TypeError(error_message)
+ worker.concurrency = concurrency
diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py b/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py
index c25ff53a934..fc780b87662 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py
@@ -200,3 +200,10 @@ class WorkerSetStateReturn(BaseModel):
str | None,
Field(description="Comments about the maintenance state of the
worker."),
] = None
+ concurrency: Annotated[
+ int | None,
+ Field(
+ description="Desired concurrency for the worker set by an
administrator. "
+ "None means no remote override; the worker uses its startup
value.",
+ ),
+ ] = None
diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py
index 8e3dce56e15..34368c98c4a 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py
@@ -238,7 +238,10 @@ def set_state(
)
_assert_version(body.sysinfo) # Exception only after worker state is in
the DB
return WorkerSetStateReturn(
- state=worker.state, queues=worker.queues,
maintenance_comments=worker.maintenance_comment
+ state=worker.state,
+ queues=worker.queues,
+ maintenance_comments=worker.maintenance_comment,
+ concurrency=worker.concurrency,
)
diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
index 94feded9ff9..a5daeacb86f 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
@@ -1413,6 +1413,13 @@ components:
- type: 'null'
title: Maintenance Comments
description: Comments about the maintenance state of the worker.
+ concurrency:
+ anyOf:
+ - type: integer
+ - type: 'null'
+ title: Concurrency
+ description: Desired concurrency for the worker set by an
administrator.
+ None means no remote override; the worker uses its startup value.
type: object
required:
- state
diff --git a/providers/edge3/tests/unit/edge3/cli/test_definition.py b/providers/edge3/tests/unit/edge3/cli/test_definition.py
index a99bbc7565e..0c1dac8e6d3 100644
--- a/providers/edge3/tests/unit/edge3/cli/test_definition.py
+++ b/providers/edge3/tests/unit/edge3/cli/test_definition.py
@@ -53,8 +53,8 @@ class TestEdgeCliDefinition:
assert len(commands) == 1
def test_edge_commands_count(self):
- """Test that EDGE_COMMANDS contains all 13 subcommands."""
- assert len(EDGE_COMMANDS) == 13
+ """Test that EDGE_COMMANDS contains all 14 subcommands."""
+ assert len(EDGE_COMMANDS) == 14
@pytest.mark.parametrize(
"command",
@@ -234,3 +234,17 @@ class TestEdgeCliDefinition:
params = ["edge", "shutdown-all-workers", "--yes"]
args = self.arg_parser.parse_args(params)
assert args.yes is True
+
+ def test_set_worker_concurrency_args(self):
+ """Test set-worker-concurrency command with required arguments."""
+ params = [
+ "edge",
+ "set-worker-concurrency",
+ "--edge-hostname",
+ "remote-worker-1",
+ "--concurrency",
+ "16",
+ ]
+ args = self.arg_parser.parse_args(params)
+ assert args.edge_hostname == "remote-worker-1"
+ assert args.concurrency == 16
diff --git a/providers/edge3/tests/unit/edge3/cli/test_worker.py b/providers/edge3/tests/unit/edge3/cli/test_worker.py
index b8fa906dbd3..a9c16d10822 100644
--- a/providers/edge3/tests/unit/edge3/cli/test_worker.py
+++ b/providers/edge3/tests/unit/edge3/cli/test_worker.py
@@ -392,6 +392,33 @@ class TestEdgeWorker:
assert "queue1" in (queue_list)
assert "queue2" in (queue_list)
+ @patch("airflow.providers.edge3.cli.worker.worker_set_state")
+ async def test_heartbeat_adopts_remote_concurrency(self, mock_set_state,
worker_with_job: EdgeWorker):
+ EdgeWorker.jobs = []
+ EdgeWorker.drain = False
+ EdgeWorker.maintenance_mode = False
+ mock_set_state.return_value = WorkerSetStateReturn(
+ state=EdgeWorkerState.IDLE, queues=None, concurrency=32
+ )
+ with conf_vars({("edge", "api_url"):
"https://invalid-api-test-endpoint"}):
+ await worker_with_job.heartbeat()
+ assert worker_with_job.concurrency == 32
+
+ @patch("airflow.providers.edge3.cli.worker.worker_set_state")
+ async def test_heartbeat_no_concurrency_override_keeps_startup_value(
+ self, mock_set_state, worker_with_job: EdgeWorker
+ ):
+ EdgeWorker.jobs = []
+ EdgeWorker.drain = False
+ EdgeWorker.maintenance_mode = False
+ original_concurrency = worker_with_job.concurrency
+ mock_set_state.return_value = WorkerSetStateReturn(
+ state=EdgeWorkerState.IDLE, queues=None, concurrency=None
+ )
+ with conf_vars({("edge", "api_url"):
"https://invalid-api-test-endpoint"}):
+ await worker_with_job.heartbeat()
+ assert worker_with_job.concurrency == original_concurrency
+
@patch("airflow.providers.edge3.cli.worker.worker_set_state")
async def test_version_mismatch(self, mock_set_state, worker_with_job):
mock_set_state.side_effect = EdgeWorkerVersionException("")
diff --git a/providers/edge3/tests/unit/edge3/models/test_db.py b/providers/edge3/tests/unit/edge3/models/test_db.py
index 4424c4b1e59..3bef0a1e563 100644
--- a/providers/edge3/tests/unit/edge3/models/test_db.py
+++ b/providers/edge3/tests/unit/edge3/models/test_db.py
@@ -168,11 +168,13 @@ class TestEdgeDBManager:
__import__("airflow.providers.edge3.models.db",
fromlist=["EdgeDBManager"]).EdgeDBManager,
"get_current_revision",
)
- def test_initdb_new_db(self, mock_get_rev, mock_create, mock_upgrade,
session):
+ @mock.patch("airflow.providers.edge3.models.db.inspect")
+ def test_initdb_new_db(self, mock_inspect, mock_get_rev, mock_create,
mock_upgrade, session):
"""Test that initdb calls create_db_from_orm for new databases."""
from airflow.providers.edge3.models.db import EdgeDBManager
mock_get_rev.return_value = None
+ mock_inspect.return_value.get_table_names.return_value = [] # no
tables exist
manager = EdgeDBManager(session)
manager.initdb()
@@ -205,11 +207,80 @@ class TestEdgeDBManager:
mock_create.assert_not_called()
def test_revision_heads_map_populated(self):
- """Test that _REVISION_HEADS_MAP is populated with the initial
migration."""
+ """Test that _REVISION_HEADS_MAP is populated with all known
migrations."""
from airflow.providers.edge3.models.db import _REVISION_HEADS_MAP
assert "3.0.0" in _REVISION_HEADS_MAP
assert _REVISION_HEADS_MAP["3.0.0"] == "9d34dfc2de06"
+ assert "3.2.0" in _REVISION_HEADS_MAP
+ assert _REVISION_HEADS_MAP["3.2.0"] == "b3c4d5e6f7a8"
+
+ def
test_initdb_stamps_and_upgrades_when_tables_exist_without_version(self,
session):
+ """Test that initdb runs incremental migrations when tables exist but
alembic version table does not."""
+ from sqlalchemy import inspect, text
+
+ from airflow import settings
+ from airflow.providers.edge3.models.db import EdgeDBManager
+
+ manager = EdgeDBManager(session)
+
+ # Simulate pre-alembic state: tables exist but no version table and no
concurrency column
+ with settings.engine.begin() as conn:
+ inspector = inspect(conn)
+ if inspector.has_table("alembic_version_edge3"):
+ conn.execute(text("DELETE FROM alembic_version_edge3"))
+ if "concurrency" in {col["name"] for col in
inspector.get_columns("edge_worker")}:
+ from alembic.migration import MigrationContext
+ from alembic.operations import Operations
+
+ mc = MigrationContext.configure(conn, opts={"render_as_batch":
True})
+ ops = Operations(mc)
+ ops.drop_column("edge_worker", "concurrency")
+
+ # initdb() should detect tables exist, stamp to base, then upgrade
+ manager.initdb()
+
+ with settings.engine.connect() as conn:
+ version = conn.execute(text("SELECT version_num FROM
alembic_version_edge3")).scalar()
+ columns = {col["name"] for col in
inspect(conn).get_columns("edge_worker")}
+
+ assert version == "b3c4d5e6f7a8"
+ assert "concurrency" in columns
+
+ def test_migration_adds_concurrency_column(self, session):
+ """Test that upgrading from 3.0.0 actually adds the concurrency
column."""
+ from alembic import command
+ from alembic.migration import MigrationContext
+ from alembic.operations import Operations
+ from sqlalchemy import inspect
+
+ from airflow import settings
+ from airflow.providers.edge3.models.db import EdgeDBManager
+
+ manager = EdgeDBManager(session)
+ config = manager.get_alembic_config()
+
+ # DDL must be committed before alembic opens its own connection — use
engine.begin()
+ # so the DROP is visible to the fresh connection that upgradedb()
creates internally.
+ with settings.engine.begin() as conn:
+ inspector = inspect(conn)
+ if "concurrency" in {col["name"] for col in
inspector.get_columns("edge_worker")}:
+ mc = MigrationContext.configure(conn, opts={"render_as_batch":
True})
+ ops = Operations(mc)
+ ops.drop_column("edge_worker", "concurrency")
+
+ # Stamp to old revision (pre-concurrency) using alembic's own
connection
+ command.stamp(config, "9d34dfc2de06")
+
+ # Run the upgrade — migration 0002 should detect the missing column
and add it
+ manager.upgradedb()
+
+ # Verify with a fresh connection (upgradedb also uses its own
connection)
+ with settings.engine.connect() as conn:
+ inspector = inspect(conn)
+ columns = {col["name"] for col in
inspector.get_columns("edge_worker")}
+
+ assert "concurrency" in columns, "Migration 0002 should have added the
concurrency column"
def test_drop_tables_handles_missing_tables(self, session):
"""Test that drop_tables handles missing tables gracefully."""
diff --git a/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py b/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py
index 57a0bd81021..fc5cd111a1d 100644
--- a/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py
+++ b/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py
@@ -26,7 +26,11 @@ from sqlalchemy import delete, select
from airflow.providers.common.compat.sdk import timezone
from airflow.providers.edge3.cli.worker import EdgeWorker
-from airflow.providers.edge3.models.edge_worker import EdgeWorkerModel,
EdgeWorkerState
+from airflow.providers.edge3.models.edge_worker import (
+ EdgeWorkerModel,
+ EdgeWorkerState,
+ set_worker_concurrency,
+)
from airflow.providers.edge3.worker_api.datamodels import
WorkerQueueUpdateBody, WorkerStateBody
from airflow.providers.edge3.worker_api.routes.worker import (
_assert_version,
@@ -248,6 +252,90 @@ class TestWorkerApiRoutes:
assert worker[0].queues == queues
assert return_queues == ["default", "default2"]
+ def test_set_state_returns_concurrency(self, session: Session, cli_worker:
EdgeWorker):
+ """set_state includes the DB-stored concurrency override in its
response."""
+ rwm = EdgeWorkerModel(
+ worker_name="test2_worker",
+ state=EdgeWorkerState.IDLE,
+ queues=["default"],
+ first_online=timezone.utcnow(),
+ )
+ rwm.concurrency = 16
+ session.add(rwm)
+ session.commit()
+
+ body = WorkerStateBody(
+ state=EdgeWorkerState.RUNNING,
+ jobs_active=0,
+ queues=["default"],
+ sysinfo=cli_worker._get_sysinfo(),
+ )
+ result = set_state("test2_worker", body, session)
+ assert result.concurrency == 16
+
+ def test_set_state_returns_none_concurrency_when_not_overridden(
+ self, session: Session, cli_worker: EdgeWorker
+ ):
+ """set_state returns None for concurrency when no override is set."""
+ rwm = EdgeWorkerModel(
+ worker_name="test2_worker",
+ state=EdgeWorkerState.IDLE,
+ queues=["default"],
+ first_online=timezone.utcnow(),
+ )
+ session.add(rwm)
+ session.commit()
+
+ body = WorkerStateBody(
+ state=EdgeWorkerState.RUNNING,
+ jobs_active=0,
+ queues=["default"],
+ sysinfo=cli_worker._get_sysinfo(),
+ )
+ result = set_state("test2_worker", body, session)
+ assert result.concurrency is None
+
+ def test_set_worker_concurrency(self, session: Session):
+ rwm = EdgeWorkerModel(
+ worker_name="test2_worker",
+ state=EdgeWorkerState.IDLE,
+ queues=["default"],
+ first_online=timezone.utcnow(),
+ )
+ session.add(rwm)
+ session.commit()
+
+ set_worker_concurrency("test2_worker", 16, session=session)
+ session.commit()
+
+ worker = session.scalars(select(EdgeWorkerModel)).one()
+ assert worker.concurrency == 16
+
+ @pytest.mark.parametrize(
+ "offline_state",
+ [
+ pytest.param(EdgeWorkerState.OFFLINE, id="offline"),
+ pytest.param(EdgeWorkerState.OFFLINE_MAINTENANCE,
id="offline-maintenance"),
+ pytest.param(EdgeWorkerState.UNKNOWN, id="unknown"),
+ ],
+ )
+ def test_set_worker_concurrency_rejects_offline(self, session: Session,
offline_state: EdgeWorkerState):
+ rwm = EdgeWorkerModel(
+ worker_name="test2_worker",
+ state=offline_state,
+ queues=["default"],
+ first_online=timezone.utcnow(),
+ )
+ session.add(rwm)
+ session.commit()
+
+ with pytest.raises(TypeError, match="Cannot set concurrency"):
+ set_worker_concurrency("test2_worker", 8, session=session)
+
+ def test_set_worker_concurrency_raises_for_unknown_worker(self, session:
Session):
+ with pytest.raises(ValueError, match="not found"):
+ set_worker_concurrency("nonexistent", 8, session=session)
+
@pytest.mark.parametrize(
("add_queues", "remove_queues", "expected_queues"),
[