   PaddingLike,
   LaxPadding,
   PromoteDtypeFn,
+  EinsumT,
 )

 Array = jax.Array
@@ -426,6 +427,8 @@ class Einsum(Module):
       dtype. The function should accept a tuple of ``(inputs, kernel, bias)``
       and a ``dtype`` keyword argument, and return a tuple of arrays with the
       promoted dtype.
+    einsum_op: An injectable alternative to ``jnp.einsum`` used to perform the
+      computation. Should support the same signature as ``jnp.einsum``.
     rngs: rng key.
   """

@@ -441,6 +444,7 @@ def __init__(
     kernel_init: Initializer = default_kernel_init,
     bias_init: Initializer = default_bias_init,
     promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
+    einsum_op: EinsumT = jnp.einsum,
     rngs: rnglib.Rngs,
   ):
     einsum_str = einsum_str.replace(' ', '')
@@ -465,6 +469,7 @@ def __init__(
     self.kernel_init = kernel_init
     self.bias_init = bias_init
     self.promote_dtype = promote_dtype
+    self.einsum_op = einsum_op

   def __call__(
     self, inputs: Array, einsum_str: tp.Optional[str] = None
@@ -500,7 +505,7 @@ def __call__(
       dtype=self.dtype,
     )

-    y = jnp.einsum(einsum_str, inputs, kernel, precision=self.precision)
+    y = self.einsum_op(einsum_str, inputs, kernel, precision=self.precision)

     if bias is not None:
       broadcasted_bias_shape = self._infer_broadcasted_bias_shape(
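For reference, a minimal usage sketch of the injected ``einsum_op``: any callable matching ``jnp.einsum``'s signature can be passed, for example to log, shard, or quantize the contraction. The ``logging_einsum`` wrapper, the einsum string, and the shapes below are illustrative assumptions, not part of this change.

import jax.numpy as jnp
from flax import nnx

# Hypothetical drop-in for jnp.einsum: same signature, logs each
# contraction before delegating to jnp.einsum.
def logging_einsum(subscripts, *operands, **kwargs):
  print(f'einsum: {subscripts}')
  return jnp.einsum(subscripts, *operands, **kwargs)

layer = nnx.Einsum(
  'bte,ehq->bthq',           # contract inputs (b, t, e) with kernel (e, h, q)
  (8, 4, 16),                # kernel shape (e, h, q)
  einsum_op=logging_einsum,  # injected in place of the default jnp.einsum
  rngs=nnx.Rngs(0),
)

x = jnp.ones((2, 3, 8))      # (b, t, e)
y = layer(x)                 # prints: einsum: bte,ehq->bthq
assert y.shape == (2, 3, 4, 16)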