Skip to content

Add stricter type checking and expected matrix sizes#42

Merged
gpleiss merged 75 commits intocornellius-gp:mainfrom
corwinjoy:jaxtyping
May 2, 2023
Merged

Add stricter type checking and expected matrix sizes#42
gpleiss merged 75 commits intocornellius-gp:mainfrom
corwinjoy:jaxtyping

Conversation

@corwinjoy
Copy link
Contributor

As per the discussion: cornellius-gp/gpytorch#2180
I have upgraded the library to use Jaxtyping.
This allows a couple things:

  1. The signatures can now include explicit matrix sizes, showing what kinds of results are expected / where broadcasting may be happening.
  2. I have upgraded the operator tests to annotate each tested library with typeguard. So, only during testing, the type hints are tested for correctness to see if the functions actually behave the way the type hints say they will. This helps make the type hints much more accurate so I think its a good step forward.

Issues:

  1. The jaxtyping library requires jax. It should be not to hard to get rid of this dependency (it is on the maintainer's todo list) but that has not been done yet.
  2. It would be nice to be able to add the storage type for the operator (e.g. dense or not) but this again requires upgrading jaxtyping.
  3. I would also like to upgrade the jaxtyping math so I could have expressions like min(M, N) where M,N are array dimensions for some operators.
  4. Because typeguard does not inherit type hint signatures, I've had to copy/paste these from the base _linear_operators class to the derived operators. I think it probably makes sense to create a tool to do this automatically to help with signature maintenance.

Corwin Joy added 30 commits November 9, 2022 13:31
- Pull requests run a single type checked test (for speed reasons)
- Pushes to main and releases run all type checked tests (for
completeness)
@gpleiss gpleiss merged commit 32ba847 into cornellius-gp:main May 2, 2023
@gpleiss
Copy link
Member

gpleiss commented May 2, 2023

Finally got this in! Thanks @corwinjoy !

@corwinjoy
Copy link
Contributor Author

Awesome! Thanks so much for reviewing, polishing and merging this @gpleiss. I know it was a lot to review!

Balandat added a commit to Balandat/linear_operator that referenced this pull request May 3, 2023
gpleiss pushed a commit that referenced this pull request May 3, 2023
@Balandat
Copy link
Collaborator

Balandat commented May 3, 2023

Hmm @gpleiss I might have missed this, but it looks like jaxtyping is now a required dep even for installation, rather than just testing? In that case we need to add it to install_requires. Will put up a PR for this

Balandat added a commit to Balandat/linear_operator that referenced this pull request May 3, 2023
Balandat added a commit that referenced this pull request May 3, 2023
* Make jaxtyping and typeguard required dependencies

See #42 (comment)

ALso removes the `packaging` dep, this is transitive through `pytest`: https://github.com/pytest-dev/pytest/blob/main/setup.cfg#L48

* Make typeguard an "test" dep only (I don't think it's needed at runtime, is it?)

* Remove typeguard dep (transitive via jaxtyping)

* Add typeguard restriction to <3.0.0 back in

def _size(self) -> torch.Size:
return self._diag.shape + self._diag.shape[-1:]
return torch.Size([*self._diag.shape, self._diag.shape[-1]])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change broke the case where self._diag is a scalar tensor. Will put up a fix

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants