Skip to content

Commit ecbb4c3

Browse files
authored
Expose Transforms to Python Binding (apache#556)
* bucket transform rust binding * format * poetry x maturin * ignore poetry.lock in license check * update bindings_python_ci to use makefile * newline * python-poetry/poetry#9135 * use hatch instead of poetry * refactor * revert licenserc change * adopt review feedback * comments * unused dependency * adopt review comment * newline * I like this approach a lot better * more tests
1 parent 905ebd2 commit ecbb4c3

File tree

7 files changed

+190
-2
lines changed

7 files changed

+190
-2
lines changed

.github/workflows/bindings_python_ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,4 +80,4 @@ jobs:
8080
set -e
8181
pip install hatch==1.12.0
8282
hatch run dev:pip install dist/pyiceberg_core-*.whl --force-reinstall
83-
hatch run dev:test
83+
hatch run dev:test

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,5 @@ dist/*
2424
**/venv
2525
*.so
2626
*.pyc
27+
*.whl
28+
*.tar.gz

bindings/python/Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,5 @@ crate-type = ["cdylib"]
3232

3333
[dependencies]
3434
iceberg = { path = "../../crates/iceberg" }
35-
pyo3 = { version = "0.22", features = ["extension-module"] }
35+
pyo3 = { version = "0.21.1", features = ["extension-module"] }
36+
arrow = { version = "52.2.0", features = ["pyarrow"] }

bindings/python/pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ ignore = ["F403", "F405"]
4343
dependencies = [
4444
"maturin>=1.0,<2.0",
4545
"pytest>=8.3.2",
46+
"pyarrow>=17.0.0",
4647
]
4748

4849
[tool.hatch.envs.dev.scripts]

bindings/python/src/lib.rs

+6
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,21 @@
1717

1818
use iceberg::io::FileIOBuilder;
1919
use pyo3::prelude::*;
20+
use pyo3::wrap_pyfunction;
21+
22+
mod transform;
2023

2124
#[pyfunction]
2225
fn hello_world() -> PyResult<String> {
2326
let _ = FileIOBuilder::new_fs_io().build().unwrap();
2427
Ok("Hello, world!".to_string())
2528
}
2629

30+
2731
#[pymodule]
2832
fn pyiceberg_core_rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
2933
m.add_function(wrap_pyfunction!(hello_world, m)?)?;
34+
35+
m.add_class::<transform::ArrowArrayTransform>()?;
3036
Ok(())
3137
}

