Skip to content

Commit

Permalink
check that pushFilters is called from explain()
Browse files Browse the repository at this point in the history
  • Loading branch information
wengh committed Feb 28, 2025
1 parent 21b6eaa commit 48ab696
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions python/pyspark/sql/tests/test_python_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
CaseInsensitiveDict,
)
from pyspark.sql.functions import spark_partition_id
from pyspark.sql.session import SparkSession
from pyspark.sql.types import Row, StructType
from pyspark.testing.sqlutils import (
have_pyarrow,
Expand All @@ -44,6 +45,8 @@

@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
class BasePythonDataSourceTestsMixin:
spark: SparkSession

def test_basic_data_source_class(self):
class MyDataSource(DataSource):
...
Expand Down Expand Up @@ -318,9 +321,11 @@ def reader(self, schema) -> "DataSourceReader":
self.spark.read.format("test").load().filter("x = 1").show()

def test_filter_pushdown_error(self):
error_str = "dummy error"

class TestDataSourceReader(DataSourceReader):
def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
raise Exception("dummy error")
raise Exception(error_str)

def read(self, partition):
yield [1]
Expand All @@ -336,8 +341,8 @@ def reader(self, schema) -> "DataSourceReader":
self.spark.dataSource.register(TestDataSource)
df = self.spark.read.format("TestDataSource").load().filter("x = 1 or x is null")
assertDataFrameEqual(df, [Row(x=1)]) # works when not pushing down filters
with self.assertRaisesRegex(Exception, "dummy error"):
df.filter("x = 1").show()
with self.assertRaisesRegex(Exception, error_str):
df.filter("x = 1").explain()

def test_filter_pushdown_disabled(self):
class TestDataSourceReader(DataSourceReader):
Expand Down

0 comments on commit 48ab696

Please sign in to comment.