Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(treespec): memorize ongoing repr / hash calls to resolve infinite recursion under self-referential case #82

Merged
merged 8 commits into from
Sep 24, 2023
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ repos:
hooks:
- id: clang-format
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.290
rev: v0.0.291
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand All @@ -43,7 +43,7 @@ repos:
hooks:
- id: black
- repo: https://github.com/asottile/pyupgrade
rev: v3.12.0
rev: v3.13.0
hooks:
- id: pyupgrade
args: [--py37-plus]
Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Memorize ongoing `repr` / `hash` calls to resolve infinite recursion in self-referential cases by [@XuehaiPan](https://github.com/XuehaiPan) and [@JieRen98](https://github.com/JieRen98) in [#82](https://github.com/metaopt/optree/pull/82).

### Removed

Expand Down
38 changes: 34 additions & 4 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.

#pragma once

#include <absl/container/flat_hash_set.h>
#include <absl/container/inlined_vector.h>
#include <absl/hash/hash.h>
#include <pybind11/pybind11.h>
Expand All @@ -25,6 +26,7 @@ limitations under the License.
#include <optional>
#include <stdexcept>
#include <string>
#include <thread> // NOLINT[build/c++11]
#include <tuple>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -156,6 +158,7 @@ class PyTreeSpec {
inline bool operator>(const PyTreeSpec &other) const { return IsSuffix(other, true); }
inline bool operator>=(const PyTreeSpec &other) const { return IsSuffix(other, false); }

// Return the hash value of the PyTreeSpec.
template <typename H>
friend H AbslHashValue(H h, const Node &n) {
ssize_t data_hash = 0;
Expand Down Expand Up @@ -201,17 +204,34 @@ class PyTreeSpec {
INTERNAL_ERROR();
}

h = H::combine(
return H::combine(
std::move(h), n.kind, n.arity, n.custom, n.num_leaves, n.num_nodes, data_hash);
return h;
}

// Unguarded hashing step: fold the traversal nodes, the none-is-leaf flag and the registry
// namespace of `t` into the hash state `h` and hand the combined state back.
template <typename H>
friend H AbslHashValueImpl(H h, const PyTreeSpec &t) {
    H combined = H::combine(std::move(h), t.m_traversal, t.m_none_is_leaf, t.m_namespace);
    return combined;
}

template <typename H>
friend H AbslHashValue(H h, const PyTreeSpec &t) {
JieRen98 marked this conversation as resolved.
Show resolved Hide resolved
h = H::combine(std::move(h), t.m_traversal, t.m_none_is_leaf, t.m_namespace);
return h;
std::pair<const PyTreeSpec *, std::thread::id> indent{&t, std::this_thread::get_id()};
if (sm_hash_running.contains(indent)) {
return h;
}

sm_hash_running.insert(indent);
try {
H hash = AbslHashValueImpl(std::move(h), t);
sm_hash_running.erase(indent);
return hash;
} catch (...) {
sm_hash_running.erase(indent);
std::rethrow_exception(std::current_exception());
}
}

// Return a string representation of the PyTreeSpec.
[[nodiscard]] std::string ToString() const;

// Transform the PyTreeSpec into a picklable object.
Expand Down Expand Up @@ -268,6 +288,14 @@ class PyTreeSpec {
// The registry namespace used to resolve the custom pytree node types.
std::string m_namespace;

// A set of (treespec, thread_id) pairs that are currently being represented as strings.
// ToString() adds its own pair before building the representation and removes it afterwards;
// a call that finds its pair already present returns "..." instead of recursing.
inline static absl::flat_hash_set<std::pair<const PyTreeSpec *, std::thread::id>>
sm_repr_running{};

// A set of (treespec, thread_id) pairs that are currently being hashed.
// AbslHashValue(H, const PyTreeSpec &) adds its own pair before hashing and removes it
// afterwards; a call that finds its pair already present returns the hash state unchanged.
inline static absl::flat_hash_set<std::pair<const PyTreeSpec *, std::thread::id>>
sm_hash_running{};

// Helper that manufactures an instance of a node given its children.
static py::object MakeNode(const Node &node, const absl::Span<py::object> &children);

Expand Down Expand Up @@ -307,6 +335,8 @@ class PyTreeSpec {
const ssize_t &pos,
const ssize_t &depth) const;

[[nodiscard]] std::string ToStringImpl() const;

static std::unique_ptr<PyTreeSpec> FromPicklableImpl(const py::object &picklable);
};

Expand Down
4 changes: 2 additions & 2 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -457,11 +457,11 @@ inline void TotalOrderSort(py::list& list) { // NOLINT[runtime/references]
// The keys remain in the insertion order.
PyErr_Clear();
} else [[unlikely]] {
throw;
std::rethrow_exception(std::current_exception());
}
}
} else [[unlikely]] {
throw;
std::rethrow_exception(std::current_exception());
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ target_link_libraries(
_C
PUBLIC
absl::flat_hash_map
absl::flat_hash_set
absl::str_format
)
19 changes: 18 additions & 1 deletion src/treespec/treespec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ template PyTreeKind PyTreeSpec::GetKind<NONE_IS_LEAF>(const py::handle&,
const std::string&);

// NOLINTNEXTLINE[readability-function-cognitive-complexity]
std::string PyTreeSpec::ToString() const {
std::string PyTreeSpec::ToStringImpl() const {
auto agenda = std::vector<std::string>{};
for (const Node& node : m_traversal) {
EXPECT_GE(py::ssize_t_cast(agenda.size()), node.arity, "Too few elements for container.");
Expand Down Expand Up @@ -697,6 +697,23 @@ std::string PyTreeSpec::ToString() const {
")");
}

std::string PyTreeSpec::ToString() const {
std::pair<const PyTreeSpec*, std::thread::id> indent{this, std::this_thread::get_id()};
if (sm_repr_running.contains(indent)) {
return "...";
}

sm_repr_running.insert(indent);
try {
std::string representation = ToStringImpl();
sm_repr_running.erase(indent);
return representation;
} catch (...) {
sm_repr_running.erase(indent);
std::rethrow_exception(std::current_exception());
}
}

py::object PyTreeSpec::ToPicklable() const {
py::tuple node_states{GetNumNodes()};
ssize_t i = 0;
Expand Down
58 changes: 58 additions & 0 deletions tests/test_treespec.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,64 @@ def test_treespec_string_representation(data):
assert str(treespec) == correct_string


def test_treespec_self_referential():
    # Exercise str() / hash() of a treespec whose dict key is mutated after construction,
    # including states where the key's value refers back to the treespec itself.

    class Holder:
        # Hashable wrapper whose equality, hash and repr all follow the current `value`.

        def __init__(self, value):
            self.value = value

        def __eq__(self, other):
            if not isinstance(other, Holder):
                return False
            return self.value == other.value

        def __hash__(self):
            return hash(self.value)

        def __repr__(self):
            return f'Holder({self.value!r})'

    seen_hashes = set()

    def record_fresh_hash(spec):
        # The hash must be stable across calls and differ from every previously recorded one.
        assert hash(spec) == hash(spec)
        assert hash(spec) not in seen_hashes
        seen_hashes.add(hash(spec))

    holder = Holder('a')
    treespec = optree.tree_structure({holder: 0})

    # Plain, non-recursive key values.
    assert str(treespec) == "PyTreeSpec({Holder('a'): *})"
    record_fresh_hash(treespec)

    holder.value = 'b'
    assert str(treespec) == "PyTreeSpec({Holder('b'): *})"
    record_fresh_hash(treespec)

    # The key now refers directly back to the treespec that contains it.
    holder.value = treespec
    assert str(treespec) == 'PyTreeSpec({Holder(...): *})'
    record_fresh_hash(treespec)

    # The treespec appears twice inside the key's value.
    holder.value = ('a', treespec, treespec)
    assert str(treespec) == "PyTreeSpec({Holder(('a', ..., ...)): *})"
    record_fresh_hash(treespec)

    # A second treespec whose key wraps the first one.
    other = optree.tree_structure({Holder(treespec): 1})
    assert str(other) == "PyTreeSpec({Holder(PyTreeSpec({Holder(('a', ..., ...)): *})): *})"
    record_fresh_hash(other)

    # Close the loop: the two treespecs now refer to each other.
    holder.value = other
    mutual_repr = 'PyTreeSpec({Holder(PyTreeSpec({Holder(...): *})): *})'
    assert str(treespec) == mutual_repr
    assert str(other) == mutual_repr
    record_fresh_hash(treespec)
    assert hash(other) == hash(other)
    assert hash(treespec) == hash(other)

    # Comparing the mutually-referential treespecs is expected to raise RecursionError.
    with pytest.raises(RecursionError):
        assert treespec != other


def test_with_namespace():
tree = NAMESPACED_TREE

Expand Down
Loading