dianfu commented on a change in pull request #11718: [FLINK-17118][python] 
Support Primitive DataTypes in Cython
URL: https://github.com/apache/flink/pull/11718#discussion_r407909862
 
 

 ##########
 File path: flink-python/pyflink/fn_execution/fast_coder_impl.pyx
 ##########
 @@ -0,0 +1,558 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+# cython: language_level = 3
+# cython: infer_types = True
+# cython: profile=True
+# cython: boundscheck=False, wraparound=False, initializedcheck=False, 
cdivision=True
+
+cimport libc.stdlib
+from libc.string cimport strlen
+
+import datetime
+
+cdef class WrapperFuncInputElement:
+    def __cinit__(self, func, wrapper_input_element):
+        self.func = func
+        self.wrapper_input_element = wrapper_input_element
+
+cdef class WrapperInputElement:
+    def __cinit__(self, input_stream):
+        self.input_stream = input_stream
+
+cdef class PassThroughLengthPrefixCoderImpl(StreamCoderImpl):
+    def __cinit__(self, value_coder):
+        self._value_coder = value_coder
+
+    cpdef encode_to_stream(self, value, OutputStream out_stream, bint nested):
+        self._value_coder.encode_to_stream(value, out_stream, False)
+
+    cpdef decode_from_stream(self, InputStream in_stream, bint nested):
+        return self._value_coder.decode_from_stream(in_stream, nested)
+
+    cpdef get_estimated_size_and_observables(self, value, bint nested=False):
+        return 0, []
+
+cdef class TableFunctionRowCoderImpl(StreamCoderImpl):
+    def __cinit__(self, flatten_row_coder):
+        self._flatten_row_coder = flatten_row_coder
+
+    cpdef encode_to_stream(self, value, OutputStream out_stream, bint nested):
+        self._flatten_row_coder.encode_table_row_result(value, out_stream)
+
+    cpdef decode_from_stream(self, InputStream in_stream, bint nested):
+        return self._flatten_row_coder.decode_from_stream(in_stream, nested)
+
+cdef class FlattenRowCoderImpl(StreamCoderImpl):
+    def __cinit__(self, field_coders):
+        self._output_field_coders = field_coders
+        self._output_field_count = len(self._output_field_coders)
+        self._output_field_type = <TypeName*> libc.stdlib.malloc(
+            self._output_field_count * sizeof(TypeName))
+        self._output_coder_type = <CoderType*> libc.stdlib.malloc(
+            self._output_field_count * sizeof(CoderType))
+        self._output_leading_complete_bytes_num = self._output_field_count // 8
+        self._output_remaining_bits_num = self._output_field_count % 8
+        self._output_row_buffer_size = 1024
+        self._output_row_pos = 0
+        self._output_row_data = <char*> 
libc.stdlib.malloc(self._output_row_buffer_size)
+        self._null_byte_search_table = <unsigned char*> libc.stdlib.malloc(
+            8 * sizeof(unsigned char))
+        self._init_attribute()
+
+    cpdef decode_from_stream(self, InputStream in_stream, bint nested):
+        cdef WrapperInputElement wrapper_input_element
+        wrapper_input_element = WrapperInputElement(in_stream)
+        self._consume_input_data(wrapper_input_element, in_stream.size())
+        return wrapper_input_element
+
+    cpdef encode_to_stream(self, wrapper_stream, OutputStream out_stream, bint 
nested):
+        self.encode_row_result(wrapper_stream, out_stream)
+
+    cdef encode_row_result(self, WrapperFuncInputElement 
wrapper_func_input_element,
+                           OutputStream out_stream):
+        cdef list result
+        self._before_encode(wrapper_func_input_element, out_stream)
+        while self._input_buffer_size > self._input_pos:
+            self._load_row()
+            result = self.func(self.row)
+            self._write_data(result)
+            self._dump_row()
+        self._after_encode(out_stream)
+
+    cdef encode_table_row_result(self, WrapperFuncInputElement 
wrapper_func_input_element,
+                                 OutputStream out_stream):
+        self._before_encode(wrapper_func_input_element, out_stream)
+        while self._input_buffer_size > self._input_pos:
+            self._load_row()
+            result = self.func(self.row)
+            if result:
+                for value in result:
+                    if self._output_field_count == 1:
+                        value = (value,)
+                    self._write_data(value)
+            self._dump_end_message()
+
+        self._after_encode(out_stream)
+
+    cdef void _init_attribute(self):
+        self._null_byte_search_table[0] = 0x80
+        self._null_byte_search_table[1] = 0x40
+        self._null_byte_search_table[2] = 0x20
+        self._null_byte_search_table[3] = 0x10
+        self._null_byte_search_table[4] = 0x08
+        self._null_byte_search_table[5] = 0x04
+        self._null_byte_search_table[6] = 0x02
+        self._null_byte_search_table[7] = 0x01
+        for i in range(self._output_field_count):
+            self._output_field_type[i] = 
self._output_field_coders[i].type_name()
+            self._output_coder_type[i] = 
self._output_field_coders[i].coder_type()
+
+    cdef void _consume_input_data(self, WrapperInputElement 
wrapper_input_element, size_t size):
+        # wrappers the input field coders and input_stream together
+        # so that it can be transposed to operations
+        wrapper_input_element.input_field_coders = self._output_field_coders
+        wrapper_input_element.input_remaining_bits_num = 
self._output_remaining_bits_num
+        wrapper_input_element.input_leading_complete_bytes_num = \
+            self._output_leading_complete_bytes_num
+        wrapper_input_element.input_field_count = self._output_field_count
+        wrapper_input_element.input_field_type = self._output_field_type
+        wrapper_input_element.input_coder_type = self._output_coder_type
+        wrapper_input_element.input_stream.pos = size
+        wrapper_input_element.input_buffer_size = size
+
+    cdef void _write_data(self, value):
+        cdef libc.stdint.int32_t i
+        self._write_null_mask(value, self._output_leading_complete_bytes_num,
+                              self._output_remaining_bits_num)
+        for i in range(self._output_field_count):
+            item = value[i]
+            if item is not None:
+                if self._output_coder_type[i] == PRIMITIVE:
+                    self._dump_field_primitive(self._output_field_type[i], 
item)
+
+        self._dump_row()
+
+    cdef void _read_null_mask(self, bint*null_mask,
+                              libc.stdint.int32_t 
input_leading_complete_bytes_num,
+                              libc.stdint.int32_t input_remaining_bits_num):
+        cdef libc.stdint.int32_t field_pos, i
+        cdef unsigned char b
+        field_pos = 0
+        for _ in range(input_leading_complete_bytes_num):
+            b = self._input_data[self._input_pos]
+            self._input_pos += 1
+            for i in range(8):
+                null_mask[field_pos] = (b & self._null_byte_search_table[i]) > 0
+                field_pos += 1
+
+        if input_remaining_bits_num:
+            b = self._input_data[self._input_pos]
+            self._input_pos += 1
+            for i in range(input_remaining_bits_num):
+                null_mask[field_pos] = (b & self._null_byte_search_table[i]) > 0
+                field_pos += 1
+
+    cdef void _load_row(self):
+        cdef libc.stdint.int32_t i
+        # skip prefix variable int length
+        while self._input_data[self._input_pos] & 0x80:
+            self._input_pos += 1
+        self._input_pos += 1
+        self._read_null_mask(self._null_mask, 
self._input_leading_complete_bytes_num,
+                             self._input_remaining_bits_num)
+        for i in range(self._input_field_count):
+            if self._null_mask[i]:
+                self.row[i] = None
+            else:
+                if self._input_coder_type[i] == PRIMITIVE:
+                    self.row[i] = 
self._load_field_primitive(self._input_field_type[i])
+
+    cdef object _load_field_primitive(self, TypeName field_type):
+        cdef libc.stdint.int32_t value, minutes, seconds, hours
+        cdef libc.stdint.int64_t milliseconds
+        if field_type == TINYINT:
+            # tinyint
+            return self._load_byte()
+        elif field_type == SMALLINT:
+            # smallint
+            return self._load_smallint()
+        elif field_type == INT:
+            # int
+            return self._load_int()
+        elif field_type == BIGINT:
+            # bigint
+            return self._load_bigint()
+        elif field_type == BOOLEAN:
+            # boolean
+            return not not self._load_byte()
+        elif field_type == FLOAT:
+            # float
+            return self._load_float()
+        elif field_type == DOUBLE:
+            # double
+            return self._load_double()
+        elif field_type == BINARY:
+            # bytes
+            return self._load_bytes()
+        elif field_type == CHAR:
+            # str
+            return self._load_bytes().decode("utf-8")
+        elif field_type == DATE:
+            # Date
+            return datetime.date.fromordinal(self._load_int() + 719163)
 
 Review comment:
   Could you add an explanation of how 719163 is computed?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to