This is an automated email from the ASF dual-hosted git repository. diwu pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git
The following commit(s) were added to refs/heads/master by this push: new c0b1bc2 [feature](connector) support datasource v2 pushdown (#250) c0b1bc2 is described below commit c0b1bc2f51c0ccabe8deaadb3ca1b79fcd40d03f Author: gnehil <adamlee...@gmail.com> AuthorDate: Fri Jan 3 17:48:16 2025 +0800 [feature](connector) support datasource v2 pushdown (#250) --- .github/workflows/run-itcase-12.yml | 6 +- .github/workflows/run-itcase-20.yml | 6 +- .../spark-doris-connector-it/pom.xml | 2 +- .../apache/doris/spark/sql/DorisReaderITCase.scala | 26 +++++++ .../{DorisScan.scala => AbstractDorisScan.scala} | 10 ++- .../org/apache/doris/spark/read/DorisScan.scala | 28 +------- .../doris/spark/read/DorisScanBuilderBase.scala | 17 +---- .../apache/doris/spark/read/DorisScanBuilder.scala | 23 ++++++- .../apache/doris/spark/read/DorisScanBuilder.scala | 24 ++++++- .../spark-doris-connector-spark-3.3/pom.xml | 5 ++ .../apache/doris/spark/read/DorisScanBuilder.scala | 26 ++++++- .../org/apache/doris/spark/read/DorisScanV2.scala} | 13 +++- .../read/expression/V2ExpressionBuilder.scala | 79 ++++++++++++++++++++++ .../read/expression/V2ExpressionBuilderTest.scala | 49 ++++++++++++++ .../spark-doris-connector-spark-3.4/pom.xml | 5 ++ .../apache/doris/spark/read/DorisScanBuilder.scala | 26 ++++++- .../org/apache/doris/spark/read/DorisScanV2.scala} | 13 +++- .../read/expression/V2ExpressionBuilder.scala | 79 ++++++++++++++++++++++ .../read/expression/V2ExpressionBuilderTest.scala | 49 ++++++++++++++ .../spark-doris-connector-spark-3.5/pom.xml | 5 ++ .../apache/doris/spark/read/DorisScanBuilder.scala | 26 ++++++- .../org/apache/doris/spark/read/DorisScanV2.scala} | 13 +++- .../read/expression/V2ExpressionBuilder.scala | 79 ++++++++++++++++++++++ .../read/expression/V2ExpressionBuilderTest.scala | 49 ++++++++++++++ 24 files changed, 592 insertions(+), 66 deletions(-) diff --git a/.github/workflows/run-itcase-12.yml b/.github/workflows/run-itcase-12.yml index fddcda7..fd28357 100644 --- 
a/.github/workflows/run-itcase-12.yml +++ b/.github/workflows/run-itcase-12.yml @@ -42,6 +42,10 @@ jobs: run: | cd spark-doris-connector && mvn clean test -Pspark-2-it,spark-2.4_2.11 -pl spark-doris-connector-it -am -DfailIfNoTests=false -Dtest="*ITCase" -Dimage="adamlee489/doris:1.2.7.1_x86" - - name: Run ITCases for spark 3 + - name: Run ITCases for spark 3.1 run: | cd spark-doris-connector && mvn clean test -Pspark-3-it,spark-3.1 -pl spark-doris-connector-it -am -DfailIfNoTests=false -Dtest="*ITCase" -Dimage="adamlee489/doris:1.2.7.1_x86" + + - name: Run ITCases for spark 3.3 + run: | + cd spark-doris-connector && mvn clean test -Pspark-3-it,spark-3.3 -pl spark-doris-connector-it -am -DfailIfNoTests=false -Dtest="*ITCase" -Dimage="adamlee489/doris:1.2.7.1_x86" diff --git a/.github/workflows/run-itcase-20.yml b/.github/workflows/run-itcase-20.yml index d16d810..b0f31c0 100644 --- a/.github/workflows/run-itcase-20.yml +++ b/.github/workflows/run-itcase-20.yml @@ -42,7 +42,11 @@ jobs: run: | cd spark-doris-connector && mvn clean test -Pspark-2-it,spark-2.4_2.11 -pl spark-doris-connector-it -am -DfailIfNoTests=false -Dtest="*ITCase" -Dimage="adamlee489/doris:2.0.3" - - name: Run ITCases for spark 3 + - name: Run ITCases for spark 3.1 run: | cd spark-doris-connector && mvn clean test -Pspark-3-it,spark-3.1 -pl spark-doris-connector-it -am -DfailIfNoTests=false -Dtest="*ITCase" -Dimage="adamlee489/doris:2.0.3" + + - name: Run ITCases for spark 3.3 + run: | + cd spark-doris-connector && mvn clean test -Pspark-3-it,spark-3.3 -pl spark-doris-connector-it -am -DfailIfNoTests=false -Dtest="*ITCase" -Dimage="adamlee489/doris:2.0.3" \ No newline at end of file diff --git a/spark-doris-connector/spark-doris-connector-it/pom.xml b/spark-doris-connector/spark-doris-connector-it/pom.xml index 9493c03..9797fb6 100644 --- a/spark-doris-connector/spark-doris-connector-it/pom.xml +++ b/spark-doris-connector/spark-doris-connector-it/pom.xml @@ -97,7 +97,7 @@ <dependencies> 
<dependency> <groupId>org.apache.doris</groupId> - <artifactId>spark-doris-connector-spark-3.1</artifactId> + <artifactId>spark-doris-connector-spark-${spark.major.version}</artifactId> <version>${project.version}</version> <scope>test</scope> </dependency> diff --git a/spark-doris-connector/spark-doris-connector-it/src/test/java/org/apache/doris/spark/sql/DorisReaderITCase.scala b/spark-doris-connector/spark-doris-connector-it/src/test/java/org/apache/doris/spark/sql/DorisReaderITCase.scala index a147a7d..2d7930a 100644 --- a/spark-doris-connector/spark-doris-connector-it/src/test/java/org/apache/doris/spark/sql/DorisReaderITCase.scala +++ b/spark-doris-connector/spark-doris-connector-it/src/test/java/org/apache/doris/spark/sql/DorisReaderITCase.scala @@ -141,5 +141,31 @@ class DorisReaderITCase extends DorisTestBase { } else false } + @Test + @throws[Exception] + def testSQLSourceWithCondition(): Unit = { + initializeTable(TABLE_READ_TBL) + val session = SparkSession.builder().master("local[*]").getOrCreate() + session.sql( + s""" + |CREATE TEMPORARY VIEW test_source + |USING doris + |OPTIONS( + | "table.identifier"="${DATABASE + "." 
+ TABLE_READ_TBL}", + | "fenodes"="${DorisTestBase.getFenodes}", + | "user"="${DorisTestBase.USERNAME}", + | "password"="${DorisTestBase.PASSWORD}" + |) + |""".stripMargin) + + val result = session.sql( + """ + |select name,age from test_source where age = 18 + |""".stripMargin).collect().toList.toString() + session.stop() + + assert("List([doris,18])".equals(result)) + } + } diff --git a/spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/DorisScan.scala b/spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/AbstractDorisScan.scala similarity index 82% copy from spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/DorisScan.scala copy to spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/AbstractDorisScan.scala index b715766..f1666ad 100644 --- a/spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/DorisScan.scala +++ b/spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/AbstractDorisScan.scala @@ -20,15 +20,13 @@ package org.apache.doris.spark.read import org.apache.doris.spark.client.entity.{Backend, DorisReaderPartition} import org.apache.doris.spark.client.read.ReaderPartitionGenerator import org.apache.doris.spark.config.{DorisConfig, DorisOptions} -import org.apache.doris.spark.util.DorisDialects import org.apache.spark.internal.Logging import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan} -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import scala.language.implicitConversions -class DorisScan(config: DorisConfig, schema: StructType, filters: Array[Filter]) extends Scan with Batch with Logging { +abstract class AbstractDorisScan(config: DorisConfig, schema: StructType) extends Scan with Batch 
with Logging { private val scanMode = ScanMode.valueOf(config.getValue(DorisOptions.READ_MODE).toUpperCase) @@ -37,9 +35,7 @@ class DorisScan(config: DorisConfig, schema: StructType, filters: Array[Filter]) override def toBatch: Batch = this override def planInputPartitions(): Array[InputPartition] = { - val inValueLengthLimit = config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT) - val compiledFilters = filters.map(DorisDialects.compileFilter(_, inValueLengthLimit)).filter(_.isDefined).map(_.get) - ReaderPartitionGenerator.generatePartitions(config, schema.names, compiledFilters).map(toInputPartition) + ReaderPartitionGenerator.generatePartitions(config, schema.names, compiledFilters()).map(toInputPartition) } @@ -50,6 +46,8 @@ class DorisScan(config: DorisConfig, schema: StructType, filters: Array[Filter]) private def toInputPartition(rp: DorisReaderPartition): DorisInputPartition = DorisInputPartition(rp.getDatabase, rp.getTable, rp.getBackend, rp.getTablets.map(_.toLong), rp.getOpaquedQueryPlan, rp.getReadColumns, rp.getFilters) + protected def compiledFilters(): Array[String] + } case class DorisInputPartition(database: String, table: String, backend: Backend, tablets: Array[Long], opaquedQueryPlan: String, readCols: Array[String], predicates: Array[String]) extends InputPartition diff --git a/spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/DorisScan.scala b/spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/DorisScan.scala index b715766..d52a82a 100644 --- a/spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/DorisScan.scala +++ b/spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/DorisScan.scala @@ -17,39 +17,17 @@ package org.apache.doris.spark.read -import org.apache.doris.spark.client.entity.{Backend, DorisReaderPartition} -import 
org.apache.doris.spark.client.read.ReaderPartitionGenerator import org.apache.doris.spark.config.{DorisConfig, DorisOptions} import org.apache.doris.spark.util.DorisDialects import org.apache.spark.internal.Logging -import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan} import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import scala.language.implicitConversions -class DorisScan(config: DorisConfig, schema: StructType, filters: Array[Filter]) extends Scan with Batch with Logging { - - private val scanMode = ScanMode.valueOf(config.getValue(DorisOptions.READ_MODE).toUpperCase) - - override def readSchema(): StructType = schema - - override def toBatch: Batch = this - - override def planInputPartitions(): Array[InputPartition] = { +class DorisScan(config: DorisConfig, schema: StructType, filters: Array[Filter]) extends AbstractDorisScan(config, schema) with Logging { + override protected def compiledFilters(): Array[String] = { val inValueLengthLimit = config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT) - val compiledFilters = filters.map(DorisDialects.compileFilter(_, inValueLengthLimit)).filter(_.isDefined).map(_.get) - ReaderPartitionGenerator.generatePartitions(config, schema.names, compiledFilters).map(toInputPartition) - } - - - override def createReaderFactory(): PartitionReaderFactory = { - new DorisPartitionReaderFactory(readSchema(), scanMode, config) + filters.map(DorisDialects.compileFilter(_, inValueLengthLimit)).filter(_.isDefined).map(_.get) } - - private def toInputPartition(rp: DorisReaderPartition): DorisInputPartition = - DorisInputPartition(rp.getDatabase, rp.getTable, rp.getBackend, rp.getTablets.map(_.toLong), rp.getOpaquedQueryPlan, rp.getReadColumns, rp.getFilters) - } - -case class DorisInputPartition(database: String, table: String, backend: Backend, tablets: Array[Long], opaquedQueryPlan: String, readCols: Array[String], predicates: Array[String]) 
extends InputPartition diff --git a/spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/DorisScanBuilderBase.scala b/spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/DorisScanBuilderBase.scala index a6a97dc..cec9890 100644 --- a/spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/DorisScanBuilderBase.scala +++ b/spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/DorisScanBuilderBase.scala @@ -24,24 +24,9 @@ import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType protected[spark] abstract class DorisScanBuilderBase(config: DorisConfig, schema: StructType) extends ScanBuilder - with SupportsPushDownFilters with SupportsPushDownRequiredColumns { - private var readSchema: StructType = schema - - private var pushDownPredicates: Array[Filter] = Array[Filter]() - - private val inValueLengthLimit = config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT) - - override def build(): Scan = new DorisScan(config, readSchema, pushDownPredicates) - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { - val (pushed, unsupported) = filters.partition(DorisDialects.compileFilter(_, inValueLengthLimit).isDefined) - this.pushDownPredicates = pushed - unsupported - } - - override def pushedFilters(): Array[Filter] = pushDownPredicates + protected var readSchema: StructType = schema override def pruneColumns(requiredSchema: StructType): Unit = { readSchema = StructType(requiredSchema.fields.filter(schema.contains(_))) diff --git a/spark-doris-connector/spark-doris-connector-spark-3.1/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala b/spark-doris-connector/spark-doris-connector-spark-3.1/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala index 9e199af..5c8e716 100644 --- 
a/spark-doris-connector/spark-doris-connector-spark-3.1/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala +++ b/spark-doris-connector/spark-doris-connector-spark-3.1/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala @@ -17,7 +17,26 @@ package org.apache.doris.spark.read -import org.apache.doris.spark.config.DorisConfig +import org.apache.doris.spark.config.{DorisConfig, DorisOptions} +import org.apache.doris.spark.util.DorisDialects +import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType -class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisScanBuilderBase(config, schema) {} +class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisScanBuilderBase(config, schema) with SupportsPushDownFilters { + + private var pushDownPredicates: Array[Filter] = Array[Filter]() + + private val inValueLengthLimit = config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT) + + override def build(): Scan = new DorisScan(config, readSchema, pushDownPredicates) + + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + val (pushed, unsupported) = filters.partition(DorisDialects.compileFilter(_, inValueLengthLimit).isDefined) + this.pushDownPredicates = pushed + unsupported + } + + override def pushedFilters(): Array[Filter] = pushDownPredicates + +} diff --git a/spark-doris-connector/spark-doris-connector-spark-3.2/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala b/spark-doris-connector/spark-doris-connector-spark-3.2/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala index 9e199af..68241df 100644 --- a/spark-doris-connector/spark-doris-connector-spark-3.2/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala +++ b/spark-doris-connector/spark-doris-connector-spark-3.2/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala @@ -17,7 
+17,27 @@ package org.apache.doris.spark.read -import org.apache.doris.spark.config.DorisConfig +import org.apache.doris.spark.config.{DorisConfig, DorisOptions} +import org.apache.doris.spark.util.DorisDialects +import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType -class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisScanBuilderBase(config, schema) {} +class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisScanBuilderBase(config, schema) + with SupportsPushDownFilters { + + private var pushDownPredicates: Array[Filter] = Array[Filter]() + + private val inValueLengthLimit = config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT) + + override def build(): Scan = new DorisScan(config, readSchema, pushDownPredicates) + + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + val (pushed, unsupported) = filters.partition(DorisDialects.compileFilter(_, inValueLengthLimit).isDefined) + this.pushDownPredicates = pushed + unsupported + } + + override def pushedFilters(): Array[Filter] = pushDownPredicates + +} diff --git a/spark-doris-connector/spark-doris-connector-spark-3.3/pom.xml b/spark-doris-connector/spark-doris-connector-spark-3.3/pom.xml index ecc71ed..7a046a2 100644 --- a/spark-doris-connector/spark-doris-connector-spark-3.3/pom.xml +++ b/spark-doris-connector/spark-doris-connector-spark-3.3/pom.xml @@ -48,6 +48,11 @@ <groupId>org.apache.spark</groupId> <artifactId>spark-sql_${scala.major.version}</artifactId> </dependency> + <dependency> + <groupId>org.junit.jupiter</groupId> + <artifactId>junit-jupiter-api</artifactId> + <scope>test</scope> + </dependency> </dependencies> </project> \ No newline at end of file diff --git a/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala 
b/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala index 9e199af..cc8ddd2 100644 --- a/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala +++ b/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala @@ -17,7 +17,29 @@ package org.apache.doris.spark.read -import org.apache.doris.spark.config.DorisConfig +import org.apache.doris.spark.config.{DorisConfig, DorisOptions} +import org.apache.doris.spark.read.expression.V2ExpressionBuilder +import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownV2Filters} import org.apache.spark.sql.types.StructType -class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisScanBuilderBase(config, schema) {} +class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisScanBuilderBase(config, schema) + with SupportsPushDownV2Filters { + + private var pushDownPredicates: Array[Predicate] = Array[Predicate]() + + private val expressionBuilder = new V2ExpressionBuilder(config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT)) + + override def build(): Scan = new DorisScanV2(config, schema, pushDownPredicates) + + override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = { + val (pushed, unsupported) = predicates.partition(predicate => { + Option(expressionBuilder.build(predicate)).isDefined + }) + this.pushDownPredicates = pushed + unsupported + } + + override def pushedPredicates(): Array[Predicate] = pushDownPredicates + +} diff --git a/spark-doris-connector/spark-doris-connector-spark-3.2/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala b/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala similarity index 55% copy from 
spark-doris-connector/spark-doris-connector-spark-3.2/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala copy to spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala index 9e199af..634257a 100644 --- a/spark-doris-connector/spark-doris-connector-spark-3.2/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala +++ b/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala @@ -17,7 +17,16 @@ package org.apache.doris.spark.read -import org.apache.doris.spark.config.DorisConfig +import org.apache.doris.spark.config.{DorisConfig, DorisOptions} +import org.apache.doris.spark.read.expression.V2ExpressionBuilder +import org.apache.spark.internal.Logging +import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.types.StructType -class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisScanBuilderBase(config, schema) {} +class DorisScanV2(config: DorisConfig, schema: StructType, filters: Array[Predicate]) extends AbstractDorisScan(config, schema) with Logging { + override protected def compiledFilters(): Array[String] = { + val inValueLengthLimit = config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT) + val v2ExpressionBuilder = new V2ExpressionBuilder(inValueLengthLimit) + filters.map(e => Option[String](v2ExpressionBuilder.build(e))).filter(_.isDefined).map(_.get) + } +} diff --git a/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilder.scala b/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilder.scala new file mode 100644 index 0000000..f13830c --- /dev/null +++ b/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilder.scala @@ -0,0 +1,79 @@ +// 
Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.spark.read.expression + +import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And, Not, Or} +import org.apache.spark.sql.connector.expressions.{Expression, GeneralScalarExpression, Literal, NamedReference} + +class V2ExpressionBuilder(inValueLengthLimit: Int) { + + def build(predicate: Expression): String = { + predicate match { + case and: And => s"(${build(and.left())} AND ${build(and.right())})" + case or: Or => s"(${build(or.left())} OR ${build(or.right())})" + case not: Not => + not.child().name() match { + case "IS_NULL" => build(new GeneralScalarExpression("IS_NOT_NULL", not.children()(0).children())) + case "=" => build(new GeneralScalarExpression("!=", not.children()(0).children())) + case _ => s"NOT (${build(not.child())})" + } + case _: AlwaysTrue => "1=1" + case _: AlwaysFalse => "1=0" + case expr: Expression => + expr match { + case literal: Literal[_] => literal.toString + case namedRef: NamedReference => namedRef.toString + case e: GeneralScalarExpression => e.name() match { + case "IN" => + val expressions = e.children() + if (expressions.nonEmpty && expressions.length <= inValueLengthLimit) { + 
s"""`${build(expressions(0))}` IN (${expressions.slice(1, expressions.length).map(build).mkString(",")})""" + } else null + case "IS_NULL" => s"`${build(e.children()(0))}` IS NULL" + case "IS_NOT_NULL" => s"`${build(e.children()(0))}` IS NOT NULL" + case "STARTS_WITH" => visitStartWith(build(e.children()(0)), build(e.children()(1))); + case "ENDS_WITH" => visitEndWith(build(e.children()(0)), build(e.children()(1))); + case "CONTAINS" => visitContains(build(e.children()(0)), build(e.children()(1))); + case "=" => s"`${build(e.children()(0))}` = ${build(e.children()(1))}" + case "!=" | "<>" => s"`${build(e.children()(0))}` != ${build(e.children()(1))}" + case "<" => s"`${build(e.children()(0))}` < ${build(e.children()(1))}" + case "<=" => s"`${build(e.children()(0))}` <= ${build(e.children()(1))}" + case ">" => s"`${build(e.children()(0))}` > ${build(e.children()(1))}" + case ">=" => s"`${build(e.children()(0))}` >= ${build(e.children()(1))}" + case _ => null + } + } + } + } + + def visitStartWith(l: String, r: String): String = { + val value = r.substring(1, r.length - 1) + s"`$l` LIKE '$value%'" + } + + def visitEndWith(l: String, r: String): String = { + val value = r.substring(1, r.length - 1) + s"`$l` LIKE '%$value'" + } + + def visitContains(l: String, r: String): String = { + val value = r.substring(1, r.length - 1) + s"`$l` LIKE '%$value%'" + } + +} diff --git a/spark-doris-connector/spark-doris-connector-spark-3.3/src/test/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilderTest.scala b/spark-doris-connector/spark-doris-connector-spark-3.3/src/test/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilderTest.scala new file mode 100644 index 0000000..fc29495 --- /dev/null +++ b/spark-doris-connector/spark-doris-connector-spark-3.3/src/test/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilderTest.scala @@ -0,0 +1,49 @@ +package org.apache.doris.spark.read.expression + +// Licensed to the Apache Software Foundation (ASF) under 
one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import org.apache.spark.sql.sources._ +import org.junit.jupiter.api.{Assertions, Test} + +class V2ExpressionBuilderTest { + + @Test + def buildTest(): Unit = { + + val builder = new V2ExpressionBuilder(10) + Assertions.assertEquals(builder.build(EqualTo("c0", 1).toV2), "`c0` = 1") + Assertions.assertEquals(builder.build(Not(EqualTo("c1", 2)).toV2), "`c1` != 2") + Assertions.assertEquals(builder.build(GreaterThan("c2", 3.4).toV2), "`c2` > 3.4") + Assertions.assertEquals(builder.build(GreaterThanOrEqual("c3", 5.6).toV2), "`c3` >= 5.6") + Assertions.assertEquals(builder.build(LessThan("c4", 7.8).toV2), "`c4` < 7.8") + Assertions.assertEquals(builder.build(LessThanOrEqual("c5", BigDecimal(9.1011)).toV2), "`c5` <= 9.1011") + Assertions.assertEquals(builder.build(StringStartsWith("c6","a").toV2), "`c6` LIKE 'a%'") + Assertions.assertEquals(builder.build(StringEndsWith("c7", "b").toV2), "`c7` LIKE '%b'") + Assertions.assertEquals(builder.build(StringContains("c8", "c").toV2), "`c8` LIKE '%c%'") + Assertions.assertEquals(builder.build(In("c9", Array(12,13,14)).toV2), "`c9` IN (12,13,14)") + Assertions.assertEquals(builder.build(IsNull("c10").toV2), "`c10` IS NULL") + Assertions.assertEquals(builder.build(Not(IsNull("c11")).toV2), 
"`c11` IS NOT NULL") + Assertions.assertEquals(builder.build(And(EqualTo("c12", 15), EqualTo("c13", 16)).toV2), "(`c12` = 15 AND `c13` = 16)") + Assertions.assertEquals(builder.build(Or(EqualTo("c14", 17), EqualTo("c15", 18)).toV2), "(`c14` = 17 OR `c15` = 18)") + Assertions.assertEquals(builder.build(AlwaysTrue.toV2), "1=1") + Assertions.assertEquals(builder.build(AlwaysFalse.toV2), "1=0") + Assertions.assertNull(builder.build(In("c19", Array(19,20,21,22,23,24,25,26,27,28,29)).toV2)) + + } + +} diff --git a/spark-doris-connector/spark-doris-connector-spark-3.4/pom.xml b/spark-doris-connector/spark-doris-connector-spark-3.4/pom.xml index eeee285..84b84a1 100644 --- a/spark-doris-connector/spark-doris-connector-spark-3.4/pom.xml +++ b/spark-doris-connector/spark-doris-connector-spark-3.4/pom.xml @@ -48,6 +48,11 @@ <groupId>org.apache.spark</groupId> <artifactId>spark-sql_${scala.major.version}</artifactId> </dependency> + <dependency> + <groupId>org.junit.jupiter</groupId> + <artifactId>junit-jupiter-api</artifactId> + <scope>test</scope> + </dependency> </dependencies> </project> \ No newline at end of file diff --git a/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala b/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala index 9e199af..cc8ddd2 100644 --- a/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala +++ b/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala @@ -17,7 +17,29 @@ package org.apache.doris.spark.read -import org.apache.doris.spark.config.DorisConfig +import org.apache.doris.spark.config.{DorisConfig, DorisOptions} +import org.apache.doris.spark.read.expression.V2ExpressionBuilder +import org.apache.spark.sql.connector.expressions.filter.Predicate +import 
org.apache.spark.sql.connector.read.{Scan, SupportsPushDownV2Filters} import org.apache.spark.sql.types.StructType -class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisScanBuilderBase(config, schema) {} +class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisScanBuilderBase(config, schema) + with SupportsPushDownV2Filters { + + private var pushDownPredicates: Array[Predicate] = Array[Predicate]() + + private val expressionBuilder = new V2ExpressionBuilder(config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT)) + + override def build(): Scan = new DorisScanV2(config, schema, pushDownPredicates) + + override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = { + val (pushed, unsupported) = predicates.partition(predicate => { + Option(expressionBuilder.build(predicate)).isDefined + }) + this.pushDownPredicates = pushed + unsupported + } + + override def pushedPredicates(): Array[Predicate] = pushDownPredicates + +} diff --git a/spark-doris-connector/spark-doris-connector-spark-3.2/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala b/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala similarity index 55% copy from spark-doris-connector/spark-doris-connector-spark-3.2/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala copy to spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala index 9e199af..634257a 100644 --- a/spark-doris-connector/spark-doris-connector-spark-3.2/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala +++ b/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala @@ -17,7 +17,16 @@ package org.apache.doris.spark.read -import org.apache.doris.spark.config.DorisConfig +import org.apache.doris.spark.config.{DorisConfig, DorisOptions} +import 
org.apache.doris.spark.read.expression.V2ExpressionBuilder +import org.apache.spark.internal.Logging +import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.types.StructType -class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisScanBuilderBase(config, schema) {} +class DorisScanV2(config: DorisConfig, schema: StructType, filters: Array[Predicate]) extends AbstractDorisScan(config, schema) with Logging { + override protected def compiledFilters(): Array[String] = { + val inValueLengthLimit = config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT) + val v2ExpressionBuilder = new V2ExpressionBuilder(inValueLengthLimit) + filters.map(e => Option[String](v2ExpressionBuilder.build(e))).filter(_.isDefined).map(_.get) + } +} diff --git a/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilder.scala b/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilder.scala new file mode 100644 index 0000000..f13830c --- /dev/null +++ b/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/expression/V2ExpressionBuilder.scala @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.spark.read.expression

