This is an automated email from the ASF dual-hosted git repository.
timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 1da08a25 Implement all CSV reader options (#1361)
1da08a25 is described below
commit 1da08a259d6844c8dff8c4c8c9ba1874c86f8894
Author: Tim Saucer <[email protected]>
AuthorDate: Thu Feb 5 16:48:04 2026 -0500
Implement all CSV reader options (#1361)
* Implement all CSV options with a builder pattern
* Remove unused clippy warning
* Add additional tests for csv read options
---
docs/source/user-guide/io/csv.rst | 22 +++
examples/csv-read-options.py | 96 +++++++++++++
python/datafusion/__init__.py | 3 +
python/datafusion/context.py | 118 ++++++++++++----
python/datafusion/io.py | 11 +-
python/datafusion/options.py | 284 ++++++++++++++++++++++++++++++++++++++
python/tests/test_context.py | 162 +++++++++++++++++++++-
python/tests/test_sql.py | 2 +-
src/context.rs | 76 ++--------
src/lib.rs | 5 +
src/options.rs | 142 +++++++++++++++++++
11 files changed, 822 insertions(+), 99 deletions(-)
diff --git a/docs/source/user-guide/io/csv.rst
b/docs/source/user-guide/io/csv.rst
index 144b6615..9c23c291 100644
--- a/docs/source/user-guide/io/csv.rst
+++ b/docs/source/user-guide/io/csv.rst
@@ -36,3 +36,25 @@ An alternative is to use :py:func:`~datafusion.context.SessionContext.register_csv`
ctx.register_csv("file", "file.csv")
df = ctx.table("file")
+
+If you require additional control over how to read the CSV file, you can use
+:py:class:`~datafusion.options.CsvReadOptions` to set a variety of options.
+
+.. code-block:: python
+
+ from datafusion import CsvReadOptions
+ options = (
+ CsvReadOptions()
+ .with_has_header(True) # File contains a header row
+ .with_delimiter(";") # Use ; as the delimiter instead of ,
+ .with_comment("#") # Skip lines starting with #
+ .with_escape("\\") # Escape character
+ .with_null_regex(r"^(null|NULL|N/A)$") # Treat these as NULL
+ .with_truncated_rows(True) # Allow rows to have incomplete columns
+ .with_file_compression_type("gzip") # Read gzipped CSV
+ .with_file_extension(".gz") # File extension other than .csv
+ )
+ df = ctx.read_csv("data.csv.gz", options=options)
+
+Details for all CSV reading options can be found on the
+`DataFusion documentation site <https://datafusion.apache.org/library-user-guide/custom-table-providers.html>`_.
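The same options object can also be passed when registering a CSV table; a
minimal sketch based on the register_csv signature added in this commit (the
table name and file path are illustrative):

    from datafusion import CsvReadOptions, SessionContext

    ctx = SessionContext()
    options = CsvReadOptions().with_has_header(True).with_delimiter(";")
    # register_csv accepts the same options= parameter as read_csv
    ctx.register_csv("events", "events.csv", options=options)
    df = ctx.sql("SELECT * FROM events")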
diff --git a/examples/csv-read-options.py b/examples/csv-read-options.py
new file mode 100644
index 00000000..a5952d95
--- /dev/null
+++ b/examples/csv-read-options.py
@@ -0,0 +1,96 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Example demonstrating CsvReadOptions usage."""
+
+from datafusion import CsvReadOptions, SessionContext
+
+# Create a SessionContext
+ctx = SessionContext()
+
+# Example 1: Using CsvReadOptions with default values
+print("Example 1: Default CsvReadOptions")
+options = CsvReadOptions()
+df = ctx.read_csv("data.csv", options=options)
+
+# Example 2: Using CsvReadOptions with custom parameters
+print("\nExample 2: Custom CsvReadOptions")
+options = CsvReadOptions(
+ has_header=True,
+ delimiter=",",
+ quote='"',
+ schema_infer_max_records=1000,
+ file_extension=".csv",
+)
+df = ctx.read_csv("data.csv", options=options)
+
+# Example 3: Using the builder pattern (recommended for readability)
+print("\nExample 3: Builder pattern")
+options = (
+ CsvReadOptions()
+ .with_has_header(True) # noqa: FBT003
+ .with_delimiter("|")
+ .with_quote("'")
+ .with_schema_infer_max_records(500)
+ .with_truncated_rows(False) # noqa: FBT003
+ .with_newlines_in_values(True) # noqa: FBT003
+)
+df = ctx.read_csv("data.csv", options=options)
+
+# Example 4: Advanced options
+print("\nExample 4: Advanced options")
+options = (
+ CsvReadOptions()
+ .with_has_header(True) # noqa: FBT003
+ .with_delimiter(",")
+ .with_comment("#") # Skip lines starting with #
+ .with_escape("\\") # Escape character
+ .with_null_regex(r"^(null|NULL|N/A)$") # Treat these as NULL
+ .with_truncated_rows(True) # noqa: FBT003
+ .with_file_compression_type("gzip") # Read gzipped CSV
+ .with_file_extension(".gz")
+)
+df = ctx.read_csv("data.csv.gz", options=options)
+
+# Example 5: Register CSV table with options
+print("\nExample 5: Register CSV table")
+options = CsvReadOptions().with_has_header(True).with_delimiter(",")  # noqa: FBT003
+ctx.register_csv("my_table", "data.csv", options=options)
+df = ctx.sql("SELECT * FROM my_table")
+
+# Example 6: Backward compatibility (without options)
+print("\nExample 6: Backward compatibility")
+# Still works the old way!
+df = ctx.read_csv("data.csv", has_header=True, delimiter=",")
+
+print("\nAll examples completed!")
+print("\nFor all available options, see the CsvReadOptions documentation:")
+print(" - has_header: bool")
+print(" - delimiter: str")
+print(" - quote: str")
+print(" - terminator: str | None")
+print(" - escape: str | None")
+print(" - comment: str | None")
+print(" - newlines_in_values: bool")
+print(" - schema: pa.Schema | None")
+print(" - schema_infer_max_records: int")
+print(" - file_extension: str")
+print(" - table_partition_cols: list[tuple[str, pa.DataType]]")
+print(" - file_compression_type: str")
+print(" - file_sort_order: list[list[SortExpr]]")
+print(" - null_regex: str | None")
+print(" - truncated_rows: bool")
diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py
index 784d4ccc..2e6f8116 100644
--- a/python/datafusion/__init__.py
+++ b/python/datafusion/__init__.py
@@ -54,6 +54,7 @@ from .dataframe import (
from .dataframe_formatter import configure_formatter
from .expr import Expr, WindowFrame
from .io import read_avro, read_csv, read_json, read_parquet
+from .options import CsvReadOptions
from .plan import ExecutionPlan, LogicalPlan
from .record_batch import RecordBatch, RecordBatchStream
from .user_defined import (
@@ -75,6 +76,7 @@ __all__ = [
"AggregateUDF",
"Catalog",
"Config",
+ "CsvReadOptions",
"DFSchema",
"DataFrame",
"DataFrameWriteOptions",
@@ -106,6 +108,7 @@ __all__ = [
"lit",
"literal",
"object_store",
+ "options",
"read_avro",
"read_csv",
"read_json",
diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index be647fef..7b92c082 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -34,6 +34,11 @@ import pyarrow as pa
from datafusion.catalog import Catalog
from datafusion.dataframe import DataFrame
from datafusion.expr import sort_list_to_raw_sort_list
+from datafusion.options import (
+ DEFAULT_MAX_INFER_SCHEMA,
+ CsvReadOptions,
+ _convert_table_partition_cols,
+)
from datafusion.record_batch import RecordBatchStream
from ._internal import RuntimeEnvBuilder as RuntimeEnvBuilderInternal
@@ -584,7 +589,7 @@ class SessionContext:
"""
if table_partition_cols is None:
table_partition_cols = []
- table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
+ table_partition_cols = _convert_table_partition_cols(table_partition_cols)
self.ctx.register_listing_table(
name,
str(path),
@@ -905,7 +910,7 @@ class SessionContext:
"""
if table_partition_cols is None:
table_partition_cols = []
- table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
+ table_partition_cols = _convert_table_partition_cols(table_partition_cols)
self.ctx.register_parquet(
name,
str(path),
@@ -924,9 +929,10 @@ class SessionContext:
schema: pa.Schema | None = None,
has_header: bool = True,
delimiter: str = ",",
- schema_infer_max_records: int = 1000,
+ schema_infer_max_records: int = DEFAULT_MAX_INFER_SCHEMA,
file_extension: str = ".csv",
file_compression_type: str | None = None,
+ options: CsvReadOptions | None = None,
) -> None:
"""Register a CSV file as a table.
@@ -946,18 +952,46 @@ class SessionContext:
file_extension: File extension; only files with this extension are
selected for data input.
file_compression_type: File compression type.
+ options: Set advanced options for CSV reading. This cannot be
+ combined with any of the other options in this method.
"""
- path = [str(p) for p in path] if isinstance(path, list) else str(path)
+ path_arg = [str(p) for p in path] if isinstance(path, list) else str(path)
+
+ if options is not None and (
+ schema is not None
+ or not has_header
+ or delimiter != ","
+ or schema_infer_max_records != DEFAULT_MAX_INFER_SCHEMA
+ or file_extension != ".csv"
+ or file_compression_type is not None
+ ):
+ message = (
+ "Combining CsvReadOptions parameter with additional options "
+ "is not supported. Use CsvReadOptions to set parameters."
+ )
+ warnings.warn(
+ message,
+ category=UserWarning,
+ stacklevel=2,
+ )
+
+ options = (
+ options
+ if options is not None
+ else CsvReadOptions(
+ schema=schema,
+ has_header=has_header,
+ delimiter=delimiter,
+ schema_infer_max_records=schema_infer_max_records,
+ file_extension=file_extension,
+ file_compression_type=file_compression_type,
+ )
+ )
self.ctx.register_csv(
name,
- path,
- schema,
- has_header,
- delimiter,
- schema_infer_max_records,
- file_extension,
- file_compression_type,
+ path_arg,
+ options.to_inner(),
)
def register_json(
@@ -988,7 +1022,7 @@ class SessionContext:
"""
if table_partition_cols is None:
table_partition_cols = []
- table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
+ table_partition_cols = _convert_table_partition_cols(table_partition_cols)
self.ctx.register_json(
name,
str(path),
@@ -1021,7 +1055,7 @@ class SessionContext:
"""
if table_partition_cols is None:
table_partition_cols = []
- table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
+ table_partition_cols = _convert_table_partition_cols(table_partition_cols)
self.ctx.register_avro(
name, str(path), schema, file_extension, table_partition_cols
)
@@ -1101,7 +1135,7 @@ class SessionContext:
"""
if table_partition_cols is None:
table_partition_cols = []
- table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
+ table_partition_cols = _convert_table_partition_cols(table_partition_cols)
return DataFrame(
self.ctx.read_json(
str(path),
@@ -1119,10 +1153,11 @@ class SessionContext:
schema: pa.Schema | None = None,
has_header: bool = True,
delimiter: str = ",",
- schema_infer_max_records: int = 1000,
+ schema_infer_max_records: int = DEFAULT_MAX_INFER_SCHEMA,
file_extension: str = ".csv",
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
file_compression_type: str | None = None,
+ options: CsvReadOptions | None = None,
) -> DataFrame:
"""Read a CSV data source.
@@ -1140,26 +1175,51 @@ class SessionContext:
selected for data input.
table_partition_cols: Partition columns.
file_compression_type: File compression type.
+ options: Set advanced options for CSV reading. This cannot be
+ combined with any of the other options in this method.
Returns:
DataFrame representation of the read CSV files
"""
- if table_partition_cols is None:
- table_partition_cols = []
- table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
+ path_arg = [str(p) for p in path] if isinstance(path, list) else str(path)
+
+ if options is not None and (
+ schema is not None
+ or not has_header
+ or delimiter != ","
+ or schema_infer_max_records != DEFAULT_MAX_INFER_SCHEMA
+ or file_extension != ".csv"
+ or table_partition_cols is not None
+ or file_compression_type is not None
+ ):
+ message = (
+ "Combining CsvReadOptions parameter with additional options "
+ "is not supported. Use CsvReadOptions to set parameters."
+ )
+ warnings.warn(
+ message,
+ category=UserWarning,
+ stacklevel=2,
+ )
- path = [str(p) for p in path] if isinstance(path, list) else str(path)
+ options = (
+ options
+ if options is not None
+ else CsvReadOptions(
+ schema=schema,
+ has_header=has_header,
+ delimiter=delimiter,
+ schema_infer_max_records=schema_infer_max_records,
+ file_extension=file_extension,
+ table_partition_cols=table_partition_cols,
+ file_compression_type=file_compression_type,
+ )
+ )
return DataFrame(
self.ctx.read_csv(
- path,
- schema,
- has_header,
- delimiter,
- schema_infer_max_records,
- file_extension,
- table_partition_cols,
- file_compression_type,
+ path_arg,
+ options.to_inner(),
)
)
@@ -1197,7 +1257,7 @@ class SessionContext:
"""
if table_partition_cols is None:
table_partition_cols = []
- table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
+ table_partition_cols = _convert_table_partition_cols(table_partition_cols)
file_sort_order = self._convert_file_sort_order(file_sort_order)
return DataFrame(
self.ctx.read_parquet(
@@ -1231,7 +1291,7 @@ class SessionContext:
"""
if file_partition_cols is None:
file_partition_cols = []
- file_partition_cols = self._convert_table_partition_cols(file_partition_cols)
+ file_partition_cols = _convert_table_partition_cols(file_partition_cols)
return DataFrame(
self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension)
)
diff --git a/python/datafusion/io.py b/python/datafusion/io.py
index 67dbc730..4f9c3c51 100644
--- a/python/datafusion/io.py
+++ b/python/datafusion/io.py
@@ -31,6 +31,8 @@ if TYPE_CHECKING:
from datafusion.dataframe import DataFrame
from datafusion.expr import Expr
+ from .options import CsvReadOptions
+
def read_parquet(
path: str | pathlib.Path,
@@ -126,6 +128,7 @@ def read_csv(
file_extension: str = ".csv",
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
file_compression_type: str | None = None,
+ options: CsvReadOptions | None = None,
) -> DataFrame:
"""Read a CSV data source.
@@ -147,15 +150,12 @@ def read_csv(
selected for data input.
table_partition_cols: Partition columns.
file_compression_type: File compression type.
+ options: Set advanced options for CSV reading. This cannot be
+ combined with any of the other options in this method.
Returns:
DataFrame representation of the read CSV files
"""
- if table_partition_cols is None:
- table_partition_cols = []
-
- path = [str(p) for p in path] if isinstance(path, list) else str(path)
-
return SessionContext.global_ctx().read_csv(
path,
schema,
@@ -165,6 +165,7 @@ def read_csv(
file_extension,
table_partition_cols,
file_compression_type,
+ options,
)
diff --git a/python/datafusion/options.py b/python/datafusion/options.py
new file mode 100644
index 00000000..ec19f37d
--- /dev/null
+++ b/python/datafusion/options.py
@@ -0,0 +1,284 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Options for reading various file formats."""
+
+from __future__ import annotations
+
+import warnings
+from typing import TYPE_CHECKING
+
+import pyarrow as pa
+
+from datafusion.expr import sort_list_to_raw_sort_list
+
+if TYPE_CHECKING:
+ from datafusion.expr import SortExpr
+
+from ._internal import options
+
+__all__ = ["CsvReadOptions"]
+
+DEFAULT_MAX_INFER_SCHEMA = 1000
+
+
+class CsvReadOptions:
+ """Options for reading CSV files.
+
+ This class provides a builder pattern for configuring CSV reading options.
+ All methods starting with ``with_`` return ``self`` to allow method chaining.
+ """
+
+ def __init__(
+ self,
+ *,
+ has_header: bool = True,
+ delimiter: str = ",",
+ quote: str = '"',
+ terminator: str | None = None,
+ escape: str | None = None,
+ comment: str | None = None,
+ newlines_in_values: bool = False,
+ schema: pa.Schema | None = None,
+ schema_infer_max_records: int = DEFAULT_MAX_INFER_SCHEMA,
+ file_extension: str = ".csv",
+ table_partition_cols: list[tuple[str, pa.DataType]] | None = None,
+ file_compression_type: str = "",
+ file_sort_order: list[list[SortExpr]] | None = None,
+ null_regex: str | None = None,
+ truncated_rows: bool = False,
+ ) -> None:
+ """Initialize CsvReadOptions.
+
+ Args:
+ has_header: Does the CSV file have a header row? If schema inference
+ is run on a file with no headers, default column names are created.
+ delimiter: Column delimiter character. Must be a single ASCII character.
+ quote: Quote character for fields containing delimiters or newlines.
+ Must be a single ASCII character.
+ terminator: Optional line terminator character. If ``None``, uses CRLF.
+ Must be a single ASCII character.
+ escape: Optional escape character for quotes. Must be a single ASCII
+ character.
+ comment: If specified, lines beginning with this character are ignored.
+ Must be a single ASCII character.
+ newlines_in_values: Whether newlines in quoted values are supported.
+ Parsing newlines in quoted values may be affected by execution
+ behavior such as parallel file scanning. Setting this to ``True``
+ ensures that newlines in values are parsed successfully, which may
+ reduce performance.
+ schema: Optional PyArrow schema representing the CSV files. If ``None``,
+ the CSV reader will try to infer it based on data in the file.
+ schema_infer_max_records: Maximum number of rows to read from CSV files
+ for schema inference if needed.
+ file_extension: File extension; only files with this extension are
+ selected for data input.
+ table_partition_cols: Partition columns as a list of tuples of
+ (column_name, data_type).
+ file_compression_type: File compression type. Supported values are
+ ``"gzip"``, ``"bz2"``, ``"xz"``, ``"zstd"``, or empty string for
+ uncompressed.
+ file_sort_order: Optional sort order of the files as a list of sort
+ expressions per file.
+ null_regex: Optional regex pattern to match null values in the CSV.
+ truncated_rows: Whether to allow truncated rows when parsing. By default
+ this is ``False`` and will error if the CSV rows have different
+ lengths. When set to ``True``, it will allow records with less than
+ the expected number of columns and fill the missing columns with
+ nulls. If the record's schema is not nullable, it will still return
+ an error.
+ """
+ validate_single_character("delimiter", delimiter)
+ validate_single_character("quote", quote)
+ validate_single_character("terminator", terminator)
+ validate_single_character("escape", escape)
+ validate_single_character("comment", comment)
+
+ self.has_header = has_header
+ self.delimiter = delimiter
+ self.quote = quote
+ self.terminator = terminator
+ self.escape = escape
+ self.comment = comment
+ self.newlines_in_values = newlines_in_values
+ self.schema = schema
+ self.schema_infer_max_records = schema_infer_max_records
+ self.file_extension = file_extension
+ self.table_partition_cols = table_partition_cols or []
+ self.file_compression_type = file_compression_type
+ self.file_sort_order = file_sort_order or []
+ self.null_regex = null_regex
+ self.truncated_rows = truncated_rows
+
+ def with_has_header(self, has_header: bool) -> CsvReadOptions:
+ """Configure whether the CSV has a header row."""
+ self.has_header = has_header
+ return self
+
+ def with_delimiter(self, delimiter: str) -> CsvReadOptions:
+ """Configure the column delimiter."""
+ self.delimiter = delimiter
+ return self
+
+ def with_quote(self, quote: str) -> CsvReadOptions:
+ """Configure the quote character."""
+ self.quote = quote
+ return self
+
+ def with_terminator(self, terminator: str | None) -> CsvReadOptions:
+ """Configure the line terminator character."""
+ self.terminator = terminator
+ return self
+
+ def with_escape(self, escape: str | None) -> CsvReadOptions:
+ """Configure the escape character."""
+ self.escape = escape
+ return self
+
+ def with_comment(self, comment: str | None) -> CsvReadOptions:
+ """Configure the comment character."""
+ self.comment = comment
+ return self
+
+ def with_newlines_in_values(self, newlines_in_values: bool) -> CsvReadOptions:
+ """Configure whether newlines in values are supported."""
+ self.newlines_in_values = newlines_in_values
+ return self
+
+ def with_schema(self, schema: pa.Schema | None) -> CsvReadOptions:
+ """Configure the schema."""
+ self.schema = schema
+ return self
+
+ def with_schema_infer_max_records(
+ self, schema_infer_max_records: int
+ ) -> CsvReadOptions:
+ """Configure maximum records for schema inference."""
+ self.schema_infer_max_records = schema_infer_max_records
+ return self
+
+ def with_file_extension(self, file_extension: str) -> CsvReadOptions:
+ """Configure the file extension filter."""
+ self.file_extension = file_extension
+ return self
+
+ def with_table_partition_cols(
+ self, table_partition_cols: list[tuple[str, pa.DataType]]
+ ) -> CsvReadOptions:
+ """Configure table partition columns."""
+ self.table_partition_cols = table_partition_cols
+ return self
+
+ def with_file_compression_type(self, file_compression_type: str) -> CsvReadOptions:
+ """Configure file compression type."""
+ self.file_compression_type = file_compression_type
+ return self
+
+ def with_file_sort_order(
+ self, file_sort_order: list[list[SortExpr]]
+ ) -> CsvReadOptions:
+ """Configure file sort order."""
+ self.file_sort_order = file_sort_order
+ return self
+
+ def with_null_regex(self, null_regex: str | None) -> CsvReadOptions:
+ """Configure null value regex pattern."""
+ self.null_regex = null_regex
+ return self
+
+ def with_truncated_rows(self, truncated_rows: bool) -> CsvReadOptions:
+ """Configure whether to allow truncated rows."""
+ self.truncated_rows = truncated_rows
+ return self
+
+ def to_inner(self) -> options.CsvReadOptions:
+ """Convert this object into the underlying Rust structure.
+
+ This is intended for internal use only.
+ """
+ file_sort_order = (
+ []
+ if self.file_sort_order is None
+ else [
+ sort_list_to_raw_sort_list(sort_list)
+ for sort_list in self.file_sort_order
+ ]
+ )
+
+ return options.CsvReadOptions(
+ has_header=self.has_header,
+ delimiter=ord(self.delimiter[0]) if self.delimiter else ord(","),
+ quote=ord(self.quote[0]) if self.quote else ord('"'),
+ terminator=ord(self.terminator[0]) if self.terminator else None,
+ escape=ord(self.escape[0]) if self.escape else None,
+ comment=ord(self.comment[0]) if self.comment else None,
+ newlines_in_values=self.newlines_in_values,
+ schema=self.schema,
+ schema_infer_max_records=self.schema_infer_max_records,
+ file_extension=self.file_extension,
+ table_partition_cols=_convert_table_partition_cols(
+ self.table_partition_cols
+ ),
+ file_compression_type=self.file_compression_type or "",
+ file_sort_order=file_sort_order,
+ null_regex=self.null_regex,
+ truncated_rows=self.truncated_rows,
+ )
+
+
+def validate_single_character(name: str, value: str | None) -> None:
+ if value is not None and len(value) != 1:
+ message = f"{name} must be a single character"
+ raise ValueError(message)
+
+
+def _convert_table_partition_cols(
+ table_partition_cols: list[tuple[str, str | pa.DataType]],
+) -> list[tuple[str, pa.DataType]]:
+ warn = False
+ converted_table_partition_cols = []
+
+ for col, data_type in table_partition_cols:
+ if isinstance(data_type, str):
+ warn = True
+ if data_type == "string":
+ converted_data_type = pa.string()
+ elif data_type == "int":
+ converted_data_type = pa.int32()
+ else:
+ message = (
+ f"Unsupported literal data type '{data_type}' for
partition "
+ "column. Supported types are 'string' and 'int'"
+ )
+ raise ValueError(message)
+ else:
+ converted_data_type = data_type
+
+ converted_table_partition_cols.append((col, converted_data_type))
+
+ if warn:
+ message = (
+ "using literals for table_partition_cols data types is deprecated,"
+ "use pyarrow types instead"
+ )
+ warnings.warn(
+ message,
+ category=DeprecationWarning,
+ stacklevel=2,
+ )
+
+ return converted_table_partition_cols
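For reference, a minimal sketch of the two spellings this deprecation path
distinguishes (the column name is illustrative; the warning fires when the
options are converted via to_inner()):

    import pyarrow as pa
    from datafusion import CsvReadOptions

    # Deprecated: a string literal for the partition column type emits
    # the DeprecationWarning above when the options are converted
    legacy = CsvReadOptions(table_partition_cols=[("group", "string")])

    # Preferred: pass a PyArrow data type directly
    current = CsvReadOptions(table_partition_cols=[("group", pa.string())])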
diff --git a/python/tests/test_context.py b/python/tests/test_context.py
index bd65305e..5853f9fe 100644
--- a/python/tests/test_context.py
+++ b/python/tests/test_context.py
@@ -22,6 +22,7 @@ import pyarrow as pa
import pyarrow.dataset as ds
import pytest
from datafusion import (
+ CsvReadOptions,
DataFrame,
RuntimeEnvBuilder,
SessionConfig,
@@ -626,6 +627,8 @@ def test_read_csv_list(ctx):
def test_read_csv_compressed(ctx, tmp_path):
test_data_path = pathlib.Path("testing/data/csv/aggregate_test_100.csv")
+ expected = ctx.read_csv(test_data_path).collect()
+
# File compression type
gzip_path = tmp_path / "aggregate_test_100.csv.gz"
@@ -636,7 +639,13 @@ def test_read_csv_compressed(ctx, tmp_path):
gzipped_file.writelines(csv_file)
csv_df = ctx.read_csv(gzip_path, file_extension=".gz", file_compression_type="gz")
- csv_df.select(column("c1")).show()
+ assert csv_df.collect() == expected
+
+ csv_df = ctx.read_csv(
+ gzip_path,
+ options=CsvReadOptions(file_extension=".gz",
file_compression_type="gz"),
+ )
+ assert csv_df.collect() == expected
def test_read_parquet(ctx):
@@ -710,3 +719,154 @@ def test_create_dataframe_with_global_ctx(batch):
result = df.collect()[0].column(0)
assert result == pa.array([4, 5, 6])
+
+
+def test_csv_read_options_builder_pattern():
+ """Test CsvReadOptions builder pattern."""
+ from datafusion import CsvReadOptions
+
+ options = (
+ CsvReadOptions()
+ .with_has_header(False) # noqa: FBT003
+ .with_delimiter("|")
+ .with_quote("'")
+ .with_schema_infer_max_records(2000)
+ .with_truncated_rows(True) # noqa: FBT003
+ .with_newlines_in_values(True) # noqa: FBT003
+ .with_file_extension(".tsv")
+ )
+ assert options.has_header is False
+ assert options.delimiter == "|"
+ assert options.quote == "'"
+ assert options.schema_infer_max_records == 2000
+ assert options.truncated_rows is True
+ assert options.newlines_in_values is True
+ assert options.file_extension == ".tsv"
+
+
+def read_csv_with_options_inner(
+ tmp_path: pathlib.Path,
+ csv_content: str,
+ options: CsvReadOptions,
+ expected: pa.RecordBatch,
+ as_read: bool,
+ global_ctx: bool,
+) -> None:
+ from datafusion import SessionContext
+
+ # Create a test CSV file
+ group_dir = tmp_path / "group=a"
+ group_dir.mkdir(exist_ok=True)
+
+ csv_path = group_dir / "test.csv"
+ csv_path.write_text(csv_content)
+
+ ctx = SessionContext()
+
+ if as_read:
+ if global_ctx:
+ from datafusion.io import read_csv
+
+ df = read_csv(str(tmp_path), options=options)
+ else:
+ df = ctx.read_csv(str(tmp_path), options=options)
+ else:
+ ctx.register_csv("test_table", str(tmp_path), options=options)
+ df = ctx.sql("SELECT * FROM test_table")
+ df.show()
+
+ # Verify the data
+ result = df.collect()
+ assert len(result) == 1
+ assert result[0] == expected
+
+
[email protected](
+ ("as_read", "global_ctx"),
+ [
+ (True, True),
+ (True, False),
+ (False, False),
+ ],
+)
+def test_read_csv_with_options(tmp_path, as_read, global_ctx):
+ """Test reading CSV with CsvReadOptions."""
+
+ csv_content = "Alice;30;|New York;
NY|\nBob;25\n#Charlie;35;Paris\nPhil;75;Detroit'
MI\nKarin;50;|Stockholm\nSweden|" # noqa: E501
+
+ # Some of the read options, such as schema and schema_infer_max_records,
+ # are difficult to test in combination, so run multiple tests.
+ # file_sort_order doesn't impact reading, but it is included here to
+ # ensure all options parse correctly.
+ options = CsvReadOptions(
+ has_header=False,
+ delimiter=";",
+ quote="|",
+ terminator="\n",
+ escape="\\",
+ comment="#",
+ newlines_in_values=True,
+ schema_infer_max_records=1,
+ null_regex="[pP]+aris",
+ truncated_rows=True,
+ file_sort_order=[[column("column_1").sort(), column("column_2")],
["column_3"]],
+ )
+
+ expected = pa.RecordBatch.from_arrays(
+ [
+ pa.array(["Alice", "Bob", "Phil", "Karin"]),
+ pa.array([30, 25, 75, 50]),
+ pa.array(["New York; NY", None, "Detroit' MI",
"Stockholm\nSweden"]),
+ ],
+ names=["column_1", "column_2", "column_3"],
+ )
+
+ read_csv_with_options_inner(
+ tmp_path, csv_content, options, expected, as_read, global_ctx
+ )
+
+ schema = pa.schema(
+ [
+ pa.field("name", pa.string(), nullable=False),
+ pa.field("age", pa.float32(), nullable=False),
+ pa.field("location", pa.string(), nullable=True),
+ ]
+ )
+ options.with_schema(schema)
+
+ expected = pa.RecordBatch.from_arrays(
+ [
+ pa.array(["Alice", "Bob", "Phil", "Karin"]),
+ pa.array([30.0, 25.0, 75.0, 50.0]),
+ pa.array(["New York; NY", None, "Detroit' MI",
"Stockholm\nSweden"]),
+ ],
+ schema=schema,
+ )
+
+ read_csv_with_options_inner(
+ tmp_path, csv_content, options, expected, as_read, global_ctx
+ )
+
+ csv_content = "name,age\nAlice,30\nBob,25\nCharlie,35\nDiego,40\nEmily,15"
+
+ expected = pa.RecordBatch.from_arrays(
+ [
+ pa.array(["Alice", "Bob", "Charlie", "Diego", "Emily"]),
+ pa.array([30, 25, 35, 40, 15]),
+ pa.array(["a", "a", "a", "a", "a"]),
+ ],
+ schema=pa.schema(
+ [
+ pa.field("name", pa.string(), nullable=True),
+ pa.field("age", pa.int64(), nullable=True),
+ pa.field("group", pa.string(), nullable=False),
+ ]
+ ),
+ )
+ options = CsvReadOptions(
+ table_partition_cols=[("group", pa.string())],
+ )
+
+ read_csv_with_options_inner(
+ tmp_path, csv_content, options, expected, as_read, global_ctx
+ )
diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py
index 85afd021..48c37466 100644
--- a/python/tests/test_sql.py
+++ b/python/tests/test_sql.py
@@ -92,7 +92,7 @@ def test_register_csv(ctx, tmp_path):
result = pa.Table.from_batches(result)
assert result.schema == alternative_schema
- with pytest.raises(ValueError, match="Delimiter must be a single
character"):
+ with pytest.raises(ValueError, match="delimiter must be a single
character"):
ctx.register_csv("csv4", path, delimiter="wrong")
with pytest.raises(
diff --git a/src/context.rs b/src/context.rs
index 1cd04ac2..f28c5982 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -64,9 +64,9 @@ use crate::dataframe::PyDataFrame;
use crate::dataset::Dataset;
use crate::errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult};
use crate::expr::sort_expr::PySortExpr;
+use crate::options::PyCsvReadOptions;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
-use crate::sql::exceptions::py_value_err;
use crate::sql::logical::PyLogicalPlan;
use crate::sql::util::replace_placeholders_with_strings;
use crate::store::StorageContexts;
@@ -724,41 +724,20 @@ impl PySessionContext {
Ok(())
}
- #[allow(clippy::too_many_arguments)]
#[pyo3(signature = (name,
path,
- schema=None,
- has_header=true,
- delimiter=",",
- schema_infer_max_records=1000,
- file_extension=".csv",
- file_compression_type=None))]
+ options=None))]
pub fn register_csv(
&self,
name: &str,
path: &Bound<'_, PyAny>,
- schema: Option<PyArrowType<Schema>>,
- has_header: bool,
- delimiter: &str,
- schema_infer_max_records: usize,
- file_extension: &str,
- file_compression_type: Option<String>,
+ options: Option<&PyCsvReadOptions>,
py: Python,
) -> PyDataFusionResult<()> {
- let delimiter = delimiter.as_bytes();
- if delimiter.len() != 1 {
- return Err(PyDataFusionError::PythonError(py_value_err(
- "Delimiter must be a single character",
- )));
- }
-
- let mut options = CsvReadOptions::new()
- .has_header(has_header)
- .delimiter(delimiter[0])
- .schema_infer_max_records(schema_infer_max_records)
- .file_extension(file_extension)
- .file_compression_type(parse_file_compression_type(file_compression_type)?);
- options.schema = schema.as_ref().map(|x| &x.0);
+ let options = options
+ .map(|opts| opts.try_into())
+ .transpose()?
+ .unwrap_or_default();
if path.is_instance_of::<PyList>() {
let paths = path.extract::<Vec<String>>()?;
@@ -978,48 +957,19 @@ impl PySessionContext {
Ok(PyDataFrame::new(df))
}
- #[allow(clippy::too_many_arguments)]
#[pyo3(signature = (
path,
- schema=None,
- has_header=true,
- delimiter=",",
- schema_infer_max_records=1000,
- file_extension=".csv",
- table_partition_cols=vec![],
- file_compression_type=None))]
+ options=None))]
pub fn read_csv(
&self,
path: &Bound<'_, PyAny>,
- schema: Option<PyArrowType<Schema>>,
- has_header: bool,
- delimiter: &str,
- schema_infer_max_records: usize,
- file_extension: &str,
- table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
- file_compression_type: Option<String>,
+ options: Option<&PyCsvReadOptions>,
py: Python,
) -> PyDataFusionResult<PyDataFrame> {
- let delimiter = delimiter.as_bytes();
- if delimiter.len() != 1 {
- return Err(PyDataFusionError::PythonError(py_value_err(
- "Delimiter must be a single character",
- )));
- };
-
- let mut options = CsvReadOptions::new()
- .has_header(has_header)
- .delimiter(delimiter[0])
- .schema_infer_max_records(schema_infer_max_records)
- .file_extension(file_extension)
- .table_partition_cols(
- table_partition_cols
- .into_iter()
- .map(|(name, ty)| (name, ty.0))
- .collect::<Vec<(String, DataType)>>(),
- )
- .file_compression_type(parse_file_compression_type(file_compression_type)?);
- options.schema = schema.as_ref().map(|x| &x.0);
+ let options = options
+ .map(|opts| opts.try_into())
+ .transpose()?
+ .unwrap_or_default();
if path.is_instance_of::<PyList>() {
let paths = path.extract::<Vec<String>>()?;
diff --git a/src/lib.rs b/src/lib.rs
index eda50fe1..081366b2 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -43,6 +43,7 @@ pub mod errors;
pub mod expr;
#[allow(clippy::borrow_deref_ref)]
mod functions;
+mod options;
pub mod physical_plan;
mod pyarrow_filter_expression;
pub mod pyarrow_util;
@@ -126,6 +127,10 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
store::init_module(&store)?;
m.add_submodule(&store)?;
+ let options = PyModule::new(py, "options")?;
+ options::init_module(&options)?;
+ m.add_submodule(&options)?;
+
// Register substrait as a submodule
#[cfg(feature = "substrait")]
setup_substrait_module(py, &m)?;
diff --git a/src/options.rs b/src/options.rs
new file mode 100644
index 00000000..a37664b2
--- /dev/null
+++ b/src/options.rs
@@ -0,0 +1,142 @@
+use arrow::datatypes::{DataType, Schema};
+use arrow::pyarrow::PyArrowType;
+use datafusion::prelude::CsvReadOptions;
+use pyo3::prelude::{PyModule, PyModuleMethods};
+use pyo3::{pyclass, pymethods, Bound, PyResult};
+
+use crate::context::parse_file_compression_type;
+use crate::errors::PyDataFusionError;
+use crate::expr::sort_expr::PySortExpr;
+
+/// Options for reading CSV files
+#[pyclass(name = "CsvReadOptions", module = "datafusion.options", frozen)]
+pub struct PyCsvReadOptions {
+ pub has_header: bool,
+ pub delimiter: u8,
+ pub quote: u8,
+ pub terminator: Option<u8>,
+ pub escape: Option<u8>,
+ pub comment: Option<u8>,
+ pub newlines_in_values: bool,
+ pub schema: Option<PyArrowType<Schema>>,
+ pub schema_infer_max_records: usize,
+ pub file_extension: String,
+ pub table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
+ pub file_compression_type: String,
+ pub file_sort_order: Vec<Vec<PySortExpr>>,
+ pub null_regex: Option<String>,
+ pub truncated_rows: bool,
+}
+
+#[pymethods]
+impl PyCsvReadOptions {
+ #[allow(clippy::too_many_arguments)]
+ #[pyo3(signature = (
+ has_header=true,
+ delimiter=b',',
+ quote=b'"',
+ terminator=None,
+ escape=None,
+ comment=None,
+ newlines_in_values=false,
+ schema=None,
+ schema_infer_max_records=1000,
+ file_extension=".csv".to_string(),
+ table_partition_cols=vec![],
+ file_compression_type="".to_string(),
+ file_sort_order=vec![],
+ null_regex=None,
+ truncated_rows=false
+ ))]
+ #[new]
+ fn new(
+ has_header: bool,
+ delimiter: u8,
+ quote: u8,
+ terminator: Option<u8>,
+ escape: Option<u8>,
+ comment: Option<u8>,
+ newlines_in_values: bool,
+ schema: Option<PyArrowType<Schema>>,
+ schema_infer_max_records: usize,
+ file_extension: String,
+ table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
+ file_compression_type: String,
+ file_sort_order: Vec<Vec<PySortExpr>>,
+ null_regex: Option<String>,
+ truncated_rows: bool,
+ ) -> Self {
+ Self {
+ has_header,
+ delimiter,
+ quote,
+ terminator,
+ escape,
+ comment,
+ newlines_in_values,
+ schema,
+ schema_infer_max_records,
+ file_extension,
+ table_partition_cols,
+ file_compression_type,
+ file_sort_order,
+ null_regex,
+ truncated_rows,
+ }
+ }
+}
+
+impl<'a> TryFrom<&'a PyCsvReadOptions> for CsvReadOptions<'a> {
+ type Error = PyDataFusionError;
+
+ fn try_from(value: &'a PyCsvReadOptions) -> Result<CsvReadOptions<'a>, Self::Error> {
+ let partition_cols: Vec<(String, DataType)> = value
+ .table_partition_cols
+ .iter()
+ .map(|(name, dtype)| (name.clone(), dtype.0.clone()))
+ .collect();
+
+ let compression = parse_file_compression_type(Some(value.file_compression_type.clone()))?;
+
+ let sort_order: Vec<Vec<datafusion::logical_expr::SortExpr>> = value
+ .file_sort_order
+ .iter()
+ .map(|inner| {
+ inner
+ .iter()
+ .map(|sort_expr| sort_expr.sort.clone())
+ .collect()
+ })
+ .collect();
+
+ // Explicit struct initialization to catch upstream changes
+ let mut options = CsvReadOptions {
+ has_header: value.has_header,
+ delimiter: value.delimiter,
+ quote: value.quote,
+ terminator: value.terminator,
+ escape: value.escape,
+ comment: value.comment,
+ newlines_in_values: value.newlines_in_values,
+ schema: None, // Will be set separately due to lifetime constraints
+ schema_infer_max_records: value.schema_infer_max_records,
+ file_extension: value.file_extension.as_str(),
+ table_partition_cols: partition_cols,
+ file_compression_type: compression,
+ file_sort_order: sort_order,
+ null_regex: value.null_regex.clone(),
+ truncated_rows: value.truncated_rows,
+ };
+
+ // Set schema separately to handle the lifetime
+ options.schema = value.schema.as_ref().map(|s| &s.0);
+
+ Ok(options)
+ }
+}
+
+pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
+ m.add_class::<PyCsvReadOptions>()?;
+
+ Ok(())
+}
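Taken together with the Python wrapper above: CsvReadOptions.to_inner() on the
Python side constructs this PyCsvReadOptions pyclass, and
PySessionContext::read_csv / register_csv convert it to DataFusion's native
CsvReadOptions through the TryFrom impl. A minimal sketch of the round trip
from Python, using only names from this commit (the file path is illustrative):

    from datafusion import CsvReadOptions, SessionContext

    ctx = SessionContext()
    opts = CsvReadOptions(delimiter=";")
    # to_inner() builds the _internal.options.CsvReadOptions pyclass;
    # read_csv then applies CsvReadOptions::try_from on the Rust side
    df = ctx.read_csv("data.csv", options=opts)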
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]