
nn.Parameter equivalent in mlx #127

Closed
henrhie opened this issue Dec 11, 2023 · 2 comments

Comments

@henrhie
Contributor

henrhie commented Dec 11, 2023

How do you achieve the equivalent of nn.Parameter (which is used to create a trainable parameter for a module in PyTorch) in mlx?

@angeloskath
Member

Any public member of an nn.Module (i.e., any attribute whose name does not start with an underscore) that is an array is a trainable parameter.

You can have "frozen" parameters, which will not be returned by trainable_parameters(), by calling freeze(). Or you can have what are effectively constants by setting private member variables whose names start with an underscore. For instance:

import mlx.core as mx
import mlx.nn as nn

class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = mx.array(0.)
        self.b = mx.array(1.)
        self._c = mx.array(2.)

f = Foo()

print(f.trainable_parameters())
# {'a': array(0, dtype=float32), 'b': array(1, dtype=float32)}
print(f.parameters())
# {'a': array(0, dtype=float32), 'b': array(1, dtype=float32)}

f.freeze(keys=["b"])
print(f.trainable_parameters())
# {'a': array(0, dtype=float32)}
print(f.parameters())
# {'a': array(0, dtype=float32), 'b': array(1, dtype=float32)}

print(f._c)
# array(2, dtype=float32)

@henrhie
Contributor Author

henrhie commented Dec 11, 2023

Makes sense. Thanks!
