Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ConstantInterval #8179

Merged
merged 22 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
d24357f
Make ConstantInterval more of a first-class thing
abadams Apr 3, 2024
1d6b970
Restore bound_correlated_differences calls
abadams Apr 3, 2024
0bb89a0
Elaborate on TODO
abadams Apr 3, 2024
f33d64d
Merge remote-tracking branch 'origin/main' into abadams/refactor_cons…
abadams Apr 3, 2024
64d59ce
Handle some TODOs
abadams Apr 3, 2024
443e486
Fix constant interval mod, clean up constant interval saturating cast
abadams Apr 3, 2024
a73c79b
Improve comment
abadams Apr 3, 2024
4d7e1fd
Avoid unsigned overflow
abadams Apr 4, 2024
af6012c
Fix the most obvious bug in lossless_cast, to make the fuzzer pass more
abadams Apr 4, 2024
85439fe
Skip over pipelines that fail the lossless_cast check
abadams Apr 4, 2024
56874af
Drop iteration count on lossless_cast test
abadams Apr 4, 2024
ed07bed
Add test to CMakeLists.txt
abadams Apr 4, 2024
273f025
Avoid UB in constant_interval test (signed integer overflow of the sc…
abadams Apr 4, 2024
8f132ba
Merge remote-tracking branch 'origin/main' into abadams/refactor_cons…
abadams Apr 4, 2024
a74ab74
Restore accidentally-deleted line from CMakeLists.txt
abadams Apr 4, 2024
2467064
Print on success
abadams Apr 5, 2024
e006eab
Handle Lets in constant_integer_bounds
abadams Apr 5, 2024
a601990
Delete duplicate operator<<
abadams Apr 5, 2024
f2d3927
Just always cast the bounds back to the range of the op type
abadams Apr 19, 2024
46f6f74
Address review comments
abadams Apr 20, 2024
5e448b6
Redo operator<< for ConstantIntervals
abadams Apr 22, 2024
5fcee27
Improve comment; disable buggy code for now
abadams Apr 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,8 @@ SOURCE_FILES = \
CodeGen_WebGPU_Dev.cpp \
CodeGen_X86.cpp \
CompilerLogger.cpp \
ConstantBounds.cpp \
ConstantInterval.cpp \
CPlusPlusMangle.cpp \
CSE.cpp \
Debug.cpp \
Expand Down Expand Up @@ -671,6 +673,8 @@ HEADER_FILES = \
CompilerLogger.h \
ConciseCasts.h \
CPlusPlusMangle.h \
ConstantBounds.h \
ConstantInterval.h \
CSE.h \
Debug.h \
DebugArguments.h \
Expand Down
4 changes: 4 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ set(HEADER_FILES
CompilerLogger.h
ConciseCasts.h
CPlusPlusMangle.h
ConstantBounds.h
ConstantInterval.h
CSE.h
Debug.h
DebugArguments.h
Expand Down Expand Up @@ -219,6 +221,8 @@ set(SOURCE_FILES
CodeGen_X86.cpp
CompilerLogger.cpp
CPlusPlusMangle.cpp
ConstantBounds.cpp
ConstantInterval.cpp
CSE.cpp
Debug.cpp
DebugArguments.cpp
Expand Down
170 changes: 170 additions & 0 deletions src/ConstantBounds.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
#include "ConstantBounds.h"
#include "IR.h"
#include "IROperator.h"
#include "IRPrinter.h"

namespace Halide {
namespace Internal {

namespace {
ConstantInterval bounds_helper(const Expr &e,
Scope<ConstantInterval> &scope,
std::map<Expr, ConstantInterval, ExprCompare> *cache) {
internal_assert(e.defined());

auto recurse = [&](const Expr &e) {
return bounds_helper(e, scope, cache);
};

auto get_infinite_bounds = [&]() {
// Compute the bounds of each IR node from the bounds of its args. Math
// on ConstantInterval is in terms of infinite integers.
if (const UIntImm *op = e.as<UIntImm>()) {
if (Int(64).can_represent(op->value)) {
return ConstantInterval::single_point((int64_t)(op->value));
}
} else if (const IntImm *op = e.as<IntImm>()) {
return ConstantInterval::single_point(op->value);
} else if (const Variable *op = e.as<Variable>()) {
if (const auto *in = scope.find(op->name)) {
return *in;
}
} else if (const Add *op = e.as<Add>()) {
return recurse(op->a) + recurse(op->b);
} else if (const Sub *op = e.as<Sub>()) {
return recurse(op->a) - recurse(op->b);
} else if (const Mul *op = e.as<Mul>()) {
return recurse(op->a) * recurse(op->b);
} else if (const Div *op = e.as<Div>()) {
return recurse(op->a) / recurse(op->b);
} else if (const Mod *op = e.as<Mod>()) {
return recurse(op->a) % recurse(op->b);
} else if (const Min *op = e.as<Min>()) {
return min(recurse(op->a), recurse(op->b));
} else if (const Max *op = e.as<Max>()) {
return max(recurse(op->a), recurse(op->b));
} else if (const Cast *op = e.as<Cast>()) {
return recurse(op->value);
} else if (const Broadcast *op = e.as<Broadcast>()) {
return recurse(op->value);
} else if (const VectorReduce *op = e.as<VectorReduce>()) {
int f = op->value.type().lanes() / op->type.lanes();
ConstantInterval factor(f, f);
ConstantInterval arg_bounds = recurse(op->value);
switch (op->op) {
case VectorReduce::Add:
return arg_bounds * factor;
case VectorReduce::SaturatingAdd:
return saturating_cast(op->type, arg_bounds * factor);
case VectorReduce::Min:
case VectorReduce::Max:
case VectorReduce::And:
case VectorReduce::Or:
return arg_bounds;
default:;
}
} else if (const Shuffle *op = e.as<Shuffle>()) {
ConstantInterval arg_bounds = recurse(op->vectors[0]);
for (size_t i = 1; i < op->vectors.size(); i++) {
arg_bounds.include(recurse(op->vectors[i]));
}
return arg_bounds;
} else if (const Let *op = e.as<Let>()) {
ScopedBinding bind(scope, op->name, recurse(op->value));
return recurse(op->body);
} else if (const Call *op = e.as<Call>()) {
ConstantInterval result;
if (op->is_intrinsic(Call::abs)) {
return abs(recurse(op->args[0]));
} else if (op->is_intrinsic(Call::absd)) {
return abs(recurse(op->args[0]) - recurse(op->args[1]));
} else if (op->is_intrinsic(Call::count_leading_zeros) ||
op->is_intrinsic(Call::count_trailing_zeros)) {
// Conservatively just say it's the potential number of zeros in the type.
return ConstantInterval(0, op->args[0].type().bits());
} else if (op->is_intrinsic(Call::halving_add)) {
return (recurse(op->args[0]) + recurse(op->args[1])) / 2;
} else if (op->is_intrinsic(Call::halving_sub)) {
return (recurse(op->args[0]) - recurse(op->args[1])) / 2;
} else if (op->is_intrinsic(Call::rounding_halving_add)) {
return (recurse(op->args[0]) + recurse(op->args[1]) + 1) / 2;
} else if (op->is_intrinsic(Call::saturating_add)) {
return saturating_cast(op->type,
(recurse(op->args[0]) +
recurse(op->args[1])));
} else if (op->is_intrinsic(Call::saturating_sub)) {
return saturating_cast(op->type,
(recurse(op->args[0]) -
recurse(op->args[1])));
} else if (op->is_intrinsic({Call::widening_add, Call::widen_right_add})) {
return recurse(op->args[0]) + recurse(op->args[1]);
} else if (op->is_intrinsic({Call::widening_sub, Call::widen_right_sub})) {
return recurse(op->args[0]) - recurse(op->args[1]);
} else if (op->is_intrinsic({Call::widening_mul, Call::widen_right_mul})) {
return recurse(op->args[0]) * recurse(op->args[1]);
} else if (op->is_intrinsic({Call::shift_right, Call::widening_shift_right})) {
return recurse(op->args[0]) >> recurse(op->args[1]);
} else if (op->is_intrinsic({Call::shift_left, Call::widening_shift_left})) {
return recurse(op->args[0]) << recurse(op->args[1]);
} else if (op->is_intrinsic({Call::rounding_shift_right, Call::rounding_shift_left})) {
ConstantInterval ca = recurse(op->args[0]);
ConstantInterval cb = recurse(op->args[1]);
if (op->is_intrinsic(Call::rounding_shift_left)) {
cb = -cb;
}
ConstantInterval rounding_term = 1 << (cb - 1);
// Note if cb is <= 0, rounding_term is zero.
return (ca + rounding_term) >> cb;
} else if (op->is_intrinsic(Call::mul_shift_right)) {
ConstantInterval ca = recurse(op->args[0]);
ConstantInterval cb = recurse(op->args[1]);
ConstantInterval cq = recurse(op->args[2]);
return (ca * cb) >> cq;
} else if (op->is_intrinsic(Call::rounding_mul_shift_right)) {
ConstantInterval ca = recurse(op->args[0]);
ConstantInterval cb = recurse(op->args[1]);
ConstantInterval cq = recurse(op->args[2]);
ConstantInterval rounding_term = 1 << (cq - 1);
return (ca * cb + rounding_term) >> cq;
}
// If you add a new intrinsic here, also add it to the expression
// generator in test/correctness/lossless_cast.cpp
}

return ConstantInterval::bounds_of_type(e.type());
};

auto get_typed_bounds = [&]() {
return cast(e.type(), get_infinite_bounds());
};

ConstantInterval ret;
if (cache) {
auto [it, cache_miss] = cache->try_emplace(e);
if (cache_miss) {
it->second = get_typed_bounds();
}
ret = it->second;
} else {
ret = get_typed_bounds();
}

internal_assert((!ret.min_defined || e.type().can_represent(ret.min)) &&
(!ret.max_defined || e.type().can_represent(ret.max)))
<< "constant_bounds returned defined bounds that are not representable in "
<< "the type of the Expr passed in.\n Expr: " << e << "\n Bounds: " << ret;

return ret;
}
} // namespace

ConstantInterval constant_integer_bounds(const Expr &e,
const Scope<ConstantInterval> &scope,
std::map<Expr, ConstantInterval, ExprCompare> *cache) {
Scope<ConstantInterval> sub_scope;
sub_scope.set_containing_scope(&scope);
return bounds_helper(e, sub_scope, cache);
}

} // namespace Internal
} // namespace Halide
35 changes: 35 additions & 0 deletions src/ConstantBounds.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#ifndef HALIDE_CONSTANT_BOUNDS_H
#define HALIDE_CONSTANT_BOUNDS_H

#include "ConstantInterval.h"
#include "Expr.h"
#include "Scope.h"

/** \file
* Methods for computing compile-time constant int64_t upper and lower bounds of
* an expression. Cheaper than symbolic bounds inference, and useful for things
* like instruction selection.
*/

namespace Halide {
namespace Internal {

/** Deduce constant integer bounds on an expression. This can be useful to
* decide if, for example, the expression can be cast to another type, be
* negated, be incremented, etc without risking overflow.
*
* Also optionally accepts a scope containing the integer bounds of any
* variables that may be referenced, and a cache of constant integer bounds on
* known Exprs, which this function will update. The cache is helpful to
* short-circuit large numbers of redundant queries, but it should not be used
* in contexts where the same Expr object may take on different values within a
* single Expr (i.e. before uniquify_variable_names).
*/
ConstantInterval constant_integer_bounds(const Expr &e,
const Scope<ConstantInterval> &scope = Scope<ConstantInterval>::empty_scope(),
std::map<Expr, ConstantInterval, ExprCompare> *cache = nullptr);

} // namespace Internal
} // namespace Halide

#endif
Loading
Loading