Skip to content

Commit eae9464

Browse files
authored
refactor(python): Expose transform as a submodule for pyiceberg_core (apache#628)
1 parent 8a3de4e commit eae9464

File tree

6 files changed

+102
-96
lines changed

6 files changed

+102
-96
lines changed

bindings/python/Cargo.toml

+2-2
Original file line number | Diff line number | Diff line change
@@ -32,5 +32,5 @@ crate-type = ["cdylib"]
3232

3333
[dependencies]
3434
iceberg = { path = "../../crates/iceberg" }
35-
pyo3 = { version = "0.21.1", features = ["extension-module"] }
36-
arrow = { version = "52.2.0", features = ["pyarrow"] }
35+
pyo3 = { version = "0.21", features = ["extension-module"] }
36+
arrow = { version = "52", features = ["pyarrow"] }

bindings/python/src/error.rs

+24
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,24 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use pyo3::exceptions::PyValueError;
19+
use pyo3::PyErr;
20+
21+
/// Convert an iceberg error to a python error
22+
pub fn to_py_err(err: iceberg::Error) -> PyErr {
23+
PyValueError::new_err(err.to_string())
24+
}

bindings/python/src/lib.rs

+3-13
Original file line number | Diff line number | Diff line change
@@ -15,23 +15,13 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use iceberg::io::FileIOBuilder;
1918
use pyo3::prelude::*;
20-
use pyo3::wrap_pyfunction;
2119

20+
mod error;
2221
mod transform;
2322

24-
#[pyfunction]
25-
fn hello_world() -> PyResult<String> {
26-
let _ = FileIOBuilder::new_fs_io().build().unwrap();
27-
Ok("Hello, world!".to_string())
28-
}
29-
30-
3123
#[pymodule]
32-
fn pyiceberg_core_rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
33-
m.add_function(wrap_pyfunction!(hello_world, m)?)?;
34-
35-
m.add_class::<transform::ArrowArrayTransform>()?;
24+
fn pyiceberg_core_rust(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
25+
transform::register_module(py, m)?;
3626
Ok(())
3727
}

bindings/python/src/transform.rs

+55-49
Original file line number | Diff line number | Diff line change
@@ -15,24 +15,55 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use arrow::array::{make_array, Array, ArrayData};
19+
use arrow::pyarrow::{FromPyArrow, ToPyArrow};
1820
use iceberg::spec::Transform;
1921
use iceberg::transform::create_transform_function;
22+
use pyo3::prelude::*;
2023

21-
use arrow::{
22-
array::{make_array, Array, ArrayData},
23-
};
24-
use arrow::pyarrow::{FromPyArrow, ToPyArrow};
25-
use pyo3::{exceptions::PyValueError, prelude::*};
24+
use crate::error::to_py_err;
25+
26+
#[pyfunction]
27+
pub fn identity(py: Python, array: PyObject) -> PyResult<PyObject> {
28+
apply(py, array, Transform::Identity)
29+
}
30+
31+
#[pyfunction]
32+
pub fn void(py: Python, array: PyObject) -> PyResult<PyObject> {
33+
apply(py, array, Transform::Void)
34+
}
35+
36+
#[pyfunction]
37+
pub fn year(py: Python, array: PyObject) -> PyResult<PyObject> {
38+
apply(py, array, Transform::Year)
39+
}
40+
41+
#[pyfunction]
42+
pub fn month(py: Python, array: PyObject) -> PyResult<PyObject> {
43+
apply(py, array, Transform::Month)
44+
}
2645

27-
fn to_py_err(err: iceberg::Error) -> PyErr {
28-
PyValueError::new_err(err.to_string())
46+
#[pyfunction]
47+
pub fn day(py: Python, array: PyObject) -> PyResult<PyObject> {
48+
apply(py, array, Transform::Day)
2949
}
3050

31-
#[pyclass]
32-
pub struct ArrowArrayTransform {
51+
#[pyfunction]
52+
pub fn hour(py: Python, array: PyObject) -> PyResult<PyObject> {
53+
apply(py, array, Transform::Hour)
3354
}
3455

35-
fn apply(array: PyObject, transform: Transform, py: Python) -> PyResult<PyObject> {
56+
#[pyfunction]
57+
pub fn bucket(py: Python, array: PyObject, num_buckets: u32) -> PyResult<PyObject> {
58+
apply(py, array, Transform::Bucket(num_buckets))
59+
}
60+
61+
#[pyfunction]
62+
pub fn truncate(py: Python, array: PyObject, width: u32) -> PyResult<PyObject> {
63+
apply(py, array, Transform::Truncate(width))
64+
}
65+
66+
fn apply(py: Python, array: PyObject, transform: Transform) -> PyResult<PyObject> {
3667
// import
3768
let array = ArrayData::from_pyarrow_bound(array.bind(py))?;
3869
let array = make_array(array);
@@ -43,45 +74,20 @@ fn apply(array: PyObject, transform: Transform, py: Python) -> PyResult<PyObject
4374
array.to_pyarrow(py)
4475
}
4576

46-
#[pymethods]
47-
impl ArrowArrayTransform {
48-
#[staticmethod]
49-
pub fn identity(array: PyObject, py: Python) -> PyResult<PyObject> {
50-
apply(array, Transform::Identity, py)
51-
}
52-
53-
#[staticmethod]
54-
pub fn void(array: PyObject, py: Python) -> PyResult<PyObject> {
55-
apply(array, Transform::Void, py)
56-
}
57-
58-
#[staticmethod]
59-
pub fn year(array: PyObject, py: Python) -> PyResult<PyObject> {
60-
apply(array, Transform::Year, py)
61-
}
62-
63-
#[staticmethod]
64-
pub fn month(array: PyObject, py: Python) -> PyResult<PyObject> {
65-
apply(array, Transform::Month, py)
66-
}
67-
68-
#[staticmethod]
69-
pub fn day(array: PyObject, py: Python) -> PyResult<PyObject> {
70-
apply(array, Transform::Day, py)
71-
}
72-
73-
#[staticmethod]
74-
pub fn hour(array: PyObject, py: Python) -> PyResult<PyObject> {
75-
apply(array, Transform::Hour, py)
76-
}
77+
pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
78+
let this = PyModule::new_bound(py, "transform")?;
7779

78-
#[staticmethod]
79-
pub fn bucket(array: PyObject, num_buckets: u32, py: Python) -> PyResult<PyObject> {
80-
apply(array, Transform::Bucket(num_buckets), py)
81-
}
80+
this.add_function(wrap_pyfunction!(identity, &this)?)?;
81+
this.add_function(wrap_pyfunction!(void, &this)?)?;
82+
this.add_function(wrap_pyfunction!(year, &this)?)?;
83+
this.add_function(wrap_pyfunction!(month, &this)?)?;
84+
this.add_function(wrap_pyfunction!(day, &this)?)?;
85+
this.add_function(wrap_pyfunction!(hour, &this)?)?;
86+
this.add_function(wrap_pyfunction!(bucket, &this)?)?;
87+
this.add_function(wrap_pyfunction!(truncate, &this)?)?;
8288

83-
#[staticmethod]
84-
pub fn truncate(array: PyObject, width: u32, py: Python) -> PyResult<PyObject> {
85-
apply(array, Transform::Truncate(width), py)
86-
}
89+
m.add_submodule(&this)?;
90+
py.import_bound("sys")?
91+
.getattr("modules")?
92+
.set_item("pyiceberg_core.transform", this)
8793
}

bindings/python/tests/test_basic.py

-22
This file was deleted.

bindings/python/tests/test_transform.py

+18-10
Original file line number | Diff line number | Diff line change
@@ -19,18 +19,18 @@
1919

2020
import pyarrow as pa
2121
import pytest
22-
from pyiceberg_core import ArrowArrayTransform
22+
from pyiceberg_core import transform
2323

2424

2525
def test_identity_transform():
2626
arr = pa.array([1, 2])
27-
result = ArrowArrayTransform.identity(arr)
27+
result = transform.identity(arr)
2828
assert result == arr
2929

3030

3131
def test_bucket_transform():
3232
arr = pa.array([1, 2])
33-
result = ArrowArrayTransform.bucket(arr, 10)
33+
result = transform.bucket(arr, 10)
3434
expected = pa.array([6, 2], type=pa.int32())
3535
assert result == expected
3636

@@ -41,14 +41,14 @@ def test_bucket_transform_fails_for_list_type_input():
4141
ValueError,
4242
match=r"FeatureUnsupported => Unsupported data type for bucket transform",
4343
):
44-
ArrowArrayTransform.bucket(arr, 10)
44+
transform.bucket(arr, 10)
4545

4646

4747
def test_bucket_chunked_array():
4848
chunked = pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])])
4949
result_chunks = []
5050
for arr in chunked.iterchunks():
51-
result_chunks.append(ArrowArrayTransform.bucket(arr, 10))
51+
result_chunks.append(transform.bucket(arr, 10))
5252

