Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Determine symmetric linear op fro CG from abstract output dtype #23486

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

Joshuaalbert
Copy link
Contributor

Address #23403

@Joshuaalbert
Copy link
Contributor Author

@jakevdp want to review this?

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I lost track of this while I was out of office.

Thanks for putting this together – a couple comments below

jax/_src/scipy/sparse/linalg.py Outdated Show resolved Hide resolved
jax/_src/scipy/sparse/linalg.py Outdated Show resolved Hide resolved
@jakevdp jakevdp self-assigned this Sep 11, 2024
Copy link
Contributor Author

@Joshuaalbert Joshuaalbert left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made the change

@jakevdp
Copy link
Collaborator

jakevdp commented Sep 16, 2024

Lint is failing due to an unused import

@jakevdp
Copy link
Collaborator

jakevdp commented Sep 16, 2024

The CI failures look to be related to this change

return not issubclass(x.dtype.type, np.complexfloating)
if callable(A) and x0 is not None:
# we use output dtype as the proxy for dtype.
symmetric = all(map(real_valued, tree_leaves(jax.eval_shape(A, x0))))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is problematic: it assumes that all tree leaves are arrays, which is not always the case.

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approval to allow running test

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Sep 16, 2024
@jakevdp
Copy link
Collaborator

jakevdp commented Sep 16, 2024

There's some issue with github actions currently, so for some reason I can't trigger the tests. I'd suggest running them locally via pytest -n auto tests to see if there's any remaining issues

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks!

Can you please squash the changes into a single commit? See https://jax.readthedocs.io/en/latest/contributing.html#single-change-commits-and-pull-requests

Thanks!

@Joshuaalbert
Copy link
Contributor Author

I don't see an easy way to squash since I've already created the PR. @jakevdp

@Joshuaalbert
Copy link
Contributor Author

I squashed on another PR. https://github.com/Joshuaalbert/jax/tree/update-cg

@jakevdp
Copy link
Collaborator

jakevdp commented Sep 24, 2024

You can squash the commits and force-push to the branch this PR was made from, and it will update here. Please let me know if you'd like me to walk you through it – thanks!

@Joshuaalbert
Copy link
Contributor Author

@jakevdp figured it out. There was some issue with it dropping my branch after interactive rebase, which made force push fail, but found that git rebase --continue solved it.

@jakevdp
Copy link
Collaborator

jakevdp commented Sep 25, 2024

Thanks - but it looks like there are now conflicts with respect to the main branch. Can you rebase against the updated main branch, and then make sure your branch only contains one commit on top of main? Thanks!

@Joshuaalbert
Copy link
Contributor Author

@jakevdp There you go

@jakevdp
Copy link
Collaborator

jakevdp commented Sep 30, 2024

Something went wrong here: you now have almost 100 commits on your branch. It looks like you probably merged an updated main branch, and then rebased against an older version of the main branch.

Can you please rebase against the most recent main branch?

@Joshuaalbert
Copy link
Contributor Author

Probably because I forked and then modified the main branch. I'll see what I can do with this mess.

…rator is symmetric.

* Added SciPy API extension to bicgstab, so user can specify if linear operator is symmetric

(cherry picked from commit 618e777)
@Joshuaalbert Joshuaalbert reopened this Oct 1, 2024
@Joshuaalbert
Copy link
Contributor Author

Mess resolved?

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! The last thing we'll need here is tests for the new keywords. We can probably modify existing cases in lax_scipy_sparse_test.py.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants