harshmotw-db commented on code in PR #49450: URL: https://github.com/apache/spark/pull/49450#discussion_r1911903839
########## python/pyspark/sql/variant_utils.py: ########## @@ -496,3 +525,297 @@ def _handle_array(cls, value: bytes, pos: int, func: Callable[[List[int]], Any]) element_pos = data_start + offset value_pos_list.append(element_pos) return func(value_pos_list) + + +class FieldEntry(NamedTuple): + """ + Info about an object field + """ + + key: str + id: int + offset: int + + +class VariantBuilder: + """ + A utility class for building VariantVal. + """ + + DEFAULT_SIZE_LIMIT = 16 * 1024 * 1024 + + def __init__(self, size_limit=DEFAULT_SIZE_LIMIT): + self.value = bytearray() + self.dictionary = dict[str, int]() + self.dictionary_keys = list[bytes]() + self.size_limit = size_limit + + def build(self, json_str: str) -> (bytes, bytes): + parsed = json.loads(json_str, parse_float=self._handle_float) + self._process_parsed_json(parsed) + + num_keys = len(self.dictionary_keys) + dictionary_string_size = sum(len(key) for key in self.dictionary_keys) + + # Determine the number of bytes required per offset entry. + # The largest offset is the one-past-the-end value, which is total string size. It's very + # unlikely that the number of keys could be larger, but incorporate that into the + # calculation in case of pathological data. + max_size = max(dictionary_string_size, num_keys) + if max_size > self.size_limit: + raise PySparkValueError(errorClass="VARIANT_SIZE_LIMIT_EXCEEDED", messageParameters={}) + offset_size = self._get_integer_size(max_size) + + offset_start = 1 + offset_size + string_start = offset_start + (num_keys + 1) * offset_size + metadata_size = string_start + dictionary_string_size + if metadata_size > self.size_limit: + raise PySparkValueError(errorClass="VARIANT_SIZE_LIMIT_EXCEEDED", messageParameters={}) + + metadata = bytearray() + header_byte = VariantUtils.VERSION | ((offset_size - 1) << 6) + metadata.extend(header_byte.to_bytes(1, byteorder="little")) + metadata.extend(num_keys.to_bytes(offset_size, byteorder="little")) + # write offsets + current_offset = 0 + for key in self.dictionary_keys: + metadata.extend(current_offset.to_bytes(offset_size, byteorder="little")) + current_offset += len(key) + metadata.extend(current_offset.to_bytes(offset_size, byteorder="little")) + # write key data + for key in self.dictionary_keys: + metadata.extend(key) + return (bytes(self.value), bytes(metadata)) + + def _process_parsed_json(self, parsed: Any) -> None: + if type(parsed) is dict: + fields = list[FieldEntry]() + start = len(self.value) + for key, value in parsed.items(): + id = self._add_key(key) + fields.append(FieldEntry(key, id, len(self.value) - start)) + self._process_parsed_json(value) + self._finish_writing_object(start, fields) + elif type(parsed) is list: + offsets = [] + start = len(self.value) + for elem in parsed: + offsets.append(len(self.value) - start) + self._process_parsed_json(elem) + self._finish_writing_array(start, offsets) + elif type(parsed) is str: + self._append_string(parsed) + elif type(parsed) is int: + if not self._append_int(parsed): + self._process_parsed_json(self._handle_float(str(parsed))) + elif type(parsed) is float: + self._append_float(parsed) + elif type(parsed) is decimal.Decimal: + self._append_decimal(parsed) + elif type(parsed) is bool: + self._append_boolean(parsed) + elif parsed is None: + self._append_null() + else: + raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={}) + + # Choose the smallest unsigned integer type that can store `value`. It must be within + # [0, U24_MAX]. + def _get_integer_size(self, value: int) -> int: + if value <= VariantUtils.U8_MAX: + return 1 + if value <= VariantUtils.U16_MAX: + return 2 + return VariantUtils.U24_SIZE + + def _check_capacity(self, additional: int) -> None: + required = len(self.value) + additional + if required > self.size_limit: + raise PySparkValueError(errorClass="VARIANT_SIZE_LIMIT_EXCEEDED", messageParameters={}) + + def _primitive_header(self, type: int) -> bytes: + return bytes([(type << 2) | VariantUtils.PRIMITIVE]) + + def _short_string_header(self, size: int) -> bytes: + return bytes([size << 2 | VariantUtils.SHORT_STR]) + + def _array_header(self, large_size: bool, offset_size: int) -> bytes: + return bytes( + [ + ( + (large_size << (VariantUtils.BASIC_TYPE_BITS + 2)) + | ((offset_size - 1) << VariantUtils.BASIC_TYPE_BITS) + | VariantUtils.ARRAY + ) + ] + ) + + def _object_header(self, large_size: bool, id_size: int, offset_size: int) -> bytes: + return bytes( + [ + ( + (large_size << (VariantUtils.BASIC_TYPE_BITS + 4)) + | ((id_size - 1) << (VariantUtils.BASIC_TYPE_BITS + 2)) + | ((offset_size - 1) << VariantUtils.BASIC_TYPE_BITS) + | VariantUtils.OBJECT + ) + ] + ) + + # Add a key to the variant dictionary. If the key already exists, the dictionary is + # not modified. In either case, return the id of the key. + def _add_key(self, key: str) -> int: + if key in self.dictionary: + return self.dictionary[key] + id = len(self.dictionary_keys) + self.dictionary[key] = id + self.dictionary_keys.append(key.encode("utf-8")) + return id + + def _handle_float(self, num_str): + # a float can be a decimal if it only contains digits, '-', or '-'. + if all([ch.isdecimal() or ch == "-" or ch == "." for ch in num_str]): + dec = decimal.Decimal(num_str) + precision = len(dec.as_tuple().digits) + scale = -dec.as_tuple().exponent + + if ( + scale <= VariantUtils.MAX_DECIMAL16_PRECISION + and precision <= VariantUtils.MAX_DECIMAL16_PRECISION + ): + return dec + return float(num_str) + + def _append_boolean(self, b: bool) -> None: + self._check_capacity(1) + self.value.extend(self._primitive_header(VariantUtils.TRUE if b else VariantUtils.FALSE)) + + def _append_null(self) -> None: + self._check_capacity(1) + self.value.extend(self._primitive_header(VariantUtils.NULL)) + + def _append_string(self, s: str) -> None: + text = s.encode("utf-8") + long_str = len(text) > VariantUtils.MAX_SHORT_STR_SIZE + additional = (1 + VariantUtils.U32_SIZE) if long_str else 1 + self._check_capacity(additional + len(text)) + if long_str: + self.value.extend(self._primitive_header(VariantUtils.LONG_STR)) + self.value.extend(len(text).to_bytes(VariantUtils.U32_SIZE, byteorder="little")) + else: + self.value.extend(self._short_string_header(len(text))) + self.value.extend(text) + + def _append_int(self, i: int) -> bool: + self._check_capacity(1 + 8) + if i >= VariantUtils.I8_MIN and i <= VariantUtils.I8_MAX: + self.value.extend(self._primitive_header(VariantUtils.INT1)) + self.value.extend(i.to_bytes(1, byteorder="little", signed=True)) + if i >= VariantUtils.I16_MIN and i <= VariantUtils.I16_MAX: Review Comment: Shouldn't these be `elif`? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org