Skip to content

Commit cde35ab

Browse files
authored
feat: support projection pushdown for datafusion iceberg (apache#594)
* support projection pushdown for datafusion iceberg * support projection pushdown for datafusion iceberg * fix ci * fix field id * remove dependencies * remove dependencies
1 parent eae9464 commit cde35ab

File tree

4 files changed

+134
-18
lines changed

4 files changed

+134
-18
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ iceberg = { version = "0.3.0", path = "./crates/iceberg" }
6464
iceberg-catalog-rest = { version = "0.3.0", path = "./crates/catalog/rest" }
6565
iceberg-catalog-hms = { version = "0.3.0", path = "./crates/catalog/hms" }
6666
iceberg-catalog-memory = { version = "0.3.0", path = "./crates/catalog/memory" }
67+
iceberg-datafusion = { version = "0.3.0", path = "./crates/integrations/datafusion" }
6768
itertools = "0.13"
6869
log = "0.4"
6970
mockito = "1"

crates/integrations/datafusion/src/physical_plan/scan.rs

+35-4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
use std::any::Any;
1919
use std::pin::Pin;
2020
use std::sync::Arc;
21+
use std::vec;
2122

2223
use datafusion::arrow::array::RecordBatch;
2324
use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef;
@@ -44,17 +45,25 @@ pub(crate) struct IcebergTableScan {
4445
/// Stores certain, often expensive to compute,
4546
/// plan properties used in query optimization.
4647
plan_properties: PlanProperties,
48+
/// Projection column names, None means all columns
49+
projection: Option<Vec<String>>,
4750
}
4851

4952
impl IcebergTableScan {
5053
/// Creates a new [`IcebergTableScan`] object.
51-
pub(crate) fn new(table: Table, schema: ArrowSchemaRef) -> Self {
54+
pub(crate) fn new(
55+
table: Table,
56+
schema: ArrowSchemaRef,
57+
projection: Option<&Vec<usize>>,
58+
) -> Self {
5259
let plan_properties = Self::compute_properties(schema.clone());
60+
let projection = get_column_names(schema.clone(), projection);
5361

5462
Self {
5563
table,
5664
schema,
5765
plan_properties,
66+
projection,
5867
}
5968
}
6069

@@ -100,7 +109,7 @@ impl ExecutionPlan for IcebergTableScan {
100109
_partition: usize,
101110
_context: Arc<TaskContext>,
102111
) -> DFResult<SendableRecordBatchStream> {
103-
let fut = get_batch_stream(self.table.clone());
112+
let fut = get_batch_stream(self.table.clone(), self.projection.clone());
104113
let stream = futures::stream::once(fut).try_flatten();
105114

106115
Ok(Box::pin(RecordBatchStreamAdapter::new(
@@ -116,7 +125,13 @@ impl DisplayAs for IcebergTableScan {
116125
_t: datafusion::physical_plan::DisplayFormatType,
117126
f: &mut std::fmt::Formatter,
118127
) -> std::fmt::Result {
119-
write!(f, "IcebergTableScan")
128+
write!(
129+
f,
130+
"IcebergTableScan projection:[{}]",
131+
self.projection
132+
.clone()
133+
.map_or(String::new(), |v| v.join(","))
134+
)
120135
}
121136
}
122137

@@ -127,8 +142,13 @@ impl DisplayAs for IcebergTableScan {
127142
/// and then converts it into a stream of Arrow [`RecordBatch`]es.
128143
async fn get_batch_stream(
129144
table: Table,
145+
column_names: Option<Vec<String>>,
130146
) -> DFResult<Pin<Box<dyn Stream<Item = DFResult<RecordBatch>> + Send>>> {
131-
let table_scan = table.scan().build().map_err(to_datafusion_error)?;
147+
let scan_builder = match column_names {
148+
Some(column_names) => table.scan().select(column_names),
149+
None => table.scan().select_all(),
150+
};
151+
let table_scan = scan_builder.build().map_err(to_datafusion_error)?;
132152

133153
let stream = table_scan
134154
.to_arrow()
@@ -138,3 +158,14 @@ async fn get_batch_stream(
138158

139159
Ok(Box::pin(stream))
140160
}
161+
162+
fn get_column_names(
163+
schema: ArrowSchemaRef,
164+
projection: Option<&Vec<usize>>,
165+
) -> Option<Vec<String>> {
166+
projection.map(|v| {
167+
v.iter()
168+
.map(|p| schema.field(*p).name().clone())
169+
.collect::<Vec<String>>()
170+
})
171+
}

crates/integrations/datafusion/src/table.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,14 @@ impl TableProvider for IcebergTableProvider {
7575
async fn scan(
7676
&self,
7777
_state: &dyn Session,
78-
_projection: Option<&Vec<usize>>,
78+
projection: Option<&Vec<usize>>,
7979
_filters: &[Expr],
8080
_limit: Option<usize>,
8181
) -> DFResult<Arc<dyn ExecutionPlan>> {
8282
Ok(Arc::new(IcebergTableScan::new(
8383
self.table.clone(),
8484
self.schema.clone(),
85+
projection,
8586
)))
8687
}
8788
}

crates/integrations/datafusion/tests/integration_datafusion_test.rs

+96-13
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@
1919
2020
use std::collections::HashMap;
2121
use std::sync::Arc;
22+
use std::vec;
2223

24+
use datafusion::arrow::array::{Array, StringArray};
2325
use datafusion::arrow::datatypes::DataType;
2426
use datafusion::execution::context::SessionContext;
2527
use iceberg::io::FileIOBuilder;
26-
use iceberg::spec::{NestedField, PrimitiveType, Schema, Type};
28+
use iceberg::spec::{NestedField, PrimitiveType, Schema, StructType, Type};
2729
use iceberg::{Catalog, NamespaceIdent, Result, TableCreation};
2830
use iceberg_catalog_memory::MemoryCatalog;
2931
use iceberg_datafusion::IcebergCatalogProvider;
@@ -39,6 +41,13 @@ fn get_iceberg_catalog() -> MemoryCatalog {
3941
MemoryCatalog::new(file_io, Some(temp_path()))
4042
}
4143

44+
fn get_struct_type() -> StructType {
45+
StructType::new(vec![
46+
NestedField::required(4, "s_foo1", Type::Primitive(PrimitiveType::Int)).into(),
47+
NestedField::required(5, "s_foo2", Type::Primitive(PrimitiveType::String)).into(),
48+
])
49+
}
50+
4251
async fn set_test_namespace(catalog: &MemoryCatalog, namespace: &NamespaceIdent) -> Result<()> {
4352
let properties = HashMap::new();
4453

@@ -47,14 +56,21 @@ async fn set_test_namespace(catalog: &MemoryCatalog, namespace: &NamespaceIdent)
4756
Ok(())
4857
}
4958

50-
fn set_table_creation(location: impl ToString, name: impl ToString) -> Result<TableCreation> {
51-
let schema = Schema::builder()
52-
.with_schema_id(0)
53-
.with_fields(vec![
54-
NestedField::required(1, "foo", Type::Primitive(PrimitiveType::Int)).into(),
55-
NestedField::required(2, "bar", Type::Primitive(PrimitiveType::String)).into(),
56-
])
57-
.build()?;
59+
fn get_table_creation(
60+
location: impl ToString,
61+
name: impl ToString,
62+
schema: Option<Schema>,
63+
) -> Result<TableCreation> {
64+
let schema = match schema {
65+
None => Schema::builder()
66+
.with_schema_id(0)
67+
.with_fields(vec![
68+
NestedField::required(1, "foo1", Type::Primitive(PrimitiveType::Int)).into(),
69+
NestedField::required(2, "foo2", Type::Primitive(PrimitiveType::String)).into(),
70+
])
71+
.build()?,
72+
Some(schema) => schema,
73+
};
5874

5975
let creation = TableCreation::builder()
6076
.location(location.to_string())
@@ -72,7 +88,7 @@ async fn test_provider_get_table_schema() -> Result<()> {
7288
let namespace = NamespaceIdent::new("test_provider_get_table_schema".to_string());
7389
set_test_namespace(&iceberg_catalog, &namespace).await?;
7490

75-
let creation = set_table_creation(temp_path(), "my_table")?;
91+
let creation = get_table_creation(temp_path(), "my_table", None)?;
7692
iceberg_catalog.create_table(&namespace, creation).await?;
7793

7894
let client = Arc::new(iceberg_catalog);
@@ -87,7 +103,7 @@ async fn test_provider_get_table_schema() -> Result<()> {
87103
let table = schema.table("my_table").await.unwrap().unwrap();
88104
let table_schema = table.schema();
89105

90-
let expected = [("foo", &DataType::Int32), ("bar", &DataType::Utf8)];
106+
let expected = [("foo1", &DataType::Int32), ("foo2", &DataType::Utf8)];
91107

92108
for (field, exp) in table_schema.fields().iter().zip(expected.iter()) {
93109
assert_eq!(field.name(), exp.0);
@@ -104,7 +120,7 @@ async fn test_provider_list_table_names() -> Result<()> {
104120
let namespace = NamespaceIdent::new("test_provider_list_table_names".to_string());
105121
set_test_namespace(&iceberg_catalog, &namespace).await?;
106122

107-
let creation = set_table_creation(temp_path(), "my_table")?;
123+
let creation = get_table_creation(temp_path(), "my_table", None)?;
108124
iceberg_catalog.create_table(&namespace, creation).await?;
109125

110126
let client = Arc::new(iceberg_catalog);
@@ -130,7 +146,6 @@ async fn test_provider_list_schema_names() -> Result<()> {
130146
let namespace = NamespaceIdent::new("test_provider_list_schema_names".to_string());
131147
set_test_namespace(&iceberg_catalog, &namespace).await?;
132148

133-
set_table_creation("test_provider_list_schema_names", "my_table")?;
134149
let client = Arc::new(iceberg_catalog);
135150
let catalog = Arc::new(IcebergCatalogProvider::try_new(client).await?);
136151

@@ -147,3 +162,71 @@ async fn test_provider_list_schema_names() -> Result<()> {
147162
.all(|item| result.contains(&item.to_string())));
148163
Ok(())
149164
}
165+
166+
#[tokio::test]
167+
async fn test_table_projection() -> Result<()> {
168+
let iceberg_catalog = get_iceberg_catalog();
169+
let namespace = NamespaceIdent::new("ns".to_string());
170+
set_test_namespace(&iceberg_catalog, &namespace).await?;
171+
172+
let schema = Schema::builder()
173+
.with_schema_id(0)
174+
.with_fields(vec![
175+
NestedField::required(1, "foo1", Type::Primitive(PrimitiveType::Int)).into(),
176+
NestedField::required(2, "foo2", Type::Primitive(PrimitiveType::String)).into(),
177+
NestedField::optional(3, "foo3", Type::Struct(get_struct_type())).into(),
178+
])
179+
.build()?;
180+
let creation = get_table_creation(temp_path(), "t1", Some(schema))?;
181+
iceberg_catalog.create_table(&namespace, creation).await?;
182+
183+
let client = Arc::new(iceberg_catalog);
184+
let catalog = Arc::new(IcebergCatalogProvider::try_new(client).await?);
185+
186+
let ctx = SessionContext::new();
187+
ctx.register_catalog("catalog", catalog);
188+
let table_df = ctx.table("catalog.ns.t1").await.unwrap();
189+
190+
let records = table_df
191+
.clone()
192+
.explain(false, false)
193+
.unwrap()
194+
.collect()
195+
.await
196+
.unwrap();
197+
assert_eq!(1, records.len());
198+
let record = &records[0];
199+
// the first column is plan_type, the second column is the plan string.
200+
let s = record
201+
.column(1)
202+
.as_any()
203+
.downcast_ref::<StringArray>()
204+
.unwrap();
205+
assert_eq!(2, s.len());
206+
// the first row is logical_plan, the second row is physical_plan
207+
assert_eq!(
208+
"IcebergTableScan projection:[foo1,foo2,foo3]",
209+
s.value(1).trim()
210+
);
211+
212+
// datafusion doesn't support querying foo3.s_foo1 directly, so select foo3 instead
213+
let records = table_df
214+
.select_columns(&["foo1", "foo3"])
215+
.unwrap()
216+
.explain(false, false)
217+
.unwrap()
218+
.collect()
219+
.await
220+
.unwrap();
221+
assert_eq!(1, records.len());
222+
let record = &records[0];
223+
let s = record
224+
.column(1)
225+
.as_any()
226+
.downcast_ref::<StringArray>()
227+
.unwrap();
228+
assert_eq!(2, s.len());
229+
assert_eq!("IcebergTableScan projection:[foo1,foo3]", s.value(1).trim());
230+
231+
Ok(())
232+
}

0 commit comments

Comments
 (0)