Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix batch size calculation with vmap and static_argnums #1150

Merged
merged 5 commits into from
Sep 24, 2024
Merged

Conversation

dime10
Copy link
Collaborator

@dime10 dime10 commented Sep 23, 2024

Resolves a bug accessing argument shapes in the vmap function when those arguments don't have a shape. This might happen when using vmap together with static_argnums for example:

@qjit(static_argnums=0)
@vmap(in_axes=(None, 0))
def test(n, array):
    for _ in range(n):
        array = array * 2
    return array

test(10, np.random.rand(3,2))

Also improves error message for out of bounds axis specifier.

[sc-74291]

@dime10 dime10 added bug Something isn't working frontend Pull requests that update the frontend labels Sep 23, 2024
@dime10 dime10 requested a review from a team September 23, 2024 22:23
Copy link

codecov bot commented Sep 23, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 97.87%. Comparing base (c13cd9b) to head (8b23dda).
Report is 3 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1150      +/-   ##
==========================================
- Coverage   97.88%   97.87%   -0.02%     
==========================================
  Files          76       76              
  Lines       10863    10850      -13     
  Branches     1283     1283              
==========================================
- Hits        10633    10619      -14     
  Misses        179      179              
- Partials       51       52       +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

doc/releases/changelog-dev.md Outdated Show resolved Hide resolved
frontend/test/pytest/test_vmap.py Outdated Show resolved Hide resolved
Co-authored-by: Joey Carter <joseph.carter@xanadu.ai>
@dime10 dime10 merged commit 5e14033 into main Sep 24, 2024
41 of 42 checks passed
@dime10 dime10 deleted the bug-vmap-shape branch September 24, 2024 13:27
@dime10 dime10 added this to the v0.8.2 milestone Oct 9, 2024
dime10 added a commit that referenced this pull request Oct 11, 2024
Resolves a bug accessing argument shapes in the vmap function when those
arguments don't have a shape. This might happen when using vmap together
with `static_argnums` for example:

```py
@qjit(static_argnums=0)
@vmap(in_axes=(None, 0))
def test(n, array):
    for _ in range(n):
        array = array * 2
    return array

test(10, np.random.rand(3,2))
```

Also improves error message for out of bounds axis specifier.

[sc-74291]

---------

Co-authored-by: Joey Carter <joseph.carter@xanadu.ai>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working frontend Pull requests that update the frontend
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants