One other caveat: while writing up this example I realized that we made
SparkPlan private, and we are already packaging 1.3-RC3. So you'll need a
custom build of Spark for this to run. We'll fix this in the next release.

On Thu, Mar 5, 2015 at 5:26 PM, Michael Armbrust <mich...@databricks.com>
wrote:

>> Currently we have implemented the External Data Source API and are able to
>> push filters and projections.
>>
>> Could you provide some info on how joins could be pushed down to the
>> original data source if both data sources are from the same database?
>>
>
> First a disclaimer: This is an experimental API that exposes internals
> that are likely to change between Spark releases.  As a result, most
> data sources should be written against the stable public API in
> org.apache.spark.sql.sources
> <https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala>.
> We expose this mostly to get feedback on what optimizations we should add
> to the stable API in order to get the best performance out of data sources.
>
> We'll start with a simple artificial data source that just returns ranges
> of consecutive integers.
>
> /** A data source that returns ranges of consecutive integers in a column named `a`. */
> case class SimpleRelation(
>     start: Int,
>     end: Int)(
>     @transient val sqlContext: SQLContext)
>   extends BaseRelation with TableScan {
>
>   val schema = StructType('a.int :: Nil)
>   def buildScan() = sqlContext.sparkContext.parallelize(start to end).map(Row(_))
> }
>
>
> Given this we can create tables:
>
> sqlContext.baseRelationToDataFrame(SimpleRelation(1, 1)(sqlContext)).registerTempTable("smallTable")
> sqlContext.baseRelationToDataFrame(SimpleRelation(1, 10000000)(sqlContext)).registerTempTable("bigTable")
>
>
> However, doing a join is pretty slow since we need to shuffle the big
> table around for no reason:
>
> sql("SELECT * FROM smallTable s JOIN bigTable b ON s.a = b.a").collect()
> res3: Array[org.apache.spark.sql.Row] = Array([1,1])
>
>
> This takes about 10 seconds on my cluster.  Clearly we can do better.  So
> let's define special physical operators for the case when we are inner
> joining two of these relations using equality. One will handle the case
> when there is no overlap and the other when there is.  Physical operators
> must extend SparkPlan and must return an RDD[Row] containing the answer
> when execute() is called.
>
> import org.apache.spark.sql.catalyst.expressions.{Attribute, EqualTo}
> import org.apache.spark.sql.catalyst.plans._
> import org.apache.spark.sql.catalyst.plans.logical._
> import org.apache.spark.sql.execution.SparkPlan
>
> /** A join that just returns the pre-calculated overlap of two ranges of consecutive integers. */
> case class OverlappingRangeJoin(leftOutput: Attribute, rightOutput: Attribute, start: Int, end: Int)
>   extends SparkPlan {
>   def output: Seq[Attribute] = leftOutput :: rightOutput :: Nil
>
>   def execute(): org.apache.spark.rdd.RDD[Row] = {
>     sqlContext.sparkContext.parallelize(start to end).map(i => Row(i, i))
>   }
>
>   def children: Seq[SparkPlan] = Nil
> }
>
> /** Used when a join is known to produce no results. */
> case class EmptyJoin(output: Seq[Attribute]) extends SparkPlan {
>   def execute(): org.apache.spark.rdd.RDD[Row] = {
>     sqlContext.sparkContext.emptyRDD
>   }
>
>   def children: Seq[SparkPlan] = Nil
> }
>
> /** Finds cases where two sets of consecutive integer ranges are inner joined on equality. */
> object SmartSimpleJoin extends Strategy with Serializable {
>   def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
>     // Find inner joins between two SimpleRelations where the condition is equality.
>     case Join(l @ LogicalRelation(left: SimpleRelation), r @ LogicalRelation(right: SimpleRelation), Inner, Some(EqualTo(a, b))) =>
>       // Check if the join condition is comparing `a` from each relation.
>       if ((a == l.output.head && b == r.output.head) || (a == r.output.head && b == l.output.head)) {
>         if ((left.start <= right.end) && (left.end >= right.start)) {
>           OverlappingRangeJoin(
>             l.output.head,
>             r.output.head,
>             math.max(left.start, right.start),
>             math.min(left.end, right.end)) :: Nil
>         } else {
>           // Ranges don't overlap, join will be empty
>           EmptyJoin(l.output.head :: r.output.head :: Nil) :: Nil
>         }
>       } else {
>         // Join isn't between the columns output...
>         // Let's just let the query planner handle this.
>         Nil
>       }
>     case _ => Nil // Return an empty list if we don't know how to handle this plan.
>   }
> }
>
>
> We can then add these strategies to the query planner through the
> experimental hook.  Added strategies take precedence over built-in ones.
>
> // Add the strategy to the query planner.
> sqlContext.experimental.extraStrategies = SmartSimpleJoin :: Nil
>
>
> sql("SELECT * FROM smallTable s JOIN bigTable b ON s.a = b.a").collect()
> res4: Array[org.apache.spark.sql.Row] = Array([1,1])
>
>
> Now our join returns in < 1 second.  For more advanced matching of joins
> and their conditions you should look at the patterns that are available
> <https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala>,
> and the built-in join strategies
> <https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala>.
> Let me know if you have any questions.
>
> Michael
>
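
For anyone who wants to check the range arithmetic that SmartSimpleJoin
relies on without a custom Spark build, here is a minimal plain-Scala
sketch. It assumes inclusive integer ranges, as in SimpleRelation above.
IntRange, RangeOverlap, overlap and joinRows are hypothetical names
introduced here for illustration; they are not part of Spark and the
sketch does not touch the planner at all.

```scala
// Hypothetical illustration only: none of these names exist in Spark.
// A range of consecutive integers, inclusive on both ends.
final case class IntRange(start: Int, end: Int)

object RangeOverlap {
  /** Returns the inclusive overlap of two ranges, or None if they are disjoint. */
  def overlap(left: IntRange, right: IntRange): Option[IntRange] =
    if (left.start <= right.end && left.end >= right.start)
      Some(IntRange(math.max(left.start, right.start), math.min(left.end, right.end)))
    else
      None

  /** The rows an inner equi-join on the single column would produce. */
  def joinRows(left: IntRange, right: IntRange): List[(Int, Int)] =
    overlap(left, right).toList.flatMap(r => (r.start to r.end).map(i => (i, i)))
}
```

The Some case corresponds to what OverlappingRangeJoin computes and the
None case to EmptyJoin. Neither needs to look at the rows of either
input, which is why the shuffle can be skipped.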
