Trainer working end to end
patrick-toulme committed Mar 20, 2024
1 parent b8d05c4 commit ccbe0a6
Showing 9 changed files with 99 additions and 36 deletions.
6 changes: 4 additions & 2 deletions axlearn/common/attention.py
@@ -2968,8 +2968,10 @@ def set_ffn_partition_specs(ff_layer: TransformerFeedForwardLayer.Config):
ff_layer.linear1.param_partition_spec = (fsdp_axis_names, tp_axis_names)
ff_layer.linear2.param_partition_spec = (tp_axis_names, fsdp_axis_names)
# Encourage the right activation sharding.
ff_layer.linear1.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names)
ff_layer.linear2.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names)
#ff_layer.linear1.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names)
#ff_layer.linear2.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names)
ff_layer.linear1.output_partition_spec = (batch_axis_names, None, tp_axis_names)
ff_layer.linear2.output_partition_spec = (batch_axis_names, None, None)

if not isinstance(cfg, Sequence):
cfg = [cfg]
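Note: the new specs stop sharding the sequence dimension of the FFN activations. A minimal sketch of how such a spec is typically applied as a sharding constraint, assuming a hypothetical 4x8 data/model mesh (e.g. 32 host devices via --xla_force_host_platform_device_count=32, as set in train_setup.sh below); this is an illustration, not axlearn's internal wiring:

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    mesh = Mesh(np.array(jax.devices()).reshape(4, 8), axis_names=("data", "model"))

    @jax.jit
    def constrain_linear1_output(x):
        # (batch_axis_names, None, tp_axis_names): batch sharded over "data",
        # sequence replicated, hidden sharded over "model", mirroring the new spec above.
        return jax.lax.with_sharding_constraint(
            x, NamedSharding(mesh, PartitionSpec("data", None, "model"))
        )

    x = jnp.zeros((4, 128, 64))
    print(constrain_linear1_output(x).sharding)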
1 change: 0 additions & 1 deletion axlearn/common/decoder.py
@@ -732,5 +732,4 @@ def forward(self, x: Tensor) -> Tensor:
Returns:
A float Tensor of shape [batch_size, seq_len, vocab_size].
"""
x = x.astype(self.parameters["weight"].dtype) # ptoulme - this is to fix CC op coalescing fp32 bug
return jnp.einsum("bsh,vh->bsv", x, self.parameters["weight"])
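For reference, a standalone sketch of the LmHead forward this hunk edits, with the deleted cast kept as a comment; a plain array stands in for self.parameters["weight"]:

    import jax.numpy as jnp

    def lm_head_forward(x, weight):
        # x: [batch, seq, hidden]; weight: [vocab, hidden].
        # The deleted line cast activations to the weight dtype first:
        # x = x.astype(weight.dtype)
        return jnp.einsum("bsh,vh->bsv", x, weight)

    logits = lm_head_forward(jnp.zeros((2, 8, 16)), jnp.zeros((100, 16)))  # [2, 8, 100]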
2 changes: 1 addition & 1 deletion axlearn/common/neuron_tests/attention_test.py
@@ -258,7 +258,7 @@ def test_model(self):
stacked_layer = StackedTransformerLayer.default_config()
decoder_cfg = llama_decoder_config(
stack_cfg=stacked_layer,
num_layers=1,
num_layers=4,
hidden_dim=model_dim,
num_heads=num_heads,
vocab_size=vocab_size,
2 changes: 1 addition & 1 deletion axlearn/common/neuron_tests/run.slurm
@@ -4,4 +4,4 @@
#SBATCH --exclusive
#SBATCH --nodes=1

srun --kill-on-bad-exit=1 run_test.sh
srun --kill-on-bad-exit=1 run_trainer.sh
10 changes: 10 additions & 0 deletions axlearn/common/neuron_tests/run_trainer.sh
@@ -0,0 +1,10 @@
#! /bin/bash
source /shared_new/ptoulme/axlearn/venv/bin/activate
source ./setup.sh
source ./train_setup.sh

OUTPUT_DIR=./c4_test_dump
DATA_DIR=gs://axlearn-public/tensorflow_datasets
python3 -m axlearn.common.launch_trainer_main \
--module=text.gpt.c4_trainer --config=fuji-test \
--trainer_dir=$OUTPUT_DIR --data_dir=$DATA_DIR --jax_backend=neuron
9 changes: 5 additions & 4 deletions axlearn/common/neuron_tests/train_setup.sh
@@ -30,7 +30,7 @@ echo $DISTRIBUTED_ARGS
#export TF_CPP_MIN_LOG_LEVEL=0 # Enable SPMD verbose logging - 0 means most verbose
#export TF_CPP_MAX_VLOG_LEVEL=2 # Needs above flag for logging but goes in reverse. 0 means no log
#export TF_CPP_MIN_LOG_LEVEL=0
#export TF_CPP_MAX_VLOG_LEVEL=5
#export TF_CPP_MAX_VLOG_LEVEL=1

export PJRT_DEVICE="NEURON"
export NEURON_RT_NUM_CORES=32
@@ -46,8 +46,9 @@ export NEURON_INTERNAL_USE_VANILLA_TORCH_XLA=1
export NEURON_USE_VANILLA_TORCH_XLA=1
export NEURON_TRANSFER_WITH_STATIC_RING_OPS=""
export NEURON_TRANSFER_ALL_PARAMETERS_WITH_STATIC_RING=0
export XLA_FLAGS="--xla_force_host_platform_device_count=32 --xla_dump_hlo_as_text --xla_dump_hlo_as_proto --xla_dump_to=./jax_dump_new --xla_dump_hlo_pass_re='.*'"

export XLA_FLAGS="--xla_force_host_platform_device_count=32 --xla_dump_hlo_as_text --xla_dump_hlo_as_proto --xla_dump_to=./trainer_dump --xla_dump_hlo_pass_re='.*'"
# To run on CPU set below
#export JAX_PLATFORMS='cpu'
#Snapshotting
#export XLA_FLAGS=" --xla_dump_hlo_snapshots --xla_dump_to=/shared/ptoulme/GSPMD/NeuronGSPMDTests/src/NeuronGSPMDTests/transformers/snapshots"
export XLA_IR_DEBUG=1
@@ -59,7 +60,7 @@ export XLA_HLO_DEBUG=1
export XLA_USE_BF16=1
#export NEURON_CC_FLAGS="--dump=./compiler_dump --framework=XLA --model-type=transformer --distribution-strategy=llm-training -O1 --no-internal-hlo-remat"
export NEURON_CC_FLAGS="--dump=./compiler_dump --framework=XLA --model-type transformer --internal-io-to-internal-dmacopy-insertion --enable-mixed-precision-accumulation -O1"

#export NEURON_CC_FLAGS="--dump=./compiler_dump --framework=XLA --model-type transformer --no-internal-hlo-remat --distribution-strategy=llm-training --enable-mixed-precision-accumulation -O1"
export NEURON_RT_STOCHASTIC_ROUNDING_EN=1
export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=5

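A hypothetical smoke-check (not part of the commit) for the environment these variables set up, assuming they have been sourced first:

    import os
    import jax

    print("backend:", jax.default_backend())          # expected "neuron" when launched with --jax_backend=neuron
    print("devices:", jax.device_count())             # should line up with NEURON_RT_NUM_CORES=32 per node
    print("XLA_FLAGS:", os.environ.get("XLA_FLAGS"))  # HLO text/proto dumps go to ./trainer_dump per the flag above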
73 changes: 58 additions & 15 deletions axlearn/common/trainer.py
@@ -42,7 +42,19 @@
prune_tree,
thread_stack_traces,
)

import jax
import optax
from jax import numpy as jnp
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
from axlearn.common.base_layer import ParameterSpec
from axlearn.common.learner import Learner
from axlearn.common.module import functional as F, InvocationContext
from axlearn.common.optimizer_base import NestedOptParam, OptParam
from axlearn.common.optimizers import AddDecayedWeightsState
from typing import Any,Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
from axlearn.common.utils import Tensor, VDict, NestedTensor, TensorSpec
import os

def _prune_empty(in_tree: NestedTensor) -> NestedTensor:
"""Returns a shallow copy of the input tree with empty subtrees pruned.
@@ -578,39 +590,46 @@ def _init_state_cpu(prng_key: Tensor, prebuilt_model_state: NestedTensor):

cpu_device = jax.devices("cpu")[0]
with jax.default_device(cpu_device):
logging.info("prebuilt_model_state: %s", utils.shapes(prebuilt_model_state))
model_params = self.model.initialize_parameters_recursively(
init_key,
prebuilt=prebuilt_model_state,
)
learner_params = self.learner.init(self._opt_params(model_params))
logging.info("CPU initialization completed.")
return prng_key, model_params, learner_params

return prng_key, model_params

def _move_state_to_neuron(prng_key: Tensor, model_params):
def _move_state_to_neuron(prng_key: Tensor, model_params, learner_params):
model_params = jax.device_put(model_params)
learner_params = jax.device_put(learner_params)
self.vlog(
1, "tree_structure(model_params)=%s", jax.tree_util.tree_structure(model_params)
)
learner_params = self.learner.init(self._opt_params(model_params))
return TrainerState(
prng_key=prng_key,
model=model_params,
learner=learner_params,
)

logging.info("prebuilt_model_state_partition_spec: %s", prebuilt_model_state_partition_spec)
logging.info("trainer_state_partition_specs: %s", self._trainer_state_partition_specs)
init_computation = pjit(
model_specs = jax.tree_util.tree_map(
lambda value: create_named_sharding(value, self.mesh()) if isinstance(value, PartitionSpec) else None,
self._trainer_state_partition_specs[1],
)
learner_specs = jax.tree_util.tree_map(
lambda value: create_named_sharding_optimizer(value, self.mesh()) if isinstance(value, TensorSpec) else None,
self._learner_state_partition_specs,
)
init_computation = jax.jit(
#_init_state,
_move_state_to_neuron,
in_shardings=(None, prebuilt_model_state_partition_spec),
out_shardings=self._trainer_state_partition_specs,
in_shardings=(None, model_specs, learner_specs),
)
self._step_log("Initializing trainer state.")
cpu_device = jax.devices("cpu")[0]
with jax.default_device(cpu_device):
prng_key, model_params, learner_params = _init_state_cpu(prng_key, prebuilt_model_state)
with self.mesh():
#self._trainer_state = init_computation(prng_key, prebuilt_model_state)
prng_key, model_params = _init_state_cpu(prng_key, prebuilt_model_state)
self._trainer_state = init_computation(prng_key, model_params)
self._trainer_state = init_computation(prng_key, model_params, learner_params)
logging.info("Transfer to device completed.")

def _log_trainer_state_stats(self):
total_num_params = count_model_params(self._trainer_state.model)
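The new flow above initializes model and learner state on the host CPU and only then moves it onto the device mesh. A stripped-down sketch of that pattern with a toy pytree (the shapes, names, and 4x8 mesh are illustrative assumptions, not the trainer's real state or its exact jit-based transfer):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    mesh = Mesh(np.array(jax.devices()).reshape(4, 8), axis_names=("data", "model"))

    # 1) Build parameters on the host so initialization never touches the accelerator.
    with jax.default_device(jax.devices("cpu")[0]):
        params = {"w": jnp.zeros((1024, 1024)), "b": jnp.zeros((1024,))}

    # 2) Describe the target placement leaf by leaf.
    shardings = {
        "w": NamedSharding(mesh, PartitionSpec("data", "model")),
        "b": NamedSharding(mesh, PartitionSpec(None)),
    }

    # 3) Transfer: device_put accepts a matching pytree of shardings.
    device_params = jax.device_put(params, shardings)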
@@ -867,7 +886,7 @@ def _pjit_train_step(self) -> jax.stages.Wrapped:
aux=None,
),
),
donate_argnums=(0,), # donate the state
#donate_argnums=(0,), # donate the state - neuron doesnt support this
)

def compile_train_step(self) -> jax.stages.Compiled:
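The hunk above drops buffer donation because, per the comment, the Neuron backend does not support it. A toy illustration of what donate_argnums buys on backends that do support donation (a hypothetical function, not the trainer step):

    from functools import partial

    import jax
    import jax.numpy as jnp

    @partial(jax.jit, donate_argnums=(0,))  # the option the commit comments out for Neuron
    def sgd_step(state, grad):
        # XLA may reuse `state`'s buffer for the output, so the caller must not touch `state` afterwards.
        return state - 0.1 * grad

    state = jnp.ones((8,))
    state = sgd_step(state, jnp.ones((8,)))  # rebind, since the old buffer may have been donated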
@@ -1033,3 +1052,27 @@ def select_mesh_config(trainer_config: SpmdTrainer.Config, *, mesh_selector: str
logging.info("Mesh selector %s matches mesh rule %s", mesh_selector, mesh)
if mesh is not REQUIRED:
trainer_config.mesh_shape = mesh

def create_named_sharding_optimizer(tensor_spec, mesh):
zero1=True
if isinstance(tensor_spec, TensorSpec):
if tensor_spec.mesh_axes == (None,):
return NamedSharding(mesh, PartitionSpec(None))
else:
if len(tensor_spec.mesh_axes) > len(tensor_spec.shape):
adjusted_mesh_axes = tensor_spec.mesh_axes[1:]
else:
adjusted_mesh_axes = tensor_spec.mesh_axes
if zero1:
adjusted_mesh_axes = tuple('data' if axis == 'fsdp' else axis for axis in adjusted_mesh_axes)
partition_spec = PartitionSpec(*adjusted_mesh_axes)
return NamedSharding(mesh, partition_spec)
return tensor_spec

def create_named_sharding(param_spec, mesh):
if isinstance(param_spec, PartitionSpec):
return NamedSharding(
mesh,
param_spec
)
return param_spec
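A rough usage sketch for the helpers above (with them in scope), assuming a 4x8 data/model mesh; create_named_sharding_optimizer additionally remaps 'fsdp' to 'data' for the zero1 case and drops a leading mesh axis when a TensorSpec has more axes than the tensor has dims:

    import jax
    import numpy as np
    from jax.sharding import Mesh, PartitionSpec

    mesh = Mesh(np.array(jax.devices()).reshape(4, 8), axis_names=("data", "model"))

    # Model-parameter case: a plain PartitionSpec is wrapped into a NamedSharding as-is.
    print(create_named_sharding(PartitionSpec("data", "model"), mesh))
    # Non-spec leaves pass through unchanged.
    print(create_named_sharding(None, mesh))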
17 changes: 12 additions & 5 deletions axlearn/experiments/text/gpt/common.py
@@ -36,6 +36,7 @@
TransformerLayer,
build_remat_spec,
set_double_shard_weights_config,
StackedTransformerLayer
)
from axlearn.common.checkpointer import every_n_steps_policy
from axlearn.common.config import (
@@ -45,7 +46,7 @@
maybe_instantiate,
maybe_set_config,
)
from axlearn.common.decoder import Decoder
from axlearn.common.decoder import Decoder, LmHead
from axlearn.common.embedding import TransformerTextEmbeddings
from axlearn.common.evaler import BaseMetricCalculator, ModelSummaryAccumulator, SpmdEvaler
from axlearn.common.evaler import every_n_steps_policy as eval_every_n_steps_policy
@@ -214,7 +215,7 @@ def model_config(
layer_cfg.self_attention.attention.input_linear = attention_qkv_linear
layer_cfg.self_attention.structure = atten_structure
layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap
if stack_cfg.klass is RepeatedTransformerLayer:
if stack_cfg.klass is RepeatedTransformerLayer or stack_cfg.klass is StackedTransformerLayer:
# Enable remat to reduce memory usage for larger models.
layer_cfg.remat_spec = build_remat_spec(stack_cfg)
# Stack.
@@ -226,6 +227,7 @@
vocab_size=vocab_size,
emb=emb_cfg,
dropout_rate=dropout_rate,
lm_head=LmHead.default_config().set(dtype=jnp.bfloat16) #bfloat16
)
# Model.
model_param_init = DefaultInitializer.default_config().set(
@@ -242,16 +244,20 @@
batch_axis_names=batch_axis_names,
seq_axis_names="seq",
)
cfg.dtype = jnp.float32
cfg.dtype = jnp.bfloat16
# Shard some FFN and attention weights over multiple axes.
set_double_shard_weights_config(
cfg.decoder.transformer.layer,
batch_axis_names=batch_axis_names,
fsdp_axis_names=("expert", "fsdp", "seq"),
fsdp_axis_names="data",
tp_axis_names="model",
seq_axis_names=("seq",),
)
cfg.decoder.logits_partition_spec = (batch_axis_names, "seq", "model")
tp_axis_names='model'
fsdp_axis_names='data'
cfg.decoder.emb.token_emb.param_partition_spec = (tp_axis_names, fsdp_axis_names) # shard vocab
cfg.decoder.lm_head.param_partition_spec = (tp_axis_names, fsdp_axis_names) # shard vocab
#cfg.decoder.logits_partition_spec = (batch_axis_names, "seq", "model")
set_bias_recursively(cfg, False)
set_norm_recursively(cfg, normalization)
cfg.z_loss_scale = z_loss_scale
@@ -290,6 +296,7 @@ def learner_config(
weight_decay=weight_decay,
weight_decay_per_param_scale=None,
adam_update_transformation=None,
mu_dtype=jnp.bfloat16
),
]
)
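One effect of the learner change is a bfloat16 Adam first moment. A toy check with plain optax (axlearn's adamw wrapper is assumed to forward mu_dtype similarly; this is not the project's optimizer config):

    import jax.numpy as jnp
    import optax

    opt = optax.adamw(learning_rate=6e-4, weight_decay=0.01, mu_dtype=jnp.bfloat16)
    state = opt.init({"w": jnp.zeros((4, 4), dtype=jnp.float32)})
    print(state[0].mu["w"].dtype)  # bfloat16: the first moment is stored at half precision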
15 changes: 8 additions & 7 deletions axlearn/experiments/text/gpt/fuji.py
@@ -15,6 +15,7 @@
FusedQKVLinear,
RepeatedTransformerLayer,
RoFormerQKVLinear,
StackedTransformerLayer
)
from axlearn.common.utils import DataPartitionType
from axlearn.common.embedding import TransformerTextEmbeddings
@@ -35,21 +36,21 @@ def get_trainer_kwargs(model_size: str, *, vocab_size: int) -> Dict[str, Any]:
trainer_kwargs = dict(
model_kwargs=dict(
num_layers=1,
hidden_dim=32,
hidden_dim=1024,
#ffn_dim=scaled_hidden_dim(scale=8 / 3, round_up_to_multiples_of=16),
ffn_dim=scaled_hidden_dim(scale=4, round_up_to_multiples_of=16),
num_heads=8,
vocab_size=32,
num_heads=32,
vocab_size=8000,
),
learner_kwargs=dict(
peak_lr=6e-4,
weight_decay=0.01,
),
input_partition_type=DataPartitionType.DATA,
max_sequence_length=64,
train_batch_size=8,
max_sequence_length=2048,
train_batch_size=4,
max_step=5000,
mesh_shape=mesh_shape_from_axes(data=4, model=8), # gpu
mesh_shape=mesh_shape_from_axes(data=4, model=8),
)
elif model_size == "7B":
trainer_kwargs = dict(
@@ -121,7 +122,7 @@ def model_config(
hidden_dim=hidden_dim,
num_heads=num_heads,
vocab_size=vocab_size,
stack_cfg=RepeatedTransformerLayer.default_config(),
stack_cfg=StackedTransformerLayer.default_config(), # Repeated transformer layer breaks Neuron
activation_fn=activation_fn,
ffn_dim=ffn_dim,
normalization=RMSNorm.default_config().set(eps=1e-5, forward_dtype=None),
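The test config keeps mesh_shape_from_axes(data=4, model=8), i.e. a 32-way mesh that matches NEURON_RT_NUM_CORES=32 in train_setup.sh. A minimal sketch of the equivalent raw-JAX mesh (the device ordering and reshape are assumptions):

    import jax
    import numpy as np
    from jax.sharding import Mesh

    devices = np.array(jax.devices()).reshape(4, 8)  # 4 data-parallel groups x 8 tensor-parallel ranks
    mesh = Mesh(devices, axis_names=("data", "model"))
    print(mesh.shape)  # {'data': 4, 'model': 8}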
