One other caveat: While writing up this example I realized that we made SparkPlan private, and we are already packaging 1.3-RC3... So you'll need a custom build of Spark for this to run. We'll fix this in the next release.
On Thu, Mar 5, 2015 at 5:26 PM, Michael Armbrust <mich...@databricks.com> wrote:
>> Currently we have implemented External Data Source API and are able to
>> push filters and projections.
>>
>> Could you provide some info on how perhaps the joins could be pushed to
>> the original Data Source if both the data sources are from same database.
>
> First a disclaimer: This is an experimental API that exposes internals
> that are likely to change in between different Spark releases. As a
> result, most datasources should be written against the stable public API in
> org.apache.spark.sql.sources
> <https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala>.
> We expose this mostly to get feedback on what optimizations we should add
> to the stable API in order to get the best performance out of data sources.
>
> We'll start with a simple artificial data source that just returns ranges
> of consecutive integers.
>
> /** A data source that returns ranges of consecutive integers in a column
>     named `a`. */
> case class SimpleRelation(
>     start: Int,
>     end: Int)(
>     @transient val sqlContext: SQLContext)
>   extends BaseRelation with TableScan {
>
>   val schema = StructType('a.int :: Nil)
>
>   def buildScan() =
>     sqlContext.sparkContext.parallelize(start to end).map(Row(_))
> }
>
> Given this we can create tables:
>
> sqlContext.baseRelationToDataFrame(SimpleRelation(1, 1)(sqlContext)).registerTempTable("smallTable")
> sqlContext.baseRelationToDataFrame(SimpleRelation(1, 10000000)(sqlContext)).registerTempTable("bigTable")
>
> However, doing a join is pretty slow since we need to shuffle the big
> table around for no reason:
>
> sql("SELECT * FROM smallTable s JOIN bigTable b ON s.a = b.a").collect()
> res3: Array[org.apache.spark.sql.Row] = Array([1,1])
>
> This takes about 10 seconds on my cluster. Clearly we can do better. So
> let's define special physical operators for the case when we are inner
> joining two of these relations using equality. One will handle the case
> when there is no overlap and the other when there is. Physical operators
> must extend SparkPlan and must return an RDD[Row] containing the answer
> when execute() is called.
>
> import org.apache.spark.sql.catalyst.expressions.{Attribute, EqualTo}
> import org.apache.spark.sql.catalyst.plans._
> import org.apache.spark.sql.catalyst.plans.logical._
> import org.apache.spark.sql.execution.SparkPlan
>
> /** A join that just returns the pre-calculated overlap of two ranges of
>     consecutive integers. */
> case class OverlappingRangeJoin(
>     leftOutput: Attribute,
>     rightOutput: Attribute,
>     start: Int,
>     end: Int) extends SparkPlan {
>
>   def output: Seq[Attribute] = leftOutput :: rightOutput :: Nil
>
>   def execute(): org.apache.spark.rdd.RDD[Row] = {
>     sqlContext.sparkContext.parallelize(start to end).map(i => Row(i, i))
>   }
>
>   def children: Seq[SparkPlan] = Nil
> }
>
> /** Used when a join is known to produce no results. */
> case class EmptyJoin(output: Seq[Attribute]) extends SparkPlan {
>   def execute(): org.apache.spark.rdd.RDD[Row] = {
>     sqlContext.sparkContext.emptyRDD
>   }
>
>   def children: Seq[SparkPlan] = Nil
> }
>
> /** Finds cases where two sets of consecutive integer ranges are inner
>     joined on equality. */
> object SmartSimpleJoin extends Strategy with Serializable {
>   def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
>     // Find inner joins between two SimpleRelations where the condition is equality.
>     case Join(l @ LogicalRelation(left: SimpleRelation), r @ LogicalRelation(right: SimpleRelation), Inner, Some(EqualTo(a, b))) =>
>       // Check if the join condition is comparing `a` from each relation.
>       if (a == l.output.head && b == r.output.head || a == r.output.head && b == l.output.head) {
>         if ((left.start <= right.end) && (left.end >= right.start)) {
>           OverlappingRangeJoin(
>             l.output.head,
>             r.output.head,
>             math.max(left.start, right.start),
>             math.min(left.end, right.end)) :: Nil
>         } else {
>           // Ranges don't overlap, so the join will be empty.
>           EmptyJoin(l.output.head :: r.output.head :: Nil) :: Nil
>         }
>       } else {
>         // The join isn't between the columns the relations output...
>         // Let's just let the query planner handle this.
>         Nil
>       }
>     case _ => Nil // Return an empty list if we don't know how to handle this plan.
>   }
> }
>
> We can then add these strategies to the query planner through the
> experimental hook. Added strategies take precedence over built-in ones.
>
> // Add the strategy to the query planner.
> sqlContext.experimental.extraStrategies = SmartSimpleJoin :: Nil
>
> sql("SELECT * FROM smallTable s JOIN bigTable b ON s.a = b.a").collect()
> res4: Array[org.apache.spark.sql.Row] = Array([1,1])
>
> Now our join returns in < 1 second. For more advanced matching of joins
> and their conditions you should look at the patterns that are available
> <https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala>,
> and the built-in join strategies
> <https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala>.
> Let me know if you have any questions.
>
> Michael
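P.S. A quick sketch of how you might sanity-check that the strategy is actually being used (this assumes the smallTable/bigTable registrations and the extraStrategies assignment from the example above; it only relies on DataFrame.explain() and the experimental hook already shown):

// Print the physical plan; it should show OverlappingRangeJoin rather
// than a shuffle-based join when the strategy is registered.
sql("SELECT * FROM smallTable s JOIN bigTable b ON s.a = b.a").explain()

// Clearing extraStrategies falls back to the built-in join planning.
sqlContext.experimental.extraStrategies = Nil
sql("SELECT * FROM smallTable s JOIN bigTable b ON s.a = b.a").explain()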