[Do not review] Fuji inference examples #1204


Draft · wants to merge 9 commits into main
30 changes: 18 additions & 12 deletions axlearn/common/attention.py
@@ -90,6 +90,8 @@
import jax
from absl import logging
from jax import numpy as jnp
from jax._src.mesh import thread_resources
from jax.sharding import NamedSharding

from axlearn.common import param_init
from axlearn.common.attention_bias import (
@@ -1645,6 +1647,18 @@ def _forward_for_mode(
query_positions = query_positions + time_step[:, None] # [batch, steps]
q_proj, k_proj, v_proj = self.i_proj(query, query_positions=query_positions, **kv_kwargs)

q_proj = self._remat_name(q_proj, "q_proj")
k_proj = self._remat_name(k_proj, "k_proj")
v_proj = self._remat_name(v_proj, "v_proj")

# Scale query and key.
q_proj, k_proj = self._scale_qk(
q_proj=q_proj,
k_proj=k_proj,
query_positions=query_positions,
key_positions=None, # ScaleKey doesn't use positions.
)

if mode == ForwardMode.FORWARD:
new_cached_states = dict()
key_positions = jnp.arange(k_proj.shape[1])[None]
@@ -1685,18 +1699,6 @@
else:
raise ValueError(f"Unrecognized mode {mode}.")

q_proj = self._remat_name(q_proj, "q_proj")
k_proj = self._remat_name(k_proj, "k_proj")
v_proj = self._remat_name(v_proj, "v_proj")

# Scale query and key.
q_proj, k_proj = self._scale_qk(
q_proj=q_proj,
k_proj=k_proj,
query_positions=query_positions,
key_positions=key_positions,
)

self.vlog(3, "atten.q_proj=%s", q_proj.sum())
self.vlog(3, "atten.k_proj=%s", k_proj.sum())
self.vlog(3, "atten.v_proj=%s", v_proj.sum())
@@ -2623,6 +2625,10 @@ def attention_thunk(target: Tensor) -> tuple[Optional[NestedTensor], Tensor]:
return dict(attention=atten_state), atten_output

if cfg.structure == "prenorm":
target = jax.lax.with_sharding_constraint(
target,
NamedSharding(thread_resources.env.physical_mesh, PartitionSpec(None, None, None)),
)
skip_input = target # pre-norm: where normalization happens within the residual part.
norm_target = self.norm(target)
atten_state, atten_output = attention_thunk(norm_target)
6 changes: 5 additions & 1 deletion axlearn/experiments/text/gpt/fuji.py
@@ -243,6 +243,7 @@ def get_trainer_kwargs(
vocab_size: int,
version: Version,
flash_attention: bool = False,
use_stacked: bool = False,
) -> dict[str, Any]:
"""Construct default trainer kwargs given a model size."""
tokens_per_batch = TOKENS_PER_BATCH[version]
@@ -816,7 +817,10 @@ def get_trainer_kwargs(
if version == Version.V3_TIKTOKEN: # tiktoken tokenizer
model_kwargs["pad_token_id"] = 128004
model_kwargs["eos_token_id"] = 128001
trainer_kwargs["model_cfg"] = model_config(**model_kwargs)
trainer_kwargs["model_cfg"] = model_config(
**model_kwargs,
stack_cfg=None if not use_stacked else StackedTransformerLayer.default_config(),
)
trainer_kwargs["learner_cfg"] = adamw_decoupled_learner_config(
max_step=trainer_kwargs["max_step"],
**trainer_kwargs.pop("learner_kwargs"),
221 changes: 221 additions & 0 deletions axlearn/experiments/text/gpt/fuji_inference.py
@@ -0,0 +1,221 @@
# Copyright © 2023 Apple Inc.

"""An example to run offline inference for fuji.

Some optimizations are ported from the not-yet open sourced inference engine so the forward pass
performance is comparable to what we have internally.

Some known issues:
1. KVCache update isn't efficient and potentially requires copying the whole KV cache. This could
be solved by either changing the KV cache update method or using PagedKVCache. An illustrative
sketch of the dense update follows this docstring.
"""

import functools
import zlib
from typing import Optional, Sequence

import jax
import jax.numpy as jnp

from axlearn.common.attention import (
BaseLayer,
FusedGroupedQKVLinear,
GroupedQKVLinear,
Module,
TransformerFeedForwardLayer,
TransformerLayer,
set_attention_partition_specs,
)
from axlearn.common.config import ConfigBase
from axlearn.common.inference import DataPartitionType, InferenceRunner
from axlearn.common.state_builder import Builder, TensorStoreStateStorageBuilder
from axlearn.common.utils import Tensor
from axlearn.experiments.text.gpt.common import MESH_AXIS_NAMES, mesh_shape_from_axes
from axlearn.experiments.text.gpt.fuji import Version, get_trainer_kwargs


def set_inference_partition_spec(cfg: ConfigBase) -> ConfigBase:
"""Set inference-friendly model weight and activation sharding."""
if isinstance(cfg, TransformerLayer.Config):
raise ValueError("TransformerLayer cannot be the root of the input model config.")

batch_axis_names = ("data", "expert", "fsdp")
fsdp_axis_names = "fsdp"
tp_axis_names = "model"
seq_axis_names = "seq"

def enter_fn(_, layer_cfg, default_kv):
if isinstance(layer_cfg, TransformerLayer.Config):
set_attention_partition_specs(
layer_cfg.self_attention.attention,
fsdp_axis_names=fsdp_axis_names,
tp_axis_names=tp_axis_names,
)
if layer_cfg.cross_attention is not None:
set_attention_partition_specs(
layer_cfg.cross_attention.attention,
fsdp_axis_names=fsdp_axis_names,
tp_axis_names=tp_axis_names,
)
if isinstance(layer_cfg.feed_forward, TransformerFeedForwardLayer.Config):
cfg = layer_cfg.feed_forward
# Shard weights.
cfg.linear1.param_partition_spec = (fsdp_axis_names, tp_axis_names)
cfg.linear2.param_partition_spec = (tp_axis_names, fsdp_axis_names)
cfg.linear1.output_partition_spec = (
batch_axis_names,
seq_axis_names,
tp_axis_names,
)
# Do not shard the output. This avoids a reduce-scatter + all-gather in favor of a single
# all-reduce, which has lower latency for inference (see the sketch after this function).
cfg.linear2.output_partition_spec = (batch_axis_names, seq_axis_names, None)
return default_kv

cfg.visit(visit_fn=lambda k, v: None, enter_fn=enter_fn)
return cfg
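

# A minimal sketch of the collective trade-off noted above (the mesh, shapes, and helper
# name are assumptions for illustration, not part of the model config): with the matmul
# contraction dim sharded over "model", constraining the output to be replicated lowers to
# a single all-reduce of partial sums, whereas a "model"-sharded output would lower to a
# reduce-scatter followed by a later all-gather.
def _single_allreduce_sketch():
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    mesh = Mesh(jax.devices(), axis_names=("model",))
    # Activations with the hidden (contraction) dim sharded over "model".
    x = jax.device_put(jnp.ones((8, 1024)), NamedSharding(mesh, PartitionSpec(None, "model")))
    # A linear2-style weight with its input dim sharded over "model".
    w = jax.device_put(jnp.ones((1024, 4096)), NamedSharding(mesh, PartitionSpec("model", None)))

    @jax.jit
    def matmul_with_replicated_output(x, w):
        # Replicated output => XLA emits one all-reduce over the per-shard partial products.
        return jax.lax.with_sharding_constraint(
            x @ w, NamedSharding(mesh, PartitionSpec(None, None))
        )

    return matmul_with_replicated_output(x, w)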


class DummyBuilder(TensorStoreStateStorageBuilder):
"""A dummy state builder that returns random weights."""

def input_state_type(self) -> Builder.StateType:
return Builder.StateType.TENSOR_SPECS

def __call__(self, state: Builder.State) -> Builder.State:
cfg = self.config
seed = zlib.adler32(cfg.dir.encode("utf-8"))

key = jax.random.PRNGKey(seed)
out_shardings = jax.tree.map(lambda spec: spec.sharding, state.trainer_state)

@functools.partial(jax.jit, out_shardings=out_shardings)
def jit_init():
return jax.tree.map(
lambda spec: jax.random.normal(key=key, shape=spec.shape, dtype=spec.dtype)
if spec.dtype in [jnp.bfloat16, jnp.float32]
else jnp.zeros(shape=spec.shape, dtype=spec.dtype),
state.trainer_state,
)

return Builder.State(
step=0,
trainer_state=jit_init(),
built_keys=set(),
)


def replace_layer_config_recursively(
cfg: ConfigBase,
*,
target_cls: type[Module],
source_config: BaseLayer.Config,
exclude_keys: Optional[Sequence[str]] = None,
) -> ConfigBase:
"""Replaces the target_cls's config with the source_config.

This function is useful when one wants to replace a specific layer in the model, e.g., replacing
the MultiheadAttention layer with GroupedQueryAttention. Note that the target layer must not be
the root layer of the given model config.

Example usage -- Replacing MultiheadAttention by GroupedQueryAttention:
model_cfg = ... # The original model config.
replace_layer_config_recursively(
model_cfg,
target_cls=MultiheadAttention,
source_config=GroupedQueryAttention.default_config().set(num_kv_heads=8),
exclude_keys=["num_kv_heads"],
)

Args:
cfg: A ConfigBase, usually a top-level model config or a trainer config.
target_cls: The target layer class (a Module subclass) whose config will be replaced.
source_config: A new BaseLayer config to be put into the model.
exclude_keys: A sequence of strings naming fields that should not be copied from the target
config into the source config. By default, only klass is excluded.

Returns:
A ConfigBase with the modified configs. The input config is also modified in place, so the
return value may be ignored.

Raises:
ValueError: If the target layer is the root of the input cfg.
"""
if isinstance(cfg, target_cls.Config):
raise ValueError("The target cls cannot be the root of the input model config.")

exclude_kwargs = set(["klass"])
if exclude_keys is not None:
exclude_kwargs.update(exclude_keys)

def enter_fn(_, child, default_kv):
if isinstance(child, ConfigBase):
for key, value in child.items():
if isinstance(value, target_cls.Config):
new_cfg = source_config.set(
**{k: v for k, v in value.items() if k not in exclude_kwargs},
)
setattr(child, key, new_cfg)
return default_kv

cfg.visit(visit_fn=lambda k, v: None, enter_fn=enter_fn)
return cfg


# Stop the generation early for profiling purposes.
class LengthStoppingCondition:
def __init__(self, length: int):
self._length = length

def __call__(self, *, index: Tensor, sequences: Tensor, out_of_prompt: Tensor) -> Tensor:
return jnp.broadcast_to((index >= self._length)[:, None], out_of_prompt.shape)


# StackedTransformer is faster when doing inference.
model_cfg = get_trainer_kwargs(
"70B", vocab_size=32000, version=Version.V3, flash_attention=True, use_stacked=True
)["model_cfg"]
# Grouped QKV linear has better sharding support.
model_cfg = replace_layer_config_recursively(
model_cfg,
target_cls=FusedGroupedQKVLinear,
source_config=GroupedQKVLinear.default_config(),
)
model_cfg = set_inference_partition_spec(model_cfg)
model_cfg.decoder.emb.token_emb.param_partition_spec = ("model", ("expert", "fsdp", "seq"))


inference_runner_cfg = InferenceRunner.default_config().set(
mesh_shape=mesh_shape_from_axes(model=8),
mesh_axis_names=MESH_AXIS_NAMES,
model=model_cfg,
inference_dtype=jnp.bfloat16,
input_batch_partition_spec=DataPartitionType.REPLICATED,
init_state_builder=DummyBuilder.default_config().set(dir="dummy"),
name="runner",
)
print(inference_runner_cfg)

inference_runner = inference_runner_cfg.instantiate(parent=None)
model = model_cfg.set(name="prefill_model").instantiate(parent=None)
prng_key = jax.random.PRNGKey(0)
stopping_cond = LengthStoppingCondition(3)
input_tokens = jnp.zeros((32, 4096), dtype=jnp.int32)


def input_iter():
# Warm up.
yield {"prefix": input_tokens}
jax.profiler.start_trace("/tmp/gpt_test/summaries")
yield {"prefix": input_tokens}


for r in inference_runner.run(
input_batches=input_iter(),
method="sample_decode",
prng_key=prng_key,
stop_decoding_condition=stopping_cond,
):
jax.block_until_ready(r)

jax.profiler.stop_trace()