This is an automated email from the ASF dual-hosted git repository. morrysnow pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new b181a9f099 [feature](Nereids) support array type in fold constant framework (#23373) b181a9f099 is described below commit b181a9f09905aea17f312cc93a837c0203892903 Author: morrySnow <101034200+morrys...@users.noreply.github.com> AuthorDate: Mon Aug 28 10:47:43 2023 +0800 [feature](Nereids) support array type in fold constant framework (#23373) 1. use legacy planner way to process constant folding result from be 2. support signature with complex type for constant folding on fe --- .../expression/rules/FoldConstantRuleOnBE.java | 62 +++++++++++++++++----- .../trees/expressions/ExpressionEvaluator.java | 46 +++++++++++----- 2 files changed, 83 insertions(+), 25 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java index f381f328fd..6ed045a300 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java @@ -19,12 +19,12 @@ package org.apache.doris.nereids.rules.expression.rules; import org.apache.doris.analysis.Expr; import org.apache.doris.analysis.ExprId; -import org.apache.doris.analysis.LiteralExpr; import org.apache.doris.catalog.Env; import org.apache.doris.catalog.PrimitiveType; -import org.apache.doris.catalog.Type; +import org.apache.doris.catalog.ScalarType; import org.apache.doris.common.IdGenerator; import org.apache.doris.common.UserException; +import org.apache.doris.common.util.DebugUtil; import org.apache.doris.common.util.TimeUtils; import org.apache.doris.nereids.glue.translator.ExpressionTranslator; import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; @@ -33,9 +33,15 @@ import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.types.CharType; import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DateTimeV2Type; +import org.apache.doris.nereids.types.DecimalV2Type; +import org.apache.doris.nereids.types.DecimalV3Type; +import org.apache.doris.nereids.types.VarcharType; import org.apache.doris.proto.InternalService; import org.apache.doris.proto.InternalService.PConstantExprResult; +import org.apache.doris.proto.Types.PScalarType; import org.apache.doris.qe.ConnectContext; import org.apache.doris.rpc.BackendServiceProxy; import org.apache.doris.system.Backend; @@ -46,6 +52,7 @@ import org.apache.doris.thrift.TPrimitiveType; import org.apache.doris.thrift.TQueryGlobals; import org.apache.doris.thrift.TQueryOptions; +import com.google.common.base.Preconditions; import com.google.common.collect.Maps; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -69,8 +76,8 @@ public class FoldConstantRuleOnBE extends AbstractExpressionRewriteRule { private final IdGenerator<ExprId> idGenerator = ExprId.createGenerator(); @Override - public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { - Expression expression = FoldConstantRuleOnFE.INSTANCE.rewrite(expr, ctx); + public Expression rewrite(Expression expression, ExpressionRewriteContext ctx) { + expression = FoldConstantRuleOnFE.INSTANCE.rewrite(expression, ctx); return foldByBE(expression, ctx); } @@ -175,26 +182,57 @@ public class FoldConstantRuleOnBE extends AbstractExpressionRewriteRule { if (result.getStatus().getStatusCode() == 0) { for (Entry<String, InternalService.PExprResultMap> e : result.getExprResultMapMap().entrySet()) { for (Entry<String, InternalService.PExprResult> e1 : e.getValue().getMapMap().entrySet()) { + PScalarType pScalarType = e1.getValue().getType(); + TPrimitiveType tPrimitiveType = TPrimitiveType.findByValue(pScalarType.getType()); + PrimitiveType primitiveType = PrimitiveType.fromThrift(Objects.requireNonNull(tPrimitiveType)); Expression ret; if (e1.getValue().getSuccess()) { - TPrimitiveType type = TPrimitiveType.findByValue(e1.getValue().getType().getType()); - Type t = Type.fromPrimitiveType(PrimitiveType.fromThrift(Objects.requireNonNull(type))); - Expr staleExpr = LiteralExpr.create(e1.getValue().getContent(), Objects.requireNonNull(t)); - // Nereids type - DataType t1 = DataType.convertFromString(staleExpr.getType().getPrimitiveType().toString()); - ret = Literal.of(staleExpr.getStringValue()).castTo(t1); + DataType type; + if (PrimitiveType.ARRAY == primitiveType + || PrimitiveType.MAP == primitiveType + || PrimitiveType.STRUCT == primitiveType + || PrimitiveType.AGG_STATE == primitiveType) { + ret = constMap.get(e1.getKey()); + } else { + if (primitiveType == PrimitiveType.CHAR) { + Preconditions.checkState(pScalarType.hasLen(), + "be return char type without len"); + type = CharType.createCharType(pScalarType.getLen()); + } else if (primitiveType == PrimitiveType.VARCHAR) { + Preconditions.checkState(pScalarType.hasLen(), + "be return varchar type without len"); + type = VarcharType.createVarcharType(pScalarType.getLen()); + } else if (primitiveType == PrimitiveType.DECIMALV2) { + type = DecimalV2Type.createDecimalV2Type( + pScalarType.getPrecision(), pScalarType.getScale()); + } else if (primitiveType == PrimitiveType.DATETIMEV2) { + type = DateTimeV2Type.of(pScalarType.getScale()); + } else if (primitiveType == PrimitiveType.DECIMAL32 + || primitiveType == PrimitiveType.DECIMAL64 + || primitiveType == PrimitiveType.DECIMAL128) { + type = DecimalV3Type.createDecimalV3Type( + pScalarType.getPrecision(), pScalarType.getScale()); + } else { + type = DataType.fromCatalogType(ScalarType.createType( + PrimitiveType.fromThrift(tPrimitiveType))); + } + ret = Literal.of(e1.getValue().getContent()).castTo(type); + } } else { ret = constMap.get(e1.getKey()); } + LOG.debug("Be constant folding convert {} to {}", e1.getKey(), ret); resultMap.put(e1.getKey(), ret); } } } else { - LOG.warn("failed to get const expr value from be: {}", result.getStatus().getErrorMsgsList()); + LOG.warn("query {} failed to get const expr value from be: {}", + DebugUtil.printId(context.queryId()), result.getStatus().getErrorMsgsList()); } } catch (Exception e) { - LOG.warn("failed to get const expr value from be: {}", e.getMessage()); + LOG.warn("query {} failed to get const expr value from be: {}", + DebugUtil.printId(context.queryId()), e.getMessage()); } return resultMap; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java index 0e7ef81cc4..2964a4eeb3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java @@ -30,8 +30,11 @@ import org.apache.doris.nereids.trees.expressions.functions.executable.NumericAr import org.apache.doris.nereids.trees.expressions.literal.DateLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.types.ArrayType; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.DecimalV3Type; +import org.apache.doris.nereids.types.MapType; +import org.apache.doris.nereids.types.StructType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMultimap; @@ -106,9 +109,6 @@ public enum ExpressionEvaluator { private FunctionInvoker getFunction(FunctionSignature signature) { Collection<FunctionInvoker> functionInvokers = functions.get(signature.getName()); - if (functionInvokers == null) { - return null; - } for (FunctionInvoker candidate : functionInvokers) { DataType[] candidateTypes = candidate.getSignature().getArgTypes(); DataType[] expectedTypes = signature.getArgTypes(); @@ -134,9 +134,8 @@ public enum ExpressionEvaluator { if (functions != null) { return; } - ImmutableMultimap.Builder<String, FunctionInvoker> mapBuilder = - new ImmutableMultimap.Builder<String, FunctionInvoker>(); - List<Class> classes = ImmutableList.of( + ImmutableMultimap.Builder<String, FunctionInvoker> mapBuilder = new ImmutableMultimap.Builder<>(); + List<Class<?>> classes = ImmutableList.of( DateTimeAcquire.class, DateTimeExtractAndTransform.class, ExecutableFunctions.class, @@ -144,7 +143,7 @@ public enum ExpressionEvaluator { DateTimeArithmetic.class, NumericArithmetic.class ); - for (Class cls : classes) { + for (Class<?> cls : classes) { for (Method method : cls.getDeclaredMethods()) { ExecFunctionList annotationList = method.getAnnotation(ExecFunctionList.class); if (annotationList != null) { @@ -165,18 +164,39 @@ public enum ExpressionEvaluator { DataType returnType = DataType.convertFromString(annotation.returnType()); List<DataType> argTypes = new ArrayList<>(); for (String type : annotation.argTypes()) { - if (type.equalsIgnoreCase("DECIMALV3")) { - argTypes.add(DecimalV3Type.WILDCARD); - } else { - argTypes.add(DataType.convertFromString(type)); - } + argTypes.add(replaceDecimalV3WithWildcard(DataType.convertFromString(type))); } FunctionSignature signature = new FunctionSignature(name, - argTypes.toArray(new DataType[argTypes.size()]), returnType); + argTypes.toArray(new DataType[0]), returnType); mapBuilder.put(name, new FunctionInvoker(method, signature)); } } + private DataType replaceDecimalV3WithWildcard(DataType input) { + if (input instanceof ArrayType) { + DataType item = replaceDecimalV3WithWildcard(((ArrayType) input).getItemType()); + if (item == ((ArrayType) input).getItemType()) { + return input; + } + return ArrayType.of(item); + } else if (input instanceof MapType) { + DataType keyType = replaceDecimalV3WithWildcard(((MapType) input).getKeyType()); + DataType valueType = replaceDecimalV3WithWildcard(((MapType) input).getValueType()); + if (keyType == ((MapType) input).getKeyType() && valueType == ((MapType) input).getValueType()) { + return input; + } + return MapType.of(keyType, valueType); + } else if (input instanceof StructType) { + // TODO: support struct type + return input; + } else { + if (input instanceof DecimalV3Type) { + return DecimalV3Type.WILDCARD; + } + return input; + } + } + /** * function invoker. */ --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org