
in shard_map docs, bidirectional collective matmul, use pad-and-add pattern before dot #20243

Open
@mattjj

Description


In the shard_map docs' bidirectional collective matmul example, use the pad-and-add pattern before the dot: pad the lo and hi chunks (and the matching weight row blocks) into one doubled-size operand and add them, so each step issues a single larger matmul instead of two small ones. Code from @sharadmv:

  # Assumes `import jax`, `import jax.numpy as jnp`, and a `chunk_size`
  # defined in the enclosing scope.
  def chunked_einsum(x_split, idx_lo, idx_hi, w):
    x_lo, x_hi = x_split
    # Embed the lo chunk in the first half of the columns and the hi chunk in
    # the second half; the zero padding keeps them disjoint under addition.
    x_lo = jnp.pad(x_lo, [(0, 0), (0, chunk_size)])
    x_hi = jnp.pad(x_hi, [(0, 0), (chunk_size, 0)])
    x = x_lo + x_hi
    # Do the same for the matching weight row blocks: lo chunks sit at even
    # chunk offsets of w, hi chunks at odd offsets.
    w_block_lo = jnp.pad(
        jax.lax.dynamic_slice_in_dim(w, 2 * idx_lo * chunk_size, chunk_size, 0),
        [(0, chunk_size), (0, 0)],
    )
    w_block_hi = jnp.pad(
        jax.lax.dynamic_slice_in_dim(
            w, (2 * idx_hi + 1) * chunk_size, chunk_size, 0
        ),
        [(chunk_size, 0), (0, 0)],
    )
    w_block = w_block_lo + w_block_hi
    # A single (batch, 2*chunk_size) @ (2*chunk_size, n_out) dot; by
    # construction it equals x_lo @ w_rows_lo + x_hi @ w_rows_hi.
    return x @ w_block
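
For reference, here is a minimal sketch of how chunked_einsum could be driven from a bidirectional ring all-gather matmul under shard_map. It assumes activations sharded along the contracting dimension, the weight sharded along its output dimension, and the chunked_einsum above in scope (closing over a module-level chunk_size); the mesh axis name 'i', the sizes, and the bidirectional_matmul wrapper are illustrative, not taken from the docs.

  from functools import partial

  import numpy as np
  import jax
  import jax.numpy as jnp
  from jax.sharding import Mesh, PartitionSpec as P
  from jax.experimental.shard_map import shard_map

  # Illustrative sizes; everything must divide evenly across devices.
  num_devices = jax.device_count()
  batch, d_model, d_out = 8, 16 * num_devices, 4 * num_devices
  chunk_size = d_model // (2 * num_devices)  # chunked_einsum closes over this

  mesh = Mesh(np.array(jax.devices()), ('i',))

  @jax.jit
  @partial(shard_map, mesh=mesh,
           in_specs=(P(None, 'i'), P(None, 'i')), out_specs=P(None, 'i'))
  def bidirectional_matmul(x_block, w_block):
    # x_block: (batch, 2 * chunk_size), this device's slice of the contracting dim.
    # w_block: (d_model, d_out // num_devices), this device's output columns.
    n = jax.lax.psum(1, 'i')
    my_idx = jax.lax.axis_index('i')
    x_lo, x_hi = jnp.split(x_block, 2, axis=1)
    acc = jnp.zeros((x_block.shape[0], w_block.shape[1]), x_block.dtype)
    for t in range(n):
      if t > 0:
        # Send lo chunks forward and hi chunks backward around the ring.
        x_lo = jax.lax.ppermute(x_lo, 'i', [(j, (j + 1) % n) for j in range(n)])
        x_hi = jax.lax.ppermute(x_hi, 'i', [(j, (j - 1) % n) for j in range(n)])
      idx_lo = (my_idx - t) % n  # device the current lo chunk originated on
      idx_hi = (my_idx + t) % n  # device the current hi chunk originated on
      acc = acc + chunked_einsum((x_lo, x_hi), idx_lo, idx_hi, w_block)
    return acc

  x = jax.random.normal(jax.random.PRNGKey(0), (batch, d_model))
  w = jax.random.normal(jax.random.PRNGKey(1), (d_model, d_out))
  print(jnp.allclose(bidirectional_matmul(x, w), x @ w, atol=1e-3))  # True

The lo chunks travel forward around the ring and the hi chunks travel backward, so each device sees every feature chunk exactly once, and each step issues one fused dot via the pad-and-add trick rather than two half-sized ones.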
