Merge pull request #22923 from Rifur13:faster
PiperOrigin-RevId: 660736990
jax authors committed Aug 8, 2024
2 parents 42fe45f + e6425a2 commit 9fbc51b
Showing 1 changed file with 39 additions and 50 deletions.
jax/experimental/pallas/ops/gpu/attention.py (89 changes: 39 additions & 50 deletions)
@@ -103,9 +103,7 @@ def body(start_k, carry):
) # Use m_next instead of m_curr to avoid a correction on l_curr
l_curr = s_curr.sum(axis=-1)
l_next = l_prev_corr + l_curr
l_next_rcp = 1. / l_next
s_curr = s_curr * l_next_rcp[:, None]
o_prev_corr = (l_prev_corr * l_next_rcp)[:, None] * o_prev
o_prev_corr = correction[:, None] * o_prev
v = pl.load(v_ref, (curr_k_slice, pl.dslice(block_d)))
o_curr = pl.dot(s_curr.astype(v.dtype), v)
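
The deleted lines above normalized `s_curr` and the running output by `1. / l_next` on every k-block; the replacement keeps the accumulator unscaled and applies only the `exp(m_prev - m_next)` correction, deferring the single division by `l_i` to after the loop (next hunk), as in FlashAttention-2. A minimal JAX sketch of that accumulation strategy; the function and variable names here are illustrative, not taken from this kernel:

```python
import jax.numpy as jnp

def online_softmax_attention(q, k, v, block_k=2):
  """Toy single-head attention using the unscaled-accumulator strategy."""
  n, seq_len = q.shape[0], k.shape[0]
  m = jnp.full((n,), -jnp.inf)            # running row-wise max of the logits
  l = jnp.zeros((n,))                     # running softmax denominator
  o = jnp.zeros((n, v.shape[-1]))         # running *unscaled* numerator
  for start in range(0, seq_len, block_k):
    kb, vb = k[start:start + block_k], v[start:start + block_k]
    s = q @ kb.T                          # logits for this k-block
    m_next = jnp.maximum(m, s.max(axis=-1))
    correction = jnp.exp(m - m_next)      # rescales the previous partial sums
    p = jnp.exp(s - m_next[:, None])
    l = l * correction + p.sum(axis=-1)
    o = o * correction[:, None] + p @ vb  # no division inside the loop
    m = m_next
  return o / l[:, None]                   # normalize once at the end
```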

@@ -118,10 +116,15 @@ def body(start_k, carry):
upper_bound = pl.cdiv(seq_len, block_k)
o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i))

# We keep an unscaled version of o during the scan over seq_len. Scaling it
# by the last l_i gives us the correct final output. See section 3.1.1 in the
# FlashAttention-2 paper: https://arxiv.org/pdf/2307.08691.
o /= l_i[:, None]

if residual_refs:
l_ref, m_ref = residual_refs
pl.store(l_ref, (curr_q_slice,), l_i)
pl.store(m_ref, (curr_q_slice,), m_i)
lse_ref = residual_refs[0]
lse_i = m_i + jnp.log(l_i)
pl.store(lse_ref, (curr_q_slice,), lse_i)
# Write output to dram.
o = o.astype(o_ref.dtype)
pl.store(o_ref, (curr_q_slice, pl.dslice(None)), o)
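
Collapsing the two residuals `l` (softmax denominator) and `m` (running max) into a single `lse = m + log(l)` works because the backward pass only needs normalized probabilities, and those can be recovered from `lse` alone. The identity being used is standard log-sum-exp algebra (summarized here, not quoted from the file):

```latex
\mathrm{softmax}(x)_j
  = \frac{e^{x_j - m}}{\ell}
  = e^{x_j - (m + \log \ell)}
  = e^{x_j - \mathrm{lse}(x)},
\qquad m = \max_j x_j,\quad \ell = \sum_j e^{x_j - m}.
```

This is what `p = jnp.exp(qk - lse[:, None])` computes in the backward kernel further down.
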
@@ -258,11 +261,10 @@ def _mha_forward(
sm_scale=sm_scale, causal=causal, block_q=block_q,
block_k=block_k, block_d=head_dim)
out_shape = [
jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), # out
jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len), # l
dtype=jnp.float32),
jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len), # m
dtype=jnp.float32)
jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), # out
jax.ShapeDtypeStruct(
shape=(batch_size, num_heads, seq_len), dtype=jnp.float32 # lse
),
]
in_specs = [
pl.BlockSpec(
@@ -280,7 +282,7 @@ def _mha_forward(
if segment_ids is None
else pl.BlockSpec((None, seq_len), lambda _, j, k: (j, 0))
)
out, l, m = pl.pallas_call(
out, lse = pl.pallas_call(
kernel,
grid=grid_,
in_specs=in_specs,
@@ -289,7 +291,6 @@ def _mha_forward(
(None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
),
pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)),
pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)),
],
compiler_params=dict(
triton=dict(num_warps=num_warps_, num_stages=num_stages)
@@ -299,55 +300,45 @@ def _mha_forward(
interpret=interpret,
name="mha_forward",
)(q, k, v, segment_ids)
return out, (q, k, v, segment_ids, out, l, m)
return out, (q, k, v, segment_ids, out, lse)
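
The residual tuple change above follows the usual `jax.custom_vjp` convention: the forward rule returns `(primal_out, residuals)` and the backward rule receives those residuals plus the output cotangent, so besides the inputs only `out` and the single `lse` tensor need to be saved. A self-contained toy sketch of that wiring (the dense math and names like `toy_attention` are illustrative, not this file's Pallas implementation):

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def toy_attention(q, k, v):
  return jax.nn.softmax(q @ k.T, axis=-1) @ v

def toy_attention_fwd(q, k, v):
  s = q @ k.T
  m = s.max(axis=-1)
  lse = m + jnp.log(jnp.exp(s - m[:, None]).sum(axis=-1))
  out = jnp.exp(s - lse[:, None]) @ v
  # Save only what the backward rule needs: inputs, output, and the lse residual.
  return out, (q, k, v, out, lse)

def toy_attention_bwd(res, do):
  q, k, v, out, lse = res
  p = jnp.exp(q @ k.T - lse[:, None])   # recompute probabilities from lse
  delta = jnp.sum(out * do, axis=-1)    # same role as `delta` in the kernels below
  ds = p * (do @ v.T - delta[:, None])
  return ds @ k, ds.T @ q, p.T @ do     # (dq, dk, dv)

toy_attention.defvjp(toy_attention_fwd, toy_attention_bwd)
```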


def _preprocess_backward_kernel(out_ref, dout_ref, l_ref,
new_dout_ref, delta_ref, *,
block_q: int):
def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref, *, block_q: int):
pid_m = pl.program_id(0)

off_m = pl.ds(pid_m * block_q, block_q)
# load
o = pl.load(out_ref, (off_m, slice(None))).astype(jnp.float32)
do = pl.load(dout_ref, (off_m, slice(None))).astype(jnp.float32)
denom = pl.load(l_ref, (off_m,)).astype(jnp.float32)
# compute
do = do / denom[:, None]
delta = jnp.sum(o * do, axis=1)
# write-back
pl.store(new_dout_ref, (off_m, slice(None)),
do.astype(new_dout_ref.dtype))
pl.store(delta_ref, (off_m,), delta.astype(delta_ref.dtype))

@jax.named_scope("preprocess_backward")
def _preprocess_backward(out, do, l, block_q: int,
def _preprocess_backward(out, do, lse, block_q: int,
debug: bool, interpret: bool):
batch_size, seq_len, num_heads, head_dim = out.shape
out_shape = [
jax.ShapeDtypeStruct(do.shape, do.dtype),
jax.ShapeDtypeStruct(l.shape, l.dtype),
]
do_scaled, delta = pl.pallas_call(
out_shape = jax.ShapeDtypeStruct(lse.shape, lse.dtype)
delta = pl.pallas_call(
functools.partial(_preprocess_backward_kernel, block_q=block_q),
grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads),
in_specs=[
pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)),
],
out_specs=[
pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)),
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
),
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
),
],
compiler_params=dict(
triton=dict(num_warps=4, num_stages=3)
),
out_specs=pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)),
compiler_params=dict(triton=dict(num_warps=4, num_stages=3)),
out_shape=out_shape,
debug=debug,
interpret=interpret,
name="mha_preprocess_backward")(out, do, l)
return do_scaled, delta
name="mha_preprocess_backward",
)(out, do)
return delta
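
The simplified preprocess kernel now materializes only `delta = sum(o * do, axis=1)`. The old `do / denom` rescaling is no longer needed: previously the backward worked with unnormalized `p = exp(qk - m)` and folded the missing `1/l` factor into `do`, whereas `p = exp(qk - lse)` is already normalized, so raw `do` is passed straight through. `delta` is the usual row-wise correction term of the softmax backward, which can be sanity-checked against `jax.vjp`; a standalone check with made-up shapes, not part of the file:

```python
import jax
import jax.numpy as jnp

def attn(s, v):  # logits -> output, dense reference implementation
  return jax.nn.softmax(s, axis=-1) @ v

s = jax.random.normal(jax.random.PRNGKey(0), (4, 6))
v = jax.random.normal(jax.random.PRNGKey(1), (6, 8))
do = jax.random.normal(jax.random.PRNGKey(2), (4, 8))

out, vjp = jax.vjp(attn, s, v)
ds_ref, dv_ref = vjp(do)

p = jax.nn.softmax(s, axis=-1)
delta = jnp.sum(out * do, axis=-1)      # what the preprocess kernel stores
ds = p * (do @ v.T - delta[:, None])    # dp is seeded with -delta, as in the kernel
dv = p.T @ do

assert jnp.allclose(ds, ds_ref, atol=1e-5) and jnp.allclose(dv, dv_ref, atol=1e-5)
```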


# This kernel computes dK_i, dV_i and dQ_i in parallel across the sequence
@@ -361,8 +352,7 @@ def mha_backward_kernel(
segment_ids_ref: jax.Array | None,
out_ref,
do_scaled_ref,
l_ref,
m_ref,
lse_ref,
delta_ref,
# Outputs
dq_ref,
@@ -377,7 +367,7 @@ def mha_backward_kernel(
block_k2: int,
block_d: int,
):
del out_ref, l_ref # Not needed
del out_ref # Not needed
seq_len = q_ref.shape[0]

# Scan #1: dK and dV
@@ -422,11 +412,11 @@ def inner_loop_dkdv(start_q, carry):
)
qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)

m = pl.load(m_ref, (curr_q_slice,))
lse = pl.load(lse_ref, (curr_q_slice,))
di = pl.load(delta_ref, (curr_q_slice,))
do = pl.load(do_scaled_ref, (curr_q_slice, slice(None)))

p = jnp.exp(qk - m[:, None])
p = jnp.exp(qk - lse[:, None])
dv = dv + pl.dot(p.astype(do.dtype).T, do)
dp = jnp.zeros((block_q1, block_k1), dtype=jnp.float32) - di[:, None]
dp = dp + pl.dot(do, v.T)
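
For reference, the two scans in this kernel split the backward pass: the loop in this hunk holds a k-block fixed and iterates over q-blocks to accumulate dK and dV, while the second loop (`inner_loop_dq` below) holds a q-block fixed and iterates over k-blocks to accumulate dQ. Ignoring the masking and `sm_scale` details, the quantities being accumulated are the standard attention-backward formulas (summarized here, not quoted from the file):

```latex
P = \exp\!\big(S - \mathrm{lse}\,\mathbf{1}^{\top}\big), \qquad
\Delta = \mathrm{rowsum}(O \odot dO), \\
dV = P^{\top} dO, \qquad
dS = P \odot \big(dO\,V^{\top} - \Delta\,\mathbf{1}^{\top}\big), \qquad
dK = dS^{\top} Q, \qquad
dQ = dS\,K.
```
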
@@ -461,7 +451,7 @@ def inner_loop_dkdv(start_q, carry):
if segment_ids_ref is None
else pl.load(segment_ids_ref, (curr_q_slice,))
)
m = pl.load(m_ref, (curr_q_slice,))
lse = pl.load(lse_ref, (curr_q_slice,))
do = pl.load(do_scaled_ref, (curr_q_slice, slice(None)))
di = pl.load(delta_ref, (curr_q_slice,))

@@ -488,7 +478,7 @@ def inner_loop_dq(start_k, dq):
)
qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)

p = jnp.exp(qk - m[:, None])
p = jnp.exp(qk - lse[:, None])
dp = jnp.zeros((block_q2, block_k2), dtype=jnp.float32) - di[:, None]
dp = dp + pl.dot(do, v.T)
ds = p * dp
@@ -513,7 +503,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
num_stages: int, grid: Any, interpret: bool,
debug: bool, res, do):
del num_warps, num_stages, grid
q, k, v, segment_ids, out, l, m = res
q, k, v, segment_ids, out, lse = res

if backward_pass_impl == "xla":
return jax.vjp(
@@ -527,7 +517,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
batch_size, seq_len, num_heads, head_dim = q.shape
block_q = min(block_q, seq_len)
block_k = min(block_k, seq_len)
do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret)
delta = _preprocess_backward(out, do, lse, block_q, debug, interpret)
out_shapes = [
jax.ShapeDtypeStruct(q.shape, q.dtype),
jax.ShapeDtypeStruct(k.shape, k.dtype),
@@ -552,7 +542,6 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
),
pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)),
pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)),
pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)),
]
if segment_ids is None:
in_specs.insert(3, None) # type: ignore[arg-type]
@@ -593,7 +582,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
debug=debug,
interpret=interpret,
compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=2)),
)(q, k, v, segment_ids, out, do_scaled, l, m, delta)
)(q, k, v, segment_ids, out, do, lse, delta)
else:
raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}")
return dq.astype(q.dtype), dk, dv, None
