Skip to content

Commit

Permalink
Add IndexTransformGridPartition::GetCellTransform method
Browse files Browse the repository at this point in the history
The existing `PartitionIndexTransformOverGrid` function iterates over
both the `grid_cell_indices` and corresponding `cell_transform`.

The `GetGridCellRanges` function, used for storage statistics
computation, provides only ranges of `grid_cell_indices`, but not the
corresponding cell transforms, to avoid wastefully computing it even
when not needed.

The newly added `GetCellTransform` method will be used in a subsequent
commit that adds zarr v3 support to obtain the cell transform
afterwards in cases where it is needed.

PiperOrigin-RevId: 564262593
Change-Id: I3e774d0be9653ed55e43ecd4e866699bf1014c5f
  • Loading branch information
jbms authored and copybara-github committed Sep 11, 2023
1 parent 85e8a50 commit 40d2628
Show file tree
Hide file tree
Showing 7 changed files with 312 additions and 142 deletions.
12 changes: 7 additions & 5 deletions tensorstore/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -535,18 +535,18 @@ tensorstore_cc_library(
"//tensorstore:rank",
"//tensorstore:strided_layout",
"//tensorstore/index_space:index_transform",
"//tensorstore/index_space:output_index_method",
"//tensorstore/util:byte_strided_pointer",
"//tensorstore/util:division",
"//tensorstore/util:dimension_set",
"//tensorstore/util:iterate",
"//tensorstore/util:iterate_over_index_range",
"//tensorstore/util:result",
"//tensorstore/util:span",
"//tensorstore/util:status",
"//tensorstore/util:str_cat",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/status",
],
Expand All @@ -564,11 +564,13 @@ tensorstore_cc_test(
"//tensorstore:index",
"//tensorstore:index_interval",
"//tensorstore/index_space:index_transform",
"//tensorstore/util:dimension_set",
"//tensorstore/util:result",
"//tensorstore/util:span",
"//tensorstore/util:status",
"//tensorstore/util:status_testutil",
"//tensorstore/util:str_cat",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest_main",
],
)
Expand All @@ -579,8 +581,8 @@ tensorstore_cc_test(
srcs = ["grid_partition_test.cc"],
deps = [
":grid_partition",
":grid_partition_impl",
":irregular_grid",
":memory",
":regular_grid",
"//tensorstore:array",
"//tensorstore:box",
Expand All @@ -590,7 +592,7 @@ tensorstore_cc_test(
"//tensorstore/util:result",
"//tensorstore/util:span",
"//tensorstore/util:status",
"//tensorstore/util:status_testutil",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest_main",
],
)
Expand Down
128 changes: 4 additions & 124 deletions tensorstore/internal/grid_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,106 +61,6 @@ struct ConnectedSetIterateParameters {
func;
};

/// Allocates the `cell_transform` and initializes the portions that are the
/// same for all grid cells.
///
/// The per-partition portions (the `input_shape` of the synthetic dimensions
/// and the index array pointers) are left unset here.
///
/// \param info The preprocessed partitioning data.
/// \param full_transform The transform over the "full" input space to be
///     partitioned.  Its `input_rank` is used as the full input rank, and its
///     per-dimension input info is copied for the input dimensions that are
///     not part of any connected set.
/// \returns A non-null pointer to a partially-initialized transform from the
/// synthetic "cell" index space, of rank `cell_input_rank`, to the "full"
/// index space, of rank `full_input_rank`.
internal_index_space::TransformRep::Ptr<> InitializeCellTransform(
const IndexTransformGridPartition& info, TransformRep* full_transform) {
const DimensionIndex full_input_rank = full_transform->input_rank;
// Total number of original input dimensions covered by index array
// connected sets.
DimensionIndex num_index_array_dims = 0;
for (const IndexArraySet& index_array_set : info.index_array_sets()) {
num_index_array_dims += index_array_set.input_dimensions.count();
}
// Each index array connected set is represented by a single synthetic cell
// input dimension, in place of all of the original input dimensions that it
// contains.
const DimensionIndex cell_input_rank =
full_input_rank - num_index_array_dims + info.index_array_sets().size();

internal_index_space::TransformRep::Ptr<> cell_transform =
TransformRep::Allocate(cell_input_rank, full_input_rank);
cell_transform->input_rank = cell_input_rank;
cell_transform->output_rank = full_input_rank;
cell_transform->implicit_lower_bounds = false;
cell_transform->implicit_upper_bounds = false;

const span<Index> input_origin =
cell_transform->input_origin().first(cell_input_rank);
const span<OutputIndexMap> output_maps =
cell_transform->output_index_maps().first(full_input_rank);

// Initialize the `cell_transform` output index maps for all input
// dimensions of the original input space that do affect grid cell indices
// (i.e. contained in a connected set).
{
// Next synthetic input dimension index, corresponding to a connected set.
// The synthetic input dimensions for index array connected sets come before
// those for strided connected sets, to match the order of the recursive
// iteration.
DimensionIndex cell_input_dim = 0;
for (const IndexArraySet& index_array_set : info.index_array_sets()) {
// The `input_origin` is always 0 for the synthetic input dimension
// corresponding to an index array connected set (in fact the origin is
// arbitrary and any origin could be used). While iterating, the
// `input_shape[cell_input_dim]` will be set appropriately for each
// partition.
input_origin[cell_input_dim] = 0;
for (const DimensionIndex full_input_dim :
index_array_set.input_dimensions.index_view()) {
auto& map = output_maps[full_input_dim];
// Use an `offset` of `0` and stride of `1`, since the precomputed index
// arrays correspond directly to the input domains.
map.offset() = 0;
map.stride() = 1;
auto& index_array_data = map.SetArrayIndexing(cell_input_rank);
// Zero all byte strides first, so that the index array varies only
// along the synthetic dimension assigned below.
std::fill_n(index_array_data.byte_strides, cell_input_rank, 0);
// Initialize the index array `byte_strides`, which are the same for
// all partitions.
index_array_data.byte_strides[cell_input_dim] =
index_array_set.partitioned_input_indices.byte_strides()[0];
}
++cell_input_dim;
}

// The output index maps corresponding to the original input dimensions in
// strided connected sets do not depend on the partition.
for (const auto& strided_set : info.strided_sets()) {
auto& map = output_maps[strided_set.input_dimension];
map.SetSingleInputDimension(cell_input_dim);
// Use an `offset` of `0`. The actual starting index into the original
// input dimension will be set as `input_origin[cell_input_dim]`.
map.offset() = 0;
// Use a `stride` of `1`, to not skip any part of the original input
// domain.
map.stride() = 1;
++cell_input_dim;
}
}

// Set up the `cell_transform` output index maps corresponding to all input
// dimensions of the original input space that do not affect grid cell
// indices (i.e. not contained in a connected set). These output index maps
// will not be modified.
//
// The synthetic dimensions for these start immediately after those assigned
// to the index array and strided connected sets above.
for (DimensionIndex cell_input_dim = info.index_array_sets().size() +
info.strided_sets().size(),
full_input_dim = 0;
full_input_dim < full_input_rank; ++full_input_dim) {
auto& map = output_maps[full_input_dim];
// Only maps whose method is still `constant` are handled here; any other
// map is skipped.  NOTE(review): presumably `constant` is the method of a
// freshly allocated map, so this selects exactly the maps not assigned by
// the connected-set loops above -- confirm against `TransformRep::Allocate`.
if (map.method() != OutputIndexMethod::constant) continue;
map.SetSingleInputDimension(cell_input_dim);
map.offset() = 0;
map.stride() = 1;
// Copy the input dimension info from the corresponding dimension of the
// full transform.
cell_transform->input_dimension(cell_input_dim) =
full_transform->input_dimension(full_input_dim);
++cell_input_dim;
}

// Invariants checked in InvokeCallback
return cell_transform;
}

/// Sets the fixed grid cell indices for all grid dimensions that do not
/// depend on any input dimensions (i.e. not contained in a connected set).
void InitializeConstantGridCellIndices(
Expand Down Expand Up @@ -249,9 +149,8 @@ class ConnectedSetIterateHelper {
explicit ConnectedSetIterateHelper(ConnectedSetIterateParameters params)
: params_(std::move(params)),
grid_cell_indices_(params_.grid_output_dimensions.size()),
cell_transform_(InitializeCellTransform(
params_.info,
internal_index_space::TransformAccess::rep(params_.transform))) {
cell_transform_(internal_grid_partition::InitializeCellTransform(
params_.info, params_.transform)) {
InitializeConstantGridCellIndices(
params_.transform, params_.grid_output_dimensions,
params_.output_to_grid_cell, grid_cell_indices_);
Expand Down Expand Up @@ -302,27 +201,8 @@ class ConnectedSetIterateHelper {
.grid_cell_indices[grid_cell_indices_offset + grid_i++];
}

// Update the output index maps for the original input dimensions in this
// connected set to reference the precomputed index array of input indices
// corresponding to this partition.
const SharedArray<const Index, 2> partition_input_indices =
index_array_set.partition_input_indices(partition_i);
cell_transform_->input_shape()[set_i] =
partition_input_indices.shape()[0];
ByteStridedPointer<const Index> partition_input_indices_ptr =
partition_input_indices.byte_strided_pointer();
const Index vector_dimension_byte_stride =
partition_input_indices.byte_strides()[1];
const span<OutputIndexMap> output_maps =
cell_transform_->output_index_maps();
for (DimensionIndex full_input_dim :
index_array_set.input_dimensions.index_view()) {
internal_index_space::IndexArrayData& index_array_data =
output_maps[full_input_dim].index_array_data();
index_array_data.element_pointer = std::shared_ptr<const Index>(
partition_input_indices.pointer(), partition_input_indices_ptr);
partition_input_indices_ptr += vector_dimension_byte_stride;
}
UpdateCellTransformForIndexArraySetPartition(
index_array_set, set_i, partition_i, cell_transform_.get());
TENSORSTORE_RETURN_IF_ERROR(IterateOverIndexArraySets(set_i + 1));
}
return absl::OkStatus();
Expand Down
Loading

0 comments on commit 40d2628

Please sign in to comment.