Skip to content

Commit

Permalink
[XLA] [AlgebraicSimplifier] Reuse visitor across computations.
Browse files Browse the repository at this point in the history
This avoids recalculating instruction count in a hlo module per
computation, which is a non-trivial overhead if the model is big.

PiperOrigin-RevId: 262589426
  • Loading branch information
yunxing authored and tensorflower-gardener committed Aug 9, 2019
1 parent 7f70878 commit c02d99f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 14 deletions.
35 changes: 22 additions & 13 deletions tensorflow/compiler/xla/service/algebraic_simplifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ bool IsUnstridedSlice(const HloInstruction* hlo) {
// more general case a worklist based approach would be needed.
class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
public:
explicit AlgebraicSimplifierVisitor(const AlgebraicSimplifierOptions& options,
AlgebraicSimplifier* simplifier)
: options_(options), simplifier_(simplifier) {}

Status HandleAdd(HloInstruction* add) override;

Status HandleAnd(HloInstruction* logical_and) override;
Expand Down Expand Up @@ -230,7 +234,7 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {

Status HandleReshape(HloInstruction* reshape) override;

Status HandleReduce(HloInstruction* reduce) override;
Status HandleReduce(HloInstruction* hlo) override;

Status HandleReduceWindow(HloInstruction* reduce_window) override;

Expand All @@ -252,16 +256,11 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
Status HandleMap(HloInstruction* map) override;

// Runs the visitor on a computation.
static bool Run(HloComputation* computation,
const AlgebraicSimplifierOptions& options,
AlgebraicSimplifier* simplifier);
bool Run(HloComputation* computation,
const AlgebraicSimplifierOptions& options,
AlgebraicSimplifier* simplifier);

private:
explicit AlgebraicSimplifierVisitor(HloComputation* computation,
const AlgebraicSimplifierOptions& options,
AlgebraicSimplifier* simplifier)
: computation_(computation), options_(options), simplifier_(simplifier) {}

// Removes degenerate dimension from dot.
StatusOr<bool> RemoveDegenerateDimensionFromDot(HloInstruction* dot);

Expand Down Expand Up @@ -391,6 +390,9 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
// Tries to convert slice(reshape(X)) into reshape(slice(X))
StatusOr<bool> TryToReorderSliceAndReshape(HloInstruction* slice);

// Useful when we want to use the same visitor over multiple computations.
void ResetState(HloComputation* computation);

// Current HloComputation instance the AlgebraicSimplifierVisitor is
// traversing.
HloComputation* computation_;
Expand All @@ -409,12 +411,18 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {

} // namespace

void AlgebraicSimplifierVisitor::ResetState(HloComputation* computation) {
changed_ = false;
ResetVisitStates();
computation_ = computation;
}

bool AlgebraicSimplifierVisitor::Run(HloComputation* computation,
const AlgebraicSimplifierOptions& options,
AlgebraicSimplifier* simplifier) {
AlgebraicSimplifierVisitor visitor(computation, options, simplifier);
TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_ || visitor.changed();
ResetState(computation);
TF_CHECK_OK(computation->Accept(this));
return changed_ || changed();
}

bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs,
Expand Down Expand Up @@ -4045,8 +4053,9 @@ StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
XLA_VLOG_LINES(2,
"AlgebraicSimplifier::Run(), before:\n" + module->ToString());
bool changed = false;
AlgebraicSimplifierVisitor visitor(options_, this);
for (auto* comp : module->MakeNonfusionComputations()) {
if (AlgebraicSimplifierVisitor::Run(comp, options_, this)) {
if (visitor.Run(comp, options_, this)) {
changed = true;
}
}
Expand Down
7 changes: 6 additions & 1 deletion tensorflow/compiler/xla/service/dfs_hlo_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,12 @@ class DfsHloVisitorBase {

// Useful when we want to visit the same computation more than once with the
// same visitor.
void ResetVisitStates() { visit_state_.clear(); }
void ResetVisitStates() {
// Clear the map, but don't resize the capacity across uses -- Calculating
// and reserving space could be expensive, and we always use the same
// module->instruction_count() as the capacity.
visit_state_.erase(visit_state_.begin(), visit_state_.end());
}

void SetVisitState(int id, VisitState state) { visit_state_[id] = state; }

Expand Down

0 comments on commit c02d99f

Please sign in to comment.