@@ -1743,3 +1743,51 @@ func.func @test_compress_neg_axis(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.
17431743 %0 = torch.operator " onnx.Compress" (%arg0 , %cst ) {torch.onnx.axis = -2 : si64 } : (!torch.vtensor <[2 ,3 ,4 ],f32 >, !torch.vtensor <[3 ],si64 >) -> !torch.vtensor <[2 ,2 ,4 ],f32 >
17441744 return %0 : !torch.vtensor <[2 ,2 ,4 ],f32 >
17451745}
1746+
1747+ // -----
1748+
1749+ // CHECK-LABEL: func.func @test_einsum_batch_diagonal
1750+ func.func @test_einsum_batch_diagonal (%arg0: !torch.vtensor <[3 ,5 ,5 ],f64 >) -> !torch.vtensor <[3 ,5 ],f64 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 12 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1751+ // CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[3,5,5],f64>) -> !torch.list<vtensor>
1752+ // CHECK: %[[EQUATION:.*]] = torch.constant.str "...ii ->...i"
1753+ // CHECK: %[[PATH:.*]] = torch.constant.none
1754+ // CHECK: torch.aten.einsum %[[EQUATION]], %[[TENSORS]], %[[PATH]] : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[3,5],f64>
1755+ %0 = torch.operator " onnx.Einsum" (%arg0 ) {torch.onnx.equation = " ...ii ->...i" } : (!torch.vtensor <[3 ,5 ,5 ],f64 >) -> !torch.vtensor <[3 ,5 ],f64 >
1756+ return %0 : !torch.vtensor <[3 ,5 ],f64 >
1757+ }
1758+
1759+ // -----
1760+
1761+ // CHECK-LABEL: func.func @test_einsum_batch_matmul
1762+ func.func @test_einsum_batch_matmul (%arg0: !torch.vtensor <[5 ,2 ,3 ],f64 >, %arg1: !torch.vtensor <[5 ,3 ,4 ],f64 >) -> !torch.vtensor <[5 ,2 ,4 ],f64 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 12 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1763+ // CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[5,2,3],f64>, !torch.vtensor<[5,3,4],f64>) -> !torch.list<vtensor>
1764+ // CHECK: %[[EQUATION:.*]] = torch.constant.str "bij, bjk -> bik"
1765+ // CHECK: %[[PATH:.*]] = torch.constant.none
1766+ // CHECK: torch.aten.einsum %[[EQUATION]], %[[TENSORS]], %[[PATH]] : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[5,2,4],f64>
1767+ %0 = torch.operator " onnx.Einsum" (%arg0 , %arg1 ) {torch.onnx.equation = " bij, bjk -> bik" } : (!torch.vtensor <[5 ,2 ,3 ],f64 >, !torch.vtensor <[5 ,3 ,4 ],f64 >) -> !torch.vtensor <[5 ,2 ,4 ],f64 >
1768+ return %0 : !torch.vtensor <[5 ,2 ,4 ],f64 >
1769+ }
1770+
1771+ // -----
1772+
1773+ // CHECK-LABEL: func.func @test_einsum_sum
1774+ func.func @test_einsum_sum (%arg0: !torch.vtensor <[3 ,4 ],f64 >) -> !torch.vtensor <[3 ],f64 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 12 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1775+ // CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[3,4],f64>) -> !torch.list<vtensor>
1776+ // CHECK: %[[EQUATION:.*]] = torch.constant.str "ij->i"
1777+ // CHECK: %[[PATH:.*]] = torch.constant.none
1778+ // CHECK: torch.aten.einsum %[[EQUATION]], %[[TENSORS]], %[[PATH]] : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[3],f64>
1779+ %0 = torch.operator " onnx.Einsum" (%arg0 ) {torch.onnx.equation = " ij->i" } : (!torch.vtensor <[3 ,4 ],f64 >) -> !torch.vtensor <[3 ],f64 >
1780+ return %0 : !torch.vtensor <[3 ],f64 >
1781+ }
1782+
1783+ // -----
1784+
1785+ // CHECK-LABEL: func.func @test_einsum_transpose
1786+ func.func @test_einsum_transpose (%arg0: !torch.vtensor <[3 ,4 ],f64 >) -> !torch.vtensor <[4 ,3 ],f64 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 12 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1787+ // CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[3,4],f64>) -> !torch.list<vtensor>
1788+ // CHECK: %[[EQUATION:.*]] = torch.constant.str "ij->ji"
1789+ // CHECK: %[[PATH:.*]] = torch.constant.none
1790+ // CHECK: torch.aten.einsum %[[EQUATION]], %[[TENSORS]], %[[PATH]] : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[4,3],f64>
1791+ %0 = torch.operator " onnx.Einsum" (%arg0 ) {torch.onnx.equation = " ij->ji" } : (!torch.vtensor <[3 ,4 ],f64 >) -> !torch.vtensor <[4 ,3 ],f64 >
1792+ return %0 : !torch.vtensor <[4 ,3 ],f64 >
1793+ }
0 commit comments