Commit c680107
authored
Avoid nn.Linear decomposition by replacing view -> mm -> view with einsum (#26)
* [WIP] Replace view -> mm -> view with matmul
This tries to support CP-style sharding, by overcoming a limitation of DTensor. Doesn't yet work as _mm_strategy is failing
* Fix matmul propagation rule
Somethings are starting to work, but we are not yet there
* Move function to graph_utils.py
* Pull improvements from #29
* Fix equation for einsum
* Cleanup code now that PyTorch has fixed _gen_einsum_strategies
Requires pytorch/pytorch#157593
* Generalize to more than 3d
* Generalize backward pass as well and make everything call into einsum
* Add note about future work
* Add einsum flops and generalize creation of sharded tensors
Before this, if we had a list of tensors we wouldn't shard the tensors inside the list
* Disable erroneous sdpa rule from backward
* Account for compute cost in collectives as well
This removes a long-standing hack to tell the solver that S(1) -> R is more expensive than S(0) -> R because of an additional data movement
* Account for compute cost in collectives as well
This removes a long-standing hack to tell the solver that S(1) -> R is more expensive than S(0) -> R because of an additional data movement
* Support getitem as well
* Improve comments and suppose 80% efficiency
* Suppose 70% efficiency for comms
* Add comment and set it to false by default
* Revert changes from another PR
* Add spaces back1 parent 8737c5a commit c680107
File tree
4 files changed
+133
-1
lines changed- autoparallel
4 files changed
+133
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
29 | 29 | | |
30 | 30 | | |
31 | 31 | | |
| 32 | + | |
32 | 33 | | |
33 | 34 | | |
34 | 35 | | |
| |||
37 | 38 | | |
38 | 39 | | |
39 | 40 | | |
| 41 | + | |
| 42 | + | |
40 | 43 | | |
41 | 44 | | |
42 | 45 | | |
| |||
230 | 233 | | |
231 | 234 | | |
232 | 235 | | |
| 236 | + | |
| 237 | + | |
233 | 238 | | |
234 | 239 | | |
235 | 240 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
8 | 8 | | |
9 | 9 | | |
10 | 10 | | |
11 | | - | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
12 | 39 | | |
13 | 40 | | |
14 | 41 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
153 | 153 | | |
154 | 154 | | |
155 | 155 | | |
| 156 | + | |
| 157 | + | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
| 191 | + | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
761 | 761 | | |
762 | 762 | | |
763 | 763 | | |
| 764 | + | |
| 765 | + | |
| 766 | + | |
| 767 | + | |
| 768 | + | |
| 769 | + | |
| 770 | + | |
| 771 | + | |
| 772 | + | |
| 773 | + | |
| 774 | + | |
| 775 | + | |
| 776 | + | |
| 777 | + | |
| 778 | + | |
| 779 | + | |
| 780 | + | |
| 781 | + | |
| 782 | + | |
| 783 | + | |
| 784 | + | |
| 785 | + | |
0 commit comments