Description
So for context, I'm implementing LLaMA 3 using Halide + the Python bindings. I have a NumPy reference implementation that I use to check correctness. This is hugely convenient: I can convert each np.ndarray directly into a Halide Buffer, realize, and compare.
To schedule, I compile to the conceptual statement and iterate by recompiling after each schedule change. However, once I used the official LLaMA model dimensions for my Buffers (i.e. matrices of 3072x2048), my compile times exploded. Interestingly, the slowdown only affected the conceptual statement. How slow? I'm not sure: the compiler never finished before I terminated the process, but easily over 10 minutes. Regular JIT compilation was slower as well, but tolerable in absolute time.
The Culprit
I've tracked the compile-time problem to this loop in src/StmtToHTML.cpp at line 563:

```cpp
while (getline(asm_stream, line)) {
// Try all markers
std::vector<uint64_t> matched_nodes;
for (auto const &[node, marker] : markers) {
if (std::regex_search(line, marker)) {
// Save line number
lnos[node] = lno;
// Save this node's id
matched_nodes.push_back(node);
}
}
// We map to the first match, stop
// checking matched nodes
for (auto const &node : matched_nodes) {
markers.erase(node);
}
lno++;
}
```

The issue is a combination of several factors:
- The embedded buffers are huge, GBs of total data (a single 3072x2048 float32 matrix is already ~24 MiB, and the model has many of them).
- The buffers are embedded into the output textual assembly as `.ascii` sections (`asm_stream` in the source).
- The inner `regex_search` loop grinds through those GBs of data and appears to hang.
- There is a filter for these large `.ascii` sections (shown below) so they don't appear in the browser; however, it is applied too late in the pipeline to help:

```cpp
for (std::string line; std::getline(ss, line);) {
if (line.length() > 500) {
// Very long lines in the assembly are typically the _gpu_kernel_sources
// as a raw ASCII block in the assembly. Let's chop that off to make
// browsers faster when dealing with this.
line = line.substr(0, 100) + "\" # omitted the remainder of the ASCII buffer";
}
stream << html_code_printer.escape_html(line) << "\n";
}
```

Finally, there are two other downstream performance bottlenecks. The first is in src/LLVM_Output.cpp at line 364:

```cpp
// Work on a copy of the module to avoid modifying the original.
std::unique_ptr<llvm::Module> module = clone_module(module_in);
```

This appears to make a deep copy of the module, including the embedded buffers, which causes a considerable slowdown when they are huge. Secondly, LLVM itself is impacted by the size of the buffers, although the exact issue is unclear: the built-in timer for the old pass manager doesn't show any slowdown. It's also unclear why this compilation is so much slower than regular JIT compilation; something about the way the passes are configured for the conceptual statement causes it.
Reproduction
- Head of the main branch on GitHub as of the time of writing.
- macOS on an Apple M1.
- Python 3.13.
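For concreteness, here is a minimal sketch of the kind of pipeline that triggers this, written against the C++ API rather than the Python bindings I'm actually using, with a single weight matrix standing in for the full model (names and sizes are illustrative):

```cpp
#include "Halide.h"

using namespace Halide;

int main() {
    // A large constant buffer standing in for one LLaMA weight matrix.
    // Because it is a Buffer rather than an ImageParam, its contents are
    // baked into the compiled module as data.
    Buffer<float> weights(3072, 2048);
    weights.fill(0.5f);

    Var x("x"), y("y");
    Func out("out");
    out(x, y) = weights(x, y) * 2.0f;

    // Emitting the HTML Stmt output walks the generated assembly, which
    // now contains the buffer contents as giant .ascii lines.
    out.compile_to_lowered_stmt("out.stmt.html", {}, HTML);
    return 0;
}
```

Scale this up to a full set of LLaMA weights and the StmtToHTML loop above never finishes.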
Discussion of Solutions
One could argue that users should never run into this problem in the first place and should use ImageParam instead of Buffer; however, this is still semantically valid Halide that is virtually uncompilable. One could simply apply the line-length filter earlier in the pipeline. That does not address the follow-up performance issues of the IR deep copy + LLVM compile times, but because those contribute far less to the slowdown, it might make sense to simply avoid the worst part (i.e. getline + regex_search over giant lines) and warn users that they may still experience slow compile times.
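To illustrate, a minimal untested sketch of such a guard, reusing the 500-character threshold from the existing browser filter (the threshold and placement are my guesses, not a vetted patch):

```cpp
while (getline(asm_stream, line)) {
    // Giant lines are almost certainly embedded .ascii buffer data that
    // no marker could match, so skip the regexes entirely for them.
    if (line.length() > 500) {
        lno++;
        continue;
    }
    // Try all markers
    std::vector<uint64_t> matched_nodes;
    for (auto const &[node, marker] : markers) {
        if (std::regex_search(line, marker)) {
            lnos[node] = lno;
            matched_nodes.push_back(node);
        }
    }
    // We map to the first match, stop checking matched nodes
    for (auto const &node : matched_nodes) {
        markers.erase(node);
    }
    lno++;
}
```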
Lastly, there is the question of how to steer users away from this performance cliff altogether.
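For what it's worth, the user-side escape hatch today is the ImageParam route mentioned above: the weights become a runtime input rather than a compile-time constant, so nothing large is embedded in the module. A sketch, assuming the same trivial pipeline as before:

```cpp
#include "Halide.h"

using namespace Halide;

int main() {
    // Weights are now a pipeline parameter; only their type and
    // dimensionality are known at compile time.
    ImageParam weights(Float(32), 2, "weights");

    Var x("x"), y("y");
    Func out("out");
    out(x, y) = weights(x, y) * 2.0f;

    // Bind the actual data at realize time.
    Buffer<float> w(3072, 2048);
    w.fill(0.5f);
    weights.set(w);
    Buffer<float> result = out.realize({3072, 2048});
    return 0;
}
```

The downside is losing any constant-folding over the weights, but compile times stay sane.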