This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new feda987  ARROW-9333: [Python] Expose more IPC options
feda987 is described below

commit feda9877f8145aebf907c61a24640735a968a230
Author: Antoine Pitrou <[email protected]>
AuthorDate: Mon Jul 13 12:49:07 2020 -0500

    ARROW-9333: [Python] Expose more IPC options
    
    Also make some optional arguments keyword-only.
    
    Closes #7730 from pitrou/ARROW-9333-py-ipc-options
    
    Authored-by: Antoine Pitrou <[email protected]>
    Signed-off-by: Wes McKinney <[email protected]>
---
 cpp/src/arrow/ipc/options.h          |  7 ++-
 python/pyarrow/_flight.pyx           |  6 +--
 python/pyarrow/includes/libarrow.pxd |  2 +
 python/pyarrow/io.pxi                | 29 +++++++++--
 python/pyarrow/ipc.pxi               | 55 ++++++++++++++++++---
 python/pyarrow/ipc.py                | 15 +++---
 python/pyarrow/tests/test_flight.py  |  6 +++
 python/pyarrow/tests/test_ipc.py     | 95 ++++++++++++++++++++++++------------
 python/pyarrow/tests/util.py         | 16 ++++++
 9 files changed, 174 insertions(+), 57 deletions(-)

diff --git a/cpp/src/arrow/ipc/options.h b/cpp/src/arrow/ipc/options.h
index 69e248c..6bbd7b8 100644
--- a/cpp/src/arrow/ipc/options.h
+++ b/cpp/src/arrow/ipc/options.h
@@ -56,10 +56,9 @@ struct ARROW_EXPORT IpcWriteOptions {
   /// \brief The memory pool to use for allocations made during IPC writing
   MemoryPool* memory_pool = default_memory_pool();
 
-  /// \brief EXPERIMENTAL: Codec to use for compressing and decompressing
-  /// record batch body buffers. This is not part of the Arrow IPC protocol and
-  /// only for internal use (e.g. Feather files). May only be LZ4_FRAME and
-  /// ZSTD
+  /// \brief Compression codec to use for record batch body buffers
+  ///
+  /// May only be UNCOMPRESSED, LZ4_FRAME and ZSTD.
   Compression::type compression = Compression::UNCOMPRESSED;
   int compression_level = Compression::kUseDefaultCompressionLevel;
 
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index 7e3c837..7b6b281 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -97,10 +97,8 @@ def _munge_grpc_python_error(message):
 
 
 cdef IpcWriteOptions _get_options(options):
-    cdef IpcWriteOptions write_options = \
-        <IpcWriteOptions> _get_legacy_format_default(
-            use_legacy_format=None, options=options)
-    return write_options
+    return <IpcWriteOptions> _get_legacy_format_default(
+        use_legacy_format=None, options=options)
 
 
 cdef class FlightCallOptions:
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 76203f0..3e461c4 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -1329,6 +1329,8 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
         c_bool write_legacy_ipc_format
         CMemoryPool* memory_pool
         CMetadataVersion metadata_version
+        CCompressionType compression
+        c_bool use_threads
 
         @staticmethod
         CIpcWriteOptions Defaults()
diff --git a/python/pyarrow/io.pxi b/python/pyarrow/io.pxi
index 76a058d..058b09a 100644
--- a/python/pyarrow/io.pxi
+++ b/python/pyarrow/io.pxi
@@ -1539,24 +1539,43 @@ def _detect_compression(path):
 
 cdef CCompressionType _ensure_compression(str name) except *:
     uppercase = name.upper()
-    if uppercase == 'GZIP':
-        return CCompressionType_GZIP
-    elif uppercase == 'BZ2':
+    if uppercase == 'BZ2':
         return CCompressionType_BZ2
+    elif uppercase == 'GZIP':
+        return CCompressionType_GZIP
     elif uppercase == 'BROTLI':
         return CCompressionType_BROTLI
     elif uppercase == 'LZ4' or uppercase == 'LZ4_FRAME':
         return CCompressionType_LZ4_FRAME
     elif uppercase == 'LZ4_RAW':
         return CCompressionType_LZ4
-    elif uppercase == 'ZSTD':
-        return CCompressionType_ZSTD
     elif uppercase == 'SNAPPY':
         return CCompressionType_SNAPPY
+    elif uppercase == 'ZSTD':
+        return CCompressionType_ZSTD
     else:
         raise ValueError('Invalid value for compression: {!r}'.format(name))
 
 
+cdef str _compression_name(CCompressionType ctype):
+    if ctype == CCompressionType_GZIP:
+        return 'gzip'
+    elif ctype == CCompressionType_BROTLI:
+        return 'brotli'
+    elif ctype == CCompressionType_BZ2:
+        return 'bz2'
+    elif ctype == CCompressionType_LZ4_FRAME:
+        return 'lz4'
+    elif ctype == CCompressionType_LZ4:
+        return 'lz4_raw'
+    elif ctype == CCompressionType_SNAPPY:
+        return 'snappy'
+    elif ctype == CCompressionType_ZSTD:
+        return 'zstd'
+    else:
+        raise RuntimeError('Unexpected CCompressionType value')
+
+
 cdef class Codec:
     """
     Compression codec.
diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi
index 4a27634..99352ca 100644
--- a/python/pyarrow/ipc.pxi
+++ b/python/pyarrow/ipc.pxi
@@ -50,29 +50,70 @@ cdef class IpcWriteOptions:
 
     Parameters
     ----------
+    metadata_version : MetadataVersion, default MetadataVersion.V5
+        The metadata version to write.  V5 is the current and latest,
+        V4 is the pre-1.0 metadata version (with incompatible Union layout).
     use_legacy_format : bool, default False
         Whether to use the pre-Arrow 0.15 IPC format.
-    metadata_version : MetadataVersion, default MetadataVersion.V5
-        The metadata version to write.
+    compression: str or None
+        If not None, compression codec to use for record batch buffers.
+        May only be "lz4", "zstd" or None.
+    use_threads: bool
+        Whether to use the global CPU thread pool to parallelize any
+        computational tasks like compression.
     """
+    __slots__ = ()
 
     # cdef block is in lib.pxd
 
-    def __init__(self, use_legacy_format=False,
-                 metadata_version=MetadataVersion.V5):
+    def __init__(self, *, metadata_version=MetadataVersion.V5,
+                 use_legacy_format=False, compression=None,
+                 bint use_threads=True):
         self.c_options = CIpcWriteOptions.Defaults()
-        self.c_options.write_legacy_ipc_format = use_legacy_format
-        self.c_options.metadata_version = \
-            _unwrap_metadata_version(metadata_version)
+        self.use_legacy_format = use_legacy_format
+        self.metadata_version = metadata_version
+        if compression is not None:
+            self.compression = compression
+        self.use_threads = use_threads
 
     @property
     def use_legacy_format(self):
         return self.c_options.write_legacy_ipc_format
 
+    @use_legacy_format.setter
+    def use_legacy_format(self, bint value):
+        self.c_options.write_legacy_ipc_format = value
+
     @property
     def metadata_version(self):
         return _wrap_metadata_version(self.c_options.metadata_version)
 
+    @metadata_version.setter
+    def metadata_version(self, value):
+        self.c_options.metadata_version = _unwrap_metadata_version(value)
+
+    @property
+    def compression(self):
+        if self.c_options.compression == CCompressionType_UNCOMPRESSED:
+            return None
+        else:
+            return _compression_name(self.c_options.compression)
+
+    @compression.setter
+    def compression(self, value):
+        if value is None:
+            self.c_options.compression = CCompressionType_UNCOMPRESSED
+        else:
+            self.c_options.compression = _ensure_compression(value)
+
+    @property
+    def use_threads(self):
+        return self.c_options.use_threads
+
+    @use_threads.setter
+    def use_threads(self, bint value):
+        self.c_options.use_threads = value
+
 
 cdef class Message:
     """
diff --git a/python/pyarrow/ipc.py b/python/pyarrow/ipc.py
index e65bb81..19e80ba 100644
--- a/python/pyarrow/ipc.py
+++ b/python/pyarrow/ipc.py
@@ -92,7 +92,7 @@ class RecordBatchStreamWriter(lib._RecordBatchStreamWriter):
 
 {}""".format(_ipc_writer_class_doc)
 
-    def __init__(self, sink, schema, use_legacy_format=None, options=None):
+    def __init__(self, sink, schema, *, use_legacy_format=None, options=None):
         options = _get_legacy_format_default(use_legacy_format, options)
         self._open(sink, schema, options=options)
 
@@ -120,7 +120,7 @@ class RecordBatchFileWriter(lib._RecordBatchFileWriter):
 
 {}""".format(_ipc_writer_class_doc)
 
-    def __init__(self, sink, schema, use_legacy_format=None, options=None):
+    def __init__(self, sink, schema, *, use_legacy_format=None, options=None):
         options = _get_legacy_format_default(use_legacy_format, options)
         self._open(sink, schema, options=options)
 
@@ -130,6 +130,9 @@ def _get_legacy_format_default(use_legacy_format, options):
         raise ValueError(
             "Can provide at most one of options and use_legacy_format")
     elif options:
+        if not isinstance(options, IpcWriteOptions):
+            raise TypeError("expected IpcWriteOptions, got {}"
+                            .format(type(options)))
         return options
 
     metadata_version = MetadataVersion.V5
@@ -142,7 +145,7 @@ def _get_legacy_format_default(use_legacy_format, options):
                            metadata_version=metadata_version)
 
 
-def new_stream(sink, schema, use_legacy_format=None, options=None):
+def new_stream(sink, schema, *, use_legacy_format=None, options=None):
     return RecordBatchStreamWriter(sink, schema,
                                    use_legacy_format=use_legacy_format,
                                    options=options)
@@ -170,7 +173,7 @@ def open_stream(source):
     return RecordBatchStreamReader(source)
 
 
-def new_file(sink, schema, use_legacy_format=None, options=None):
+def new_file(sink, schema, *, use_legacy_format=None, options=None):
     return RecordBatchFileWriter(sink, schema,
                                  use_legacy_format=use_legacy_format,
                                  options=options)
@@ -201,7 +204,7 @@ def open_file(source, footer_offset=None):
     return RecordBatchFileReader(source, footer_offset=footer_offset)
 
 
-def serialize_pandas(df, nthreads=None, preserve_index=None):
+def serialize_pandas(df, *, nthreads=None, preserve_index=None):
     """
     Serialize a pandas DataFrame into a buffer protocol compatible object.
 
@@ -229,7 +232,7 @@ def serialize_pandas(df, nthreads=None, preserve_index=None):
     return sink.getvalue()
 
 
-def deserialize_pandas(buf, use_threads=True):
+def deserialize_pandas(buf, *, use_threads=True):
     """Deserialize a buffer protocol compatible object into a pandas DataFrame.
 
     Parameters
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index 5a7fda8..50e993d 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -706,6 +706,12 @@ def test_flight_do_get_ints():
         data = client.do_get(flight.Ticket(b'ints')).read_all()
         assert data.equals(table)
 
+    with pytest.raises(flight.FlightServerError,
+                       match="expected IpcWriteOptions, got <class 'int'>"):
+        with ConstantFlightServer(options=42) as server:
+            client = flight.connect(('localhost', server.port))
+            data = client.do_get(flight.Ticket(b'ints')).read_all()
+
 
 @pytest.mark.pandas
 def test_do_get_ints_pandas():
diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py
index 7e53be7..44f8499 100644
--- a/python/pyarrow/tests/test_ipc.py
+++ b/python/pyarrow/tests/test_ipc.py
@@ -24,6 +24,7 @@ import threading
 import numpy as np
 
 import pyarrow as pa
+from pyarrow.tests.util import changed_environ
 
 
 try:
@@ -315,7 +316,45 @@ def test_stream_simple_roundtrip(stream_fixture, 
use_legacy_ipc_format):
         reader.read_next_batch()
 
 
-def test_options_legacy_exclusive(stream_fixture):
+def test_write_options():
+    options = pa.ipc.IpcWriteOptions()
+    assert options.use_legacy_format is False
+    assert options.metadata_version == pa.ipc.MetadataVersion.V5
+
+    options.use_legacy_format = True
+    assert options.use_legacy_format is True
+
+    options.metadata_version = pa.ipc.MetadataVersion.V4
+    assert options.metadata_version == pa.ipc.MetadataVersion.V4
+    for value in ('V5', 42):
+        with pytest.raises((TypeError, ValueError)):
+            options.metadata_version = value
+
+    assert options.compression is None
+    for value in ['lz4', 'zstd']:
+        options.compression = value
+        assert options.compression == value
+        options.compression = value.upper()
+        assert options.compression == value
+    options.compression = None
+    assert options.compression is None
+
+    assert options.use_threads is True
+    options.use_threads = False
+    assert options.use_threads is False
+
+    options = pa.ipc.IpcWriteOptions(
+        metadata_version=pa.ipc.MetadataVersion.V4,
+        use_legacy_format=True,
+        compression='lz4',
+        use_threads=False)
+    assert options.metadata_version == pa.ipc.MetadataVersion.V4
+    assert options.use_legacy_format is True
+    assert options.compression == 'lz4'
+    assert options.use_threads is False
+
+
+def test_write_options_legacy_exclusive(stream_fixture):
     with pytest.raises(
             ValueError,
             match="provide at most one of options and use_legacy_format"):
@@ -365,36 +404,30 @@ def test_envvar_set_legacy_ipc_format():
     assert not writer._use_legacy_format
     assert writer._metadata_version == pa.ipc.MetadataVersion.V5
 
-    import os
-
-    os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = '1'
-    writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
-    assert writer._use_legacy_format
-    assert writer._metadata_version == pa.ipc.MetadataVersion.V5
-    writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
-    assert writer._use_legacy_format
-    assert writer._metadata_version == pa.ipc.MetadataVersion.V5
-    del os.environ['ARROW_PRE_0_15_IPC_FORMAT']
-
-    os.environ['ARROW_PRE_1_0_METADATA_VERSION'] = '1'
-    writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
-    assert not writer._use_legacy_format
-    assert writer._metadata_version == pa.ipc.MetadataVersion.V4
-    writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
-    assert not writer._use_legacy_format
-    assert writer._metadata_version == pa.ipc.MetadataVersion.V4
-    del os.environ['ARROW_PRE_1_0_METADATA_VERSION']
-
-    os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = '1'
-    os.environ['ARROW_PRE_1_0_METADATA_VERSION'] = '1'
-    writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
-    assert writer._use_legacy_format
-    assert writer._metadata_version == pa.ipc.MetadataVersion.V4
-    writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
-    assert writer._use_legacy_format
-    assert writer._metadata_version == pa.ipc.MetadataVersion.V4
-    del os.environ['ARROW_PRE_0_15_IPC_FORMAT']
-    del os.environ['ARROW_PRE_1_0_METADATA_VERSION']
+    with changed_environ('ARROW_PRE_0_15_IPC_FORMAT', '1'):
+        writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
+        assert writer._use_legacy_format
+        assert writer._metadata_version == pa.ipc.MetadataVersion.V5
+        writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
+        assert writer._use_legacy_format
+        assert writer._metadata_version == pa.ipc.MetadataVersion.V5
+
+    with changed_environ('ARROW_PRE_1_0_METADATA_VERSION', '1'):
+        writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
+        assert not writer._use_legacy_format
+        assert writer._metadata_version == pa.ipc.MetadataVersion.V4
+        writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
+        assert not writer._use_legacy_format
+        assert writer._metadata_version == pa.ipc.MetadataVersion.V4
+
+    with changed_environ('ARROW_PRE_1_0_METADATA_VERSION', '1'):
+        with changed_environ('ARROW_PRE_0_15_IPC_FORMAT', '1'):
+            writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
+            assert writer._use_legacy_format
+            assert writer._metadata_version == pa.ipc.MetadataVersion.V4
+            writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
+            assert writer._use_legacy_format
+            assert writer._metadata_version == pa.ipc.MetadataVersion.V4
 
 
 def test_stream_read_all(stream_fixture):
diff --git a/python/pyarrow/tests/util.py b/python/pyarrow/tests/util.py
index dccf2d0..50844d2 100644
--- a/python/pyarrow/tests/util.py
+++ b/python/pyarrow/tests/util.py
@@ -194,3 +194,19 @@ def invoke_script(script_name, *args):
     cmd.extend(args)
 
     subprocess.check_call(cmd, env=subprocess_env)
+
+
+@contextlib.contextmanager
+def changed_environ(name, value):
+    """
+    Temporarily set environment variable *name* to *value*.
+    """
+    orig_value = os.environ.get(name)
+    os.environ[name] = value
+    try:
+        yield
+    finally:
+        if orig_value is None:
+            del os.environ[name]
+        else:
+            os.environ[name] = orig_value

Reply via email to