Open
Description
code from @sharadmv
def chunked_einsum(x_split, idx_lo, idx_hi, w):
    """Multiply a two-part input by a two-part row block gathered from ``w``.

    Each operand is assembled by zero-padding its two parts on opposite
    sides and adding the padded arrays, rather than by concatenating them.

    NOTE(review): ``chunk_size`` is not a parameter; it is read from the
    enclosing scope and must be defined there.

    Args:
        x_split: Pair ``(x_lo, x_hi)`` of 2-D arrays. Presumably each has
            ``chunk_size`` columns, so that the padded sum is their
            side-by-side layout along axis 1 -- confirm against the caller.
        idx_lo: Index selecting the "lo" slice of ``w``; the slice starts
            at row ``2 * idx_lo * chunk_size``.
        idx_hi: Index selecting the "hi" slice of ``w``; the slice starts
            at row ``(2 * idx_hi + 1) * chunk_size``.
        w: 2-D array sliced along axis 0 in pieces of ``chunk_size`` rows.

    Returns:
        The matrix product of the assembled input with the assembled
        ``2 * chunk_size``-row block of ``w``.
    """
    lo_part, hi_part = x_split

    # Row offsets into ``w``: "lo" slices start at even multiples of
    # ``chunk_size``, "hi" slices at odd multiples.
    lo_start = 2 * idx_lo * chunk_size
    hi_start = (2 * idx_hi + 1) * chunk_size

    # Stack the two weight slices: the lo slice fills the top ``chunk_size``
    # rows, the hi slice the bottom ``chunk_size`` rows; the zero padding of
    # one lines up with the data of the other.
    w_top = jnp.pad(
        jax.lax.dynamic_slice_in_dim(w, lo_start, chunk_size, 0),
        [(0, chunk_size), (0, 0)],
    )
    w_bottom = jnp.pad(
        jax.lax.dynamic_slice_in_dim(w, hi_start, chunk_size, 0),
        [(chunk_size, 0), (0, 0)],
    )
    stacked_w = w_top + w_bottom

    # Same trick along axis 1 for the input: lo on the left, hi on the right.
    x_left = jnp.pad(lo_part, [(0, 0), (0, chunk_size)])
    x_right = jnp.pad(hi_part, [(0, 0), (chunk_size, 0)])
    joined_x = x_left + x_right

    return joined_x @ stacked_w