146 changes: 101 additions & 45 deletions reproject/common.py
@@ -60,6 +60,7 @@ def _reproject_dispatcher(
shape_out,
wcs_out,
block_size=None,
non_reprojected_dims=None,
array_out=None,
return_footprint=True,
output_footprint=None,
@@ -93,6 +94,11 @@ def _reproject_dispatcher(
the block size automatically determined. If ``block_size`` is not
specified or set to `None`, the reprojection will not be carried out in
blocks.
non_reprojected_dims : tuple of int, optional
Dimensions that should not be reprojected and for which a 1-to-1
mapping between input and output pixel space should be assumed instead.
By default, these are any leading extra dimensions present when the
input WCS has fewer dimensions than the input data.
array_out : `~numpy.ndarray`, optional
An array in which to store the reprojected data. This can be any numpy
array including a memory map, which may be helpful when dealing with
@@ -143,6 +149,19 @@ def _reproject_dispatcher(
if reproject_func_kwargs is None:
reproject_func_kwargs = {}

# For now, we are quite restrictive in what non_reprojected_dims can be,
# although the design leaves room to support more use cases in future. At
# present it has to be a tuple whose elements increase sequentially from
# zero, e.g. (0,), (0, 1), or (0, 1, 2)

if non_reprojected_dims is None:
n_dim_reproject = min(wcs_in.low_level_wcs.pixel_n_dim, wcs_out.low_level_wcs.pixel_n_dim)
else:
if non_reprojected_dims == tuple(range(len(non_reprojected_dims))):
n_dim_reproject = len(shape_out) - len(non_reprojected_dims)
else:
raise ValueError("non_reprojected_dims should be a tuple with values increasing sequentially from zero")

# We set up a global temporary directory since this will be used e.g. to
# store memory mapped Numpy arrays and zarr arrays.

@@ -206,9 +225,41 @@ def _reproject_dispatcher(
# shape_out will be the full size of the output array as this is updated
# in parse_output_projection, even if shape_out was originally passed in as
# the shape of a single image.
broadcasting = wcs_in.low_level_wcs.pixel_n_dim < len(shape_out)

logger.info(f"Broadcasting is {'' if broadcasting else 'not '}being used")
broadcasting = n_dim_reproject < len(shape_out)

logger.info(f"Broadcasting is {'' if broadcasting else 'not '}being used, reprojecting last {n_dim_reproject} axes")

# Output shape should match input shape for any ignored dimensions
# TODO: check for shape_out not matching shape_in along broadcasted dimensions

shape_in = array_in.shape

if shape_out[:-n_dim_reproject] != shape_in[:-n_dim_reproject]:
raise ValueError("Input shape should match output shape for non-reprojected dimensions")

if len(block_size) > len(shape_out):
raise ValueError(
f"block_size {block_size} cannot have more elements "
f"than the dimensionality of the output ({len(shape_out)})"
)


if len(block_size) != n_dim_reproject and len(block_size) != len(shape_out):
raise ValueError(
f"block_size {block_size} should have either {n_dim_reproject} or {len(shape_out)} elements"
)

if len(block_size) == n_dim_reproject:
block_size = (-1,) * (len(shape_out) - n_dim_reproject) + tuple(block_size)

block_size = [(block_size[i] if block_size[i] != -1 else shape_out[i]) for i in range(len(block_size))]

block_size = tuple(block_size)
shape_out = tuple(shape_out)
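To make the normalization above concrete, here is a standalone restatement as a sketch (``normalize_block_size`` is a hypothetical helper, not part of reproject):

def normalize_block_size(block_size, shape_out, n_dim_reproject):
    # Pad a reprojected-axes-only block size with -1 for the leading
    # (non-reprojected) axes, then replace each -1 by the full axis length.
    if len(block_size) == n_dim_reproject:
        block_size = (-1,) * (len(shape_out) - n_dim_reproject) + tuple(block_size)
    return tuple(s if b == -1 else b for b, s in zip(block_size, shape_out))

# e.g. with shape_out=(4, 512, 512) and n_dim_reproject=2:
#   normalize_block_size((256, 256), (4, 512, 512), 2)      -> (4, 256, 256)
#   normalize_block_size((-1, 256, 256), (4, 512, 512), 2)  -> (4, 256, 256)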

# TODO: replace block size of -1 by actual value for logic below to work
# TODO: re-implement block_size auto

# Check block size and determine whether block size indicates we should
# parallelize over broadcasted dimension. The logic is as follows: if
@@ -220,32 +271,15 @@ def _reproject_dispatcher(
# don't make any assumptions for now and assume a single chunk in the
# missing dimensions.
broadcasted_parallelization = False
if broadcasting and block_size is not None and block_size != "auto":
if len(block_size) == len(shape_out):
if (
block_size[-wcs_in.low_level_wcs.pixel_n_dim :]
== shape_out[-wcs_in.low_level_wcs.pixel_n_dim :]
):
broadcasted_parallelization = True
block_size = (
block_size[: -wcs_in.low_level_wcs.pixel_n_dim]
+ (-1,) * wcs_in.low_level_wcs.pixel_n_dim
)
else:
for i in range(len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim):
if block_size[i] != -1 and block_size[i] != shape_out[i]:
raise ValueError(
"block shape should either match output data shape along broadcasted dimension or non-broadcasted dimensions"
)
elif len(block_size) < len(shape_out):
block_size = [-1] * (len(shape_out) - len(block_size)) + list(block_size)
else:
if broadcasting and block_size is not None:
if block_size[-n_dim_reproject:] == shape_out[-n_dim_reproject:]:
# TODO: maybe error if block_size was given in full and is wrong
broadcasted_parallelization = True
block_size = (1,) * (len(shape_out) - n_dim_reproject) + block_size[-n_dim_reproject:]
elif block_size[:-n_dim_reproject] != shape_out[:-n_dim_reproject]:
raise ValueError(
"block shape should either match output data shape along reprojected dimensions or non-reprojected dimensions"
)
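A toy restatement of the decision above, showing which block sizes select which strategy (illustrative only; ``plan_parallelization`` is not a reproject function):

def plan_parallelization(block_size, shape_out, n_dim_reproject):
    image_shape = shape_out[-n_dim_reproject:]
    if block_size[-n_dim_reproject:] == image_shape:
        # Whole reprojected plane per chunk: parallelize over the broadcasted axes
        return "broadcasted", (1,) * (len(shape_out) - n_dim_reproject) + image_shape
    if block_size[:-n_dim_reproject] != shape_out[:-n_dim_reproject]:
        raise ValueError("block shape should match either the reprojected or the non-reprojected dimensions")
    # Chunk within the reprojected plane, keep the broadcasted axes whole
    return "per-plane blocks", block_size

print(plan_parallelization((4, 512, 512), (4, 512, 512), 2))  # ('broadcasted', (1, 512, 512))
print(plan_parallelization((4, 256, 256), (4, 512, 512), 2))  # ('per-plane blocks', (4, 256, 256))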

logger.info(
f"{'P' if broadcasted_parallelization else 'Not p'}arallelizing along "
@@ -255,8 +289,6 @@ def _reproject_dispatcher(
if output_footprint is None and return_footprint:
output_footprint = np.zeros(shape_out, dtype=float)

shape_in = array_in.shape

def reproject_single_block(a, array_or_path, block_info=None):

if (
@@ -270,6 +302,8 @@ def reproject_single_block(a, array_or_path, block_info=None):
if isinstance(array_or_path, str) and array_or_path == "from-dict":
array_or_path = dask_arrays["array"]

shape_out = block_info[None]["chunk-shape"][1:]
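# Note (explanatory, values assumed for illustration): block_info is supplied
# by dask.array.map_blocks, and block_info[None] describes the output chunk
# currently being computed:
#   block_info[None]["chunk-shape"]    -> shape of this chunk; [1:] drops the
#                                         extra leading axis added by the dispatcher
#   block_info[None]["array-location"] -> list of (start, stop) pairs, one per
#                                         axis, locating the chunk in the full array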

# The WCS class from astropy is not thread-safe, see e.g.
# https://github.com/astropy/astropy/issues/16244
# https://github.com/astropy/astropy/issues/16245
@@ -281,16 +315,38 @@ def reproject_single_block(a, array_or_path, block_info=None):
wcs_in_cp = wcs_in.deepcopy() if isinstance(wcs_in, WCS) else wcs_in
wcs_out_cp = wcs_out.deepcopy() if isinstance(wcs_out, WCS) else wcs_out

slices = [
slice(*x) for x in block_info[None]["array-location"][-wcs_out_cp.pixel_n_dim :]
]
slices_in = []
slices_out = []
for idx in range(len(shape_out)):
interval = block_info[None]["array-location"][idx + 1]
if broadcasted_parallelization and idx < len(shape_out) - n_dim_reproject:
if interval[1] - interval[0] != 1:
raise RuntimeError(f"Expected a chunk of width 1 along dimension {idx} (got {interval[1] - interval[0]})")
slices_in.append(interval[0])
slices_out.append(interval[0])
else:
slices_in.append(slice(None))
slices_out.append(slice(*block_info[None]["array-location"][idx + 1]))

slices_in = slices_in[-wcs_in.pixel_n_dim :]
slices_out = slices_out[-wcs_out.pixel_n_dim :]

if broadcasted_parallelization:
if isinstance(wcs_in_cp, BaseHighLevelWCS):
low_level_wcs_in = SlicedLowLevelWCS(wcs_in_cp.low_level_wcs, slices=slices_in)
else:
low_level_wcs_in = SlicedLowLevelWCS(wcs_in_cp, slices=slices_in)

wcs_in_sub = HighLevelWCSWrapper(low_level_wcs_in)
else:
wcs_in_sub = wcs_in_cp

if isinstance(wcs_out_cp, BaseHighLevelWCS):
low_level_wcs_out = SlicedLowLevelWCS(wcs_out_cp.low_level_wcs, slices=slices_out)
else:
low_level_wcs_out = SlicedLowLevelWCS(wcs_out_cp, slices=slices_out)

wcs_out_sub = HighLevelWCSWrapper(low_level_wcs_out)

if isinstance(array_or_path, tuple):
array_in = np.memmap(array_or_path[0], **array_or_path[1], mode="r")
Expand All @@ -302,11 +358,9 @@ def reproject_single_block(a, array_or_path, block_info=None):
if array_or_path is None:
raise RuntimeError("array_or_path is not set")

shape_out = block_info[None]["chunk-shape"][1:]

array, footprint = reproject_func(
array_in,
wcs_in_cp,
wcs_in_sub,
wcs_out_sub,
shape_out=shape_out,
array_out=np.zeros(shape_out),
@@ -319,12 +373,14 @@ def _reproject_dispatcher(

array_out_dask = da.empty(shape_out, chunks=block_size)
if isinstance(array_in, da.core.Array):
if array_in.chunksize != block_size:
logger.info(
f"Rechunking input dask array as chunks ({array_in.chunksize}) "
"do not match block size ({block_size})"
)
array_in = array_in.rechunk(block_size)
pass
# FIXME: Should take into account -1s here
# if array_in.chunksize != block_size:
# logger.info(
# f"Rechunking input dask array as chunks ({array_in.chunksize}) "
# f"do not match block size ({block_size})"
# )
# array_in = array_in.rechunk(block_size)
else:

class ArrayWrapper:
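The per-block WCS handling in ``reproject_single_block`` above relies on astropy's ``SlicedLowLevelWCS`` and ``HighLevelWCSWrapper``. A self-contained illustration with a hypothetical 3D WCS (not part of the diff):

from astropy.wcs import WCS
from astropy.wcs.wcsapi import HighLevelWCSWrapper, SlicedLowLevelWCS

wcs3d = WCS(naxis=3)
wcs3d.wcs.ctype = ["RA---TAN", "DEC--TAN", "FREQ"]

# Slices are given in array (numpy) order: fix the spectral axis at index 2,
# keep both sky axes whole, which yields a 2D celestial WCS for that plane
sliced = SlicedLowLevelWCS(wcs3d, slices=[2, slice(None), slice(None)])
print(sliced.pixel_n_dim)  # 2

wcs2d = HighLevelWCSWrapper(sliced)  # high-level interface, as used above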
2 changes: 1 addition & 1 deletion reproject/interpolation/core.py
@@ -10,7 +10,7 @@

def _validate_wcs(wcs_in, wcs_out, shape_in, shape_out):
if wcs_in.low_level_wcs.pixel_n_dim != wcs_out.low_level_wcs.pixel_n_dim:
raise ValueError("Number of dimensions in input and output WCS should match")
raise ValueError(f"Number of dimensions in input and output WCS should match (got {wcs_in.low_level_wcs.pixel_n_dim} and {wcs_out.low_level_wcs.pixel_n_dim})")
elif len(shape_out) < wcs_out.low_level_wcs.pixel_n_dim:
raise ValueError("Too few dimensions in shape_out")
elif len(shape_in) < wcs_in.low_level_wcs.pixel_n_dim:
2 changes: 2 additions & 0 deletions reproject/interpolation/high_level.py
@@ -25,6 +25,7 @@ def reproject_interp(
output_footprint=None,
return_footprint=True,
block_size=None,
non_reprojected_dims=None,
parallel=False,
return_type=None,
dask_method=None,
@@ -152,6 +153,7 @@ def reproject_interp(
array_out=output_array,
parallel=parallel,
block_size=block_size,
non_reprojected_dims=non_reprojected_dims,
return_footprint=return_footprint,
output_footprint=output_footprint,
reproject_func_kwargs=dict(
15 changes: 10 additions & 5 deletions reproject/mosaicking/coadd.py
@@ -243,10 +243,15 @@ def reproject_and_coadd(
# convex in the output projection), and transforming every edge pixel,
# which provides a lot of redundant information.

edges = sample_array_edges(
array_in.shape[-wcs_in.low_level_wcs.pixel_n_dim :], n_samples=11
)[::-1]
edges_out = pixel_to_pixel(wcs_in, wcs_out, *edges)[::-1]
# TODO: ignore non-reprojected dims here and slice WCS

try:
edges = sample_array_edges(
array_in.shape[-wcs_in.low_level_wcs.pixel_n_dim :], n_samples=11
)[::-1]
edges_out = pixel_to_pixel(wcs_in, wcs_out, *edges)[::-1]
except Exception:
edges_out = np.array([np.nan])
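# Illustration (assumed behaviour, not added by this diff): the transformed
# edge samples are typically reduced to a bounding box in output pixel
# coordinates, e.g.
#   imin = max(0, int(np.floor(edges_out[-1].min())))
#   imax = min(shape_out[-1], int(np.ceil(edges_out[-1].max())) + 1)
# while a fully-NaN ``edges_out`` (the except branch above) marks an input
# with no usable overlap so that it can be skipped via ``skip_data`` below.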

# Determine the cutout parameters

@@ -257,7 +262,7 @@ def reproject_and_coadd(
ndim_out = len(shape_out)

# Determine how many extra broadcasted dimensions are present
n_broadcasted = len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim
n_broadcasted = len(shape_out) - wcs_out.low_level_wcs.pixel_n_dim
Member Author commented: Hack currently!


skip_data = False
if np.any(np.isnan(edges_out)):