Skip to content

Commit a40e237

Browse files
ericastorcopybara-github
authored andcommitted
[XLS] When comparing TreeBitLocations, compare node ID rather than pointer
Do the same when caching BDD trees for each node in the `node_locations_` cache in the BDD query engine. This is generally both a correctness & a safety improvement. In this specific case, it was found as a fix for nondeterminism! Due to false matches when memory locations were reused for new Nodes, the BDD query engine's variable caches could nondeterministically match with an (no-longer-used) variable created for a deleted Node. This was safe, but resulted in effective changes to the variable order, which could cause the BDD to saturate in different places nondeterministically. PiperOrigin-RevId: 761556944
1 parent 79b5a09 commit a40e237

File tree

4 files changed

+61
-15
lines changed

4 files changed

+61
-15
lines changed

xls/ir/node.h

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
#include "absl/log/check.h"
3030
#include "absl/status/status.h"
3131
#include "absl/status/statusor.h"
32-
#include "absl/strings/str_cat.h"
3332
#include "absl/types/span.h"
3433
#include "xls/common/casts.h"
3534
#include "xls/common/status/status_macros.h"
@@ -50,6 +49,35 @@ absl::Span<ChangeListener* const> GetChangeListeners(
5049
// Forward declaration to avoid circular dependency.
5150
class DfsVisitor;
5251

52+
// A (non-owning) reference to a node; contains both the node's ID and the
53+
// pointer to the node, enabling efficient access & safe comparison.
54+
class NodeRef {
55+
public:
56+
explicit NodeRef(Node* node);
57+
58+
int64_t id() const { return id_; }
59+
Node* node() const { return node_; }
60+
61+
// On dereference, behaves exactly like the underlying node.
62+
Node& operator*() const { return *node_; }
63+
Node* operator->() const { return node_; }
64+
65+
friend bool operator==(const NodeRef& a, const NodeRef& b) {
66+
DCHECK(a.id() != b.id() || a.node() == b.node())
67+
<< "False node match due to reused ID";
68+
return a.id() == b.id();
69+
}
70+
71+
template <typename H>
72+
friend H AbslHashValue(H h, const NodeRef& node) {
73+
return h.combine(std::move(h), node.id());
74+
}
75+
76+
private:
77+
int64_t id_;
78+
Node* node_;
79+
};
80+
5381
// Abstract type for a node (representing an expression) in the high level IR.
5482
//
5583
// Node is subtyped and can be checked-converted via the As* methods below.
@@ -336,6 +364,9 @@ class Node {
336364
absl::InlinedVector<Node*, 2> users_;
337365
};
338366

367+
inline NodeRef::NodeRef(Node* node)
368+
: id_(node == nullptr ? -1 : node->id()), node_(node) {}
369+
339370
inline std::ostream& operator<<(std::ostream& os, const Node& node) {
340371
os << node.ToString();
341372
return os;
@@ -344,10 +375,20 @@ inline std::ostream& operator<<(std::ostream& os, const Node* node) {
344375
os << (node == nullptr ? std::string("<nullptr Node*>") : node->ToString());
345376
return os;
346377
}
378+
inline std::ostream& operator<<(std::ostream& os, const NodeRef& node) {
379+
os << node->ToString();
380+
return os;
381+
}
347382

348-
inline void NodeAppend(std::string* out, const Node* n) {
349-
absl::StrAppend(out, n->ToString());
383+
inline bool operator==(const NodeRef& a, Node* b) {
384+
if (b == nullptr) {
385+
return a.node() == nullptr;
386+
}
387+
DCHECK(a.id() != b->id() || a.node() == b)
388+
<< "False node match due to reused ID";
389+
return a.id() == b->id();
350390
}
391+
inline bool operator==(Node* a, const NodeRef& b) { return b == a; }
351392

352393
} // namespace xls
353394

xls/passes/bdd_query_engine.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -497,16 +497,17 @@ BddTree BddQueryEngine::ComputeInfo(
497497
Node* node, absl::Span<const BddTree* const> operand_infos) const {
498498
if (!ShouldEvaluate(node)) {
499499
VLOG(3) << " node filtered out by generic ShouldEvaluate heuristic.";
500-
return leaf_type_tree::Clone(GetVariablesFor(node));
500+
return leaf_type_tree::Clone(GetVariablesFor(NodeRef(node)));
501501
}
502502
if (node_filter_.has_value() && !(*node_filter_)(node)) {
503503
VLOG(3) << " node filtered out by configured filter.";
504-
return leaf_type_tree::Clone(GetVariablesFor(node));
504+
return leaf_type_tree::Clone(GetVariablesFor(NodeRef(node)));
505505
}
506506

507507
VLOG(3) << " computing BDD value...";
508-
BddNodeEvaluator node_evaluator(
509-
*evaluator_, [this](Node* node) { return GetVariablesFor(node); });
508+
BddNodeEvaluator node_evaluator(*evaluator_, [this](Node* node) {
509+
return GetVariablesFor(NodeRef(node));
510+
});
510511
absl::flat_hash_set<Node*> injected_operands;
511512
injected_operands.reserve(node->operand_count());
512513
for (auto [operand, operand_info] :
@@ -1270,7 +1271,7 @@ bool BddQueryEngine::IsFullyKnown(
12701271
}
12711272

12721273
BddNodeIndex BddQueryEngine::GetVariableFor(TreeBitLocation location) const {
1273-
if (auto it = node_variables_.find(location.node());
1274+
if (auto it = node_variables_.find(location.node_ref());
12741275
it != node_variables_.end()) {
12751276
if (it->second->type() == location.node()->GetType()) {
12761277
return std::get<BddNodeIndex>(
@@ -1286,7 +1287,7 @@ BddNodeIndex BddQueryEngine::GetVariableFor(TreeBitLocation location) const {
12861287
CHECK(bit_variables_.emplace(location, result).second);
12871288
return result;
12881289
}
1289-
BddTreeView BddQueryEngine::GetVariablesFor(Node* node) const {
1290+
BddTreeView BddQueryEngine::GetVariablesFor(NodeRef node) const {
12901291
if (auto it = node_variables_.find(node);
12911292
it != node_variables_.end() && it->second->type() == node->GetType()) {
12921293
// If a node has changed type (which can happen!), we need a new set of

xls/passes/bdd_query_engine.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,11 @@ class BddQueryEngine
224224
mutable absl::flat_hash_map<TreeBitLocation, BddNodeIndex> bit_variables_;
225225
BddNodeIndex GetVariableFor(TreeBitLocation location) const;
226226

227-
// A map from nodes to BDD variables used to represent fully-unknown values;
228-
// used to avoid creating new variables for the same node.
229-
mutable absl::flat_hash_map<Node*, std::unique_ptr<BddTree>> node_variables_;
230-
BddTreeView GetVariablesFor(Node* node) const;
227+
// A map from node IDs to BDD variables used to represent fully-unknown
228+
// values; used to avoid creating new variables for the same node.
229+
mutable absl::flat_hash_map<NodeRef, std::unique_ptr<BddTree>>
230+
node_variables_;
231+
BddTreeView GetVariablesFor(NodeRef node) const;
231232
};
232233

233234
} // namespace xls

xls/passes/query_engine.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ class TreeBitLocation {
5050
bit_index_(bit_index),
5151
tree_index_(tree_index.begin(), tree_index.end()) {}
5252

53-
Node* node() const { return node_; }
53+
// TODO: https://github.com/google/xls/issues/2235 - Replace node() with this.
54+
const NodeRef& node_ref() const { return node_; }
55+
56+
Node* node() const { return node_.node(); }
5457

5558
int64_t bit_index() const { return bit_index_; }
5659

@@ -77,7 +80,7 @@ class TreeBitLocation {
7780
}
7881

7982
private:
80-
Node* node_;
83+
NodeRef node_;
8184
int64_t bit_index_;
8285
std::vector<int64_t> tree_index_;
8386
};

0 commit comments

Comments
 (0)