dianfu commented on a change in pull request #11718: [FLINK-17118][python] Support Primitive DataTypes in Cython
URL: https://github.com/apache/flink/pull/11718#discussion_r407917041
 
 

 ##########
 File path: flink-python/pyflink/fn_execution/fast_coder_impl.pyx
 ##########
 @@ -0,0 +1,558 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+# cython: language_level = 3
+# cython: infer_types = True
+# cython: profile=True
+# cython: boundscheck=False, wraparound=False, initializedcheck=False, 
cdivision=True
+
+cimport libc.stdlib
+from libc.string cimport strlen
+
+import datetime
+
+cdef class WrapperFuncInputElement:
+    def __cinit__(self, func, wrapper_input_element):
+        self.func = func
+        self.wrapper_input_element = wrapper_input_element
+
+cdef class WrapperInputElement:
+    def __cinit__(self, input_stream):
+        self.input_stream = input_stream
+
+cdef class PassThroughLengthPrefixCoderImpl(StreamCoderImpl):
+    def __cinit__(self, value_coder):
+        self._value_coder = value_coder
+
+    cpdef encode_to_stream(self, value, OutputStream out_stream, bint nested):
+        self._value_coder.encode_to_stream(value, out_stream, False)
+
+    cpdef decode_from_stream(self, InputStream in_stream, bint nested):
+        return self._value_coder.decode_from_stream(in_stream, nested)
+
+    cpdef get_estimated_size_and_observables(self, value, bint nested=False):
+        return 0, []
+
+cdef class TableFunctionRowCoderImpl(StreamCoderImpl):
+    def __cinit__(self, flatten_row_coder):
+        self._flatten_row_coder = flatten_row_coder
+
+    cpdef encode_to_stream(self, value, OutputStream out_stream, bint nested):
+        self._flatten_row_coder.encode_table_row_result(value, out_stream)
+
+    cpdef decode_from_stream(self, InputStream in_stream, bint nested):
+        return self._flatten_row_coder.decode_from_stream(in_stream, nested)
+
+cdef class FlattenRowCoderImpl(StreamCoderImpl):
+    def __cinit__(self, field_coders):
+        self._output_field_coders = field_coders
+        self._output_field_count = len(self._output_field_coders)
+        self._output_field_type = <TypeName*> libc.stdlib.malloc(
+            self._output_field_count * sizeof(TypeName))
+        self._output_coder_type = <CoderType*> libc.stdlib.malloc(
+            self._output_field_count * sizeof(CoderType))
+        self._output_leading_complete_bytes_num = self._output_field_count // 8
+        self._output_remaining_bits_num = self._output_field_count % 8
+        self._output_row_buffer_size = 1024
+        self._output_row_pos = 0
+        self._output_row_data = <char*> 
libc.stdlib.malloc(self._output_row_buffer_size)
+        self._null_byte_search_table = <unsigned char*> libc.stdlib.malloc(
+            8 * sizeof(unsigned char))
+        self._init_attribute()
+
+    cpdef decode_from_stream(self, InputStream in_stream, bint nested):
+        cdef WrapperInputElement wrapper_input_element
+        wrapper_input_element = WrapperInputElement(in_stream)
+        self._consume_input_data(wrapper_input_element, in_stream.size())
+        return wrapper_input_element
+
+    cpdef encode_to_stream(self, wrapper_stream, OutputStream out_stream, bint 
nested):
+        self.encode_row_result(wrapper_stream, out_stream)
+
+    cdef encode_row_result(self, WrapperFuncInputElement 
wrapper_func_input_element,
+                           OutputStream out_stream):
+        cdef list result
+        self._before_encode(wrapper_func_input_element, out_stream)
+        while self._input_buffer_size > self._input_pos:
+            self._load_row()
+            result = self.func(self.row)
+            self._write_data(result)
+            self._dump_row()
+        self._after_encode(out_stream)
+
+    cdef encode_table_row_result(self, WrapperFuncInputElement 
wrapper_func_input_element,
+                                 OutputStream out_stream):
+        self._before_encode(wrapper_func_input_element, out_stream)
+        while self._input_buffer_size > self._input_pos:
+            self._load_row()
+            result = self.func(self.row)
+            if result:
+                for value in result:
+                    if self._output_field_count == 1:
+                        value = (value,)
+                    self._write_data(value)
+            self._dump_end_message()
+
+        self._after_encode(out_stream)
+
+    cdef void _init_attribute(self):
+        self._null_byte_search_table[0] = 0x80
+        self._null_byte_search_table[1] = 0x40
+        self._null_byte_search_table[2] = 0x20
+        self._null_byte_search_table[3] = 0x10
+        self._null_byte_search_table[4] = 0x08
+        self._null_byte_search_table[5] = 0x04
+        self._null_byte_search_table[6] = 0x02
+        self._null_byte_search_table[7] = 0x01
+        for i in range(self._output_field_count):
+            self._output_field_type[i] = 
self._output_field_coders[i].type_name()
+            self._output_coder_type[i] = 
self._output_field_coders[i].coder_type()
+
+    cdef void _consume_input_data(self, WrapperInputElement 
wrapper_input_element, size_t size):
+        # wrappers the input field coders and input_stream together
+        # so that it can be transposed to operations
+        wrapper_input_element.input_field_coders = self._output_field_coders
+        wrapper_input_element.input_remaining_bits_num = 
self._output_remaining_bits_num
+        wrapper_input_element.input_leading_complete_bytes_num = \
+            self._output_leading_complete_bytes_num
+        wrapper_input_element.input_field_count = self._output_field_count
+        wrapper_input_element.input_field_type = self._output_field_type
+        wrapper_input_element.input_coder_type = self._output_coder_type
+        wrapper_input_element.input_stream.pos = size
+        wrapper_input_element.input_buffer_size = size
+
+    cdef void _write_data(self, value):
+        cdef libc.stdint.int32_t i
+        self._write_null_mask(value, self._output_leading_complete_bytes_num,
+                              self._output_remaining_bits_num)
+        for i in range(self._output_field_count):
+            item = value[i]
+            if item is not None:
+                if self._output_coder_type[i] == PRIMITIVE:
+                    self._dump_field_primitive(self._output_field_type[i], 
item)
+
+        self._dump_row()
+
+    cdef void _read_null_mask(self, bint*null_mask,
+                              libc.stdint.int32_t 
input_leading_complete_bytes_num,
+                              libc.stdint.int32_t input_remaining_bits_num):
+        cdef libc.stdint.int32_t field_pos, i
+        cdef unsigned char b
+        field_pos = 0
+        for _ in range(input_leading_complete_bytes_num):
+            b = self._input_data[self._input_pos]
+            self._input_pos += 1
+            for i in range(8):
+                null_mask[field_pos] = (b & self._null_byte_search_table[i]) > 0
+                field_pos += 1
+
+        if input_remaining_bits_num:
+            b = self._input_data[self._input_pos]
+            self._input_pos += 1
+            for i in range(input_remaining_bits_num):
+                null_mask[field_pos] = (b & self._null_byte_search_table[i]) > 0
+                field_pos += 1
+
+    cdef void _load_row(self):
+        cdef libc.stdint.int32_t i
+        # skip prefix variable int length
+        while self._input_data[self._input_pos] & 0x80:
+            self._input_pos += 1
+        self._input_pos += 1
+        self._read_null_mask(self._null_mask, 
self._input_leading_complete_bytes_num,
+                             self._input_remaining_bits_num)
+        for i in range(self._input_field_count):
+            if self._null_mask[i]:
+                self.row[i] = None
+            else:
+                if self._input_coder_type[i] == PRIMITIVE:
+                    self.row[i] = 
self._load_field_primitive(self._input_field_type[i])
+
+    cdef object _load_field_primitive(self, TypeName field_type):
+        cdef libc.stdint.int32_t value, minutes, seconds, hours
+        cdef libc.stdint.int64_t milliseconds
+        if field_type == TINYINT:
+            # tinyint
+            return self._load_byte()
+        elif field_type == SMALLINT:
+            # smallint
+            return self._load_smallint()
+        elif field_type == INT:
+            # int
+            return self._load_int()
+        elif field_type == BIGINT:
+            # bigint
+            return self._load_bigint()
+        elif field_type == BOOLEAN:
+            # boolean
+            return not not self._load_byte()
+        elif field_type == FLOAT:
+            # float
+            return self._load_float()
+        elif field_type == DOUBLE:
+            # double
+            return self._load_double()
+        elif field_type == BINARY:
+            # bytes
+            return self._load_bytes()
+        elif field_type == CHAR:
+            # str
+            return self._load_bytes().decode("utf-8")
+        elif field_type == DATE:
+            # Date
+            return datetime.date.fromordinal(self._load_int() + 719163)
+        elif field_type == TIME:
+            # Time
+            value = self._load_int()
+            seconds = value // 1000
+            milliseconds = value % 1000
+            minutes = seconds // 60
+            seconds %= 60
+            hours = minutes // 60
+            minutes %= 60
+            return datetime.time(hours, minutes, seconds, milliseconds * 1000)
+
+    cdef unsigned char _load_byte(self) except? -1:
+        self._input_pos += 1
+        return <unsigned char> self._input_data[self._input_pos - 1]
+
+    cdef libc.stdint.int16_t _load_smallint(self) except? -1:
+        self._input_pos += 2
+        return (<unsigned char> self._input_data[self._input_pos - 1]
+                | <libc.stdint.uint32_t> <unsigned char> 
self._input_data[self._input_pos - 2] << 8)
+
+    cdef libc.stdint.int32_t _load_int(self) except? -1:
+        self._input_pos += 4
+        return (<unsigned char> self._input_data[self._input_pos - 1]
+                | <libc.stdint.uint32_t> <unsigned char> 
self._input_data[self._input_pos - 2] << 8
+                | <libc.stdint.uint32_t> <unsigned char> 
self._input_data[self._input_pos - 3] << 16
+                | <libc.stdint.uint32_t> <unsigned char> self._input_data[
+                    self._input_pos - 4] << 24)
+
+    cdef libc.stdint.int64_t _load_bigint(self) except? -1:
+        self._input_pos += 8
+        return (<unsigned char> self._input_data[self._input_pos - 1]
+                | <libc.stdint.uint64_t> <unsigned char> 
self._input_data[self._input_pos - 2] << 8
+                | <libc.stdint.uint64_t> <unsigned char> 
self._input_data[self._input_pos - 3] << 16
+                | <libc.stdint.uint64_t> <unsigned char> 
self._input_data[self._input_pos - 4] << 24
+                | <libc.stdint.uint64_t> <unsigned char> 
self._input_data[self._input_pos - 5] << 32
+                | <libc.stdint.uint64_t> <unsigned char> 
self._input_data[self._input_pos - 6] << 40
+                | <libc.stdint.uint64_t> <unsigned char> 
self._input_data[self._input_pos - 7] << 48
+                | <libc.stdint.uint64_t> <unsigned char> self._input_data[
+                    self._input_pos - 8] << 56)
+
+    cdef float _load_float(self) except? -1:
+        cdef libc.stdint.int32_t as_long = self._load_int()
+        return (<float*> <char*> &as_long)[0]
+
+    cdef double _load_double(self) except? -1:
+        cdef libc.stdint.int64_t as_long = self._load_bigint()
+        return (<double*> <char*> &as_long)[0]
+
+    cdef bytes _load_bytes(self):
+        cdef libc.stdint.int32_t size = self._load_int()
+        self._input_pos += size
+        return self._input_data[self._input_pos - size: self._input_pos]
+
+    cdef void _before_encode(self, WrapperFuncInputElement 
wrapper_func_input_element,
+                             OutputStream out_stream):
+        cdef WrapperInputElement wrapper_input_element
+        # get the data pointer of output_stream
+        self._output_data = out_stream.data
+        self._output_pos = out_stream.pos
+        self._output_buffer_size = out_stream.buffer_size
+        self._output_row_pos = 0
+
+        # get the data pointer of input_stream
+        self._input_data = 
wrapper_func_input_element.wrapper_input_element.input_stream.allc
+        self._input_buffer_size = 
wrapper_func_input_element.wrapper_input_element.input_buffer_size
+
+        # get the infos of input coder which will be used to decode data from 
input_stream
+        wrapper_input_element = 
wrapper_func_input_element.wrapper_input_element
+        self._input_field_count = wrapper_input_element.input_field_count
+        self._input_leading_complete_bytes_num = 
wrapper_input_element.input_leading_complete_bytes_num
+        self._input_remaining_bits_num = 
wrapper_input_element.input_remaining_bits_num
+        self._input_field_type = wrapper_input_element.input_field_type
+        self._input_coder_type = wrapper_input_element.input_coder_type
+        self._input_field_coders = wrapper_input_element.input_field_coders
+        self._null_mask = <bint*> libc.stdlib.malloc(self._input_field_count * 
sizeof(bint))
+        self._input_pos = 0
+
+        # initial the result row and get the Python user-defined function
+        self.row = [None for _ in range(self._input_field_count)]
+        self.func = wrapper_func_input_element.func
+
+    cdef void _after_encode(self, OutputStream out_stream):
+        # map the output_data to the buffer of output_stream
+        out_stream.data = self._output_data
+        out_stream.pos = self._output_pos
+        out_stream.buffer_size = self._output_buffer_size
+
+    cdef void _dump_field_primitive(self, TypeName field_type, item):
+        cdef libc.stdint.int32_t hour, minute, seconds, microsecond, 
milliseconds
+        if field_type == TINYINT:
+            # tinyint
+            self._dump_byte(item)
+        elif field_type == SMALLINT:
+            # smallint
+            self._dump_smallint(item)
+        elif field_type == INT:
+            # int
+            self._dump_int(item)
+        elif field_type == BIGINT:
+            # bigint
+            self._dump_bigint(item)
+        elif field_type == BOOLEAN:
+            # boolean
+            self._dump_byte(item)
+        elif field_type == FLOAT:
+            # float
+            self._dump_float(item)
+        elif field_type == DOUBLE:
+            # double
+            self._dump_double(item)
+        elif field_type == BINARY:
+            # bytes
+            self._dump_bytes(item)
+        elif field_type == CHAR:
+            # str
+            self._dump_bytes(item.encode('utf-8'))
+        elif field_type == DATE:
+            # Date
+            self._dump_int(item.toordinal() - 719163)
+        elif field_type == TIME:
+            # Time
+            hour = item.hour
+            minute = item.minute
+            seconds = item.second
+            microsecond = item.microsecond
+            milliseconds = hour * 3600000 + minute * 60000 + seconds * 1000 + 
microsecond // 1000
+            self._dump_int(milliseconds)
+
+    # write 0x00 as end message of udtf
+    cdef void _dump_end_message(self):
+        if self._output_buffer_size < self._output_pos + 2:
+            self._output_buffer_size *= 2
+            self._output_data = <char*> libc.stdlib.realloc(self._output_data,
+                                                            
self._output_buffer_size)
+        self._output_data[self._output_pos] = 0x01
+        self._output_data[self._output_pos + 1] = 0x00
+        self._output_pos += 2
+
+    cdef void _dump_row(self):
+        cdef size_t size
+        cdef size_t i
+        cdef bint is_realloc
+        cdef char bits
+        # the length of the variable prefix length will be less than 9 bytes
+        if self._output_buffer_size < self._output_pos + self._output_row_pos 
+ 9:
+            self._output_buffer_size += self._output_row_buffer_size + 9
+            self._output_data = <char*> libc.stdlib.realloc(self._output_data,
+                                                            
self._output_buffer_size)
+        size = self._output_row_pos
+        # write variable prefix length
+        while size:
+            bits = size & 0x7F
+            size >>= 7
+            if size:
+                bits |= 0x80
+            self._output_data[self._output_pos] = bits
+            self._output_pos += 1
+        if self._output_row_pos < 8:
+            # This is faster than memcpy when the string is short.
+            for i in range(self._output_row_pos):
+                self._output_data[self._output_pos + i] = 
self._output_row_data[i]
+        else:
+            libc.string.memcpy(self._output_data + self._output_pos, 
self._output_row_data,
+                               self._output_row_pos)
+        self._output_pos += self._output_row_pos
+        self._output_row_pos = 0
+
+    cdef void _dump_byte(self, unsigned char val):
+        if self._output_row_buffer_size < self._output_row_pos + 1:
+            self._output_row_buffer_size *= 2
+            self._output_row_data = <char*> 
libc.stdlib.realloc(self._output_row_data,
+                                                                
self._output_row_buffer_size)
+        self._output_row_data[self._output_row_pos] = val
+        self._output_row_pos += 1
+
+    cdef void _dump_smallint(self, libc.stdint.int16_t v):
+        if self._output_row_buffer_size < self._output_row_pos + 2:
+            self._output_row_buffer_size *= 2
+            self._output_row_data = <char*> 
libc.stdlib.realloc(self._output_row_data,
+                                                                
self._output_row_buffer_size)
+        self._output_row_data[self._output_row_pos] = <unsigned char> (v >> 8)
+        self._output_row_data[self._output_row_pos + 1] = <unsigned char> (v)
 
 Review comment:
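   Remove the redundant parentheses around `v`: `<unsigned char> (v)` can simply be `<unsigned char> v`. (Keep the ones in `<unsigned char> (v >> 8)` on the previous line, since a cast binds tighter than `>>`.)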

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to