Dockerfile and update train_gpt2.py to most recent record
bluecoconut committed Nov 12, 2024
1 parent 49b70ad commit ff93d20
Showing 3 changed files with 108 additions and 50 deletions.
31 changes: 31 additions & 0 deletions Dockerfile
@@ -0,0 +1,31 @@
FROM nvidia/cuda:12.6.2-cudnn-devel-ubuntu24.04

ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHON_VERSION=3.12.7
ENV PATH=/usr/local/bin:$PATH

RUN apt update && apt install -y --no-install-recommends build-essential libssl-dev zlib1g-dev \
libbz2-dev libreadline-dev libsqlite3-dev curl git libncursesw5-dev xz-utils tk-dev libxml2-dev \
libxmlsec1-dev libffi-dev liblzma-dev \
&& apt clean && rm -rf /var/lib/apt/lists/*

RUN curl -O https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.tgz && \
tar -xzf Python-${PYTHON_VERSION}.tgz && \
cd Python-${PYTHON_VERSION} && \
./configure --enable-optimizations && \
make -j$(nproc) && \
make altinstall && \
cd .. && \
rm -rf Python-${PYTHON_VERSION} Python-${PYTHON_VERSION}.tgz

RUN ln -s /usr/local/bin/python3.12 /usr/local/bin/python && \
ln -s /usr/local/bin/pip3.12 /usr/local/bin/pip

COPY requirements.txt /modded-nanogpt/requirements.txt
WORKDIR /modded-nanogpt

RUN python -m pip install --upgrade pip && \
pip install -r requirements.txt

CMD ["bash"]
ENTRYPOINT []
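
Editorial aside (not part of the diff): `ENTRYPOINT []` clears any entrypoint inherited from the CUDA base image, so whatever command is passed to `docker run` executes directly, with `CMD ["bash"]` as the default. A quick sanity check that the built image can see the GPUs, assuming it is tagged `modded-nanogpt` as in the README below:

```bash
sudo docker run --rm --gpus all modded-nanogpt nvidia-smi
```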
11 changes: 11 additions & 0 deletions README.md
@@ -37,6 +37,17 @@ For comparison, the default llm.c PyTorch trainer yields [>3.28 validation loss
Both of these changes will have no effect on training: you should get exactly the same loss curve as the most recent record, because the training code
automatically adjusts the gradient accumulation to keep the total batch size the same.

## Running with Docker

For cases where the CUDA or NCCL versions are incompatible with your current system setup, Docker can be a helpful alternative.
This approach pins the CUDA, NCCL, cuDNN, and Python versions, reducing dependency issues and simplifying setup.
Note: an NVIDIA driver must already be installed on the host, so this route is useful when only the NVIDIA driver and Docker are available.

```bash
sudo docker build -t modded-nanogpt .
sudo docker run -it --rm --gpus all -v $(pwd):/modded-nanogpt modded-nanogpt python data/cached_fineweb10B.py 18
sudo docker run -it --rm --gpus all -v $(pwd):/modded-nanogpt modded-nanogpt sh run.sh
```
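
If `docker run` rejects the `--gpus all` flag, the host is likely missing the NVIDIA Container Toolkit. A minimal install sketch for Debian/Ubuntu hosts, assuming NVIDIA's apt repository is already configured (package and tool names per NVIDIA's standard instructions):

```bash
sudo apt-get install -y nvidia-container-toolkit
sudo nvidia-ctk runtime configure --runtime=docker
sudo systemctl restart docker
```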
---

## World record history
116 changes: 66 additions & 50 deletions train_gpt2.py
@@ -24,7 +24,7 @@ def zeropower_via_svd(G, steps=None):

@torch.compile
def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
"""
r"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
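
Editorial aside (not part of the diff): a self-contained sketch of the quintic Newton-Schulz iteration the docstring describes; the coefficients are the ones used in modded-nanogpt, but treat the exact values and normalization details as assumptions here:

```python
import torch

def newtonschulz5_sketch(G: torch.Tensor, steps: int = 10, eps: float = 1e-7) -> torch.Tensor:
    # Approximately orthogonalize G (its "zeroth power") by driving all singular
    # values toward 1 with a quintic polynomial iteration in bfloat16.
    a, b, c = 3.4445, -4.7750, 2.0315  # chosen to maximize the slope at zero
    X = G.bfloat16()
    X = X / (X.norm() + eps)  # scale so the top singular value is <= 1
    if G.size(0) > G.size(1):
        X = X.T  # iterate on the wide orientation
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X  # X <- a*X + b*(A @ X) + c*(A @ A @ X)
    if G.size(0) > G.size(1):
        X = X.T
    return X
```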
@@ -125,20 +125,17 @@ class Rotary(torch.nn.Module):

def __init__(self, dim, base=10000):
super().__init__()
self.dim = dim
self.base = base
self.inv_freq = None
self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.seq_len_cached = None
self.cos_cached = None
self.sin_cached = None

def forward(self, x):
seq_len = x.shape[1]
if seq_len != self.seq_len_cached:
self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
self.seq_len_cached = seq_len
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq).to(x.device)
self.cos_cached = freqs.cos().bfloat16()
self.sin_cached = freqs.sin().bfloat16()
return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
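
Editorial aside (not part of the diff): the cached `cos`/`sin` returned above are shaped to broadcast against a `(batch, seq, heads, head_dim)` activation. A sketch of the rotate-by-halves application this module pairs with (function name assumed):

```python
import torch

def apply_rotary_sketch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (B, T, n_head, head_dim); cos, sin: (1, T, 1, head_dim // 2)
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    y1 = x1 * cos + x2 * sin   # 2-D rotation applied pairwise to the two halves
    y2 = -x1 * sin + x2 * cos
    return torch.cat([y1, y2], dim=-1).type_as(x)
```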
@@ -238,6 +235,13 @@ def __init__(self, config):
wte = nn.Embedding(config.vocab_size, config.n_embd),
h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
))

# U-net design by @brendanh0gan
self.encoder_layers = config.n_layer // 2 # Half of the layers for encoder
self.decoder_layers = config.n_layer - self.encoder_layers # Remaining for decoder
# Add learnable skip connection weights for decoder layers
self.skip_weights = nn.Parameter(torch.ones(self.decoder_layers))

self.lm_head = CastedLinear(config.n_embd, config.vocab_size, bias=False)
self.lm_head.weight.data.zero_() # @Grad62304977

@@ -248,10 +252,23 @@ def forward(self, idx, target):
x = F.rms_norm(x, (x.size(-1),)) # @Grad62304977
x0 = x
v1 = None
for block in self.transformer.h:
x, v1 = block(x, v1, x0)
x = F.rms_norm(x, (x.size(-1),))

# Store outputs for U-Net skip connections
skip_connections = []

# Encoder pass - process only the first half of the blocks
for i in range(self.encoder_layers):
x, v1 = self.transformer.h[i](x, v1, x0)
skip_connections.append(x) # Store the output for skip connections

# Decoder pass - process the remaining blocks with weighted skip connections
for i in range(self.decoder_layers):
skip_connection = skip_connections.pop() # Get the corresponding encoder output
# Apply learnable weight to skip connection
weighted_skip = self.skip_weights[i] * skip_connection
x, v1 = self.transformer.h[self.encoder_layers + i](x + weighted_skip, v1, x0)

x = F.rms_norm(x, (x.size(-1),))
logits = self.lm_head(x)
logits = 30 * torch.tanh(logits / 30) # @Grad62304977
logits = logits.float()
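
Editorial aside (not part of the diff): the encoder/decoder bookkeeping above is a stack discipline; the first half of the blocks push their outputs, and the second half pop them in reverse order, each scaled by a learnable weight. A toy, self-contained version (all names hypothetical):

```python
import torch
import torch.nn as nn

class UNetStackSketch(nn.Module):
    def __init__(self, n_layer: int = 6, dim: int = 16):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_layer)])
        self.n_enc = n_layer // 2
        self.skip_weights = nn.Parameter(torch.ones(n_layer - self.n_enc))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skips = []
        for i in range(self.n_enc):                      # encoder half: push
            x = self.blocks[i](x)
            skips.append(x)
        for i in range(len(self.blocks) - self.n_enc):   # decoder half: pop in reverse
            x = self.blocks[self.n_enc + i](x + self.skip_weights[i] * skips.pop())
        return x
```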
@@ -345,9 +362,9 @@ class Hyperparameters:
batch_size : int = 8*64 # batch size, in sequences, across all devices
device_batch_size : int = 64 # batch size, in sequences, per device
sequence_length : int = 1024 # sequence length, in tokens
num_iterations : int = 3242 # number of iterations to run
num_iterations : int = 3000 # number of iterations to run
warmup_iters : int = 0
warmdown_iters : int = 926 # number of iterations of linear warmup/warmdown for triangular or trapezoidal schedule
warmdown_iters : int = 900 # number of iterations of linear warmup/warmdown for triangular or trapezoidal schedule
weight_decay : float = 0
# evaluation and logging hyperparams
val_loss_every : int = 125 # every how many steps to evaluate val loss? 0 for only at the end
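
Editorial aside (not part of the diff): these batch-size values determine the gradient-accumulation factor mentioned in the README; a quick check of how it scales with the number of devices (variable names assumed):

```python
batch_size = 8 * 64        # sequences per step, across all devices
device_batch_size = 64     # sequences per device per micro-step
for world_size in (8, 4, 1):
    grad_accum = batch_size // (device_batch_size * world_size)
    print(world_size, grad_accum)  # 8 -> 1, 4 -> 2, 1 -> 8
```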
@@ -366,33 +383,6 @@ class Hyperparameters:
print(f"using device: {device}")
master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.

# begin logging
logfile = None
if master_process:
run_id = str(uuid.uuid4())
logdir = 'logs/%s/' % run_id
os.makedirs(logdir, exist_ok=True)
logfile = 'logs/%s.txt' % run_id
# create the log file
with open(logfile, "w") as f:
# begin the log by printing this file (the Python code)
f.write('='*100 + '\n')
f.write(code)
f.write('='*100 + '\n')
def print0(s, logonly=False):
if master_process:
with open(logfile, "a") as f:
if not logonly:
print(s)
f.write(s+'\n')
# log information about the hardware/software environment this is running on
# and print the full `nvidia-smi` to file
print0(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:")
import subprocess
result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
print0(f'{result.stdout}', logonly=True)
print0('='*100, logonly=True)

# convenience variables
B, T = args.device_batch_size, args.sequence_length
# calculate the number of steps to take in the val loop.
@@ -405,9 +395,9 @@ def print0(s, logonly=False):
# load tokens
train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
print0('='*100, logonly=True)
if master_process:
print(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
print(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
x, y = train_loader.next_batch()

# there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. suggested to me by @Grad62304977.
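
Editorial aside (not part of the diff): the padded vocabulary works out to 50304; a one-line check of the ceiling-to-multiple arithmetic:

```python
vocab_size = 50257
padded = -(-vocab_size // 128) * 128  # ceil(50257 / 128) * 128
assert padded == 50304
```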
@@ -418,6 +408,7 @@ def print0(s, logonly=False):
for m in model.modules():
if isinstance(m, CastedLinear):
m.float()

if hasattr(config, "coordinate_descent_tuning"):
config.coordinate_descent_tuning = True # suggested by @Chillee
model = torch.compile(model)
@@ -433,13 +424,13 @@ def print0(s, logonly=False):
enable_math_sdp(False)
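# Editorial note (not part of the diff): the elided lines above configure the other
# scaled-dot-product-attention backends (enable_flash_sdp / enable_mem_efficient_sdp /
# enable_cudnn_sdp, all from torch.backends.cuda; exact flags assumed); disabling the
# math backend here rules out the slow reference kernel.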

# init the optimizer(s)
optimizer1 = torch.optim.Adam([raw_model.transformer.wte.weight], lr=0.3, betas=(0.9, 0.95), fused=True)
optimizer2 = torch.optim.Adam([raw_model.lm_head.weight], lr=0.002, betas=(0.9, 0.95), fused=True)
optimizer1 = torch.optim.Adam([raw_model.transformer.wte.weight], lr=0.6, betas=(0.9, 0.95), fused=True)
optimizer2 = torch.optim.Adam([raw_model.lm_head.weight], lr=0.008, betas=(0.9, 0.95), fused=True)
params = list(raw_model.transformer.h.parameters())
matrix_params = [p for p in params if p.ndim == 2]
scalar_params = [p for p in params if p.ndim < 2]
optimizer3 = Muon(matrix_params, lr=0.02, momentum=0.95)
optimizer4 = torch.optim.Adam(scalar_params, lr=0.02, betas=(0.9, 0.95), fused=True) # note that this learning rate is neither sensitive nor tuned
scalar_params = [p for p in params if p.ndim < 2]+[raw_model.skip_weights]
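# Editorial note (not part of the diff): skip_weights lives on the top-level model
# rather than under transformer.h, so it is appended explicitly; being 1-D, it joins
# the Adam-optimized scalar group below instead of Muon's 2-D matrix group.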
optimizer3 = Muon(matrix_params, lr=0.04, momentum=0.95)
optimizer4 = torch.optim.Adam(scalar_params, lr=0.04, betas=(0.9, 0.95), fused=True) # note that this learning rate is neither sensitive nor tuned
optimizers = [optimizer1, optimizer2, optimizer3, optimizer4]
# learning rate decay scheduler (linear warmup and warmdown)
def get_lr(it):
@@ -456,7 +447,26 @@ def get_lr(it):
return decay_ratio
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]
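# Editorial sketch (not part of the diff): the body of get_lr is elided above; a
# trapezoidal schedule consistent with warmup_iters=0, warmdown_iters=900, and the
# visible "return decay_ratio" would be (exact details assumed):
#
#     def get_lr(it):
#         assert it <= args.num_iterations
#         if it < args.warmup_iters:                          # linear warmup
#             return (it + 1) / args.warmup_iters
#         if it < args.num_iterations - args.warmdown_iters:  # flat top
#             return 1.0
#         else:                                               # linear warmdown
#             decay_ratio = (args.num_iterations - it) / args.warmdown_iters
#             return decay_ratio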

# Start training loop
# begin logging
if master_process:
run_id = str(uuid.uuid4())
logdir = 'logs/%s/' % run_id
os.makedirs(logdir, exist_ok=True)
logfile = 'logs/%s.txt' % run_id
# create the log file
with open(logfile, "w") as f:
# begin the log by printing this file (the Python code)
f.write('='*100 + '\n')
f.write(code)
f.write('='*100 + '\n')
# log information about the hardware/software environment this is running on
# and print the full `nvidia-smi` to file
f.write(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:\n")
import subprocess
result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
f.write(f'{result.stdout}\n')
f.write('='*100 + '\n')

training_time_ms = 0
# start the clock
torch.cuda.synchronize()
@@ -489,7 +499,10 @@ def get_lr(it):
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
val_loss /= val_steps
# log val loss to console and to logfile
print0(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms')
if master_process:
print(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms')
with open(logfile, "a") as f:
f.write(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms\n')
# start the clock again
torch.cuda.synchronize()
t0 = time.time()
@@ -541,8 +554,11 @@ def get_lr(it):
# everything that follows now is just diagnostics, prints, logging, etc.

#dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower
approx_time = training_time_ms + 1000 * (time.time() - t0)
print0(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms")
if master_process:
approx_time = training_time_ms + 1000 * (time.time() - t0)
print(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms")
with open(logfile, "a") as f:
f.write(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms\n")

if master_process:
print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
