Commit e42682b
authored
Reduce Peak WAN inference VRAM usage (#9898)
* flux: Do the xq and xk ropes one at a time
This was doing independendent interleaved tensor math on the q and k
tensors, leading to the holding of more than the minimum intermediates
in VRAM. On a bad day, it would VRAM OOM on xk intermediates.
Do everything q and then everything k, so torch can garbage collect
all of qs intermediates before k allocates its intermediates.
This reduces peak VRAM usage for some WAN2.2 inferences (at least).
* wan: Optimize qkv intermediates on attention
As commented. The former logic computed independent pieces of QKV in
parallel which help more inference intermediates in VRAM spiking
VRAM usage. Fully roping Q and garbage collecting the intermediates
before touching K reduces the peak inference VRAM usage.1 parent a39ac59 commit e42682b
2 files changed
+17
-14
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
35 | 35 | | |
36 | 36 | | |
37 | 37 | | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
38 | 42 | | |
39 | 43 | | |
40 | | - | |
41 | | - | |
42 | | - | |
43 | | - | |
44 | | - | |
45 | | - | |
| 44 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
8 | 8 | | |
9 | 9 | | |
10 | 10 | | |
11 | | - | |
| 11 | + | |
12 | 12 | | |
13 | 13 | | |
14 | 14 | | |
| |||
60 | 60 | | |
61 | 61 | | |
62 | 62 | | |
63 | | - | |
64 | | - | |
| 63 | + | |
65 | 64 | | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
66 | 68 | | |
67 | | - | |
68 | | - | |
| 69 | + | |
69 | 70 | | |
70 | | - | |
71 | | - | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
72 | 76 | | |
73 | 77 | | |
74 | 78 | | |
75 | 79 | | |
76 | | - | |
| 80 | + | |
77 | 81 | | |
78 | 82 | | |
79 | 83 | | |
| |||
0 commit comments