Description
Hello, first off, apologies if this information is already available elsewhere. I've searched through the documentation and existing issues but haven't found a clear answer to my question.
I have access to 2 to 4 nodes (16 to 32 GPUs in total), each equipped with 8x140GB H200 GPUs. My objective is to perform large-scale distributed inference using a massive 111B-parameter Teacher model (CohereLabs/c4ai-command-a-03-2025) and simultaneously conduct online knowledge distillation (soft-logit based) from this 111B Teacher model to a smaller 8B Student model (CohereLabs/c4ai-command-r7b-12-2024).
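To be concrete about what I mean by "soft-logit based" distillation: the student is trained to match the teacher's temperature-scaled token distributions through a KL term, roughly like the minimal single-device sketch below (names are illustrative; the full script additionally masks out prompt and padding tokens):

```python
import torch.nn.functional as F

def soft_logit_kd_loss(student_logits, teacher_logits, temperature=1.0):
    # Forward KL(teacher || student): sum over the vocabulary, then mean over all tokens.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)
    kl_per_token = F.kl_div(student_log_probs, teacher_log_probs,
                            reduction="none", log_target=True).sum(dim=-1)
    return kl_per_token.mean()
```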
Is there a way to run distributed inference for a Teacher model at the 111B+ scale and distributed training for the Student model simultaneously in a multi-node setup, using Hugging Face Transformers' Trainer?
The Transformers version I'm using is v4.51.3. I noticed that the `deepspeed_init` function in `src/transformers/integrations/deepspeed.py` calls `model = deepspeed.tp_model_init(...)`. I attempted to apply the same call in my own script, but it resulted in a `torch.distributed.DistBackendError`.
I would be very grateful if someone could explain what would be most suitable for my use case; a minimal working example would be the icing on the cake. If the Open LLM Leaderboard shows that online knowledge distillation (soft-logit) is possible with models exceeding 111B parameters, there must surely be a straightforward way to achieve what I want, but I'm unsure how everyone else does it.
For reference, below is the script I'm currently working with:
deepspeed --num_nodes 2 --num_gpus 8 \
  --hostfile $HOSTFILE \
  --master_addr $MASTER_ADDR \
  --master_port=62535 \
  train.py \
  --teacher CohereLabs/c4ai-command-a-03-2025 \
  --student CohereLabs/c4ai-command-r7b-12-2024 \
  --epochs 1 --batch_size 1 --seq_len 4096 \
  --temperature 1.0 --max_samples 150 --lr 1e-6 \
  2>&1 | tee -a "./train.log"
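The `$HOSTFILE` above is a standard DeepSpeed hostfile; for the two-node case it looks like this (the hostnames are placeholders for my actual nodes):

```
node-0 slots=8
node-1 slots=8
```

And here is train.py itself:

```python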
import torch.distributed as dist
import os, math, argparse, warnings, torch, random, multiprocessing as mp
import deepspeed  # launcher args, DS-Inference for the teacher, training engine for the student
from datasets import load_dataset, concatenate_datasets
from transformers import (AutoTokenizer, AutoModelForCausalLM,
PreTrainedTokenizerBase)
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from datetime import timedelta
from deepspeed.runtime.utils import see_memory_usage
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ.setdefault("NCCL_ASYNC_ERROR_HANDLING", "1")
warnings.filterwarnings("ignore", category=UserWarning)
mp.set_start_method("spawn", force=True)
def get_args():
p = argparse.ArgumentParser()
p.add_argument("--teacher", default="")
p.add_argument("--student", default="")
p.add_argument("--dataset", default="")
p.add_argument("--split", default="train")
p.add_argument("--epochs", type=int, default=1)
p.add_argument("--batch_size", type=int, default=1,
help="per-GPU micro-batch")
p.add_argument("--seq_len", type=int, default=4096)
p.add_argument("--temperature", type=float, default=1.0)
p.add_argument("--lr", type=float, default=1e-6)
p.add_argument("--max_samples", type=int, default=0,
help="0=1000 ")
p.add_argument("--local_rank", type=int, default=-1,
help="deepspeed/torch launcher GPU index")
p.add_argument("--cache_path", default="")
p.add_argument("--hf_token", default="")
p = deepspeed.add_config_arguments(p)
return p.parse_args()
def main():
timeout_seconds = 3600
timeout_duration = timedelta(seconds=timeout_seconds)
dist.init_process_group(
backend="nccl",
timeout=timeout_duration
)
args = get_args()
deepspeed.init_distributed()
rank, world = deepspeed.comm.get_rank(), deepspeed.comm.get_world_size()
device = torch.device("cuda", deepspeed.comm.get_local_rank())
# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.student,
use_fast=True, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
    # Resolve explicit eos/pad token ids from the tokenizer vocabulary
tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
# Teacher (inference only)
teacher_model = AutoModelForCausalLM.from_pretrained(
args.teacher, torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True, device_map=None,
        cache_dir=args.cache_path, token=args.hf_token)
see_memory_usage("After load model", force=True)
teacher_model.config.eos_token_id = tokenizer.eos_token_id
teacher_model.config.pad_token_id = tokenizer.pad_token_id
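    # DeepSpeed-Inference shards the frozen teacher across all ranks via tensor parallelism
    # (mp_size = world size), so no single GPU has to hold the full 111B model.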
teacher_engine = deepspeed.init_inference(
teacher_model,
mp_size=world,
dtype=torch.bfloat16,
replace_with_kernel_inject=True,
replace_method="auto")
see_memory_usage("After DS-inference init", force=True)
teacher_engine.module.eval()
teacher_engine.optimizer = None
# Student
student_model = AutoModelForCausalLM.from_pretrained(
args.student, torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
        trust_remote_code=True, cache_dir=args.cache_path, token=args.hf_token)
student_model.config.eos_token_id = tokenizer.eos_token_id
student_model.config.pad_token_id = tokenizer.pad_token_id
# Dataset
ds = [
load_dataset("Raphael21/LogicKor_Aug_small_v2", split="train", data_dir="v0.1.1", streaming=False)
]
ds = concatenate_datasets(ds).select_columns(["messages"])
total_samples = args.max_samples or len(ds)
total_steps = args.epochs * math.ceil(total_samples / args.batch_size)
# Deepspeed Config
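    # ZeRO-3 partitions the student's parameters, gradients and optimizer states across ranks,
    # with both optimizer state and parameters offloaded to CPU. train_batch_size is the global
    # batch (per-GPU micro-batch * world size, gradient_accumulation_steps = 1).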
ds_cfg = {
"train_batch_size": args.batch_size * world,
"gradient_accumulation_steps": 1,
"bf16": {"enabled": True},
"zero_optimization": {
"stage": 3,
"stage3_max_live_parameters": 1e9,
"stage3_prefetch_bucket_size": 5e8,
"stage3_param_persistence_threshold": 1e4,
"overlap_comm": True,
"contiguous_gradients": True,
"allgather_bucket_size": 5e8,
"reduce_bucket_size": 5e8,
"offload_optimizer": {
"device": "cpu",
"pin_memory": True
},
"offload_param": {
"device": "cpu",
"pin_memory": True
},
},
"activation_checkpointing": {
"partition_activations": True,
"contiguous_memory_optimization": True
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": args.lr,
"betas": [0.9, 0.999],
"eps": 1e-8,
"weight_decay": 0.01
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": args.lr,
"warmup_num_steps": int(0.1 * total_steps)
}
}
}
    student_model = deepspeed.tp_model_init(
        model=student_model,
        tp_size=16,           # tensor-parallel degree (2 nodes x 8 GPUs)
        dtype=torch.bfloat16
    )
    # Build the DeepSpeed training engine for the student from the ZeRO-3 config above;
    # student_engine provides the backward()/step()/save_checkpoint() calls used below.
    student_engine, _, _, _ = deepspeed.initialize(
        model=student_model,
        model_parameters=student_model.parameters(),
        config=ds_cfg
    )
    if not hasattr(student_engine, "optimizer"):
        student_engine.optimizer = None
# Debug Messages
if rank == 0:
print("Configured with ZeRO-3, total_steps:", total_steps)
# Data Loader
def preprocess_batch(examples):
prompt_key = "prompt"
messages_key = "messages"
ignore_index = -100
max_length = min(tokenizer.model_max_length, 4096)
def get_tokens_from_chat_template(messages_dict_or_str, add_gen_prompt, max_len=max_length):
tokens = tokenizer.apply_chat_template(
messages_dict_or_str,
tokenize=True,
add_generation_prompt=add_gen_prompt,
truncation=True,
max_length=max_len,
padding="do_not_pad",
return_tensors=None,
)
return tokens, [1] * len(tokens)
results = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"prompts": [],
"prompt_attention_mask": [],
}
prompts = examples.get(prompt_key, [None] * len(examples[messages_key]))
for message, prompt in zip(examples[messages_key], prompts):
if prompt is None:
prompt_messages = message[:-1]
prompt_ids, prompt_attn = get_tokens_from_chat_template(prompt_messages, add_gen_prompt=True)
else:
                prompt_ids, prompt_attn = get_tokens_from_chat_template(prompt, add_gen_prompt=True)
input_ids, attn_mask = get_tokens_from_chat_template(message, add_gen_prompt=False)
label = [ignore_index] * len(input_ids)
start_idx = len(prompt_ids)
if start_idx < len(input_ids):
for i in range(start_idx, len(input_ids)):
label[i] = input_ids[i]
results["input_ids"].append(input_ids)
results["attention_mask"].append(attn_mask)
results["labels"].append(label)
results["prompts"].append(prompt_ids)
results["prompt_attention_mask"].append(prompt_attn)
return results
def chatml_collate_fn(batch, pad_token_id=0, ignore_index=-100, max_length=4096):
def pad_and_truncate(seqs, pad_val, max_length):
return torch.stack([
torch.tensor(seq[:max_length] + [pad_val] * (max_length - len(seq)), dtype=torch.long)
for seq in seqs
])
fields = ["input_ids", "attention_mask", "labels", "prompts", "prompt_attention_mask"]
pad_values = [pad_token_id, 0, ignore_index, pad_token_id, 0]
return {
field: pad_and_truncate([ex[field] for ex in batch], pad_val, max_length)
for field, pad_val in zip(fields, pad_values)
}
if args.max_samples:
ds = ds.select(range(args.max_samples))
ds = ds.map(preprocess_batch, batched=True)
loader = torch.utils.data.DataLoader(
ds,
batch_size=args.batch_size,
shuffle=True,
pin_memory=True,
collate_fn=lambda x: chatml_collate_fn(
x,
pad_token_id=tokenizer.pad_token_id,
ignore_index=-100,
max_length=args.seq_len
)
)
T = args.temperature
for epoch in range(args.epochs):
for step, batch in enumerate(loader):
prompt_lengths_batch = batch["prompt_attention_mask"].sum(dim=1).cpu().tolist()
prompt_lengths_tensor = torch.tensor(prompt_lengths_batch, device=device, dtype=torch.long)
input_ids = batch["input_ids"].to(device)
attn = batch["attention_mask"].to(device)
labels_batch = batch["labels"].to(device)
with torch.no_grad():
teacher_logits = teacher_engine.module(
input_ids=input_ids,
attention_mask=attn,
use_cache=False
).logits
student_logits = student_engine(
input_ids=input_ids,
attention_mask=attn,
use_cache=False
).logits
if rank == 0:
sample_idx_to_inspect = 0
original_input_ids = batch["input_ids"][sample_idx_to_inspect].cpu().tolist()
original_labels_list = batch["labels"][sample_idx_to_inspect].cpu().tolist()
# Student Model (argmax)
student_predictions_ids = student_logits[sample_idx_to_inspect].argmax(dim=-1).cpu().tolist()
decoded_student_predictions = [tokenizer.decode([t], skip_special_tokens=False) for t in student_predictions_ids]
# Teacher Model (argmax)
teacher_predictions_ids = teacher_logits[sample_idx_to_inspect].argmax(dim=-1).cpu().tolist()
decoded_teacher_predictions = [tokenizer.decode([t], skip_special_tokens=False) for t in teacher_predictions_ids]
print(f"Decoded Student Predictions: {''.join(decoded_student_predictions[:100])} ...")
print(f"Decoded Teacher Predictions: {''.join(decoded_teacher_predictions[:100])} ...")
shifted_student_logits = student_logits[:, :-1, :]
shifted_teacher_logits = teacher_logits[:, :-1, :]
shifted_labels = labels_batch[:, 1:]
shifted_attention_mask = attn[:, 1:]
current_seq_len = shifted_labels.size(1)
response_mask = torch.zeros_like(shifted_labels, dtype=torch.bool)
            for i in range(input_ids.size(0)):  # actual batch size (the last batch may be smaller)
start_response_idx_in_shifted = prompt_lengths_tensor[i] - 1
start_response_idx_in_shifted = max(0, start_response_idx_in_shifted)
if start_response_idx_in_shifted < current_seq_len:
response_mask[i, start_response_idx_in_shifted:] = True
shifted_attention_mask = shifted_attention_mask & response_mask
# Apply temperature scaling
student_logits_scaled = shifted_student_logits / args.temperature
teacher_logits_scaled = shifted_teacher_logits / args.temperature
            # Log-probabilities for both student and teacher (F.kl_div below uses log_target=True)
student_log_probs = F.log_softmax(student_logits_scaled, dim=-1)
teacher_log_probs = F.log_softmax(teacher_logits_scaled, dim=-1)
kd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
kd_loss_per_token = kd_loss.sum(dim=-1)
valid_labels_mask = (shifted_labels != -100)
combined_mask = shifted_attention_mask & valid_labels_mask
masked_kd_loss = kd_loss_per_token * combined_mask
num_valid_tokens = combined_mask.sum()
if num_valid_tokens > 0:
kd_loss = masked_kd_loss.sum() / num_valid_tokens
else:
kd_loss = torch.tensor(0.0, device=device, requires_grad=True)
# Cross-Entropy Loss
ce_loss = F.cross_entropy(
shifted_student_logits.view(-1, shifted_student_logits.size(-1)), # (B*S, V)
shifted_labels.view(-1),
ignore_index=-100
)
if ce_loss.numel() == 0 or torch.isnan(ce_loss):
ce_loss = torch.tensor(0.0, device=device, requires_grad=True)
alpha = 0.5
total_loss = alpha * ce_loss + (1 - alpha) * kd_loss
student_engine.backward(total_loss)
student_engine.step()
# empty cache
torch.cuda.empty_cache()
if rank == 0 and step % 1 == 0:
print(f"[Epoch {epoch}][{step}/{len(loader)}] total_loss = {total_loss.item():.4f}, ce_loss = {ce_loss.item():.4f}, kd_loss = {kd_loss.item():.4f}")
# Save Checkpoint
student_engine.save_checkpoint("./save_checkpoint")
if __name__ == "__main__":
    main()
```