import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And, Not, Or}
import org.apache.spark.sql.connector.expressions.{Expression, GeneralScalarExpression, Literal, NamedReference}

/**
 * Translates Spark DataSource V2 predicate [[Expression]]s into Doris-compatible
 * filter strings for pushdown.
 *
 * Contract: `build` returns null for any expression that cannot be translated,
 * so callers (DorisScanBuilder, DorisScanV2) treat null as "not pushable".
 *
 * @param inValueLengthLimit maximum number of values allowed in an IN list;
 *                           longer IN predicates are rejected (null)
 */
class V2ExpressionBuilder(inValueLengthLimit: Int) {

  /**
   * Builds the Doris filter string for the given predicate.
   *
   * @return the compiled filter, or null if the predicate — or any of its
   *         children — is not supported for pushdown
   */
  def build(predicate: Expression): String = {
    predicate match {
      case and: And =>
        val left = build(and.left())
        val right = build(and.right())
        // If either side is unsupported the whole conjunction is unsupported.
        // Interpolating directly would otherwise emit the literal text "null".
        if (left == null || right == null) null else s"($left AND $right)"
      case or: Or =>
        val left = build(or.left())
        val right = build(or.right())
        if (left == null || right == null) null else s"($left OR $right)"
      case not: Not =>
        not.child().name() match {
          // NOT(IS_NULL(c)) and NOT(c = v) have direct Doris equivalents.
          case "IS_NULL" => build(new GeneralScalarExpression("IS_NOT_NULL", not.children()(0).children()))
          case "=" => build(new GeneralScalarExpression("!=", not.children()(0).children()))
          case _ =>
            val child = build(not.child())
            if (child == null) null else s"NOT ($child)"
        }
      case _: AlwaysTrue => "1=1"
      case _: AlwaysFalse => "1=0"
      case literal: Literal[_] => literal.toString
      case namedRef: NamedReference => namedRef.toString
      case e: GeneralScalarExpression => e.name() match {
        case "IN" =>
          // children()(0) is the column reference; the rest are the values.
          val expressions = e.children()
          if (expressions.nonEmpty && expressions.length <= inValueLengthLimit) {
            s"""`${build(expressions(0))}` IN (${expressions.slice(1, expressions.length).map(build).mkString(",")})"""
          } else null
        case "IS_NULL" => s"`${build(e.children()(0))}` IS NULL"
        case "IS_NOT_NULL" => s"`${build(e.children()(0))}` IS NOT NULL"
        case "STARTS_WITH" => visitStartWith(build(e.children()(0)), build(e.children()(1)))
        case "ENDS_WITH" => visitEndWith(build(e.children()(0)), build(e.children()(1)))
        case "CONTAINS" => visitContains(build(e.children()(0)), build(e.children()(1)))
        case "=" => s"`${build(e.children()(0))}` = ${build(e.children()(1))}"
        case "!=" | "<>" => s"`${build(e.children()(0))}` != ${build(e.children()(1))}"
        case "<" => s"`${build(e.children()(0))}` < ${build(e.children()(1))}"
        case "<=" => s"`${build(e.children()(0))}` <= ${build(e.children()(1))}"
        case ">" => s"`${build(e.children()(0))}` > ${build(e.children()(1))}"
        case ">=" => s"`${build(e.children()(0))}` >= ${build(e.children()(1))}"
        case _ => null
      }
      // Any other expression type is not pushable; previously this path
      // threw a MatchError instead of returning null.
      case _ => null
    }
  }

