This is an automated email from the ASF dual-hosted git repository.

jshao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gravitino.git


The following commit(s) were added to refs/heads/main by this push:
     new 794814b5ba [#5202] feat(client-python): Add column default value serdes (#7676)
794814b5ba is described below

commit 794814b5ba74165a7d6f2f1bc99de00b5a3829d4
Author: George T. C. Lai <[email protected]>
AuthorDate: Fri Jul 18 10:12:13 2025 +0800

    [#5202] feat(client-python): Add column default value serdes (#7676)
    
    ### What changes were proposed in this pull request?
    
    This is expected to be the last PR needed to support Column and its
    default value. Python counterparts of the following Java classes/methods
    are implemented:
    
    - JsonUtils.ColumnDefaultValueSerializer
    - JsonUtils.ColumnDefaultValueDeserializer
    - JsonUtils.readFunctionArg
    - JsonUtils.writeFunctionArg
    
    ### Why are the changes needed?
    
    We need to support Column and its default value in the Python client.
    
    #5202
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit tests
    
    ---------
    
    Signed-off-by: George T. C. Lai <[email protected]>
---
 .../gravitino/api/types/json_serdes/base.py        |   5 +-
 clients/client-python/gravitino/api/types/types.py |  20 +-
 .../client-python/gravitino/dto/rel/column_dto.py  |   8 +-
 .../json_serdes/_helper/serdes_utils.py            | 139 +++++++++++
 .../json_serdes/column_default_value_serdes.py     |  55 +++++
 .../dto/rel/test_column_default_value_serdes.py    |  91 ++++++++
 .../tests/unittests/dto/rel/test_column_dto.py     | 184 +++++++++++++++
 .../tests/unittests/dto/rel/test_function_arg.py   |  20 +-
 .../tests/unittests/dto/rel/test_literal_dto.py    |   4 +
 .../tests/unittests/dto/rel/test_serdes_utils.py   | 255 +++++++++++++++++++++
 10 files changed, 765 insertions(+), 16 deletions(-)

diff --git a/clients/client-python/gravitino/api/types/json_serdes/base.py 
b/clients/client-python/gravitino/api/types/json_serdes/base.py
index 27c838478b..853eeae66e 100644
--- a/clients/client-python/gravitino/api/types/json_serdes/base.py
+++ b/clients/client-python/gravitino/api/types/json_serdes/base.py
@@ -16,13 +16,14 @@
 # under the License.
 
 from abc import ABC, abstractmethod
-from typing import Generic, TypeVar
+from typing import Generic, TypeVar, Union
 
 from dataclasses_json.core import Json
 
+from gravitino.api.expressions.expression import Expression
 from gravitino.api.types.types import Type
 
-GravitinoTypeT = TypeVar("GravitinoTypeT", bound=Type)
+GravitinoTypeT = TypeVar("GravitinoTypeT", bound=Union[Expression, Type])
 
 
 class JsonSerializable(ABC, Generic[GravitinoTypeT]):
diff --git a/clients/client-python/gravitino/api/types/types.py 
b/clients/client-python/gravitino/api/types/types.py
index f005c158a6..6e82725713 100644
--- a/clients/client-python/gravitino/api/types/types.py
+++ b/clients/client-python/gravitino/api/types/types.py
@@ -16,16 +16,18 @@
 # under the License.
 # pylint: disable=C0302
 from __future__ import annotations
+
 from typing import List
+
 from .type import (
-    Type,
-    Name,
-    PrimitiveType,
-    IntegralType,
-    FractionType,
+    ComplexType,
     DateTimeType,
+    FractionType,
+    IntegralType,
     IntervalType,
-    ComplexType,
+    Name,
+    PrimitiveType,
+    Type,
 )
 
 
@@ -815,7 +817,7 @@ class Types:
             )
 
         def __eq__(self, other):
-            if not isinstance(other, Types.ListType):
+            if isinstance(other, Types.ListType):
                 return (
                     self._element_nullable == other.element_nullable()
                     and self._element_type == other.element_type()
@@ -972,7 +974,7 @@ class Types:
             Returns:
                 True if both UnionType objects have the same types, False 
otherwise.
             """
-            if not isinstance(other, Types.UnionType):
+            if isinstance(other, Types.UnionType):
                 return self._types == other.types()
             return False
 
@@ -1025,7 +1027,7 @@ class Types:
             Returns:
                 True if both unparsed_type objects have the same unparsed type 
string, False otherwise.
             """
-            if not isinstance(other, Types.UnparsedType):
+            if isinstance(other, Types.UnparsedType):
                 return self._unparsed_type == other.unparsed_type()
             return False
 
diff --git a/clients/client-python/gravitino/dto/rel/column_dto.py 
b/clients/client-python/gravitino/dto/rel/column_dto.py
index 2846ca4cbf..42286fd128 100644
--- a/clients/client-python/gravitino/dto/rel/column_dto.py
+++ b/clients/client-python/gravitino/dto/rel/column_dto.py
@@ -27,6 +27,9 @@ from gravitino.api.expressions.expression import Expression
 from gravitino.api.types.json_serdes.type_serdes import TypeSerdes
 from gravitino.api.types.type import Type
 from gravitino.api.types.types import Types
+from gravitino.dto.rel.expressions.json_serdes.column_default_value_serdes 
import (
+    ColumnDefaultValueSerdes,
+)
 from gravitino.dto.rel.expressions.literal_dto import LiteralDTO
 from gravitino.utils.precondition import Precondition
 
@@ -50,13 +53,12 @@ class ColumnDTO(Column, DataClassJsonMixin):
     _comment: str = field(metadata=config(field_name="comment"))
     """The comment associated with the column."""
 
-    # TODO: We shall specify encoder/decoder in the future PR. They're now 
dummy serdes.
     _default_value: Optional[Union[Expression, List[Expression]]] = field(
         default_factory=lambda: Column.DEFAULT_VALUE_NOT_SET,
         metadata=config(
             field_name="defaultValue",
-            encoder=lambda _: None,
-            decoder=lambda _: Column.DEFAULT_VALUE_NOT_SET,
+            encoder=ColumnDefaultValueSerdes.serialize,
+            decoder=ColumnDefaultValueSerdes.deserialize,
             exclude=lambda value: value is None
             or value is Column.DEFAULT_VALUE_NOT_SET,
         ),
diff --git 
a/clients/client-python/gravitino/dto/rel/expressions/json_serdes/_helper/serdes_utils.py
 
b/clients/client-python/gravitino/dto/rel/expressions/json_serdes/_helper/serdes_utils.py
new file mode 100644
index 0000000000..fc4b7734cd
--- /dev/null
+++ 
b/clients/client-python/gravitino/dto/rel/expressions/json_serdes/_helper/serdes_utils.py
@@ -0,0 +1,139 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Any, ClassVar, Dict, cast
+
+from gravitino.api.types.json_serdes._helper.serdes_utils import (
+    SerdesUtils as TypesSerdesUtils,
+)
+from gravitino.dto.rel.expressions.field_reference_dto import FieldReferenceDTO
+from gravitino.dto.rel.expressions.func_expression_dto import FuncExpressionDTO
+from gravitino.dto.rel.expressions.function_arg import FunctionArg
+from gravitino.dto.rel.expressions.literal_dto import LiteralDTO
+from gravitino.dto.rel.expressions.unparsed_expression_dto import 
UnparsedExpressionDTO
+from gravitino.exceptions.base import IllegalArgumentException
+from gravitino.utils.precondition import Precondition
+
+
+class SerdesUtils:
+    EXPRESSION_TYPE: ClassVar[str] = "type"
+    DATA_TYPE: ClassVar[str] = "dataType"
+    LITERAL_VALUE: ClassVar[str] = "value"
+    FIELD_NAME: ClassVar[str] = "fieldName"
+    FUNCTION_NAME: ClassVar[str] = "funcName"
+    FUNCTION_ARGS: ClassVar[str] = "funcArgs"
+    UNPARSED_EXPRESSION: ClassVar[str] = "unparsedExpression"
+
+    @classmethod
+    def write_function_arg(cls, arg: FunctionArg) -> Dict[str, Any]:
+        arg_type = arg.arg_type()
+        if arg_type not in FunctionArg.ArgType:
+            raise ValueError(f"Unknown function argument type: {arg_type}")
+
+        arg_data = {cls.EXPRESSION_TYPE: arg_type.name.lower()}
+        if arg_type is FunctionArg.ArgType.LITERAL:
+            expression = cast(LiteralDTO, arg)
+            arg_data[cls.DATA_TYPE] = TypesSerdesUtils.write_data_type(
+                data_type=expression.data_type()
+            )
+            arg_data[cls.LITERAL_VALUE] = expression.value()
+
+        if arg_type is FunctionArg.ArgType.FIELD:
+            arg_data[cls.FIELD_NAME] = cast(FieldReferenceDTO, 
arg).field_name()
+
+        if arg_type is FunctionArg.ArgType.FUNCTION:
+            expression = cast(FuncExpressionDTO, arg)
+            arg_data[cls.FUNCTION_NAME] = expression.function_name()
+            arg_data[cls.FUNCTION_ARGS] = [
+                cls.write_function_arg(func_arg) for func_arg in 
expression.args()
+            ]
+
+        if arg_type is FunctionArg.ArgType.UNPARSED:
+            expression = cast(UnparsedExpressionDTO, arg)
+            arg_data[cls.UNPARSED_EXPRESSION] = 
expression.unparsed_expression()
+
+        return arg_data
+
+    @classmethod
+    def read_function_arg(cls, data: Dict[str, Any]) -> FunctionArg:
+        Precondition.check_argument(
+            data is not None and isinstance(data, dict),
+            f"Cannot parse function arg from invalid JSON: {data}",
+        )
+        Precondition.check_argument(
+            data.get(cls.EXPRESSION_TYPE) is not None,
+            f"Cannot parse function arg from missing type: {data}",
+        )
+        try:
+            arg_type = FunctionArg.ArgType(data[cls.EXPRESSION_TYPE].lower())
+        except ValueError:
+            raise IllegalArgumentException(
+                f"Unknown function argument type: {data[cls.EXPRESSION_TYPE]}"
+            )
+
+        if arg_type is FunctionArg.ArgType.LITERAL:
+            Precondition.check_argument(
+                data.get(cls.DATA_TYPE) is not None,
+                f"Cannot parse literal arg from missing data type: {data}",
+            )
+            Precondition.check_argument(
+                data.get(cls.LITERAL_VALUE) is not None,
+                f"Cannot parse literal arg from missing literal value: {data}",
+            )
+            return (
+                LiteralDTO.builder()
+                .with_data_type(
+                    
data_type=TypesSerdesUtils.read_data_type(data[cls.DATA_TYPE])
+                )
+                .with_value(value=data[cls.LITERAL_VALUE])
+                .build()
+            )
+
+        if arg_type is FunctionArg.ArgType.FIELD:
+            Precondition.check_argument(
+                data.get(cls.FIELD_NAME) is not None,
+                f"Cannot parse field reference arg from missing field name: 
{data}",
+            )
+            return (
+                FieldReferenceDTO.builder()
+                .with_field_name(field_name=data[cls.FIELD_NAME])
+                .build()
+            )
+
+        if arg_type is FunctionArg.ArgType.FUNCTION:
+            Precondition.check_argument(
+                data.get(cls.FUNCTION_NAME) is not None,
+                f"Cannot parse function function arg from missing function 
name: {data}",
+            )
+            Precondition.check_argument(
+                data.get(cls.FUNCTION_ARGS) is not None,
+                f"Cannot parse function function arg from missing function 
args: {data}",
+            )
+            args = [cls.read_function_arg(arg) for arg in 
data[cls.FUNCTION_ARGS]]
+            return (
+                FuncExpressionDTO.builder()
+                .with_function_name(function_name=data[cls.FUNCTION_NAME])
+                .with_function_args(function_args=args or 
FunctionArg.EMPTY_ARGS)
+                .build()
+            )
+
+        if arg_type is FunctionArg.ArgType.UNPARSED:
+            Precondition.check_argument(
+                isinstance(data.get(cls.UNPARSED_EXPRESSION), str),
+                f"Cannot parse unparsed expression from missing string field 
unparsedExpression: {data}",
+            )
+            return UnparsedExpressionDTO(data[cls.UNPARSED_EXPRESSION])
diff --git 
a/clients/client-python/gravitino/dto/rel/expressions/json_serdes/column_default_value_serdes.py
 
b/clients/client-python/gravitino/dto/rel/expressions/json_serdes/column_default_value_serdes.py
new file mode 100644
index 0000000000..6e717c4f62
--- /dev/null
+++ 
b/clients/client-python/gravitino/dto/rel/expressions/json_serdes/column_default_value_serdes.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import overload
+
+from dataclasses_json.core import Json
+
+from gravitino.api.column import Column
+from gravitino.api.expressions.expression import Expression
+from gravitino.api.types.json_serdes.base import JsonSerializable
+from gravitino.dto.rel.expressions.json_serdes._helper.serdes_utils import 
SerdesUtils
+
+
+class ColumnDefaultValueSerdes(JsonSerializable[Expression]):
+    """Custom JSON serializer/deserializer for Column default value."""
+
+    @classmethod
+    def serialize(cls, value: Expression) -> Json:
+        if cls.is_empty(value):
+            return None
+        return SerdesUtils.write_function_arg(arg=value)
+
+    @classmethod
+    def deserialize(cls, data: Json) -> Expression:
+        if cls.is_empty(data):
+            return Column.DEFAULT_VALUE_NOT_SET
+        return SerdesUtils.read_function_arg(data=data)
+
+    @classmethod
+    @overload
+    def is_empty(cls, value: Expression) -> bool: ...
+
+    @classmethod
+    @overload
+    def is_empty(cls, value: Json) -> bool: ...
+
+    @classmethod
+    def is_empty(cls, value):
+        if isinstance(value, (Expression, list)):
+            return value is None or value is Column.DEFAULT_VALUE_NOT_SET
+        return value is None
diff --git 
a/clients/client-python/tests/unittests/dto/rel/test_column_default_value_serdes.py
 
b/clients/client-python/tests/unittests/dto/rel/test_column_default_value_serdes.py
new file mode 100644
index 0000000000..303dd98f45
--- /dev/null
+++ 
b/clients/client-python/tests/unittests/dto/rel/test_column_default_value_serdes.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from unittest.mock import patch
+
+from gravitino.api.column import Column
+from gravitino.api.types.types import Types
+from gravitino.dto.rel.expressions.field_reference_dto import FieldReferenceDTO
+from gravitino.dto.rel.expressions.func_expression_dto import FuncExpressionDTO
+from gravitino.dto.rel.expressions.json_serdes._helper.serdes_utils import 
SerdesUtils
+from gravitino.dto.rel.expressions.json_serdes.column_default_value_serdes 
import (
+    ColumnDefaultValueSerdes,
+)
+from gravitino.dto.rel.expressions.literal_dto import LiteralDTO
+from gravitino.dto.rel.expressions.unparsed_expression_dto import 
UnparsedExpressionDTO
+from gravitino.exceptions.base import IllegalArgumentException
+
+
+class TestColumnDefaultValueSerdes(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls._literal_dto = (
+            LiteralDTO.builder()
+            .with_data_type(data_type=Types.StringType.get())
+            .with_value(value="test_string")
+            .build()
+        )
+        cls._dtos = [
+            cls._literal_dto,
+            FieldReferenceDTO.builder()
+            .with_field_name(field_name=["field_name"])
+            .build(),
+            FuncExpressionDTO.builder()
+            .with_function_name(function_name="simple_func_name")
+            .with_function_args(function_args=[cls._literal_dto])
+            .build(),
+            UnparsedExpressionDTO.builder()
+            
.with_unparsed_expression(unparsed_expression="unparsed_expression")
+            .build(),
+        ]
+
+    def test_column_default_serdes_serialize_empty(self):
+        self.assertIsNone(ColumnDefaultValueSerdes.serialize(value=None))
+        self.assertIsNone(
+            
ColumnDefaultValueSerdes.serialize(value=Column.DEFAULT_VALUE_NOT_SET)
+        )
+
+    def test_column_default_serdes_deserialize_empty(self):
+        self.assertIs(
+            ColumnDefaultValueSerdes.deserialize(data=None),
+            Column.DEFAULT_VALUE_NOT_SET,
+        )
+
+        self.assertRaisesRegex(
+            IllegalArgumentException,
+            "Cannot parse function arg from invalid JSON",
+            ColumnDefaultValueSerdes.deserialize,
+            data="None",
+        )
+
+    def test_serialize_dto(self):
+        for dto in self._dtos:
+            with patch.object(
+                SerdesUtils, "write_function_arg"
+            ) as mock_write_function_arg:
+                ColumnDefaultValueSerdes.serialize(value=dto)
+                mock_write_function_arg.assert_called_once_with(arg=dto)
+
+    def test_deserialize_dto(self):
+        for dto in self._dtos:
+            data = ColumnDefaultValueSerdes.serialize(value=dto)
+            with patch.object(
+                SerdesUtils, "read_function_arg"
+            ) as mock_read_function_arg:
+                ColumnDefaultValueSerdes.deserialize(data=data)
+                mock_read_function_arg.assert_called_once_with(data=data)
diff --git a/clients/client-python/tests/unittests/dto/rel/test_column_dto.py 
b/clients/client-python/tests/unittests/dto/rel/test_column_dto.py
index 8bdd6d1307..0a9c2b4a3d 100644
--- a/clients/client-python/tests/unittests/dto/rel/test_column_dto.py
+++ b/clients/client-python/tests/unittests/dto/rel/test_column_dto.py
@@ -17,13 +17,20 @@
 
 import json
 import unittest
+from itertools import product
 
 from gravitino.api.column import Column
 from gravitino.api.types.json_serdes import TypeSerdes
 from gravitino.api.types.json_serdes._helper.serdes_utils import SerdesUtils
 from gravitino.api.types.types import Types
 from gravitino.dto.rel.column_dto import ColumnDTO
+from gravitino.dto.rel.expressions.field_reference_dto import FieldReferenceDTO
+from gravitino.dto.rel.expressions.func_expression_dto import FuncExpressionDTO
+from gravitino.dto.rel.expressions.json_serdes.column_default_value_serdes 
import (
+    ColumnDefaultValueSerdes,
+)
 from gravitino.dto.rel.expressions.literal_dto import LiteralDTO
+from gravitino.dto.rel.expressions.unparsed_expression_dto import 
UnparsedExpressionDTO
 from gravitino.exceptions.base import IllegalArgumentException
 
 
@@ -75,6 +82,7 @@ class TestColumnDTO(unittest.TestCase):
         column_dto_2 = self._string_columns[2]
         self.assertNotEqual(column_dto_1, column_dto_2)
         self.assertEqual(column_dto_1, column_dto_1)
+        self.assertNotEqual(column_dto_1, "test")
 
     def test_column_dto_hash(self):
         column_dto_1 = self._string_columns[1]
@@ -214,3 +222,179 @@ class TestColumnDTO(unittest.TestCase):
                 deserialized_column_dto.default_value(), 
Column.DEFAULT_VALUE_NOT_SET
             )
             self.assertEqual(serialized_json, deserialized_json)
+
+    def test_column_dto_serdes_with_default_value_literal(self):
+        expected_dict = {
+            "name": "",
+            "type": "",
+            "comment": "",
+            "defaultValue": None,
+            "nullable": False,
+            "autoIncrement": False,
+        }
+        for data_type, default_value_type in product(
+            self._supported_types, self._supported_types
+        ):
+            default_value = (
+                LiteralDTO.builder()
+                .with_data_type(default_value_type)
+                .with_value(default_value_type.simple_string())
+                .build()
+            )
+            column_dto = (
+                ColumnDTO.builder()
+                .with_name(name=str(data_type.name()))
+                .with_data_type(data_type=data_type)
+                .with_default_value(default_value=default_value)
+                .with_comment(comment=data_type.simple_string())
+                .build()
+            )
+            expected_dict["name"] = str(data_type.name())
+            expected_dict["type"] = TypeSerdes.serialize(data_type)
+            expected_dict["comment"] = data_type.simple_string()
+            expected_dict["defaultValue"] = ColumnDefaultValueSerdes.serialize(
+                value=default_value
+            )
+            expected_dict["nullable"] = True
+            expected_dict["autoIncrement"] = False
+
+            serialized_result = column_dto.to_json()
+            deserialized_dto = ColumnDTO.from_json(serialized_result)
+
+            self.assertDictEqual(json.loads(serialized_result), expected_dict)
+            self.assertEqual(deserialized_dto, column_dto)
+
+    def test_column_dto_serdes_with_default_value_field_ref(self):
+        expected_dict = {
+            "name": "",
+            "type": "",
+            "comment": "",
+            "defaultValue": None,
+            "nullable": False,
+            "autoIncrement": False,
+        }
+        default_value = (
+            
FieldReferenceDTO.builder().with_column_name(["field_reference"]).build()
+        )
+        for data_type in self._supported_types:
+            column_dto = (
+                ColumnDTO.builder()
+                .with_name(name=str(data_type.name()))
+                .with_data_type(data_type=data_type)
+                .with_default_value(default_value=default_value)
+                .with_comment(comment=data_type.simple_string())
+                .build()
+            )
+            expected_dict["name"] = str(data_type.name())
+            expected_dict["type"] = TypeSerdes.serialize(data_type)
+            expected_dict["comment"] = data_type.simple_string()
+            expected_dict["defaultValue"] = ColumnDefaultValueSerdes.serialize(
+                value=default_value
+            )
+            expected_dict["nullable"] = True
+            expected_dict["autoIncrement"] = False
+
+            serialized_result = column_dto.to_json()
+            deserialized_dto = ColumnDTO.from_json(serialized_result)
+
+            self.assertDictEqual(json.loads(serialized_result), expected_dict)
+            self.assertEqual(deserialized_dto, column_dto)
+
+    def test_column_dto_serdes_with_default_value_func_expression(self):
+        expected_dict = {
+            "name": "",
+            "type": "",
+            "comment": "",
+            "defaultValue": None,
+            "nullable": False,
+            "autoIncrement": False,
+        }
+        func_args = [
+            LiteralDTO.builder()
+            .with_data_type(Types.StringType.get())
+            .with_value("year")
+            .build(),
+            FieldReferenceDTO.builder().with_column_name(["birthday"]).build(),
+            FuncExpressionDTO.builder()
+            .with_function_name("randint")
+            .with_function_args(
+                [
+                    LiteralDTO.builder()
+                    .with_data_type(Types.IntegerType.get())
+                    .with_value("1")
+                    .build(),
+                    LiteralDTO.builder()
+                    .with_data_type(Types.IntegerType.get())
+                    .with_value("100")
+                    .build(),
+                ]
+            )
+            .build(),
+        ]
+        for data_type in self._supported_types:
+            default_value = (
+                FuncExpressionDTO.builder()
+                .with_function_name("test_function")
+                .with_function_args(func_args)
+                .build()
+            )
+            column_dto = (
+                ColumnDTO.builder()
+                .with_name(name=str(data_type.name()))
+                .with_data_type(data_type=data_type)
+                .with_default_value(default_value=default_value)
+                .with_comment(comment=data_type.simple_string())
+                .build()
+            )
+            expected_dict["name"] = str(data_type.name())
+            expected_dict["type"] = TypeSerdes.serialize(data_type)
+            expected_dict["comment"] = data_type.simple_string()
+            expected_dict["defaultValue"] = ColumnDefaultValueSerdes.serialize(
+                value=default_value
+            )
+            expected_dict["nullable"] = True
+            expected_dict["autoIncrement"] = False
+
+            serialized_result = column_dto.to_json()
+            deserialized_dto = ColumnDTO.from_json(serialized_result)
+
+            self.assertDictEqual(json.loads(serialized_result), expected_dict)
+            self.assertEqual(deserialized_dto, column_dto)
+
+    def test_column_dto_serialize_with_default_value_unparsed(self):
+        expected_dict = {
+            "name": "",
+            "type": "",
+            "comment": "",
+            "defaultValue": None,
+            "nullable": False,
+            "autoIncrement": False,
+        }
+        for data_type in self._supported_types:
+            default_value = (
+                UnparsedExpressionDTO.builder()
+                .with_unparsed_expression("unparsed_expression")
+                .build()
+            )
+            column_dto = (
+                ColumnDTO.builder()
+                .with_name(name=str(data_type.name()))
+                .with_data_type(data_type=data_type)
+                .with_default_value(default_value=default_value)
+                .with_comment(comment=data_type.simple_string())
+                .build()
+            )
+            expected_dict["name"] = str(data_type.name())
+            expected_dict["type"] = TypeSerdes.serialize(data_type)
+            expected_dict["comment"] = data_type.simple_string()
+            expected_dict["defaultValue"] = ColumnDefaultValueSerdes.serialize(
+                value=default_value
+            )
+            expected_dict["nullable"] = True
+            expected_dict["autoIncrement"] = False
+
+            serialized_result = column_dto.to_json()
+            deserialized_dto = ColumnDTO.from_json(serialized_result)
+
+            self.assertDictEqual(json.loads(serialized_result), expected_dict)
+            self.assertEqual(deserialized_dto, column_dto)
diff --git a/clients/client-python/tests/unittests/dto/rel/test_function_arg.py 
b/clients/client-python/tests/unittests/dto/rel/test_function_arg.py
index 4590bf7cd6..03e4059f8d 100644
--- a/clients/client-python/tests/unittests/dto/rel/test_function_arg.py
+++ b/clients/client-python/tests/unittests/dto/rel/test_function_arg.py
@@ -19,6 +19,8 @@ import unittest
 
 from gravitino.api.types.types import Types
 from gravitino.dto.rel.column_dto import ColumnDTO
+from gravitino.dto.rel.expressions.field_reference_dto import FieldReferenceDTO
+from gravitino.dto.rel.expressions.func_expression_dto import FuncExpressionDTO
 from gravitino.dto.rel.expressions.function_arg import FunctionArg
 from gravitino.dto.rel.expressions.literal_dto import LiteralDTO
 
@@ -45,7 +47,21 @@ class TestFunctionArg(unittest.TestCase):
         self.assertEqual(FunctionArg.EMPTY_ARGS, [])
 
     def test_function_arg_validate(self):
-        LiteralDTO(data_type=Types.StringType.get(), value="test").validate(
+        literal_dto = (
+            LiteralDTO.builder()
+            .with_data_type(Types.StringType.get())
+            .with_value("test")
+            .build()
+        )
+        literal_dto.validate(columns=self._columns)
+
+        field_ref_dto = (
+            
FieldReferenceDTO.builder().with_column_name(self._column_names).build()
+        )
+        field_ref_dto.validate(columns=self._columns)
+
+        FuncExpressionDTO.builder().with_function_name(
+            "test_function"
+        ).with_function_args([field_ref_dto, literal_dto]).build().validate(
             columns=self._columns
         )
-        # TODO: add unit test for FunctionArg with children
diff --git a/clients/client-python/tests/unittests/dto/rel/test_literal_dto.py 
b/clients/client-python/tests/unittests/dto/rel/test_literal_dto.py
index 62d1d61bd6..d730a49de2 100644
--- a/clients/client-python/tests/unittests/dto/rel/test_literal_dto.py
+++ b/clients/client-python/tests/unittests/dto/rel/test_literal_dto.py
@@ -68,3 +68,7 @@ class TestLiteralDTO(unittest.TestCase):
         )
         self.assertIsInstance(dto, LiteralDTO)
         self.assertTrue(dto == self._literal_dto)
+
+    def test_literal_dto_equality(self):
+        self.assertEqual(self._literal_dto, self._literal_dto)
+        self.assertNotEqual(self._literal_dto, "test")
diff --git a/clients/client-python/tests/unittests/dto/rel/test_serdes_utils.py 
b/clients/client-python/tests/unittests/dto/rel/test_serdes_utils.py
new file mode 100644
index 0000000000..d8b29659f6
--- /dev/null
+++ b/clients/client-python/tests/unittests/dto/rel/test_serdes_utils.py
@@ -0,0 +1,255 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from enum import Enum
+from unittest.mock import patch
+
+from gravitino.api.types.types import Types
+from gravitino.dto.rel.expressions.field_reference_dto import FieldReferenceDTO
+from gravitino.dto.rel.expressions.func_expression_dto import FuncExpressionDTO
+from gravitino.dto.rel.expressions.function_arg import FunctionArg
+from gravitino.dto.rel.expressions.json_serdes._helper.serdes_utils import SerdesUtils
+from gravitino.dto.rel.expressions.literal_dto import LiteralDTO
+from gravitino.dto.rel.expressions.unparsed_expression_dto import UnparsedExpressionDTO
+from gravitino.exceptions.base import IllegalArgumentException
+
+
+class MockArgType(str, Enum):
+    INVALID_ARG_TYPE = "invalid_arg_type"
+
+
+class TestExpressionSerdesUtils(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls._literal_dto = (
+            LiteralDTO.builder()
+            .with_data_type(data_type=Types.StringType.get())
+            .with_value(value="test_string")
+            .build()
+        )
+        cls._field_reference_dto = (
+            FieldReferenceDTO.builder()
+            .with_field_name(field_name=["field_name"])
+            .build()
+        )
+        cls._naive_func_expression_dto = (
+            FuncExpressionDTO.builder()
+            .with_function_name(function_name="simple_func_name")
+            .with_function_args(function_args=[cls._literal_dto])
+            .build()
+        )
+        cls._func_expression_dto = (
+            FuncExpressionDTO.builder()
+            .with_function_name(function_name="func_name")
+            .with_function_args(
+                function_args=[
+                    cls._literal_dto,
+                    cls._field_reference_dto,
+                    cls._naive_func_expression_dto,
+                ]
+            )
+            .build()
+        )
+        cls._unparsed_expression_dto = (
+            UnparsedExpressionDTO.builder()
+            .with_unparsed_expression(unparsed_expression="unparsed_expression")
+            .build()
+        )
+
+    def test_write_function_arg_invalid_arg_type(self):
+        mock_dto = (
+            LiteralDTO.builder()
+            .with_data_type(data_type=Types.StringType.get())
+            .with_value(value="test")
+            .build()
+        )
+        with patch.object(
+            mock_dto, "arg_type", return_value=MockArgType.INVALID_ARG_TYPE
+        ):
+            self.assertRaises(ValueError, SerdesUtils.write_function_arg, arg=mock_dto)
+
+    def test_write_function_arg_literal_dto(self):
+        result = SerdesUtils.write_function_arg(arg=self._literal_dto)
+        expected_result = {
+            SerdesUtils.EXPRESSION_TYPE: self._literal_dto.arg_type().name.lower(),
+            SerdesUtils.DATA_TYPE: self._literal_dto.data_type().simple_string(),
+            SerdesUtils.LITERAL_VALUE: self._literal_dto.value(),
+        }
+        self.assertDictEqual(result, expected_result)
+
+    def test_write_function_arg_field_reference_dto(self):
+        result = SerdesUtils.write_function_arg(arg=self._field_reference_dto)
+        expected_result = {
+            SerdesUtils.EXPRESSION_TYPE: self._field_reference_dto.arg_type().name.lower(),
+            SerdesUtils.FIELD_NAME: self._field_reference_dto.field_name(),
+        }
+        self.assertDictEqual(result, expected_result)
+
+    def test_write_function_arg_func_expression_dto(self):
+        result = SerdesUtils.write_function_arg(arg=self._naive_func_expression_dto)
+        expected_result = {
+            SerdesUtils.EXPRESSION_TYPE: self._naive_func_expression_dto.arg_type().name.lower(),
+            SerdesUtils.FUNCTION_NAME: self._naive_func_expression_dto.function_name(),
+            SerdesUtils.FUNCTION_ARGS: [
+                SerdesUtils.write_function_arg(arg=self._literal_dto),
+            ],
+        }
+        self.assertDictEqual(result, expected_result)
+
+        result = SerdesUtils.write_function_arg(arg=self._func_expression_dto)
+        expected_result = {
+            SerdesUtils.EXPRESSION_TYPE: self._func_expression_dto.arg_type().name.lower(),
+            SerdesUtils.FUNCTION_NAME: self._func_expression_dto.function_name(),
+            SerdesUtils.FUNCTION_ARGS: [
+                SerdesUtils.write_function_arg(arg=self._literal_dto),
+                SerdesUtils.write_function_arg(arg=self._field_reference_dto),
+                SerdesUtils.write_function_arg(arg=self._naive_func_expression_dto),
+            ],
+        }
+        self.assertDictEqual(result, expected_result)
+
+    def test_write_function_arg_unparsed_expression_dto(self):
+        result = SerdesUtils.write_function_arg(arg=self._unparsed_expression_dto)
+        expected_result = {
+            SerdesUtils.EXPRESSION_TYPE: self._unparsed_expression_dto.arg_type().name.lower(),
+            SerdesUtils.UNPARSED_EXPRESSION: self._unparsed_expression_dto.unparsed_expression(),
+        }
+        self.assertDictEqual(result, expected_result)
+
+    def test_read_function_arg_invalid_data(self):
+        self.assertRaisesRegex(
+            IllegalArgumentException,
+            "Cannot parse function arg from invalid JSON",
+            SerdesUtils.read_function_arg,
+            data=None,
+        )
+
+        self.assertRaisesRegex(
+            IllegalArgumentException,
+            "Cannot parse function arg from invalid JSON",
+            SerdesUtils.read_function_arg,
+            data="invalid_data",
+        )
+
+        self.assertRaisesRegex(
+            IllegalArgumentException,
+            "Cannot parse function arg from missing type",
+            SerdesUtils.read_function_arg,
+            data={},
+        )
+
+    def test_read_function_arg_literal_dto(self):
+        data = {SerdesUtils.EXPRESSION_TYPE: self._literal_dto.arg_type().name.lower()}
+        self.assertRaisesRegex(
+            IllegalArgumentException,
+            "Cannot parse literal arg from missing data type",
+            SerdesUtils.read_function_arg,
+            data=data,
+        )
+
+        data[SerdesUtils.DATA_TYPE] = self._literal_dto.data_type().simple_string()
+        self.assertRaisesRegex(
+            IllegalArgumentException,
+            "Cannot parse literal arg from missing literal value",
+            SerdesUtils.read_function_arg,
+            data=data,
+        )
+
+        data[SerdesUtils.LITERAL_VALUE] = self._literal_dto.value()
+        result = SerdesUtils.read_function_arg(data=data)
+        self.assertEqual(result, self._literal_dto)
+
+    def test_read_function_arg_field_reference_dto(self):
+        data = {
+            SerdesUtils.EXPRESSION_TYPE: self._field_reference_dto.arg_type().name.lower()
+        }
+        self.assertRaisesRegex(
+            IllegalArgumentException,
+            "Cannot parse field reference arg from missing field name",
+            SerdesUtils.read_function_arg,
+            data=data,
+        )
+
+        data[SerdesUtils.FIELD_NAME] = self._field_reference_dto.field_name()
+        result = SerdesUtils.read_function_arg(data=data)
+        self.assertEqual(result, self._field_reference_dto)
+
+    def test_read_function_arg_func_expression_dto(self):
+        data = {
+            SerdesUtils.EXPRESSION_TYPE: self._naive_func_expression_dto.arg_type().name.lower()
+        }
+        self.assertRaisesRegex(
+            IllegalArgumentException,
+            "Cannot parse function function arg from missing function name",
+            SerdesUtils.read_function_arg,
+            data=data,
+        )
+
+        data[SerdesUtils.FUNCTION_NAME] = (
+            self._naive_func_expression_dto.function_name()
+        )
+        self.assertRaisesRegex(
+            IllegalArgumentException,
+            "Cannot parse function function arg from missing function args",
+            SerdesUtils.read_function_arg,
+            data=data,
+        )
+
+        data[SerdesUtils.FUNCTION_ARGS] = [
+            SerdesUtils.write_function_arg(arg=self._literal_dto),
+        ]
+        result = SerdesUtils.read_function_arg(data=data)
+        self.assertEqual(result, self._naive_func_expression_dto)
+
+        data[SerdesUtils.FUNCTION_ARGS] = []
+        result = SerdesUtils.read_function_arg(data=data)
+        self.assertEqual(
+            result,
+            FuncExpressionDTO.builder()
+            .with_function_name(
+                function_name=self._naive_func_expression_dto.function_name()
+            )
+            .with_function_args(function_args=FunctionArg.EMPTY_ARGS)
+            .build(),
+        )
+
+    def test_read_function_arg_unparsed_expression_dto(self):
+        data = {SerdesUtils.EXPRESSION_TYPE: "invalid_expression_type"}
+        self.assertRaisesRegex(
+            IllegalArgumentException,
+            "Unknown function argument type",
+            SerdesUtils.read_function_arg,
+            data=data,
+        )
+
+        data[SerdesUtils.EXPRESSION_TYPE] = (
+            self._unparsed_expression_dto.arg_type().name.lower()
+        )
+        data[SerdesUtils.UNPARSED_EXPRESSION] = {}
+        self.assertRaisesRegex(
+            IllegalArgumentException,
+            "Cannot parse unparsed expression from missing string field unparsedExpression",
+            SerdesUtils.read_function_arg,
+            data=data,
+        )
+
+        data[SerdesUtils.UNPARSED_EXPRESSION] = (
+            self._unparsed_expression_dto.unparsed_expression()
+        )
+        result = SerdesUtils.read_function_arg(data=data)
+        self.assertEqual(result, self._unparsed_expression_dto)


Reply via email to