Skip to content

Conversation

@thijs-vanweezel
Copy link

@thijs-vanweezel thijs-vanweezel commented Nov 17, 2025

What does this PR do?

As per issue #5075, this PR slightly modifies nnx.PathContains to allow for substring matching. This is important in order to match all layers which have been given a similar name (e.g., "conv1", "conv2", etc.). As suggested, the modification implements an exact parameter, which, if set to False, ensures that key is matched against each element of path. Note; since a Key instance is not guaranteed to implement .__contains__, the PathParts are first converted to string.

P.S.: Although the tests were eventually passed, it appeared that JAX had to be built from source to satisfy recently added imports (an instruction perhaps worth adding to this guide), and that pre-existing typing errors in mypy required setting RUN_MYPY=false.

Fixes #5075

Example

import jax
from flax import nnx
from jax import numpy as jnp

# Modified LeNet-5 for 36X60 images
class LeNet(nnx.Module):
    def __init__(self, key):
        super().__init__()
        self.conv1 = nnx.Conv(1, 8, (4,4), rngs=key, padding="VALID")
        self.conv2 = nnx.Conv(8, 16, (4,4), rngs=key, padding="VALID")
        self.fc1 = nnx.Linear(6*12*16, 128, rngs=key)
        self.fc2 = nnx.Linear(128, 64, rngs=key)
        self.fc3 = nnx.Linear(64, 16, rngs=key)
    
model = LeNet(nnx.Rngs(0))

# Split model into fc and conv layers
struct, fc, conv, rest = nnx.split(
    model, 
    nnx.All(nnx.PathContains("fc", exact=False), nnx.PathContains("kernel")), 
    nnx.All(nnx.PathContains("conv", exact=False), nnx.PathContains("kernel")),
    ...
)

Checklist

  • This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other checks if that's the case).
  • This change is discussed in a Github issue/discussion (please add a link).
  • The documentation and docstrings adhere to the documentation guidelines.
  • This change includes necessary high-coverage tests. (No quality testing = no merge!)

@google-cla
Copy link

google-cla bot commented Nov 17, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR @thijs-vanweezel !
Can you please add a test for the exact=False case?

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@thijs-vanweezel
Copy link
Author

@vfdev-5 good suggestion. It has been added now.

@vfdev-5
Copy link
Collaborator

vfdev-5 commented Nov 18, 2025

Please rebase your PR on the latest main branch to keep only your updates

@vfdev-5
Copy link
Collaborator

vfdev-5 commented Nov 18, 2025

The failure of the CI is related to the broken main which is now fixed (please rebase)

Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @thijs-vanweezel !

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

nnx.PathContains should support regex

2 participants