5353
expected = pa.chunked_array(
5454
[pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())]
@@ -58,34 +58,42 @@ def test_bucket_chunked_array():
5858

5959
def test_year_transform():
6060
arr = pa.array([date(1970, 1, 1), date(2000, 1, 1)])
61-
result = ArrowArrayTransform.year(arr)
61+
result = transform.year(arr)
6262
expected = pa.array([0, 30], type=pa.int32())
6363
assert result == expected
6464

6565

6666
def test_month_transform():
6767
arr = pa.array([date(1970, 1, 1), date(2000, 4, 1)])
68-
result = ArrowArrayTransform.month(arr)
68+
result = transform.month(arr)
6969
expected = pa.array([0, 30 * 12 + 3], type=pa.int32())
7070
assert result == expected
7171

7272

7373
def test_day_transform():
7474
arr = pa.array([date(1970, 1, 1), date(2000, 4, 1)])
75-
result = ArrowArrayTransform.day(arr)
75+
result = transform.day(arr)
7676
expected = pa.array([0, 11048], type=pa.int32())
7777
assert result == expected
7878

7979

8080
def test_hour_transform():
8181
arr = pa.array([datetime(1970, 1, 1, 19, 1, 23), datetime(2000, 3, 1, 12, 1, 23)])
82-
result = ArrowArrayTransform.hour(arr)
82+
result = transform.hour(arr)
8383
expected = pa.array([19, 264420], type=pa.int32())
8484
assert result == expected
8585

8686

8787
def test_truncate_transform():
8888
arr = pa.array(["this is a long string", "hi my name is sung"])
89-
result = ArrowArrayTransform.truncate(arr, 5)
89+
result = transform.truncate(arr, 5)
9090
expected = pa.array(["this ", "hi my"])
9191
assert result == expected
92+
93+
94+
def test_identity_transform_with_direct_import():
95+
from pyiceberg_core.transform import identity
96+
97+
arr = pa.array([1, 2])
98+
result = identity(arr)
99+
assert result == arr

0 commit comments

Comments
 (0)