This is an automated email from the ASF dual-hosted git repository.

turaga pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e0b55c21e27 Centralized runtime control of Edge Worker concurrency in distributed deployments (#62896)
e0b55c21e27 is described below

commit e0b55c21e27edd0e5a5b7b90f72be5373d01da90
Author: Dheeraj Turaga <[email protected]>
AuthorDate: Sat Mar 7 12:32:00 2026 -0600

    Centralized runtime control of Edge Worker concurrency in distributed deployments (#62896)
---
 .../src/airflow/providers/edge3/cli/definition.py  | 15 ++++
 .../airflow/providers/edge3/cli/edge_command.py    | 24 ++++++
 .../src/airflow/providers/edge3/cli/worker.py      |  7 ++
 .../0002_3_2_0_add_concurrency_to_edge_worker.py   | 49 ++++++++++++
 .../edge3/src/airflow/providers/edge3/models/db.py | 28 +++++++
 .../airflow/providers/edge3/models/edge_worker.py  | 21 +++++
 .../providers/edge3/worker_api/datamodels.py       |  7 ++
 .../providers/edge3/worker_api/routes/worker.py    |  5 +-
 .../edge3/worker_api/v2-edge-generated.yaml        |  7 ++
 .../edge3/tests/unit/edge3/cli/test_definition.py  | 18 ++++-
 .../edge3/tests/unit/edge3/cli/test_worker.py      | 27 +++++++
 providers/edge3/tests/unit/edge3/models/test_db.py | 75 +++++++++++++++++-
 .../unit/edge3/worker_api/routes/test_worker.py    | 90 +++++++++++++++++++++-
 13 files changed, 367 insertions(+), 6 deletions(-)

diff --git a/providers/edge3/src/airflow/providers/edge3/cli/definition.py b/providers/edge3/src/airflow/providers/edge3/cli/definition.py
index 5cd611cc3a4..a1819e9a9a0 100644
--- a/providers/edge3/src/airflow/providers/edge3/cli/definition.py
+++ b/providers/edge3/src/airflow/providers/edge3/cli/definition.py
@@ -59,6 +59,12 @@ ARG_QUEUES_MANAGE = Arg(
     help="Comma delimited list of queues to add or remove.",
     required=True,
 )
+ARG_CONCURRENCY_REQUIRED = Arg(
+    ("-c", "--concurrency"),
+    type=int,
+    help="The number of worker processes. Must be a positive integer.",
+    required=True,
+)
 ARG_WAIT_MAINT = Arg(
     ("-w", "--wait"),
     default=False,
@@ -229,6 +235,15 @@ EDGE_COMMANDS: list[ActionCommand] = [
         
func=lazy_load_command("airflow.providers.edge3.cli.edge_command.shutdown_all_workers"),
         args=(ARG_YES,),
     ),
+    ActionCommand(
+        name="set-worker-concurrency",
+        help="Set the concurrency of a remote edge worker.",
+        
func=lazy_load_command("airflow.providers.edge3.cli.edge_command.set_remote_worker_concurrency"),
+        args=(
+            ARG_REQUIRED_EDGE_HOSTNAME,
+            ARG_CONCURRENCY_REQUIRED,
+        ),
+    ),
 ]
 
 
diff --git a/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py b/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py
index 864d6851ace..c18fbdc7f90 100644
--- a/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py
+++ b/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py
@@ -427,3 +427,27 @@ def remove_worker_queues(args) -> None:
     except TypeError as e:
         logger.error(str(e))
         raise SystemExit
+
+
+@cli_utils.action_cli(check_db=False)
+@providers_configuration_loaded
+def set_remote_worker_concurrency(args) -> None:
+    """Set the concurrency of a remote edge worker."""
+    _check_valid_db_connection()
+    _check_if_registered_edge_host(hostname=args.edge_hostname)
+    from airflow.providers.edge3.models.edge_worker import 
set_worker_concurrency
+
+    if args.concurrency <= 0:
+        raise SystemExit("Error: Concurrency must be a positive integer.")
+
+    try:
+        set_worker_concurrency(args.edge_hostname, args.concurrency)
+        logger.info(
+            "Concurrency set to %d for Edge Worker host %s by %s.",
+            args.concurrency,
+            args.edge_hostname,
+            getuser(),
+        )
+    except TypeError as e:
+        logger.error(str(e))
+        raise SystemExit
diff --git a/providers/edge3/src/airflow/providers/edge3/cli/worker.py b/providers/edge3/src/airflow/providers/edge3/cli/worker.py
index c4aa1d735c5..6ac43e388f8 100644
--- a/providers/edge3/src/airflow/providers/edge3/cli/worker.py
+++ b/providers/edge3/src/airflow/providers/edge3/cli/worker.py
@@ -401,6 +401,13 @@ class EdgeWorker:
                 new_maintenance_comments,
             )
             self.queues = worker_info.queues
+            if worker_info.concurrency is not None and worker_info.concurrency 
!= self.concurrency:
+                logger.info(
+                    "Concurrency updated from %d to %d by remote request.",
+                    self.concurrency,
+                    worker_info.concurrency,
+                )
+                self.concurrency = worker_info.concurrency
             if worker_info.state == EdgeWorkerState.MAINTENANCE_REQUEST:
                 logger.info("Maintenance mode requested!")
                 self.maintenance_mode = True
diff --git a/providers/edge3/src/airflow/providers/edge3/migrations/versions/0002_3_2_0_add_concurrency_to_edge_worker.py b/providers/edge3/src/airflow/providers/edge3/migrations/versions/0002_3_2_0_add_concurrency_to_edge_worker.py
new file mode 100644
index 00000000000..edf6cc7f39c
--- /dev/null
+++ b/providers/edge3/src/airflow/providers/edge3/migrations/versions/0002_3_2_0_add_concurrency_to_edge_worker.py
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Add concurrency column to edge_worker table.
+
+Revision ID: b3c4d5e6f7a8
+Revises: 9d34dfc2de06
+Create Date: 2026-03-04 00:00:00.000000
+"""
+
+from __future__ import annotations
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "b3c4d5e6f7a8"
+down_revision = "9d34dfc2de06"
+branch_labels = None
+depends_on = None
+edge3_version = "3.2.0"
+
+
+def upgrade() -> None:
+    bind = op.get_bind()
+    inspector = sa.inspect(bind)
+    existing_columns = {col["name"] for col in 
inspector.get_columns("edge_worker")}
+    if "concurrency" not in existing_columns:
+        op.add_column("edge_worker", sa.Column("concurrency", sa.Integer(), 
nullable=True))
+
+
+def downgrade() -> None:
+    op.drop_column("edge_worker", "concurrency")
diff --git a/providers/edge3/src/airflow/providers/edge3/models/db.py b/providers/edge3/src/airflow/providers/edge3/models/db.py
index 1c98cb40235..e98b3cc24b9 100644
--- a/providers/edge3/src/airflow/providers/edge3/models/db.py
+++ b/providers/edge3/src/airflow/providers/edge3/models/db.py
@@ -31,6 +31,7 @@ PACKAGE_DIR = Path(__file__).parents[1]
 
 _REVISION_HEADS_MAP: dict[str, str] = {
     "3.0.0": "9d34dfc2de06",
+    "3.2.0": "b3c4d5e6f7a8",
 }
 
 
@@ -45,6 +46,33 @@ class EdgeDBManager(BaseDBManager):
     supports_table_dropping = True
     revision_heads_map = _REVISION_HEADS_MAP
 
+    def initdb(self):
+        """
+        Initialize the database, handling pre-alembic installations.
+
+        If the edge3 tables already exist but the alembic version table does 
not
+        (e.g. created via create_all before the migration chain was 
introduced),
+        stamp to the first revision and run the incremental upgrade so every
+        migration is applied rather than jumping straight to head.
+        """
+        db_exists = self.get_current_revision()
+        if db_exists:
+            self.upgradedb()
+        else:
+            from airflow import settings
+
+            existing_tables = set(inspect(settings.engine).get_table_names())
+            if any(table in existing_tables for table in self.metadata.tables):
+                script = self.get_script_object()
+                base_revision = next(r.revision for r in 
script.walk_revisions() if r.down_revision is None)
+                config = self.get_alembic_config()
+                from alembic import command
+
+                command.stamp(config, base_revision)
+                self.upgradedb()
+            else:
+                self.create_db_from_orm()
+
     def drop_tables(self, connection):
         """Drop only edge3 tables in reverse dependency order."""
         if not self.supports_table_dropping:
diff --git a/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py b/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py
index e4b0698e7e0..5b037f2903c 100644
--- a/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py
+++ b/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py
@@ -103,6 +103,7 @@ class EdgeWorkerModel(Base, LoggingMixin):
     jobs_success: Mapped[int] = mapped_column(Integer, default=0)
     jobs_failed: Mapped[int] = mapped_column(Integer, default=0)
     sysinfo: Mapped[str | None] = mapped_column(String(256))
+    concurrency: Mapped[int | None] = mapped_column(Integer, nullable=True)
 
     def __init__(
         self,
@@ -392,3 +393,23 @@ def remove_worker_queues(worker_name: str, queues: list[str], session: Session =
         logger.error(error_message)
         raise TypeError(error_message)
     worker.remove_queues(queues)
+
+
+@provide_session
+def set_worker_concurrency(worker_name: str, concurrency: int, session: 
Session = NEW_SESSION) -> None:
+    """Set the concurrency of an edge worker."""
+    query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == 
worker_name)
+    worker: EdgeWorkerModel | None = session.scalar(query)
+    if not worker:
+        raise ValueError(f"Edge Worker {worker_name} not found in list of 
registered workers")
+    if worker.state in (
+        EdgeWorkerState.OFFLINE,
+        EdgeWorkerState.OFFLINE_MAINTENANCE,
+        EdgeWorkerState.UNKNOWN,
+    ):
+        error_message = (
+            f"Cannot set concurrency for edge worker {worker_name} as it is in 
{worker.state} state!"
+        )
+        logger.error(error_message)
+        raise TypeError(error_message)
+    worker.concurrency = concurrency
diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py b/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py
index c25ff53a934..fc780b87662 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py
@@ -200,3 +200,10 @@ class WorkerSetStateReturn(BaseModel):
         str | None,
         Field(description="Comments about the maintenance state of the 
worker."),
     ] = None
+    concurrency: Annotated[
+        int | None,
+        Field(
+            description="Desired concurrency for the worker set by an 
administrator. "
+            "None means no remote override; the worker uses its startup 
value.",
+        ),
+    ] = None
diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py
index 8e3dce56e15..34368c98c4a 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py
@@ -238,7 +238,10 @@ def set_state(
     )
     _assert_version(body.sysinfo)  # Exception only after worker state is in 
the DB
     return WorkerSetStateReturn(
-        state=worker.state, queues=worker.queues, 
maintenance_comments=worker.maintenance_comment
+        state=worker.state,
+        queues=worker.queues,
+        maintenance_comments=worker.maintenance_comment,
+        concurrency=worker.concurrency,
     )
 
 
diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
index 94feded9ff9..a5daeacb86f 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
@@ -1413,6 +1413,13 @@ components:
           - type: 'null'
           title: Maintenance Comments
           description: Comments about the maintenance state of the worker.
+        concurrency:
+          anyOf:
+          - type: integer
+          - type: 'null'
+          title: Concurrency
+          description: Desired concurrency for the worker set by an 
administrator.
+            None means no remote override; the worker uses its startup value.
       type: object
       required:
       - state
diff --git a/providers/edge3/tests/unit/edge3/cli/test_definition.py b/providers/edge3/tests/unit/edge3/cli/test_definition.py
index a99bbc7565e..0c1dac8e6d3 100644
--- a/providers/edge3/tests/unit/edge3/cli/test_definition.py
+++ b/providers/edge3/tests/unit/edge3/cli/test_definition.py
@@ -53,8 +53,8 @@ class TestEdgeCliDefinition:
         assert len(commands) == 1
 
     def test_edge_commands_count(self):
-        """Test that EDGE_COMMANDS contains all 13 subcommands."""
-        assert len(EDGE_COMMANDS) == 13
+        """Test that EDGE_COMMANDS contains all 14 subcommands."""
+        assert len(EDGE_COMMANDS) == 14
 
     @pytest.mark.parametrize(
         "command",
@@ -234,3 +234,17 @@ class TestEdgeCliDefinition:
         params = ["edge", "shutdown-all-workers", "--yes"]
         args = self.arg_parser.parse_args(params)
         assert args.yes is True
+
+    def test_set_worker_concurrency_args(self):
+        """Test set-worker-concurrency command with required arguments."""
+        params = [
+            "edge",
+            "set-worker-concurrency",
+            "--edge-hostname",
+            "remote-worker-1",
+            "--concurrency",
+            "16",
+        ]
+        args = self.arg_parser.parse_args(params)
+        assert args.edge_hostname == "remote-worker-1"
+        assert args.concurrency == 16
diff --git a/providers/edge3/tests/unit/edge3/cli/test_worker.py b/providers/edge3/tests/unit/edge3/cli/test_worker.py
index b8fa906dbd3..a9c16d10822 100644
--- a/providers/edge3/tests/unit/edge3/cli/test_worker.py
+++ b/providers/edge3/tests/unit/edge3/cli/test_worker.py
@@ -392,6 +392,33 @@ class TestEdgeWorker:
         assert "queue1" in (queue_list)
         assert "queue2" in (queue_list)
 
+    @patch("airflow.providers.edge3.cli.worker.worker_set_state")
+    async def test_heartbeat_adopts_remote_concurrency(self, mock_set_state, 
worker_with_job: EdgeWorker):
+        EdgeWorker.jobs = []
+        EdgeWorker.drain = False
+        EdgeWorker.maintenance_mode = False
+        mock_set_state.return_value = WorkerSetStateReturn(
+            state=EdgeWorkerState.IDLE, queues=None, concurrency=32
+        )
+        with conf_vars({("edge", "api_url"): 
"https://invalid-api-test-endpoint"}):
+            await worker_with_job.heartbeat()
+        assert worker_with_job.concurrency == 32
+
+    @patch("airflow.providers.edge3.cli.worker.worker_set_state")
+    async def test_heartbeat_no_concurrency_override_keeps_startup_value(
+        self, mock_set_state, worker_with_job: EdgeWorker
+    ):
+        EdgeWorker.jobs = []
+        EdgeWorker.drain = False
+        EdgeWorker.maintenance_mode = False
+        original_concurrency = worker_with_job.concurrency
+        mock_set_state.return_value = WorkerSetStateReturn(
+            state=EdgeWorkerState.IDLE, queues=None, concurrency=None
+        )
+        with conf_vars({("edge", "api_url"): 
"https://invalid-api-test-endpoint"}):
+            await worker_with_job.heartbeat()
+        assert worker_with_job.concurrency == original_concurrency
+
     @patch("airflow.providers.edge3.cli.worker.worker_set_state")
     async def test_version_mismatch(self, mock_set_state, worker_with_job):
         mock_set_state.side_effect = EdgeWorkerVersionException("")
diff --git a/providers/edge3/tests/unit/edge3/models/test_db.py b/providers/edge3/tests/unit/edge3/models/test_db.py
index 4424c4b1e59..3bef0a1e563 100644
--- a/providers/edge3/tests/unit/edge3/models/test_db.py
+++ b/providers/edge3/tests/unit/edge3/models/test_db.py
@@ -168,11 +168,13 @@ class TestEdgeDBManager:
         __import__("airflow.providers.edge3.models.db", 
fromlist=["EdgeDBManager"]).EdgeDBManager,
         "get_current_revision",
     )
-    def test_initdb_new_db(self, mock_get_rev, mock_create, mock_upgrade, 
session):
+    @mock.patch("airflow.providers.edge3.models.db.inspect")
+    def test_initdb_new_db(self, mock_inspect, mock_get_rev, mock_create, 
mock_upgrade, session):
         """Test that initdb calls create_db_from_orm for new databases."""
         from airflow.providers.edge3.models.db import EdgeDBManager
 
         mock_get_rev.return_value = None
+        mock_inspect.return_value.get_table_names.return_value = []  # no 
tables exist
 
         manager = EdgeDBManager(session)
         manager.initdb()
@@ -205,11 +207,80 @@ class TestEdgeDBManager:
         mock_create.assert_not_called()
 
     def test_revision_heads_map_populated(self):
-        """Test that _REVISION_HEADS_MAP is populated with the initial 
migration."""
+        """Test that _REVISION_HEADS_MAP is populated with all known 
migrations."""
         from airflow.providers.edge3.models.db import _REVISION_HEADS_MAP
 
         assert "3.0.0" in _REVISION_HEADS_MAP
         assert _REVISION_HEADS_MAP["3.0.0"] == "9d34dfc2de06"
+        assert "3.2.0" in _REVISION_HEADS_MAP
+        assert _REVISION_HEADS_MAP["3.2.0"] == "b3c4d5e6f7a8"
+
+    def 
test_initdb_stamps_and_upgrades_when_tables_exist_without_version(self, 
session):
+        """Test that initdb runs incremental migrations when tables exist but 
alembic version table does not."""
+        from sqlalchemy import inspect, text
+
+        from airflow import settings
+        from airflow.providers.edge3.models.db import EdgeDBManager
+
+        manager = EdgeDBManager(session)
+
+        # Simulate pre-alembic state: tables exist but no version table and no 
concurrency column
+        with settings.engine.begin() as conn:
+            inspector = inspect(conn)
+            if inspector.has_table("alembic_version_edge3"):
+                conn.execute(text("DELETE FROM alembic_version_edge3"))
+            if "concurrency" in {col["name"] for col in 
inspector.get_columns("edge_worker")}:
+                from alembic.migration import MigrationContext
+                from alembic.operations import Operations
+
+                mc = MigrationContext.configure(conn, opts={"render_as_batch": 
True})
+                ops = Operations(mc)
+                ops.drop_column("edge_worker", "concurrency")
+
+        # initdb() should detect tables exist, stamp to base, then upgrade
+        manager.initdb()
+
+        with settings.engine.connect() as conn:
+            version = conn.execute(text("SELECT version_num FROM 
alembic_version_edge3")).scalar()
+            columns = {col["name"] for col in 
inspect(conn).get_columns("edge_worker")}
+
+        assert version == "b3c4d5e6f7a8"
+        assert "concurrency" in columns
+
+    def test_migration_adds_concurrency_column(self, session):
+        """Test that upgrading from 3.0.0 actually adds the concurrency 
column."""
+        from alembic import command
+        from alembic.migration import MigrationContext
+        from alembic.operations import Operations
+        from sqlalchemy import inspect
+
+        from airflow import settings
+        from airflow.providers.edge3.models.db import EdgeDBManager
+
+        manager = EdgeDBManager(session)
+        config = manager.get_alembic_config()
+
+        # DDL must be committed before alembic opens its own connection — use 
engine.begin()
+        # so the DROP is visible to the fresh connection that upgradedb() 
creates internally.
+        with settings.engine.begin() as conn:
+            inspector = inspect(conn)
+            if "concurrency" in {col["name"] for col in 
inspector.get_columns("edge_worker")}:
+                mc = MigrationContext.configure(conn, opts={"render_as_batch": 
True})
+                ops = Operations(mc)
+                ops.drop_column("edge_worker", "concurrency")
+
+        # Stamp to old revision (pre-concurrency) using alembic's own 
connection
+        command.stamp(config, "9d34dfc2de06")
+
+        # Run the upgrade — migration 0002 should detect the missing column 
and add it
+        manager.upgradedb()
+
+        # Verify with a fresh connection (upgradedb also uses its own 
connection)
+        with settings.engine.connect() as conn:
+            inspector = inspect(conn)
+            columns = {col["name"] for col in 
inspector.get_columns("edge_worker")}
+
+        assert "concurrency" in columns, "Migration 0002 should have added the 
concurrency column"
 
     def test_drop_tables_handles_missing_tables(self, session):
         """Test that drop_tables handles missing tables gracefully."""
diff --git a/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py b/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py
index 57a0bd81021..fc5cd111a1d 100644
--- a/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py
+++ b/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py
@@ -26,7 +26,11 @@ from sqlalchemy import delete, select
 
 from airflow.providers.common.compat.sdk import timezone
 from airflow.providers.edge3.cli.worker import EdgeWorker
-from airflow.providers.edge3.models.edge_worker import EdgeWorkerModel, 
EdgeWorkerState
+from airflow.providers.edge3.models.edge_worker import (
+    EdgeWorkerModel,
+    EdgeWorkerState,
+    set_worker_concurrency,
+)
 from airflow.providers.edge3.worker_api.datamodels import 
WorkerQueueUpdateBody, WorkerStateBody
 from airflow.providers.edge3.worker_api.routes.worker import (
     _assert_version,
@@ -248,6 +252,90 @@ class TestWorkerApiRoutes:
         assert worker[0].queues == queues
         assert return_queues == ["default", "default2"]
 
+    def test_set_state_returns_concurrency(self, session: Session, cli_worker: 
EdgeWorker):
+        """set_state includes the DB-stored concurrency override in its 
response."""
+        rwm = EdgeWorkerModel(
+            worker_name="test2_worker",
+            state=EdgeWorkerState.IDLE,
+            queues=["default"],
+            first_online=timezone.utcnow(),
+        )
+        rwm.concurrency = 16
+        session.add(rwm)
+        session.commit()
+
+        body = WorkerStateBody(
+            state=EdgeWorkerState.RUNNING,
+            jobs_active=0,
+            queues=["default"],
+            sysinfo=cli_worker._get_sysinfo(),
+        )
+        result = set_state("test2_worker", body, session)
+        assert result.concurrency == 16
+
+    def test_set_state_returns_none_concurrency_when_not_overridden(
+        self, session: Session, cli_worker: EdgeWorker
+    ):
+        """set_state returns None for concurrency when no override is set."""
+        rwm = EdgeWorkerModel(
+            worker_name="test2_worker",
+            state=EdgeWorkerState.IDLE,
+            queues=["default"],
+            first_online=timezone.utcnow(),
+        )
+        session.add(rwm)
+        session.commit()
+
+        body = WorkerStateBody(
+            state=EdgeWorkerState.RUNNING,
+            jobs_active=0,
+            queues=["default"],
+            sysinfo=cli_worker._get_sysinfo(),
+        )
+        result = set_state("test2_worker", body, session)
+        assert result.concurrency is None
+
+    def test_set_worker_concurrency(self, session: Session):
+        rwm = EdgeWorkerModel(
+            worker_name="test2_worker",
+            state=EdgeWorkerState.IDLE,
+            queues=["default"],
+            first_online=timezone.utcnow(),
+        )
+        session.add(rwm)
+        session.commit()
+
+        set_worker_concurrency("test2_worker", 16, session=session)
+        session.commit()
+
+        worker = session.scalars(select(EdgeWorkerModel)).one()
+        assert worker.concurrency == 16
+
+    @pytest.mark.parametrize(
+        "offline_state",
+        [
+            pytest.param(EdgeWorkerState.OFFLINE, id="offline"),
+            pytest.param(EdgeWorkerState.OFFLINE_MAINTENANCE, 
id="offline-maintenance"),
+            pytest.param(EdgeWorkerState.UNKNOWN, id="unknown"),
+        ],
+    )
+    def test_set_worker_concurrency_rejects_offline(self, session: Session, 
offline_state: EdgeWorkerState):
+        rwm = EdgeWorkerModel(
+            worker_name="test2_worker",
+            state=offline_state,
+            queues=["default"],
+            first_online=timezone.utcnow(),
+        )
+        session.add(rwm)
+        session.commit()
+
+        with pytest.raises(TypeError, match="Cannot set concurrency"):
+            set_worker_concurrency("test2_worker", 8, session=session)
+
+    def test_set_worker_concurrency_raises_for_unknown_worker(self, session: 
Session):
+        with pytest.raises(ValueError, match="not found"):
+            set_worker_concurrency("nonexistent", 8, session=session)
+
     @pytest.mark.parametrize(
         ("add_queues", "remove_queues", "expected_queues"),
         [

Reply via email to