
ANE compile OOMs on certain input shapes #8439

@metascroy


🐛 Describe the bug

On a static llama (https://github.com/pytorch/executorch/pull/8436/files), the ANE OOMs during compilation for certain input sizes. The OOM behavior is not monotonic in input size and appears to be tied to particular input shapes.

For example, for a static Llama1B (https://github.com/pytorch/executorch/pull/8436/files) running on an iPhone 15 Pro, I see the following behavior in CoreML depending on seq_len/max_seq_len:


| seq_len | max_seq_len | time (ms) |
|--------:|------------:|----------:|
|       1 |        1024 |     35.55 |
|     512 |        1024 |    263.06 |
|     256 |        1024 |    121.98 |
|     130 |        1024 |     83.19 |
|     128 |        1024 |       OOM |
|      64 |        1024 |     46.22 |
|      32 |        1024 |     39.31 |
|       1 |         512 |     32.99 |
|      32 |         512 |       OOM |
|      64 |         512 |       OOM |
|     128 |         512 |       OOM |
|     256 |         512 |     88.71 |
|       1 |        2048 |       OOM |
|      64 |        2048 |     64.92 |

Note that with max_seq_len = 1024, seq_len = 128 OOMs during compilation, but seq_len = 130, 256, and 512 succeed and run. I wonder if the compiler is doing something different for certain shapes.

To reproduce the models, check out https://github.com/pytorch/executorch/pull/8436/files and run the following from executorch/examples/apple/coreml/llama (adjust --static_seq_length and --max_seq_length as needed):

```bash
python export.py -n /path/to/output.pte -p /path/to/params.json -c /path/to/model.pth --static_seq_length 128 --max_seq_length 1024 -E "4,32" --coreml-quantize "c4w"
```
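To map out which shapes fail, here is a minimal sweep sketch over --static_seq_length. It assumes the export.py invocation above works when run from executorch/examples/apple/coreml/llama; the paths and the seq_len list are placeholders taken from the table above:

```python
import subprocess

# Placeholder values from the table above; adjust to the shapes you want to test.
SEQ_LENS = [1, 32, 64, 128, 130, 256, 512]
MAX_SEQ_LEN = 1024

for seq_len in SEQ_LENS:
    # Same export.py invocation as shown above, parameterized over seq_len.
    cmd = [
        "python", "export.py",
        "-n", f"/path/to/output_{seq_len}_{MAX_SEQ_LEN}.pte",
        "-p", "/path/to/params.json",
        "-c", "/path/to/model.pth",
        "--static_seq_length", str(seq_len),
        "--max_seq_length", str(MAX_SEQ_LEN),
        "-E", "4,32",
        "--coreml-quantize", "c4w",
    ]
    print(f"exporting seq_len={seq_len}, max_seq_len={MAX_SEQ_LEN}")
    subprocess.run(cmd, check=True)
```

Each resulting .pte can then be taken through the extraction and profiling steps below to see which shapes OOM.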

After the .pte file is generated, extract the mlpackage using executorch/examples/apple/coreml/scripts/extract_coreml_models.py and profile it in the CoreML profiling tool (there is only one CoreML model in the package).
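Alternatively, the compile step can be exercised directly from Python with coremltools by loading the extracted mlpackage with the ANE enabled. A minimal sketch, assuming coremltools is installed and using a placeholder path for the extracted package:

```python
import coremltools as ct

# Loading an .mlpackage triggers Core ML compilation for the requested
# compute units. CPU_AND_NE targets the Apple Neural Engine, which is
# where the compile-time OOM is observed. The path is a placeholder.
model = ct.models.MLModel(
    "extracted/model.mlpackage",
    compute_units=ct.ComputeUnit.CPU_AND_NE,
)

# Print the model's input descriptions to confirm the load succeeded.
print(model.get_spec().description.input)
```

If the OOM reproduces here, it would confirm the failure happens at Core ML compile time rather than in the ExecuTorch runtime.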

Versions

PyTorch version: 2.7.0.dev20250131
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.2 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.3)
CMake version: version 3.31.4
Libc version: N/A

Python version: 3.10.16 (main, Dec 11 2024, 10:22:29) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-15.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] executorch==0.6.0a0+ee7d388
[pip3] executorchcoreml==0.0.1
[pip3] flake8==6.1.0
[pip3] flake8-breakpoint==1.1.0
[pip3] flake8-bugbear==24.4.26
[pip3] flake8-comprehensions==3.14.0
[pip3] flake8-plugin-utils==1.3.3
[pip3] flake8-pyi==23.5.0
[pip3] mypy==1.14.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.0.0
[pip3] torch==2.7.0.dev20250131
[pip3] torchao==0.8.0+git11333ba2
[pip3] torchaudio==2.6.0.dev20250131
[pip3] torchsr==1.0.4
[pip3] torchvision==0.22.0.dev20250131
[conda] executorch 0.6.0a0+ee7d388 pypi_0 pypi
[conda] executorchcoreml 0.0.1 pypi_0 pypi
[conda] numpy 2.0.0 pypi_0 pypi
[conda] torch 2.7.0.dev20250131 pypi_0 pypi
[conda] torchao 0.9.0+git04f3f03d pypi_0 pypi
[conda] torchaudio 2.6.0.dev20250131 pypi_0 pypi
[conda] torchfix 0.6.0 pypi_0 pypi
[conda] torchsr 1.0.4 pypi_0 pypi
[conda] torchvision 0.22.0.dev20250131 pypi_0 pypi

cc @kimishpatel @YifanShenSZ @cymbalrush

Labels: module: coreml, triaged
