4 changes: 3 additions & 1 deletion .gitignore
@@ -1,4 +1,6 @@
launch.json
__pycache__
voxcpm.egg-info
.DS_Store
.DS_Store
lora/
models/
132 changes: 131 additions & 1 deletion scripts/train_voxcpm_finetune.py
@@ -14,6 +14,8 @@
from tensorboardX import SummaryWriter
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup
import signal
import os

try:
from safetensors.torch import save_file
@@ -171,6 +173,39 @@ def tokenize(batch):
num_training_steps=total_training_steps,
)

# Try to load checkpoint and resume training
start_step = 0
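# Only rank 0 reads the checkpoint from disk; the resolved step is shared with the other ranks below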
if accelerator.rank == 0:
start_step = load_checkpoint(model, optimizer, scheduler, save_dir)
# Share start_step with all processes (only rank 0 loaded the checkpoint; the other ranks start from 0)
if hasattr(accelerator, 'all_reduce'):
start_step_tensor = torch.tensor(start_step, device=accelerator.device)
accelerator.all_reduce(start_step_tensor)
start_step = int(start_step_tensor.item())

if start_step > 0 and accelerator.rank == 0:
tracker.print(f"Resuming training from step {start_step}")

# Resume tracker for signal handler to read current step
resume = {"step": start_step}

# Register signal handler to save checkpoint on termination (SIGTERM/SIGINT)
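# The live objects are bound via default arguments so the handler does not depend on module-level globals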
def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _resume=resume):
try:
cur_step = int(_resume.get("step", start_step))
except Exception:
cur_step = start_step
print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...")
try:
save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained)
print("Checkpoint saved. Exiting.")
except Exception as e:
print(f"Error saving checkpoint on signal: {e}")
os._exit(0)

signal.signal(signal.SIGTERM, _signal_handler)
signal.signal(signal.SIGINT, _signal_handler)
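# SIGTERM covers scheduler-initiated shutdowns (e.g. preemption or a timeout); SIGINT covers Ctrl+C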

# Manual epoch management instead of itertools.cycle to support DistributedSampler.set_epoch()
grad_accum_steps = max(int(grad_accum_steps), 1)
data_epoch = 0
@@ -191,7 +226,9 @@ def get_next_batch():
return next(train_iter)

with tracker.live():
for step in range(num_iters):
for step in range(start_step, num_iters):
# update the resume step so the signal handler can save current progress
resume["step"] = step
tracker.step = step
optimizer.zero_grad(set_to_none=True)

@@ -301,6 +338,76 @@ def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas):
model.train()


def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
"""
Load the latest checkpoint if it exists.
Returns the step number to resume from, or 0 if no checkpoint found.
"""
latest_folder = save_dir / "latest"
if not latest_folder.exists():
return 0

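# Unwrap a DDP-style wrapper, if present, to reach the underlying model and its lora_config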
unwrapped = model.module if hasattr(model, "module") else model
lora_cfg = unwrapped.lora_config

# Load model weights
if lora_cfg is not None:
# LoRA: load lora_weights
lora_weights_path = latest_folder / "lora_weights.safetensors"
if not lora_weights_path.exists():
lora_weights_path = latest_folder / "lora_weights.ckpt"

if lora_weights_path.exists():
if lora_weights_path.suffix == ".safetensors":
from safetensors.torch import load_file
state_dict = load_file(str(lora_weights_path))
else:
ckpt = torch.load(lora_weights_path, map_location="cpu")
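# accept both a raw state_dict and a wrapped {"state_dict": ...} checkpoint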
state_dict = ckpt.get("state_dict", ckpt)

# Load only the LoRA weights; strict=False leaves the remaining base-model parameters untouched
unwrapped.load_state_dict(state_dict, strict=False)
print(f"Loaded LoRA weights from {lora_weights_path}")
else:
# Full finetune: load model.safetensors or pytorch_model.bin
model_path = latest_folder / "model.safetensors"
if not model_path.exists():
model_path = latest_folder / "pytorch_model.bin"

if model_path.exists():
if model_path.suffix == ".safetensors":
from safetensors.torch import load_file
state_dict = load_file(str(model_path))
else:
ckpt = torch.load(model_path, map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)

unwrapped.load_state_dict(state_dict, strict=False)
print(f"Loaded model weights from {model_path}")

# Load optimizer state
optimizer_path = latest_folder / "optimizer.pth"
if optimizer_path.exists():
optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
print(f"Loaded optimizer state from {optimizer_path}")

# Load scheduler state
scheduler_path = latest_folder / "scheduler.pth"
if scheduler_path.exists():
scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
print(f"Loaded scheduler state from {scheduler_path}")

# The latest checkpoint does not record its step, so infer it from the step_* folder names
step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")]
if step_folders:
steps = [int(d.name.split("_")[1]) for d in step_folders]
resume_step = max(steps)
print(f"Resuming from step {resume_step}")
return resume_step

return 0


def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None):
"""
Save checkpoint with different strategies for full finetune vs LoRA:
@@ -345,6 +452,29 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None):
torch.save(optimizer.state_dict(), folder / "optimizer.pth")
torch.save(scheduler.state_dict(), folder / "scheduler.pth")

# Update (or create) a `latest` symlink pointing to the most recent checkpoint folder
latest_link = save_dir / "latest"
try:
if latest_link.exists() or latest_link.is_symlink():
# remove existing link or directory
if latest_link.is_dir() and not latest_link.is_symlink():
shutil.rmtree(latest_link)
else:
latest_link.unlink()
# Create a symlink pointing to the new folder
os.symlink(str(folder), str(latest_link))
except Exception:
# If symlink creation fails (e.g., on Windows or permission issues), fall back to copying
try:
if latest_link.exists():
if latest_link.is_dir():
shutil.rmtree(latest_link)
else:
latest_link.unlink()
shutil.copytree(folder, latest_link)
except Exception:
print(f"Warning: failed to update latest checkpoint link at {latest_link}")


if __name__ == "__main__":
from voxcpm.training.config import load_yaml_config