[SEDONA-221] Outer join throws NPE for null geometries. (#749)
umartin authored Jan 20, 2023
1 parent 43e1d79 commit 9229439
Showing 2 changed files with 89 additions and 30 deletions.
org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.sedona_sql.strategy.join

import org.apache.sedona.core.spatialOperator.{SpatialPredicate, SpatialPredicateEvaluators}
import org.apache.sedona.core.spatialOperator.SpatialPredicateEvaluators.SpatialPredicateEvaluator
+import org.apache.sedona.sql.utils.GeometrySerializer

import scala.collection.JavaConverters._
import org.apache.spark.broadcast.Broadcast
@@ -29,12 +30,14 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, GenericInternalRow, JoinedRow, Predicate, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
import org.apache.spark.sql.sedona_sql.execution.SedonaBinaryExecNode
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.geom.prep.{PreparedGeometry, PreparedGeometryFactory}
import org.locationtech.jts.index.SpatialIndex

+import java.util.Collections
import scala.collection.mutable

case class BroadcastIndexJoinExec(
@@ -68,6 +71,10 @@ case class BroadcastIndexJoinExec(
}
}

+override lazy val metrics = Map(
+"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+
+
private val (streamed, broadcast) = indexBuildSide match {
case LeftSide => (right, left.asInstanceOf[SpatialIndexExec])
case RightSide => (left, right.asInstanceOf[SpatialIndexExec])
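
Note: the new metrics override follows the stock Spark SQLMetrics pattern. A minimal sketch of how such a metric is declared and consumed, using an assumed generic operator (the real wiring is in doExecute further down in this diff):

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.Attribute
    import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
    import org.apache.spark.sql.execution.metric.SQLMetrics

    // Sketch only (operator name is illustrative): metrics maps a key to an
    // SQLMetric (an accumulator); longMetric(key) fetches it in doExecute, and
    // incrementing it once per emitted row is what surfaces
    // "number of output rows" in the Spark SQL UI.
    case class CountingExec(child: SparkPlan) extends UnaryExecNode {
      override def output: Seq[Attribute] = child.output
      override lazy val metrics = Map(
        "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
      protected override def doExecute(): RDD[InternalRow] = {
        val numOutputRows = longMetric("numOutputRows")
        child.execute().map { row =>
          numOutputRows += 1 // counted on executors, merged on the driver
          row
        }
      }
      override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
        copy(child = newChild)
    }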
@@ -115,34 +122,34 @@ case class BroadcastIndexJoinExec(
SpatialPredicateEvaluators.create(SpatialPredicate.inverse(spatialPredicate))
}

-private def innerJoin(streamIter: Iterator[Geometry], index: Broadcast[SpatialIndex]): Iterator[InternalRow] = {
+private def innerJoin(streamIter: Iterator[(Geometry, UnsafeRow)], index: Broadcast[SpatialIndex]): Iterator[InternalRow] = {
val factory = new PreparedGeometryFactory()
val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry]
val joinedRow = new JoinedRow
-streamIter.flatMap { srow =>
-joinedRow.withLeft(srow.getUserData.asInstanceOf[UnsafeRow])
-index.value.query(srow.getEnvelopeInternal)
+streamIter.flatMap { case (geom, row) =>
+joinedRow.withLeft(row)
+index.value.query(geom.getEnvelopeInternal)
.iterator.asScala.asInstanceOf[Iterator[Geometry]]
-.filter(candidate => evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, { factory.create(candidate) }), srow))
+.filter(candidate => evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, { factory.create(candidate) }), geom))
.map(candidate => joinedRow.withRight(candidate.getUserData.asInstanceOf[UnsafeRow]))
.filter(boundCondition)
}
}

private def semiJoin(
-streamIter: Iterator[Geometry], index: Broadcast[SpatialIndex]
+streamIter: Iterator[(Geometry, UnsafeRow)], index: Broadcast[SpatialIndex]
): Iterator[InternalRow] = {
val factory = new PreparedGeometryFactory()
val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry]
val joinedRow = new JoinedRow
-streamIter.flatMap { srow =>
-val left = srow.getUserData.asInstanceOf[UnsafeRow]
+streamIter.flatMap { case (geom, row) =>
+val left = row
joinedRow.withLeft(left)
-val anyMatches = index.value.query(srow.getEnvelopeInternal)
+val anyMatches = index.value.query(geom.getEnvelopeInternal)
.iterator.asScala.asInstanceOf[Iterator[Geometry]]
.filter(candidate => evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, {
factory.create(candidate)
-}), srow))
+}), geom))
.map(candidate => joinedRow.withRight(candidate.getUserData.asInstanceOf[UnsafeRow]))
.exists(boundCondition)
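
Note: the recurring change in these methods swaps Iterator[Geometry], with the row smuggled through Geometry.getUserData, for Iterator[(Geometry, UnsafeRow)]. Calling getEnvelopeInternal (or getUserData) on a null geometry is what raised the NPE. A hedged sketch of the new candidate lookup; candidatesFor is an illustrative helper, not a method in the patch:

    import org.locationtech.jts.geom.Geometry
    import org.locationtech.jts.index.SpatialIndex
    import scala.collection.JavaConverters._

    // The tuple keeps the UnsafeRow even when its geometry is null, so the
    // index is only probed for non-null geometries and nothing dereferences null.
    def candidatesFor(geom: Geometry, index: SpatialIndex): Iterator[Geometry] =
      if (geom == null) Iterator.empty
      else index.query(geom.getEnvelopeInternal)
        .iterator.asScala.asInstanceOf[Iterator[Geometry]]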

@@ -155,19 +162,19 @@ case class BroadcastIndexJoinExec(
}

private def antiJoin(
-streamIter: Iterator[Geometry], index: Broadcast[SpatialIndex]
+streamIter: Iterator[(Geometry, UnsafeRow)], index: Broadcast[SpatialIndex]
): Iterator[InternalRow] = {
val factory = new PreparedGeometryFactory()
val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry]
val joinedRow = new JoinedRow
-streamIter.flatMap { srow =>
-val left = srow.getUserData.asInstanceOf[UnsafeRow]
-joinedRow.withLeft(left)
-val anyMatches = index.value.query(srow.getEnvelopeInternal)
+streamIter.flatMap { case (geom, row) =>
+val left = row
+joinedRow.withLeft(row)
+val anyMatches = (if (geom == null) Collections.EMPTY_LIST else index.value.query(geom.getEnvelopeInternal))
.iterator.asScala.asInstanceOf[Iterator[Geometry]]
.filter(candidate => evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, {
factory.create(candidate)
-}), srow))
+}), geom))
.map(candidate => joinedRow.withRight(candidate.getUserData.asInstanceOf[UnsafeRow]))
.exists(boundCondition)
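
Note: with the Collections.EMPTY_LIST guard, a null stream geometry simply yields an empty candidate iterator, so anyMatches is false. A left anti join therefore emits the row, since it matches nothing, and the outer join below pairs it with nullRow instead of throwing.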

@@ -180,20 +187,20 @@ case class BroadcastIndexJoinExec(
}

private def outerJoin(
-streamIter: Iterator[Geometry], index: Broadcast[SpatialIndex]
+streamIter: Iterator[(Geometry, UnsafeRow)], index: Broadcast[SpatialIndex]
): Iterator[InternalRow] = {
val factory = new PreparedGeometryFactory()
val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry]
val joinedRow = new JoinedRow
val nullRow = new GenericInternalRow(broadcast.output.length)

-streamIter.flatMap { srow =>
-joinedRow.withLeft(srow.getUserData.asInstanceOf[UnsafeRow])
-val candidates = index.value.query(srow.getEnvelopeInternal)
+streamIter.flatMap { case (geom, row) =>
+joinedRow.withLeft(row)
+val candidates = (if (geom == null) Collections.EMPTY_LIST else index.value.query(geom.getEnvelopeInternal))
.iterator.asScala.asInstanceOf[Iterator[Geometry]]
.filter(candidate => evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, {
factory.create(candidate)
-}), srow))
+}), geom))

new RowIterator {
private var found = false
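
Note: nullRow, a GenericInternalRow sized to the broadcast side's output, supplies the all-null right-hand columns for unmatched stream rows, including rows whose geometry is null. A hedged end-to-end illustration of the query shape this commit fixes (column names are illustrative and mirror the new test below; assumes the usual org.apache.spark.sql.functions imports):

    // Before this commit the null left_geom row made the broadcast index join
    // throw a NullPointerException; now it comes back joined against nulls.
    val result = left.join(
      broadcast(right),
      expr("ST_Intersects(left_geom, right_geom)"),
      "left")
    result.count() // 4: three matching points plus the null-geometry row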
@@ -218,20 +225,15 @@ case class BroadcastIndexJoinExec(
}

override protected def doExecute(): RDD[InternalRow] = {
+val numOutputRows = longMetric("numOutputRows")
val boundStreamShape = BindReferences.bindReference(streamShape, streamed.output)
val streamResultsRaw = streamed.execute().asInstanceOf[RDD[UnsafeRow]]

val broadcastIndex = broadcast.executeBroadcast[SpatialIndex]()

-// If there's a distance and the objects are being broadcast, we need to build the expanded envelope on the window stream side
-val streamShapes = distance match {
-case Some(distanceExpression) if indexBuildSide != windowJoinSide =>
-toExpandedEnvelopeRDD(streamResultsRaw, boundStreamShape, BindReferences.bindReference(distanceExpression, streamed.output))
-case _ =>
-toSpatialRDD(streamResultsRaw, boundStreamShape)
-}
+val streamShapes = createStreamShapes(streamResultsRaw, boundStreamShape)

-streamShapes.getRawSpatialRDD.rdd.mapPartitions { streamedIter =>
+streamShapes.mapPartitions { streamedIter =>
val joinedIter = joinType match {
case _: InnerLike =>
innerJoin(streamedIter, broadcastIndex)
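
Note: createStreamShapes returns a plain RDD of (Geometry, UnsafeRow) pairs rather than a SpatialRDD, so mapPartitions is now called on it directly, and rows whose geometry evaluates to null flow through to the join methods instead of being lost (or blowing up) in the SpatialRDD conversion.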
@@ -248,11 +250,40 @@ case class BroadcastIndexJoinExec(

val resultProj = createResultProjection()
joinedIter.map { r =>
+numOutputRows += 1
resultProj(r)
}
}
}

+private def createStreamShapes(streamResultsRaw: RDD[UnsafeRow], boundStreamShape: Expression) = {
+// If there's a distance and the objects are being broadcast, we need to build the expanded envelope on the window stream side
+distance match {
+case Some(distanceExpression) if indexBuildSide != windowJoinSide =>
+streamResultsRaw.map(row => {
+val geom = boundStreamShape.eval(row).asInstanceOf[Array[Byte]]
+if (geom == null) {
+(null, row)
+} else {
+val geometry = GeometrySerializer.deserialize(geom)
+val radius = BindReferences.bindReference(distanceExpression, streamed.output).eval(row).asInstanceOf[Double]
+val envelope = geometry.getEnvelopeInternal
+envelope.expandBy(radius)
+(geometry.getFactory.toGeometry(envelope), row)
+}
+})
+case _ =>
+streamResultsRaw.map(row => {
+val geom = boundStreamShape.eval(row).asInstanceOf[Array[Byte]]
+if (geom == null) {
+(null, row)
+} else {
+(GeometrySerializer.deserialize(geom), row)
+}
+})
+}
+}

protected def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = {
copy(left = newLeft, right = newRight)
}
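
Note: the distance branch of createStreamShapes inlines what toExpandedEnvelopeRDD used to do: pad the stream geometry's bounding box by the per-row radius so the index probe covers everything within that distance. A hedged standalone sketch of the expansion (JTS API; the function name is an assumption):

    import org.locationtech.jts.geom.Geometry

    // getEnvelopeInternal returns a copy of the geometry's envelope, so it is
    // safe to mutate; expandBy pads it by `radius` in both x and y, and
    // toGeometry turns the padded box back into a rectangle for the index query.
    def expandedQueryShape(geometry: Geometry, radius: Double): Geometry = {
      val envelope = geometry.getEnvelopeInternal
      envelope.expandBy(radius)
      geometry.getFactory.toGeometry(envelope)
    }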
BroadcastIndexJoinSuite.scala
@@ -1408,7 +1408,7 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
}
}

describe("Sedona SQL automatic broadcast") {
describe("Sedona-SQL Automatic broadcast") {
it("Datasets smaller than threshold should be broadcasted") {
val polygonDf = buildPolygonDf.repartition(3).alias("polygon")
val pointDf = buildPointDf.repartition(5).alias("point")
@@ -1430,4 +1430,32 @@
assert(df.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 0)
}
}

describe("Sedona-SQL Broadcast join with null geometries") {
it("Left outer join with nulls on left side") {
import sparkSession.implicits._
val left = Seq(("1", "POINT(1 1)"), ("2", "POINT(1 1)"), ("3", "POINT(1 1)"), ("4", null))
.toDF("seq", "left_geom")
.withColumn("left_geom", expr("ST_GeomFromText(left_geom)"))
val right = Seq("POLYGON((2 0, 2 2, 0 2, 0 0, 2 0))")
.toDF("right_geom")
.withColumn("right_geom", expr("ST_GeomFromText(right_geom)"))
val result = left.join(broadcast(right), expr("ST_Intersects(left_geom, right_geom)"), "left")
assert(result.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
assert(result.count() == 4)
}

it("Left anti join with nulls on left side") {
import sparkSession.implicits._
val left = Seq(("1", "POINT(1 1)"), ("2", "POINT(1 1)"), ("3", "POINT(1 1)"), ("4", null))
.toDF("seq", "left_geom")
.withColumn("left_geom", expr("ST_GeomFromText(left_geom)"))
val right = Seq("POLYGON((2 0, 2 2, 0 2, 0 0, 2 0))")
.toDF("right_geom")
.withColumn("right_geom", expr("ST_GeomFromText(right_geom)"))
val result = left.join(broadcast(right), expr("ST_Intersects(left_geom, right_geom)"), "left_anti")
assert(result.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
assert(result.count() == 1)
}
}
}
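
Note: for completeness, a hedged companion check that is not part of the patch: with the same data, a left semi join should drop the null-geometry row, since a null geometry can match nothing.

    val semi = left.join(broadcast(right), expr("ST_Intersects(left_geom, right_geom)"), "left_semi")
    assert(semi.count() == 3) // only the three POINT(1 1) rows intersect the square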
