1717from  contextlib  import  ExitStack 
1818from  io  import  BytesIO 
1919from  pathlib  import  Path 
20- from  typing  import  TYPE_CHECKING , Any , Final , cast 
20+ from  typing  import  TYPE_CHECKING , Any , Callable ,  Final , cast 
2121
2222import  numpy  as  np 
2323import  pandas  as  pd 
@@ -138,96 +138,110 @@ def open_example_mfdataset(names, *args, **kwargs) -> Dataset:
138138    )
139139
140140
141- def  create_masked_and_scaled_data () ->  Dataset :
142-     x  =  np .array ([np .nan , np .nan , 10 , 10.1 , 10.2 ], dtype = np . float32 )
141+ def  create_masked_and_scaled_data (dtype :  type [ np . number ]  =   np . float32 ) ->  Dataset :
142+     x  =  np .array ([np .nan , np .nan , 10 , 10.1 , 10.2 ], dtype = dtype )
143143    encoding  =  {
144144        "_FillValue" : - 1 ,
145-         "add_offset" : 10 ,
146-         "scale_factor" : np . float32 (0.1 ),
145+         "add_offset" : dtype ( 10 ) ,
146+         "scale_factor" : dtype (0.1 ),
147147        "dtype" : "i2" ,
148148    }
149149    return  Dataset ({"x" : ("t" , x , {}, encoding )})
150150
151151
152- def  create_encoded_masked_and_scaled_data () ->  Dataset :
153-     attributes  =  {"_FillValue" : - 1 , "add_offset" : 10 , "scale_factor" : np .float32 (0.1 )}
152+ def  create_encoded_masked_and_scaled_data (
153+     dtype : type [np .number ] =  np .float32 ,
154+ ) ->  Dataset :
155+     attributes  =  {"_FillValue" : - 1 , "add_offset" : dtype (10 ), "scale_factor" : dtype (0.1 )}
154156    return  Dataset (
155157        {"x" : ("t" , np .array ([- 1 , - 1 , 0 , 1 , 2 ], dtype = np .int16 ), attributes )}
156158    )
157159
158160
159- def  create_unsigned_masked_scaled_data () ->  Dataset :
161+ def  create_unsigned_masked_scaled_data (
162+     dtype : type [np .number ] =  np .float32 ,
163+ ) ->  Dataset :
160164    encoding  =  {
161165        "_FillValue" : 255 ,
162166        "_Unsigned" : "true" ,
163167        "dtype" : "i1" ,
164-         "add_offset" : 10 ,
165-         "scale_factor" : np . float32 (0.1 ),
168+         "add_offset" : dtype ( 10 ) ,
169+         "scale_factor" : dtype (0.1 ),
166170    }
167-     x  =  np .array ([10.0 , 10.1 , 22.7 , 22.8 , np .nan ], dtype = np . float32 )
171+     x  =  np .array ([10.0 , 10.1 , 22.7 , 22.8 , np .nan ], dtype = dtype )
168172    return  Dataset ({"x" : ("t" , x , {}, encoding )})
169173
170174
171- def  create_encoded_unsigned_masked_scaled_data () ->  Dataset :
175+ def  create_encoded_unsigned_masked_scaled_data (
176+     dtype : type [np .number ] =  np .float32 ,
177+ ) ->  Dataset :
172178    # These are values as written to the file: the _FillValue will 
173179    # be represented in the signed form. 
174180    attributes  =  {
175181        "_FillValue" : - 1 ,
176182        "_Unsigned" : "true" ,
177-         "add_offset" : 10 ,
178-         "scale_factor" : np . float32 (0.1 ),
183+         "add_offset" : dtype ( 10 ) ,
184+         "scale_factor" : dtype (0.1 ),
179185    }
180186    # Create unsigned data corresponding to [0, 1, 127, 128, 255] unsigned 
181187    sb  =  np .asarray ([0 , 1 , 127 , - 128 , - 1 ], dtype = "i1" )
182188    return  Dataset ({"x" : ("t" , sb , attributes )})
183189
184190
185- def  create_bad_unsigned_masked_scaled_data () ->  Dataset :
191+ def  create_bad_unsigned_masked_scaled_data (
192+     dtype : type [np .number ] =  np .float32 ,
193+ ) ->  Dataset :
186194    encoding  =  {
187195        "_FillValue" : 255 ,
188196        "_Unsigned" : True ,
189197        "dtype" : "i1" ,
190-         "add_offset" : 10 ,
191-         "scale_factor" : np . float32 (0.1 ),
198+         "add_offset" : dtype ( 0 ) ,
199+         "scale_factor" : dtype (0.1 ),
192200    }
193-     x  =  np .array ([10.0 , 10.1 , 22.7 , 22.8 , np .nan ], dtype = np . float32 )
201+     x  =  np .array ([10.0 , 10.1 , 22.7 , 22.8 , np .nan ], dtype = dtype )
194202    return  Dataset ({"x" : ("t" , x , {}, encoding )})
195203
196204
197- def  create_bad_encoded_unsigned_masked_scaled_data () ->  Dataset :
205+ def  create_bad_encoded_unsigned_masked_scaled_data (
206+     dtype : type [np .number ] =  np .float32 ,
207+ ) ->  Dataset :
198208    # These are values as written to the file: the _FillValue will 
199209    # be represented in the signed form. 
200210    attributes  =  {
201211        "_FillValue" : - 1 ,
202212        "_Unsigned" : True ,
203-         "add_offset" : 10 ,
204-         "scale_factor" : np . float32 (0.1 ),
213+         "add_offset" : dtype ( 10 ) ,
214+         "scale_factor" : dtype (0.1 ),
205215    }
206216    # Create signed data corresponding to [0, 1, 127, 128, 255] unsigned 
207217    sb  =  np .asarray ([0 , 1 , 127 , - 128 , - 1 ], dtype = "i1" )
208218    return  Dataset ({"x" : ("t" , sb , attributes )})
209219
210220
211- def  create_signed_masked_scaled_data () ->  Dataset :
221+ def  create_signed_masked_scaled_data (
222+     dtype : type [np .number ] =  np .float32 ,
223+ ) ->  Dataset :
212224    encoding  =  {
213225        "_FillValue" : - 127 ,
214226        "_Unsigned" : "false" ,
215227        "dtype" : "i1" ,
216-         "add_offset" : 10 ,
217-         "scale_factor" : np . float32 (0.1 ),
228+         "add_offset" : dtype ( 10 ) ,
229+         "scale_factor" : dtype (0.1 ),
218230    }
219-     x  =  np .array ([- 1.0 , 10.1 , 22.7 , np .nan ], dtype = np . float32 )
231+     x  =  np .array ([- 1.0 , 10.1 , 22.7 , np .nan ], dtype = dtype )
220232    return  Dataset ({"x" : ("t" , x , {}, encoding )})
221233
222234
223- def  create_encoded_signed_masked_scaled_data () ->  Dataset :
235+ def  create_encoded_signed_masked_scaled_data (
236+     dtype : type [np .number ] =  np .float32 ,
237+ ) ->  Dataset :
224238    # These are values as written to the file: the _FillValue will 
225239    # be represented in the signed form. 
226240    attributes  =  {
227241        "_FillValue" : - 127 ,
228242        "_Unsigned" : "false" ,
229-         "add_offset" : 10 ,
230-         "scale_factor" : np . float32 (0.1 ),
243+         "add_offset" : dtype ( 10 ) ,
244+         "scale_factor" : dtype (0.1 ),
231245    }
232246    # Create signed data corresponding to [0, 1, 127, 128, 255] unsigned 
233247    sb  =  np .asarray ([- 110 , 1 , 127 , - 127 ], dtype = "i1" )
@@ -859,6 +873,8 @@ def test_roundtrip_string_with_fill_value_nchar(self) -> None:
859873            with  self .roundtrip (original ) as  actual :
860874                assert_identical (expected , actual )
861875
876+     # Todo: (kmuehlbauer) make this work np.float64 
877+     @pytest .mark .parametrize ("dtype" , [np .float32 ]) 
862878    @pytest .mark .parametrize ( 
863879        "decoded_fn, encoded_fn" , 
864880        [ 
@@ -878,9 +894,20 @@ def test_roundtrip_string_with_fill_value_nchar(self) -> None:
878894            (create_masked_and_scaled_data , create_encoded_masked_and_scaled_data ), 
879895        ], 
880896    ) 
881-     def  test_roundtrip_mask_and_scale (self , decoded_fn , encoded_fn ) ->  None :
882-         decoded  =  decoded_fn ()
883-         encoded  =  encoded_fn ()
897+     def  test_roundtrip_mask_and_scale (
898+         self ,
899+         decoded_fn : Callable [[type [np .number ]], Dataset ],
900+         encoded_fn : Callable [[type [np .number ]], Dataset ],
901+         dtype : type [np .number ],
902+     ) ->  None :
903+         if  dtype  ==  np .float32  and  isinstance (
904+             self , (TestZarrDirectoryStore , TestZarrDictStore )
905+         ):
906+             pytest .skip (
907+                 "zarr attributes (eg. `scale_factor` are unconditionally promoted to `float64`" 
908+             )
909+         decoded  =  decoded_fn (dtype )
910+         encoded  =  encoded_fn (dtype )
884911
885912        with  self .roundtrip (decoded ) as  actual :
886913            for  k  in  decoded .variables :
@@ -901,7 +928,7 @@ def test_roundtrip_mask_and_scale(self, decoded_fn, encoded_fn) -> None:
901928
902929        # make sure roundtrip encoding didn't change the 
903930        # original dataset. 
904-         assert_allclose (encoded , encoded_fn (), decode_bytes = False )
931+         assert_allclose (encoded , encoded_fn (dtype ), decode_bytes = False )
905932
906933        with  self .roundtrip (encoded ) as  actual :
907934            for  k  in  decoded .variables :
0 commit comments