This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1af5b93157a Add 'name' and 'group' to public Asset class (#42812)
1af5b93157a is described below

commit 1af5b93157ab8310e8b61c7ef7f923c03a69aa2e
Author: Tzu-ping Chung <uranu...@gmail.com>
AuthorDate: Wed Oct 16 10:23:29 2024 +0800

    Add 'name' and 'group' to public Asset class (#42812)
---
 airflow/assets/__init__.py                     | 73 ++++++++++++++++++++------
 airflow/models/asset.py                        | 20 ++++---
 tests/assets/{tests_asset.py => test_asset.py} | 41 ++++++++++++++-
 tests/models/test_dag.py                       |  6 +--
 tests/serialization/test_serialized_objects.py |  2 +-
 5 files changed, 115 insertions(+), 27 deletions(-)

diff --git a/airflow/assets/__init__.py b/airflow/assets/__init__.py
index e11b9c49df3..15805418472 100644
--- a/airflow/assets/__init__.py
+++ b/airflow/assets/__init__.py
@@ -21,7 +21,7 @@ import logging
 import os
 import urllib.parse
 import warnings
-from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Iterator, cast
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Iterator, cast, overload
 
 import attr
 from sqlalchemy import select
@@ -74,12 +74,6 @@ def _sanitize_uri(uri: str) -> str:
     This checks for URI validity, and normalizes the URI if needed. A fully
     normalized URI is returned.
     """
-    if not uri:
-        raise ValueError("Asset URI cannot be empty")
-    if uri.isspace():
-        raise ValueError("Asset URI cannot be just whitespace")
-    if not uri.isascii():
-        raise ValueError("Asset URI must only consist of ASCII characters")
     parsed = urllib.parse.urlsplit(uri)
     if not parsed.scheme and not parsed.netloc:  # Does not look like a URI.
         return uri
@@ -126,6 +120,24 @@ def _sanitize_uri(uri: str) -> str:
     return urllib.parse.urlunsplit(parsed)
 
 
+def _validate_identifier(instance, attribute, value):
+    if not isinstance(value, str):
+        raise ValueError(f"{type(instance).__name__} {attribute.name} must be a string")
+    if len(value) > 1500:
+        raise ValueError(f"{type(instance).__name__} {attribute.name} cannot exceed 1500 characters")
+    if value.isspace():
+        raise ValueError(f"{type(instance).__name__} {attribute.name} cannot be just whitespace")
+    if not value.isascii():
+        raise ValueError(f"{type(instance).__name__} {attribute.name} must only consist of ASCII characters")
+    return value
+
+
+def _validate_non_empty_identifier(instance, attribute, value):
+    if not _validate_identifier(instance, attribute, value):
+        raise ValueError(f"{type(instance).__name__} {attribute.name} cannot be empty")
+    return value
+
+
 def extract_event_key(value: str | Asset | AssetAlias) -> str:
     """
     Extract the key of an inlet or an outlet event.
@@ -157,7 +169,7 @@ def expand_alias_to_assets(alias: str | AssetAlias, *, session: Session = NEW_SE
         select(AssetAliasModel).where(AssetAliasModel.name == alias_name).limit(1)
     )
     if asset_alias_obj:
-        return [Asset(uri=asset.uri, extra=asset.extra) for asset in asset_alias_obj.datasets]
+        return [asset.to_public() for asset in asset_alias_obj.datasets]
     return []
 
 
@@ -214,7 +226,7 @@ class BaseAsset:
 class AssetAlias(BaseAsset):
     """A represeation of asset alias which is used to create asset during the runtime."""
 
-    name: str
+    name: str = attr.field(validator=_validate_non_empty_identifier)
 
     def iter_assets(self) -> Iterator[tuple[str, Asset]]:
         return iter(())
@@ -256,18 +268,49 @@ def _set_extra_default(extra: dict | None) -> dict:
     return extra
 
 
-@attr.define(unsafe_hash=False)
+@attr.define(init=False, unsafe_hash=False)
 class Asset(os.PathLike, BaseAsset):
     """A representation of data dependencies between workflows."""
 
-    uri: str = attr.field(
-        converter=_sanitize_uri,
-        validator=[attr.validators.min_len(1), attr.validators.max_len(1500)],
-    )
-    extra: dict[str, Any] = attr.field(factory=dict, converter=_set_extra_default)
+    name: str = attr.field()
+    uri: str = attr.field()
+    group: str = attr.field()
+    extra: dict[str, Any] = attr.field()
 
     __version__: ClassVar[int] = 1
 
+    @overload
+    def __init__(self, name: str, uri: str, *, group: str = "", extra: dict | None = None) -> None:
+        """Canonical; both name and uri are provided."""
+
+    @overload
+    def __init__(self, name: str, *, group: str = "", extra: dict | None = None) -> None:
+        """It's possible to only provide the name, either by keyword or as the only positional argument."""
+
+    @overload
+    def __init__(self, *, uri: str, group: str = "", extra: dict | None = None) -> None:
+        """It's possible to only provide the URI as a keyword argument."""
+
+    def __init__(
+        self,
+        name: str | None = None,
+        uri: str | None = None,
+        *,
+        group: str = "",
+        extra: dict | None = None,
+    ) -> None:
+        if name is None and uri is None:
+            raise TypeError("Asset() requires either 'name' or 'uri'")
+        elif name is None:
+            name = uri
+        elif uri is None:
+            uri = name
+        fields = attr.fields_dict(Asset)
+        self.name = _validate_non_empty_identifier(self, fields["name"], name)
+        self.uri = _sanitize_uri(_validate_non_empty_identifier(self, fields["uri"], uri))
+        self.group = _validate_identifier(self, fields["group"], group)
+        self.extra = _set_extra_default(extra)
+
     def __fspath__(self) -> str:
         return self.uri
 
diff --git a/airflow/models/asset.py b/airflow/models/asset.py
index d5ca0ea513f..b565c9a100e 100644
--- a/airflow/models/asset.py
+++ b/airflow/models/asset.py
@@ -205,17 +205,23 @@ class AssetModel(Base):
 
     @classmethod
     def from_public(cls, obj: Asset) -> AssetModel:
-        return cls(uri=obj.uri, extra=obj.extra)
-
-    def __init__(self, uri: str, **kwargs):
+        return cls(name=obj.name, uri=obj.uri, group=obj.group, extra=obj.extra)
+
+    def __init__(self, name: str = "", uri: str = "", **kwargs):
+        if not name and not uri:
+            raise TypeError("must provide either 'name' or 'uri'")
+        elif not name:
+            name = uri
+        elif not uri:
+            uri = name
         try:
             uri.encode("ascii")
         except UnicodeEncodeError:
-            raise ValueError("URI must be ascii")
+            raise ValueError("URI must be ascii") from None
         parsed = urlsplit(uri)
         if parsed.scheme and parsed.scheme.lower() == "airflow":
-            raise ValueError("Scheme `airflow` is reserved.")
-        super().__init__(name=uri, uri=uri, **kwargs)
+            raise ValueError("Scheme 'airflow' is reserved.")
+        super().__init__(name=name, uri=uri, **kwargs)
 
     def __eq__(self, other):
         if isinstance(other, (self.__class__, Asset)):
@@ -229,7 +235,7 @@ class AssetModel(Base):
         return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})"
 
     def to_public(self) -> Asset:
