|
14 | 14 | """
|
15 | 15 |
|
16 | 16 |
|
17 |
| -def solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover |
18 |
| - """ |
19 |
| - Like torch.linalg.solve, tries to return X |
20 |
| - such that AX=B, with A square. |
21 |
| - """ |
22 |
| - if hasattr(torch, "linalg") and hasattr(torch.linalg, "solve"): |
23 |
| - # PyTorch version >= 1.8.0 |
24 |
| - return torch.linalg.solve(A, B) |
25 |
| - |
26 |
| - # pyre-fixme[16]: `Tuple` has no attribute `solution`. |
27 |
| - return torch.solve(B, A).solution |
28 |
| - |
29 |
| - |
30 |
| -def lstsq(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover |
31 |
| - """ |
32 |
| - Like torch.linalg.lstsq, tries to return X |
33 |
| - such that AX=B. |
34 |
| - """ |
35 |
| - if hasattr(torch, "linalg") and hasattr(torch.linalg, "lstsq"): |
36 |
| - # PyTorch version >= 1.9 |
37 |
| - return torch.linalg.lstsq(A, B).solution |
38 |
| - |
39 |
| - solution = torch.lstsq(B, A).solution |
40 |
| - if A.shape[1] < A.shape[0]: |
41 |
| - return solution[: A.shape[1]] |
42 |
| - return solution |
43 |
| - |
44 |
| - |
45 |
| -def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover |
46 |
| - """ |
47 |
| - Like torch.linalg.qr. |
48 |
| - """ |
49 |
| - if hasattr(torch, "linalg") and hasattr(torch.linalg, "qr"): |
50 |
| - # PyTorch version >= 1.9 |
51 |
| - return torch.linalg.qr(A) |
52 |
| - return torch.qr(A) |
53 |
| - |
54 |
| - |
55 |
| -def eigh(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover |
56 |
| - """ |
57 |
| - Like torch.linalg.eigh, assuming the argument is a symmetric real matrix. |
58 |
| - """ |
59 |
| - if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"): |
60 |
| - return torch.linalg.eigh(A) |
61 |
| - return torch.symeig(A, eigenvectors=True) |
62 |
| - |
63 |
| - |
64 | 17 | def meshgrid_ij(
|
65 | 18 | *A: Union[torch.Tensor, Sequence[torch.Tensor]]
|
66 | 19 | ) -> Tuple[torch.Tensor, ...]: # pragma: no cover
|
|
0 commit comments