diff --git a/jax/experimental/pallas/ops/gpu/attention.py b/jax/experimental/pallas/ops/gpu/attention.py
index 647310dcaacc..1cf8349e7da2 100644
--- a/jax/experimental/pallas/ops/gpu/attention.py
+++ b/jax/experimental/pallas/ops/gpu/attention.py
@@ -103,9 +103,7 @@ def body(start_k, carry):
     )  # Use m_next instead of m_curr to avoid a correction on l_curr
     l_curr = s_curr.sum(axis=-1)
     l_next = l_prev_corr + l_curr
-    l_next_rcp = 1. / l_next
-    s_curr = s_curr * l_next_rcp[:, None]
-    o_prev_corr = (l_prev_corr * l_next_rcp)[:, None] * o_prev
+    o_prev_corr = correction[:, None] * o_prev
     v = pl.load(v_ref, (curr_k_slice, pl.dslice(block_d)))
     o_curr = pl.dot(s_curr.astype(v.dtype), v)
 
@@ -118,10 +116,15 @@ def body(start_k, carry):
     upper_bound = pl.cdiv(seq_len, block_k)
   o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i))
 
+  # We keep an unscaled version of o during the scan over seq_len. Scaling it
+  # by the last l_i gives us the correct final output. See section 3.1.1 in the
+  # FlashAttention-2 paper: https://arxiv.org/pdf/2307.08691.
+  o /= l_i[:, None]
+
   if residual_refs:
-    l_ref, m_ref = residual_refs
-    pl.store(l_ref, (curr_q_slice,), l_i)
-    pl.store(m_ref, (curr_q_slice,), m_i)
+    lse_ref = residual_refs[0]
+    lse_i = m_i + jnp.log(l_i)
+    pl.store(lse_ref, (curr_q_slice,), lse_i)
   # Write output to dram.
   o = o.astype(o_ref.dtype)
   pl.store(o_ref, (curr_q_slice, pl.dslice(None)), o)
@@ -258,11 +261,10 @@ def _mha_forward(
       sm_scale=sm_scale, causal=causal, block_q=block_q, block_k=block_k,
       block_d=head_dim)
   out_shape = [
-      jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),  # out
-      jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len),  # l
-                           dtype=jnp.float32),
-      jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len),  # m
-                           dtype=jnp.float32)
+      jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),  # out
+      jax.ShapeDtypeStruct(
+          shape=(batch_size, num_heads, seq_len), dtype=jnp.float32  # lse
+      ),
   ]
   in_specs = [
       pl.BlockSpec(
@@ -280,7 +282,7 @@ def _mha_forward(
       if segment_ids is None
       else pl.BlockSpec((None, seq_len), lambda _, j, k: (j, 0))
   )
-  out, l, m = pl.pallas_call(
+  out, lse = pl.pallas_call(
       kernel,
       grid=grid_,
       in_specs=in_specs,
@@ -289,7 +291,6 @@ def _mha_forward(
               (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
           ),
          pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)),
-          pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)),
      ],
      compiler_params=dict(
          triton=dict(num_warps=num_warps_, num_stages=num_stages)
@@ -299,55 +300,45 @@ def _mha_forward(
      interpret=interpret,
      name="mha_forward",
  )(q, k, v, segment_ids)
-  return out, (q, k, v, segment_ids, out, l, m)
+  return out, (q, k, v, segment_ids, out, lse)
 
 
-def _preprocess_backward_kernel(out_ref, dout_ref, l_ref,
-                                new_dout_ref, delta_ref, *,
-                                block_q: int):
+def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref, *, block_q: int):
   pid_m = pl.program_id(0)
   off_m = pl.ds(pid_m * block_q, block_q)
   # load
   o = pl.load(out_ref, (off_m, slice(None))).astype(jnp.float32)
   do = pl.load(dout_ref, (off_m, slice(None))).astype(jnp.float32)
-  denom = pl.load(l_ref, (off_m,)).astype(jnp.float32)
   # compute
-  do = do / denom[:, None]
   delta = jnp.sum(o * do, axis=1)
   # write-back
-  pl.store(new_dout_ref, (off_m, slice(None)),
-           do.astype(new_dout_ref.dtype))
   pl.store(delta_ref, (off_m,), delta.astype(delta_ref.dtype))
 
 
 @jax.named_scope("preprocess_backward")
-def _preprocess_backward(out, do, l, block_q: int,
+def _preprocess_backward(out, do, lse, block_q: int,
                          debug: bool, interpret: bool):
   batch_size, seq_len, num_heads, head_dim = out.shape
-  out_shape = [
-      jax.ShapeDtypeStruct(do.shape, do.dtype),
-      jax.ShapeDtypeStruct(l.shape, l.dtype),
-  ]
-  do_scaled, delta = pl.pallas_call(
+  out_shape = jax.ShapeDtypeStruct(lse.shape, lse.dtype)
+  delta = pl.pallas_call(
      functools.partial(_preprocess_backward_kernel, block_q=block_q),
      grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads),
      in_specs=[
-          pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
-          pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
-          pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)),
-      ],
-      out_specs=[
-          pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
-          pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)),
+          pl.BlockSpec(
+              (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
+          ),
+          pl.BlockSpec(
+              (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
+          ),
      ],
-      compiler_params=dict(
-          triton=dict(num_warps=4, num_stages=3)
-      ),
+      out_specs=pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)),
+      compiler_params=dict(triton=dict(num_warps=4, num_stages=3)),
      out_shape=out_shape,
      debug=debug,
      interpret=interpret,
-      name="mha_preprocess_backward")(out, do, l)
-  return do_scaled, delta
+      name="mha_preprocess_backward",
+  )(out, do)
+  return delta
 
 
 # This kernel computes dK_i, dV_i and dQ_i in parallel across the sequence
@@ -361,8 +352,7 @@ def mha_backward_kernel(
     segment_ids_ref: jax.Array | None,
     out_ref,
     do_scaled_ref,
-    l_ref,
-    m_ref,
+    lse_ref,
     delta_ref,
     # Outputs
     dq_ref,
@@ -377,7 +367,7 @@ def mha_backward_kernel(
     block_k2: int,
     block_d: int,
 ):
-  del out_ref, l_ref  # Not needed
+  del out_ref  # Not needed
   seq_len = q_ref.shape[0]
 
   # Scan #1: dK and dV
@@ -422,11 +412,11 @@ def inner_loop_dkdv(start_q, carry):
         )
       qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)
 
-    m = pl.load(m_ref, (curr_q_slice,))
+    lse = pl.load(lse_ref, (curr_q_slice,))
     di = pl.load(delta_ref, (curr_q_slice,))
     do = pl.load(do_scaled_ref, (curr_q_slice, slice(None)))
 
-    p = jnp.exp(qk - m[:, None])
+    p = jnp.exp(qk - lse[:, None])
     dv = dv + pl.dot(p.astype(do.dtype).T, do)
     dp = jnp.zeros((block_q1, block_k1), dtype=jnp.float32) - di[:, None]
     dp = dp + pl.dot(do, v.T)
@@ -461,7 +451,7 @@ def inner_loop_dkdv(start_q, carry):
       if segment_ids_ref is None
      else pl.load(segment_ids_ref, (curr_q_slice,))
  )
-  m = pl.load(m_ref, (curr_q_slice,))
+  lse = pl.load(lse_ref, (curr_q_slice,))
  do = pl.load(do_scaled_ref, (curr_q_slice, slice(None)))
  di = pl.load(delta_ref, (curr_q_slice,))
 
@@ -488,7 +478,7 @@ def inner_loop_dq(start_k, dq):
        )
      qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)
 
-    p = jnp.exp(qk - m[:, None])
+    p = jnp.exp(qk - lse[:, None])
    dp = jnp.zeros((block_q2, block_k2), dtype=jnp.float32) - di[:, None]
    dp = dp + pl.dot(do, v.T)
    ds = p * dp
@@ -513,7 +503,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
                   num_stages: int, grid: Any, interpret: bool,
                   debug: bool, res, do):
   del num_warps, num_stages, grid
-  q, k, v, segment_ids, out, l, m = res
+  q, k, v, segment_ids, out, lse = res
 
   if backward_pass_impl == "xla":
     return jax.vjp(
@@ -527,7 +517,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
     batch_size, seq_len, num_heads, head_dim = q.shape
     block_q = min(block_q, seq_len)
     block_k = min(block_k, seq_len)
-    do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret)
+    delta = _preprocess_backward(out, do, lse, block_q, debug, interpret)
     out_shapes = [
       jax.ShapeDtypeStruct(q.shape, q.dtype),
       jax.ShapeDtypeStruct(k.shape, k.dtype),
@@ -552,7 +542,6 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
         ),
        pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)),
        pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)),
-        pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)),
    ]
    if segment_ids is None:
      in_specs.insert(3, None)  # type: ignore[arg-type]
@@ -593,7 +582,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
        debug=debug,
        interpret=interpret,
        compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=2)),
-    )(q, k, v, segment_ids, out, do_scaled, l, m, delta)
+    )(q, k, v, segment_ids, out, do, lse, delta)
  else:
    raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}")
  return dq.astype(q.dtype), dk, dv, None
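
Note (illustration, not part of the patch): after this change the forward loop applies only the exp(m_prev - m_next) correction to the running output and divides by l_i once after the fori_loop, instead of renormalizing by 1 / l_next on every iteration. A minimal plain-JAX sketch of that equivalence, with made-up names, a toy two-block split, no masking and no sm_scale:

import jax
import jax.numpy as jnp

def attn_two_blocks(q, k_blocks, v_blocks):
  # Online softmax with deferred rescaling: o stays unscaled inside the loop.
  block_q, head_dim = q.shape
  o = jnp.zeros((block_q, head_dim), jnp.float32)
  m = jnp.full((block_q,), -jnp.inf, jnp.float32)
  l = jnp.zeros((block_q,), jnp.float32)
  for k, v in zip(k_blocks, v_blocks):
    qk = q @ k.T
    m_next = jnp.maximum(m, qk.max(axis=-1))
    correction = jnp.exp(m - m_next)     # rescale previously accumulated terms
    s = jnp.exp(qk - m_next[:, None])
    l = correction * l + s.sum(axis=-1)
    o = correction[:, None] * o + s @ v  # no 1 / l_next factor here
    m = m_next
  return o / l[:, None]                  # single division at the very end

key = jax.random.PRNGKey(0)
kq, kk, kv = jax.random.split(key, 3)
q = jax.random.normal(kq, (4, 16))
k = jax.random.normal(kk, (8, 16))
v = jax.random.normal(kv, (8, 16))
out = attn_two_blocks(q, jnp.split(k, 2), jnp.split(v, 2))
ref = jax.nn.softmax(q @ k.T, axis=-1) @ v
assert jnp.allclose(out, ref, atol=1e-4)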
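
The residual written by the new forward kernel is lse_i = m_i + log(l_i), so the backward kernels can recompute the attention probabilities as p = exp(qk - lse) instead of exp(qk - m) / l. A small plain-JAX check of that identity (names and shapes are illustrative only):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
qk = jax.random.normal(key, (4, 8))        # one block of (scaled) logits
m = qk.max(axis=-1)                        # row max kept by the old kernel
l = jnp.exp(qk - m[:, None]).sum(axis=-1)  # row normalizer kept by the old kernel
lse = m + jnp.log(l)                       # the single residual stored now

p_old = jnp.exp(qk - m[:, None]) / l[:, None]
p_new = jnp.exp(qk - lse[:, None])
assert jnp.allclose(p_old, p_new, atol=1e-6)
# lse matches the row-wise log-sum-exp of the logits.
assert jnp.allclose(lse, jax.nn.logsumexp(qk, axis=-1), atol=1e-6)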
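
Because the stored output is now already normalized, _preprocess_backward_kernel can compute delta directly as jnp.sum(o * do, axis=1) without first dividing do by l. The sketch below (plain JAX, made-up shapes) shows why this delta equals the row-wise sum of p * (do @ v.T), which is exactly the di term subtracted inside dp in the backward kernels:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(1)
k1, k2, k3 = jax.random.split(key, 3)
qk = jax.random.normal(k1, (4, 8))   # logits for one query block
v = jax.random.normal(k2, (8, 16))
do = jax.random.normal(k3, (4, 16))  # incoming output gradient

p = jax.nn.softmax(qk, axis=-1)      # already-normalized probabilities
o = p @ v                            # normalized output, as the forward now stores it
delta = jnp.sum(o * do, axis=1)      # what the preprocess kernel writes
assert jnp.allclose(delta, jnp.sum(p * (do @ v.T), axis=1), atol=1e-4)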