This is an automated email from the ASF dual-hosted git repository.
jshao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gravitino.git
The following commit(s) were added to refs/heads/main by this push:
new 794814b5ba [#5202] feat(client-python): Add column default value
serdes (#7676)
794814b5ba is described below
commit 794814b5ba74165a7d6f2f1bc99de00b5a3829d4
Author: George T. C. Lai <[email protected]>
AuthorDate: Fri Jul 18 10:12:13 2025 +0800
[#5202] feat(client-python): Add column default value serdes (#7676)
### What changes were proposed in this pull request?
This is expected to be the last PR to support Column and its default
value. Python counterparts of the following Java classes/methods are implemented:
- JsonUtils.ColumnDefaultValueSerializer
- JsonUtils.ColumnDefaultValueDeserializer
- JsonUtils.readFunctionArg
- JsonUtils.writeFunctionArg
### Why are the changes needed?
We need to support Column and its default value in the Python client.
#5202
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests
---------
Signed-off-by: George T. C. Lai <[email protected]>
---
.../gravitino/api/types/json_serdes/base.py | 5 +-
clients/client-python/gravitino/api/types/types.py | 20 +-
.../client-python/gravitino/dto/rel/column_dto.py | 8 +-
.../json_serdes/_helper/serdes_utils.py | 139 +++++++++++
.../json_serdes/column_default_value_serdes.py | 55 +++++
.../dto/rel/test_column_default_value_serdes.py | 91 ++++++++
.../tests/unittests/dto/rel/test_column_dto.py | 184 +++++++++++++++
.../tests/unittests/dto/rel/test_function_arg.py | 20 +-
.../tests/unittests/dto/rel/test_literal_dto.py | 4 +
.../tests/unittests/dto/rel/test_serdes_utils.py | 255 +++++++++++++++++++++
10 files changed, 765 insertions(+), 16 deletions(-)
diff --git a/clients/client-python/gravitino/api/types/json_serdes/base.py
b/clients/client-python/gravitino/api/types/json_serdes/base.py
index 27c838478b..853eeae66e 100644
--- a/clients/client-python/gravitino/api/types/json_serdes/base.py
+++ b/clients/client-python/gravitino/api/types/json_serdes/base.py
@@ -16,13 +16,14 @@
# under the License.
from abc import ABC, abstractmethod
-from typing import Generic, TypeVar
+from typing import Generic, TypeVar, Union
from dataclasses_json.core import Json
+from gravitino.api.expressions.expression import Expression
from gravitino.api.types.types import Type
-GravitinoTypeT = TypeVar("GravitinoTypeT", bound=Type)
+GravitinoTypeT = TypeVar("GravitinoTypeT", bound=Union[Expression, Type])
class JsonSerializable(ABC, Generic[GravitinoTypeT]):
diff --git a/clients/client-python/gravitino/api/types/types.py
b/clients/client-python/gravitino/api/types/types.py
index f005c158a6..6e82725713 100644
--- a/clients/client-python/gravitino/api/types/types.py
+++ b/clients/client-python/gravitino/api/types/types.py
@@ -16,16 +16,18 @@
# under the License.
# pylint: disable=C0302
from __future__ import annotations
+
from typing import List
+
from .type import (
- Type,
- Name,
- PrimitiveType,
- IntegralType,
- FractionType,
+ ComplexType,
DateTimeType,
+ FractionType,
+ IntegralType,
IntervalType,
- ComplexType,
+ Name,
+ PrimitiveType,
+ Type,
)
@@ -815,7 +817,7 @@ class Types:
)
def __eq__(self, other):
- if not isinstance(other, Types.ListType):
+ if isinstance(other, Types.ListType):
return (
self._element_nullable == other.element_nullable()
and self._element_type == other.element_type()
@@ -972,7 +974,7 @@ class Types:
Returns:
True if both UnionType objects have the same types, False
otherwise.
"""
- if not isinstance(other, Types.UnionType):
+ if isinstance(other, Types.UnionType):
return self._types == other.types()
return False
@@ -1025,7 +1027,7 @@ class Types:
Returns:
True if both unparsed_type objects have the same unparsed type
string, False otherwise.
"""
- if not isinstance(other, Types.UnparsedType):
+ if isinstance(other, Types.UnparsedType):
return self._unparsed_type == other.unparsed_type()
return False
diff --git a/clients/client-python/gravitino/dto/rel/column_dto.py
b/clients/client-python/gravitino/dto/rel/column_dto.py
index 2846ca4cbf..42286fd128 100644
--- a/clients/client-python/gravitino/dto/rel/column_dto.py
+++ b/clients/client-python/gravitino/dto/rel/column_dto.py
@@ -27,6 +27,9 @@ from gravitino.api.expressions.expression import Expression
from gravitino.api.types.json_serdes.type_serdes import TypeSerdes
from gravitino.api.types.type import Type
from gravitino.api.types.types import Types
+from gravitino.dto.rel.expressions.json_serdes.column_default_value_serdes
import (
+ ColumnDefaultValueSerdes,
+)
from gravitino.dto.rel.expressions.literal_dto import LiteralDTO
from gravitino.utils.precondition import Precondition
@@ -50,13 +53,12 @@ class ColumnDTO(Column, DataClassJsonMixin):
_comment: str = field(metadata=config(field_name="comment"))
"""The comment associated with the column."""
- # TODO: We shall specify encoder/decoder in the future PR. They're now
dummy serdes.
_default_value: Optional[Union[Expression, List[Expression]]] = field(
default_factory=lambda: Column.DEFAULT_VALUE_NOT_SET,
metadata=config(
field_name="defaultValue",
- encoder=lambda _: None,
- decoder=lambda _: Column.DEFAULT_VALUE_NOT_SET,
+ encoder=ColumnDefaultValueSerdes.serialize,
+ decoder=ColumnDefaultValueSerdes.deserialize,
exclude=lambda value: value is None
or value is Column.DEFAULT_VALUE_NOT_SET,
),
diff --git
a/clients/client-python/gravitino/dto/rel/expressions/json_serdes/_helper/serdes_utils.py
b/clients/client-python/gravitino/dto/rel/expressions/json_serdes/_helper/serdes_utils.py
new file mode 100644
index 0000000000..fc4b7734cd
--- /dev/null
+++
b/clients/client-python/gravitino/dto/rel/expressions/json_serdes/_helper/serdes_utils.py
@@ -0,0 +1,139 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Any, ClassVar, Dict, cast
+
+from gravitino.api.types.json_serdes._helper.serdes_utils import (
+ SerdesUtils as TypesSerdesUtils,
+)
+from gravitino.dto.rel.expressions.field_reference_dto import FieldReferenceDTO
+from gravitino.dto.rel.expressions.func_expression_dto import FuncExpressionDTO
+from gravitino.dto.rel.expressions.function_arg import FunctionArg
+from gravitino.dto.rel.expressions.literal_dto import LiteralDTO
+from gravitino.dto.rel.expressions.unparsed_expression_dto import
UnparsedExpressionDTO
+from gravitino.exceptions.base import IllegalArgumentException
+from gravitino.utils.precondition import Precondition
+
+
+class SerdesUtils:
+ EXPRESSION_TYPE: ClassVar[str] = "type"
+ DATA_TYPE: ClassVar[str] = "dataType"
+ LITERAL_VALUE: ClassVar[str] = "value"
+ FIELD_NAME: ClassVar[str] = "fieldName"
+ FUNCTION_NAME: ClassVar[str] = "funcName"
+ FUNCTION_ARGS: ClassVar[str] = "funcArgs"
+ UNPARSED_EXPRESSION: ClassVar[str] = "unparsedExpression"
+
+ @classmethod
+ def write_function_arg(cls, arg: FunctionArg) -> Dict[str, Any]:
+ arg_type = arg.arg_type()
+ if arg_type not in FunctionArg.ArgType:
+ raise ValueError(f"Unknown function argument type: {arg_type}")
+
+ arg_data = {cls.EXPRESSION_TYPE: arg_type.name.lower()}
+ if arg_type is FunctionArg.ArgType.LITERAL:
+ expression = cast(LiteralDTO, arg)
+ arg_data[cls.DATA_TYPE] = TypesSerdesUtils.write_data_type(
+ data_type=expression.data_type()
+ )
+ arg_data[cls.LITERAL_VALUE] = expression.value()
+
+ if arg_type is FunctionArg.ArgType.FIELD:
+ arg_data[cls.FIELD_NAME] = cast(FieldReferenceDTO,
arg).field_name()
+
+ if arg_type is FunctionArg.ArgType.FUNCTION:
+ expression = cast(FuncExpressionDTO, arg)
+ arg_data[cls.FUNCTION_NAME] = expression.function_name()
+ arg_data[cls.FUNCTION_ARGS] = [
+ cls.write_function_arg(func_arg) for func_arg in
expression.args()
+ ]
+
+ if arg_type is FunctionArg.ArgType.UNPARSED:
+ expression = cast(UnparsedExpressionDTO, arg)
+ arg_data[cls.UNPARSED_EXPRESSION] =
expression.unparsed_expression()
+
+ return arg_data
+
+ @classmethod
+ def read_function_arg(cls, data: Dict[str, Any]) -> FunctionArg:
+ Precondition.check_argument(
+ data is not None and isinstance(data, dict),
+ f"Cannot parse function arg from invalid JSON: {data}",
+ )
+ Precondition.check_argument(
+ data.get(cls.EXPRESSION_TYPE) is not None,
+ f"Cannot parse function arg from missing type: {data}",
+ )
+ try:
+ arg_type = FunctionArg.ArgType(data[cls.EXPRESSION_TYPE].lower())
+ except ValueError:
+ raise IllegalArgumentException(
+ f"Unknown function argument type: {data[cls.EXPRESSION_TYPE]}"
+ )
+
+ if arg_type is FunctionArg.ArgType.LITERAL:
+ Precondition.check_argument(
+ data.get(cls.DATA_TYPE) is not None,
+ f"Cannot parse literal arg from missing data type: {data}",
+ )
+ Precondition.check_argument(
+ data.get(cls.LITERAL_VALUE) is not None,
+ f"Cannot parse literal arg from missing literal value: {data}",
+ )
+ return (
+ LiteralDTO.builder()
+ .with_data_type(
+
data_type=TypesSerdesUtils.read_data_type(data[cls.DATA_TYPE])
+ )
+ .with_value(value=data[cls.LITERAL_VALUE])
+ .build()
+ )
+
+ if arg_type is FunctionArg.ArgType.FIELD:
+ Precondition.check_argument(
+ data.get(cls.FIELD_NAME) is not None,
+ f"Cannot parse field reference arg from missing field name:
{data}",
+ )
+ return (
+ FieldReferenceDTO.builder()
+ .with_field_name(field_name=data[cls.FIELD_NAME])
+ .build()
+ )
+
+ if arg_type is FunctionArg.ArgType.FUNCTION:
+ Precondition.check_argument(
+ data.get(cls.FUNCTION_NAME) is not None,
+ f"Cannot parse function function arg from missing function
name: {data}",
+ )
+ Precondition.check_argument(
+ data.get(cls.FUNCTION_ARGS) is not None,
+ f"Cannot parse function function arg from missing function
args: {data}",
+ )
+ args = [cls.read_function_arg(arg) for arg in
data[cls.FUNCTION_ARGS]]
+ return (
+ FuncExpressionDTO.builder()
+ .with_function_name(function_name=data[cls.FUNCTION_NAME])
+ .with_function_args(function_args=args or
FunctionArg.EMPTY_ARGS)
+ .build()
+ )
+
+ if arg_type is FunctionArg.ArgType.UNPARSED:
+ Precondition.check_argument(
+ isinstance(data.get(cls.UNPARSED_EXPRESSION), str),
+ f"Cannot parse unparsed expression from missing string field
unparsedExpression: {data}",
+ )
+ return UnparsedExpressionDTO(data[cls.UNPARSED_EXPRESSION])
diff --git
a/clients/client-python/gravitino/dto/rel/expressions/json_serdes/column_default_value_serdes.py
b/clients/client-python/gravitino/dto/rel/expressions/json_serdes/column_default_value_serdes.py
new file mode 100644
index 0000000000..6e717c4f62
--- /dev/null
+++
b/clients/client-python/gravitino/dto/rel/expressions/json_serdes/column_default_value_serdes.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import overload
+
+from dataclasses_json.core import Json
+
+from gravitino.api.column import Column
+from gravitino.api.expressions.expression import Expression
+from gravitino.api.types.json_serdes.base import JsonSerializable
+from gravitino.dto.rel.expressions.json_serdes._helper.serdes_utils import
SerdesUtils
+
+
+class ColumnDefaultValueSerdes(JsonSerializable[Expression]):
+ """Custom JSON serializer/deserializer for Column default value."""
+
+ @classmethod
+ def serialize(cls, value: Expression) -> Json:
+ if cls.is_empty(value):
+ return None
+ return SerdesUtils.write_function_arg(arg=value)
+
+ @classmethod
+ def deserialize(cls, data: Json) -> Expression:
+ if cls.is_empty(data):
+ return Column.DEFAULT_VALUE_NOT_SET
+ return SerdesUtils.read_function_arg(data=data)
+
+ @classmethod
+ @overload
+ def is_empty(cls, value: Expression) -> bool: ...
+
+ @classmethod
+ @overload
+ def is_empty(cls, value: Json) -> bool: ...
+
+ @classmethod
+ def is_empty(cls, value):
+ if isinstance(value, (Expression, list)):
+ return value is None or value is Column.DEFAULT_VALUE_NOT_SET
+ return value is None
diff --git
a/clients/client-python/tests/unittests/dto/rel/test_column_default_value_serdes.py
b/clients/client-python/tests/unittests/dto/rel/test_column_default_value_serdes.py
new file mode 100644
index 0000000000..303dd98f45
--- /dev/null
+++
b/clients/client-python/tests/unittests/dto/rel/test_column_default_value_serdes.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from unittest.mock import patch
+
+from gravitino.api.column import Column
+from gravitino.api.types.types import Types
+from gravitino.dto.rel.expressions.field_reference_dto import FieldReferenceDTO
+from gravitino.dto.rel.expressions.func_expression_dto import FuncExpressionDTO
+from gravitino.dto.rel.expressions.json_serdes._helper.serdes_utils import
SerdesUtils
+from gravitino.dto.rel.expressions.json_serdes.column_default_value_serdes
import (
+ ColumnDefaultValueSerdes,
+)
+from gravitino.dto.rel.expressions.literal_dto import LiteralDTO
+from gravitino.dto.rel.expressions.unparsed_expression_dto import
UnparsedExpressionDTO
+from gravitino.exceptions.base import IllegalArgumentException
+
+
+class TestColumnDefaultValueSerdes(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls) -> None:
+ cls._literal_dto = (
+ LiteralDTO.builder()
+ .with_data_type(data_type=Types.StringType.get())
+ .with_value(value="test_string")
+ .build()
+ )
+ cls._dtos = [
+ cls._literal_dto,
+ FieldReferenceDTO.builder()
+ .with_field_name(field_name=["field_name"])
+ .build(),
+ FuncExpressionDTO.builder()
+ .with_function_name(function_name="simple_func_name")
+ .with_function_args(function_args=[cls._literal_dto])
+ .build(),
+ UnparsedExpressionDTO.builder()
+
.with_unparsed_expression(unparsed_expression="unparsed_expression")
+ .build(),
+ ]
+
+ def test_column_default_serdes_serialize_empty(self):
+ self.assertIsNone(ColumnDefaultValueSerdes.serialize(value=None))
+ self.assertIsNone(
+
ColumnDefaultValueSerdes.serialize(value=Column.DEFAULT_VALUE_NOT_SET)
+ )
+
+ def test_column_default_serdes_deserialize_empty(self):
+ self.assertIs(
+ ColumnDefaultValueSerdes.deserialize(data=None),
+ Column.DEFAULT_VALUE_NOT_SET,
+ )
+
+ self.assertRaisesRegex(
+ IllegalArgumentException,
+ "Cannot parse function arg from invalid JSON",
+ ColumnDefaultValueSerdes.deserialize,
+ data="None",
+ )
+
+ def test_serialize_dto(self):
+ for dto in self._dtos:
+ with patch.object(
+ SerdesUtils, "write_function_arg"
+ ) as mock_write_function_arg:
+ ColumnDefaultValueSerdes.serialize(value=dto)
+ mock_write_function_arg.assert_called_once_with(arg=dto)
+
+ def test_deserialize_dto(self):
+ for dto in self._dtos:
+ data = ColumnDefaultValueSerdes.serialize(value=dto)
+ with patch.object(
+ SerdesUtils, "read_function_arg"
+ ) as mock_read_function_arg:
+ ColumnDefaultValueSerdes.deserialize(data=data)
+ mock_read_function_arg.assert_called_once_with(data=data)
diff --git a/clients/client-python/tests/unittests/dto/rel/test_column_dto.py
b/clients/client-python/tests/unittests/dto/rel/test_column_dto.py
index 8bdd6d1307..0a9c2b4a3d 100644
--- a/clients/client-python/tests/unittests/dto/rel/test_column_dto.py
+++ b/clients/client-python/tests/unittests/dto/rel/test_column_dto.py
@@ -17,13 +17,20 @@
import json
import unittest
+from itertools import product
from gravitino.api.column import Column
from gravitino.api.types.json_serdes import TypeSerdes
from gravitino.api.types.json_serdes._helper.serdes_utils import SerdesUtils
from gravitino.api.types.types import Types
from gravitino.dto.rel.column_dto import ColumnDTO
+from gravitino.dto.rel.expressions.field_reference_dto import FieldReferenceDTO
+from gravitino.dto.rel.expressions.func_expression_dto import FuncExpressionDTO
+from gravitino.dto.rel.expressions.json_serdes.column_default_value_serdes
import (
+ ColumnDefaultValueSerdes,
+)
from gravitino.dto.rel.expressions.literal_dto import LiteralDTO
+from gravitino.dto.rel.expressions.unparsed_expression_dto import
UnparsedExpressionDTO
from gravitino.exceptions.base import IllegalArgumentException
@@ -75,6 +82,7 @@ class TestColumnDTO(unittest.TestCase):
column_dto_2 = self._string_columns[2]
self.assertNotEqual(column_dto_1, column_dto_2)
self.assertEqual(column_dto_1, column_dto_1)
+ self.assertNotEqual(column_dto_1, "test")
def test_column_dto_hash(self):
column_dto_1 = self._string_columns[1]
@@ -214,3 +222,179 @@ class TestColumnDTO(unittest.TestCase):
deserialized_column_dto.default_value(),
Column.DEFAULT_VALUE_NOT_SET
)
self.assertEqual(serialized_json, deserialized_json)
+
+ def test_column_dto_serdes_with_default_value_literal(self):
+ expected_dict = {
+ "name": "",
+ "type": "",
+ "comment": "",
+ "defaultValue": None,
+ "nullable": False,
+ "autoIncrement": False,
+ }
+ for data_type, default_value_type in product(
+ self._supported_types, self._supported_types
+ ):
+ default_value = (
+ LiteralDTO.builder()
+ .with_data_type(default_value_type)
+ .with_value(default_value_type.simple_string())
+ .build()
+ )
+ column_dto = (
+ ColumnDTO.builder()
+ .with_name(name=str(data_type.name()))
+ .with_data_type(data_type=data_type)
+ .with_default_value(default_value=default_value)
+ .with_comment(comment=data_type.simple_string())
+ .build()
+ )
+ expected_dict["name"] = str(data_type.name())
+ expected_dict["type"] = TypeSerdes.serialize(data_type)
+ expected_dict["comment"] = data_type.simple_string()
+ expected_dict["defaultValue"] = ColumnDefaultValueSerdes.serialize(
+ value=default_value
+ )
+ expected_dict["nullable"] = True
+ expected_dict["autoIncrement"] = False
+
+ serialized_result = column_dto.to_json()
+ deserialized_dto = ColumnDTO.from_json(serialized_result)
+
+ self.assertDictEqual(json.loads(serialized_result), expected_dict)
+ self.assertEqual(deserialized_dto, column_dto)
+
+ def test_column_dto_serdes_with_default_value_field_ref(self):
+ expected_dict = {
+ "name": "",
+ "type": "",
+ "comment": "",
+ "defaultValue": None,
+ "nullable": False,
+ "autoIncrement": False,
+ }
+ default_value = (
+
FieldReferenceDTO.builder().with_column_name(["field_reference"]).build()
+ )
+ for data_type in self._supported_types:
+ column_dto = (
+ ColumnDTO.builder()
+ .with_name(name=str(data_type.name()))
+ .with_data_type(data_type=data_type)
+ .with_default_value(default_value=default_value)
+ .with_comment(comment=data_type.simple_string())
+ .build()
+ )
+ expected_dict["name"] = str(data_type.name())
+ expected_dict["type"] = TypeSerdes.serialize(data_type)
+ expected_dict["comment"] = data_type.simple_string()
+ expected_dict["defaultValue"] = ColumnDefaultValueSerdes.serialize(
+ value=default_value
+ )
+ expected_dict["nullable"] = True
+ expected_dict["autoIncrement"] = False
+
+ serialized_result = column_dto.to_json()
+ deserialized_dto = ColumnDTO.from_json(serialized_result)
+
+ self.assertDictEqual(json.loads(serialized_result), expected_dict)
+ self.assertEqual(deserialized_dto, column_dto)
+
+ def test_column_dto_serdes_with_default_value_func_expression(self):
+ expected_dict = {
+ "name": "",
+ "type": "",
+ "comment": "",
+ "defaultValue": None,
+ "nullable": False,
+ "autoIncrement": False,
+ }
+ func_args = [
+ LiteralDTO.builder()
+ .with_data_type(Types.StringType.get())
+ .with_value("year")
+ .build(),
+ FieldReferenceDTO.builder().with_column_name(["birthday"]).build(),
+ FuncExpressionDTO.builder()
+ .with_function_name("randint")
+ .with_function_args(
+ [
+ LiteralDTO.builder()
+ .with_data_type(Types.IntegerType.get())
+ .with_value("1")
+ .build(),
+ LiteralDTO.builder()
+ .with_data_type(Types.IntegerType.get())
+ .with_value("100")
+ .build(),
+ ]
+ )
+ .build(),
+ ]
+ for data_type in self._supported_types:
+ default_value = (
+ FuncExpressionDTO.builder()
+ .with_function_name("test_function")
+ .with_function_args(func_args)
+ .build()
+ )
+ column_dto = (
+ ColumnDTO.builder()
+ .with_name(name=str(data_type.name()))
+ .with_data_type(data_type=data_type)
+ .with_default_value(default_value=default_value)
+ .with_comment(comment=data_type.simple_string())
+ .build()
+ )
+ expected_dict["name"] = str(data_type.name())
+ expected_dict["type"] = TypeSerdes.serialize(data_type)
+ expected_dict["comment"] = data_type.simple_string()
+ expected_dict["defaultValue"] = ColumnDefaultValueSerdes.serialize(
+ value=default_value
+ )
+ expected_dict["nullable"] = True
+ expected_dict["autoIncrement"] = False
+
+ serialized_result = column_dto.to_json()
+ deserialized_dto = ColumnDTO.from_json(serialized_result)
+
+ self.assertDictEqual(json.loads(serialized_result), expected_dict)
+ self.assertEqual(deserialized_dto, column_dto)
+
+ def test_column_dto_serialize_with_default_value_unparsed(self):
+ expected_dict = {
+ "name": "",
+ "type": "",
+ "comment": "",
+ "defaultValue": None,
+ "nullable": False,
+ "autoIncrement": False,
+ }
+ for data_type in self._supported_types:
+ default_value = (
+ UnparsedExpressionDTO.builder()
+ .with_unparsed_expression("unparsed_expression")
+ .build()
+ )
+ column_dto = (
+ ColumnDTO.builder()
+ .with_name(name=str(data_type.name()))
+ .with_data_type(data_type=data_type)
+ .with_default_value(default_value=default_value)
+ .with_comment(comment=data_type.simple_string())
+ .build()
+ )
+ expected_dict["name"] = str(data_type.name())
+ expected_dict["type"] = TypeSerdes.serialize(data_type)
+ expected_dict["comment"] = data_type.simple_string()
+ expected_dict["defaultValue"] = ColumnDefaultValueSerdes.serialize(
+ value=default_value
+ )
+ expected_dict["nullable"] = True
+ expected_dict["autoIncrement"] = False
+
+ serialized_result = column_dto.to_json()
+ deserialized_dto = ColumnDTO.from_json(serialized_result)
+
+ self.assertDictEqual(json.loads(serialized_result), expected_dict)
+ self.assertEqual(deserialized_dto, column_dto)
diff --git a/clients/client-python/tests/unittests/dto/rel/test_function_arg.py
b/clients/client-python/tests/unittests/dto/rel/test_function_arg.py
index 4590bf7cd6..03e4059f8d 100644
--- a/clients/client-python/tests/unittests/dto/rel/test_function_arg.py
+++ b/clients/client-python/tests/unittests/dto/rel/test_function_arg.py
@@ -19,6 +19,8 @@ import unittest
from gravitino.api.types.types import Types
from gravitino.dto.rel.column_dto import ColumnDTO
+from gravitino.dto.rel.expressions.field_reference_dto import FieldReferenceDTO
+from gravitino.dto.rel.expressions.func_expression_dto import FuncExpressionDTO
from gravitino.dto.rel.expressions.function_arg import FunctionArg
from gravitino.dto.rel.expressions.literal_dto import LiteralDTO
@@ -45,7 +47,21 @@ class TestFunctionArg(unittest.TestCase):
self.assertEqual(FunctionArg.EMPTY_ARGS, [])
def test_function_arg_validate(self):
- LiteralDTO(data_type=Types.StringType.get(), value="test").validate(
+ literal_dto = (
+ LiteralDTO.builder()
+ .with_data_type(Types.StringType.get())
+ .with_value("test")
+ .build()
+ )
+ literal_dto.validate(columns=self._columns)
+
+ field_ref_dto = (
+
FieldReferenceDTO.builder().with_column_name(self._column_names).build()
+ )
+ field_ref_dto.validate(columns=self._columns)
+
+ FuncExpressionDTO.builder().with_function_name(
+ "test_function"
+ ).with_function_args([field_ref_dto, literal_dto]).build().validate(
columns=self._columns
)
- # TODO: add unit test for FunctionArg with children
diff --git a/clients/client-python/tests/unittests/dto/rel/test_literal_dto.py
b/clients/client-python/tests/unittests/dto/rel/test_literal_dto.py
index 62d1d61bd6..d730a49de2 100644
--- a/clients/client-python/tests/unittests/dto/rel/test_literal_dto.py
+++ b/clients/client-python/tests/unittests/dto/rel/test_literal_dto.py
@@ -68,3 +68,7 @@ class TestLiteralDTO(unittest.TestCase):
)
self.assertIsInstance(dto, LiteralDTO)
self.assertTrue(dto == self._literal_dto)
+
+ def test_literal_dto_equality(self):
+ self.assertEqual(self._literal_dto, self._literal_dto)
+ self.assertNotEqual(self._literal_dto, "test")
diff --git a/clients/client-python/tests/unittests/dto/rel/test_serdes_utils.py
b/clients/client-python/tests/unittests/dto/rel/test_serdes_utils.py
new file mode 100644
index 0000000000..d8b29659f6
--- /dev/null
+++ b/clients/client-python/tests/unittests/dto/rel/test_serdes_utils.py
@@ -0,0 +1,255 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from enum import Enum
+from unittest.mock import patch
+
+from gravitino.api.types.types import Types
+from gravitino.dto.rel.expressions.field_reference_dto import FieldReferenceDTO
+from gravitino.dto.rel.expressions.func_expression_dto import FuncExpressionDTO
+from gravitino.dto.rel.expressions.function_arg import FunctionArg
+from gravitino.dto.rel.expressions.json_serdes._helper.serdes_utils import
SerdesUtils
+from gravitino.dto.rel.expressions.literal_dto import LiteralDTO
+from gravitino.dto.rel.expressions.unparsed_expression_dto import
UnparsedExpressionDTO
+from gravitino.exceptions.base import IllegalArgumentException
+
+
+class MockArgType(str, Enum):
+ INVALID_ARG_TYPE = "invalid_arg_type"
+
+
+class TestExpressionSerdesUtils(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls) -> None:
+ cls._literal_dto = (
+ LiteralDTO.builder()
+ .with_data_type(data_type=Types.StringType.get())
+ .with_value(value="test_string")
+ .build()
+ )
+ cls._field_reference_dto = (
+ FieldReferenceDTO.builder()
+ .with_field_name(field_name=["field_name"])
+ .build()
+ )
+ cls._naive_func_expression_dto = (
+ FuncExpressionDTO.builder()
+ .with_function_name(function_name="simple_func_name")
+ .with_function_args(function_args=[cls._literal_dto])
+ .build()
+ )
+ cls._func_expression_dto = (
+ FuncExpressionDTO.builder()
+ .with_function_name(function_name="func_name")
+ .with_function_args(
+ function_args=[
+ cls._literal_dto,
+ cls._field_reference_dto,
+ cls._naive_func_expression_dto,
+ ]
+ )
+ .build()
+ )
+ cls._unparsed_expression_dto = (
+ UnparsedExpressionDTO.builder()
+
.with_unparsed_expression(unparsed_expression="unparsed_expression")
+ .build()
+ )
+
+ def test_write_function_arg_invalid_arg_type(self):
+ mock_dto = (
+ LiteralDTO.builder()
+ .with_data_type(data_type=Types.StringType.get())
+ .with_value(value="test")
+ .build()
+ )
+ with patch.object(
+ mock_dto, "arg_type", return_value=MockArgType.INVALID_ARG_TYPE
+ ):
+ self.assertRaises(ValueError, SerdesUtils.write_function_arg,
arg=mock_dto)
+
+ def test_write_function_arg_literal_dto(self):
+ result = SerdesUtils.write_function_arg(arg=self._literal_dto)
+ expected_result = {
+ SerdesUtils.EXPRESSION_TYPE:
self._literal_dto.arg_type().name.lower(),
+ SerdesUtils.DATA_TYPE:
self._literal_dto.data_type().simple_string(),
+ SerdesUtils.LITERAL_VALUE: self._literal_dto.value(),
+ }
+ self.assertDictEqual(result, expected_result)
+
+ def test_write_function_arg_field_reference_dto(self):
+ result = SerdesUtils.write_function_arg(arg=self._field_reference_dto)
+ expected_result = {
+ SerdesUtils.EXPRESSION_TYPE:
self._field_reference_dto.arg_type().name.lower(),
+ SerdesUtils.FIELD_NAME: self._field_reference_dto.field_name(),
+ }
+ self.assertDictEqual(result, expected_result)
+
+ def test_write_function_arg_func_expression_dto(self):
+ result =
SerdesUtils.write_function_arg(arg=self._naive_func_expression_dto)
+ expected_result = {
+ SerdesUtils.EXPRESSION_TYPE:
self._naive_func_expression_dto.arg_type().name.lower(),
+ SerdesUtils.FUNCTION_NAME:
self._naive_func_expression_dto.function_name(),
+ SerdesUtils.FUNCTION_ARGS: [
+ SerdesUtils.write_function_arg(arg=self._literal_dto),
+ ],
+ }
+ self.assertDictEqual(result, expected_result)
+
+ result = SerdesUtils.write_function_arg(arg=self._func_expression_dto)
+ expected_result = {
+ SerdesUtils.EXPRESSION_TYPE:
self._func_expression_dto.arg_type().name.lower(),
+ SerdesUtils.FUNCTION_NAME:
self._func_expression_dto.function_name(),
+ SerdesUtils.FUNCTION_ARGS: [
+ SerdesUtils.write_function_arg(arg=self._literal_dto),
+ SerdesUtils.write_function_arg(arg=self._field_reference_dto),
+
SerdesUtils.write_function_arg(arg=self._naive_func_expression_dto),
+ ],
+ }
+ self.assertDictEqual(result, expected_result)
+
+ def test_write_function_arg_unparsed_expression_dto(self):
+ result =
SerdesUtils.write_function_arg(arg=self._unparsed_expression_dto)
+ expected_result = {
+ SerdesUtils.EXPRESSION_TYPE:
self._unparsed_expression_dto.arg_type().name.lower(),
+ SerdesUtils.UNPARSED_EXPRESSION:
self._unparsed_expression_dto.unparsed_expression(),
+ }
+ self.assertDictEqual(result, expected_result)
+
+ def test_read_function_arg_invalid_data(self):
+ self.assertRaisesRegex(
+ IllegalArgumentException,
+ "Cannot parse function arg from invalid JSON",
+ SerdesUtils.read_function_arg,
+ data=None,
+ )
+
+ self.assertRaisesRegex(
+ IllegalArgumentException,
+ "Cannot parse function arg from invalid JSON",
+ SerdesUtils.read_function_arg,
+ data="invalid_data",
+ )
+
+ self.assertRaisesRegex(
+ IllegalArgumentException,
+ "Cannot parse function arg from missing type",
+ SerdesUtils.read_function_arg,
+ data={},
+ )
+
+ def test_read_function_arg_literal_dto(self):
+ data = {SerdesUtils.EXPRESSION_TYPE:
self._literal_dto.arg_type().name.lower()}
+ self.assertRaisesRegex(
+ IllegalArgumentException,
+ "Cannot parse literal arg from missing data type",
+ SerdesUtils.read_function_arg,
+ data=data,
+ )
+
+ data[SerdesUtils.DATA_TYPE] =
self._literal_dto.data_type().simple_string()
+ self.assertRaisesRegex(
+ IllegalArgumentException,
+ "Cannot parse literal arg from missing literal value",
+ SerdesUtils.read_function_arg,
+ data=data,
+ )
+
+ data[SerdesUtils.LITERAL_VALUE] = self._literal_dto.value()
+ result = SerdesUtils.read_function_arg(data=data)
+ self.assertEqual(result, self._literal_dto)
+
+ def test_read_function_arg_field_reference_dto(self):
+ data = {
+ SerdesUtils.EXPRESSION_TYPE:
self._field_reference_dto.arg_type().name.lower()
+ }
+ self.assertRaisesRegex(
+ IllegalArgumentException,
+ "Cannot parse field reference arg from missing field name",
+ SerdesUtils.read_function_arg,
+ data=data,
+ )
+
+ data[SerdesUtils.FIELD_NAME] = self._field_reference_dto.field_name()
+ result = SerdesUtils.read_function_arg(data=data)
+ self.assertEqual(result, self._field_reference_dto)
+
+ def test_read_function_arg_func_expression_dto(self):
+ data = {
+ SerdesUtils.EXPRESSION_TYPE:
self._naive_func_expression_dto.arg_type().name.lower()
+ }
+ self.assertRaisesRegex(
+ IllegalArgumentException,
+ "Cannot parse function function arg from missing function name",
+ SerdesUtils.read_function_arg,
+ data=data,
+ )
+
+ data[SerdesUtils.FUNCTION_NAME] = (
+ self._naive_func_expression_dto.function_name()
+ )
+ self.assertRaisesRegex(
+ IllegalArgumentException,
+ "Cannot parse function function arg from missing function args",
+ SerdesUtils.read_function_arg,
+ data=data,
+ )
+
+ data[SerdesUtils.FUNCTION_ARGS] = [
+ SerdesUtils.write_function_arg(arg=self._literal_dto),
+ ]
+ result = SerdesUtils.read_function_arg(data=data)
+ self.assertEqual(result, self._naive_func_expression_dto)
+
+ data[SerdesUtils.FUNCTION_ARGS] = []
+ result = SerdesUtils.read_function_arg(data=data)
+ self.assertEqual(
+ result,
+ FuncExpressionDTO.builder()
+ .with_function_name(
+ function_name=self._naive_func_expression_dto.function_name()
+ )
+ .with_function_args(function_args=FunctionArg.EMPTY_ARGS)
+ .build(),
+ )
+
+ def test_read_function_arg_unparsed_expression_dto(self):
+ data = {SerdesUtils.EXPRESSION_TYPE: "invalid_expression_type"}
+ self.assertRaisesRegex(
+ IllegalArgumentException,
+ "Unknown function argument type",
+ SerdesUtils.read_function_arg,
+ data=data,
+ )
+
+ data[SerdesUtils.EXPRESSION_TYPE] = (
+ self._unparsed_expression_dto.arg_type().name.lower()
+ )
+ data[SerdesUtils.UNPARSED_EXPRESSION] = {}
+ self.assertRaisesRegex(
+ IllegalArgumentException,
+ "Cannot parse unparsed expression from missing string field
unparsedExpression",
+ SerdesUtils.read_function_arg,
+ data=data,
+ )
+
+ data[SerdesUtils.UNPARSED_EXPRESSION] = (
+ self._unparsed_expression_dto.unparsed_expression()
+ )
+ result = SerdesUtils.read_function_arg(data=data)
+ self.assertEqual(result, self._unparsed_expression_dto)