Skip to content

Commit 6835a1b

Browse files
ericastorcopybara-github
authored andcommitted
[XLS] When comparing TreeBitLocations, compare node ID rather than pointer
Do the same when caching BDD trees for each node in the `node_locations_` cache in the BDD query engine. This is generally both a correctness & a safety improvement. In this specific case, it was found as a fix for nondeterminism! Due to false matches when memory locations were reused for new Nodes, the BDD query engine's variable caches could nondeterministically match with an (no-longer-used) variable created for a deleted Node. This was safe, but resulted in effective changes to the variable order, which could cause the BDD to saturate in different places nondeterministically. PiperOrigin-RevId: 761556944
1 parent b14928c commit 6835a1b

File tree

4 files changed

+61
-15
lines changed

4 files changed

+61
-15
lines changed

xls/ir/node.h

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
#include "absl/log/check.h"
3030
#include "absl/status/status.h"
3131
#include "absl/status/statusor.h"
32-
#include "absl/strings/str_cat.h"
3332
#include "absl/types/span.h"
3433
#include "xls/common/casts.h"
3534
#include "xls/common/status/status_macros.h"
@@ -50,6 +49,35 @@ absl::Span<ChangeListener* const> GetChangeListeners(
5049
// Forward declaration to avoid circular dependency.
5150
class DfsVisitor;
5251

52+
// A (non-owning) reference to a node; contains both the node's ID and the
53+
// pointer to the node, enabling efficient access & safe comparison.
54+
class NodeRef {
55+
public:
56+
explicit NodeRef(Node* node);
57+
58+
int64_t id() const { return id_; }
59+
Node* node() const { return node_; }
60+
61+
// On dereference, behaves exactly like the underlying node.
62+
Node& operator*() const { return *node_; }
63+
Node* operator->() const { return node_; }
64+
65+
friend bool operator==(const NodeRef& a, const NodeRef& b) {
66+
DCHECK(a.id() != b.id() || a.node() == b.node())
67+
<< "False node match due to reused ID";
68+
return a.id() == b.id();
69+
}
70+
71+
template <typename H>
72+
friend H AbslHashValue(H h, const NodeRef& node) {
73+
return h.combine(std::move(h), node.id());
74+
}
75+
76+
private:
77+
int64_t id_;
78+
Node* node_;
79+
};
80+
5381
// Abstract type for a node (representing an expression) in the high level IR.
5482
//
5583
// Node is subtyped and can be checked-converted via the As* methods below.
@@ -336,6 +364,9 @@ class Node {
336364
absl::InlinedVector<Node*, 2> users_;
337365
};
338366

367+
inline NodeRef::NodeRef(Node* node)
368+
: id_(node == nullptr ? -1 : node->id()), node_(node) {}
369+
339370
inline std::ostream& operator<<(std::ostream& os, const Node& node) {
340371
os << node.ToString();
341372
return os;
@@ -344,10 +375,20 @@ inline std::ostream& operator<<(std::ostream& os, const Node* node) {
344375
os << (node == nullptr ? std::string("<nullptr Node*>") : node->ToString());
345376
return os;
346377
}
378+
inline std::ostream& operator<<(std::ostream& os, const NodeRef& node) {
379+
os << node->ToString();
380+
return os;
381+
}
347382

348-
inline void NodeAppend(std::string* out, const Node* n) {
349-
absl::StrAppend(out, n->ToString());
383+
inline bool operator==(const NodeRef& a, Node* b) {
384+
if (b == nullptr) {
385+
return a.node() == nullptr;
386+
}
387+
DCHECK(a.id() != b->id() || a.node() == b)
388+
<< "False node match due to reused ID";
389+
return a.id() == b->id();
350390
}
391+
inline bool operator==(Node* a, const NodeRef& b) { return b == a; }
351392

352393
} // namespace xls
353394

xls/passes/bdd_query_engine.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -482,16 +482,17 @@ BddTree BddQueryEngine::ComputeInfo(
482482
Node* node, absl::Span<const BddTree* const> operand_infos) const {
483483
if (!ShouldEvaluate(node)) {
484484
VLOG(3) << " node filtered out by generic ShouldEvaluate heuristic.";
485-
return leaf_type_tree::Clone(GetVariablesFor(node));
485+
return leaf_type_tree::Clone(GetVariablesFor(NodeRef(node)));
486486
}
487487
if (node_filter_.has_value() && !(*node_filter_)(node)) {
488488
VLOG(3) << " node filtered out by configured filter.";
489-
return leaf_type_tree::Clone(GetVariablesFor(node));
489+
return leaf_type_tree::Clone(GetVariablesFor(NodeRef(node)));
490490
}
491491

492492
VLOG(3) << " computing BDD value...";
493-
BddNodeEvaluator node_evaluator(
494-
*evaluator_, [this](Node* node) { return GetVariablesFor(node); });
493+
BddNodeEvaluator node_evaluator(*evaluator_, [this](Node* node) {
494+
return GetVariablesFor(NodeRef(node));
495+
});
495496
absl::flat_hash_set<Node*> injected_operands;
496497
injected_operands.reserve(node->operand_count());
497498
for (auto [operand, operand_info] :
@@ -1255,7 +1256,7 @@ bool BddQueryEngine::IsFullyKnown(
12551256
}
12561257

12571258
BddNodeIndex BddQueryEngine::GetVariableFor(TreeBitLocation location) const {
1258-
if (auto it = node_variables_.find(location.node());
1259+
if (auto it = node_variables_.find(location.node_ref());
12591260
it != node_variables_.end()) {
12601261
if (it->second->type() == location.node()->GetType()) {
12611262
return std::get<BddNodeIndex>(
@@ -1271,7 +1272,7 @@ BddNodeIndex BddQueryEngine::GetVariableFor(TreeBitLocation location) const {
12711272
CHECK(bit_variables_.emplace(location, result).second);
12721273
return result;
12731274
}
1274-
BddTreeView BddQueryEngine::GetVariablesFor(Node* node) const {
1275+
BddTreeView BddQueryEngine::GetVariablesFor(NodeRef node) const {
12751276
if (auto it = node_variables_.find(node);
12761277
it != node_variables_.end() && it->second->type() == node->GetType()) {
12771278
// If a node has changed type (which can happen!), we need a new set of

xls/passes/bdd_query_engine.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -226,10 +226,11 @@ class BddQueryEngine
226226
mutable absl::flat_hash_map<TreeBitLocation, BddNodeIndex> bit_variables_;
227227
BddNodeIndex GetVariableFor(TreeBitLocation location) const;
228228

229-
// A map from nodes to BDD variables used to represent fully-unknown values;
230-
// used to avoid creating new variables for the same node.
231-
mutable absl::flat_hash_map<Node*, std::unique_ptr<BddTree>> node_variables_;
232-
BddTreeView GetVariablesFor(Node* node) const;
229+
// A map from node IDs to BDD variables used to represent fully-unknown
230+
// values; used to avoid creating new variables for the same node.
231+
mutable absl::flat_hash_map<NodeRef, std::unique_ptr<BddTree>>
232+
node_variables_;
233+
BddTreeView GetVariablesFor(NodeRef node) const;
233234
};
234235

235236
} // namespace xls

xls/passes/query_engine.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ class TreeBitLocation {
5050
bit_index_(bit_index),
5151
tree_index_(tree_index.begin(), tree_index.end()) {}
5252

53-
Node* node() const { return node_; }
53+
// TODO: https://github.com/google/xls/issues/2235 - Replace node() with this.
54+
const NodeRef& node_ref() const { return node_; }
55+
56+
Node* node() const { return node_.node(); }
5457

5558
int64_t bit_index() const { return bit_index_; }
5659

@@ -77,7 +80,7 @@ class TreeBitLocation {
7780
}
7881

7982
private:
80-
Node* node_;
83+
NodeRef node_;
8184
int64_t bit_index_;
8285
std::vector<int64_t> tree_index_;
8386
};

0 commit comments

Comments
 (0)