Skip to content

Commit 1533c43

Browse files
a-agmonAlon Agmon
and
Alon Agmon
authored
feat (datafusion integration): convert datafusion expr filters to Iceberg Predicate (apache#588)
* adding main function and tests * adding tests, removing integration test for now * fixing typos and lints * fixing typing issue * - added support in schema to convert Date32 to correct arrow type - refactored scan to use new predicate converter as visitor and separated it to a new mod - added support for simple predicates with column cast expressions - added testing, mostly around date functions * fixing format and lic * reducing number of tests (17 -> 7) * fix formats * fix naming * refactoring to use TreeNodeVisitor * fixing fmt * small refactor * adding swapped op and fixing CR comments --------- Co-authored-by: Alon Agmon <[email protected]>
1 parent e967deb commit 1533c43

File tree

5 files changed

+395
-7
lines changed

5 files changed

+395
-7
lines changed

crates/iceberg/src/arrow/schema.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ use arrow_array::types::{
2424
validate_decimal_precision_and_scale, Decimal128Type, TimestampMicrosecondType,
2525
};
2626
use arrow_array::{
27-
BooleanArray, Datum as ArrowDatum, Float32Array, Float64Array, Int32Array, Int64Array,
28-
PrimitiveArray, Scalar, StringArray, TimestampMicrosecondArray,
27+
BooleanArray, Date32Array, Datum as ArrowDatum, Float32Array, Float64Array, Int32Array,
28+
Int64Array, PrimitiveArray, Scalar, StringArray, TimestampMicrosecondArray,
2929
};
3030
use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema, TimeUnit};
3131
use bitvec::macros::internal::funty::Fundamental;
@@ -646,6 +646,9 @@ pub(crate) fn get_arrow_datum(datum: &Datum) -> Result<Box<dyn ArrowDatum + Send
646646
(PrimitiveType::String, PrimitiveLiteral::String(value)) => {
647647
Ok(Box::new(StringArray::new_scalar(value.as_str())))
648648
}
649+
(PrimitiveType::Date, PrimitiveLiteral::Int(value)) => {
650+
Ok(Box::new(Date32Array::new_scalar(*value)))
651+
}
649652
(PrimitiveType::Timestamp, PrimitiveLiteral::Long(value)) => {
650653
Ok(Box::new(TimestampMicrosecondArray::new_scalar(*value)))
651654
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::collections::VecDeque;
19+
20+
use datafusion::common::tree_node::{TreeNodeRecursion, TreeNodeVisitor};
21+
use datafusion::common::Column;
22+
use datafusion::error::DataFusionError;
23+
use datafusion::logical_expr::{Expr, Operator};
24+
use datafusion::scalar::ScalarValue;
25+
use iceberg::expr::{Predicate, Reference};
26+
use iceberg::spec::Datum;
27+
28+
pub struct ExprToPredicateVisitor {
29+
stack: VecDeque<Option<Predicate>>,
30+
}
31+
impl ExprToPredicateVisitor {
32+
/// Create a new predicate conversion visitor.
33+
pub fn new() -> Self {
34+
Self {
35+
stack: VecDeque::new(),
36+
}
37+
}
38+
/// Get the predicate from the stack.
39+
pub fn get_predicate(&self) -> Option<Predicate> {
40+
self.stack
41+
.iter()
42+
.filter_map(|opt| opt.clone())
43+
.reduce(Predicate::and)
44+
}
45+
46+
/// Convert a column expression to an iceberg predicate.
47+
fn convert_column_expr(
48+
&self,
49+
col: &Column,
50+
op: &Operator,
51+
lit: &ScalarValue,
52+
) -> Option<Predicate> {
53+
let reference = Reference::new(col.name.clone());
54+
let datum = scalar_value_to_datum(lit)?;
55+
Some(binary_op_to_predicate(reference, op, datum))
56+
}
57+
58+
/// Convert a compound expression to an iceberg predicate.
59+
///
60+
/// The strategy is to support the following cases:
61+
/// - if its an AND expression then the result will be the valid predicates, whether there are 2 or just 1
62+
/// - if its an OR expression then a predicate will be returned only if there are 2 valid predicates on both sides
63+
fn convert_compound_expr(&self, valid_preds: &[Predicate], op: &Operator) -> Option<Predicate> {
64+
let valid_preds_count = valid_preds.len();
65+
match (op, valid_preds_count) {
66+
(Operator::And, 1) => valid_preds.first().cloned(),
67+
(Operator::And, 2) => Some(Predicate::and(
68+
valid_preds[0].clone(),
69+
valid_preds[1].clone(),
70+
)),
71+
(Operator::Or, 2) => Some(Predicate::or(
72+
valid_preds[0].clone(),
73+
valid_preds[1].clone(),
74+
)),
75+
_ => None,
76+
}
77+
}
78+
}
79+
80+
// Implement TreeNodeVisitor for ExprToPredicateVisitor
81+
impl<'n> TreeNodeVisitor<'n> for ExprToPredicateVisitor {
82+
type Node = Expr;
83+
84+
fn f_down(&mut self, _node: &'n Expr) -> Result<TreeNodeRecursion, DataFusionError> {
85+
Ok(TreeNodeRecursion::Continue)
86+
}
87+
88+
fn f_up(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion, DataFusionError> {
89+
if let Expr::BinaryExpr(binary) = expr {
90+
match (&*binary.left, &binary.op, &*binary.right) {
91+
// process simple binary expressions, e.g. col > 1
92+
(Expr::Column(col), op, Expr::Literal(lit)) => {
93+
let col_pred = self.convert_column_expr(col, op, lit);
94+
self.stack.push_back(col_pred);
95+
}
96+
// // process reversed binary expressions, e.g. 1 < col
97+
(Expr::Literal(lit), op, Expr::Column(col)) => {
98+
let col_pred = op
99+
.swap()
100+
.and_then(|negated_op| self.convert_column_expr(col, &negated_op, lit));
101+
self.stack.push_back(col_pred);
102+
}
103+
// process compound expressions (involving logical operators. e.g., AND or OR and children)
104+
(_left, op, _right) if op.is_logic_operator() => {
105+
let right_pred = self.stack.pop_back().flatten();
106+
let left_pred = self.stack.pop_back().flatten();
107+
let children: Vec<_> = [left_pred, right_pred].into_iter().flatten().collect();
108+
let compound_pred = self.convert_compound_expr(&children, op);
109+
self.stack.push_back(compound_pred);
110+
}
111+
_ => return Ok(TreeNodeRecursion::Continue),
112+
}
113+
}
114+
Ok(TreeNodeRecursion::Continue)
115+
}
116+
}
117+
118+
const MILLIS_PER_DAY: i64 = 24 * 60 * 60 * 1000;
119+
/// Convert a scalar value to an iceberg datum.
120+
fn scalar_value_to_datum(value: &ScalarValue) -> Option<Datum> {
121+
match value {
122+
ScalarValue::Int8(Some(v)) => Some(Datum::int(*v as i32)),
123+
ScalarValue::Int16(Some(v)) => Some(Datum::int(*v as i32)),
124+
ScalarValue::Int32(Some(v)) => Some(Datum::int(*v)),
125+
ScalarValue::Int64(Some(v)) => Some(Datum::long(*v)),
126+
ScalarValue::Float32(Some(v)) => Some(Datum::double(*v as f64)),
127+
ScalarValue::Float64(Some(v)) => Some(Datum::double(*v)),
128+
ScalarValue::Utf8(Some(v)) => Some(Datum::string(v.clone())),
129+
ScalarValue::LargeUtf8(Some(v)) => Some(Datum::string(v.clone())),
130+
ScalarValue::Date32(Some(v)) => Some(Datum::date(*v)),
131+
ScalarValue::Date64(Some(v)) => Some(Datum::date((*v / MILLIS_PER_DAY) as i32)),
132+
_ => None,
133+
}
134+
}
135+
136+
/// convert the data fusion Exp to an iceberg [`Predicate`]
137+
fn binary_op_to_predicate(reference: Reference, op: &Operator, datum: Datum) -> Predicate {
138+
match op {
139+
Operator::Eq => reference.equal_to(datum),
140+
Operator::NotEq => reference.not_equal_to(datum),
141+
Operator::Lt => reference.less_than(datum),
142+
Operator::LtEq => reference.less_than_or_equal_to(datum),
143+
Operator::Gt => reference.greater_than(datum),
144+
Operator::GtEq => reference.greater_than_or_equal_to(datum),
145+
_ => Predicate::AlwaysTrue,
146+
}
147+
}
148+
149+
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;

    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::tree_node::TreeNode;
    use datafusion::common::DFSchema;
    use datafusion::prelude::SessionContext;
    use iceberg::expr::{Predicate, Reference};
    use iceberg::spec::Datum;

    use super::ExprToPredicateVisitor;

    /// Schema shared by all tests: `my_table(foo: Int32, bar: Utf8)`.
    fn create_test_schema() -> DFSchema {
        let fields = vec![
            Field::new("foo", DataType::Int32, false),
            Field::new("bar", DataType::Utf8, false),
        ];
        let arrow_schema = Schema::new(fields);
        DFSchema::try_from_qualified_schema("my_table", &arrow_schema).unwrap()
    }

    /// Parse `sql` against the test schema, run the visitor over the resulting
    /// expression and return whatever predicate it produced.
    fn convert(sql: &str) -> Option<Predicate> {
        let df_schema = create_test_schema();
        let expr = SessionContext::new()
            .parse_sql_expr(sql, &df_schema)
            .unwrap();
        let mut visitor = ExprToPredicateVisitor::new();
        expr.visit(&mut visitor).unwrap();
        visitor.get_predicate()
    }

    #[test]
    fn test_predicate_conversion_with_single_condition() {
        let predicate = convert("foo > 1").unwrap();
        assert_eq!(
            predicate,
            Reference::new("foo").greater_than(Datum::long(1))
        );
    }

    #[test]
    fn test_predicate_conversion_with_single_unsupported_condition() {
        let predicate = convert("foo is null");
        assert_eq!(predicate, None);
    }

    #[test]
    fn test_predicate_conversion_with_single_condition_rev() {
        let predicate = convert("1 < foo").unwrap();
        assert_eq!(
            predicate,
            Reference::new("foo").greater_than(Datum::long(1))
        );
    }

    #[test]
    fn test_predicate_conversion_with_and_condition() {
        let predicate = convert("foo > 1 and bar = 'test'").unwrap();
        let expected_predicate = Predicate::and(
            Reference::new("foo").greater_than(Datum::long(1)),
            Reference::new("bar").equal_to(Datum::string("test")),
        );
        assert_eq!(predicate, expected_predicate);
    }

    #[test]
    fn test_predicate_conversion_with_and_condition_unsupported() {
        let predicate = convert("foo > 1 and bar is not null").unwrap();
        let expected_predicate = Reference::new("foo").greater_than(Datum::long(1));
        assert_eq!(predicate, expected_predicate);
    }

    #[test]
    fn test_predicate_conversion_with_and_condition_both_unsupported() {
        let predicate = convert("foo in (1, 2, 3) and bar is not null");
        let expected_predicate = None;
        assert_eq!(predicate, expected_predicate);
    }

    #[test]
    fn test_predicate_conversion_with_or_condition_unsupported() {
        let predicate = convert("foo > 1 or bar is not null");
        let expected_predicate = None;
        assert_eq!(predicate, expected_predicate);
    }

    #[test]
    fn test_predicate_conversion_with_complex_binary_expr() {
        let predicate = convert("(foo > 1 and bar = 'test') or foo < 0 ").unwrap();
        let inner_predicate = Predicate::and(
            Reference::new("foo").greater_than(Datum::long(1)),
            Reference::new("bar").equal_to(Datum::string("test")),
        );
        let expected_predicate = Predicate::or(
            inner_predicate,
            Reference::new("foo").less_than(Datum::long(0)),
        );
        assert_eq!(predicate, expected_predicate);
    }

    #[test]
    fn test_predicate_conversion_with_complex_binary_expr_unsupported() {
        let predicate = convert("(foo > 1 or bar in ('test', 'test2')) and foo < 0 ").unwrap();
        let expected_predicate = Reference::new("foo").less_than(Datum::long(0));
        assert_eq!(predicate, expected_predicate);
    }

    #[test]
    // test the get result method
    fn test_get_result_multiple() {
        let stack = VecDeque::from(vec![
            Some(Reference::new("foo").greater_than(Datum::long(1))),
            None,
            Some(Reference::new("bar").equal_to(Datum::string("test"))),
        ]);
        let visitor = ExprToPredicateVisitor { stack };
        assert_eq!(
            visitor.get_predicate(),
            Some(Predicate::and(
                Reference::new("foo").greater_than(Datum::long(1)),
                Reference::new("bar").equal_to(Datum::string("test")),
            ))
        );
    }

    #[test]
    fn test_get_result_single() {
        let stack = VecDeque::from(vec![Some(
            Reference::new("foo").greater_than(Datum::long(1)),
        )]);
        let visitor = ExprToPredicateVisitor { stack };
        assert_eq!(
            visitor.get_predicate(),
            Some(Reference::new("foo").greater_than(Datum::long(1)))
        );
    }
}

crates/integrations/datafusion/src/physical_plan/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
pub(crate) mod expr_to_predicate;
1819
pub(crate) mod scan;

0 commit comments

Comments
 (0)