Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid redundant scope lookups #8103

Merged
merged 5 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 32 additions & 33 deletions src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,13 +406,12 @@ class Bounds : public IRVisitor {

if (const_bound) {
bounds_of_type(op->type);
if (scope.contains(op->name)) {
const Interval &scope_interval = scope.get(op->name);
if (scope_interval.has_upper_bound() && is_const(scope_interval.max)) {
interval.max = Interval::make_min(interval.max, scope_interval.max);
if (const Interval *scope_interval = scope.find(op->name)) {
if (scope_interval->has_upper_bound() && is_const(scope_interval->max)) {
interval.max = Interval::make_min(interval.max, scope_interval->max);
}
if (scope_interval.has_lower_bound() && is_const(scope_interval.min)) {
interval.min = Interval::make_max(interval.min, scope_interval.min);
if (scope_interval->has_lower_bound() && is_const(scope_interval->min)) {
interval.min = Interval::make_max(interval.min, scope_interval->min);
}
}

Expand All @@ -429,8 +428,8 @@ class Bounds : public IRVisitor {
}
}
} else {
if (scope.contains(op->name)) {
interval = scope.get(op->name);
if (const Interval *in = scope.find(op->name)) {
interval = *in;
} else if (op->type.is_vector()) {
// Uh oh, we need to take the min/max lane of some unknown vector. Treat as unbounded.
bounds_of_type(op->type);
Expand Down Expand Up @@ -2054,11 +2053,10 @@ class FindInnermostVar : public IRVisitor {
int innermost_depth = -1;

void visit(const Variable *op) override {
if (vars_depth.contains(op->name)) {
int depth = vars_depth.get(op->name);
if (depth > innermost_depth) {
if (const int *depth = vars_depth.find(op->name)) {
if (*depth > innermost_depth) {
innermost_var = op->name;
innermost_depth = depth;
innermost_depth = *depth;
}
}
}
Expand Down Expand Up @@ -2545,16 +2543,17 @@ class BoxesTouched : public IRGraphVisitor {
// If this let stmt is a redefinition of a previous one, we should
// remove the old let stmt from the 'children' map since it is
// no longer valid at this point.
if ((f.vi.instance > 0) && let_stmts.contains(op->name)) {
const Expr &val = let_stmts.get(op->name);
CollectVars collect(op->name);
val.accept(&collect);
f.old_let_vars = collect.vars;

VarInstance old_vi = VarInstance(f.vi.var, f.vi.instance - 1);
for (const auto &v : f.old_let_vars) {
internal_assert(vars_renaming.count(v));
children[get_var_instance(v)].erase(old_vi);
if (f.vi.instance > 0) {
if (const Expr *val = let_stmts.find(op->name)) {
CollectVars collect(op->name);
val->accept(&collect);
f.old_let_vars = collect.vars;

VarInstance old_vi = VarInstance(f.vi.var, f.vi.instance - 1);
for (const auto &v : f.old_let_vars) {
internal_assert(vars_renaming.count(v));
children[get_var_instance(v)].erase(old_vi);
}
}
}
let_stmts.push(op->name, op->value);
Expand Down Expand Up @@ -2756,17 +2755,17 @@ class BoxesTouched : public IRGraphVisitor {
expr_uses_var(box[i].min, l.min_name))) ||
(box[i].has_upper_bound() && (expr_uses_var(box[i].max, l.max_name) ||
expr_uses_var(box[i].max, l.min_name)))) {
internal_assert(let_stmts.contains(l.var));
const Expr &val = let_stmts.get(l.var);
v_bound = bounds_of_expr_in_scope(val, scope, func_bounds);
const Expr *val = let_stmts.find(l.var);
internal_assert(val);
v_bound = bounds_of_expr_in_scope(*val, scope, func_bounds);
bool fixed = v_bound.min.same_as(v_bound.max);
v_bound.min = simplify(v_bound.min);
v_bound.max = fixed ? v_bound.min : simplify(v_bound.max);

internal_assert(scope.contains(l.var));
const Interval &old_bound = scope.get(l.var);
v_bound.max = simplify(min(v_bound.max, old_bound.max));
v_bound.min = simplify(max(v_bound.min, old_bound.min));
const Interval *old_bound = scope.find(l.var);
internal_assert(old_bound);
v_bound.max = simplify(min(v_bound.max, old_bound->max));
v_bound.min = simplify(max(v_bound.min, old_bound->min));
}

if (box[i].has_lower_bound()) {
Expand Down Expand Up @@ -3017,14 +3016,14 @@ class BoxesTouched : public IRGraphVisitor {
}

Expr min_val, max_val;
if (scope.contains(op->name + ".loop_min")) {
min_val = scope.get(op->name + ".loop_min").min;
if (const Interval *in = scope.find(op->name + ".loop_min")) {
min_val = in->min;
} else {
min_val = bounds_of_expr_in_scope(op->min, scope, func_bounds).min;
}

if (scope.contains(op->name + ".loop_max")) {
max_val = scope.get(op->name + ".loop_max").max;
if (const Interval *in = scope.find(op->name + ".loop_max")) {
max_val = in->max;
} else {
max_val = bounds_of_expr_in_scope(op->extent, scope, func_bounds).max;
max_val += bounds_of_expr_in_scope(op->min, scope, func_bounds).max;
Expand Down
4 changes: 2 additions & 2 deletions src/CSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ class RemoveLets : public IRGraphMutator {
Scope<Expr> scope;

Expr visit(const Variable *op) override {
if (scope.contains(op->name)) {
return scope.get(op->name);
if (const Expr *e = scope.find(op->name)) {
return *e;
} else {
return op;
}
Expand Down
6 changes: 4 additions & 2 deletions src/ClampUnsafeAccesses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ struct ClampUnsafeAccesses : IRMutator {
}

Expr visit(const Variable *var) override {
if (is_inside_indexing && let_var_inside_indexing.contains(var->name)) {
let_var_inside_indexing.ref(var->name) = true;
if (is_inside_indexing) {
if (bool *b = let_var_inside_indexing.shallow_find(var->name)) {
*b = true;
}
}
return var;
}
Expand Down
5 changes: 3 additions & 2 deletions src/CodeGen_ARM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,14 @@ class SubstituteInStridedLoads : public IRMutator {
Expr visit(const Shuffle *op) override {
int stride = op->slice_stride();
const Variable *var = op->vectors[0].as<Variable>();
const Expr *vec = nullptr;
if (var &&
poisoned_vars.count(var->name) == 0 &&
op->vectors.size() == 1 &&
2 <= stride && stride <= 4 &&
op->slice_begin() < stride &&
loads.contains(var->name)) {
return Shuffle::make_slice({loads.get(var->name)}, op->slice_begin(), op->slice_stride(), op->type.lanes());
(vec = loads.find(var->name))) {
return Shuffle::make_slice({*vec}, op->slice_begin(), op->slice_stride(), op->type.lanes());
} else {
return IRMutator::visit(op);
}
Expand Down
5 changes: 3 additions & 2 deletions src/CodeGen_C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1936,8 +1936,9 @@ void CodeGen_C::visit(const Load *op) {
user_assert(is_const_one(op->predicate)) << "Predicated scalar load is not supported by C backend.\n";

string id_index = print_expr(op->index);
bool type_cast_needed = !(allocations.contains(op->name) &&
allocations.get(op->name).type.element_of() == t.element_of());
const auto *alloc = allocations.find(op->name);
bool type_cast_needed = !(alloc &&
alloc->type.element_of() == t.element_of());
if (type_cast_needed) {
const char *const_flag = output_kind == CPlusPlusImplementation ? " const" : "";
rhs << "((" << print_type(t.element_of()) << const_flag << " *)" << name << ")";
Expand Down
5 changes: 3 additions & 2 deletions src/CodeGen_D3D12Compute_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -592,8 +592,9 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Load *op) {
string id_index = print_expr(op->index);

// Get the rhs just for the cache.
bool type_cast_needed = !(allocations.contains(op->name) &&
allocations.get(op->name).type == op->type);
const auto *alloc = allocations.find(op->name);
bool type_cast_needed = !(alloc &&
alloc->type == op->type);

ostringstream rhs;
if (type_cast_needed) {
Expand Down
11 changes: 5 additions & 6 deletions src/CodeGen_Hexagon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ class SloppyUnpredicateLoadsAndStores : public IRMutator {
}
}
} else if (const Variable *op = e.as<Variable>()) {
if (monotonic_vectors.contains(op->name)) {
return monotonic_vectors.get(op->name);
if (const auto *p = monotonic_vectors.find(op->name)) {
return *p;
}
} else if (const Let *op = e.as<Let>()) {
auto v = get_extreme_lanes(op->value);
Expand Down Expand Up @@ -2245,10 +2245,9 @@ void CodeGen_Hexagon::visit(const Allocate *alloc) {
codegen(alloc->body);

// If there was no early free, free it now.
if (allocations.contains(alloc->name)) {
Allocation alloc_obj = allocations.get(alloc->name);
internal_assert(alloc_obj.destructor);
trigger_destructor(alloc_obj.destructor_function, alloc_obj.destructor);
if (const Allocation *alloc_obj = allocations.find(alloc->name)) {
internal_assert(alloc_obj->destructor);
trigger_destructor(alloc_obj->destructor_function, alloc_obj->destructor);

allocations.pop(alloc->name);
sym_pop(alloc->name);
Expand Down
5 changes: 3 additions & 2 deletions src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1268,7 +1268,8 @@ void CodeGen_LLVM::sym_pop(const string &name) {

llvm::Value *CodeGen_LLVM::sym_get(const string &name, bool must_succeed) const {
// look in the symbol table
if (!symbol_table.contains(name)) {
llvm::Value *const *v = symbol_table.find(name);
if (!v) {
if (must_succeed) {
std::ostringstream err;
err << "Symbol not found: " << name << "\n";
Expand All @@ -1283,7 +1284,7 @@ llvm::Value *CodeGen_LLVM::sym_get(const string &name, bool must_succeed) const
return nullptr;
}
}
return symbol_table.get(name);
return *v;
}

bool CodeGen_LLVM::sym_exists(const string &name) const {
Expand Down
9 changes: 5 additions & 4 deletions src/CodeGen_Metal_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,9 @@ void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const Load *op) {
string id_index = print_expr(op->index);

// Get the rhs just for the cache.
bool type_cast_needed = !(allocations.contains(op->name) &&
allocations.get(op->name).type == op->type);
const auto *alloc = allocations.find(op->name);
bool type_cast_needed = !(alloc &&
alloc->type == op->type);
ostringstream rhs;
if (type_cast_needed) {
rhs << "((" << get_memory_space(op->name) << " "
Expand Down Expand Up @@ -467,8 +468,8 @@ void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const Store *op) {
<< id_value << "[" << i << "];\n";
}
} else {
bool type_cast_needed = !(allocations.contains(op->name) &&
allocations.get(op->name).type == t);
const auto *alloc = allocations.find(op->name);
bool type_cast_needed = !(alloc && alloc->type == t);

string id_index = print_expr(op->index);
stream << get_indent();
Expand Down
8 changes: 4 additions & 4 deletions src/CodeGen_OpenCL_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,8 @@ string CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::print_array_access(const string &na
const Type &type,
const string &id_index) {
ostringstream rhs;
bool type_cast_needed = !(allocations.contains(name) &&
allocations.get(name).type == type);
const auto *alloc = allocations.find(name);
bool type_cast_needed = !(alloc && alloc->type == type);

if (type_cast_needed) {
rhs << "((" << get_memory_space(name) << " "
Expand Down Expand Up @@ -583,8 +583,8 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const Store *op) {
// For atomicAdd, we check if op->value - store[index] is independent of store.
// The atomicAdd operations in OpenCL only supports integers so we also check that.
bool is_atomic_add = t.is_int_or_uint() && !expr_uses_var(delta, op->name);
bool type_cast_needed = !(allocations.contains(op->name) &&
allocations.get(op->name).type == t);
const auto *alloc = allocations.find(op->name);
bool type_cast_needed = !(alloc && alloc->type == t);
auto print_store_var = [&]() {
if (type_cast_needed) {
stream << "(("
Expand Down
4 changes: 2 additions & 2 deletions src/CodeGen_Posix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ void CodeGen_Posix::free_allocation(const std::string &name) {
}

string CodeGen_Posix::get_allocation_name(const std::string &n) {
if (allocations.contains(n)) {
return allocations.get(n).name;
if (const auto *alloc = allocations.find(n)) {
return alloc->name;
} else {
return n;
}
Expand Down
28 changes: 15 additions & 13 deletions src/CodeGen_Vulkan_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1539,10 +1539,10 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Load *op) {
user_assert(is_const_one(op->predicate)) << "Predicated loads not supported by SPIR-V codegen\n";

// Construct the pointer to read from
internal_assert(symbol_table.contains(op->name));
SymbolIdStorageClassPair id_and_storage_class = symbol_table.get(op->name);
SpvId variable_id = id_and_storage_class.first;
SpvStorageClass storage_class = id_and_storage_class.second;
const SymbolIdStorageClassPair *id_and_storage_class = symbol_table.find(op->name);
internal_assert(id_and_storage_class);
SpvId variable_id = id_and_storage_class->first;
SpvStorageClass storage_class = id_and_storage_class->second;
internal_assert(variable_id != SpvInvalidId);
internal_assert(((uint32_t)storage_class) < ((uint32_t)SpvStorageClassMax));

Expand Down Expand Up @@ -1576,10 +1576,10 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Store *op) {
op->value.accept(this);
SpvId value_id = builder.current_id();

internal_assert(symbol_table.contains(op->name));
SymbolIdStorageClassPair id_and_storage_class = symbol_table.get(op->name);
SpvId variable_id = id_and_storage_class.first;
SpvStorageClass storage_class = id_and_storage_class.second;
const SymbolIdStorageClassPair *id_and_storage_class = symbol_table.find(op->name);
internal_assert(id_and_storage_class);
SpvId variable_id = id_and_storage_class->first;
SpvStorageClass storage_class = id_and_storage_class->second;
internal_assert(variable_id != SpvInvalidId);
internal_assert(((uint32_t)storage_class) < ((uint32_t)SpvStorageClassMax));

Expand Down Expand Up @@ -1665,9 +1665,10 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const For *op) {
const std::string intrinsic_var_name = std::string("k") + std::to_string(kernel_index) + std::string("_") + intrinsic.first;

// Intrinsics are inserted when adding the kernel
internal_assert(symbol_table.contains(intrinsic_var_name));
SpvId intrinsic_id = symbol_table.get(intrinsic_var_name).first;
SpvStorageClass storage_class = symbol_table.get(intrinsic_var_name).second;
const auto *intrin = symbol_table.find(intrinsic_var_name);
internal_assert(intrin);
SpvId intrinsic_id = intrin->first;
SpvStorageClass storage_class = intrin->second;

// extract and cast to the extent type (which is what's expected by Halide's for loops)
Type unsigned_type = UInt(32);
Expand Down Expand Up @@ -1908,8 +1909,9 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Allocate *op) {

void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Free *op) {
debug(3) << "Vulkan: Popping allocation called " << op->name << " off the symbol table\n";
internal_assert(symbol_table.contains(op->name));
SpvId variable_id = symbol_table.get(op->name).first;
const auto *id = symbol_table.find(op->name);
internal_assert(id);
SpvId variable_id = id->first;
storage_access_map.erase(variable_id);
symbol_table.pop(op->name);
}
Expand Down
8 changes: 4 additions & 4 deletions src/CodeGen_WebGPU_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,8 +684,8 @@ void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Load *op) {

// Get the allocation type, which may be different from the result type.
Type alloc_type = result_type;
if (allocations.contains(op->name)) {
alloc_type = allocations.get(op->name).type;
if (const auto *alloc = allocations.find(op->name)) {
alloc_type = alloc->type;
} else if (workgroup_allocations.count(op->name)) {
alloc_type = workgroup_allocations.at(op->name)->type;
}
Expand Down Expand Up @@ -826,8 +826,8 @@ void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Store *op) {

// Get the allocation type, which may be different from the value type.
Type alloc_type = value_type;
if (allocations.contains(op->name)) {
alloc_type = allocations.get(op->name).type;
if (const auto *alloc = allocations.find(op->name)) {
alloc_type = alloc->type;
} else if (workgroup_allocations.count(op->name)) {
alloc_type = workgroup_allocations.at(op->name)->type;
}
Expand Down
Loading
Loading