19
19
20
20
use std:: collections:: HashMap ;
21
21
use std:: sync:: Arc ;
22
+ use std:: vec;
22
23
24
+ use datafusion:: arrow:: array:: { Array , StringArray } ;
23
25
use datafusion:: arrow:: datatypes:: DataType ;
24
26
use datafusion:: execution:: context:: SessionContext ;
25
27
use iceberg:: io:: FileIOBuilder ;
26
- use iceberg:: spec:: { NestedField , PrimitiveType , Schema , Type } ;
28
+ use iceberg:: spec:: { NestedField , PrimitiveType , Schema , StructType , Type } ;
27
29
use iceberg:: { Catalog , NamespaceIdent , Result , TableCreation } ;
28
30
use iceberg_catalog_memory:: MemoryCatalog ;
29
31
use iceberg_datafusion:: IcebergCatalogProvider ;
@@ -39,6 +41,13 @@ fn get_iceberg_catalog() -> MemoryCatalog {
39
41
MemoryCatalog :: new ( file_io, Some ( temp_path ( ) ) )
40
42
}
41
43
44
+ fn get_struct_type ( ) -> StructType {
45
+ StructType :: new ( vec ! [
46
+ NestedField :: required( 4 , "s_foo1" , Type :: Primitive ( PrimitiveType :: Int ) ) . into( ) ,
47
+ NestedField :: required( 5 , "s_foo2" , Type :: Primitive ( PrimitiveType :: String ) ) . into( ) ,
48
+ ] )
49
+ }
50
+
42
51
async fn set_test_namespace ( catalog : & MemoryCatalog , namespace : & NamespaceIdent ) -> Result < ( ) > {
43
52
let properties = HashMap :: new ( ) ;
44
53
@@ -47,14 +56,21 @@ async fn set_test_namespace(catalog: &MemoryCatalog, namespace: &NamespaceIdent)
47
56
Ok ( ( ) )
48
57
}
49
58
50
- fn set_table_creation ( location : impl ToString , name : impl ToString ) -> Result < TableCreation > {
51
- let schema = Schema :: builder ( )
52
- . with_schema_id ( 0 )
53
- . with_fields ( vec ! [
54
- NestedField :: required( 1 , "foo" , Type :: Primitive ( PrimitiveType :: Int ) ) . into( ) ,
55
- NestedField :: required( 2 , "bar" , Type :: Primitive ( PrimitiveType :: String ) ) . into( ) ,
56
- ] )
57
- . build ( ) ?;
59
+ fn get_table_creation (
60
+ location : impl ToString ,
61
+ name : impl ToString ,
62
+ schema : Option < Schema > ,
63
+ ) -> Result < TableCreation > {
64
+ let schema = match schema {
65
+ None => Schema :: builder ( )
66
+ . with_schema_id ( 0 )
67
+ . with_fields ( vec ! [
68
+ NestedField :: required( 1 , "foo1" , Type :: Primitive ( PrimitiveType :: Int ) ) . into( ) ,
69
+ NestedField :: required( 2 , "foo2" , Type :: Primitive ( PrimitiveType :: String ) ) . into( ) ,
70
+ ] )
71
+ . build ( ) ?,
72
+ Some ( schema) => schema,
73
+ } ;
58
74
59
75
let creation = TableCreation :: builder ( )
60
76
. location ( location. to_string ( ) )
@@ -72,7 +88,7 @@ async fn test_provider_get_table_schema() -> Result<()> {
72
88
let namespace = NamespaceIdent :: new ( "test_provider_get_table_schema" . to_string ( ) ) ;
73
89
set_test_namespace ( & iceberg_catalog, & namespace) . await ?;
74
90
75
- let creation = set_table_creation ( temp_path ( ) , "my_table" ) ?;
91
+ let creation = get_table_creation ( temp_path ( ) , "my_table" , None ) ?;
76
92
iceberg_catalog. create_table ( & namespace, creation) . await ?;
77
93
78
94
let client = Arc :: new ( iceberg_catalog) ;
@@ -87,7 +103,7 @@ async fn test_provider_get_table_schema() -> Result<()> {
87
103
let table = schema. table ( "my_table" ) . await . unwrap ( ) . unwrap ( ) ;
88
104
let table_schema = table. schema ( ) ;
89
105
90
- let expected = [ ( "foo " , & DataType :: Int32 ) , ( "bar " , & DataType :: Utf8 ) ] ;
106
+ let expected = [ ( "foo1 " , & DataType :: Int32 ) , ( "foo2 " , & DataType :: Utf8 ) ] ;
91
107
92
108
for ( field, exp) in table_schema. fields ( ) . iter ( ) . zip ( expected. iter ( ) ) {
93
109
assert_eq ! ( field. name( ) , exp. 0 ) ;
@@ -104,7 +120,7 @@ async fn test_provider_list_table_names() -> Result<()> {
104
120
let namespace = NamespaceIdent :: new ( "test_provider_list_table_names" . to_string ( ) ) ;
105
121
set_test_namespace ( & iceberg_catalog, & namespace) . await ?;
106
122
107
- let creation = set_table_creation ( temp_path ( ) , "my_table" ) ?;
123
+ let creation = get_table_creation ( temp_path ( ) , "my_table" , None ) ?;
108
124
iceberg_catalog. create_table ( & namespace, creation) . await ?;
109
125
110
126
let client = Arc :: new ( iceberg_catalog) ;
@@ -130,7 +146,6 @@ async fn test_provider_list_schema_names() -> Result<()> {
130
146
let namespace = NamespaceIdent :: new ( "test_provider_list_schema_names" . to_string ( ) ) ;
131
147
set_test_namespace ( & iceberg_catalog, & namespace) . await ?;
132
148
133
- set_table_creation ( "test_provider_list_schema_names" , "my_table" ) ?;
134
149
let client = Arc :: new ( iceberg_catalog) ;
135
150
let catalog = Arc :: new ( IcebergCatalogProvider :: try_new ( client) . await ?) ;
136
151
@@ -147,3 +162,71 @@ async fn test_provider_list_schema_names() -> Result<()> {
147
162
. all( |item| result. contains( & item. to_string( ) ) ) ) ;
148
163
Ok ( ( ) )
149
164
}
165
+
166
+ #[ tokio:: test]
167
+ async fn test_table_projection ( ) -> Result < ( ) > {
168
+ let iceberg_catalog = get_iceberg_catalog ( ) ;
169
+ let namespace = NamespaceIdent :: new ( "ns" . to_string ( ) ) ;
170
+ set_test_namespace ( & iceberg_catalog, & namespace) . await ?;
171
+
172
+ let schema = Schema :: builder ( )
173
+ . with_schema_id ( 0 )
174
+ . with_fields ( vec ! [
175
+ NestedField :: required( 1 , "foo1" , Type :: Primitive ( PrimitiveType :: Int ) ) . into( ) ,
176
+ NestedField :: required( 2 , "foo2" , Type :: Primitive ( PrimitiveType :: String ) ) . into( ) ,
177
+ NestedField :: optional( 3 , "foo3" , Type :: Struct ( get_struct_type( ) ) ) . into( ) ,
178
+ ] )
179
+ . build ( ) ?;
180
+ let creation = get_table_creation ( temp_path ( ) , "t1" , Some ( schema) ) ?;
181
+ iceberg_catalog. create_table ( & namespace, creation) . await ?;
182
+
183
+ let client = Arc :: new ( iceberg_catalog) ;
184
+ let catalog = Arc :: new ( IcebergCatalogProvider :: try_new ( client) . await ?) ;
185
+
186
+ let ctx = SessionContext :: new ( ) ;
187
+ ctx. register_catalog ( "catalog" , catalog) ;
188
+ let table_df = ctx. table ( "catalog.ns.t1" ) . await . unwrap ( ) ;
189
+
190
+ let records = table_df
191
+ . clone ( )
192
+ . explain ( false , false )
193
+ . unwrap ( )
194
+ . collect ( )
195
+ . await
196
+ . unwrap ( ) ;
197
+ assert_eq ! ( 1 , records. len( ) ) ;
198
+ let record = & records[ 0 ] ;
199
+ // the first column is plan_type, the second column plan string.
200
+ let s = record
201
+ . column ( 1 )
202
+ . as_any ( )
203
+ . downcast_ref :: < StringArray > ( )
204
+ . unwrap ( ) ;
205
+ assert_eq ! ( 2 , s. len( ) ) ;
206
+ // the first row is logical_plan, the second row is physical_plan
207
+ assert_eq ! (
208
+ "IcebergTableScan projection:[foo1,foo2,foo3]" ,
209
+ s. value( 1 ) . trim( )
210
+ ) ;
211
+
212
+ // datafusion doesn't support query foo3.s_foo1, use foo3 instead
213
+ let records = table_df
214
+ . select_columns ( & [ "foo1" , "foo3" ] )
215
+ . unwrap ( )
216
+ . explain ( false , false )
217
+ . unwrap ( )
218
+ . collect ( )
219
+ . await
220
+ . unwrap ( ) ;
221
+ assert_eq ! ( 1 , records. len( ) ) ;
222
+ let record = & records[ 0 ] ;
223
+ let s = record
224
+ . column ( 1 )
225
+ . as_any ( )
226
+ . downcast_ref :: < StringArray > ( )
227
+ . unwrap ( ) ;
228
+ assert_eq ! ( 2 , s. len( ) ) ;
229
+ assert_eq ! ( "IcebergTableScan projection:[foo1,foo3]" , s. value( 1 ) . trim( ) ) ;
230
+
231
+ Ok ( ( ) )
232
+ }
0 commit comments