  /** STARTS_WITH → LIKE 'value%'. `r` is a quoted string literal; the quotes are stripped. */
  def visitStartWith(l: String, r: String): String = {
    val value = r.substring(1, r.length - 1)
    s"`$l` LIKE '$value%'"
  }

  /** ENDS_WITH → LIKE '%value'. */
  def visitEndWith(l: String, r: String): String = {
    val value = r.substring(1, r.length - 1)
    s"`$l` LIKE '%$value'"
  }

  /** CONTAINS → LIKE '%value%'. */
  def visitContains(l: String, r: String): String = {
    val value = r.substring(1, r.length - 1)
    s"`$l` LIKE '%$value%'"
  }

}
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.spark.read.expression

import org.apache.spark.sql.sources._
import org.junit.jupiter.api.{Assertions, Test}

/** Verifies that V1 filters converted via toV2 translate into the expected Doris filter strings. */
class V2ExpressionBuilderTest {

  @Test
  def buildTest(): Unit = {
    val builder = new V2ExpressionBuilder(10)

    // Comparison operators.
    Assertions.assertEquals("`c0` = 1", builder.build(EqualTo("c0", 1).toV2))
    Assertions.assertEquals("`c1` != 2", builder.build(Not(EqualTo("c1", 2)).toV2))
    Assertions.assertEquals("`c2` > 3.4", builder.build(GreaterThan("c2", 3.4).toV2))
    Assertions.assertEquals("`c3` >= 5.6", builder.build(GreaterThanOrEqual("c3", 5.6).toV2))
    Assertions.assertEquals("`c4` < 7.8", builder.build(LessThan("c4", 7.8).toV2))
    Assertions.assertEquals("`c5` <= 9.1011", builder.build(LessThanOrEqual("c5", BigDecimal(9.1011)).toV2))

    // String pattern predicates map onto LIKE.
    Assertions.assertEquals("`c6` LIKE 'a%'", builder.build(StringStartsWith("c6", "a").toV2))
    Assertions.assertEquals("`c7` LIKE '%b'", builder.build(StringEndsWith("c7", "b").toV2))
    Assertions.assertEquals("`c8` LIKE '%c%'", builder.build(StringContains("c8", "c").toV2))

    // Membership and null checks.
    Assertions.assertEquals("`c9` IN (12,13,14)", builder.build(In("c9", Array(12, 13, 14)).toV2))
    Assertions.assertEquals("`c10` IS NULL", builder.build(IsNull("c10").toV2))
    Assertions.assertEquals("`c11` IS NOT NULL", builder.build(Not(IsNull("c11")).toV2))

    // Boolean combinators and constant predicates.
    Assertions.assertEquals("(`c12` = 15 AND `c13` = 16)", builder.build(And(EqualTo("c12", 15), EqualTo("c13", 16)).toV2))
    Assertions.assertEquals("(`c14` = 17 OR `c15` = 18)", builder.build(Or(EqualTo("c14", 17), EqualTo("c15", 18)).toV2))
    Assertions.assertEquals("1=1", builder.build(AlwaysTrue.toV2))
    Assertions.assertEquals("1=0", builder.build(AlwaysFalse.toV2))

    // An IN list longer than the configured limit (10) must not be pushed down.
    Assertions.assertNull(builder.build(In("c19", Array(19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29)).toV2))
  }

}
package org.apache.doris.spark.read

