Skip to content

Commit

Permalink
Add IsBroadcastScalar function
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 551386329
Change-Id: I110e14188aeace424a4a5bdd62fef609dc28a33d
  • Loading branch information
jbms authored and copybara-github committed Jul 27, 2023
1 parent 386e858 commit f11d7a5
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 1 deletion.
15 changes: 15 additions & 0 deletions tensorstore/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -2053,6 +2053,9 @@ UnbroadcastArrayPreserveRank(
}

/// Checks if `array` has a contiguous layout with the specified order.
///
/// \relates Array
/// \id array
template <typename ElementTag, DimensionIndex Rank, ArrayOriginKind OriginKind,
ContainerKind LayoutCKind>
bool IsContiguousLayout(
Expand All @@ -2062,6 +2065,18 @@ bool IsContiguousLayout(
array.dtype().size());
}

/// Checks if `array` has at most a single distinct element.
///
/// This holds when every dimension with extent greater than 1 has a zero byte
/// stride, so all positions alias the same underlying element.
///
/// \relates Array
/// \membergroup Broadcasting
/// \id array
template <typename ElementTag, DimensionIndex Rank, ArrayOriginKind OriginKind,
          ContainerKind LayoutCKind>
bool IsBroadcastScalar(
    const Array<ElementTag, Rank, OriginKind, LayoutCKind>& array) {
  // Only the layout determines the answer; delegate to the StridedLayout
  // overload.
  const auto& layout = array.layout();
  return tensorstore::IsBroadcastScalar(layout);
}

namespace internal_array {

/// Encodes an array to `sink`.
Expand Down
8 changes: 8 additions & 0 deletions tensorstore/strided_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ bool IsContiguousLayout(DimensionIndex rank, const Index* shape,
}
return true;
}

// Returns `true` if the layout addresses at most one distinct element: every
// dimension whose extent exceeds 1 must have a byte stride of 0.
bool IsBroadcastScalar(DimensionIndex rank, const Index* shape,
                       const Index* byte_strides) {
  for (DimensionIndex dim = rank; dim-- > 0;) {
    // A dimension contributes multiple distinct elements only when it has
    // both an extent above 1 and a non-zero stride.
    const bool multi_extent = shape[dim] > 1;
    if (multi_extent && byte_strides[dim] != 0) return false;
  }
  return true;
}
} // namespace internal_strided_layout

} // namespace tensorstore
20 changes: 19 additions & 1 deletion tensorstore/strided_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -932,10 +932,13 @@ namespace internal_strided_layout {
bool IsContiguousLayout(DimensionIndex rank, const Index* shape,
const Index* byte_strides, ContiguousLayoutOrder order,
Index element_size);
}
} // namespace internal_strided_layout

/// Checks if `layout` is a contiguous layout with the specified order and
/// element size.
///
/// \relates StridedLayout
/// \id strided_layout
template <DimensionIndex Rank, ArrayOriginKind OriginKind, ContainerKind CKind>
bool IsContiguousLayout(const StridedLayout<Rank, OriginKind, CKind>& layout,
ContiguousLayoutOrder order, Index element_size) {
Expand All @@ -944,6 +947,21 @@ bool IsContiguousLayout(const StridedLayout<Rank, OriginKind, CKind>& layout,
element_size);
}

namespace internal_strided_layout {
bool IsBroadcastScalar(DimensionIndex rank, const Index* shape,
const Index* byte_strides);
} // namespace internal_strided_layout

/// Checks if `layout` contains at most a single distinct element.
///
/// \relates StridedLayout
/// \id strided_layout
template <DimensionIndex Rank, ArrayOriginKind OriginKind, ContainerKind CKind>
bool IsBroadcastScalar(const StridedLayout<Rank, OriginKind, CKind>& layout) {
return internal_strided_layout::IsBroadcastScalar(
layout.rank(), layout.shape().data(), layout.byte_strides().data());
}

} // namespace tensorstore

#endif // TENSORSTORE_STRIDED_LAYOUT_H_
9 changes: 9 additions & 0 deletions tensorstore/strided_layout_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1144,4 +1144,13 @@ TEST(StridedLayoutTest, IsContiguousLayout) {
ContiguousLayoutOrder::c, 2));
}

// Exercises IsBroadcastScalar: a layout is broadcast-scalar iff every
// dimension with extent > 1 has a zero byte stride.
TEST(StridedLayoutTest, IsBroadcastScalar) {
  // Extent 1: at most one element regardless of the stride value.
  EXPECT_TRUE(IsBroadcastScalar(StridedLayout<>({1}, {5})));
  // Extent 2 with non-zero stride addresses two distinct elements.
  EXPECT_FALSE(IsBroadcastScalar(StridedLayout<>({2}, {5})));
  // Extent 2 but zero stride: both positions alias the same element.
  EXPECT_TRUE(IsBroadcastScalar(StridedLayout<>({2}, {0})));
  // All extents 1: single element even with arbitrary non-zero strides.
  EXPECT_TRUE(IsBroadcastScalar(StridedLayout<>({1, 1, 1}, {5, 10, 15})));
  // Second dimension has extent 2 and a non-zero stride -> not scalar.
  EXPECT_FALSE(IsBroadcastScalar(StridedLayout<>({1, 2}, {0, 5})));
  // Non-zero stride only on the extent-1 dimension; still broadcast-scalar.
  EXPECT_TRUE(IsBroadcastScalar(StridedLayout<>({1, 2}, {5, 0})));
}

} // namespace

0 comments on commit f11d7a5

Please sign in to comment.