Conversation

@ajroetker
Summary

Fixes an ONNX export failure for models using the SpanMLP span representation.

The SpanMLP class used an inferred dimension (-1) in the view() call:

span_rep = span_rep.view(B, L, -1, D)

This causes ONNX export to fail with a reshape size mismatch error because PyTorch's ONNX exporter cannot reliably infer the -1 dimension when using dynamic shapes.

Error example:

total requested size 786432 (dimensions=[1 128 12 512]) doesn't match original size 65536 (dimensions [1 128 512])
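
For context, here is a minimal sketch of the pattern that trips the exporter. The class below is an illustrative stand-in (names and constructor arguments are approximations, not the exact GLiNER source):

```python
import torch
import torch.nn as nn

class SpanMLPSketch(nn.Module):
    """Simplified stand-in for the SpanMLP pattern (not the exact GLiNER source)."""

    def __init__(self, hidden_size: int, max_width: int):
        super().__init__()
        # Project each token representation to max_width span slots.
        self.mlp = nn.Linear(hidden_size, hidden_size * max_width)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        B, L, D = h.shape
        span_rep = self.mlp(h)             # (B, L, max_width * D)
        # The inferred -1 below is what the exporter fails to resolve
        # once batch size and sequence length are marked dynamic.
        return span_rep.view(B, L, -1, D)

# Eager execution is fine; the failure only shows up at ONNX export time.
out = SpanMLPSketch(hidden_size=512, max_width=12)(torch.randn(1, 128, 512))
print(out.shape)  # torch.Size([1, 128, 12, 512])
```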

Fix

  1. Store max_width as an instance variable in __init__
  2. Use self.max_width explicitly in view() instead of -1

# Before
span_rep = span_rep.view(B, L, -1, D)

# After
span_rep = span_rep.view(B, L, self.max_width, D)

This makes SpanMLP consistent with other span representation classes (SpanMarker, SpanMarkerV0, SpanMarkerV1) which already use explicit max_width in their view() calls.
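
Put together, a sketch of the patched module looks roughly like the following (again an approximation of the real class, not the exact source):

```python
import torch
import torch.nn as nn

class SpanMLPFixed(nn.Module):
    """Illustrative sketch of the patched SpanMLP (approximate, not the exact source)."""

    def __init__(self, hidden_size: int, max_width: int):
        super().__init__()
        self.max_width = max_width         # 1. keep max_width on the module
        self.mlp = nn.Linear(hidden_size, hidden_size * max_width)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        B, L, D = h.shape
        span_rep = self.mlp(h)
        # 2. explicit span-width axis instead of -1, so the exported
        #    Reshape node gets a concrete target shape.
        return span_rep.view(B, L, self.max_width, D)
```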

Testing

  • The issue was discovered while attempting to run exported GLiNER ONNX models in GoMLX
  • The fix allows successful ONNX export and inference with dynamic batch sizes and sequence lengths (see the export sketch below)
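
As a rough illustration of the export path this unblocks, reusing the SpanMLPFixed sketch above; the file name, tensor names, and dummy shapes are placeholders, not the actual GLiNER export script:

```python
import torch

model = SpanMLPFixed(hidden_size=512, max_width=12).eval()
dummy = torch.randn(1, 128, 512)           # (batch, seq_len, hidden)

torch.onnx.export(
    model,
    (dummy,),
    "span_mlp.onnx",
    input_names=["token_embeddings"],
    output_names=["span_rep"],
    # Mark batch and sequence length as dynamic; with the explicit
    # max_width the Reshape no longer has to infer a -1 dimension.
    dynamic_axes={
        "token_embeddings": {0: "batch", 1: "seq_len"},
        "span_rep": {0: "batch", 1: "seq_len"},
    },
    opset_version=17,
)
```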