import org.apache.doris.spark.config.{DorisConfig, DorisOptions}
import org.apache.doris.spark.read.expression.V2ExpressionBuilder
import org.apache.spark.internal.Logging
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownV2Filters}
import org.apache.spark.sql.types.StructType

/**
 * Scan builder (Spark 3.5) supporting DataSource V2 filter pushdown.
 * Predicates that V2ExpressionBuilder can translate are pushed to Doris;
 * the rest are handed back to Spark for post-scan evaluation.
 */
class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisScanBuilderBase(config, schema)
  with SupportsPushDownV2Filters {

  // Predicates accepted for pushdown by the latest pushPredicates() call.
  private var supportedPredicates: Array[Predicate] = Array.empty[Predicate]

  private val exprBuilder = new V2ExpressionBuilder(config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT))

  override def build(): Scan = new DorisScanV2(config, schema, supportedPredicates)

  /**
   * Splits the incoming predicates into pushable and unsupported sets,
   * records the pushable ones, and returns those Spark must still evaluate.
   */
  override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
    // A null result from the expression builder marks a predicate as untranslatable.
    val (supported, rejected) = predicates.partition(p => exprBuilder.build(p) != null)
    supportedPredicates = supported
    rejected
  }

  override def pushedPredicates(): Array[Predicate] = supportedPredicates

}

