Logging outlier batch #332

Draft · wants to merge 1 commit into base: main
163 changes: 162 additions & 1 deletion src/nanotron/helpers.py
@@ -8,14 +8,18 @@
from datetime import datetime
from functools import partial
from math import ceil
from typing import Any, Dict, Iterable, List, Optional, Tuple
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import LambdaLR
from torch.profiler import ProfilerActivity, profile, tensorboard_trace_handler
import torch.distributed as dist
import torch.nn.functional as F
from packaging import version

from nanotron import distributed as dist
from nanotron import logging
@@ -45,6 +49,7 @@
from nanotron.scaling.parametrization import LearningRateForSP, LearningRateForSpectralMup, ParametrizationMethod
from nanotron.serialize import DataStageMetadata
from nanotron.serialize.metadata import TrainingMetadata
from transformers import AutoTokenizer

logger = logging.get_logger(__name__)

@@ -849,3 +854,159 @@ def get_consumed_train_samples_of_a_data_stage_from_ckp(
)

return consumed_train_samples

def log_outlier_batch(
    loss: Union[float, torch.Tensor],
    ema_loss: Union[float, torch.Tensor],
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    label_ids: torch.Tensor,
    label_mask: torch.Tensor,
    config: Config,
    parallel_context: ParallelContext,
):
"""Logs detailed information about a batch that caused a loss spike."""
    # TODO: hardcoded debug location; make this configurable instead of pointing at a personal scratch directory
    log_dir = Path("/fsx/elie_bakouch/nanotron-dev/batch_outlier_logs/")
try:
log_dir.mkdir(parents=True, exist_ok=True)
except OSError as e:
log_rank(f"Error creating log directory {log_dir}: {e}", logger=logger, level=logging.ERROR)
return # Don't proceed if directory creation fails

    # --- Load Tokenizer ---
    tokenizer = None
    # TODO: read this from config.tokenizer.tokenizer_name_or_path instead of hardcoding it
    tokenizer_path = "meta-llama/Llama-3.2-1B"
    if tokenizer_path:
        try:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            # Add a padding token if the tokenizer doesn't define one, for cleaner decoding
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
                log_rank(f"Tokenizer missing pad token, setting it to the EOS token: {tokenizer.eos_token}", logger=logger, level=logging.WARNING)
        except Exception as e:
            log_rank(f"Error loading tokenizer from {tokenizer_path}: {e}", logger=logger, level=logging.ERROR)
    else:
        log_rank("No tokenizer path available; input_ids will not be detokenized.", logger=logger, level=logging.WARNING)
    # ----------------------

# Get rank information
dp_rank = parallel_context.dp_pg.rank()
pp_rank = parallel_context.pp_pg.rank()
tp_rank = parallel_context.tp_pg.rank()
# Assuming expert parallel rank is tied to cp_pg if it exists, otherwise default
ep_rank = parallel_context.cp_pg.rank() if parallel_context.cp_pg is not None else 0

# Get timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")

# Construct filename
    # TODO: placeholder; should be the actual run name (e.g. from the general section of the config) once wired up
    run_name = "blablablabla"
# Sanitize run_name for filename
sanitized_run_name = "".join(c if c.isalnum() or c in ('_', '-') else '_' for c in run_name)
filename = log_dir / f"outlier_{sanitized_run_name}_dp{dp_rank}_pp{pp_rank}_tp{tp_rank}_ep{ep_rank}_{timestamp}.txt"
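    # e.g. outlier_<run_name>_dp0_pp1_tp0_ep0_20240101_120000_000000.txt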

    # Helper to convert a tensor to a nested Python list (handles non-tensor inputs such as TensorPointer)
    def tensor_to_list(tensor):
        if isinstance(tensor, torch.Tensor):
            try:
                return tensor.detach().cpu().tolist()
            except Exception as e:
                return f"Error converting tensor to list: {e}"
        # Non-tensor inputs are most likely TensorPointers from another pipeline stage;
        # logging their values would require extra communication, so only record the owning rank.
        if hasattr(tensor, "group_rank"):
            return f"TensorPointer(group_rank={tensor.group_rank})"
        return f"Non-Tensor data: {type(tensor)}"

try:
with open(filename, "w") as f:
f.write("=" * 50 + "\n")
f.write("Loss Spike Detected - Batch Information\n")
f.write("=" * 50 + "\n\n")

f.write(f"Timestamp: {timestamp}\n")
f.write(f"Run Name: {run_name}\n")
# TODO: Add step and microbatch index if available later
# f.write(f"Step: {step}\n")
# f.write(f"Microbatch Index: {microbatch_idx}\n")
f.write(f"Detected Loss: {loss:.4f}\n")
f.write(f"EMA Loss: {ema_loss:.4f}\n\n")

f.write("Distributed Ranks:\n")
f.write(f" Data Parallel Rank (DP): {dp_rank}\n")
f.write(f" Pipeline Parallel Rank (PP): {pp_rank}\n")
f.write(f" Tensor Parallel Rank (TP): {tp_rank}\n")
f.write(f" Expert Parallel Rank (EP/CP): {ep_rank}\n\n")

f.write("Model Configuration Snippet:\n")
# Check if model config exists and get its class name safely
model_config = getattr(config, "model_config", None)
model_type = model_config.__class__.__name__ if model_config else "N/A"
f.write(f" Model Type: {model_type}\n")
f.write(f" Tokenizer Path: {tokenizer_path if tokenizer_path else 'N/A'}\n")
# Add other relevant config snippets if needed
f.write("\n")

f.write("Batch Tensor Information:\n")
tensors = {
"input_ids": input_ids,
"position_ids": position_ids,
"label_ids": label_ids,
"label_mask": label_mask,
}

for name, tensor in tensors.items():
f.write(f" {name}:\n")
if isinstance(tensor, torch.Tensor):
f.write(f" Shape: {tensor.shape}\n")
f.write(f" Device: {tensor.device}\n")
f.write(f" Dtype: {tensor.dtype}\n")
f.write(f" Requires Grad: {tensor.requires_grad}\n")
# Log values with better formatting
values_list = tensor_to_list(tensor)
if isinstance(values_list, list):
formatted_values = "\n ".join(str(row) for row in values_list) # Indent each row
f.write(f" Values:\n {formatted_values}\n\n")
else: # Handle error string from tensor_to_list
f.write(f" Values: {values_list}\n\n")
else:
f.write(f" Data: {tensor_to_list(tensor)}\n\n") # Use updated helper

# --- Detokenize input_ids ---
if tokenizer is not None and isinstance(input_ids, torch.Tensor):
try:
f.write("Detokenized Input IDs (Batch):\n")
# Assuming input_ids might be flattened [b*s] or [b, s]
input_ids_cpu = input_ids.detach().cpu()
if input_ids_cpu.dim() == 1:
# Try to infer batch size if possible, otherwise decode as single sequence
# This part is tricky without knowing the original batch structure
# For now, decode as one long sequence
log_rank("Input IDs are flattened, attempting to decode as a single sequence.", logger=logger, level=logging.WARNING)
decoded_text = tokenizer.decode(input_ids_cpu, skip_special_tokens=False)
f.write(f"[Flattened Sequence]: {decoded_text}\n")
else:
# Assume shape [batch_size, seq_length]
decoded_batch = tokenizer.batch_decode(input_ids_cpu, skip_special_tokens=False)
for i, text in enumerate(decoded_batch):
f.write(f"[Example {i}]: {text}\n")
f.write("\n")
except Exception as e:
f.write(f" Error detokenizing input_ids: {e}\n\n")
elif isinstance(input_ids, torch.Tensor):
f.write("Detokenization skipped: Tokenizer not available or failed to load.\n\n")
# ---------------------------

f.write("=" * 50 + "\n")
f.write("End of Report\n")
f.write("=" * 50 + "\n")

log_rank(f"Loss spike detected. Batch info saved to: {filename}", logger=logger, level=logging.INFO)

except Exception as e:
log_rank(f"Error writing outlier batch log to {filename}: {e}", logger=logger, level=logging.ERROR)


48 changes: 40 additions & 8 deletions src/nanotron/models/qwen.py
@@ -29,7 +29,7 @@
)
from nanotron.random import RandomStates
from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator

from nanotron.helpers import log_outlier_batch
logger = logging.get_logger(__name__)


@@ -759,35 +759,39 @@ def forward(
class Qwen2ForTraining(NanotronModel):
def __init__(
self,
config: Qwen2Config,
model_config: Qwen2Config,
parallel_context: ParallelContext,
parallel_config: Optional[ParallelismArgs],
main_config: Config,
random_states: Optional[RandomStates] = None,
):
super().__init__()
self.model = Qwen2Model(config=config, parallel_context=parallel_context, parallel_config=parallel_config)
self.model = Qwen2Model(config=model_config, parallel_context=parallel_context, parallel_config=parallel_config)
self.model_config = model_config
self.main_config = main_config

# Choose the appropriate loss class based on config
loss_kwargs = {
"tp_pg": parallel_context.tp_pg,
}
if config.z_loss_enabled:
loss_kwargs["z_loss_coefficient"] = config.z_loss_coefficient
if self.model_config.z_loss_enabled:
loss_kwargs["z_loss_coefficient"] = self.model_config.z_loss_coefficient

self.loss = PipelineBlock(
p2p=self.model.p2p,
module_builder=LossWithZLoss if config.z_loss_enabled else Loss,
module_builder=LossWithZLoss if self.model_config.z_loss_enabled else Loss,
module_kwargs=loss_kwargs,
module_input_keys={
"sharded_logits",
"label_ids",
"label_mask",
},
module_output_keys={"loss", "z_loss"} if config.z_loss_enabled else {"loss"},
module_output_keys={"loss", "z_loss"} if self.model_config.z_loss_enabled else {"loss"},
)
self.parallel_context = parallel_context
        self.config = model_config  # kept so existing code that reads `self.config` still works
self.parallel_config = parallel_config
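        # Scalar EMA of the training loss; registered as a buffer (not a parameter) so it moves
        # with the module across devices and is saved and restored with checkpoints.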
self.register_buffer("ema_loss_buffer", torch.zeros(()))
self.ema_beta = 0.95

def forward(
self,
@@ -805,6 +809,34 @@ def forward(
label_ids=label_ids,
label_mask=label_mask,
)
        current_loss = loss["loss"]

        # Only the pipeline rank that actually computes the loss receives a tensor here;
        # the other ranks get a TensorPointer, so skip the spike tracking on them.
        if isinstance(current_loss, torch.Tensor):
            current_loss = current_loss.detach()
            is_outlier = False

            # Check for a loss spike only if outlier logging is configured.
            # `outlier_logging` is assumed to live on the top-level Config (main_config),
            # matching the `config: Config` parameter of `log_outlier_batch`.
            if self.main_config.outlier_logging is not None:
                threshold = self.main_config.outlier_logging.loss_spike_threshold
                # Skip the comparison until the EMA has been seeded, otherwise the
                # buffer's initial value of 0.0 would flag the very first step as a spike.
                if self.ema_loss_buffer.item() != 0.0 and abs(self.ema_loss_buffer - current_loss) > threshold:
                    is_outlier = True
                    log_outlier_batch(
                        loss=current_loss,
                        ema_loss=self.ema_loss_buffer,
                        input_ids=input_ids,
                        position_ids=position_ids,
                        label_ids=label_ids,
                        label_mask=label_mask,
                        config=self.main_config,
                        parallel_context=self.parallel_context,
                    )

            if not is_outlier:
                # Keep outlier batches out of the EMA so a single spike does not drag
                # the running average upwards.
                self.ema_loss_buffer.copy_(
                    self.ema_beta * self.ema_loss_buffer + (1.0 - self.ema_beta) * current_loss
                )

if self.config.z_loss_enabled:
return {"loss": loss["loss"], "z_loss": loss["z_loss"]}
else:
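
Stripped of the nanotron plumbing, the spike check added to forward() boils down to an exponential moving average of the loss plus a fixed threshold. The sketch below is a framework-free illustration, not part of the PR: the 0.95 decay and the 2.0 threshold are just the defaults used above, and it seeds the EMA with the first observed loss instead of a zero buffer, which avoids the warm-up bias noted in the diff.

# Minimal sketch of the EMA-based spike detection used above (no nanotron dependencies).
def run_spike_detector(losses, beta=0.95, threshold=2.0):
    ema = None
    spikes = []
    for step, loss in enumerate(losses):
        if ema is not None and abs(ema - loss) > threshold:
            spikes.append(step)  # would trigger log_outlier_batch in the real code
            continue             # outlier steps are kept out of the EMA
        ema = loss if ema is None else beta * ema + (1.0 - beta) * loss
    return spikes


# Example: a flat loss curve with one bad batch at step 3.
print(run_spike_detector([4.0, 3.9, 3.8, 9.5, 3.7]))  # -> [3]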