bindings/python/src/transform.rs

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use iceberg::spec::Transform;
19+
use iceberg::transform::create_transform_function;
20+
21+
use arrow::{
22+
array::{make_array, Array, ArrayData},
23+
};
24+
use arrow::pyarrow::{FromPyArrow, ToPyArrow};
25+
use pyo3::{exceptions::PyValueError, prelude::*};
26+
27+
fn to_py_err(err: iceberg::Error) -> PyErr {
28+
PyValueError::new_err(err.to_string())
29+
}
30+
31+
#[pyclass]
32+
pub struct ArrowArrayTransform {
33+
}
34+
35+
fn apply(array: PyObject, transform: Transform, py: Python) -> PyResult<PyObject> {
36+
// import
37+
let array = ArrayData::from_pyarrow_bound(array.bind(py))?;
38+
let array = make_array(array);
39+
let transform_function = create_transform_function(&transform).map_err(to_py_err)?;
40+
let array = transform_function.transform(array).map_err(to_py_err)?;
41+
// export
42+
let array = array.into_data();
43+
array.to_pyarrow(py)
44+
}
45+
46+
#[pymethods]
47+
impl ArrowArrayTransform {
48+
#[staticmethod]
49+
pub fn identity(array: PyObject, py: Python) -> PyResult<PyObject> {
50+
apply(array, Transform::Identity, py)
51+
}
52+
53+
#[staticmethod]
54+
pub fn void(array: PyObject, py: Python) -> PyResult<PyObject> {
55+
apply(array, Transform::Void, py)
56+
}
57+
58+
#[staticmethod]
59+
pub fn year(array: PyObject, py: Python) -> PyResult<PyObject> {
60+
apply(array, Transform::Year, py)
61+
}
62+
63+
#[staticmethod]
64+
pub fn month(array: PyObject, py: Python) -> PyResult<PyObject> {
65+
apply(array, Transform::Month, py)
66+
}
67+
68+
#[staticmethod]
69+
pub fn day(array: PyObject, py: Python) -> PyResult<PyObject> {
70+
apply(array, Transform::Day, py)
71+
}
72+
73+
#[staticmethod]
74+
pub fn hour(array: PyObject, py: Python) -> PyResult<PyObject> {
75+
apply(array, Transform::Hour, py)
76+
}
77+
78+
#[staticmethod]
79+
pub fn bucket(array: PyObject, num_buckets: u32, py: Python) -> PyResult<PyObject> {
80+
apply(array, Transform::Bucket(num_buckets), py)
81+
}
82+
83+
#[staticmethod]
84+
pub fn truncate(array: PyObject, width: u32, py: Python) -> PyResult<PyObject> {
85+
apply(array, Transform::Truncate(width), py)
86+
}
87+
}
+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from datetime import date, datetime
19+
20+
import pyarrow as pa
21+
import pytest
22+
from pyiceberg_core import ArrowArrayTransform
23+
24+
25+
def test_identity_transform():
26+
arr = pa.array([1, 2])
27+
result = ArrowArrayTransform.identity(arr)
28+
assert result == arr
29+
30+
31+
def test_bucket_transform():
32+
arr = pa.array([1, 2])
33+
result = ArrowArrayTransform.bucket(arr, 10)
34+
expected = pa.array([6, 2], type=pa.int32())
35+
assert result == expected
36+
37+
38+
def test_bucket_transform_fails_for_list_type_input():
39+
arr = pa.array([[1, 2], [3, 4]])
40+
with pytest.raises(
41+
ValueError,
42+
match=r"FeatureUnsupported => Unsupported data type for bucket transform",
43+
):
44+
ArrowArrayTransform.bucket(arr, 10)
45+
46+
47+
def test_bucket_chunked_array():
48+
chunked = pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])])
49+
result_chunks = []
50+
for arr in chunked.iterchunks():
51+
result_chunks.append(ArrowArrayTransform.bucket(arr, 10))
52+
53+
expected = pa.chunked_array(
54+
[pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())]
55+
)
56+
assert pa.chunked_array(result_chunks).equals(expected)
57+
58+
59+
def test_year_transform():
60+
arr = pa.array([date(1970, 1, 1), date(2000, 1, 1)])
61+
result = ArrowArrayTransform.year(arr)
62+
expected = pa.array([0, 30], type=pa.int32())
63+
assert result == expected
64+
65+
66+
def test_month_transform():
67+
arr = pa.array([date(1970, 1, 1), date(2000, 4, 1)])
68+
result = ArrowArrayTransform.month(arr)
69+
expected = pa.array([0, 30 * 12 + 3], type=pa.int32())
70+
assert result == expected
71+
72+
73+
def test_day_transform():
74+
arr = pa.array([date(1970, 1, 1), date(2000, 4, 1)])
75+
result = ArrowArrayTransform.day(arr)
76+
expected = pa.array([0, 11048], type=pa.int32())
77+
assert result == expected
78+
79+
80+
def test_hour_transform():
81+
arr = pa.array([datetime(1970, 1, 1, 19, 1, 23), datetime(2000, 3, 1, 12, 1, 23)])
82+
result = ArrowArrayTransform.hour(arr)
83+
expected = pa.array([19, 264420], type=pa.int32())
84+
assert result == expected
85+
86+
87+
def test_truncate_transform():
88+
arr = pa.array(["this is a long string", "hi my name is sung"])
89+
result = ArrowArrayTransform.truncate(arr, 5)
90+
expected = pa.array(["this ", "hi my"])
91+
assert result == expected

0 commit comments

Comments
 (0)