Conversation

@Pranaykarvi
Contributor

Summary

This PR fixes excessive memory allocation for Transformer attention ops when the
sequence length is statically known at compile time (e.g. seq_len=128).

For static-shape attention, Metal may eagerly allocate large intermediate buffers
(e.g. QKᵀ matrices), which can lead to multi-GB allocations and OOM on iOS devices.
The existing attention slicing pass was gated behind a high sequence-length
threshold and did not trigger for smaller static shapes.

This change enables memory-efficient attention slicing for static sequence lengths
while preserving the existing behavior for dynamic-shape models.


Problem

When exporting Transformer models with a statically known sequence length,
scaled_dot_product_attention may materialize large intermediate tensors during
lowering. On iOS, this can result in excessive Metal buffer allocation (observed
~10GB) and OOM during inference or benchmarking, even for relatively small models
(e.g. Llama-style models with seq_len=128).


Solution

  • Detect whether the attention sequence length is statically known.
  • For static shapes with sequence length ≥ 64, apply the existing
    scaled_dot_product_attention_sliced_q pass to break the computation into
    smaller chunks and reduce peak memory usage.
  • Preserve the original slicing threshold (1280) and behavior for
    dynamic-shape models to avoid unnecessary overhead.

This approach limits the change to the pathological static-shape case and avoids
global behavior changes.
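
A minimal sketch of the gating decision described above, using illustrative names (should_slice_sdpa and the constant name are not the exact identifiers in the PR); only the thresholds of 64 for static shapes and 1280 for the pre-existing dynamic-shape behavior come from this description:

from coremltools.converters.mil.mil.types.symbolic import is_symbolic

STATIC_MIN_SEQ_LENGTH = 64  # threshold for statically known sequence lengths

def should_slice_sdpa(q_seq_length) -> bool:
    """Return True if the sliced SDPA replacement should be applied."""
    if is_symbolic(q_seq_length):
        # Dynamic sequence length: the value is unknown at compile time, so the
        # pre-existing gating (threshold 1280) is left untouched and no slicing
        # decision is made here.
        return False
    # Static sequence length: slice to keep the QK^T intermediate small and
    # avoid eager Metal buffer allocation.
    return q_seq_length >= STATIC_MIN_SEQ_LENGTH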


Testing

  • Added a regression test that constructs a Transformer-style attention block with
    a static sequence length (seq_len=128).
  • The test verifies:
    • The sequence length is statically known.
    • The attention op is replaced by sliced logic.
    • Intermediate tensor sizes remain below a conservative safety bound, preventing
      pathological buffer materialization.
  • Tests run at conversion time only and do not require iOS or Metal execution.
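
A conversion-time sketch of such a regression test is shown below. The program construction, pass-registry lookup, and assertions are illustrative (the shapes, names, and exact checks in the PR's test may differ), and it assumes the PR's lowered static-shape threshold is in effect; the pass name scaled_dot_product_attention_sliced_q is the one referenced above, assumed here to be registered under the common:: namespace.

from coremltools.converters.mil._deployment_compatibility import AvailableTarget as target
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY
from coremltools.converters.mil.mil.types.symbolic import is_symbolic

SEQ_LEN = 128  # static sequence length under test
HEAD_DIM = 64  # illustrative head dimension

def test_static_seq_len_sdpa_is_sliced():
    @mb.program(
        input_specs=[
            mb.TensorSpec(shape=(1, 1, SEQ_LEN, HEAD_DIM)),
            mb.TensorSpec(shape=(1, 1, SEQ_LEN, HEAD_DIM)),
            mb.TensorSpec(shape=(1, 1, SEQ_LEN, HEAD_DIM)),
        ],
        opset_version=target.iOS18,
    )
    def prog(q, k, v):
        return mb.scaled_dot_product_attention(query=q, key=k, value=v)

    # Premise of the test: the sequence length is a concrete value, not symbolic.
    assert not is_symbolic(SEQ_LEN)

    # Apply the slicing pass at conversion time; no iOS or Metal execution needed.
    PASS_REGISTRY["common::scaled_dot_product_attention_sliced_q"](prog)

    # The monolithic SDPA op should no longer appear in the main function.
    op_types = [op.op_type for op in prog.functions["main"].operations]
    assert "scaled_dot_product_attention" not in op_types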

Notes

  • This fix targets mobile inference scenarios where static shapes can trigger
    eager buffer allocation.
  • Dynamic-shape models are intentionally unaffected.

Fixes #2590.

@Pranaykarvi
Contributor Author

Hi @TobyRoseman, just a gentle follow-up in case this slipped through.
Happy to make any adjustments if needed, thanks!

@TobyRoseman
Collaborator

Your new unit tests don't pass.

from coremltools.converters.mil._deployment_compatibility import AvailableTarget as target
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil import types
from coremltools.converters.mil.mil.types.symbolic import is_symbolic
Collaborator

I don't think you need this line.

# allocation issues as static shapes, so the higher threshold is appropriate.
logger.debug(
f"skipping SDPA op, Q seq_length is {q_seq_length} (minimum seq length needed: {self._min_seq_length}"
f"skipping SDPA op, Q seq_length is dynamic (symbolic), "
Collaborator

This shouldn't be an f-string since there is no variable being used.

Collaborator

Looks like this is also an issue in several other places of this PR.
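
For reference, a minimal illustration of the requested change (message text shortened from the snippet above): drop the f prefix when the string has no placeholders.

import logging

logger = logging.getLogger(__name__)

# Flagged by the review: an f-string with nothing to interpolate.
logger.debug(f"skipping SDPA op, Q seq_length is dynamic (symbolic)")

# Equivalent plain string literal; no f prefix needed.
logger.debug("skipping SDPA op, Q seq_length is dynamic (symbolic)")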

"common::remove_symbolic_reshape",
"common::noop_elimination",
# Apply attention slicing early to reduce memory allocation for static sequence lengths.
# This pass replaces scaled_dot_product_attention with a memory-efficient sliced implementation.
Collaborator

Remove this line of the comment. It doesn't really add much and can easily become outdated/inaccurate.

Defines the size of the chunks of Q being processed in SDPA (chunk_size = seq_length / seq_length_divider)
"""

# Default threshold for dynamic-shape models. Dynamic shapes use runtime allocation
Collaborator

The added comments in this file are far too long. They need to be much more concise.

3. The model can be converted successfully
"""
# Create a minimal transformer attention block with static seq_len=128
batch_size = 1
Collaborator

These are constants? If so, the variable names should be in all caps.

return output

# Apply the default pass pipeline which includes the slicing pass
from coremltools.converters.mil.mil.passes.pass_pipeline import PassPipeline
Collaborator

Import statements should be done at the top of the file. If for some reason they can't be, do them at the top of the function.

Collaborator

Looks like this is also an issue elsewhere.

f"This indicates the memory allocation fix is not working correctly."
)

# Verify the program structure is correct
Collaborator

Did you mean to delete this comment?

"""
Regression test for memory allocation bug with static sequence length transformers.
This test verifies that exporting a Llama-style Transformer with a static sequence
Collaborator

I might be wrong, but I don't think any of this is specific to a Llama-style Transformer.

# The key verification is that attention ops are sliced and tensor sizes are reasonable
# which we've already checked above

def test_static_seq_len_128_with_quantization(self):
Collaborator

There is a lot of duplicated code here with the previous method. Please create a helper function.
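
A possible shape for such a shared helper is sketched below with hypothetical names (_build_static_attention_prog is not an identifier from the PR):

from coremltools.converters.mil._deployment_compatibility import AvailableTarget as target
from coremltools.converters.mil.mil import Builder as mb

SEQ_LEN = 128
HEAD_DIM = 64

def _build_static_attention_prog(seq_len=SEQ_LEN, head_dim=HEAD_DIM):
    """Shared setup: a MIL program containing one static-shape SDPA op."""
    @mb.program(
        input_specs=[
            mb.TensorSpec(shape=(1, 1, seq_len, head_dim)),
            mb.TensorSpec(shape=(1, 1, seq_len, head_dim)),
            mb.TensorSpec(shape=(1, 1, seq_len, head_dim)),
        ],
        opset_version=target.iOS18,
    )
    def prog(q, k, v):
        return mb.scaled_dot_product_attention(query=q, key=k, value=v)

    return prog

Both tests could then call this helper and differ only in the quantization step and their assertions.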

Successfully merging this pull request may close these issues.

Exported Llama 1B transformer with static 128 sequence length tries to allocate 10Gb on iOS18 causing OOM