-        return Asset(uri=self.uri, extra=self.extra)
+        return Asset(name=self.name, uri=self.uri, group=self.group, extra=self.extra)
 
 
 class AssetActive(Base):
diff --git a/tests/assets/tests_asset.py b/tests/assets/test_asset.py
similarity index 95%
rename from tests/assets/tests_asset.py
rename to tests/assets/test_asset.py
index 0bcfb83e88a..4d3466b90c1 100644
--- a/tests/assets/tests_asset.py
+++ b/tests/assets/test_asset.py
@@ -50,12 +50,26 @@ def clear_assets():
     clear_db_assets()
 
 
+@pytest.mark.parametrize(
+    ["name"],
+    [
+        pytest.param("", id="empty"),
+        pytest.param("\n\t", id="whitespace"),
+        pytest.param("a" * 1501, id="too_long"),
+        pytest.param("😊", id="non-ascii"),
+    ],
+)
+def test_invalid_names(name):
+    with pytest.raises(ValueError):
+        Asset(name=name)
+
+
 @pytest.mark.parametrize(
     ["uri"],
     [
         pytest.param("", id="empty"),
         pytest.param("\n\t", id="whitespace"),
-        pytest.param("a" * 3001, id="too_long"),
+        pytest.param("a" * 1501, id="too_long"),
         pytest.param("airflow://xcom/dag/task", id="reserved_scheme"),
         pytest.param("😊", id="non-ascii"),
     ],
@@ -65,6 +79,31 @@ def test_invalid_uris(uri):
         Asset(uri=uri)
 
 
+def test_only_name():
+    asset = Asset(name="foobar")
+    assert asset.name == "foobar"
+    assert asset.uri == "foobar"
+
+
+def test_only_uri():
+    asset = Asset(uri="s3://bucket/key/path")
+    assert asset.name == "s3://bucket/key/path"
+    assert asset.uri == "s3://bucket/key/path"
+
+
+@pytest.mark.parametrize("arg", ["foobar", "s3://bucket/key/path"])
+def test_only_posarg(arg):
+    asset = Asset(arg)
+    assert asset.name == arg
+    assert asset.uri == arg
+
+
+def test_both_name_and_uri():
+    asset = Asset("foobar", "s3://bucket/key/path")
+    assert asset.name == "foobar"
+    assert asset.uri == "s3://bucket/key/path"
+
+
 @pytest.mark.parametrize(
     "uri, normalized",
     [
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index b439487b016..e6f8042253e 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -2715,10 +2715,10 @@ class TestDagModel:
         dag = DAG(
             dag_id="test_dag_asset_expression",
             schedule=AssetAny(
-                Asset("s3://dag1/output_1.txt", {"hi": "bye"}),
+                Asset("s3://dag1/output_1.txt", extra={"hi": "bye"}),
                 AssetAll(
-                    Asset("s3://dag2/output_1.txt", {"hi": "bye"}),
-                    Asset("s3://dag3/output_3.txt", {"hi": "bye"}),
+                    Asset("s3://dag2/output_1.txt", extra={"hi": "bye"}),
+                    Asset("s3://dag3/output_3.txt", extra={"hi": "bye"}),
                 ),
                 AssetAlias(name="test_name"),
             ),
diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py
index 5d35278d89b..56a31d4d38b 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -327,7 +327,7 @@ sample_objects = {
         id=1, filename="test_file", elasticsearch_id="test_id", created_at=datetime.now()
     ),
     DagTagPydantic: DagTag(),
-    AssetPydantic: Asset("uri", {}),
+    AssetPydantic: Asset("uri", extra={}),
     AssetEventPydantic: AssetEvent(),
 }
 

Reply via email to