diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml
index a5e97b65ba89b..56b34e9b2a49f 100644
--- a/.github/workflows/buildAndTest.yml
+++ b/.github/workflows/buildAndTest.yml
@@ -17,8 +17,8 @@ concurrency:
 
 # Provisioned Jobs:
-# ubuntu - x86_64 - llvm in-tree - pytorch binary - build+test # most used dev flow and fastest signal
-# ubuntu - x86_64 - llvm out-of-tree - pytorch source - build+test # most elaborate build
+# ubuntu/docker - x86_64 - llvm in-tree - pytorch binary - build+test # most used dev flow and fastest signal
+# ubuntu/docker - x86_64 - llvm out-of-tree - pytorch source - build+test # most elaborate build
 # macos - arm64 - llvm in-tree - pytorch binary - build only # cross compile, can't test arm64
 jobs:
   build-test:
@@ -57,60 +57,11 @@ jobs:
       with:
         cache-suffix: ${{ matrix.os-arch }}-${{ matrix.llvm-build }}-${{ matrix.torch-binary }}
 
-    - name: Configure os-arch='ubuntu-x86_64' llvm-build='in-tree' torch-binary='${{ matrix.torch-binary }}'
-      # Fastest build, most used dev flow
-      if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }}
-      run: |
-        cmake -GNinja -Bbuild \
-          -DCMAKE_BUILD_TYPE=Release \
-          -DCMAKE_C_COMPILER=clang \
-          -DCMAKE_CXX_COMPILER=clang++ \
-          -DCMAKE_C_COMPILER_LAUNCHER=ccache \
-          -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
-          -DCMAKE_LINKER=lld \
-          -DLLVM_ENABLE_ASSERTIONS=ON \
-          -DLLVM_ENABLE_PROJECTS=mlir \
-          -DLLVM_EXTERNAL_PROJECTS="torch-mlir;torch-mlir-dialects" \
-          -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$GITHUB_WORKSPACE" \
-          -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR="${GITHUB_WORKSPACE}/externals/llvm-external-projects/torch-mlir-dialects" \
-          -DLLVM_TARGETS_TO_BUILD=host \
-          -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-          -DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \
-          -DPython3_EXECUTABLE="$(which python)" \
-          $GITHUB_WORKSPACE/externals/llvm-project/llvm
-
-    - name: Configure os-arch='ubuntu-x86_64' llvm-build='out-of-tree' torch-binary='${{ matrix.torch-binary }}'
-      # Most elaborate build, but cached
-      if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'out-of-tree' }}
+    - name: Build and Test os-arch='ubuntu-x86_64' llvm-build='${{ matrix.llvm-build }}' torch-binary='${{ matrix.torch-binary }}'
+      if: ${{ matrix.os-arch == 'ubuntu-x86_64' }}
       run: |
-        cmake -GNinja -Bllvm-build \
-          -DCMAKE_BUILD_TYPE=Release \
-          -DCMAKE_C_COMPILER=clang \
-          -DCMAKE_CXX_COMPILER=clang++ \
-          -DCMAKE_C_COMPILER_LAUNCHER=ccache \
-          -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
-          -DCMAKE_LINKER=lld \
-          -DLLVM_ENABLE_ASSERTIONS=ON \
-          -DLLVM_ENABLE_PROJECTS=mlir \
-          -DLLVM_TARGETS_TO_BUILD=host \
-          -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-          -DPython3_EXECUTABLE="$(which python)" \
-          $GITHUB_WORKSPACE/externals/llvm-project/llvm
-        cmake --build llvm-build
-
-        cmake -GNinja -Bbuild \
-          -DCMAKE_C_COMPILER=clang \
-          -DCMAKE_CXX_COMPILER=clang++ \
-          -DCMAKE_C_COMPILER_LAUNCHER=ccache \
-          -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
-          -DCMAKE_LINKER=lld \
-          -DLLVM_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/llvm/" \
-          -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir/" \
-          -DMLIR_ENABLE_BINDINGS_PYTHON=OFF \
-          -DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \
-          -DPython3_EXECUTABLE="$(which python)" \
-          $GITHUB_WORKSPACE
-
+        cd $GITHUB_WORKSPACE
+        TM_PACKAGES="${{ matrix.llvm-build }}" TM_USE_PYTORCH_BINARY="${{ matrix.torch-binary }}" ./build_tools/python_deploy/build_linux_packages.sh
     - name: Configure os-arch='macos-arm64' llvm-build='in-tree' torch-binary='${{ matrix.torch-binary }}'
       # cross compile, can't test arm64
       if: ${{ matrix.os-arch == 'macos-arm64' && matrix.llvm-build == 'in-tree' }}
@@ -139,63 +90,7 @@ jobs:
           -DMACOSX_DEPLOYMENT_TARGET=12.0 \
           -DPython3_EXECUTABLE="$(which python)" \
           $GITHUB_WORKSPACE/externals/llvm-project/llvm
-
-    - name: Build torch-mlir
-      if: ${{ matrix.os-arch == 'ubuntu-x86_64' }}
-      run: |
-        cmake --build build
-
     - name: Build torch-mlir (cross-compile)
       if: ${{ matrix.os-arch == 'macos-arm64' }}
       run: |
         cmake --build build_arm64
-
-    - name: Run torch-mlir unit tests
-      if: ${{ matrix.os-arch == 'ubuntu-x86_64' }}
-      run: |
-        cmake --build build --target check-torch-mlir-all
-
-    - name: Ensure generated files are up to date
-      if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }}
-      run: |
-        ./build_tools/update_torch_ods.sh
-        ./build_tools/update_shape_lib.sh
-        if ! git diff --quiet; then
-          echo "#######################################################"
-          echo "Generated files are not up to date (see diff below)"
-          echo ">>> Please run ./build_tools/update_torch_ods.sh and ./build_tools/update_shape_lib.sh <<<"
-          echo "#######################################################"
-          git diff --color=always
-          exit 1
-        fi
-
-    - name: Run refbackend e2e integration tests
-      if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }}
-      run: |
-        export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
-        python -m e2e_testing.main --config=refbackend -v
-
-    - name: Run eager_mode e2e integration tests
-      if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }}
-      run: |
-        export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
-        python -m e2e_testing.main --config=eager_mode -v
-
-    - name: Run mhlo e2e integration tests
-      if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }}
-      run: |
-        export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
-        python -m e2e_testing.main --config=mhlo -v
-
-    - name: Run tosa e2e integration tests
-      if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }}
-      run: |
-        export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
-        python -m e2e_testing.main --config=tosa -v
-
-    - name: Run lazy_tensor_core e2e integration tests
-      if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }}
-      run: |
-        export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
-        echo "LTC tests disabled temporarily. https://github.com/llvm/torch-mlir/pull/1292"
-        # python -m e2e_testing.main --config=lazy_tensor_core -v
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8f6d4d932d19d..0ffd338baafc9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -54,7 +54,9 @@ torch_mlir_add_llvm_external_project(
   TORCH_MLIR_DIALECTS
   ${CMAKE_CURRENT_SOURCE_DIR}/externals/llvm-external-projects/torch-mlir-dialects)
 
-if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
+option(TORCH_MLIR_OOT_BUILD "Specifies an out of tree build" OFF)
+
+if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR OR TORCH_MLIR_OOT_BUILD)
   message(STATUS "Torch-MLIR out-of-tree build.")
   # Out-of-tree build
diff --git a/README.md b/README.md
index cd6f43b5c4c23..48a5efd75de61 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@ Multiple Vendors use MLIR as the middle layer, mapping from platform frameworks
 
 We have few paths to lower down to the Torch MLIR Dialect.
 
-![Torch Lowering Architectures](Torch-MLIR.png)
+![Torch Lowering Architectures](docs/Torch-MLIR.png)
 
 - TorchScript
     This is the most tested path down to Torch MLIR Dialect, and the PyTorch ecosystem is converging on using TorchScript IR as a lingua franca.
diff --git a/Torch-MLIR.png b/Torch-MLIR.png
deleted file mode 100644
index 4a800e4e61675..0000000000000
Binary files a/Torch-MLIR.png and /dev/null differ
diff --git a/build_tools/autogen_ltc_backend.py b/build_tools/autogen_ltc_backend.py
index 55ec80f6d8a7d..74501cdf123c2 100644
--- a/build_tools/autogen_ltc_backend.py
+++ b/build_tools/autogen_ltc_backend.py
@@ -264,6 +264,9 @@ def get_opnames(ops):
         # Additional ops to support that are not supported by Torch-MLIR explicitly
         supported |= set(config.get("additional_ops", []))
 
+        # List of ops that will take in symints for its size
+        symint = set(config.get("symint", []))
+
         self.ops = sorted(ops)
 
         with self.source_yaml.open("w") as f:
@@ -272,6 +275,7 @@ def get_opnames(ops):
                 "cpp_namespace": "torch::lazy",
                 "full_codegen": self.ops,
                 "supported": sorted(supported),
+                "symint": sorted(symint),
                 "non_native": non_native,
             }
             yaml.dump(source_yaml, f, default_flow_style=False)
diff --git a/build_tools/autogen_ltc_backend.yaml b/build_tools/autogen_ltc_backend.yaml
index dde1f0f014c20..c55a5fffb2274 100644
--- a/build_tools/autogen_ltc_backend.yaml
+++ b/build_tools/autogen_ltc_backend.yaml
@@ -16,6 +16,7 @@ blacklist:
 - copy_
 
 # Disabled for consistency with TS backend
+- lift_fresh_copy
 - new_empty
 - rsub
 - slice.Tensor # Disabled in favour of slice_copy.Tensor
@@ -60,6 +61,7 @@ supported:
 # but their implementations call view operators (which we need to functionalize away).
 - block_diag
 - new_empty_strided
+- narrow_copy
 - pixel_shuffle
 - pixel_unshuffle
 - select_backward
@@ -69,6 +71,16 @@ supported:
 - linalg_pinv.atol_rtol_tensor
 - logsumexp.out
 
+# List of ops that will take in symints for the size instead of ints
+symint:
+- empty.memory_format
+- new_empty_strided
+- expand
+- expand_copy
+- narrow_copy
+- view
+- view_copy
+
 additional_ops:
 # Additional ops to support that are not supported by Torch-MLIR explicitly
diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh
index ac824f41cb69c..e84de6c1ec1bb 100755
--- a/build_tools/python_deploy/build_linux_packages.sh
+++ b/build_tools/python_deploy/build_linux_packages.sh
@@ -48,7 +48,7 @@ TM_PYTHON_VERSIONS="${TM_PYTHON_VERSIONS:-cp39-cp39 cp310-cp310}"
 # Location to store Release wheels
 TM_OUTPUT_DIR="${TM_OUTPUT_DIR:-${this_dir}/wheelhouse}"
 # What "packages to build"
-TM_PACKAGES="${TM_PACKAGES:-torch-mlir out-of-tree in-tree}"
+TM_PACKAGES="${TM_PACKAGES:-torch-mlir}"
 # Use pre-built Pytorch
 TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}"
 # Skip running tests if you want quick iteration
@@ -211,11 +211,9 @@ function _check_file_not_changed_by() {
     # TODO: Is there a better cleanup strategy that doesn't require duplicating
    # this inside and outside the `if`?
    rm "$file_new"
-    rm "$file_backup"
     return 1
   fi
   rm "$file_new"
-  rm "$file_backup"
 }
 
 function test_in_tree() {
@@ -238,7 +236,9 @@ function test_in_tree() {
   python -m e2e_testing.main --config=eager_mode -v
 
   echo ":::: Run TOSA e2e integration tests"
-  python -m e2e_testing.main --config=tosa -v
+  # crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed issues:
+  # - AvgPool2dFloatModule_basic,AvgPool2dCeilModeTrueModule_basic: https://github.com/llvm/torch-mlir/issues/1361
+  python -m e2e_testing.main --config=tosa -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed AvgPool2dFloatModule_basic AvgPool2dCeilModeTrueModule_basic
 
   # Temporarily disabled in top of main (https://github.com/llvm/torch-mlir/pull/1292)
   #echo ":::: Run Lazy Tensor Core e2e integration tests"
diff --git a/docs/Torch-MLIR.excalidraw b/docs/Torch-MLIR.excalidraw
index 4b7d8d29b4130..4431f266d5c19 100644
--- a/docs/Torch-MLIR.excalidraw
+++ b/docs/Torch-MLIR.excalidraw
[Machine-generated Excalidraw JSON diff elided: the diagram source is updated to match the regenerated docs/Torch-MLIR.png added below. Elements are repositioned and resized, the "torch_dispatch" / "Build per-op JIT graph" boxes and their arrows are deleted, the "Torch-MLIR Dialect" node is relabeled "Torch Dialect", and the remaining arrows are rebound to the surviving nodes.]
diff --git a/docs/Torch-MLIR.png b/docs/Torch-MLIR.png
new file mode 100644
index 0000000000000..ce85b239a653f
Binary files /dev/null and b/docs/Torch-MLIR.png differ
diff --git a/docs/architecture.md b/docs/architecture.md
index ebfb9029d4f07..3b19cf37d5f6d 100644
--- a/docs/architecture.md
+++ b/docs/architecture.md
@@ -445,4 +445,4 @@ characteristics.
 ### Presentations and Talks
 
 * 2021-10-07: MLIR ODM: Introduction to Torch-MLIR. ([recording](https://www.youtube.com/watch?v=QbNkex-gizs) and [slides](https://docs.google.com/presentation/d/1ZhzfE4EK6XV7AdQTYicrsE_OYjkER_yiB0vBeszRfzY/edit#slide=id.gf56404f79c_1_55))
-* 2022-08-20: Overview of Torch-MLIR passes. ([recording](https://www.youtube.com/watch?v=ZpwlVxsD9_U&t=2374s) and [slides](https://drive.google.com/file/d/1ZSlk1HGttRuVhJSxtP6spWt_hxClit2T/view))
+* 2022-08-20: Overview of Torch-MLIR passes. ([recording](https://www.youtube.com/watch?v=ZpwlVxsD9_U) and [slides](https://drive.google.com/file/d/1ZSlk1HGttRuVhJSxtP6spWt_hxClit2T/view))
diff --git a/docs/development.md b/docs/development.md
index 46bdb9c5609b1..f990fed0a0fe7 100644
--- a/docs/development.md
+++ b/docs/development.md
@@ -313,7 +313,10 @@ Torch-MLIR has two types of tests:
 
 2. Compiler and Python API unit tests. These use LLVM's `lit` testing framework.
    For example, these might involve using `torch-mlir-opt` to run a pass and
-   check the output with `FileCheck`.
+   check the output with `FileCheck`. These tests usually live in the `test/`
+   directory with a parallel file naming scheme to the `lib/*` structure.
+   More details about this kind of test is available in the upstream
+   [LLVM Testing Guide](https://llvm.org/docs/TestingGuide.html#regression-test-structure).
 
 ## Running execution (end-to-end) tests:
diff --git a/e2e_testing/main.py b/e2e_testing/main.py
index 623a7739d0a93..7c5a18b8fd81b 100644
--- a/e2e_testing/main.py
+++ b/e2e_testing/main.py
@@ -68,7 +68,12 @@ def _get_argparse():
     parser.add_argument('-s', '--sequential',
                         default=False,
                         action='store_true',
-                        help='run e2e tests sequentially rather than in parallel')
+                        help='''Run tests sequentially rather than in parallel.
+This can be useful for debugging, since it runs the tests in the same process,
+which make it easier to attach a debugger or get a stack trace.''')
+    parser.add_argument('--crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed',
+                        metavar="TEST", type=str, nargs='+',
+                        help='A set of tests to not attempt to run, since they crash and cannot be XFAILed.')
     return parser
 
 def main():
@@ -102,9 +107,17 @@ def main():
         config = LazyTensorCoreTestConfig()
         xfail_set = LTC_XFAIL_SET
 
+    do_not_attempt = set(args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or [])
+    available_tests = [test for test in GLOBAL_TEST_REGISTRY if test.unique_name not in do_not_attempt]
+    if args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed is not None:
+        for arg in args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed:
+            if arg not in all_test_unique_names:
+                print(f'ERROR: --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed argument "{arg}" is not a valid test name')
+                sys.exit(1)
+
     # Find the selected tests, and emit a diagnostic if none are found.
     tests = [
-        test for test in GLOBAL_TEST_REGISTRY
+        test for test in available_tests
         if re.match(args.filter, test.unique_name)
     ]
     if len(tests) == 0:
@@ -112,7 +125,7 @@ def main():
             f'ERROR: the provided filter {args.filter!r} does not match any tests'
         )
         print('The available tests are:')
-        for test in GLOBAL_TEST_REGISTRY:
+        for test in available_tests:
             print(test.unique_name)
         sys.exit(1)
diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index b05632ec726b3..c833c9f409fc2 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -18,13 +18,103 @@
     # RefBackend fails
     "TableBatchEmbeddingModule_basic",
     "QuantizedMLP_basic",
-    "Matmul_vecmat"
+    "Matmul_vecmat",
+    "BatchMlpLayerModule_basic"
 }
 
 MHLO_PASS_SET = {
+    "BroadcastToIdentityCaseStaticModule_basic",
+    "ArangeDtypeFloatModule_basic",
+    "ArangeDtypeIntModule_basic",
+    "ArangeFalsePinMemoryModule_basic",
+    "ArangeFloatModule_basic",
+    "ArangeIntModule_basic",
+    "ArangeNegativeStartFloatModule_basic",
+    "ArangeNegativeStartIntModule_basic",
+    "ArangeStartFloatModule_basic",
+    "ArangeStartIntModule_basic",
+    "ArangeStartNegativeStepFloatModule_basic",
+    "ArangeStartNegativeStepIntModule_basic",
+    "ArangeStartStepFloatModule_basic",
+    "ArangeStartStepIntModule_basic",
+    "ArangeZeroElementOutputModule_basic",
+    "ElementwiseClampModule_basic",
+    "ElementwiseClampMinModule_basic",
+    "ElementwiseClampMaxModule_basic",
+    "BmmModule_basic",
+    "BroadcastToModule_basic",
+    "ElementwiseExpModule_basic",
+    "ElementwiseLogModule_basic",
+    "ElementwiseNegModule_basic",
+    "ElementwiseSqrtModule_basic",
+    "ElementwiseUnaryModule_basic",
+    "ElementwiseUnsqueezeNegDimsModule_basic",
+    "ElementwiseToDtypeF32ToI64Module_basic",
+    "ElementwiseAddModule_basic",
+    "ElementwiseAddScalarFloatModule_basic",
+    "ElementwiseAddScalarInt64Module_basic",
+    "ElementwiseAddScalarIntModule_basic",
+    "ElementwiseDivScalarModule_basic",
+    "ElementwiseEqDiffWidthScalarModule_basic",
+    "ElementwiseEqFloatScalarModule_basic",
+    "ElementwiseEqIntScalarModule_basic",
+    "ElementwiseErfModule_basic",
+    "ElementwiseGeluModule_basic",
+    "ElementwiseGtFloatScalarModule_basic",
+    "ElementwiseGtIntScalarModule_basic",
+    "ElementwiseGtMixed2ScalarModule_basic",
+    "ElementwiseLtDiffWidthScalarModule_basic",
+    "ElementwiseLtFloatScalarModule_basic",
+    "ElementwiseLtIntScalarModule_basic",
+    "ElementwiseMulScalarModule_basic",
+    "ElementwiseMulScalarModule_float",
+    "ElementwiseMulScalarModule_int",
+    "ElementwiseNeFloatTensorModule_basic",
+    "ElementwiseNeIntScalarModule_basic",
+    "ElementwiseReciprocalModule_basic",
+    "ElementwiseRelu6Module_basic",
+    "ElementwiseReluModule_basic",
+    "ElementwiseSubScalarFloatModule_basic",
+    "ElementwiseSubScalarIntModule_basic",
+    "ExpandAsIntModule_basic",
+    "ExpandModule_basic",
+    "FullLikeModuleDefaultDtype_basic",
+    "FullLikeModuleFalsePinMemory_basic",
+    "FullLikeModuleFloat2D_basic",
+    "FullLikeModuleFloat3DStatic_basic",
+    "FullLikeModuleFloat3D_basic",
+    "FullLikeModuleInt2DStatic_basic",
+    "FullLikeModuleInt2D_basic",
+    "FullLikeModuleInt3D_basic",
+    "FullModuleDefaultDtype_basic",
+    "FullModuleFalsePinMemory_basic",
+    "FullModuleFloat2D_basic",
+    "FullModuleFloat3D_basic",
+    "FullModuleInt2D_basic",
+    "FullModuleInt3D_basic",
+    "MatmulBroadcastBatchDim_basic",
+    "MatmulSingleDynamicBatchDim_basic",
+    "Matmul_3d",
+    "Matmul_4d",
+    "MeanDtypeModule_basic",
+    "MmTanhModule_basic",
+    "ReduceFrobeniusNormKeepDimModule_basic",
+    "ReduceSumDimIntListDtypeFloatModule_basic",
+    "ReduceSumDimIntListDtypeIntModule_basic",
+    "ReduceSumDimIntListKeepDimFloatModule_basic",
+    "ReduceSumDimIntListKeepDimIntModule_basic",
+    "ReduceSumDtypeFloatModule_basic",
+    "ReduceSumDtypeIntModule_basic",
+    "SelectIntModule_basic",
+    "SliceSingleIdxModule_basic",
+    "SqueezeDimModule_dynamic",
+    "SqueezeDimModule_negDim",
+    "ReduceFrobeniusNormModule_basic",
     "FlattenStaticModule_basic",
     "FlattenRank0Module_basic",
     "TensorsConcatNegativeDimModule_basic",
+    "LiftFreshCopyModule_basic",
+    "Mlp2LayerModuleNoBias_basic",
     "NumelModule_basic",
     "ReduceSumDimIntListEmptyDimModule_basic",
     "SqueezeModule_allUnitDim",
@@ -81,6 +171,11 @@
     "MaxPool2dStaticModule_basic",
     "PermuteModule_basic",
     "PermuteNegativeIndexModule_basic",
+    "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
+    "ZerosLikeModule_defaultDtype",
+    "ZerosLikeModule_falsePinMemory",
+    "ZerosLikeModule_float",
+    "ZerosLikeModule_int",
     "ZerosModuleDefaultDtype_basic",
     "ZerosModuleInt2D_basic",
     "ZerosModuleInt3D_basic",
@@ -91,6 +186,10 @@
     "OnesModuleInt_basic",
     "OnesModuleFloat_basic",
     "OnesModuleFalsePinMemory_basic",
+    "OnesLikeModule_defaultDtype",
+    "OnesLikeModule_falsePinMemory",
+    "OnesLikeModule_float",
+    "OnesLikeModule_int",
     "NewZerosModuleDefaultDtype_basic",
     "NewZerosModuleInt2D_basic",
     "NewZerosModuleInt3D_basic",
@@ -158,6 +257,18 @@
     "TensorOpaqueLiteralModule_basic",
     "TransposeIntModule_basic",
     "TransposeIntNegDimsModule_basic",
+    "ToDtypeBoolLayoutNoneModule_basic",
+    "ToDtypeLayoutNoneModule_basic",
+    "ToDtypeLayoutStridedModule_basic",
+    "TypeAsSameModule_basic",
+    "TypeConversionF32ToF64Module_basic",
+    "TypeConversionF64ToF32Module_basic",
+    "TypeConversionI1ToF32Module_basic",
+    "TypeConversionI1ToF64Module_basic",
+    "TypeConversionI1ToI32Module_basic",
+    "TypeConversionI1ToI64Module_basic",
+    "TypeConversionI32ToI64Module_basic",
+    "TypeConversionI64ToI32Module_basic",
     "OnesModuleCPUDevice_basic",
     "Permute0RankModule_basic",
     "UnsafeViewCollapseModule_basic",
@@ -167,6 +278,8 @@
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "ElementwiseCloneContiguousModule_basic",
+    "ElementwiseCloneModule_basic",
     "ElementwiseUnaryModule_basic",
     "ElementwiseBinaryModule_basic",
     "ElementwiseSigmoidModule_basic",
@@ -243,6 +356,8 @@
     "ElementwiseMulScalarModule_float",
     "ElementwiseCeilModule_basic",
     "ElementwiseReciprocalModule_basic",
+    "ElementwiseNotIntegerModule_basic",
+    "ElementwiseNotInt32Module_basic",
     "TypePromotionAlphaWiderModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_basic",
     "BatchNorm1DModule_basic",
@@ -313,6 +428,7 @@
     "Convolution2DStaticModule_basic",
     "ElementwiseNegModule_basic",
     "TestMultipleTensorReturn_basic",
+    "TypeAsSameModule_basic",
     "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
     "BaddbmmDynamicModule_basic",
     "BaddbmmStaticModule_basic",
@@ -332,6 +448,9 @@
     "ArgmaxModule_keepDim",
     "ArgmaxModule_with_dim",
     "_LogSoftmaxModuleStable_basic",
+    "LiftFreshCopyModule_basic",
+    "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
+    "BroadcastToIdentityCaseStaticModule_basic",
 }
 
 LTC_XFAIL_SET = {
@@ -434,6 +553,9 @@
     "IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
     "IndexTensorMultiInputNonContiguousDynamic_basic",
     "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
+    "IndexTensorHackedTwinModule_basic",
+    "IndexTensorHackedTwinModule3dInput_basic",
+    "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
     "Matmul_dot",
     "Matmul_matvec",
     "MulIntModule_basic",
diff --git a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorBase.td b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorBase.td
index 9a641bcba6976..a4093cb651657 100644
--- a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorBase.td
+++ b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorBase.td
@@ -43,6 +43,7 @@ def TMTensor_Dialect : Dialect {
     to.
   }];
   let hasCanonicalizer = 1;
+  let emitAccessorPrefix = kEmitAccessorPrefix_Raw;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/externals/llvm-project b/externals/llvm-project
index 00d648bdb5a8b..458598ccc50c5 160000
--- a/externals/llvm-project
+++ b/externals/llvm-project
@@ -1 +1 @@
-Subproject commit 00d648bdb5a8b71785269b4851b651c883de2cd9
+Subproject commit 458598ccc50c5118107f05d60f3d043772a91f26
diff --git a/externals/mlir-hlo b/externals/mlir-hlo
index 305a2f2522966..cd9da150e729f 160000
--- a/externals/mlir-hlo
+++ b/externals/mlir-hlo
@@ -1 +1 @@
-Subproject commit 305a2f25229660ea789bf70ed8e7336227f6228a
+Subproject commit cd9da150e729fd046109e7962e5f63f5fe067a3b
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index e2f2b5061e120..faaeaa7a36e31 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -158,6 +158,51 @@ def Torch_AtenRelu_Op : Torch_Op<"aten.relu_", [
   }];
 }
 
+def Torch_AtenRelu6Op : Torch_Op<"aten.relu6", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::relu6 : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRelu6Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenRelu6Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
+def Torch_AtenRelu6_Op : Torch_Op<"aten.relu6_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::relu6_ : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRelu6_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenRelu6_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenLeakyReluOp : Torch_Op<"aten.leaky_relu", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -2851,6 +2896,30 @@ def Torch_AtenPowTensorScalarOp : Torch_Op<"aten.pow.Tensor_Scalar", [
   }];
 }
 
+def Torch_AtenPowTensorTensorOp : Torch_Op<"aten.pow.Tensor_Tensor", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$exponent
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenPowTensorTensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenPowTensorTensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -4459,6 +4528,31 @@ def Torch_AtenLinalgVectorNormOp : Torch_Op<"aten.linalg_vector_norm", [
   }];
 }
 
+def Torch_AtenFrobeniusNormDimOp : Torch_Op<"aten.frobenius_norm.dim", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$dim,
+    Torch_BoolType:$keepdim
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenFrobeniusNormDimOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenFrobeniusNormDimOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -5163,6 +5257,29 @@ def Torch_AtenCloneOp : Torch_Op<"aten.clone", [
   }];
 }
 
+def Torch_AtenLiftFreshCopyOp : Torch_Op<"aten.lift_fresh_copy", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::lift_fresh_copy : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLiftFreshCopyOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenLiftFreshCopyOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenContiguousOp : Torch_Op<"aten.contiguous", [
     AllowsTypeRefinement,
     ReadOnly
@@ -5589,6 +5706,30 @@ def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [
   }];
 }
 
+def Torch_AtenIndexTensorHackedTwinOp : Torch_Op<"aten.index.Tensor_hacked_twin", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTensorType:$indices
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenIndexTensorHackedTwinOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenIndexTensorHackedTwinOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -6133,6 +6274,7 @@ def Torch_AtenTypeAsOp : Torch_Op<"aten.type_as", [
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenViewOp : Torch_Op<"aten.view", [
@@ -8791,6 +8933,7 @@ def Torch_AtenDivOp : Torch_Op<"aten.div", [
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenAddOp : Torch_Op<"aten.add", [
@@ -8817,6 +8960,55 @@ def Torch_AtenAddOp : Torch_Op<"aten.add", [
   }];
 }
 
+def Torch_AtenSubOp : Torch_Op<"aten.sub", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::sub : (Scalar, Scalar) -> (Scalar)`";
+  let arguments = (ins
+    AnyTorchScalarType:$a,
+    AnyTorchScalarType:$b
+  );
+  let results = (outs
+    AnyTorchScalarType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenSubOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenSubOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
+def Torch_AtenCeilScalarOp : Torch_Op<"aten.ceil.Scalar", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::ceil.Scalar : (Scalar) -> (Scalar)`";
+  let arguments = (ins
+    AnyTorchScalarType:$a
+  );
+  let results = (outs
+    AnyTorchScalarType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenCeilScalarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenCeilScalarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
 def Torch_AtenSqrtIntOp : Torch_Op<"aten.sqrt.int", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchBase.td b/include/torch-mlir/Dialect/Torch/IR/TorchBase.td
index 6adb733aab4f9..a5e8767e6fa76 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchBase.td
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchBase.td
@@ -37,6 +37,7 @@ def Torch_Dialect : Dialect {
   let hasRegionArgAttrVerify = 1;
   let hasConstantMaterializer = 1;
   let useDefaultTypePrinterParser = 0;
+  let emitAccessorPrefix = kEmitAccessorPrefix_Raw;
 
   let extraClassDeclaration = [{
     /// Parse a type registered to this dialect.
@@ -56,5 +57,6 @@
 def ReadOnly : TorchOpTrait<"ReadOnly">;
 def IsTrailingUnderscoreInplaceVariant : TorchOpTrait<"IsTrailingUnderscoreInplaceVariant">;
 def AllowsTypeRefinement : TorchOpTrait<"AllowsTypeRefinement">;
+def AllowedInModuleInitializer : TorchOpTrait<"AllowedInModuleInitializer">;
 
 #endif // TORCH_BASE
diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
index fae78b45ae373..df73855a328a9 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
@@ -251,7 +251,8 @@ def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
 
 def Torch_GlobalSlotModuleInitializerOp : Torch_Op<"global_slot.module_initializer", [
     IsolatedFromAbove,
-    SingleBlockImplicitTerminator<"::mlir::torch::Torch::InitializeGlobalSlotsOp">
+    SingleBlockImplicitTerminator<"::mlir::torch::Torch::InitializeGlobalSlotsOp">,
+    AllowedInModuleInitializer,
 ]> {
   let summary = "Module initializer for all `torch.global_slot` ops";
   let description = [{
@@ -277,7 +278,9 @@ def Torch_GlobalSlotModuleInitializerOp : Torch_Op<"global_slot.module_initializ
 
 def Torch_InitializeGlobalSlotsOp : Torch_Op<"initialize.global_slots", [
     Terminator,
-    HasParent<"::mlir::torch::Torch::GlobalSlotModuleInitializerOp">]> {
+    HasParent<"::mlir::torch::Torch::GlobalSlotModuleInitializerOp">,
+    AllowedInModuleInitializer,
+  ]> {
   let summary = "Terminator for torch.global_slot.module_initializer region";
   let description = [{
     Atomically updates the value of all the global slots named in `slotSymNames`
@@ -375,8 +378,9 @@ def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
     NoSideEffect,
     TypesMatchWith<"contained types correspond to operand types",
     "elements", "result",
    "Torch::TupleType::get($_ctxt, llvm::to_vector<6>($_self))",
-    "isValidSubtype">
-  ]> {
+    "isValidSubtype">,
+    AllowedInModuleInitializer,
+  ]> {
   let summary = "TorchScript prim::TupleConstruct op";
   let description = [{
     Note: This op does not allow trivial type refinement, because the
@@ -398,7 +402,8 @@ def Torch_PrimListConstructOp:
Torch_Op<"prim.ListConstruct", [ NoSideEffect, AllowsTypeRefinement, - ]> { + AllowedInModuleInitializer, + ]> { let summary = "TorchScript prim::ListConstruct op"; let arguments = (ins @@ -418,7 +423,8 @@ def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [ def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [ AllowsTypeRefinement, SameVariadicOperandSize, - ]> { + AllowedInModuleInitializer, + ]> { let summary = "TorchScript prim::DictConstruct op"; let arguments = (ins @@ -650,9 +656,12 @@ def Torch_PrimExitOp : Torch_Op<"prim.Exit", []> { // Ops corresponding to prim::Constant //===----------------------------------------------------------------------===// -def Torch_ConstantNoneOp : Torch_Op<"constant.none", - [ConstantLike, NoSideEffect, - DeclareOpInterfaceMethods]> { +def Torch_ConstantNoneOp : Torch_Op<"constant.none", [ + ConstantLike, + NoSideEffect, + DeclareOpInterfaceMethods, + AllowedInModuleInitializer, + ]> { let summary = "Get the singleton None value."; let description = [{ Not to be confused with the `mlir::NoneType`. Be careful to use @@ -664,9 +673,12 @@ def Torch_ConstantNoneOp : Torch_Op<"constant.none", let hasFolder = 1; } -def Torch_ConstantStrOp : Torch_Op<"constant.str", - [ConstantLike, NoSideEffect, - DeclareOpInterfaceMethods]> { +def Torch_ConstantStrOp : Torch_Op<"constant.str", [ + ConstantLike, + NoSideEffect, + DeclareOpInterfaceMethods, + AllowedInModuleInitializer, + ]> { let summary = "Materialize a constant str value."; let description = [{ Note: Strings in Python (and TorchScript) are immutable. @@ -697,9 +709,12 @@ def Torch_ConstantDeviceOp : Torch_Op<"constant.device", let assemblyFormat = "$value attr-dict"; } -def Torch_ConstantIntOp : Torch_Op<"constant.int", - [ConstantLike, NoSideEffect, - DeclareOpInterfaceMethods]> { +def Torch_ConstantIntOp : Torch_Op<"constant.int", [ + ConstantLike, + NoSideEffect, + DeclareOpInterfaceMethods, + AllowedInModuleInitializer, + ]> { let summary = "Materialize a constant `int` value."; let description = [{ Note: TorchScript represents integers as 64-bit signed values, unlike @@ -716,9 +731,12 @@ def Torch_ConstantIntOp : Torch_Op<"constant.int", let hasFolder = 1; } -def Torch_ConstantFloatOp : Torch_Op<"constant.float", - [ConstantLike, NoSideEffect, - DeclareOpInterfaceMethods]> { +def Torch_ConstantFloatOp : Torch_Op<"constant.float", [ + ConstantLike, + NoSideEffect, + DeclareOpInterfaceMethods, + AllowedInModuleInitializer, + ]> { let summary = "Materialize a constant `float` value."; let description = [{ Note: TorchScript represents `float` as 64-bit floating point values. @@ -735,9 +753,34 @@ def Torch_ConstantFloatOp : Torch_Op<"constant.float", let hasFolder = 1; } -def Torch_ConstantBoolOp : Torch_Op<"constant.bool", - [ConstantLike, NoSideEffect, - DeclareOpInterfaceMethods]> { +def Torch_ConstantNumberOp : Torch_Op<"constant.number", + [ConstantLike, NoSideEffect]> { + let summary = "Materialize a constant `number` value."; + let description = [{ + This op is used as a workaround to the fact that the constant + materialization in MLIR must materialize a constant with a single op. + To materialize ops with a static `!torch.number` type, we must use this op, + even though we statically know if it is an integer or a float. 
+ + Note: This op unconditionally canonicalizes to + `torch.constant.{float,int}` + `torch.derefine` + }]; + let arguments = (ins + AnyAttrOf<[F64Attr, I64Attr]>:$value + ); + let results = (outs + Torch_NumberType:$result + ); + let hasFolder = 1; + let hasCanonicalizer = 1; +} + +def Torch_ConstantBoolOp : Torch_Op<"constant.bool", [ + ConstantLike, + NoSideEffect, + DeclareOpInterfaceMethods, + AllowedInModuleInitializer, + ]> { let summary = "Materialize a constant `bool` value."; let description = [{ }]; @@ -808,7 +851,8 @@ def Torch_OperatorOp : Torch_Op<"operator", [ } def Torch_LinearParamsCreateOp : Torch_Op<"linear_params.create", [ - AllowsTypeRefinement + AllowsTypeRefinement, + AllowedInModuleInitializer, ]> { let summary = "Create a `!torch.LinearParams`"; let arguments = (ins @@ -823,7 +867,8 @@ def Torch_LinearParamsCreateOp : Torch_Op<"linear_params.create", [ } def Torch_PerTensorAffineCreateOp : Torch_Op<"per_tensor_affine.create", [ - AllowsTypeRefinement + AllowsTypeRefinement, + AllowedInModuleInitializer, ]> { let summary = "Create a per-tensor-affine quantized tensor"; let description = [{ @@ -854,6 +899,7 @@ def Torch_PerTensorAffineCreateOp : Torch_Op<"per_tensor_affine.create", [ def Torch_NonValueTensorLiteralOp : Torch_Op<"tensor.literal", [ DeclareOpInterfaceMethods, AllowsTypeRefinement, + AllowedInModuleInitializer, ]> { let summary = "Create a value of !torch.tensor type from a literal"; let description = [{ diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTraits.h b/include/torch-mlir/Dialect/Torch/IR/TorchTraits.h index 23b4c2ffe2dba..20f1bc1098854 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTraits.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTraits.h @@ -56,6 +56,14 @@ template class AllowsTypeRefinement : public ::mlir::OpTrait::TraitBase {}; +// If a Torch op has this trait, it means that the op is allowed to be used +// in the module initializer. Only a small set of ops are permitted in the +// module initializer. These ops are essentially those which can be produced +// by the IValue importer. +template +class AllowedInModuleInitializer + : public ::mlir::OpTrait::TraitBase {}; + } // namespace OpTrait } // namespace Torch } // namespace torch diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index 61872ae9fc1fc..aacd7566e4531 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -31,6 +31,10 @@ Type getTypeForScalarType( MLIRContext *context, torch_upstream::ScalarType dtypeInt, mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed); +Type getTypeForTorchType( + MLIRContext *context, Type type, + mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed); + Type getTorchTypeForScalarType(MLIRContext *context, torch_upstream::ScalarType dtypeInt); diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index f8b749768d7c9..fd350da1d61ec 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -24,21 +24,30 @@ namespace TorchConversion { /// Creates a pipeline that lowers from the torch backend contract to the /// linalg-on-tensors backend contract. 
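The `AllowedInModuleInitializer` trait introduced in TorchTraits.h above is queryable through the standard `Operation::hasTrait` hook. A minimal sketch of the kind of check it enables (a hypothetical helper, not the patch's actual verifier logic):

```cpp
#include "mlir/IR/Operation.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTraits.h"

using namespace mlir;

// Hypothetical helper: reject any op inside a module initializer body that
// does not carry the AllowedInModuleInitializer trait, i.e. anything the
// IValue importer could not have produced.
LogicalResult verifyOnlyAllowedOps(Operation *moduleInitializer) {
  auto result = moduleInitializer->walk([&](Operation *op) {
    if (op == moduleInitializer ||
        op->hasTrait<torch::Torch::OpTrait::AllowedInModuleInitializer>())
      return WalkResult::advance();
    op->emitOpError("is not allowed in a module initializer");
    return WalkResult::interrupt();
  });
  return failure(result.wasInterrupted());
}
```

The ops tagged with the trait in this patch (the `prim.{Tuple,List,Dict}Construct` ops, `torch.constant.{none,str,int,float,bool}`, the linear-params and per-tensor-affine create ops, and `torch.tensor.literal`) line up with the stated intent that only ops producible by the IValue importer appear in the initializer.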
-void createTorchBackendToLinalgOnTensorsBackendPipeline( - OpPassManager &pm, - const torch::Torch::TorchLoweringPipelineOptions &options); +void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm); /// Creates a pipeline that lowers from the torch backend contract to the /// TOSA backend contract. -void createTorchBackendToTosaBackendPipeline( - OpPassManager &pm, - const torch::Torch::TorchLoweringPipelineOptions &options); +void createTorchBackendToTosaBackendPipeline(OpPassManager &pm); // Do not register the torch-to-mhlo pipeline if mhlo target is disabled #ifdef TORCH_MLIR_ENABLE_MHLO +struct MhloBackendPipelineOptions + : public PassPipelineOptions { + Option enableStaticShape{ + *this, "enable-static-shape", + llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)}; + // The i64 calculation is much slower than i32 on some devices, such as + // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes + // are unlikely to exceed the range of i32(4GiB) + Option enableI32Index{ + *this, "enable-i32-index", + llvm::cl::desc("Enable truncate index from i64 to i32(unsafely)"), + llvm::cl::init(false)}; +}; + void createTorchBackendToMhloBackendPipeline( - OpPassManager &pm, - const torch::Torch::TorchLoweringPipelineOptions &options); + OpPassManager &pm, const MhloBackendPipelineOptions &options); std::unique_ptr> createVerifyMhloBackendContractPass(); #endif diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index 98f1acb75e054..ff8a9a16ddcbf 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -11,6 +11,7 @@ #ifdef TORCH_MLIR_ENABLE_MHLO #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Transforms/passes.h" #endif // TORCH_MLIR_ENABLE_MHLO #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" @@ -34,5 +35,8 @@ void mlir::torch::registerConversionPasses() { ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { return mlir::mhlo::createLegalizeHloToLinalgPass(); }); + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return mlir::createSymbolicShapeOptimizationPass(); + }); #endif // TORCH_MLIR_ENABLE_MHLO } diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 9482187e5bed3..9c01db32c1077 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -296,6 +296,21 @@ class ConvertAtenMatmulOp : public OpConversionPattern { op, "unable to perform broadcast operation"); } + if (maxRank == 3) { + Value zeroTensor = createZeroInitTensor( + rewriter, loc, + ValueRange{broadcastedBatchShape[0], lhsDim0, rhsDim1}, + elementType); + Value matmul = + rewriter + .create( + loc, zeroTensor.getType(), + ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor) + .getResult(0); + rewriter.replaceOpWithNewOp(op, newResultType, matmul); + return success(); + } + // Check if the result of the matrix multiplication has more than one // dynamic batch dimensions. ArrayRef batchDimsInt = resultType.getShape().drop_back(2); @@ -454,176 +469,6 @@ class ConvertAtenBmmOp : public OpConversionPattern { }; } // namespace -namespace { -// See comments at in convertMmOp and the heading for this section for general -// considerations. This function needs to be auto-generated. 
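The new `MhloBackendPipelineOptions` above gives the MHLO path its own option struct instead of the generic `TorchLoweringPipelineOptions`, so `enable-static-shape` and `enable-i32-index` can be threaded through command-line tooling. A minimal registration sketch, assuming the standard MLIR pass-registry machinery (illustrative, not the patch's actual registration code):

```cpp
#include "mlir/Pass/PassRegistry.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::torch;

// Registers the pipeline so its options parse from the command line, e.g.
//   --torch-backend-to-mhlo-backend-pipeline="enable-static-shape=true enable-i32-index=true"
void registerMhloBackendPipeline() {
  PassPipelineRegistration<TorchConversion::MhloBackendPipelineOptions>(
      "torch-backend-to-mhlo-backend-pipeline",
      "Pipeline lowering torch backend contract to MHLO backend contract.",
      TorchConversion::createTorchBackendToMhloBackendPipeline);
}
```

With `enable-i32-index=true`, shape arithmetic is emitted on i32 rather than i64; per the comment above, that is an unchecked assumption that no dimension size exceeds the i32 range.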
-class ConvertAtenLinearOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(AtenLinearOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MLIRContext *context = op->getContext(); - Location loc = op->getLoc(); - Value input = adaptor.input(); - Value weight = adaptor.weight(); - Value bias = adaptor.bias(); - if (failed(verifyLinalgCompatibleTypes(op, rewriter))) - return failure(); - auto inputType = input.getType().cast(); - auto weightType = weight.getType().cast(); - - if (inputType.getRank() != 2 && inputType.getRank() != 3) { - return rewriter.notifyMatchFailure( - op, "expected input to be rank 2 or rank 3"); - } - - if (!bias.getType().isa()) { - auto biasType = bias.getType().cast(); - // Only handle the case of rank 2 `weight` for now. - // TODO: Insert the appropriate reshape to collapse any leading dimensions. - if (weightType.getRank() != 2 || biasType.getRank() != 1) { - return rewriter.notifyMatchFailure( - op, "expected weight to be rank 2 and bias to be rank 1"); - } - // TODO: Handle type promotion. What are ATen's promotion rules? - if (inputType.getElementType() != weightType.getElementType() || - inputType.getElementType() != biasType.getElementType()) { - return rewriter.notifyMatchFailure(op, "unimplemented: type promotion"); - } - // TODO: We can handle a static size 1 here at some complexity cost, but the - // dynamic case is not representable in linalg. We don't handle either for - // now. Biases are generally statically shaped for most models (since for - // inference they are constants, and for training they don't change shape - // typically), so this is not too constraining. - auto biasSize = bias.getType().cast().getShape()[0]; - if (biasSize == 1 || biasSize == ShapedType::kDynamicSize) - return rewriter.notifyMatchFailure( - op, "unimplemented: size-1 broadcasting for aten::LinearOp"); - } - - - Value batchDim = nullptr; - int restDim = 0; - if (inputType.getRank() == 3) { - batchDim = getDimOp(rewriter, loc, input, 0); - restDim = 1; - } - - Value inputDim0 = getDimOp(rewriter, loc, input, restDim + 0); - Value inputDim1 = getDimOp(rewriter, loc, input, restDim + 1); - Value weightDim0 = getDimOp(rewriter, loc, weight, 0); - Value weightDim1 = getDimOp(rewriter, loc, weight, 1); - Value contractingDimEqual = rewriter.create( - loc, arith::CmpIPredicate::eq, inputDim1, weightDim1); - rewriter.create( - loc, contractingDimEqual, - rewriter.getStringAttr( - "mismatching contracting dimension for aten.linear")); - - if (!bias.getType().isa()) { - Value biasDim0 = getDimOp(rewriter, loc, bias, 0); - // Here we take advantage of ruling out the size-1 case above. - // In the static-size-1 case, we will not emit this check at all. 
- Value biasSizeCorrect = rewriter.create( - loc, arith::CmpIPredicate::eq, weightDim0, biasDim0); - rewriter.create( - loc, biasSizeCorrect, - rewriter.getStringAttr("mismatching bias size for aten.linear")); - } - - Value initTensor; - SmallVector broadcastIndexingMaps; - Value transposedWeightInitTensor; - if (inputType.getRank() > 2) { - initTensor = rewriter.create( - loc, ValueRange{batchDim, inputDim0, weightDim0}, - inputType.getElementType()); - transposedWeightInitTensor = rewriter.create( - loc, ValueRange{batchDim, weightDim1, weightDim0}, - weightType.getElementType()); - broadcastIndexingMaps = { - AffineMap::get( - /*dimCount=*/inputType.getRank(), /*symbolCount=*/0, - {rewriter.getAffineDimExpr(1 + restDim)}, context), - rewriter.getMultiDimIdentityMap(inputType.getRank())}; - } else { - initTensor = rewriter.create( - loc, ValueRange{inputDim0, weightDim0}, - inputType.getElementType()); - transposedWeightInitTensor = rewriter.create( - loc, ValueRange{weightDim1, weightDim0}, weightType.getElementType()); - broadcastIndexingMaps = { - AffineMap::get( - /*dimCount=*/inputType.getRank(), /*symbolCount=*/0, - {rewriter.getAffineDimExpr(1)}, context), - rewriter.getMultiDimIdentityMap(inputType.getRank())}; - } - - SmallVector iteratorTypes(inputType.getRank(), "parallel"); - Value broadcasted; - if (!bias.getType().isa()) { - broadcasted = - rewriter - .create( - loc, initTensor.getType(), bias, initTensor, - /*indexingMaps=*/broadcastIndexingMaps, - /*iteratorTypes=*/iteratorTypes, - [](OpBuilder &b, Location loc, ValueRange args) { - b.create(loc, args[0]); - }) - .getResult(0); - } else { - Type elementType = - input.getType().cast().getElementType(); - Value c0float = rewriter.create( - loc, FloatAttr::get(elementType, 0.0)); - broadcasted = rewriter.create(loc, c0float, initTensor) - .getResult(0); - } - // We need a matmul with dimension ordering (N, K) * (M, K), so transpose - // the weights to fit into linalg::MatmulOp which is (N, K) * (K, M). - // TODO: This whole aten.linear lowering should eventually be generated from - // a single linalg ODS generator statement. Both the bias and matmul part. 
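In equations, this deleted lowering implements the standard linear form

$$y = x\,W^{\top} + b, \qquad x \in \mathbb{R}^{N \times K},\; W \in \mathbb{R}^{M \times K},\; b \in \mathbb{R}^{M},$$

with an extra leading batch dimension in the rank-3 case, which is why the weights are transposed to $K \times M$ before the `linalg::MatmulOp`. With the pattern gone, `aten.linear` is no longer marked illegal in `populateLinearPatternsAndLegality` below.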
- SmallVector transposeIndexingMaps = { - AffineMap::get( - /*dimCount=*/inputType.getRank(), /*symbolCount=*/0, - {rewriter.getAffineDimExpr(1 + restDim), - rewriter.getAffineDimExpr(0 + restDim)}, - context), - rewriter.getMultiDimIdentityMap(inputType.getRank())}; - Value transposedWeights = - rewriter - .create( - loc, transposedWeightInitTensor.getType(), weight, - transposedWeightInitTensor, - /*indexingMaps=*/transposeIndexingMaps, - /*iteratorTypes=*/iteratorTypes, - [](OpBuilder &b, Location loc, ValueRange args) { - b.create(loc, args[0]); - }) - .getResult(0); - Value matmul; - if (batchDim) - matmul = rewriter - .create( - loc, broadcasted.getType(), - ValueRange{input, transposedWeights}, broadcasted) - .getResult(0); - else - matmul = rewriter - .create( - loc, broadcasted.getType(), - ValueRange{input, transposedWeights}, broadcasted) - .getResult(0); - - Type newResultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, newResultType, matmul); - return success(); - } -}; -} // namespace - namespace { class ConvertAtenConvolutionOp : public OpConversionPattern { public: @@ -996,8 +841,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); } diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index 728c53bf21dd0..dc4e9704c05e4 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -196,7 +196,7 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc, elementType.getIntOrFloatBitWidth()))); } - if (isa(op)) + if (isa(op) || isa(op)) return b.create(loc, b.getZeroAttr(elementType)); op->emitError("unimplemented lowering in createInitElementForReduceOp"); @@ -244,6 +244,15 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, Value ord = convertScalarToDtype(b, loc, adaptor.ord(), resultElementType); auto pow = b.create(loc, abs, ord); return b.create(loc, pow, result); + } else if (isa(op)) { + Value elem = payloadArgs[0]; + Value result = payloadArgs[1]; + Value self = convertScalarToDtype(b, loc, elem, resultElementType); + auto abs = b.create(loc, self); + Attribute twoAttr = b.getFloatAttr(resultElementType, 2.0); + auto ord = b.create(loc, twoAttr); + auto pow = b.create(loc, abs, ord); + return b.create(loc, pow, result); } op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp"); return nullptr; @@ -321,6 +330,9 @@ class ConvertReductionOp : public ConversionPattern { if (auto normOp = dyn_cast(op)) return computeReductionOpInfoForDimVariantOp(normOp, operands, rewriter); + if (auto normOp = dyn_cast(op)) + return computeReductionOpInfoForDimVariantOp(normOp, operands, rewriter); + return rewriter.notifyMatchFailure(op, "not a supported reduce op"); } @@ -405,7 +417,8 @@ class ConvertReductionOp : public ConversionPattern { LogicalResult validateReductionElementType(Operation *op, Type elemType, ConversionPatternRewriter &rewriter) const { - if (isa(op) && !elemType.isa()) + if ((isa(op) || isa(op)) && + !elemType.isa()) return rewriter.notifyMatchFailure( op, "only float types are valid for vector norm ops"); // No checks for all other reduction operations @@ -455,6 +468,15 @@ class ConvertReductionOp : public ConversionPattern { reduceOp = 
*secondReduceOp; } + // If it is aten.frobenius_norm.dim op, take the square root of reduceOp as + // the final result + if (auto normOp = dyn_cast(op)) { + auto halfAttr = rewriter.getFloatAttr(elemType, 0.5); + auto exp = rewriter.create(loc, halfAttr); + reduceOp = + createElementwiseExp(loc, elemType, exp, reduceOp, *opInfo, rewriter); + } + rewriter.replaceOpWithNewOp(op, resultType, reduceOp); return success(); } @@ -471,5 +493,6 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality( target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); patterns.add(typeConverter, context); } diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index d1f4739903611..70b158e09048d 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -23,6 +23,7 @@ #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "llvm/ADT/APSInt.h" using namespace mlir; using namespace mlir::torch; @@ -528,6 +529,19 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, payloadArgs[0], expPromoted); } + if (auto pow = dyn_cast(op)) { + Type dtype = converter->convertType(pow.getType()) + .cast() + .getElementType(); + if (!dtype.isa()) { + pow.emitError("unimplemented: non-floating point dtype"); + return nullptr; + } + Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); + Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); + return b.create(loc, lhs, rhs); + } + if (auto gtScalar = dyn_cast(op)) { Type dtype = gtScalar.self().getType().cast().getDtype(); @@ -935,6 +949,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, pred, scalar, zero); } + if (auto bitwiseNot = dyn_cast(op)) { + Type elementType = converter->convertType(bitwiseNot.getType()) + .cast() + .getElementType(); + if (elementType.isa()) { + bitwiseNot.emitError("Bitwise_Not does not support floating point dtype"); + return nullptr; + } + + Value allOnesVal = b.create( + loc, b.getIntegerAttr( + elementType, + APSInt::getAllOnesValue(elementType.getIntOrFloatBitWidth()))); + return b.create(loc, payloadArgs[0], allOnesVal); + } + op->emitError("unimplemented lowering in " "createLinalgPayloadCalculationForElementwiseOp"); return nullptr; @@ -973,15 +1003,17 @@ class ConvertElementwiseOp : public ConversionPattern { AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp, - AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op, - AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp, - AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, - AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, - AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, - AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, - AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, - AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp, - AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp>(op)) + AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, + AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, + AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp, + AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp, + AtenGeScalarOp, 
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, + AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp, + AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp, + AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, + AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp, + AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp, + AtenBitwiseNotOp>(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) @@ -1257,271 +1289,6 @@ class ConvertAtenBatchNormOp : public OpConversionPattern { }; } // namespace -// For layernorm, the mean and standard-deviation are calculated separately over -// the last certain number dimensions which have to be of the shape specified by -// normalized_shape. -// -// The shapes of different parts are as the following: -// +-------------------+--------------------+ -// | meanAndVarShape | normalizedShape | -// +-------------------+--------------------- -// <------------+ inputShape +--------------> -// There are the following steps: -// Step 1. Check if all the arguments meet the requirements. -// Step 2. Common parts to be used for getting mean and var. -// This includes elements count, affineMap and iteratorTypes. -// Step 3. Get mean. -// Step 4. Get rSTD. -// Step 5. Get layernorm. -namespace { -class ConvertAtenNativeLayerNormOp - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(AtenNativeLayerNormOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MLIRContext *context = op->getContext(); - Location loc = op->getLoc(); - Value input = adaptor.input(); - Value weight = adaptor.weight(); - Value bias = adaptor.bias(); - Value eps = adaptor.eps(); - Value normalizedShape = op.normalized_shape(); - - if (failed(verifyLinalgCompatibleTypes(op, rewriter))) - return failure(); - - // TODO: Handle the None cases for the optional parameters: - // weight, bias. - if (failed(checkNotNone(rewriter, op, weight)) || - failed(checkNotNone(rewriter, op, bias))) - return failure(); - - auto inputType = input.getType().cast(); - auto weightType = weight.getType().cast(); - auto biasType = bias.getType().cast(); - int64_t inputRank = inputType.getRank(); - Type elemTy = inputType.getElementType(); - - // Step 1. Check if all the arguments meet the requirements. 
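For reference, the five steps enumerated above compute, with $N$ the element count of `normalized_shape` and sums taken over the normalized dimensions:

$$\mu = \frac{1}{N}\sum_i x_i, \qquad \sigma^2 = \frac{1}{N}\sum_i (x_i - \mu)^2, \qquad \mathrm{rstd} = \frac{1}{\sqrt{\sigma^2 + \epsilon}}, \qquad y = (x - \mu)\,\mathrm{rstd} \cdot w + b.$$

The deleted body below follows these equations literally: a sum reduction and divide for $\mu$, a squared-difference reduction for $\sigma^2$, `calculateRSTD` for the reciprocal square root, and a final fused elementwise generic for $y$.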
- SmallVector normalizedShapeSizesTorchInt; - if (!getListConstructElements(normalizedShape, - normalizedShapeSizesTorchInt)) { - return rewriter.notifyMatchFailure(op, - "Unimplemented normalized_shape not" - "constructed from ListConstruct"); - } - SmallVector normalizedShapeSizesInt = getTypeConvertedValues( - rewriter, loc, getTypeConverter(), normalizedShapeSizesTorchInt); - int64_t normalizedShapeRank = normalizedShapeSizesInt.size(); - if (weightType.getRank() != normalizedShapeRank || - biasType.getRank() != normalizedShapeRank || - inputRank < normalizedShapeRank || normalizedShapeRank < 1) - return rewriter.notifyMatchFailure(op, "Input or weight or bias shape or" - "normalized shape not compatible"); - - // Check all the dimensions match the normalized_shape - int64_t meanAndVarShapeRank = inputRank - normalizedShapeSizesInt.size(); - for (auto en : enumerate((normalizedShapeSizesInt))) { - auto index = en.index(); - auto inputDim = - getDimOp(rewriter, loc, input, index + meanAndVarShapeRank); - auto weightDim = getDimOp(rewriter, loc, weight, index); - auto biasDim = getDimOp(rewriter, loc, bias, index); - - auto expectedSize = en.value(); - checkDimEqualHelper(rewriter, loc, inputDim, expectedSize); - checkDimEqualHelper(rewriter, loc, weightDim, expectedSize); - checkDimEqualHelper(rewriter, loc, biasDim, expectedSize); - } - - // Get iterator types for input shape. - SmallVector normalizedShapeIteratorTypes( - normalizedShapeRank, getReductionIteratorTypeName()); - SmallVector meanAndVarIterationTypes( - meanAndVarShapeRank, getParallelIteratorTypeName()); - SmallVector inputShapeIteratorTypes = meanAndVarIterationTypes; - inputShapeIteratorTypes.append(normalizedShapeIteratorTypes); - - // Step 2. Common parts to be used for getting mean and var. - - // Get sizes and affineMaps needed for mean and var. - AffineMap inputShapeAffineMap = rewriter.getMultiDimIdentityMap(inputRank); - SmallVector meanAndVarShapeExprs; - for (int i = 0; i < meanAndVarShapeRank; i++) - meanAndVarShapeExprs.push_back(mlir::getAffineDimExpr(i, context)); - auto meanAndVarShapeAffineMap = AffineMap::get( - /*dimCount=*/inputRank, - /*symbolCount=*/0, meanAndVarShapeExprs, context); - SmallVector meanAndVarShapeSizes = - getTensorSizesUntilDim(rewriter, loc, input, meanAndVarShapeRank - 1); - - // Get number of elements to be used for calculating mean and var. - Value elemCnts = normalizedShapeSizesInt[0]; - for (int i = 1; i < normalizedShapeRank; i++) { - elemCnts = rewriter.create(loc, elemCnts, - normalizedShapeSizesInt[i]); - } - Value elemCntsFloat = - rewriter.create(loc, elemTy, elemCnts); - - // Helper to calculate mean and var. - auto genMeanOrVarCalculation = [&](Value sumOrSquareSum) { - SmallVector indexingMaps( - 2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank)); - Value initShapeTensor = rewriter.create( - loc, meanAndVarShapeSizes, elemTy); - return rewriter - .create( - loc, initShapeTensor.getType(), sumOrSquareSum, initShapeTensor, - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/meanAndVarIterationTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value sumOrSqureSum = args[0]; - Value result = - b.create(loc, sumOrSqureSum, elemCntsFloat); - b.create(loc, result); - }) - .getResult(0); - }; - - // Step 3. Get mean. - - // Get sum to be used for calculating mean. 
- SmallVector sumIndexingMaps = { - inputShapeAffineMap, // input - meanAndVarShapeAffineMap, // output - }; - auto initSumTensor = - createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy); - Value sum = rewriter - .create( - loc, initSumTensor.getType(), input, initSumTensor, - /*indexingMaps=*/sumIndexingMaps, - /*iteratorTypes=*/inputShapeIteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value input = args[0], sum = args[1]; - Value result = - rewriter.create(loc, sum, input); - b.create(loc, result); - }) - .getResult(0); - Value mean = genMeanOrVarCalculation(sum); - - // Step 4. Get rSTD. - - // Calculate squareSum for the layer. - SmallVector squareSumIndexingMaps{ - inputShapeAffineMap, - meanAndVarShapeAffineMap, - meanAndVarShapeAffineMap, - }; - auto initSquareSumTensor = - createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy); - Value squareSum = - rewriter - .create( - loc, initSquareSumTensor.getType(), ValueRange{input, mean}, - initSquareSumTensor, - /*indexingMaps=*/squareSumIndexingMaps, - /*iteratorTypes=*/inputShapeIteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value input = args[0], mean = args[1], squareSum = args[2]; - Value sub = rewriter.create(loc, input, mean); - Value square = rewriter.create(loc, sub, sub); - Value result = - rewriter.create(loc, squareSum, square); - b.create(loc, result); - }) - .getResult(0); - Value var = genMeanOrVarCalculation(squareSum); - Value rSTDTensor = rewriter.create( - loc, meanAndVarShapeSizes, elemTy); - SmallVector rSTDIndexingMap( - 2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank)); - - Value rSTD = rewriter - .create( - loc, rSTDTensor.getType(), var, rSTDTensor, - rSTDIndexingMap, meanAndVarIterationTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value result = - calculateRSTD(b, loc, elemTy, eps, args[0]); - b.create(loc, result); - }) - .getResult(0); - - // Step 5. Get layernorm. - - // Get affineMap for normalized shape. - SmallVector normalizedShapeExprs; - for (int i = meanAndVarShapeRank; i < inputRank; i++) - normalizedShapeExprs.push_back(mlir::getAffineDimExpr(i, context)); - auto normalizedShapeAffineMap = AffineMap::get( - /*dimCount=*/inputRank, - /*symbolCount=*/0, normalizedShapeExprs, context); - auto inputSizes = getTensorSizes(rewriter, loc, input); - Value initLayerNormTensor = - rewriter.create(loc, inputSizes, elemTy); - SmallVector indexingMaps(1, inputShapeAffineMap); - indexingMaps.resize(3, meanAndVarShapeAffineMap); - indexingMaps.resize(5, normalizedShapeAffineMap); - indexingMaps.push_back(inputShapeAffineMap); - SmallVector layerNormIterationTypes( - inputRank, getParallelIteratorTypeName()); - Value layerNorm = - rewriter - .create( - loc, initLayerNormTensor.getType(), - ValueRange{input, mean, rSTD, weight, bias}, - initLayerNormTensor, - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/layerNormIterationTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value input = args[0], mean = args[1], rSTD = args[2], - weight = args[3], bias = args[4]; - Value result = - createLinalgPayloadCalculationForNormOpsWithRSTD( - b, loc, elemTy, input, mean, rSTD, eps, weight, bias); - b.create(loc, result); - }) - .getResult(0); - SmallVector expandShape(inputRank, 1); - for (int i = 0; i < meanAndVarShapeRank; i++) { - // `mean` and `rstd` are not yet casted, so they will be having dynamic - // shape. Hence to match them, for each dimension corresponding to `mean` - // or `rstd` assign -1. 
- expandShape[i] = -1; - } - auto expandShapeType = RankedTensorType::get(expandShape, elemTy); - SmallVector reassociation(meanAndVarShapeRank); - for (auto i : llvm::seq(0, meanAndVarShapeRank)) { - reassociation[i].push_back(i); - if (i == meanAndVarShapeRank - 1) { - for (auto j : llvm::seq(0, normalizedShapeRank)) - reassociation[i].push_back(i + j + 1); - } - } - Value meanResult = rewriter.create( - loc, expandShapeType, mean, reassociation); - Value rSTDResult = rewriter.create( - loc, expandShapeType, rSTD, reassociation); - Type layerNormResultType = getTypeConverter()->convertType(op.getType(0)); - Type meanResultType = getTypeConverter()->convertType(op.getType(1)); - Type rSTDResultType = getTypeConverter()->convertType(op.getType(2)); - Value layerNorm_ = - rewriter.create(loc, layerNormResultType, layerNorm); - Value mean_ = - rewriter.create(loc, meanResultType, meanResult); - Value var_ = - rewriter.create(loc, rSTDResultType, rSTDResult); - rewriter.replaceOp(op, {layerNorm_, mean_, var_}); - return success(); - } -}; -} // namespace - namespace { class ConvertAtenNllLossBackwardOp : public OpConversionPattern { @@ -1714,13 +1481,14 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, - AtenPowTensorScalarOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, - AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp, - AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, - AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, - AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, - AtenNeScalarOp, AtenMaskedFillScalarOp, AtenMaskedFillTensorOp, - AtenLogicalOrOp, AtenTriuOp, AtenRemainderScalarOp>(); + AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp, + AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, + AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, + AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp, + AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, + AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp, + AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp, + AtenRemainderScalarOp, AtenBitwiseNotOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); @@ -1728,8 +1496,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); patterns.add(typeConverter, context); diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp index 5d4d95c26b824..a61202c0ec072 100644 --- a/lib/Conversion/TorchToMhlo/Basic.cpp +++ b/lib/Conversion/TorchToMhlo/Basic.cpp @@ -41,6 +41,60 @@ bool skipMultiplyAlpha(Value alphaValue) { return ((isFloat && doubleValue == 1.0) || (isInt && intValue == 1.0)); } +static FailureOr getMaxValueOfDtype(Operation *op, Type elementType, + PatternRewriter &rewriter) { + auto constType = RankedTensorType::get({}, elementType); + if (elementType.isa()) { + auto constAttr = SplatElementsAttr::get( + constType, + APFloat::getInf(elementType.cast().getFloatSemantics(), + 
/*negative=*/false)); + return rewriter.create(op->getLoc(), constType, constAttr) + .getResult(); + } + if (elementType.isa()) { + auto integerType = elementType.cast(); + DenseElementsAttr constAttr; + if (integerType.isUnsigned()) { + constAttr = SplatElementsAttr::get( + constType, APInt::getMaxValue(integerType.getWidth())); + } else { + constAttr = SplatElementsAttr::get( + constType, APInt::getSignedMaxValue(integerType.getWidth())); + } + return rewriter.create(op->getLoc(), constType, constAttr) + .getResult(); + } + return failure(); +} + +static FailureOr getMinValueOfDtype(Operation *op, Type elementType, + PatternRewriter &rewriter) { + auto constType = RankedTensorType::get({}, elementType); + if (elementType.isa()) { + auto constAttr = SplatElementsAttr::get( + constType, + APFloat::getInf(elementType.cast().getFloatSemantics(), + /*negative=*/true)); + return rewriter.create(op->getLoc(), constType, constAttr) + .getResult(); + } + if (elementType.isa()) { + auto integerType = elementType.cast(); + DenseElementsAttr constAttr; + if (integerType.isUnsigned()) { + constAttr = SplatElementsAttr::get( + constType, APInt::getMinValue(integerType.getWidth())); + } else { + constAttr = SplatElementsAttr::get( + constType, APInt::getSignedMinValue(integerType.getWidth())); + } + return rewriter.create(op->getLoc(), constType, constAttr) + .getResult(); + } + return failure(); +} + // These legalizations are for unary ops with only for floating point datatypes. // There is no supported quantized integer mode for these. namespace { @@ -130,6 +184,41 @@ class ConvertAtenConstPatternOp : public OpConversionPattern { } // namespace +// The binary broadcast patterns +namespace { +template +class ConvertAtenBinaryBroadcastOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lhs = adaptor.self(); + auto lhsTy = lhs.getType().cast(); + Value rhs = adaptor.other(); + auto rhsTy = rhs.getType().cast(); + + if (!lhsTy || !rhsTy) + return op.emitError("only Tensor types supported"); + + auto lhsElemTy = lhsTy.getElementType(); + auto rhsElemTy = rhsTy.getElementType(); + + if (lhsElemTy != rhsElemTy) + return op.emitError("input data types mismatched"); + + rewriter.replaceOpWithNewOp( + op, + OpConversionPattern::getTypeConverter()->convertType( + op.getType()), + lhs, rhs, + /*broadcast_attr*/ nullptr); + return success(); + } +}; +} // namespace + // These binary op legalizations are specific to add/sub which have an // alpha multiplier. namespace { @@ -379,6 +468,36 @@ class ConvertAtenTransposeIntOp }; } // namespace +// AtenToDtypeOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenToDtypeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value self = adaptor.self(); + auto outType = + getTypeConverter()->convertType(op.getType()).cast(); + rewriter.replaceOpWithNewOp(op, outType, self); + return success(); +} + +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenSizeIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Not a tensor type. 
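The templated `ConvertAtenBinaryBroadcastOp` defined above is instantiated once per aten/chlo op pair by the `INSERT_BINARY_BROADCAST_PATTERN` macro added at the bottom of this file; expanded, one instantiation is simply (a sketch, assuming the `typeConverter`, `patterns`, `target`, and `context` locals of `populateBasicOpPatternsAndLegality`):

```cpp
target.addIllegalOp<AtenMaximumOp>();
patterns.add<ConvertAtenBinaryBroadcastOp<AtenMaximumOp, chlo::BroadcastMaxOp>>(
    typeConverter, context);
```

Leaving the `broadcast_dimensions` attribute null defers to the CHLO broadcast ops' implicit, numpy-style broadcasting between the two operands.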
+ auto selfType = adaptor.self().getType().dyn_cast(); + if (!selfType) + return op.emitError("only tensor types are currently supported"); + auto dim = rewriter.create( + op.getLoc(), rewriter.getIndexType(), adaptor.dim()); + auto dimSize = rewriter.create( + op.getLoc(), rewriter.getIndexType(), adaptor.self(), dim); + + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), dimSize); + return success(); +} + // AtenBroadcastToOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -409,10 +528,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value dValue = shape[i]; Value newD; int64_t dInt; - if (!(matchPattern(dValue, m_TorchConstantInt(&dInt)))) { - return op->emitError("element of desired shape must be a scalar"); - } - if (i >= leadingRank && dInt == -1) { + if (i >= leadingRank && matchPattern(dValue, m_TorchConstantInt(&dInt)) && + dInt == -1) { newD = rewriter.create(op->getLoc(), self, i - leadingRank); } else { @@ -433,6 +550,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } } + if (bcastShapeVec.size() == 0) { + rewriter.replaceOpWithNewOp(op, outType, self); + } else { Value bcastShapeTensor = rewriter.create( op->getLoc(), ValueRange{bcastShapeVec}); auto dimensionNumbers = @@ -440,7 +560,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, outType, self, bcastShapeTensor, rewriter.getI64TensorAttr(dimensionNumbers)); - return success(); + } + return success(); } // AtenPermuteOp @@ -757,10 +878,19 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } else { Type outputTy = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp( - op, outputTy, input, weight, bias, runningMean, runningVar, - rewriter.getFloatAttr(inputTy.getElementType(), eps), - rewriter.getI64IntegerAttr(1)); + SmallVector castShape{inputTy.getShape().begin(), + inputTy.getShape().end()}; + castShape[1] = weightTy.getShape()[0]; + auto castTy = RankedTensorType::get(castShape, inputTy.getElementType()); + // Feature counts must match among operands of mhlo::BatchNormInferenceOp. + Value inputCasted = + rewriter.create(op.getLoc(), castTy, input); + Value output = rewriter.create( + op.getLoc(), inputCasted.getType(), inputCasted, weight, bias, + runningMean, runningVar, + // 'epsilon' must satisfy constraint: 32-bit float attribute. 
+ rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1)); + rewriter.replaceOpWithNewOp(op, outputTy, output); return success(); } } @@ -942,33 +1072,122 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // AtenNumelOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( - AtenNumelOp op, - OpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const { + AtenNumelOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { auto self = adaptor.self(); auto selfTy = self.getType().dyn_cast(); size_t rank = selfTy.getRank(); Type intType = rewriter.getIntegerType(options.dimSizeIndexBits); auto loc = op->getLoc(); - Value numel = - rewriter.create(loc, rewriter.getIntegerAttr(intType, 1)); - for (size_t d = 0 ; d < rank; ++ d) { - Value dimSize = rewriter.create( + Value numel = rewriter.create( + loc, rewriter.getIntegerAttr(intType, 1)); + for (size_t d = 0; d < rank; ++d) { + Value dimSize = rewriter.create( loc, intType, rewriter.create(loc, self, d)); - numel = rewriter.create(loc, numel, dimSize); + numel = rewriter.create(loc, numel, dimSize); } auto outTy = getTypeConverter()->convertType(op.getType()); if (outTy != numel.getType()) { - rewriter.replaceOpWithNewOp( - op, outTy, numel); + rewriter.replaceOpWithNewOp(op, outTy, numel); } else { rewriter.replaceOp(op, numel); } return success(); } +// AtenClampOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenClampOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.self(); + auto inputType = input.getType().cast(); + auto inputElemType = inputType.getElementType(); + Value minValue = adaptor.min(); + Value maxValue = adaptor.max(); + if (failed(checkNotNone(rewriter, op, minValue)) && + failed(checkNotNone(rewriter, op, maxValue))) { + return rewriter.notifyMatchFailure( + op, "this op should be folded as its `min` and `max` both are none"); + } else if (failed(checkNotNone(rewriter, op, minValue))) { + maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType); + auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter); + if (failed(minInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to generate min value of dtype"); + } + minValue = *minInfo; + } else if (failed(checkNotNone(rewriter, op, maxValue))) { + minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType); + auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter); + if (failed(maxInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to generate max value of dtype"); + } + maxValue = *maxInfo; + } else { + minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType); + maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType); + } + rewriter.replaceOpWithNewOp(op, minValue, input, maxValue); + return success(); +} + +// AtenArangeStartStepOp +// aten.arange.start_step = range(ceil((end-start)/step)) * step + start. +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenArangeStartStepOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + + // The pinMemory should be either `none` or `false`. 
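The closed form quoted at the top of this conversion, `range(ceil((end-start)/step)) * step + start`, is easy to sanity-check outside the compiler; a standalone, illustrative C++ check (not part of the patch):

```cpp
#include <cmath>
#include <cstdio>

// Emulates aten.arange.start_step for floating-point inputs:
// length = ceil((end - start) / step); result[i] = start + i * step.
int main() {
  double start = 1.0, end = 10.0, step = 2.5;
  long length = (long)std::ceil((end - start) / step); // ceil(3.6) == 4
  for (long i = 0; i < length; ++i)
    std::printf("%g ", start + i * step); // prints: 1 3.5 6 8.5
  std::printf("\n");
  return 0;
}
```

The conversion below materializes the same computation in MHLO: subtract and divide to get the length, `mhlo.ceil` plus a convert for non-integer dtypes, `mhlo.dynamic_iota` over that length, then broadcast multiply and add against `step` and `start`.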
+ bool pinMemory; + if (!op.pin_memory().getType().isa() && + (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) || + pinMemory)) { + return rewriter.notifyMatchFailure( + op, "unimplemented: pin_memory must be either None or false"); + } + + // Get element type of resultType as dtype + auto outType = this->getTypeConverter() + ->convertType(op.getType()) + .cast(); + auto dtype = outType.getElementType(); + if (!dtype.isa() && !dtype.isa()) { + return rewriter.notifyMatchFailure( + op, "unimplemented: only int or float dtype supported"); + } + + Value start = mhlo::scalarToMhloTensor(rewriter, op, adaptor.start(), dtype); + Value end = mhlo::scalarToMhloTensor(rewriter, op, adaptor.end(), dtype); + Value step = mhlo::scalarToMhloTensor(rewriter, op, adaptor.step(), dtype); + + // Get length of the 1-d output tensor + Value subOut = rewriter.create(loc, end, start); + Value divOut = rewriter.create(loc, subOut, step); + + Value resultLength = rewriter.create( + loc, RankedTensorType::get({1}, dtype), divOut); + if (dtype.isa()) { + resultLength = rewriter.create(loc, resultLength); + resultLength = rewriter.create( + loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength); + } + + Value window = + rewriter.create(loc, outType, resultLength, 0); + DenseIntElementsAttr broadcastDimensions; + Value mulOut = rewriter.create(loc, window, step, + broadcastDimensions); + rewriter.replaceOpWithNewOp(op, mulOut, start, + broadcastDimensions); + return success(); +} + void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToMhloOptions &options) { @@ -1047,9 +1266,22 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenErfOp); INSERT_ATENOP_PATTERN(AtenCatOp); + INSERT_ATENOP_PATTERN(AtenClampOp); + INSERT_ATENOP_PATTERN(AtenArangeStartStepOp); INSERT_ATENOP_PATTERN(AtenBatchNormOp); INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); INSERT_ATENOP_PATTERN(AtenNumelOp); + INSERT_ATENOP_PATTERN(AtenSizeIntOp); + INSERT_ATENOP_PATTERN(AtenToDtypeOp); #undef INSERT_ATENOP_PATTERN + +#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, MhloOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, \ + context) + INSERT_BINARY_BROADCAST_PATTERN(AtenMaximumOp, chlo::BroadcastMaxOp); + INSERT_BINARY_BROADCAST_PATTERN(AtenMinimumOp, chlo::BroadcastMinOp); + INSERT_BINARY_BROADCAST_PATTERN(Aten__And__TensorOp, chlo::BroadcastAndOp); +#undef INSERT_BINARY_BROADCAST_PATTERN } diff --git a/lib/Conversion/TorchToMhlo/CMakeLists.txt b/lib/Conversion/TorchToMhlo/CMakeLists.txt index 39d956fddb176..8195107249425 100644 --- a/lib/Conversion/TorchToMhlo/CMakeLists.txt +++ b/lib/Conversion/TorchToMhlo/CMakeLists.txt @@ -15,6 +15,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo MhloDialect MhloToLinalg MLIRMhloPassIncGen + LMHLOTransformsPassIncGen TorchMLIRConversionPassIncGen LINK_COMPONENTS @@ -26,6 +27,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo MLIRPass MhloDialect MhloToLinalg + MLIRBufferTransforms StablehloBase TorchMLIRTorchDialect ) diff --git a/lib/Conversion/TorchToMhlo/Linear.cpp b/lib/Conversion/TorchToMhlo/Linear.cpp index 3b428c98cda19..a8c3c1544574e 100644 --- a/lib/Conversion/TorchToMhlo/Linear.cpp +++ b/lib/Conversion/TorchToMhlo/Linear.cpp @@ -71,6 +71,63 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input, return result.getResult(); } +RankedTensorType 
castContractingDim(PatternRewriter &rewriter, Operation *op, + Value &lhs, Value &rhs, + int64_t lhsResultDim, int64_t rhsResultDim, + int64_t lhsContractingDim, + int64_t rhsContractingDim) { + auto lhsTy = lhs.getType().dyn_cast(); + auto rhsTy = rhs.getType().dyn_cast(); + + auto oldLhsShape = lhsTy.getShape(); + auto oldRhsShape = rhsTy.getShape(); + SmallVector lhsShape; + SmallVector rhsShape; + lhsShape.append(oldLhsShape.begin(), oldLhsShape.end()); + rhsShape.append(oldRhsShape.begin(), oldRhsShape.end()); + auto lhsContractingDimSize = lhsShape[lhsContractingDim]; + auto rhsContractingDimSize = rhsShape[rhsContractingDim]; + if (lhsContractingDimSize != rhsContractingDimSize) { + if (lhsContractingDimSize == ShapedType::kDynamicSize && + rhsContractingDimSize >= 0) { + lhsShape[lhsContractingDim] = rhsContractingDimSize; + auto newRankTy = RankedTensorType::get(lhsShape, lhsTy.getElementType()); + lhs = rewriter.create(op->getLoc(), newRankTy, lhs); + } else if (rhsContractingDimSize == ShapedType::kDynamicSize && + lhsContractingDimSize >= 0) { + rhsShape[rhsContractingDim] = lhsContractingDimSize; + auto newRankTy = RankedTensorType::get(rhsShape, rhsTy.getElementType()); + rhs = rewriter.create(op->getLoc(), newRankTy, rhs); + } + } + SmallVector outShape; + // set batch dims, will skip invalid dimensions + for (size_t k = 0; k < lhsShape.size(); ++k) { + if (k == lhsResultDim || k == lhsContractingDim) + continue; + outShape.push_back(lhsShape[k]); + } + for (size_t k = 0, b = 0; k < rhsShape.size(); ++k) { + if (b >= outShape.size()) + break; + if (k == rhsResultDim || k == rhsContractingDim) + continue; + if (outShape[b] == ShapedType::kDynamicSize && rhsShape[k] >= 0) { + outShape[b] = rhsShape[k]; + } + b++; + } + + // set result dimensions + if (lhsResultDim < lhsShape.size() && lhsResultDim >= 0) { + outShape.push_back(lhsShape[lhsResultDim]); + } + if (rhsResultDim < rhsShape.size() && rhsResultDim >= 0) { + outShape.push_back(rhsShape[rhsResultDim]); + } + return RankedTensorType::get(outShape, lhsTy.getElementType()); +} + void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, Value &inpRhs, int64_t leadingRank, size_t dimSizeIndexBits) { @@ -183,10 +240,15 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp { options.dimSizeIndexBits); } auto batchDims = llvm::to_vector<4>(llvm::seq(0, nBatchDims)); + + auto lhsResultDim = nBatchDims; + auto rhsResultDim = nBatchDims + 1; auto lhsContractingDim = nBatchDims + 1; auto rhsContractingDim = nBatchDims; - if (lhsRank == 1) + if (lhsRank == 1) { + lhsResultDim = nBatchDims + 1; lhsContractingDim = nBatchDims; + } mhlo::DotDimensionNumbersAttr dotDimensionNumbers = mhlo::DotDimensionNumbersAttr::get( @@ -195,15 +257,13 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp { /*rhsBatchingDimensions=*/batchDims, /*lhsContractingDimensions=*/{lhsContractingDim}, /*rhsContractingDimensions=*/{rhsContractingDim}); - auto resultTy = ConvertAtenOp::getTypeConverter() - ->convertType(op.getType()) - .template cast(); - + auto outTy = + castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim, + lhsContractingDim, rhsContractingDim); output = rewriter - .create(op->getLoc(), resultTy, lhs, rhs, + .create(op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr) .getResult(); - return success(); } @@ -221,7 +281,7 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp { if (failed(performMatmul(op, adaptor, rewriter, lhs, rhs, output))) return op.emitError("failed to perform matmul 
operation"); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, ConvertAtenOp::getTypeConverter() ->convertType(op.getType()) @@ -355,9 +415,15 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { auto resultRank = std::max(lhsTy.getRank(), rhsTy.getRank()); auto nBatchDims = resultRank - 2; auto batchDims = llvm::to_vector<4>(llvm::seq(0, nBatchDims)); + + auto lhsResultDim = nBatchDims; + auto rhsResultDim = nBatchDims + 1; auto lhsContractingDim = nBatchDims + 1; auto rhsContractingDim = nBatchDims; + auto outTy = + castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim, + lhsContractingDim, rhsContractingDim); mhlo::DotDimensionNumbersAttr dotDimensionNumbers = mhlo::DotDimensionNumbersAttr::get( rewriter.getContext(), @@ -365,24 +431,21 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { /*rhsBatchingDimensions=*/batchDims, /*lhsContractingDimensions=*/{lhsContractingDim}, /*rhsContractingDimensions=*/{rhsContractingDim}); - - auto resultTy = - ConvertAtenOp::getTypeConverter()->convertType(op.getType()); - Value matmulOutput = rewriter.create( - op->getLoc(), resultTy, lhs, rhs, dotDimensionNumbers, nullptr); + op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr); Value matmulPlusBias = matmulOutput; if (!biasTy.template isa()) { // Bias addition broadcasts to the matmul output shape. - matmulPlusBias = - rewriter - .create(op->getLoc(), resultTy, - matmulOutput, bias, nullptr) - .getResult(); + matmulPlusBias = rewriter + .create( + op->getLoc(), outTy, matmulOutput, bias, nullptr) + .getResult(); } - rewriter.replaceOpWithNewOp(op, resultTy, matmulPlusBias); + auto resultTy = + ConvertAtenOp::getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resultTy, matmulPlusBias); return success(); } }; @@ -609,8 +672,9 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { return mhloConvOp.getResult(); } - LogicalResult matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { + LogicalResult + matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { Value input = adaptor.input(); Value weight = adaptor.weight(); diff --git a/lib/Conversion/TorchToMhlo/Reduction.cpp b/lib/Conversion/TorchToMhlo/Reduction.cpp index eb5422d58683a..a185a27d45907 100644 --- a/lib/Conversion/TorchToMhlo/Reduction.cpp +++ b/lib/Conversion/TorchToMhlo/Reduction.cpp @@ -30,7 +30,7 @@ using namespace mlir::torch::torch_to_mhlo; static Value createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { auto constType = RankedTensorType::get({}, elementTy); - if (isa(op)) { + if (isa(op)) { if (elementTy.isa()) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getZero( @@ -571,6 +571,113 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } } // namespace +// AtenFrobeniusNormDimOp +// aten.frobenius_norm.dim => mhlo.reduce(calculate square sum along given dims) +// + mhlo.sqrt +namespace { +template <> +LogicalResult ConvertAtenReductionOp::matchAndRewrite( + AtenFrobeniusNormDimOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + const TorchToMhloOptions &options = getOptions(); + + Value input = adaptor.self(); + auto inputType = input.getType().dyn_cast(); + if (!inputType) { + return op.emitError( + "only ranked tensor input supported in AtenFrobeniusNormDimOp"); + } + auto inputRank = inputType.getRank(); + auto inputElemType = 
inputType.getElementType(); + if (!inputElemType.isa()) { + return op.emitError( + "only float dtype allowed in input tensor of AtenFrobeniusNormDimOp"); + } + + SmallVector dims; + if (!matchPattern(op.dim(), m_TorchConstantIntList(dims))) { + return rewriter.notifyMatchFailure( + op, "non-const integer `dim` is not supported"); + } + for (auto &dim : dims) { + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) { + return rewriter.notifyMatchFailure(op, + "invalid dimension detected in `dim`"); + } + } + + // Sort the dims in ascending order, making the conversion + // stable with unordered dims. + std::sort(dims.begin(), dims.end()); + + bool keepDim = false; + if (!matchPattern(op.keepdim(), m_TorchConstantBool(&keepDim))) { + return rewriter.notifyMatchFailure( + op, "non-const bool `keepdim` is not supported"); + } + + auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter); + if (!initValue) { + return failure(); + } + + auto squareSumReduceOp = rewriter.create( + op->getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); + + Region ®ion = squareSumReduceOp.body(); + Block &block = region.emplaceBlock(); + auto blockArgumentTy = RankedTensorType::get({}, inputElemType); + + block.addArgument(blockArgumentTy, op->getLoc()); + block.addArgument(blockArgumentTy, op->getLoc()); + + auto *firstArgument = block.args_begin(); + auto secondArgument = block.args_rbegin(); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + + auto constantOrd2 = rewriter.create( + op->getLoc(), blockArgumentTy, + DenseElementsAttr::get(blockArgumentTy, llvm::ArrayRef{2.0})); + auto abs = rewriter.create(op->getLoc(), *secondArgument); + auto squareResult = rewriter.create( + op->getLoc(), abs, constantOrd2); + auto addResult = rewriter.create(op->getLoc(), squareResult, + *firstArgument); + rewriter.create(op->getLoc(), addResult.getResult()); + } + + auto output = rewriter.create(op->getLoc(), + squareSumReduceOp.getResult(0)); + + if (keepDim) { + auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + if (failed(outShapeInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + } + auto outShapeVec = *outShapeInfo; + auto one = rewriter.create( + op->getLoc(), rewriter.getIntegerAttr( + rewriter.getIntegerType(options.dimSizeIndexBits), 1)); + for (int64_t i : dims) { + outShapeVec[i] = one; + } + auto outShapeTensor = rewriter.create( + op->getLoc(), outShapeVec); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), output, + outShapeTensor); + return success(); + } + rewriter.replaceOp(op, output.getResult()); + return success(); +} +} // namespace + void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToMhloOptions &options) { @@ -583,5 +690,6 @@ void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality( INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp); + INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp); #undef INSERT_ATEN_REDUCTION_OP_PATTERN } diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index 60b03d576acde..d92a295547d08 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ 
b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -14,6 +14,8 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h" @@ -72,6 +74,29 @@ static Value createTMTensorScatterOp( return scatterOp->getResult(0); } +static Value createTMTensorScanOp( + OpBuilder &b, Location loc, Value input, Value output, Value accumulator, + int64_t dim, bool inclusive, + function_ref bodyBuild) { + auto inputType = input.getType().cast(); + auto accType = accumulator.getType().cast(); + Type elementType = inputType.getElementType(); + auto scanOp = b.create( + loc, TypeRange{inputType, accType}, input, + ValueRange{output, accumulator}, b.getI64IntegerAttr(dim), + b.getBoolAttr(inclusive)); + + Region &scanOpRegion = scanOp.region(); + auto &scanOpBlock = scanOpRegion.emplaceBlock(); + scanOpBlock.addArguments({elementType, elementType}, {loc, loc}); + OpBuilder regionBuilder(scanOpRegion); + auto blockArgs = scanOpBlock.getArguments(); + Value inputElement = blockArgs[0]; + Value accElement = blockArgs[1]; + bodyBuild(regionBuilder, loc, inputElement, accElement); + return scanOp->getResult(0); +} + namespace { // aten::bincount op counts the frequency of each value in a 1-d input tensor of // non-negative ints. @@ -523,6 +548,60 @@ class ConvertAtenMaxPool2dWithIndicesBackwardOp }; } // namespace +namespace { +class ConvertAtenCumsumOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenCumsumOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Value input = adaptor.self(); + auto resultType = input.getType().cast(); + Type elementType = resultType.getElementType(); + int64_t inputRank = resultType.getRank(); + Location loc = op->getLoc(); + Value dtype = op.dtype(); + if (!dtype.getType().isa()) + return rewriter.notifyMatchFailure( + op, "unsupported: dtype argument not supported"); + + int64_t dim; + if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure( + op, "unimplemented: only constant dim value is supported"); + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) + return rewriter.notifyMatchFailure(op, "invalid dim"); + + SmallVector sizes = getTensorSizes(rewriter, loc, input); + Value output = createZeroInitTensor(rewriter, loc, sizes, elementType); + output = rewriter.create(loc, resultType, output); + + SmallVector accSizes(sizes); + accSizes.erase(accSizes.begin() + dim); + SmallVector accStatic(resultType.getShape()); + accStatic.erase(accStatic.begin() + dim); + Value acc = createZeroInitTensor(rewriter, loc, accSizes, elementType); + Type accType = RankedTensorType::get(accStatic, elementType); + acc = rewriter.create(loc, accType, acc); + + Value result = createTMTensorScanOp( + rewriter, loc, input, output, acc, dim, /*inclusive=*/true, + [](OpBuilder &b, Location loc, Value input, Value acc) { + Value sum = (input.getType().isa() + ? 
b.create(loc, input, acc) + : b.create(loc, input, acc)) + ->getResult(0); + b.create(loc, sum); + }); + + rewriter.replaceOpWithNewOp(op, resultType, result); + return success(); + } +}; +} // namespace + // ----------------------------------------------------------------------------- // The pass // ----------------------------------------------------------------------------- @@ -560,6 +639,8 @@ class ConvertTorchToTMTensor target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 558ac82a1ab97..906f462bdfffd 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -2866,8 +2866,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( adaptor.self(), dimAttr); if (argMax.getType() != indicesType) { - argMax = rewriter.create(op->getLoc(), indicesType, argMax, - prunedShapeAttr); + argMax = rewriter.create( + op->getLoc(), indicesType, argMax, + rewriter.getI64ArrayAttr(reducedShape)); } if (!keepDim) { @@ -2939,6 +2940,40 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenBroadcastToOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + // Not a tensor type. + auto selfType = adaptor.self().getType().dyn_cast(); + if (!selfType || !selfType.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "Only tensor types with static shape are supported"); + + auto selfElemTy = selfType.getElementType(); + if (!selfElemTy.isIntOrFloat()) { + return rewriter.notifyMatchFailure( + op, "Only floating-point or integer datatype legalization supported"); + } + + SmallVector outShape; + if (!matchPattern(op.size(), m_TorchConstantIntList(outShape))) + return rewriter.notifyMatchFailure(op, + "size must consist of Scalar constants"); + + SmallVector inputShape(selfType.getShape()); + if (!llvm::equal(inputShape, outShape)) + return rewriter.notifyMatchFailure(op, + "Only identity cases are supported."); + + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), adaptor.self(), + rewriter.getI64ArrayAttr(outShape)); + + return success(); +} + template class ConvertAtenPoolingBaseOp : public OpConversionPattern { public: @@ -3358,6 +3393,31 @@ class ConvertAtenFillScalarOp : public OpConversionPattern { } }; +// Legalizes the torch.clone op. 
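+// As a hedged illustration (not taken from the test suite, shapes invented):
+// a clone in the default contiguous memory format is a pure value copy, so,
+// assuming the replacement op below is tosa.cast, something like
+//   %1 = torch.aten.clone %0, %none : !torch.vtensor<[2,3],f32>, !torch.none
+//          -> !torch.vtensor<[2,3],f32>
+// lowers to
+//   %1 = "tosa.cast"(%0) : (tensor<2x3xf32>) -> tensor<2x3xf32>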
+template +class ConvertAtenCloneOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + int64_t memoryFormat; + if (!op.memory_format().getType().template isa() && + (!matchPattern(op.memory_format(), m_TorchConstantInt(&memoryFormat)) || + memoryFormat != torch_upstream::MemoryFormat::Contiguous)) { + return op.emitError( + "unimplemented: only default memory format is supported"); + } + auto outType = OpConversionPattern::getTypeConverter() + ->convertType(op.getType()) + .template dyn_cast(); + rewriter.replaceOpWithNewOp(op, outType, adaptor.self()); + + return success(); + } +}; + } // namespace // ----------------------------------------------------------------------------- @@ -3561,8 +3621,15 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(AtenTransposeIntOp); INSERT_ATENOP_PATTERN(AtenMaxDimOp); INSERT_ATENOP_PATTERN(AtenSliceTensorOp); + INSERT_ATENOP_PATTERN(AtenBroadcastToOp); #undef INSERT_ATENOP_PATTERN +#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp); +#undef INSERT_CLONE_ATENOP_PATTERN + if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); diff --git a/lib/Dialect/Torch/IR/TorchDialect.cpp b/lib/Dialect/Torch/IR/TorchDialect.cpp index 835be031644be..a29c2e16a3ae9 100644 --- a/lib/Dialect/Torch/IR/TorchDialect.cpp +++ b/lib/Dialect/Torch/IR/TorchDialect.cpp @@ -149,6 +149,14 @@ Operation *TorchDialect::materializeConstant(OpBuilder &builder, if (auto floatType = type.dyn_cast()) return builder.create(loc, value.cast()); + if (auto numberType = type.dyn_cast()) { + if (auto floatValue = value.dyn_cast()) { + return builder.create(loc, floatValue); + } else if (auto intValue = value.dyn_cast()) { + return builder.create(loc, intValue); + } + } + if (type.isa()) { return builder.create(loc, value.cast()); diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index f34ece3468da1..0de4784a29f1d 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -673,6 +673,20 @@ OpFoldResult AtenSqueezeDimOp::fold(ArrayRef operands) { return nullptr; } +//===----------------------------------------------------------------------===// +// AtenTypeAsOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenTypeAsOp::fold(ArrayRef operands) { + Type inType = self().getType(); + Type newType = other().getType(); + + if (inType == newType) + return self(); + + return nullptr; +} + //===----------------------------------------------------------------------===// // AtenToDtypeOp //===----------------------------------------------------------------------===// @@ -1577,6 +1591,34 @@ void Torch::ConstantFloatOp::getAsmResultNames( setNameFn(getResult(), StringRef(buf.data(), buf.size())); } +//===----------------------------------------------------------------------===// +// ConstantNumberOp +//===----------------------------------------------------------------------===// + +OpFoldResult Torch::ConstantNumberOp::fold(ArrayRef operands) { + return valueAttr(); +} + +void Torch::ConstantNumberOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + 
patterns.add(+[](Torch::ConstantNumberOp op, PatternRewriter &rewriter) { + Location loc = op->getLoc(); + + Value constValue; + Attribute value = op.valueAttr(); + if (auto floatValue = value.dyn_cast()) { + constValue = rewriter.create(loc, floatValue); + } else if (auto intValue = value.dyn_cast()) { + constValue = rewriter.create(loc, intValue); + } else { + return failure(); + } + rewriter.replaceOpWithNewOp(op, op.getType(), + constValue); + return success(); + }); +} + //===----------------------------------------------------------------------===// // ConstantBoolOp //===----------------------------------------------------------------------===// @@ -1872,15 +1914,39 @@ OpFoldResult Aten__Contains__IntListOp::fold(ArrayRef operands) { } using BinaryIntOperatorFn = std::function; -template -static OpFoldResult atenBinaryIntOperatorFoldHelper(OpTy op, - BinaryIntOperatorFn f) { - int64_t lhs, rhs; - if (!matchPattern(op.getOperand(0), m_TorchConstantInt(&lhs)) || - !matchPattern(op.getOperand(1), m_TorchConstantInt(&rhs))) +static OpFoldResult +atenBinaryIntOperatorFoldHelper(ArrayRef operands, + BinaryIntOperatorFn f) { + auto intLhs = operands[0].dyn_cast_or_null(); + auto intRhs = operands[1].dyn_cast_or_null(); + if (!intLhs || !intRhs) { return nullptr; + } + return IntegerAttr::get( + intLhs.getType(), + f(intLhs.getValue().getSExtValue(), intRhs.getValue().getSExtValue())); +} - return getI64IntegerAttr(op.getContext(), f(lhs, rhs)); +using BinaryFloatOperatorFn = std::function; +static OpFoldResult +atenBinaryFloatOperatorFoldHelper(ArrayRef operands, + BinaryFloatOperatorFn f) { + double lhs, rhs; + auto parseDoubleAttribute = [](Attribute attr, double &value) -> bool { + if (auto intLhs = attr.dyn_cast_or_null()) { + value = static_cast(intLhs.getValue().getSExtValue()); + } else if (auto floatLhs = attr.dyn_cast_or_null()) { + value = floatLhs.getValue().convertToDouble(); + } else { + return false; + } + return true; + }; + if (!parseDoubleAttribute(operands[0], lhs) || + !parseDoubleAttribute(operands[1], rhs)) { + return nullptr; + } + return getF64FloatAttr(operands[0].getContext(), f(lhs, rhs)); } //===----------------------------------------------------------------------===// @@ -1889,7 +1955,7 @@ static OpFoldResult atenBinaryIntOperatorFoldHelper(OpTy op, OpFoldResult AtenFloordivIntOp::fold(ArrayRef operands) { return atenBinaryIntOperatorFoldHelper( - *this, [](int64_t a, int64_t b) { return std::floor(a / (double)b); }); + operands, [](int64_t a, int64_t b) { return std::floor(a / (double)b); }); } //===----------------------------------------------------------------------===// @@ -1898,7 +1964,7 @@ OpFoldResult AtenFloordivIntOp::fold(ArrayRef operands) { OpFoldResult AtenRemainderIntOp::fold(ArrayRef operands) { return atenBinaryIntOperatorFoldHelper( - *this, [](int64_t a, int64_t b) { return a % b; }); + operands, [](int64_t a, int64_t b) { return a % b; }); } //===----------------------------------------------------------------------===// @@ -1907,7 +1973,7 @@ OpFoldResult AtenRemainderIntOp::fold(ArrayRef operands) { OpFoldResult AtenAddIntOp::fold(ArrayRef operands) { return atenBinaryIntOperatorFoldHelper( - *this, [](int64_t a, int64_t b) { return a + b; }); + operands, [](int64_t a, int64_t b) { return a + b; }); } //===----------------------------------------------------------------------===// @@ -1916,7 +1982,7 @@ OpFoldResult AtenAddIntOp::fold(ArrayRef operands) { OpFoldResult AtenSubIntOp::fold(ArrayRef operands) { return 
atenBinaryIntOperatorFoldHelper(
-      *this, [](int64_t a, int64_t b) { return a - b; });
+      operands, [](int64_t a, int64_t b) { return a - b; });
 }
 
 //===----------------------------------------------------------------------===//
@@ -1934,6 +2000,54 @@ OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenSubOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenSubOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[0] || !operands[1]) {
+    return nullptr;
+  }
+
+  if (operands[0].isa<IntegerAttr>() && operands[1].isa<IntegerAttr>()) {
+    return atenBinaryIntOperatorFoldHelper(
+        operands, [](int64_t a, int64_t b) -> int64_t { return a - b; });
+  }
+  return atenBinaryFloatOperatorFoldHelper(
+      operands, [](double a, double b) -> double { return a - b; });
+}
+
+//===----------------------------------------------------------------------===//
+// AtenDivOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenDivOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[0] || !operands[1]) {
+    return nullptr;
+  }
+  // Since AtenDivOp always returns a float value, we don't need to handle the
+  // case where both operands are integers separately.
+  return atenBinaryFloatOperatorFoldHelper(
+      operands, [](double a, double b) -> double { return a / b; });
+}
+
+//===----------------------------------------------------------------------===//
+// AtenCeilScalarOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenCeilScalarOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[0]) {
+    return nullptr;
+  }
+  auto floatValue = operands[0].dyn_cast_or_null<FloatAttr>();
+  if (!floatValue) {
+    return nullptr;
+  }
+  return getI64IntegerAttr(
+      getContext(),
+      static_cast<int64_t>(std::ceil(floatValue.getValue().convertToDouble())));
+}
+
 //===----------------------------------------------------------------------===//
 // AtenNegIntOp
 //===----------------------------------------------------------------------===//
@@ -2189,11 +2303,7 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() {
   // We only permit a small set of ops in the module initializer.
   // These ops are essentially those which can be produced by the IValue
   // importer.
-  if (isa(op))
+  if (op->hasTrait<mlir::torch::Torch::OpTrait::AllowedInModuleInitializer>())
     return WalkResult::advance();
   op->emitOpError() << "is not allowed in a module initializer";
   return WalkResult::interrupt();
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 0f9e5d149b0be..73c47ec26bfbd 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -126,13 +126,10 @@ static Value createTensorSub(PatternRewriter &rewriter, Location loc,
 // converted to the element type of the given tensor type.
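+// Illustrative sketch of what the rewritten helper below now emits (hedged;
+// value names invented, argument order abbreviated): a single full-style
+// creation op, roughly
+//   %t = torch.aten.full %sizes, %scalar, %none, %none, %none, %none
+// rather than an aten.empty.memory_format followed by a separate fill, so the
+// scalar initialization stays visible to later decompositions as one op.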
static Value createInitTensor(PatternRewriter &rewriter, Location loc, Type resultType, Value scalar, Value sizeList) { - BaseTensorType tensorType = resultType.cast(); Value noneVal = rewriter.create(loc); - Value emptyTensor = rewriter.create( - loc, tensorType, sizeList, /*dtype=*/noneVal, /*layout=*/noneVal, - /*device=*/noneVal, /*pin_memory=*/noneVal, /*memory_format=*/noneVal); - return rewriter.create(loc, resultType, - emptyTensor, scalar); + return rewriter.create( + loc, resultType, sizeList, scalar, /*dtype=*/noneVal, /*layout=*/noneVal, + /*device=*/noneVal, /*memory_format=*/noneVal); } // Helper to create a rank 0 tensor filled with the given `scalar`. `scalar` @@ -648,6 +645,20 @@ static Value getRelu6Results(PatternRewriter &rewriter, Location loc, return relu6Out; } +namespace { +class DecomposeAtenRelu6Op : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenRelu6Op op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value relu6 = getRelu6Results(rewriter, loc, op.self()); + rewriter.replaceOp(op, relu6); + return success(); + } +}; +} // namespace + // Hardswish(x) = x * Relu6(x+3)/6 namespace { class DecomposeAtenHardswishOp : public OpRewritePattern { @@ -1504,8 +1515,8 @@ class DecomposeAtenRandLikeOp : public OpRewritePattern { rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); Value one = rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); - Value emptyTensor = rewriter.create( - loc, resultType, input, op.dtype(), op.layout(), op.device(), + Value emptyTensor = rewriter.create( + loc, resultType, input, zero, op.dtype(), op.layout(), op.device(), op.pin_memory(), op.memory_format()); rewriter.replaceOpWithNewOp( op, resultType, emptyTensor, /*from=*/zero, /*to=*/one, @@ -1690,6 +1701,88 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenNativeLayerNormOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNativeLayerNormOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto context = op.getContext(); + + auto inputTy = op.input().getType().cast(); + if (!inputTy.hasSizes()) + return rewriter.notifyMatchFailure( + op, "input tensor should have known sizes."); + int64_t inputRank = inputTy.getSizes().size(); + Value normalizedShape = op.normalized_shape(); + SmallVector normalizedShapeSizesTorchInt; + getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt); + int64_t axis = inputRank - normalizedShapeSizesTorchInt.size(); + auto reduceDimInts = llvm::to_vector<4>(llvm::seq(axis, inputRank)); + auto reducedTy = op.getResult(1).getType(); + auto sizeListType = ListType::get(IntType::get(context)); + + // build reduce dims + SmallVector reduceDimVals; + reduceDimVals.reserve(reduceDimInts.size()); + std::transform(reduceDimInts.begin(), reduceDimInts.end(), + std::back_inserter(reduceDimVals), [&](int64_t d) { + return rewriter.create( + loc, rewriter.getI64IntegerAttr(d)); + }); + Value reduceDimList = + rewriter.create(loc, sizeListType, reduceDimVals); + Value one = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + + Value cstTrue = rewriter.create(loc, true); + Value none = rewriter.create(loc); + // mean(x) + Value inputMean = rewriter.create( + loc, reducedTy, op.input(), reduceDimList, cstTrue, none); + + // x - mean(x) + Value inputMeanExpanded = + rewriter.create(loc, inputTy, inputMean, 
op.input()); + Value inputZeroMean = rewriter.create( + loc, inputTy, op.input(), inputMeanExpanded, one); + // var(x) = mean((x - mean(x))^2) + Value inputZeroMeanSquare = rewriter.create( + loc, inputTy, inputZeroMean, inputZeroMean); + Value inputVar = rewriter.create( + loc, reducedTy, inputZeroMeanSquare, reduceDimList, cstTrue, none); + + // rsqrt(var(x) + eps) + Value inputVarPlusEps = rewriter.create( + loc, reducedTy, inputVar, op.eps(), one); + Value inputRsqrtVar = + rewriter.create(loc, reducedTy, inputVarPlusEps); + + // (x - mean(x)) * rsqrt(var(x) + eps) + Value inputRsqrtVarExpanded = rewriter.create( + loc, inputTy, inputRsqrtVar, op.input()); + Value inputNormalized = rewriter.create( + loc, inputTy, inputZeroMean, inputRsqrtVarExpanded); + Value out = rewriter.create( + loc, op.getResult(0).getType(), inputNormalized); + + Value weight = op.weight(); + Value bias = op.bias(); + if (!weight.getType().isa()) { + out = rewriter.create(loc, out.getType(), out, weight); + } + if (!bias.getType().isa()) { + out = + rewriter.create(loc, out.getType(), out, bias, one); + } + rewriter.replaceOp(op, {out, inputMean, inputRsqrtVar}); + + return success(); + } +}; +} // namespace + namespace { // Decompose `aten.empty_like` op into `aten.size` and `aten.empty` ops. class DecomposeAtenEmptyLikeOp : public OpRewritePattern { @@ -1752,22 +1845,18 @@ class DecomposeAtenArangeStartOp : public OpRewritePattern { } // namespace namespace { -// Decompose constant tensor allocation like ops. +// Decompose constant tensor full like ops. template class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - // Allocate a memory block. - Value initTensor = rewriter.create( - loc, op.getType(), op.self(), op.dtype(), op.layout(), op.device(), - op.pin_memory(), op.memory_format()); Value constVal = rewriter.create( loc, rewriter.getI64IntegerAttr(fillVal)); - // Initialize the allocated memory block with `fillVal`. - rewriter.replaceOpWithNewOp( - op, initTensor.getType(), initTensor, constVal); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.self(), constVal, op.dtype(), op.layout(), + op.device(), op.pin_memory(), op.memory_format()); return success(); } }; @@ -1960,19 +2049,73 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern { } // namespace namespace { -// Decompose `aten.full` op into `aten.empty` and `aten.fill` ops. 
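+// Note on the replacement decomposition below (a sketch; op for the rank-0
+// wrapping is assumed from the surrounding code, names invented): instead of
+// allocating with aten.empty and filling in place, the fill value is wrapped
+// in a rank-0 tensor, converted to the result dtype, and broadcast:
+//   %fill = torch.prim.NumToTensor.Scalar %fill_value
+//   %out  = torch.aten.broadcast_to %fill, %size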
+// Decompose `aten.full` op into `aten.broadcast_to` class DecomposeAtenFullOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenFullOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value noneVal = rewriter.create(loc); - Value emptyTensor = rewriter.create( - loc, op.getType(), op.size(), op.dtype(), op.layout(), op.device(), - op.pin_memory(), /*memory_format=*/noneVal); - rewriter.replaceOpWithNewOp( - op, op.getType(), emptyTensor, op.fill_value()); + BaseTensorType outTy = op.getType().template cast(); + SmallVector empty; + auto dtype = + getTypeForTorchType(op.getContext(), op.fill_value().getType()); + Type tensorType = + outTy.getWithSizesAndDtype(llvm::makeArrayRef(empty), dtype); + Value fillVal = rewriter.create(loc, tensorType, + op.fill_value()); + fillVal = convertTensorToDtype(rewriter, loc, fillVal, outTy.getDtype()); + rewriter.replaceOpWithNewOp(op, op.getType(), fillVal, + op.size()); + return success(); + } +}; +} // namespace + +namespace { +// Decompose `aten.linear` op into `aten.matmul` and `aten.add` ops. +class DecomposeAtenLinearOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenLinearOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = op.input(); + Value weight = op.weight(); + Value bias = op.bias(); + + BaseTensorType inputType = input.getType().cast(); + if (!inputType.hasSizes() || inputType.getSizes().size() < 2) + return rewriter.notifyMatchFailure( + op, "expected input to be rank 2 or greater"); + + BaseTensorType weightType = weight.getType().cast(); + // `weight` must be a rank 2 matrix. + if (!weightType.hasSizes() || weightType.getSizes().size() != 2) + return rewriter.notifyMatchFailure(op, "expected weight to be a rank 2"); + + SmallVector transposeShape = + llvm::to_vector(llvm::reverse(weightType.getSizes())); + Type transposeType = weightType.getWithSizesAndDtype( + llvm::makeArrayRef(transposeShape), weightType.getDtype()); + Value transposeWeight = + rewriter.create(loc, transposeType, weight); + + Value matmul = rewriter.create(loc, op.getType(), input, + transposeWeight); + if (bias.getType().isa()) { + rewriter.replaceOp(op, matmul); + return success(); + } + + BaseTensorType biasType = bias.getType().cast(); + if (!biasType.hasSizes() || biasType.getSizes().size() != 1) + return rewriter.notifyMatchFailure(op, "expected bias to be rank 1"); + + Value alpha = + rewriter.create(loc, rewriter.getF64FloatAttr(1)); + rewriter.replaceOpWithNewOp(op, op.getType(), matmul, + op.bias(), alpha); return success(); } }; @@ -1985,11 +2128,18 @@ class DecomposeAtenFullLikeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenFullLikeOp op, PatternRewriter &rewriter) const override { - Value emptyTensor = rewriter.create( - op.getLoc(), op.getType(), op.self(), op.dtype(), op.layout(), - op.device(), op.pin_memory(), op.memory_format()); - rewriter.replaceOpWithNewOp( - op, op.getType(), emptyTensor, op.fill_value()); + BaseTensorType outTy = op.getType().template cast(); + SmallVector empty; + auto dtype = + getTypeForTorchType(op.getContext(), op.fill_value().getType()); + Type tensorType = + outTy.getWithSizesAndDtype(llvm::makeArrayRef(empty), dtype); + Value fillVal = rewriter.create( + op.getLoc(), tensorType, op.fill_value()); + fillVal = + convertTensorToDtype(rewriter, op.getLoc(), 
fillVal, outTy.getDtype()); + rewriter.replaceOpWithNewOp(op, op.getType(), fillVal, + op.self()); return success(); } }; @@ -2035,8 +2185,10 @@ class DecomposeAten_ToCopyOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_ToCopyOp op, PatternRewriter &rewriter) const override { - Value emptyTensor = rewriter.create( - op.getLoc(), op.getType(), op.self(), op.dtype(), op.layout(), + Value zero = rewriter.create( + op.getLoc(), rewriter.getF64FloatAttr(0.0)); + Value emptyTensor = rewriter.create( + op.getLoc(), op.getType(), op.self(), zero, op.dtype(), op.layout(), op.device(), op.pin_memory(), op.memory_format()); rewriter.replaceOpWithNewOp( op, op.getType(), emptyTensor, op.self(), op.non_blocking()); @@ -2616,6 +2768,36 @@ class DecomposeAten_EmbeddingBagOp }; } // namespace +namespace { +// Decompose `aten.lift_fresh_copy` op into `aten.clone` op. +class DecomposeAtenLiftFreshCopyOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenLiftFreshCopyOp op, + PatternRewriter &rewriter) const override { + Value constantNone = rewriter.create(op.getLoc()); + rewriter.replaceOpWithNewOp(op, op.getType(), op.self(), + /*memoryFormat=*/constantNone); + return success(); + } +}; +} // namespace + +namespace { +// Decompose `aten.index.Tensor_hacked_twin` op into `aten.index.Tensor` op. +class DecomposeAtenIndexTensorHackedTwinOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenIndexTensorHackedTwinOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getType(), op.self(), + op.indices()); + return success(); + } +}; +} // namespace + namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { @@ -2696,6 +2878,9 @@ class DecomposeComplexOpsPass target.addIllegalOp(); target.addIllegalOp(); patterns.add(context); + target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); @@ -2736,6 +2921,8 @@ class DecomposeComplexOpsPass target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); patterns.add(context); @@ -2752,6 +2939,8 @@ class DecomposeComplexOpsPass target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); patterns.add(context); @@ -2796,6 +2985,10 @@ class DecomposeComplexOpsPass target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); for (std::string opName : legalOps) { target.addLegalOp(OperationName(opName, context)); diff --git a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp index 29e435b3d10f0..e947433bc0db7 100644 --- a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp +++ b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp @@ -98,12 +98,6 @@ class InlineGlobalSlotsAnalysisState : public AnalysisState { setSafe(); } - bool isUninitialized() const override { - // We are an optimistic analysis, so we are always default initialized to - // the optimistic "assumed safe" state. - return false; - } - void print(raw_ostream &os) const override { os << "InlineGlobalSlotsAnalysisState(" << (isSafe ? 
"safe" : "unsafe") << ")"; diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index a67f65ec34440..b286546ecbd32 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -167,15 +167,19 @@ namespace { // we cannot claim to know something about a value which is false. // This class could also be called "dataflow facts", "lattice value", etc. struct ValueKnowledge { - ValueKnowledge() = delete; + ValueKnowledge() = default; ValueKnowledge(Type dtype, Type scalarType, OptionalKnowledge optionalKnowledge, torch_upstream::TypeKind kind) - : dtype(dtype), scalarType(scalarType), kind(kind), + : isInitialized(true), dtype(dtype), scalarType(scalarType), kind(kind), optional(optionalKnowledge) {} void print(raw_ostream &os) const { os << "ValueKnowledge("; + if (!isInitialized) { + os << "uninitialized)"; + return; + } if (dtype) os << "dtype=" << dtype; if (scalarType) @@ -249,13 +253,21 @@ struct ValueKnowledge { } bool operator==(const ValueKnowledge &rhs) const { - return std::make_tuple(dtype, optional) == - std::make_tuple(rhs.dtype, rhs.optional); + if (!isInitialized && !rhs.isInitialized) + return true; + return isInitialized && rhs.isInitialized && + std::make_tuple(dtype, optional) == + std::make_tuple(rhs.dtype, rhs.optional); } // Return true if the `refinedType` has more concrete type info than `type`. static bool hasStrictlyMoreRefinedTypeInfo(const ValueKnowledge &refinedType, const ValueKnowledge &type) { + if (!refinedType.isInitialized) + return false; + if (!type.isInitialized) + return true; + if (type.kind == torch_upstream::TypeKind::AnyType && refinedType.kind != torch_upstream::TypeKind::AnyType) return true; @@ -284,6 +296,11 @@ struct ValueKnowledge { // both. static ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs) { + if (!lhs.isInitialized) + return rhs; + if (!rhs.isInitialized) + return lhs; + // Mental model: All conditions are checking how to change from the safe "no // knowledge" default-initialized state to a state with more knowledge // consistent with lhs and rhs. @@ -294,6 +311,11 @@ struct ValueKnowledge { static ValueKnowledge joinTypes(const ValueKnowledge &lhs, const ValueKnowledge &rhs) { + if (!lhs.isInitialized) + return rhs; + if (!rhs.isInitialized) + return lhs; + if (hasStrictlyMoreRefinedTypeInfo(lhs, rhs)) return rhs; if (hasStrictlyMoreRefinedTypeInfo(rhs, lhs)) @@ -308,6 +330,11 @@ struct ValueKnowledge { // If the two pieces of knowledge are contradictory, None is returned. static Optional meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs) { + if (!lhs.isInitialized) + return lhs; + if (!rhs.isInitialized) + return rhs; + Optional knowledge = meetTypes(lhs, rhs); if (!knowledge.has_value()) @@ -324,6 +351,11 @@ struct ValueKnowledge { static Optional meetTypes(const ValueKnowledge &lhs, const ValueKnowledge &rhs) { + if (!lhs.isInitialized) + return lhs; + if (!rhs.isInitialized) + return rhs; + if (hasStrictlyMoreRefinedTypeInfo(lhs, rhs)) return lhs; if (hasStrictlyMoreRefinedTypeInfo(rhs, lhs)) @@ -333,6 +365,9 @@ struct ValueKnowledge { return None; } + // We start in the uninitialized state by default. + bool isInitialized = false; + // The dtype of a tensor. // This is equal to nullptr for the follow cases: // 1. 
it is unknown whether the value is a tensor or not, ie the `kind` field @@ -383,6 +418,12 @@ class TypeAnalysis : public dataflow::SparseDataFlowAnalysis< void visitOperation(Operation *op, ArrayRef operands, ArrayRef results) final; + void setToEntryState(ValueState *lattice) override { + auto refType = lattice->getPoint().getType(); + auto knowledge = ValueKnowledge::getKnowledgeFromType(refType); + propagateIfChanged(lattice, lattice->join(knowledge)); + } + private: // Get the MLIR type of the tensor dtype given the dtype integer value and the // input dtype. When DType is None the type is inferred from the input dtype. @@ -636,7 +677,7 @@ void TypeAnalysis::visitOperation(Operation *op, // Take dtype from first operand. if (isa( - op)) { + AtenRollOp, AtenPowTensorTensorOp, AtenLiftFreshCopyOp, + AtenIndexTensorHackedTwinOp>(op)) { return incorporateKnowledge(op->getResult(0), operands[0]->getValue()); } // Dtype is always float32, except for bfloat16, float64 and nullptr. if (isa(op)) { + AtenLog1pOp, AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp>(op)) { ValueKnowledge knowledge = ValueKnowledge::getTensorPessimisticValueState(op->getContext()); Type dtype = operands[0]->getValue().dtype; @@ -889,6 +930,10 @@ void TypeAnalysis::visitOperation(Operation *op, if (auto sum = dyn_cast(op)) { Type defaultDtype = operands[0]->getValue().dtype; + // If the input dtype is bool, the result type should be i64. + if (defaultDtype.isInteger(1)) + defaultDtype = + IntegerType::get(op->getContext(), 64, IntegerType::Signed); Type dtype = getDtypeOrDefault(sum.getContext(), sum.dtype(), defaultDtype); auto knowledge = ValueKnowledge::getTensorPessimisticValueState(op->getContext()); @@ -898,6 +943,10 @@ void TypeAnalysis::visitOperation(Operation *op, } if (auto sumDimIntList = dyn_cast(op)) { Type defaultDtype = operands[0]->getValue().dtype; + // If the input dtype is bool, the result type should be i64. + if (defaultDtype.isInteger(1)) + defaultDtype = + IntegerType::get(op->getContext(), 64, IntegerType::Signed); Type dtype = getDtypeOrDefault(sumDimIntList.getContext(), sumDimIntList.dtype(), defaultDtype); visitReductionAlongDimIntListOp(sumDimIntList, sumDimIntList.dim(), @@ -1106,9 +1155,8 @@ void TypeAnalysis::visitOperation(Operation *op, return; } - // Otherwise, this is an unknown operation. Just mark all results as - // having reached a pessimistic fixpoint. - markAllPessimisticFixpoint(results); + // Otherwise, this is an unknown operation, so reset the state. 
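+  // setAllToEntryStates re-applies the setToEntryState override above to each
+  // result lattice, so an unknown op falls back to whatever knowledge its
+  // results' static types provide instead of a blanket pessimistic fixpoint.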
+ setAllToEntryStates(results); return; } @@ -1426,13 +1474,13 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) { }; if (auto tensorType = v.getType().dyn_cast()) { const ValueState *latticeElement = solver.lookupState(v); - if (!latticeElement || latticeElement->isUninitialized()) + if (!latticeElement) return nullptr; const ValueKnowledge &knowledge = latticeElement->getValue(); return getRefinedTensorType(tensorType, knowledge); } else if (auto optionalType = v.getType().dyn_cast()) { const ValueState *latticeElement = solver.lookupState(v); - if (!latticeElement || latticeElement->isUninitialized()) + if (!latticeElement) return nullptr; const ValueKnowledge &knowledge = latticeElement->getValue(); if (knowledge.optional == OptionalKnowledge::isNone) @@ -1446,7 +1494,7 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) { } } else if (auto scalarType = v.getType().dyn_cast()) { const ValueState *latticeElement = solver.lookupState(v); - if (!latticeElement || latticeElement->isUninitialized()) + if (!latticeElement) return nullptr; const ValueKnowledge &knowledge = latticeElement->getValue(); if (knowledge.kind == torch_upstream::TypeKind::IntType) diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index f7a8f69ca3554..af1823f83ac5e 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -17,6932 +17,7812 @@ using namespace mlir; StringRef mlir::torch::Torch::getShapeLibrary() { -// TODO: Find a way to embed this string nicely. -// It is currently too long, and will probably break MSVC builds if anyone -// attempts that. -// We want to preserve the legibility of the shape library as a checked in file, -// since that is sometimes useful for debugging / diffing. -// Probably the ideal outcome is to have the shape library be a .mlir file -// that is checked in, and then we embed it as part of the build process. 
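+// MSVC does not accept clang-specific pragmas (and caps very long string
+// literals, per the TODO above), so the -Woverlength-strings suppression
+// below is compiled only for non-MSVC toolchains.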
+#ifndef _MSC_VER #pragma clang diagnostic push #pragma clang diagnostic ignored "-Woverlength-strings" - constexpr StringLiteral shapeLib(R"mlir( -module { - func.func @__torch__.torch.jit._shape_functions.unary(%arg0: !torch.list) -> !torch.list { - %true = torch.constant.bool true - %0 = torch.prim.ListConstruct : () -> !torch.list - %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %1, %true, init() { - ^bb0(%arg1: !torch.int): - %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int - %3 = torch.aten.append.t %0, %2 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - return %0 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions._copy(%arg0: !torch.list) -> !torch.list { - %true = torch.constant.bool true - %0 = torch.prim.ListConstruct : () -> !torch.list - %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %1, %true, init() { - ^bb0(%arg1: !torch.int): - %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int - %3 = torch.aten.append.t %0, %2 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - return %0 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.adaptive_avg_pool2d(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int2 = torch.constant.int 2 - %int3 = torch.constant.int 3 - %int4 = torch.constant.int 4 - %int1 = torch.constant.int 1 - %int0 = torch.constant.int 0 - %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %1 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool - %4 = torch.prim.If %3 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %12 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %13 = torch.aten.eq.int %12, %int4 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %13 : !torch.bool - } - torch.prim.If %4 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %6 = torch.aten.__range_length %int1, %5, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int - torch.prim.Loop %6, %true, init() { - ^bb0(%arg2: !torch.int): - %12 = torch.aten.__derive_index %arg2, %int1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int - %13 = torch.aten.__getitem__.t %arg0, %12 : !torch.list, !torch.int -> !torch.int - %14 = torch.aten.ne.int %13, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %14 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %7 = torch.prim.ListConstruct : () -> !torch.list - %8 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %9 = torch.aten.sub.int %8, %int2 : !torch.int, !torch.int -> !torch.int - %10 = torch.aten.__range_length %int0, %9, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int - 
torch.prim.Loop %10, %true, init() { - ^bb0(%arg2: !torch.int): - %12 = torch.aten.__derive_index %arg2, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int - %13 = torch.aten.__getitem__.t %arg0, %12 : !torch.list, !torch.int -> !torch.int - %14 = torch.aten.append.t %7, %13 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %11 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - torch.prim.Loop %11, %true, init() { - ^bb0(%arg2: !torch.int): - %12 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int - %13 = torch.aten.append.t %7, %12 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - return %7 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.zero_dim_tensor(%arg0: !torch.any) -> !torch.list { - %0 = torch.prim.ListConstruct : () -> !torch.list - return %0 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.arange_end(%arg0: !torch.union, %arg1: !torch.any, %arg2: !torch.any, %arg3: !torch.any, %arg4: !torch.any) -> !torch.list { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int0 = torch.constant.int 0 - %0 = torch.operator "aten.ge"(%arg0, %int0) : (!torch.union, !torch.int) -> !torch.bool - torch.prim.If %0 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %1 = torch.operator "aten.ceil.Scalar"(%arg0) : (!torch.union) -> !torch.number - %2 = torch.aten.Int.Scalar %1 : !torch.number -> !torch.int - %3 = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list - return %3 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.arange_start(%arg0: !torch.union, %arg1: !torch.union, %arg2: !torch.any, %arg3: !torch.any, %arg4: !torch.any, %arg5: !torch.any) -> !torch.list { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int0 = torch.constant.int 0 - %0 = torch.operator "aten.ge"(%arg1, %int0) : (!torch.union, !torch.int) -> !torch.bool - torch.prim.If %0 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %1 = torch.operator "aten.ge"(%arg1, %arg0) : (!torch.union, !torch.union) -> !torch.bool - torch.prim.If %1 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %2 = torch.operator "aten.sub"(%arg1, %arg0) : (!torch.union, !torch.union) -> !torch.number - %3 = torch.operator "aten.ceil.Scalar"(%2) : (!torch.number) -> !torch.number - %4 = torch.aten.Int.Scalar %3 : !torch.number -> !torch.int - %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list - return %5 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.arange_start_step(%arg0: !torch.union, %arg1: !torch.union, %arg2: !torch.union, %arg3: !torch.any, %arg4: !torch.any, %arg5: !torch.any, %arg6: !torch.any) -> !torch.list { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int0 = torch.constant.int 0 - %0 = torch.operator "aten.ne"(%arg2, %int0) : (!torch.union, !torch.int) -> !torch.bool - torch.prim.If %0 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %1 = torch.operator "aten.lt"(%arg2, %int0) : (!torch.union, !torch.int) -> !torch.bool - 
torch.prim.If %1 -> () { - %6 = torch.operator "aten.ge"(%arg0, %arg1) : (!torch.union, !torch.union) -> !torch.bool - torch.prim.If %6 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - torch.prim.If.yield - } else { - %6 = torch.operator "aten.ge"(%arg1, %arg0) : (!torch.union, !torch.union) -> !torch.bool - torch.prim.If %6 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - torch.prim.If.yield - } - %2 = torch.operator "aten.sub"(%arg1, %arg0) : (!torch.union, !torch.union) -> !torch.number - %3 = torch.aten.div %2, %arg2 : !torch.number, !torch.union -> !torch.float - %4 = torch.aten.ceil.float %3 : !torch.float -> !torch.int - %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list - return %5 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.squeeze_nodim(%arg0: !torch.list) -> !torch.list { - %true = torch.constant.bool true - %int1 = torch.constant.int 1 - %0 = torch.prim.ListConstruct : () -> !torch.list - %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %1, %true, init() { - ^bb0(%arg1: !torch.int): - %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int - %3 = torch.aten.ne.int %2, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %3 -> () { - %4 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int - %5 = torch.aten.append.t %0, %4 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } else { - torch.prim.If.yield - } - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - return %0 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.squeeze(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int0 = torch.constant.int 0 - %true = torch.constant.bool true - %int1 = torch.constant.int 1 - %0 = torch.prim.ListConstruct : () -> !torch.list - %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %2 = torch.aten.le.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - %3 = torch.prim.If %2 -> (!torch.int) { - torch.prim.If.yield %int1 : !torch.int - } else { - torch.prim.If.yield %1 : !torch.int - } - %4 = torch.aten.neg.int %3 : !torch.int -> !torch.int - %5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int - %6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool - %7 = torch.prim.If %6 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %12 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %12 : !torch.bool - } - %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool - torch.prim.If %8 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool - %10 = torch.prim.If %9 -> (!torch.int) { - %12 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %12 : !torch.int - } else { - torch.prim.If.yield %arg1 : !torch.int - } - %11 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %11, %true, init() { - ^bb0(%arg2: !torch.int): - %12 = torch.aten.eq.int %arg2, %10 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %12 -> () { - %13 = torch.aten.__getitem__.t 
%arg0, %arg2 : !torch.list, !torch.int -> !torch.int - %14 = torch.aten.ne.int %13, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %14 -> () { - %15 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int - %16 = torch.aten.append.t %0, %15 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } else { - torch.prim.If.yield - } - torch.prim.If.yield - } else { - %13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int - %14 = torch.aten.append.t %0, %13 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - return %0 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int { - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %0 = torch.aten.le.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool - %1 = torch.prim.If %0 -> (!torch.int) { - torch.prim.If %arg2 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - torch.prim.If.yield %int1 : !torch.int - } else { - torch.prim.If.yield %arg1 : !torch.int - } - %2 = torch.aten.neg.int %1 : !torch.int -> !torch.int - %3 = torch.aten.sub.int %1, %int1 : !torch.int, !torch.int -> !torch.int - %4 = torch.aten.lt.int %arg0, %2 : !torch.int, !torch.int -> !torch.bool - %5 = torch.prim.If %4 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %9 = torch.aten.gt.int %arg0, %3 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %9 : !torch.bool - } - %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool - torch.prim.If %6 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %7 = torch.aten.lt.int %arg0, %int0 : !torch.int, !torch.int -> !torch.bool - %8 = torch.prim.If %7 -> (!torch.int) { - %9 = torch.aten.add.int %arg0, %1 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %9 : !torch.int - } else { - torch.prim.If.yield %arg0 : !torch.int - } - return %8 : !torch.int - } - func.func @__torch__.torch.jit._shape_functions.unsqueeze(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list { - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.aten.add.int %0, %int1 : !torch.int, !torch.int -> !torch.int - %2 = torch.aten.le.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - %3 = torch.prim.If %2 -> (!torch.int) { - torch.prim.If.yield %int1 : !torch.int - } else { - torch.prim.If.yield %1 : !torch.int - } - %4 = torch.aten.neg.int %3 : !torch.int -> !torch.int - %5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int - %6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool - %7 = torch.prim.If %6 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %13 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %13 : !torch.bool - } - %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool - torch.prim.If %8 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, 
%none : !torch.str, !torch.none - torch.prim.If.yield - } - %9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool - %10 = torch.prim.If %9 -> (!torch.int) { - %13 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %13 : !torch.int - } else { - torch.prim.If.yield %arg1 : !torch.int - } - %11 = torch.prim.ListConstruct : () -> !torch.list - %12 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %12, %true, init() { - ^bb0(%arg2: !torch.int): - %13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int - %14 = torch.aten.append.t %11, %13 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.aten.insert.t %11, %10, %int1 : !torch.list, !torch.int, !torch.int - return %11 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.slice(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.int) -> !torch.list { - %int9223372036854775807 = torch.constant.int 9223372036854775807 - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.aten.ne.int %0, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %1 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %2 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool - %3 = torch.prim.If %2 -> (!torch.int) { - torch.prim.If.yield %int1 : !torch.int - } else { - torch.prim.If.yield %0 : !torch.int - } - %4 = torch.aten.neg.int %3 : !torch.int -> !torch.int - %5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int - %6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool - %7 = torch.prim.If %6 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %33 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %33 : !torch.bool - } - %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool - torch.prim.If %8 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool - %10 = torch.prim.If %9 -> (!torch.int) { - %33 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %33 : !torch.int - } else { - torch.prim.If.yield %arg1 : !torch.int - } - %11 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool - %12 = torch.prim.If %11 -> (!torch.int) { - %33 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int - torch.prim.If.yield %33 : !torch.int - } else { - torch.prim.If.yield %int0 : !torch.int - } - %13 = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool - %14 = torch.prim.If %13 -> (!torch.int) { - %33 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int - torch.prim.If.yield %33 : !torch.int - } else { - torch.prim.If.yield %int9223372036854775807 : !torch.int - } - %15 = torch.aten.gt.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %15 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - 
torch.prim.If.yield - } - %16 = torch.aten.eq.int %12, %int9223372036854775807 : !torch.int, !torch.int -> !torch.bool - %17 = torch.prim.If %16 -> (!torch.int) { - torch.prim.If.yield %int0 : !torch.int - } else { - torch.prim.If.yield %12 : !torch.int - } - %18 = torch.aten.lt.int %17, %int0 : !torch.int, !torch.int -> !torch.bool - %19 = torch.prim.If %18 -> (!torch.int) { - %33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int - %34 = torch.aten.add.int %17, %33 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %34 : !torch.int - } else { - torch.prim.If.yield %17 : !torch.int - } - %20 = torch.aten.lt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool - %21 = torch.prim.If %20 -> (!torch.int) { - %33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int - %34 = torch.aten.add.int %14, %33 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %34 : !torch.int - } else { - torch.prim.If.yield %14 : !torch.int - } - %22 = torch.aten.lt.int %19, %int0 : !torch.int, !torch.int -> !torch.bool - %23 = torch.prim.If %22 -> (!torch.int) { - torch.prim.If.yield %int0 : !torch.int - } else { - %33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int - %34 = torch.aten.gt.int %19, %33 : !torch.int, !torch.int -> !torch.bool - %35 = torch.prim.If %34 -> (!torch.int) { - %36 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %36 : !torch.int - } else { - torch.prim.If.yield %19 : !torch.int - } - torch.prim.If.yield %35 : !torch.int - } - %24 = torch.aten.lt.int %21, %23 : !torch.int, !torch.int -> !torch.bool - %25 = torch.prim.If %24 -> (!torch.int) { - torch.prim.If.yield %23 : !torch.int - } else { - %33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int - %34 = torch.aten.ge.int %21, %33 : !torch.int, !torch.int -> !torch.bool - %35 = torch.prim.If %34 -> (!torch.int) { - %36 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %36 : !torch.int - } else { - torch.prim.If.yield %21 : !torch.int - } - torch.prim.If.yield %35 : !torch.int - } - %26 = torch.aten.sub.int %25, %23 : !torch.int, !torch.int -> !torch.int - %27 = torch.prim.ListConstruct : () -> !torch.list - %28 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %28, %true, init() { - ^bb0(%arg5: !torch.int): - %33 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list, !torch.int -> !torch.int - %34 = torch.aten.append.t %27, %33 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %29 = torch.aten.add.int %26, %arg4 : !torch.int, !torch.int -> !torch.int - %30 = torch.aten.sub.int %29, %int1 : !torch.int, !torch.int -> !torch.int - %31 = torch.aten.floordiv.int %30, %arg4 : !torch.int, !torch.int -> !torch.int - %32 = torch.aten._set_item.t %27, %10, %31 : !torch.list, !torch.int, !torch.int -> !torch.list - return %27 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.max_int() -> !torch.int { - %int9223372036854775807 = torch.constant.int 9223372036854775807 - return %int9223372036854775807 : !torch.int - } - func.func @__torch__.torch.jit._shape_functions.select(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list { - %int1 = torch.constant.int 1 - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int0 = torch.constant.int 0 - 
%0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.aten.ne.int %0, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %1 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %2 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool - %3 = torch.prim.If %2 -> (!torch.int) { - torch.prim.If.yield %int1 : !torch.int - } else { - torch.prim.If.yield %0 : !torch.int - } - %4 = torch.aten.neg.int %3 : !torch.int -> !torch.int - %5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int - %6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool - %7 = torch.prim.If %6 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %17 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %17 : !torch.bool - } - %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool - torch.prim.If %8 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool - %10 = torch.prim.If %9 -> (!torch.int) { - %17 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %17 : !torch.int - } else { - torch.prim.If.yield %arg1 : !torch.int - } - %11 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int - %12 = torch.aten.neg.int %11 : !torch.int -> !torch.int - %13 = torch.aten.lt.int %arg2, %12 : !torch.int, !torch.int -> !torch.bool - %14 = torch.prim.If %13 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %17 = torch.aten.ge.int %arg2, %11 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %17 : !torch.bool - } - %15 = torch.aten.__not__ %14 : !torch.bool -> !torch.bool - torch.prim.If %15 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %16 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %0, %true, init() { - ^bb0(%arg3: !torch.int): - %17 = torch.aten.ne.int %arg3, %10 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %17 -> () { - %18 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int - %19 = torch.aten.append.t %16, %18 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } else { - torch.prim.If.yield - } - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - return %16 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.index_select(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list) -> !torch.list { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %true = torch.constant.bool true - %int1 = torch.constant.int 1 - %int0 = torch.constant.int 0 - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool - %2 = torch.prim.If %1 -> (!torch.int) { - torch.prim.If.yield %int1 : !torch.int - } else { - torch.prim.If.yield %0 : !torch.int - } - %3 = torch.aten.neg.int %2 : !torch.int -> !torch.int - %4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int - %5 = torch.aten.lt.int %arg1, %3 : !torch.int, !torch.int -> !torch.bool - %6 = torch.prim.If %5 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %18 = torch.aten.gt.int %arg1, %4 : 
!torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %18 : !torch.bool - } - %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool - torch.prim.If %7 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %8 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool - %9 = torch.prim.If %8 -> (!torch.int) { - %18 = torch.aten.add.int %arg1, %2 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %18 : !torch.int - } else { - torch.prim.If.yield %arg1 : !torch.int - } - %10 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %11 = torch.prim.Loop %10, %true, init(%int1) { - ^bb0(%arg3: !torch.int, %arg4: !torch.int): - %18 = torch.aten.__getitem__.t %arg2, %arg3 : !torch.list, !torch.int -> !torch.int - %19 = torch.aten.mul.int %arg4, %18 : !torch.int, !torch.int -> !torch.int - torch.prim.Loop.condition %true, iter(%19 : !torch.int) - } : (!torch.int, !torch.bool, !torch.int) -> !torch.int - %12 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %13 = torch.aten.le.int %12, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %13 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %14 = torch.aten.eq.int %9, %int0 : !torch.int, !torch.int -> !torch.bool - %15 = torch.prim.If %14 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %18 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %19 = torch.aten.lt.int %9, %18 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %19 : !torch.bool - } - torch.prim.If %15 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %16 = torch.prim.ListConstruct : () -> !torch.list - %17 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %17, %true, init() { - ^bb0(%arg3: !torch.int): - %18 = torch.aten.eq.int %9, %arg3 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %18 -> () { - %19 = torch.aten.append.t %16, %11 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } else { - %19 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int - %20 = torch.aten.append.t %16, %19 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - return %16 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.multiply_integers(%arg0: !torch.list) -> !torch.int { - %true = torch.constant.bool true - %int1 = torch.constant.int 1 - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.prim.Loop %0, %true, init(%int1) { - ^bb0(%arg1: !torch.int, %arg2: !torch.int): - %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int - %3 = torch.aten.mul.int %arg2, %2 : !torch.int, !torch.int -> !torch.int - torch.prim.Loop.condition %true, iter(%3 : !torch.int) - } : (!torch.int, !torch.bool, !torch.int) -> !torch.int - return %1 : !torch.int - } - func.func @__torch__.torch.jit._shape_functions.embedding(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.list { - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int2 = torch.constant.int 2 - %int1 = torch.constant.int 1 - %int0 = torch.constant.int 0 - %0 = torch.aten.len.t %arg0 : !torch.list 
-> !torch.int - %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %1 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %3 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool - %4 = torch.prim.If %3 -> (!torch.list) { - %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %6 = torch.aten.le.int %5, %int0 : !torch.int, !torch.int -> !torch.bool - %7 = torch.prim.If %6 -> (!torch.int) { - torch.prim.If.yield %int1 : !torch.int - } else { - torch.prim.If.yield %5 : !torch.int - } - %8 = torch.aten.neg.int %7 : !torch.int -> !torch.int - %9 = torch.aten.sub.int %7, %int1 : !torch.int, !torch.int -> !torch.int - %10 = torch.aten.lt.int %int0, %8 : !torch.int, !torch.int -> !torch.bool - %11 = torch.prim.If %10 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %19 = torch.aten.gt.int %int0, %9 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %19 : !torch.bool - } - %12 = torch.aten.__not__ %11 : !torch.bool -> !torch.bool - torch.prim.If %12 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %13 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %14 = torch.prim.Loop %13, %true, init(%int1) { - ^bb0(%arg5: !torch.int, %arg6: !torch.int): - %19 = torch.aten.__getitem__.t %arg1, %arg5 : !torch.list, !torch.int -> !torch.int - %20 = torch.aten.mul.int %arg6, %19 : !torch.int, !torch.int -> !torch.int - torch.prim.Loop.condition %true, iter(%20 : !torch.int) - } : (!torch.int, !torch.bool, !torch.int) -> !torch.int - %15 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %16 = torch.aten.le.int %15, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %16 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %17 = torch.prim.ListConstruct : () -> !torch.list - %18 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %18, %true, init() { - ^bb0(%arg5: !torch.int): - %19 = torch.aten.eq.int %int0, %arg5 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %19 -> () { - %20 = torch.aten.append.t %17, %14 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } else { - %20 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list, !torch.int -> !torch.int - %21 = torch.aten.append.t %17, %20 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %17 : !torch.list - } else { - %5 = torch.prim.ListConstruct : () -> !torch.list - %6 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - torch.prim.Loop %6, %true, init() { - ^bb0(%arg5: !torch.int): - %9 = torch.aten.__getitem__.t %arg1, %arg5 : !torch.list, !torch.int -> !torch.int - %10 = torch.aten.append.t %5, %9 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %7 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int - %8 = torch.aten.append.t %5, %7 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield %5 : !torch.list - } - return %4 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.mm(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %str = torch.constant.str 
"AssertionError: " - %str_0 = torch.constant.str "AssertionError: mat2 must be a matrix" - %none = torch.constant.none - %str_1 = torch.constant.str "AssertionError: self must be a matrix" - %int2 = torch.constant.int 2 - %int1 = torch.constant.int 1 - %int0 = torch.constant.int 0 - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %1 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %3 = torch.aten.eq.int %2, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %3 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %4 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int - %5 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %6 = torch.aten.eq.int %4, %5 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %6 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %7 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %8 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int - %9 = torch.prim.ListConstruct %7, %8 : (!torch.int, !torch.int) -> !torch.list - return %9 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.dot(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %false = torch.constant.bool false - %int1 = torch.constant.int 1 - %int0 = torch.constant.int 0 - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool - %2 = torch.prim.If %1 -> (!torch.bool) { - %7 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %8 = torch.aten.eq.int %7, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %8 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %2 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %3 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %4 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %5 = torch.aten.eq.int %3, %4 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %5 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %6 = torch.prim.ListConstruct : () -> !torch.list - return %6 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.mv(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %false = torch.constant.bool false - %int2 = torch.constant.int 2 - %int1 = torch.constant.int 1 - %int0 = torch.constant.int 0 - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool - %2 = torch.prim.If %1 -> (!torch.bool) { - %8 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %9 = torch.aten.eq.int %8, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %9 : !torch.bool - } else { - torch.prim.If.yield 
%false : !torch.bool - } - torch.prim.If %2 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %3 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int - %4 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %5 = torch.aten.eq.int %3, %4 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %5 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %6 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %7 = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list - return %7 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.matmul(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %str = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}" - %str_0 = torch.constant.str "AssertionError: mat2 must be a matrix" - %str_1 = torch.constant.str "AssertionError: self must be a matrix" - %str_2 = torch.constant.str "AssertionError: " - %none = torch.constant.none - %str_3 = torch.constant.str "AssertionError: both arguments to matmul need to be at least 1D" - %int-1 = torch.constant.int -1 - %true = torch.constant.bool true - %int-2 = torch.constant.int -2 - %false = torch.constant.bool false - %int1 = torch.constant.int 1 - %int2 = torch.constant.int 2 - %int0 = torch.constant.int 0 - %0 = torch.prim.Uninitialized : !torch.list - %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %3 = torch.aten.eq.int %1, %int1 : !torch.int, !torch.int -> !torch.bool - %4 = torch.prim.If %3 -> (!torch.bool) { - %6 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %6 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %5 = torch.prim.If %4 -> (!torch.list) { - %6 = torch.prim.ListConstruct : () -> !torch.list - %7 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %8 = torch.aten.eq.int %7, %int1 : !torch.int, !torch.int -> !torch.bool - %9 = torch.prim.If %8 -> (!torch.bool) { - %13 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %14 = torch.aten.eq.int %13, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %14 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %9 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %11 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %12 = torch.aten.eq.int %10, %11 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %12 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none - torch.prim.If.yield - } - torch.prim.If.yield %6 : !torch.list - } else { - %6 = torch.aten.eq.int %1, %int2 : !torch.int, !torch.int -> !torch.bool - %7 = torch.prim.If %6 -> (!torch.bool) { - %9 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %9 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %8 = torch.prim.If %7 -> (!torch.list) { - %9 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %10 = torch.aten.eq.int %9, %int2 : !torch.int, !torch.int -> 
!torch.bool - %11 = torch.prim.If %10 -> (!torch.bool) { - %17 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %18 = torch.aten.eq.int %17, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %18 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %11 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %12 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int - %13 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %14 = torch.aten.eq.int %12, %13 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %14 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %15 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %16 = torch.prim.ListConstruct %15 : (!torch.int) -> !torch.list - torch.prim.If.yield %16 : !torch.list - } else { - %9 = torch.aten.eq.int %1, %int1 : !torch.int, !torch.int -> !torch.bool - %10 = torch.prim.If %9 -> (!torch.bool) { - %12 = torch.aten.eq.int %2, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %12 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %11 = torch.prim.If %10 -> (!torch.list) { - %12 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %13 = torch.aten.add.int %12, %int1 : !torch.int, !torch.int -> !torch.int - %14 = torch.aten.le.int %13, %int0 : !torch.int, !torch.int -> !torch.bool - %15 = torch.prim.If %14 -> (!torch.int) { - torch.prim.If.yield %int1 : !torch.int - } else { - torch.prim.If.yield %13 : !torch.int - } - %16 = torch.aten.neg.int %15 : !torch.int -> !torch.int - %17 = torch.aten.sub.int %15, %int1 : !torch.int, !torch.int -> !torch.int - %18 = torch.aten.lt.int %int0, %16 : !torch.int, !torch.int -> !torch.bool - %19 = torch.prim.If %18 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %34 = torch.aten.gt.int %int0, %17 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %34 : !torch.bool - } - %20 = torch.aten.__not__ %19 : !torch.bool -> !torch.bool - torch.prim.If %20 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %21 = torch.prim.ListConstruct : () -> !torch.list - %22 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %22, %true, init() { - ^bb0(%arg2: !torch.int): - %34 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int - %35 = torch.aten.append.t %21, %34 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.aten.insert.t %21, %int0, %int1 : !torch.list, !torch.int, !torch.int - %23 = torch.aten.len.t %21 : !torch.list -> !torch.int - %24 = torch.aten.eq.int %23, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %24 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %25 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %26 = torch.aten.eq.int %25, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %26 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %27 = torch.aten.__getitem__.t %21, %int1 : !torch.list, !torch.int -> !torch.int - 
%28 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %29 = torch.aten.eq.int %27, %28 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %29 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %30 = torch.aten.__getitem__.t %21, %int0 : !torch.list, !torch.int -> !torch.int - %31 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int - %32 = torch.prim.ListConstruct %30, %31 : (!torch.int, !torch.int) -> !torch.list - %33 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %int2, %true, init() { - ^bb0(%arg2: !torch.int): - %34 = torch.aten.eq.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %34 -> () { - %35 = torch.aten.__getitem__.t %32, %arg2 : !torch.list, !torch.int -> !torch.int - %36 = torch.aten.ne.int %35, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %36 -> () { - %37 = torch.aten.__getitem__.t %32, %arg2 : !torch.list, !torch.int -> !torch.int - %38 = torch.aten.append.t %33, %37 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } else { - torch.prim.If.yield - } - torch.prim.If.yield - } else { - %35 = torch.aten.__getitem__.t %32, %arg2 : !torch.list, !torch.int -> !torch.int - %36 = torch.aten.append.t %33, %35 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %33 : !torch.list - } else { - %12 = torch.aten.eq.int %1, %int2 : !torch.int, !torch.int -> !torch.bool - %13 = torch.prim.If %12 -> (!torch.bool) { - %15 = torch.aten.eq.int %2, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %15 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %14 = torch.prim.If %13 -> (!torch.list) { - %15 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %16 = torch.aten.eq.int %15, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %16 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %17 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %18 = torch.aten.eq.int %17, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %18 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %19 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int - %20 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %21 = torch.aten.eq.int %19, %20 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %21 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %22 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %23 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int - %24 = torch.prim.ListConstruct %22, %23 : (!torch.int, !torch.int) -> !torch.list - torch.prim.If.yield %24 : !torch.list - } else { - %15 = torch.aten.ge.int %1, %int1 : !torch.int, !torch.int -> !torch.bool - %16 = torch.prim.If %15 -> (!torch.bool) { - %18 = torch.aten.ge.int %2, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %18 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %17 = torch.prim.If %16 -> (!torch.list) { - %18 = torch.aten.gt.int 
%1, %int1 : !torch.int, !torch.int -> !torch.bool - %19 = torch.prim.If %18 -> (!torch.int) { - %31 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %31 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %20 = torch.prim.ListConstruct : () -> !torch.list - %21 = torch.aten.sub.int %1, %int2 : !torch.int, !torch.int -> !torch.int - torch.prim.Loop %21, %true, init() { - ^bb0(%arg2: !torch.int): - %31 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int - %32 = torch.aten.append.t %20, %31 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %22 = torch.aten.__getitem__.t %arg1, %int-1 : !torch.list, !torch.int -> !torch.int - %23 = torch.prim.ListConstruct : () -> !torch.list - %24 = torch.aten.sub.int %2, %int2 : !torch.int, !torch.int -> !torch.int - torch.prim.Loop %24, %true, init() { - ^bb0(%arg2: !torch.int): - %31 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int - %32 = torch.aten.append.t %23, %31 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %25 = torch.aten.len.t %20 : !torch.list -> !torch.int - %26 = torch.aten.len.t %23 : !torch.list -> !torch.int - %27 = torch.prim.max.int %25, %26 : !torch.int, !torch.int -> !torch.int - %28 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %27, %true, init() { - ^bb0(%arg2: !torch.int): - %31 = torch.aten.sub.int %27, %int1 : !torch.int, !torch.int -> !torch.int - %32 = torch.aten.sub.int %31, %arg2 : !torch.int, !torch.int -> !torch.int - %33 = torch.aten.sub.int %25, %int1 : !torch.int, !torch.int -> !torch.int - %34 = torch.aten.sub.int %33, %32 : !torch.int, !torch.int -> !torch.int - %35 = torch.aten.sub.int %26, %int1 : !torch.int, !torch.int -> !torch.int - %36 = torch.aten.sub.int %35, %32 : !torch.int, !torch.int -> !torch.int - %37 = torch.aten.ge.int %34, %int0 : !torch.int, !torch.int -> !torch.bool - %38 = torch.prim.If %37 -> (!torch.int) { - %47 = torch.aten.__getitem__.t %20, %34 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %47 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %39 = torch.aten.ge.int %36, %int0 : !torch.int, !torch.int -> !torch.bool - %40 = torch.prim.If %39 -> (!torch.int) { - %47 = torch.aten.__getitem__.t %23, %36 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %47 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %41 = torch.aten.ne.int %38, %40 : !torch.int, !torch.int -> !torch.bool - %42 = torch.prim.If %41 -> (!torch.bool) { - %47 = torch.aten.ne.int %38, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %47 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %43 = torch.prim.If %42 -> (!torch.bool) { - %47 = torch.aten.ne.int %40, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %47 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %43 -> () { - %47 = torch.aten.format(%str, %38, %40, %arg2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str - %48 = torch.aten.add.str %str_2, %47 : !torch.str, !torch.str -> !torch.str - torch.prim.RaiseException %48, %none : !torch.str, !torch.none - torch.prim.If.yield - } else { - torch.prim.If.yield - } - %44 = torch.aten.eq.int %38, %int1 : !torch.int, !torch.int -> !torch.bool - %45 = 
torch.prim.If %44 -> (!torch.int) { - torch.prim.If.yield %40 : !torch.int - } else { - torch.prim.If.yield %38 : !torch.int - } - %46 = torch.aten.append.t %28, %45 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %29 = torch.aten.gt.int %1, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %29 -> () { - %31 = torch.aten.append.t %28, %19 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } else { - torch.prim.If.yield - } - %30 = torch.aten.gt.int %2, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %30 -> () { - %31 = torch.aten.append.t %28, %22 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } else { - torch.prim.If.yield - } - torch.prim.If.yield %28 : !torch.list - } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none - torch.prim.If.yield %0 : !torch.list - } - torch.prim.If.yield %17 : !torch.list - } - torch.prim.If.yield %14 : !torch.list - } - torch.prim.If.yield %11 : !torch.list - } - torch.prim.If.yield %8 : !torch.list - } - return %5 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.broadcast(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %str_0 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}" - %false = torch.constant.bool false - %true = torch.constant.bool true - %int1 = torch.constant.int 1 - %int0 = torch.constant.int 0 - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %2 = torch.prim.max.int %0, %1 : !torch.int, !torch.int -> !torch.int - %3 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %2, %true, init() { - ^bb0(%arg2: !torch.int): - %4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int - %5 = torch.aten.sub.int %4, %arg2 : !torch.int, !torch.int -> !torch.int - %6 = torch.aten.sub.int %0, %int1 : !torch.int, !torch.int -> !torch.int - %7 = torch.aten.sub.int %6, %5 : !torch.int, !torch.int -> !torch.int - %8 = torch.aten.sub.int %1, %int1 : !torch.int, !torch.int -> !torch.int - %9 = torch.aten.sub.int %8, %5 : !torch.int, !torch.int -> !torch.int - %10 = torch.aten.ge.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - %11 = torch.prim.If %10 -> (!torch.int) { - %20 = torch.aten.__getitem__.t %arg0, %7 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %20 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %12 = torch.aten.ge.int %9, %int0 : !torch.int, !torch.int -> !torch.bool - %13 = torch.prim.If %12 -> (!torch.int) { - %20 = torch.aten.__getitem__.t %arg1, %9 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %20 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %14 = torch.aten.ne.int %11, %13 : !torch.int, !torch.int -> !torch.bool - %15 = torch.prim.If %14 -> (!torch.bool) { - %20 = torch.aten.ne.int %11, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %20 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %16 = torch.prim.If %15 -> (!torch.bool) { - %20 = torch.aten.ne.int %13, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %20 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %16 -> () { - %20 = torch.aten.format(%str_0, %11, %13, %arg2) : !torch.str, !torch.int, 
!torch.int, !torch.int -> !torch.str - %21 = torch.aten.add.str %str, %20 : !torch.str, !torch.str -> !torch.str - torch.prim.RaiseException %21, %none : !torch.str, !torch.none - torch.prim.If.yield - } else { - torch.prim.If.yield - } - %17 = torch.aten.eq.int %11, %int1 : !torch.int, !torch.int -> !torch.bool - %18 = torch.prim.If %17 -> (!torch.int) { - torch.prim.If.yield %13 : !torch.int - } else { - torch.prim.If.yield %11 : !torch.int - } - %19 = torch.aten.append.t %3, %18 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - return %3 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.linear(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>) -> !torch.list { - %str = torch.constant.str "AssertionError: both arguments to matmul need to be at least 1D" - %int-1 = torch.constant.int -1 - %true = torch.constant.bool true - %int-2 = torch.constant.int -2 - %false = torch.constant.bool false - %str_0 = torch.constant.str "AssertionError: self must be a matrix" - %str_1 = torch.constant.str "AssertionError: mat2 must be a matrix" - %str_2 = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}" - %str_3 = torch.constant.str "AssertionError: " - %none = torch.constant.none - %int1 = torch.constant.int 1 - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 - %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %1 = torch.aten.le.int %0, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %1 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %3 = torch.aten.eq.int %2, %int0 : !torch.int, !torch.int -> !torch.bool - %4 = torch.prim.If %3 -> (!torch.list) { - %13 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.If.yield %13 : !torch.list - } else { - %13 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool - %14 = torch.prim.If %13 -> (!torch.list) { - %15 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %16 = torch.prim.ListConstruct %15 : (!torch.int) -> !torch.list - torch.prim.If.yield %16 : !torch.list - } else { - %15 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int - %16 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %17 = torch.prim.ListConstruct %15, %16 : (!torch.int, !torch.int) -> !torch.list - torch.prim.If.yield %17 : !torch.list - } - torch.prim.If.yield %14 : !torch.list - } - %5 = torch.prim.ListConstruct : () -> !torch.list - %6 = torch.prim.Uninitialized : !torch.list - %7 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %8 = torch.aten.len.t %4 : !torch.list -> !torch.int - %9 = torch.aten.eq.int %7, %int1 : !torch.int, !torch.int -> !torch.bool - %10 = torch.prim.If %9 -> (!torch.bool) { - %13 = torch.aten.eq.int %8, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %13 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %11 = torch.prim.If %10 -> (!torch.list) { - %13 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %14 = torch.aten.eq.int %13, %int1 : !torch.int, !torch.int -> !torch.bool - %15 = torch.prim.If %14 -> (!torch.bool) { - %19 = torch.aten.len.t %4 : !torch.list -> !torch.int - %20 = torch.aten.eq.int %19, %int1 : !torch.int, !torch.int -> !torch.bool - 
torch.prim.If.yield %20 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %15 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %16 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %17 = torch.aten.__getitem__.t %4, %int0 : !torch.list, !torch.int -> !torch.int - %18 = torch.aten.eq.int %16, %17 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %18 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none - torch.prim.If.yield - } - torch.prim.If.yield %5 : !torch.list - } else { - %13 = torch.aten.eq.int %7, %int2 : !torch.int, !torch.int -> !torch.bool - %14 = torch.prim.If %13 -> (!torch.bool) { - %16 = torch.aten.eq.int %8, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %16 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %15 = torch.prim.If %14 -> (!torch.list) { - %16 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %17 = torch.aten.eq.int %16, %int2 : !torch.int, !torch.int -> !torch.bool - %18 = torch.prim.If %17 -> (!torch.bool) { - %24 = torch.aten.len.t %4 : !torch.list -> !torch.int - %25 = torch.aten.eq.int %24, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %25 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %18 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %19 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int - %20 = torch.aten.__getitem__.t %4, %int0 : !torch.list, !torch.int -> !torch.int - %21 = torch.aten.eq.int %19, %20 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %21 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %22 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %23 = torch.prim.ListConstruct %22 : (!torch.int) -> !torch.list - torch.prim.If.yield %23 : !torch.list - } else { - %16 = torch.aten.eq.int %7, %int1 : !torch.int, !torch.int -> !torch.bool - %17 = torch.prim.If %16 -> (!torch.bool) { - %19 = torch.aten.eq.int %8, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %19 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %18 = torch.prim.If %17 -> (!torch.list) { - %19 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %20 = torch.aten.add.int %19, %int1 : !torch.int, !torch.int -> !torch.int - %21 = torch.aten.le.int %20, %int0 : !torch.int, !torch.int -> !torch.bool - %22 = torch.prim.If %21 -> (!torch.int) { - torch.prim.If.yield %int1 : !torch.int - } else { - torch.prim.If.yield %20 : !torch.int - } - %23 = torch.aten.neg.int %22 : !torch.int -> !torch.int - %24 = torch.aten.sub.int %22, %int1 : !torch.int, !torch.int -> !torch.int - %25 = torch.aten.lt.int %int0, %23 : !torch.int, !torch.int -> !torch.bool - %26 = torch.prim.If %25 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %41 = torch.aten.gt.int %int0, %24 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %41 : !torch.bool - } - %27 = torch.aten.__not__ %26 : !torch.bool -> !torch.bool - torch.prim.If %27 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %28 
= torch.prim.ListConstruct : () -> !torch.list - %29 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %29, %true, init() { - ^bb0(%arg3: !torch.int): - %41 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int - %42 = torch.aten.append.t %28, %41 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.aten.insert.t %28, %int0, %int1 : !torch.list, !torch.int, !torch.int - %30 = torch.aten.len.t %28 : !torch.list -> !torch.int - %31 = torch.aten.eq.int %30, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %31 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %32 = torch.aten.len.t %4 : !torch.list -> !torch.int - %33 = torch.aten.eq.int %32, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %33 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %34 = torch.aten.__getitem__.t %28, %int1 : !torch.list, !torch.int -> !torch.int - %35 = torch.aten.__getitem__.t %4, %int0 : !torch.list, !torch.int -> !torch.int - %36 = torch.aten.eq.int %34, %35 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %36 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %37 = torch.aten.__getitem__.t %28, %int0 : !torch.list, !torch.int -> !torch.int - %38 = torch.aten.__getitem__.t %4, %int1 : !torch.list, !torch.int -> !torch.int - %39 = torch.prim.ListConstruct %37, %38 : (!torch.int, !torch.int) -> !torch.list - %40 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %int2, %true, init() { - ^bb0(%arg3: !torch.int): - %41 = torch.aten.eq.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %41 -> () { - %42 = torch.aten.__getitem__.t %39, %arg3 : !torch.list, !torch.int -> !torch.int - %43 = torch.aten.ne.int %42, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %43 -> () { - %44 = torch.aten.__getitem__.t %39, %arg3 : !torch.list, !torch.int -> !torch.int - %45 = torch.aten.append.t %40, %44 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } else { - torch.prim.If.yield - } - torch.prim.If.yield - } else { - %42 = torch.aten.__getitem__.t %39, %arg3 : !torch.list, !torch.int -> !torch.int - %43 = torch.aten.append.t %40, %42 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %40 : !torch.list - } else { - %19 = torch.aten.eq.int %7, %int2 : !torch.int, !torch.int -> !torch.bool - %20 = torch.prim.If %19 -> (!torch.bool) { - %22 = torch.aten.eq.int %8, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %22 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %21 = torch.prim.If %20 -> (!torch.list) { - %22 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %23 = torch.aten.eq.int %22, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %23 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %24 = torch.aten.len.t %4 : !torch.list -> !torch.int - %25 = torch.aten.eq.int %24, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %25 -> () { - torch.prim.If.yield - } else { - 
torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %26 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int - %27 = torch.aten.__getitem__.t %4, %int0 : !torch.list, !torch.int -> !torch.int - %28 = torch.aten.eq.int %26, %27 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %28 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %29 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %30 = torch.aten.__getitem__.t %4, %int1 : !torch.list, !torch.int -> !torch.int - %31 = torch.prim.ListConstruct %29, %30 : (!torch.int, !torch.int) -> !torch.list - torch.prim.If.yield %31 : !torch.list - } else { - %22 = torch.aten.ge.int %7, %int1 : !torch.int, !torch.int -> !torch.bool - %23 = torch.prim.If %22 -> (!torch.bool) { - %25 = torch.aten.ge.int %8, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %25 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %24 = torch.prim.If %23 -> (!torch.list) { - %25 = torch.aten.gt.int %7, %int1 : !torch.int, !torch.int -> !torch.bool - %26 = torch.prim.If %25 -> (!torch.int) { - %38 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %38 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %27 = torch.prim.ListConstruct : () -> !torch.list - %28 = torch.aten.sub.int %7, %int2 : !torch.int, !torch.int -> !torch.int - torch.prim.Loop %28, %true, init() { - ^bb0(%arg3: !torch.int): - %38 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int - %39 = torch.aten.append.t %27, %38 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %29 = torch.aten.__getitem__.t %4, %int-1 : !torch.list, !torch.int -> !torch.int - %30 = torch.prim.ListConstruct : () -> !torch.list - %31 = torch.aten.sub.int %8, %int2 : !torch.int, !torch.int -> !torch.int - torch.prim.Loop %31, %true, init() { - ^bb0(%arg3: !torch.int): - %38 = torch.aten.__getitem__.t %4, %arg3 : !torch.list, !torch.int -> !torch.int - %39 = torch.aten.append.t %30, %38 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %32 = torch.aten.len.t %27 : !torch.list -> !torch.int - %33 = torch.aten.len.t %30 : !torch.list -> !torch.int - %34 = torch.prim.max.int %32, %33 : !torch.int, !torch.int -> !torch.int - %35 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %34, %true, init() { - ^bb0(%arg3: !torch.int): - %38 = torch.aten.sub.int %34, %int1 : !torch.int, !torch.int -> !torch.int - %39 = torch.aten.sub.int %38, %arg3 : !torch.int, !torch.int -> !torch.int - %40 = torch.aten.sub.int %32, %int1 : !torch.int, !torch.int -> !torch.int - %41 = torch.aten.sub.int %40, %39 : !torch.int, !torch.int -> !torch.int - %42 = torch.aten.sub.int %33, %int1 : !torch.int, !torch.int -> !torch.int - %43 = torch.aten.sub.int %42, %39 : !torch.int, !torch.int -> !torch.int - %44 = torch.aten.ge.int %41, %int0 : !torch.int, !torch.int -> !torch.bool - %45 = torch.prim.If %44 -> (!torch.int) { - %54 = torch.aten.__getitem__.t %27, %41 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %54 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %46 = torch.aten.ge.int %43, %int0 : !torch.int, !torch.int -> !torch.bool - %47 = 
torch.prim.If %46 -> (!torch.int) { - %54 = torch.aten.__getitem__.t %30, %43 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %54 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %48 = torch.aten.ne.int %45, %47 : !torch.int, !torch.int -> !torch.bool - %49 = torch.prim.If %48 -> (!torch.bool) { - %54 = torch.aten.ne.int %45, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %54 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %50 = torch.prim.If %49 -> (!torch.bool) { - %54 = torch.aten.ne.int %47, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %54 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %50 -> () { - %54 = torch.aten.format(%str_2, %45, %47, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str - %55 = torch.aten.add.str %str_3, %54 : !torch.str, !torch.str -> !torch.str - torch.prim.RaiseException %55, %none : !torch.str, !torch.none - torch.prim.If.yield - } else { - torch.prim.If.yield - } - %51 = torch.aten.eq.int %45, %int1 : !torch.int, !torch.int -> !torch.bool - %52 = torch.prim.If %51 -> (!torch.int) { - torch.prim.If.yield %47 : !torch.int - } else { - torch.prim.If.yield %45 : !torch.int - } - %53 = torch.aten.append.t %35, %52 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %36 = torch.aten.gt.int %7, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %36 -> () { - %38 = torch.aten.append.t %35, %26 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } else { - torch.prim.If.yield - } - %37 = torch.aten.gt.int %8, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %37 -> () { - %38 = torch.aten.append.t %35, %29 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } else { - torch.prim.If.yield - } - torch.prim.If.yield %35 : !torch.list - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield %6 : !torch.list - } - torch.prim.If.yield %24 : !torch.list - } - torch.prim.If.yield %21 : !torch.list - } - torch.prim.If.yield %18 : !torch.list - } - torch.prim.If.yield %15 : !torch.list - } - %12 = torch.aten.__isnot__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool - torch.prim.If %12 -> () { - %13 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list - %14 = torch.aten.len.t %13 : !torch.list -> !torch.int - %15 = torch.aten.len.t %11 : !torch.list -> !torch.int - %16 = torch.prim.max.int %14, %15 : !torch.int, !torch.int -> !torch.int - %17 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %16, %true, init() { - ^bb0(%arg3: !torch.int): - %19 = torch.aten.sub.int %16, %int1 : !torch.int, !torch.int -> !torch.int - %20 = torch.aten.sub.int %19, %arg3 : !torch.int, !torch.int -> !torch.int - %21 = torch.aten.sub.int %14, %int1 : !torch.int, !torch.int -> !torch.int - %22 = torch.aten.sub.int %21, %20 : !torch.int, !torch.int -> !torch.int - %23 = torch.aten.sub.int %15, %int1 : !torch.int, !torch.int -> !torch.int - %24 = torch.aten.sub.int %23, %20 : !torch.int, !torch.int -> !torch.int - %25 = torch.aten.ge.int %22, %int0 : !torch.int, !torch.int -> !torch.bool - %26 = torch.prim.If %25 -> (!torch.int) { - %35 = torch.aten.__getitem__.t %13, %22 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %35 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %27 = torch.aten.ge.int %24, %int0 
: !torch.int, !torch.int -> !torch.bool - %28 = torch.prim.If %27 -> (!torch.int) { - %35 = torch.aten.__getitem__.t %11, %24 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %35 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %29 = torch.aten.ne.int %26, %28 : !torch.int, !torch.int -> !torch.bool - %30 = torch.prim.If %29 -> (!torch.bool) { - %35 = torch.aten.ne.int %26, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %35 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %31 = torch.prim.If %30 -> (!torch.bool) { - %35 = torch.aten.ne.int %28, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %35 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %31 -> () { - %35 = torch.aten.format(%str_2, %26, %28, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str - %36 = torch.aten.add.str %str_3, %35 : !torch.str, !torch.str -> !torch.str - torch.prim.RaiseException %36, %none : !torch.str, !torch.none - torch.prim.If.yield - } else { - torch.prim.If.yield - } - %32 = torch.aten.eq.int %26, %int1 : !torch.int, !torch.int -> !torch.bool - %33 = torch.prim.If %32 -> (!torch.int) { - torch.prim.If.yield %28 : !torch.int - } else { - torch.prim.If.yield %26 : !torch.int - } - %34 = torch.aten.append.t %17, %33 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %18 = torch.aten.eq.int_list %17, %11 : !torch.list, !torch.list -> !torch.bool - torch.prim.If %18 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none - torch.prim.If.yield - } - torch.prim.If.yield - } else { - torch.prim.If.yield - } - return %11 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.t(%arg0: !torch.list) -> !torch.list { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int2 = torch.constant.int 2 - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.aten.le.int %0, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %1 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %3 = torch.aten.eq.int %2, %int0 : !torch.int, !torch.int -> !torch.bool - %4 = torch.prim.If %3 -> (!torch.list) { - %5 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.If.yield %5 : !torch.list - } else { - %5 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool - %6 = torch.prim.If %5 -> (!torch.list) { - %7 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %8 = torch.prim.ListConstruct %7 : (!torch.int) -> !torch.list - torch.prim.If.yield %8 : !torch.list - } else { - %7 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int - %8 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %9 = torch.prim.ListConstruct %7, %8 : (!torch.int, !torch.int) -> !torch.list - torch.prim.If.yield %9 : !torch.list - } - torch.prim.If.yield %6 : !torch.list - } - return %4 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.max_pool2d(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.list { - %false = torch.constant.bool 
false - %str = torch.constant.str "AssertionError: stride should not be zero" - %int-1 = torch.constant.int -1 - %int-2 = torch.constant.int -2 - %int-3 = torch.constant.int -3 - %int-4 = torch.constant.int -4 - %str_0 = torch.constant.str "AssertionError: " - %str_1 = torch.constant.str "AssertionError: max_pool2d: dilation must be either a single int, or a tuple of two ints" - %str_2 = torch.constant.str "AssertionError: max_pool2d: padding must be either be a single int, or a tuple of two ints" - %str_3 = torch.constant.str "AssertionError: max_pool2d: stride must either be omitted, a single int, or a tuple of two ints" - %none = torch.constant.none - %str_4 = torch.constant.str "AssertionError: max_pool2d: kernel_size must either be a single int, or a tuple of two ints" - %true = torch.constant.bool true - %int1 = torch.constant.int 1 - %int2 = torch.constant.int 2 - %int0 = torch.constant.int 0 - %int3 = torch.constant.int 3 - %int4 = torch.constant.int 4 - %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool - %2 = torch.prim.If %1 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %86 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %87 = torch.aten.eq.int %86, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %87 : !torch.bool - } - torch.prim.If %2 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %3 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %4 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %5 = torch.aten.eq.int %4, %int1 : !torch.int, !torch.int -> !torch.bool - %6 = torch.prim.If %5 -> (!torch.int) { - torch.prim.If.yield %3 : !torch.int - } else { - %86 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %86 : !torch.int - } - %7 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - %9 = torch.prim.If %8 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %86 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %87 = torch.aten.eq.int %86, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %87 : !torch.bool - } - %10 = torch.prim.If %9 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %86 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %87 = torch.aten.eq.int %86, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %87 : !torch.bool - } - torch.prim.If %10 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %11 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %12 = torch.aten.eq.int %11, %int0 : !torch.int, !torch.int -> !torch.bool - %13 = torch.prim.If %12 -> (!torch.int) { - torch.prim.If.yield %3 : !torch.int - } else { - %86 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %86 : !torch.int - } - %14 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %15 = torch.aten.eq.int %14, %int0 : !torch.int, !torch.int -> !torch.bool - %16 = torch.prim.If %15 -> (!torch.int) { - torch.prim.If.yield %6 : !torch.int - } else { - %86 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %87 = torch.aten.eq.int %86, %int1 : !torch.int, !torch.int -> !torch.bool - %88 =
torch.prim.If %87 -> (!torch.int) { - torch.prim.If.yield %13 : !torch.int - } else { - %89 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %89 : !torch.int - } - torch.prim.If.yield %88 : !torch.int - } - %17 = torch.aten.len.t %arg3 : !torch.list -> !torch.int - %18 = torch.aten.eq.int %17, %int1 : !torch.int, !torch.int -> !torch.bool - %19 = torch.prim.If %18 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %86 = torch.aten.len.t %arg3 : !torch.list -> !torch.int - %87 = torch.aten.eq.int %86, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %87 : !torch.bool - } - torch.prim.If %19 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %20 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int - %21 = torch.aten.len.t %arg3 : !torch.list -> !torch.int - %22 = torch.aten.eq.int %21, %int1 : !torch.int, !torch.int -> !torch.bool - %23 = torch.prim.If %22 -> (!torch.int) { - torch.prim.If.yield %20 : !torch.int - } else { - %86 = torch.aten.__getitem__.t %arg3, %int1 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %86 : !torch.int - } - %24 = torch.aten.len.t %arg4 : !torch.list -> !torch.int - %25 = torch.aten.eq.int %24, %int1 : !torch.int, !torch.int -> !torch.bool - %26 = torch.prim.If %25 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %86 = torch.aten.len.t %arg4 : !torch.list -> !torch.int - %87 = torch.aten.eq.int %86, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %87 : !torch.bool - } - torch.prim.If %26 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %27 = torch.aten.__getitem__.t %arg4, %int0 : !torch.list, !torch.int -> !torch.int - %28 = torch.aten.len.t %arg4 : !torch.list -> !torch.int - %29 = torch.aten.eq.int %28, %int1 : !torch.int, !torch.int -> !torch.bool - %30 = torch.prim.If %29 -> (!torch.int) { - torch.prim.If.yield %27 : !torch.int - } else { - %86 = torch.aten.__getitem__.t %arg4, %int1 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %86 : !torch.int - } - %31 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %32 = torch.aten.eq.int %31, %int3 : !torch.int, !torch.int -> !torch.bool - %33 = torch.prim.If %32 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %86 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %87 = torch.aten.eq.int %86, %int4 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %87 : !torch.bool - } - torch.prim.If %33 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %34 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %35 = torch.aten.eq.int %34, %int4 : !torch.int, !torch.int -> !torch.bool - %36 = torch.prim.If %35 -> (!torch.int) { - %86 = torch.aten.__getitem__.t %arg0, %int-4 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %86 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %37 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list, !torch.int -> !torch.int - %38 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int - %39 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int - %40 = torch.aten.ne.int %13, %int0 : !torch.int, !torch.int -> !torch.bool 
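The chains of torch.prim.If above are the unrolled argument canonicalization for max_pool2d: a one-element kernel_size, padding, or dilation list is broadcast to both spatial dimensions, and an empty stride list defaults to the kernel size. A minimal Python sketch of the same rule (the helper names here are illustrative, not part of this file):

    # Illustrative sketch of the canonicalization the IR unrolls above.
    def canonicalize_pool_params(kernel_size, stride, padding, dilation):
        def unpack2(xs, default=None):
            # [] -> default, [a] -> (a, a), [a, b] -> (a, b)
            if len(xs) == 0:
                return default
            return (xs[0], xs[0]) if len(xs) == 1 else (xs[0], xs[1])
        k = unpack2(kernel_size)
        s = unpack2(stride, default=k)  # empty stride falls back to kernel
        p = unpack2(padding)
        d = unpack2(dilation)
        return k, s, p, d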
- torch.prim.If %40 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %41 = torch.aten.add.int %38, %20 : !torch.int, !torch.int -> !torch.int - %42 = torch.aten.add.int %41, %20 : !torch.int, !torch.int -> !torch.int - %43 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int - %44 = torch.aten.mul.int %27, %43 : !torch.int, !torch.int -> !torch.int - %45 = torch.aten.sub.int %42, %44 : !torch.int, !torch.int -> !torch.int - %46 = torch.aten.sub.int %45, %int1 : !torch.int, !torch.int -> !torch.int - %47 = torch.prim.If %arg5 -> (!torch.int) { - %86 = torch.aten.sub.int %13, %int1 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %86 : !torch.int - } else { - torch.prim.If.yield %int0 : !torch.int - } - %48 = torch.aten.add.int %46, %47 : !torch.int, !torch.int -> !torch.int - %49 = torch.aten.floordiv.int %48, %13 : !torch.int, !torch.int -> !torch.int - %50 = torch.aten.add.int %49, %int1 : !torch.int, !torch.int -> !torch.int - %51 = torch.prim.If %arg5 -> (!torch.int) { - %86 = torch.aten.mul.int %49, %13 : !torch.int, !torch.int -> !torch.int - %87 = torch.aten.add.int %38, %20 : !torch.int, !torch.int -> !torch.int - %88 = torch.aten.ge.int %86, %87 : !torch.int, !torch.int -> !torch.bool - %89 = torch.prim.If %88 -> (!torch.int) { - torch.prim.If.yield %49 : !torch.int - } else { - torch.prim.If.yield %50 : !torch.int - } - torch.prim.If.yield %89 : !torch.int - } else { - torch.prim.If.yield %50 : !torch.int - } - %52 = torch.aten.ne.int %16, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %52 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %53 = torch.aten.add.int %39, %23 : !torch.int, !torch.int -> !torch.int - %54 = torch.aten.add.int %53, %23 : !torch.int, !torch.int -> !torch.int - %55 = torch.aten.sub.int %6, %int1 : !torch.int, !torch.int -> !torch.int - %56 = torch.aten.mul.int %30, %55 : !torch.int, !torch.int -> !torch.int - %57 = torch.aten.sub.int %54, %56 : !torch.int, !torch.int -> !torch.int - %58 = torch.aten.sub.int %57, %int1 : !torch.int, !torch.int -> !torch.int - %59 = torch.prim.If %arg5 -> (!torch.int) { - %86 = torch.aten.sub.int %16, %int1 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %86 : !torch.int - } else { - torch.prim.If.yield %int0 : !torch.int - } - %60 = torch.aten.add.int %58, %59 : !torch.int, !torch.int -> !torch.int - %61 = torch.aten.floordiv.int %60, %16 : !torch.int, !torch.int -> !torch.int - %62 = torch.aten.add.int %61, %int1 : !torch.int, !torch.int -> !torch.int - %63 = torch.prim.If %arg5 -> (!torch.int) { - %86 = torch.aten.mul.int %61, %16 : !torch.int, !torch.int -> !torch.int - %87 = torch.aten.add.int %39, %23 : !torch.int, !torch.int -> !torch.int - %88 = torch.aten.ge.int %86, %87 : !torch.int, !torch.int -> !torch.bool - %89 = torch.prim.If %88 -> (!torch.int) { - torch.prim.If.yield %61 : !torch.int - } else { - torch.prim.If.yield %62 : !torch.int - } - torch.prim.If.yield %89 : !torch.int - } else { - torch.prim.If.yield %62 : !torch.int - } - %64 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %65 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool - %66 = torch.prim.If %65 -> (!torch.bool) { - %86 = torch.aten.gt.int %3, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %86 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - 
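The block above inlines the standard pooling output-size computation for the height dimension (the width dimension follows, and the same arithmetic is factored out below as pooling_output_shape_pad_lr). A Python reference that closely mirrors torch.jit._shape_functions:

    # Reference pooling output size per spatial dimension.
    def pooling_output_shape_pad_lr(in_size, kernel, pad_l, pad_r,
                                    stride, dilation, ceil_mode):
        out = (in_size + pad_l + pad_r - dilation * (kernel - 1) - 1
               + (stride - 1 if ceil_mode else 0)) // stride + 1
        # In ceil mode, drop the last window if it would start entirely
        # inside the right/bottom padding.
        if ceil_mode and (out - 1) * stride >= in_size + pad_l:
            out -= 1
        return out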
torch.prim.If %66 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %67 = torch.aten.gt.int %16, %int0 : !torch.int, !torch.int -> !torch.bool - %68 = torch.prim.If %67 -> (!torch.bool) { - %86 = torch.aten.gt.int %13, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %86 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %68 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %69 = torch.aten.gt.int %27, %int0 : !torch.int, !torch.int -> !torch.bool - %70 = torch.prim.If %69 -> (!torch.bool) { - %86 = torch.aten.gt.int %30, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %86 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %70 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %71 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int - %72 = torch.aten.ne.int %71, %int0 : !torch.int, !torch.int -> !torch.bool - %73 = torch.prim.If %72 -> (!torch.bool) { - %86 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int - %87 = torch.aten.ne.int %86, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %87 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %74 = torch.aten.eq.int %64, %int3 : !torch.int, !torch.int -> !torch.bool - %75 = torch.prim.If %74 -> (!torch.bool) { - %86 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %87 = torch.aten.ne.int %86, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %87 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %76 = torch.prim.If %75 -> (!torch.bool) { - torch.prim.If.yield %73 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %77 = torch.prim.If %76 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %86 = torch.aten.eq.int %64, %int4 : !torch.int, !torch.int -> !torch.bool - %87 = torch.prim.If %86 -> (!torch.bool) { - torch.prim.If.yield %73 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %88 = torch.prim.If %87 -> (!torch.bool) { - %89 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int - %90 = torch.aten.ne.int %89, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %90 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If.yield %88 : !torch.bool - } - torch.prim.If %77 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %78 = torch.aten.floordiv.int %6, %int2 : !torch.int, !torch.int -> !torch.int - %79 = torch.aten.ge.int %78, %23 : !torch.int, !torch.int -> !torch.bool - %80 = torch.prim.If %79 -> (!torch.bool) { - %86 = torch.aten.floordiv.int %3, %int2 : !torch.int, !torch.int -> !torch.int - %87 = torch.aten.ge.int %86, %20 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %87 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %80 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %81 = torch.aten.ge.int %63, %int1 : !torch.int, !torch.int -> 
!torch.bool - %82 = torch.prim.If %81 -> (!torch.bool) { - %86 = torch.aten.ge.int %51, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %86 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %82 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %83 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %84 = torch.aten.eq.int %83, %int3 : !torch.int, !torch.int -> !torch.bool - %85 = torch.prim.If %84 -> (!torch.list) { - %86 = torch.prim.ListConstruct %37, %51, %63 : (!torch.int, !torch.int, !torch.int) -> !torch.list - torch.prim.If.yield %86 : !torch.list - } else { - %86 = torch.prim.ListConstruct %36, %37, %51, %63 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - torch.prim.If.yield %86 : !torch.list - } - return %85 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.pooling_output_shape(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.bool) -> !torch.int { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: stride should not be zero" - %int0 = torch.constant.int 0 - %0 = torch.aten.ne.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %0 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %1 = call @__torch__.torch.jit._shape_functions.pooling_output_shape_pad_lr(%arg0, %arg1, %arg2, %arg2, %arg3, %arg4, %arg5) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int - return %1 : !torch.int - } - func.func @__torch__.torch.jit._shape_functions.pooling_output_shape_pad_lr(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.bool) -> !torch.int { - %int1 = torch.constant.int 1 - %int0 = torch.constant.int 0 - %0 = torch.aten.add.int %arg0, %arg2 : !torch.int, !torch.int -> !torch.int - %1 = torch.aten.add.int %0, %arg3 : !torch.int, !torch.int -> !torch.int - %2 = torch.aten.sub.int %arg1, %int1 : !torch.int, !torch.int -> !torch.int - %3 = torch.aten.mul.int %arg5, %2 : !torch.int, !torch.int -> !torch.int - %4 = torch.aten.sub.int %1, %3 : !torch.int, !torch.int -> !torch.int - %5 = torch.aten.sub.int %4, %int1 : !torch.int, !torch.int -> !torch.int - %6 = torch.prim.If %arg6 -> (!torch.int) { - %11 = torch.aten.sub.int %arg4, %int1 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %11 : !torch.int - } else { - torch.prim.If.yield %int0 : !torch.int - } - %7 = torch.aten.add.int %5, %6 : !torch.int, !torch.int -> !torch.int - %8 = call @__torch__.torch.jit._shape_functions.div_rtn(%7, %arg4) : (!torch.int, !torch.int) -> !torch.int - %9 = torch.aten.add.int %8, %int1 : !torch.int, !torch.int -> !torch.int - %10 = torch.prim.If %arg6 -> (!torch.int) { - %11 = torch.aten.sub.int %9, %int1 : !torch.int, !torch.int -> !torch.int - %12 = torch.aten.mul.int %11, %arg4 : !torch.int, !torch.int -> !torch.int - %13 = torch.aten.add.int %arg0, %arg2 : !torch.int, !torch.int -> !torch.int - %14 = torch.aten.ge.int %12, %13 : !torch.int, !torch.int -> !torch.bool - %15 = torch.prim.If %14 -> (!torch.int) { - %16 = torch.aten.sub.int %9, %int1 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %16 : !torch.int - } else { - torch.prim.If.yield %9 : !torch.int - } - torch.prim.If.yield %15 :
!torch.int - } else { - torch.prim.If.yield %9 : !torch.int - } - return %10 : !torch.int - } - func.func @__torch__.torch.jit._shape_functions.div_rtn(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int { - %0 = torch.aten.floordiv.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int - return %0 : !torch.int - } - func.func @__torch__.torch.jit._shape_functions.pool2d_shape_check(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.int, %arg8: !torch.int, %arg9: !torch.int, %arg10: !torch.int, %arg11: !torch.int, %arg12: !torch.int, %arg13: !torch.int) -> !torch.none { - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %false = torch.constant.bool false - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %int2 = torch.constant.int 2 - %int3 = torch.constant.int 3 - %int4 = torch.constant.int 4 - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.aten.gt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool - %2 = torch.prim.If %1 -> (!torch.bool) { - %19 = torch.aten.gt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %19 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %2 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %3 = torch.aten.gt.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool - %4 = torch.prim.If %3 -> (!torch.bool) { - %19 = torch.aten.gt.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %19 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %4 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %5 = torch.aten.gt.int %arg7, %int0 : !torch.int, !torch.int -> !torch.bool - %6 = torch.prim.If %5 -> (!torch.bool) { - %19 = torch.aten.gt.int %arg8, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %19 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %6 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %7 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int - %8 = torch.aten.ne.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - %9 = torch.prim.If %8 -> (!torch.bool) { - %19 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int - %20 = torch.aten.ne.int %19, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %20 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %10 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool - %11 = torch.prim.If %10 -> (!torch.bool) { - %19 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %20 = torch.aten.ne.int %19, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %20 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %12 = torch.prim.If %11 -> (!torch.bool) { - torch.prim.If.yield %9 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %13 = torch.prim.If %12 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %19 = torch.aten.eq.int %0, %int4 : !torch.int, !torch.int -> !torch.bool - %20 
= torch.prim.If %19 -> (!torch.bool) { - torch.prim.If.yield %9 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %21 = torch.prim.If %20 -> (!torch.bool) { - %22 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int - %23 = torch.aten.ne.int %22, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %23 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If.yield %21 : !torch.bool - } - torch.prim.If %13 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %14 = torch.aten.floordiv.int %arg2, %int2 : !torch.int, !torch.int -> !torch.int - %15 = torch.aten.ge.int %14, %arg6 : !torch.int, !torch.int -> !torch.bool - %16 = torch.prim.If %15 -> (!torch.bool) { - %19 = torch.aten.floordiv.int %arg1, %int2 : !torch.int, !torch.int -> !torch.int - %20 = torch.aten.ge.int %19, %arg5 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %20 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %16 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %17 = torch.aten.ge.int %arg13, %int1 : !torch.int, !torch.int -> !torch.bool - %18 = torch.prim.If %17 -> (!torch.bool) { - %19 = torch.aten.ge.int %arg12, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %19 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %18 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - return %none : !torch.none - } - func.func @__torch__.torch.jit._shape_functions.max_pool2d_with_indices(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.tuple, list> { - %false = torch.constant.bool false - %str = torch.constant.str "AssertionError: stride should not be zero" - %int4 = torch.constant.int 4 - %int3 = torch.constant.int 3 - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 - %int1 = torch.constant.int 1 - %true = torch.constant.bool true - %str_0 = torch.constant.str "AssertionError: max_pool2d: kernel_size must either be a single int, or a tuple of two ints" - %none = torch.constant.none - %str_1 = torch.constant.str "AssertionError: max_pool2d: stride must either be omitted, a single int, or a tuple of two ints" - %str_2 = torch.constant.str "AssertionError: max_pool2d: padding must be either be a single int, or a tuple of two ints" - %str_3 = torch.constant.str "AssertionError: max_pool2d: dilation must be either a single int, or a tuple of two ints" - %str_4 = torch.constant.str "AssertionError: " - %int-4 = torch.constant.int -4 - %int-3 = torch.constant.int -3 - %int-2 = torch.constant.int -2 - %int-1 = torch.constant.int -1 - %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool - %2 = torch.prim.If %1 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %87 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %88 = torch.aten.eq.int %87, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %88 : !torch.bool - } - torch.prim.If %2 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %3 =
torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %4 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %5 = torch.aten.eq.int %4, %int1 : !torch.int, !torch.int -> !torch.bool - %6 = torch.prim.If %5 -> (!torch.int) { - torch.prim.If.yield %3 : !torch.int - } else { - %87 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %87 : !torch.int - } - %7 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - %9 = torch.prim.If %8 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %87 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %88 = torch.aten.eq.int %87, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %88 : !torch.bool - } - %10 = torch.prim.If %9 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %87 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %88 = torch.aten.eq.int %87, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %88 : !torch.bool - } - torch.prim.If %10 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %11 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %12 = torch.aten.eq.int %11, %int0 : !torch.int, !torch.int -> !torch.bool - %13 = torch.prim.If %12 -> (!torch.int) { - torch.prim.If.yield %3 : !torch.int - } else { - %87 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %87 : !torch.int - } - %14 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %15 = torch.aten.eq.int %14, %int0 : !torch.int, !torch.int -> !torch.bool - %16 = torch.prim.If %15 -> (!torch.int) { - torch.prim.If.yield %6 : !torch.int - } else { - %87 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %88 = torch.aten.eq.int %87, %int1 : !torch.int, !torch.int -> !torch.bool - %89 = torch.prim.If %88 -> (!torch.int) { - torch.prim.If.yield %13 : !torch.int - } else { - %90 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %90 : !torch.int - } - torch.prim.If.yield %89 : !torch.int - } - %17 = torch.aten.len.t %arg3 : !torch.list -> !torch.int - %18 = torch.aten.eq.int %17, %int1 : !torch.int, !torch.int -> !torch.bool - %19 = torch.prim.If %18 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %87 = torch.aten.len.t %arg3 : !torch.list -> !torch.int - %88 = torch.aten.eq.int %87, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %88 : !torch.bool - } - torch.prim.If %19 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %20 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int - %21 = torch.aten.len.t %arg3 : !torch.list -> !torch.int - %22 = torch.aten.eq.int %21, %int1 : !torch.int, !torch.int -> !torch.bool - %23 = torch.prim.If %22 -> (!torch.int) { - torch.prim.If.yield %20 : !torch.int - } else { - %87 = torch.aten.__getitem__.t %arg3, %int1 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %87 : !torch.int - } - %24 = torch.aten.len.t %arg4 : !torch.list -> !torch.int - %25 = torch.aten.eq.int %24, %int1 : !torch.int, !torch.int -> !torch.bool - %26 = torch.prim.If %25 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %87 = torch.aten.len.t %arg4 : !torch.list 
-> !torch.int - %88 = torch.aten.eq.int %87, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %88 : !torch.bool - } - torch.prim.If %26 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %27 = torch.aten.__getitem__.t %arg4, %int0 : !torch.list, !torch.int -> !torch.int - %28 = torch.aten.len.t %arg4 : !torch.list -> !torch.int - %29 = torch.aten.eq.int %28, %int1 : !torch.int, !torch.int -> !torch.bool - %30 = torch.prim.If %29 -> (!torch.int) { - torch.prim.If.yield %27 : !torch.int - } else { - %87 = torch.aten.__getitem__.t %arg4, %int1 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %87 : !torch.int - } - %31 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %32 = torch.aten.eq.int %31, %int3 : !torch.int, !torch.int -> !torch.bool - %33 = torch.prim.If %32 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %87 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %88 = torch.aten.eq.int %87, %int4 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %88 : !torch.bool - } - torch.prim.If %33 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %34 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %35 = torch.aten.eq.int %34, %int4 : !torch.int, !torch.int -> !torch.bool - %36 = torch.prim.If %35 -> (!torch.int) { - %87 = torch.aten.__getitem__.t %arg0, %int-4 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %87 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %37 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list, !torch.int -> !torch.int - %38 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int - %39 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int - %40 = torch.aten.ne.int %13, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %40 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %41 = torch.aten.add.int %38, %20 : !torch.int, !torch.int -> !torch.int - %42 = torch.aten.add.int %41, %20 : !torch.int, !torch.int -> !torch.int - %43 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int - %44 = torch.aten.mul.int %27, %43 : !torch.int, !torch.int -> !torch.int - %45 = torch.aten.sub.int %42, %44 : !torch.int, !torch.int -> !torch.int - %46 = torch.aten.sub.int %45, %int1 : !torch.int, !torch.int -> !torch.int - %47 = torch.prim.If %arg5 -> (!torch.int) { - %87 = torch.aten.sub.int %13, %int1 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %87 : !torch.int - } else { - torch.prim.If.yield %int0 : !torch.int - } - %48 = torch.aten.add.int %46, %47 : !torch.int, !torch.int -> !torch.int - %49 = torch.aten.floordiv.int %48, %13 : !torch.int, !torch.int -> !torch.int - %50 = torch.aten.add.int %49, %int1 : !torch.int, !torch.int -> !torch.int - %51 = torch.prim.If %arg5 -> (!torch.int) { - %87 = torch.aten.mul.int %49, %13 : !torch.int, !torch.int -> !torch.int - %88 = torch.aten.add.int %38, %20 : !torch.int, !torch.int -> !torch.int - %89 = torch.aten.ge.int %87, %88 : !torch.int, !torch.int -> !torch.bool - %90 = torch.prim.If %89 -> (!torch.int) { - torch.prim.If.yield %49 : !torch.int - } else { - torch.prim.If.yield %50 : !torch.int - } - torch.prim.If.yield %90 : !torch.int - } else { - torch.prim.If.yield %50 : 
!torch.int - } - %52 = torch.aten.ne.int %16, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %52 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %53 = torch.aten.add.int %39, %23 : !torch.int, !torch.int -> !torch.int - %54 = torch.aten.add.int %53, %23 : !torch.int, !torch.int -> !torch.int - %55 = torch.aten.sub.int %6, %int1 : !torch.int, !torch.int -> !torch.int - %56 = torch.aten.mul.int %30, %55 : !torch.int, !torch.int -> !torch.int - %57 = torch.aten.sub.int %54, %56 : !torch.int, !torch.int -> !torch.int - %58 = torch.aten.sub.int %57, %int1 : !torch.int, !torch.int -> !torch.int - %59 = torch.prim.If %arg5 -> (!torch.int) { - %87 = torch.aten.sub.int %16, %int1 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %87 : !torch.int - } else { - torch.prim.If.yield %int0 : !torch.int - } - %60 = torch.aten.add.int %58, %59 : !torch.int, !torch.int -> !torch.int - %61 = torch.aten.floordiv.int %60, %16 : !torch.int, !torch.int -> !torch.int - %62 = torch.aten.add.int %61, %int1 : !torch.int, !torch.int -> !torch.int - %63 = torch.prim.If %arg5 -> (!torch.int) { - %87 = torch.aten.mul.int %61, %16 : !torch.int, !torch.int -> !torch.int - %88 = torch.aten.add.int %39, %23 : !torch.int, !torch.int -> !torch.int - %89 = torch.aten.ge.int %87, %88 : !torch.int, !torch.int -> !torch.bool - %90 = torch.prim.If %89 -> (!torch.int) { - torch.prim.If.yield %61 : !torch.int - } else { - torch.prim.If.yield %62 : !torch.int - } - torch.prim.If.yield %90 : !torch.int - } else { - torch.prim.If.yield %62 : !torch.int - } - %64 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %65 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool - %66 = torch.prim.If %65 -> (!torch.bool) { - %87 = torch.aten.gt.int %3, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %87 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %66 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %67 = torch.aten.gt.int %16, %int0 : !torch.int, !torch.int -> !torch.bool - %68 = torch.prim.If %67 -> (!torch.bool) { - %87 = torch.aten.gt.int %13, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %87 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %68 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %69 = torch.aten.gt.int %27, %int0 : !torch.int, !torch.int -> !torch.bool - %70 = torch.prim.If %69 -> (!torch.bool) { - %87 = torch.aten.gt.int %30, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %87 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %70 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %71 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int - %72 = torch.aten.ne.int %71, %int0 : !torch.int, !torch.int -> !torch.bool - %73 = torch.prim.If %72 -> (!torch.bool) { - %87 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int - %88 = torch.aten.ne.int %87, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %88 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %74 = 
torch.aten.eq.int %64, %int3 : !torch.int, !torch.int -> !torch.bool - %75 = torch.prim.If %74 -> (!torch.bool) { - %87 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %88 = torch.aten.ne.int %87, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %88 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %76 = torch.prim.If %75 -> (!torch.bool) { - torch.prim.If.yield %73 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %77 = torch.prim.If %76 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %87 = torch.aten.eq.int %64, %int4 : !torch.int, !torch.int -> !torch.bool - %88 = torch.prim.If %87 -> (!torch.bool) { - torch.prim.If.yield %73 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %89 = torch.prim.If %88 -> (!torch.bool) { - %90 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int - %91 = torch.aten.ne.int %90, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %91 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If.yield %89 : !torch.bool - } - torch.prim.If %77 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %78 = torch.aten.floordiv.int %6, %int2 : !torch.int, !torch.int -> !torch.int - %79 = torch.aten.ge.int %78, %23 : !torch.int, !torch.int -> !torch.bool - %80 = torch.prim.If %79 -> (!torch.bool) { - %87 = torch.aten.floordiv.int %3, %int2 : !torch.int, !torch.int -> !torch.int - %88 = torch.aten.ge.int %87, %20 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %88 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %80 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %81 = torch.aten.ge.int %63, %int1 : !torch.int, !torch.int -> !torch.bool - %82 = torch.prim.If %81 -> (!torch.bool) { - %87 = torch.aten.ge.int %51, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %87 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %82 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %83 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %84 = torch.aten.eq.int %83, %int3 : !torch.int, !torch.int -> !torch.bool - %85 = torch.prim.If %84 -> (!torch.list) { - %87 = torch.prim.ListConstruct %37, %51, %63 : (!torch.int, !torch.int, !torch.int) -> !torch.list - torch.prim.If.yield %87 : !torch.list - } else { - %87 = torch.prim.ListConstruct %36, %37, %51, %63 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - torch.prim.If.yield %87 : !torch.list - } - %86 = torch.prim.TupleConstruct %85, %85 : !torch.list, !torch.list -> !torch.tuple, list> - return %86 : !torch.tuple, list> - } - func.func @__torch__.torch.jit._shape_functions.transpose(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %true = torch.constant.bool true - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool - %2 = torch.prim.If %1 -> (!torch.int) { - torch.prim.If.yield %int1 : !torch.int 
- } else { - torch.prim.If.yield %0 : !torch.int - } - %3 = torch.aten.neg.int %2 : !torch.int -> !torch.int - %4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int - %5 = torch.aten.lt.int %arg1, %3 : !torch.int, !torch.int -> !torch.bool - %6 = torch.prim.If %5 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %21 = torch.aten.gt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %21 : !torch.bool - } - %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool - torch.prim.If %7 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %8 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool - %9 = torch.prim.If %8 -> (!torch.int) { - %21 = torch.aten.add.int %arg1, %2 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %21 : !torch.int - } else { - torch.prim.If.yield %arg1 : !torch.int - } - %10 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool - %11 = torch.prim.If %10 -> (!torch.int) { - torch.prim.If.yield %int1 : !torch.int - } else { - torch.prim.If.yield %0 : !torch.int - } - %12 = torch.aten.neg.int %11 : !torch.int -> !torch.int - %13 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int - %14 = torch.aten.lt.int %arg2, %12 : !torch.int, !torch.int -> !torch.bool - %15 = torch.prim.If %14 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %21 = torch.aten.gt.int %arg2, %13 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %21 : !torch.bool - } - %16 = torch.aten.__not__ %15 : !torch.bool -> !torch.bool - torch.prim.If %16 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %17 = torch.aten.lt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool - %18 = torch.prim.If %17 -> (!torch.int) { - %21 = torch.aten.add.int %arg2, %11 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %21 : !torch.int - } else { - torch.prim.If.yield %arg2 : !torch.int - } - %19 = torch.aten.eq.int %9, %18 : !torch.int, !torch.int -> !torch.bool - %20 = torch.prim.If %19 -> (!torch.list) { - %21 = torch.prim.ListConstruct : () -> !torch.list - %22 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %22, %true, init() { - ^bb0(%arg3: !torch.int): - %23 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int - %24 = torch.aten.append.t %21, %23 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %21 : !torch.list - } else { - %21 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %0, %true, init() { - ^bb0(%arg3: !torch.int): - %22 = torch.aten.eq.int %arg3, %9 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %22 -> () { - %23 = torch.aten.__getitem__.t %arg0, %18 : !torch.list, !torch.int -> !torch.int - %24 = torch.aten.append.t %21, %23 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } else { - %23 = torch.aten.eq.int %arg3, %18 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %23 -> () { - %24 = torch.aten.__getitem__.t %arg0, %9 : !torch.list, !torch.int -> !torch.int - %25 = torch.aten.append.t %21, %24 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } else { - %24 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int - %25 = torch.aten.append.t 
%21, %24 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } - torch.prim.If.yield - } - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %21 : !torch.list - } - return %20 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.conv1d(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list { - %true = torch.constant.bool true - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 - %int1 = torch.constant.int 1 - %false = torch.constant.bool false - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int3 = torch.constant.int 3 - %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %1 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %1 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %3 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %5 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %6 = torch.aten.len.t %arg4 : !torch.list -> !torch.int - %7 = torch.prim.Loop %6, %true, init(%false) { - ^bb0(%arg7: !torch.int, %arg8: !torch.bool): - %34 = torch.aten.__getitem__.t %arg4, %arg7 : !torch.list, !torch.int -> !torch.int - %35 = torch.aten.lt.int %34, %int0 : !torch.int, !torch.int -> !torch.bool - %36 = torch.prim.If %35 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - torch.prim.If.yield %arg8 : !torch.bool - } - torch.prim.Loop.condition %true, iter(%36 : !torch.bool) - } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool - %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool - torch.prim.If %8 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %9 = torch.aten.len.t %arg3 : !torch.list -> !torch.int - %10 = torch.prim.Loop %9, %true, init(%false) { - ^bb0(%arg7: !torch.int, %arg8: !torch.bool): - %34 = torch.aten.__getitem__.t %arg3, %arg7 : !torch.list, !torch.int -> !torch.int - %35 = torch.aten.lt.int %34, %int0 : !torch.int, !torch.int -> !torch.bool - %36 = torch.prim.If %35 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - torch.prim.If.yield %arg8 : !torch.bool - } - torch.prim.Loop.condition %true, iter(%36 : !torch.bool) - } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool - %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool - torch.prim.If %11 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %12 = torch.aten.eq.int %5, %4 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %12 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %13 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %14 = torch.aten.ge.int %13, %arg6 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %14 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - 
torch.prim.If.yield - } - %15 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %16 = torch.aten.remainder.int %15, %arg6 : !torch.int, !torch.int -> !torch.int - %17 = torch.aten.eq.int %16, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %17 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %18 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int - %19 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int - %20 = torch.aten.mul.int %19, %arg6 : !torch.int, !torch.int -> !torch.int - %21 = torch.aten.eq.int %18, %20 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %21 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %22 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool - %23 = torch.prim.If %22 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %34 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list - %35 = torch.aten.len.t %34 : !torch.list -> !torch.int - %36 = torch.aten.eq.int %35, %int1 : !torch.int, !torch.int -> !torch.bool - %37 = torch.prim.If %36 -> (!torch.bool) { - %38 = torch.aten.__getitem__.t %34, %int0 : !torch.list, !torch.int -> !torch.int - %39 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %40 = torch.aten.eq.int %38, %39 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %40 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If.yield %37 : !torch.bool - } - torch.prim.If %23 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %24 = torch.aten.__range_length %int2, %4, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int - torch.prim.Loop %24, %true, init() { - ^bb0(%arg7: !torch.int): - %34 = torch.aten.__derive_index %arg7, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int - %35 = torch.aten.__getitem__.t %arg0, %34 : !torch.list, !torch.int -> !torch.int - %36 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int - %37 = torch.aten.__getitem__.t %arg4, %36 : !torch.list, !torch.int -> !torch.int - %38 = torch.aten.mul.int %37, %int2 : !torch.int, !torch.int -> !torch.int - %39 = torch.aten.add.int %35, %38 : !torch.int, !torch.int -> !torch.int - %40 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int - %41 = torch.aten.__getitem__.t %arg5, %40 : !torch.list, !torch.int -> !torch.int - %42 = torch.aten.__getitem__.t %arg1, %34 : !torch.list, !torch.int -> !torch.int - %43 = torch.aten.sub.int %42, %int1 : !torch.int, !torch.int -> !torch.int - %44 = torch.aten.mul.int %41, %43 : !torch.int, !torch.int -> !torch.int - %45 = torch.aten.add.int %44, %int1 : !torch.int, !torch.int -> !torch.int - %46 = torch.aten.ge.int %39, %45 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %46 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %25 = torch.aten.len.t %arg5 : !torch.list -> !torch.int - %26 = torch.aten.gt.int %25, %int0 : !torch.int, !torch.int -> !torch.bool - %27 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %28 = 
torch.prim.ListConstruct : () -> !torch.list - %29 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %30 = torch.aten.append.t %28, %29 : !torch.list, !torch.int -> !torch.list - %31 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %32 = torch.aten.append.t %28, %31 : !torch.list, !torch.int -> !torch.list - %33 = torch.aten.__range_length %int2, %27, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int - torch.prim.Loop %33, %true, init() { - ^bb0(%arg7: !torch.int): - %34 = torch.aten.__derive_index %arg7, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int - %35 = torch.prim.If %26 -> (!torch.int) { - %51 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int - %52 = torch.aten.__getitem__.t %arg5, %51 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %52 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %36 = torch.aten.__getitem__.t %arg1, %34 : !torch.list, !torch.int -> !torch.int - %37 = torch.aten.sub.int %36, %int1 : !torch.int, !torch.int -> !torch.int - %38 = torch.aten.mul.int %35, %37 : !torch.int, !torch.int -> !torch.int - %39 = torch.aten.add.int %38, %int1 : !torch.int, !torch.int -> !torch.int - %40 = torch.aten.__getitem__.t %arg0, %34 : !torch.list, !torch.int -> !torch.int - %41 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int - %42 = torch.aten.__getitem__.t %arg4, %41 : !torch.list, !torch.int -> !torch.int - %43 = torch.aten.mul.int %42, %int2 : !torch.int, !torch.int -> !torch.int - %44 = torch.aten.add.int %40, %43 : !torch.int, !torch.int -> !torch.int - %45 = torch.aten.sub.int %44, %39 : !torch.int, !torch.int -> !torch.int - %46 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int - %47 = torch.aten.__getitem__.t %arg3, %46 : !torch.list, !torch.int -> !torch.int - %48 = torch.aten.floordiv.int %45, %47 : !torch.int, !torch.int -> !torch.int - %49 = torch.aten.add.int %48, %int1 : !torch.int, !torch.int -> !torch.int - %50 = torch.aten.append.t %28, %49 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - return %28 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.conv_output_size(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list { - %true = torch.constant.bool true - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 - %int1 = torch.constant.int 1 - %0 = call @__torch__.torch.jit._shape_functions.check_shape_forward(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.none - %1 = torch.aten.len.t %arg5 : !torch.list -> !torch.int - %2 = torch.aten.gt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - %3 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %4 = torch.prim.ListConstruct : () -> !torch.list - %5 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %6 = torch.aten.append.t %4, %5 : !torch.list, !torch.int -> !torch.list - %7 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %8 = torch.aten.append.t %4, %7 : !torch.list, !torch.int -> !torch.list - %9 = torch.aten.__range_length %int2, %3, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int - torch.prim.Loop %9, %true, init() { - ^bb0(%arg7: 
!torch.int): - %10 = torch.aten.__derive_index %arg7, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int - %11 = torch.prim.If %2 -> (!torch.int) { - %27 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int - %28 = torch.aten.__getitem__.t %arg5, %27 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %28 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %12 = torch.aten.__getitem__.t %arg1, %10 : !torch.list, !torch.int -> !torch.int - %13 = torch.aten.sub.int %12, %int1 : !torch.int, !torch.int -> !torch.int - %14 = torch.aten.mul.int %11, %13 : !torch.int, !torch.int -> !torch.int - %15 = torch.aten.add.int %14, %int1 : !torch.int, !torch.int -> !torch.int - %16 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int - %17 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int - %18 = torch.aten.__getitem__.t %arg4, %17 : !torch.list, !torch.int -> !torch.int - %19 = torch.aten.mul.int %int2, %18 : !torch.int, !torch.int -> !torch.int - %20 = torch.aten.add.int %16, %19 : !torch.int, !torch.int -> !torch.int - %21 = torch.aten.sub.int %20, %15 : !torch.int, !torch.int -> !torch.int - %22 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int - %23 = torch.aten.__getitem__.t %arg3, %22 : !torch.list, !torch.int -> !torch.int - %24 = torch.aten.floordiv.int %21, %23 : !torch.int, !torch.int -> !torch.int - %25 = torch.aten.add.int %24, %int1 : !torch.int, !torch.int -> !torch.int - %26 = torch.aten.append.t %4, %25 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - return %4 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.check_shape_forward(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.none { - %false = torch.constant.bool false - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %int2 = torch.constant.int 2 - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %2 = call @__torch__.torch.jit._shape_functions.check_non_negative(%arg4) : (!torch.list) -> !torch.bool - %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool - torch.prim.If %3 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %4 = call @__torch__.torch.jit._shape_functions.check_non_negative(%arg3) : (!torch.list) -> !torch.bool - %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool - torch.prim.If %5 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %6 = torch.aten.eq.int %1, %0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %6 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %7 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %8 = torch.aten.ge.int %7, %arg6 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %8 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %9 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - 
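The loop in conv_output_size above evaluates the familiar convolution output formula per spatial dimension. A compact Python rendering (the function name conv_out_dim is chosen here for readability and is not part of this file):

    # Per-dimension convolution output size, matching the loop above.
    def conv_out_dim(in_size, kernel, pad, stride, dilation=1):
        effective_kernel = dilation * (kernel - 1) + 1
        return (in_size + 2 * pad - effective_kernel) // stride + 1

    # Example: a 224-wide input with a 7-wide kernel, stride 2, padding 3
    # gives (224 + 6 - 7) // 2 + 1 == 112.
    assert conv_out_dim(224, 7, pad=3, stride=2) == 112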
-    %10 = torch.aten.remainder.int %9, %arg6 : !torch.int, !torch.int -> !torch.int
-    %11 = torch.aten.eq.int %10, %int0 : !torch.int, !torch.int -> !torch.bool
-    torch.prim.If %11 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %12 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int
-    %13 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int
-    %14 = torch.aten.mul.int %13, %arg6 : !torch.int, !torch.int -> !torch.int
-    %15 = torch.aten.eq.int %12, %14 : !torch.int, !torch.int -> !torch.bool
-    torch.prim.If %15 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %16 = torch.aten.__is__ %arg2, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
-    %17 = torch.prim.If %16 -> (!torch.bool) {
-      torch.prim.If.yield %true : !torch.bool
-    } else {
-      %19 = torch.prim.unchecked_cast %arg2 : !torch.optional<list<int>> -> !torch.list<int>
-      %20 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
-      %21 = torch.aten.eq.int %20, %int1 : !torch.int, !torch.int -> !torch.bool
-      %22 = torch.prim.If %21 -> (!torch.bool) {
-        %23 = torch.aten.__getitem__.t %19, %int0 : !torch.list<int>, !torch.int -> !torch.int
-        %24 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
-        %25 = torch.aten.eq.int %23, %24 : !torch.int, !torch.int -> !torch.bool
-        torch.prim.If.yield %25 : !torch.bool
-      } else {
-        torch.prim.If.yield %false : !torch.bool
-      }
-      torch.prim.If.yield %22 : !torch.bool
-    }
-    torch.prim.If %17 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %18 = torch.aten.__range_length %int2, %0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-    torch.prim.Loop %18, %true, init() {
-    ^bb0(%arg7: !torch.int):
-      %19 = torch.aten.__derive_index %arg7, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-      %20 = torch.aten.__getitem__.t %arg0, %19 : !torch.list<int>, !torch.int -> !torch.int
-      %21 = torch.aten.sub.int %19, %int2 : !torch.int, !torch.int -> !torch.int
-      %22 = torch.aten.__getitem__.t %arg4, %21 : !torch.list<int>, !torch.int -> !torch.int
-      %23 = torch.aten.mul.int %int2, %22 : !torch.int, !torch.int -> !torch.int
-      %24 = torch.aten.add.int %20, %23 : !torch.int, !torch.int -> !torch.int
-      %25 = torch.aten.sub.int %19, %int2 : !torch.int, !torch.int -> !torch.int
-      %26 = torch.aten.__getitem__.t %arg5, %25 : !torch.list<int>, !torch.int -> !torch.int
-      %27 = torch.aten.__getitem__.t %arg1, %19 : !torch.list<int>, !torch.int -> !torch.int
-      %28 = torch.aten.sub.int %27, %int1 : !torch.int, !torch.int -> !torch.int
-      %29 = torch.aten.mul.int %26, %28 : !torch.int, !torch.int -> !torch.int
-      %30 = torch.aten.add.int %29, %int1 : !torch.int, !torch.int -> !torch.int
-      %31 = torch.aten.ge.int %24, %30 : !torch.int, !torch.int -> !torch.bool
-      torch.prim.If %31 -> () {
-        torch.prim.If.yield
-      } else {
-        torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-        torch.prim.If.yield
-      }
-      torch.prim.Loop.condition %true, iter()
-    } : (!torch.int, !torch.bool) -> ()
-    return %none : !torch.none
-  }
-  func.func @__torch__.torch.jit._shape_functions.check_non_negative(%arg0: !torch.list<int>) -> !torch.bool {
-    %true = torch.constant.bool true
-    %false = torch.constant.bool false
-    %int0 = torch.constant.int 0
-    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-    %1 = torch.prim.Loop %0, %true, init(%false) {
-    ^bb0(%arg1: !torch.int, %arg2: !torch.bool):
-      %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<int>, !torch.int -> !torch.int
-      %3 = torch.aten.lt.int %2, %int0 : !torch.int, !torch.int -> !torch.bool
-      %4 = torch.prim.If %3 -> (!torch.bool) {
-        torch.prim.If.yield %true : !torch.bool
-      } else {
-        torch.prim.If.yield %arg2 : !torch.bool
-      }
-      torch.prim.Loop.condition %true, iter(%4 : !torch.bool)
-    } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool
-    return %1 : !torch.bool
-  }
-  func.func @__torch__.torch.jit._shape_functions.conv2d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {
-    %true = torch.constant.bool true
-    %int0 = torch.constant.int 0
-    %int2 = torch.constant.int 2
-    %int1 = torch.constant.int 1
-    %false = torch.constant.bool false
-    %none = torch.constant.none
-    %str = torch.constant.str "AssertionError: "
-    %int4 = torch.constant.int 4
-    %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
-    %1 = torch.aten.eq.int %0, %int4 : !torch.int, !torch.int -> !torch.bool
-    torch.prim.If %1 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-    %3 = torch.aten.eq.int %2, %int4 : !torch.int, !torch.int -> !torch.bool
-    torch.prim.If %3 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %4 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-    %5 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
-    %6 = torch.aten.len.t %arg4 : !torch.list<int> -> !torch.int
-    %7 = torch.prim.Loop %6, %true, init(%false) {
-    ^bb0(%arg7: !torch.int, %arg8: !torch.bool):
-      %34 = torch.aten.__getitem__.t %arg4, %arg7 : !torch.list<int>, !torch.int -> !torch.int
-      %35 = torch.aten.lt.int %34, %int0 : !torch.int, !torch.int -> !torch.bool
-      %36 = torch.prim.If %35 -> (!torch.bool) {
-        torch.prim.If.yield %true : !torch.bool
-      } else {
-        torch.prim.If.yield %arg8 : !torch.bool
-      }
-      torch.prim.Loop.condition %true, iter(%36 : !torch.bool)
-    } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool
-    %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
-    torch.prim.If %8 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %9 = torch.aten.len.t %arg3 : !torch.list<int> -> !torch.int
-    %10 = torch.prim.Loop %9, %true, init(%false) {
-    ^bb0(%arg7: !torch.int, %arg8: !torch.bool):
-      %34 = torch.aten.__getitem__.t %arg3, %arg7 : !torch.list<int>, !torch.int -> !torch.int
-      %35 = torch.aten.lt.int %34, %int0 : !torch.int, !torch.int -> !torch.bool
-      %36 = torch.prim.If %35 -> (!torch.bool) {
-        torch.prim.If.yield %true : !torch.bool
-      } else {
-        torch.prim.If.yield %arg8 : !torch.bool
-      }
-      torch.prim.Loop.condition %true, iter(%36 : !torch.bool)
-    } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool
-    %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool
-    torch.prim.If %11 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %12 = torch.aten.eq.int %5, %4 : !torch.int, !torch.int -> !torch.bool
-    torch.prim.If %12 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %13 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
-    %14 = torch.aten.ge.int %13, %arg6 : !torch.int, !torch.int -> !torch.bool
-    torch.prim.If %14 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %15 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
-    %16 = torch.aten.remainder.int %15, %arg6 : !torch.int, !torch.int -> !torch.int
-    %17 = torch.aten.eq.int %16, %int0 : !torch.int, !torch.int -> !torch.bool
-    torch.prim.If %17 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %18 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int
-    %19 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int
-    %20 = torch.aten.mul.int %19, %arg6 : !torch.int, !torch.int -> !torch.int
-    %21 = torch.aten.eq.int %18, %20 : !torch.int, !torch.int -> !torch.bool
-    torch.prim.If %21 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %22 = torch.aten.__is__ %arg2, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
-    %23 = torch.prim.If %22 -> (!torch.bool) {
-      torch.prim.If.yield %true : !torch.bool
-    } else {
-      %34 = torch.prim.unchecked_cast %arg2 : !torch.optional<list<int>> -> !torch.list<int>
-      %35 = torch.aten.len.t %34 : !torch.list<int> -> !torch.int
-      %36 = torch.aten.eq.int %35, %int1 : !torch.int, !torch.int -> !torch.bool
-      %37 = torch.prim.If %36 -> (!torch.bool) {
-        %38 = torch.aten.__getitem__.t %34, %int0 : !torch.list<int>, !torch.int -> !torch.int
-        %39 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
-        %40 = torch.aten.eq.int %38, %39 : !torch.int, !torch.int -> !torch.bool
-        torch.prim.If.yield %40 : !torch.bool
-      } else {
-        torch.prim.If.yield %false : !torch.bool
-      }
-      torch.prim.If.yield %37 : !torch.bool
-    }
-    torch.prim.If %23 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %24 = torch.aten.__range_length %int2, %4, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-    torch.prim.Loop %24, %true, init() {
-    ^bb0(%arg7: !torch.int):
-      %34 = torch.aten.__derive_index %arg7, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-      %35 = torch.aten.__getitem__.t %arg0, %34 : !torch.list<int>, !torch.int -> !torch.int
-      %36 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int
-      %37 = torch.aten.__getitem__.t %arg4, %36 : !torch.list<int>, !torch.int -> !torch.int
-      %38 = torch.aten.mul.int %37, %int2 : !torch.int, !torch.int -> !torch.int
-      %39 = torch.aten.add.int %35, %38 : !torch.int, !torch.int -> !torch.int
-      %40 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int
-      %41 = torch.aten.__getitem__.t %arg5, %40 : !torch.list<int>, !torch.int -> !torch.int
-      %42 = torch.aten.__getitem__.t %arg1, %34 : !torch.list<int>, !torch.int -> !torch.int
-      %43 = torch.aten.sub.int %42, %int1 : !torch.int, !torch.int -> !torch.int
-      %44 = torch.aten.mul.int %41, %43 : !torch.int, !torch.int -> !torch.int
-      %45 = torch.aten.add.int %44, %int1 : !torch.int, !torch.int -> !torch.int
-      %46 = torch.aten.ge.int %39, %45 : !torch.int, !torch.int -> !torch.bool
-      torch.prim.If %46 -> () {
-        torch.prim.If.yield
-      } else {
-        torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-        torch.prim.If.yield
-      }
-      torch.prim.Loop.condition %true, iter()
-    } : (!torch.int, !torch.bool) -> ()
-    %25 = torch.aten.len.t %arg5 : !torch.list<int> -> !torch.int
-    %26 = torch.aten.gt.int %25, %int0 : !torch.int, !torch.int -> !torch.bool
-    %27 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-    %28 = torch.prim.ListConstruct : () -> !torch.list<int>
-    %29 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int
-    %30 = torch.aten.append.t %28, %29 : !torch.list<int>, !torch.int -> !torch.list<int>
-    %31 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
-    %32 = torch.aten.append.t %28, %31 : !torch.list<int>, !torch.int -> !torch.list<int>
-    %33 = torch.aten.__range_length %int2, %27, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-    torch.prim.Loop %33, %true, init() {
-    ^bb0(%arg7: !torch.int):
-      %34 = torch.aten.__derive_index %arg7, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-      %35 = torch.prim.If %26 -> (!torch.int) {
-        %51 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int
-        %52 = torch.aten.__getitem__.t %arg5, %51 : !torch.list<int>, !torch.int -> !torch.int
-        torch.prim.If.yield %52 : !torch.int
-      } else {
-        torch.prim.If.yield %int1 : !torch.int
-      }
-      %36 = torch.aten.__getitem__.t %arg1, %34 : !torch.list<int>, !torch.int -> !torch.int
-      %37 = torch.aten.sub.int %36, %int1 : !torch.int, !torch.int -> !torch.int
-      %38 = torch.aten.mul.int %35, %37 : !torch.int, !torch.int -> !torch.int
-      %39 = torch.aten.add.int %38, %int1 : !torch.int, !torch.int -> !torch.int
-      %40 = torch.aten.__getitem__.t %arg0, %34 : !torch.list<int>, !torch.int -> !torch.int
-      %41 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int
-      %42 = torch.aten.__getitem__.t %arg4, %41 : !torch.list<int>, !torch.int -> !torch.int
-      %43 = torch.aten.mul.int %42, %int2 : !torch.int, !torch.int -> !torch.int
-      %44 = torch.aten.add.int %40, %43 : !torch.int, !torch.int -> !torch.int
-      %45 = torch.aten.sub.int %44, %39 : !torch.int, !torch.int -> !torch.int
-      %46 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int
-      %47 = torch.aten.__getitem__.t %arg3, %46 : !torch.list<int>, !torch.int -> !torch.int
-      %48 = torch.aten.floordiv.int %45, %47 : !torch.int, !torch.int -> !torch.int
-      %49 = torch.aten.add.int %48, %int1 : !torch.int, !torch.int -> !torch.int
-      %50 = torch.aten.append.t %28, %49 : !torch.list<int>, !torch.int -> !torch.list<int>
-      torch.prim.Loop.condition %true, iter()
-    } : (!torch.int, !torch.bool) -> ()
-    return %28 : !torch.list<int>
-  }
-  func.func @__torch__.torch.jit._shape_functions.batch_norm(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list<int> {
-    %true = torch.constant.bool true
-    %0 = torch.prim.ListConstruct : () -> !torch.list<int>
-    %1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-    torch.prim.Loop %1, %true, init() {
-    ^bb0(%arg9: !torch.int):
-      %2 = torch.aten.__getitem__.t %arg0, %arg9 : !torch.list<int>, !torch.int -> !torch.int
-      %3 = torch.aten.append.t %0, %2 : !torch.list<int>, !torch.int -> !torch.list<int>
-      torch.prim.Loop.condition %true, iter()
-    } : (!torch.int, !torch.bool) -> ()
-    return %0 : !torch.list<int>
-  }
-  func.func @__torch__.torch.jit._shape_functions.conv3d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {
-    %true = torch.constant.bool true
-    %int0 = torch.constant.int 0
-    %int2 = torch.constant.int 2
-    %int1 = torch.constant.int 1
-    %false = torch.constant.bool false
-    %none = torch.constant.none
-    %str = torch.constant.str "AssertionError: "
-    %int5 = torch.constant.int 5
-    %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
-    %1 = torch.aten.eq.int %0, %int5 : !torch.int, !torch.int -> !torch.bool
-    torch.prim.If %1 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-    %3 = torch.aten.eq.int %2, %int5 : !torch.int, !torch.int -> !torch.bool
-    torch.prim.If %3 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %4 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-    %5 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
-    %6 = torch.aten.len.t %arg4 : !torch.list<int> -> !torch.int
-    %7 = torch.prim.Loop %6, %true, init(%false) {
-    ^bb0(%arg7: !torch.int, %arg8: !torch.bool):
-      %34 = torch.aten.__getitem__.t %arg4, %arg7 : !torch.list<int>, !torch.int -> !torch.int
-      %35 = torch.aten.lt.int %34, %int0 : !torch.int, !torch.int -> !torch.bool
-      %36 = torch.prim.If %35 -> (!torch.bool) {
-        torch.prim.If.yield %true : !torch.bool
-      } else {
-        torch.prim.If.yield %arg8 : !torch.bool
-      }
-      torch.prim.Loop.condition %true, iter(%36 : !torch.bool)
-    } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool
-    %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
-    torch.prim.If %8 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %9 = torch.aten.len.t %arg3 : !torch.list<int> -> !torch.int
-    %10 = torch.prim.Loop %9, %true, init(%false) {
-    ^bb0(%arg7: !torch.int, %arg8: !torch.bool):
-      %34 = torch.aten.__getitem__.t %arg3, %arg7 : !torch.list<int>, !torch.int -> !torch.int
-      %35 = torch.aten.lt.int %34, %int0 : !torch.int, !torch.int -> !torch.bool
-      %36 = torch.prim.If %35 -> (!torch.bool) {
-        torch.prim.If.yield %true : !torch.bool
-      } else {
-        torch.prim.If.yield %arg8 : !torch.bool
-      }
-      torch.prim.Loop.condition %true, iter(%36 : !torch.bool)
-    } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool
-    %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool
-    torch.prim.If %11 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %12 = torch.aten.eq.int %5, %4 : !torch.int, !torch.int -> !torch.bool
-    torch.prim.If %12 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %13 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
-    %14 = torch.aten.ge.int %13, %arg6 : !torch.int, !torch.int -> !torch.bool
-    torch.prim.If %14 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %15 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
-    %16 = torch.aten.remainder.int %15, %arg6 : !torch.int, !torch.int -> !torch.int
-    %17 = torch.aten.eq.int %16, %int0 : !torch.int, !torch.int -> !torch.bool
-    torch.prim.If %17 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %18 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int
-    %19 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int
-    %20 = torch.aten.mul.int %19, %arg6 : !torch.int, !torch.int -> !torch.int
-    %21 = torch.aten.eq.int %18, %20 : !torch.int, !torch.int -> !torch.bool
-    torch.prim.If %21 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %22 = torch.aten.__is__ %arg2, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
-    %23 = torch.prim.If %22 -> (!torch.bool) {
-      torch.prim.If.yield %true : !torch.bool
-    } else {
-      %34 = torch.prim.unchecked_cast %arg2 : !torch.optional<list<int>> -> !torch.list<int>
-      %35 = torch.aten.len.t %34 : !torch.list<int> -> !torch.int
-      %36 = torch.aten.eq.int %35, %int1 : !torch.int, !torch.int -> !torch.bool
-      %37 = torch.prim.If %36 -> (!torch.bool) {
-        %38 = torch.aten.__getitem__.t %34, %int0 : !torch.list<int>, !torch.int -> !torch.int
-        %39 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
-        %40 = torch.aten.eq.int %38, %39 : !torch.int, !torch.int -> !torch.bool
-        torch.prim.If.yield %40 : !torch.bool
-      } else {
-        torch.prim.If.yield %false : !torch.bool
-      }
-      torch.prim.If.yield %37 : !torch.bool
-    }
-    torch.prim.If %23 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %24 = torch.aten.__range_length %int2, %4, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-    torch.prim.Loop %24, %true, init() {
-    ^bb0(%arg7: !torch.int):
-      %34 = torch.aten.__derive_index %arg7, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-      %35 = torch.aten.__getitem__.t %arg0, %34 : !torch.list<int>, !torch.int -> !torch.int
-      %36 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int
-      %37 = torch.aten.__getitem__.t %arg4, %36 : !torch.list<int>, !torch.int -> !torch.int
-      %38 = torch.aten.mul.int %37, %int2 : !torch.int, !torch.int -> !torch.int
-      %39 = torch.aten.add.int %35, %38 : !torch.int, !torch.int -> !torch.int
-      %40 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int
-      %41 = torch.aten.__getitem__.t %arg5, %40 : !torch.list<int>, !torch.int -> !torch.int
-      %42 = torch.aten.__getitem__.t %arg1, %34 : !torch.list<int>, !torch.int -> !torch.int
-      %43 = torch.aten.sub.int %42, %int1 : !torch.int, !torch.int -> !torch.int
-      %44 = torch.aten.mul.int %41, %43 : !torch.int, !torch.int -> !torch.int
-      %45 = torch.aten.add.int %44, %int1 : !torch.int, !torch.int -> !torch.int
-      %46 = torch.aten.ge.int %39, %45 : !torch.int, !torch.int -> !torch.bool
-      torch.prim.If %46 -> () {
-        torch.prim.If.yield
-      } else {
-        torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-        torch.prim.If.yield
-      }
-      torch.prim.Loop.condition %true, iter()
-    } : (!torch.int, !torch.bool) -> ()
-    %25 = torch.aten.len.t %arg5 : !torch.list<int> -> !torch.int
-    %26 = torch.aten.gt.int %25, %int0 : !torch.int, !torch.int -> !torch.bool
-    %27 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-    %28 = torch.prim.ListConstruct : () -> !torch.list<int>
-    %29 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int
-    %30 = torch.aten.append.t %28, %29 : !torch.list<int>, !torch.int -> !torch.list<int>
-    %31 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int
-    %32 = torch.aten.append.t %28, %31 : !torch.list<int>, !torch.int -> !torch.list<int>
-    %33 = torch.aten.__range_length %int2, %27, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-    torch.prim.Loop %33, %true, init() {
-    ^bb0(%arg7: !torch.int):
-      %34 = torch.aten.__derive_index %arg7, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-      %35 = torch.prim.If %26 -> (!torch.int) {
-        %51 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int
-        %52 = torch.aten.__getitem__.t %arg5, %51 : !torch.list<int>, !torch.int -> !torch.int
-        torch.prim.If.yield %52 : !torch.int
-      } else {
-        torch.prim.If.yield %int1 : !torch.int
-      }
-      %36 = torch.aten.__getitem__.t %arg1, %34 : !torch.list<int>, !torch.int -> !torch.int
-      %37 = torch.aten.sub.int %36, %int1 : !torch.int, !torch.int -> !torch.int
-      %38 = torch.aten.mul.int %35, %37 : !torch.int, !torch.int -> !torch.int
-      %39 = torch.aten.add.int %38, %int1 : !torch.int, !torch.int -> !torch.int
-      %40 = torch.aten.__getitem__.t %arg0, %34 : !torch.list<int>, !torch.int -> !torch.int
-      %41 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int
-      %42 = torch.aten.__getitem__.t %arg4, %41 : !torch.list<int>, !torch.int -> !torch.int
-      %43 = torch.aten.mul.int %42, %int2 : !torch.int, !torch.int -> !torch.int
-      %44 = torch.aten.add.int %40, %43 : !torch.int, !torch.int -> !torch.int
-      %45 = torch.aten.sub.int %44, %39 : !torch.int, !torch.int -> !torch.int
-      %46 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int
-      %47 = torch.aten.__getitem__.t %arg3, %46 : !torch.list<int>, !torch.int -> !torch.int
-      %48 = torch.aten.floordiv.int %45, %47 : !torch.int, !torch.int -> !torch.int
-      %49 = torch.aten.add.int %48, %int1 : !torch.int, !torch.int -> !torch.int
-      %50 = torch.aten.append.t %28, %49 : !torch.list<int>, !torch.int -> !torch.list<int>
-      torch.prim.Loop.condition %true, iter()
-    } : (!torch.int, !torch.bool) -> ()
-    return %28 : !torch.list<int>
-  }
-  func.func @__torch__.torch.jit._shape_functions.conv_backwards(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<int>>) -> !torch.tuple<list<int>, list<int>, list<int>> {
-    %int1 = torch.constant.int 1
-    %true = torch.constant.bool true
-    %0 = torch.prim.ListConstruct : () -> !torch.list<int>
-    %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
-    torch.prim.Loop %1, %true, init() {
-    ^bb0(%arg4: !torch.int):
-      %7 = torch.aten.__getitem__.t %arg1, %arg4 : !torch.list<int>, !torch.int -> !torch.int
-      %8 = torch.aten.append.t %0, %7 : !torch.list<int>, !torch.int -> !torch.list<int>
-      torch.prim.Loop.condition %true, iter()
-    } : (!torch.int, !torch.bool) -> ()
-    %2 = torch.prim.ListConstruct : () -> !torch.list<int>
-    %3 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
-    torch.prim.Loop %3, %true, init() {
-    ^bb0(%arg4: !torch.int):
-      %7 = torch.aten.__getitem__.t %arg2, %arg4 : !torch.list<int>, !torch.int -> !torch.int
-      %8 = torch.aten.append.t %2, %7 : !torch.list<int>, !torch.int -> !torch.list<int>
-      torch.prim.Loop.condition %true, iter()
-    } : (!torch.int, !torch.bool) -> ()
-    %4 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int
-    %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>
-    %6 = torch.prim.TupleConstruct %0, %2, %5 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>
-    return %6 : !torch.tuple<list<int>, list<int>, list<int>>
-  }
-  func.func @__torch__.torch.jit._shape_functions.conv_forwards(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int) -> !torch.list<int> {
-    %true = torch.constant.bool true
-    %int0 = torch.constant.int 0
-    %int1 = torch.constant.int 1
-    %int2 = torch.constant.int 2
-    %0 = torch.aten.len.t %arg5 : !torch.list<int> -> !torch.int
-    %1 = torch.aten.gt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
-    %2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-    %3 = torch.prim.ListConstruct : () -> !torch.list<int>
-    %4 = torch.prim.If %arg6 -> (!torch.int) {
-      torch.prim.If.yield %int1 : !torch.int
-    } else {
-      torch.prim.If.yield %int0 : !torch.int
-    }
-    %5 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int
-    %6 = torch.aten.append.t %3, %5 : !torch.list<int>, !torch.int -> !torch.list<int>
-    %7 = torch.aten.__getitem__.t %arg1, %4 : !torch.list<int>, !torch.int -> !torch.int
-    %8 = torch.aten.append.t %3, %7 : !torch.list<int>, !torch.int -> !torch.list<int>
-    %9 = torch.aten.__range_length %int2, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-    torch.prim.Loop %9, %true, init() {
-    ^bb0(%arg9: !torch.int):
-      %10 = torch.aten.__derive_index %arg9, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-      %11 = torch.prim.If %1 -> (!torch.int) {
-        %12 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int
-        %13 = torch.aten.__getitem__.t %arg5, %12 : !torch.list<int>, !torch.int -> !torch.int
-        torch.prim.If.yield %13 : !torch.int
-      } else {
-        torch.prim.If.yield %int1 : !torch.int
-      }
-      torch.prim.If %arg6 -> () {
-        %12 = torch.aten.__getitem__.t %arg1, %10 : !torch.list<int>, !torch.int -> !torch.int
-        %13 = torch.aten.sub.int %12, %int1 : !torch.int, !torch.int -> !torch.int
-        %14 = torch.aten.mul.int %11, %13 : !torch.int, !torch.int -> !torch.int
-        %15 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
-        %16 = torch.aten.sub.int %15, %int1 : !torch.int, !torch.int -> !torch.int
-        %17 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int
-        %18 = torch.aten.__getitem__.t %arg3, %17 : !torch.list<int>, !torch.int -> !torch.int
-        %19 = torch.aten.mul.int %16, %18 : !torch.int, !torch.int -> !torch.int
-        %20 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int
-        %21 = torch.aten.__getitem__.t %arg4, %20 : !torch.list<int>, !torch.int -> !torch.int
-        %22 = torch.aten.mul.int %21, %int2 : !torch.int, !torch.int -> !torch.int
-        %23 = torch.aten.sub.int %19, %22 : !torch.int, !torch.int -> !torch.int
-        %24 = torch.aten.add.int %23, %14 : !torch.int, !torch.int -> !torch.int
-        %25 = torch.aten.add.int %24, %int1 : !torch.int, !torch.int -> !torch.int
-        %26 = torch.aten.append.t %3, %25 : !torch.list<int>, !torch.int -> !torch.list<int>
-        torch.prim.If.yield
-      } else {
-        %12 = torch.aten.__getitem__.t %arg1, %10 : !torch.list<int>, !torch.int -> !torch.int
-        %13 = torch.aten.sub.int %12, %int1 : !torch.int, !torch.int -> !torch.int
-        %14 = torch.aten.mul.int %11, %13 : !torch.int, !torch.int -> !torch.int
-        %15 = torch.aten.add.int %14, %int1 : !torch.int, !torch.int -> !torch.int
-        %16 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int
-        %17 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int
-        %18 = torch.aten.__getitem__.t %arg4, %17 : !torch.list<int>, !torch.int -> !torch.int
-        %19 = torch.aten.mul.int %18, %int2 : !torch.int, !torch.int -> !torch.int
-        %20 = torch.aten.add.int %16, %19 : !torch.int, !torch.int -> !torch.int
-        %21 = torch.aten.sub.int %20, %15 : !torch.int, !torch.int -> !torch.int
-        %22 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int
-        %23 = torch.aten.__getitem__.t %arg3, %22 : !torch.list<int>, !torch.int -> !torch.int
-        %24 = torch.aten.floordiv.int %21, %23 : !torch.int, !torch.int -> !torch.int
-        %25 = torch.aten.add.int %24, %int1 : !torch.int, !torch.int -> !torch.int
-        %26 = torch.aten.append.t %3, %25 : !torch.list<int>, !torch.int -> !torch.list<int>
-        torch.prim.If.yield
-      }
-      torch.prim.Loop.condition %true, iter()
-    } : (!torch.int, !torch.bool) -> ()
-    return %3 : !torch.list<int>
-  }
-  func.func @__torch__.torch.jit._shape_functions.conv_transpose2d_input(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.optional<list<int>>, %arg6: !torch.int, %arg7: !torch.optional<list<int>>) -> !torch.list<int> {
-    %true = torch.constant.bool true
-    %none = torch.constant.none
-    %int1 = torch.constant.int 1
-    %int0 = torch.constant.int 0
-    %int2 = torch.constant.int 2
-    %0 = torch.aten.__is__ %arg3, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
-    %1 = torch.prim.If %0 -> (!torch.list<int>) {
-      %15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-      torch.prim.If.yield %15 : !torch.list<int>
-    } else {
-      %15 = torch.prim.unchecked_cast %arg3 : !torch.optional<list<int>> -> !torch.list<int>
-      torch.prim.If.yield %15 : !torch.list<int>
-    }
-    %2 = torch.aten.__is__ %arg4, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
-    %3 = torch.prim.If %2 -> (!torch.list<int>) {
-      %15 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
-      torch.prim.If.yield %15 : !torch.list<int>
-    } else {
-      %15 = torch.prim.unchecked_cast %arg4 : !torch.optional<list<int>> -> !torch.list<int>
-      torch.prim.If.yield %15 : !torch.list<int>
-    }
-    %4 = torch.aten.__is__ %arg7, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
-    %5 = torch.prim.If %4 -> (!torch.list<int>) {
-      %15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-      torch.prim.If.yield %15 : !torch.list<int>
-    } else {
-      %15 = torch.prim.unchecked_cast %arg7 : !torch.optional<list<int>> -> !torch.list<int>
-      torch.prim.If.yield %15 : !torch.list<int>
-    }
-    %6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
-    %7 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
-    %8 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-    %9 = torch.prim.ListConstruct : () -> !torch.list<int>
-    %10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int
-    %11 = torch.aten.append.t %9, %10 : !torch.list<int>, !torch.int -> !torch.list<int>
-    %12 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int
-    %13 = torch.aten.append.t %9, %12 : !torch.list<int>, !torch.int -> !torch.list<int>
-    %14 = torch.aten.__range_length %int2, %8, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-    torch.prim.Loop %14, %true, init() {
-    ^bb0(%arg8: !torch.int):
-      %15 = torch.aten.__derive_index %arg8, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-      %16 = torch.prim.If %7 -> (!torch.int) {
-        %32 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int
-        %33 = torch.aten.__getitem__.t %5, %32 : !torch.list<int>, !torch.int -> !torch.int
-        torch.prim.If.yield %33 : !torch.int
-      } else {
-        torch.prim.If.yield %int1 : !torch.int
-      }
-      %17 = torch.aten.__getitem__.t %arg1, %15 : !torch.list<int>, !torch.int -> !torch.int
-      %18 = torch.aten.sub.int %17, %int1 : !torch.int, !torch.int -> !torch.int
-      %19 = torch.aten.mul.int %16, %18 : !torch.int, !torch.int -> !torch.int
-      %20 = torch.aten.__getitem__.t %arg0, %15 : !torch.list<int>, !torch.int -> !torch.int
-      %21 = torch.aten.sub.int %20, %int1 : !torch.int, !torch.int -> !torch.int
-      %22 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int
-      %23 = torch.aten.__getitem__.t %1, %22 : !torch.list<int>, !torch.int -> !torch.int
-      %24 = torch.aten.mul.int %21, %23 : !torch.int, !torch.int -> !torch.int
-      %25 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int
-      %26 = torch.aten.__getitem__.t %3, %25 : !torch.list<int>, !torch.int -> !torch.int
-      %27 = torch.aten.mul.int %26, %int2 : !torch.int, !torch.int -> !torch.int
-      %28 = torch.aten.sub.int %24, %27 : !torch.int, !torch.int -> !torch.int
-      %29 = torch.aten.add.int %28, %19 : !torch.int, !torch.int -> !torch.int
-      %30 = torch.aten.add.int %29, %int1 : !torch.int, !torch.int -> !torch.int
-      %31 = torch.aten.append.t %9, %30 : !torch.list<int>, !torch.int -> !torch.list<int>
-      torch.prim.Loop.condition %true, iter()
-    } : (!torch.int, !torch.bool) -> ()
-    return %9 : !torch.list<int>
-  }
-  func.func @__torch__.torch.jit._shape_functions.flatten(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list<int> {
-    %none = torch.constant.none
-    %str = torch.constant.str "AssertionError: "
-    %true = torch.constant.bool true
-    %int0 = torch.constant.int 0
-    %int1 = torch.constant.int 1
-    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-    %1 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
-    %2 = torch.prim.If %1 -> (!torch.int) {
-      torch.prim.If.yield %int1 : !torch.int
-    } else {
-      torch.prim.If.yield %0 : !torch.int
-    }
-    %3 = torch.aten.neg.int %2 : !torch.int -> !torch.int
-    %4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int
-    %5 = torch.aten.lt.int %arg1, %3 : !torch.int, !torch.int -> !torch.bool
-    %6 = torch.prim.If %5 -> (!torch.bool) {
-      torch.prim.If.yield %true : !torch.bool
-    } else {
-      %24 = torch.aten.gt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool
-      torch.prim.If.yield %24 : !torch.bool
-    }
-    %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool
-    torch.prim.If %7 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %8 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
-    %9 = torch.prim.If %8 -> (!torch.int) {
-      %24 = torch.aten.add.int %arg1, %2 : !torch.int, !torch.int -> !torch.int
-      torch.prim.If.yield %24 : !torch.int
-    } else {
-      torch.prim.If.yield %arg1 : !torch.int
-    }
-    %10 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-    %11 = torch.aten.le.int %10, %int0 : !torch.int, !torch.int -> !torch.bool
-    %12 = torch.prim.If %11 -> (!torch.int) {
-      torch.prim.If.yield %int1 : !torch.int
-    } else {
-      torch.prim.If.yield %10 : !torch.int
-    }
-    %13 = torch.aten.neg.int %12 : !torch.int -> !torch.int
-    %14 = torch.aten.sub.int %12, %int1 : !torch.int, !torch.int -> !torch.int
-    %15 = torch.aten.lt.int %arg2, %13 : !torch.int, !torch.int -> !torch.bool
-    %16 = torch.prim.If %15 -> (!torch.bool) {
-      torch.prim.If.yield %true : !torch.bool
-    } else {
-      %24 = torch.aten.gt.int %arg2, %14 : !torch.int, !torch.int -> !torch.bool
-      torch.prim.If.yield %24 : !torch.bool
-    }
-    %17 = torch.aten.__not__ %16 : !torch.bool -> !torch.bool
-    torch.prim.If %17 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %18 = torch.aten.lt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool
-    %19 = torch.prim.If %18 -> (!torch.int) {
-      %24 = torch.aten.add.int %arg2, %12 : !torch.int, !torch.int -> !torch.int
-      torch.prim.If.yield %24 : !torch.int
-    } else {
-      torch.prim.If.yield %arg2 : !torch.int
-    }
-    %20 = torch.aten.le.int %9, %19 : !torch.int, !torch.int -> !torch.bool
-    torch.prim.If %20 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %21 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-    %22 = torch.aten.eq.int %21, %int0 : !torch.int, !torch.int -> !torch.bool
-    %23 = torch.prim.If %22 -> (!torch.list<int>) {
-      %24 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
-      torch.prim.If.yield %24 : !torch.list<int>
-    } else {
-      %24 = torch.aten.eq.int %9, %19 : !torch.int, !torch.int -> !torch.bool
-      %25 = torch.prim.If %24 -> (!torch.list<int>) {
-        %26 = torch.prim.ListConstruct : () -> !torch.list<int>
-        %27 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-        torch.prim.Loop %27, %true, init() {
-        ^bb0(%arg3: !torch.int):
-          %28 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int
-          %29 = torch.aten.append.t %26, %28 : !torch.list<int>, !torch.int -> !torch.list<int>
-          torch.prim.Loop.condition %true, iter()
-        } : (!torch.int, !torch.bool) -> ()
-        torch.prim.If.yield %26 : !torch.list<int>
-      } else {
-        %26 = torch.aten.add.int %19, %int1 : !torch.int, !torch.int -> !torch.int
-        %27 = torch.aten.__range_length %9, %26, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-        %28 = torch.prim.Loop %27, %true, init(%int1) {
-        ^bb0(%arg3: !torch.int, %arg4: !torch.int):
-          %34 = torch.aten.__derive_index %arg3, %9, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-          %35 = torch.aten.__getitem__.t %arg0, %34 : !torch.list<int>, !torch.int -> !torch.int
-          %36 = torch.aten.mul.int %arg4, %35 : !torch.int, !torch.int -> !torch.int
-          torch.prim.Loop.condition %true, iter(%36 : !torch.int)
-        } : (!torch.int, !torch.bool, !torch.int) -> !torch.int
-        %29 = torch.prim.ListConstruct : () -> !torch.list<int>
-        torch.prim.Loop %9, %true, init() {
-        ^bb0(%arg3: !torch.int):
-          %34 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int
-          %35 = torch.aten.append.t %29, %34 : !torch.list<int>, !torch.int -> !torch.list<int>
-          torch.prim.Loop.condition %true, iter()
-        } : (!torch.int, !torch.bool) -> ()
-        %30 = torch.aten.append.t %29, %28 : !torch.list<int>, !torch.int -> !torch.list<int>
-        %31 = torch.aten.add.int %19, %int1 : !torch.int, !torch.int -> !torch.int
-        %32 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-        %33 = torch.aten.__range_length %31, %32, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-        torch.prim.Loop %33, %true, init() {
-        ^bb0(%arg3: !torch.int):
-          %34 = torch.aten.__derive_index %arg3, %31, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-          %35 = torch.aten.__getitem__.t %arg0, %34 : !torch.list<int>, !torch.int -> !torch.int
-          %36 = torch.aten.append.t %29, %35 : !torch.list<int>, !torch.int -> !torch.list<int>
-          torch.prim.Loop.condition %true, iter()
-        } : (!torch.int, !torch.bool) -> ()
-        torch.prim.If.yield %29 : !torch.list<int>
-      }
-      torch.prim.If.yield %25 : !torch.list<int>
-    }
-    return %23 : !torch.list<int>
-  }
-  func.func @__torch__.torch.jit._shape_functions.cat(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {
-    %str = torch.constant.str "AssertionError: Sizes of tensors must match except in dimension"
-    %str_0 = torch.constant.str "AssertionError: Tensors must have same number of dimensions"
-    %false = torch.constant.bool false
-    %int1 = torch.constant.int 1
-    %true = torch.constant.bool true
-    %none = torch.constant.none
-    %str_1 = torch.constant.str "AssertionError: "
-    %int0 = torch.constant.int 0
-    %0 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
-    torch.prim.Loop %0, %true, init() {
-    ^bb0(%arg2: !torch.int):
-      %13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
-      %14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
-      %15 = torch.aten.gt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool
-      torch.prim.If %15 -> () {
-        torch.prim.If.yield
-      } else {
-        torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
-        torch.prim.If.yield
-      }
-      torch.prim.Loop.condition %true, iter()
-    } : (!torch.int, !torch.bool) -> ()
-    %1 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
-    %2 = torch.derefine %none : !torch.none to !torch.optional<int>
-    %3 = torch.prim.Loop %1, %true, init(%2) {
-    ^bb0(%arg2: !torch.int, %arg3: !torch.optional<int>):
-      %13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
-      %14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
-      %15 = torch.aten.eq.int %14, %int1 : !torch.int, !torch.int -> !torch.bool
-      %16 = torch.prim.If %15 -> (!torch.bool) {
-        %19 = torch.aten.__getitem__.t %13, %int0 : !torch.list<int>, !torch.int -> !torch.int
-        %20 = torch.aten.eq.int %19, %int0 : !torch.int, !torch.int -> !torch.bool
-        torch.prim.If.yield %20 : !torch.bool
-      } else {
-        torch.prim.If.yield %false : !torch.bool
-      }
-      %17 = torch.aten.__not__ %16 : !torch.bool -> !torch.bool
-      %18 = torch.prim.If %17 -> (!torch.optional<int>) {
-        %19 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool
-        %20 = torch.prim.If %19 -> (!torch.int) {
-          %22 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
-          %23 = torch.aten.le.int %22, %int0 : !torch.int, !torch.int -> !torch.bool
-          %24 = torch.prim.If %23 -> (!torch.int) {
-            torch.prim.If.yield %int1 : !torch.int
-          } else {
-            torch.prim.If.yield %22 : !torch.int
-          }
-          %25 = torch.aten.neg.int %24 : !torch.int -> !torch.int
-          %26 = torch.aten.sub.int %24, %int1 : !torch.int, !torch.int -> !torch.int
-          %27 = torch.aten.lt.int %arg1, %25 : !torch.int, !torch.int -> !torch.bool
-          %28 = torch.prim.If %27 -> (!torch.bool) {
-            torch.prim.If.yield %true : !torch.bool
-          } else {
-            %32 = torch.aten.gt.int %arg1, %26 : !torch.int, !torch.int -> !torch.bool
-            torch.prim.If.yield %32 : !torch.bool
-          }
-          %29 = torch.aten.__not__ %28 : !torch.bool -> !torch.bool
-          torch.prim.If %29 -> () {
-            torch.prim.If.yield
-          } else {
-            torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
-            torch.prim.If.yield
-          }
-          %30 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool
-          %31 = torch.prim.If %30 -> (!torch.int) {
-            %32 = torch.aten.add.int %arg1, %24 : !torch.int, !torch.int -> !torch.int
-            torch.prim.If.yield %32 : !torch.int
-          } else {
-            torch.prim.If.yield %arg1 : !torch.int
-          }
-          torch.prim.If.yield %31 : !torch.int
-        } else {
-          %22 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int
-          torch.prim.If.yield %22 : !torch.int
-        }
-        %21 = torch.derefine %20 : !torch.int to !torch.optional<int>
-        torch.prim.If.yield %21 : !torch.optional<int>
-      } else {
-        torch.prim.If.yield %arg3 : !torch.optional<int>
-      }
-      torch.prim.Loop.condition %true, iter(%18 : !torch.optional<int>)
-    } : (!torch.int, !torch.bool, !torch.optional<int>) -> !torch.optional<int>
-    %4 = torch.aten.__is__ %3, %none : !torch.optional<int>, !torch.none -> !torch.bool
-    %5 = torch.prim.If %4 -> (!torch.int) {
-      torch.prim.If.yield %arg1 : !torch.int
-    } else {
-      %13 = torch.prim.unchecked_cast %3 : !torch.optional<int> -> !torch.int
-      torch.prim.If.yield %13 : !torch.int
-    }
-    %6 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
-    %7 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
-    torch.prim.If %7 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %8 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
-    %9 = torch.derefine %none : !torch.none to !torch.optional<list<int>>
-    %10 = torch.prim.Loop %8, %true, init(%9) {
-    ^bb0(%arg2: !torch.int, %arg3: !torch.optional<list<int>>):
-      %13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
-      %14 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
-      %15 = torch.prim.Loop %14, %true, init(%int1) {
-      ^bb0(%arg4: !torch.int, %arg5: !torch.int):
-        %20 = torch.aten.__getitem__.t %13, %arg4 : !torch.list<int>, !torch.int -> !torch.int
-        %21 = torch.aten.mul.int %arg5, %20 : !torch.int, !torch.int -> !torch.int
-        torch.prim.Loop.condition %true, iter(%21 : !torch.int)
-      } : (!torch.int, !torch.bool, !torch.int) -> !torch.int
-      %16 = torch.aten.eq.int %15, %int0 : !torch.int, !torch.int -> !torch.bool
-      %17 = torch.prim.If %16 -> (!torch.bool) {
-        %20 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
-        %21 = torch.aten.eq.int %20, %int1 : !torch.int, !torch.int -> !torch.bool
-        torch.prim.If.yield %21 : !torch.bool
-      } else {
-        torch.prim.If.yield %false : !torch.bool
-      }
-      %18 = torch.aten.__not__ %17 : !torch.bool -> !torch.bool
-      %19 = torch.prim.If %18 -> (!torch.optional<list<int>>) {
-        %20 = torch.derefine %13 : !torch.list<int> to !torch.optional<list<int>>
-        torch.prim.If.yield %20 : !torch.optional<list<int>>
-      } else {
-        torch.prim.If.yield %arg3 : !torch.optional<list<int>>
-      }
-      torch.prim.Loop.condition %true, iter(%19 : !torch.optional<list<int>>)
-    } : (!torch.int, !torch.bool, !torch.optional<list<int>>) -> !torch.optional<list<int>>
-    %11 = torch.aten.__is__ %10, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
-    %12 = torch.prim.If %11 -> (!torch.list<int>) {
-      %13 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
-      torch.prim.If.yield %13 : !torch.list<int>
-    } else {
-      %13 = torch.prim.unchecked_cast %10 : !torch.optional<list<int>> -> !torch.list<int>
-      %14 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
-      %15 = torch.prim.Loop %14, %true, init(%int0) {
-      ^bb0(%arg2: !torch.int, %arg3: !torch.int):
-        %19 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
-        %20 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
-        %21 = torch.prim.Loop %20, %true, init(%int1) {
-        ^bb0(%arg4: !torch.int, %arg5: !torch.int):
-          %26 = torch.aten.__getitem__.t %19, %arg4 : !torch.list<int>, !torch.int -> !torch.int
-          %27 = torch.aten.mul.int %arg5, %26 : !torch.int, !torch.int -> !torch.int
-          torch.prim.Loop.condition %true, iter(%27 : !torch.int)
-        } : (!torch.int, !torch.bool, !torch.int) -> !torch.int
-        %22 = torch.aten.eq.int %21, %int0 : !torch.int, !torch.int -> !torch.bool
-        %23 = torch.prim.If %22 -> (!torch.bool) {
-          %26 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
-          %27 = torch.aten.eq.int %26, %int1 : !torch.int, !torch.int -> !torch.bool
-          torch.prim.If.yield %27 : !torch.bool
-        } else {
-          torch.prim.If.yield %false : !torch.bool
-        }
-        %24 = torch.aten.__not__ %23 : !torch.bool -> !torch.bool
-        %25 = torch.prim.If %24 -> (!torch.int) {
-          %26 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
-          %27 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int
-          %28 = torch.aten.eq.int %26, %27 : !torch.int, !torch.int -> !torch.bool
-          torch.prim.If %28 -> () {
-            torch.prim.If.yield
-          } else {
-            torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
-            torch.prim.If.yield
-          }
-          %29 = torch.aten.__range_length %int0, %26, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-          torch.prim.Loop %29, %true, init() {
-          ^bb0(%arg4: !torch.int):
-            %32 = torch.aten.__derive_index %arg4, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-            %33 = torch.aten.ne.int %32, %5 : !torch.int, !torch.int -> !torch.bool
-            torch.prim.If %33 -> () {
-              %34 = torch.aten.__getitem__.t %13, %32 : !torch.list<int>, !torch.int -> !torch.int
-              %35 = torch.aten.__getitem__.t %19, %32 : !torch.list<int>, !torch.int -> !torch.int
-              %36 = torch.aten.eq.int %34, %35 : !torch.int, !torch.int -> !torch.bool
-              torch.prim.If %36 -> () {
-                torch.prim.If.yield
-              } else {
-                torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-                torch.prim.If.yield
-              }
-              torch.prim.If.yield
-            } else {
-              torch.prim.If.yield
-            }
-            torch.prim.Loop.condition %true, iter()
-          } : (!torch.int, !torch.bool) -> ()
-          %30 = torch.aten.__getitem__.t %19, %5 : !torch.list<int>, !torch.int -> !torch.int
-          %31 = torch.aten.add.int %arg3, %30 : !torch.int, !torch.int -> !torch.int
-          torch.prim.If.yield %31 : !torch.int
-        } else {
-          torch.prim.If.yield %arg3 : !torch.int
-        }
-        torch.prim.Loop.condition %true, iter(%25 : !torch.int)
-      } : (!torch.int, !torch.bool, !torch.int) -> !torch.int
-      %16 = torch.prim.ListConstruct : () -> !torch.list<int>
-      %17 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int
-      torch.prim.Loop %17, %true, init() {
-      ^bb0(%arg2: !torch.int):
-        %19 = torch.aten.__getitem__.t %13, %arg2 : !torch.list<int>, !torch.int -> !torch.int
-        %20 = torch.aten.append.t %16, %19 : !torch.list<int>, !torch.int -> !torch.list<int>
-        torch.prim.Loop.condition %true, iter()
-      } : (!torch.int, !torch.bool) -> ()
-      %18 = torch.aten._set_item.t %16, %5, %15 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
-      torch.prim.If.yield %16 : !torch.list<int>
-    }
-    return %12 : !torch.list<int>
-  }
-  func.func @__torch__.torch.jit._shape_functions.check_cat_no_zero_dim(%arg0: !torch.list<list<int>>) -> !torch.none {
-    %none = torch.constant.none
-    %str = torch.constant.str "AssertionError: "
-    %true = torch.constant.bool true
-    %int0 = torch.constant.int 0
-    %0 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int
-    torch.prim.Loop %0, %true, init() {
-    ^bb0(%arg1: !torch.int):
-      %1 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
-      %2 = torch.aten.len.t %1 : !torch.list<int> -> !torch.int
-      %3 = torch.aten.gt.int %2, %int0 : !torch.int, !torch.int -> !torch.bool
-      torch.prim.If %3 -> () {
-        torch.prim.If.yield
-      } else {
-        torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-        torch.prim.If.yield
-      }
-      torch.prim.Loop.condition %true, iter()
-    } : (!torch.int, !torch.bool) -> ()
-    return %none : !torch.none
-  }
-  func.func @__torch__.torch.jit._shape_functions.legacy_cat_wrap_dim(%arg0: !torch.int, %arg1: !torch.list<list<int>>) -> !torch.int {
-    %false = torch.constant.bool false
-    %true = torch.constant.bool true
-    %none = torch.constant.none
-    %int1 = torch.constant.int 1
-    %int0 = torch.constant.int 0
-    %0 = torch.aten.len.t %arg1 : !torch.list<list<int>> -> !torch.int
-    %1 = torch.derefine %none : !torch.none to !torch.optional<int>
-    %2 = torch.prim.Loop %0, %true, init(%1) {
-    ^bb0(%arg2: !torch.int, %arg3: !torch.optional<int>):
-      %5 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
-      %6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
-      %7 = torch.aten.eq.int %6, %int1 : !torch.int, !torch.int -> !torch.bool
-      %8 = torch.prim.If %7 -> (!torch.bool) {
-        %11 = torch.aten.__getitem__.t %5, %int0 : !torch.list<int>, !torch.int -> !torch.int
-        %12 = torch.aten.eq.int %11, %int0 : !torch.int, !torch.int -> !torch.bool
-        torch.prim.If.yield %12 : !torch.bool
-      } else {
-        torch.prim.If.yield %false : !torch.bool
-      }
-      %9 = torch.aten.__not__ %8 : !torch.bool -> !torch.bool
-      %10 = torch.prim.If %9 -> (!torch.optional<int>) {
-        %11 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool
-        %12 = torch.prim.If %11 -> (!torch.int) {
-          %14 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
-          %15 = func.call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg0, %14, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int
-          torch.prim.If.yield %15 : !torch.int
-        } else {
-          %14 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int
-          torch.prim.If.yield %14 : !torch.int
-        }
-        %13 = torch.derefine %12 : !torch.int to !torch.optional<int>
-        torch.prim.If.yield %13 : !torch.optional<int>
-      } else {
-        torch.prim.If.yield %arg3 : !torch.optional<int>
-      }
-      torch.prim.Loop.condition %true, iter(%10 : !torch.optional<int>)
-    } : (!torch.int, !torch.bool, !torch.optional<int>) -> !torch.optional<int>
-    %3 = torch.aten.__is__ %2, %none : !torch.optional<int>, !torch.none -> !torch.bool
-    %4 = torch.prim.If %3 -> (!torch.int) {
-      torch.prim.If.yield %arg0 : !torch.int
-    } else {
-      %5 = torch.prim.unchecked_cast %2 : !torch.optional<int> -> !torch.int
-      torch.prim.If.yield %5 : !torch.int
-    }
-    return %4 : !torch.int
-  }
-  func.func @__torch__.torch.jit._shape_functions.should_skip(%arg0: !torch.list<int>) -> !torch.bool {
-    %false = torch.constant.bool false
-    %int0 = torch.constant.int 0
-    %int1 = torch.constant.int 1
-    %0 = call @__torch__.torch.jit._shape_functions.numel(%arg0) : (!torch.list<int>) -> !torch.int
-    %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
-    %2 = torch.prim.If %1 -> (!torch.bool) {
-      %3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-      %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool
-      torch.prim.If.yield %4 : !torch.bool
-    } else {
-      torch.prim.If.yield %false : !torch.bool
-    }
-    return %2 : !torch.bool
-  }
-  func.func @__torch__.torch.jit._shape_functions.numel(%arg0: !torch.list<int>) -> !torch.int {
-    %true = torch.constant.bool true
-    %int1 = torch.constant.int 1
-    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-    %1 = torch.prim.Loop %0, %true, init(%int1) {
-    ^bb0(%arg1: !torch.int, %arg2: !torch.int):
-      %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<int>, !torch.int -> !torch.int
-      %3 = torch.aten.mul.int %arg2, %2 : !torch.int, !torch.int -> !torch.int
-      torch.prim.Loop.condition %true, iter(%3 : !torch.int)
-    } : (!torch.int, !torch.bool, !torch.int) -> !torch.int
-    return %1 : !torch.int
-  }
-  func.func @__torch__.torch.jit._shape_functions.check_cat_shape_except_dim(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.none {
-    %str = torch.constant.str "AssertionError: Sizes of tensors must match except in dimension"
-    %true = torch.constant.bool true
-    %int1 = torch.constant.int 1
-    %none = torch.constant.none
-    %str_0 = torch.constant.str "AssertionError: Tensors must have same number of dimensions"
-    %int0 = torch.constant.int 0
-    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-    %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
-    %2 = torch.aten.eq.int %0, %1 : !torch.int, !torch.int -> !torch.bool
-    torch.prim.If %2 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %3 = torch.aten.__range_length %int0, %0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-    torch.prim.Loop %3, %true, init() {
-    ^bb0(%arg4: !torch.int):
-      %4 = torch.aten.__derive_index %arg4, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-      %5 = torch.aten.ne.int %4, %arg2 : !torch.int, !torch.int -> !torch.bool
-      torch.prim.If %5 -> () {
-        %6 = torch.aten.__getitem__.t %arg0, %4 : !torch.list<int>, !torch.int -> !torch.int
-        %7 = torch.aten.__getitem__.t %arg1, %4 : !torch.list<int>, !torch.int -> !torch.int
-        %8 = torch.aten.eq.int %6, %7 : !torch.int, !torch.int -> !torch.bool
-        torch.prim.If %8 -> () {
-          torch.prim.If.yield
-        } else {
-          torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-          torch.prim.If.yield
-        }
-        torch.prim.If.yield
-      } else {
-        torch.prim.If.yield
-      }
-      torch.prim.Loop.condition %true, iter()
-    } : (!torch.int, !torch.bool) -> ()
-    return %none : !torch.none
-  }
-  func.func @__torch__.torch.jit._shape_functions.permute(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
-    %int0 = torch.constant.int 0
-    %true = torch.constant.bool true
-    %none = torch.constant.none
-    %str = torch.constant.str "AssertionError: "
-    %int1 = torch.constant.int 1
-    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-    %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
-    %2 = torch.aten.eq.int %0, %1 : !torch.int, !torch.int -> !torch.bool
-    torch.prim.If %2 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %3 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
-    %4 = torch.prim.ListConstruct : () -> !torch.list<int>
-    %5 = torch.prim.ListConstruct : () -> !torch.list<int>
-    torch.prim.Loop %3, %true, init() {
-    ^bb0(%arg2: !torch.int):
-      %7 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int
-      %8 = torch.aten.le.int %3, %int0 : !torch.int, !torch.int -> !torch.bool
-      %9 = torch.prim.If %8 -> (!torch.int) {
-        torch.prim.If.yield %int1 : !torch.int
-      } else {
-        torch.prim.If.yield %3 : !torch.int
-      }
-      %10 = torch.aten.neg.int %9 : !torch.int -> !torch.int
-      %11 = torch.aten.sub.int %9, %int1 : !torch.int, !torch.int -> !torch.int
-      %12 = torch.aten.lt.int %7, %10 : !torch.int, !torch.int -> !torch.bool
-      %13 = torch.prim.If %12 -> (!torch.bool) {
-        torch.prim.If.yield %true : !torch.bool
-      } else {
-        %20 = torch.aten.gt.int %7, %11 : !torch.int, !torch.int -> !torch.bool
-        torch.prim.If.yield %20 : !torch.bool
-      }
-      %14 = torch.aten.__not__ %13 : !torch.bool -> !torch.bool
-      torch.prim.If %14 -> () {
-        torch.prim.If.yield
-      } else {
-        torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-        torch.prim.If.yield
-      }
-      %15 = torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool
-      %16 = torch.prim.If %15 -> (!torch.int) {
-        %20 = torch.aten.add.int %7, %9 : !torch.int, !torch.int -> !torch.int
-        torch.prim.If.yield %20 : !torch.int
-      } else {
-        torch.prim.If.yield %7 : !torch.int
-      }
-      %17 = torch.aten.append.t %4, %16 : !torch.list<int>, !torch.int -> !torch.list<int>
-      %18 = torch.aten.__getitem__.t %arg0, %16 : !torch.list<int>, !torch.int -> !torch.int
-      %19 = torch.aten.append.t %5, %18 : !torch.list<int>, !torch.int -> !torch.list<int>
-      torch.prim.Loop.condition %true, iter()
-    } : (!torch.int, !torch.bool) -> ()
-    %6 = torch.aten.__range_length %int1, %3, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-    torch.prim.Loop %6, %true, init() {
-    ^bb0(%arg2: !torch.int):
-      %7 = torch.aten.__derive_index %arg2, %int1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-      torch.prim.Loop %7, %true, init() {
-      ^bb0(%arg3: !torch.int):
-        %8 = torch.aten.__getitem__.t %4, %7 : !torch.list<int>, !torch.int -> !torch.int
-        %9 = torch.aten.__getitem__.t %4, %arg3 : !torch.list<int>, !torch.int -> !torch.int
-        %10 = torch.aten.ne.int %8, %9 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %10 -> () {
-          torch.prim.If.yield
-        } else {
-          torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-          torch.prim.If.yield
-        }
-        torch.prim.Loop.condition %true, iter()
-      } : (!torch.int, !torch.bool) -> ()
-      torch.prim.Loop.condition %true, iter()
-    } : (!torch.int, !torch.bool) -> ()
-    return %5 : !torch.list
-  }
-  func.func @__torch__.torch.jit._shape_functions.view(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {
-    %str = torch.constant.str "AssertionError: invalid shape"
-    %false = torch.constant.bool false
-    %str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
-    %str_1 = torch.constant.str "AssertionError: only one dimension can be inferred"
-    %int-1 = torch.constant.int -1
-    %none = torch.constant.none
-    %int0 = torch.constant.int 0
-    %int1 = torch.constant.int 1
-    %true = torch.constant.bool true
-    %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int
-    %1 = torch.prim.Loop %0, %true, init(%int1) {
-    ^bb0(%arg2: !torch.int, %arg3: !torch.int):
-      %12 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int
-      %13 = torch.aten.mul.int %arg3, %12 : !torch.int, !torch.int -> !torch.int
-      torch.prim.Loop.condition %true, iter(%13 : !torch.int)
-    } : (!torch.int, !torch.bool, !torch.int) -> !torch.int
-    %2 = torch.prim.Uninitialized : !torch.int
-    %3 = torch.aten.len.t %arg1 : !torch.list -> !torch.int
-    %4 = torch.derefine %none : !torch.none to !torch.optional
-    %5:2 = torch.prim.Loop %3, %true, init(%int1, %4) {
-    ^bb0(%arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.optional):
-      %12 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int
-      %13 = torch.aten.eq.int %12, %int-1 : !torch.int, !torch.int -> !torch.bool
-      %14:2 = torch.prim.If %13 -> (!torch.int, !torch.optional) {
-        %15 = torch.aten.__isnot__ %arg4, %none : !torch.optional, !torch.none -> !torch.bool
-        torch.prim.If %15 -> () {
-          torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
-          torch.prim.If.yield
-        } else {
-          torch.prim.If.yield
-        }
-        %16 = torch.derefine %arg2 : !torch.int to !torch.optional
-        torch.prim.If.yield %arg3, %16 : !torch.int, !torch.optional
-      } else {
-        %15 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int
-        %16 = torch.aten.ge.int %15, %int0 : !torch.int, !torch.int -> !torch.bool
-        %17 = torch.prim.If %16 -> (!torch.int) {
-          %18 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int
-          %19 = torch.aten.mul.int %arg3, %18 : !torch.int, !torch.int -> !torch.int
-          torch.prim.If.yield %19 : !torch.int
-        } else {
-          torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
-          torch.prim.If.yield %2 : !torch.int
-        }
-        torch.prim.If.yield %17, %arg4 : !torch.int, !torch.optional
-      }
-      torch.prim.Loop.condition %true, iter(%14#0, %14#1 : !torch.int, !torch.optional)
-    } : (!torch.int, !torch.bool, !torch.int, !torch.optional) -> (!torch.int, !torch.optional)
-    %6 = torch.aten.eq.int %1, %5#0 : !torch.int, !torch.int -> !torch.bool
-    %7 = torch.prim.If %6 -> (!torch.bool) {
-      torch.prim.If.yield %true : !torch.bool
-    } else {
-      %12 = torch.aten.__isnot__ %5#1, %none : !torch.optional, !torch.none -> !torch.bool
-      %13 = torch.prim.If %12 -> (!torch.bool) {
-        %15 = torch.prim.unchecked_cast %5#1 : !torch.optional -> !torch.int
-        %16 = torch.aten.gt.int %5#0, %int0 : !torch.int, !torch.int -> !torch.bool
-        torch.prim.If.yield %16 : !torch.bool
-      } else {
-        torch.prim.If.yield %false : !torch.bool
-      }
-      %14 = torch.prim.If %13 -> (!torch.bool) {
-        %15 = torch.prim.unchecked_cast %5#1 : !torch.optional -> !torch.int
-        %16 = torch.aten.remainder.int %1, %5#0 : !torch.int, !torch.int -> !torch.int
-        %17 = torch.aten.eq.int %16, %int0 : !torch.int, !torch.int -> !torch.bool
-        torch.prim.If.yield %17 : !torch.bool
-      } else {
-        torch.prim.If.yield %false : !torch.bool
-      }
-      torch.prim.If.yield %14 : !torch.bool
-    }
-    %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool
-    torch.prim.If %8 -> () {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    } else {
-      torch.prim.If.yield
-    }
-    %9 = torch.prim.ListConstruct : () -> !torch.list
-    %10 = torch.aten.len.t %arg1 : !torch.list -> !torch.int
-    torch.prim.Loop %10, %true, init() {
-    ^bb0(%arg2: !torch.int):
-      %12 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int
-      %13 = torch.aten.append.t %9, %12 : !torch.list, !torch.int -> !torch.list
-      torch.prim.Loop.condition %true, iter()
-    } : (!torch.int, !torch.bool) -> ()
-    %11 = torch.aten.__isnot__ %5#1, %none : !torch.optional, !torch.none -> !torch.bool
-    torch.prim.If %11 -> () {
-      %12 = torch.prim.unchecked_cast %5#1 : !torch.optional -> !torch.int
-      %13 = torch.aten.floordiv.int %1, %5#0 : !torch.int, !torch.int -> !torch.int
-      %14 = torch.aten._set_item.t %9, %12, %13 : !torch.list, !torch.int, !torch.int -> !torch.list
-      torch.prim.If.yield
-    } else {
-      torch.prim.If.yield
-    }
-    return %9 : !torch.list
-  }
-  func.func @__torch__.torch.jit._shape_functions.infer_size_impl(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {
-    %str = torch.constant.str "AssertionError: invalid shape"
-    %false = torch.constant.bool false
-    %str_0 = torch.constant.str "AssertionError: invalid shape dimensions"
-    %str_1 = torch.constant.str "AssertionError: only one dimension can be inferred"
-    %int-1 = torch.constant.int -1
-    %true = torch.constant.bool true
-    %none = torch.constant.none
-    %int1 = torch.constant.int 1
-    %int0 = torch.constant.int 0
-    %0 = torch.prim.Uninitialized : !torch.int
-    %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int
-    %2 = torch.derefine %none : !torch.none to !torch.optional
-    %3:2 = torch.prim.Loop %1, %true, init(%int1, %2) {
-    ^bb0(%arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.optional):
-      %9 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int
-      %10 = torch.aten.eq.int %9, %int-1 : !torch.int, !torch.int -> !torch.bool
-      %11:2 = torch.prim.If %10 -> (!torch.int, !torch.optional) {
-        %12 = torch.aten.__isnot__ %arg4, %none : !torch.optional, !torch.none -> !torch.bool
-        torch.prim.If %12 -> () {
-          torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none
-          torch.prim.If.yield
-        } else {
-          torch.prim.If.yield
-        }
-        %13 = torch.derefine %arg2 : !torch.int to !torch.optional
-        torch.prim.If.yield %arg3, %13 : !torch.int, !torch.optional
-      } else {
-        %12 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int
-        %13 = torch.aten.ge.int %12, %int0 : !torch.int, !torch.int -> !torch.bool
-        %14 = torch.prim.If %13 -> (!torch.int) {
-          %15 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int
-          %16 = torch.aten.mul.int %arg3, %15 : !torch.int, !torch.int -> !torch.int
-          torch.prim.If.yield %16 : !torch.int
-        } else {
-          torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none
-          torch.prim.If.yield %0 : !torch.int
-        }
-        torch.prim.If.yield %14, %arg4 : !torch.int, !torch.optional
-      }
-      torch.prim.Loop.condition %true, iter(%11#0, %11#1 : !torch.int, !torch.optional)
-    } : (!torch.int, !torch.bool, !torch.int, !torch.optional) -> (!torch.int, !torch.optional)
-    %4 = torch.aten.eq.int %arg1, %3#0 : !torch.int, !torch.int -> !torch.bool
-    %5 = torch.prim.If %4 -> (!torch.bool) {
-      torch.prim.If.yield %true : !torch.bool
-    } else {
-      %9 = torch.aten.__isnot__ %3#1, %none : !torch.optional, !torch.none -> !torch.bool
-      %10 = torch.prim.If %9 -> (!torch.bool) {
-        %12 = torch.prim.unchecked_cast %3#1 : !torch.optional -> !torch.int
-        %13 = torch.aten.gt.int %3#0, %int0 : !torch.int, !torch.int -> !torch.bool
-        torch.prim.If.yield %13 : !torch.bool
-      } else {
-        torch.prim.If.yield %false : !torch.bool
-      }
-      %11 = torch.prim.If %10 -> (!torch.bool) {
-        %12 = torch.prim.unchecked_cast %3#1 : !torch.optional -> !torch.int
-        %13 = torch.aten.remainder.int %arg1, %3#0 : !torch.int, !torch.int -> !torch.int
-        %14 = torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool
-        torch.prim.If.yield %14 : !torch.bool
-      } else {
-        torch.prim.If.yield %false : !torch.bool
-      }
-      torch.prim.If.yield %11 : !torch.bool
-    }
-    %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool
-    torch.prim.If %6 -> () {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    } else {
-      torch.prim.If.yield
-    }
-    %7 = call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list) -> !torch.list
-    %8 = torch.aten.__isnot__ %3#1, %none : !torch.optional, !torch.none -> !torch.bool
-    torch.prim.If %8 -> () {
-      %9 = torch.prim.unchecked_cast %3#1 : !torch.optional -> !torch.int
-      %10 = torch.aten.floordiv.int %arg1, %3#0 : !torch.int, !torch.int -> !torch.int
-      %11 = torch.aten._set_item.t %7, %9, %10 : !torch.list, !torch.int, !torch.int -> !torch.list
-      torch.prim.If.yield
-    } else {
-      torch.prim.If.yield
-    }
-    return %7 : !torch.list
-  }
-  func.func @__torch__.torch.jit._shape_functions.expand(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {
-    %int-1 = torch.constant.int -1
-    %true = torch.constant.bool true
-    %none = torch.constant.none
-    %str = torch.constant.str "AssertionError: "
-    %int0 = torch.constant.int 0
-    %int1 = torch.constant.int 1
-    %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int
-    %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int
-    %2 = torch.aten.ge.int %0, %1 : !torch.int, !torch.int -> !torch.bool
-    torch.prim.If %2 -> () {
-      torch.prim.If.yield
-    } else {
-      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
-      torch.prim.If.yield
-    }
-    %3 = torch.aten.len.t %arg1 : !torch.list -> !torch.int
-    %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int
-    %5 = torch.aten.eq.int %3, %int0 : !torch.int, !torch.int -> !torch.bool
-    %6 = torch.prim.If %5 -> (!torch.list) {
-      %7 = torch.prim.ListConstruct : () -> !torch.list
-      %8 = torch.aten.len.t %arg1 : !torch.list -> !torch.int
-      torch.prim.Loop %8, %true, init() {
-      ^bb0(%arg2: !torch.int):
-        %9 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int
-        %10 = torch.aten.append.t %7, %9 : !torch.list, !torch.int -> !torch.list
-        torch.prim.Loop.condition %true, iter()
-      } : (!torch.int, !torch.bool) -> ()
-      torch.prim.If.yield %7 : !torch.list
-    } else {
-      %7 = torch.prim.ListConstruct : () -> !torch.list
-      torch.prim.Loop %3, %true, init() {
-      ^bb0(%arg2: !torch.int):
-        %8 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int
-        %9 = torch.aten.sub.int %8, %arg2 : !torch.int, !torch.int -> !torch.int
-        %10 = torch.aten.sub.int %4, %int1 : !torch.int, !torch.int -> !torch.int
-
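For orientation between the deleted definitions: view and infer_size_impl above are the TorchScript-compiled form of reshape size inference — multiply out the requested sizes, allow at most one -1, and require the element count to either match or divide evenly. A minimal Python sketch of the same algorithm, paraphrased from the torch.jit._shape_functions logic this MLIR is generated from (the function and variable names here are mine, not upstream's):

from typing import List, Optional

def infer_size(target: List[int], numel: int) -> List[int]:
    # Product of the known sizes; remember the position of a single -1.
    newsize = 1
    infer_dim: Optional[int] = None
    for i, d in enumerate(target):
        if d == -1:
            assert infer_dim is None, "only one dimension can be inferred"
            infer_dim = i
        elif d >= 0:
            newsize *= d
        else:
            raise AssertionError("invalid shape dimensions")
    # Either the counts match exactly, or the inferred dim can absorb
    # the remainder evenly.
    assert numel == newsize or (
        infer_dim is not None and newsize > 0 and numel % newsize == 0
    ), "invalid shape"
    out = list(target)
    if infer_dim is not None:
        out[infer_dim] = numel // newsize
    return out

For example, infer_size([2, -1], 6) returns [2, 3], while infer_size([7], 6) trips the "invalid shape" assert. The expand function, whose body continues below, applies the separate rule of aten.expand, where -1 means "keep the input's size at that position".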
%11 = torch.aten.sub.int %10, %9 : !torch.int, !torch.int -> !torch.int - %12 = torch.aten.ge.int %11, %int0 : !torch.int, !torch.int -> !torch.bool - %13 = torch.prim.If %12 -> (!torch.int) { - %20 = torch.aten.__getitem__.t %arg0, %11 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %20 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %14 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int - %15 = torch.aten.eq.int %14, %int-1 : !torch.int, !torch.int -> !torch.bool - %16 = torch.prim.If %15 -> (!torch.int) { - %20 = torch.aten.ge.int %11, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %20 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - torch.prim.If.yield %13 : !torch.int - } else { - torch.prim.If.yield %14 : !torch.int - } - %17 = torch.aten.ne.int %13, %16 : !torch.int, !torch.int -> !torch.bool - %18 = torch.prim.If %17 -> (!torch.int) { - %20 = torch.aten.eq.int %13, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %20 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - torch.prim.If.yield %16 : !torch.int - } else { - torch.prim.If.yield %13 : !torch.int - } - %19 = torch.aten.append.t %7, %18 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %7 : !torch.list - } - return %6 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.expand_one_unused(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.any) -> !torch.list { - %int1 = torch.constant.int 1 - %int0 = torch.constant.int 0 - %str = torch.constant.str "AssertionError: " - %none = torch.constant.none - %true = torch.constant.bool true - %int-1 = torch.constant.int -1 - %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %2 = torch.aten.ge.int %0, %1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %2 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %3 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %5 = torch.aten.eq.int %3, %int0 : !torch.int, !torch.int -> !torch.bool - %6 = torch.prim.If %5 -> (!torch.list) { - %7 = torch.prim.ListConstruct : () -> !torch.list - %8 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - torch.prim.Loop %8, %true, init() { - ^bb0(%arg3: !torch.int): - %9 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list, !torch.int -> !torch.int - %10 = torch.aten.append.t %7, %9 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %7 : !torch.list - } else { - %7 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %3, %true, init() { - ^bb0(%arg3: !torch.int): - %8 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int - %9 = torch.aten.sub.int %8, %arg3 : !torch.int, !torch.int -> !torch.int - %10 = torch.aten.sub.int %4, %int1 : !torch.int, !torch.int -> !torch.int - %11 = torch.aten.sub.int %10, %9 : !torch.int, !torch.int -> !torch.int - %12 = torch.aten.ge.int %11, %int0 : !torch.int, !torch.int -> !torch.bool - %13 = torch.prim.If %12 -> (!torch.int) { - %20 = torch.aten.__getitem__.t %arg0, %11 : 
!torch.list, !torch.int -> !torch.int - torch.prim.If.yield %20 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %14 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list, !torch.int -> !torch.int - %15 = torch.aten.eq.int %14, %int-1 : !torch.int, !torch.int -> !torch.bool - %16 = torch.prim.If %15 -> (!torch.int) { - %20 = torch.aten.ge.int %11, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %20 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - torch.prim.If.yield %13 : !torch.int - } else { - torch.prim.If.yield %14 : !torch.int - } - %17 = torch.aten.ne.int %13, %16 : !torch.int, !torch.int -> !torch.bool - %18 = torch.prim.If %17 -> (!torch.int) { - %20 = torch.aten.eq.int %13, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %20 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - torch.prim.If.yield %16 : !torch.int - } else { - torch.prim.If.yield %13 : !torch.int - } - %19 = torch.aten.append.t %7, %18 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %7 : !torch.list - } - return %6 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.any) -> !torch.list { - %str = torch.constant.str "AssertionError: " - %false = torch.constant.bool false - %true = torch.constant.bool true - %none = torch.constant.none - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %0 = torch.prim.ListConstruct : () -> !torch.list - %1 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool - %2 = torch.prim.If %1 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - %6 = torch.aten.len.t %5 : !torch.list -> !torch.int - %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %7 : !torch.bool - } - %3 = torch.prim.If %2 -> (!torch.list) { - %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %6 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %5, %true, init() { - ^bb0(%arg4: !torch.int): - %7 = torch.aten.append.t %6, %arg4 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %6 : !torch.list - } else { - %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - torch.prim.If.yield %5 : !torch.list - } - %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %4, %true, init() { - ^bb0(%arg4: !torch.int): - %5 = torch.aten.len.t %3 : !torch.list -> !torch.int - %6 = torch.prim.Loop %5, %true, init(%false) { - ^bb0(%arg5: !torch.int, %arg6: !torch.bool): - %7 = torch.aten.__getitem__.t %3, %arg5 : !torch.list, !torch.int -> !torch.int - %8 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %9 = torch.aten.le.int %8, %int0 : !torch.int, !torch.int -> !torch.bool - %10 = torch.prim.If %9 -> (!torch.int) { - torch.prim.If.yield %int1 : !torch.int - } else { - torch.prim.If.yield %8 : !torch.int - } - %11 = torch.aten.neg.int %10 : !torch.int -> !torch.int - %12 = torch.aten.sub.int %10, %int1 : !torch.int, !torch.int -> !torch.int - %13 = torch.aten.lt.int %7, %11 : !torch.int, !torch.int -> 
!torch.bool - %14 = torch.prim.If %13 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %20 = torch.aten.gt.int %7, %12 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %20 : !torch.bool - } - %15 = torch.aten.__not__ %14 : !torch.bool -> !torch.bool - torch.prim.If %15 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %16 = torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - %17 = torch.prim.If %16 -> (!torch.int) { - %20 = torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %20 : !torch.int - } else { - torch.prim.If.yield %7 : !torch.int - } - %18 = torch.aten.eq.int %arg4, %17 : !torch.int, !torch.int -> !torch.bool - %19 = torch.prim.If %18 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - torch.prim.If.yield %arg6 : !torch.bool - } - torch.prim.Loop.condition %true, iter(%19 : !torch.bool) - } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool - torch.prim.If %6 -> () { - torch.prim.If %arg2 -> () { - %7 = torch.aten.append.t %0, %int1 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } else { - torch.prim.If.yield - } - torch.prim.If.yield - } else { - %7 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list, !torch.int -> !torch.int - %8 = torch.aten.append.t %0, %7 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - return %0 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.max_dim(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple, list> { - %false = torch.constant.bool false - %true = torch.constant.bool true - %none = torch.constant.none - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %str = torch.constant.str "AssertionError: " - %0 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list - %1 = torch.prim.ListConstruct : () -> !torch.list - %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %2, %true, init() { - ^bb0(%arg3: !torch.int): - %4 = torch.prim.Loop %int1, %true, init(%false) { - ^bb0(%arg4: !torch.int, %arg5: !torch.bool): - %5 = torch.aten.__getitem__.t %0, %arg4 : !torch.list, !torch.int -> !torch.int - %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %7 = torch.aten.le.int %6, %int0 : !torch.int, !torch.int -> !torch.bool - %8 = torch.prim.If %7 -> (!torch.int) { - torch.prim.If.yield %int1 : !torch.int - } else { - torch.prim.If.yield %6 : !torch.int - } - %9 = torch.aten.neg.int %8 : !torch.int -> !torch.int - %10 = torch.aten.sub.int %8, %int1 : !torch.int, !torch.int -> !torch.int - %11 = torch.aten.lt.int %5, %9 : !torch.int, !torch.int -> !torch.bool - %12 = torch.prim.If %11 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %18 = torch.aten.gt.int %5, %10 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %18 : !torch.bool - } - %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool - torch.prim.If %13 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %14 = torch.aten.lt.int %5, %int0 : !torch.int, !torch.int -> !torch.bool - %15 = torch.prim.If %14 -> (!torch.int) { - %18 = torch.aten.add.int %5, %8 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %18 : !torch.int - } else { - torch.prim.If.yield %5 : !torch.int - } 
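sum_mean_dim above and max_dim, whose body continues below, share one pattern: wrap each requested dim the way maybe_wrap_dim does, then drop the reduced dimension, or keep it as size 1 when keepdim is set. A short Python sketch of that shared logic, paraphrasing the compiled code (helper names are mine, not upstream's):

from typing import List, Sequence

def wrap_dim(dim: int, rank: int) -> int:
    # Rank-0 tensors are treated as rank 1 so dim 0 / -1 stay legal.
    n = rank if rank > 0 else 1
    assert -n <= dim <= n - 1, "AssertionError: "
    return dim + n if dim < 0 else dim

def reduce_shape(shape: List[int], dims: Sequence[int], keepdim: bool) -> List[int]:
    # None/empty dims means "reduce over every dimension".
    if not dims:
        dims = range(len(shape))
    wrapped = {wrap_dim(d, len(shape)) for d in dims}
    out: List[int] = []
    for i, size in enumerate(shape):
        if i in wrapped:
            if keepdim:
                out.append(1)  # reduced dim survives as size 1
        else:
            out.append(size)
    return out

So reduce_shape([2, 3, 4], [-1], keepdim=True) yields [2, 3, 1], and an empty dims list reduces every dimension, matching the None/empty-dims branch of sum_mean_dim above.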
- %16 = torch.aten.eq.int %arg3, %15 : !torch.int, !torch.int -> !torch.bool - %17 = torch.prim.If %16 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - torch.prim.If.yield %arg5 : !torch.bool - } - torch.prim.Loop.condition %true, iter(%17 : !torch.bool) - } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool - torch.prim.If %4 -> () { - torch.prim.If %arg2 -> () { - %5 = torch.aten.append.t %1, %int1 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } else { - torch.prim.If.yield - } - torch.prim.If.yield - } else { - %5 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int - %6 = torch.aten.append.t %1, %5 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %3 = torch.prim.TupleConstruct %1, %1 : !torch.list, !torch.list -> !torch.tuple, list> - return %3 : !torch.tuple, list> - } - func.func @__torch__.torch.jit._shape_functions.addmm(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.any, %arg4: !torch.any) -> !torch.list { - %str = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}" - %false = torch.constant.bool false - %true = torch.constant.bool true - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %int2 = torch.constant.int 2 - %str_0 = torch.constant.str "AssertionError: self must be a matrix" - %none = torch.constant.none - %str_1 = torch.constant.str "AssertionError: mat2 must be a matrix" - %str_2 = torch.constant.str "AssertionError: " - %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %1 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %2 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %3 = torch.aten.eq.int %2, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %3 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %4 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int - %5 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int - %6 = torch.aten.eq.int %4, %5 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %6 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %7 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %8 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int - %9 = torch.prim.ListConstruct %7, %8 : (!torch.int, !torch.int) -> !torch.list - %10 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %11 = torch.prim.max.int %10, %int2 : !torch.int, !torch.int -> !torch.int - %12 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %11, %true, init() { - ^bb0(%arg5: !torch.int): - %13 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int - %14 = torch.aten.sub.int %13, %arg5 : !torch.int, !torch.int -> !torch.int - %15 = torch.aten.sub.int %10, %int1 : !torch.int, !torch.int -> !torch.int - %16 = torch.aten.sub.int %15, %14 : !torch.int, !torch.int -> !torch.int - %17 = torch.aten.sub.int %int1, %14 : !torch.int, !torch.int -> !torch.int - %18 = torch.aten.ge.int %16, %int0 : !torch.int, 
!torch.int -> !torch.bool - %19 = torch.prim.If %18 -> (!torch.int) { - %28 = torch.aten.__getitem__.t %arg0, %16 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %28 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %20 = torch.aten.ge.int %17, %int0 : !torch.int, !torch.int -> !torch.bool - %21 = torch.prim.If %20 -> (!torch.int) { - %28 = torch.aten.__getitem__.t %9, %17 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %28 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %22 = torch.aten.ne.int %19, %21 : !torch.int, !torch.int -> !torch.bool - %23 = torch.prim.If %22 -> (!torch.bool) { - %28 = torch.aten.ne.int %19, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %28 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %24 = torch.prim.If %23 -> (!torch.bool) { - %28 = torch.aten.ne.int %21, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %28 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %24 -> () { - %28 = torch.aten.format(%str, %19, %21, %arg5) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str - %29 = torch.aten.add.str %str_2, %28 : !torch.str, !torch.str -> !torch.str - torch.prim.RaiseException %29, %none : !torch.str, !torch.none - torch.prim.If.yield - } else { - torch.prim.If.yield - } - %25 = torch.aten.eq.int %19, %int1 : !torch.int, !torch.int -> !torch.bool - %26 = torch.prim.If %25 -> (!torch.int) { - torch.prim.If.yield %21 : !torch.int - } else { - torch.prim.If.yield %19 : !torch.int - } - %27 = torch.aten.append.t %12, %26 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - return %12 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.upsample_nearest2d(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>) -> !torch.optional> { - %str = torch.constant.str "AssertionError: Either output_size or scale_factors must be presented" - %str_0 = torch.constant.str "AssertionError: " - %str_1 = torch.constant.str "AssertionError: Must specify exactly one of output_size and scale_factors" - %none = torch.constant.none - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %int2 = torch.constant.int 2 - %int3 = torch.constant.int 3 - %0 = torch.prim.ListConstruct : () -> !torch.list - %1 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %2 = torch.aten.append.t %0, %1 : !torch.list, !torch.int -> !torch.list - %3 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int - %4 = torch.aten.append.t %0, %3 : !torch.list, !torch.int -> !torch.list - %5 = torch.aten.__isnot__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool - %6 = torch.prim.If %5 -> (!torch.optional>) { - %7 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - %8 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool - torch.prim.If %8 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %9 = torch.aten.len.t %7 : !torch.list -> !torch.int - %10 = torch.aten.eq.int %9, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %10 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %11 = torch.aten.__getitem__.t %7, %int0 : !torch.list, !torch.int -> 
!torch.int - %12 = torch.aten.append.t %0, %11 : !torch.list, !torch.int -> !torch.list - %13 = torch.aten.__getitem__.t %7, %int1 : !torch.list, !torch.int -> !torch.int - %14 = torch.aten.append.t %0, %13 : !torch.list, !torch.int -> !torch.list - %15 = torch.derefine %0 : !torch.list to !torch.optional> - torch.prim.If.yield %15 : !torch.optional> - } else { - %7 = torch.aten.__isnot__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool - %8 = torch.prim.If %7 -> (!torch.optional>) { - %9 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list - %10 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool - torch.prim.If %10 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %11 = torch.aten.len.t %9 : !torch.list -> !torch.int - %12 = torch.aten.eq.int %11, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %12 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %13 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int - %14 = torch.aten.__getitem__.t %9, %int0 : !torch.list, !torch.int -> !torch.float - %15 = torch.operator "aten.mul.int_float"(%13, %14) : (!torch.int, !torch.float) -> !torch.float - %16 = torch.aten.Int.float %15 : !torch.float -> !torch.int - %17 = torch.aten.append.t %0, %16 : !torch.list, !torch.int -> !torch.list - %18 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int - %19 = torch.aten.__getitem__.t %9, %int1 : !torch.list, !torch.int -> !torch.float - %20 = torch.operator "aten.mul.int_float"(%18, %19) : (!torch.int, !torch.float) -> !torch.float - %21 = torch.aten.Int.float %20 : !torch.float -> !torch.int - %22 = torch.aten.append.t %0, %21 : !torch.list, !torch.int -> !torch.list - %23 = torch.derefine %0 : !torch.list to !torch.optional> - torch.prim.If.yield %23 : !torch.optional> - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - %9 = torch.derefine %none : !torch.none to !torch.optional> - torch.prim.If.yield %9 : !torch.optional> - } - torch.prim.If.yield %8 : !torch.optional> - } - return %6 : !torch.optional> - } - func.func @__torch__.torch.jit._shape_functions.argmax(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list { - %true = torch.constant.bool true - %int9223372036854775807 = torch.constant.int 9223372036854775807 - %int1 = torch.constant.int 1 - %int0 = torch.constant.int 0 - %str = torch.constant.str "AssertionError: " - %none = torch.constant.none - %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool - %1 = torch.prim.If %0 -> (!torch.list) { - %2 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.If.yield %2 : !torch.list - } else { - %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int - %3 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %4 = torch.aten.le.int %3, %int0 : !torch.int, !torch.int -> !torch.bool - %5 = torch.prim.If %4 -> (!torch.int) { - torch.prim.If.yield %int1 : !torch.int - } else { - torch.prim.If.yield %3 : !torch.int - } - %6 = torch.aten.neg.int %5 : !torch.int -> !torch.int - %7 = torch.aten.sub.int %5, %int1 : !torch.int, !torch.int -> !torch.int - %8 = torch.aten.lt.int %2, %6 : !torch.int, !torch.int -> !torch.bool - %9 = torch.prim.If %8 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %17 = 
torch.aten.gt.int %2, %7 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %17 : !torch.bool - } - %10 = torch.aten.__not__ %9 : !torch.bool -> !torch.bool - torch.prim.If %10 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %11 = torch.aten.lt.int %2, %int0 : !torch.int, !torch.int -> !torch.bool - %12 = torch.prim.If %11 -> (!torch.int) { - %17 = torch.aten.add.int %2, %5 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %17 : !torch.int - } else { - torch.prim.If.yield %2 : !torch.int - } - %13 = torch.prim.ListConstruct : () -> !torch.list - %14 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %15 = torch.prim.ListConstruct %int9223372036854775807, %14 : (!torch.int, !torch.int) -> !torch.list - %16 = torch.prim.min.self_int %15 : !torch.list -> !torch.int - torch.prim.Loop %16, %true, init() { - ^bb0(%arg3: !torch.int): - %17 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int - %18 = torch.aten.eq.int %arg3, %12 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %18 -> () { - torch.prim.If %arg2 -> () { - %19 = torch.aten.append.t %13, %int1 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } else { - torch.prim.If.yield - } - torch.prim.If.yield - } else { - %19 = torch.aten.append.t %13, %17 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %13 : !torch.list - } - return %1 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions._reduce_along_dim(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list { - %true = torch.constant.bool true - %int9223372036854775807 = torch.constant.int 9223372036854775807 - %int1 = torch.constant.int 1 - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg1, %0, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int - %2 = torch.prim.ListConstruct : () -> !torch.list - %3 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %4 = torch.prim.ListConstruct %int9223372036854775807, %3 : (!torch.int, !torch.int) -> !torch.list - %5 = torch.prim.min.self_int %4 : !torch.list -> !torch.int - torch.prim.Loop %5, %true, init() { - ^bb0(%arg3: !torch.int): - %6 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int - %7 = torch.aten.eq.int %arg3, %1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %7 -> () { - torch.prim.If %arg2 -> () { - %8 = torch.aten.append.t %2, %int1 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } else { - torch.prim.If.yield - } - torch.prim.If.yield - } else { - %8 = torch.aten.append.t %2, %6 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - return %2 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.bmm(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %str = torch.constant.str "AssertionError: mismatching contracting dimension" - %str_0 = torch.constant.str "AssertionError: mismatching batch dimension" - %none = torch.constant.none - %str_1 = torch.constant.str "AssertionError: bmm only supports 3D tensors" - %int3 = torch.constant.int 3 - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 - %int1 = torch.constant.int 1 - %0 = torch.aten.len.t %arg0 : !torch.list 
-> !torch.int - %1 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %1 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %3 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %5 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %6 = torch.aten.eq.int %4, %5 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %6 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %7 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int - %8 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int - %9 = torch.aten.eq.int %7, %8 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %9 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %11 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int - %12 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list, !torch.int -> !torch.int - %13 = torch.prim.ListConstruct %10, %11, %12 : (!torch.int, !torch.int, !torch.int) -> !torch.list - return %13 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions._shape_as_tensor(%arg0: !torch.list) -> !torch.list { - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list - return %1 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.topk(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.tuple, list> { - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %str_0 = torch.constant.str "k ({}) is too big for dimension {} of size {}" - %int0 = torch.constant.int 0 - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool - %2 = torch.prim.If %1 -> (!torch.list) { - %4 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.If.yield %4 : !torch.list - } else { - %4 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int - %5 = torch.aten.le.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %5 -> () { - torch.prim.If.yield - } else { - %9 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int - %10 = torch.aten.format(%str_0, %arg1, %arg2, %9) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str - %11 = torch.aten.add.str %str, %10 : !torch.str, !torch.str -> !torch.str - torch.prim.RaiseException %11, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %6 = torch.prim.ListConstruct : () -> !torch.list - %7 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %7, %true, init() { - ^bb0(%arg3: !torch.int): - %9 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int - %10 = torch.aten.append.t %6, %9 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition 
%true, iter() - } : (!torch.int, !torch.bool) -> () - %8 = torch.aten._set_item.t %6, %arg2, %arg1 : !torch.list, !torch.int, !torch.int -> !torch.list - torch.prim.If.yield %6 : !torch.list - } - %3 = torch.prim.TupleConstruct %2, %2 : !torch.list, !torch.list -> !torch.tuple, list> - return %3 : !torch.tuple, list> - } - func.func @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.int) -> !torch.tuple, list> { - %int-1 = torch.constant.int -1 - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %false = torch.constant.bool false - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 - %int1 = torch.constant.int 1 - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %2 = torch.aten.lt.int %int0, %0 : !torch.int, !torch.int -> !torch.bool - %3 = torch.prim.If %2 -> (!torch.bool) { - %16 = torch.aten.le.int %0, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %16 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %3 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %4 = torch.aten.le.int %1, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %4 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %5 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool - %6 = torch.prim.If %5 -> (!torch.bool) { - %16 = torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %16 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %7 = torch.prim.If %6 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %16 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %17 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %18 = torch.aten.eq.int %16, %17 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %18 : !torch.bool - } - torch.prim.If %7 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %8 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int - %9 = torch.prim.ListConstruct : () -> !torch.list - %10 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool - %11 = torch.prim.If %10 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %16 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list - %17 = torch.aten.len.t %16 : !torch.list -> !torch.int - %18 = torch.aten.eq.int %17, %int1 : !torch.int, !torch.int -> !torch.bool - %19 = torch.prim.If %18 -> (!torch.bool) { - %20 = torch.aten.__getitem__.t %16, %int0 : !torch.list, !torch.int -> !torch.int - %21 = torch.aten.eq.int %20, %8 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %21 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If.yield %19 : !torch.bool - } - torch.prim.If %11 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %12 = torch.aten.eq.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool - %13 = torch.prim.If %12 -> (!torch.bool) { - %16 = 
torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %16 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %14 = torch.prim.If %13 -> (!torch.list) { - %16 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %17 = torch.prim.ListConstruct %16 : (!torch.int) -> !torch.list - torch.prim.If.yield %17 : !torch.list - } else { - torch.prim.If.yield %9 : !torch.list - } - %15 = torch.prim.TupleConstruct %14, %9 : !torch.list, !torch.list -> !torch.tuple, list> - return %15 : !torch.tuple, list> - } - func.func @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0: !torch.list, %arg1: !torch.list) -> !torch.tuple, list, list> { - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %0 = torch.prim.ListConstruct : () -> !torch.list - %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %3 = torch.aten.sub.int %1, %2 : !torch.int, !torch.int -> !torch.int - %4 = torch.aten.ge.int %3, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %4 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - torch.prim.Loop %3, %true, init() { - ^bb0(%arg2: !torch.int): - %10 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int - %11 = torch.aten.append.t %0, %10 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %6 = torch.aten.__range_length %3, %5, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int - torch.prim.Loop %6, %true, init() { - ^bb0(%arg2: !torch.int): - %10 = torch.aten.append.t %0, %int1 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %7 = torch.prim.ListConstruct : () -> !torch.list - %8 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %8, %true, init() { - ^bb0(%arg2: !torch.int): - %10 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int - %11 = torch.aten.append.t %7, %10 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %9 = torch.prim.TupleConstruct %7, %0, %0 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list> - return %9 : !torch.tuple, list, list> - } - func.func @__torch__.torch.jit._shape_functions.native_batch_norm(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool) -> !torch.tuple, list, list> { - %true = torch.constant.bool true - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %0 = torch.prim.If %arg5 -> (!torch.list) { - %4 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int - %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list - torch.prim.If.yield %5 : !torch.list - } else { - %4 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list - torch.prim.If.yield %4 : !torch.list - } - %1 = torch.prim.ListConstruct : () -> !torch.list - %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %2, %true, init() { - ^bb0(%arg6: !torch.int): - %4 = torch.aten.__getitem__.t %arg0, %arg6 : 
!torch.list, !torch.int -> !torch.int - %5 = torch.aten.append.t %1, %4 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %3 = torch.prim.TupleConstruct %1, %0, %0 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list> - return %3 : !torch.tuple, list, list> - } - func.func @__torch__.torch.jit._shape_functions.broadcast_three(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list { - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %true = torch.constant.bool true - %false = torch.constant.bool false - %str = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}" - %str_0 = torch.constant.str "AssertionError: " - %none = torch.constant.none - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %2 = torch.prim.max.int %0, %1 : !torch.int, !torch.int -> !torch.int - %3 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %2, %true, init() { - ^bb0(%arg3: !torch.int): - %8 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int - %9 = torch.aten.sub.int %8, %arg3 : !torch.int, !torch.int -> !torch.int - %10 = torch.aten.sub.int %0, %int1 : !torch.int, !torch.int -> !torch.int - %11 = torch.aten.sub.int %10, %9 : !torch.int, !torch.int -> !torch.int - %12 = torch.aten.sub.int %1, %int1 : !torch.int, !torch.int -> !torch.int - %13 = torch.aten.sub.int %12, %9 : !torch.int, !torch.int -> !torch.int - %14 = torch.aten.ge.int %11, %int0 : !torch.int, !torch.int -> !torch.bool - %15 = torch.prim.If %14 -> (!torch.int) { - %24 = torch.aten.__getitem__.t %arg0, %11 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %24 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %16 = torch.aten.ge.int %13, %int0 : !torch.int, !torch.int -> !torch.bool - %17 = torch.prim.If %16 -> (!torch.int) { - %24 = torch.aten.__getitem__.t %arg1, %13 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %24 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %18 = torch.aten.ne.int %15, %17 : !torch.int, !torch.int -> !torch.bool - %19 = torch.prim.If %18 -> (!torch.bool) { - %24 = torch.aten.ne.int %15, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %24 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %20 = torch.prim.If %19 -> (!torch.bool) { - %24 = torch.aten.ne.int %17, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %24 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %20 -> () { - %24 = torch.aten.format(%str, %15, %17, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str - %25 = torch.aten.add.str %str_0, %24 : !torch.str, !torch.str -> !torch.str - torch.prim.RaiseException %25, %none : !torch.str, !torch.none - torch.prim.If.yield - } else { - torch.prim.If.yield - } - %21 = torch.aten.eq.int %15, %int1 : !torch.int, !torch.int -> !torch.bool - %22 = torch.prim.If %21 -> (!torch.int) { - torch.prim.If.yield %17 : !torch.int - } else { - torch.prim.If.yield %15 : !torch.int - } - %23 = torch.aten.append.t %3, %22 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %4 = torch.aten.len.t %3 : !torch.list -> !torch.int - %5 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %6 = torch.prim.max.int %4, 
%5 : !torch.int, !torch.int -> !torch.int - %7 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %6, %true, init() { - ^bb0(%arg3: !torch.int): - %8 = torch.aten.sub.int %6, %int1 : !torch.int, !torch.int -> !torch.int - %9 = torch.aten.sub.int %8, %arg3 : !torch.int, !torch.int -> !torch.int - %10 = torch.aten.sub.int %4, %int1 : !torch.int, !torch.int -> !torch.int - %11 = torch.aten.sub.int %10, %9 : !torch.int, !torch.int -> !torch.int - %12 = torch.aten.sub.int %5, %int1 : !torch.int, !torch.int -> !torch.int - %13 = torch.aten.sub.int %12, %9 : !torch.int, !torch.int -> !torch.int - %14 = torch.aten.ge.int %11, %int0 : !torch.int, !torch.int -> !torch.bool - %15 = torch.prim.If %14 -> (!torch.int) { - %24 = torch.aten.__getitem__.t %3, %11 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %24 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %16 = torch.aten.ge.int %13, %int0 : !torch.int, !torch.int -> !torch.bool - %17 = torch.prim.If %16 -> (!torch.int) { - %24 = torch.aten.__getitem__.t %arg2, %13 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %24 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %18 = torch.aten.ne.int %15, %17 : !torch.int, !torch.int -> !torch.bool - %19 = torch.prim.If %18 -> (!torch.bool) { - %24 = torch.aten.ne.int %15, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %24 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %20 = torch.prim.If %19 -> (!torch.bool) { - %24 = torch.aten.ne.int %17, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %24 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %20 -> () { - %24 = torch.aten.format(%str, %15, %17, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str - %25 = torch.aten.add.str %str_0, %24 : !torch.str, !torch.str -> !torch.str - torch.prim.RaiseException %25, %none : !torch.str, !torch.none - torch.prim.If.yield - } else { - torch.prim.If.yield - } - %21 = torch.aten.eq.int %15, %int1 : !torch.int, !torch.int -> !torch.bool - %22 = torch.prim.If %21 -> (!torch.int) { - torch.prim.If.yield %17 : !torch.int - } else { - torch.prim.If.yield %15 : !torch.int - } - %23 = torch.aten.append.t %7, %22 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - return %7 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.broadcast_one_three(%arg0: !torch.list, %arg1: !torch.any, %arg2: !torch.list) -> !torch.list { - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %true = torch.constant.bool true - %false = torch.constant.bool false - %str = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}" - %str_0 = torch.constant.str "AssertionError: " - %none = torch.constant.none - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %2 = torch.prim.max.int %0, %1 : !torch.int, !torch.int -> !torch.int - %3 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %2, %true, init() { - ^bb0(%arg3: !torch.int): - %4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int - %5 = torch.aten.sub.int %4, %arg3 : !torch.int, !torch.int -> !torch.int - %6 = torch.aten.sub.int %0, %int1 : !torch.int, !torch.int -> !torch.int - %7 = torch.aten.sub.int %6, %5 : !torch.int, !torch.int -> !torch.int - %8 = 
torch.aten.sub.int %1, %int1 : !torch.int, !torch.int -> !torch.int - %9 = torch.aten.sub.int %8, %5 : !torch.int, !torch.int -> !torch.int - %10 = torch.aten.ge.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - %11 = torch.prim.If %10 -> (!torch.int) { - %20 = torch.aten.__getitem__.t %arg0, %7 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %20 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %12 = torch.aten.ge.int %9, %int0 : !torch.int, !torch.int -> !torch.bool - %13 = torch.prim.If %12 -> (!torch.int) { - %20 = torch.aten.__getitem__.t %arg2, %9 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %20 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %14 = torch.aten.ne.int %11, %13 : !torch.int, !torch.int -> !torch.bool - %15 = torch.prim.If %14 -> (!torch.bool) { - %20 = torch.aten.ne.int %11, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %20 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %16 = torch.prim.If %15 -> (!torch.bool) { - %20 = torch.aten.ne.int %13, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %20 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %16 -> () { - %20 = torch.aten.format(%str, %11, %13, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str - %21 = torch.aten.add.str %str_0, %20 : !torch.str, !torch.str -> !torch.str - torch.prim.RaiseException %21, %none : !torch.str, !torch.none - torch.prim.If.yield - } else { - torch.prim.If.yield - } - %17 = torch.aten.eq.int %11, %int1 : !torch.int, !torch.int -> !torch.bool - %18 = torch.prim.If %17 -> (!torch.int) { - torch.prim.If.yield %13 : !torch.int - } else { - torch.prim.If.yield %11 : !torch.int - } - %19 = torch.aten.append.t %3, %18 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - return %3 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.broadcast_inplace(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %str = torch.constant.str "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}" - %false = torch.constant.bool false - %true = torch.constant.bool true - %none = torch.constant.none - %str_0 = torch.constant.str "AssertionError: " - %str_1 = torch.constant.str "The dims of tensor b ({}) must be less than or equal tothe dims of tensor a ({}) " - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %2 = torch.aten.gt.int %1, %0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %2 -> () { - %5 = torch.aten.format(%str_1, %1, %0) : !torch.str, !torch.int, !torch.int -> !torch.str - %6 = torch.aten.add.str %str_0, %5 : !torch.str, !torch.str -> !torch.str - torch.prim.RaiseException %6, %none : !torch.str, !torch.none - torch.prim.If.yield - } else { - torch.prim.If.yield - } - torch.prim.Loop %0, %true, init() { - ^bb0(%arg2: !torch.int): - %5 = torch.aten.sub.int %1, %0 : !torch.int, !torch.int -> !torch.int - %6 = torch.aten.add.int %5, %arg2 : !torch.int, !torch.int -> !torch.int - %7 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int - %8 = torch.aten.ge.int %6, %int0 : !torch.int, !torch.int -> !torch.bool - %9 = torch.prim.If %8 -> (!torch.int) { - %12 = torch.aten.__getitem__.t %arg1, %6 : !torch.list, 
!torch.int -> !torch.int - torch.prim.If.yield %12 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %10 = torch.aten.ne.int %7, %9 : !torch.int, !torch.int -> !torch.bool - %11 = torch.prim.If %10 -> (!torch.bool) { - %12 = torch.aten.ne.int %9, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %12 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %11 -> () { - %12 = torch.aten.format(%str, %7, %9, %arg2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str - %13 = torch.aten.add.str %str_0, %12 : !torch.str, !torch.str -> !torch.str - torch.prim.RaiseException %13, %none : !torch.str, !torch.none - torch.prim.If.yield - } else { - torch.prim.If.yield - } - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %3 = torch.prim.ListConstruct : () -> !torch.list - %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %4, %true, init() { - ^bb0(%arg2: !torch.int): - %5 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int - %6 = torch.aten.append.t %3, %5 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - return %3 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.nonzero_lower_bound(%arg0: !torch.list) -> !torch.list { - %int0 = torch.constant.int 0 - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.prim.ListConstruct %int0, %0 : (!torch.int, !torch.int) -> !torch.list - return %1 : !torch.list - } - func.func @__torch__.torch.jit._shape_functions.nonzero_upper_bound(%arg0: !torch.list) -> !torch.list { - %int1 = torch.constant.int 1 - %true = torch.constant.bool true - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.prim.Loop %0, %true, init(%int1) { - ^bb0(%arg1: !torch.int, %arg2: !torch.int): - %4 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int - %5 = torch.aten.mul.int %arg2, %4 : !torch.int, !torch.int -> !torch.int - torch.prim.Loop.condition %true, iter(%5 : !torch.int) - } : (!torch.int, !torch.bool, !torch.int) -> !torch.int - %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list - return %3 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.triu"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.tanh"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.erf"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.sigmoid"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.hardsigmoid"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.softplus"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list { 
- %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.square"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.hardswish"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.silu"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.exp"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.expm1"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.sin"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.cos"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.hardtanh"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.sqrt"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.neg"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.floor"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.detach"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.log2"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.log1p"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.rsqrt"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.abs"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func 
@"__torch_mlir_shape_fn.aten.reciprocal"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.tanh_backward"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.gelu_backward"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.str) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.ceil"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.log"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.relu"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten._softmax"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.softmax.int"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten._log_softmax"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.log_softmax.int"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.clamp"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.clamp_min"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.clamp_max"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.rsub.Scalar"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.to.dtype"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.list { - %0 = call 
@__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.to.dtype_layout"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.bool, %arg6: !torch.bool, %arg7: !torch.optional) -> !torch.list { - return %arg0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.to.device"(%arg0: !torch.list, %arg1: !torch.Device, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool, %arg5: !torch.optional) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.to.other"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.type_as"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.dropout"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.bool) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.gelu"(%arg0: !torch.list, %arg1: !torch.str) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.contiguous"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.clone"(%arg0: !torch.list, %arg1: !torch.optional) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten._log_softmax_backward_data"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.eq.Scalar"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.ne.Scalar"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.gt.Scalar"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.ge.Scalar"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.le.Scalar"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { - %0 = call 
@__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.lt.Scalar"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.add.Scalar"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.sub.Scalar"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.mul.Scalar"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.div.Scalar"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.remainder.Scalar"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.floor_divide.Scalar"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.pow.Tensor_Scalar"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.leaky_relu"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.gather"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg2) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.layer_norm"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float, %arg5: !torch.bool) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten._softmax_backward_data"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.any"(%arg0: !torch.list) -> !torch.list { - %0 = torch.prim.ListConstruct : () -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.all"(%arg0: !torch.list) -> !torch.list { - %0 = torch.prim.ListConstruct : () -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.max"(%arg0: 
!torch.list) -> !torch.list { - %0 = torch.prim.ListConstruct : () -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.sum"(%arg0: !torch.list, %arg1: !torch.optional) -> !torch.list { - %0 = torch.prim.ListConstruct : () -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.mean"(%arg0: !torch.list, %arg1: !torch.optional) -> !torch.list { - %0 = torch.prim.ListConstruct : () -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.var"(%arg0: !torch.list, %arg1: !torch.bool) -> !torch.list { - %0 = torch.prim.ListConstruct : () -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.var.dim"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list { - %none = torch.constant.none - %0 = torch.derefine %none : !torch.none to !torch.any - %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list - return %1 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.var.correction"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.list { - %none = torch.constant.none - %0 = torch.derefine %none : !torch.none to !torch.any - %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list - return %1 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.std"(%arg0: !torch.list, %arg1: !torch.bool) -> !torch.list { - %0 = torch.prim.ListConstruct : () -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.std.dim"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list { - %none = torch.constant.none - %0 = torch.derefine %none : !torch.none to !torch.any - %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list - return %1 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.argmax"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list { - %none = torch.constant.none - %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool - %1 = torch.prim.If %0 -> (!torch.list) { - %2 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.If.yield %2 : !torch.list - } else { - %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int - %3 = func.call @__torch__._reduce_along_dim(%arg0, %2, %arg2) : (!torch.list, !torch.int, !torch.bool) -> !torch.list - torch.prim.If.yield %3 : !torch.list - } - return %1 : !torch.list - } - func.func @__torch__._reduce_along_dim(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list { - %true = torch.constant.bool true - %int9223372036854775807 = torch.constant.int 9223372036854775807 - %int1 = torch.constant.int 1 - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg1, %0, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int - %2 = torch.prim.ListConstruct : () -> !torch.list - %3 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %4 = torch.prim.ListConstruct %int9223372036854775807, %3 : (!torch.int, !torch.int) -> !torch.list - %5 = torch.prim.min.self_int %4 : !torch.list -> !torch.int - torch.prim.Loop 
%5, %true, init() { - ^bb0(%arg3: !torch.int): - %6 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int - %7 = torch.aten.eq.int %arg3, %1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %7 -> () { - torch.prim.If %arg2 -> () { - %8 = torch.aten.append.t %2, %int1 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } else { - torch.prim.If.yield - } - torch.prim.If.yield - } else { - %8 = torch.aten.append.t %2, %6 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield - } - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - return %2 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.any.dim"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list { - %0 = call @__torch__._reduce_along_dim(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.bool) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.max.dim"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple, list> { - %0 = call @__torch__._reduce_along_dim(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.bool) -> !torch.list - %1 = torch.prim.TupleConstruct %0, %0 : !torch.list, !torch.list -> !torch.tuple, list> - return %1 : !torch.tuple, list> - } - func.func @"__torch_mlir_shape_fn.aten.mean.dim"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.list { - %0 = torch.derefine %arg3 : !torch.optional to !torch.any - %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list - return %1 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.sum.dim_IntList"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.list { - %0 = torch.derefine %arg3 : !torch.optional to !torch.any - %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list - return %1 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.permute"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.transpose.int"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.transpose(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.int) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.t"(%arg0: !torch.list) -> !torch.list { - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %0 = call @__torch__.torch.jit._shape_functions.transpose(%arg0, %int0, %int1) : (!torch.list, !torch.int, !torch.int) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.numpy_T"(%arg0: !torch.list) -> !torch.list { - %true = torch.constant.bool true - %int0 = torch.constant.int 0 - %0 = torch.prim.ListConstruct : () -> !torch.list - %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - torch.prim.Loop %1, %true, init() { - ^bb0(%arg1: !torch.int): - %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int - torch.aten.insert.t %0, %int0, %2 : !torch.list, !torch.int, !torch.int - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - 
return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.matmul"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.matmul(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.mm"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.addmm"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.float, %arg4: !torch.float) -> !torch.list { - %0 = torch.derefine %arg3 : !torch.float to !torch.any - %1 = torch.derefine %arg4 : !torch.float to !torch.any - %2 = call @__torch__.torch.jit._shape_functions.addmm(%arg0, %arg1, %arg2, %0, %1) : (!torch.list, !torch.list, !torch.list, !torch.any, !torch.any) -> !torch.list - return %2 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.bmm"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %str = torch.constant.str "AssertionError: mismatching contracting dimension" - %str_0 = torch.constant.str "AssertionError: mismatching batch dimension" - %none = torch.constant.none - %str_1 = torch.constant.str "AssertionError: bmm only supports 3D tensors" - %int3 = torch.constant.int 3 - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 - %int1 = torch.constant.int 1 - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %1 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %3 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %5 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %6 = torch.aten.eq.int %4, %5 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %6 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %7 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int - %8 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int - %9 = torch.aten.eq.int %7, %8 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %9 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %11 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int - %12 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list, !torch.int -> !torch.int - %13 = torch.prim.ListConstruct %10, %11, %12 : (!torch.int, !torch.int, !torch.int) -> !torch.list - return %13 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.baddbmm"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.float, %arg4: !torch.float) -> !torch.list { - %str = torch.constant.str "AssertionError: mismatching contracting dimension" - %str_0 = torch.constant.str "AssertionError: mismatching 
batch dimension" - %none = torch.constant.none - %str_1 = torch.constant.str "AssertionError: baddbmm only supports 3D tensors" - %int3 = torch.constant.int 3 - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 - %int1 = torch.constant.int 1 - %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %1 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %1 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %2 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %3 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %4 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %5 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int - %6 = torch.aten.eq.int %4, %5 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %6 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %7 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list, !torch.int -> !torch.int - %8 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int - %9 = torch.aten.eq.int %7, %8 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %9 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %10 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %11 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int - %12 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list, !torch.int -> !torch.int - %13 = torch.prim.ListConstruct %10, %11, %12 : (!torch.int, !torch.int, !torch.int) -> !torch.list - return %13 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.embedding"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.embedding(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.list, !torch.int, !torch.bool, !torch.bool) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.repeat"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int0 = torch.constant.int 0 - %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %2 = torch.aten.ge.int %0, %1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %2 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %3 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %5 = torch.aten.eq.int %3, %int0 : !torch.int, !torch.int -> !torch.bool - %6 = torch.prim.If %5 -> (!torch.list) { - %7 = func.call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list) -> !torch.list - torch.prim.If.yield %7 : !torch.list - } else { - %7 = torch.prim.ListConstruct : () -> !torch.list - %8 = torch.aten.sub.int %3, %4 : !torch.int, !torch.int -> !torch.int - torch.prim.Loop %8, %true, init() { - ^bb0(%arg2: !torch.int): - %9 = 
torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int - %10 = torch.aten.append.t %7, %9 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.prim.Loop %4, %true, init() { - ^bb0(%arg2: !torch.int): - %9 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int - %10 = torch.aten.add.int %arg2, %8 : !torch.int, !torch.int -> !torch.int - %11 = torch.aten.__getitem__.t %arg1, %10 : !torch.list, !torch.int -> !torch.int - %12 = torch.aten.mul.int %9, %11 : !torch.int, !torch.int -> !torch.int - %13 = torch.aten.append.t %7, %12 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %7 : !torch.list - } - return %6 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.roll"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.expand"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.expand(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.expand_as"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.broadcast_to"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.expand(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.view"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.view(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.reshape"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.view(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten._reshape_alias"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.view(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten._unsafe_view"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - return %arg1 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.resize_"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional) -> !torch.list { - return %arg1 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.max_pool2d"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.max_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.max_pool2d_with_indices"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> 
!torch.tuple, list> { - %0 = call @__torch__.torch.jit._shape_functions.max_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool) -> !torch.list - %1 = torch.prim.TupleConstruct %0, %0 : !torch.list, !torch.list -> !torch.tuple, list> - return %1 : !torch.tuple, list> - } - func.func @"__torch_mlir_shape_fn.aten.max_pool2d_with_indices_backward"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list) -> !torch.list { - return %arg1 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.avg_pool2d"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.list { - %0 = call @__torch__.avg_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.optional) -> !torch.list - return %0 : !torch.list - } - func.func @__torch__.avg_pool2d(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.list { - %int-1 = torch.constant.int -1 - %int-2 = torch.constant.int -2 - %int-3 = torch.constant.int -3 - %int-4 = torch.constant.int -4 - %str = torch.constant.str "AssertionError: " - %str_0 = torch.constant.str "AssertionError: avg_pool2d: padding must be either be a single int, or a tuple of two ints" - %str_1 = torch.constant.str "AssertionError: avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints" - %none = torch.constant.none - %str_2 = torch.constant.str "AssertionError: avg_pool2d: kernel_size must either be a single int, or a tuple of two ints" - %true = torch.constant.bool true - %int1 = torch.constant.int 1 - %int2 = torch.constant.int 2 - %int0 = torch.constant.int 0 - %int3 = torch.constant.int 3 - %int4 = torch.constant.int 4 - %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool - %2 = torch.prim.If %1 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %39 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %40 = torch.aten.eq.int %39, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %40 : !torch.bool - } - torch.prim.If %2 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %3 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %4 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %5 = torch.aten.eq.int %4, %int1 : !torch.int, !torch.int -> !torch.bool - %6 = torch.prim.If %5 -> (!torch.int) { - torch.prim.If.yield %3 : !torch.int - } else { - %39 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %39 : !torch.int - } - %7 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - %9 = torch.prim.If %8 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %39 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %40 = torch.aten.eq.int %39, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %40 : !torch.bool - } - %10 = torch.prim.If %9 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - 
%39 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %40 = torch.aten.eq.int %39, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %40 : !torch.bool - } - torch.prim.If %10 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %11 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %12 = torch.aten.eq.int %11, %int0 : !torch.int, !torch.int -> !torch.bool - %13 = torch.prim.If %12 -> (!torch.int) { - torch.prim.If.yield %3 : !torch.int - } else { - %39 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %39 : !torch.int - } - %14 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %15 = torch.aten.eq.int %14, %int0 : !torch.int, !torch.int -> !torch.bool - %16 = torch.prim.If %15 -> (!torch.int) { - torch.prim.If.yield %6 : !torch.int - } else { - %39 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %40 = torch.aten.eq.int %39, %int1 : !torch.int, !torch.int -> !torch.bool - %41 = torch.prim.If %40 -> (!torch.int) { - torch.prim.If.yield %13 : !torch.int - } else { - %42 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %42 : !torch.int - } - torch.prim.If.yield %41 : !torch.int - } - %17 = torch.aten.len.t %arg3 : !torch.list -> !torch.int - %18 = torch.aten.eq.int %17, %int1 : !torch.int, !torch.int -> !torch.bool - %19 = torch.prim.If %18 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %39 = torch.aten.len.t %arg3 : !torch.list -> !torch.int - %40 = torch.aten.eq.int %39, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %40 : !torch.bool - } - torch.prim.If %19 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %20 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int - %21 = torch.aten.len.t %arg3 : !torch.list -> !torch.int - %22 = torch.aten.eq.int %21, %int1 : !torch.int, !torch.int -> !torch.bool - %23 = torch.prim.If %22 -> (!torch.int) { - torch.prim.If.yield %20 : !torch.int - } else { - %39 = torch.aten.__getitem__.t %arg3, %int1 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %39 : !torch.int - } - %24 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %25 = torch.aten.eq.int %24, %int3 : !torch.int, !torch.int -> !torch.bool - %26 = torch.prim.If %25 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %39 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %40 = torch.aten.eq.int %39, %int4 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %40 : !torch.bool - } - torch.prim.If %26 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %27 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %28 = torch.aten.eq.int %27, %int4 : !torch.int, !torch.int -> !torch.bool - %29 = torch.prim.If %28 -> (!torch.int) { - %39 = torch.aten.__getitem__.t %arg0, %int-4 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %39 : !torch.int - } else { - torch.prim.If.yield %int1 : !torch.int - } - %30 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list, !torch.int -> !torch.int - %31 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int - %32 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int - %33 = call 
@__torch__.torch.jit._shape_functions.pooling_output_shape(%31, %3, %20, %13, %int1, %arg4) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int - %34 = call @__torch__.torch.jit._shape_functions.pooling_output_shape(%32, %6, %23, %16, %int1, %arg4) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int - %35 = call @__torch__.torch.jit._shape_functions.pool2d_shape_check(%arg0, %3, %6, %13, %16, %20, %23, %int1, %int1, %30, %31, %32, %33, %34) : (!torch.list, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.none - %36 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %37 = torch.aten.eq.int %36, %int3 : !torch.int, !torch.int -> !torch.bool - %38 = torch.prim.If %37 -> (!torch.list) { - %39 = torch.prim.ListConstruct %30, %33, %34 : (!torch.int, !torch.int, !torch.int) -> !torch.list - torch.prim.If.yield %39 : !torch.list - } else { - %39 = torch.prim.ListConstruct %29, %30, %33, %34 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - torch.prim.If.yield %39 : !torch.list - } - return %38 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.adaptive_avg_pool2d"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.adaptive_avg_pool2d(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.flatten.using_ints"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.flatten(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.int) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.linear"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list, !torch.list, !torch.optional>) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.zeros"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.list { - return %arg0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.ones"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.list { - return %arg0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.empty.memory_format"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list { - return %arg0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.full"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list { - return %arg0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.full_like"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.list { - return %arg0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.zeros_like"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list { - %0 = call 
@__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.ones_like"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.empty_like"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.new_zeros"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list { - return %arg1 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.new_ones"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list { - return %arg1 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.new_empty"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list { - return %arg1 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten._to_copy"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.masked_fill.Scalar"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.masked_fill.Tensor"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.zero"(%arg0: !torch.list) -> !torch.list { - return %arg0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.fill.Scalar"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { - return %arg0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.copy"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.uniform"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list { - return %arg0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.bernoulli.float"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.any) -> !torch.list { - return %arg0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.bernoulli.Tensor"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.any) -> !torch.list { - return %arg0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.index_put_impl"(%arg0: !torch.list, %arg1: !torch.list>>, %arg2: !torch.list, %arg3: !torch.bool, %arg4: !torch.bool) 
-> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.bernoulli"(%arg0: !torch.list, %arg1: !torch.any) -> !torch.list { - return %arg0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.rand_like"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list { - return %arg0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.arange.start_step"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.list { - %0 = torch.derefine %arg0 : !torch.float to !torch.union - %1 = torch.derefine %arg1 : !torch.float to !torch.union - %2 = torch.derefine %arg2 : !torch.float to !torch.union - %3 = torch.derefine %arg3 : !torch.optional to !torch.any - %4 = torch.derefine %arg4 : !torch.optional to !torch.any - %5 = torch.derefine %arg5 : !torch.optional to !torch.any - %6 = torch.derefine %arg6 : !torch.optional to !torch.any - %7 = call @__torch__.torch.jit._shape_functions.arange_start_step(%0, %1, %2, %3, %4, %5, %6) : (!torch.union, !torch.union, !torch.union, !torch.any, !torch.any, !torch.any, !torch.any) -> !torch.list - return %7 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.arange.start"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list { - %0 = torch.derefine %arg0 : !torch.float to !torch.union - %1 = torch.derefine %arg1 : !torch.float to !torch.union - %2 = torch.derefine %arg2 : !torch.optional to !torch.any - %3 = torch.derefine %arg3 : !torch.optional to !torch.any - %4 = torch.derefine %arg4 : !torch.optional to !torch.any - %5 = torch.derefine %arg5 : !torch.optional to !torch.any - %6 = call @__torch__.torch.jit._shape_functions.arange_start(%0, %1, %2, %3, %4, %5) : (!torch.union, !torch.union, !torch.any, !torch.any, !torch.any, !torch.any) -> !torch.list - return %6 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.arange"(%arg0: !torch.float, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.list { - %0 = torch.derefine %arg0 : !torch.float to !torch.union - %1 = torch.derefine %arg1 : !torch.optional to !torch.any - %2 = torch.derefine %arg2 : !torch.optional to !torch.any - %3 = torch.derefine %arg3 : !torch.optional to !torch.any - %4 = torch.derefine %arg4 : !torch.optional to !torch.any - %5 = call @__torch__.torch.jit._shape_functions.arange_end(%0, %1, %2, %3, %4) : (!torch.union, !torch.any, !torch.any, !torch.any, !torch.any) -> !torch.list - return %5 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.add.Tensor"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.sub.Tensor"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.mul.Tensor"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call 
@__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.div.Tensor"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.div.Tensor_mode"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.floor_divide"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.atan2"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.__and__.Tensor"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.minimum"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.maximum"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.bitwise_and.Tensor"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.logical_or"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.threshold"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.threshold_backward"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.eq.Tensor"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.gt.Tensor"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - 
func.func @"__torch_mlir_shape_fn.aten.lt.Tensor"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.unsqueeze"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unsqueeze(%arg0, %arg1) : (!torch.list, !torch.int) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.squeeze"(%arg0: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.squeeze_nodim(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.squeeze.dim"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.squeeze(%arg0, %arg1) : (!torch.list, !torch.int) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.prim.NumToTensor.Scalar"(%arg0: !torch.float) -> !torch.list { - %0 = torch.prim.ListConstruct : () -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.tensor.float"(%arg0: !torch.float, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.list { - %0 = torch.prim.ListConstruct : () -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.tensor.int"(%arg0: !torch.int, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.list { - %0 = torch.prim.ListConstruct : () -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.tensor.bool"(%arg0: !torch.bool, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.list { - %0 = torch.prim.ListConstruct : () -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten._shape_as_tensor"(%arg0: !torch.list) -> !torch.list { - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list - return %1 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.where.self"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list, !torch.list) -> !torch.list - %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list, !torch.list) -> !torch.list - return %1 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.where.Scalar"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.where.ScalarOther"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.where.ScalarSelf"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg2) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.lerp.Tensor"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list { - %0 = call 
@__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list, !torch.list) -> !torch.list - %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list, !torch.list) -> !torch.list - return %1 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.addcmul"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list, !torch.list) -> !torch.list - %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list, !torch.list) -> !torch.list - return %1 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.addcdiv"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.float) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list, !torch.list) -> !torch.list - %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list, !torch.list) -> !torch.list - return %1 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.topk"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.tuple, list> { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %str_0 = torch.constant.str "k ({}) is too big for dimension {} of size {}" - %0 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int - %1 = torch.aten.le.int %arg1, %0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %1 -> () { - torch.prim.If.yield - } else { - %4 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int - %5 = torch.aten.format(%str_0, %arg1, %arg2, %4) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str - %6 = torch.aten.add.str %str, %5 : !torch.str, !torch.str -> !torch.str - torch.prim.RaiseException %6, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %2 = torch.aten._set_item.t %arg0, %arg2, %arg1 : !torch.list, !torch.int, !torch.int -> !torch.list - %3 = torch.prim.TupleConstruct %arg0, %arg0 : !torch.list, !torch.list -> !torch.tuple, list> - return %3 : !torch.tuple, list> - } - func.func @"__torch_mlir_shape_fn.aten.conv2d"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.conv_transpose2d.input"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.list { - %0 = torch.derefine %arg3 : !torch.list to !torch.optional> - %1 = torch.derefine %arg4 : !torch.list to !torch.optional> - %2 = torch.derefine %arg5 : !torch.list to !torch.optional> - %3 = torch.derefine %arg7 : !torch.list to !torch.optional> - %4 = call @__torch__.torch.jit._shape_functions.conv_transpose2d_input(%arg0, %arg1, %arg2, %0, %1, %2, %arg6, %3) : (!torch.list, !torch.list, !torch.optional>, !torch.optional>, !torch.optional>, !torch.optional>, !torch.int, !torch.optional>) -> !torch.list - return %4 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.convolution"(%arg0: !torch.list, %arg1: 
!torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten._convolution"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.list { - %0 = call @"__torch_mlir_shape_fn.aten.convolution"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten._convolution.deprecated"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.list { - %0 = call @"__torch_mlir_shape_fn.aten.convolution"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.flip"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - return %arg0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.batch_norm"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list { - return %arg0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.slice.Tensor"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.int) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.narrow"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list { - %int1 = torch.constant.int 1 - %0 = torch.aten.add.int %arg2, %arg3 : !torch.int, !torch.int -> !torch.int - %1 = torch.derefine %arg2 : !torch.int to !torch.optional - %2 = torch.derefine %0 : !torch.int to !torch.optional - %3 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %1, %2, %int1) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list - return %3 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.slice_scatter"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.int) -> !torch.list { - return %arg0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.select.int"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.select(%arg0, %arg1, %arg2) : 
(!torch.list, !torch.int, !torch.int) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.select_scatter"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list { - return %arg0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.index_select"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.index_put"(%arg0: !torch.list, %arg1: !torch.list>>, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.index_put.hacked_twin"(%arg0: !torch.list, %arg1: !torch.list>, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.embedding_bag.padding_idx"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.bool, %arg6: !torch.optional>, %arg7: !torch.bool, %arg8: !torch.optional) -> !torch.tuple, list, list, list> { - %0 = call @__torch__._embedding_bag_helper(%arg0, %arg1, %arg2, %arg7, %arg4) : (!torch.list, !torch.list, !torch.list, !torch.bool, !torch.int) -> !torch.tuple, list, list, list> - return %0 : !torch.tuple, list, list, list> - } - func.func @__torch__._embedding_bag_helper(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.bool, %arg4: !torch.int) -> !torch.tuple, list, list, list> { - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int2 = torch.constant.int 2 - %int1 = torch.constant.int 1 - %int0 = torch.constant.int 0 - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %1 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %3 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %3 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %4 = torch.aten.len.t %arg2 : !torch.list -> !torch.int - %5 = torch.aten.eq.int %4, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %5 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %6 = torch.prim.ListConstruct : () -> !torch.list - %7 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int - %8 = torch.prim.If %arg3 -> (!torch.int) { - %19 = torch.aten.sub.int %7, %int1 : !torch.int, !torch.int -> !torch.int - torch.prim.If.yield %19 : !torch.int - } else { - torch.prim.If.yield %7 : !torch.int - } - %9 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int - %10 = torch.aten.append.t %6, %8 : !torch.list, !torch.int -> !torch.list - %11 = torch.aten.append.t %6, %9 : !torch.list, !torch.int -> !torch.list - %12 = torch.prim.ListConstruct : () -> !torch.list - %13 = torch.aten.eq.int %arg4, %int1 : !torch.int, 
!torch.int -> !torch.bool - %14 = torch.prim.If %13 -> (!torch.list) { - %19 = torch.aten.append.t %12, %int0 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield %12 : !torch.list - } else { - %19 = func.call @__torch__.torch.jit._shape_functions._copy(%arg1) : (!torch.list) -> !torch.list - torch.prim.If.yield %19 : !torch.list - } - %15 = call @__torch__.torch.jit._shape_functions._copy(%arg2) : (!torch.list) -> !torch.list - %16 = torch.aten.eq.int %arg4, %int2 : !torch.int, !torch.int -> !torch.bool - %17 = torch.prim.If %16 -> (!torch.list) { - %19 = func.call @__torch__.torch.jit._shape_functions._copy(%6) : (!torch.list) -> !torch.list - torch.prim.If.yield %19 : !torch.list - } else { - %19 = func.call @__torch__.torch.jit._shape_functions._copy(%arg2) : (!torch.list) -> !torch.list - torch.prim.If.yield %19 : !torch.list - } - %18 = torch.prim.TupleConstruct %6, %14, %15, %17 : !torch.list, !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list, list> - return %18 : !torch.tuple, list, list, list> - } - func.func @"__torch_mlir_shape_fn.aten._embedding_bag"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.bool, %arg6: !torch.optional>, %arg7: !torch.bool, %arg8: !torch.int) -> !torch.tuple, list, list, list> { - %0 = call @__torch__._embedding_bag_helper(%arg0, %arg1, %arg2, %arg7, %arg4) : (!torch.list, !torch.list, !torch.list, !torch.bool, !torch.int) -> !torch.tuple, list, list, list> - return %0 : !torch.tuple, list, list, list> - } - func.func @"__torch_mlir_shape_fn.aten.nll_loss_forward"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple, list> { - %int-1 = torch.constant.int -1 - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %false = torch.constant.bool false - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 - %int1 = torch.constant.int 1 - %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %2 = torch.aten.lt.int %int0, %0 : !torch.int, !torch.int -> !torch.bool - %3 = torch.prim.If %2 -> (!torch.bool) { - %15 = torch.aten.le.int %0, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %15 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If %3 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %4 = torch.aten.le.int %1, %int1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %4 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %5 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool - %6 = torch.prim.If %5 -> (!torch.bool) { - %15 = torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %15 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %7 = torch.prim.If %6 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %15 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %16 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %17 = torch.aten.eq.int %15, %16 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %17 : !torch.bool - } - torch.prim.If %7 -> () { - torch.prim.If.yield - } else { - 
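The __torch__._embedding_bag_helper IR above reduces to simple shape arithmetic once its three rank assertions (weight 2-D, indices 1-D, offsets 1-D) pass. Re-expressed as a hypothetical C++ helper for readability (the function and every name below are ours, not part of the patch):

#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Returns the shapes of (output, offset2bag, bag_size, max_indices),
// mirroring the asserts and arithmetic in the generated IR above.
std::array<Shape, 4> embeddingBagShapes(const Shape &weight,
                                        const Shape &indices,
                                        const Shape &offsets,
                                        bool includeLastOffset, int64_t mode) {
  assert(weight.size() == 2 && indices.size() == 1 && offsets.size() == 1);
  // One bag per offset, minus the sentinel entry when include_last_offset.
  int64_t numBags = includeLastOffset ? offsets[0] - 1 : offsets[0];
  Shape output = {numBags, weight[1]};
  // mode: 0 = sum, 1 = mean, 2 = max; mean yields an empty offset2bag.
  Shape offset2bag = (mode == 1) ? Shape{0} : indices;
  Shape bagSize = offsets;
  Shape maxIndices = (mode == 2) ? output : offsets;
  return {output, offset2bag, bagSize, maxIndices};
}

So for weight [V, D], indices [n], and offsets [B] with include_last_offset set, the bag output is [B - 1, D].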
torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %8 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int - %9 = torch.prim.ListConstruct : () -> !torch.list - %10 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool - %11 = torch.prim.If %10 -> (!torch.bool) { - torch.prim.If.yield %true : !torch.bool - } else { - %15 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list - %16 = torch.aten.len.t %15 : !torch.list -> !torch.int - %17 = torch.aten.eq.int %16, %int1 : !torch.int, !torch.int -> !torch.bool - %18 = torch.prim.If %17 -> (!torch.bool) { - %19 = torch.aten.__getitem__.t %15, %int0 : !torch.list, !torch.int -> !torch.int - %20 = torch.aten.eq.int %19, %8 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %20 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - torch.prim.If.yield %18 : !torch.bool - } - torch.prim.If %11 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %12 = torch.aten.eq.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool - %13 = torch.prim.If %12 -> (!torch.bool) { - %15 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool - torch.prim.If.yield %15 : !torch.bool - } else { - torch.prim.If.yield %false : !torch.bool - } - %14 = torch.prim.If %13 -> (!torch.tuple, list>) { - %15 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %16 = torch.prim.ListConstruct %15 : (!torch.int) -> !torch.list - %17 = torch.prim.TupleConstruct %16, %9 : !torch.list, !torch.list -> !torch.tuple, list> - torch.prim.If.yield %17 : !torch.tuple, list> - } else { - %15 = torch.prim.TupleConstruct %9, %9 : !torch.list, !torch.list -> !torch.tuple, list> - torch.prim.If.yield %15 : !torch.tuple, list> - } - return %14 : !torch.tuple, list> - } - func.func @"__torch_mlir_shape_fn.aten.nll_loss_backward"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional>, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.native_layer_norm"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float) -> !torch.tuple, list, list> { - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: " - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %0 = torch.prim.ListConstruct : () -> !torch.list - %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %3 = torch.aten.sub.int %1, %2 : !torch.int, !torch.int -> !torch.int - %4 = torch.aten.ge.int %3, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %4 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - torch.prim.Loop %3, %true, init() { - ^bb0(%arg5: !torch.int): - %8 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list, !torch.int -> !torch.int - %9 = torch.aten.append.t %0, %8 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %6 = torch.aten.__range_length %3, %5, %int1 : 
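After its assertions (input rank in {1, 2}, target rank <= 1, matching batch sizes, and an optional weight vector whose length equals the class dimension input[-1]), the nll_loss_forward shape rule above reduces to a two-way case split. A minimal C++ sketch (names ours, not from the patch):

#include <cstdint>
#include <utility>
#include <vector>

using Shape = std::vector<int64_t>;

// Returns the shapes of (output, total_weight). Only reduction == 0
// ("none") on a 2-D input keeps the batch dimension; every other
// combination produces two rank-0 (scalar) results.
std::pair<Shape, Shape> nllLossForwardShapes(const Shape &input,
                                             int64_t reduction) {
  if (reduction == 0 && input.size() == 2)
    return {Shape{input[0]}, Shape{}};
  return {Shape{}, Shape{}};
}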
!torch.int, !torch.int, !torch.int -> !torch.int - torch.prim.Loop %6, %true, init() { - ^bb0(%arg5: !torch.int): - %8 = torch.aten.append.t %0, %int1 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %7 = torch.prim.TupleConstruct %arg0, %0, %0 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list> - return %7 : !torch.tuple, list, list> - } - func.func @"__torch_mlir_shape_fn.aten.native_batch_norm"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple, list, list> { - %int1 = torch.constant.int 1 - %int0 = torch.constant.int 0 - %0 = torch.prim.If %arg5 -> (!torch.tuple, list, list>) { - %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int - %2 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list - %3 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int - %4 = torch.prim.ListConstruct %3 : (!torch.int) -> !torch.list - %5 = torch.prim.TupleConstruct %arg0, %2, %4 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list> - torch.prim.If.yield %5 : !torch.tuple, list, list> - } else { - %1 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list - %2 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list - %3 = torch.prim.TupleConstruct %arg0, %1, %2 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list> - torch.prim.If.yield %3 : !torch.tuple, list, list> - } - return %0 : !torch.tuple, list, list> - } - func.func @"__torch_mlir_shape_fn.aten.constant_pad_nd"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float) -> !torch.list { - %0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @__torch__.pad_shape_fn(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %true = torch.constant.bool true - %str = torch.constant.str "AssertionError: Number of padded dimensions must be less than or equal to the input dimension" - %none = torch.constant.none - %str_0 = torch.constant.str "AssertionError: Must have paired low-high pad amount values" - %int2 = torch.constant.int 2 - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %1 = torch.aten.remainder.int %0, %int2 : !torch.int, !torch.int -> !torch.int - %2 = torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %2 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %3 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %4 = torch.aten.floordiv.int %3, %int2 : !torch.int, !torch.int -> !torch.int - %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %6 = torch.aten.le.int %4, %5 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %6 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %7 = torch.aten.len.t %arg1 : !torch.list -> !torch.int - %8 = torch.aten.floordiv.int %7, %int2 : !torch.int, !torch.int -> !torch.int - torch.prim.Loop %8, %true, init() { - ^bb0(%arg2: !torch.int): - %9 = torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int - %10 = torch.aten.neg.int %9 : !torch.int -> !torch.int - %11 = torch.aten.mul.int %int2, %arg2 : 
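The native_batch_norm shape rule above is equally small: the normalized output always keeps the input shape, and only the training flag decides whether the saved statistics are per-channel or empty. A sketch (names ours):

#include <array>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Returns the shapes of (output, save_mean, save_invstd). Training mode
// saves one statistic per channel (dim 1 of the input); inference mode
// returns empty ([0]) statistics, as in the IR's else branch.
std::array<Shape, 3> nativeBatchNormShapes(const Shape &input, bool training) {
  Shape stats = training ? Shape{input[1]} : Shape{0};
  return {input, stats, stats};
}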
!torch.int, !torch.int -> !torch.int - %12 = torch.aten.__getitem__.t %arg1, %11 : !torch.list, !torch.int -> !torch.int - %13 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int - %14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int - %15 = torch.aten.__getitem__.t %arg1, %14 : !torch.list, !torch.int -> !torch.int - %16 = torch.aten.add.int %12, %15 : !torch.int, !torch.int -> !torch.int - %17 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int - %18 = torch.aten.add.int %17, %16 : !torch.int, !torch.int -> !torch.int - %19 = torch.aten._set_item.t %arg0, %10, %18 : !torch.list, !torch.int, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - return %arg0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.pad"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.str, %arg3: !torch.optional) -> !torch.list { - %0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.index.Tensor"(%arg0: !torch.list, %arg1: !torch.list>>) -> !torch.list { - %false = torch.constant.bool false - %int-1 = torch.constant.int -1 - %true = torch.constant.bool true - %none = torch.constant.none - %str = torch.constant.str "AssertionError: More indices than dimensions to index" - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %int9223372036854775807 = torch.constant.int 9223372036854775807 - %0 = torch.aten.len.t %arg1 : !torch.list>> -> !torch.int - %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %2 = torch.aten.le.int %0, %1 : !torch.int, !torch.int -> !torch.bool - torch.prim.If %2 -> () { - torch.prim.If.yield - } else { - torch.prim.RaiseException %str, %none : !torch.str, !torch.none - torch.prim.If.yield - } - %3 = torch.prim.ListConstruct : () -> !torch.list - %4 = torch.prim.ListConstruct : () -> !torch.list - %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int - %6 = torch.prim.Loop %5, %true, init(%3) { - ^bb0(%arg2: !torch.int, %arg3: !torch.list): - %10 = torch.aten.len.t %arg1 : !torch.list>> -> !torch.int - %11 = torch.aten.ge.int %arg2, %10 : !torch.int, !torch.int -> !torch.bool - %12 = torch.prim.If %11 -> (!torch.list) { - %13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int - %14 = torch.aten.append.t %4, %13 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield %arg3 : !torch.list - } else { - %13 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list>>, !torch.int -> !torch.optional> - %14 = torch.aten.__isnot__ %13, %none : !torch.optional>, !torch.none -> !torch.bool - %15 = torch.prim.If %14 -> (!torch.list) { - %16 = torch.prim.unchecked_cast %13 : !torch.optional> -> !torch.list - %17 = func.call @__torch__.torch.jit._shape_functions.broadcast(%arg3, %16) : (!torch.list, !torch.list) -> !torch.list - torch.prim.If.yield %17 : !torch.list - } else { - %16 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int - %17 = torch.aten.append.t %4, %16 : !torch.list, !torch.int -> !torch.list - torch.prim.If.yield %arg3 : !torch.list - } - torch.prim.If.yield %15 : !torch.list - } - torch.prim.Loop.condition %true, iter(%12 : !torch.list) - } : (!torch.int, !torch.bool, !torch.list) -> !torch.list - %7 = torch.aten.len.t %4 : !torch.list -> !torch.int - %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - %9 = torch.prim.If %8 -> (!torch.list) { - 
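__torch__.pad_shape_fn above walks the pad list in low/high pairs and grows the trailing dimensions of the input shape in place; aten.constant_pad_nd and aten.pad both delegate to it. A C++ sketch of the same arithmetic (names ours):

#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// pad is ordered (last_dim_low, last_dim_high, next_low, next_high, ...),
// so pair i adjusts the i-th dimension counted from the end.
Shape padShape(Shape shape, const std::vector<int64_t> &pad) {
  assert(pad.size() % 2 == 0 && "Must have paired low-high pad amount values");
  assert(pad.size() / 2 <= shape.size() &&
         "Number of padded dimensions must be <= the input dimension");
  for (size_t i = 0; i < pad.size() / 2; ++i)
    shape[shape.size() - 1 - i] += pad[2 * i] + pad[2 * i + 1];
  return shape;
}

For example, shape [2, 3, 4] with pad (1, 1, 0, 2) becomes [2, 5, 6].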
torch.prim.If.yield %6 : !torch.list - } else { - %10 = torch.aten.len.t %arg1 : !torch.list>> -> !torch.int - %11 = torch.prim.ListConstruct %int9223372036854775807, %10 : (!torch.int, !torch.int) -> !torch.list - %12 = torch.prim.min.self_int %11 : !torch.list -> !torch.int - %13:2 = torch.prim.Loop %12, %true, init(%true, %int-1) { - ^bb0(%arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.int): - %16 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list>>, !torch.int -> !torch.optional> - %17 = torch.aten.__isnot__ %16, %none : !torch.optional>, !torch.none -> !torch.bool - %18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) { - %19 = torch.aten.eq.int %arg4, %int-1 : !torch.int, !torch.int -> !torch.bool - %20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) { - torch.prim.If.yield %arg3, %arg2 : !torch.bool, !torch.int - } else { - %21 = torch.aten.sub.int %arg2, %arg4 : !torch.int, !torch.int -> !torch.int - %22 = torch.aten.ne.int %21, %int1 : !torch.int, !torch.int -> !torch.bool - %23 = torch.prim.If %22 -> (!torch.bool) { - torch.prim.If.yield %false : !torch.bool - } else { - torch.prim.If.yield %arg3 : !torch.bool - } - torch.prim.If.yield %23, %arg4 : !torch.bool, !torch.int - } - torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int - } else { - torch.prim.If.yield %arg3, %arg4 : !torch.bool, !torch.int - } - torch.prim.Loop.condition %true, iter(%18#0, %18#1 : !torch.bool, !torch.int) - } : (!torch.int, !torch.bool, !torch.bool, !torch.int) -> (!torch.bool, !torch.int) - %14 = torch.aten.__not__ %13#0 : !torch.bool -> !torch.bool - %15 = torch.prim.If %14 -> (!torch.list) { - %16 = torch.aten.add.t %6, %4 : !torch.list, !torch.list -> !torch.list - torch.prim.If.yield %16 : !torch.list - } else { - %16 = torch.prim.ListConstruct : () -> !torch.list - torch.prim.Loop %13#1, %true, init() { - ^bb0(%arg2: !torch.int): - %20 = torch.aten.__getitem__.t %4, %arg2 : !torch.list, !torch.int -> !torch.int - %21 = torch.aten.append.t %16, %20 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %17 = torch.aten.len.t %6 : !torch.list -> !torch.int - torch.prim.Loop %17, %true, init() { - ^bb0(%arg2: !torch.int): - %20 = torch.aten.__getitem__.t %6, %arg2 : !torch.list, !torch.int -> !torch.int - %21 = torch.aten.append.t %16, %20 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - %18 = torch.aten.len.t %4 : !torch.list -> !torch.int - %19 = torch.aten.__range_length %13#1, %18, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int - torch.prim.Loop %19, %true, init() { - ^bb0(%arg2: !torch.int): - %20 = torch.aten.__derive_index %arg2, %13#1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int - %21 = torch.aten.__getitem__.t %4, %20 : !torch.list, !torch.int -> !torch.int - %22 = torch.aten.append.t %16, %21 : !torch.list, !torch.int -> !torch.list - torch.prim.Loop.condition %true, iter() - } : (!torch.int, !torch.bool) -> () - torch.prim.If.yield %16 : !torch.list - } - torch.prim.If.yield %15 : !torch.list - } - return %9 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.cat"(%arg0: !torch.list>, %arg1: !torch.int) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.cat(%arg0, %arg1) : (!torch.list>, !torch.int) -> !torch.list - return %0 : !torch.list - } - func.func @"__torch_mlir_shape_fn.aten.bincount"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.int) -> !torch.list { - %0 = call 
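The aten.index.Tensor rule that ends above implements upstream advanced-indexing shape semantics in two phases: first, every dimension covered by a non-None index folds into one common broadcast shape while the remaining dimensions pass through; second, the broadcast block either replaces the indexed dimensions in place (when the advanced indices sit next to each other) or is hoisted to the front of the result (when they are separated). The first phase, sketched in C++ (both helpers and their names are ours; broadcastShapes is a simplified stand-in for torch.jit._shape_functions.broadcast):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

using Shape = std::vector<int64_t>;

// Simplified right-aligned broadcasting of two shapes.
Shape broadcastShapes(const Shape &a, const Shape &b) {
  Shape out(std::max(a.size(), b.size()));
  for (size_t i = 0; i < out.size(); ++i) {
    int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    assert((da == db || da == 1 || db == 1) && "shapes not broadcastable");
    out[out.size() - 1 - i] = std::max(da, db);
  }
  return out;
}

// Phase one: indexed dims fold into one broadcast shape, unindexed dims
// pass through in order.
std::pair<Shape, Shape> splitIndexedDims(
    const Shape &input, const std::vector<std::optional<Shape>> &indices) {
  assert(indices.size() <= input.size() &&
         "More indices than dimensions to index");
  Shape broadcast, unindexed;
  for (size_t i = 0; i < input.size(); ++i) {
    if (i >= indices.size() || !indices[i])
      unindexed.push_back(input[i]);
    else
      broadcast = broadcastShapes(broadcast, *indices[i]);
  }
  return {broadcast, unindexed};
}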
@__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int
-    %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list
-    return %1 : !torch.list
-  }
-  func.func @__torch__.hacky_get_unknown_dimension_size() -> !torch.int {
-    %0 = torch.prim.CreateObject !torch.nn.Module<"__torch__.DummyClassType">
-    %1 = torch.prim.CallMethod %0["__init__"] () : !torch.nn.Module<"__torch__.DummyClassType">, () -> !torch.none
-    %2 = torch.operator "prim.id"(%0) : (!torch.nn.Module<"__torch__.DummyClassType">) -> !torch.int
-    return %2 : !torch.int
-  }
-  func.func @__torch__.DummyClassType.__init__(%arg0: !torch.nn.Module<"__torch__.DummyClassType">) -> !torch.none {
-    %none = torch.constant.none
-    return %none : !torch.none
-  }
-  func.func @"__torch_mlir_shape_fn.aten.linalg_vector_norm"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.optional>, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.list {
-    %0 = torch.derefine %arg4 : !torch.optional to !torch.any
-    %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg2, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list
-    return %1 : !torch.list
-  }
-}
-)mlir");
+#endif
+  // clang-format off
+  return "module {\n"
+"  func.func @__torch__.torch._decomp.decompositions.nll_loss_backward(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.tensor, %arg3: !torch.optional, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.tensor) -> !torch.tensor {\n"
+"    %float-1.000000e00 = torch.constant.float -1.000000e+00\n"
+"    %str = torch.constant.str \"Expected a single element grad_output tensor, but got: {}\"\n"
+"    %str_0 = torch.constant.str \"Expected a tensor of dimension 1 and tensor.size[0] == {} but got: dimension {} and tensor.size[0] == {}\"\n"
+"    %str_1 = torch.constant.str \"AssertionError: weight tensor should be defined either for all or no classes\"\n"
+"    %int-1 = torch.constant.int -1\n"
+"    %str_2 = torch.constant.str \"{} ({} elements)\"\n"
+"    %str_3 = torch.constant.str \"expected total_weight to be a single element tensor, got: \"\n"
+"    %str_4 = torch.constant.str \"AssertionError: \"\n"
+"    %str_5 = torch.constant.str \"size mismatch (got input: {}, target: {})\"\n"
+"    %true = torch.constant.bool true\n"
+"    %str_6 = torch.constant.str \"AssertionError: 0D or 1D target tensor expected, multi-target not supported\"\n"
+"    %none = torch.constant.none\n"
+"    %str_7 = torch.constant.str \"AssertionError: input tensor should be 1D or 2D\"\n"
+"    %false = torch.constant.bool false\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int2 = torch.constant.int 2\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.prim.Uninitialized : !torch.optional\n"
+"    %1 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n"
+"    %2 = torch.aten.le.int %int0, %1 : !torch.int, !torch.int -> !torch.bool\n"
+"    %3 = torch.prim.If %2 -> (!torch.bool) {\n"
+"      %35 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n"
+"      %36 = torch.aten.le.int %35, %int2 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %36 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %3 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_7, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %4 = torch.aten.dim %arg2 : !torch.tensor -> !torch.int\n"
+"    %5 = torch.aten.le.int %4, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %5 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
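From here on the hunk switches how the generated library is embedded: the module is no longer one giant raw string literal but a sequence of adjacent C string literals, one per line of IR, which the compiler concatenates; // clang-format off keeps the formatter from re-wrapping the generated lines, and the #endif closes a preprocessor conditional whose opening lies outside this hunk. Keeping each literal short sidesteps per-literal length limits in some compilers (MSVC, for example, rejects single literals beyond roughly 16K characters) and keeps regeneration diffs line-oriented. A minimal sketch of the technique (the function and IR below are ours, not from the patch):

// Adjacent string literals concatenate at compile time, so a large
// generated MLIR module can be emitted one short literal per IR line.
const char *getExampleModuleAsm() {
  // clang-format off
  return "module {\n"
"  func.func @f(%arg0: !torch.int) -> !torch.int {\n"
"    return %arg0 : !torch.int\n"
"  }\n"
"}\n";
  // clang-format on
}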
torch.prim.RaiseException %str_6, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" +" %7 = torch.aten.eq.int %6, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %35 = torch.aten.dim %arg2 : !torch.tensor -> !torch.int\n" +" %36 = torch.aten.eq.int %35, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %36 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %35 = torch.aten.size.int %arg1, %int0 : !torch.tensor, !torch.int -> !torch.int\n" +" %36 = torch.aten.size.int %arg2, %int0 : !torch.tensor, !torch.int -> !torch.int\n" +" %37 = torch.aten.eq.int %35, %36 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %37 : !torch.bool\n" +" }\n" +" torch.prim.If %9 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %35 = torch.aten.size %arg1 : !torch.tensor -> !torch.list\n" +" %36 = torch.aten.size %arg2 : !torch.tensor -> !torch.list\n" +" %37 = torch.aten.format(%str_5, %35, %36) : !torch.str, !torch.list, !torch.list -> !torch.str\n" +" %38 = torch.aten.add.str %str_4, %37 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %38, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %10 = torch.aten.numel %arg6 : !torch.tensor -> !torch.int\n" +" %11 = torch.aten.eq.int %10, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %35 = torch.aten.size %arg6 : !torch.tensor -> !torch.list\n" +" %36 = torch.aten.numel %arg6 : !torch.tensor -> !torch.int\n" +" %37 = torch.aten.format(%str_2, %35, %36) : !torch.str, !torch.list, !torch.int -> !torch.str\n" +" %38 = torch.prim.TupleConstruct %str_3, %37 : !torch.str, !torch.str -> !torch.tuple\n" +" %39 = torch.aten.str %38 : !torch.tuple -> !torch.str\n" +" %40 = torch.aten.add.str %str_4, %39 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %40, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %12 = torch.aten.__is__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %13 = torch.prim.If %12 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %35 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.tensor\n" +" %36 = torch.aten.numel %35 : !torch.tensor -> !torch.int\n" +" %37 = torch.aten.size.int %arg1, %int-1 : !torch.tensor, !torch.int -> !torch.int\n" +" %38 = torch.aten.eq.int %36, %37 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %38 : !torch.bool\n" +" }\n" +" %14 = torch.prim.If %13 -> (!torch.optional) {\n" +" torch.prim.If.yield %arg3 : !torch.optional\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.optional\n" +" }\n" +" %15 = torch.aten.eq.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.bool) {\n" +" %35 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" +" %36 = torch.aten.eq.int %35, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %36 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %16 -> () {\n" +" %35 = torch.aten.dim %arg0 : !torch.tensor -> !torch.int\n" +" %36 = torch.aten.eq.int 
%35, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %37 = torch.prim.If %36 -> (!torch.bool) {\n" +" %38 = torch.aten.size.int %arg0, %int0 : !torch.tensor, !torch.int -> !torch.int\n" +" %39 = torch.aten.size.int %arg1, %int0 : !torch.tensor, !torch.int -> !torch.int\n" +" %40 = torch.aten.eq.int %38, %39 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %40 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %37 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %38 = torch.aten.size.int %arg1, %int0 : !torch.tensor, !torch.int -> !torch.int\n" +" %39 = torch.aten.dim %arg0 : !torch.tensor -> !torch.int\n" +" %40 = torch.aten.size.int %arg0, %int0 : !torch.tensor, !torch.int -> !torch.int\n" +" %41 = torch.aten.format(%str_0, %38, %39, %40) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n" +" %42 = torch.aten.add.str %str_4, %41 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %42, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" %35 = torch.aten.dim %arg0 : !torch.tensor -> !torch.int\n" +" %36 = torch.aten.le.int %35, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %37 = torch.prim.If %36 -> (!torch.bool) {\n" +" %38 = torch.aten.numel %arg0 : !torch.tensor -> !torch.int\n" +" %39 = torch.aten.eq.int %38, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %39 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %37 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %38 = torch.aten.size %arg0 : !torch.tensor -> !torch.list\n" +" %39 = torch.aten.format(%str, %38) : !torch.str, !torch.list -> !torch.str\n" +" %40 = torch.aten.add.str %str_4, %39 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %40, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" }\n" +" %17 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" +" %18 = torch.aten.lt.int %17, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %19 = torch.prim.If %18 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %20 = torch.aten.eq.int %arg4, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %21 = torch.prim.If %20 -> (!torch.tensor) {\n" +" %35 = torch.aten.div.Tensor %arg0, %arg6 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" torch.prim.If.yield %35 : !torch.tensor\n" +" } else {\n" +" torch.prim.If.yield %arg0 : !torch.tensor\n" +" }\n" +" %22 = torch.aten.unsqueeze %arg2, %19 : !torch.tensor, !torch.int -> !torch.tensor\n" +" %23 = torch.aten.zeros_like %arg1, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" +" %24 = torch.operator \"aten.scatter.value\"(%23, %19, %22, %float-1.000000e00) : (!torch.tensor, !torch.int, !torch.tensor, !torch.float) -> !torch.tensor\n" +" %25 = torch.aten.dim %24 : !torch.tensor -> !torch.int\n" +" %26 = torch.aten.dim %21 : !torch.tensor -> !torch.int\n" +" %27 = torch.aten.gt.int %25, %26 : !torch.int, !torch.int -> !torch.bool\n" +" %28 = torch.prim.If %27 -> (!torch.bool) {\n" +" %35 = torch.aten.dim %21 : !torch.tensor -> !torch.int\n" +" %36 = torch.aten.gt.int %35, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %36 : !torch.bool\n" +" } else {\n" +" 
torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %29 = torch.prim.If %28 -> (!torch.tensor) {\n" +" %35 = torch.aten.unsqueeze %21, %19 : !torch.tensor, !torch.int -> !torch.tensor\n" +" torch.prim.If.yield %35 : !torch.tensor\n" +" } else {\n" +" torch.prim.If.yield %21 : !torch.tensor\n" +" }\n" +" %30 = torch.aten.__isnot__ %14, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %31 = torch.prim.If %30 -> (!torch.tensor) {\n" +" %35 = torch.prim.unchecked_cast %14 : !torch.optional -> !torch.tensor\n" +" %36 = torch.prim.ListConstruct : () -> !torch.list\n" +" %37 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" +" torch.prim.Loop %37, %true, init() {\n" +" ^bb0(%arg7: !torch.int):\n" +" %42 = torch.aten.append.t %36, %int1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %38 = torch.aten.size.int %35, %int0 : !torch.tensor, !torch.int -> !torch.int\n" +" %39 = torch.aten._set_item.t %36, %19, %38 : !torch.list, !torch.int, !torch.int -> !torch.list\n" +" %40 = torch.aten.reshape %35, %36 : !torch.tensor, !torch.list -> !torch.tensor\n" +" %41 = torch.aten.mul.Tensor %29, %40 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" torch.prim.If.yield %41 : !torch.tensor\n" +" } else {\n" +" torch.prim.If.yield %29 : !torch.tensor\n" +" }\n" +" %32 = torch.aten.ge.int %arg5, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %33 = torch.prim.If %32 -> (!torch.tensor) {\n" +" %35 = torch.aten.ne.Scalar %22, %arg5 : !torch.tensor, !torch.int -> !torch.tensor\n" +" %36 = torch.aten.where.ScalarOther %35, %31, %int0 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" +" torch.prim.If.yield %36 : !torch.tensor\n" +" } else {\n" +" torch.prim.If.yield %31 : !torch.tensor\n" +" }\n" +" %34 = torch.aten.mul.Tensor %24, %33 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" return %34 : !torch.tensor\n" +" }\n" +" func.func @__torch__.torch._decomp.decompositions._nll_loss_backward(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.tensor, %arg3: !torch.optional, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.tensor) -> !torch.tensor {\n" +" %true = torch.constant.bool true\n" +" %false = torch.constant.bool false\n" +" %float-1.000000e00 = torch.constant.float -1.000000e+00\n" +" %none = torch.constant.none\n" +" %int2 = torch.constant.int 2\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" +" %1 = torch.aten.lt.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %3 = torch.aten.eq.int %arg4, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.tensor) {\n" +" %18 = torch.aten.div.Tensor %arg0, %arg6 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" torch.prim.If.yield %18 : !torch.tensor\n" +" } else {\n" +" torch.prim.If.yield %arg0 : !torch.tensor\n" +" }\n" +" %5 = torch.aten.unsqueeze %arg2, %2 : !torch.tensor, !torch.int -> !torch.tensor\n" +" %6 = torch.aten.zeros_like %arg1, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" +" %7 = torch.operator \"aten.scatter.value\"(%6, %2, %5, %float-1.000000e00) : (!torch.tensor, !torch.int, !torch.tensor, !torch.float) -> !torch.tensor\n" +" %8 = 
torch.aten.dim %7 : !torch.tensor -> !torch.int\n" +" %9 = torch.aten.dim %4 : !torch.tensor -> !torch.int\n" +" %10 = torch.aten.gt.int %8, %9 : !torch.int, !torch.int -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.bool) {\n" +" %18 = torch.aten.dim %4 : !torch.tensor -> !torch.int\n" +" %19 = torch.aten.gt.int %18, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %19 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %12 = torch.prim.If %11 -> (!torch.tensor) {\n" +" %18 = torch.aten.unsqueeze %4, %2 : !torch.tensor, !torch.int -> !torch.tensor\n" +" torch.prim.If.yield %18 : !torch.tensor\n" +" } else {\n" +" torch.prim.If.yield %4 : !torch.tensor\n" +" }\n" +" %13 = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %14 = torch.prim.If %13 -> (!torch.tensor) {\n" +" %18 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.tensor\n" +" %19 = torch.prim.ListConstruct : () -> !torch.list\n" +" %20 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" +" torch.prim.Loop %20, %true, init() {\n" +" ^bb0(%arg7: !torch.int):\n" +" %25 = torch.aten.append.t %19, %int1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %21 = torch.aten.size.int %18, %int0 : !torch.tensor, !torch.int -> !torch.int\n" +" %22 = torch.aten._set_item.t %19, %2, %21 : !torch.list, !torch.int, !torch.int -> !torch.list\n" +" %23 = torch.aten.reshape %18, %19 : !torch.tensor, !torch.list -> !torch.tensor\n" +" %24 = torch.aten.mul.Tensor %12, %23 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" torch.prim.If.yield %24 : !torch.tensor\n" +" } else {\n" +" torch.prim.If.yield %12 : !torch.tensor\n" +" }\n" +" %15 = torch.aten.ge.int %arg5, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.tensor) {\n" +" %18 = torch.aten.ne.Scalar %5, %arg5 : !torch.tensor, !torch.int -> !torch.tensor\n" +" %19 = torch.aten.where.ScalarOther %18, %14, %int0 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" +" torch.prim.If.yield %19 : !torch.tensor\n" +" } else {\n" +" torch.prim.If.yield %14 : !torch.tensor\n" +" }\n" +" %17 = torch.aten.mul.Tensor %7, %16 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" return %17 : !torch.tensor\n" +" }\n" +" func.func @__torch__.torch._decomp.decompositions.nll_loss2d_backward(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.tensor, %arg3: !torch.optional, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.tensor) -> !torch.tensor {\n" +" %true = torch.constant.bool true\n" +" %float-1.000000e00 = torch.constant.float -1.000000e+00\n" +" %str = torch.constant.str \"expected total_weight to be a single element tensor, got: {} ( {}, elements)\"\n" +" %str_0 = torch.constant.str \"size mismatch (got input: {}, target: {}\"\n" +" %false = torch.constant.bool false\n" +" %str_1 = torch.constant.str \"only batches of spatial targets supported (3D tensors) but got targets of dimension: {}\"\n" +" %none = torch.constant.none\n" +" %str_2 = torch.constant.str \"AssertionError: \"\n" +" %str_3 = torch.constant.str \"only batches of spatial inputs supported (4D tensors), but got input of dimension: {}\"\n" +" %int4 = torch.constant.int 4\n" +" %int3 = torch.constant.int 3\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" +" %1 = 
torch.aten.eq.int %0, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %29 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" +" %30 = torch.aten.format(%str_3, %29) : !torch.str, !torch.int -> !torch.str\n" +" %31 = torch.aten.add.str %str_2, %30 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %31, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.dim %arg2 : !torch.tensor -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %29 = torch.aten.dim %arg2 : !torch.tensor -> !torch.int\n" +" %30 = torch.aten.format(%str_1, %29) : !torch.str, !torch.int -> !torch.str\n" +" %31 = torch.aten.add.str %str_2, %30 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %31, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.size.int %arg1, %int0 : !torch.tensor, !torch.int -> !torch.int\n" +" %5 = torch.aten.size.int %arg2, %int0 : !torch.tensor, !torch.int -> !torch.int\n" +" %6 = torch.aten.eq.int %4, %5 : !torch.int, !torch.int -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %29 = torch.aten.size.int %arg1, %int2 : !torch.tensor, !torch.int -> !torch.int\n" +" %30 = torch.aten.size.int %arg2, %int1 : !torch.tensor, !torch.int -> !torch.int\n" +" %31 = torch.aten.eq.int %29, %30 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %31 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %29 = torch.aten.size.int %arg1, %int3 : !torch.tensor, !torch.int -> !torch.int\n" +" %30 = torch.aten.size.int %arg2, %int2 : !torch.tensor, !torch.int -> !torch.int\n" +" %31 = torch.aten.eq.int %29, %30 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %31 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %29 = torch.aten.size %arg1 : !torch.tensor -> !torch.list\n" +" %30 = torch.aten.size %arg2 : !torch.tensor -> !torch.list\n" +" %31 = torch.aten.format(%str_0, %29, %30) : !torch.str, !torch.list, !torch.list -> !torch.str\n" +" %32 = torch.aten.add.str %str_2, %31 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %32, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.numel %arg6 : !torch.tensor -> !torch.int\n" +" %10 = torch.aten.eq.int %9, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %29 = torch.aten.size %arg6 : !torch.tensor -> !torch.list\n" +" %30 = torch.aten.numel %arg6 : !torch.tensor -> !torch.int\n" +" %31 = torch.aten.format(%str, %29, %30) : !torch.str, !torch.list, !torch.int -> !torch.str\n" +" %32 = torch.aten.add.str %str_2, %31 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %32, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %11 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" +" %12 = torch.aten.lt.int %11, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %13 = torch.prim.If %12 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %14 = torch.aten.eq.int %arg4, %int1 
: !torch.int, !torch.int -> !torch.bool\n" +" %15 = torch.prim.If %14 -> (!torch.tensor) {\n" +" %29 = torch.aten.div.Tensor %arg0, %arg6 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" torch.prim.If.yield %29 : !torch.tensor\n" +" } else {\n" +" torch.prim.If.yield %arg0 : !torch.tensor\n" +" }\n" +" %16 = torch.aten.unsqueeze %arg2, %13 : !torch.tensor, !torch.int -> !torch.tensor\n" +" %17 = torch.aten.zeros_like %arg1, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" +" %18 = torch.operator \"aten.scatter.value\"(%17, %13, %16, %float-1.000000e00) : (!torch.tensor, !torch.int, !torch.tensor, !torch.float) -> !torch.tensor\n" +" %19 = torch.aten.dim %18 : !torch.tensor -> !torch.int\n" +" %20 = torch.aten.dim %15 : !torch.tensor -> !torch.int\n" +" %21 = torch.aten.gt.int %19, %20 : !torch.int, !torch.int -> !torch.bool\n" +" %22 = torch.prim.If %21 -> (!torch.bool) {\n" +" %29 = torch.aten.dim %15 : !torch.tensor -> !torch.int\n" +" %30 = torch.aten.gt.int %29, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %30 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %23 = torch.prim.If %22 -> (!torch.tensor) {\n" +" %29 = torch.aten.unsqueeze %15, %13 : !torch.tensor, !torch.int -> !torch.tensor\n" +" torch.prim.If.yield %29 : !torch.tensor\n" +" } else {\n" +" torch.prim.If.yield %15 : !torch.tensor\n" +" }\n" +" %24 = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %25 = torch.prim.If %24 -> (!torch.tensor) {\n" +" %29 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.tensor\n" +" %30 = torch.prim.ListConstruct : () -> !torch.list\n" +" %31 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" +" torch.prim.Loop %31, %true, init() {\n" +" ^bb0(%arg7: !torch.int):\n" +" %36 = torch.aten.append.t %30, %int1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %32 = torch.aten.size.int %29, %int0 : !torch.tensor, !torch.int -> !torch.int\n" +" %33 = torch.aten._set_item.t %30, %13, %32 : !torch.list, !torch.int, !torch.int -> !torch.list\n" +" %34 = torch.aten.reshape %29, %30 : !torch.tensor, !torch.list -> !torch.tensor\n" +" %35 = torch.aten.mul.Tensor %23, %34 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" torch.prim.If.yield %35 : !torch.tensor\n" +" } else {\n" +" torch.prim.If.yield %23 : !torch.tensor\n" +" }\n" +" %26 = torch.aten.ge.int %arg5, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %27 = torch.prim.If %26 -> (!torch.tensor) {\n" +" %29 = torch.aten.ne.Scalar %16, %arg5 : !torch.tensor, !torch.int -> !torch.tensor\n" +" %30 = torch.aten.where.ScalarOther %29, %25, %int0 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" +" torch.prim.If.yield %30 : !torch.tensor\n" +" } else {\n" +" torch.prim.If.yield %25 : !torch.tensor\n" +" }\n" +" %28 = torch.aten.mul.Tensor %18, %27 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" return %28 : !torch.tensor\n" +" }\n" +" func.func @__torch__.torch._decomp.decompositions._log_softmax_backward_data(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.int, %arg3: !torch.int) -> !torch.tensor {\n" +" %int1 = torch.constant.int 1\n" +" %none = torch.constant.none\n" +" %true = torch.constant.bool true\n" +" %0 = torch.aten.exp %arg1 : !torch.tensor -> !torch.tensor\n" +" %1 = torch.prim.ListConstruct %arg2 : (!torch.int) 
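All three nll_loss backward decompositions above — nll_loss_backward, the assertion-free _nll_loss_backward, and the 2-D spatial nll_loss2d_backward — realize the same per-element gradient: a scatter of -1.0 at the target class (the aten.scatter.value call), multiplied by the class weight reshaped for broadcast, zeroed at ignore_index (the aten.where.ScalarOther call), with grad_output pre-divided by total_weight only for mean reduction. In LaTeX (notation ours: x the input of shape (N, C), t the targets, w the optional class weights, g the incoming gradient, W the saved total_weight):

\[
  \frac{\partial L}{\partial x_{n,c}}
    = -\,\mathbb{1}[c = t_n]\;\mathbb{1}[t_n \neq \text{ignore\_index}]\;w_c\,\tilde{g}_n,
  \qquad
  \tilde{g} =
  \begin{cases}
    g / W & \text{reduction} = \text{mean}\\
    g & \text{otherwise}
  \end{cases}
\]

(g is per-sample for reduction "none" and scalar otherwise; PyTorch's Reduction enum makes mean the arg4 == 1 branch in the IR.)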
-> !torch.list\n" +" %2 = torch.aten.sum.dim_IntList %arg0, %1, %true, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" +" %3 = torch.aten.mul.Tensor %0, %2 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" %4 = torch.aten.sub.Tensor %arg0, %3, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" +" return %4 : !torch.tensor\n" +" }\n" +" func.func @__torch__.torch._decomp.decompositions._softmax_backward_data(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.int, %arg3: !torch.int) -> !torch.tensor {\n" +" %int1 = torch.constant.int 1\n" +" %none = torch.constant.none\n" +" %true = torch.constant.bool true\n" +" %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" %1 = torch.prim.ListConstruct %arg2 : (!torch.int) -> !torch.list\n" +" %2 = torch.aten.sum.dim_IntList %0, %1, %true, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" +" %3 = torch.aten.mul.Tensor %arg1, %2 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" %4 = torch.aten.sub.Tensor %0, %3, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" +" return %4 : !torch.tensor\n" +" }\n" +" func.func @__torch__.torch._decomp.decompositions.log_sigmoid_forward(%arg0: !torch.tensor) -> !torch.tuple {\n" +" %int1 = torch.constant.int 1\n" +" %none = torch.constant.none\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" %1 = torch.aten.new_zeros %arg0, %0, %none, %none, %none, %none : !torch.tensor, !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" +" %2 = torch.aten.minimum %1, %arg0 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" %3 = torch.aten.abs %arg0 : !torch.tensor -> !torch.tensor\n" +" %4 = torch.aten.neg %3 : !torch.tensor -> !torch.tensor\n" +" %5 = torch.aten.exp %4 : !torch.tensor -> !torch.tensor\n" +" %6 = torch.operator \"prim.is_cuda\"(%arg0) : (!torch.tensor) -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.tensor) {\n" +" %11 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n" +" %12 = torch.aten.new_zeros %arg0, %11, %none, %none, %none, %none : !torch.tensor, !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" +" torch.prim.If.yield %12 : !torch.tensor\n" +" } else {\n" +" torch.prim.If.yield %5 : !torch.tensor\n" +" }\n" +" %8 = torch.aten.log1p %5 : !torch.tensor -> !torch.tensor\n" +" %9 = torch.aten.sub.Tensor %2, %8, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" +" %10 = torch.prim.TupleConstruct %9, %7 : !torch.tensor, !torch.tensor -> !torch.tuple\n" +" return %10 : !torch.tuple\n" +" }\n" +" func.func @__torch__.torch._decomp.decompositions_for_jvp.native_layer_norm_backward(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.list, %arg3: !torch.tensor, %arg4: !torch.tensor, %arg5: !torch.optional, %arg6: !torch.optional, %arg7: !torch.list) -> !torch.tuple, optional, optional> {\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %0 = torch.aten.size %arg1 : !torch.tensor -> !torch.list\n" +" %1 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" +" %2 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %3 = torch.aten.sub.int %1, %2 : !torch.int, !torch.int -> !torch.int\n" +" %4 = torch.aten.slice.t %0, %3, 
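The two *_softmax_backward_data decompositions above are the standard closed forms. With y the saved forward output, g the incoming gradient, and sums taken over the softmax dimension with keepdim (the aten.sum.dim_IntList calls; the trailing %int1 on aten.sub.Tensor is the default alpha, i.e. plain subtraction):

\[
  \text{log\_softmax:}\quad \frac{\partial L}{\partial x} = g - e^{y} \odot \textstyle\sum_d g,
  \qquad
  \text{softmax:}\quad \frac{\partial L}{\partial x} = y \odot g - y \odot \textstyle\sum_d (y \odot g)
\]

The adjacent log_sigmoid_forward decomposition returns min(0, x) - log1p(exp(-|x|)) plus a buffer that holds exp(-|x|) on CPU but is returned empty under the prim.is_cuda branch, matching the eager kernel's unused CUDA buffer.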
%none, %int1 : !torch.list, !torch.int, !torch.none, !torch.int -> !torch.list\n" +" %5 = torch.aten.slice.t %0, %none, %3, %int1 : !torch.list, !torch.none, !torch.int, !torch.int -> !torch.list\n" +" %6 = torch.prim.ListConstruct : () -> !torch.list\n" +" %7 = torch.aten.__range_length %3, %1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %7, %true, init() {\n" +" ^bb0(%arg8: !torch.int):\n" +" %17 = torch.aten.__derive_index %arg8, %3, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %18 = torch.aten.append.t %6, %17 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %8 = torch.prim.ListConstruct : () -> !torch.list\n" +" %9 = torch.aten.__range_length %int0, %3, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %9, %true, init() {\n" +" ^bb0(%arg8: !torch.int):\n" +" %17 = torch.aten.__derive_index %arg8, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %18 = torch.aten.append.t %8, %17 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %10 = torch.aten.len.t %4 : !torch.list -> !torch.int\n" +" %11 = torch.prim.Loop %10, %true, init(%int1) {\n" +" ^bb0(%arg8: !torch.int, %arg9: !torch.int):\n" +" %17 = torch.aten.__getitem__.t %4, %arg8 : !torch.list, !torch.int -> !torch.int\n" +" %18 = torch.aten.mul.int %arg9, %17 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop.condition %true, iter(%18 : !torch.int)\n" +" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" +" %12 = torch.aten.len.t %5 : !torch.list -> !torch.int\n" +" %13 = torch.prim.Loop %12, %true, init(%int1) {\n" +" ^bb0(%arg8: !torch.int, %arg9: !torch.int):\n" +" %17 = torch.aten.__getitem__.t %5, %arg8 : !torch.list, !torch.int -> !torch.int\n" +" %18 = torch.aten.mul.int %arg9, %17 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop.condition %true, iter(%18 : !torch.int)\n" +" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" +" %14 = torch.aten.le.int %13, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %15 = torch.prim.If %14 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %17 = torch.aten.le.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %17 : !torch.bool\n" +" }\n" +" %16 = torch.prim.If %15 -> (!torch.tuple, optional, optional>) {\n" +" %17 = torch.aten.new_zeros %arg1, %0, %none, %none, %none, %none : !torch.tensor, !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" +" %18 = torch.aten.slice.t %0, %3, %none, %int1 : !torch.list, !torch.int, !torch.none, !torch.int -> !torch.list\n" +" %19 = torch.aten.new_zeros %arg1, %18, %none, %none, %none, %none : !torch.tensor, !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" +" %20 = torch.aten.slice.t %0, %3, %none, %int1 : !torch.list, !torch.int, !torch.none, !torch.int -> !torch.list\n" +" %21 = torch.aten.new_zeros %arg1, %20, %none, %none, %none, %none : !torch.tensor, !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" +" %22 = torch.prim.TupleConstruct %17, %19, %21 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple\n" +" %23 = torch.derefine %22 : !torch.tuple to !torch.tuple, optional, optional>\n" +" torch.prim.If.yield %23 : !torch.tuple, optional, optional>\n" +" } else {\n" +" %17 
= torch.aten.mean.dim %arg1, %6, %true, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" +" %18 = torch.aten.var.dim %arg1, %6, %false, %true : !torch.tensor, !torch.list, !torch.bool, !torch.bool -> !torch.tensor\n" +" %19 = torch.aten.reciprocal %arg4 : !torch.tensor -> !torch.tensor\n" +" %20 = torch.aten.pow.Tensor_Scalar %19, %int2 : !torch.tensor, !torch.int -> !torch.tensor\n" +" %21 = torch.aten.sub.Tensor %20, %18, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" +" %22 = torch.aten.detach %21 : !torch.tensor -> !torch.tensor\n" +" %23 = torch.aten.add.Tensor %18, %22, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" +" %24 = torch.aten.sqrt %23 : !torch.tensor -> !torch.tensor\n" +" %25 = torch.aten.reciprocal %24 : !torch.tensor -> !torch.tensor\n" +" %26 = torch.aten.sub.Tensor %arg1, %17, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" +" %27 = torch.aten.mul.Tensor %26, %25 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" %28 = torch.aten.__isnot__ %arg5, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %29 = torch.prim.If %28 -> (!torch.tensor) {\n" +" %46 = torch.prim.unchecked_cast %arg5 : !torch.optional -> !torch.tensor\n" +" %47 = torch.aten.mul.Tensor %arg0, %46 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" torch.prim.If.yield %47 : !torch.tensor\n" +" } else {\n" +" torch.prim.If.yield %arg0 : !torch.tensor\n" +" }\n" +" %30 = torch.aten.mul.Scalar %29, %11 : !torch.tensor, !torch.int -> !torch.tensor\n" +" %31 = torch.aten.sum.dim_IntList %29, %6, %true, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" +" %32 = torch.aten.mul.Tensor %29, %27 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" %33 = torch.aten.sum.dim_IntList %32, %6, %true, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" +" %34 = torch.aten.mul.Tensor %27, %33 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" %35 = torch.aten.sub.Tensor %30, %31, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" +" %36 = torch.aten.sub.Tensor %35, %34, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" +" %37 = torch.aten.__getitem__.t %arg7, %int0 : !torch.list, !torch.int -> !torch.bool\n" +" %38 = torch.prim.If %37 -> (!torch.tensor) {\n" +" %46 = torch.aten.div.Scalar %25, %11 : !torch.tensor, !torch.int -> !torch.tensor\n" +" %47 = torch.aten.mul.Tensor %46, %36 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" torch.prim.If.yield %47 : !torch.tensor\n" +" } else {\n" +" %46 = torch.aten.zeros_like %arg1, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" +" torch.prim.If.yield %46 : !torch.tensor\n" +" }\n" +" %39 = torch.aten.__getitem__.t %arg7, %int1 : !torch.list, !torch.int -> !torch.bool\n" +" %40 = torch.prim.If %39 -> (!torch.bool) {\n" +" %46 = torch.aten.__isnot__ %arg5, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %46 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %41 = torch.prim.If %40 -> (!torch.tensor) {\n" +" %46 = torch.aten.len.t %8 : !torch.list -> !torch.int\n" +" %47 = torch.aten.gt.int %46, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %48 = torch.prim.If %47 -> (!torch.tensor) {\n" +" %49 = torch.aten.mul.Tensor %arg0, %27 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" %50 = 
torch.aten.sum.dim_IntList %49, %8, %false, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" +" torch.prim.If.yield %50 : !torch.tensor\n" +" } else {\n" +" %49 = torch.aten.mul.Tensor %arg0, %27 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" torch.prim.If.yield %49 : !torch.tensor\n" +" }\n" +" torch.prim.If.yield %48 : !torch.tensor\n" +" } else {\n" +" %46 = torch.aten.__isnot__ %arg5, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %47 = torch.prim.If %46 -> (!torch.tensor) {\n" +" %48 = torch.prim.unchecked_cast %arg5 : !torch.optional -> !torch.tensor\n" +" %49 = torch.aten.zeros_like %48, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" +" torch.prim.If.yield %49 : !torch.tensor\n" +" } else {\n" +" %48 = torch.prim.ListConstruct : () -> !torch.list\n" +" %49 = torch.aten.zeros %48, %none, %none, %none, %none : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" +" torch.prim.If.yield %49 : !torch.tensor\n" +" }\n" +" torch.prim.If.yield %47 : !torch.tensor\n" +" }\n" +" %42 = torch.aten.__getitem__.t %arg7, %int2 : !torch.list, !torch.int -> !torch.bool\n" +" %43 = torch.prim.If %42 -> (!torch.bool) {\n" +" %46 = torch.aten.__isnot__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %46 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %44 = torch.prim.If %43 -> (!torch.tensor) {\n" +" %46 = torch.aten.len.t %8 : !torch.list -> !torch.int\n" +" %47 = torch.aten.gt.int %46, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %48 = torch.prim.If %47 -> (!torch.tensor) {\n" +" %49 = torch.aten.sum.dim_IntList %arg0, %8, %false, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" +" torch.prim.If.yield %49 : !torch.tensor\n" +" } else {\n" +" %49 = torch.aten.clone %arg0, %none : !torch.tensor, !torch.none -> !torch.tensor\n" +" torch.prim.If.yield %49 : !torch.tensor\n" +" }\n" +" torch.prim.If.yield %48 : !torch.tensor\n" +" } else {\n" +" %46 = torch.aten.__isnot__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %47 = torch.prim.If %46 -> (!torch.tensor) {\n" +" %48 = torch.prim.unchecked_cast %arg6 : !torch.optional -> !torch.tensor\n" +" %49 = torch.aten.zeros_like %48, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" +" torch.prim.If.yield %49 : !torch.tensor\n" +" } else {\n" +" %48 = torch.prim.ListConstruct : () -> !torch.list\n" +" %49 = torch.aten.zeros %48, %none, %none, %none, %none : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" +" torch.prim.If.yield %49 : !torch.tensor\n" +" }\n" +" torch.prim.If.yield %47 : !torch.tensor\n" +" }\n" +" %45 = torch.prim.TupleConstruct %38, %41, %44 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple, optional, optional>\n" +" torch.prim.If.yield %45 : !torch.tuple, optional, optional>\n" +" }\n" +" return %16 : !torch.tuple, optional, optional>\n" +" }\n" +" func.func @__torch__.torch._decomp.decompositions_for_jvp.recompute_mean_var(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.tuple {\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %0 = torch.aten.mean.dim %arg0, %arg2, 
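native_layer_norm_backward above is the textbook formula, specialized so that value-correct statistics still carry tangents (see the note after recompute_mean_var below). Writing \(\mathcal{I}\) for the normalized (inner) dims, \(\mathcal{O}\) for the rest, N for the element count per normalization group (%11 in the IR), \(\hat{x} = (x - \mu)\,\eta\), and \(a = g \odot \gamma\) when a weight is given (else \(a = g\)), the three results are:

\[
  \frac{\partial L}{\partial x}
    = \frac{\eta}{N}\Big( N a - \textstyle\sum_{\mathcal{I}} a
      - \hat{x} \odot \textstyle\sum_{\mathcal{I}} (a \odot \hat{x}) \Big),
  \qquad
  \frac{\partial L}{\partial \gamma} = \textstyle\sum_{\mathcal{O}} (g \odot \hat{x}),
  \qquad
  \frac{\partial L}{\partial \beta} = \textstyle\sum_{\mathcal{O}} g
\]

Each entry of %arg7 (the output mask) selects between the real gradient and zeros, and the early branch on %15 short-circuits to all-zero results whenever a normalization group has no elements.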
%arg3, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" +" %1 = torch.aten.var.dim %arg0, %arg2, %false, %arg3 : !torch.tensor, !torch.list, !torch.bool, !torch.bool -> !torch.tensor\n" +" %2 = torch.aten.reciprocal %arg1 : !torch.tensor -> !torch.tensor\n" +" %3 = torch.aten.mul.Scalar %2, %int1 : !torch.tensor, !torch.int -> !torch.tensor\n" +" %4 = torch.aten.pow.Tensor_Scalar %3, %int2 : !torch.tensor, !torch.int -> !torch.tensor\n" +" %5 = torch.aten.sub.Tensor %4, %1, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" +" %6 = torch.aten.detach %5 : !torch.tensor -> !torch.tensor\n" +" %7 = torch.aten.add.Tensor %1, %6, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" +" %8 = torch.aten.sqrt %7 : !torch.tensor -> !torch.tensor\n" +" %9 = torch.aten.reciprocal %8 : !torch.tensor -> !torch.tensor\n" +" %10 = torch.aten.mul.Scalar %9, %int1 : !torch.tensor, !torch.int -> !torch.tensor\n" +" %11 = torch.prim.TupleConstruct %0, %10 : !torch.tensor, !torch.tensor -> !torch.tuple\n" +" return %11 : !torch.tuple\n" +" }\n" +" func.func @__torch__.torch._decomp.decompositions_for_jvp.native_batch_norm_backward(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional, %arg7: !torch.bool, %arg8: !torch.float, %arg9: !torch.list) -> !torch.tuple, optional> {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %true = torch.constant.bool true\n" +" %str_0 = torch.constant.str \"AssertionError: when train=True, save_mean and save_invstd are required\"\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: rank of the input must be at least 2\"\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %float1.000000e00 = torch.constant.float 1.000000e+00\n" +" %0 = torch.prim.Uninitialized : !torch.tensor\n" +" %1 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" +" %2 = torch.aten.ge.int %1, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" +" %4 = torch.prim.Loop %3, %true, init(%int1) {\n" +" ^bb0(%arg10: !torch.int, %arg11: !torch.int):\n" +" %34 = torch.aten.size.int %arg1, %arg10 : !torch.tensor, !torch.int -> !torch.int\n" +" %35 = torch.aten.mul.int %arg11, %34 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop.condition %true, iter(%35 : !torch.int)\n" +" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" +" %5 = torch.aten.size.int %arg1, %int1 : !torch.tensor, !torch.int -> !torch.int\n" +" %6 = torch.operator \"aten.div.int\"(%4, %5) : (!torch.int, !torch.int) -> !torch.float\n" +" %7:2 = torch.prim.If %arg7 -> (!torch.tensor, !torch.tensor) {\n" +" %34 = torch.aten.__isnot__ %arg5, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %35 = torch.prim.If %34 -> (!torch.bool) {\n" +" %52 = torch.aten.__isnot__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %52 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %35 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, 
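recompute_mean_var above (and the inlined copy inside the layer-norm backward) exists for forward-mode AD: the saved statistics carry no tangents, so the decomposition recomputes mean and variance differentiably and then snaps the value back to the saved rstd \(\eta_{\text{saved}}\) with a detach correction:

\[
  \tilde{v} = v + \operatorname{detach}\!\big(\eta_{\text{saved}}^{-2} - v\big),
  \qquad
  \tilde{\eta} = \tilde{v}^{-1/2}
\]

Numerically \(\tilde{v}\) equals \(\eta_{\text{saved}}^{-2}\) (the detached correction cancels the recomputed v), so forward values match the eager kernel, while derivatives flow only through the recomputed statistics — which is why these variants live in torch._decomp.decompositions_for_jvp.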
!torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %36 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n" +" %37 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" +" %38 = torch.prim.ListConstruct : () -> !torch.list\n" +" %39 = torch.aten.__range_length %int2, %37, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %39, %true, init() {\n" +" ^bb0(%arg10: !torch.int):\n" +" %52 = torch.aten.__derive_index %arg10, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %53 = torch.aten.append.t %38, %52 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %40 = torch.aten.add.t %36, %38 : !torch.list, !torch.list -> !torch.list\n" +" %41 = torch.aten.__isnot__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %42 = torch.prim.If %41 -> (!torch.tensor) {\n" +" %52 = torch.prim.unchecked_cast %arg6 : !torch.optional -> !torch.tensor\n" +" torch.prim.If.yield %52 : !torch.tensor\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.tensor\n" +" }\n" +" %43 = torch.aten.mean.dim %arg1, %40, %false, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" +" %44 = torch.aten.var.dim %arg1, %40, %false, %false : !torch.tensor, !torch.list, !torch.bool, !torch.bool -> !torch.tensor\n" +" %45 = torch.aten.reciprocal %42 : !torch.tensor -> !torch.tensor\n" +" %46 = torch.aten.pow.Tensor_Scalar %45, %int2 : !torch.tensor, !torch.int -> !torch.tensor\n" +" %47 = torch.aten.sub.Tensor %46, %44, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" +" %48 = torch.aten.detach %47 : !torch.tensor -> !torch.tensor\n" +" %49 = torch.aten.add.Tensor %44, %48, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" +" %50 = torch.aten.sqrt %49 : !torch.tensor -> !torch.tensor\n" +" %51 = torch.aten.reciprocal %50 : !torch.tensor -> !torch.tensor\n" +" torch.prim.If.yield %43, %51 : !torch.tensor, !torch.tensor\n" +" } else {\n" +" %34 = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %35 = torch.prim.If %34 -> (!torch.bool) {\n" +" %39 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.tensor\n" +" %40 = torch.aten.__isnot__ %arg4, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %40 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %36:2 = torch.prim.If %35 -> (!torch.tensor, !torch.tensor) {\n" +" %39 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.tensor\n" +" %40 = torch.prim.unchecked_cast %arg4 : !torch.optional -> !torch.tensor\n" +" torch.prim.If.yield %40, %39 : !torch.tensor, !torch.tensor\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0, %0 : !torch.tensor, !torch.tensor\n" +" }\n" +" %37 = torch.aten.add.Scalar %36#0, %arg8, %int1 : !torch.tensor, !torch.float, !torch.int -> !torch.tensor\n" +" %38 = torch.aten.rsqrt %37 : !torch.tensor -> !torch.tensor\n" +" torch.prim.If.yield %36#1, %38 : !torch.tensor, !torch.tensor\n" +" }\n" +" %8 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list\n" +" %9 = torch.operator \"aten.mul.left_t\"(%8, %1) : (!torch.list, !torch.int) -> !torch.list\n" +" %10 = torch.aten.size.int %arg1, %int1 : !torch.tensor, !torch.int -> !torch.int\n" +" %11 = torch.aten._set_item.t %9, %int1, %10 
: !torch.list, !torch.int, !torch.int -> !torch.list\n" +" %12 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.Loop %1, %true, init() {\n" +" ^bb0(%arg10: !torch.int):\n" +" %34 = torch.aten.ne.int %arg10, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %34 -> () {\n" +" %35 = torch.aten.append.t %12, %arg10 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %13 = torch.aten.reshape %7#0, %9 : !torch.tensor, !torch.list -> !torch.tensor\n" +" %14 = torch.aten.div.float %float1.000000e00, %6 : !torch.float, !torch.float -> !torch.float\n" +" %15 = torch.aten.sum.dim_IntList %arg0, %12, %false, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" +" %16 = torch.aten.sub.Tensor %arg1, %13, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" +" %17 = torch.aten.mul.Tensor %arg0, %16 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" %18 = torch.aten.sum.dim_IntList %17, %12, %false, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" +" %19 = torch.aten.mul.Scalar %15, %14 : !torch.tensor, !torch.float -> !torch.tensor\n" +" %20 = torch.aten.reshape %19, %9 : !torch.tensor, !torch.list -> !torch.tensor\n" +" %21 = torch.aten.mul.Scalar %18, %14 : !torch.tensor, !torch.float -> !torch.tensor\n" +" %22 = torch.aten.mul.Tensor %7#1, %7#1 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" %23 = torch.aten.mul.Tensor %21, %22 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" %24 = torch.aten.reshape %23, %9 : !torch.tensor, !torch.list -> !torch.tensor\n" +" %25 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %26 = torch.prim.If %25 -> (!torch.tensor) {\n" +" %34 = torch.aten.reshape %7#1, %9 : !torch.tensor, !torch.list -> !torch.tensor\n" +" %35 = torch.aten.mul.Scalar %34, %float1.000000e00 : !torch.tensor, !torch.float -> !torch.tensor\n" +" torch.prim.If.yield %35 : !torch.tensor\n" +" } else {\n" +" %34 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.tensor\n" +" %35 = torch.aten.mul.Tensor %7#1, %34 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" %36 = torch.aten.reshape %35, %9 : !torch.tensor, !torch.list -> !torch.tensor\n" +" torch.prim.If.yield %36 : !torch.tensor\n" +" }\n" +" %27 = torch.prim.If %arg7 -> (!torch.tensor) {\n" +" %34 = torch.aten.sub.Tensor %arg1, %13, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" +" %35 = torch.aten.mul.Tensor %34, %24 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" %36 = torch.aten.sub.Tensor %arg0, %35, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" +" %37 = torch.aten.sub.Tensor %36, %20, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" +" %38 = torch.aten.mul.Tensor %37, %26 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" torch.prim.If.yield %38 : !torch.tensor\n" +" } else {\n" +" %34 = torch.aten.mul.Tensor %arg0, %26 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" torch.prim.If.yield %34 : !torch.tensor\n" +" }\n" +" %28 = torch.aten.__getitem__.t %arg9, %int1 : !torch.list, !torch.int -> !torch.bool\n" +" %29 = torch.prim.If %28 -> (!torch.tensor) {\n" +" %34 = torch.aten.mul.Tensor %18, %7#1 : !torch.tensor, !torch.tensor -> !torch.tensor\n" +" torch.prim.If.yield %34 : !torch.tensor\n" +" } else {\n" +" %34 = torch.aten.__isnot__ 
%arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %35 = torch.prim.If %34 -> (!torch.tensor) {\n" +" %36 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.tensor\n" +" %37 = torch.aten.zeros_like %36, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" +" torch.prim.If.yield %37 : !torch.tensor\n" +" } else {\n" +" %36 = torch.prim.ListConstruct : () -> !torch.list\n" +" %37 = torch.aten.zeros %36, %none, %none, %none, %none : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" +" torch.prim.If.yield %37 : !torch.tensor\n" +" }\n" +" torch.prim.If.yield %35 : !torch.tensor\n" +" }\n" +" %30 = torch.aten.__getitem__.t %arg9, %int2 : !torch.list, !torch.int -> !torch.bool\n" +" %31 = torch.prim.If %30 -> (!torch.tensor) {\n" +" torch.prim.If.yield %15 : !torch.tensor\n" +" } else {\n" +" %34 = torch.aten.zeros_like %15, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" +" torch.prim.If.yield %34 : !torch.tensor\n" +" }\n" +" %32 = torch.prim.TupleConstruct %27, %29, %31 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple\n" +" %33 = torch.derefine %32 : !torch.tuple to !torch.tuple, optional>\n" +" return %33 : !torch.tuple, optional>\n" +" }\n" +" func.func @__torch__.torch._decomp.decompositions_for_jvp.prod(%arg0: !torch.list) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.prim.Loop %0, %true, init(%int1) {\n" +" ^bb0(%arg1: !torch.int, %arg2: !torch.int):\n" +" %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.mul.int %arg2, %2 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop.condition %true, iter(%3 : !torch.int)\n" +" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @__torch__.torch._decomp.decompositions.cudnn_batch_norm_backward(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.tensor, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional, %arg7: !torch.float, %arg8: !torch.tensor) -> !torch.tuple {\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.ListConstruct %true, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list\n" +" %result0, %result1, %result2 = torch.aten.native_batch_norm_backward %arg1, %arg0, %arg2, %arg3, %arg4, %arg5, %arg6, %true, %arg7, %0 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.optional, !torch.optional, !torch.optional, !torch.optional, !torch.bool, !torch.float, !torch.list -> !torch.tensor, !torch.tensor, !torch.tensor\n" +" %1 = torch.prim.TupleConstruct %result0, %result1, %result2 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.unary(%arg0: !torch.list) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %1, %true, init() {\n" +" ^bb0(%arg1: !torch.int):\n" +" %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.append.t %0, %2 : !torch.list, !torch.int -> !torch.list\n" +" 
torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions._copy(%arg0: !torch.list) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %1, %true, init() {\n" +" ^bb0(%arg1: !torch.int):\n" +" %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.append.t %0, %2 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.adaptive_avg_pool2d(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %int4 = torch.constant.int 4\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %12 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %13 = torch.aten.eq.int %12, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %6 = torch.aten.__range_length %int1, %5, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %6, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %12 = torch.aten.__derive_index %arg2, %int1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %13 = torch.aten.__getitem__.t %arg0, %12 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.ne.int %13, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %14 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %7 = torch.prim.ListConstruct : () -> !torch.list\n" +" %8 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %9 = torch.aten.sub.int %8, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %10 = torch.aten.__range_length %int0, %9, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %10, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %12 = torch.aten.__derive_index %arg2, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %13 = torch.aten.__getitem__.t %arg0, %12 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.append.t %7, %13 : !torch.list, !torch.int -> !torch.list\n" +" 
torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %11 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" torch.prim.Loop %11, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %12 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.append.t %7, %12 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %7 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.zero_dim_tensor(%arg0: !torch.any) -> !torch.list {\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.arange_end(%arg0: !torch.union, %arg1: !torch.any, %arg2: !torch.any, %arg3: !torch.any, %arg4: !torch.any) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.operator \"aten.ge\"(%arg0, %int0) : (!torch.union, !torch.int) -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = torch.aten.ceil.Scalar %arg0 : !torch.union -> !torch.number\n" +" %2 = torch.aten.Int.Scalar %1 : !torch.number -> !torch.int\n" +" %3 = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list\n" +" return %3 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.arange_start(%arg0: !torch.union, %arg1: !torch.union, %arg2: !torch.any, %arg3: !torch.any, %arg4: !torch.any, %arg5: !torch.any) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.operator \"aten.ge\"(%arg1, %int0) : (!torch.union, !torch.int) -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = torch.operator \"aten.ge\"(%arg1, %arg0) : (!torch.union, !torch.union) -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.sub %arg1, %arg0 : !torch.union, !torch.union -> !torch.number\n" +" %3 = torch.aten.ceil.Scalar %2 : !torch.number -> !torch.number\n" +" %4 = torch.aten.Int.Scalar %3 : !torch.number -> !torch.int\n" +" %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list\n" +" return %5 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.arange_start_step(%arg0: !torch.union, %arg1: !torch.union, %arg2: !torch.union, %arg3: !torch.any, %arg4: !torch.any, %arg5: !torch.any, %arg6: !torch.any) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.operator \"aten.ne\"(%arg2, %int0) : (!torch.union, !torch.int) -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = torch.operator \"aten.lt\"(%arg2, %int0) : (!torch.union, !torch.int) -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" %6 = torch.operator \"aten.ge\"(%arg0, %arg1) : (!torch.union, 
!torch.union) -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" %6 = torch.operator \"aten.ge\"(%arg1, %arg0) : (!torch.union, !torch.union) -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.sub %arg1, %arg0 : !torch.union, !torch.union -> !torch.number\n" +" %3 = torch.aten.div %2, %arg2 : !torch.number, !torch.union -> !torch.float\n" +" %4 = torch.aten.ceil.float %3 : !torch.float -> !torch.int\n" +" %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list\n" +" return %5 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.squeeze_nodim(%arg0: !torch.list) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %1, %true, init() {\n" +" ^bb0(%arg1: !torch.int):\n" +" %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.ne.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" %4 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.aten.append.t %0, %4 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.squeeze(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %true = torch.constant.bool true\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %2 = torch.aten.le.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %1 : !torch.int\n" +" }\n" +" %4 = torch.aten.neg.int %3 : !torch.int -> !torch.int\n" +" %5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %12 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" }\n" +" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.int) {\n" +" %12 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %12 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %arg1 : !torch.int\n" +" }\n" +" %11 = torch.aten.len.t %arg0 : 
!torch.list -> !torch.int\n" +" torch.prim.Loop %11, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %12 = torch.aten.eq.int %arg2, %10 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %12 -> () {\n" +" %13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.ne.int %13, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %14 -> () {\n" +" %15 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.aten.append.t %0, %15 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" %13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.append.t %0, %13 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.le.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If %arg2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %arg1 : !torch.int\n" +" }\n" +" %2 = torch.aten.neg.int %1 : !torch.int -> !torch.int\n" +" %3 = torch.aten.sub.int %1, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %4 = torch.aten.lt.int %arg0, %2 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %9 = torch.aten.gt.int %arg0, %3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %9 : !torch.bool\n" +" }\n" +" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = torch.aten.lt.int %arg0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.int) {\n" +" %9 = torch.aten.add.int %arg0, %1 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %9 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %arg0 : !torch.int\n" +" }\n" +" return %8 : !torch.int\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.unsqueeze(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.add.int %0, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %2 = torch.aten.le.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %1 : !torch.int\n" +" }\n" +" %4 = 
torch.aten.neg.int %3 : !torch.int -> !torch.int\n" +" %5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %13 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" }\n" +" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.int) {\n" +" %13 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %13 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %arg1 : !torch.int\n" +" }\n" +" %11 = torch.prim.ListConstruct : () -> !torch.list\n" +" %12 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %12, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.append.t %11, %13 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.aten.insert.t %11, %10, %int1 : !torch.list, !torch.int, !torch.int\n" +" return %11 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.slice(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.int) -> !torch.list {\n" +" %int9223372036854775807 = torch.constant.int 9223372036854775807\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.ne.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0 : !torch.int\n" +" }\n" +" %4 = torch.aten.neg.int %3 : !torch.int -> !torch.int\n" +" %5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %33 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %33 : !torch.bool\n" +" }\n" +" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.int) {\n" +" %33 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %33 : !torch.int\n" +" } else {\n" +" 
torch.prim.If.yield %arg1 : !torch.int\n" +" }\n" +" %11 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.int) {\n" +" %33 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %33 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" }\n" +" %13 = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %14 = torch.prim.If %13 -> (!torch.int) {\n" +" %33 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %33 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int9223372036854775807 : !torch.int\n" +" }\n" +" %15 = torch.aten.gt.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %15 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %16 = torch.aten.eq.int %12, %int9223372036854775807 : !torch.int, !torch.int -> !torch.bool\n" +" %17 = torch.prim.If %16 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %12 : !torch.int\n" +" }\n" +" %18 = torch.aten.lt.int %17, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %19 = torch.prim.If %18 -> (!torch.int) {\n" +" %33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int\n" +" %34 = torch.aten.add.int %17, %33 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %34 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %17 : !torch.int\n" +" }\n" +" %20 = torch.aten.lt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %21 = torch.prim.If %20 -> (!torch.int) {\n" +" %33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int\n" +" %34 = torch.aten.add.int %14, %33 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %34 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %14 : !torch.int\n" +" }\n" +" %22 = torch.aten.lt.int %19, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %23 = torch.prim.If %22 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" %33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int\n" +" %34 = torch.aten.gt.int %19, %33 : !torch.int, !torch.int -> !torch.bool\n" +" %35 = torch.prim.If %34 -> (!torch.int) {\n" +" %36 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %36 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %19 : !torch.int\n" +" }\n" +" torch.prim.If.yield %35 : !torch.int\n" +" }\n" +" %24 = torch.aten.lt.int %21, %23 : !torch.int, !torch.int -> !torch.bool\n" +" %25 = torch.prim.If %24 -> (!torch.int) {\n" +" torch.prim.If.yield %23 : !torch.int\n" +" } else {\n" +" %33 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int\n" +" %34 = torch.aten.ge.int %21, %33 : !torch.int, !torch.int -> !torch.bool\n" +" %35 = torch.prim.If %34 -> (!torch.int) {\n" +" %36 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %36 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %21 : !torch.int\n" +" }\n" +" torch.prim.If.yield %35 : !torch.int\n" +" }\n" +" %26 = torch.aten.sub.int %25, %23 : !torch.int, !torch.int -> !torch.int\n" +" %27 = torch.prim.ListConstruct : () -> !torch.list\n" +" %28 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" 
torch.prim.Loop %28, %true, init() {\n" +" ^bb0(%arg5: !torch.int):\n" +" %33 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list, !torch.int -> !torch.int\n" +" %34 = torch.aten.append.t %27, %33 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %29 = torch.aten.add.int %26, %arg4 : !torch.int, !torch.int -> !torch.int\n" +" %30 = torch.aten.sub.int %29, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %31 = torch.aten.floordiv.int %30, %arg4 : !torch.int, !torch.int -> !torch.int\n" +" %32 = torch.aten._set_item.t %27, %10, %31 : !torch.list, !torch.int, !torch.int -> !torch.list\n" +" return %27 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.max_int() -> !torch.int {\n" +" %int9223372036854775807 = torch.constant.int 9223372036854775807\n" +" return %int9223372036854775807 : !torch.int\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.select(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list {\n" +" %int1 = torch.constant.int 1\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.ne.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0 : !torch.int\n" +" }\n" +" %4 = torch.aten.neg.int %3 : !torch.int -> !torch.int\n" +" %5 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %6 = torch.aten.lt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %17 = torch.aten.gt.int %arg1, %5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %17 : !torch.bool\n" +" }\n" +" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.int) {\n" +" %17 = torch.aten.add.int %arg1, %3 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %17 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %arg1 : !torch.int\n" +" }\n" +" %11 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.neg.int %11 : !torch.int -> !torch.int\n" +" %13 = torch.aten.lt.int %arg2, %12 : !torch.int, !torch.int -> !torch.bool\n" +" %14 = torch.prim.If %13 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %17 = torch.aten.ge.int %arg2, %11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %17 : !torch.bool\n" +" }\n" +" %15 = torch.aten.__not__ %14 : !torch.bool -> !torch.bool\n" +" torch.prim.If %15 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %16 = 
torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.Loop %0, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %17 = torch.aten.ne.int %arg3, %10 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %17 -> () {\n" +" %18 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %19 = torch.aten.append.t %16, %18 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %16 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.index_select(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %true = torch.constant.bool true\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0 : !torch.int\n" +" }\n" +" %3 = torch.aten.neg.int %2 : !torch.int -> !torch.int\n" +" %4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %5 = torch.aten.lt.int %arg1, %3 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %18 = torch.aten.gt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %18 : !torch.bool\n" +" }\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" %18 = torch.aten.add.int %arg1, %2 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %18 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %arg1 : !torch.int\n" +" }\n" +" %10 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %11 = torch.prim.Loop %10, %true, init(%int1) {\n" +" ^bb0(%arg3: !torch.int, %arg4: !torch.int):\n" +" %18 = torch.aten.__getitem__.t %arg2, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %19 = torch.aten.mul.int %arg4, %18 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop.condition %true, iter(%19 : !torch.int)\n" +" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" +" %12 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %13 = torch.aten.le.int %12, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %13 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %14 = torch.aten.eq.int %9, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %15 = torch.prim.If %14 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %18 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %19 = torch.aten.lt.int %9, %18 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %19 : !torch.bool\n" +" }\n" +" torch.prim.If %15 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, 
!torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %16 = torch.prim.ListConstruct : () -> !torch.list\n" +" %17 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %17, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %18 = torch.aten.eq.int %9, %arg3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %18 -> () {\n" +" %19 = torch.aten.append.t %16, %11 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" %19 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %20 = torch.aten.append.t %16, %19 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %16 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.multiply_integers(%arg0: !torch.list) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.prim.Loop %0, %true, init(%int1) {\n" +" ^bb0(%arg1: !torch.int, %arg2: !torch.int):\n" +" %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.mul.int %arg2, %2 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop.condition %true, iter(%3 : !torch.int)\n" +" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.embedding(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.list) {\n" +" %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %6 = torch.aten.le.int %5, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" %8 = torch.aten.neg.int %7 : !torch.int -> !torch.int\n" +" %9 = torch.aten.sub.int %7, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %10 = torch.aten.lt.int %int0, %8 : !torch.int, !torch.int -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %19 = torch.aten.gt.int %int0, %9 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %19 : !torch.bool\n" +" }\n" +" %12 = torch.aten.__not__ %11 : !torch.bool -> !torch.bool\n" +" torch.prim.If %12 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %13 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %14 = torch.prim.Loop %13, %true, init(%int1) {\n" +" ^bb0(%arg5: !torch.int, %arg6: 
!torch.int):\n" +" %19 = torch.aten.__getitem__.t %arg1, %arg5 : !torch.list, !torch.int -> !torch.int\n" +" %20 = torch.aten.mul.int %arg6, %19 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop.condition %true, iter(%20 : !torch.int)\n" +" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" +" %15 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %16 = torch.aten.le.int %15, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %16 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %17 = torch.prim.ListConstruct : () -> !torch.list\n" +" %18 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %18, %true, init() {\n" +" ^bb0(%arg5: !torch.int):\n" +" %19 = torch.aten.eq.int %int0, %arg5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %19 -> () {\n" +" %20 = torch.aten.append.t %17, %14 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" %20 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.append.t %17, %20 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.If.yield %17 : !torch.list\n" +" } else {\n" +" %5 = torch.prim.ListConstruct : () -> !torch.list\n" +" %6 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" torch.prim.Loop %6, %true, init() {\n" +" ^bb0(%arg5: !torch.int):\n" +" %9 = torch.aten.__getitem__.t %arg1, %arg5 : !torch.list, !torch.int -> !torch.int\n" +" %10 = torch.aten.append.t %5, %9 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %7 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.append.t %5, %7 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield %5 : !torch.list\n" +" }\n" +" return %4 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.mm(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %str_0 = torch.constant.str \"AssertionError: mat2 must be a matrix\"\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: self must be a matrix\"\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %6 = torch.aten.eq.int %4, %5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" 
torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %9 = torch.prim.ListConstruct %7, %8 : (!torch.int, !torch.int) -> !torch.list\n" +" return %9 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.dot(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %7 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %8 = torch.aten.eq.int %7, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %8 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.aten.eq.int %3, %4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.prim.ListConstruct : () -> !torch.list\n" +" return %6 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.mv(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %8 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %9 = torch.aten.eq.int %8, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %9 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.aten.eq.int %3, %4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %7 = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list\n" +" return %7 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.matmul(%arg0: !torch.list, %arg1: !torch.list) -> 
!torch.list {\n" +" %str = torch.constant.str \"The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}\"\n" +" %str_0 = torch.constant.str \"AssertionError: mat2 must be a matrix\"\n" +" %str_1 = torch.constant.str \"AssertionError: self must be a matrix\"\n" +" %str_2 = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %str_3 = torch.constant.str \"AssertionError: both arguments to matmul need to be at least 1D\"\n" +" %int-1 = torch.constant.int -1\n" +" %true = torch.constant.bool true\n" +" %int-2 = torch.constant.int -2\n" +" %false = torch.constant.bool false\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.prim.Uninitialized : !torch.list\n" +" %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %1, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %6 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %6 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %5 = torch.prim.If %4 -> (!torch.list) {\n" +" %6 = torch.prim.ListConstruct : () -> !torch.list\n" +" %7 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %8 = torch.aten.eq.int %7, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" %13 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %14 = torch.aten.eq.int %13, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %14 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %9 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %11 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.eq.int %10, %11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %12 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %6 : !torch.list\n" +" } else {\n" +" %6 = torch.aten.eq.int %1, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %9 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %9 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %8 = torch.prim.If %7 -> (!torch.list) {\n" +" %9 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %10 = torch.aten.eq.int %9, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.bool) {\n" +" %17 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %18 = torch.aten.eq.int %17, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %18 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %12 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %13 = 
torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.eq.int %12, %13 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %14 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %15 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.prim.ListConstruct %15 : (!torch.int) -> !torch.list\n" +" torch.prim.If.yield %16 : !torch.list\n" +" } else {\n" +" %9 = torch.aten.eq.int %1, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.bool) {\n" +" %12 = torch.aten.eq.int %2, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %11 = torch.prim.If %10 -> (!torch.list) {\n" +" %12 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %13 = torch.aten.add.int %12, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.aten.le.int %13, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %15 = torch.prim.If %14 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %13 : !torch.int\n" +" }\n" +" %16 = torch.aten.neg.int %15 : !torch.int -> !torch.int\n" +" %17 = torch.aten.sub.int %15, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %18 = torch.aten.lt.int %int0, %16 : !torch.int, !torch.int -> !torch.bool\n" +" %19 = torch.prim.If %18 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %34 = torch.aten.gt.int %int0, %17 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %34 : !torch.bool\n" +" }\n" +" %20 = torch.aten.__not__ %19 : !torch.bool -> !torch.bool\n" +" torch.prim.If %20 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %21 = torch.prim.ListConstruct : () -> !torch.list\n" +" %22 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %22, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %34 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %35 = torch.aten.append.t %21, %34 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.aten.insert.t %21, %int0, %int1 : !torch.list, !torch.int, !torch.int\n" +" %23 = torch.aten.len.t %21 : !torch.list -> !torch.int\n" +" %24 = torch.aten.eq.int %23, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %24 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %25 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %26 = torch.aten.eq.int %25, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %26 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %27 = torch.aten.__getitem__.t %21, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %28 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %29 = torch.aten.eq.int %27, %28 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %29 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" 
torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %30 = torch.aten.__getitem__.t %21, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %31 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %32 = torch.prim.ListConstruct %30, %31 : (!torch.int, !torch.int) -> !torch.list\n" +" %33 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.Loop %int2, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %34 = torch.aten.eq.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %34 -> () {\n" +" %35 = torch.aten.__getitem__.t %32, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %36 = torch.aten.ne.int %35, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %36 -> () {\n" +" %37 = torch.aten.__getitem__.t %32, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %38 = torch.aten.append.t %33, %37 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" %35 = torch.aten.__getitem__.t %32, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %36 = torch.aten.append.t %33, %35 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.If.yield %33 : !torch.list\n" +" } else {\n" +" %12 = torch.aten.eq.int %1, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %13 = torch.prim.If %12 -> (!torch.bool) {\n" +" %15 = torch.aten.eq.int %2, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %15 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %14 = torch.prim.If %13 -> (!torch.list) {\n" +" %15 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %16 = torch.aten.eq.int %15, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %16 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %17 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %18 = torch.aten.eq.int %17, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %18 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %19 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %20 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.eq.int %19, %20 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %21 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %22 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %23 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %24 = torch.prim.ListConstruct %22, %23 : (!torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %24 : !torch.list\n" +" } else {\n" +" %15 = torch.aten.ge.int %1, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.bool) {\n" +" %18 = torch.aten.ge.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %18 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" 
+" }\n" +" %17 = torch.prim.If %16 -> (!torch.list) {\n" +" %18 = torch.aten.gt.int %1, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %19 = torch.prim.If %18 -> (!torch.int) {\n" +" %31 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %31 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %20 = torch.prim.ListConstruct : () -> !torch.list\n" +" %21 = torch.aten.sub.int %1, %int2 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %21, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %31 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %32 = torch.aten.append.t %20, %31 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %22 = torch.aten.__getitem__.t %arg1, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %23 = torch.prim.ListConstruct : () -> !torch.list\n" +" %24 = torch.aten.sub.int %2, %int2 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %24, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %31 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %32 = torch.aten.append.t %23, %31 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %25 = torch.aten.len.t %20 : !torch.list -> !torch.int\n" +" %26 = torch.aten.len.t %23 : !torch.list -> !torch.int\n" +" %27 = torch.prim.max.int %25, %26 : !torch.int, !torch.int -> !torch.int\n" +" %28 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.Loop %27, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %31 = torch.aten.sub.int %27, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %32 = torch.aten.sub.int %31, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" %33 = torch.aten.sub.int %25, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %34 = torch.aten.sub.int %33, %32 : !torch.int, !torch.int -> !torch.int\n" +" %35 = torch.aten.sub.int %26, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %36 = torch.aten.sub.int %35, %32 : !torch.int, !torch.int -> !torch.int\n" +" %37 = torch.aten.ge.int %34, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %38 = torch.prim.If %37 -> (!torch.int) {\n" +" %47 = torch.aten.__getitem__.t %20, %34 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %47 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %39 = torch.aten.ge.int %36, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %40 = torch.prim.If %39 -> (!torch.int) {\n" +" %47 = torch.aten.__getitem__.t %23, %36 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %47 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %41 = torch.aten.ne.int %38, %40 : !torch.int, !torch.int -> !torch.bool\n" +" %42 = torch.prim.If %41 -> (!torch.bool) {\n" +" %47 = torch.aten.ne.int %38, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %47 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %43 = torch.prim.If %42 -> (!torch.bool) {\n" +" %47 = torch.aten.ne.int %40, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %47 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %43 -> () {\n" +" %47 = torch.aten.format(%str, %38, %40, %arg2) : !torch.str, 
!torch.int, !torch.int, !torch.int -> !torch.str\n" +" %48 = torch.aten.add.str %str_2, %47 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %48, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %44 = torch.aten.eq.int %38, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %45 = torch.prim.If %44 -> (!torch.int) {\n" +" torch.prim.If.yield %40 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %38 : !torch.int\n" +" }\n" +" %46 = torch.aten.append.t %28, %45 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %29 = torch.aten.gt.int %1, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %29 -> () {\n" +" %31 = torch.aten.append.t %28, %19 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %30 = torch.aten.gt.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %30 -> () {\n" +" %31 = torch.aten.append.t %28, %22 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %28 : !torch.list\n" +" } else {\n" +" torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.list\n" +" }\n" +" torch.prim.If.yield %17 : !torch.list\n" +" }\n" +" torch.prim.If.yield %14 : !torch.list\n" +" }\n" +" torch.prim.If.yield %11 : !torch.list\n" +" }\n" +" torch.prim.If.yield %8 : !torch.list\n" +" }\n" +" return %5 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.broadcast(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %str_0 = torch.constant.str \"The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}\"\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %2 = torch.prim.max.int %0, %1 : !torch.int, !torch.int -> !torch.int\n" +" %3 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.Loop %2, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %5 = torch.aten.sub.int %4, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" %6 = torch.aten.sub.int %0, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %7 = torch.aten.sub.int %6, %5 : !torch.int, !torch.int -> !torch.int\n" +" %8 = torch.aten.sub.int %1, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %9 = torch.aten.sub.int %8, %5 : !torch.int, !torch.int -> !torch.int\n" +" %10 = torch.aten.ge.int %7, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.int) {\n" +" %20 = torch.aten.__getitem__.t %arg0, %7 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %20 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %12 = torch.aten.ge.int %9, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %13 = torch.prim.If %12 -> (!torch.int) {\n" +" %20 = torch.aten.__getitem__.t %arg1, %9 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %20 : !torch.int\n" +" } else {\n" +" 
torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %14 = torch.aten.ne.int %11, %13 : !torch.int, !torch.int -> !torch.bool\n" +" %15 = torch.prim.If %14 -> (!torch.bool) {\n" +" %20 = torch.aten.ne.int %11, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %16 = torch.prim.If %15 -> (!torch.bool) {\n" +" %20 = torch.aten.ne.int %13, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %16 -> () {\n" +" %20 = torch.aten.format(%str_0, %11, %13, %arg2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n" +" %21 = torch.aten.add.str %str, %20 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %21, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %17 = torch.aten.eq.int %11, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %18 = torch.prim.If %17 -> (!torch.int) {\n" +" torch.prim.If.yield %13 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %11 : !torch.int\n" +" }\n" +" %19 = torch.aten.append.t %3, %18 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %3 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.linear(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional<list<int>>) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: both arguments to matmul need to be at least 1D\"\n" +" %int-1 = torch.constant.int -1\n" +" %true = torch.constant.bool true\n" +" %int-2 = torch.constant.int -2\n" +" %false = torch.constant.bool false\n" +" %str_0 = torch.constant.str \"AssertionError: self must be a matrix\"\n" +" %str_1 = torch.constant.str \"AssertionError: mat2 must be a matrix\"\n" +" %str_2 = torch.constant.str \"The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}\"\n" +" %str_3 = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.le.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.list) {\n" +" %13 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.If.yield %13 : !torch.list\n" +" } else {\n" +" %13 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %14 = torch.prim.If %13 -> (!torch.list) {\n" +" %15 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.prim.ListConstruct %15 : (!torch.int) -> !torch.list\n" +" torch.prim.If.yield %16 : !torch.list\n" +" } else {\n" +" %15 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %17 = torch.prim.ListConstruct %15, %16 : (!torch.int, !torch.int) -> !torch.list\n" +" 
torch.prim.If.yield %17 : !torch.list\n" +" }\n" +" torch.prim.If.yield %14 : !torch.list\n" +" }\n" +" %5 = torch.prim.ListConstruct : () -> !torch.list\n" +" %6 = torch.prim.Uninitialized : !torch.list\n" +" %7 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %8 = torch.aten.len.t %4 : !torch.list -> !torch.int\n" +" %9 = torch.aten.eq.int %7, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.bool) {\n" +" %13 = torch.aten.eq.int %8, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %11 = torch.prim.If %10 -> (!torch.list) {\n" +" %13 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %14 = torch.aten.eq.int %13, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %15 = torch.prim.If %14 -> (!torch.bool) {\n" +" %19 = torch.aten.len.t %4 : !torch.list -> !torch.int\n" +" %20 = torch.aten.eq.int %19, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %15 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %16 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %17 = torch.aten.__getitem__.t %4, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %18 = torch.aten.eq.int %16, %17 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %18 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %5 : !torch.list\n" +" } else {\n" +" %13 = torch.aten.eq.int %7, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %14 = torch.prim.If %13 -> (!torch.bool) {\n" +" %16 = torch.aten.eq.int %8, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %16 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %15 = torch.prim.If %14 -> (!torch.list) {\n" +" %16 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %17 = torch.aten.eq.int %16, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %18 = torch.prim.If %17 -> (!torch.bool) {\n" +" %24 = torch.aten.len.t %4 : !torch.list -> !torch.int\n" +" %25 = torch.aten.eq.int %24, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %25 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %18 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %19 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %20 = torch.aten.__getitem__.t %4, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.eq.int %19, %20 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %21 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %22 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %23 = torch.prim.ListConstruct %22 : (!torch.int) -> !torch.list\n" +" torch.prim.If.yield %23 : !torch.list\n" +" } else {\n" +" %16 = torch.aten.eq.int %7, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" 
%17 = torch.prim.If %16 -> (!torch.bool) {\n" +" %19 = torch.aten.eq.int %8, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %19 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %18 = torch.prim.If %17 -> (!torch.list) {\n" +" %19 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %20 = torch.aten.add.int %19, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.le.int %20, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %22 = torch.prim.If %21 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %20 : !torch.int\n" +" }\n" +" %23 = torch.aten.neg.int %22 : !torch.int -> !torch.int\n" +" %24 = torch.aten.sub.int %22, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %25 = torch.aten.lt.int %int0, %23 : !torch.int, !torch.int -> !torch.bool\n" +" %26 = torch.prim.If %25 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %41 = torch.aten.gt.int %int0, %24 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %41 : !torch.bool\n" +" }\n" +" %27 = torch.aten.__not__ %26 : !torch.bool -> !torch.bool\n" +" torch.prim.If %27 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %28 = torch.prim.ListConstruct : () -> !torch.list\n" +" %29 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %29, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %41 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %42 = torch.aten.append.t %28, %41 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.aten.insert.t %28, %int0, %int1 : !torch.list, !torch.int, !torch.int\n" +" %30 = torch.aten.len.t %28 : !torch.list -> !torch.int\n" +" %31 = torch.aten.eq.int %30, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %31 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %32 = torch.aten.len.t %4 : !torch.list -> !torch.int\n" +" %33 = torch.aten.eq.int %32, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %33 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %34 = torch.aten.__getitem__.t %28, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %35 = torch.aten.__getitem__.t %4, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %36 = torch.aten.eq.int %34, %35 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %36 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %37 = torch.aten.__getitem__.t %28, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %38 = torch.aten.__getitem__.t %4, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %39 = torch.prim.ListConstruct %37, %38 : (!torch.int, !torch.int) -> !torch.list\n" +" %40 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.Loop %int2, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %41 = torch.aten.eq.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %41 -> () {\n" +" %42 = torch.aten.__getitem__.t %39, %arg3 : 
!torch.list, !torch.int -> !torch.int\n" +" %43 = torch.aten.ne.int %42, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %43 -> () {\n" +" %44 = torch.aten.__getitem__.t %39, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %45 = torch.aten.append.t %40, %44 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" %42 = torch.aten.__getitem__.t %39, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %43 = torch.aten.append.t %40, %42 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.If.yield %40 : !torch.list\n" +" } else {\n" +" %19 = torch.aten.eq.int %7, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %20 = torch.prim.If %19 -> (!torch.bool) {\n" +" %22 = torch.aten.eq.int %8, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %22 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %21 = torch.prim.If %20 -> (!torch.list) {\n" +" %22 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %23 = torch.aten.eq.int %22, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %23 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %24 = torch.aten.len.t %4 : !torch.list -> !torch.int\n" +" %25 = torch.aten.eq.int %24, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %25 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %26 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %27 = torch.aten.__getitem__.t %4, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %28 = torch.aten.eq.int %26, %27 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %28 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %29 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %30 = torch.aten.__getitem__.t %4, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %31 = torch.prim.ListConstruct %29, %30 : (!torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %31 : !torch.list\n" +" } else {\n" +" %22 = torch.aten.ge.int %7, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %23 = torch.prim.If %22 -> (!torch.bool) {\n" +" %25 = torch.aten.ge.int %8, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %25 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %24 = torch.prim.If %23 -> (!torch.list) {\n" +" %25 = torch.aten.gt.int %7, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %26 = torch.prim.If %25 -> (!torch.int) {\n" +" %38 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %38 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %27 = torch.prim.ListConstruct : () -> !torch.list\n" +" %28 = torch.aten.sub.int %7, %int2 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %28, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %38 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> 
!torch.int\n" +" %39 = torch.aten.append.t %27, %38 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %29 = torch.aten.__getitem__.t %4, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %30 = torch.prim.ListConstruct : () -> !torch.list\n" +" %31 = torch.aten.sub.int %8, %int2 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %31, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %38 = torch.aten.__getitem__.t %4, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %39 = torch.aten.append.t %30, %38 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %32 = torch.aten.len.t %27 : !torch.list -> !torch.int\n" +" %33 = torch.aten.len.t %30 : !torch.list -> !torch.int\n" +" %34 = torch.prim.max.int %32, %33 : !torch.int, !torch.int -> !torch.int\n" +" %35 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.Loop %34, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %38 = torch.aten.sub.int %34, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %39 = torch.aten.sub.int %38, %arg3 : !torch.int, !torch.int -> !torch.int\n" +" %40 = torch.aten.sub.int %32, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %41 = torch.aten.sub.int %40, %39 : !torch.int, !torch.int -> !torch.int\n" +" %42 = torch.aten.sub.int %33, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %43 = torch.aten.sub.int %42, %39 : !torch.int, !torch.int -> !torch.int\n" +" %44 = torch.aten.ge.int %41, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %45 = torch.prim.If %44 -> (!torch.int) {\n" +" %54 = torch.aten.__getitem__.t %27, %41 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %54 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %46 = torch.aten.ge.int %43, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %47 = torch.prim.If %46 -> (!torch.int) {\n" +" %54 = torch.aten.__getitem__.t %30, %43 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %54 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %48 = torch.aten.ne.int %45, %47 : !torch.int, !torch.int -> !torch.bool\n" +" %49 = torch.prim.If %48 -> (!torch.bool) {\n" +" %54 = torch.aten.ne.int %45, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %54 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %50 = torch.prim.If %49 -> (!torch.bool) {\n" +" %54 = torch.aten.ne.int %47, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %54 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %50 -> () {\n" +" %54 = torch.aten.format(%str_2, %45, %47, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n" +" %55 = torch.aten.add.str %str_3, %54 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %55, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %51 = torch.aten.eq.int %45, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %52 = torch.prim.If %51 -> (!torch.int) {\n" +" torch.prim.If.yield %47 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %45 : !torch.int\n" +" }\n" +" %53 = torch.aten.append.t %35, %52 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> 
()\n" +" %36 = torch.aten.gt.int %7, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %36 -> () {\n" +" %38 = torch.aten.append.t %35, %26 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %37 = torch.aten.gt.int %8, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %37 -> () {\n" +" %38 = torch.aten.append.t %35, %29 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %35 : !torch.list\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %6 : !torch.list\n" +" }\n" +" torch.prim.If.yield %24 : !torch.list\n" +" }\n" +" torch.prim.If.yield %21 : !torch.list\n" +" }\n" +" torch.prim.If.yield %18 : !torch.list\n" +" }\n" +" torch.prim.If.yield %15 : !torch.list\n" +" }\n" +" %12 = torch.aten.__isnot__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" torch.prim.If %12 -> () {\n" +" %13 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list\n" +" %14 = torch.aten.len.t %13 : !torch.list -> !torch.int\n" +" %15 = torch.aten.len.t %11 : !torch.list -> !torch.int\n" +" %16 = torch.prim.max.int %14, %15 : !torch.int, !torch.int -> !torch.int\n" +" %17 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.Loop %16, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %19 = torch.aten.sub.int %16, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %20 = torch.aten.sub.int %19, %arg3 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.sub.int %14, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %22 = torch.aten.sub.int %21, %20 : !torch.int, !torch.int -> !torch.int\n" +" %23 = torch.aten.sub.int %15, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %24 = torch.aten.sub.int %23, %20 : !torch.int, !torch.int -> !torch.int\n" +" %25 = torch.aten.ge.int %22, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %26 = torch.prim.If %25 -> (!torch.int) {\n" +" %35 = torch.aten.__getitem__.t %13, %22 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %35 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %27 = torch.aten.ge.int %24, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %28 = torch.prim.If %27 -> (!torch.int) {\n" +" %35 = torch.aten.__getitem__.t %11, %24 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %35 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %29 = torch.aten.ne.int %26, %28 : !torch.int, !torch.int -> !torch.bool\n" +" %30 = torch.prim.If %29 -> (!torch.bool) {\n" +" %35 = torch.aten.ne.int %26, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %35 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %31 = torch.prim.If %30 -> (!torch.bool) {\n" +" %35 = torch.aten.ne.int %28, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %35 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %31 -> () {\n" +" %35 = torch.aten.format(%str_2, %26, %28, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n" +" %36 = torch.aten.add.str %str_3, %35 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %36, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %32 = 
torch.aten.eq.int %26, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %33 = torch.prim.If %32 -> (!torch.int) {\n" +" torch.prim.If.yield %28 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %26 : !torch.int\n" +" }\n" +" %34 = torch.aten.append.t %17, %33 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %18 = torch.aten.eq.int_list %17, %11 : !torch.list, !torch.list -> !torch.bool\n" +" torch.prim.If %18 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" return %11 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.t(%arg0: !torch.list) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int2 = torch.constant.int 2\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.le.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.list) {\n" +" %5 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.If.yield %5 : !torch.list\n" +" } else {\n" +" %5 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.list) {\n" +" %7 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.prim.ListConstruct %7 : (!torch.int) -> !torch.list\n" +" torch.prim.If.yield %8 : !torch.list\n" +" } else {\n" +" %7 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %9 = torch.prim.ListConstruct %7, %8 : (!torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %9 : !torch.list\n" +" }\n" +" torch.prim.If.yield %6 : !torch.list\n" +" }\n" +" return %4 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.max_pool2d(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.list {\n" +" %false = torch.constant.bool false\n" +" %str = torch.constant.str \"AssertionError: stride should not be zero\"\n" +" %int-1 = torch.constant.int -1\n" +" %int-2 = torch.constant.int -2\n" +" %int-3 = torch.constant.int -3\n" +" %int-4 = torch.constant.int -4\n" +" %str_0 = torch.constant.str \"AssertionError: \"\n" +" %str_1 = torch.constant.str \"AssertionError: max_pool2d: dilation must be either a single int, or a tuple of two ints\"\n" +" %str_2 = torch.constant.str \"AssertionError: max_pool2d: padding must either be a single int, or a tuple of two ints\"\n" +" %str_3 = torch.constant.str \"AssertionError: max_pool2d: stride must either be omitted, a single int, or a tuple of two ints\"\n" +" %none = torch.constant.none\n" +" %str_4 = torch.constant.str \"AssertionError: max_pool2d: kernel_size must either be a single int, or a tuple of two ints\"\n" +" %true = torch.constant.bool true\n" +" %int1 
= torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %int0 = torch.constant.int 0\n" +" %int3 = torch.constant.int 3\n" +" %int4 = torch.constant.int 4\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %86 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %87 = torch.aten.eq.int %86, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %87 : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %5 = torch.aten.eq.int %4, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %3 : !torch.int\n" +" } else {\n" +" %86 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %86 : !torch.int\n" +" }\n" +" %7 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %86 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %87 = torch.aten.eq.int %86, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %87 : !torch.bool\n" +" }\n" +" %10 = torch.prim.If %9 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %86 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %87 = torch.aten.eq.int %86, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %87 : !torch.bool\n" +" }\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %11 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %12 = torch.aten.eq.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %13 = torch.prim.If %12 -> (!torch.int) {\n" +" torch.prim.If.yield %3 : !torch.int\n" +" } else {\n" +" %86 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %86 : !torch.int\n" +" }\n" +" %14 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %15 = torch.aten.eq.int %14, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.int) {\n" +" torch.prim.If.yield %6 : !torch.int\n" +" } else {\n" +" %86 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %87 = torch.aten.eq.int %86, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %88 = torch.prim.If %87 -> (!torch.int) {\n" +" torch.prim.If.yield %13 : !torch.int\n" +" } else {\n" +" %89 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %89 : !torch.int\n" +" }\n" +" torch.prim.If.yield %88 : !torch.int\n" +" }\n" +" %17 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %18 = torch.aten.eq.int %17, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %19 = torch.prim.If %18 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %86 = torch.aten.len.t 
%arg3 : !torch.list -> !torch.int\n" +" %87 = torch.aten.eq.int %86, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %87 : !torch.bool\n" +" }\n" +" torch.prim.If %19 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %20 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %22 = torch.aten.eq.int %21, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %23 = torch.prim.If %22 -> (!torch.int) {\n" +" torch.prim.If.yield %20 : !torch.int\n" +" } else {\n" +" %86 = torch.aten.__getitem__.t %arg3, %int1 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %86 : !torch.int\n" +" }\n" +" %24 = torch.aten.len.t %arg4 : !torch.list -> !torch.int\n" +" %25 = torch.aten.eq.int %24, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %26 = torch.prim.If %25 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %86 = torch.aten.len.t %arg4 : !torch.list -> !torch.int\n" +" %87 = torch.aten.eq.int %86, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %87 : !torch.bool\n" +" }\n" +" torch.prim.If %26 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %27 = torch.aten.__getitem__.t %arg4, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %28 = torch.aten.len.t %arg4 : !torch.list -> !torch.int\n" +" %29 = torch.aten.eq.int %28, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %30 = torch.prim.If %29 -> (!torch.int) {\n" +" torch.prim.If.yield %27 : !torch.int\n" +" } else {\n" +" %86 = torch.aten.__getitem__.t %arg4, %int1 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %86 : !torch.int\n" +" }\n" +" %31 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %32 = torch.aten.eq.int %31, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %33 = torch.prim.If %32 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %86 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %87 = torch.aten.eq.int %86, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %87 : !torch.bool\n" +" }\n" +" torch.prim.If %33 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %34 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %35 = torch.aten.eq.int %34, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %36 = torch.prim.If %35 -> (!torch.int) {\n" +" %86 = torch.aten.__getitem__.t %arg0, %int-4 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %86 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %37 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list, !torch.int -> !torch.int\n" +" %38 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" +" %39 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %40 = torch.aten.ne.int %13, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %40 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %41 = torch.aten.add.int %38, %20 : !torch.int, 
!torch.int -> !torch.int\n" +" %42 = torch.aten.add.int %41, %20 : !torch.int, !torch.int -> !torch.int\n" +" %43 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %44 = torch.aten.mul.int %27, %43 : !torch.int, !torch.int -> !torch.int\n" +" %45 = torch.aten.sub.int %42, %44 : !torch.int, !torch.int -> !torch.int\n" +" %46 = torch.aten.sub.int %45, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %47 = torch.prim.If %arg5 -> (!torch.int) {\n" +" %86 = torch.aten.sub.int %13, %int1 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %86 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" }\n" +" %48 = torch.aten.add.int %46, %47 : !torch.int, !torch.int -> !torch.int\n" +" %49 = torch.aten.floordiv.int %48, %13 : !torch.int, !torch.int -> !torch.int\n" +" %50 = torch.aten.add.int %49, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %51 = torch.prim.If %arg5 -> (!torch.int) {\n" +" %86 = torch.aten.mul.int %49, %13 : !torch.int, !torch.int -> !torch.int\n" +" %87 = torch.aten.add.int %38, %20 : !torch.int, !torch.int -> !torch.int\n" +" %88 = torch.aten.ge.int %86, %87 : !torch.int, !torch.int -> !torch.bool\n" +" %89 = torch.prim.If %88 -> (!torch.int) {\n" +" torch.prim.If.yield %49 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %50 : !torch.int\n" +" }\n" +" torch.prim.If.yield %89 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %50 : !torch.int\n" +" }\n" +" %52 = torch.aten.ne.int %16, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %52 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %53 = torch.aten.add.int %39, %23 : !torch.int, !torch.int -> !torch.int\n" +" %54 = torch.aten.add.int %53, %23 : !torch.int, !torch.int -> !torch.int\n" +" %55 = torch.aten.sub.int %6, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %56 = torch.aten.mul.int %30, %55 : !torch.int, !torch.int -> !torch.int\n" +" %57 = torch.aten.sub.int %54, %56 : !torch.int, !torch.int -> !torch.int\n" +" %58 = torch.aten.sub.int %57, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %59 = torch.prim.If %arg5 -> (!torch.int) {\n" +" %86 = torch.aten.sub.int %16, %int1 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %86 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" }\n" +" %60 = torch.aten.add.int %58, %59 : !torch.int, !torch.int -> !torch.int\n" +" %61 = torch.aten.floordiv.int %60, %16 : !torch.int, !torch.int -> !torch.int\n" +" %62 = torch.aten.add.int %61, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %63 = torch.prim.If %arg5 -> (!torch.int) {\n" +" %86 = torch.aten.mul.int %61, %16 : !torch.int, !torch.int -> !torch.int\n" +" %87 = torch.aten.add.int %39, %23 : !torch.int, !torch.int -> !torch.int\n" +" %88 = torch.aten.ge.int %86, %87 : !torch.int, !torch.int -> !torch.bool\n" +" %89 = torch.prim.If %88 -> (!torch.int) {\n" +" torch.prim.If.yield %61 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %62 : !torch.int\n" +" }\n" +" torch.prim.If.yield %89 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %62 : !torch.int\n" +" }\n" +" %64 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %65 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %66 = torch.prim.If %65 -> (!torch.bool) {\n" +" %86 = torch.aten.gt.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %86 : !torch.bool\n" 
+" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %66 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %67 = torch.aten.gt.int %16, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %68 = torch.prim.If %67 -> (!torch.bool) {\n" +" %86 = torch.aten.gt.int %13, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %86 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %68 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %69 = torch.aten.gt.int %27, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %70 = torch.prim.If %69 -> (!torch.bool) {\n" +" %86 = torch.aten.gt.int %30, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %86 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %70 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %71 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %72 = torch.aten.ne.int %71, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %73 = torch.prim.If %72 -> (!torch.bool) {\n" +" %86 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %87 = torch.aten.ne.int %86, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %87 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %74 = torch.aten.eq.int %64, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %75 = torch.prim.If %74 -> (!torch.bool) {\n" +" %86 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %87 = torch.aten.ne.int %86, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %87 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %76 = torch.prim.If %75 -> (!torch.bool) {\n" +" torch.prim.If.yield %73 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %77 = torch.prim.If %76 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %86 = torch.aten.eq.int %64, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %87 = torch.prim.If %86 -> (!torch.bool) {\n" +" torch.prim.If.yield %73 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %88 = torch.prim.If %87 -> (!torch.bool) {\n" +" %89 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" +" %90 = torch.aten.ne.int %89, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %90 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If.yield %88 : !torch.bool\n" +" }\n" +" torch.prim.If %77 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %78 = torch.aten.floordiv.int %6, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %79 = torch.aten.ge.int %78, %23 : !torch.int, !torch.int -> !torch.bool\n" +" %80 = torch.prim.If %79 -> (!torch.bool) {\n" +" %86 = torch.aten.floordiv.int %3, %int2 : !torch.int, !torch.int -> 
!torch.int\n" +" %87 = torch.aten.ge.int %86, %20 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %87 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %80 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %81 = torch.aten.ge.int %63, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %82 = torch.prim.If %81 -> (!torch.bool) {\n" +" %86 = torch.aten.ge.int %51, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %86 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %82 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %83 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %84 = torch.aten.eq.int %83, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %85 = torch.prim.If %84 -> (!torch.list) {\n" +" %86 = torch.prim.ListConstruct %37, %51, %63 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %86 : !torch.list\n" +" } else {\n" +" %86 = torch.prim.ListConstruct %36, %37, %51, %63 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %86 : !torch.list\n" +" }\n" +" return %85 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.pooling_output_shape(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.bool) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: stride should not be zeero\"\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.ne.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = call @__torch__.torch.jit._shape_functions.pooling_output_shape_pad_lr(%arg0, %arg1, %arg2, %arg2, %arg3, %arg4, %arg5) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.pooling_output_shape_pad_lr(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.bool) -> !torch.int {\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.add.int %arg0, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" %1 = torch.aten.add.int %0, %arg3 : !torch.int, !torch.int -> !torch.int\n" +" %2 = torch.aten.sub.int %arg1, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %3 = torch.aten.mul.int %arg5, %2 : !torch.int, !torch.int -> !torch.int\n" +" %4 = torch.aten.sub.int %1, %3 : !torch.int, !torch.int -> !torch.int\n" +" %5 = torch.aten.sub.int %4, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %6 = torch.prim.If %arg6 -> (!torch.int) {\n" +" %11 = torch.aten.sub.int %arg4, %int1 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %11 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" }\n" +" %7 = torch.aten.add.int %5, %6 : !torch.int, !torch.int -> !torch.int\n" +" %8 = call @__torch__.torch.jit._shape_functions.div_rtn(%7, %arg4) : (!torch.int, !torch.int) -> !torch.int\n" 
+" %9 = torch.aten.add.int %8, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %10 = torch.prim.If %arg6 -> (!torch.int) {\n" +" %11 = torch.aten.sub.int %9, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %12 = torch.aten.mul.int %11, %arg4 : !torch.int, !torch.int -> !torch.int\n" +" %13 = torch.aten.add.int %arg0, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.aten.ge.int %12, %13 : !torch.int, !torch.int -> !torch.bool\n" +" %15 = torch.prim.If %14 -> (!torch.int) {\n" +" %16 = torch.aten.sub.int %9, %int1 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %16 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %9 : !torch.int\n" +" }\n" +" torch.prim.If.yield %15 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %9 : !torch.int\n" +" }\n" +" return %10 : !torch.int\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.div_rtn(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" %0 = torch.aten.floordiv.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.pool2d_shape_check(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.int, %arg8: !torch.int, %arg9: !torch.int, %arg10: !torch.int, %arg11: !torch.int, %arg12: !torch.int, %arg13: !torch.int) -> !torch.none {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %int4 = torch.constant.int 4\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.gt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %19 = torch.aten.gt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %19 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.gt.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %19 = torch.aten.gt.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %19 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.aten.gt.int %arg7, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" +" %19 = torch.aten.gt.int %arg8, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %19 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.ne.int %7, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.bool) 
{\n" +" %19 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %20 = torch.aten.ne.int %19, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %10 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.bool) {\n" +" %19 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %20 = torch.aten.ne.int %19, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %12 = torch.prim.If %11 -> (!torch.bool) {\n" +" torch.prim.If.yield %9 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %13 = torch.prim.If %12 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %19 = torch.aten.eq.int %0, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %20 = torch.prim.If %19 -> (!torch.bool) {\n" +" torch.prim.If.yield %9 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %21 = torch.prim.If %20 -> (!torch.bool) {\n" +" %22 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" +" %23 = torch.aten.ne.int %22, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %23 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If.yield %21 : !torch.bool\n" +" }\n" +" torch.prim.If %13 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %14 = torch.aten.floordiv.int %arg2, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %15 = torch.aten.ge.int %14, %arg6 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.bool) {\n" +" %19 = torch.aten.floordiv.int %arg1, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %20 = torch.aten.ge.int %19, %arg5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %16 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %17 = torch.aten.ge.int %arg13, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %18 = torch.prim.If %17 -> (!torch.bool) {\n" +" %19 = torch.aten.ge.int %arg12, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %19 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %18 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %none : !torch.none\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.max_pool2d_with_indices(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.tuple, list> {\n" +" %false = torch.constant.bool false\n" +" %str = torch.constant.str \"AssertionError: stride should not be zeero\"\n" +" %int4 = torch.constant.int 4\n" +" %int3 = torch.constant.int 3\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %true = 
torch.constant.bool true\n" +" %str_0 = torch.constant.str \"AssertionError: max_pool2d: kernel_size must either be a single int, or a tuple of two ints\"\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: max_pool2d: stride must either be omitted, a single int, or a tuple of two ints\"\n" +" %str_2 = torch.constant.str \"AssertionError: max_pool2d: padding must either be a single int, or a tuple of two ints\"\n" +" %str_3 = torch.constant.str \"AssertionError: max_pool2d: dilation must be either a single int, or a tuple of two ints\"\n" +" %str_4 = torch.constant.str \"AssertionError: \"\n" +" %int-4 = torch.constant.int -4\n" +" %int-3 = torch.constant.int -3\n" +" %int-2 = torch.constant.int -2\n" +" %int-1 = torch.constant.int -1\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %87 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %88 = torch.aten.eq.int %87, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %88 : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %5 = torch.aten.eq.int %4, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %3 : !torch.int\n" +" } else {\n" +" %87 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %87 : !torch.int\n" +" }\n" +" %7 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %87 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %88 = torch.aten.eq.int %87, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %88 : !torch.bool\n" +" }\n" +" %10 = torch.prim.If %9 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %87 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %88 = torch.aten.eq.int %87, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %88 : !torch.bool\n" +" }\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %11 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %12 = torch.aten.eq.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %13 = torch.prim.If %12 -> (!torch.int) {\n" +" torch.prim.If.yield %3 : !torch.int\n" +" } else {\n" +" %87 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %87 : !torch.int\n" +" }\n" +" %14 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %15 = torch.aten.eq.int %14, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.int) {\n" +" torch.prim.If.yield %6 : !torch.int\n" +" } else {\n" +" %87 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %88 = torch.aten.eq.int %87, %int1 : 
!torch.int, !torch.int -> !torch.bool\n" +" %89 = torch.prim.If %88 -> (!torch.int) {\n" +" torch.prim.If.yield %13 : !torch.int\n" +" } else {\n" +" %90 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %90 : !torch.int\n" +" }\n" +" torch.prim.If.yield %89 : !torch.int\n" +" }\n" +" %17 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %18 = torch.aten.eq.int %17, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %19 = torch.prim.If %18 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %87 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %88 = torch.aten.eq.int %87, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %88 : !torch.bool\n" +" }\n" +" torch.prim.If %19 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %20 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %22 = torch.aten.eq.int %21, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %23 = torch.prim.If %22 -> (!torch.int) {\n" +" torch.prim.If.yield %20 : !torch.int\n" +" } else {\n" +" %87 = torch.aten.__getitem__.t %arg3, %int1 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %87 : !torch.int\n" +" }\n" +" %24 = torch.aten.len.t %arg4 : !torch.list -> !torch.int\n" +" %25 = torch.aten.eq.int %24, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %26 = torch.prim.If %25 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %87 = torch.aten.len.t %arg4 : !torch.list -> !torch.int\n" +" %88 = torch.aten.eq.int %87, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %88 : !torch.bool\n" +" }\n" +" torch.prim.If %26 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %27 = torch.aten.__getitem__.t %arg4, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %28 = torch.aten.len.t %arg4 : !torch.list -> !torch.int\n" +" %29 = torch.aten.eq.int %28, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %30 = torch.prim.If %29 -> (!torch.int) {\n" +" torch.prim.If.yield %27 : !torch.int\n" +" } else {\n" +" %87 = torch.aten.__getitem__.t %arg4, %int1 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %87 : !torch.int\n" +" }\n" +" %31 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %32 = torch.aten.eq.int %31, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %33 = torch.prim.If %32 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %87 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %88 = torch.aten.eq.int %87, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %88 : !torch.bool\n" +" }\n" +" torch.prim.If %33 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %34 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %35 = torch.aten.eq.int %34, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %36 = torch.prim.If %35 -> (!torch.int) {\n" +" %87 = torch.aten.__getitem__.t %arg0, %int-4 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %87 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield 
%int1 : !torch.int\n" +" }\n" +" %37 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list, !torch.int -> !torch.int\n" +" %38 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" +" %39 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %40 = torch.aten.ne.int %13, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %40 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %41 = torch.aten.add.int %38, %20 : !torch.int, !torch.int -> !torch.int\n" +" %42 = torch.aten.add.int %41, %20 : !torch.int, !torch.int -> !torch.int\n" +" %43 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %44 = torch.aten.mul.int %27, %43 : !torch.int, !torch.int -> !torch.int\n" +" %45 = torch.aten.sub.int %42, %44 : !torch.int, !torch.int -> !torch.int\n" +" %46 = torch.aten.sub.int %45, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %47 = torch.prim.If %arg5 -> (!torch.int) {\n" +" %87 = torch.aten.sub.int %13, %int1 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %87 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" }\n" +" %48 = torch.aten.add.int %46, %47 : !torch.int, !torch.int -> !torch.int\n" +" %49 = torch.aten.floordiv.int %48, %13 : !torch.int, !torch.int -> !torch.int\n" +" %50 = torch.aten.add.int %49, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %51 = torch.prim.If %arg5 -> (!torch.int) {\n" +" %87 = torch.aten.mul.int %49, %13 : !torch.int, !torch.int -> !torch.int\n" +" %88 = torch.aten.add.int %38, %20 : !torch.int, !torch.int -> !torch.int\n" +" %89 = torch.aten.ge.int %87, %88 : !torch.int, !torch.int -> !torch.bool\n" +" %90 = torch.prim.If %89 -> (!torch.int) {\n" +" torch.prim.If.yield %49 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %50 : !torch.int\n" +" }\n" +" torch.prim.If.yield %90 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %50 : !torch.int\n" +" }\n" +" %52 = torch.aten.ne.int %16, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %52 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %53 = torch.aten.add.int %39, %23 : !torch.int, !torch.int -> !torch.int\n" +" %54 = torch.aten.add.int %53, %23 : !torch.int, !torch.int -> !torch.int\n" +" %55 = torch.aten.sub.int %6, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %56 = torch.aten.mul.int %30, %55 : !torch.int, !torch.int -> !torch.int\n" +" %57 = torch.aten.sub.int %54, %56 : !torch.int, !torch.int -> !torch.int\n" +" %58 = torch.aten.sub.int %57, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %59 = torch.prim.If %arg5 -> (!torch.int) {\n" +" %87 = torch.aten.sub.int %16, %int1 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %87 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" }\n" +" %60 = torch.aten.add.int %58, %59 : !torch.int, !torch.int -> !torch.int\n" +" %61 = torch.aten.floordiv.int %60, %16 : !torch.int, !torch.int -> !torch.int\n" +" %62 = torch.aten.add.int %61, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %63 = torch.prim.If %arg5 -> (!torch.int) {\n" +" %87 = torch.aten.mul.int %61, %16 : !torch.int, !torch.int -> !torch.int\n" +" %88 = torch.aten.add.int %39, %23 : !torch.int, !torch.int -> !torch.int\n" +" %89 = torch.aten.ge.int %87, %88 : 
!torch.int, !torch.int -> !torch.bool\n" +" %90 = torch.prim.If %89 -> (!torch.int) {\n" +" torch.prim.If.yield %61 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %62 : !torch.int\n" +" }\n" +" torch.prim.If.yield %90 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %62 : !torch.int\n" +" }\n" +" %64 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %65 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %66 = torch.prim.If %65 -> (!torch.bool) {\n" +" %87 = torch.aten.gt.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %87 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %66 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %67 = torch.aten.gt.int %16, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %68 = torch.prim.If %67 -> (!torch.bool) {\n" +" %87 = torch.aten.gt.int %13, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %87 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %68 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %69 = torch.aten.gt.int %27, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %70 = torch.prim.If %69 -> (!torch.bool) {\n" +" %87 = torch.aten.gt.int %30, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %87 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %70 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %71 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %72 = torch.aten.ne.int %71, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %73 = torch.prim.If %72 -> (!torch.bool) {\n" +" %87 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %88 = torch.aten.ne.int %87, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %88 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %74 = torch.aten.eq.int %64, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %75 = torch.prim.If %74 -> (!torch.bool) {\n" +" %87 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %88 = torch.aten.ne.int %87, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %88 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %76 = torch.prim.If %75 -> (!torch.bool) {\n" +" torch.prim.If.yield %73 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %77 = torch.prim.If %76 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %87 = torch.aten.eq.int %64, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %88 = torch.prim.If %87 -> (!torch.bool) {\n" +" torch.prim.If.yield %73 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %89 = torch.prim.If %88 -> (!torch.bool) {\n" +" %90 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" +" %91 = torch.aten.ne.int %90, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" 
torch.prim.If.yield %91 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If.yield %89 : !torch.bool\n" +" }\n" +" torch.prim.If %77 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %78 = torch.aten.floordiv.int %6, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %79 = torch.aten.ge.int %78, %23 : !torch.int, !torch.int -> !torch.bool\n" +" %80 = torch.prim.If %79 -> (!torch.bool) {\n" +" %87 = torch.aten.floordiv.int %3, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %88 = torch.aten.ge.int %87, %20 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %88 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %80 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %81 = torch.aten.ge.int %63, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %82 = torch.prim.If %81 -> (!torch.bool) {\n" +" %87 = torch.aten.ge.int %51, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %87 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %82 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %83 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %84 = torch.aten.eq.int %83, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %85 = torch.prim.If %84 -> (!torch.list) {\n" +" %87 = torch.prim.ListConstruct %37, %51, %63 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %87 : !torch.list\n" +" } else {\n" +" %87 = torch.prim.ListConstruct %36, %37, %51, %63 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %87 : !torch.list\n" +" }\n" +" %86 = torch.prim.TupleConstruct %85, %85 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %86 : !torch.tuple, list>\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.transpose(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %true = torch.constant.bool true\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0 : !torch.int\n" +" }\n" +" %3 = torch.aten.neg.int %2 : !torch.int -> !torch.int\n" +" %4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %5 = torch.aten.lt.int %arg1, %3 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %21 = torch.aten.gt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %21 : !torch.bool\n" +" }\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.aten.lt.int %arg1, 
%int0 : !torch.int, !torch.int -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" %21 = torch.aten.add.int %arg1, %2 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %21 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %arg1 : !torch.int\n" +" }\n" +" %10 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0 : !torch.int\n" +" }\n" +" %12 = torch.aten.neg.int %11 : !torch.int -> !torch.int\n" +" %13 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.aten.lt.int %arg2, %12 : !torch.int, !torch.int -> !torch.bool\n" +" %15 = torch.prim.If %14 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %21 = torch.aten.gt.int %arg2, %13 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %21 : !torch.bool\n" +" }\n" +" %16 = torch.aten.__not__ %15 : !torch.bool -> !torch.bool\n" +" torch.prim.If %16 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %17 = torch.aten.lt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %18 = torch.prim.If %17 -> (!torch.int) {\n" +" %21 = torch.aten.add.int %arg2, %11 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %21 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %arg2 : !torch.int\n" +" }\n" +" %19 = torch.aten.eq.int %9, %18 : !torch.int, !torch.int -> !torch.bool\n" +" %20 = torch.prim.If %19 -> (!torch.list) {\n" +" %21 = torch.prim.ListConstruct : () -> !torch.list\n" +" %22 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %22, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %23 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %24 = torch.aten.append.t %21, %23 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.If.yield %21 : !torch.list\n" +" } else {\n" +" %21 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.Loop %0, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %22 = torch.aten.eq.int %arg3, %9 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %22 -> () {\n" +" %23 = torch.aten.__getitem__.t %arg0, %18 : !torch.list, !torch.int -> !torch.int\n" +" %24 = torch.aten.append.t %21, %23 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" %23 = torch.aten.eq.int %arg3, %18 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %23 -> () {\n" +" %24 = torch.aten.__getitem__.t %arg0, %9 : !torch.list, !torch.int -> !torch.int\n" +" %25 = torch.aten.append.t %21, %24 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" %24 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %25 = torch.aten.append.t %21, %24 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.If.yield %21 : !torch.list\n" +" }\n" +" return %20 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.conv1d(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, 
%arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int3 = torch.constant.int 3\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %5 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %6 = torch.aten.len.t %arg4 : !torch.list -> !torch.int\n" +" %7 = torch.prim.Loop %6, %true, init(%false) {\n" +" ^bb0(%arg7: !torch.int, %arg8: !torch.bool):\n" +" %34 = torch.aten.__getitem__.t %arg4, %arg7 : !torch.list, !torch.int -> !torch.int\n" +" %35 = torch.aten.lt.int %34, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %36 = torch.prim.If %35 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %arg8 : !torch.bool\n" +" }\n" +" torch.prim.Loop.condition %true, iter(%36 : !torch.bool)\n" +" } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool\n" +" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %10 = torch.prim.Loop %9, %true, init(%false) {\n" +" ^bb0(%arg7: !torch.int, %arg8: !torch.bool):\n" +" %34 = torch.aten.__getitem__.t %arg3, %arg7 : !torch.list, !torch.int -> !torch.int\n" +" %35 = torch.aten.lt.int %34, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %36 = torch.prim.If %35 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %arg8 : !torch.bool\n" +" }\n" +" torch.prim.Loop.condition %true, iter(%36 : !torch.bool)\n" +" } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool\n" +" %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %12 = torch.aten.eq.int %5, %4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %12 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %13 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.ge.int %13, %arg6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %14 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %15 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" 
%16 = torch.aten.remainder.int %15, %arg6 : !torch.int, !torch.int -> !torch.int\n" +" %17 = torch.aten.eq.int %16, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %17 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %18 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %19 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %20 = torch.aten.mul.int %19, %arg6 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.eq.int %18, %20 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %21 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %22 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %23 = torch.prim.If %22 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %34 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list\n" +" %35 = torch.aten.len.t %34 : !torch.list -> !torch.int\n" +" %36 = torch.aten.eq.int %35, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %37 = torch.prim.If %36 -> (!torch.bool) {\n" +" %38 = torch.aten.__getitem__.t %34, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %39 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %40 = torch.aten.eq.int %38, %39 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %40 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If.yield %37 : !torch.bool\n" +" }\n" +" torch.prim.If %23 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %24 = torch.aten.__range_length %int2, %4, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %24, %true, init() {\n" +" ^bb0(%arg7: !torch.int):\n" +" %34 = torch.aten.__derive_index %arg7, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %35 = torch.aten.__getitem__.t %arg0, %34 : !torch.list, !torch.int -> !torch.int\n" +" %36 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %37 = torch.aten.__getitem__.t %arg4, %36 : !torch.list, !torch.int -> !torch.int\n" +" %38 = torch.aten.mul.int %37, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %39 = torch.aten.add.int %35, %38 : !torch.int, !torch.int -> !torch.int\n" +" %40 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %41 = torch.aten.__getitem__.t %arg5, %40 : !torch.list, !torch.int -> !torch.int\n" +" %42 = torch.aten.__getitem__.t %arg1, %34 : !torch.list, !torch.int -> !torch.int\n" +" %43 = torch.aten.sub.int %42, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %44 = torch.aten.mul.int %41, %43 : !torch.int, !torch.int -> !torch.int\n" +" %45 = torch.aten.add.int %44, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %46 = torch.aten.ge.int %39, %45 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %46 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %25 = torch.aten.len.t %arg5 : !torch.list -> !torch.int\n" +" %26 = 
torch.aten.gt.int %25, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %27 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %28 = torch.prim.ListConstruct : () -> !torch.list\n" +" %29 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %30 = torch.aten.append.t %28, %29 : !torch.list, !torch.int -> !torch.list\n" +" %31 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %32 = torch.aten.append.t %28, %31 : !torch.list, !torch.int -> !torch.list\n" +" %33 = torch.aten.__range_length %int2, %27, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %33, %true, init() {\n" +" ^bb0(%arg7: !torch.int):\n" +" %34 = torch.aten.__derive_index %arg7, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %35 = torch.prim.If %26 -> (!torch.int) {\n" +" %51 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %52 = torch.aten.__getitem__.t %arg5, %51 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %52 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %36 = torch.aten.__getitem__.t %arg1, %34 : !torch.list, !torch.int -> !torch.int\n" +" %37 = torch.aten.sub.int %36, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %38 = torch.aten.mul.int %35, %37 : !torch.int, !torch.int -> !torch.int\n" +" %39 = torch.aten.add.int %38, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %40 = torch.aten.__getitem__.t %arg0, %34 : !torch.list, !torch.int -> !torch.int\n" +" %41 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %42 = torch.aten.__getitem__.t %arg4, %41 : !torch.list, !torch.int -> !torch.int\n" +" %43 = torch.aten.mul.int %42, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %44 = torch.aten.add.int %40, %43 : !torch.int, !torch.int -> !torch.int\n" +" %45 = torch.aten.sub.int %44, %39 : !torch.int, !torch.int -> !torch.int\n" +" %46 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %47 = torch.aten.__getitem__.t %arg3, %46 : !torch.list, !torch.int -> !torch.int\n" +" %48 = torch.aten.floordiv.int %45, %47 : !torch.int, !torch.int -> !torch.int\n" +" %49 = torch.aten.add.int %48, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %50 = torch.aten.append.t %28, %49 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %28 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.conv_output_size(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %0 = call @__torch__.torch.jit._shape_functions.check_shape_forward(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.none\n" +" %1 = torch.aten.len.t %arg5 : !torch.list -> !torch.int\n" +" %2 = torch.aten.gt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %3 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %4 = torch.prim.ListConstruct : () -> !torch.list\n" +" %5 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %6 = torch.aten.append.t %4, %5 : !torch.list, !torch.int -> 
!torch.list\n" +" %7 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.append.t %4, %7 : !torch.list, !torch.int -> !torch.list\n" +" %9 = torch.aten.__range_length %int2, %3, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %9, %true, init() {\n" +" ^bb0(%arg7: !torch.int):\n" +" %10 = torch.aten.__derive_index %arg7, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %11 = torch.prim.If %2 -> (!torch.int) {\n" +" %27 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %28 = torch.aten.__getitem__.t %arg5, %27 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %28 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %12 = torch.aten.__getitem__.t %arg1, %10 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.sub.int %12, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.aten.mul.int %11, %13 : !torch.int, !torch.int -> !torch.int\n" +" %15 = torch.aten.add.int %14, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %16 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int\n" +" %17 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %18 = torch.aten.__getitem__.t %arg4, %17 : !torch.list, !torch.int -> !torch.int\n" +" %19 = torch.aten.mul.int %int2, %18 : !torch.int, !torch.int -> !torch.int\n" +" %20 = torch.aten.add.int %16, %19 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.sub.int %20, %15 : !torch.int, !torch.int -> !torch.int\n" +" %22 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %23 = torch.aten.__getitem__.t %arg3, %22 : !torch.list, !torch.int -> !torch.int\n" +" %24 = torch.aten.floordiv.int %21, %23 : !torch.int, !torch.int -> !torch.int\n" +" %25 = torch.aten.add.int %24, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %26 = torch.aten.append.t %4, %25 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %4 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.check_shape_forward(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.none {\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %2 = call @__torch__.torch.jit._shape_functions.check_non_negative(%arg4) : (!torch.list) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = call @__torch__.torch.jit._shape_functions.check_non_negative(%arg3) : (!torch.list) -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.eq.int %1, %0 : !torch.int, !torch.int -> 
!torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.ge.int %7, %arg6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %10 = torch.aten.remainder.int %9, %arg6 : !torch.int, !torch.int -> !torch.int\n" +" %11 = torch.aten.eq.int %10, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %12 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.mul.int %13, %arg6 : !torch.int, !torch.int -> !torch.int\n" +" %15 = torch.aten.eq.int %12, %14 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %15 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %16 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %17 = torch.prim.If %16 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %19 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list\n" +" %20 = torch.aten.len.t %19 : !torch.list -> !torch.int\n" +" %21 = torch.aten.eq.int %20, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %22 = torch.prim.If %21 -> (!torch.bool) {\n" +" %23 = torch.aten.__getitem__.t %19, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %24 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %25 = torch.aten.eq.int %23, %24 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %25 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If.yield %22 : !torch.bool\n" +" }\n" +" torch.prim.If %17 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %18 = torch.aten.__range_length %int2, %0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %18, %true, init() {\n" +" ^bb0(%arg7: !torch.int):\n" +" %19 = torch.aten.__derive_index %arg7, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %20 = torch.aten.__getitem__.t %arg0, %19 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.sub.int %19, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %22 = torch.aten.__getitem__.t %arg4, %21 : !torch.list, !torch.int -> !torch.int\n" +" %23 = torch.aten.mul.int %int2, %22 : !torch.int, !torch.int -> !torch.int\n" +" %24 = torch.aten.add.int %20, %23 : !torch.int, !torch.int -> !torch.int\n" +" %25 = torch.aten.sub.int %19, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %26 = torch.aten.__getitem__.t %arg5, %25 : !torch.list, !torch.int -> !torch.int\n" +" %27 = torch.aten.__getitem__.t %arg1, %19 : !torch.list, !torch.int -> !torch.int\n" +" %28 = torch.aten.sub.int %27, 
%int1 : !torch.int, !torch.int -> !torch.int\n" +" %29 = torch.aten.mul.int %26, %28 : !torch.int, !torch.int -> !torch.int\n" +" %30 = torch.aten.add.int %29, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %31 = torch.aten.ge.int %24, %30 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %31 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %none : !torch.none\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.check_non_negative(%arg0: !torch.list) -> !torch.bool {\n" +" %true = torch.constant.bool true\n" +" %false = torch.constant.bool false\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.prim.Loop %0, %true, init(%false) {\n" +" ^bb0(%arg1: !torch.int, %arg2: !torch.bool):\n" +" %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.lt.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %arg2 : !torch.bool\n" +" }\n" +" torch.prim.Loop.condition %true, iter(%4 : !torch.bool)\n" +" } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool\n" +" return %1 : !torch.bool\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.conv2d(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int4 = torch.constant.int 4\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %5 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %6 = torch.aten.len.t %arg4 : !torch.list -> !torch.int\n" +" %7 = torch.prim.Loop %6, %true, init(%false) {\n" +" ^bb0(%arg7: !torch.int, %arg8: !torch.bool):\n" +" %34 = torch.aten.__getitem__.t %arg4, %arg7 : !torch.list, !torch.int -> !torch.int\n" +" %35 = torch.aten.lt.int %34, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %36 = torch.prim.If %35 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %arg8 : !torch.bool\n" +" }\n" +" torch.prim.Loop.condition %true, iter(%36 : !torch.bool)\n" +" } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool\n" +" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" 
torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %10 = torch.prim.Loop %9, %true, init(%false) {\n" +" ^bb0(%arg7: !torch.int, %arg8: !torch.bool):\n" +" %34 = torch.aten.__getitem__.t %arg3, %arg7 : !torch.list, !torch.int -> !torch.int\n" +" %35 = torch.aten.lt.int %34, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %36 = torch.prim.If %35 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %arg8 : !torch.bool\n" +" }\n" +" torch.prim.Loop.condition %true, iter(%36 : !torch.bool)\n" +" } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool\n" +" %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %12 = torch.aten.eq.int %5, %4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %12 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %13 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.ge.int %13, %arg6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %14 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %15 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.aten.remainder.int %15, %arg6 : !torch.int, !torch.int -> !torch.int\n" +" %17 = torch.aten.eq.int %16, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %17 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %18 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %19 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %20 = torch.aten.mul.int %19, %arg6 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.eq.int %18, %20 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %21 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %22 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %23 = torch.prim.If %22 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %34 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list\n" +" %35 = torch.aten.len.t %34 : !torch.list -> !torch.int\n" +" %36 = torch.aten.eq.int %35, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %37 = torch.prim.If %36 -> (!torch.bool) {\n" +" %38 = torch.aten.__getitem__.t %34, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %39 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %40 = torch.aten.eq.int %38, %39 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %40 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If.yield %37 : !torch.bool\n" +" }\n" +" torch.prim.If %23 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, 
!torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %24 = torch.aten.__range_length %int2, %4, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %24, %true, init() {\n" +" ^bb0(%arg7: !torch.int):\n" +" %34 = torch.aten.__derive_index %arg7, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %35 = torch.aten.__getitem__.t %arg0, %34 : !torch.list, !torch.int -> !torch.int\n" +" %36 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %37 = torch.aten.__getitem__.t %arg4, %36 : !torch.list, !torch.int -> !torch.int\n" +" %38 = torch.aten.mul.int %37, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %39 = torch.aten.add.int %35, %38 : !torch.int, !torch.int -> !torch.int\n" +" %40 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %41 = torch.aten.__getitem__.t %arg5, %40 : !torch.list, !torch.int -> !torch.int\n" +" %42 = torch.aten.__getitem__.t %arg1, %34 : !torch.list, !torch.int -> !torch.int\n" +" %43 = torch.aten.sub.int %42, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %44 = torch.aten.mul.int %41, %43 : !torch.int, !torch.int -> !torch.int\n" +" %45 = torch.aten.add.int %44, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %46 = torch.aten.ge.int %39, %45 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %46 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %25 = torch.aten.len.t %arg5 : !torch.list -> !torch.int\n" +" %26 = torch.aten.gt.int %25, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %27 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %28 = torch.prim.ListConstruct : () -> !torch.list\n" +" %29 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %30 = torch.aten.append.t %28, %29 : !torch.list, !torch.int -> !torch.list\n" +" %31 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %32 = torch.aten.append.t %28, %31 : !torch.list, !torch.int -> !torch.list\n" +" %33 = torch.aten.__range_length %int2, %27, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %33, %true, init() {\n" +" ^bb0(%arg7: !torch.int):\n" +" %34 = torch.aten.__derive_index %arg7, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %35 = torch.prim.If %26 -> (!torch.int) {\n" +" %51 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %52 = torch.aten.__getitem__.t %arg5, %51 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %52 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %36 = torch.aten.__getitem__.t %arg1, %34 : !torch.list, !torch.int -> !torch.int\n" +" %37 = torch.aten.sub.int %36, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %38 = torch.aten.mul.int %35, %37 : !torch.int, !torch.int -> !torch.int\n" +" %39 = torch.aten.add.int %38, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %40 = torch.aten.__getitem__.t %arg0, %34 : !torch.list, !torch.int -> !torch.int\n" +" %41 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %42 = torch.aten.__getitem__.t %arg4, %41 : !torch.list, !torch.int -> !torch.int\n" +" %43 = torch.aten.mul.int %42, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %44 = torch.aten.add.int %40, %43 : !torch.int, !torch.int -> 
!torch.int\n" +" %45 = torch.aten.sub.int %44, %39 : !torch.int, !torch.int -> !torch.int\n" +" %46 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %47 = torch.aten.__getitem__.t %arg3, %46 : !torch.list, !torch.int -> !torch.int\n" +" %48 = torch.aten.floordiv.int %45, %47 : !torch.int, !torch.int -> !torch.int\n" +" %49 = torch.aten.add.int %48, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %50 = torch.aten.append.t %28, %49 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %28 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.batch_norm(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %1, %true, init() {\n" +" ^bb0(%arg9: !torch.int):\n" +" %2 = torch.aten.__getitem__.t %arg0, %arg9 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.append.t %0, %2 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.conv3d(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %5 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %6 = torch.aten.len.t %arg4 : !torch.list -> !torch.int\n" +" %7 = torch.prim.Loop %6, %true, init(%false) {\n" +" ^bb0(%arg7: !torch.int, %arg8: !torch.bool):\n" +" %34 = torch.aten.__getitem__.t %arg4, %arg7 : !torch.list, !torch.int -> !torch.int\n" +" %35 = torch.aten.lt.int %34, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %36 = torch.prim.If %35 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %arg8 : !torch.bool\n" +" }\n" +" torch.prim.Loop.condition %true, iter(%36 : !torch.bool)\n" +" } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool\n" +" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, 
!torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %10 = torch.prim.Loop %9, %true, init(%false) {\n" +" ^bb0(%arg7: !torch.int, %arg8: !torch.bool):\n" +" %34 = torch.aten.__getitem__.t %arg3, %arg7 : !torch.list, !torch.int -> !torch.int\n" +" %35 = torch.aten.lt.int %34, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %36 = torch.prim.If %35 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %arg8 : !torch.bool\n" +" }\n" +" torch.prim.Loop.condition %true, iter(%36 : !torch.bool)\n" +" } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool\n" +" %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %12 = torch.aten.eq.int %5, %4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %12 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %13 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.ge.int %13, %arg6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %14 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %15 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.aten.remainder.int %15, %arg6 : !torch.int, !torch.int -> !torch.int\n" +" %17 = torch.aten.eq.int %16, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %17 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %18 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %19 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %20 = torch.aten.mul.int %19, %arg6 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.eq.int %18, %20 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %21 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %22 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %23 = torch.prim.If %22 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %34 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list\n" +" %35 = torch.aten.len.t %34 : !torch.list -> !torch.int\n" +" %36 = torch.aten.eq.int %35, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %37 = torch.prim.If %36 -> (!torch.bool) {\n" +" %38 = torch.aten.__getitem__.t %34, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %39 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %40 = torch.aten.eq.int %38, %39 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %40 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If.yield %37 : !torch.bool\n" +" }\n" +" torch.prim.If %23 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %24 = 
torch.aten.__range_length %int2, %4, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %24, %true, init() {\n" +" ^bb0(%arg7: !torch.int):\n" +" %34 = torch.aten.__derive_index %arg7, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %35 = torch.aten.__getitem__.t %arg0, %34 : !torch.list, !torch.int -> !torch.int\n" +" %36 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %37 = torch.aten.__getitem__.t %arg4, %36 : !torch.list, !torch.int -> !torch.int\n" +" %38 = torch.aten.mul.int %37, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %39 = torch.aten.add.int %35, %38 : !torch.int, !torch.int -> !torch.int\n" +" %40 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %41 = torch.aten.__getitem__.t %arg5, %40 : !torch.list, !torch.int -> !torch.int\n" +" %42 = torch.aten.__getitem__.t %arg1, %34 : !torch.list, !torch.int -> !torch.int\n" +" %43 = torch.aten.sub.int %42, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %44 = torch.aten.mul.int %41, %43 : !torch.int, !torch.int -> !torch.int\n" +" %45 = torch.aten.add.int %44, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %46 = torch.aten.ge.int %39, %45 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %46 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %25 = torch.aten.len.t %arg5 : !torch.list -> !torch.int\n" +" %26 = torch.aten.gt.int %25, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %27 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %28 = torch.prim.ListConstruct : () -> !torch.list\n" +" %29 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %30 = torch.aten.append.t %28, %29 : !torch.list, !torch.int -> !torch.list\n" +" %31 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %32 = torch.aten.append.t %28, %31 : !torch.list, !torch.int -> !torch.list\n" +" %33 = torch.aten.__range_length %int2, %27, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %33, %true, init() {\n" +" ^bb0(%arg7: !torch.int):\n" +" %34 = torch.aten.__derive_index %arg7, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %35 = torch.prim.If %26 -> (!torch.int) {\n" +" %51 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %52 = torch.aten.__getitem__.t %arg5, %51 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %52 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %36 = torch.aten.__getitem__.t %arg1, %34 : !torch.list, !torch.int -> !torch.int\n" +" %37 = torch.aten.sub.int %36, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %38 = torch.aten.mul.int %35, %37 : !torch.int, !torch.int -> !torch.int\n" +" %39 = torch.aten.add.int %38, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %40 = torch.aten.__getitem__.t %arg0, %34 : !torch.list, !torch.int -> !torch.int\n" +" %41 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %42 = torch.aten.__getitem__.t %arg4, %41 : !torch.list, !torch.int -> !torch.int\n" +" %43 = torch.aten.mul.int %42, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %44 = torch.aten.add.int %40, %43 : !torch.int, !torch.int -> !torch.int\n" +" %45 = torch.aten.sub.int %44, %39 : 
!torch.int, !torch.int -> !torch.int\n" +" %46 = torch.aten.sub.int %34, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %47 = torch.aten.__getitem__.t %arg3, %46 : !torch.list, !torch.int -> !torch.int\n" +" %48 = torch.aten.floordiv.int %45, %47 : !torch.int, !torch.int -> !torch.int\n" +" %49 = torch.aten.add.int %48, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %50 = torch.aten.append.t %28, %49 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %28 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.conv_backwards(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional>) -> !torch.tuple, list, list> {\n" +" %int1 = torch.constant.int 1\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" torch.prim.Loop %1, %true, init() {\n" +" ^bb0(%arg4: !torch.int):\n" +" %7 = torch.aten.__getitem__.t %arg1, %arg4 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.append.t %0, %7 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %2 = torch.prim.ListConstruct : () -> !torch.list\n" +" %3 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" torch.prim.Loop %3, %true, init() {\n" +" ^bb0(%arg4: !torch.int):\n" +" %7 = torch.aten.__getitem__.t %arg2, %arg4 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.append.t %2, %7 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %4 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list\n" +" %6 = torch.prim.TupleConstruct %0, %2, %5 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" +" return %6 : !torch.tuple, list, list>\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.conv_forwards(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %0 = torch.aten.len.t %arg5 : !torch.list -> !torch.int\n" +" %1 = torch.aten.gt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %3 = torch.prim.ListConstruct : () -> !torch.list\n" +" %4 = torch.prim.If %arg6 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" }\n" +" %5 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %6 = torch.aten.append.t %3, %5 : !torch.list, !torch.int -> !torch.list\n" +" %7 = torch.aten.__getitem__.t %arg1, %4 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.append.t %3, %7 : !torch.list, !torch.int -> !torch.list\n" +" %9 = torch.aten.__range_length %int2, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %9, %true, init() {\n" +" ^bb0(%arg9: !torch.int):\n" +" %10 = torch.aten.__derive_index %arg9, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %11 = torch.prim.If %1 -> 
(!torch.int) {\n" +" %12 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %13 = torch.aten.__getitem__.t %arg5, %12 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %13 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" torch.prim.If %arg6 -> () {\n" +" %12 = torch.aten.__getitem__.t %arg1, %10 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.sub.int %12, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.aten.mul.int %11, %13 : !torch.int, !torch.int -> !torch.int\n" +" %15 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.aten.sub.int %15, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %17 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %18 = torch.aten.__getitem__.t %arg3, %17 : !torch.list, !torch.int -> !torch.int\n" +" %19 = torch.aten.mul.int %16, %18 : !torch.int, !torch.int -> !torch.int\n" +" %20 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.__getitem__.t %arg4, %20 : !torch.list, !torch.int -> !torch.int\n" +" %22 = torch.aten.mul.int %21, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %23 = torch.aten.sub.int %19, %22 : !torch.int, !torch.int -> !torch.int\n" +" %24 = torch.aten.add.int %23, %14 : !torch.int, !torch.int -> !torch.int\n" +" %25 = torch.aten.add.int %24, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %26 = torch.aten.append.t %3, %25 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" %12 = torch.aten.__getitem__.t %arg1, %10 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.sub.int %12, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.aten.mul.int %11, %13 : !torch.int, !torch.int -> !torch.int\n" +" %15 = torch.aten.add.int %14, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %16 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int\n" +" %17 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %18 = torch.aten.__getitem__.t %arg4, %17 : !torch.list, !torch.int -> !torch.int\n" +" %19 = torch.aten.mul.int %18, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %20 = torch.aten.add.int %16, %19 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.sub.int %20, %15 : !torch.int, !torch.int -> !torch.int\n" +" %22 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %23 = torch.aten.__getitem__.t %arg3, %22 : !torch.list, !torch.int -> !torch.int\n" +" %24 = torch.aten.floordiv.int %21, %23 : !torch.int, !torch.int -> !torch.int\n" +" %25 = torch.aten.add.int %24, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %26 = torch.aten.append.t %3, %25 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %3 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.conv_transpose2d_input(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.optional>, %arg6: !torch.int, %arg7: !torch.optional>) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %0 = torch.aten.__is__ %arg3, %none : !torch.optional>, !torch.none -> !torch.bool\n" 
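+// Descriptive note: the three torch.prim.If blocks that follow substitute defaults when the optional arguments are none -- %arg3 becomes [1, 1], %arg4 becomes [0, 0], and %arg7 becomes [1, 1]; matching these against the torch.jit._shape_functions.conv_transpose2d_input signature (an assumption for the names), they would be stride, padding, and dilation. The loop afterwards computes each spatial output dimension as (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1, the usual transposed-convolution output-size formula; %arg5 does not enter this serialized computation.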
+" %1 = torch.prim.If %0 -> (!torch.list) {\n" +" %15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %15 : !torch.list\n" +" } else {\n" +" %15 = torch.prim.unchecked_cast %arg3 : !torch.optional> -> !torch.list\n" +" torch.prim.If.yield %15 : !torch.list\n" +" }\n" +" %2 = torch.aten.__is__ %arg4, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.list) {\n" +" %15 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %15 : !torch.list\n" +" } else {\n" +" %15 = torch.prim.unchecked_cast %arg4 : !torch.optional> -> !torch.list\n" +" torch.prim.If.yield %15 : !torch.list\n" +" }\n" +" %4 = torch.aten.__is__ %arg7, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.list) {\n" +" %15 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %15 : !torch.list\n" +" } else {\n" +" %15 = torch.prim.unchecked_cast %arg7 : !torch.optional> -> !torch.list\n" +" torch.prim.If.yield %15 : !torch.list\n" +" }\n" +" %6 = torch.aten.len.t %5 : !torch.list -> !torch.int\n" +" %7 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %9 = torch.prim.ListConstruct : () -> !torch.list\n" +" %10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %11 = torch.aten.append.t %9, %10 : !torch.list, !torch.int -> !torch.list\n" +" %12 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.append.t %9, %12 : !torch.list, !torch.int -> !torch.list\n" +" %14 = torch.aten.__range_length %int2, %8, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %14, %true, init() {\n" +" ^bb0(%arg8: !torch.int):\n" +" %15 = torch.aten.__derive_index %arg8, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %16 = torch.prim.If %7 -> (!torch.int) {\n" +" %32 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %33 = torch.aten.__getitem__.t %5, %32 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %33 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %17 = torch.aten.__getitem__.t %arg1, %15 : !torch.list, !torch.int -> !torch.int\n" +" %18 = torch.aten.sub.int %17, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %19 = torch.aten.mul.int %16, %18 : !torch.int, !torch.int -> !torch.int\n" +" %20 = torch.aten.__getitem__.t %arg0, %15 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.sub.int %20, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %22 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %23 = torch.aten.__getitem__.t %1, %22 : !torch.list, !torch.int -> !torch.int\n" +" %24 = torch.aten.mul.int %21, %23 : !torch.int, !torch.int -> !torch.int\n" +" %25 = torch.aten.sub.int %15, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %26 = torch.aten.__getitem__.t %3, %25 : !torch.list, !torch.int -> !torch.int\n" +" %27 = torch.aten.mul.int %26, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %28 = torch.aten.sub.int %24, %27 : !torch.int, !torch.int -> !torch.int\n" +" %29 = torch.aten.add.int %28, %19 : !torch.int, !torch.int -> !torch.int\n" +" %30 = torch.aten.add.int %29, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %31 = torch.aten.append.t %9, 
%30 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %9 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.flatten(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %true = torch.constant.bool true\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0 : !torch.int\n" +" }\n" +" %3 = torch.aten.neg.int %2 : !torch.int -> !torch.int\n" +" %4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %5 = torch.aten.lt.int %arg1, %3 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %24 = torch.aten.gt.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %24 : !torch.bool\n" +" }\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" %24 = torch.aten.add.int %arg1, %2 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %24 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %arg1 : !torch.int\n" +" }\n" +" %10 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %11 = torch.aten.le.int %10, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %10 : !torch.int\n" +" }\n" +" %13 = torch.aten.neg.int %12 : !torch.int -> !torch.int\n" +" %14 = torch.aten.sub.int %12, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %15 = torch.aten.lt.int %arg2, %13 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %24 = torch.aten.gt.int %arg2, %14 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %24 : !torch.bool\n" +" }\n" +" %17 = torch.aten.__not__ %16 : !torch.bool -> !torch.bool\n" +" torch.prim.If %17 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %18 = torch.aten.lt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %19 = torch.prim.If %18 -> (!torch.int) {\n" +" %24 = torch.aten.add.int %arg2, %12 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %24 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %arg2 : !torch.int\n" +" }\n" +" %20 = torch.aten.le.int %9, %19 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %20 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %21 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %22 = torch.aten.eq.int %21, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" 
%23 = torch.prim.If %22 -> (!torch.list) {\n" +" %24 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list\n" +" torch.prim.If.yield %24 : !torch.list\n" +" } else {\n" +" %24 = torch.aten.eq.int %9, %19 : !torch.int, !torch.int -> !torch.bool\n" +" %25 = torch.prim.If %24 -> (!torch.list) {\n" +" %26 = torch.prim.ListConstruct : () -> !torch.list\n" +" %27 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %27, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %28 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %29 = torch.aten.append.t %26, %28 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.If.yield %26 : !torch.list\n" +" } else {\n" +" %26 = torch.aten.add.int %19, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %27 = torch.aten.__range_length %9, %26, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %28 = torch.prim.Loop %27, %true, init(%int1) {\n" +" ^bb0(%arg3: !torch.int, %arg4: !torch.int):\n" +" %34 = torch.aten.__derive_index %arg3, %9, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %35 = torch.aten.__getitem__.t %arg0, %34 : !torch.list, !torch.int -> !torch.int\n" +" %36 = torch.aten.mul.int %arg4, %35 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop.condition %true, iter(%36 : !torch.int)\n" +" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" +" %29 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.Loop %9, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %34 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %35 = torch.aten.append.t %29, %34 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %30 = torch.aten.append.t %29, %28 : !torch.list, !torch.int -> !torch.list\n" +" %31 = torch.aten.add.int %19, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %32 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %33 = torch.aten.__range_length %31, %32, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %33, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %34 = torch.aten.__derive_index %arg3, %31, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %35 = torch.aten.__getitem__.t %arg0, %34 : !torch.list, !torch.int -> !torch.int\n" +" %36 = torch.aten.append.t %29, %35 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.If.yield %29 : !torch.list\n" +" }\n" +" torch.prim.If.yield %25 : !torch.list\n" +" }\n" +" return %23 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.cat(%arg0: !torch.list>, %arg1: !torch.int) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: Sizes of tensors must match except in dimension\"\n" +" %str_0 = torch.constant.str \"AssertionError: Tensors must have same number of dimensions\"\n" +" %false = torch.constant.bool false\n" +" %int1 = torch.constant.int 1\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" torch.prim.Loop %0, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %13 = torch.aten.__getitem__.t %arg0, %arg2 : 
!torch.list>, !torch.int -> !torch.list\n" +" %14 = torch.aten.len.t %13 : !torch.list -> !torch.int\n" +" %15 = torch.aten.gt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %15 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %1 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" %2 = torch.derefine %none : !torch.none to !torch.optional\n" +" %3 = torch.prim.Loop %1, %true, init(%2) {\n" +" ^bb0(%arg2: !torch.int, %arg3: !torch.optional):\n" +" %13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list>, !torch.int -> !torch.list\n" +" %14 = torch.aten.len.t %13 : !torch.list -> !torch.int\n" +" %15 = torch.aten.eq.int %14, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.bool) {\n" +" %19 = torch.aten.__getitem__.t %13, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %20 = torch.aten.eq.int %19, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %17 = torch.aten.__not__ %16 : !torch.bool -> !torch.bool\n" +" %18 = torch.prim.If %17 -> (!torch.optional) {\n" +" %19 = torch.aten.__is__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %20 = torch.prim.If %19 -> (!torch.int) {\n" +" %22 = torch.aten.len.t %13 : !torch.list -> !torch.int\n" +" %23 = torch.aten.le.int %22, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %24 = torch.prim.If %23 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %22 : !torch.int\n" +" }\n" +" %25 = torch.aten.neg.int %24 : !torch.int -> !torch.int\n" +" %26 = torch.aten.sub.int %24, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %27 = torch.aten.lt.int %arg1, %25 : !torch.int, !torch.int -> !torch.bool\n" +" %28 = torch.prim.If %27 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %32 = torch.aten.gt.int %arg1, %26 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %32 : !torch.bool\n" +" }\n" +" %29 = torch.aten.__not__ %28 : !torch.bool -> !torch.bool\n" +" torch.prim.If %29 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %30 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %31 = torch.prim.If %30 -> (!torch.int) {\n" +" %32 = torch.aten.add.int %arg1, %24 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %32 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %arg1 : !torch.int\n" +" }\n" +" torch.prim.If.yield %31 : !torch.int\n" +" } else {\n" +" %22 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %22 : !torch.int\n" +" }\n" +" %21 = torch.derefine %20 : !torch.int to !torch.optional\n" +" torch.prim.If.yield %21 : !torch.optional\n" +" } else {\n" +" torch.prim.If.yield %arg3 : !torch.optional\n" +" }\n" +" torch.prim.Loop.condition %true, iter(%18 : !torch.optional)\n" +" } : (!torch.int, !torch.bool, !torch.optional) -> !torch.optional\n" +" %4 = torch.aten.__is__ %3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %arg1 : !torch.int\n" +" } else {\n" +" %13 = 
torch.prim.unchecked_cast %3 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %13 : !torch.int\n" +" }\n" +" %6 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" %7 = torch.aten.gt.int %6, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" %9 = torch.derefine %none : !torch.none to !torch.optional>\n" +" %10 = torch.prim.Loop %8, %true, init(%9) {\n" +" ^bb0(%arg2: !torch.int, %arg3: !torch.optional>):\n" +" %13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list>, !torch.int -> !torch.list\n" +" %14 = torch.aten.len.t %13 : !torch.list -> !torch.int\n" +" %15 = torch.prim.Loop %14, %true, init(%int1) {\n" +" ^bb0(%arg4: !torch.int, %arg5: !torch.int):\n" +" %20 = torch.aten.__getitem__.t %13, %arg4 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.mul.int %arg5, %20 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop.condition %true, iter(%21 : !torch.int)\n" +" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" +" %16 = torch.aten.eq.int %15, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %17 = torch.prim.If %16 -> (!torch.bool) {\n" +" %20 = torch.aten.len.t %13 : !torch.list -> !torch.int\n" +" %21 = torch.aten.eq.int %20, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %21 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %18 = torch.aten.__not__ %17 : !torch.bool -> !torch.bool\n" +" %19 = torch.prim.If %18 -> (!torch.optional>) {\n" +" %20 = torch.derefine %13 : !torch.list to !torch.optional>\n" +" torch.prim.If.yield %20 : !torch.optional>\n" +" } else {\n" +" torch.prim.If.yield %arg3 : !torch.optional>\n" +" }\n" +" torch.prim.Loop.condition %true, iter(%19 : !torch.optional>)\n" +" } : (!torch.int, !torch.bool, !torch.optional>) -> !torch.optional>\n" +" %11 = torch.aten.__is__ %10, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.list) {\n" +" %13 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n" +" torch.prim.If.yield %13 : !torch.list\n" +" } else {\n" +" %13 = torch.prim.unchecked_cast %10 : !torch.optional> -> !torch.list\n" +" %14 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" %15 = torch.prim.Loop %14, %true, init(%int0) {\n" +" ^bb0(%arg2: !torch.int, %arg3: !torch.int):\n" +" %19 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list>, !torch.int -> !torch.list\n" +" %20 = torch.aten.len.t %19 : !torch.list -> !torch.int\n" +" %21 = torch.prim.Loop %20, %true, init(%int1) {\n" +" ^bb0(%arg4: !torch.int, %arg5: !torch.int):\n" +" %26 = torch.aten.__getitem__.t %19, %arg4 : !torch.list, !torch.int -> !torch.int\n" +" %27 = torch.aten.mul.int %arg5, %26 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop.condition %true, iter(%27 : !torch.int)\n" +" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" +" %22 = torch.aten.eq.int %21, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %23 = torch.prim.If %22 -> (!torch.bool) {\n" +" %26 = torch.aten.len.t %19 : !torch.list -> !torch.int\n" +" %27 = torch.aten.eq.int %26, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %27 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %24 = torch.aten.__not__ %23 : !torch.bool -> 
!torch.bool\n" +" %25 = torch.prim.If %24 -> (!torch.int) {\n" +" %26 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int\n" +" %27 = torch.aten.len.t %19 : !torch.list<int> -> !torch.int\n" +" %28 = torch.aten.eq.int %26, %27 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %28 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %29 = torch.aten.__range_length %int0, %26, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %29, %true, init() {\n" +" ^bb0(%arg4: !torch.int):\n" +" %32 = torch.aten.__derive_index %arg4, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %33 = torch.aten.ne.int %32, %5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %33 -> () {\n" +" %34 = torch.aten.__getitem__.t %13, %32 : !torch.list<int>, !torch.int -> !torch.int\n" +" %35 = torch.aten.__getitem__.t %19, %32 : !torch.list<int>, !torch.int -> !torch.int\n" +" %36 = torch.aten.eq.int %34, %35 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %36 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %30 = torch.aten.__getitem__.t %19, %5 : !torch.list<int>, !torch.int -> !torch.int\n" +" %31 = torch.aten.add.int %arg3, %30 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %31 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %arg3 : !torch.int\n" +" }\n" +" torch.prim.Loop.condition %true, iter(%25 : !torch.int)\n" +" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" +" %16 = torch.prim.ListConstruct : () -> !torch.list<int>\n" +" %17 = torch.aten.len.t %13 : !torch.list<int> -> !torch.int\n" +" torch.prim.Loop %17, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %19 = torch.aten.__getitem__.t %13, %arg2 : !torch.list<int>, !torch.int -> !torch.int\n" +" %20 = torch.aten.append.t %16, %19 : !torch.list<int>, !torch.int -> !torch.list<int>\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %18 = torch.aten._set_item.t %16, %5, %15 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n" +" torch.prim.If.yield %16 : !torch.list<int>\n" +" }\n" +" return %12 : !torch.list<int>\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.check_cat_no_zero_dim(%arg0: !torch.list<list<int>>) -> !torch.none {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %true = torch.constant.bool true\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n" +" torch.prim.Loop %0, %true, init() {\n" +" ^bb0(%arg1: !torch.int):\n" +" %1 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n" +" %2 = torch.aten.len.t %1 : !torch.list<int> -> !torch.int\n" +" %3 = torch.aten.gt.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %none : !torch.none\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.legacy_cat_wrap_dim(%arg0: !torch.int, %arg1: !torch.list<list<int>>) -> !torch.int {\n" +" 
%false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg1 : !torch.list<list<int>> -> !torch.int\n" +" %1 = torch.derefine %none : !torch.none to !torch.optional<int>\n" +" %2 = torch.prim.Loop %0, %true, init(%1) {\n" +" ^bb0(%arg2: !torch.int, %arg3: !torch.optional<int>):\n" +" %5 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n" +" %6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int\n" +" %7 = torch.aten.eq.int %6, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %11 = torch.aten.__getitem__.t %5, %int0 : !torch.list<int>, !torch.int -> !torch.int\n" +" %12 = torch.aten.eq.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %9 = torch.aten.__not__ %8 : !torch.bool -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.optional<int>) {\n" +" %11 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.int) {\n" +" %14 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int\n" +" %15 = func.call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg0, %14, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int\n" +" torch.prim.If.yield %15 : !torch.int\n" +" } else {\n" +" %14 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n" +" torch.prim.If.yield %14 : !torch.int\n" +" }\n" +" %13 = torch.derefine %12 : !torch.int to !torch.optional<int>\n" +" torch.prim.If.yield %13 : !torch.optional<int>\n" +" } else {\n" +" torch.prim.If.yield %arg3 : !torch.optional<int>\n" +" }\n" +" torch.prim.Loop.condition %true, iter(%10 : !torch.optional<int>)\n" +" } : (!torch.int, !torch.bool, !torch.optional<int>) -> !torch.optional<int>\n" +" %3 = torch.aten.__is__ %2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %arg0 : !torch.int\n" +" } else {\n" +" %5 = torch.prim.unchecked_cast %2 : !torch.optional<int> -> !torch.int\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.should_skip(%arg0: !torch.list<int>) -> !torch.bool {\n" +" %false = torch.constant.bool false\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = call @__torch__.torch.jit._shape_functions.numel(%arg0) : (!torch.list<int>) -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n" +" %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %4 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" return %2 : !torch.bool\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.numel(%arg0: !torch.list<int>) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n" +" %1 = torch.prim.Loop %0, %true, init(%int1) {\n" +" ^bb0(%arg1: !torch.int, %arg2: !torch.int):\n" +" %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<int>, !torch.int -> !torch.int\n" +" %3 = torch.aten.mul.int %arg2, %2 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop.condition %true, 
iter(%3 : !torch.int)\n" +" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.check_cat_shape_except_dim(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.none {\n" +" %str = torch.constant.str \"AssertionError: Sizes of tensors must match except in dimension\"\n" +" %true = torch.constant.bool true\n" +" %int1 = torch.constant.int 1\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: Tensors must have same number of dimensions\"\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n" +" %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n" +" %2 = torch.aten.eq.int %0, %1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.__range_length %int0, %0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %3, %true, init() {\n" +" ^bb0(%arg4: !torch.int):\n" +" %4 = torch.aten.__derive_index %arg4, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %5 = torch.aten.ne.int %4, %arg2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" %6 = torch.aten.__getitem__.t %arg0, %4 : !torch.list<int>, !torch.int -> !torch.int\n" +" %7 = torch.aten.__getitem__.t %arg1, %4 : !torch.list<int>, !torch.int -> !torch.int\n" +" %8 = torch.aten.eq.int %6, %7 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %none : !torch.none\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.permute(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n" +" %int0 = torch.constant.int 0\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n" +" %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n" +" %2 = torch.aten.eq.int %0, %1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n" +" %4 = torch.prim.ListConstruct : () -> !torch.list<int>\n" +" %5 = torch.prim.ListConstruct : () -> !torch.list<int>\n" +" torch.prim.Loop %3, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %7 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<int>, !torch.int -> !torch.int\n" +" %8 = torch.aten.le.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" %10 = torch.aten.neg.int %9 : !torch.int -> !torch.int\n" +" %11 = torch.aten.sub.int %9, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %12 = torch.aten.lt.int %7, %10 : !torch.int, !torch.int -> !torch.bool\n" +" %13 = torch.prim.If %12 -> 
(!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %20 = torch.aten.gt.int %7, %11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" }\n" +" %14 = torch.aten.__not__ %13 : !torch.bool -> !torch.bool\n" +" torch.prim.If %14 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %15 = torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.int) {\n" +" %20 = torch.aten.add.int %7, %9 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %20 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %7 : !torch.int\n" +" }\n" +" %17 = torch.aten.append.t %4, %16 : !torch.list, !torch.int -> !torch.list\n" +" %18 = torch.aten.__getitem__.t %arg0, %16 : !torch.list, !torch.int -> !torch.int\n" +" %19 = torch.aten.append.t %5, %18 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %6 = torch.aten.__range_length %int1, %3, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %6, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %7 = torch.aten.__derive_index %arg2, %int1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %7, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %8 = torch.aten.__getitem__.t %4, %7 : !torch.list, !torch.int -> !torch.int\n" +" %9 = torch.aten.__getitem__.t %4, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %10 = torch.aten.ne.int %8, %9 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %5 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.view(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: invalid shape\"\n" +" %false = torch.constant.bool false\n" +" %str_0 = torch.constant.str \"AssertionError: invalid shape dimensions\"\n" +" %str_1 = torch.constant.str \"AssertionError: only one dimension can be inferred\"\n" +" %int-1 = torch.constant.int -1\n" +" %none = torch.constant.none\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %true = torch.constant.bool true\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.prim.Loop %0, %true, init(%int1) {\n" +" ^bb0(%arg2: !torch.int, %arg3: !torch.int):\n" +" %12 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.mul.int %arg3, %12 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop.condition %true, iter(%13 : !torch.int)\n" +" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" +" %2 = torch.prim.Uninitialized : !torch.int\n" +" %3 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %4 = torch.derefine %none : !torch.none to !torch.optional\n" +" %5:2 = torch.prim.Loop %3, %true, init(%int1, %4) {\n" +" ^bb0(%arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.optional):\n" +" %12 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.eq.int 
%12, %int-1 : !torch.int, !torch.int -> !torch.bool\n" +" %14:2 = torch.prim.If %13 -> (!torch.int, !torch.optional) {\n" +" %15 = torch.aten.__isnot__ %arg4, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If %15 -> () {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %16 = torch.derefine %arg2 : !torch.int to !torch.optional\n" +" torch.prim.If.yield %arg3, %16 : !torch.int, !torch.optional\n" +" } else {\n" +" %15 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.aten.ge.int %15, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %17 = torch.prim.If %16 -> (!torch.int) {\n" +" %18 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %19 = torch.aten.mul.int %arg3, %18 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %19 : !torch.int\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" torch.prim.If.yield %17, %arg4 : !torch.int, !torch.optional\n" +" }\n" +" torch.prim.Loop.condition %true, iter(%14#0, %14#1 : !torch.int, !torch.optional)\n" +" } : (!torch.int, !torch.bool, !torch.int, !torch.optional) -> (!torch.int, !torch.optional)\n" +" %6 = torch.aten.eq.int %1, %5#0 : !torch.int, !torch.int -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %12 = torch.aten.__isnot__ %5#1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %13 = torch.prim.If %12 -> (!torch.bool) {\n" +" %15 = torch.prim.unchecked_cast %5#1 : !torch.optional -> !torch.int\n" +" %16 = torch.aten.gt.int %5#0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %16 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %14 = torch.prim.If %13 -> (!torch.bool) {\n" +" %15 = torch.prim.unchecked_cast %5#1 : !torch.optional -> !torch.int\n" +" %16 = torch.aten.remainder.int %1, %5#0 : !torch.int, !torch.int -> !torch.int\n" +" %17 = torch.aten.eq.int %16, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %17 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If.yield %14 : !torch.bool\n" +" }\n" +" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.prim.ListConstruct : () -> !torch.list\n" +" %10 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" torch.prim.Loop %10, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %12 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.append.t %9, %12 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %11 = torch.aten.__isnot__ %5#1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If %11 -> () {\n" +" %12 = torch.prim.unchecked_cast %5#1 : !torch.optional -> !torch.int\n" +" %13 = torch.aten.floordiv.int %1, %5#0 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.aten._set_item.t %9, %12, %13 : !torch.list, !torch.int, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" 
torch.prim.If.yield\n" +" }\n" +" return %9 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.infer_size_impl(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: invalid shape\"\n" +" %false = torch.constant.bool false\n" +" %str_0 = torch.constant.str \"AssertionError: invalid shape dimensions\"\n" +" %str_1 = torch.constant.str \"AssertionError: only one dimension can be inferred\"\n" +" %int-1 = torch.constant.int -1\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %2 = torch.derefine %none : !torch.none to !torch.optional\n" +" %3:2 = torch.prim.Loop %1, %true, init(%int1, %2) {\n" +" ^bb0(%arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.optional):\n" +" %9 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %10 = torch.aten.eq.int %9, %int-1 : !torch.int, !torch.int -> !torch.bool\n" +" %11:2 = torch.prim.If %10 -> (!torch.int, !torch.optional) {\n" +" %12 = torch.aten.__isnot__ %arg4, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If %12 -> () {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %13 = torch.derefine %arg2 : !torch.int to !torch.optional\n" +" torch.prim.If.yield %arg3, %13 : !torch.int, !torch.optional\n" +" } else {\n" +" %12 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.ge.int %12, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %14 = torch.prim.If %13 -> (!torch.int) {\n" +" %15 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.aten.mul.int %arg3, %15 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %16 : !torch.int\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.int\n" +" }\n" +" torch.prim.If.yield %14, %arg4 : !torch.int, !torch.optional\n" +" }\n" +" torch.prim.Loop.condition %true, iter(%11#0, %11#1 : !torch.int, !torch.optional)\n" +" } : (!torch.int, !torch.bool, !torch.int, !torch.optional) -> (!torch.int, !torch.optional)\n" +" %4 = torch.aten.eq.int %arg1, %3#0 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %9 = torch.aten.__isnot__ %3#1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.bool) {\n" +" %12 = torch.prim.unchecked_cast %3#1 : !torch.optional -> !torch.int\n" +" %13 = torch.aten.gt.int %3#0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %11 = torch.prim.If %10 -> (!torch.bool) {\n" +" %12 = torch.prim.unchecked_cast %3#1 : !torch.optional -> !torch.int\n" +" %13 = torch.aten.remainder.int %arg1, %3#0 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %14 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" }\n" +" %6 = torch.aten.__not__ %5 : !torch.bool -> 
!torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list) -> !torch.list\n" +" %8 = torch.aten.__isnot__ %3#1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" %9 = torch.prim.unchecked_cast %3#1 : !torch.optional -> !torch.int\n" +" %10 = torch.aten.floordiv.int %arg1, %3#0 : !torch.int, !torch.int -> !torch.int\n" +" %11 = torch.aten._set_item.t %7, %9, %10 : !torch.list, !torch.int, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" return %7 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.expand(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %int-1 = torch.constant.int -1\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %2 = torch.aten.ge.int %0, %1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %5 = torch.aten.eq.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.list) {\n" +" %7 = torch.prim.ListConstruct : () -> !torch.list\n" +" %8 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" torch.prim.Loop %8, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %9 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %10 = torch.aten.append.t %7, %9 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.If.yield %7 : !torch.list\n" +" } else {\n" +" %7 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.Loop %3, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %8 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %9 = torch.aten.sub.int %8, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" %10 = torch.aten.sub.int %4, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %11 = torch.aten.sub.int %10, %9 : !torch.int, !torch.int -> !torch.int\n" +" %12 = torch.aten.ge.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %13 = torch.prim.If %12 -> (!torch.int) {\n" +" %20 = torch.aten.__getitem__.t %arg0, %11 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %20 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %14 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %15 = torch.aten.eq.int %14, %int-1 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.int) {\n" +" %20 = torch.aten.ge.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %20 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %13 : 
!torch.int\n" +" } else {\n" +" torch.prim.If.yield %14 : !torch.int\n" +" }\n" +" %17 = torch.aten.ne.int %13, %16 : !torch.int, !torch.int -> !torch.bool\n" +" %18 = torch.prim.If %17 -> (!torch.int) {\n" +" %20 = torch.aten.eq.int %13, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %20 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %16 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %13 : !torch.int\n" +" }\n" +" %19 = torch.aten.append.t %7, %18 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.If.yield %7 : !torch.list\n" +" }\n" +" return %6 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.expand_one_unused(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.any) -> !torch.list {\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %true = torch.constant.bool true\n" +" %int-1 = torch.constant.int -1\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %2 = torch.aten.ge.int %0, %1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %5 = torch.aten.eq.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.list) {\n" +" %7 = torch.prim.ListConstruct : () -> !torch.list\n" +" %8 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" torch.prim.Loop %8, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %9 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %10 = torch.aten.append.t %7, %9 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.If.yield %7 : !torch.list\n" +" } else {\n" +" %7 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.Loop %3, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %8 = torch.aten.sub.int %3, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %9 = torch.aten.sub.int %8, %arg3 : !torch.int, !torch.int -> !torch.int\n" +" %10 = torch.aten.sub.int %4, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %11 = torch.aten.sub.int %10, %9 : !torch.int, !torch.int -> !torch.int\n" +" %12 = torch.aten.ge.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %13 = torch.prim.If %12 -> (!torch.int) {\n" +" %20 = torch.aten.__getitem__.t %arg0, %11 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %20 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %14 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %15 = torch.aten.eq.int %14, %int-1 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.int) {\n" +" %20 = torch.aten.ge.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %20 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, 
!torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %13 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %14 : !torch.int\n" +" }\n" +" %17 = torch.aten.ne.int %13, %16 : !torch.int, !torch.int -> !torch.bool\n" +" %18 = torch.prim.If %17 -> (!torch.int) {\n" +" %20 = torch.aten.eq.int %13, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %20 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %16 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %13 : !torch.int\n" +" }\n" +" %19 = torch.aten.append.t %7, %18 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.If.yield %7 : !torch.list\n" +" }\n" +" return %6 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.any) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list\n" +" %6 = torch.aten.len.t %5 : !torch.list -> !torch.int\n" +" %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" }\n" +" %3 = torch.prim.If %2 -> (!torch.list) {\n" +" %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %6 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.Loop %5, %true, init() {\n" +" ^bb0(%arg4: !torch.int):\n" +" %7 = torch.aten.append.t %6, %arg4 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.If.yield %6 : !torch.list\n" +" } else {\n" +" %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list\n" +" torch.prim.If.yield %5 : !torch.list\n" +" }\n" +" %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %4, %true, init() {\n" +" ^bb0(%arg4: !torch.int):\n" +" %5 = torch.aten.len.t %3 : !torch.list -> !torch.int\n" +" %6 = torch.prim.Loop %5, %true, init(%false) {\n" +" ^bb0(%arg5: !torch.int, %arg6: !torch.bool):\n" +" %7 = torch.aten.__getitem__.t %3, %arg5 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %9 = torch.aten.le.int %8, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %8 : !torch.int\n" +" }\n" +" %11 = torch.aten.neg.int %10 : !torch.int -> !torch.int\n" +" %12 = torch.aten.sub.int %10, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %13 = torch.aten.lt.int %7, %11 : !torch.int, !torch.int -> !torch.bool\n" +" %14 = torch.prim.If %13 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %20 = torch.aten.gt.int %7, %12 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : 
!torch.bool\n" +" }\n" +" %15 = torch.aten.__not__ %14 : !torch.bool -> !torch.bool\n" +" torch.prim.If %15 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %16 = torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %17 = torch.prim.If %16 -> (!torch.int) {\n" +" %20 = torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %20 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %7 : !torch.int\n" +" }\n" +" %18 = torch.aten.eq.int %arg4, %17 : !torch.int, !torch.int -> !torch.bool\n" +" %19 = torch.prim.If %18 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %arg6 : !torch.bool\n" +" }\n" +" torch.prim.Loop.condition %true, iter(%19 : !torch.bool)\n" +" } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If %arg2 -> () {\n" +" %7 = torch.aten.append.t %0, %int1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" %7 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.append.t %0, %7 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.max_dim(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple, list> {\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list\n" +" %1 = torch.prim.ListConstruct : () -> !torch.list\n" +" %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %2, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %4 = torch.prim.Loop %int1, %true, init(%false) {\n" +" ^bb0(%arg4: !torch.int, %arg5: !torch.bool):\n" +" %5 = torch.aten.__getitem__.t %0, %arg4 : !torch.list, !torch.int -> !torch.int\n" +" %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %7 = torch.aten.le.int %6, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %6 : !torch.int\n" +" }\n" +" %9 = torch.aten.neg.int %8 : !torch.int -> !torch.int\n" +" %10 = torch.aten.sub.int %8, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %11 = torch.aten.lt.int %5, %9 : !torch.int, !torch.int -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %18 = torch.aten.gt.int %5, %10 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %18 : !torch.bool\n" +" }\n" +" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" +" torch.prim.If %13 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %14 = torch.aten.lt.int %5, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %15 = torch.prim.If %14 -> (!torch.int) {\n" +" %18 = torch.aten.add.int %5, %8 : !torch.int, 
!torch.int -> !torch.int\n" +" torch.prim.If.yield %18 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" %16 = torch.aten.eq.int %arg3, %15 : !torch.int, !torch.int -> !torch.bool\n" +" %17 = torch.prim.If %16 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %arg5 : !torch.bool\n" +" }\n" +" torch.prim.Loop.condition %true, iter(%17 : !torch.bool)\n" +" } : (!torch.int, !torch.bool, !torch.bool) -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If %arg2 -> () {\n" +" %5 = torch.aten.append.t %1, %int1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" %5 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %6 = torch.aten.append.t %1, %5 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %3 = torch.prim.TupleConstruct %1, %1 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %3 : !torch.tuple, list>\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.addmm(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.any, %arg4: !torch.any) -> !torch.list {\n" +" %str = torch.constant.str \"The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}\"\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %str_0 = torch.constant.str \"AssertionError: self must be a matrix\"\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: mat2 must be a matrix\"\n" +" %str_2 = torch.constant.str \"AssertionError: \"\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %6 = torch.aten.eq.int %4, %5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %9 = torch.prim.ListConstruct %7, %8 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %11 = torch.prim.max.int %10, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %12 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.Loop %11, %true, init() {\n" +" ^bb0(%arg5: !torch.int):\n" +" %13 = 
torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.aten.sub.int %13, %arg5 : !torch.int, !torch.int -> !torch.int\n" +" %15 = torch.aten.sub.int %10, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %16 = torch.aten.sub.int %15, %14 : !torch.int, !torch.int -> !torch.int\n" +" %17 = torch.aten.sub.int %int1, %14 : !torch.int, !torch.int -> !torch.int\n" +" %18 = torch.aten.ge.int %16, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %19 = torch.prim.If %18 -> (!torch.int) {\n" +" %28 = torch.aten.__getitem__.t %arg0, %16 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %28 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %20 = torch.aten.ge.int %17, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %21 = torch.prim.If %20 -> (!torch.int) {\n" +" %28 = torch.aten.__getitem__.t %9, %17 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %28 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %22 = torch.aten.ne.int %19, %21 : !torch.int, !torch.int -> !torch.bool\n" +" %23 = torch.prim.If %22 -> (!torch.bool) {\n" +" %28 = torch.aten.ne.int %19, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %28 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %24 = torch.prim.If %23 -> (!torch.bool) {\n" +" %28 = torch.aten.ne.int %21, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %28 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %24 -> () {\n" +" %28 = torch.aten.format(%str, %19, %21, %arg5) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n" +" %29 = torch.aten.add.str %str_2, %28 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %29, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %25 = torch.aten.eq.int %19, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %26 = torch.prim.If %25 -> (!torch.int) {\n" +" torch.prim.If.yield %21 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %19 : !torch.int\n" +" }\n" +" %27 = torch.aten.append.t %12, %26 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %12 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.upsample_nearest2d(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>) -> !torch.optional> {\n" +" %str = torch.constant.str \"AssertionError: Either output_size or scale_factors must be presented\"\n" +" %str_0 = torch.constant.str \"AssertionError: \"\n" +" %str_1 = torch.constant.str \"AssertionError: Must specify exactly one of output_size and scale_factors\"\n" +" %none = torch.constant.none\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" %1 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %2 = torch.aten.append.t %0, %1 : !torch.list, !torch.int -> !torch.list\n" +" %3 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.aten.append.t %0, %3 : !torch.list, !torch.int -> !torch.list\n" +" %5 = torch.aten.__isnot__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %6 = torch.prim.If 
%5 -> (!torch.optional>) {\n" +" %7 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list\n" +" %8 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.len.t %7 : !torch.list -> !torch.int\n" +" %10 = torch.aten.eq.int %9, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %11 = torch.aten.__getitem__.t %7, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.append.t %0, %11 : !torch.list, !torch.int -> !torch.list\n" +" %13 = torch.aten.__getitem__.t %7, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.append.t %0, %13 : !torch.list, !torch.int -> !torch.list\n" +" %15 = torch.derefine %0 : !torch.list to !torch.optional>\n" +" torch.prim.If.yield %15 : !torch.optional>\n" +" } else {\n" +" %7 = torch.aten.__isnot__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.optional>) {\n" +" %9 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list\n" +" %10 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %11 = torch.aten.len.t %9 : !torch.list -> !torch.int\n" +" %12 = torch.aten.eq.int %11, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %12 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %13 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.__getitem__.t %9, %int0 : !torch.list, !torch.int -> !torch.float\n" +" %15 = torch.operator \"aten.mul.int_float\"(%13, %14) : (!torch.int, !torch.float) -> !torch.float\n" +" %16 = torch.aten.Int.float %15 : !torch.float -> !torch.int\n" +" %17 = torch.aten.append.t %0, %16 : !torch.list, !torch.int -> !torch.list\n" +" %18 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" +" %19 = torch.aten.__getitem__.t %9, %int1 : !torch.list, !torch.int -> !torch.float\n" +" %20 = torch.operator \"aten.mul.int_float\"(%18, %19) : (!torch.int, !torch.float) -> !torch.float\n" +" %21 = torch.aten.Int.float %20 : !torch.float -> !torch.int\n" +" %22 = torch.aten.append.t %0, %21 : !torch.list, !torch.int -> !torch.list\n" +" %23 = torch.derefine %0 : !torch.list to !torch.optional>\n" +" torch.prim.If.yield %23 : !torch.optional>\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" %9 = torch.derefine %none : !torch.none to !torch.optional>\n" +" torch.prim.If.yield %9 : !torch.optional>\n" +" }\n" +" torch.prim.If.yield %8 : !torch.optional>\n" +" }\n" +" return %6 : !torch.optional>\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.argmax(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %int9223372036854775807 = torch.constant.int 9223372036854775807\n" +" %int1 = torch.constant.int 1\n" +" %int0 = 
torch.constant.int 0\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.list) {\n" +" %2 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.If.yield %2 : !torch.list\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" %3 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %4 = torch.aten.le.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" %6 = torch.aten.neg.int %5 : !torch.int -> !torch.int\n" +" %7 = torch.aten.sub.int %5, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %8 = torch.aten.lt.int %2, %6 : !torch.int, !torch.int -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %17 = torch.aten.gt.int %2, %7 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %17 : !torch.bool\n" +" }\n" +" %10 = torch.aten.__not__ %9 : !torch.bool -> !torch.bool\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %11 = torch.aten.lt.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.int) {\n" +" %17 = torch.aten.add.int %2, %5 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %17 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" %13 = torch.prim.ListConstruct : () -> !torch.list\n" +" %14 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %15 = torch.prim.ListConstruct %int9223372036854775807, %14 : (!torch.int, !torch.int) -> !torch.list\n" +" %16 = torch.prim.min.self_int %15 : !torch.list -> !torch.int\n" +" torch.prim.Loop %16, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %17 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %18 = torch.aten.eq.int %arg3, %12 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %18 -> () {\n" +" torch.prim.If %arg2 -> () {\n" +" %19 = torch.aten.append.t %13, %int1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" %19 = torch.aten.append.t %13, %17 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.If.yield %13 : !torch.list\n" +" }\n" +" return %1 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions._reduce_along_dim(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %int9223372036854775807 = torch.constant.int 9223372036854775807\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg1, %0, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int\n" +" %2 = torch.prim.ListConstruct : () -> !torch.list\n" +" %3 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %4 = torch.prim.ListConstruct %int9223372036854775807, %3 : (!torch.int, !torch.int) -> 
!torch.list\n" +" %5 = torch.prim.min.self_int %4 : !torch.list -> !torch.int\n" +" torch.prim.Loop %5, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %6 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %7 = torch.aten.eq.int %arg3, %1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If %arg2 -> () {\n" +" %8 = torch.aten.append.t %2, %int1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" %8 = torch.aten.append.t %2, %6 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %2 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.bmm(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: mismatching contracting dimension\"\n" +" %str_0 = torch.constant.str \"AssertionError: mismatching batch dimension\"\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: bmm only supports 3D tensors\"\n" +" %int3 = torch.constant.int 3\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %6 = torch.aten.eq.int %4, %5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %9 = torch.aten.eq.int %7, %8 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %9 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %11 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.prim.ListConstruct %10, %11, %12 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %13 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions._shape_as_tensor(%arg0: !torch.list) -> !torch.list {\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" 
+" func.func @__torch__.torch.jit._shape_functions.topk(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.tuple, list> {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %str_0 = torch.constant.str \"k ({}) is too big for dimension {} of size {}\"\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.list) {\n" +" %4 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.If.yield %4 : !torch.list\n" +" } else {\n" +" %4 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.aten.le.int %arg1, %4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %9 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %10 = torch.aten.format(%str_0, %arg1, %arg2, %9) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n" +" %11 = torch.aten.add.str %str, %10 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %11, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.prim.ListConstruct : () -> !torch.list\n" +" %7 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %7, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %9 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %10 = torch.aten.append.t %6, %9 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %8 = torch.aten._set_item.t %6, %arg2, %arg1 : !torch.list, !torch.int, !torch.int -> !torch.list\n" +" torch.prim.If.yield %6 : !torch.list\n" +" }\n" +" %3 = torch.prim.TupleConstruct %2, %2 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %3 : !torch.tuple, list>\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.int) -> !torch.tuple, list> {\n" +" %int-1 = torch.constant.int -1\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %2 = torch.aten.lt.int %int0, %0 : !torch.int, !torch.int -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" %16 = torch.aten.le.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %16 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.le.int %1, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" 
+" %16 = torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %16 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %16 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %17 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %18 = torch.aten.eq.int %16, %17 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %18 : !torch.bool\n" +" }\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %9 = torch.prim.ListConstruct : () -> !torch.list\n" +" %10 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %16 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list\n" +" %17 = torch.aten.len.t %16 : !torch.list -> !torch.int\n" +" %18 = torch.aten.eq.int %17, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %19 = torch.prim.If %18 -> (!torch.bool) {\n" +" %20 = torch.aten.__getitem__.t %16, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.eq.int %20, %8 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %21 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If.yield %19 : !torch.bool\n" +" }\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %12 = torch.aten.eq.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %13 = torch.prim.If %12 -> (!torch.bool) {\n" +" %16 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %16 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %14 = torch.prim.If %13 -> (!torch.list) {\n" +" %16 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %17 = torch.prim.ListConstruct %16 : (!torch.int) -> !torch.list\n" +" torch.prim.If.yield %17 : !torch.list\n" +" } else {\n" +" torch.prim.If.yield %9 : !torch.list\n" +" }\n" +" %15 = torch.prim.TupleConstruct %14, %9 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %15 : !torch.tuple, list>\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0: !torch.list, %arg1: !torch.list) -> !torch.tuple, list, list> {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.sub.int %1, %2 : !torch.int, !torch.int -> !torch.int\n" +" %4 = torch.aten.ge.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" 
torch.prim.Loop %3, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %10 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %11 = torch.aten.append.t %0, %10 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %6 = torch.aten.__range_length %3, %5, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %6, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %10 = torch.aten.append.t %0, %int1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %7 = torch.prim.ListConstruct : () -> !torch.list\n" +" %8 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %8, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %10 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %11 = torch.aten.append.t %7, %10 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %9 = torch.prim.TupleConstruct %7, %0, %0 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" +" return %9 : !torch.tuple, list, list>\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.native_batch_norm(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool) -> !torch.tuple, list, list> {\n" +" %true = torch.constant.bool true\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.prim.If %arg5 -> (!torch.list) {\n" +" %4 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list\n" +" torch.prim.If.yield %5 : !torch.list\n" +" } else {\n" +" %4 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n" +" torch.prim.If.yield %4 : !torch.list\n" +" }\n" +" %1 = torch.prim.ListConstruct : () -> !torch.list\n" +" %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %2, %true, init() {\n" +" ^bb0(%arg6: !torch.int):\n" +" %4 = torch.aten.__getitem__.t %arg0, %arg6 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.aten.append.t %1, %4 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %3 = torch.prim.TupleConstruct %1, %0, %0 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" +" return %3 : !torch.tuple, list, list>\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.broadcast_three(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %true = torch.constant.bool true\n" +" %false = torch.constant.bool false\n" +" %str = torch.constant.str \"The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}\"\n" +" %str_0 = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %2 = torch.prim.max.int %0, %1 : !torch.int, !torch.int -> !torch.int\n" +" %3 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.Loop %2, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" 
+" %8 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %9 = torch.aten.sub.int %8, %arg3 : !torch.int, !torch.int -> !torch.int\n" +" %10 = torch.aten.sub.int %0, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %11 = torch.aten.sub.int %10, %9 : !torch.int, !torch.int -> !torch.int\n" +" %12 = torch.aten.sub.int %1, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %13 = torch.aten.sub.int %12, %9 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.aten.ge.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %15 = torch.prim.If %14 -> (!torch.int) {\n" +" %24 = torch.aten.__getitem__.t %arg0, %11 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %24 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %16 = torch.aten.ge.int %13, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %17 = torch.prim.If %16 -> (!torch.int) {\n" +" %24 = torch.aten.__getitem__.t %arg1, %13 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %24 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %18 = torch.aten.ne.int %15, %17 : !torch.int, !torch.int -> !torch.bool\n" +" %19 = torch.prim.If %18 -> (!torch.bool) {\n" +" %24 = torch.aten.ne.int %15, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %24 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %20 = torch.prim.If %19 -> (!torch.bool) {\n" +" %24 = torch.aten.ne.int %17, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %24 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %20 -> () {\n" +" %24 = torch.aten.format(%str, %15, %17, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n" +" %25 = torch.aten.add.str %str_0, %24 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %25, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %21 = torch.aten.eq.int %15, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %22 = torch.prim.If %21 -> (!torch.int) {\n" +" torch.prim.If.yield %17 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %15 : !torch.int\n" +" }\n" +" %23 = torch.aten.append.t %3, %22 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %4 = torch.aten.len.t %3 : !torch.list -> !torch.int\n" +" %5 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %6 = torch.prim.max.int %4, %5 : !torch.int, !torch.int -> !torch.int\n" +" %7 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.Loop %6, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %8 = torch.aten.sub.int %6, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %9 = torch.aten.sub.int %8, %arg3 : !torch.int, !torch.int -> !torch.int\n" +" %10 = torch.aten.sub.int %4, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %11 = torch.aten.sub.int %10, %9 : !torch.int, !torch.int -> !torch.int\n" +" %12 = torch.aten.sub.int %5, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %13 = torch.aten.sub.int %12, %9 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.aten.ge.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %15 = torch.prim.If %14 -> (!torch.int) {\n" +" %24 = torch.aten.__getitem__.t %3, %11 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %24 : !torch.int\n" +" } else {\n" +" 
torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %16 = torch.aten.ge.int %13, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %17 = torch.prim.If %16 -> (!torch.int) {\n" +" %24 = torch.aten.__getitem__.t %arg2, %13 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %24 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %18 = torch.aten.ne.int %15, %17 : !torch.int, !torch.int -> !torch.bool\n" +" %19 = torch.prim.If %18 -> (!torch.bool) {\n" +" %24 = torch.aten.ne.int %15, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %24 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %20 = torch.prim.If %19 -> (!torch.bool) {\n" +" %24 = torch.aten.ne.int %17, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %24 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %20 -> () {\n" +" %24 = torch.aten.format(%str, %15, %17, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n" +" %25 = torch.aten.add.str %str_0, %24 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %25, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %21 = torch.aten.eq.int %15, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %22 = torch.prim.If %21 -> (!torch.int) {\n" +" torch.prim.If.yield %17 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %15 : !torch.int\n" +" }\n" +" %23 = torch.aten.append.t %7, %22 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %7 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.broadcast_one_three(%arg0: !torch.list, %arg1: !torch.any, %arg2: !torch.list) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %true = torch.constant.bool true\n" +" %false = torch.constant.bool false\n" +" %str = torch.constant.str \"The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}\"\n" +" %str_0 = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %2 = torch.prim.max.int %0, %1 : !torch.int, !torch.int -> !torch.int\n" +" %3 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.Loop %2, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %4 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %5 = torch.aten.sub.int %4, %arg3 : !torch.int, !torch.int -> !torch.int\n" +" %6 = torch.aten.sub.int %0, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %7 = torch.aten.sub.int %6, %5 : !torch.int, !torch.int -> !torch.int\n" +" %8 = torch.aten.sub.int %1, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %9 = torch.aten.sub.int %8, %5 : !torch.int, !torch.int -> !torch.int\n" +" %10 = torch.aten.ge.int %7, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.int) {\n" +" %20 = torch.aten.__getitem__.t %arg0, %7 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %20 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %12 = torch.aten.ge.int %9, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %13 = torch.prim.If %12 -> (!torch.int) {\n" +" %20 = 
torch.aten.__getitem__.t %arg2, %9 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %20 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %14 = torch.aten.ne.int %11, %13 : !torch.int, !torch.int -> !torch.bool\n" +" %15 = torch.prim.If %14 -> (!torch.bool) {\n" +" %20 = torch.aten.ne.int %11, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %16 = torch.prim.If %15 -> (!torch.bool) {\n" +" %20 = torch.aten.ne.int %13, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %16 -> () {\n" +" %20 = torch.aten.format(%str, %11, %13, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n" +" %21 = torch.aten.add.str %str_0, %20 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %21, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %17 = torch.aten.eq.int %11, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %18 = torch.prim.If %17 -> (!torch.int) {\n" +" torch.prim.If.yield %13 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %11 : !torch.int\n" +" }\n" +" %19 = torch.aten.append.t %3, %18 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %3 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.broadcast_inplace(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %str = torch.constant.str \"The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}\"\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: \"\n" +" %str_1 = torch.constant.str \"The dims of tensor b ({}) must be less than or equal to the dims of tensor a ({}) \"\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %2 = torch.aten.gt.int %1, %0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" %5 = torch.aten.format(%str_1, %1, %0) : !torch.str, !torch.int, !torch.int -> !torch.str\n" +" %6 = torch.aten.add.str %str_0, %5 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %6, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop %0, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %5 = torch.aten.sub.int %1, %0 : !torch.int, !torch.int -> !torch.int\n" +" %6 = torch.aten.add.int %5, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" %7 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.ge.int %6, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" %12 = torch.aten.__getitem__.t %arg1, %6 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %12 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %10 = torch.aten.ne.int %7, %9 : !torch.int, !torch.int -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.bool) {\n" +" %12 = torch.aten.ne.int %9, %int1 : 
!torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %11 -> () {\n" +" %12 = torch.aten.format(%str, %7, %9, %arg2) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n" +" %13 = torch.aten.add.str %str_0, %12 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %13, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %3 = torch.prim.ListConstruct : () -> !torch.list\n" +" %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %4, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %5 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %6 = torch.aten.append.t %3, %5 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %3 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.nonzero_lower_bound(%arg0: !torch.list) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.prim.ListConstruct %int0, %0 : (!torch.int, !torch.int) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" +" func.func @__torch__.torch.jit._shape_functions.nonzero_upper_bound(%arg0: !torch.list) -> !torch.list {\n" +" %int1 = torch.constant.int 1\n" +" %true = torch.constant.bool true\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.prim.Loop %0, %true, init(%int1) {\n" +" ^bb0(%arg1: !torch.int, %arg2: !torch.int):\n" +" %4 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.aten.mul.int %arg2, %4 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop.condition %true, iter(%5 : !torch.int)\n" +" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" +" %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" return %3 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.triu\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.tanh\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.erf\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.sigmoid\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.hardsigmoid\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.softplus\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list {\n" 
+" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.square\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.hardswish\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.silu\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.exp\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.expm1\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.cos\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.hardtanh\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.sqrt\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.neg\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.floor\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.detach\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.log2\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.log1p\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.rsqrt\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return 
%0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.abs\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.reciprocal\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.tanh_backward\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.gelu_backward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.str) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.ceil\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.log\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.relu\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.relu6\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten._softmax\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.softmax.int\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten._log_softmax\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.log_softmax.int\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.clamp\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.clamp_min\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" 
func.func @\"__torch_mlir_shape_fn.aten.clamp_max\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.rsub.Scalar\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.to.dtype\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.to.dtype_layout\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.bool, %arg6: !torch.bool, %arg7: !torch.optional) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.to.device\"(%arg0: !torch.list, %arg1: !torch.Device, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool, %arg5: !torch.optional) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.to.other\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.type_as\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.dropout\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.gelu\"(%arg0: !torch.list, %arg1: !torch.str) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.contiguous\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.clone\"(%arg0: !torch.list, %arg1: !torch.optional) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.lift_fresh_copy\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten._log_softmax_backward_data\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> 
!torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.eq.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.ne.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.gt.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.ge.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.le.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.lt.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.add.Scalar\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.sub.Scalar\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.mul.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.div.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.remainder.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.floor_divide.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.pow.Tensor_Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call 
@__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.leaky_relu\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.gather\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg2) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.layer_norm\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float, %arg5: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten._softmax_backward_data\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.any\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.all\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.max\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.sum\"(%arg0: !torch.list, %arg1: !torch.optional) -> !torch.list {\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.mean\"(%arg0: !torch.list, %arg1: !torch.optional) -> !torch.list {\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.var\"(%arg0: !torch.list, %arg1: !torch.bool) -> !torch.list {\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.var.dim\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %0 = torch.derefine %none : !torch.none to !torch.any\n" +" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.var.correction\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %0 = torch.derefine %none : !torch.none to !torch.any\n" +" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.std\"(%arg0: !torch.list, %arg1: !torch.bool) -> !torch.list {\n" +" %0 = 
torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
  [shape-function entries continue, one "+"-prefixed C++ string literal per line of Torch-dialect IR, for: aten.std.dim, aten.argmax (via the helper @__torch__._reduce_along_dim, which wraps the dim with maybe_wrap_dim and then copies every dimension except the reduced one, appending 1 in its place when keepdim is set), aten.any.dim, aten.max.dim, aten.mean.dim, aten.sum.dim_IntList, aten.permute, aten.transpose.int, aten.t, aten.numpy_T, aten.matmul, aten.mm, aten.addmm, aten.bmm and aten.baddbmm (both asserting 3-D operands and matching batch/contracting dimensions before building the [B, M, N] result), aten.embedding, aten.repeat, aten.roll, aten.expand, aten.expand_as, aten.broadcast_to, aten.view, aten.reshape, aten._reshape_alias, aten._unsafe_view, aten.resize_, aten.max_pool2d, aten.max_pool2d_with_indices, aten.max_pool2d_with_indices_backward, aten.avg_pool2d (via the helper @__torch__.avg_pool2d, which validates kernel_size/stride/padding arity and computes the output extents with pooling_output_shape and pool2d_shape_check), aten.adaptive_avg_pool2d, aten.flatten.using_ints, aten.linear, aten.zeros, aten.ones, aten.empty.memory_format, aten.full, aten.full_like, aten.zeros_like, aten.ones_like, aten.empty_like, aten.new_zeros, aten.new_ones, aten.new_empty, aten._to_copy, aten.masked_fill.Scalar, aten.masked_fill.Tensor, aten.zero, aten.fill.Scalar, aten.copy, aten.uniform, aten.bernoulli.float, aten.bernoulli.Tensor, aten.index_put_impl, aten.bernoulli, aten.cumsum, aten.rand_like, aten.arange.start_step, aten.arange.start, aten.arange, aten.add.Tensor, aten.sub.Tensor, aten.mul.Tensor, aten.div.Tensor, aten.div.Tensor_mode, aten.floor_divide, aten.atan2, aten.__and__.Tensor, aten.minimum, aten.maximum, aten.bitwise_and.Tensor, aten.bitwise_not, aten.logical_or, aten.threshold, aten.threshold_backward, aten.eq.Tensor, aten.gt.Tensor, aten.lt.Tensor, aten.unsqueeze, aten.squeeze, aten.squeeze.dim, prim.NumToTensor.Scalar, aten.tensor.float, aten.tensor.int, and aten.tensor.bool]
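(Note: each entry above is Torch-dialect IR compiled from a Python shape function; the calls into @__torch__.torch.jit._shape_functions.* target PyTorch's real torch.jit._shape_functions module. As a reading aid, here is a minimal Python sketch of the @__torch__._reduce_along_dim helper, transcribed from the generated IR above rather than copied from the generator source, so treat names and defaults as reconstructed:

    from typing import List
    from torch.jit._shape_functions import maybe_wrap_dim

    def _reduce_along_dim(self: List[int], dim: int, keepdim: bool) -> List[int]:
        # Normalize a possibly-negative dim; the generated IR calls
        # maybe_wrap_dim with wrap_scalar=True.
        dim = maybe_wrap_dim(dim, len(self), True)
        out: List[int] = []
        for i, size in enumerate(self):
            if i == dim:
                if keepdim:
                    out.append(1)  # reduced dim kept as size 1
                # otherwise the reduced dim is dropped entirely
            else:
                out.append(size)
        return out

For example, _reduce_along_dim([2, 3, 4], 1, False) returns [2, 4], and with keepdim=True it returns [2, 1, 4], matching the torch.prim.Loop in the @__torch__._reduce_along_dim entry.)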
@\"__torch_mlir_shape_fn.aten._shape_as_tensor\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.where.self\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list, !torch.list) -> !torch.list\n" +" %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list, !torch.list) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.where.Scalar\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.where.ScalarOther\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.where.ScalarSelf\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg2) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.lerp.Tensor\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list, !torch.list) -> !torch.list\n" +" %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list, !torch.list) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.addcmul\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list, !torch.list) -> !torch.list\n" +" %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list, !torch.list) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.addcdiv\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list, !torch.list) -> !torch.list\n" +" %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list, !torch.list) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.topk\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.tuple, list> {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %str_0 = torch.constant.str \"k ({}) is too big for dimension {} of size {}\"\n" +" %0 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %1 = torch.aten.le.int %arg1, %0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %4 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.aten.format(%str_0, %arg1, 
%arg2, %4) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n" +" %6 = torch.aten.add.str %str, %5 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %6, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten._set_item.t %arg0, %arg2, %arg1 : !torch.list, !torch.int, !torch.int -> !torch.list\n" +" %3 = torch.prim.TupleConstruct %arg0, %arg0 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %3 : !torch.tuple, list>\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.conv2d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.conv_transpose2d.input\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.list {\n" +" %0 = torch.derefine %arg3 : !torch.list to !torch.optional>\n" +" %1 = torch.derefine %arg4 : !torch.list to !torch.optional>\n" +" %2 = torch.derefine %arg5 : !torch.list to !torch.optional>\n" +" %3 = torch.derefine %arg7 : !torch.list to !torch.optional>\n" +" %4 = call @__torch__.torch.jit._shape_functions.conv_transpose2d_input(%arg0, %arg1, %arg2, %0, %1, %2, %arg6, %3) : (!torch.list, !torch.list, !torch.optional>, !torch.optional>, !torch.optional>, !torch.optional>, !torch.int, !torch.optional>) -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten._convolution\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.list {\n" +" %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten._convolution.deprecated\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.list {\n" +" %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, 
!torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.flip\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.batch_norm\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.narrow\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.add.int %arg2, %arg3 : !torch.int, !torch.int -> !torch.int\n" +" %1 = torch.derefine %arg2 : !torch.int to !torch.optional\n" +" %2 = torch.derefine %0 : !torch.int to !torch.optional\n" +" %3 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %1, %2, %int1) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list\n" +" return %3 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.slice_scatter\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.int) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.select.int\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.select(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.select_scatter\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.index_select\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.index_put\"(%arg0: !torch.list, %arg1: !torch.list>>, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.index_put.hacked_twin\"(%arg0: !torch.list, %arg1: !torch.list>, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.embedding_bag.padding_idx\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.bool, %arg6: !torch.optional>, %arg7: !torch.bool, %arg8: !torch.optional) -> !torch.tuple, list, list, list> 
{\n" +" %0 = call @__torch__._embedding_bag_helper(%arg0, %arg1, %arg2, %arg7, %arg4) : (!torch.list, !torch.list, !torch.list, !torch.bool, !torch.int) -> !torch.tuple, list, list, list>\n" +" return %0 : !torch.tuple, list, list, list>\n" +" }\n" +" func.func @__torch__._embedding_bag_helper(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.bool, %arg4: !torch.int) -> !torch.tuple, list, list, list> {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %5 = torch.aten.eq.int %4, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.prim.ListConstruct : () -> !torch.list\n" +" %7 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.prim.If %arg3 -> (!torch.int) {\n" +" %19 = torch.aten.sub.int %7, %int1 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %19 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %7 : !torch.int\n" +" }\n" +" %9 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %10 = torch.aten.append.t %6, %8 : !torch.list, !torch.int -> !torch.list\n" +" %11 = torch.aten.append.t %6, %9 : !torch.list, !torch.int -> !torch.list\n" +" %12 = torch.prim.ListConstruct : () -> !torch.list\n" +" %13 = torch.aten.eq.int %arg4, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %14 = torch.prim.If %13 -> (!torch.list) {\n" +" %19 = torch.aten.append.t %12, %int0 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield %12 : !torch.list\n" +" } else {\n" +" %19 = func.call @__torch__.torch.jit._shape_functions._copy(%arg1) : (!torch.list) -> !torch.list\n" +" torch.prim.If.yield %19 : !torch.list\n" +" }\n" +" %15 = call @__torch__.torch.jit._shape_functions._copy(%arg2) : (!torch.list) -> !torch.list\n" +" %16 = torch.aten.eq.int %arg4, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %17 = torch.prim.If %16 -> (!torch.list) {\n" +" %19 = func.call @__torch__.torch.jit._shape_functions._copy(%6) : (!torch.list) -> !torch.list\n" +" torch.prim.If.yield %19 : !torch.list\n" +" } else {\n" +" %19 = func.call @__torch__.torch.jit._shape_functions._copy(%arg2) : (!torch.list) -> !torch.list\n" +" torch.prim.If.yield %19 : !torch.list\n" +" }\n" +" %18 = torch.prim.TupleConstruct %6, %14, %15, %17 : !torch.list, !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list, list>\n" +" return %18 : !torch.tuple, list, list, list>\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten._embedding_bag\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: 
!torch.list<int>, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.bool, %arg6: !torch.optional<list<int>>, %arg7: !torch.bool, %arg8: !torch.int) -> !torch.tuple<list<int>, list<int>, list<int>, list<int>> {\n"
+"    %0 = call @__torch__._embedding_bag_helper(%arg0, %arg1, %arg2, %arg7, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.int) -> !torch.tuple<list<int>, list<int>, list<int>, list<int>>\n"
+"    return %0 : !torch.tuple<list<int>, list<int>, list<int>, list<int>>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
+"    %int-1 = torch.constant.int -1\n"
+"    %true = torch.constant.bool true\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %false = torch.constant.bool false\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int2 = torch.constant.int 2\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+"    %2 = torch.aten.lt.int %int0, %0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %3 = torch.prim.If %2 -> (!torch.bool) {\n"
+"      %15 = torch.aten.le.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %15 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %3 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %4 = torch.aten.le.int %1, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %4 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %5 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+"    %6 = torch.prim.If %5 -> (!torch.bool) {\n"
+"      %15 = torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %15 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    %7 = torch.prim.If %6 -> (!torch.bool) {\n"
+"      torch.prim.If.yield %true : !torch.bool\n"
+"    } else {\n"
+"      %15 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %16 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %17 = torch.aten.eq.int %15, %16 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %17 : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %7 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %8 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %9 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"    %10 = torch.aten.__is__ %arg2, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
+"    %11 = torch.prim.If %10 -> (!torch.bool) {\n"
+"      torch.prim.If.yield %true : !torch.bool\n"
+"    } else {\n"
+"      %15 = torch.prim.unchecked_cast %arg2 : !torch.optional<list<int>> -> !torch.list<int>\n"
+"      %16 = torch.aten.len.t %15 : !torch.list<int> -> !torch.int\n"
+"      %17 = torch.aten.eq.int %16, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+"      %18 = torch.prim.If %17 -> (!torch.bool) {\n"
+"        %19 = torch.aten.__getitem__.t %15, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"        %20 = torch.aten.eq.int %19, %8 : !torch.int, !torch.int -> !torch.bool\n"
+"        torch.prim.If.yield %20 : !torch.bool\n"
+"      } else {\n"
+"        torch.prim.If.yield %false : !torch.bool\n"
+"      }\n"
+"      torch.prim.If.yield %18 : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %11 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %12 = torch.aten.eq.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %13 = torch.prim.If %12 -> (!torch.bool) {\n"
+"      %15 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %15 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    %14 = torch.prim.If %13 -> (!torch.tuple<list<int>, list<int>>) {\n"
+"      %15 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %16 = torch.prim.ListConstruct %15 : (!torch.int) -> !torch.list<int>\n"
+"      %17 = torch.prim.TupleConstruct %16, %9 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
+"      torch.prim.If.yield %17 : !torch.tuple<list<int>, list<int>>\n"
+"    } else {\n"
+"      %15 = torch.prim.TupleConstruct %9, %9 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
+"      torch.prim.If.yield %15 : !torch.tuple<list<int>, list<int>>\n"
+"    }\n"
+"    return %14 : !torch.tuple<list<int>, list<int>>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.nll_loss_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<int>>, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
+"    %true = torch.constant.bool true\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"    %1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+"    %3 = torch.aten.sub.int %1, %2 : !torch.int, !torch.int -> !torch.int\n"
+"    %4 = torch.aten.ge.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %4 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    torch.prim.Loop %3, %true, init() {\n"
+"    ^bb0(%arg5: !torch.int):\n"
+"      %8 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %9 = torch.aten.append.t %0, %8 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"      torch.prim.Loop.condition %true, iter()\n"
+"    } : (!torch.int, !torch.bool) -> ()\n"
+"    %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %6 = torch.aten.__range_length %3, %5, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
+"    torch.prim.Loop %6, %true, init() {\n"
+"    ^bb0(%arg5: !torch.int):\n"
+"      %8 = torch.aten.append.t %0, %int1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"      torch.prim.Loop.condition %true, iter()\n"
+"    } : (!torch.int, !torch.bool) -> ()\n"
+"    %7 = torch.prim.TupleConstruct %arg0, %0, %0 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
+"    return %7 : !torch.tuple<list<int>, list<int>, list<int>>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.native_batch_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
+"    %int1 = torch.constant.int 1\n"
+"    %int0 = torch.constant.int 0\n"
+"    %0 = torch.prim.If %arg5 -> (!torch.tuple<list<int>, list<int>, list<int>>) {\n"
+"      %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %2 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list<int>\n"
+"      %3 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %4 = torch.prim.ListConstruct %3 : (!torch.int) -> !torch.list<int>\n"
+"      %5 = torch.prim.TupleConstruct %arg0, %2, %4 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
+"      torch.prim.If.yield %5 : !torch.tuple<list<int>, list<int>, list<int>>\n"
+"    } else {\n"
+"      %1 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>\n"
+"      %2 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>\n"
+"      %3 = torch.prim.TupleConstruct %arg0, %1, %2 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
+"      torch.prim.If.yield %3 : !torch.tuple<list<int>, list<int>, list<int>>\n"
+"    }\n"
+"    return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.constant_pad_nd\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @__torch__.pad_shape_fn(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %true = torch.constant.bool true\n"
+"    %str = torch.constant.str \"AssertionError: Number of padded dimensions must be less than or equal to the input dimension\"\n"
+"    %none = torch.constant.none\n"
+"    %str_0 = torch.constant.str \"AssertionError: Must have paired low-high pad amount values\"\n"
+"    %int2 = torch.constant.int 2\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.aten.remainder.int %0, %int2 : !torch.int, !torch.int -> !torch.int\n"
+"    %2 = torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %3 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+"    %4 = torch.aten.floordiv.int %3, %int2 : !torch.int, !torch.int -> !torch.int\n"
+"    %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %6 = torch.aten.le.int %4, %5 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %6 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %7 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+"    %8 = torch.aten.floordiv.int %7, %int2 : !torch.int, !torch.int -> !torch.int\n"
+"    torch.prim.Loop %8, %true, init() {\n"
+"    ^bb0(%arg2: !torch.int):\n"
+"      %9 = torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int\n"
+"      %10 = torch.aten.neg.int %9 : !torch.int -> !torch.int\n"
+"      %11 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int\n"
+"      %12 = torch.aten.__getitem__.t %arg1, %11 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %13 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int\n"
+"      %14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int\n"
+"      %15 = torch.aten.__getitem__.t %arg1, %14 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %16 = torch.aten.add.int %12, %15 : !torch.int, !torch.int -> !torch.int\n"
+"      %17 = torch.aten.__getitem__.t %arg0, %10 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %18 = torch.aten.add.int %17, %16 : !torch.int, !torch.int -> !torch.int\n"
+"      %19 = torch.aten._set_item.t %arg0, %10, %18 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
+"      torch.prim.Loop.condition %true, iter()\n"
+"    } : (!torch.int, !torch.bool) -> ()\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.pad\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.str, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
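For reference while reading the IR above, this is a sketch (hand-reconstructed for illustration, not the checked-in source) of the Python helper that `@__torch__.pad_shape_fn` is compiled from; the pad list holds paired (low, high) amounts for the innermost dimensions first:

    from typing import List

    def pad_shape_fn(input: List[int], pad: List[int]) -> List[int]:
        assert len(pad) % 2 == 0, "Must have paired low-high pad amount values"
        assert len(pad) // 2 <= len(input), \
            "Number of padded dimensions must be less than or equal to the input dimension"
        # Each pair (pad[2*i], pad[2*i+1]) widens dimension -(i+1).
        for i in range(len(pad) // 2):
            input[-(i + 1)] += pad[2 * i] + pad[2 * i + 1]
        return input

    assert pad_shape_fn([2, 3], [1, 2]) == [2, 6]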
+"  func.func @\"__torch_mlir_shape_fn.aten.index.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.index_tensor_like(%arg0, %arg1) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @__torch__.index_tensor_like(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>) -> !torch.list<int> {\n"
+"    %false = torch.constant.bool false\n"
+"    %int-1 = torch.constant.int -1\n"
+"    %true = torch.constant.bool true\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: More indices than dimensions to index\"\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %int9223372036854775807 = torch.constant.int 9223372036854775807\n"
+"    %0 = torch.aten.len.t %arg1 : !torch.list<optional<list<int>>> -> !torch.int\n"
+"    %1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %2 = torch.aten.le.int %0, %1 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %3 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"    %4 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"    %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %6 = torch.prim.Loop %5, %true, init(%3) {\n"
+"    ^bb0(%arg2: !torch.int, %arg3: !torch.list<int>):\n"
+"      %10 = torch.aten.len.t %arg1 : !torch.list<optional<list<int>>> -> !torch.int\n"
+"      %11 = torch.aten.ge.int %arg2, %10 : !torch.int, !torch.int -> !torch.bool\n"
+"      %12 = torch.prim.If %11 -> (!torch.list<int>) {\n"
+"        %13 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int\n"
+"        %14 = torch.aten.append.t %4, %13 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"        torch.prim.If.yield %arg3 : !torch.list<int>\n"
+"      } else {\n"
+"        %13 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<optional<list<int>>>, !torch.int -> !torch.optional<list<int>>\n"
+"        %14 = torch.aten.__isnot__ %13, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
+"        %15 = torch.prim.If %14 -> (!torch.list<int>) {\n"
+"          %16 = torch.prim.unchecked_cast %13 : !torch.optional<list<int>> -> !torch.list<int>\n"
+"          %17 = func.call @__torch__.torch.jit._shape_functions.broadcast(%arg3, %16) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"          torch.prim.If.yield %17 : !torch.list<int>\n"
+"        } else {\n"
+"          %16 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<int>, !torch.int -> !torch.int\n"
+"          %17 = torch.aten.append.t %4, %16 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"          torch.prim.If.yield %arg3 : !torch.list<int>\n"
+"        }\n"
+"        torch.prim.If.yield %15 : !torch.list<int>\n"
+"      }\n"
+"      torch.prim.Loop.condition %true, iter(%12 : !torch.list<int>)\n"
+"    } : (!torch.int, !torch.bool, !torch.list<int>) -> !torch.list<int>\n"
+"    %7 = torch.aten.len.t %4 : !torch.list<int> -> !torch.int\n"
+"    %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %9 = torch.prim.If %8 -> (!torch.list<int>) {\n"
+"      torch.prim.If.yield %6 : !torch.list<int>\n"
+"    } else {\n"
+"      %10 = torch.aten.len.t %arg1 : !torch.list<optional<list<int>>> -> !torch.int\n"
+"      %11 = torch.prim.ListConstruct %int9223372036854775807, %10 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"      %12 = torch.prim.min.self_int %11 : !torch.list<int> -> !torch.int\n"
+"      %13:2 = torch.prim.Loop %12, %true, init(%true, %int-1) {\n"
+"      ^bb0(%arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.int):\n"
+"        %16 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<optional<list<int>>>, !torch.int -> !torch.optional<list<int>>\n"
+"        %17 = torch.aten.__isnot__ %16, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
+"        %18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) {\n"
+"          %19 = torch.aten.eq.int %arg4, %int-1 : !torch.int, !torch.int -> !torch.bool\n"
+"          %20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) {\n"
+"            torch.prim.If.yield %arg3, %arg2 : !torch.bool, !torch.int\n"
+"          } else {\n"
+"            %21 = torch.aten.sub.int %arg2, %arg4 : !torch.int, !torch.int -> !torch.int\n"
+"            %22 = torch.aten.ne.int %21, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+"            %23 = torch.prim.If %22 -> (!torch.bool) {\n"
+"              torch.prim.If.yield %false : !torch.bool\n"
+"            } else {\n"
+"              torch.prim.If.yield %arg3 : !torch.bool\n"
+"            }\n"
+"            torch.prim.If.yield %23, %arg4 : !torch.bool, !torch.int\n"
+"          }\n"
+"          torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int\n"
+"        } else {\n"
+"          torch.prim.If.yield %arg3, %arg4 : !torch.bool, !torch.int\n"
+"        }\n"
+"        torch.prim.Loop.condition %true, iter(%18#0, %18#1 : !torch.bool, !torch.int)\n"
+"      } : (!torch.int, !torch.bool, !torch.bool, !torch.int) -> (!torch.bool, !torch.int)\n"
+"      %14 = torch.aten.__not__ %13#0 : !torch.bool -> !torch.bool\n"
+"      %15 = torch.prim.If %14 -> (!torch.list<int>) {\n"
+"        %16 = torch.aten.add.t %6, %4 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
+"        torch.prim.If.yield %16 : !torch.list<int>\n"
+"      } else {\n"
+"        %16 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"        torch.prim.Loop %13#1, %true, init() {\n"
+"        ^bb0(%arg2: !torch.int):\n"
+"          %20 = torch.aten.__getitem__.t %4, %arg2 : !torch.list<int>, !torch.int -> !torch.int\n"
+"          %21 = torch.aten.append.t %16, %20 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"          torch.prim.Loop.condition %true, iter()\n"
+"        } : (!torch.int, !torch.bool) -> ()\n"
+"        %17 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int\n"
+"        torch.prim.Loop %17, %true, init() {\n"
+"        ^bb0(%arg2: !torch.int):\n"
+"          %20 = torch.aten.__getitem__.t %6, %arg2 : !torch.list<int>, !torch.int -> !torch.int\n"
+"          %21 = torch.aten.append.t %16, %20 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"          torch.prim.Loop.condition %true, iter()\n"
+"        } : (!torch.int, !torch.bool) -> ()\n"
+"        %18 = torch.aten.len.t %4 : !torch.list<int> -> !torch.int\n"
+"        %19 = torch.aten.__range_length %13#1, %18, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
+"        torch.prim.Loop %19, %true, init() {\n"
+"        ^bb0(%arg2: !torch.int):\n"
+"          %20 = torch.aten.__derive_index %arg2, %13#1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
+"          %21 = torch.aten.__getitem__.t %4, %20 : !torch.list<int>, !torch.int -> !torch.int\n"
+"          %22 = torch.aten.append.t %16, %21 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"          torch.prim.Loop.condition %true, iter()\n"
+"        } : (!torch.int, !torch.bool) -> ()\n"
+"        torch.prim.If.yield %16 : !torch.list<int>\n"
+"      }\n"
+"      torch.prim.If.yield %15 : !torch.list<int>\n"
+"    }\n"
+"    return %9 : !torch.list<int>\n"
+"  }\n"
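A worked shape-only example of the rule compiled above (values chosen for illustration; `index_tensor_like` is the Python helper defined later in this diff): index tensors on non-consecutive dimensions move the broadcasted index shape to the front, followed by the unindexed dimension sizes.

    x_shape = [2, 3, 4, 5]
    indices = [None, [4], None, [4]]   # index tensors on dims 1 and 3
    # broadcast([4], [4]) == [4]; unindexed dims 0 and 2 contribute [2, 4]
    assert index_tensor_like(x_shape, indices) == [4, 2, 4]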
+" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.index.Tensor_hacked_twin\"(%arg0: !torch.list, %arg1: !torch.list>) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list>>\n" +" %1 = torch.aten.len.t %arg1 : !torch.list> -> !torch.int\n" +" torch.prim.Loop %1, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %3 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list>, !torch.int -> !torch.list\n" +" %4 = torch.aten.append.t %0, %3 : !torch.list>>, !torch.list -> !torch.list>>\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %2 = call @__torch__.index_tensor_like(%arg0, %0) : (!torch.list, !torch.list>>) -> !torch.list\n" +" return %2 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.cat\"(%arg0: !torch.list>, %arg1: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.cat(%arg0, %arg1) : (!torch.list>, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n" +" %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" +" func.func @__torch__.hacky_get_unknown_dimension_size() -> !torch.int {\n" +" %0 = torch.prim.CreateObject !torch.nn.Module<\"__torch__.DummyClassType\">\n" +" %1 = torch.prim.CallMethod %0[\"__init__\"] () : !torch.nn.Module<\"__torch__.DummyClassType\">, () -> !torch.none\n" +" %2 = torch.operator \"prim.id\"(%0) : (!torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.int\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @__torch__.DummyClassType.__init__(%arg0: !torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.none {\n" +" %none = torch.constant.none\n" +" return %none : !torch.none\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.linalg_vector_norm\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.optional>, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.list {\n" +" %0 = torch.derefine %arg4 : !torch.optional to !torch.any\n" +" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg2, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.frobenius_norm.dim\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.derefine %arg1 : !torch.list to !torch.optional>\n" +" %1 = torch.derefine %int0 : !torch.int to !torch.any\n" +" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" +" return %2 : !torch.list\n" +" }\n" +"}\n" +""; + // clang-format on +#ifndef _MSC_VER #pragma clang diagnostic pop - return shapeLib; +#endif } diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 906243e14668a..f89be8d916b52 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -60,6 +60,16 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { llvm::report_fatal_error("unhandled type for getScalarTypeForType"); } +Type Torch::getTypeForTorchType( + MLIRContext *context, Type type, + mlir::IntegerType::SignednessSemantics signedness) { + if 
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 906243e14668a..f89be8d916b52 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -60,6 +60,16 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
   llvm::report_fatal_error("unhandled type for getScalarTypeForType");
 }
 
+Type Torch::getTypeForTorchType(
+    MLIRContext *context, Type type,
+    mlir::IntegerType::SignednessSemantics signedness) {
+  if (type.isa<Torch::IntType>())
+    return IntegerType::get(context, 64, signedness);
+  if (type.isa<Torch::FloatType>())
+    return Float64Type::get(context);
+  llvm::report_fatal_error("unhandled type for getTypeForTorchType");
+}
+
 Type Torch::getTypeForScalarType(
     MLIRContext *context, torch_upstream::ScalarType dtypeInt,
     mlir::IntegerType::SignednessSemantics signedness) {
diff --git a/lib/Dialect/TorchConversion/IR/CMakeLists.txt b/lib/Dialect/TorchConversion/IR/CMakeLists.txt
index d932650c5a381..38f6f32686656 100644
--- a/lib/Dialect/TorchConversion/IR/CMakeLists.txt
+++ b/lib/Dialect/TorchConversion/IR/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(TorchMLIRTorchConversionDialect
 
   DEPENDS
   MLIRTorchConversionOpsIncGen
+  MLIRTorchTypesIncGen
 
   LINK_COMPONENTS
   Core
diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp
index f7a9141640c3c..f7eb50aa6d9a2 100644
--- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp
+++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp
@@ -41,19 +41,19 @@ namespace {
 void mlir::torch::registerTorchConversionPasses() {
   ::registerPasses();
-  mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
+  mlir::PassPipelineRegistration<>(
       "torch-backend-to-linalg-on-tensors-backend-pipeline",
       "Pipeline lowering torch backend contract to linalg-on-tensors backend "
       "contract.",
       TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline);
 
-  mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
+  mlir::PassPipelineRegistration<>(
      "torch-backend-to-tosa-backend-pipeline",
      "Pipeline lowering torch backend contract to TOSA backend "
      "contract.",
      TorchConversion::createTorchBackendToTosaBackendPipeline);
 #ifdef TORCH_MLIR_ENABLE_MHLO
-  mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
+  mlir::PassPipelineRegistration<TorchConversion::MhloBackendPipelineOptions>(
      "torch-backend-to-mhlo-backend-pipeline",
      "Pipeline lowering torch backend contract to MHLO backend "
      "contract.",
@@ -62,7 +62,7 @@ void mlir::torch::registerTorchConversionPasses() {
 }
 
 void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
-    OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
+    OpPassManager &pm) {
   // Lower to linalg + guards which is the input to codegen backends.
   // We do this first as it tends to involve pattern-matching against constants,
   // (e.g. dimensions which must be constant in a ranked programming model)
@@ -96,7 +96,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
 }
 
 void TorchConversion::createTorchBackendToTosaBackendPipeline(
-    OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
+    OpPassManager &pm) {
   pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
   // Perform rank broadcasting so TosaToLinalg pass works
   pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());
@@ -121,8 +121,10 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
 
 #ifdef TORCH_MLIR_ENABLE_MHLO
 void TorchConversion::createTorchBackendToMhloBackendPipeline(
-    OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
-  pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass());
+    OpPassManager &pm,
+    const TorchConversion::MhloBackendPipelineOptions &options) {
+  pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass(
+      options.enableStaticShape, options.enableI32Index));
   // Clean up any non-canonical code introduced above..
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp
index 8bc19645dd979..c28ac45eb06a2 100644
--- a/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp
+++ b/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp
@@ -12,6 +12,7 @@
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/OpDefinition.h"
@@ -45,9 +46,10 @@ class VerifyMhloBackendContractPass
     ConversionTarget target(*context);
 
     // Structural operations.
-    target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(
-        opHasLegalTypes);
-    // Basic scalar operations.
+    target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(opHasLegalTypes);
+    // Shape operations.
+    target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes);
+    target.addLegalDialect<chlo::ChloDialect>();
     target.addLegalDialect<mhlo::MhloDialect>();
     target.addLegalDialect<tensor::TensorDialect>();
diff --git a/lib/RefBackend/CMakeLists.txt b/lib/RefBackend/CMakeLists.txt
index 733dcc3deff9d..2ef5dab3ae8d0 100644
--- a/lib/RefBackend/CMakeLists.txt
+++ b/lib/RefBackend/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_library(TorchMLIRRefBackend
   ${PROJECT_SRC_DIR}/include/torch-mlir/RefBackend
 
   DEPENDS
+  MLIRTorchTypesIncGen
   TorchMLIRRefBackendPassIncGen
 
   LINK_COMPONENTS
diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py
index 046c2fd44ae05..099e50234896b 100644
--- a/python/torch_mlir/__init__.py
+++ b/python/torch_mlir/__init__.py
@@ -126,7 +126,7 @@ def like(tensor: torch.Tensor, dynamic_axes: List[int] = None):
 # ops in the backend contract, and move these lists somewhere deeper in the
 # compiler where each backend can "own" its set of legal ops.
 BACKEND_LEGAL_OPS = {
-    OutputType.TOSA: ['torch.aten.flatten.using_ints',],
+    OutputType.TOSA: ['torch.aten.flatten.using_ints','torch.aten.native_layer_norm','torch.aten.linear'],
     OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints',],
     OutputType.MHLO: [],
 }
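Roughly how the per-backend list is consumed (a sketch; the exact plumbing lives inside `torch_mlir.compile`, and the option-string format is an assumption here): the op names are joined into the option string handed to the backend-contract lowering, so the TOSA path now sees `aten.native_layer_norm` and `aten.linear` as intact ops it can pattern-match directly.

    legal_ops = BACKEND_LEGAL_OPS[OutputType.TOSA]
    option_string = "{backend-legal-ops=" + ",".join(legal_ops) + "}"
    # "{backend-legal-ops=torch.aten.flatten.using_ints,torch.aten.native_layer_norm,torch.aten.linear}"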
diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp
index 32cba4fdf63f9..192b4b2838167 100644
--- a/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp
+++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp
@@ -301,9 +301,11 @@ at::Tensor LazyNativeFunctions::_to_copy(
   }
 };
 
-at::Tensor LazyNativeFunctions::empty(
-    at::SymIntArrayRef sym_size, c10::optional<at::ScalarType> dtype,
-    c10::optional<at::Layout> layout, c10::optional<at::Device> device,
+at::Tensor LazyNativeFunctions::empty_symint(
+    at::SymIntArrayRef sym_size,
+    c10::optional<at::ScalarType> dtype,
+    c10::optional<at::Layout> layout,
+    c10::optional<at::Device> device,
     c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
   // TODO: support this directly
@@ -333,8 +335,8 @@ at::Tensor LazyNativeFunctions::empty_strided(
     c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
     c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
   TORCH_LAZY_FN_COUNTER("lazy::");
-  at::Tensor t = empty(
-      c10::SymIntArrayRef::fromIntArrayRef(size),
+  at::Tensor t = empty_symint(
+      c10::fromIntArrayRef(size),
       dtype, layout, device, pin_memory, c10::nullopt);
   return t.as_strided(size, stride, /*storage_offset=*/0);
 }
@@ -354,7 +356,7 @@ LazyNativeFunctions::fill_(at::Tensor& self, const at::Scalar& value) {
 at::Tensor LazyNativeFunctions::_unsafe_view(
     const at::Tensor& self, at::IntArrayRef size) {
   TORCH_LAZY_FN_COUNTER("lazy::");
-  return LazyNativeFunctions::view_copy(self, c10::SymIntArrayRef::fromIntArrayRef(size));
+  return LazyNativeFunctions::view_copy_symint(self, c10::fromIntArrayRef(size));
 }
 
 // This is needed by the torch.tensor constructor.
@@ -380,15 +382,27 @@ at::Tensor LazyNativeFunctions::block_diag(at::TensorList tensors) {
   return at::functionalization::functionalize_aten_op<ATEN_OP(
       block_diag)>::call(tensors);
 }
-at::Tensor LazyNativeFunctions::new_empty_strided(
-    const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride,
-    c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
-    c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
+at::Tensor LazyNativeFunctions::new_empty_strided_symint(
+    const at::Tensor& self,
+    c10::SymIntArrayRef size,
+    c10::SymIntArrayRef stride,
+    c10::optional<at::ScalarType> dtype,
+    c10::optional<at::Layout> layout,
+    c10::optional<at::Device> device,
+    c10::optional<bool> pin_memory) {
   return at::functionalization::
       functionalize_aten_op<ATEN_OP(new_empty_strided)>::call(
           self, size, stride, dtype, layout, device, pin_memory);
 }
 
+at::Tensor LazyNativeFunctions::narrow_copy_symint(
+    const at::Tensor& self,
+    int64_t dim,
+    c10::SymInt start,
+    c10::SymInt length) {
+  return at::functionalization::functionalize_aten_op<ATEN_OP(
+      narrow_copy)>::call(self, dim, start, length);
+}
 at::Tensor LazyNativeFunctions::pixel_shuffle(
     const at::Tensor& self, int64_t upscale_factor) {
   return at::functionalization::functionalize_aten_op<ATEN_OP(
       pixel_shuffle)>::call(self, upscale_factor);
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ ... @@ def aten〇tanh(self: List[int]) -> List[int]:
 def aten〇relu(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇relu6(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇_softmax(self: List[int], dim: int, half_to_float: bool) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -448,6 +451,9 @@ def aten〇contiguous(self: List[int], memory_format: int = 0) -> List[int]:
 def aten〇clone(self: List[int], memory_format: Optional[int] = None) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇lift_fresh_copy(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇_log_softmax_backward_data(grad_output: List[int], output: List[int], dim: int, input_dtype: int) -> List[int]:
     return upstream_shape_functions.unary(grad_output)
 
@@ -490,6 +496,9 @@ def aten〇floor_divide〇Scalar(self: List[int], other: float) -> List[int]:
 def aten〇pow〇Tensor_Scalar(self: List[int], exponent: float) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇pow〇Tensor_Tensor(self: List[int], exponent: List[int]) -> List[int]:
+    return upstream_shape_functions.broadcast(self, exponent)
+
 def aten〇rsub〇Scalar(self: List[int], other: float, alpha: float = 1) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -803,6 +812,9 @@ def aten〇index_put_impl(self: List[int], indices: List[Optional[List[int]]], v
 def aten〇bernoulli(self: List[int], generator: Any = None) -> List[int]:
     return self
 
+def aten〇cumsum(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
+    return self
+
 def aten〇rand_like(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
     return self
 
@@ -854,6 +866,9 @@ def aten〇maximum(self: List[int], other: List[int]) -> List[int]:
 def aten〇bitwise_and〇Tensor(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(self, other)
 
+def aten〇bitwise_not(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇logical_or(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(self, other)
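Since the new `aten〇pow〇Tensor_Tensor` rule above simply defers to the upstream broadcast helper, its behavior on mixed shapes is easy to pin down (import path per the upstream shape-function module this file already references):

    from torch.jit._shape_functions import broadcast

    # pow(a, b) produces the broadcast of the two operand shapes:
    assert broadcast([3, 1], [3, 4]) == [3, 4]   # matches the e2e test added below
    assert broadcast([3, 4], [3, 4]) == [3, 4]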
@@ -1069,20 +1084,7 @@ def aten〇constant_pad_nd(self: List[int], pad: List[int], value: float = 0) -> List[int]:
 def aten〇pad(self: List[int], pad: List[int], mode: str = "constant",
              value: Optional[float] = None) -> List[int]:
     return pad_shape_fn(self, pad)
 
-# See https://numpy.org/doc/stable/user/basics.indexing.html
-@check_shape_function([
-    Invocation(TensorOfShape(2), [LongTensorOfShape(4)]), # Basic case.
-    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4)]), # More dimensions.
-    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(6, 4)]), # Multidimensional index tensor along a dimension.
-    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), None]), # Explicit None value.
-    Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), LongTensorOfShape(4)]), # Indexing tensors on consecutive dimensions.
-    Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), None, LongTensorOfShape(4)]), # Indexing tensors on non-consecutive dimensions.
-    Invocation(TensorOfShape(2, 3, 4, 5), [LongTensorOfShape(4, 2), None, LongTensorOfShape(2)]), # Indexing tensors on non-consecutive dimensions.
-    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4, 5, 6), LongTensorOfShape(1, 5, 1)]), # Broadcasting of index tensors.
-    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4)]), # Fewer index tensors than dimensions.
-    ErrorInvocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4), LongTensorOfShape(4)]), # More index tensors than dimensions.
-])
-def aten〇index〇Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
+def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
     assert len(indices) <= len(self), "More indices than dimensions to index"
     broadcasted_shape: List[int] = []
     unused_dim_sizes: List[int] = []
@@ -1122,6 +1124,26 @@ def aten〇index〇Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
         result_shape.append(unused_dim_sizes[i])
     return result_shape
 
+# See https://numpy.org/doc/stable/user/basics.indexing.html
+@check_shape_function([
+    Invocation(TensorOfShape(2), [LongTensorOfShape(4)]), # Basic case.
+    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4)]), # More dimensions.
+    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(6, 4)]), # Multidimensional index tensor along a dimension.
+    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), None]), # Explicit None value.
+    Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), LongTensorOfShape(4)]), # Indexing tensors on consecutive dimensions.
+    Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), None, LongTensorOfShape(4)]), # Indexing tensors on non-consecutive dimensions.
+    Invocation(TensorOfShape(2, 3, 4, 5), [LongTensorOfShape(4, 2), None, LongTensorOfShape(2)]), # Indexing tensors on non-consecutive dimensions.
+    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4, 5, 6), LongTensorOfShape(1, 5, 1)]), # Broadcasting of index tensors.
+    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4)]), # Fewer index tensors than dimensions.
+    ErrorInvocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4), LongTensorOfShape(4)]), # More index tensors than dimensions.
+])
+def aten〇index〇Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
+    return index_tensor_like(self, indices)
+
+def aten〇index〇Tensor_hacked_twin(self: List[int], indices: List[List[int]]) -> List[int]:
+    optional_indices: List[Optional[List[int]]] = [x for x in indices]
+    return index_tensor_like(self, optional_indices)
+
 def aten〇cat(tensors: List[List[int]], dim: int = 0) -> List[int]:
     return upstream_shape_functions.cat(tensors, dim)
 
@@ -1170,6 +1192,9 @@ def aten〇bincount(self: List[int], weights: Optional[List[int]] = None, minlen
 def aten〇linalg_vector_norm(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
     return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
 
+def aten〇frobenius_norm〇dim(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]:
+    return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)
+
 # ==============================================================================
 # Shape library generator main().
 # ==============================================================================
@@ -1229,6 +1254,17 @@ def main(args):
     # Put the `〇` back to a regular `.`.
     asm = asm.replace("\\E3\\80\\87", ".")
 
+    # We're about to put quotes around the string, so escape the `"` characters.
+    asm = asm.replace("\"", "\\\"")
+
+    # Instead of dumping one big chunk of text that is several thousand lines
+    # long (and which causes MSVC to error out), split it into multiple lines.
+    # See MSVC Compiler Error C2026
+    # [https://docs.microsoft.com/en-us/cpp/error-messages/compiler-errors-1/compiler-error-c2026?view=msvc-170]
+    # for details.
+    multiple_lines = asm.replace("\n", "\\n\"\n\"")
+    asm = f"\"{multiple_lines}\""
+
     # Write out the shape library .cpp file.
     shape_lib_cpp_file = os.path.join(
         args.torch_transforms_cpp_dir, "ShapeLibrary.cpp")
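A minimal sketch of the splitting transformation above (the sample string is a stand-in for the real MLIR dump): one giant literal becomes a sequence of adjacent per-line C++ string literals, which the C++ compiler concatenates and which MSVC accepts.

    asm = "module {\n}\n"                      # stand-in for the real asm dump
    multiple_lines = asm.replace("\n", "\\n\"\n\"")
    asm = f"\"{multiple_lines}\""
    # asm is now three literals on three lines:
    # "module {\n"
    # "}\n"
    # ""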
@@ -1254,19 +1290,16 @@ def main(args):
 using namespace mlir;
 
 StringRef mlir::torch::Torch::getShapeLibrary() {{
-// TODO: Find a way to embed this string nicely.
-// It is currently too long, and will probably break MSVC builds if anyone
-// attempts that.
-// We want to preserve the legibility of the shape library as a checked in file,
-// since that is sometimes useful for debugging / diffing.
-// Probably the ideal outcome is to have the shape library be a .mlir file
-// that is checked in, and then we embed it as part of the build process.
+#ifndef _MSC_VER
 #pragma clang diagnostic push
 #pragma clang diagnostic ignored "-Woverlength-strings"
-  constexpr StringLiteral shapeLib(R"mlir(
-{asm})mlir");
+#endif
+  // clang-format off
+  return {asm};
+  // clang-format on
+#ifndef _MSC_VER
 #pragma clang diagnostic pop
-  return shapeLib;
+#endif
 }}""")
 
 def _create_argparse() -> argparse.ArgumentParser:
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index e2ff4146bc161..7192b5540859c 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -241,6 +241,7 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::tanh : (Tensor) -> (Tensor)",
         "aten::hardtanh : (Tensor, Scalar, Scalar) -> (Tensor)",
         "aten::relu : (Tensor) -> (Tensor)",
+        "aten::relu6 : (Tensor) -> (Tensor)",
         "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
         "aten::log : (Tensor) -> (Tensor)",
         "aten::sigmoid : (Tensor) -> (Tensor)",
@@ -307,6 +308,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
     emit("aten::gelu : (Tensor, str) -> (Tensor)")
     emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
+    emit("aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)")
     emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
     emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)")
     emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
@@ -397,6 +399,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
     emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
     emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
+    emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
 
     # Misc tensor ops.
     emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
@@ -427,6 +430,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
     emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
     emit("aten::clone : (Tensor, int?) -> (Tensor)")
+    emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)")
     emit("aten::contiguous : (Tensor, int) -> (Tensor)")
     emit("aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)")
     emit("aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)")
@@ -443,6 +447,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)")
     emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")
     emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
+    emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
     emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
     emit("aten::_index_put_impl_ : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
     emit("aten::item : (Tensor) -> (Scalar)")
@@ -464,7 +469,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)")
     emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
-> (Tensor)") - emit("aten::type_as : (Tensor, Tensor) -> (Tensor)") + emit("aten::type_as : (Tensor, Tensor) -> (Tensor)", has_folder=True) emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True) emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)") emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)") @@ -586,8 +591,10 @@ def emit_with_mutating_variants(key, **kwargs): has_canonicalizer=True) emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True) emit("aten::_set_item.t : (t[], int, t) -> (t[])") - emit("aten::div : (Scalar, Scalar) -> (float)") + emit("aten::div : (Scalar, Scalar) -> (float)", has_folder=True) emit("aten::add : (Scalar, Scalar) -> (Scalar)") + emit("aten::sub : (Scalar, Scalar) -> (Scalar)", has_folder=True) + emit("aten::ceil.Scalar : (Scalar) -> (Scalar)", has_folder=True) emit("aten::sqrt.int : (int) -> (float)", has_folder=True) emit("aten::Bool.float : (float) -> (bool)", has_folder=True) emit("aten::Bool.int : (int) -> (bool)", has_folder=True) diff --git a/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py b/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py index 25896e0a0043d..0a467ef5f7b6c 100644 --- a/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py +++ b/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py @@ -36,7 +36,7 @@ def compile(self, imported_module: Module): """ run_pipeline_with_repro_report( imported_module, - "func.func(hlo-legalize-to-linalg)", + "func.func(symbolic-shape-optimization),func.func(hlo-legalize-to-linalg),func.func(canonicalize)", "Lowering MLIR-HLO to Linalg-on-Tensors") return self.refbackend.compile(imported_module) diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index 70981ead2e2e0..68c18456f232e 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -1047,6 +1047,28 @@ def BroadcastToModule_basic(module, tu: TestUtils): # ============================================================================== +class BroadcastToIdentityCaseStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 1, 1], torch.float32, True), + ]) + def forward(self, x): + return torch.broadcast_to(x, [3, 1, 1]) + + +@register_test_case(module_factory=lambda: BroadcastToIdentityCaseStaticModule()) +def BroadcastToIdentityCaseStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 1, 1)) + + +# ============================================================================== + + class RollModule(torch.nn.Module): def __init__(self): @@ -1963,6 +1985,83 @@ def IndexTensorMultiInputContiguousCenter_basic(module, tu: TestUtils): # ============================================================================== +class IndexTensorHackedTwinModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.float32, True), + ([-1, -1], torch.int64, True), + ]) + def forward(self, x, index): + return torch.ops.aten.index(x, [index]) + + +@register_test_case(module_factory=lambda: IndexTensorHackedTwinModule()) +def IndexTensorHackedTwinModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5), tu.randint(2, 3, high=4)) + + +# ============================================================================== + + +class IndexTensorHackedTwinModule3dInput(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + 
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x, index):
+        return torch.ops.aten.index(x, [index])
+
+
+@register_test_case(module_factory=lambda: IndexTensorHackedTwinModule())
+def IndexTensorHackedTwinModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5), tu.randint(2, 3, high=4))
+
+
+# ==============================================================================
+
+
+class IndexTensorHackedTwinModule3dInput(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x, index):
+        return torch.ops.aten.index(x, [index])
+
+
+@register_test_case(
+    module_factory=lambda: IndexTensorHackedTwinModule3dInput())
+def IndexTensorHackedTwinModule3dInput_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 4, 3), tu.randint(2, 3, high=3))
+
+
+# ==============================================================================
+
+
+class IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims(
+        torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([4, 1], torch.int64, True),
+        ([1, 3], torch.int64, True),
+        ([-1, 3], torch.int64, True),
+    ])
+    def forward(self, x, index1, index2, index3):
+        return torch.ops.aten.index(x, [index1, index2, index3])
+
+
+@register_test_case(
+    module_factory=lambda:
+    IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims())
+def IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic(
+        module, tu: TestUtils):
+    module.forward(tu.rand(5, 4, 3, 2), tu.randint(4, 1, high=3),
+                   tu.randint(1, 3, high=1), tu.randint(4, 3, high=1))
+
+
+# ==============================================================================
+
+
 class SquareModule(torch.nn.Module):
 
     def __init__(self):
@@ -2769,6 +2868,42 @@ def Aten_EmbeddingBagExample_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class CumsumModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, val):
+        return torch.ops.aten.cumsum(val, 1)
+
+@register_test_case(module_factory=lambda: CumsumModule())
+def CumsumModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 7, 4))
+
+class CumsumStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 7, 4], torch.float32, True),
+    ])
+    def forward(self, val):
+        return torch.ops.aten.cumsum(val, 1)
+
+@register_test_case(module_factory=lambda: CumsumStaticModule())
+def CumsumStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 7, 4))
+
+# ==============================================================================
+
 class AtenToDeviceModule(torch.nn.Module):
 
     def __init__(self):
diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py
index 98f0a94fd913d..6770d7237ff69 100644
--- a/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -345,6 +345,28 @@ def ElementwiseReluModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseRelu6Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.relu6(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseRelu6Module())
+def ElementwiseRelu6Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 2) - 0.5)
+
+
+# ==============================================================================
+
+
 class ElementwiseLeakyReluModule(torch.nn.Module):
 
     def __init__(self):
@@ -1094,6 +1116,52 @@ def ElementwisePowModule_basic(module,
tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwisePowTensorModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a, b):
+        return torch.pow(a, b)
+
+
+@register_test_case(module_factory=lambda: ElementwisePowTensorModule())
+def ElementwisePowTensorModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4), tu.rand(3, 4))
+
+
+# ==============================================================================
+
+
+class ElementwisePowTensorBroadcastModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, 1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a, b):
+        return torch.pow(a, b)
+
+
+@register_test_case(module_factory=lambda: ElementwisePowTensorBroadcastModule())
+def ElementwisePowTensorBroadcastModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 1), tu.rand(3, 4))
+
+
+# ==============================================================================
+
+
 class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
 
     def __init__(self):
@@ -1485,6 +1553,50 @@ def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseNotIntegerModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.bitwise_not(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseNotIntegerModule())
+def ElementwiseNotIntegerModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-10, high=10))
+
+
+# ==============================================================================
+
+
+class ElementwiseNotInt32Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+    def forward(self, x):
+        return torch.bitwise_not(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseNotInt32Module())
+def ElementwiseNotInt32Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-10, high=10).to(torch.int32))
+
+
+# ==============================================================================
+
+
 class ElementwiseSubScalarIntModule(torch.nn.Module):
 
     def __init__(self):
@@ -1639,6 +1751,28 @@ def ElementwiseCloneContiguousModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class LiftFreshCopyModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.lift_fresh_copy(x)
+
+
+@register_test_case(module_factory=lambda: LiftFreshCopyModule())
+def LiftFreshCopyModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4))
+
+
+# ==============================================================================
+
+
 class ElementwiseExpModule(torch.nn.Module):
 
     def __init__(self):
diff --git a/python/torch_mlir_e2e_test/test_suite/pooling.py b/python/torch_mlir_e2e_test/test_suite/pooling.py
index 36ad293605a26..dfe9d1484f709 100644
--- a/python/torch_mlir_e2e_test/test_suite/pooling.py
+++ b/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -604,12 +604,10 @@ def __init__(self):
     def forward(self, x):
         return self.ap2d(x)
 
-
 @register_test_case(module_factory=lambda: AvgPool2dFloatModule())
 def AvgPool2dFloatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 4, 20, 20) - 0.5)
 
-
 class AvgPool2dIntModule(torch.nn.Module):
 
     def __init__(self):
@@ -704,7 +702,6 @@ def __init__(self):
     def forward(self, x):
         return self.ap2d(x)
 
-
 @register_test_case(module_factory=lambda: AvgPool2dCeilModeTrueModule())
 def AvgPool2dCeilModeTrueModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0))
diff --git a/python/torch_mlir_e2e_test/test_suite/reduction.py b/python/torch_mlir_e2e_test/test_suite/reduction.py
index 1eecb5186fd70..b28d78a126ed6 100644
--- a/python/torch_mlir_e2e_test/test_suite/reduction.py
+++ b/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -49,6 +49,25 @@ def ReduceSumDtypeFloatModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
+class ReduceSumElementTypeBoolModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.bool, True),
+    ])
+    def forward(self, a):
+        return torch.sum(a)
+
+
+@register_test_case(module_factory=lambda: ReduceSumElementTypeBoolModule())
+def ReduceSumElementTypeBoolModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, 5, high=2).to(torch.bool))
+
+# ==============================================================================
+
 class ReduceSumDimIntListFloatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -106,6 +125,25 @@ def ReduceSumDimIntListKeepDimFloatModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
+class ReduceSumDimIntListKeepDimNegativeDimStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 12, 7, 7], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.sum(a, dim=(-1), keepdim=True)
+
+
+@register_test_case(module_factory=lambda: ReduceSumDimIntListKeepDimNegativeDimStaticModule())
+def ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 12, 7, 7))
+
+# ==============================================================================
+
 class ReduceSumDimIntListEmptyDimModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -125,6 +163,25 @@ def ReduceSumDimIntListEmptyDimModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
+class ReduceSumDimIntListElementTypeBoolModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.bool, True),
+    ])
+    def forward(self, a):
+        return torch.sum(a, dim=(-1), keepdim=False)
+
+
+@register_test_case(module_factory=lambda: ReduceSumDimIntListElementTypeBoolModule())
+def ReduceSumDimIntListElementTypeBoolModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(1, 128, high=2).to(dtype=torch.bool))
+
+# ==============================================================================
+
 class ReduceSumUnsignedIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -510,3 +567,37 @@ def forward(self, a):
 @register_test_case(module_factory=lambda: ReduceL3NormKeepDimModule())
 def ReduceL3NormKeepDimModule_basic(module, tu: TestUtils):
     module.forward(torch.rand(3, 4, 5))
+
+# ==============================================================================
+class ReduceFrobeniusNormModule(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.frobenius_norm(a, dim=[0, 1], keepdim=False)
+
+@register_test_case(module_factory=lambda: ReduceFrobeniusNormModule())
+def ReduceFrobeniusNormModule_basic(module, tu: TestUtils):
+    module.forward(torch.rand(3, 4, 5))
+
+# ==============================================================================
+class ReduceFrobeniusNormKeepDimModule(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.frobenius_norm(a, dim=[0, 1], keepdim=True)
+
+@register_test_case(module_factory=lambda: ReduceFrobeniusNormKeepDimModule())
+def ReduceFrobeniusNormKeepDimModule_basic(module, tu: TestUtils):
+    module.forward(torch.rand(3, 4, 5))
diff --git a/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/python/torch_mlir_e2e_test/test_suite/reshape_like.py
index f10a9a051a99b..0321e0f48d06c 100644
--- a/python/torch_mlir_e2e_test/test_suite/reshape_like.py
+++ b/python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -266,7 +266,7 @@ def __init__(self):
     @export
     @annotate_args([
         None,
-        ([3,2], torch.float32, True),
+        ([3, 2], torch.float32, True),
     ])
 
     def forward(self, a):
@@ -297,25 +297,6 @@ def ViewTwoFiveThreeStaticModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
-class ViewFiveTestStaticModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([2, 3, 4, 5, 6], torch.float32, True),
-    ])
-
-    def forward(self, a):
-        return a.view(2, 3, 4, 6, 5)
-
-@register_test_case(module_factory=lambda: ViewFiveTestStaticModule())
-def ViewFiveTestStaticModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 3, 4, 5, 6))
-
-# ==============================================================================
-
 class ViewOffsetTestStaticModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -354,63 +335,6 @@ def ViewOffsetBackwardTestStaticModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
-class ViewUnknown1TestStaticModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([2, 3, -1], torch.float32, True),
-    ])
-
-    def forward(self, a):
-        return a.view(3, 2, a.size(2))
-
-@register_test_case(module_factory=lambda: ViewUnknown1TestStaticModule())
-def ViewUnknown1TestStaticModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 3, 5))
-
-# ==============================================================================
-
-class ViewUnknown2TestStaticModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, 2, 3], torch.float32, True),
-    ])
-
-    def forward(self, a):
-        return a.view(a.size(0), 3, 2)
-
-@register_test_case(module_factory=lambda: ViewUnknown2TestStaticModule())
-def ViewUnknown2TestStaticModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(1, 2, 3))
-
-# ==============================================================================
-
-class ViewDoubleMergeStaticModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([2, 2, 4, 4], torch.float32, True),
-    ])
-
-    def forward(self, a):
-        return a.view(4, 16)
-
-@register_test_case(module_factory=lambda: ViewDoubleMergeStaticModule())
-def ViewDoubleMergeStaticModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 2, 4, 4))
-
-# ==============================================================================
-
 class View1DFoldModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -475,7 +399,7 @@ def __init__(self):
     @export
     @annotate_args([
         None,
-        ([-1, 16, 128], torch.float32, True),
+        ([1, -1, 128], torch.float32, True),
     ])
 
     def forward(self, a):
diff --git a/python/torch_mlir_e2e_test/test_suite/type_conversion.py b/python/torch_mlir_e2e_test/test_suite/type_conversion.py
index cf41631f56b89..2df66184edeb6 100644
--- a/python/torch_mlir_e2e_test/test_suite/type_conversion.py
+++ b/python/torch_mlir_e2e_test/test_suite/type_conversion.py
@@ -214,3 +214,23 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneModule())
 def ToDtypeBoolLayoutNoneModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5))
+
+
+class TypeAsSameModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        return torch.ops.aten.type_as(x, y)
+
+
+@register_test_case(module_factory=lambda: TypeAsSameModule())
+def TypeAsSameModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5), tu.rand(3, 5))
diff --git a/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py b/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py
index eefb492f0a459..302bebd4f5c5b 100644
--- a/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py
+++ b/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py
@@ -52,10 +52,13 @@ def compile(self, imported_module: Module):
             "func.func(tosa-to-linalg-named)",
             "Lowering TOSA to Linalg-on-Tensors for Named Ops")
 
+        # TOSA-to-LinAlg may generate tosa.const() ops, so we want to lower them
+        # to arith.constants here before proceeding further.
         run_pipeline_with_repro_report(
             imported_module,
-            "func.func(tosa-to-linalg)",
+            "func.func(tosa-to-linalg),func.func(tosa-to-arith)",
            "Lowering TOSA to Linalg-on-Tensors")
+
         return self.refbackend.compile(imported_module)
 
     def load(self, module):
diff --git a/setup.py b/setup.py
index 929c27e8ac728..e627107e80d5c 100644
--- a/setup.py
+++ b/setup.py
@@ -45,6 +45,9 @@
 PACKAGE_VERSION = os.environ.get("TORCH_MLIR_PYTHON_PACKAGE_VERSION") or "0.0.1"
 
+# If true, enable LTC build by default
+TORCH_MLIR_ENABLE_LTC_DEFAULT = False
+
 # Build phase discovery is unreliable. Just tell it what phases to run.
 class CustomBuild(_build):
 
@@ -68,12 +71,16 @@ def run(self):
         src_dir = os.path.abspath(os.path.dirname(__file__))
         llvm_dir = os.path.join(
             src_dir, "externals", "llvm-project", "llvm")
+
+        enable_ltc = int(os.environ.get('TORCH_MLIR_ENABLE_LTC', TORCH_MLIR_ENABLE_LTC_DEFAULT))
+
         cmake_args = [
             f"-DCMAKE_BUILD_TYPE=Release",
             f"-DPython3_EXECUTABLE={sys.executable}",
             f"-DLLVM_TARGETS_TO_BUILD=host",
             f"-DMLIR_ENABLE_BINDINGS_PYTHON=ON",
             f"-DLLVM_ENABLE_PROJECTS=mlir",
+            f"-DLLVM_ENABLE_ZSTD=OFF",
             f"-DLLVM_EXTERNAL_PROJECTS=torch-mlir;torch-mlir-dialects",
             f"-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR={src_dir}",
             f"-DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR={src_dir}/externals/llvm-external-projects/torch-mlir-dialects",
@@ -81,7 +88,7 @@ def run(self):
             f"-DCMAKE_VISIBILITY_INLINES_HIDDEN=ON",
             f"-DCMAKE_C_VISIBILITY_PRESET=hidden",
             f"-DCMAKE_CXX_VISIBILITY_PRESET=hidden",
-            f"-DTORCH_MLIR_ENABLE_LTC={'OFF' if int(os.environ.get('TORCH_MLIR_ENABLE_LTC', 1)) else 'OFF'}",
+            f"-DTORCH_MLIR_ENABLE_LTC={'ON' if enable_ltc else 'OFF'}",
         ]
 
         os.makedirs(cmake_build_dir, exist_ok=True)
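A sketch of the intended behavior after this fix (note the removed f-string produced "OFF" on both branches of its conditional, so the environment variable could never actually enable LTC):

    import os

    TORCH_MLIR_ENABLE_LTC_DEFAULT = False
    enable_ltc = int(os.environ.get('TORCH_MLIR_ENABLE_LTC', TORCH_MLIR_ENABLE_LTC_DEFAULT))
    # TORCH_MLIR_ENABLE_LTC=1 -> "-DTORCH_MLIR_ENABLE_LTC=ON"; unset or 0 -> "...=OFF"
    flag = f"-DTORCH_MLIR_ENABLE_LTC={'ON' if enable_ltc else 'OFF'}"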
diff --git a/test/Conversion/TorchToLinalg/view.mlir b/test/Conversion/TorchToLinalg/view.mlir
index 96c52da4a6b62..1ffa11a0a4256 100644
--- a/test/Conversion/TorchToLinalg/view.mlir
+++ b/test/Conversion/TorchToLinalg/view.mlir
@@ -3,81 +3,74 @@
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.view$twotothree(
-// CHECK-SAME: %[[arg0:.*]]: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> {
-// CHECK: %[[ZERO:.*]] = torch_c.to_builtin_tensor %[[arg0]] : !torch.vtensor<[3,2],f32> -> tensor<3x2xf32>
-// CHECK: %[[int3:.*]] = torch.constant.int 3
-// CHECK: %[[int2:.*]] = torch.constant.int 2
-// CHECK: %[[ONE:.*]] = torch.prim.ListConstruct %[[int2]], %[[int3]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[TWO:.*]] = torch_c.to_i64 %[[int2]]
-// CHECK: %[[THREE:.*]] = torch_c.to_i64 %[[int3]]
-// CHECK: %[[c0:.*]] = arith.constant 0 : index
-// CHECK: %[[c3:.*]] = arith.constant 3 : index
-// CHECK: %[[c1:.*]] = arith.constant 1 : index
-// CHECK: %[[c2:.*]] = arith.constant 2 : index
-// CHECK: %[[FOUR:.*]] = tensor.cast %[[ZERO]] : tensor<3x2xf32> to tensor<3x2xf32>
-// CHECK: %[[FIVE:.*]] = tensor.collapse_shape %[[FOUR]] {{\[\[}}0, 1]] : tensor<3x2xf32> into tensor<6xf32>
-// CHECK: %[[SIX:.*]] = tensor.expand_shape %[[FIVE]] {{\[\[}}0, 1]] : tensor<6xf32> into tensor<2x3xf32>
-// CHECK: %[[SEVEN:.*]] = tensor.cast %[[SIX]] : tensor<2x3xf32> to tensor<2x3xf32>
-// CHECK: %[[EIGHT:.*]] = torch_c.from_builtin_tensor %[[SEVEN]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
-// CHECK: return %[[EIGHT]] : !torch.vtensor<[2,3],f32>
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> {
+// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[3,2],f32> -> tensor<3x2xf32>
+// CHECK: %[[CASTED:.*]] = tensor.cast %[[BUILTIN_TENSOR]] : tensor<3x2xf32> to tensor<3x2xf32>
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CASTED]] {{\[\[}}0, 1]] : tensor<3x2xf32> into tensor<6xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1]] : tensor<6xf32> into tensor<2x3xf32>
+// CHECK: %[[EXPAND_CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<2x3xf32> to tensor<2x3xf32>
+// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND_CAST]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
+// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,3],f32>
 
-func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> {
-  %int3 = torch.constant.int 3
-  %int2 = torch.constant.int 2
-  %ONE = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
-  %EIGHT = torch.aten.view %arg0, %ONE : !torch.vtensor<[3,2],f32>, !torch.list<int> -> !torch.vtensor<[2,3],f32>
-  return %EIGHT : !torch.vtensor<[2,3],f32>
+func.func @torch.aten.view$twotothree(%ARG: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> {
+  %0 = torch.constant.int 3
+  %1 = torch.constant.int 2
+  %LIST = torch.prim.ListConstruct %1, %0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %VIEW = torch.aten.view %ARG, %LIST : !torch.vtensor<[3,2],f32>, !torch.list<int> -> !torch.vtensor<[2,3],f32>
+  return %VIEW : !torch.vtensor<[2,3],f32>
 }
 
 // CHECK-LABEL: func.func @torch.aten.view$dynamictest(
-// CHECK-SAME: %[[arg0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[v0:.*]] = torch_c.to_builtin_tensor %[[arg0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[int1:.*]] = torch.constant.int 1
-// CHECK: %[[v1:.*]] = torch_c.to_i64 %[[int1]]
-// CHECK: %[[int0:.*]] = torch.constant.int 0
-// CHECK: %[[v2:.*]] = torch_c.to_i64 %[[int0]]
-// CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64
-// CHECK: %[[v3:.*]] = arith.addi %[[v2]], %[[c2_i64]] : i64
-// CHECK: %[[c0_i64:.*]] = arith.constant 0 : i64
-// CHECK: %[[v4:.*]] = arith.cmpi sge, %[[v2]], %[[c0_i64]] : i64
-// CHECK: %[[v5:.*]] = arith.select %[[v4]], %[[v2]], %[[v3]] : i64
-// CHECK: %[[c0_i64_0:.*]] = arith.constant 0 : i64
-// CHECK: %[[v6:.*]] = arith.cmpi sge, %[[v5]], %[[c0_i64_0]] : i64
-// CHECK: %[[v7:.*]] = arith.cmpi slt, %[[v5]], %[[c2_i64]] : i64
-// CHECK: %[[v8:.*]] = arith.index_cast %[[v5]] : i64 to index
-// CHECK: %[[v9:.*]] = tensor.dim %[[v0]], %[[v8]] : tensor<?x?xf32>
-// CHECK: %[[v10:.*]] = arith.index_cast %[[v9]] : index to i64
-// CHECK: %[[v11:.*]] = torch_c.from_i64 %[[v10]]
-// CHECK: %[[c2_i64_1:.*]] = arith.constant 2 : i64
-// CHECK: %[[v12:.*]] = arith.addi %[[v1]], %[[c2_i64_1]] : i64
-// CHECK: %[[c0_i64_2:.*]] = arith.constant 0 : i64
-// CHECK: %[[v13:.*]] = arith.cmpi sge, %[[v1]], %[[c0_i64_2]] : i64
-// CHECK: %[[v14:.*]] = arith.select %[[v13]], %[[v1]], %[[v12]] : i64
-// CHECK: %[[c0_i64_3:.*]] = arith.constant 0 : i64
-// CHECK: %[[v15:.*]] = arith.cmpi sge, %[[v14]], %[[c0_i64_3]] : i64
-// CHECK: %[[v16:.*]] = arith.cmpi slt, %[[v14]], %[[c2_i64_1]] : i64
-// CHECK: %[[v17:.*]] = arith.index_cast %[[v14]] : i64 to index
-// CHECK: %[[v18:.*]] = tensor.dim %[[v0]], %[[v17]] : tensor<?x?xf32>
-// CHECK: %[[v19:.*]] = arith.index_cast %[[v18]] : index to i64
-// CHECK: %[[v20:.*]] = torch_c.from_i64 %[[v19]]
-// CHECK: %[[v21:.*]] = torch.prim.ListConstruct %[[v11]], %[[v20]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[v22:.*]] = torch_c.to_i64 %[[v11]]
-// CHECK: %[[v23:.*]] = torch_c.to_i64 %[[v20]]
-// CHECK: %[[c0:.*]] = arith.constant 0 : index
-// CHECK: %[[v24:.*]] = tensor.dim %[[v0]], %[[c0]] : tensor<?x?xf32>
-// CHECK: %[[c1:.*]] = arith.constant 1 : index
-// CHECK: %[[v25:.*]] = tensor.dim %[[v0]], %[[c1]] : tensor<?x?xf32>
-// CHECK: %[[v26:.*]] = tensor.cast %[[v0]] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK: %[[v27:.*]] = torch_c.from_builtin_tensor %[[v26]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[v27]] : !torch.vtensor<[?,?],f32>
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[CASTED:.*]] = tensor.cast %[[BUILTIN_TENSOR]] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,?],f32>
 
-func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %int1 = torch.constant.int 1
-  %int0 = torch.constant.int 0
-  %11 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
-  %20 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
-  %21 = torch.prim.ListConstruct %11, %20 : (!torch.int, !torch.int) -> !torch.list<int>
-  %27 = torch.aten.view %arg0, %21 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
-  return %27 : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.view$dynamictest(%ARG: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.constant.int 1
+  %1 = torch.constant.int 0
+  %LISTVALUE1 = torch.aten.size.int %ARG, %1 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+  %LISTVALUE2 = torch.aten.size.int %ARG, %0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+  %LIST = torch.prim.ListConstruct %LISTVALUE1, %LISTVALUE2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %VIEW = torch.aten.view %ARG, %LIST : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
+  return %VIEW : !torch.vtensor<[?,?],f32>
 }
 
+// CHECK-LABEL: func.func @torch.aten.view$dynamicVal(
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,?,128],f32>) -> !torch.vtensor<[16,1,128],f32> {
+// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[1,?,128],f32> -> tensor<1x?x128xf32>
+// CHECK: %[[CASTED:.*]] = tensor.cast %[[BUILTIN_TENSOR]] : tensor<1x?x128xf32> to tensor<1x16x128xf32>
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CASTED]] {{\[\[}}0, 1], [2]] : tensor<1x16x128xf32> into tensor<16x128xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0], [1, 2]] : tensor<16x128xf32> into tensor<16x1x128xf32>
+// CHECK: %[[EXPAND_CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<16x1x128xf32> to tensor<16x1x128xf32>
+// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND_CAST]] : tensor<16x1x128xf32> -> !torch.vtensor<[16,1,128],f32>
+// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[16,1,128],f32>
+
+func.func @torch.aten.view$dynamicVal(%ARG: !torch.vtensor<[1,?,128],f32>) -> !torch.vtensor<[16,1,128],f32> {
+  %0 = torch.constant.int 128
+  %1 = torch.constant.int 1
+  %2 = torch.constant.int 16
+  %LIST = torch.prim.ListConstruct %2, %1, %0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %VIEW = torch.aten.view %ARG, %LIST : !torch.vtensor<[1,?,128],f32>, !torch.list<int> -> !torch.vtensor<[16,1,128],f32>
+  return %VIEW : !torch.vtensor<[16,1,128],f32>
+ }
+
+// CHECK-LABEL: func.func @torch.aten.view$expandInferredDim(
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32> {
+// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,6],f32> -> tensor<2x6xf32>
+// CHECK: %[[CASTED:.*]] = tensor.cast %[[BUILTIN_TENSOR]] : tensor<2x6xf32> to tensor<2x6xf32>
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CASTED]] {{\[\[}}0, 1]] : tensor<2x6xf32> into tensor<12xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2]] : tensor<12xf32> into tensor<3x2x2xf32>
+// CHECK: %[[EXPAND_CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<3x2x2xf32> to tensor<3x2x2xf32>
+// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND_CAST]] : tensor<3x2x2xf32> -> !torch.vtensor<[3,2,2],f32>
+// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[3,2,2],f32>
+
+func.func @torch.aten.view$expandInferredDim(%ARG: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32> {
+  %0 = torch.constant.int 2
+  %1 = torch.constant.int 3
+  %2 = torch.constant.int -1
+  %LIST = torch.prim.ListConstruct %1, %2, %0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %VIEW = torch.aten.view %ARG, %LIST : !torch.vtensor<[2,6],f32>, !torch.list<int> -> !torch.vtensor<[3,2,2],f32>
+  return %VIEW : !torch.vtensor<[3,2,2],f32>
+ }
\ No newline at end of file
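Editor's note: the $expandInferredDim test above relies on PyTorch's inferred-dimension convention — a single -1 in the target shape is solved from the remaining extents. A minimal PyTorch sketch (my example, not part of the patch):

    import torch

    a = torch.rand(2, 6)
    # 2 * 6 == 12 == 3 * inferred * 2  =>  inferred == 2
    b = a.view(3, -1, 2)
    assert b.shape == (3, 2, 2)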
diff --git a/test/Conversion/TorchToMhlo/basic.mlir b/test/Conversion/TorchToMhlo/basic.mlir
index ae505146d5b7e..6480c57da7523 100644
--- a/test/Conversion/TorchToMhlo/basic.mlir
+++ b/test/Conversion/TorchToMhlo/basic.mlir
@@ -159,22 +159,24 @@ func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>)
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.batch_norm$training(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
-// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
-// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
-// CHECK: %true = torch.constant.bool true
-// CHECK: %false = torch.constant.bool false
-// CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01
-// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
-// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32>
-// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
-// CHECK: %[[VAL_7:.*]] = "mhlo.batch_norm_inference"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]], %[[VAL_2]], %[[VAL_3]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32>
-// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
-// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32>
-func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
+// CHECK-LABEL: func.func @torch.aten.batch_norm$inference(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
+// CHECK: %[[T1:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
+// CHECK: %[[T2:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
+// CHECK: %[[TRUE:.*]] = torch.constant.bool true
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e-01
+// CHECK: %[[FLOAT2:.*]] = torch.constant.float 1.000000e-05
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x3x?x?xf32>
+// CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]] : tensor<1xindex>
+// CHECK: %[[T5:.*]] = tensor.cast %[[T0]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
+// CHECK: %[[T6:.*]] = "mhlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32>
+// CHECK: %[[T7:.*]] = tensor.cast %[[T6]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
+// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
+// CHECK: return %[[T8]] : !torch.vtensor<[?,3,?,?],f32>
+func.func @torch.aten.batch_norm$inference(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
   %0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32>
   %1 = torch.vtensor.literal(dense<1.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32>
   %true = torch.constant.bool true
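Editor's note: for readers new to mhlo.batch_norm_inference, the op applies the usual batch-norm formula with fixed (running) statistics along feature_index. A hedged numpy sketch of the semantics (my reconstruction, not the lowering's actual code):

    import numpy as np

    def batch_norm_inference(x, scale, offset, mean, var, eps=1e-5, feature_index=1):
        # Broadcast per-channel statistics along the feature axis (axis 1 for NCHW).
        shape = [1] * x.ndim
        shape[feature_index] = -1
        mean, var = mean.reshape(shape), var.reshape(shape)
        scale, offset = scale.reshape(shape), offset.reshape(shape)
        return (x - mean) / np.sqrt(var + eps) * scale + offset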
diff --git a/test/Conversion/TorchToMhlo/linear.mlir b/test/Conversion/TorchToMhlo/linear.mlir
index bad66a84dbd99..165c874ea0617 100644
--- a/test/Conversion/TorchToMhlo/linear.mlir
+++ b/test/Conversion/TorchToMhlo/linear.mlir
@@ -5,7 +5,7 @@
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,3],f32> -> tensor<3x3xf32>
 // CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<2x3xf32>
+// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<2x3xf32> to tensor<2x3xf32>
 // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
 // CHECK: return %[[T4]] : !torch.vtensor<[2,3],f32>
 func.func @torch.aten.mm$basic$static(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,3],f32> {
@@ -20,7 +20,7 @@ func.func @torch.aten.mm$basic$static(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3],f32> -> tensor<?x3xf32>
 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,?],f32> -> tensor<3x?xf32>
 // CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x3xf32>, tensor<3x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<?x?xf32>
+// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?x?xf32> to tensor<?x?xf32>
 // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[T4]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.mm$basic$dynamic(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[3,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -46,7 +46,7 @@ func.func @torch.aten.mm$basic$dynamic(%arg0: !torch.vtensor<[?,3],f32>, %arg1:
 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
 // CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32>
 // CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32>
-// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<10x3x5xf32>
+// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<10x3x5xf32> to tensor<10x3x5xf32>
 // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<10x3x5xf32> -> !torch.vtensor<[10,3,5],f32>
 // CHECK: return %[[T12]] : !torch.vtensor<[10,3,5],f32>
 func.func @torch.aten.bmm$basic$static(%arg0: !torch.vtensor<[10,3,4],f32>, %arg1: !torch.vtensor<[10,4,5],f32>) -> !torch.vtensor<[10,3,5],f32> {
@@ -72,7 +72,7 @@ func.func @torch.aten.bmm$basic$static(%arg0: !torch.vtensor<[10,3,4],f32>, %arg
 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
 // CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x4x?xf32>, tensor<3xi64>) -> tensor<?x4x?xf32>
 // CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<?x?x4xf32>, tensor<?x4x?xf32>) -> tensor<?x?x?xf32>
-// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<?x?x?xf32>
+// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
 // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
 // CHECK: return %[[T12]] : !torch.vtensor<[?,?,?],f32>
 func.func @torch.aten.bmm$basic$dynamic(%arg0: !torch.vtensor<[?,?,4],f32>, %arg1: !torch.vtensor<[?,4,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@@ -98,7 +98,7 @@ func.func @torch.aten.bmm$basic$dynamic(%arg0: !torch.vtensor<[?,?,4],f32>, %arg
 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
 // CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32>
 // CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T9]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32>
-// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<4x256x256xf32>
+// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x256x256xf32> to tensor<4x256x256xf32>
 // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x256x256xf32> -> !torch.vtensor<[4,256,256],f32>
 // CHECK: return %[[T12]] : !torch.vtensor<[4,256,256],f32>
 func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>, %arg1: !torch.vtensor<[4,120,256],f32>) -> !torch.vtensor<[4,256,256],f32> {
@@ -124,7 +124,7 @@ func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>,
 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
 // CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32>
 // CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32>
-// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<4x?x?xf32>
+// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x?x?xf32> to tensor<4x?x?xf32>
 // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x?x?xf32> -> !torch.vtensor<[4,?,?],f32>
 // CHECK: return %[[T12]] : !torch.vtensor<[4,?,?],f32>
 func.func @torch.aten.matmul$basic$dynamic(%arg0: !torch.vtensor<[4,?,256],f32>, %arg1: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[4,?,?],f32> {
@@ -147,7 +147,7 @@ func.func @torch.aten.matmul$basic$dynamic(%arg0: !torch.vtensor<[4,?,256],f32>,
 // CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64>
 // CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32>
 // CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T0]], %[[T7]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32>
-// CHECK: %[[T9:.*]] = mhlo.convert %[[T8]] : tensor<1x?xf32>
+// CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<1x?xf32> to tensor<1x?xf32>
 // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<1x?xf32> -> !torch.vtensor<[1,?],f32>
 // CHECK: return %[[T10]] : !torch.vtensor<[1,?],f32>
 func.func @torch.aten.matmul$3dx1d(%arg0: !torch.vtensor<[1,?,256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[1,?],f32> {
@@ -170,7 +170,7 @@ func.func @torch.aten.matmul$3dx1d(%arg0: !torch.vtensor<[1,?,256],f32>, %arg1:
 // CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64>
 // CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor<?x256xf32>
 // CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T7]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<?x256xf32>, tensor<?x256x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[T9:.*]] = mhlo.convert %[[T8]] : tensor<?x?xf32>
+// CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<?x?xf32> to tensor<?x?xf32>
 // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.matmul$1dx3d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[?,256,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -185,7 +185,7 @@ func.func @torch.aten.matmul$1dx3d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor<?x256xf32>
 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
 // CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x256xf32>, tensor<256xf32>) -> tensor<?xf32>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<?xf32>
+// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?xf32> to tensor<?xf32>
 // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
 // CHECK: return %[[T4]] : !torch.vtensor<[?],f32>
 func.func @torch.aten.matmul$2dx1d(%arg0: !torch.vtensor<[?,256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[?],f32> {
@@ -200,7 +200,7 @@ func.func @torch.aten.matmul$2dx1d(%arg0: !torch.vtensor<[?,256],f32>, %arg1: !t
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32>
 // CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256x?xf32>) -> tensor<?xf32>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<?xf32>
+// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?xf32> to tensor<?xf32>
 // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
 // CHECK: return %[[T4]] : !torch.vtensor<[?],f32>
 func.func @torch.aten.matmul$1dx2d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[?],f32> {
@@ -215,7 +215,7 @@ func.func @torch.aten.matmul$1dx2d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
 // CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<f32>
+// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<f32> to tensor<f32>
 // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<f32> -> !torch.vtensor<[],f32>
 // CHECK: return %[[T4]] : !torch.vtensor<[],f32>
 func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[],f32> {
@@ -241,7 +241,7 @@ func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
 // CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x256xf32>, tensor<3xi64>) -> tensor<?x256x256xf32>
 // CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<?x?x256xf32>, tensor<?x256x256xf32>) -> tensor<?x?x256xf32>
-// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<?x?x256xf32>
+// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<?x?x256xf32> to tensor<?x?x256xf32>
 // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x256xf32> -> !torch.vtensor<[?,?,256],f32>
 // CHECK: return %[[T12]] : !torch.vtensor<[?,?,256],f32>
 func.func @torch.aten.matmul$proj(%arg0: !torch.vtensor<[?,?,256],f32>) -> !torch.vtensor<[?,?,256],f32> {
@@ -257,7 +257,7 @@ func.func @torch.aten.matmul$proj(%arg0: !torch.vtensor<[?,?,256],f32>) -> !torc
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor<?x256xf32>
 // CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32>
 // CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<?x256xf32>
+// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?x256xf32> to tensor<?x256xf32>
 // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x256xf32> -> !torch.vtensor<[?,256],f32>
 // CHECK: return %[[T4]] : !torch.vtensor<[?,256],f32>
 func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vtensor<[?,256],f32> {
diff --git a/test/Conversion/TorchToMhlo/view_like.mlir b/test/Conversion/TorchToMhlo/view_like.mlir
index 1c878a5a555dd..346faaa344ff3 100644
--- a/test/Conversion/TorchToMhlo/view_like.mlir
+++ b/test/Conversion/TorchToMhlo/view_like.mlir
@@ -349,39 +349,6 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !
   return %1 : !torch.vtensor<[?,120,4,64],f32>
 }
 
-// CHECK-LABEL: func.func @torch.aten.view$minus1(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3,?,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
-// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3,?,?],f32> -> tensor<2x3x?x?xf32>
-// CHECK: %[[INTneg1:.*]] = torch.constant.int -1
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[T1:.*]] = torch.aten.size.int %[[ARG0]], %[[INT0]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
-// CHECK: %[[T2:.*]] = torch.aten.size.int %[[ARG0]], %[[INT1]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
-// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]], %[[INTneg1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[T4:.*]] = torch_c.to_i64 %[[T1]]
-// CHECK: %[[T5:.*]] = torch_c.to_i64 %[[T2]]
-// CHECK: %[[T6:.*]] = torch_c.to_i64 %[[INTneg1]]
-// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
-// CHECK: %[[T7:.*]] = arith.muli %[[C1_I64]], %[[T4]] : i64
-// CHECK: %[[T8:.*]] = arith.muli %[[T7]], %[[T5]] : i64
-// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T6]] : i64
-// CHECK: %[[T10:.*]] = arith.index_cast %[[T9]] : i64 to index
-// CHECK: %[[T11:.*]] = tensor.from_elements %[[T4]], %[[T5]], %[[T6]] : tensor<3xi64>
-// CHECK: %[[T12:.*]] = mhlo.compute_reshape_shape %[[T10]], %[[T11]] : index, tensor<3xi64> -> tensor<3xi64>
-// CHECK: %[[T13:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T12]] : (tensor<2x3x?x?xf32>, tensor<3xi64>) -> tensor<2x3x?xf32>
-// CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]] : tensor<2x3x?xf32> -> !torch.vtensor<[2,3,?],f32>
-// CHECK: return %[[T14]] : !torch.vtensor<[2,3,?],f32>
-func.func @torch.aten.view$minus1(%arg0: !torch.vtensor<[2,3,?,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
-  %int-1 = torch.constant.int -1
-  %int1 = torch.constant.int 1
-  %int0 = torch.constant.int 0
-  %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
-  %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
-  %2 = torch.prim.ListConstruct %0, %1, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[2,3,?,?],f32>, !torch.list<int> -> !torch.vtensor<[2,3,?],f32>
-  return %3 : !torch.vtensor<[2,3,?],f32>
-}
-
 // CHECK-LABEL: func.func @torch.aten.view$to_rank1(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[],f32> -> tensor<f32>
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 8534c694e7b6d..56c9e9b811171 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -794,3 +794,26 @@ func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> )
   %0 = torch.aten.avg_pool2d %arg0, %kernel, %stride, %padding, %false, %true, %none : !torch.vtensor<[1,512,7,7],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,512,1,1],f32>
   return %0 : !torch.vtensor<[1,512,1,1],f32>
 }
+
+// -----
+
+// CHECK-LABEL: @torch.aten.max.dim$basic(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3x2x3xf32>)
+// CHECK: %[[VAL_0:.*]] = torch_c.from_builtin_tensor %[[ARG0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
+// CHECK: %[[VAL_TRUE:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_I2:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_2:.*]] = "tosa.reduce_max"(%[[VAL_1]]) {axis = 2 : i64} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
+// CHECK: %[[VAL_3:.*]] = "tosa.argmax"(%[[VAL_1]]) {axis = 2 : i64} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
+// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = [3, 2, 1]} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
+// CHECK: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
+// CHECK: return %[[VAL_6]] : tensor<3x2x1xf32>
+func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
+  %0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
+  %true = torch.constant.bool true
+  %int2 = torch.constant.int 2
+  %values, %indices = torch.aten.max.dim %0, %int2, %true : !torch.vtensor<[3,2,3],f32>, !torch.int, !torch.bool -> !torch.vtensor<[3,2,1],f32>, !torch.vtensor<[3,2,1],si64>
+  %1 = torch_c.to_builtin_tensor %values : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
+  return %1 : tensor<3x2x1xf32>
+}
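Editor's note: the new TOSA test above decomposes aten.max.dim into reduce_max plus argmax plus a reshape that restores the kept dimension. A quick PyTorch reference for the op's two results (my example, not part of the patch):

    import torch

    x = torch.rand(3, 2, 3)
    values, indices = torch.max(x, dim=2, keepdim=True)
    # values:  (3, 2, 1) -- what tosa.reduce_max produces directly
    # indices: (3, 2, 1) -- tosa.argmax yields (3, 2); tosa.reshape adds the unit dim
    assert values.shape == (3, 2, 1) and indices.shape == (3, 2, 1)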
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index 36aa26f1058ab..2694318191bb9 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -1185,6 +1185,14 @@ func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !to
   return %0 : !torch.tensor<[],f32>
 }
 
+// CHECK-LABEL: func.func @torch.aten.type_as$same(
+// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
+// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[?,?],f32>
+func.func @torch.aten.type_as$same(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
+  %0 = torch.aten.type_as %arg0, %arg0 : !torch.tensor<[?,?],f32>, !torch.tensor<[?,?],f32> -> !torch.tensor<[?,?],f32>
+  return %0 : !torch.tensor<[?,?],f32>
+}
+
 // CHECK-LABEL: func.func @torch.aten.to.dtype$same_dtype(
 // CHECK-SAME: %[[ARG:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
 // CHECK-NEXT: return %[[ARG]] : !torch.tensor<*,f32>
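Editor's note: the canonicalization above folds aten.type_as to its input when source and target dtypes already agree, mirroring PyTorch behaviour where type_as is a no-op in that case. A hedged sketch (my example, not part of the patch):

    import torch

    x = torch.rand(3, 5)                           # float32
    same = x.type_as(torch.rand(3, 5))             # dtypes match: identity case the fold targets
    assert same.dtype == torch.float32

    diff = x.type_as(torch.zeros(1, dtype=torch.float64))
    assert diff.dtype == torch.float64             # a real conversion; not folded away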
@@ -1620,4 +1628,4 @@ func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d_trunc() -> !torch.
   %1 = torch.prim.NumToTensor.Scalar %int6 : !torch.int -> !torch.vtensor<[],si64>
   %2 = torch.aten.div.Tensor_mode %1, %0, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
   return %2 : !torch.vtensor<[],si64>
-}
\ No newline at end of file
+}
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 9cd2d18538bcf..b47b6ecb5ca46 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -479,149 +479,6 @@ func.func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknow
   return %0 : !torch.vtensor<[?,?,?],f32>
 }
 
-// -----
-// CHECK-LABEL: func.func @torch.aten.bernoulli
-// CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[INT7:.*]] = torch.constant.int 7
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[NONE_0:.*]] = torch.constant.none
-// CHECK: %[[CON2FLOAT:.*]] = torch.aten.to.dtype %[[INP]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] :
-// CHECK-SAME: !torch.vtensor<[?,?,?],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
-// CHECK: %[[NONE_1:.*]] = torch.constant.none
-// CHECK: %[[NONE_2:.*]] = torch.constant.none
-// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
-// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT0]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT1]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT2]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
-// CHECK: %[[TENSOR_SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]], %[[DIM2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[TENSOR_SIZE]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f64>
-// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[EMPTY]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
-
-// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[INP]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1>
-// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
-// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
-// CHECK: %[[NONE_3:.*]] = torch.constant.none
-// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[CMP]], %[[INT7_2]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_3]] :
-// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
-// CHECK: return %[[CAST]] : !torch.vtensor
-func.func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
-  %none = torch.constant.none
-  %0 = torch.aten.bernoulli %arg0, %none : !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[?,?,?],f64>
-  %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
-  return %1 : !torch.vtensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.valsem.aten.bernoulli.float
-// CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[PROB:.*]] = torch.constant.float 4.000000e-01
-// CHECK: %[[PROB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[PROB]] : !torch.float -> !torch.vtensor<[],f64>
-// CHECK: %[[INT7:.*]] = torch.constant.int 7
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[NONE_0:.*]] = torch.constant.none
-// CHECK: %[[CON2FLOAT:.*]] = torch.aten.to.dtype %[[INP]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] :
-// CHECK-SAME: !torch.vtensor<[?,?,?],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
-// CHECK: %[[NONE_1:.*]] = torch.constant.none
-// CHECK: %[[NONE_2:.*]] = torch.constant.none
-// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
-// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT0]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT1]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT2]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
-// CHECK: %[[TENSOR_SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]], %[[DIM2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[TENSOR_SIZE]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f64>
-// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[EMPTY]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
-// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB_TENSOR]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?,?],i1>
-// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
-// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
-// CHECK: %[[NONE_3:.*]] = torch.constant.none
-// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[CMP]], %[[INT7_2]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_3]] :
-// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
-// CHECK: return %[[CAST]] : !torch.vtensor
-func.func @torch.valsem.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
-  %none = torch.constant.none
-  %prob = torch.constant.float 4.000000e-01
-  %0 = torch.valsem.aten.bernoulli.float %arg0, %prob, %none : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
-  %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
-  return %1 : !torch.vtensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.valsem.aten.bernoulli.Tensor(
-// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f64>,
-// CHECK-SAME: %[[PROB:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[INT7:.*]] = torch.constant.int 7
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[NONE_0:.*]] = torch.constant.none
-// CHECK: %[[CON2FLOAT:.*]] = torch.aten.to.dtype %[[INP]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] :
-// CHECK-SAME: !torch.vtensor<[?,?,?],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
-// CHECK: %[[NONE_1:.*]] = torch.constant.none
-// CHECK: %[[NONE_2:.*]] = torch.constant.none
-// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
-// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT0]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT1]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT2]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
-// CHECK: %[[TENSOR_SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]], %[[DIM2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[TENSOR_SIZE]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f64>
-// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[EMPTY]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
-// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1>
-// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
-// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
-// CHECK: %[[NONE_3:.*]] = torch.constant.none
-// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[CMP]], %[[INT7_2]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_3]] :
-// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f64>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
-// CHECK: return %[[CAST]] : !torch.vtensor
-func.func @torch.valsem.aten.bernoulli.Tensor(%arg0: !torch.vtensor<[?,?,?],f64>, %arg1: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
-  %none = torch.constant.none
-  %0 = torch.valsem.aten.bernoulli.Tensor %arg0, %arg1, %none : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64>, !torch.none -> !torch.vtensor<[?,?,?],f64>
-  %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f64> to !torch.vtensor
-  return %1 : !torch.vtensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.rand_like(
-// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
-// CHECK: %[[INT6:.*]] = torch.constant.int 6
-// CHECK: %[[NONE_0:.*]] = torch.constant.none
-// CHECK: %[[NONE_1:.*]] = torch.constant.none
-// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
-// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[INT0]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[INT1]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[INPUT]], %[[INT2]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
-// CHECK: %[[TENSOR_SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]], %[[DIM2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[TENSOR_SIZE]], %[[INT6]], %[[NONE_0]], %[[NONE_0]], %[[NONE_0]], %[[NONE_0]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f32>
-// CHECK: %[[UNIFORM:.*]] = torch.valsem.aten.uniform %[[EMPTY]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_1]] : !torch.vtensor<[?,?,?],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[UNIFORM]] : !torch.vtensor<[?,?,?],f32> to !torch.vtensor
-// CHECK: return %[[CAST]] : !torch.vtensor
-func.func @torch.aten.rand_like(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
-  %int6 = torch.constant.int 6
-  %none = torch.constant.none
-  %0 = torch.aten.rand_like %arg0, %int6, %none, %none, %none, %none : !torch.vtensor<[?,?,?],f64>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f32>
-  %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f32> to !torch.vtensor
-  return %1 : !torch.vtensor
-}
-
 // -----
 // CHECK-LABEL: func.func @torch.aten.select.int(
 // CHECK-SAME: %[[T:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?],si64> {
@@ -639,79 +496,6 @@ func.func @torch.aten.select.int(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vt
   return %0 : !torch.vtensor<[?],si64>
 }
 
-// -----
-// CHECK-LABEL: func.func @torch.aten.hardsigmoid(
-// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[CST1:.*]] = torch.constant.int 1
-// CHECK: %[[CST2:.*]] = torch.constant.int 3
-// CHECK: %[[CST6:.*]] = torch.constant.int 6
-// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %[[INPUT]], %[[CST2]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
-// CHECK: %[[DIV:.*]] = torch.aten.div.Scalar %[[ADD]], %[[CST6]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
-// CHECK: %[[CST0:.*]] = torch.constant.int 0
-// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
-// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
-// CHECK: %[[CST1_TENSOR:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[CST1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
-// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[CST1_TENSOR]], %[[DIV]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
-// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
-// CHECK: %[[NONE_1:.*]] = torch.constant.none
-// CHECK: %[[EMPTY_1:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]] :
-// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
-// CHECK: %[[CST0_TENSOR:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY_1]], %[[CST0]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
-// CHECK: %[[RET:.*]] = torch.aten.maximum %[[CST0_TENSOR]], %[[MIN]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[RET]] : !torch.vtensor<[?,?],f32>
-// CHECK: }
-func.func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %0 = torch.aten.hardsigmoid %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.hardswish(
-// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[INT3:.*]] = torch.constant.int 3
-// CHECK: %[[INT6:.*]] = torch.constant.int 6
-// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %[[INP]], %[[INT3]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
-// CHECK: %[[RELU:.*]] = torch.aten.relu %[[ADD]] : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
-// CHECK: %[[INT6_:.*]] = torch.constant.int 6
-// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[MEM:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
-// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
-// CHECK: %[[FILL:.*]] = torch.valsem.aten.fill.Scalar %[[MEM]], %[[INT6_]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
-// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[RELU]], %[[FILL]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32>
-// CHECK: %[[DIV:.*]] = torch.aten.div.Scalar %[[MIN]], %[[INT6]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
-// CHECK: %[[MUL:.*]] = torch.aten.mul.Tensor %[[DIV]], %[[INP]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[MUL]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.hardswish(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %0 = torch.aten.hardswish %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.hardtanh(
-// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?],f32>,
-// CHECK-SAME: %[[MIN_VAL:.*]]: !torch.float,
-// CHECK-SAME: %[[MAX_VAL:.*]]: !torch.float) -> !torch.vtensor<[?],f32> {
-// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
-// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
-// CHECK: %[[MIN_TENSOR:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[MIN_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
-// CHECK: %[[MIN:.*]] = torch.aten.maximum %[[INPUT]], %[[MIN_TENSOR]] : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?],f32>
-// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[VAL_10:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
-// CHECK-SAME: !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
-// CHECK: %[[MAX_TENSOR:.*]] = torch.valsem.aten.fill.Scalar %[[VAL_10]], %[[MAX_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
-// CHECK: %[[RET:.*]] = torch.aten.minimum %[[MAX_TENSOR]], %[[MIN]] : !torch.vtensor<[],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?],f32>
-// CHECK: return %[[RET]] : !torch.vtensor<[?],f32>
-func.func @torch.aten.hardtanh(%arg0: !torch.vtensor<[?],f32>, %min: !torch.float, %max: !torch.float) -> !torch.vtensor<[?],f32> {
-  %0 = torch.aten.hardtanh %arg0, %min, %max : !torch.vtensor<[?],f32>, !torch.float, !torch.float -> !torch.vtensor<[?],f32>
-  return %0 : !torch.vtensor<[?],f32>
-}
 
 // -----
 // CHECK-LABEL: func.func @torch.aten.new_zeros
@@ -764,48 +548,6 @@ func.func @torch.aten.silu(%arg0: !torch.vtensor<[?,?],f32> loc(unknown)) -> !to
   return %0 : !torch.vtensor
 }
 
-// -----
-// CHECK-LABEL: func.func @torch.aten.full
-// CHECK-SAME: () -> !torch.vtensor<[2,3],f32> {
-// CHECK: %[[FLOAT5:.*]] = torch.constant.float 5.000000e+00
-// CHECK: %[[INT3:.*]] = torch.constant.int 3
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[MEM_FORMAT:.*]] = torch.constant.none
-// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[MEM_FORMAT]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
-// CHECK: %[[RES:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[FLOAT5]] : !torch.vtensor<[2,3],f32>, !torch.float -> !torch.vtensor<[2,3],f32>
-// CHECK: return %[[RES]] : !torch.vtensor<[2,3],f32>
-func.func @torch.aten.full() -> !torch.vtensor<[2,3],f32> {
-  %float5.000000e00 = torch.constant.float 5.000000e+00
-  %int3 = torch.constant.int 3
-  %int2 = torch.constant.int 2
-  %none = torch.constant.none
-  %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.aten.full %0, %float5.000000e00, %none, %none, %none, %none : !torch.list<int>, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
-  return %1 : !torch.vtensor<[2,3],f32>
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.full_like(
-// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[INT5:.*]] = torch.constant.int 5
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INP]], %[[INT0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INP]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
-// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
-// CHECK: %[[RES:.*]] = torch.valsem.aten.fill.Scalar %[[EMPTY]], %[[INT5]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[RES]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.full_like(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %int5 = torch.constant.int 5
-  %none = torch.constant.none
-  %0 = torch.aten.full_like %arg0, %int5, %none, %none, %none, %none, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
-
 // -----
 // CHECK-LABEL: func.func @torch.aten.index_put(
 // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?],f32>, %[[INDEX:.*]]: !torch.vtensor<[?],si64>,
@@ -821,45 +563,6 @@ func.func @torch.aten.index_put(%input: !torch.vtensor<[?],f32>, %index: !torch.
   return %0 : !torch.vtensor<[?],f32>
 }
 
-// -----
-// CHECK-LABEL: func.func @torch.aten.expand_as(
-// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,1,1],f32>, %[[OTHER:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[OTHER]], %[[INT0]] : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.int
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[OTHER]], %[[INT1]] : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.int
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[OTHER]], %[[INT2]] : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.int
-// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]], %[[DIM2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[RES:.*]] = torch.aten.broadcast_to %[[INP]], %[[SIZE]] : !torch.vtensor<[?,1,1],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
-// CHECK: return %[[RES]] : !torch.vtensor<[?,?,?],f32>
-func.func @torch.aten.expand_as(%arg0: !torch.vtensor<[?,1,1],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
-  %0 = torch.aten.expand_as %arg0, %arg1 : !torch.vtensor<[?,1,1],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
-  return %0 : !torch.vtensor<[?,?,?],f32>
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten._to_copy(
-// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INP]], %[[INT0]] : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.int
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INP]], %[[INT1]] : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.int
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[INP]], %[[INT2]] : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.int
-// CHECK: %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]], %[[DIM2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f32>
-// CHECK: %[[RES:.*]] = torch.valsem.aten.copy %[[EMPTY]], %[[INP]], %[[FALSE]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32>, !torch.bool -> !torch.vtensor<[?,?,?],f32>
-// CHECK: return %[[RES]] : !torch.vtensor<[?,?,?],f32>
-func.func @torch.aten._to_copy(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
-  %false = torch.constant.bool false
-  %none = torch.constant.none
-  %0 = torch.aten._to_copy %arg0, %none, %none, %none, %none, %false, %none : !torch.vtensor<[?,?,?],f32>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32>
-  return %0 : !torch.vtensor<[?,?,?],f32>
-}
-
 // -----
 // CHECK-LABEL: func.func @torch.aten.dropout$eval(
 // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -873,47 +576,6 @@ func.func @torch.aten.dropout$eval(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
   return %0 : !torch.vtensor<[?,?],f32>
 }
 
-// -----
-// CHECK-LABEL: func.func @torch.aten.dropout$train(
-// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[PROB:.*]] = torch.constant.float 3.000000e-01
-// CHECK: %[[TRAIN:.*]] = torch.constant.bool true
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[CST1:.*]] = torch.constant.float 1.000000e+00
-// CHECK: %[[ONEMINUSP:.*]] = torch.aten.sub.float %[[CST1]], %[[PROB]] : !torch.float, !torch.float -> !torch.float
-// CHECK: %[[PROB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[ONEMINUSP]] : !torch.float -> !torch.vtensor<[],f64>
-// CHECK: %[[INT7:.*]] = torch.constant.int 7
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[NONE_0:.*]] = torch.constant.none
-// CHECK: %[[CON2FLOAT:.*]] = torch.aten.to.dtype %[[INP]], %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] :
-// CHECK-SAME: !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64>
-// CHECK: %[[NONE_1:.*]] = torch.constant.none
-// CHECK: %[[NONE_2:.*]] = torch.constant.none
-// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
-// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
-// CHECK: %[[TENSOR_SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[TENSOR_SIZE]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f64>
-// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[EMPTY]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?],f64>
-// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB_TENSOR]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?],i1>
-// CHECK: %[[INT6:.*]] = torch.constant.int 6
-// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
-// CHECK: %[[NONE_3:.*]] = torch.constant.none
-// CHECK: %[[BOOL_MASK:.*]] = torch.aten.to.dtype %[[CMP]], %[[INT6]], %[[FALSE_2]], %[[FALSE_2]], %[[NONE_3]] :
-// CHECK-SAME: !torch.vtensor<[?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f32>
-// CHECK: %[[MASK_INP:.*]] = torch.aten.mul.Tensor %[[BOOL_MASK]], %[[INP]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
-// CHECK: %[[OUT:.*]] = torch.aten.div.Scalar %[[MASK_INP]], %[[ONEMINUSP]] : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[OUT]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.dropout$train(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %float3.000000e-01 = torch.constant.float 3.000000e-01
-  %true = torch.constant.bool true
-  %0 = torch.aten.dropout %arg0, %float3.000000e-01, %true : !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
-
 // -----
 // CHECK-LABEL: func.func @torch.aten.zero(
 // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -945,59 +607,6 @@ func.func @torch.aten.new_empty(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
   return %1 : !torch.vtensor<[2,3],f32>
 }
 
-// -----
-// CHECK-LABEL: func.func @torch.aten.where.Scalar(
-// CHECK-SAME: %[[COND:.*]]: !torch.vtensor<[?,?,?],i1>) -> !torch.vtensor<[?,?,?],f32> {
-// CHECK: %[[CST8:.*]] = torch.constant.float 8.000000e+00
-// CHECK: %[[CST4:.*]] = torch.constant.float 4.000000e+00
-// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[LIST]], %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
-// CHECK: %[[FILL_SELF:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC]], %[[CST4]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
-// CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
-// CHECK: %[[NONE2:.*]] = torch.constant.none
-// CHECK: %[[ALLOC2:.*]] = torch.aten.empty.memory_format %[[LIST2]], %none_0, %none_0, %none_0, %none_0, %none_0 : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
-// CHECK: %[[FILL_OTHER:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC2]], %[[CST8]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
-// CHECK: %[[OUT:.*]] = torch.aten.where.self %[[COND]], %[[FILL_SELF]], %[[FILL_OTHER]] : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?,?],f32>
-// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?],f32>
-func.func @torch.aten.where.Scalar(%arg0: !torch.vtensor<[?,?,?],i1>) -> !torch.vtensor<[?,?,?],f32> {
-  %cst8 = torch.constant.float 8.000000e+00
-  %cst4 = torch.constant.float 4.000000e+00
-  %0 = torch.aten.where.Scalar %arg0, %cst4, %cst8 : !torch.vtensor<[?,?,?],i1>, !torch.float, !torch.float -> !torch.vtensor<[?,?,?],f32>
-  return %0 : !torch.vtensor<[?,?,?],f32>
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.where.ScalarSelf(
-// CHECK-SAME: %[[COND:.*]]: !torch.vtensor<[?,?,?],i1>, %[[OTHER:.*]]: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
-// CHECK: %[[CST:.*]] = torch.constant.float 4.000000e+00
-// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[LIST]], %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f64>
-// CHECK: %[[FILL:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC]], %[[CST]] : !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[],f64>
-// CHECK: %[[OUT:.*]] = torch.aten.where.self %[[COND]], %[[FILL]], %[[OTHER]] : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?,?],f64>
-// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?],f64>
-func.func @torch.aten.where.ScalarSelf(%arg0: !torch.vtensor<[?,?,?],i1>, %arg1: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
-  %cst = torch.constant.float 4.000000e+00
-  %0 = torch.aten.where.ScalarSelf %arg0, %cst, %arg1 : !torch.vtensor<[?,?,?],i1>, !torch.float, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?,?],f64>
-  return %0 : !torch.vtensor<[?,?,?],f64>
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.where.ScalarOther(
-// CHECK-SAME: %[[COND:.*]]: !torch.vtensor<[?,?,?],i1>, %[[SELF:.*]]: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
-// CHECK: %[[CST:.*]] = torch.constant.float 4.000000e+00
-// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[LIST]], %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f64>
-// CHECK: %[[FILL:.*]] = torch.valsem.aten.fill.Scalar %[[ALLOC]], %[[CST]] : !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[],f64>
-// CHECK: %[[OUT:.*]] = torch.aten.where.self %[[COND]], %[[SELF]], %[[FILL]] : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?,?],f64>
-// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?],f64>
-func.func @torch.aten.where.ScalarOther(%arg0: !torch.vtensor<[?,?,?],i1>, %arg1: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?,?],f64> {
-  %cst = torch.constant.float 4.000000e+00
-  %0 = torch.aten.where.ScalarOther %arg0, %arg1, %cst : !torch.vtensor<[?,?,?],i1>, !torch.vtensor<[?,?],f64>, !torch.float -> !torch.vtensor<[?,?,?],f64>
-  return %0 : !torch.vtensor<[?,?,?],f64>
-}
 
 // -----
 // CHECK-LABEL: func.func @torch.aten.pad
diff --git a/test/python/importer/jit_ir/ivalue_import/object-identity-torch-bug.py b/test/python/importer/jit_ir/ivalue_import/object-identity-torch-bug.py
index cde686c74d16a..25d65101486b8 100644
--- a/test/python/importer/jit_ir/ivalue_import/object-identity-torch-bug.py
+++ b/test/python/importer/jit_ir/ivalue_import/object-identity-torch-bug.py
@@ -18,6 +18,7 @@
 # `torch.Tensor` is just a pointer to a TensorImpl under the hood, and so
 # naively duplicating a Tensor retains the identity of the TensorImpl.
 
+# CHECK-LABEL: torch.class_type @__torch__.TestModule {
 class TestModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/test/python/importer/jit_ir/ivalue_import/quantization.py b/test/python/importer/jit_ir/ivalue_import/quantization.py
index 422e6bb70526c..f05cf434f8373 100644
--- a/test/python/importer/jit_ir/ivalue_import/quantization.py
+++ b/test/python/importer/jit_ir/ivalue_import/quantization.py
@@ -12,6 +12,7 @@
 
 mb = ModuleBuilder()
 
+# CHECK-LABEL: torch.class_type @__torch__.TestModule {
 class TestModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/test/python/importer/jit_ir/ivalue_import/submodules.py b/test/python/importer/jit_ir/ivalue_import/submodules.py
index d4a49928d1f22..92333d20e1db3 100644
--- a/test/python/importer/jit_ir/ivalue_import/submodules.py
+++ b/test/python/importer/jit_ir/ivalue_import/submodules.py
@@ -22,6 +22,7 @@
     def __init__(self):
        self.s0 = Submodule(0)
        self.s1 = Submodule(1)
 
+# CHECK-LABEL: torch.class_type @__torch__.TestModule {
 # CHECK: %[[T:.*]] = torch.constant.bool true
 # CHECK: %[[N0:.*]] = torch.constant.int 0
diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel
index 6a8cc1e38c4c5..1bd831223dcf8 100644
--- a/utils/bazel/torch-mlir-overlay/BUILD.bazel
+++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel
@@ -314,10 +314,7 @@ gentbl_cc_library(
     strip_include_prefix = "include",
     tbl_outs = [
         (
-            [
-                "-gen-pass-decls",
-                "-DTORCH_MLIR_ENABLE_MHLO",
-            ],
+            ["-gen-pass-decls"],
             "include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc",
         ),
     ],