Commit de0db16
refactor using encapsulate and decapsulate (#2174)
Summary:
Pull Request resolved: #2174
# context
1. we plan to consolidate the `non_strict_forward` into the TorchRec IR serializer, and this diff is to implement the new set of APIs for testing and migration.
**a.** current TorchRec IR APIs are `serialize_embedding_modules` and `deserialize_embedding_modules`.
**b.** the new set of APIs are named `encapsulate_ir_modules` and `decapsulate_ir_modules`.
**c.** two sets of APIs basically take in the same arguments.
3. the major differences are that:
**a.** in the new `encapsulate_ir_modules`, the TorchRec (embedding) modules' forward functions are replaced by a `ir_meta_forward` via the corresponding serializer. While in the old API flow the `non_strict_forward` is embedded within the TorchRec module and hidden behind a `is_non_strict_export` flag.
**b.** in the new `decapsulate_ir_modules`, it takes in the `unflatten_module` which is unflattened by `torch.export.unflatten`. While in the old API, it takes in the `ExportedProgram` and run the `torch.export.unflatten` function inside the TorchRec IR's `deserialize_embedding_modules` function, which introduced extra coupling.
3. This diff does **NOT** affect the original APIs.
4. Embedding modules including EBC, PEA, VLE, fpEBC have been tested with dynamic shape support.
# details
* schema
{F1733431041} {F1733431375}
* static ir_custom_op definition, **dynamic shape is supported**
```
torch.library.custom_op("torchrec::ir_custom_op", mutates_args={})
def ir_custom_op_impl(
tensors: List[Optional[torch.Tensor]], batch_size: int, dim: int
) -> torch.Tensor: # when multiple output tensor is needed, we can simply do a torch.split
device = None
for t in tensors:
if t is not None:
device = t.device
break
logger.info(f"torch.ops.torchrec.ir_custom_op -> ({batch_size}, {dim})")
return torch.empty(batch_size, dim, device=device)
```
* The previous `custom_op` is unified into a static operator named `ir_custom_op`, which does not require registering on the fly. So there is no need to call a helper function `register_custom_ops_for_nodes` when unflatten an exported IR artifact.
* It decouples the `non_strict_export_forward` function in each TorchRec module, and avoids the complexity in the original module's forward logic as below, and the `non_strict_export` flag in `torch.export.export` is no longer a necessary flag.
```
def forward(
self,
id_list_features: KeyedJaggedTensor,
id_score_list_features: Optional[KeyedJaggedTensor] = None,
) -> Dict[str, torch.Tensor]:
if is_non_strict_exporting() and not torch.jit.is_scripting():
return self._non_strict_exporting_forward( # <---- this create a shortcut to the original forward function
id_list_features, id_score_list_features
)
... # the actual forward function
```
Reviewed By: dstaay-fb, PaulZhang12
Differential Revision: D59019375
fbshipit-source-id: ef0e539a7e0d11a206fb62485071e9e9cf2887bc1 parent a62de86 commit de0db16
File tree
4 files changed
+192
-42
lines changed- torchrec/ir
- tests
4 files changed
+192
-42
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
8 | 8 | | |
9 | 9 | | |
10 | 10 | | |
11 | | - | |
12 | 11 | | |
13 | 12 | | |
14 | 13 | | |
| |||
23 | 22 | | |
24 | 23 | | |
25 | 24 | | |
| 25 | + | |
26 | 26 | | |
27 | 27 | | |
28 | 28 | | |
| |||
32 | 32 | | |
33 | 33 | | |
34 | 34 | | |
| 35 | + | |
| 36 | + | |
35 | 37 | | |
36 | 38 | | |
37 | 39 | | |
| |||
82 | 84 | | |
83 | 85 | | |
84 | 86 | | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
85 | 107 | | |
86 | 108 | | |
87 | 109 | | |
| |||
163 | 185 | | |
164 | 186 | | |
165 | 187 | | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
| 191 | + | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
166 | 242 | | |
167 | 243 | | |
168 | 244 | | |
169 | 245 | | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
170 | 252 | | |
171 | 253 | | |
172 | 254 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
18 | 18 | | |
19 | 19 | | |
20 | 20 | | |
21 | | - | |
| 21 | + | |
| 22 | + | |
22 | 23 | | |
23 | | - | |
24 | 24 | | |
25 | 25 | | |
26 | 26 | | |
| |||
30 | 30 | | |
31 | 31 | | |
32 | 32 | | |
33 | | - | |
34 | 33 | | |
35 | 34 | | |
36 | 35 | | |
| |||
183 | 182 | | |
184 | 183 | | |
185 | 184 | | |
186 | | - | |
| 185 | + | |
187 | 186 | | |
188 | 187 | | |
189 | 188 | | |
| |||
199 | 198 | | |
200 | 199 | | |
201 | 200 | | |
202 | | - | |
203 | | - | |
204 | | - | |
205 | | - | |
206 | | - | |
207 | | - | |
208 | | - | |
209 | | - | |
210 | | - | |
211 | | - | |
212 | | - | |
213 | | - | |
214 | | - | |
215 | | - | |
216 | | - | |
217 | | - | |
218 | | - | |
219 | | - | |
220 | | - | |
221 | | - | |
222 | | - | |
223 | | - | |
224 | | - | |
225 | | - | |
226 | | - | |
227 | | - | |
228 | 201 | | |
229 | | - | |
| 202 | + | |
| 203 | + | |
230 | 204 | | |
231 | 205 | | |
232 | 206 | | |
| |||
265 | 239 | | |
266 | 240 | | |
267 | 241 | | |
268 | | - | |
269 | | - | |
270 | 242 | | |
| 243 | + | |
271 | 244 | | |
272 | 245 | | |
273 | 246 | | |
| |||
292 | 265 | | |
293 | 266 | | |
294 | 267 | | |
295 | | - | |
| 268 | + | |
296 | 269 | | |
297 | 270 | | |
298 | 271 | | |
| |||
311 | 284 | | |
312 | 285 | | |
313 | 286 | | |
314 | | - | |
| 287 | + | |
| 288 | + | |
315 | 289 | | |
316 | 290 | | |
317 | 291 | | |
| |||
330 | 304 | | |
331 | 305 | | |
332 | 306 | | |
333 | | - | |
| 307 | + | |
334 | 308 | | |
335 | 309 | | |
336 | 310 | | |
| |||
345 | 319 | | |
346 | 320 | | |
347 | 321 | | |
348 | | - | |
349 | | - | |
| 322 | + | |
| 323 | + | |
| 324 | + | |
350 | 325 | | |
351 | 326 | | |
352 | 327 | | |
| |||
408 | 383 | | |
409 | 384 | | |
410 | 385 | | |
411 | | - | |
| 386 | + | |
412 | 387 | | |
413 | 388 | | |
414 | 389 | | |
| |||
424 | 399 | | |
425 | 400 | | |
426 | 401 | | |
427 | | - | |
428 | | - | |
| 402 | + | |
| 403 | + | |
429 | 404 | | |
430 | 405 | | |
431 | 406 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
47 | 47 | | |
48 | 48 | | |
49 | 49 | | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
9 | 9 | | |
10 | 10 | | |
11 | 11 | | |
| 12 | + | |
12 | 13 | | |
13 | 14 | | |
14 | 15 | | |
| |||
24 | 25 | | |
25 | 26 | | |
26 | 27 | | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
27 | 105 | | |
28 | 106 | | |
29 | 107 | | |
| |||
0 commit comments