@@ -2284,4 +2284,111 @@ mod tests {
22842284
22852285 Ok ( ( ) )
22862286 }
2287+
2288+ #[ tokio:: test]
2289+ async fn test_schema_preserved_with_replace_where ( ) -> TestResult {
2290+ // Test that schema is preserved when using overwrite with predicate (replaceWhere)
2291+ use arrow_array:: { BooleanArray , Int32Array , Int64Array , RecordBatch , StringArray } ;
2292+ use arrow_schema:: { DataType , Field , Schema as ArrowSchema } ;
2293+ use std:: sync:: Arc ;
2294+
2295+ // Create initial table with mixed nullability
2296+ let initial_schema = Arc :: new ( ArrowSchema :: new ( vec ! [
2297+ Field :: new( "id" , DataType :: Int64 , false ) , // non-nullable
2298+ Field :: new( "name" , DataType :: Utf8 , true ) , // nullable
2299+ Field :: new( "active" , DataType :: Boolean , false ) , // non-nullable
2300+ Field :: new( "count" , DataType :: Int32 , false ) , // non-nullable
2301+ ] ) ) ;
2302+
2303+ let initial_batch = RecordBatch :: try_new (
2304+ initial_schema. clone ( ) ,
2305+ vec ! [
2306+ Arc :: new( Int64Array :: from( vec![ 1 , 2 , 3 , 4 , 5 ] ) ) ,
2307+ Arc :: new( StringArray :: from( vec![
2308+ Some ( "Alice" ) ,
2309+ Some ( "Bob" ) ,
2310+ None ,
2311+ Some ( "David" ) ,
2312+ Some ( "Eve" ) ,
2313+ ] ) ) ,
2314+ Arc :: new( BooleanArray :: from( vec![ true , false , true , false , true ] ) ) ,
2315+ Arc :: new( Int32Array :: from( vec![ 10 , 20 , 30 , 40 , 50 ] ) ) ,
2316+ ] ,
2317+ ) ?;
2318+
2319+ let table = DeltaOps :: new_in_memory ( )
2320+ . write ( vec ! [ initial_batch] )
2321+ . with_save_mode ( SaveMode :: Overwrite )
2322+ . await ?;
2323+
2324+ // Capture initial schema
2325+ let initial_fields: Vec < _ > = table
2326+ . snapshot ( )
2327+ . unwrap ( )
2328+ . schema ( )
2329+ . fields ( )
2330+ . cloned ( )
2331+ . collect ( ) ;
2332+
2333+ // Create new data with all nullable fields (typical from Pandas)
2334+ let new_schema = Arc :: new ( ArrowSchema :: new ( vec ! [
2335+ Field :: new( "id" , DataType :: Int64 , true ) , // nullable in new data
2336+ Field :: new( "name" , DataType :: Utf8 , true ) , // nullable
2337+ Field :: new( "active" , DataType :: Boolean , true ) , // nullable
2338+ Field :: new( "count" , DataType :: Int32 , true ) , // nullable
2339+ ] ) ) ;
2340+
2341+ let replacement_batch = RecordBatch :: try_new (
2342+ new_schema. clone ( ) ,
2343+ vec ! [
2344+ Arc :: new( Int64Array :: from( vec![ Some ( 2 ) , Some ( 4 ) ] ) ) , // Replace ids 2 and 4
2345+ Arc :: new( StringArray :: from( vec![ Some ( "Bob2" ) , Some ( "David2" ) ] ) ) ,
2346+ Arc :: new( BooleanArray :: from( vec![ Some ( true ) , Some ( true ) ] ) ) ,
2347+ Arc :: new( Int32Array :: from( vec![ Some ( 200 ) , Some ( 400 ) ] ) ) ,
2348+ ] ,
2349+ ) ?;
2350+
2351+ // Use replaceWhere to selectively overwrite
2352+ let table = DeltaOps ( table)
2353+ . write ( vec ! [ replacement_batch] )
2354+ . with_save_mode ( SaveMode :: Overwrite )
2355+ . with_replace_where ( "id = 2 OR id = 4" )
2356+ . await ?;
2357+
2358+ // Verify schema is preserved
2359+ let final_fields: Vec < _ > = table. snapshot ( ) . unwrap ( ) . schema ( ) . fields ( ) . collect ( ) ;
2360+
2361+ for ( i, field) in final_fields. iter ( ) . enumerate ( ) {
2362+ assert_eq ! (
2363+ field. is_nullable( ) ,
2364+ initial_fields[ i] . is_nullable( ) ,
2365+ "Field '{}' nullability should be preserved with replaceWhere" ,
2366+ field. name( )
2367+ ) ;
2368+ }
2369+
2370+ // Now test that constraints are still enforced with replaceWhere
2371+ let invalid_batch = RecordBatch :: try_new (
2372+ new_schema,
2373+ vec ! [
2374+ Arc :: new( Int64Array :: from( vec![ None , Some ( 3 ) ] ) ) , // NULL in non-nullable id!
2375+ Arc :: new( StringArray :: from( vec![ Some ( "Invalid" ) , Some ( "Valid" ) ] ) ) ,
2376+ Arc :: new( BooleanArray :: from( vec![ Some ( false ) , Some ( false ) ] ) ) ,
2377+ Arc :: new( Int32Array :: from( vec![ Some ( 999 ) , Some ( 333 ) ] ) ) ,
2378+ ] ,
2379+ ) ?;
2380+
2381+ let result = DeltaOps ( table)
2382+ . write ( vec ! [ invalid_batch] )
2383+ . with_save_mode ( SaveMode :: Overwrite )
2384+ . with_replace_where ( "id = 1 OR id = 3" )
2385+ . await ;
2386+
2387+ assert ! (
2388+ result. is_err( ) ,
2389+ "replaceWhere should still enforce non-nullable constraints"
2390+ ) ;
2391+
2392+ Ok ( ( ) )
2393+ }
22872394}
0 commit comments