kaxil commented on code in PR #62754:
URL: https://github.com/apache/airflow/pull/62754#discussion_r2875621488
##########
providers/common/sql/src/airflow/providers/common/sql/datafusion/format_handlers.py:
##########
@@ -16,98 +16,135 @@
# under the License.
from __future__ import annotations
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
-from airflow.providers.common.sql.config import FormatType
+from airflow.providers.common.compat.sdk import
AirflowOptionalProviderFeatureException
+from airflow.providers.common.sql.config import DataSourceConfig, FormatType
from airflow.providers.common.sql.datafusion.base import FormatHandler
-from airflow.providers.common.sql.datafusion.exceptions import
FileFormatRegistrationException
+from airflow.providers.common.sql.datafusion.exceptions import (
+ FileFormatRegistrationException,
+ IcebergRegistrationException,
+)
if TYPE_CHECKING:
from datafusion import SessionContext
class ParquetFormatHandler(FormatHandler):
- """
- Parquet format handler.
-
- :param options: Additional options for the Parquet format.
-
https://datafusion.apache.org/python/autoapi/datafusion/context/index.html#datafusion.context.SessionContext.register_parquet
- """
-
- def __init__(self, options: dict[str, Any] | None = None):
- self.options = options or {}
+ """Parquet format handler."""
@property
def get_format(self) -> FormatType:
"""Return the format type."""
return FormatType.PARQUET
- def register_data_source_format(self, ctx: SessionContext, table_name:
str, path: str):
+ def register_data_source_format(self, ctx: SessionContext):
"""Register a data source format."""
try:
- ctx.register_parquet(table_name, path, **self.options)
+ ctx.register_parquet(
+ self.datasource_config.table_name,
+ self.datasource_config.uri,
+ **self.datasource_config.options,
+ )
except Exception as e:
raise FileFormatRegistrationException(f"Failed to register Parquet
data source: {e}")
class CsvFormatHandler(FormatHandler):
- """
- CSV format handler.
-
- :param options: Additional options for the CSV format.
-
https://datafusion.apache.org/python/autoapi/datafusion/context/index.html#datafusion.context.SessionContext.register_csv
- """
-
- def __init__(self, options: dict[str, Any] | None = None):
- self.options = options or {}
+ """CSV format handler."""
@property
def get_format(self) -> FormatType:
"""Return the format type."""
return FormatType.CSV
- def register_data_source_format(self, ctx: SessionContext, table_name:
str, path: str):
+ def register_data_source_format(self, ctx: SessionContext):
"""Register a data source format."""
try:
- ctx.register_csv(table_name, path, **self.options)
+ ctx.register_csv(
+ self.datasource_config.table_name,
+ self.datasource_config.uri,
+ **self.datasource_config.options,
+ )
except Exception as e:
raise FileFormatRegistrationException(f"Failed to register csv
data source: {e}")
class AvroFormatHandler(FormatHandler):
- """
- Avro format handler.
-
- :param options: Additional options for the Avro format.
-
https://datafusion.apache.org/python/autoapi/datafusion/context/index.html#datafusion.context.SessionContext.register_avro
- """
-
- def __init__(self, options: dict[str, Any] | None = None):
- self.options = options or {}
+ """Avro format handler."""
@property
def get_format(self) -> FormatType:
"""Return the format type."""
return FormatType.AVRO
- def register_data_source_format(self, ctx: SessionContext, table_name:
str, path: str) -> None:
+ def register_data_source_format(self, ctx: SessionContext) -> None:
"""Register a data source format."""
try:
- ctx.register_avro(table_name, path, **self.options)
+ ctx.register_avro(
+ self.datasource_config.table_name,
+ self.datasource_config.uri,
+ **self.datasource_config.options,
+ )
except Exception as e:
raise FileFormatRegistrationException(f"Failed to register Avro
data source: {e}")
-def get_format_handler(format_type: str, options: dict[str, Any] | None =
None) -> FormatHandler:
+class IcebergFormatHandler(FormatHandler):
+ """
+ Iceberg format handler for DataFusion.
+
+ Loads an Iceberg table from a catalog using ``IcebergHook`` and registers
+ it with a DataFusion ``SessionContext`` via ``register_table_provider``.
+ """
+
+ @property
+ def get_format(self) -> FormatType:
+ """Return the format type."""
+ return FormatType.ICEBERG
+
+ def register_data_source_format(self, ctx: SessionContext) -> None:
+ """Register an Iceberg table with the DataFusion session context."""
+ try:
+ from airflow.providers.apache.iceberg.hooks.iceberg import
IcebergHook
+ except ImportError:
+ raise AirflowOptionalProviderFeatureException(
+ "Iceberg format requires the
apache-airflow-providers-apache-iceberg package. "
+ "Install it with: pip install
'apache-airflow-providers-apache-iceberg'"
+ )
+
+ try:
+ hook = IcebergHook(iceberg_conn_id=self.datasource_config.conn_id)
+ namespace_table =
f"{self.datasource_config.db_name}.{self.datasource_config.table_name}"
 Review Comment:
   Should we require `db_name` for Iceberg, or support a fully-qualified
`table_name` when `db_name` is omitted? Right now the code always builds the
identifier as `{db_name}.{table_name}`, so an omitted `db_name` produces
`None.<table>` and the catalog lookup fails. I think we should either
(1) validate that `db_name` is explicitly set when `format="iceberg"`, or
(2) only apply the prefix when `db_name` is actually provided.
##########
providers/common/sql/tests/unit/common/sql/datafusion/test_format_handlers.py:
##########
@@ -36,47 +40,138 @@ def session_context_mock(self):
return MagicMock()
def test_parquet_handler_success(self, session_context_mock):
- handler = ParquetFormatHandler(options={"key": "value"})
- handler.register_data_source_format(session_context_mock,
"table_name", "path/to/file")
+ datasource_config = DataSourceConfig(
+ table_name="table_name",
+ uri="file://path/to/file",
+ format="parquet",
+ conn_id="conn_id",
+ options={"key": "value"},
+ )
+ handler = ParquetFormatHandler(datasource_config)
+ handler.register_data_source_format(session_context_mock)
session_context_mock.register_parquet.assert_called_once_with(
- "table_name", "path/to/file", key="value"
+ "table_name", "file://path/to/file", key="value"
)
assert handler.get_format == FormatType.PARQUET
def test_parquet_handler_failure(self, session_context_mock):
session_context_mock.register_parquet.side_effect = Exception("Error")
- handler = ParquetFormatHandler()
+ datasource_config = DataSourceConfig(
+ table_name="table_name", uri="file://path/to/file",
format="parquet", conn_id="conn_id"
+ )
+ handler = ParquetFormatHandler(datasource_config)
with pytest.raises(FileFormatRegistrationException, match="Failed to
register Parquet data source"):
- handler.register_data_source_format(session_context_mock,
"table_name", "path/to/file")
+ handler.register_data_source_format(session_context_mock)
def test_csv_handler_success(self, session_context_mock):
- handler = CsvFormatHandler(options={"delimiter": ","})
- handler.register_data_source_format(session_context_mock,
"table_name", "path/to/file")
-
session_context_mock.register_csv.assert_called_once_with("table_name",
"path/to/file", delimiter=",")
+ datasource_config = DataSourceConfig(
+ table_name="table_name",
+ uri="file://path/to/file",
+ format="csv",
+ conn_id="conn_id",
+ options={"delimiter": ","},
+ )
+ handler = CsvFormatHandler(datasource_config)
+ handler.register_data_source_format(session_context_mock)
+ session_context_mock.register_csv.assert_called_once_with(
+ "table_name", "file://path/to/file", delimiter=","
+ )
assert handler.get_format == FormatType.CSV
def test_csv_handler_failure(self, session_context_mock):
session_context_mock.register_csv.side_effect = Exception("Error")
- handler = CsvFormatHandler()
+ datasource_config = DataSourceConfig(
+ table_name="table_name", uri="file://path/to/file", format="csv",
conn_id="conn_id"
+ )
+ handler = CsvFormatHandler(datasource_config)
with pytest.raises(FileFormatRegistrationException, match="Failed to
register csv data source"):
- handler.register_data_source_format(session_context_mock,
"table_name", "path/to/file")
+ handler.register_data_source_format(session_context_mock)
def test_avro_handler_success(self, session_context_mock):
- handler = AvroFormatHandler(options={"key": "value"})
- handler.register_data_source_format(session_context_mock,
"table_name", "path/to/file")
-
session_context_mock.register_avro.assert_called_once_with("table_name",
"path/to/file", key="value")
+ datasource_config = DataSourceConfig(
+ table_name="table_name",
+ uri="file://path/to/file",
+ format="avro",
+ conn_id="conn_id",
+ options={"key": "value"},
+ )
+ handler = AvroFormatHandler(datasource_config)
+ handler.register_data_source_format(session_context_mock)
+ session_context_mock.register_avro.assert_called_once_with(
+ "table_name", "file://path/to/file", key="value"
+ )
assert handler.get_format == FormatType.AVRO
def test_avro_handler_failure(self, session_context_mock):
session_context_mock.register_avro.side_effect = Exception("Error")
- handler = AvroFormatHandler()
+ datasource_config = DataSourceConfig(
+ table_name="table_name", uri="file://path/to/file", format="avro",
conn_id="conn_id"
+ )
+ handler = AvroFormatHandler(datasource_config)
with pytest.raises(FileFormatRegistrationException, match="Failed to
register Avro data source"):
- handler.register_data_source_format(session_context_mock,
"table_name", "path/to/file")
+ handler.register_data_source_format(session_context_mock)
+
+ @patch("airflow.providers.apache.iceberg.hooks.iceberg.IcebergHook")
+ def test_iceberg_handler_success(self, mock_iceberg_hook_cls,
session_context_mock):
+ mock_hook = MagicMock()
+ mock_iceberg_hook_cls.return_value = mock_hook
+ mock_iceberg_table = MagicMock()
+ mock_iceberg_table.io.properties = {}
+ mock_hook.load_table.return_value = mock_iceberg_table
+ datasource_config = DataSourceConfig(
+ table_name="my_table",
+ format="iceberg",
+ conn_id="iceberg_default",
+ db_name="default",
+ )
+ handler = IcebergFormatHandler(datasource_config)
+ handler.register_data_source_format(session_context_mock)
+
+
mock_iceberg_hook_cls.assert_called_once_with(iceberg_conn_id="iceberg_default")
+ mock_hook.load_table.assert_called_once_with("default.my_table")
+
session_context_mock.register_table.assert_called_once_with("my_table",
mock_iceberg_table)
+ assert handler.get_format == FormatType.ICEBERG
+
+ @patch("airflow.providers.apache.iceberg.hooks.iceberg.IcebergHook")
+ def test_iceberg_handler_failure(self, mock_iceberg_hook_cls,
session_context_mock):
+ mock_hook = MagicMock()
+ mock_iceberg_hook_cls.return_value = mock_hook
+ mock_hook.load_table.side_effect = Exception("catalog error")
+ datasource_config = DataSourceConfig(
+ table_name="my_table", format="iceberg",
conn_id="iceberg_default", db_name="default"
+ )
+ handler = IcebergFormatHandler(datasource_config)
+ with pytest.raises(IcebergRegistrationException, match="Failed to
register Iceberg table"):
+ handler.register_data_source_format(session_context_mock)
+
+ def test_iceberg_handler_default_options(self):
+ datasource_config = DataSourceConfig(
+ table_name="my_table", format="iceberg", conn_id="iceberg_default"
+ )
+ handler = IcebergFormatHandler(datasource_config)
+ assert handler.datasource_config.options == {}
+ assert handler.datasource_config.conn_id == "iceberg_default"
+ assert handler.get_format == FormatType.ICEBERG
def test_get_format_handler(self):
- assert isinstance(get_format_handler("parquet"), ParquetFormatHandler)
- assert isinstance(get_format_handler("csv"), CsvFormatHandler)
- assert isinstance(get_format_handler("avro"), AvroFormatHandler)
+ assert isinstance(
+ get_format_handler(
+ DataSourceConfig(table_name="t", format="parquet",
conn_id="c", uri="file://u")
+ ),
+ ParquetFormatHandler,
+ )
+ assert isinstance(
+ get_format_handler(DataSourceConfig(table_name="t", format="csv",
conn_id="c", uri="file://u")),
+ CsvFormatHandler,
+ )
+ assert isinstance(
+ get_format_handler(DataSourceConfig(table_name="t", format="avro",
conn_id="c", uri="file://u")),
+ AvroFormatHandler,
+ )
+ assert isinstance(
+ get_format_handler(DataSourceConfig(table_name="t",
format="iceberg", conn_id="iceberg_default")),
+ IcebergFormatHandler,
+ )
- with pytest.raises(ValueError, match="Unsupported format"):
- get_format_handler("invalid")
+ with pytest.raises(ValueError, match="Unsupported storage type"):
 Review Comment:
   I think this assertion is now masking the behavior we actually want to test
in `get_format_handler`. Because `uri` is empty here, `DataSourceConfig` itself
raises a `ValueError` with "Unsupported storage type" before
`get_format_handler` ever executes, so the test never exercises the handler's
format validation. Could we pass a valid URI (for example `file://u`) and
assert on "Unsupported format" instead?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]