Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SEDONA-205] Use BinaryType in GeometryUDT in Sedona Spark. #734

Merged
merged 1 commit into from
Dec 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 3 additions & 17 deletions python/sedona/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,21 @@

import struct

from pyspark.sql.types import UserDefinedType, ArrayType, ByteType
from pyspark.sql.types import UserDefinedType, BinaryType
from shapely.wkb import dumps, loads


class GeometryType(UserDefinedType):

@classmethod
def sqlType(cls):
return ArrayType(ByteType(), containsNull=False)

def fromInternal(self, obj):
deserialized_obj = None
if obj is not None:
deserialized_obj = self.deserialize(obj)
return deserialized_obj

def toInternal(self, obj):
serialized_obj = None
if obj is not None:
serialized_obj = [el - 256 if el >= 128 else el for el in self.serialize(obj)]
return serialized_obj
return BinaryType()

def serialize(self, obj):
return dumps(obj)

def deserialize(self, datum):
bytes_data = b''.join([struct.pack('b', el) for el in datum])
geom = loads(bytes_data)
return geom
return loads(bytes(datum))

@classmethod
def module(cls):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ object Adapter {
case _: TimestampType => null.asInstanceOf[Timestamp]
case _: BooleanType => null.asInstanceOf[Boolean]
case _: StringType => null.asInstanceOf[String]
case _: BinaryType => null.asInstanceOf[Array[Byte]]
case _: StructType => null.asInstanceOf[StructType]
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ object GeometrySerializer {
/**
* Given a WKB-encoded byte array, returns the Geometry it represents
*
* @param values ArrayData
* @param value WKB-encoded bytes (Array[Byte]) of the geometry
* @return JTS geometry
*/
def deserialize(values: ArrayData): Geometry = {
def deserialize(value: Array[Byte]): Geometry = {
val reader = new WKBReader()
reader.read(values.toByteArray())
reader.read(value)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,21 @@ import org.locationtech.jts.geom.Geometry


class GeometryUDT extends UserDefinedType[Geometry] {
override def sqlType: DataType = ArrayType(ByteType, containsNull = false)
override def sqlType: DataType = BinaryType

override def pyUDT: String = "sedona.sql.types.GeometryType"

override def userClass: Class[Geometry] = classOf[Geometry]

override def serialize(obj: Geometry): GenericArrayData =
new GenericArrayData(GeometrySerializer.serialize(obj))
override def serialize(obj: Geometry): Array[Byte] = GeometrySerializer.serialize(obj)

override def deserialize(datum: Any): Geometry = {
datum match {
case values: ArrayData =>
GeometrySerializer.deserialize(values)
case value: Array[Byte] => GeometrySerializer.deserialize(value)
}
}


override private[sql] def jsonValue: JValue = {
super.jsonValue mapField {
case ("class", _) => "class" -> this.getClass.getName.stripSuffix("$")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ case class ST_PointFromText(inputExpressions: Seq[Expression])
var fileDataSplitter = FileDataSplitter.getFileDataSplitter(geomFormat)
var formatMapper = new FormatMapper(fileDataSplitter, false, GeometryType.POINT)
var geometry = formatMapper.readGeometry(geomString)
new GenericArrayData(GeometrySerializer.serialize(geometry))
GeometrySerializer.serialize(geometry)
}

override def dataType: DataType = GeometryUDT
Expand Down Expand Up @@ -90,7 +90,7 @@ case class ST_PolygonFromText(inputExpressions: Seq[Expression])
var fileDataSplitter = FileDataSplitter.getFileDataSplitter(geomFormat)
var formatMapper = new FormatMapper(fileDataSplitter, false, GeometryType.POLYGON)
var geometry = formatMapper.readGeometry(geomString)
new GenericArrayData(GeometrySerializer.serialize(geometry))
GeometrySerializer.serialize(geometry)
}

override def dataType: DataType = GeometryUDT
Expand Down Expand Up @@ -123,7 +123,7 @@ case class ST_LineFromText(inputExpressions: Seq[Expression])
var formatMapper = new FormatMapper(fileDataSplitter, false)
var geometry = formatMapper.readGeometry(lineString)
if(geometry.getGeometryType.contains("LineString")) {
new GenericArrayData(GeometrySerializer.serialize(geometry))
GeometrySerializer.serialize(geometry)
} else {
null
}
Expand Down Expand Up @@ -160,7 +160,7 @@ case class ST_LineStringFromText(inputExpressions: Seq[Expression])
var formatMapper = new FormatMapper(fileDataSplitter, false, GeometryType.LINESTRING)
var geometry = formatMapper.readGeometry(geomString)

new GenericArrayData(GeometrySerializer.serialize(geometry))
GeometrySerializer.serialize(geometry)
}

override def dataType: DataType = GeometryUDT
Expand Down Expand Up @@ -265,7 +265,7 @@ case class ST_GeomFromGeoJSON(inputExpressions: Seq[Expression])
if (inputExpressions.length > 1) {
geometry.setUserData(generateUserData(minInputLength, inputExpressions, inputRow))
}
new GenericArrayData(GeometrySerializer.serialize(geometry))
GeometrySerializer.serialize(geometry)
}

override def dataType: DataType = GeometryUDT
Expand Down Expand Up @@ -296,7 +296,7 @@ case class ST_Point(inputExpressions: Seq[Expression])
}
val geometryFactory = new GeometryFactory()
val geometry = geometryFactory.createPoint(coord)
new GenericArrayData(GeometrySerializer.serialize(geometry))
GeometrySerializer.serialize(geometry)
}

override def dataType: DataType = GeometryUDT
Expand Down Expand Up @@ -335,7 +335,7 @@ case class ST_PolygonFromEnvelope(inputExpressions: Seq[Expression])
coordinates(4) = coordinates(0)
val geometryFactory = new GeometryFactory()
val polygon = geometryFactory.createPolygon(coordinates)
new GenericArrayData(GeometrySerializer.serialize(polygon))
GeometrySerializer.serialize(polygon)
}

