Reduce amount of modules, clean-up #1

Open
wants to merge 1 commit into
base: main

@lericson lericson commented Mar 6, 2023

Use existing primitives such as `nn.GELU`
Remove no-op modules
Update `super()` calls
Remove trailing whitespace in code
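
As a rough illustration of the first and third items, here is a minimal sketch. `MlpBlock` and its layer names are hypothetical and not taken from this repository. It assumes the activation being replaced was the exact erf-based GELU, which is what `nn.GELU()` computes by default.

```python
import math

import torch
import torch.nn as nn


class MlpBlock(nn.Module):
    """Hypothetical block, used only to illustrate the clean-up."""

    def __init__(self, dim, hidden_dim):
        # Python 3 form; replaces super(MlpBlock, self).__init__()
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()  # built-in primitive instead of a custom module
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


torch.manual_seed(0)
x = torch.randn(2, 5, 8)
out = MlpBlock(8, 16)(x)

# A hand-written exact GELU, for comparison with the built-in one.
manual_gelu = 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
builtin_gelu = nn.GELU()(x)
```

If the original custom activation used the tanh approximation instead, `nn.GELU(approximate="tanh")` would be the closer match.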


lericson commented Mar 6, 2023

I should note that this fixes a minor discrepancy between this code and the JAX reference. Their code says (with some trivial rewriting):

x = self.encoder(x)

x = x[:, 0]

if repr_dim is not None:
  x = nn.Dense(repr_dim)(x)
  x = nn.tanh(x)

... whereas this is what this repository says:

        x = self.encoder(x)
        x = self.pre_logits(x)

        # only support cls token now
        x = x[:, 0]

        return self.head(x)

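It doesn't change the output, but the reference code applies x = x[:, 0] before the projection, whereas this repository applies it afterwards. This PR moves it earlier to match. Assuming pre_logits acts on each token independently (as the Dense + tanh in the reference does), the only effect is that the projection is computed for the class token alone rather than for every token, which saves some cycles.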
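
A quick way to check the equivalence is sketched below. The `pre_logits` here is an assumed stand-in (a `Linear` followed by `Tanh`, mirroring the reference snippet), not the module from this repository, and the tensor sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch, tokens, dim, repr_dim = 2, 5, 8, 4
x = torch.randn(batch, tokens, dim)  # stand-in for the encoder output

# Assumed per-token projection, like Dense + tanh in the reference.
pre_logits = nn.Sequential(nn.Linear(dim, repr_dim), nn.Tanh())

with torch.no_grad():
    late = pre_logits(x)[:, 0]   # project every token, then keep the cls token
    early = pre_logits(x[:, 0])  # keep the cls token, then project only that

same = torch.allclose(late, early, atol=1e-6)
```

Both orderings give the same values up to floating-point noise; the early selection simply runs the projection on one token per sample instead of all of them.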
