-
Notifications
You must be signed in to change notification settings - Fork 29
pin memory bug fix #141
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
pin memory bug fix #141
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,7 +1,7 @@ | ||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||
| import torch.nn as nn | ||||||||||||||||||||||||||||||||||||||||
| from typing import Dict | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| import platform | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def enable_sequential_cpu_offload(module: nn.Module, device: str = "cuda"): | ||||||||||||||||||||||||||||||||||||||||
| module = module.to("cpu") | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -26,13 +26,14 @@ def _forward_pre_hook(module: nn.Module, input_): | |||||||||||||||||||||||||||||||||||||||
| for name, buffer in module.named_buffers(recurse=recurse): | ||||||||||||||||||||||||||||||||||||||||
| buffer.data = buffer.data.to(device=device) | ||||||||||||||||||||||||||||||||||||||||
| return tuple(x.to(device=device) if isinstance(x, torch.Tensor) else x for x in input_) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| for name, param in module.named_parameters(recurse=recurse): | ||||||||||||||||||||||||||||||||||||||||
| param.data = param.data.pin_memory() | ||||||||||||||||||||||||||||||||||||||||
| for name, param in module.named_parameters(recurse=recurse): | ||||||||||||||||||||||||||||||||||||||||
| if platform.system() == 'Linux': | ||||||||||||||||||||||||||||||||||||||||
| param.data = param.data.pin_memory() | ||||||||||||||||||||||||||||||||||||||||
| offload_param_dict[name] = param.data | ||||||||||||||||||||||||||||||||||||||||
| param.data = param.data.to(device=device) | ||||||||||||||||||||||||||||||||||||||||
| for name, buffer in module.named_buffers(recurse=recurse): | ||||||||||||||||||||||||||||||||||||||||
| buffer.data = buffer.data.pin_memory() | ||||||||||||||||||||||||||||||||||||||||
| if platform.system() == 'Linux': | ||||||||||||||||||||||||||||||||||||||||
| buffer.data = buffer.data.pin_memory() | ||||||||||||||||||||||||||||||||||||||||
| offload_param_dict[name] = buffer.data | ||||||||||||||||||||||||||||||||||||||||
| buffer.data = buffer.data.to(device=device) | ||||||||||||||||||||||||||||||||||||||||
| setattr(module, "_offload_param_dict", offload_param_dict) | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -58,10 +59,12 @@ def offload_model_to_dict(module: nn.Module) -> Dict[str, torch.Tensor]: | |||||||||||||||||||||||||||||||||||||||
| module = module.to("cpu") | ||||||||||||||||||||||||||||||||||||||||
| offload_param_dict = {} | ||||||||||||||||||||||||||||||||||||||||
| for name, param in module.named_parameters(recurse=True): | ||||||||||||||||||||||||||||||||||||||||
| param.data = param.data.pin_memory() | ||||||||||||||||||||||||||||||||||||||||
| if platform.system() == 'Linux': | ||||||||||||||||||||||||||||||||||||||||
| param.data = param.data.pin_memory() | ||||||||||||||||||||||||||||||||||||||||
| offload_param_dict[name] = param.data | ||||||||||||||||||||||||||||||||||||||||
| for name, buffer in module.named_buffers(recurse=True): | ||||||||||||||||||||||||||||||||||||||||
| buffer.data = buffer.data.pin_memory() | ||||||||||||||||||||||||||||||||||||||||
| if platform.system() == 'Linux': | ||||||||||||||||||||||||||||||||||||||||
| buffer.data = buffer.data.pin_memory() | ||||||||||||||||||||||||||||||||||||||||
| offload_param_dict[name] = buffer.data | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
61
to
68
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the previous comment, Also, there are some minor style issues: there are trailing whitespaces and two spaces around
Suggested change
|
||||||||||||||||||||||||||||||||||||||||
| return offload_param_dict | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Calling
platform.system()inside loops is inefficient as the result will not change during execution. It's better to call it once before the loops and store the result in a variable. This will improve performance by avoiding redundant system calls.Additionally, there are some minor style issues: there are trailing whitespaces and two spaces around
==instead of one (platform.system() == 'Linux').