Skip to content

Commit

Permalink
reduce the model size, sequence length, vocab_size for FAKE dataset testing
Browse files Browse the repository at this point in the history
  • Loading branch information
penxujun committed Mar 14, 2024
1 parent 8ac9323 commit 368130a
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions axlearn/experiments/text/gpt/fuji.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,18 @@ def get_trainer_kwargs(model_size: str, *, vocab_size: int) -> Dict[str, Any]:
 if model_size == "test":
 trainer_kwargs = dict(
 model_kwargs=dict(
-num_layers=4,
-hidden_dim=128*32,
+num_layers=1,
+hidden_dim=32,
 ffn_dim=scaled_hidden_dim(scale=8 / 3, round_up_to_multiples_of=16),
-num_heads=32,
-#vocab_size=32,
+num_heads=8,
+vocab_size=32,
 ),
 learner_kwargs=dict(
 peak_lr=6e-4,
 weight_decay=0.01,
 ),
 input_partition_type=DataPartitionType.DATA,
-#max_sequence_length=64,
+max_sequence_length=64,
 train_batch_size=8,
 max_step=5000,
 mesh_shape=mesh_shape_from_axes(data=2, model=4), # gpu
Expand Down

0 comments on commit 368130a

Please sign in to comment.