Skip to content

Commit

Permalink
[SEDONA-495] Raster data source uses shared FileSystem connections wh…
Browse files Browse the repository at this point in the history
…ich lead to race condition (#1236)
  • Loading branch information
jiayuasu committed Apr 28, 2024
1 parent af61dc4 commit 1ae3732
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

package org.apache.spark.sql.sedona_sql.io.raster

import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
Expand All @@ -29,7 +29,6 @@ import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types.StructType

import java.io.IOException
import java.nio.file.Paths
import java.util.UUID

private[spark] class RasterFileFormat extends FileFormat with DataSourceRegister {
Expand Down Expand Up @@ -82,7 +81,7 @@ private class RasterFileWriter(savePath: String,
dataSchema: StructType,
context: TaskAttemptContext) extends OutputWriter {

private val hfs = new Path(savePath).getFileSystem(context.getConfiguration)
private val hfs = FileSystem.newInstance(new Path(savePath).toUri, context.getConfiguration)
private val rasterFieldIndex = if (rasterOptions.rasterField.isEmpty) getRasterFieldIndex else dataSchema.fieldIndex(rasterOptions.rasterField.get)

private def getRasterFieldIndex: Int = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,6 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {

override def beforeAll(): Unit = {
SedonaContext.create(sparkSession)
// Set up HDFS minicluster
val baseDir = new File("./target/hdfs/").getAbsoluteFile
FileUtil.fullyDelete(baseDir)
val hdfsConf = new HdfsConfiguration
hdfsConf.set(MiniDFSCluster.HDFS_MINIDFS_BASEDIR, baseDir.getAbsolutePath)
val builder = new MiniDFSCluster.Builder(hdfsConf)
val hdfsCluster = builder.build
hdfsURI = "hdfs://127.0.0.1:" + hdfsCluster.getNameNodePort + "/"
}

override def afterAll(): Unit = {
Expand Down Expand Up @@ -186,4 +178,18 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
}).sum
}).sum
}

/**
* Create a mini HDFS cluster and return the HDFS instance and the URI.
* @return (MiniDFSCluster, HDFS URI)
*/
def creatMiniHdfs(): (MiniDFSCluster, String) = {
val baseDir = new File("./target/hdfs/").getAbsoluteFile
FileUtil.fullyDelete(baseDir)
val hdfsConf = new HdfsConfiguration
hdfsConf.set(MiniDFSCluster.HDFS_MINIDFS_BASEDIR, baseDir.getAbsolutePath)
val builder = new MiniDFSCluster.Builder(hdfsConf)
val hdfsCluster = builder.build
(hdfsCluster, "hdfs://127.0.0.1:" + hdfsCluster.getNameNodePort + "/")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.sedona.sql

import org.apache.commons.io.FileUtils
import org.apache.hadoop.hdfs.MiniDFSCluster
import org.apache.spark.sql.SaveMode
import org.junit.Assert.assertEquals
import org.scalatest.{BeforeAndAfter, GivenWhenThen}
Expand Down Expand Up @@ -149,12 +150,14 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen
}

it("should read geotiff using binary source and write geotiff back to hdfs using raster source") {
var rasterDf = sparkSession.read.format("binaryFile").load(rasterdatalocation)
val miniHDFS: (MiniDFSCluster, String) = creatMiniHdfs()
var rasterDf = sparkSession.read.format("binaryFile").load(rasterdatalocation).repartition(3)
val rasterCount = rasterDf.count()
rasterDf.write.format("raster").mode(SaveMode.Overwrite).save(hdfsURI + "/raster-written")
rasterDf = sparkSession.read.format("binaryFile").load(hdfsURI + "/raster-written/*")
rasterDf.write.format("raster").mode(SaveMode.Overwrite).save(miniHDFS._2 + "/raster-written")
rasterDf = sparkSession.read.format("binaryFile").load(miniHDFS._2 + "/raster-written/*")
rasterDf = rasterDf.selectExpr("RS_FromGeoTiff(content)")
assert(rasterDf.count() == rasterCount)
miniHDFS._1.shutdown()
}
}

Expand Down

0 comments on commit 1ae3732

Please sign in to comment.