Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions water/include/water/Analysis/InUseForSpeculation.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,14 @@ class InUseAnalysis

/// Visitation function along branching and region control flow.
void visitBranchOperand(OpOperand &opOperand) override;

/// Visit the non-forwarded arguments of a region, such as the
/// induction variables of a loop.
void
visitNonControlFlowArguments(RegionSuccessor & /*successor*/,
ArrayRef<BlockArgument> /*arguments*/) override {
// nothing
}
};

} // namespace mlir::water
31 changes: 29 additions & 2 deletions water/lib/Dialect/Wave/Transforms/InferTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,15 @@ class InferTypeBackwardAnalysis
propagateIfChanged(lattice,
lattice->join(InferTypeLatticeStorage(tensorType)));
}

// Visit the non-forwarded arguments of a region, such as the
// induction variables of a loop.
void
visitNonControlFlowArguments(RegionSuccessor & /*successor*/,
ArrayRef<BlockArgument> /*arguments*/) override {
// This is called for induction variables of an IterateOp, which is handled
// by the forward analysis.
}
};
} // namespace

Expand Down Expand Up @@ -979,6 +988,15 @@ class ElementsPerThreadBackwardAnalysis
}
return llvm::success();
}

// Visit the non-forwarded arguments of a region, such as the
// induction variables of a loop.
void
visitNonControlFlowArguments(RegionSuccessor & /*successor*/,
ArrayRef<BlockArgument> /*arguments*/) override {
// This is called for induction variables of an IterateOp, which is handled
// by the forward analysis.
}
};

// Elements-per-thread propagation pass implementation.
Expand Down Expand Up @@ -1237,7 +1255,7 @@ class IndexExprsForwardAnalysis
LDBG() << " result #" << i << ": " << *result;
}
});
auto scope = llvm::make_scope_exit([&] {
llvm::scope_exit scope([&] {
LLVM_DEBUG({
LDBG() << " updated result lattices:";
for (auto [i, result] : llvm::enumerate(results)) {
Expand Down Expand Up @@ -1572,7 +1590,7 @@ class IndexExprsBackwardAnalysis
LDBG() << " result #" << i << ": " << *result;
}
});
auto scope = llvm::make_scope_exit([&] {
llvm::scope_exit scope([&] {
LLVM_DEBUG({
LDBG() << " updated operand lattices:";
for (auto [i, operand] : llvm::enumerate(operands)) {
Expand Down Expand Up @@ -1628,6 +1646,15 @@ class IndexExprsBackwardAnalysis
return llvm::success();
}

// Visit the non-forwarded arguments of a region, such as the
// induction variables of a loop.
void
visitNonControlFlowArguments(RegionSuccessor & /*successor*/,
ArrayRef<BlockArgument> /*arguments*/) override {
// This is called for induction variables of an IterateOp, which is handled
// by the forward analysis.
}

private:
bool initialized = false;
wave::OverrideInitializationFn overrideInitialization;
Expand Down
2 changes: 1 addition & 1 deletion water/llvm-sha.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
80314817b55facda688e9583c74abca9a6c6a49f
fdc393f9ee12c4069f9f04374e62f5c36a617298
14 changes: 7 additions & 7 deletions wave_lang/kernel/wave/mlir_converter/mlir_to_wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,33 +67,33 @@ def _convert_affine_expr_to_sympy_expr(
Raises:
ValueError: If the expression is not supported.
"""
if ir.AffineConstantExpr.isinstance(expr):
if isinstance(expr, ir.AffineConstantExpr):
return sympy.Integer(ir.AffineConstantExpr(expr).value)
if ir.AffineSymbolExpr.isinstance(expr):
if isinstance(expr, ir.AffineSymbolExpr):
return symbol_mapping[ir.AffineSymbolExpr(expr).position]
if ir.AffineAddExpr.isinstance(expr):
if isinstance(expr, ir.AffineAddExpr):
add_expr = ir.AffineAddExpr(expr)
return _convert_affine_expr_to_sympy_expr(
add_expr.lhs, symbol_mapping
) + _convert_affine_expr_to_sympy_expr(add_expr.rhs, symbol_mapping)
if ir.AffineMulExpr.isinstance(expr):
if isinstance(expr, ir.AffineMulExpr):
mul_expr = ir.AffineMulExpr(expr)
return _convert_affine_expr_to_sympy_expr(
mul_expr.lhs, symbol_mapping
) * _convert_affine_expr_to_sympy_expr(mul_expr.rhs, symbol_mapping)
if ir.AffineFloorDivExpr.isinstance(expr):
if isinstance(expr, ir.AffineFloorDivExpr):
floor_div_expr = ir.AffineFloorDivExpr(expr)
return sympy.floor(
_convert_affine_expr_to_sympy_expr(floor_div_expr.lhs, symbol_mapping)
/ _convert_affine_expr_to_sympy_expr(floor_div_expr.rhs, symbol_mapping)
)
if ir.AffineCeilDivExpr.isinstance(expr):
if isinstance(expr, ir.AffineCeilDivExpr):
ceil_div_expr = ir.AffineCeilDivExpr(expr)
return sympy.ceiling(
_convert_affine_expr_to_sympy_expr(ceil_div_expr.lhs, symbol_mapping)
/ _convert_affine_expr_to_sympy_expr(ceil_div_expr.rhs, symbol_mapping)
)
if ir.AffineModExpr.isinstance(expr):
if isinstance(expr, ir.AffineModExpr):
mod_expr = ir.AffineModExpr(expr)
return _convert_affine_expr_to_sympy_expr(
mod_expr.lhs, symbol_mapping
Expand Down
Loading