Commit 289ac65
authored
* Permit negative indices for tensordot when `axes` is a tuple
Per array API standard
Also fixes vecdot attempting to call `py_dot` with incorrect batching dims
* Fixes incorrect shape and stride in contiguous dot product indexers
* Fixes typo in GemmBatchNoAtomicFunctorThreadK
Closes #1570
* Aligns `vecdot` with array API spec changes
Only negative values for `axis` are permitted to avoid ambiguity
Now separately checks that the `axis` parameter is valid for each array before broadcasting occurs
* `test_usm_ndarray_linalg` changed to reflect `vecdot` and `tensordot` changes
* Updates `tensordot` and `vecdot` docstrings to reflect changes
* Adds tests for bugs changes in `vecdot`, `tensordot`
1 parent 1465451 commit 289ac65
File tree
4 files changed
+83
-50
lines changed- dpctl
- tensor
- libtensor/include/kernels/linalg_functions
- tests
4 files changed
+83
-50
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
90 | 90 | | |
91 | 91 | | |
92 | 92 | | |
93 | | - | |
94 | | - | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
95 | 96 | | |
96 | 97 | | |
97 | 98 | | |
| |||
154 | 155 | | |
155 | 156 | | |
156 | 157 | | |
157 | | - | |
158 | | - | |
159 | 158 | | |
160 | | - | |
161 | | - | |
162 | 159 | | |
163 | 160 | | |
164 | 161 | | |
| |||
314 | 311 | | |
315 | 312 | | |
316 | 313 | | |
317 | | - | |
318 | | - | |
319 | | - | |
320 | | - | |
321 | | - | |
322 | | - | |
| 314 | + | |
| 315 | + | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
323 | 319 | | |
324 | 320 | | |
325 | 321 | | |
| |||
355 | 351 | | |
356 | 352 | | |
357 | 353 | | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
| 357 | + | |
| 358 | + | |
| 359 | + | |
| 360 | + | |
| 361 | + | |
| 362 | + | |
358 | 363 | | |
359 | 364 | | |
360 | | - | |
361 | 365 | | |
362 | 366 | | |
363 | | - | |
364 | | - | |
365 | | - | |
366 | | - | |
367 | | - | |
368 | | - | |
369 | 367 | | |
370 | 368 | | |
371 | 369 | | |
| |||
375 | 373 | | |
376 | 374 | | |
377 | 375 | | |
| 376 | + | |
| 377 | + | |
378 | 378 | | |
379 | | - | |
| 379 | + | |
380 | 380 | | |
381 | 381 | | |
382 | 382 | | |
| |||
414 | 414 | | |
415 | 415 | | |
416 | 416 | | |
417 | | - | |
418 | | - | |
419 | | - | |
| 417 | + | |
| 418 | + | |
420 | 419 | | |
421 | 420 | | |
422 | 421 | | |
| |||
427 | 426 | | |
428 | 427 | | |
429 | 428 | | |
430 | | - | |
| 429 | + | |
431 | 430 | | |
432 | 431 | | |
433 | 432 | | |
| |||
459 | 458 | | |
460 | 459 | | |
461 | 460 | | |
462 | | - | |
463 | | - | |
| 461 | + | |
| 462 | + | |
464 | 463 | | |
465 | 464 | | |
466 | 465 | | |
| |||
471 | 470 | | |
472 | 471 | | |
473 | 472 | | |
474 | | - | |
| 473 | + | |
475 | 474 | | |
476 | 475 | | |
477 | 476 | | |
| |||
501 | 500 | | |
502 | 501 | | |
503 | 502 | | |
504 | | - | |
505 | | - | |
| 503 | + | |
| 504 | + | |
506 | 505 | | |
507 | 506 | | |
508 | 507 | | |
| |||
513 | 512 | | |
514 | 513 | | |
515 | 514 | | |
516 | | - | |
| 515 | + | |
517 | 516 | | |
518 | 517 | | |
519 | 518 | | |
| |||
548 | 547 | | |
549 | 548 | | |
550 | 549 | | |
551 | | - | |
552 | | - | |
| 550 | + | |
| 551 | + | |
553 | 552 | | |
554 | 553 | | |
555 | 554 | | |
| |||
560 | 559 | | |
561 | 560 | | |
562 | 561 | | |
563 | | - | |
| 562 | + | |
564 | 563 | | |
565 | 564 | | |
566 | 565 | | |
| |||
Lines changed: 10 additions & 10 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
552 | 552 | | |
553 | 553 | | |
554 | 554 | | |
555 | | - | |
556 | | - | |
| 555 | + | |
| 556 | + | |
557 | 557 | | |
558 | 558 | | |
559 | 559 | | |
| |||
588 | 588 | | |
589 | 589 | | |
590 | 590 | | |
591 | | - | |
592 | | - | |
| 591 | + | |
| 592 | + | |
593 | 593 | | |
594 | 594 | | |
595 | 595 | | |
| |||
1174 | 1174 | | |
1175 | 1175 | | |
1176 | 1176 | | |
1177 | | - | |
1178 | | - | |
| 1177 | + | |
| 1178 | + | |
1179 | 1179 | | |
1180 | 1180 | | |
1181 | 1181 | | |
| |||
1212 | 1212 | | |
1213 | 1213 | | |
1214 | 1214 | | |
1215 | | - | |
1216 | | - | |
| 1215 | + | |
| 1216 | + | |
1217 | 1217 | | |
1218 | 1218 | | |
1219 | 1219 | | |
| |||
1280 | 1280 | | |
1281 | 1281 | | |
1282 | 1282 | | |
1283 | | - | |
1284 | | - | |
| 1283 | + | |
| 1284 | + | |
1285 | 1285 | | |
1286 | 1286 | | |
1287 | 1287 | | |
| |||
Lines changed: 1 addition & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2100 | 2100 | | |
2101 | 2101 | | |
2102 | 2102 | | |
2103 | | - | |
| 2103 | + | |
2104 | 2104 | | |
2105 | 2105 | | |
2106 | 2106 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
782 | 782 | | |
783 | 783 | | |
784 | 784 | | |
785 | | - | |
786 | | - | |
787 | 785 | | |
788 | | - | |
789 | | - | |
| 786 | + | |
| 787 | + | |
| 788 | + | |
| 789 | + | |
| 790 | + | |
| 791 | + | |
| 792 | + | |
| 793 | + | |
| 794 | + | |
790 | 795 | | |
791 | 796 | | |
792 | 797 | | |
| |||
834 | 839 | | |
835 | 840 | | |
836 | 841 | | |
837 | | - | |
| 842 | + | |
838 | 843 | | |
839 | 844 | | |
840 | 845 | | |
| |||
864 | 869 | | |
865 | 870 | | |
866 | 871 | | |
867 | | - | |
| 872 | + | |
868 | 873 | | |
869 | 874 | | |
870 | 875 | | |
| |||
903 | 908 | | |
904 | 909 | | |
905 | 910 | | |
| 911 | + | |
| 912 | + | |
| 913 | + | |
906 | 914 | | |
907 | 915 | | |
908 | 916 | | |
| |||
946 | 954 | | |
947 | 955 | | |
948 | 956 | | |
| 957 | + | |
| 958 | + | |
| 959 | + | |
| 960 | + | |
| 961 | + | |
| 962 | + | |
| 963 | + | |
| 964 | + | |
| 965 | + | |
| 966 | + | |
| 967 | + | |
| 968 | + | |
| 969 | + | |
| 970 | + | |
| 971 | + | |
| 972 | + | |
| 973 | + | |
| 974 | + | |
| 975 | + | |
| 976 | + | |
| 977 | + | |
| 978 | + | |
| 979 | + | |
| 980 | + | |
| 981 | + | |
| 982 | + | |
0 commit comments