Skip to content

Commit

Permalink
Don't swap the left and right relations of spatial joins; simply invert the spatial predicate instead.
Browse files Browse the repository at this point in the history

Allow the distance expression in a distance join to reference attributes from the right-side relation when running a broadcast join.
  • Loading branch information
Kontinuation committed Mar 12, 2023
1 parent 6a5651f commit 892039a
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,31 +42,39 @@ import org.locationtech.jts.geom.Geometry
* @param right right side of the join
* @param leftShape expression for the first argument of spatialPredicate
* @param rightShape expression for the second argument of spatialPredicate
* @param swappedLeftAndRight boolean indicating whether left and right plans were swapped
* @param distance - ST_Distance(left, right) <= distance. Distance can be literal or a computation over 'left'.
* @param distance - ST_Distance(left, right) <= distance. Distance can be literal or a computation over 'left' or 'right'.
* @param distanceBoundToLeft true if the distance expression is bound to (evaluated against) the output of the left relation, false if it is bound to the output of the right relation
* @param spatialPredicate spatial predicate as join condition
* @param extraCondition extra join condition other than spatialPredicate
*/
case class DistanceJoinExec(left: SparkPlan,
right: SparkPlan,
leftShape: Expression,
rightShape: Expression,
swappedLeftAndRight: Boolean,
distance: Expression,
distanceBoundToLeft: Boolean,
spatialPredicate: SpatialPredicate,
extraCondition: Option[Expression] = None)
extends SedonaBinaryExecNode
with TraitJoinQueryExec
with Logging {

private val boundRadius = BindReferences.bindReference(distance, left.output)
private val boundRadius = if (distanceBoundToLeft) {
BindReferences.bindReference(distance, left.output)
} else {
BindReferences.bindReference(distance, right.output)
}

override def toSpatialRddPair(
buildRdd: RDD[UnsafeRow],
buildExpr: Expression,
streamedRdd: RDD[UnsafeRow],
streamedExpr: Expression): (SpatialRDD[Geometry], SpatialRDD[Geometry]) =
(toExpandedEnvelopeRDD(buildRdd, buildExpr, boundRadius), toSpatialRDD(streamedRdd, streamedExpr))
override def toSpatialRddPair(leftRdd: RDD[UnsafeRow],
leftShapeExpr: Expression,
rightRdd: RDD[UnsafeRow],
rightShapeExpr: Expression): (SpatialRDD[Geometry], SpatialRDD[Geometry]) = {
if (distanceBoundToLeft) {
(toExpandedEnvelopeRDD(leftRdd, leftShapeExpr, boundRadius), toSpatialRDD(rightRdd, rightShapeExpr))
} else {
(toSpatialRDD(leftRdd, leftShapeExpr), toExpandedEnvelopeRDD(rightRdd, rightShapeExpr, boundRadius))
}
}

protected def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = {
copy(left = newLeft, right = newRight)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,13 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
val relationship = s"ST_$spatialPredicate"

matchExpressionsToPlans(a, b, left, right) match {
case Some((planA, planB, swappedLeftAndRight)) =>
case Some((_, _, false)) =>
logInfo(s"Planning spatial join for $relationship relationship")
RangeJoinExec(planLater(planA), planLater(planB), a, b, swappedLeftAndRight, spatialPredicate, extraCondition) :: Nil
RangeJoinExec(planLater(left), planLater(right), a, b, spatialPredicate, extraCondition) :: Nil
case Some((_, _, true)) =>
logInfo(s"Planning spatial join for $relationship relationship with swapped left and right shapes")
val invSpatialPredicate = SpatialPredicate.inverse(spatialPredicate)
RangeJoinExec(planLater(left), planLater(right), b, a, invSpatialPredicate, extraCondition) :: Nil
case None =>
logInfo(
s"Spatial join for $relationship with arguments not aligned " +
Expand All @@ -231,18 +235,22 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
val b = children.tail.head

matchExpressionsToPlans(a, b, left, right) match {
case Some((planA, planB, swappedLeftAndRight)) =>
if (distance.references.isEmpty || matches(distance, planA)) {
logInfo("Planning spatial distance join")
DistanceJoinExec(planLater(planA), planLater(planB), a, b, swappedLeftAndRight, distance, spatialPredicate, extraCondition) :: Nil
} else if (matches(distance, planB)) {
logInfo("Planning spatial distance join")
DistanceJoinExec(planLater(planB), planLater(planA), b, a, swappedLeftAndRight, distance, spatialPredicate, extraCondition) :: Nil
} else {
logInfo(
"Spatial distance join for ST_Distance with non-scalar distance " +
"that is not a computation over just one side of the join is not supported")
Nil
case Some((_, _, swappedLeftAndRight)) =>
val (leftShape, rightShape) = if (swappedLeftAndRight) (b, a) else (a, b)
matchDistanceExpressionToJoinSide(distance, left, right) match {
case Some(LeftSide) =>
logInfo("Planning spatial distance join, distance bound to left relation")
DistanceJoinExec(planLater(left), planLater(right), leftShape, rightShape, distance, distanceBoundToLeft = true,
spatialPredicate, extraCondition) :: Nil
case Some(RightSide) =>
logInfo("Planning spatial distance join, distance bound to right relation")
DistanceJoinExec(planLater(left), planLater(right), leftShape, rightShape, distance, distanceBoundToLeft = false,
spatialPredicate, extraCondition) :: Nil
case _ =>
logInfo(
"Spatial distance join for ST_Distance with non-scalar distance " +
"that is not a computation over just one side of the join is not supported")
Nil
}
case None =>
logInfo(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,13 @@ import org.apache.spark.sql.sedona_sql.execution.SedonaBinaryExecNode
* @param right right side of the join
* @param leftShape expression for the first argument of spatialPredicate
* @param rightShape expression for the second argument of spatialPredicate
* @param swappedLeftAndRight boolean indicating whether left and right plans were swapped
* @param spatialPredicate spatial predicate as join condition
* @param extraCondition extra join condition other than spatialPredicate
*/
case class RangeJoinExec(left: SparkPlan,
right: SparkPlan,
leftShape: Expression,
rightShape: Expression,
swappedLeftAndRight: Boolean,
spatialPredicate: SpatialPredicate,
extraCondition: Option[Expression] = None)
extends SedonaBinaryExecNode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,27 @@ import org.apache.sedona.core.utils.SedonaConf
import org.apache.sedona.sql.utils.GeometrySerializer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeRow}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.execution.SparkPlan
import org.locationtech.jts.geom.Geometry

trait TraitJoinQueryBase {
self: SparkPlan =>

def toSpatialRddPair(buildRdd: RDD[UnsafeRow],
buildExpr: Expression,
streamedRdd: RDD[UnsafeRow],
streamedExpr: Expression): (SpatialRDD[Geometry], SpatialRDD[Geometry]) =
(toSpatialRDD(buildRdd, buildExpr), toSpatialRDD(streamedRdd, streamedExpr))
def toSpatialRddPair(leftRdd: RDD[UnsafeRow],
leftShapeExpr: Expression,
rightRdd: RDD[UnsafeRow],
rightShapeExpr: Expression): (SpatialRDD[Geometry], SpatialRDD[Geometry]) =
(toSpatialRDD(leftRdd, leftShapeExpr), toSpatialRDD(rightRdd, rightShapeExpr))

def toSpatialRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression): SpatialRDD[Geometry] = {
val spatialRdd = new SpatialRDD[Geometry]
spatialRdd.setRawSpatialRDD(
rdd
.map { x => {
.map { x =>
val shape = GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[Array[Byte]])
//logInfo(shape.toString)
shape.setUserData(x.copy)
shape
}
}
.toJavaRDD())
spatialRdd
}
Expand All @@ -55,7 +52,7 @@ trait TraitJoinQueryBase {
val spatialRdd = new SpatialRDD[Geometry]
spatialRdd.setRawSpatialRDD(
rdd
.map { x => {
.map { x =>
val shape = GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[Array[Byte]])
val envelope = shape.getEnvelopeInternal.copy()
envelope.expandBy(boundRadius.eval(x).asInstanceOf[Double])
Expand All @@ -64,7 +61,6 @@ trait TraitJoinQueryBase {
expandedEnvelope.setUserData(x.copy)
expandedEnvelope
}
}
.toJavaRDD())
spatialRdd
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,10 @@ trait TraitJoinQueryExec extends TraitJoinQueryBase {
val right: SparkPlan
val leftShape: Expression
val rightShape: Expression
val swappedLeftAndRight: Boolean
val spatialPredicate: SpatialPredicate
val extraCondition: Option[Expression]

override def output: Seq[Attribute] = {
if (!swappedLeftAndRight) left.output ++ right.output else right.output ++ left.output
}
override def output: Seq[Attribute] = left.output ++ right.output

override protected def doExecute(): RDD[InternalRow] = {
val boundLeftShape = BindReferences.bindReference(leftShape, left.output)
Expand Down Expand Up @@ -125,12 +122,9 @@ trait TraitJoinQueryExec extends TraitJoinQueryBase {
logDebug(s"Join result has ${matchesRDD.count()} rows")

matchesRDD.mapPartitions { iter =>
val joinRow = if (!swappedLeftAndRight) {
val joinRow = {
val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
(l: UnsafeRow, r: UnsafeRow) => joiner.join(l, r)
} else {
val joiner = GenerateUnsafeRowJoiner.create(right.schema, left.schema)
(l: UnsafeRow, r: UnsafeRow) => joiner.join(r, l)
}

val joined = iter.map { case (l, r) =>
Expand Down

0 comments on commit 892039a

Please sign in to comment.