This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new b181a9f099 [feature](Nereids) support array type in fold constant framework (#23373)
b181a9f099 is described below

commit b181a9f09905aea17f312cc93a837c0203892903
Author: morrySnow <101034200+morrys...@users.noreply.github.com>
AuthorDate: Mon Aug 28 10:47:43 2023 +0800

    [feature](Nereids) support array type in fold constant framework (#23373)
    
    1. Process constant-folding results returned from the BE the same way the legacy planner does.
    2. Support function signatures with complex types when constant folding on the FE.
---
 .../expression/rules/FoldConstantRuleOnBE.java     | 62 +++++++++++++++++-----
 .../trees/expressions/ExpressionEvaluator.java     | 46 +++++++++++-----
 2 files changed, 83 insertions(+), 25 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java
index f381f328fd..6ed045a300 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java
@@ -19,12 +19,12 @@ package org.apache.doris.nereids.rules.expression.rules;
 
 import org.apache.doris.analysis.Expr;
 import org.apache.doris.analysis.ExprId;
-import org.apache.doris.analysis.LiteralExpr;
 import org.apache.doris.catalog.Env;
 import org.apache.doris.catalog.PrimitiveType;
-import org.apache.doris.catalog.Type;
+import org.apache.doris.catalog.ScalarType;
 import org.apache.doris.common.IdGenerator;
 import org.apache.doris.common.UserException;
+import org.apache.doris.common.util.DebugUtil;
 import org.apache.doris.common.util.TimeUtils;
 import org.apache.doris.nereids.glue.translator.ExpressionTranslator;
 import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
@@ -33,9 +33,15 @@ import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.types.CharType;
 import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DateTimeV2Type;
+import org.apache.doris.nereids.types.DecimalV2Type;
+import org.apache.doris.nereids.types.DecimalV3Type;
+import org.apache.doris.nereids.types.VarcharType;
 import org.apache.doris.proto.InternalService;
 import org.apache.doris.proto.InternalService.PConstantExprResult;
+import org.apache.doris.proto.Types.PScalarType;
 import org.apache.doris.qe.ConnectContext;
 import org.apache.doris.rpc.BackendServiceProxy;
 import org.apache.doris.system.Backend;
@@ -46,6 +52,7 @@ import org.apache.doris.thrift.TPrimitiveType;
 import org.apache.doris.thrift.TQueryGlobals;
 import org.apache.doris.thrift.TQueryOptions;
 
+import com.google.common.base.Preconditions;
 import com.google.common.collect.Maps;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
@@ -69,8 +76,8 @@ public class FoldConstantRuleOnBE extends 
AbstractExpressionRewriteRule {
     private final IdGenerator<ExprId> idGenerator = ExprId.createGenerator();
 
     @Override
-    public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
-        Expression expression = FoldConstantRuleOnFE.INSTANCE.rewrite(expr, 
ctx);
+    public Expression rewrite(Expression expression, ExpressionRewriteContext 
ctx) {
+        expression = FoldConstantRuleOnFE.INSTANCE.rewrite(expression, ctx);
         return foldByBE(expression, ctx);
     }
 
@@ -175,26 +182,57 @@ public class FoldConstantRuleOnBE extends 
AbstractExpressionRewriteRule {
             if (result.getStatus().getStatusCode() == 0) {
                 for (Entry<String, InternalService.PExprResultMap> e : 
result.getExprResultMapMap().entrySet()) {
                     for (Entry<String, InternalService.PExprResult> e1 : 
e.getValue().getMapMap().entrySet()) {
+                        PScalarType pScalarType = e1.getValue().getType();
+                        TPrimitiveType tPrimitiveType = 
TPrimitiveType.findByValue(pScalarType.getType());
+                        PrimitiveType primitiveType = 
PrimitiveType.fromThrift(Objects.requireNonNull(tPrimitiveType));
                         Expression ret;
                         if (e1.getValue().getSuccess()) {
-                            TPrimitiveType type = 
TPrimitiveType.findByValue(e1.getValue().getType().getType());
-                            Type t = 
Type.fromPrimitiveType(PrimitiveType.fromThrift(Objects.requireNonNull(type)));
-                            Expr staleExpr = 
LiteralExpr.create(e1.getValue().getContent(), Objects.requireNonNull(t));
-                            // Nereids type
-                            DataType t1 = 
DataType.convertFromString(staleExpr.getType().getPrimitiveType().toString());
-                            ret = 
Literal.of(staleExpr.getStringValue()).castTo(t1);
+                            DataType type;
+                            if (PrimitiveType.ARRAY == primitiveType
+                                    || PrimitiveType.MAP == primitiveType
+                                    || PrimitiveType.STRUCT == primitiveType
+                                    || PrimitiveType.AGG_STATE == 
primitiveType) {
+                                ret = constMap.get(e1.getKey());
+                            } else {
+                                if (primitiveType == PrimitiveType.CHAR) {
+                                    
Preconditions.checkState(pScalarType.hasLen(),
+                                            "be return char type without len");
+                                    type = 
CharType.createCharType(pScalarType.getLen());
+                                } else if (primitiveType == 
PrimitiveType.VARCHAR) {
+                                    
Preconditions.checkState(pScalarType.hasLen(),
+                                            "be return varchar type without 
len");
+                                    type = 
VarcharType.createVarcharType(pScalarType.getLen());
+                                } else if (primitiveType == 
PrimitiveType.DECIMALV2) {
+                                    type = DecimalV2Type.createDecimalV2Type(
+                                            pScalarType.getPrecision(), 
pScalarType.getScale());
+                                } else if (primitiveType == 
PrimitiveType.DATETIMEV2) {
+                                    type = 
DateTimeV2Type.of(pScalarType.getScale());
+                                } else if (primitiveType == 
PrimitiveType.DECIMAL32
+                                        || primitiveType == 
PrimitiveType.DECIMAL64
+                                        || primitiveType == 
PrimitiveType.DECIMAL128) {
+                                    type = DecimalV3Type.createDecimalV3Type(
+                                            pScalarType.getPrecision(), 
pScalarType.getScale());
+                                } else {
+                                    type = 
DataType.fromCatalogType(ScalarType.createType(
+                                            
PrimitiveType.fromThrift(tPrimitiveType)));
+                                }
+                                ret = 
Literal.of(e1.getValue().getContent()).castTo(type);
+                            }
                         } else {
                             ret = constMap.get(e1.getKey());
                         }
+                        LOG.debug("Be constant folding convert {} to {}", 
e1.getKey(), ret);
                         resultMap.put(e1.getKey(), ret);
                     }
                 }
 
             } else {
-                LOG.warn("failed to get const expr value from be: {}", 
result.getStatus().getErrorMsgsList());
+                LOG.warn("query {} failed to get const expr value from be: {}",
+                        DebugUtil.printId(context.queryId()), 
result.getStatus().getErrorMsgsList());
             }
         } catch (Exception e) {
-            LOG.warn("failed to get const expr value from be: {}", 
e.getMessage());
+            LOG.warn("query {} failed to get const expr value from be: {}",
+                    DebugUtil.printId(context.queryId()), e.getMessage());
         }
         return resultMap;
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java
index 0e7ef81cc4..2964a4eeb3 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java
@@ -30,8 +30,11 @@ import 
org.apache.doris.nereids.trees.expressions.functions.executable.NumericAr
 import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
 import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.types.ArrayType;
 import org.apache.doris.nereids.types.DataType;
 import org.apache.doris.nereids.types.DecimalV3Type;
+import org.apache.doris.nereids.types.MapType;
+import org.apache.doris.nereids.types.StructType;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMultimap;
@@ -106,9 +109,6 @@ public enum ExpressionEvaluator {
 
     private FunctionInvoker getFunction(FunctionSignature signature) {
         Collection<FunctionInvoker> functionInvokers = 
functions.get(signature.getName());
-        if (functionInvokers == null) {
-            return null;
-        }
         for (FunctionInvoker candidate : functionInvokers) {
             DataType[] candidateTypes = candidate.getSignature().getArgTypes();
             DataType[] expectedTypes = signature.getArgTypes();
@@ -134,9 +134,8 @@ public enum ExpressionEvaluator {
         if (functions != null) {
             return;
         }
-        ImmutableMultimap.Builder<String, FunctionInvoker> mapBuilder =
-                new ImmutableMultimap.Builder<String, FunctionInvoker>();
-        List<Class> classes = ImmutableList.of(
+        ImmutableMultimap.Builder<String, FunctionInvoker> mapBuilder = new 
ImmutableMultimap.Builder<>();
+        List<Class<?>> classes = ImmutableList.of(
                 DateTimeAcquire.class,
                 DateTimeExtractAndTransform.class,
                 ExecutableFunctions.class,
@@ -144,7 +143,7 @@ public enum ExpressionEvaluator {
                 DateTimeArithmetic.class,
                 NumericArithmetic.class
         );
-        for (Class cls : classes) {
+        for (Class<?> cls : classes) {
             for (Method method : cls.getDeclaredMethods()) {
                 ExecFunctionList annotationList = 
method.getAnnotation(ExecFunctionList.class);
                 if (annotationList != null) {
@@ -165,18 +164,39 @@ public enum ExpressionEvaluator {
             DataType returnType = 
DataType.convertFromString(annotation.returnType());
             List<DataType> argTypes = new ArrayList<>();
             for (String type : annotation.argTypes()) {
-                if (type.equalsIgnoreCase("DECIMALV3")) {
-                    argTypes.add(DecimalV3Type.WILDCARD);
-                } else {
-                    argTypes.add(DataType.convertFromString(type));
-                }
+                
argTypes.add(replaceDecimalV3WithWildcard(DataType.convertFromString(type)));
             }
             FunctionSignature signature = new FunctionSignature(name,
-                    argTypes.toArray(new DataType[argTypes.size()]), 
returnType);
+                    argTypes.toArray(new DataType[0]), returnType);
             mapBuilder.put(name, new FunctionInvoker(method, signature));
         }
     }
 
+    private DataType replaceDecimalV3WithWildcard(DataType input) {
+        if (input instanceof ArrayType) {
+            DataType item = replaceDecimalV3WithWildcard(((ArrayType) 
input).getItemType());
+            if (item == ((ArrayType) input).getItemType()) {
+                return input;
+            }
+            return ArrayType.of(item);
+        } else if (input instanceof MapType) {
+            DataType keyType = replaceDecimalV3WithWildcard(((MapType) 
input).getKeyType());
+            DataType valueType = replaceDecimalV3WithWildcard(((MapType) 
input).getValueType());
+            if (keyType == ((MapType) input).getKeyType() && valueType == 
((MapType) input).getValueType()) {
+                return input;
+            }
+            return MapType.of(keyType, valueType);
+        } else if (input instanceof StructType) {
+            // TODO: support struct type
+            return input;
+        } else {
+            if (input instanceof DecimalV3Type) {
+                return DecimalV3Type.WILDCARD;
+            }
+            return input;
+        }
+    }
+
     /**
      * function invoker.
      */


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to