
[BUG] Run 111B+ Teacher distributed inference and 8B Student distributed training on multi-node H200 GPUs using the Transformers Trainer without encountering OOM errors? #39637

@seona21

Hello, first off, apologies if this information is already available elsewhere. I've searched through the documentation and existing issues but haven't found a clear answer to my question.

I have access to 2 to 4 nodes (16 to 32 GPUs in total), each equipped with 8x140GB H200 GPUs. My objective is to perform large-scale distributed inference using a massive 111B-parameter Teacher model (CohereLabs/c4ai-command-a-03-2025) and simultaneously conduct online knowledge distillation (soft-logit based) from this 111B Teacher model to a smaller 8B Student model (CohereLabs/c4ai-command-r7b-12-2024).

Is there a way to simultaneously run distributed inference for Teacher models larger than 111B and distributed training for Student models in a multi-node setup, utilizing Hugging Face Transformers' Trainer?

The Transformers version I'm using is v4.51.3. I noticed the call model = deepspeed.tp_model_init(...) inside the deepspeed_init function in src/transformers/integrations/deepspeed.py and tried applying it in my own script, but it results in a torch.distributed.DistBackendError.

I would be very grateful if someone could explain what would be most suitable for my use case. A minimal working example would be the icing on the cake. Surely, if the Open LLM Leaderboard shows that online knowledge distillation (soft-logit) is possible with large models exceeding 111B, there must be a straightforward way to achieve what I want, but I'm unsure how everyone else does it.
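
To make the question concrete, this is roughly the Trainer-based pattern I was hoping for: a Trainer subclass whose compute_loss adds a soft-logit KD term on top of the usual CE loss, querying the teacher under torch.no_grad(). Everything below is only a sketch with placeholder names; it deliberately leaves out the part I cannot figure out, which is how to load and shard the 111B teacher across nodes so it can be called from inside compute_loss without OOM.

import torch
import torch.nn.functional as F
from transformers import Trainer


class DistillTrainer(Trainer):
    """Hypothetical sketch: adds a soft-logit KD term to the standard Trainer loss."""

    def __init__(self, *args, teacher_model=None, temperature=1.0, alpha=0.5, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model  # assumed to be already sharded/placed for inference
        self.temperature = temperature
        self.alpha = alpha

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Student forward pass; assumes the batch carries "labels" so outputs.loss is the CE term.
        outputs = model(**inputs)
        ce_loss = outputs.loss

        # Teacher forward pass, inference only.
        with torch.no_grad():
            teacher_logits = self.teacher_model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            ).logits

        # Soft-logit KD: KL(teacher || student) on temperature-scaled distributions.
        # (A real run would also mask prompt/padding tokens, as in the full script below.)
        t = self.temperature
        kd_loss = F.kl_div(
            F.log_softmax(outputs.logits / t, dim=-1),
            F.log_softmax(teacher_logits / t, dim=-1),
            reduction="batchmean",
            log_target=True,
        ) * (t * t)

        loss = self.alpha * ce_loss + (1.0 - self.alpha) * kd_loss
        return (loss, outputs) if return_outputs else loss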

For reference, below is the script I'm currently working with:

deepspeed --num_nodes 2 --num_gpus 8 \
  --hostfile $HOSTFILE \
  --master_addr $MASTER_ADDR \
  --master_port=62535 \
  train.py \
  --teacher CohereLabs/c4ai-command-a-03-2025 \
  --student CohereLabs/c4ai-command-r7b-12-2024 \
  --epochs 1 --batch_size 1 --seq_len 4096 \
  --temperature 1.0 --max_samples 150 --lr 1e-6 \
  2>&1 | tee -a "./train.log"

import torch.distributed as dist
import os, math, argparse, warnings, torch, random, multiprocessing as mp
from datasets import load_dataset, concatenate_datasets
from transformers import (AutoTokenizer, AutoModelForCausalLM,
                          PreTrainedTokenizerBase)
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from datetime import timedelta
import deepspeed
from deepspeed.runtime.utils import see_memory_usage


os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ.setdefault("NCCL_ASYNC_ERROR_HANDLING", "1")
warnings.filterwarnings("ignore", category=UserWarning)
mp.set_start_method("spawn", force=True)

def get_args():
    p = argparse.ArgumentParser()
    p.add_argument("--teacher", default="")
    p.add_argument("--student", default="")
    p.add_argument("--dataset", default="")
    p.add_argument("--split", default="train")
    p.add_argument("--epochs", type=int, default=1)
    p.add_argument("--batch_size", type=int, default=1,
                   help="per-GPU micro-batch")
    p.add_argument("--seq_len", type=int, default=4096)
    p.add_argument("--temperature", type=float, default=1.0)
    p.add_argument("--lr", type=float, default=1e-6)
    p.add_argument("--max_samples", type=int, default=0,
                   help="0=1000 ")
    p.add_argument("--local_rank", type=int, default=-1,
               help="deepspeed/torch launcher GPU index")
    p.add_argument("--cache_path", default="")
    p.add_argument("--hf_token", default="")
    p = deepspeed.add_config_arguments(p)
    return p.parse_args()


def main():
    timeout_seconds = 3600 
    timeout_duration = timedelta(seconds=timeout_seconds)
    dist.init_process_group(
        backend="nccl",
        timeout=timeout_duration 
    )
    args = get_args()
    deepspeed.init_distributed()
    rank, world = deepspeed.comm.get_rank(), deepspeed.comm.get_world_size()
    device = torch.device("cuda", deepspeed.comm.get_local_rank())
    # Tokenizer 
    tokenizer = AutoTokenizer.from_pretrained(args.student,
                                        use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        
    # Make the eos/pad token ids explicit on the tokenizer
    tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
    tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
    
    
    # Teacher (inference only)    
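    # The teacher is materialized on CPU first (device_map=None, low_cpu_mem_usage=True);
    # deepspeed.init_inference below shards it across all ranks with mp_size=world.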
    teacher_model = AutoModelForCausalLM.from_pretrained(
        args.teacher, torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True, device_map=None, 
        cache_dir=args.cache_path,token=args.hf_token) 
    
    see_memory_usage("After load model", force=True)
    
    teacher_model.config.eos_token_id = tokenizer.eos_token_id
    teacher_model.config.pad_token_id = tokenizer.pad_token_id
        
    teacher_engine = deepspeed.init_inference(
        teacher_model,
        mp_size=world,
        dtype=torch.bfloat16,
        replace_with_kernel_inject=True, 
        replace_method="auto")
    
    see_memory_usage("After DS-inference init", force=True)

    teacher_engine.module.eval()
    teacher_engine.optimizer = None

    # Student
    student_model = AutoModelForCausalLM.from_pretrained(
        args.student, torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        trust_remote_code=True, cache_dir=args.cache_path,token=args.hf_token)
    
    student_model.config.eos_token_id = tokenizer.eos_token_id
    student_model.config.pad_token_id = tokenizer.pad_token_id

    # Dataset 
    ds = [
        load_dataset("Raphael21/LogicKor_Aug_small_v2", split="train", data_dir="v0.1.1", streaming=False)
    ]

    ds = concatenate_datasets(ds).select_columns(["messages"])
    total_samples = args.max_samples or len(ds)
    total_steps = args.epochs * math.ceil(total_samples / args.batch_size)
    
    # Deepspeed Config 
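    # ZeRO-3 shards the student's parameters, gradients, and optimizer states across
    # all ranks, and additionally offloads parameters and optimizer states to CPU.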
    ds_cfg = {
        "train_batch_size": args.batch_size * world,
        "gradient_accumulation_steps": 1,
        "bf16": {"enabled": True},
        "zero_optimization": {
            "stage": 3,  
            "stage3_max_live_parameters": 1e9,
            "stage3_prefetch_bucket_size": 5e8,
            "stage3_param_persistence_threshold": 1e4,
            "overlap_comm": True,
            "contiguous_gradients": True,
            "allgather_bucket_size": 5e8,
            "reduce_bucket_size": 5e8,
            "offload_optimizer": {
                "device": "cpu", 
                "pin_memory": True
            },
            "offload_param": {
                "device": "cpu", 
                "pin_memory": True
            },
        },
        "activation_checkpointing": {
            "partition_activations": True,
            "contiguous_memory_optimization": True
        },
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": args.lr,
                "betas": [0.9, 0.999],
                "eps": 1e-8,
                "weight_decay": 0.01
            }
        },
       
        "scheduler": {
            "type": "WarmupLR",
            "params": {
                "warmup_min_lr": 0,
                "warmup_max_lr": args.lr,
                "warmup_num_steps": int(0.1 * total_steps)
            }
        }
    }


    # This is the call that fails with torch.distributed.DistBackendError.
    student_model = deepspeed.tp_model_init(
        model=student_model,
        tp_size=16,  # tensor-parallel degree (2 nodes x 8 GPUs)
        dtype=torch.bfloat16
    )

    # Build the DeepSpeed training engine for the student from ds_cfg (ZeRO-3).
    student_engine, _, _, _ = deepspeed.initialize(
        model=student_model,
        model_parameters=student_model.parameters(),
        config=ds_cfg
    )

    if not hasattr(student_engine, "optimizer"):
        student_engine.optimizer = None
        
    # Debug Messages
    if rank == 0:
        print("Configured with ZeRO-3, total_steps:", total_steps)
        
    # Data Loader
    def preprocess_batch(examples):
        prompt_key = "prompt"
        messages_key = "messages"
        ignore_index = -100
        max_length = min(tokenizer.model_max_length, 4096)

        def get_tokens_from_chat_template(messages_dict_or_str, add_gen_prompt, max_len=max_length):
            
            tokens = tokenizer.apply_chat_template(
                messages_dict_or_str,
                tokenize=True,             
                add_generation_prompt=add_gen_prompt,
                truncation=True,
                max_length=max_len,
                padding="do_not_pad",      
                return_tensors=None,
            )
           
            return tokens, [1] * len(tokens) 
        
        results = {
            "input_ids": [],
            "attention_mask": [],
            "labels": [],
            "prompts": [],
            "prompt_attention_mask": [],
        }

        prompts = examples.get(prompt_key, [None] * len(examples[messages_key]))
        for message, prompt in zip(examples[messages_key], prompts):
            if prompt is None:
                prompt_messages = message[:-1] 
                prompt_ids, prompt_attn = get_tokens_from_chat_template(prompt_messages, add_gen_prompt=True)
            else:
                prompt_ids, prompt_attn = get_tokens_from_chat_template(prompt, add_gen_prompt=True)
                
            input_ids, attn_mask = get_tokens_from_chat_template(message, add_gen_prompt=False)
            label = [ignore_index] * len(input_ids) 
            start_idx = len(prompt_ids)

            if start_idx < len(input_ids):
                for i in range(start_idx, len(input_ids)):
                    label[i] = input_ids[i]            
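            # labels keep -100 over the prompt tokens (ignored by the CE loss);
            # only response tokens are supervised / distilled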

            results["input_ids"].append(input_ids)
            results["attention_mask"].append(attn_mask)
            results["labels"].append(label)
            results["prompts"].append(prompt_ids)
            results["prompt_attention_mask"].append(prompt_attn)
            
        return results

    def chatml_collate_fn(batch, pad_token_id=0, ignore_index=-100, max_length=4096):
        def pad_and_truncate(seqs, pad_val, max_length):
            return torch.stack([
                torch.tensor(seq[:max_length] + [pad_val] * (max_length - len(seq)), dtype=torch.long)
                for seq in seqs
            ])

        fields = ["input_ids", "attention_mask", "labels", "prompts", "prompt_attention_mask"]
        pad_values = [pad_token_id, 0, ignore_index, pad_token_id, 0]

        return {
            field: pad_and_truncate([ex[field] for ex in batch], pad_val, max_length)
            for field, pad_val in zip(fields, pad_values)
        }


    if args.max_samples:
        ds = ds.select(range(args.max_samples))

    ds = ds.map(preprocess_batch, batched=True)

    loader = torch.utils.data.DataLoader(
        ds,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        collate_fn=lambda x: chatml_collate_fn(
            x,
            pad_token_id=tokenizer.pad_token_id,
            ignore_index=-100,
            max_length=args.seq_len  
        )
    )
   
    T = args.temperature

    for epoch in range(args.epochs):
        for step, batch in enumerate(loader):
            prompt_lengths_batch = batch["prompt_attention_mask"].sum(dim=1).cpu().tolist()
            prompt_lengths_tensor = torch.tensor(prompt_lengths_batch, device=device, dtype=torch.long)
            
            input_ids = batch["input_ids"].to(device)
            attn = batch["attention_mask"].to(device)
            labels_batch = batch["labels"].to(device)

            with torch.no_grad():
                teacher_logits = teacher_engine.module( 
                    input_ids=input_ids,
                    attention_mask=attn,
                    use_cache=False
                ).logits
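            # teacher runs inference-only: no grad graph and no KV cache (use_cache=False)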

            student_logits = student_engine(
                input_ids=input_ids,
                attention_mask=attn,
                use_cache=False
            ).logits
            
            
            if rank == 0:
                sample_idx_to_inspect = 0
                original_input_ids = batch["input_ids"][sample_idx_to_inspect].cpu().tolist()
                original_labels_list = batch["labels"][sample_idx_to_inspect].cpu().tolist() 

                # Student Model (argmax)
                student_predictions_ids = student_logits[sample_idx_to_inspect].argmax(dim=-1).cpu().tolist()
                decoded_student_predictions = [tokenizer.decode([t], skip_special_tokens=False) for t in student_predictions_ids]

                # Teacher Model (argmax)
                teacher_predictions_ids = teacher_logits[sample_idx_to_inspect].argmax(dim=-1).cpu().tolist()
                decoded_teacher_predictions = [tokenizer.decode([t], skip_special_tokens=False) for t in teacher_predictions_ids]

                print(f"Decoded Student Predictions: {''.join(decoded_student_predictions[:100])} ...")
                print(f"Decoded Teacher Predictions: {''.join(decoded_teacher_predictions[:100])} ...")

            shifted_student_logits = student_logits[:, :-1, :] 
            shifted_teacher_logits = teacher_logits[:, :-1, :]
            shifted_labels = labels_batch[:, 1:] 
            shifted_attention_mask = attn[:, 1:] 
            
            current_seq_len = shifted_labels.size(1) 
            response_mask = torch.zeros_like(shifted_labels, dtype=torch.bool)

            for i in range(shifted_labels.size(0)):
                start_response_idx_in_shifted = max(0, int(prompt_lengths_tensor[i].item()) - 1)
                if start_response_idx_in_shifted < current_seq_len:
                    response_mask[i, start_response_idx_in_shifted:] = True
                    
            shifted_attention_mask = shifted_attention_mask & response_mask        
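            # after this, only response tokens (positions past the prompt) contribute to the KD loss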
                 
            # Apply temperature scaling
            student_logits_scaled = shifted_student_logits / args.temperature
            teacher_logits_scaled = shifted_teacher_logits / args.temperature

            # Compute log probabilities for student and probabilities for teacher
            student_log_probs = F.log_softmax(student_logits_scaled, dim=-1)
            teacher_log_probs = F.log_softmax(teacher_logits_scaled, dim=-1)
            
            kd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
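            # with log_target=True this is p_teacher * (log p_teacher - log p_student) per vocab
            # entry, i.e. an elementwise KL(teacher || student), summed over the vocabulary below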
            
            kd_loss_per_token = kd_loss.sum(dim=-1) 
            
            valid_labels_mask = (shifted_labels != -100) 
            combined_mask = shifted_attention_mask & valid_labels_mask 
            masked_kd_loss = kd_loss_per_token * combined_mask
            
            
            num_valid_tokens = combined_mask.sum()
            if num_valid_tokens > 0:
                kd_loss = masked_kd_loss.sum() / num_valid_tokens
            else: 
                kd_loss = torch.tensor(0.0, device=device, requires_grad=True) 
            
            # Cross-Entropy Loss 
            ce_loss = F.cross_entropy(
                shifted_student_logits.view(-1, shifted_student_logits.size(-1)), # (B*S, V)
                shifted_labels.view(-1),
                ignore_index=-100 
            )
            if ce_loss.numel() == 0 or torch.isnan(ce_loss):
                ce_loss = torch.tensor(0.0, device=device, requires_grad=True)
                
            alpha = 0.5 
            total_loss = alpha * ce_loss + (1 - alpha) * kd_loss

            student_engine.backward(total_loss)
            student_engine.step()
            
            # empty cache
            torch.cuda.empty_cache()

            if rank == 0 and step % 1 == 0:
                print(f"[Epoch {epoch}][{step}/{len(loader)}] total_loss = {total_loss.item():.4f}, ce_loss = {ce_loss.item():.4f}, kd_loss = {kd_loss.item():.4f}")
               
    # Save Checkpoint 
    student_engine.save_checkpoint("./save_checkpoint")

if __name__ == "__main__":
    main()

