[Algs] Add new inplace sparse communication algorithm #1510
base: main
Summary of Changes
Hello @tdavidcl, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request integrates a new in-place sparse communication algorithm into the
Highlights
Code Review
This pull request introduces a new in-place sparse communication algorithm, which is a significant piece of functionality. The implementation includes building a communication table and managing MPI requests. The changes are well-structured and include comprehensive tests. I've found a few issues: two critical bugs in buffer size validation that could lead to memory corruption, and a couple of medium-severity logic and API design issues in the request handling and communication table construction. My detailed comments and suggestions for fixes are provided below.
if (comm_table.send_total_size < bytebuffer_send.get_size()) {
    throw shambase::make_except_with_loc<std::invalid_argument>(shambase::format(
        "The send total size is greater than the send buffer size\n"
        " send_total_size = {}, send_buffer_size = {}",
        comm_table.send_total_size,
        bytebuffer_send.get_size()));
}
The buffer size check is incorrect. It should throw an exception if comm_table.send_total_size is greater than bytebuffer_send.get_size(). The current check comm_table.send_total_size < bytebuffer_send.get_size() is for a valid condition where the buffer is larger than needed. The error message correctly describes the condition that should be checked (is greater than), but the if condition is the opposite. This is a critical bug that can lead to buffer overflows.
- if (comm_table.send_total_size < bytebuffer_send.get_size()) {
-     throw shambase::make_except_with_loc<std::invalid_argument>(shambase::format(
-         "The send total size is greater than the send buffer size\n"
-         " send_total_size = {}, send_buffer_size = {}",
-         comm_table.send_total_size,
-         bytebuffer_send.get_size()));
- }
+ if (comm_table.send_total_size > bytebuffer_send.get_size()) {
+     throw shambase::make_except_with_loc<std::invalid_argument>(shambase::format(
+         "The send total size is greater than the send buffer size\n"
+         " send_total_size = {}, send_buffer_size = {}",
+         comm_table.send_total_size,
+         bytebuffer_send.get_size()));
+ }
if (comm_table.recv_total_size < bytebuffer_recv.get_size()) {
    throw shambase::make_except_with_loc<std::invalid_argument>(shambase::format(
        "The recv total size is greater than the recv buffer size\n"
        " recv_total_size = {}, recv_buffer_size = {}",
        comm_table.recv_total_size,
        bytebuffer_recv.get_size()));
}
Similar to the send buffer check, this size check for the receive buffer is incorrect. It should throw if comm_table.recv_total_size is greater than bytebuffer_recv.get_size(). The current condition is the opposite of what's needed to prevent a buffer overflow when receiving data.
- if (comm_table.recv_total_size < bytebuffer_recv.get_size()) {
-     throw shambase::make_except_with_loc<std::invalid_argument>(shambase::format(
-         "The recv total size is greater than the recv buffer size\n"
-         " recv_total_size = {}, recv_buffer_size = {}",
-         comm_table.recv_total_size,
-         bytebuffer_recv.get_size()));
- }
+ if (comm_table.recv_total_size > bytebuffer_recv.get_size()) {
+     throw shambase::make_except_with_loc<std::invalid_argument>(shambase::format(
+         "The recv total size is greater than the recv buffer size\n"
+         " recv_total_size = {}, recv_buffer_size = {}",
+         comm_table.recv_total_size,
+         bytebuffer_recv.get_size()));
+ }
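The send and recv checks discussed above share one rule: throw only when the communication table requires more bytes than the buffer holds. A minimal standalone sketch of that rule follows. The helper name `check_buffer_fits` and its parameters are hypothetical, and it uses plain `std::invalid_argument` rather than the `shambase` helpers from the PR.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>

// Hypothetical helper, not part of the PR: validates that a buffer of
// `buffer_size` bytes can hold the `total_size` bytes required by the
// communication table. `name` is "send" or "recv" and is only used to
// build the error message.
inline void check_buffer_fits(
    std::size_t total_size, std::size_t buffer_size, const std::string &name) {
    // Only the "too small" case is an error; a larger buffer is valid.
    if (total_size > buffer_size) {
        throw std::invalid_argument(
            "The " + name + " total size is greater than the " + name + " buffer size\n"
            + " total_size = " + std::to_string(total_size)
            + ", buffer_size = " + std::to_string(buffer_size));
    }
}
```

With this shape, an exactly sized buffer and an oversized buffer both pass, and only an undersized buffer throws, so the condition and the error message agree.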
shamcomm::mpi::Waitall(
    shambase::narrow_or_throw<i32>(rqs.size()), rqs.data(), st_lst.data());
ready_count = rqs.size();
is_ready.assign(rqs.size(), true);
After shamcomm::mpi::Waitall completes, all MPI requests are finished. The internal state of RequestList (ready_count and is_ready) should be updated to reflect this. By removing these lines, other member functions like remain_count_no_test() or all_ready() will return incorrect information if called after wait_all(). This makes the class state inconsistent and can lead to subtle bugs if this class is reused elsewhere. Please restore this state update.
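To illustrate why the bookkeeping matters, here is an MPI-free sketch of a request list. `MockRequestList` and `add_request` are hypothetical names, the `int` handles stand in for `MPI_Request`, and the blocking wait is only indicated by a comment. The member names `ready_count`, `is_ready`, `remain_count_no_test()` and `all_ready()` mirror those mentioned in the comment, but their real implementations may differ.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical, MPI-free stand-in for the request list discussed above.
struct MockRequestList {
    std::vector<int> rqs;        // placeholder for MPI_Request handles
    std::vector<bool> is_ready;  // per-request completion flag
    std::size_t ready_count = 0; // number of requests known to be complete

    void add_request(int handle) {
        rqs.push_back(handle);
        is_ready.push_back(false);
    }

    void wait_all() {
        // A real implementation would block here (e.g. in MPI_Waitall).
        // Once the wait returns, every request is complete, so the cached
        // state is updated to match.
        ready_count = rqs.size();
        is_ready.assign(rqs.size(), true);
    }

    // Queries that rely on the cached state only (no MPI test call).
    std::size_t remain_count_no_test() const { return rqs.size() - ready_count; }
    bool all_ready() const { return ready_count == rqs.size(); }
};
```

If `wait_all()` skipped the two assignments, `remain_count_no_test()` would keep reporting pending requests after the wait had returned.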
void spin_lock_partial_wait(size_t max_in_flight, f64 timeout, f64 print_freq) {

    if (rqs.size() < max_in_flight) {
        return;
    }

    f64 last_print_time = 0;
    size_t in_flight = remain_count();

    if (in_flight < max_in_flight) {
        return;
    }

    shambase::Timer twait;
    twait.start();
    do {
        twait.end();
        if (twait.elasped_sec() > timeout) {
            report_timeout();
        }

        if (twait.elasped_sec() - last_print_time > print_freq) {
            logger::warn_ln(
                "SparseComm",
                "too many messages in flight :",
                in_flight,
                "/",
                max_in_flight);
            last_print_time = twait.elasped_sec();
        }
        in_flight = remain_count();
    } while (in_flight >= max_in_flight);
}
The logic in this spin lock can be simplified and made more correct. The current implementation uses a stale value of in_flight for logging because it's updated at the end of the loop. This can be misleading during debugging. Additionally, there are some redundant checks before the loop. A while loop would be cleaner, more efficient, and ensure the logged information is always up-to-date.
void spin_lock_partial_wait(size_t max_in_flight, f64 timeout, f64 print_freq) {
    if (rqs.size() < max_in_flight) {
        return;
    }
    shambase::Timer twait;
    twait.start();
    f64 last_print_time = 0;
    size_t in_flight;
    while ((in_flight = remain_count()) >= max_in_flight) {
        twait.end();
        if (twait.elasped_sec() > timeout) {
            report_timeout();
        }
        if (twait.elasped_sec() - last_print_time > print_freq) {
            logger::warn_ln(
                "SparseComm",
                "too many messages in flight :",
                in_flight,
                "/",
                max_in_flight);
            last_print_time = twait.elasped_sec();
        }
    }
}

// the sender shoudl have set the offset for all messages, otherwise throw
auto expected_offset = shambase::get_check_ref(
    messages_send.at(send_idx).message_bytebuf_offset_send);

// check that the send offset match for good measure
if (message_info.message_bytebuf_offset_send != expected_offset) {
    throw shambase::make_except_with_loc<std::invalid_argument>(shambase::format(
        "The sender has not set the offset for all messages, otherwise throw\n"
        " expected_offset = {}, actual_offset = {}",
        expected_offset,
        message_info.message_bytebuf_offset_send.value()));
}
This check verifies that the message_bytebuf_offset_send provided in the input messages_send matches the offset calculated within this function. This makes the API fragile, as it requires the caller to pre-calculate and provide correct send offsets. A more robust design would be for this function to be solely responsible for calculating the offsets, ignoring any that might be provided in the input. This would simplify the caller's responsibility to only providing message sizes, senders, and receivers. Consider removing this check and making the function calculate the offsets authoritatively.
Additionally, there's a typo "shoudl" in the comment on line 143, and the error message on line 150 is unclear.
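As a rough illustration of the alternative design, the sketch below derives send offsets from message sizes alone, using an exclusive prefix sum. `MessageInfo`, `SendLayout` and `compute_send_offsets` are hypothetical names, and the sketch assumes all listed messages are sent by the local rank and are packed contiguously in input order. The PR's actual message structure and packing rules may differ.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical message descriptor: the caller provides only the ranks and
// the message size, never an offset.
struct MessageInfo {
    int sender;
    int receiver;
    std::size_t size; // message size in bytes
};

// Hypothetical result type: one offset per message plus the total size.
struct SendLayout {
    std::vector<std::size_t> offsets;
    std::size_t total_size = 0;
};

// Computes each message's offset in the send buffer as the sum of the sizes
// of all preceding messages (exclusive prefix sum).
inline SendLayout compute_send_offsets(const std::vector<MessageInfo> &messages) {
    SendLayout layout;
    layout.offsets.reserve(messages.size());
    std::size_t running = 0;
    for (const auto &msg : messages) {
        layout.offsets.push_back(running);
        running += msg.size;
    }
    layout.total_size = running;
    return layout;
}
```

Because the function is the only source of offsets, there is nothing for the caller to get wrong, and the resulting `total_size` can be compared directly against the send buffer size.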
Workflow report
Workflow report corresponding to commit 1174f0a
Pre-commit check report
Pre-commit check: ✅ Test pipeline can run.
Clang-tidy diff report
No relevant changes found. You should now go back to your normal life and enjoy a hopefully sunny day while waiting for the review.
Doxygen diff with
extracted from #1484