
Commit 6310690

Author: Flax Authors
Merge pull request google#4705 from IvyZX:einsum
PiperOrigin-RevId: 747923709
2 parents 32ad6c5 + 000ee89

3 files changed, +16 -1 lines changed

flax/nnx/nn/linear.py (+6 -1)
@@ -36,6 +36,7 @@
   PaddingLike,
   LaxPadding,
   PromoteDtypeFn,
+  EinsumT,
 )
 
 Array = jax.Array
@@ -426,6 +427,8 @@ class Einsum(Module):
       dtype. The function should accept a tuple of ``(inputs, kernel, bias)``
       and a ``dtype`` keyword argument, and return a tuple of arrays with the
       promoted dtype.
+    einsum_op: an injectable alternative to `jnp.einsum` used to perform the
+      computation. Should support the same signature as `jnp.einsum`.
     rngs: rng key.
   """
 
@@ -441,6 +444,7 @@ def __init__(
     kernel_init: Initializer = default_kernel_init,
     bias_init: Initializer = default_bias_init,
     promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
+    einsum_op: EinsumT = jnp.einsum,
     rngs: rnglib.Rngs,
   ):
     einsum_str = einsum_str.replace(' ', '')

@@ -465,6 +469,7 @@ def __init__(
     self.kernel_init = kernel_init
     self.bias_init = bias_init
     self.promote_dtype = promote_dtype
+    self.einsum_op = einsum_op
 
   def __call__(
     self, inputs: Array, einsum_str: tp.Optional[str] = None

@@ -500,7 +505,7 @@ def __call__(
       dtype=self.dtype,
     )
 
-    y = jnp.einsum(einsum_str, inputs, kernel, precision=self.precision)
+    y = self.einsum_op(einsum_str, inputs, kernel, precision=self.precision)
 
     if bias is not None:
       broadcasted_bias_shape = self._infer_broadcasted_bias_shape(
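
With the new `einsum_op` hook, any callable matching `jnp.einsum`'s signature can be injected into `nnx.Einsum`, for example to instrument or customize the contraction. A minimal sketch of how this might be used (the `logging_einsum` wrapper is illustrative, not part of the commit):

  import jax.numpy as jnp
  from flax import nnx

  def logging_einsum(subscripts, *operands, **kwargs):
    # Same signature as jnp.einsum; nnx.Einsum also forwards a
    # `precision` keyword, which **kwargs passes through.
    print(f'einsum: {subscripts}')
    return jnp.einsum(subscripts, *operands, **kwargs)

  layer = nnx.Einsum('ab,bc->ac', (3, 4), einsum_op=logging_einsum,
                     rngs=nnx.Rngs(0))
  y = layer(jnp.ones((2, 3)))  # prints 'einsum: ab,bc->ac'; y.shape == (2, 4)

This mirrors the `dot_general` injection point already exposed by layers such as `nnx.Linear` (see `DotGeneralT` in `flax/typing.py` below).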

flax/typing.py (+1)
@@ -68,6 +68,7 @@ def is_key_like(x: Any) -> TypeGuard[Key]:
 ]
 DotGeneralT = Callable[..., Array]
 ConvGeneralDilatedT = Callable[..., Array]
+EinsumT = Callable[..., Array]
 
 PaddingLike = Union[str, int, Sequence[Union[int, tuple[int, int]]]]
 LaxPadding = Union[str, Sequence[tuple[int, int]]]
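
Like `DotGeneralT` and `ConvGeneralDilatedT` above it, `EinsumT` is a loose alias: any callable that follows `jnp.einsum`'s calling convention and returns an `Array` satisfies it. A quick sketch (the `scaled_einsum` helper is hypothetical):

  import jax
  import jax.numpy as jnp
  from flax.typing import EinsumT

  def scaled_einsum(subscripts, *operands, **kwargs) -> jax.Array:
    # Follows jnp.einsum's calling convention, rescaling the result.
    return 0.5 * jnp.einsum(subscripts, *operands, **kwargs)

  op: EinsumT = scaled_einsum  # Callable[..., Array] accepts any arguments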

tests/nnx/nn/linear_test.py (+9)
@@ -154,6 +154,15 @@ def test_nnx_einsum_equivalence(
     assert isinstance(out, jax.Array)
     np.testing.assert_array_equal(out, out_nnx)
 
+  def test_einsum_op(self):
+    def custom_einsum(*args, **kwargs):
+      out = jnp.einsum(*args, **kwargs)
+      return out.reshape((1, *out.shape))
+    model = nnx.Einsum('ab,bc->ac', (3, 4), einsum_op=custom_einsum,
+                       rngs=nnx.Rngs(42))
+    y = model(jnp.ones((2, 3)))
+    assert y.shape == (1, 2, 4)
+
 
 if __name__ == '__main__':
   absltest.main()
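
The custom op in this test prepends a leading axis of size 1, so the asserted shape (1, 2, 4) confirms that `nnx.Einsum` routed the contraction through `custom_einsum` rather than `jnp.einsum`. Since the file guards with `if __name__ == '__main__'`, it can presumably be run directly, e.g. `python tests/nnx/nn/linear_test.py`.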
