Skip to content

Commit

Permalink
Add "identical" array comparison operation
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 547298221
Change-Id: I699fa3da5afebfedd6102aa3c990331578072afe
  • Loading branch information
jbms authored and copybara-github committed Jul 11, 2023
1 parent 3db32dd commit 4296def
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 16 deletions.
11 changes: 11 additions & 0 deletions tensorstore/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,17 @@ bool AreArraysSameValueEqual(const OffsetArrayView<const void>& a,
.success;
}

/// Returns `true` if `a` and `b` have the same data type, the same domain, and
/// contents that compare equal under the data type's `compare_identical`
/// operation.
bool AreArraysIdenticallyEqual(const OffsetArrayView<const void>& a,
                               const OffsetArrayView<const void>& b) {
  // Arrays that disagree on data type or domain are never identical.  The data
  // type is checked first, so the domain comparison is only reached when the
  // data types match.
  if (a.dtype() != b.dtype() || a.domain() != b.domain()) return false;

  // Both arrays share a data type at this point, so `a`'s comparison function
  // applies to the elements of `b` as well.
  const auto iteration_result = internal::IterateOverArrays(
      {&a.dtype()->compare_identical, nullptr},
      /*status=*/nullptr,
      /*constraints=*/skip_repeated_elements, a, b);
  return iteration_result.success;
}

