Hello!
If you try to use a Linear layer with complex-valued weights on a complex-valued input, the code raises an error:
StructureMismatchError: (tester) Mismatch while checking structures: At root: Value has the wrong dtype: expected a sub-dtype of <class 'numpy.floating'> but got dtype complex64.
penzai/penzai/nn/linear_and_affine.py, line 589 in bf38ed0:
dtype=jnp.floating,
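On line 589 of linear_and_affine.py, the expected input dtype defaults to jnp.floating. That only matches real floating-point dtypes, so a complex64 input fails the structure check.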
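Here's a quick illustration of why complex64 is rejected. It assumes the structure check is equivalent to jnp.issubdtype(dtype, jnp.floating), which is what the error message suggests. I haven't traced penzai's exact implementation.

```python
import jax.numpy as jnp

# Assumed equivalent of the check behind the error: "is the dtype a
# sub-dtype of jnp.floating?"
real_ok = jnp.issubdtype(jnp.float32, jnp.floating)        # True
complex_ok = jnp.issubdtype(jnp.complex64, jnp.floating)   # False

# Complex dtypes sit under jnp.complexfloating, a sibling of jnp.floating.
# Both share the common parent jnp.inexact.
complex_inexact = jnp.issubdtype(jnp.complex64, jnp.inexact)  # True

print(real_ok, complex_ok, complex_inexact)  # True False True
```

So relaxing the declared dtype to jnp.inexact might be enough to let complex inputs pass the check. I haven't verified whether the rest of the layer handles complex weights correctly.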
To fix this and allow complex inputs, the layer could check whether the input dtype is complex and, if so, make sure the weights are complex as well.
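A minimal sketch of that idea, written in plain JAX rather than penzai. The helpers weight_dtype_for, init_linear_weights and apply_linear are hypothetical and are not penzai's API. The sketch only shows choosing a complex weight dtype when the input is complex.

```python
import jax
import jax.numpy as jnp


def weight_dtype_for(input_dtype, default=jnp.float32):
    """Hypothetical helper: pick a weight dtype compatible with the input.

    Complex inputs get weights of the same complex dtype.
    Everything else keeps the default real dtype.
    """
    if jnp.issubdtype(input_dtype, jnp.complexfloating):
        return input_dtype
    return default


def init_linear_weights(key, in_features, out_features, input_dtype):
    """Hypothetical initializer for a plain-JAX linear map (not penzai's)."""
    dtype = weight_dtype_for(input_dtype)
    # jax.random.normal accepts complex dtypes directly.
    w = jax.random.normal(key, (in_features, out_features), dtype=dtype)
    # Multiplying by a Python float keeps the weight dtype unchanged.
    return w * (1.0 / in_features ** 0.5)


def apply_linear(w, x):
    """Hypothetical forward pass: contract the last axis of x with w."""
    return x @ w


key = jax.random.PRNGKey(0)
x = jnp.ones((2, 3), dtype=jnp.complex64) * (1 + 1j)
w = init_linear_weights(key, 3, 4, x.dtype)
y = apply_linear(w, x)
print(w.dtype, y.dtype, y.shape)
```

With a complex input, both the weights and the output stay complex. A real input would still get float32 weights as before.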
Please let me know if you need any more information!