/**
 * DataSource V2 scan (Spark 3.5) that compiles the pushed-down predicates
 * into Doris filter strings, dropping any that do not translate.
 */
class DorisScanV2(config: DorisConfig, schema: StructType, filters: Array[Predicate]) extends AbstractDorisScan(config, schema) with Logging {

  override protected def compiledFilters(): Array[String] = {
    val builder = new V2ExpressionBuilder(config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT))
    filters.flatMap(predicate => Option(builder.build(predicate)))
  }
}
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.spark.read.expression

import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And, Not, Or}
import org.apache.spark.sql.connector.expressions.{Expression, GeneralScalarExpression, Literal, NamedReference}

/**
 * Translates Spark DataSource V2 predicate [[Expression]]s into Doris-compatible
 * filter strings for pushdown (Spark 3.5 copy).
 *
 * Contract: `build` returns null for any expression that cannot be translated,
 * so callers (DorisScanBuilder, DorisScanV2) treat null as "not pushable".
 *
 * @param inValueLengthLimit maximum number of values allowed in an IN list;
 *                           longer IN predicates are rejected (null)
 */
class V2ExpressionBuilder(inValueLengthLimit: Int) {

  /**
   * Builds the Doris filter string for the given predicate.
   *
   * @return the compiled filter, or null if the predicate — or any of its
   *         children — is not supported for pushdown
   */
  def build(predicate: Expression): String = {
    predicate match {
      case and: And =>
        val left = build(and.left())
        val right = build(and.right())
        // If either side is unsupported the whole conjunction is unsupported.
        // Interpolating directly would otherwise emit the literal text "null".
        if (left == null || right == null) null else s"($left AND $right)"
      case or: Or =>
        val left = build(or.left())
        val right = build(or.right())
        if (left == null || right == null) null else s"($left OR $right)"
      case not: Not =>
        not.child().name() match {
          // NOT(IS_NULL(c)) and NOT(c = v) have direct Doris equivalents.
          case "IS_NULL" => build(new GeneralScalarExpression("IS_NOT_NULL", not.children()(0).children()))
          case "=" => build(new GeneralScalarExpression("!=", not.children()(0).children()))
          case _ =>
            val child = build(not.child())
            if (child == null) null else s"NOT ($child)"
        }
      case _: AlwaysTrue => "1=1"
      case _: AlwaysFalse => "1=0"
      case literal: Literal[_] => literal.toString
      case namedRef: NamedReference => namedRef.toString
      case e: GeneralScalarExpression => e.name() match {
        case "IN" =>
          // children()(0) is the column reference; the rest are the values.
          val expressions = e.children()
          if (expressions.nonEmpty && expressions.length <= inValueLengthLimit) {
            s"""`${build(expressions(0))}` IN (${expressions.slice(1, expressions.length).map(build).mkString(",")})"""
          } else null
        case "IS_NULL" => s"`${build(e.children()(0))}` IS NULL"
        case "IS_NOT_NULL" => s"`${build(e.children()(0))}` IS NOT NULL"
        case "STARTS_WITH" => visitStartWith(build(e.children()(0)), build(e.children()(1)))
        case "ENDS_WITH" => visitEndWith(build(e.children()(0)), build(e.children()(1)))
        case "CONTAINS" => visitContains(build(e.children()(0)), build(e.children()(1)))
        case "=" => s"`${build(e.children()(0))}` = ${build(e.children()(1))}"
        case "!=" | "<>" => s"`${build(e.children()(0))}` != ${build(e.children()(1))}"
        case "<" => s"`${build(e.children()(0))}` < ${build(e.children()(1))}"
        case "<=" => s"`${build(e.children()(0))}` <= ${build(e.children()(1))}"
        case ">" => s"`${build(e.children()(0))}` > ${build(e.children()(1))}"
        case ">=" => s"`${build(e.children()(0))}` >= ${build(e.children()(1))}"
        case _ => null
      }
      // Any other expression type is not pushable; previously this path
      // threw a MatchError instead of returning null.
      case _ => null
    }
  }

