Commit daf92c2
implementation of fbgemm op - permute_multi_embedding (#2120)
Summary:
X-link: pytorch/FBGEMM#2738
# context
* current we have a working function `permute_pooled_embs_auto_grad` to do a full permute of KTs, including forward and backward
* it has several limitations:
a) it has to be a full permute, duplicates are not supported;
b) in the main [use case](https://fburl.com/code/89od0rqm) there has to be a torch.concat on the input KTs, which is not very efficient;
c) the function output a single KT which requires a split operation
* there is some attempt to support duplicated outputs, but the backward doesn't work
* this diff is trying to create a new kernel (named `permute_multi_embedding`) to support a multiple-KT to multiple-KT mapping operation with backward support
# notes
* this diff focuses on the implemenation and test of the operator
* performance analysis and benchmark are in the next diff
# operator example usage
* used in python
```
# test inputs: 3 KTs with batch_size=2048
batch_size = 2048
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[96, 256], [512, 128, 768], [1024]]
values = [
torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
for lens in lengths
]
# target outputs: 4 KTs with re-arranged keys (features), duplicates are allowed
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
# accessorial arguments to the op/kernel
permutes, in_lengths, out_lengths = _multi_remap_to_groups(
keys, lengths, groups
)
# arguments
outputs = torch.ops.fbgemm.permute_multi_embedding(
values, permutes, in_lengths, out_lengths
)
```
* permutes
```
permutes = tensor(
[
[0, 0, 0, 0, 3, 4], # f1
[1, 0, 0, 3, 5, 0], # f3
[0, 1, 3, 0, 4, 0], # f2
[1, 2, 5, 0, 6, 0], # f4
[0, 2, 0, 6, 3, -6], # f1
[2, 2, 0, 9, 8, 0], # f6
[0, 3, 0, 0, 3, -8], # f1
[1, 3, 11, 3, 7, 0], # f5
]
)
```
# details
1. from the above example usage, we can clearly see that the operatior takes in the following:
a) values: List[torch.Tensor], which represents the input KTs
b) permutes: torch.Tensor, which contains the permute information, will be explained later
c) output_lengths_list: List[int], the lengths of the output tensors (KTs), which is needed to allocate memory on device ahead
d) in_lengths: torch.Tensor, lengths of input tensors, which is on device
e) out_lengths: torch.Tensor, lengths of output tensors, which is on device
2. the operator returns a list of tensors, which represents the permuted KTs
3. `permute` is the most critical argument in this operator:
a) 2-D tensor
b) each row represents a key (feature) permute move
c) a permute move = [input_tensor_id, output_tensor_id, input_start_idx, output_start_idx, feature_length, jump]
d) jump is used in backward when a key (feature) from the input tensor is mapped to multiple places in the output tensors
Differential Revision: D570556161 parent e367cfc commit daf92c2
File tree
2 files changed
+257
-2
lines changed- torchrec/sparse
- tests
2 files changed
+257
-2
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
36 | 36 | | |
37 | 37 | | |
38 | 38 | | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
39 | 45 | | |
40 | 46 | | |
41 | 47 | | |
| |||
240 | 246 | | |
241 | 247 | | |
242 | 248 | | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
| 278 | + | |
| 279 | + | |
| 280 | + | |
| 281 | + | |
| 282 | + | |
| 283 | + | |
| 284 | + | |
| 285 | + | |
| 286 | + | |
| 287 | + | |
| 288 | + | |
| 289 | + | |
| 290 | + | |
| 291 | + | |
| 292 | + | |
| 293 | + | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
| 297 | + | |
| 298 | + | |
| 299 | + | |
| 300 | + | |
| 301 | + | |
| 302 | + | |
| 303 | + | |
| 304 | + | |
| 305 | + | |
| 306 | + | |
| 307 | + | |
| 308 | + | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
| 314 | + | |
| 315 | + | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
| 321 | + | |
| 322 | + | |
| 323 | + | |
| 324 | + | |
243 | 325 | | |
244 | 326 | | |
245 | 327 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
16 | 16 | | |
17 | 17 | | |
18 | 18 | | |
| 19 | + | |
19 | 20 | | |
20 | 21 | | |
21 | 22 | | |
| |||
1374 | 1375 | | |
1375 | 1376 | | |
1376 | 1377 | | |
| 1378 | + | |
| 1379 | + | |
| 1380 | + | |
| 1381 | + | |
| 1382 | + | |
| 1383 | + | |
| 1384 | + | |
| 1385 | + | |
| 1386 | + | |
| 1387 | + | |
| 1388 | + | |
| 1389 | + | |
| 1390 | + | |
| 1391 | + | |
| 1392 | + | |
| 1393 | + | |
| 1394 | + | |
| 1395 | + | |
| 1396 | + | |
| 1397 | + | |
| 1398 | + | |
| 1399 | + | |
| 1400 | + | |
| 1401 | + | |
| 1402 | + | |
| 1403 | + | |
| 1404 | + | |
| 1405 | + | |
| 1406 | + | |
| 1407 | + | |
| 1408 | + | |
| 1409 | + | |
| 1410 | + | |
| 1411 | + | |
| 1412 | + | |
| 1413 | + | |
| 1414 | + | |
| 1415 | + | |
| 1416 | + | |
| 1417 | + | |
| 1418 | + | |
| 1419 | + | |
| 1420 | + | |
| 1421 | + | |
| 1422 | + | |
| 1423 | + | |
| 1424 | + | |
| 1425 | + | |
| 1426 | + | |
| 1427 | + | |
| 1428 | + | |
| 1429 | + | |
| 1430 | + | |
| 1431 | + | |
| 1432 | + | |
| 1433 | + | |
| 1434 | + | |
| 1435 | + | |
| 1436 | + | |
| 1437 | + | |
| 1438 | + | |
| 1439 | + | |
| 1440 | + | |
| 1441 | + | |
| 1442 | + | |
| 1443 | + | |
| 1444 | + | |
| 1445 | + | |
| 1446 | + | |
| 1447 | + | |
| 1448 | + | |
| 1449 | + | |
| 1450 | + | |
| 1451 | + | |
| 1452 | + | |
| 1453 | + | |
| 1454 | + | |
| 1455 | + | |
| 1456 | + | |
| 1457 | + | |
| 1458 | + | |
| 1459 | + | |
| 1460 | + | |
| 1461 | + | |
| 1462 | + | |
| 1463 | + | |
| 1464 | + | |
| 1465 | + | |
| 1466 | + | |
| 1467 | + | |
| 1468 | + | |
| 1469 | + | |
| 1470 | + | |
| 1471 | + | |
| 1472 | + | |
| 1473 | + | |
| 1474 | + | |
| 1475 | + | |
| 1476 | + | |
| 1477 | + | |
| 1478 | + | |
| 1479 | + | |
| 1480 | + | |
| 1481 | + | |
| 1482 | + | |
| 1483 | + | |
| 1484 | + | |
| 1485 | + | |
| 1486 | + | |
| 1487 | + | |
| 1488 | + | |
| 1489 | + | |
| 1490 | + | |
| 1491 | + | |
| 1492 | + | |
| 1493 | + | |
| 1494 | + | |
| 1495 | + | |
| 1496 | + | |
| 1497 | + | |
| 1498 | + | |
| 1499 | + | |
| 1500 | + | |
| 1501 | + | |
| 1502 | + | |
| 1503 | + | |
| 1504 | + | |
| 1505 | + | |
| 1506 | + | |
| 1507 | + | |
| 1508 | + | |
| 1509 | + | |
| 1510 | + | |
| 1511 | + | |
| 1512 | + | |
| 1513 | + | |
| 1514 | + | |
| 1515 | + | |
| 1516 | + | |
| 1517 | + | |
| 1518 | + | |
| 1519 | + | |
| 1520 | + | |
| 1521 | + | |
| 1522 | + | |
| 1523 | + | |
| 1524 | + | |
| 1525 | + | |
| 1526 | + | |
| 1527 | + | |
| 1528 | + | |
| 1529 | + | |
| 1530 | + | |
| 1531 | + | |
| 1532 | + | |
| 1533 | + | |
| 1534 | + | |
| 1535 | + | |
| 1536 | + | |
| 1537 | + | |
| 1538 | + | |
| 1539 | + | |
| 1540 | + | |
| 1541 | + | |
| 1542 | + | |
| 1543 | + | |
| 1544 | + | |
| 1545 | + | |
| 1546 | + | |
| 1547 | + | |
| 1548 | + | |
| 1549 | + | |
| 1550 | + | |
| 1551 | + | |
1377 | 1552 | | |
1378 | 1553 | | |
1379 | 1554 | | |
| |||
1650 | 1825 | | |
1651 | 1826 | | |
1652 | 1827 | | |
1653 | | - | |
1654 | | - | |
1655 | 1828 | | |
1656 | 1829 | | |
1657 | 1830 | | |
| |||
0 commit comments