
[QEff. Finetune]: Added support to sync gradients across devices during backward step only. #477

Open · wants to merge 2 commits into base: main
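
This change makes the training loop set DDP's `require_backward_grad_sync` flag per micro-step, so gradients are all-reduced across devices only on the micro-step where the optimizer actually runs; this is the same flag that `model.no_sync()` toggles internally. A minimal sketch of the pattern, assuming an already-initialized process group and a DDP-wrapped model (`accum_steps`, `dataloader`, and `optimizer` are illustrative placeholders, and the loss normalization is standard practice rather than part of this diff):

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def run_accumulation_steps(model: DDP, dataloader, optimizer, accum_steps: int):
    for step, batch in enumerate(dataloader):
        # Sync gradients across ranks only when this micro-step closes an
        # accumulation window (or the dataloader is exhausted).
        is_backward_step = (step + 1) % accum_steps == 0 or step == len(dataloader) - 1
        model.require_backward_grad_sync = is_backward_step

        loss = model(**batch).loss / accum_steps
        loss.backward()  # the gradient all-reduce fires here only when sync is enabled

        if is_backward_step:
            optimizer.step()
            optimizer.zero_grad()
```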
41 changes: 41 additions & 0 deletions QEfficient/finetune/utils/helper.py
@@ -5,7 +5,48 @@
 #
 # -----------------------------------------------------------------------------
 
+from contextlib import nullcontext
+
+import torch
+
+try:
+    import torch_qaic.debug as qaic_debug  # noqa: F401
+except ImportError as e:
+    print(f"Warning: {e}. Moving ahead without these qaic modules.")
+
+
 TASK_TYPE = ["generation", "seq_classification"]
 PEFT_METHOD = ["lora"]
 DEVICE = ["qaic", "cpu", "cuda"]
 BATCHING_STRATEGY = ["padding", "packing"]
+
+
+def get_autocast_ctx(use_autocast, device_type, dtype=torch.float16):
+    return torch.autocast(device_type=device_type, dtype=dtype) if use_autocast else nullcontext()
+
+
+def get_op_verifier_ctx(
+    use_op_by_op_verifier,
+    train_device,
+    dump_dir,
+    step,
+    ref_device="cpu",
+    ref_dtype=torch.float32,
+    atol=1e-1,
+    rtol=1e-5,
+    use_ref_output_on_mismatch=True,
+):
+    if not use_op_by_op_verifier:
+        return nullcontext()
+
+    filter_config = qaic_debug.DispatchFilterConfig.default(train_device)
+    dump_dir = dump_dir + "_" + str(step)
+    return qaic_debug.OpByOpVerifierMode(
+        ref_device=ref_device,
+        ref_dtype=ref_dtype,
+        atol=atol,
+        rtol=rtol,
+        use_ref_output_on_mismatch=use_ref_output_on_mismatch,
+        filter_config=filter_config,
+        dump_root_dir=dump_dir,
+    )
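
For context, `train_utils.py` below builds these contexts once and reuses them every step: the autocast context directly, and the op-by-op verifier through `functools.partial` with the step number supplied per iteration. A self-contained usage sketch (the config values, toy model, and dump directory are illustrative placeholders, not part of this PR):

```python
from functools import partial

import torch

from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx

# Stand-ins for the real train_config / device settings.
use_autocast, use_op_verifier = True, False
device, device_type, dump_dir = "cpu", "cpu", "./op_verify_dump"

# Built once before the loop; each helper returns a no-op nullcontext when disabled.
autocast_ctx = get_autocast_ctx(use_autocast, device_type, dtype=torch.bfloat16)
op_verifier_ctx = partial(get_op_verifier_ctx, use_op_verifier, device, dump_dir)

model = torch.nn.Linear(8, 2)
batch = torch.randn(4, 8)

for step in range(3):
    with autocast_ctx:
        with op_verifier_ctx(step):  # nullcontext here, since the verifier is disabled
            out = model(batch)
print(out.dtype)  # bfloat16 under CPU autocast; float32 if use_autocast were False
```

bfloat16 is used in the sketch because CPU autocast primarily targets bfloat16; the training code itself passes `torch.float16` on the qaic device.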
54 changes: 23 additions & 31 deletions QEfficient/finetune/utils/train_utils.py
@@ -8,8 +8,8 @@
 import json
 import os
 import time
-from contextlib import nullcontext
 from datetime import datetime
+from functools import partial
 from typing import Dict, List, Tuple
 
 import torch
@@ -19,6 +19,7 @@
 from tqdm import tqdm
 
 from QEfficient.finetune.configs.training import TrainConfig
+from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx
 
 try:
     import torch_qaic  # noqa: F401
@@ -110,6 +111,9 @@ def train(
         num_classes = model.classifier.out_features
         acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)
 
+    autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
+    op_verifier_ctx = partial(get_op_verifier_ctx, train_config.opByOpVerifier, device, train_config.dump_root_dir)
+
     # Start the training loop
     for epoch in range(train_config.num_epochs):
         if loss_0_counter.item() == train_config.convergence_counter:
@@ -174,38 +178,29 @@ def train(
                 break
             batch = {k: v.to(device) for k, v in batch.items()}  # move the batch elements to qaic device
 
-            with (
-                torch.autocast(device_type=device_type, dtype=torch.float16)
-                if train_config.use_autocast
-                else nullcontext()
-            ):
+            is_backward_step = (step + 1) % train_config.gradient_accumulation_steps == 0 or step == len(
+                train_dataloader
+            ) - 1
+            if train_config.enable_ddp:
+                # Below block derived from : https://github.com/karpathy/nanoGPT/blob/93a43d9a5c22450bbf06e78da2cb6eeef084b717/train.py#L293
+                # in DDP training we only need to sync gradients at the last micro step.
+                # the official way to do this is with model.no_sync() context manager, but
+                # using too many context managers may bloat the code and forces us to repeat code
+                # looking at the source of that context manager, it just toggles this variable
+                model.require_backward_grad_sync = is_backward_step
+
+            with autocast_ctx:
                 # an additional condition can be put here to avoid opByOpVerifier getting triggered for each step
-                if train_config.opByOpVerifier:
-                    with qaic_debug.OpByOpVerifierMode(
-                        ref_device="cpu",
-                        ref_dtype=torch.float32,
-                        # adjust atol & rtol this as required
-                        atol=1e-1,
-                        use_ref_output_on_mismatch=True,
-                        filter_config=qaic_debug.DispatchFilterConfig.default(device),
-                        dump_root_dir=train_config.dump_root_dir + str(step),
-                    ) as verifier:
-                        model_outputs = model(**batch)
-                        loss = model_outputs.loss  # Forward call
-                        if train_config.task_type == "seq_classification":
-                            logits = model_outputs.logits
-                            labels = batch["labels"][:, 0]
-                            preds = torch.nn.functional.softmax(logits, dim=-1)
-                            acc_helper.forward(preds, labels)
-                    print("Mismatches detected:", verifier.get_perop_mismatch_count())
-                else:
+                with op_verifier_ctx(step) as verifier:
                     model_outputs = model(**batch)
                     loss = model_outputs.loss  # Forward call
                     if train_config.task_type == "seq_classification":
                         logits = model_outputs.logits
                         labels = batch["labels"][:, 0]
                         preds = torch.nn.functional.softmax(logits, dim=-1)
                         acc_helper.forward(preds, labels)
+                if train_config.opByOpVerifier:
+                    print("Mismatches detected:", verifier.get_perop_mismatch_count())
 
             total_loss += loss.detach().float()
             # Accumalate gradients
@@ -242,7 +237,7 @@ def train(
             else:
                 loss.backward()  # backward pass
 
-            if (step + 1) % train_config.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+            if is_backward_step:
                 if train_config.grad_scaler:
                     scaler.step(optimizer)
                     scaler.update()
@@ -421,6 +416,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
     eval_loss = 0.0  # Initialize evaluation loss
     device_type = torch.device(device).type
 
+    autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
     for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
         # stop when the maximum number of eval steps is reached
         if train_config.max_eval_step > 0 and step > train_config.max_eval_step:
@@ -431,11 +427,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
         # Ensure no gradients are computed for this scope to save memory
         with torch.no_grad():
             # Forward pass and compute loss
-            with (
-                torch.autocast(device_type=device_type, dtype=torch.float16)
-                if train_config.use_autocast
-                else nullcontext()
-            ):
+            with autocast_ctx:
                 outputs = model(**batch)
                 loss = outputs.loss
 