|
| 1 | +import torch |
1 | 2 | import torch.nn as nn |
2 | 3 |
|
3 | | -from diffsynth_engine.models.basic.transformer_helper import RMSNorm |
4 | | -from diffsynth_engine.models.basic.relative_position_emb import RelativePositionEmbedding |
5 | | - |
6 | | - |
7 | | -SUPPORTED_OFFLOAD_MODULES = ( |
8 | | - nn.Embedding, |
9 | | - nn.Linear, |
10 | | - nn.LayerNorm, |
11 | | - nn.Conv2d, |
12 | | - nn.GroupNorm, |
13 | | - RMSNorm, |
14 | | - RelativePositionEmbedding, |
15 | | -) |
16 | | - |
17 | 4 |
|
18 | 5 | def enable_sequential_cpu_offload(module: nn.Module, device: str = "cuda:0"): |
19 | | - if isinstance(module, SUPPORTED_OFFLOAD_MODULES): |
20 | | - add_cpu_offload_hook(module, device) |
| 6 | + if len(list(module.children())) == 0: |
| 7 | + if len(list(module.parameters())) > 0: # leaf module with parameters |
| 8 | + add_cpu_offload_hook(module, device) |
21 | 9 | return |
| 10 | + if len(list(module.parameters(recurse=False))) > 0: # module with direct parameters |
| 11 | + add_cpu_offload_hook(module, device, recurse=False) |
22 | 12 | for submodule in module.children(): |
23 | 13 | enable_sequential_cpu_offload(submodule, device) |
24 | 14 |
|
25 | 15 |
|
26 | | -def add_cpu_offload_hook(module: nn.Module, device: str = "cuda:0"): |
| 16 | +# TODO: supports module buffer |
| 17 | +def add_cpu_offload_hook(module: nn.Module, device: str = "cuda:0", recurse: bool = True): |
27 | 18 | def _forward_pre_hook(module: nn.Module, input): |
28 | 19 | offload_params = {} |
29 | | - for name, param in module.named_parameters(): |
| 20 | + for name, param in module.named_parameters(recurse=recurse): |
30 | 21 | offload_params[name] = param.data |
31 | 22 | param.data = param.data.to(device=device) |
32 | 23 | setattr(module, "_offload_params", offload_params) |
| 24 | + return tuple(x.to(device=device) if isinstance(x, torch.Tensor) else x for x in input) |
33 | 25 |
|
34 | 26 | def _forward_hook(module: nn.Module, input, output): |
35 | 27 | offload_params = getattr(module, "_offload_params", {}) |
36 | | - for name, param in module.named_parameters(): |
| 28 | + for name, param in module.named_parameters(recurse=recurse): |
37 | 29 | if name in offload_params: |
38 | 30 | param.data = offload_params[name] |
39 | 31 |
|
40 | | - if getattr(module, "_sequential_cpu_offload_enabled", False): |
| 32 | + if getattr(module, "_cpu_offload_enabled", False): |
41 | 33 | return |
42 | 34 | module.register_forward_pre_hook(_forward_pre_hook) |
43 | 35 | module.register_forward_hook(_forward_hook) |
44 | | - setattr(module, "_sequential_cpu_offload_enabled", True) |
| 36 | + setattr(module, "_cpu_offload_enabled", True) |
0 commit comments