namespace internal_array {

bool EncodeArray(serialization::EncodeSink& sink,
Expand Down
12 changes: 10 additions & 2 deletions tensorstore/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -1898,15 +1898,23 @@ std::string ToString(
/// ``operator==`` in that negative zero is not equal to positive zero, and
/// NaN is equal to NaN.
///
/// Note that this differs from bit equality, because there are multiple bit
/// representations of NaN, and this functions treats all of them as equal.
/// Note that this differs from bit equality (`AreArraysIdenticallyEqual`),
/// because there are multiple bit representations of NaN, and this function
/// treats all of them as equal.
///
/// Checks that the data types, domains, and content are equal.
///
/// \relates Array
bool AreArraysSameValueEqual(const OffsetArrayView<const void>& a,
const OffsetArrayView<const void>& b);

/// Compares two arrays for "identical" equality.
///
/// This differs from normal equality in that floating point values are compared
/// bitwise.  Consequently, NaN values with different bit representations are
/// not considered equal, and negative zero is not equal to positive zero.
///
/// Checks that the data types, domains, and content are equal.
///
/// \relates Array
bool AreArraysIdenticallyEqual(const OffsetArrayView<const void>& a,
const OffsetArrayView<const void>& b);

/// Validates that `source_shape` can be broadcast to `target_shape`.
///
/// A `source_shape` can be broadcast to a `target_shape` if, starting from the
Expand Down
44 changes: 44 additions & 0 deletions tensorstore/array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ using ::tensorstore::MakeArray;
using ::tensorstore::MakeArrayView;
using ::tensorstore::MakeCopy;
using ::tensorstore::MakeOffsetArray;
using ::tensorstore::MakeScalarArray;
using ::tensorstore::MakeScalarArrayView;
using ::tensorstore::MatchesStatus;
using ::tensorstore::offset_origin;
Expand Down Expand Up @@ -995,6 +996,15 @@ TEST(ArrayTest, Compare) {
EXPECT_TRUE(ArrayView<void>(MakeScalarArrayView(1.0)) !=
MakeScalarArrayView(1));
EXPECT_TRUE(MakeArrayView({1}) != MakeArrayView({1, 2}));

EXPECT_FALSE(
MakeScalarArray<float>(std::numeric_limits<float>::quiet_NaN()) ==
MakeScalarArray<float>(std::numeric_limits<float>::quiet_NaN()));

EXPECT_FALSE(MakeScalarArray<std::complex<float>>(
std::numeric_limits<float>::quiet_NaN()) ==
MakeScalarArray<std::complex<float>>(
std::numeric_limits<float>::quiet_NaN()));
}

TEST(ArrayTest, SameValue) {
Expand All @@ -1009,6 +1019,40 @@ TEST(ArrayTest, SameValue) {
EXPECT_FALSE(AreArraysSameValueEqual(
MakeArrayView<float>({{NAN, 2, +0.0}, {4, 5, 6}}),
MakeArrayView<float>({{NAN, 2, -0.0}, {4, 5, 6}})));

EXPECT_TRUE(AreArraysSameValueEqual(
MakeScalarArray<float>(std::numeric_limits<float>::quiet_NaN()),
MakeScalarArray<float>(std::numeric_limits<float>::signaling_NaN())));

EXPECT_TRUE(AreArraysSameValueEqual(
MakeScalarArray<std::complex<float>>(
std::numeric_limits<float>::quiet_NaN()),
MakeScalarArray<std::complex<float>>(
std::numeric_limits<float>::signaling_NaN())));
}

// Verifies that `AreArraysIdenticallyEqual` compares floating point elements
// by their exact representation.
TEST(ArrayTest, Identical) {
  const float quiet_nan = std::numeric_limits<float>::quiet_NaN();
  const float signaling_nan = std::numeric_limits<float>::signaling_NaN();

  // Equal ordinary values are identical.
  EXPECT_TRUE(
      AreArraysIdenticallyEqual(MakeArrayView<float>({{1, 2, 3}, {4, 5, 6}}),
                                MakeArrayView<float>({{1, 2, 3}, {4, 5, 6}})));

  // The same NaN constant on both sides is identical to itself.
  EXPECT_TRUE(AreArraysIdenticallyEqual(
      MakeArrayView<float>({{NAN, 2, 3}, {4, 5, 6}}),
      MakeArrayView<float>({{NAN, 2, 3}, {4, 5, 6}})));

  // Positive and negative zero are distinguished.
  EXPECT_FALSE(AreArraysIdenticallyEqual(
      MakeArrayView<float>({{NAN, 2, +0.0}, {4, 5, 6}}),
      MakeArrayView<float>({{NAN, 2, -0.0}, {4, 5, 6}})));

  // Quiet and signaling NaN are distinguished.
  EXPECT_FALSE(
      AreArraysIdenticallyEqual(MakeScalarArray<float>(quiet_nan),
                                MakeScalarArray<float>(signaling_nan)));

  // The same holds when the NaN is the real part of a complex value.
  EXPECT_FALSE(AreArraysIdenticallyEqual(
      MakeScalarArray<std::complex<float>>(quiet_nan),
      MakeScalarArray<std::complex<float>>(signaling_nan)));
}

TEST(CopyArrayTest, ZeroOrigin) {
Expand Down
66 changes: 58 additions & 8 deletions tensorstore/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -506,19 +506,31 @@ struct DataTypeOperations {
using AppendToStringFunction = void (*)(std::string* result, const void* ptr);
AppendToStringFunction append_to_string;

/// Compares two strided arrays for equality.
/// Compares two arrays for equality.
///
/// This uses regular equality, which for floating point types considers
/// positive and negative zero equal, and NaN unequal to itself.
using CompareEqualFunction = ElementwiseFunction<2, absl::Status*>;
CompareEqualFunction compare_equal;

/// Compares two strided arrays for equality, taking into account negative
/// Compares two arrays for equality, taking into account negative
/// zero and NaN for floating point types (negative zero is not equal to
/// positive zero, and NaN is equal to NaN).
///
/// Note that this not the same as bit equality, because there are multiple
/// possible bit representations of NaN, and this function considers all of
/// them to be equal.
/// For integer types this is equivalent to `compare_equal`.
///
/// Note that this is not the same as `compare_identical`, because there are
/// multiple possible bit representations of NaN, and this function considers
/// all of them to be equal.
CompareEqualFunction compare_same_value;

/// Checks if two arrays are identical.
///
/// For integer and floating point types, this performs a bitwise comparison.
///
/// For integer types this is equivalent to `compare_equal`.
CompareEqualFunction compare_identical;

struct CanonicalConversionOperations {
// Function for converting to/from canonical data type.
using ConvertFunction = ElementwiseFunction<2, absl::Status*>;
Expand Down Expand Up @@ -621,6 +633,11 @@ class DataType {
return operations_->compare_same_value;
}

/// Returns the `compare_identical` elementwise function from this data type's
/// operations table, which checks whether two arrays are identical.
constexpr const Ops::CompareEqualFunction& compare_identical_function()
const {
return operations_->compare_identical;
}

constexpr const Ops::CopyAssignFunction& copy_assign_function() const {
return operations_->copy_assign;
}
Expand Down Expand Up @@ -708,8 +725,9 @@ bool CompareEqual(const T& a, const T& b) {
/// For floating point types, this differs from normal `operator==` in that
/// negative zero is not equal to positive zero, and NaN is equal to NaN.
///
/// Note that this differs from bit equality, because there are multiple bit
/// representations of NaN, and this functions treats all of them as equal.
/// Note that this differs from bit equality (`CompareIdentical`), because there
/// are multiple bit representations of NaN, and this function treats all of
/// them as equal.
template <typename T>
bool CompareSameValue(const T& a, const T& b) {
if constexpr (internal::IsEqualityComparable<T>) {
Expand All @@ -718,6 +736,17 @@ bool CompareSameValue(const T& a, const T& b) {
return false;
}

/// Tests whether `a` and `b` are identical (indistinguishable).
///
/// This primary template defers to `operator==` when `T` is equality
/// comparable, and reports `false` for types that are not.
///
/// For floating point types, this does a bitwise comparison.
template <typename T>
bool CompareIdentical(const T& a, const T& b) {
  if constexpr (internal::IsEqualityComparable<T>) {
    return a == b;
  } else {
    return false;
  }
}

#define TENSORSTORE_INTERNAL_DO_DEFINE_COMPARE_SAME_VALUE_FLOAT(T, ...) \
template <> \
inline bool CompareSameValue<T>(const T& a, const T& b) { \
Expand All @@ -726,6 +755,11 @@ bool CompareSameValue(const T& a, const T& b) {
using Int = internal::uint_t<sizeof(T) * 8>; \
return internal::bit_cast<Int>(a) == internal::bit_cast<Int>(b); \
} \
template <> \
inline bool CompareIdentical<T>(const T& a, const T& b) { \
using Int = internal::uint_t<sizeof(T) * 8>; \
return internal::bit_cast<Int>(a) == internal::bit_cast<Int>(b); \
} \
/**/
TENSORSTORE_FOR_EACH_FLOAT_DATA_TYPE(
TENSORSTORE_INTERNAL_DO_DEFINE_COMPARE_SAME_VALUE_FLOAT)
Expand All @@ -737,6 +771,11 @@ TENSORSTORE_FOR_EACH_FLOAT_DATA_TYPE(
return CompareSameValue(a.real(), b.real()) && \
CompareSameValue(a.imag(), b.imag()); \
} \
template <> \
inline bool CompareIdentical<T>(const T& a, const T& b) { \
return CompareIdentical(a.real(), b.real()) && \
CompareIdentical(a.imag(), b.imag()); \
} \
/**/
TENSORSTORE_FOR_EACH_COMPLEX_DATA_TYPE(
TENSORSTORE_INTERNAL_DO_DEFINE_COMPARE_SAME_VALUE_COMPLEX)
Expand Down Expand Up @@ -801,6 +840,12 @@ struct DataTypeElementwiseOperationsImpl {
}
};

/// Elementwise functor that reports whether `*lhs` and `*rhs` are identical
/// according to `internal_data_type::CompareIdentical<T>`.
///
/// The status argument is not used.
struct CompareIdenticalImpl {
  bool operator()(const T* lhs, const T* rhs, absl::Status* /*status*/) const {
    return internal_data_type::CompareIdentical<T>(*lhs, *rhs);
  }
};

using Initialize =
internal::SimpleElementwiseFunction<InitializeImpl(T), absl::Status*>;

Expand All @@ -817,6 +862,9 @@ struct DataTypeElementwiseOperationsImpl {
absl::Status*>;
using CompareSameValue = internal::SimpleElementwiseFunction<
CompareSameValueImpl(const T, const T), absl::Status*>;

using CompareIdentical = internal::SimpleElementwiseFunction<
CompareIdenticalImpl(const T, const T), absl::Status*>;
};

template <typename T>
Expand All @@ -839,8 +887,10 @@ constexpr internal::DataTypeOperations DataTypeOperationsImpl = {
/*.append_to_string=*/&DataTypeSimpleOperationsImpl<T>::AppendToString,
/*.compare_equal=*/
typename DataTypeElementwiseOperationsImpl<T>::CompareEqual(),
/*.compare_equal=*/
/*.compare_same_value=*/
typename DataTypeElementwiseOperationsImpl<T>::CompareSameValue(),
/*.compare_identical=*/
typename DataTypeElementwiseOperationsImpl<T>::CompareIdentical(),
/*.canonical_conversion=*/nullptr,
};

Expand Down
18 changes: 12 additions & 6 deletions tensorstore/internal/async_write_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,16 @@ AsyncWriteArray::MaskedArray::GetArrayForWriteback(
bool read_state_already_integrated) {
assert(origin.size() == spec.rank());
WritebackData writeback;

// Returns `true` if `array` has to be stored.  That is the case when
// `spec.store_if_equal_to_fill_value` is set, or when `array` is not equal to
// `spec.fill_value` under the comparison selected by
// `spec.compare_to_fill_value_using_identical_equality`.
const auto must_store = [&](ArrayView<const void> array) {
  if (spec.store_if_equal_to_fill_value) return true;
  const bool equals_fill_value =
      spec.compare_to_fill_value_using_identical_equality
          ? AreArraysIdenticallyEqual(array, spec.fill_value)
          : AreArraysSameValueEqual(array, spec.fill_value);
  return !equals_fill_value;
};

if (!data) {
// No data has been allocated for the write array. This is only possible in
// two cases:
Expand All @@ -110,9 +120,7 @@ AsyncWriteArray::MaskedArray::GetArrayForWriteback(
// Case 2: array is unmodified.
assert(IsUnmodified());
if (read_array.data()) {
writeback.must_store =
spec.store_if_equal_to_fill_value ||
!AreArraysSameValueEqual(read_array, spec.fill_value);
writeback.must_store = must_store(read_array);
if (!writeback.must_store) {
writeback.array = spec.fill_value;
} else {
Expand Down Expand Up @@ -144,9 +152,7 @@ AsyncWriteArray::MaskedArray::GetArrayForWriteback(
}
writeback.array = SharedArrayView<void>(
SharedElementPointer<void>(data, spec.dtype()), spec.write_layout());
writeback.must_store =
spec.store_if_equal_to_fill_value ||
!AreArraysSameValueEqual(writeback.array, spec.fill_value);
writeback.must_store = must_store(writeback.array);
if (!writeback.must_store) {
data = nullptr;
writeback.array = spec.fill_value;
Expand Down
5 changes: 5 additions & 0 deletions tensorstore/internal/async_write_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ struct AsyncWriteArray {
/// call to `WriteFillValue`, then it won't be stored.
bool store_if_equal_to_fill_value = false;

/// If `true`, compare to fill value using `AreArraysIdenticallyEqual`. If
/// `false`, compare to fill value using `AreArraysSameValueEqual`. Only
/// has an effect if `store_if_equal_to_fill_value == false`.
bool compare_to_fill_value_using_identical_equality = false;

/// Returns the shape of the array.
span<const Index> shape() const { return fill_value.shape(); }

Expand Down
41 changes: 41 additions & 0 deletions tensorstore/internal/async_write_array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,47 @@ TEST(MaskedArrayTest, StoreIfEqualToFillValue) {
}
}

// Tests that `compare_to_fill_value_using_identical_equality==true` is
// correctly handled: a written value is compared to the fill value with
// `AreArraysIdenticallyEqual`, so a signaling NaN is not treated as equal to a
// quiet NaN fill value.
TEST(MaskedArrayTest, CompareFillValueIdenticallyEqual) {
  const float quiet_nan = std::numeric_limits<float>::quiet_NaN();
  const float signaling_nan = std::numeric_limits<float>::signaling_NaN();

  auto fill_value = MakeScalarArray<float>(quiet_nan);
  tensorstore::Box<> component_bounds;
  Spec spec(fill_value, component_bounds);
  spec.compare_to_fill_value_using_identical_equality = true;
  MaskedArray write_state(0);

  // Fully overwrite the portion within `component_bounds` with a NaN whose
  // representation differs from the fill value.
  TestWrite(&write_state, spec, {},
            tensorstore::MakeScalarArray<float>(signaling_nan),
            /*expected_modified=*/true);
  {
    auto writeback = write_state.GetArrayForWriteback(
        spec, /*origin=*/{}, /*read_array=*/{},
        /*read_state_already_integrated=*/false);
    EXPECT_TRUE(AreArraysIdenticallyEqual(
        tensorstore::MakeScalarArray<float>(signaling_nan), writeback.array));
    // The written value is not identical to the fill value, so it must be
    // stored.
    EXPECT_TRUE(writeback.must_store);
  }

  // Overwrite again, this time with a value identical to the fill value.
  TestWrite(&write_state, spec, {},
            tensorstore::MakeScalarArray<float>(quiet_nan),
            /*expected_modified=*/true);
  {
    auto writeback = write_state.GetArrayForWriteback(
        spec, /*origin=*/{}, /*read_array=*/{},
        /*read_state_already_integrated=*/false);
    EXPECT_TRUE(AreArraysIdenticallyEqual(
        tensorstore::MakeScalarArray<float>(quiet_nan), writeback.array));
    // Identical to the fill value, so nothing needs to be stored.
    EXPECT_FALSE(writeback.must_store);
  }
}

TEST(AsyncWriteArrayTest, Basic) {
AsyncWriteArray async_write_array(2);
auto fill_value = MakeArray<int32_t>({{1, 2, 3}, {4, 5, 6}});
Expand Down

0 comments on commit 4296def

Please sign in to comment.