@@ -60,6 +60,7 @@ def _reproject_dispatcher(
6060 shape_out ,
6161 wcs_out ,
6262 block_size = None ,
63+ non_reprojected_dims = None ,
6364 array_out = None ,
6465 return_footprint = True ,
6566 output_footprint = None ,
@@ -92,6 +93,11 @@ def _reproject_dispatcher(
9293 the block size automatically determined. If ``block_size`` is not
9394 specified or set to `None`, the reprojection will not be carried out in
9495 blocks.
96+ non_reprojected_dims : tuple
97+ Dimensions that should not be reprojected but instead for which a
98+ 1-to-1 mapping between input and output pixel space should be assumed.
99+ By default, this is any leading extra dimensions if the input WCS has
100+ fewer dimensions than the input data.
95101 array_out : `~numpy.ndarray`, optional
96102 An array in which to store the reprojected data. This can be any numpy
97103 array including a memory map, which may be helpful when dealing with
@@ -198,9 +204,32 @@ def _reproject_dispatcher(
198204 # shape_out will be the full size of the output array as this is updated
199205 # in parse_output_projection, even if shape_out was originally passed in as
200206 # the shape of a single image.
201- broadcasting = wcs_in .low_level_wcs .pixel_n_dim < len (shape_out )
207+ if non_reprojected_dims is None :
208+ non_reprojected_dims = list (range (len (shape_out ) - wcs_in .low_level_wcs .pixel_n_dim ))
209+ else :
210+ non_reprojected_dims = list (non_reprojected_dims )
211+
212+ broadcasting = len (non_reprojected_dims ) > 0
213+
214+ reprojected_dims = [x for x in range (len (shape_out )) if x not in non_reprojected_dims ]
202215
203216 logger .info (f"Broadcasting is { '' if broadcasting else 'not ' } being used" )
217+ logger .info (f"Dimensions being reprojected: { reprojected_dims } " )
218+ logger .info (f"Dimensions not being reprojected: { non_reprojected_dims } " )
219+
220+ if len (block_size ) < len (shape_out ):
221+ block_size = [- 1 ] * (len (shape_out ) - len (block_size )) + list (block_size )
222+ elif len (block_size ) > len (shape_out ):
223+ raise ValueError (
224+ f"block_size { len (block_size )} cannot have more elements "
225+ f"than the dimensionality of the output ({ len (shape_out )} )"
226+ )
227+
228+ block_size = np .array (block_size )
229+ shape_out = np .array (shape_out )
230+
231+ # TODO: replace block size of -1 by actual value for logic below to work
232+ # TODO: re-implement block_size auto
204233
205234 # Check block size and determine whether block size indicates we should
206235 # parallelize over broadcasted dimension. The logic is as follows: if
@@ -212,33 +241,23 @@ def _reproject_dispatcher(
212241 # don't make any assumptions for now and assume a single chunk in the
213242 # missing dimensions.
214243 broadcasted_parallelization = False
215- if broadcasting and block_size is not None and block_size != "auto" :
216- if len (block_size ) == len (shape_out ):
217- if (
218- block_size [- wcs_in .low_level_wcs .pixel_n_dim :]
219- == shape_out [- wcs_in .low_level_wcs .pixel_n_dim :]
220- ):
221- broadcasted_parallelization = True
222- block_size = (
223- block_size [: - wcs_in .low_level_wcs .pixel_n_dim ]
224- + (- 1 ,) * wcs_in .low_level_wcs .pixel_n_dim
225- )
226- else :
227- for i in range (len (shape_out ) - wcs_in .low_level_wcs .pixel_n_dim ):
228- if block_size [i ] != - 1 and block_size [i ] != shape_out [i ]:
229- raise ValueError (
230- "block shape should either match output data shape along broadcasted dimension or non-broadcasted dimensions"
231- )
232- elif len (block_size ) < len (shape_out ):
233- block_size = [- 1 ] * (len (shape_out ) - len (block_size )) + list (block_size )
234- else :
235- raise ValueError (
236- f"block_size { len (block_size )} cannot have more elements "
237- f"than the dimensionality of the output ({ len (shape_out )} )"
244+ if broadcasting and block_size is not None :
245+ if np .all (block_size [reprojected_dims ] == shape_out [reprojected_dims ]):
246+ broadcasted_parallelization = True
247+ block_size = np .array (
248+ tuple (block_size [non_reprojected_dims ].tolist ())
249+ + (- 1 ,) * len (reprojected_dims )
238250 )
251+ elif np .all (block_size [non_reprojected_dims ] != shape_out [non_reprojected_dims ]):
252+ raise ValueError (
253+ "block shape should either match output data shape along broadcasted dimension or non-broadcasted dimensions"
254+ )
239255
240256 # TODO: check for shape_out not matching shape_in along broadcasted dimensions
241257
258+ block_size = tuple (block_size .tolist ())
259+ shape_out = tuple (shape_out .tolist ())
260+
242261 logger .info (
243262 f"{ 'P' if broadcasted_parallelization else 'Not p' } arallelizing along "
244263 f"broadcasted dimension ({ block_size = } , { shape_out = } )"
@@ -270,17 +289,38 @@ def reproject_single_block(a, array_or_path, block_info=None):
270289 wcs_in_cp = wcs_in .deepcopy () if isinstance (wcs_in , WCS ) else wcs_in
271290 wcs_out_cp = wcs_out .deepcopy () if isinstance (wcs_out , WCS ) else wcs_out
272291
273- slices = [
274- slice (* x ) for x in block_info [None ]["array-location" ][- wcs_out_cp .pixel_n_dim :]
275- ]
292+ print (block_info [None ]["array-location" ])
276293
277- if isinstance (wcs_out , BaseHighLevelWCS ):
294+ slices = []
295+ for i in reprojected_dims :
296+ slices .append (slice (* block_info [None ]["array-location" ][i ]))
297+
298+ print (slices )
299+
300+ if isinstance (wcs_out_cp , BaseHighLevelWCS ):
278301 low_level_wcs = SlicedLowLevelWCS (wcs_out_cp .low_level_wcs , slices = slices )
279302 else :
280303 low_level_wcs = SlicedLowLevelWCS (wcs_out_cp , slices = slices )
281304
305+ print (low_level_wcs .pixel_n_dim , low_level_wcs .world_n_dim )
306+
282307 wcs_out_sub = HighLevelWCSWrapper (low_level_wcs )
283308
309+ slices = []
310+ for i in range (wcs_in_cp .pixel_n_dim ):
311+ if i in non_reprojected_dims :
312+ # slices.append(slice(*block_info[None]["array-location"][i]))
313+ slices .append (block_info [None ]["array-location" ][i ][0 ])
314+ else :
315+ slices .append (slice (None ))
316+
317+ if isinstance (wcs_in_cp , BaseHighLevelWCS ):
318+ low_level_wcs_in = SlicedLowLevelWCS (wcs_in_cp .low_level_wcs , slices = slices )
319+ else :
320+ low_level_wcs_in = SlicedLowLevelWCS (wcs_in_cp , slices = slices )
321+
322+ wcs_in_sub = HighLevelWCSWrapper (low_level_wcs_in )
323+
284324 if isinstance (array_or_path , tuple ):
285325 array_in = np .memmap (array_or_path [0 ], ** array_or_path [1 ], mode = "r" )
286326 elif isinstance (array_or_path , str ):
@@ -295,7 +335,7 @@ def reproject_single_block(a, array_or_path, block_info=None):
295335
296336 array , footprint = reproject_func (
297337 array_in ,
298- wcs_in_cp ,
338+ wcs_in_sub ,
299339 wcs_out_sub ,
300340 shape_out = shape_out ,
301341 array_out = np .zeros (shape_out ),
@@ -308,10 +348,11 @@ def reproject_single_block(a, array_or_path, block_info=None):
308348
309349 array_out_dask = da .empty (shape_out , chunks = block_size )
310350 if isinstance (array_in , da .core .Array ):
351+ # FIXME: Should take into account -1s here
311352 if array_in .chunksize != block_size :
312353 logger .info (
313354 f"Rechunking input dask array as chunks ({ array_in .chunksize } ) "
314- "do not match block size ({block_size})"
355+ f "do not match block size ({ block_size } )"
315356 )
316357 array_in = array_in .rechunk (block_size )
317358 else :
0 commit comments