dianfu commented on a change in pull request #11718: [FLINK-17118][python] Support Primitive DataTypes in Cython
URL: https://github.com/apache/flink/pull/11718#discussion_r408766504
 
 

 ##########
 File path: flink-python/pyflink/fn_execution/fast_coder_impl.pyx
 ##########
 @@ -0,0 +1,559 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+# cython: language_level = 3
+# cython: infer_types = True
+# cython: profile=True
+# cython: boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True
+
+cimport libc.stdlib
+from libc.string cimport strlen
+
+import datetime
+
+cdef class InputStreamAndFunctionWrapper:
+    """
+    Pairs a user-defined function with the input stream wrapper whose rows
+    it is applied to, so both can be passed to an encoder as one object.
+    """
+
+    def __cinit__(self, func, input_stream_wrapper):
+        self.input_stream_wrapper = input_stream_wrapper
+        self.func = func
+
+cdef class PassThroughLengthPrefixCoderImpl(StreamCoderImpl):
+    """
+    Forwards encoding and decoding unchanged to the wrapped value coder;
+    this coder writes and reads nothing of its own.
+    """
+
+    def __cinit__(self, value_coder):
+        self._value_coder = value_coder
+
+    cpdef encode_to_stream(self, value, OutputStream out_stream, bint nested):
+        # Delegate straight to the wrapped coder.
+        self._value_coder.encode_to_stream(value, out_stream, nested)
+
+    cpdef decode_from_stream(self, InputStream in_stream, bint nested):
+        decoded = self._value_coder.decode_from_stream(in_stream, nested)
+        return decoded
+
+    cpdef get_estimated_size_and_observables(self, value, bint nested=False):
+        # Always reports an estimated size of 0 and no observables.
+        estimated_size = 0
+        observables = []
+        return estimated_size, observables
+
+cdef class TableFunctionRowCoderImpl(FlattenRowCoderImpl):
+    """
+    Row coder for table functions: the user function may return several
+    output rows per input row, and an end message is written after the
+    output of every input row.
+    """
+
+    def __init__(self, flatten_row_coder):
+        field_coders = flatten_row_coder._output_field_coders
+        super(TableFunctionRowCoderImpl, self).__init__(field_coders)
+
+    cpdef encode_to_stream(self, input_stream_and_function_wrapper,
+                           OutputStream out_stream, bint nested):
+        self._prepare_encode(input_stream_and_function_wrapper, out_stream)
+        while self._input_pos < self._input_buffer_size:
+            self._decode_next_row()
+            results = self.func(self.row)
+            if results:
+                for output_row in results:
+                    # With exactly one output column, each yielded value is
+                    # wrapped in a 1-tuple so _encode_one_row can index it.
+                    if self._output_field_count == 1:
+                        output_row = (output_row,)
+                    self._encode_one_row(output_row)
+                    self._maybe_flush(out_stream)
+            # Terminates the group of rows produced for this input row.
+            self._encode_end_message()
+        self._map_output_data_to_output_stream(out_stream)
+
+cdef class FlattenRowCoderImpl(StreamCoderImpl):
+    def __init__(self, field_coders):
+        """
+        :param field_coders: the coders of the output fields, one per column;
+            each must provide ``type_name()`` and ``coder_type()``.
+        :raises MemoryError: if one of the native buffers cannot be allocated.
+        """
+        self._output_field_coders = field_coders
+        self._output_field_count = len(self._output_field_coders)
+        # Per-field type tags, filled in by _init_attribute.
+        self._output_field_type = <TypeName*> libc.stdlib.malloc(
+            self._output_field_count * sizeof(TypeName))
+        self._output_coder_type = <CoderType*> libc.stdlib.malloc(
+            self._output_field_count * sizeof(CoderType))
+        # The null mask uses one bit per field: this many full bytes plus one
+        # partial byte holding the remaining bits.
+        self._output_leading_complete_bytes_num = self._output_field_count // 8
+        self._output_remaining_bits_num = self._output_field_count % 8
+        self._output_row_buffer_size = 1024
+        self._output_row_pos = 0
+        self._output_row_data = <char*> libc.stdlib.malloc(
+            self._output_row_buffer_size)
+        self._null_byte_search_table = <unsigned char*> libc.stdlib.malloc(
+            8 * sizeof(unsigned char))
+        # Fail loudly instead of letting _init_attribute write through NULL.
+        # malloc(0) may legitimately return NULL, so the per-field arrays are
+        # only checked when there is at least one field.
+        if self._output_field_count > 0 and (
+                self._output_field_type == NULL or self._output_coder_type == NULL):
+            raise MemoryError()
+        if self._output_row_data == NULL or self._null_byte_search_table == NULL:
+            raise MemoryError()
+        self._init_attribute()
+
+    cpdef decode_from_stream(self, InputStream in_stream, bint nested):
+        # No rows are decoded here: the raw stream is only bundled with this
+        # coder's field metadata. Rows are presumably decoded later through
+        # _prepare_encode / _decode_next_row.
+        return self._wrap_input_stream(in_stream, in_stream.size())
+
+    cpdef encode_to_stream(self, input_stream_and_function_wrapper,
+                           OutputStream out_stream, bint nested):
+        # For every input row: decode it, apply the user function, encode the
+        # returned row and give _maybe_flush a chance to write buffered output.
+        cdef list output_row
+        self._prepare_encode(input_stream_and_function_wrapper, out_stream)
+        while self._input_pos < self._input_buffer_size:
+            self._decode_next_row()
+            output_row = self.func(self.row)
+            self._encode_one_row(output_row)
+            self._maybe_flush(out_stream)
+        self._map_output_data_to_output_stream(out_stream)
+
+    cdef void _init_attribute(self):
+        # Bit masks used to test null bits, most significant bit first:
+        # entry k is 0x80 >> k (0x80, 0x40, ..., 0x01).
+        cdef int bit
+        for bit in range(8):
+            self._null_byte_search_table[bit] = 0x80 >> bit
+        # Cache each field coder's type tags in the native arrays.
+        for i in range(self._output_field_count):
+            field_coder = self._output_field_coders[i]
+            self._output_field_type[i] = field_coder.type_name()
+            self._output_coder_type[i] = field_coder.coder_type()
+
+    cdef InputStreamWrapper _wrap_input_stream(self, InputStream input_stream,
+                                               size_t size):
+        # Wraps the input stream together with this coder's field metadata so
+        # that both can be handed on to later operations as a single object.
+        cdef InputStreamWrapper wrapper = InputStreamWrapper()
+        wrapper.input_stream = input_stream
+        wrapper.input_buffer_size = size
+        # Sets the stream position to `size`; presumably this marks the
+        # buffered bytes as consumed for the caller — TODO confirm.
+        wrapper.input_stream.pos = size
+        wrapper.input_field_coders = self._output_field_coders
+        wrapper.input_field_count = self._output_field_count
+        wrapper.input_field_type = self._output_field_type
+        wrapper.input_coder_type = self._output_coder_type
+        wrapper.input_leading_complete_bytes_num = \
+            self._output_leading_complete_bytes_num
+        wrapper.input_remaining_bits_num = self._output_remaining_bits_num
+        return wrapper
+
+    cdef void _encode_one_row(self, value):
+        # Writes the row's null mask, then every non-null field, then copies
+        # the row buffer into the output buffer.
+        cdef libc.stdint.int32_t field_index
+        self._write_null_mask(value, self._output_leading_complete_bytes_num,
+                              self._output_remaining_bits_num)
+        for field_index in range(self._output_field_count):
+            item = value[field_index]
+            if item is None:
+                continue
+            # Only SIMPLE coder types are encoded; other types write nothing.
+            if self._output_coder_type[field_index] == SIMPLE:
+                self._encode_field_simple(
+                    self._output_field_type[field_index], item)
+        self._copy_row_buffer_to_output_buffer()
+
+    cdef void _read_null_mask(self, bint*null_mask,
+                              libc.stdint.int32_t 
input_leading_complete_bytes_num,
+                              libc.stdint.int32_t input_remaining_bits_num):
+        # Reads the row's null mask from the input buffer into `null_mask`,
+        # one flag per field (a set bit means the field is null).
+        # Fields are packed 8 per byte, most significant bit first, using the
+        # masks in _null_byte_search_table. A final partial byte holds the
+        # remaining `input_remaining_bits_num` fields.
+        # Advances self._input_pos past the mask bytes.
+        cdef libc.stdint.int32_t field_pos, i
+        cdef unsigned char b
+        field_pos = 0
+        # full bytes: 8 fields each
+        for _ in range(input_leading_complete_bytes_num):
+            b = self._input_data[self._input_pos]
+            self._input_pos += 1
+            for i in range(8):
+                null_mask[field_pos] = (b & self._null_byte_search_table[i]) > 0
+                field_pos += 1
+
+        # trailing partial byte, only present when the field count is not a
+        # multiple of 8
+        if input_remaining_bits_num:
+            b = self._input_data[self._input_pos]
+            self._input_pos += 1
+            for i in range(input_remaining_bits_num):
+                null_mask[field_pos] = (b & self._null_byte_search_table[i]) > 0
+                field_pos += 1
+
+    cdef void _decode_next_row(self):
+        # Decodes the next input row into self.row, in place.
+        cdef libc.stdint.int32_t field_index
+        # Skip the variable-length int length prefix: every byte of it except
+        # the last has the high bit set.
+        while self._input_data[self._input_pos] & 0x80:
+            self._input_pos += 1
+        self._input_pos += 1
+        self._read_null_mask(self._null_mask,
+                             self._input_leading_complete_bytes_num,
+                             self._input_remaining_bits_num)
+        for field_index in range(self._input_field_count):
+            if self._null_mask[field_index]:
+                self.row[field_index] = None
+            elif self._input_coder_type[field_index] == SIMPLE:
+                self.row[field_index] = self._decode_field_simple(
+                    self._input_field_type[field_index])
+
+    cdef object _decode_field_simple(self, TypeName field_type):
+        # Decodes one non-null field whose coder type is SIMPLE, starting at the
+        # current input position, and returns it as a Python object.
+        # Returns None for a type not handled below.
+        cdef libc.stdint.int32_t value, minutes, seconds, hours
+        cdef libc.stdint.int64_t milliseconds
+        if field_type == TINYINT:
+            # tinyint is a signed 8-bit integer, but _decode_byte returns an
+            # unsigned char; reinterpret it so negative values survive
+            # (0xFF -> -1 rather than 255).
+            return <signed char> self._decode_byte()
+        elif field_type == SMALLINT:
+            # smallint
+            return self._decode_smallint()
+        elif field_type == INT:
+            # int
+            return self._decode_int()
+        elif field_type == BIGINT:
+            # bigint
+            return self._decode_bigint()
+        elif field_type == BOOLEAN:
+            # boolean: any non-zero byte is True
+            return not not self._decode_byte()
+        elif field_type == FLOAT:
+            # float
+            return self._decode_float()
+        elif field_type == DOUBLE:
+            # double
+            return self._decode_double()
+        elif field_type == BINARY:
+            # bytes
+            return self._decode_bytes()
+        elif field_type == CHAR:
+            # str, stored as UTF-8 bytes
+            return self._decode_bytes().decode("utf-8")
+        elif field_type == DATE:
+            # Date: days since 1970-01-01.
+            # EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()
+            # The value of EPOCH_ORDINAL is 719163
+            return datetime.date.fromordinal(self._decode_int() + 719163)
+        elif field_type == TIME:
+            # Time: milliseconds since midnight
+            value = self._decode_int()
+            seconds = value // 1000
+            milliseconds = value % 1000
+            minutes = seconds // 60
+            seconds %= 60
+            hours = minutes // 60
+            minutes %= 60
+            return datetime.time(hours, minutes, seconds, milliseconds * 1000)
+
+    cdef unsigned char _decode_byte(self) except? -1:
+        # Returns the byte at the current position as an unsigned value
+        # (0..255) and steps past it.
+        cdef unsigned char current = <unsigned char> self._input_data[self._input_pos]
+        self._input_pos += 1
+        return current
+
+    cdef libc.stdint.int16_t _decode_smallint(self) except? -1:
+        # Reads a 2-byte big-endian integer (most significant byte first) and
+        # advances the position by 2. The bytes are combined as unsigned, and
+        # the narrowing conversion to int16 recovers negative values.
+        # NOTE(review): that conversion is implementation-defined in C; it
+        # wraps (two's complement) on mainstream compilers.
+        self._input_pos += 2
+        return (<unsigned char> self._input_data[self._input_pos - 1]
+                | <libc.stdint.uint32_t> <unsigned char> 
self._input_data[self._input_pos - 2] << 8)
+
+    cdef libc.stdint.int32_t _decode_int(self) except? -1:
+        # Reads a 4-byte big-endian integer (most significant byte first) and
+        # advances the position by 4. The bytes are combined as uint32, and
+        # the conversion to the int32 return type recovers the sign.
+        # `except? -1`: -1 is a legal result, so Cython additionally checks
+        # for a pending exception when -1 is returned.
+        self._input_pos += 4
+        return (<unsigned char> self._input_data[self._input_pos - 1]
+                | <libc.stdint.uint32_t> <unsigned char> 
self._input_data[self._input_pos - 2] << 8
+                | <libc.stdint.uint32_t> <unsigned char> 
self._input_data[self._input_pos - 3] << 16
+                | <libc.stdint.uint32_t> <unsigned char> self._input_data[
+                    self._input_pos - 4] << 24)
+
+    cdef libc.stdint.int64_t _decode_bigint(self) except? -1:
+        # Reads an 8-byte big-endian integer (most significant byte first) and
+        # advances the position by 8. The bytes are combined as uint64, and
+        # the conversion to the int64 return type recovers the sign.
+        self._input_pos += 8
+        return (<unsigned char> self._input_data[self._input_pos - 1]
+                | <libc.stdint.uint64_t> <unsigned char> 
self._input_data[self._input_pos - 2] << 8
+                | <libc.stdint.uint64_t> <unsigned char> 
self._input_data[self._input_pos - 3] << 16
+                | <libc.stdint.uint64_t> <unsigned char> 
self._input_data[self._input_pos - 4] << 24
+                | <libc.stdint.uint64_t> <unsigned char> 
self._input_data[self._input_pos - 5] << 32
+                | <libc.stdint.uint64_t> <unsigned char> 
self._input_data[self._input_pos - 6] << 40
+                | <libc.stdint.uint64_t> <unsigned char> 
self._input_data[self._input_pos - 7] << 48
+                | <libc.stdint.uint64_t> <unsigned char> self._input_data[
+                    self._input_pos - 8] << 56)
+
+    cdef float _decode_float(self) except? -1:
+        # Reinterprets the bit pattern of a big-endian int32 as a C float.
+        cdef libc.stdint.int32_t raw_bits = self._decode_int()
+        cdef float *value_ptr = <float*> <char*> &raw_bits
+        return value_ptr[0]
+
+    cdef double _decode_double(self) except? -1:
+        # Reinterprets the bit pattern of a big-endian int64 as a C double.
+        cdef libc.stdint.int64_t raw_bits = self._decode_bigint()
+        cdef double *value_ptr = <double*> <char*> &raw_bits
+        return value_ptr[0]
+
+    cdef bytes _decode_bytes(self):
+        # Length-prefixed byte string: an int32 length followed by that many
+        # bytes. The bytes are copied into a new Python bytes object.
+        cdef libc.stdint.int32_t size = self._decode_int()
+        start = self._input_pos
+        self._input_pos = start + size
+        return self._input_data[start:self._input_pos]
+
+    cdef void _prepare_encode(self, InputStreamAndFunctionWrapper input_stream_and_function_wrapper,
+                              OutputStream out_stream):
+        # Sets up the state for one encode_to_stream call: caches the output
+        # stream's buffer, the input buffer and the input field metadata,
+        # (re)allocates the per-row null mask and resets the result row.
+        cdef InputStreamWrapper input_stream_wrapper
+        # get the data pointer of output_stream
+        self._output_data = out_stream.data
+        self._output_pos = out_stream.pos
+        self._output_buffer_size = out_stream.buffer_size
+        self._output_row_pos = 0
+
+        input_stream_wrapper = input_stream_and_function_wrapper.input_stream_wrapper
+        # get the data pointer of input_stream
+        self._input_data = input_stream_wrapper.input_stream.allc
+        self._input_buffer_size = input_stream_wrapper.input_buffer_size
+
+        # get the info of the input coder, used to decode data from input_stream
+        self._input_field_count = input_stream_wrapper.input_field_count
+        self._input_leading_complete_bytes_num = \
+            input_stream_wrapper.input_leading_complete_bytes_num
+        self._input_remaining_bits_num = input_stream_wrapper.input_remaining_bits_num
+        self._input_field_type = input_stream_wrapper.input_field_type
+        self._input_coder_type = input_stream_wrapper.input_coder_type
+        self._input_field_coders = input_stream_wrapper.input_field_coders
+        # Release the mask allocated by a previous call before allocating a
+        # new one; previously every call leaked it. On the first call the
+        # pointer is NULL (Cython zero-initializes C attributes of extension
+        # types), and free(NULL) is a no-op.
+        libc.stdlib.free(self._null_mask)
+        self._null_mask = <bint*> libc.stdlib.malloc(
+            self._input_field_count * sizeof(bint))
+        self._input_pos = 0
+
+        # initialize the result row and get the Python user-defined function
+        self.row = [None] * self._input_field_count
+        self.func = input_stream_and_function_wrapper.func
+
+    cdef void _encode_field_simple(self, TypeName field_type, item):
+        # Encodes one non-null field value of a SIMPLE coder type into the row
+        # buffer; a type not handled below writes nothing.
+        cdef libc.stdint.int32_t h, m, s, us, millis_of_day
+        if field_type == TINYINT or field_type == BOOLEAN:
+            # tinyint and boolean are both written with _encode_byte
+            self._encode_byte(item)
+        elif field_type == SMALLINT:
+            self._encode_smallint(item)
+        elif field_type == INT:
+            self._encode_int(item)
+        elif field_type == BIGINT:
+            self._encode_bigint(item)
+        elif field_type == FLOAT:
+            self._encode_float(item)
+        elif field_type == DOUBLE:
+            self._encode_double(item)
+        elif field_type == BINARY:
+            self._encode_bytes(item)
+        elif field_type == CHAR:
+            # str is written as its UTF-8 encoded bytes
+            self._encode_bytes(item.encode('utf-8'))
+        elif field_type == DATE:
+            # days since 1970-01-01; 719163 is
+            # datetime.datetime(1970, 1, 1).toordinal()
+            self._encode_int(item.toordinal() - 719163)
+        elif field_type == TIME:
+            # milliseconds since midnight; sub-millisecond precision is dropped
+            h = item.hour
+            m = item.minute
+            s = item.second
+            us = item.microsecond
+            millis_of_day = h * 3600000 + m * 60000 + s * 1000 + us // 1000
+            self._encode_int(millis_of_day)
+
+    # write 0x00 as end message of udtf
+    cdef void _encode_end_message(self):
 
 Review comment:
   Move this method to TableFunctionRowCoderImpl?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to