  /** STARTS_WITH → LIKE 'value%'. `r` is a quoted string literal; the quotes are stripped. */
  def visitStartWith(l: String, r: String): String = {
    val value = r.substring(1, r.length - 1)
    s"`$l` LIKE '$value%'"
  }

  /** ENDS_WITH → LIKE '%value'. */
  def visitEndWith(l: String, r: String): String = {
    val value = r.substring(1, r.length - 1)
    s"`$l` LIKE '%$value'"
  }

  /** CONTAINS → LIKE '%value%'. */
  def visitContains(l: String, r: String): String = {
    val value = r.substring(1, r.length - 1)
    s"`$l` LIKE '%$value%'"
  }

}
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.spark.read.expression

import org.apache.spark.sql.sources._
import org.junit.jupiter.api.{Assertions, Test}

/** Verifies that V1 filters converted via toV2 translate into the expected Doris filter strings (Spark 3.5 copy). */
class V2ExpressionBuilderTest {

  @Test
  def buildTest(): Unit = {
    val builder = new V2ExpressionBuilder(10)

    // Comparison operators.
    Assertions.assertEquals("`c0` = 1", builder.build(EqualTo("c0", 1).toV2))
    Assertions.assertEquals("`c1` != 2", builder.build(Not(EqualTo("c1", 2)).toV2))
    Assertions.assertEquals("`c2` > 3.4", builder.build(GreaterThan("c2", 3.4).toV2))
    Assertions.assertEquals("`c3` >= 5.6", builder.build(GreaterThanOrEqual("c3", 5.6).toV2))
    Assertions.assertEquals("`c4` < 7.8", builder.build(LessThan("c4", 7.8).toV2))
    Assertions.assertEquals("`c5` <= 9.1011", builder.build(LessThanOrEqual("c5", BigDecimal(9.1011)).toV2))

    // String pattern predicates map onto LIKE.
    Assertions.assertEquals("`c6` LIKE 'a%'", builder.build(StringStartsWith("c6", "a").toV2))
    Assertions.assertEquals("`c7` LIKE '%b'", builder.build(StringEndsWith("c7", "b").toV2))
    Assertions.assertEquals("`c8` LIKE '%c%'", builder.build(StringContains("c8", "c").toV2))

    // Membership and null checks.
    Assertions.assertEquals("`c9` IN (12,13,14)", builder.build(In("c9", Array(12, 13, 14)).toV2))
    Assertions.assertEquals("`c10` IS NULL", builder.build(IsNull("c10").toV2))
    Assertions.assertEquals("`c11` IS NOT NULL", builder.build(Not(IsNull("c11")).toV2))

    // Boolean combinators and constant predicates.
    Assertions.assertEquals("(`c12` = 15 AND `c13` = 16)", builder.build(And(EqualTo("c12", 15), EqualTo("c13", 16)).toV2))
    Assertions.assertEquals("(`c14` = 17 OR `c15` = 18)", builder.build(Or(EqualTo("c14", 17), EqualTo("c15", 18)).toV2))
    Assertions.assertEquals("1=1", builder.build(AlwaysTrue.toV2))
    Assertions.assertEquals("1=0", builder.build(AlwaysFalse.toV2))

    // An IN list longer than the configured limit (10) must not be pushed down.
    Assertions.assertNull(builder.build(In("c19", Array(19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29)).toV2))
  }

}