override def dataType: DataType = GeometryUDT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.sedona.sql.utils.GeometrySerializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, Generator, ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.sedona_sql.expressions.collect.Collect
import org.apache.spark.sql.sedona_sql.expressions.implicits._
Expand Down Expand Up @@ -334,20 +334,16 @@ case class ST_AsGeoJSON(inputExpressions: Seq[Expression])
}
}

// TODO: sernetcdf is bundled with an ancient version of apache commons-codec, which
// causes spark sql to throw NoSuchMethodError when folding binary expressions.
case class ST_AsBinary(inputExpressions: Seq[Expression])
extends InferredUnaryExpression(Functions.asWKB) {
extends InferredUnaryExpression(Functions.asWKB) with FoldableExpression {

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
}

// TODO: ST_AsEWKB is similar to ST_AsBinary, which is also affected by the sernetcdf
// problem.
case class ST_AsEWKB(inputExpressions: Seq[Expression])
extends InferredUnaryExpression(Functions.asEWKB) {
extends InferredUnaryExpression(Functions.asEWKB) with FoldableExpression {

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
Expand Down Expand Up @@ -769,7 +765,7 @@ case class ST_MakePolygon(inputExpressions: Seq[Expression])
val numOfElements = possibleHolesRaw.map(_.numElements()).getOrElse(0)

val holes = (0 until numOfElements).map(el => possibleHolesRaw match {
case Some(value) => Some(value.getArray(el))
case Some(value) => Some(value.getBinary(el))
case None => None
}).filter(_.nonEmpty)
.map(el => el.map(_.toGeometry))
Expand Down Expand Up @@ -833,7 +829,7 @@ case class ST_SymDifference(inputExpressions: Seq[Expression])
extends BinaryGeometryExpression with FoldableExpression with CodegenFallback {

override protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
new GenericArrayData(GeometrySerializer.serialize(leftGeometry.symDifference(rightGeometry)))
leftGeometry.symDifference(rightGeometry).toGenericArrayData
}

override def dataType: DataType = GeometryUDT
Expand All @@ -854,7 +850,7 @@ case class ST_Union(inputExpressions: Seq[Expression])
extends BinaryGeometryExpression with FoldableExpression with CodegenFallback {

override protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
new GenericArrayData(GeometrySerializer.serialize(leftGeometry.union(rightGeometry)))
leftGeometry.union(rightGeometry).toGenericArrayData
}

override def dataType: DataType = GeometryUDT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ abstract class ST_Predicate extends Expression
override def children: Seq[Expression] = inputExpressions

override final def eval(inputRow: InternalRow): Any = {
val leftArray = inputExpressions(0).eval(inputRow).asInstanceOf[ArrayData]
val leftArray = inputExpressions(0).eval(inputRow).asInstanceOf[Array[Byte]]
if (leftArray == null) {
null
} else {
val rightArray = inputExpressions(1).eval(inputRow).asInstanceOf[ArrayData]
val rightArray = inputExpressions(1).eval(inputRow).asInstanceOf[Array[Byte]]
if (rightArray == null) {
null
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ case class ST_Collect(inputExpressions: Seq[Expression])
val data = firstElement.eval(input).asInstanceOf[ArrayData]
val numElements = data.numElements()
val geomElements = (0 until numElements)
.map(element => data.getArray(element))
.map(element => data.getBinary(element))
.filter(_ != null)
.map(_.toGeometry)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ case class ST_CollectionExtract(inputExpressions: Seq[Expression]) extends Expre

override def nullable: Boolean = true

def nullSafeEval(geometry: Geometry, geomType: GeomTypeVal): GenericArrayData = {
def nullSafeEval(geometry: Geometry, geomType: GeomTypeVal): Array[Byte] = {
val geometries : util.ArrayList[Geometry] = new util.ArrayList[Geometry]()
filterGeometry(geometries, geometry, geomType);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ object implicits {

implicit class InputExpressionEnhancer(inputExpression: Expression) {
def toGeometry(input: InternalRow): Geometry = {
inputExpression.eval(input).asInstanceOf[ArrayData] match {
case arrData: ArrayData => GeometrySerializer.deserialize(arrData)
inputExpression.eval(input).asInstanceOf[Array[Byte]] match {
case binary: Array[Byte] => GeometrySerializer.deserialize(binary)
case _ => null
}
}
Expand Down Expand Up @@ -63,10 +63,10 @@ object implicits {
}
}

implicit class ArrayDataEnhancer(arrayData: ArrayData) {
implicit class ArrayDataEnhancer(arrayData: Array[Byte]) {
def toGeometry: Geometry = {
arrayData match {
case arrData: ArrayData => GeometrySerializer.deserialize(arrData)
case binary: Array[Byte] => GeometrySerializer.deserialize(binary)
case _ => null
}
}
Expand All @@ -75,8 +75,7 @@ object implicits {
implicit class GeometryEnhancer(geom: Geometry) {
private val geometryFactory = new GeometryFactory()

def toGenericArrayData: GenericArrayData =
new GenericArrayData(GeometrySerializer.serialize(geom))
def toGenericArrayData: Array[Byte] = GeometrySerializer.serialize(geom)

def getPoints: Array[Point] =
geom.getCoordinates.map(coordinate => geometryFactory.createPoint(coordinate))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ private class GeotiffFileWriter(savePath: String,
val envGeom = wktReader.read(tiffGeometry.toString).asInstanceOf[Polygon]
coordinateList = envGeom.getCoordinates()
} else {
val envGeom = GeometrySerializer.deserialize(tiffGeometry.asInstanceOf[ArrayData])
val envGeom = GeometrySerializer.deserialize(tiffGeometry.asInstanceOf[Array[Byte]])
coordinateList = envGeom.getCoordinates()
}
val referencedEnvelope = new ReferencedEnvelope(coordinateList(0).x, coordinateList(2).x, coordinateList(0).y, coordinateList(2).y, crs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ trait TraitJoinQueryBase {
spatialRdd.setRawSpatialRDD(
rdd
.map { x => {
val shape = GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[ArrayData])
val shape = GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[Array[Byte]])
//logInfo(shape.toString)
shape.setUserData(x.copy)
shape
Expand All @@ -56,7 +56,7 @@ trait TraitJoinQueryBase {
spatialRdd.setRawSpatialRDD(
rdd
.map { x => {
val shape = GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[ArrayData])
val shape = GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[Array[Byte]])
val envelope = shape.getEnvelopeInternal.copy()
envelope.expandBy(boundRadius.eval(x).asInstanceOf[Double])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ case class ST_Pixelize(inputExpressions: Seq[Expression])
override def toString: String = s" **${ST_Pixelize.getClass.getName}** "

override def eval(input: InternalRow): Any = {
val inputGeometry = GeometrySerializer.deserialize(inputExpressions(0).eval(input).asInstanceOf[ArrayData])
val inputGeometry = GeometrySerializer.deserialize(inputExpressions(0).eval(input).asInstanceOf[Array[Byte]])
val resolutionX = inputExpressions(1).eval(input).asInstanceOf[Integer]
val resolutionY = inputExpressions(2).eval(input).asInstanceOf[Integer]
val boundary = GeometrySerializer.deserialize(inputExpressions(3).eval(input).asInstanceOf[ArrayData]).getEnvelopeInternal
val boundary = GeometrySerializer.deserialize(inputExpressions(3).eval(input).asInstanceOf[Array[Byte]]).getEnvelopeInternal
val reverseCoordinate = false
val pixels = inputGeometry match {
case geometry: LineString => {
Expand Down