This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new feda987 ARROW-9333: [Python] Expose more IPC options
feda987 is described below
commit feda9877f8145aebf907c61a24640735a968a230
Author: Antoine Pitrou <[email protected]>
AuthorDate: Mon Jul 13 12:49:07 2020 -0500
ARROW-9333: [Python] Expose more IPC options
Also make some optional arguments keyword-only.
Closes #7730 from pitrou/ARROW-9333-py-ipc-options
Authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Wes McKinney <[email protected]>
---
cpp/src/arrow/ipc/options.h | 7 ++-
python/pyarrow/_flight.pyx | 6 +--
python/pyarrow/includes/libarrow.pxd | 2 +
python/pyarrow/io.pxi | 29 +++++++++--
python/pyarrow/ipc.pxi | 55 ++++++++++++++++++---
python/pyarrow/ipc.py | 15 +++---
python/pyarrow/tests/test_flight.py | 6 +++
python/pyarrow/tests/test_ipc.py | 95 ++++++++++++++++++++++++------------
python/pyarrow/tests/util.py | 16 ++++++
9 files changed, 174 insertions(+), 57 deletions(-)
diff --git a/cpp/src/arrow/ipc/options.h b/cpp/src/arrow/ipc/options.h
index 69e248c..6bbd7b8 100644
--- a/cpp/src/arrow/ipc/options.h
+++ b/cpp/src/arrow/ipc/options.h
@@ -56,10 +56,9 @@ struct ARROW_EXPORT IpcWriteOptions {
/// \brief The memory pool to use for allocations made during IPC writing
MemoryPool* memory_pool = default_memory_pool();
- /// \brief EXPERIMENTAL: Codec to use for compressing and decompressing
- /// record batch body buffers. This is not part of the Arrow IPC protocol and
- /// only for internal use (e.g. Feather files). May only be LZ4_FRAME and
- /// ZSTD
+ /// \brief Compression codec to use for record batch body buffers
+ ///
+ /// May only be UNCOMPRESSED, LZ4_FRAME and ZSTD.
Compression::type compression = Compression::UNCOMPRESSED;
int compression_level = Compression::kUseDefaultCompressionLevel;
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index 7e3c837..7b6b281 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -97,10 +97,8 @@ def _munge_grpc_python_error(message):
cdef IpcWriteOptions _get_options(options):
- cdef IpcWriteOptions write_options = \
- <IpcWriteOptions> _get_legacy_format_default(
- use_legacy_format=None, options=options)
- return write_options
+ return <IpcWriteOptions> _get_legacy_format_default(
+ use_legacy_format=None, options=options)
cdef class FlightCallOptions:
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 76203f0..3e461c4 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -1329,6 +1329,8 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
c_bool write_legacy_ipc_format
CMemoryPool* memory_pool
CMetadataVersion metadata_version
+ CCompressionType compression
+ c_bool use_threads
@staticmethod
CIpcWriteOptions Defaults()
diff --git a/python/pyarrow/io.pxi b/python/pyarrow/io.pxi
index 76a058d..058b09a 100644
--- a/python/pyarrow/io.pxi
+++ b/python/pyarrow/io.pxi
@@ -1539,24 +1539,43 @@ def _detect_compression(path):
cdef CCompressionType _ensure_compression(str name) except *:
uppercase = name.upper()
- if uppercase == 'GZIP':
- return CCompressionType_GZIP
- elif uppercase == 'BZ2':
+ if uppercase == 'BZ2':
return CCompressionType_BZ2
+ elif uppercase == 'GZIP':
+ return CCompressionType_GZIP
elif uppercase == 'BROTLI':
return CCompressionType_BROTLI
elif uppercase == 'LZ4' or uppercase == 'LZ4_FRAME':
return CCompressionType_LZ4_FRAME
elif uppercase == 'LZ4_RAW':
return CCompressionType_LZ4
- elif uppercase == 'ZSTD':
- return CCompressionType_ZSTD
elif uppercase == 'SNAPPY':
return CCompressionType_SNAPPY
+ elif uppercase == 'ZSTD':
+ return CCompressionType_ZSTD
else:
raise ValueError('Invalid value for compression: {!r}'.format(name))
+cdef str _compression_name(CCompressionType ctype):
+ if ctype == CCompressionType_GZIP:
+ return 'gzip'
+ elif ctype == CCompressionType_BROTLI:
+ return 'brotli'
+ elif ctype == CCompressionType_BZ2:
+ return 'bz2'
+ elif ctype == CCompressionType_LZ4_FRAME:
+ return 'lz4'
+ elif ctype == CCompressionType_LZ4:
+ return 'lz4_raw'
+ elif ctype == CCompressionType_SNAPPY:
+ return 'snappy'
+ elif ctype == CCompressionType_ZSTD:
+ return 'zstd'
+ else:
+ raise RuntimeError('Unexpected CCompressionType value')
+
+
cdef class Codec:
"""
Compression codec.
diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi
index 4a27634..99352ca 100644
--- a/python/pyarrow/ipc.pxi
+++ b/python/pyarrow/ipc.pxi
@@ -50,29 +50,70 @@ cdef class IpcWriteOptions:
Parameters
----------
+ metadata_version : MetadataVersion, default MetadataVersion.V5
+ The metadata version to write. V5 is the current and latest,
+ V4 is the pre-1.0 metadata version (with incompatible Union layout).
use_legacy_format : bool, default False
Whether to use the pre-Arrow 0.15 IPC format.
- metadata_version : MetadataVersion, default MetadataVersion.V5
- The metadata version to write.
+ compression: str or None
+ If not None, compression codec to use for record batch buffers.
+ May only be "lz4", "zstd" or None.
+ use_threads: bool
+ Whether to use the global CPU thread pool to parallelize any
+ computational tasks like compression.
"""
+ __slots__ = ()
# cdef block is in lib.pxd
- def __init__(self, use_legacy_format=False,
- metadata_version=MetadataVersion.V5):
+ def __init__(self, *, metadata_version=MetadataVersion.V5,
+ use_legacy_format=False, compression=None,
+ bint use_threads=True):
self.c_options = CIpcWriteOptions.Defaults()
- self.c_options.write_legacy_ipc_format = use_legacy_format
- self.c_options.metadata_version = \
- _unwrap_metadata_version(metadata_version)
+ self.use_legacy_format = use_legacy_format
+ self.metadata_version = metadata_version
+ if compression is not None:
+ self.compression = compression
+ self.use_threads = use_threads
@property
def use_legacy_format(self):
return self.c_options.write_legacy_ipc_format
+ @use_legacy_format.setter
+ def use_legacy_format(self, bint value):
+ self.c_options.write_legacy_ipc_format = value
+
@property
def metadata_version(self):
return _wrap_metadata_version(self.c_options.metadata_version)
+ @metadata_version.setter
+ def metadata_version(self, value):
+ self.c_options.metadata_version = _unwrap_metadata_version(value)
+
+ @property
+ def compression(self):
+ if self.c_options.compression == CCompressionType_UNCOMPRESSED:
+ return None
+ else:
+ return _compression_name(self.c_options.compression)
+
+ @compression.setter
+ def compression(self, value):
+ if value is None:
+ self.c_options.compression = CCompressionType_UNCOMPRESSED
+ else:
+ self.c_options.compression = _ensure_compression(value)
+
+ @property
+ def use_threads(self):
+ return self.c_options.use_threads
+
+ @use_threads.setter
+ def use_threads(self, bint value):
+ self.c_options.use_threads = value
+
cdef class Message:
"""
diff --git a/python/pyarrow/ipc.py b/python/pyarrow/ipc.py
index e65bb81..19e80ba 100644
--- a/python/pyarrow/ipc.py
+++ b/python/pyarrow/ipc.py
@@ -92,7 +92,7 @@ class RecordBatchStreamWriter(lib._RecordBatchStreamWriter):
{}""".format(_ipc_writer_class_doc)
- def __init__(self, sink, schema, use_legacy_format=None, options=None):
+ def __init__(self, sink, schema, *, use_legacy_format=None, options=None):
options = _get_legacy_format_default(use_legacy_format, options)
self._open(sink, schema, options=options)
@@ -120,7 +120,7 @@ class RecordBatchFileWriter(lib._RecordBatchFileWriter):
{}""".format(_ipc_writer_class_doc)
- def __init__(self, sink, schema, use_legacy_format=None, options=None):
+ def __init__(self, sink, schema, *, use_legacy_format=None, options=None):
options = _get_legacy_format_default(use_legacy_format, options)
self._open(sink, schema, options=options)
@@ -130,6 +130,9 @@ def _get_legacy_format_default(use_legacy_format, options):
raise ValueError(
"Can provide at most one of options and use_legacy_format")
elif options:
+ if not isinstance(options, IpcWriteOptions):
+ raise TypeError("expected IpcWriteOptions, got {}"
+ .format(type(options)))
return options
metadata_version = MetadataVersion.V5
@@ -142,7 +145,7 @@ def _get_legacy_format_default(use_legacy_format, options):
metadata_version=metadata_version)
-def new_stream(sink, schema, use_legacy_format=None, options=None):
+def new_stream(sink, schema, *, use_legacy_format=None, options=None):
return RecordBatchStreamWriter(sink, schema,
use_legacy_format=use_legacy_format,
options=options)
@@ -170,7 +173,7 @@ def open_stream(source):
return RecordBatchStreamReader(source)
-def new_file(sink, schema, use_legacy_format=None, options=None):
+def new_file(sink, schema, *, use_legacy_format=None, options=None):
return RecordBatchFileWriter(sink, schema,
use_legacy_format=use_legacy_format,
options=options)
@@ -201,7 +204,7 @@ def open_file(source, footer_offset=None):
return RecordBatchFileReader(source, footer_offset=footer_offset)
-def serialize_pandas(df, nthreads=None, preserve_index=None):
+def serialize_pandas(df, *, nthreads=None, preserve_index=None):
"""
Serialize a pandas DataFrame into a buffer protocol compatible object.
@@ -229,7 +232,7 @@ def serialize_pandas(df, nthreads=None, preserve_index=None):
return sink.getvalue()
-def deserialize_pandas(buf, use_threads=True):
+def deserialize_pandas(buf, *, use_threads=True):
"""Deserialize a buffer protocol compatible object into a pandas DataFrame.
Parameters
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index 5a7fda8..50e993d 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -706,6 +706,12 @@ def test_flight_do_get_ints():
data = client.do_get(flight.Ticket(b'ints')).read_all()
assert data.equals(table)
+ with pytest.raises(flight.FlightServerError,
+ match="expected IpcWriteOptions, got <class 'int'>"):
+ with ConstantFlightServer(options=42) as server:
+ client = flight.connect(('localhost', server.port))
+ data = client.do_get(flight.Ticket(b'ints')).read_all()
+
@pytest.mark.pandas
def test_do_get_ints_pandas():
diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py
index 7e53be7..44f8499 100644
--- a/python/pyarrow/tests/test_ipc.py
+++ b/python/pyarrow/tests/test_ipc.py
@@ -24,6 +24,7 @@ import threading
import numpy as np
import pyarrow as pa
+from pyarrow.tests.util import changed_environ
try:
@@ -315,7 +316,45 @@ def test_stream_simple_roundtrip(stream_fixture, use_legacy_ipc_format):
reader.read_next_batch()
-def test_options_legacy_exclusive(stream_fixture):
+def test_write_options():
+ options = pa.ipc.IpcWriteOptions()
+ assert options.use_legacy_format is False
+ assert options.metadata_version == pa.ipc.MetadataVersion.V5
+
+ options.use_legacy_format = True
+ assert options.use_legacy_format is True
+
+ options.metadata_version = pa.ipc.MetadataVersion.V4
+ assert options.metadata_version == pa.ipc.MetadataVersion.V4
+ for value in ('V5', 42):
+ with pytest.raises((TypeError, ValueError)):
+ options.metadata_version = value
+
+ assert options.compression is None
+ for value in ['lz4', 'zstd']:
+ options.compression = value
+ assert options.compression == value
+ options.compression = value.upper()
+ assert options.compression == value
+ options.compression = None
+ assert options.compression is None
+
+ assert options.use_threads is True
+ options.use_threads = False
+ assert options.use_threads is False
+
+ options = pa.ipc.IpcWriteOptions(
+ metadata_version=pa.ipc.MetadataVersion.V4,
+ use_legacy_format=True,
+ compression='lz4',
+ use_threads=False)
+ assert options.metadata_version == pa.ipc.MetadataVersion.V4
+ assert options.use_legacy_format is True
+ assert options.compression == 'lz4'
+ assert options.use_threads is False
+
+
+def test_write_options_legacy_exclusive(stream_fixture):
with pytest.raises(
ValueError,
match="provide at most one of options and use_legacy_format"):
@@ -365,36 +404,30 @@ def test_envvar_set_legacy_ipc_format():
assert not writer._use_legacy_format
assert writer._metadata_version == pa.ipc.MetadataVersion.V5
- import os
-
- os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = '1'
- writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
- assert writer._use_legacy_format
- assert writer._metadata_version == pa.ipc.MetadataVersion.V5
- writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
- assert writer._use_legacy_format
- assert writer._metadata_version == pa.ipc.MetadataVersion.V5
- del os.environ['ARROW_PRE_0_15_IPC_FORMAT']
-
- os.environ['ARROW_PRE_1_0_METADATA_VERSION'] = '1'
- writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
- assert not writer._use_legacy_format
- assert writer._metadata_version == pa.ipc.MetadataVersion.V4
- writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
- assert not writer._use_legacy_format
- assert writer._metadata_version == pa.ipc.MetadataVersion.V4
- del os.environ['ARROW_PRE_1_0_METADATA_VERSION']
-
- os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = '1'
- os.environ['ARROW_PRE_1_0_METADATA_VERSION'] = '1'
- writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
- assert writer._use_legacy_format
- assert writer._metadata_version == pa.ipc.MetadataVersion.V4
- writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
- assert writer._use_legacy_format
- assert writer._metadata_version == pa.ipc.MetadataVersion.V4
- del os.environ['ARROW_PRE_0_15_IPC_FORMAT']
- del os.environ['ARROW_PRE_1_0_METADATA_VERSION']
+ with changed_environ('ARROW_PRE_0_15_IPC_FORMAT', '1'):
+ writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
+ assert writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V5
+ writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
+ assert writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V5
+
+ with changed_environ('ARROW_PRE_1_0_METADATA_VERSION', '1'):
+ writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
+ assert not writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V4
+ writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
+ assert not writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V4
+
+ with changed_environ('ARROW_PRE_1_0_METADATA_VERSION', '1'):
+ with changed_environ('ARROW_PRE_0_15_IPC_FORMAT', '1'):
+ writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
+ assert writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V4
+ writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
+ assert writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V4
def test_stream_read_all(stream_fixture):
diff --git a/python/pyarrow/tests/util.py b/python/pyarrow/tests/util.py
index dccf2d0..50844d2 100644
--- a/python/pyarrow/tests/util.py
+++ b/python/pyarrow/tests/util.py
@@ -194,3 +194,19 @@ def invoke_script(script_name, *args):
cmd.extend(args)
subprocess.check_call(cmd, env=subprocess_env)
+
+
[email protected]
+def changed_environ(name, value):
+ """
+ Temporarily set environment variable *name* to *value*.
+ """
+ orig_value = os.environ.get(name)
+ os.environ[name] = value
+ try:
+ yield
+ finally:
+ if orig_value is None:
+ del os.environ[name]
+ else:
+ os.environ[name] = orig_value