I tried HESSO with YOLOv8 on the VOC dataset. I followed the sanity check and the ResNet tutorial as references, and this is the final code I ended up with:
import torch
import torch.nn as nn
from pathlib import Path
from ultralytics import YOLO
from ultralytics.data.build import build_yolo_dataset, build_dataloader
from ultralytics.data.utils import check_det_dataset
from ultralytics.utils import DEFAULT_CFG, LOGGER
from ultralytics.cfg import get_cfg
from tqdm import tqdm
# Import OTO components
from only_train_once import OTO
from only_train_once.quantization.quant_model import model_to_quantize_model
from only_train_once.quantization.quant_layers import QuantizationMode
# --- Configuration ---
class Config:
    dataset_yaml = 'VOC.yaml'
    imgsz = 640
    batch_size = 16
    stride = 32
    workers = 8
    epochs = 3
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    learning_rate = 1e-3
    weight_decay = 1e-4
    target_sparsity = 0.5
    multi_scale = False

config = Config()
def create_train_val_data(yaml_file, batch_size=16, imgsz=640, stride=32, workers=0):
    """Create train/val dataloaders for YOLOv8."""
    try:
        LOGGER.info(f"Checking dataset: {yaml_file}...")
        data_config = check_det_dataset(yaml_file)
        if 'path' not in data_config:
            data_config['path'] = Path(yaml_file).parent.parent.as_posix()
    except Exception as e:
        LOGGER.error(f"Failed to load/download '{yaml_file}': {e}")
        LOGGER.error("Please ensure your YAML file is correct and accessible.")
        return None, None, None, None, None
    LOGGER.info("Dataset check complete.")

    args = DEFAULT_CFG
    args.imgsz = imgsz
    train_dataset, train_dataloader = None, None
    val_dataset, val_dataloader = None, None

    # Training data
    if 'train' in data_config:
        train_path = data_config['train']
        LOGGER.info(f"Building training dataset from: {train_path}")
        train_dataset = build_yolo_dataset(
            cfg=args,
            img_path=train_path,
            batch=batch_size,
            data=data_config,
            mode='train',
            stride=stride
        )
        LOGGER.info("Building training dataloader...")
        train_dataloader = build_dataloader(
            dataset=train_dataset,
            batch=batch_size,
            workers=workers
        )

    # Validation data
    if 'val' in data_config:
        val_path = data_config['val']
        LOGGER.info(f"Building validation dataset from: {val_path}")
        val_dataset = build_yolo_dataset(
            cfg=args,
            img_path=val_path,
            batch=batch_size,
            data=data_config,
            mode='val',
            stride=stride
        )
        LOGGER.info("Building validation dataloader...")
        val_dataloader = build_dataloader(
            dataset=val_dataset,
            batch=batch_size,
            workers=workers
        )

    return train_dataset, train_dataloader, val_dataset, val_dataloader, data_config
def preprocess_batch(batch: dict, multi_scale: bool = False, stride: int = 32) -> dict:
    """Preprocess a batch of images for YOLOv8 training."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Move all tensors to device
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.to(device, non_blocking=device.type == "cuda")

    # Normalize images
    batch["img"] = batch["img"].float() / 255.0

    # Multi-scale training (optional)
    if multi_scale:
        import random
        import math
        imgs = batch["img"]
        sz = (
            random.randrange(int(640 * 0.5), int(640 * 1.5 + 16))
            // stride * stride
        )
        sf = sz / max(imgs.shape[2:])
        if sf != 1:
            ns = [
                math.ceil(x * sf / 32) * 32 for x in imgs.shape[2:]
            ]
            imgs = nn.functional.interpolate(
                imgs, size=ns, mode="bilinear", align_corners=False
            )
        batch["img"] = imgs

    return batch
def setup_model_and_oto():
    """Setup YOLOv8 model with OTO and HESSO optimizer."""
    # Load YOLOv8 model
    model = YOLO('yolov8n.pt')
    model = model.model

    # Freeze specific layers
    freeze_list = []
    always_freeze_names = [".dfl"]
    freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
    for k, v in model.named_parameters():
        if any(x in k for x in freeze_layer_names):
            v.requires_grad = False
        elif not v.requires_grad and v.dtype.is_floating_point:
            v.requires_grad = True

    # Convert to quantized model
    # model = model_to_quantize_model(
    #     model,
    #     quant_mode=QuantizationMode.WEIGHT_AND_ACTIVATION
    # )

    # Setup OTO
    dummy_input = torch.randn(1, 3, 640, 640)
    oto = OTO(model=model, dummy_input=dummy_input)

    # Mark unprunable parameters (YOLOv8 detection head)
    oto.mark_unprunable_by_param_names([
        'model.22.dfl.conv.weight',
        'model.22.cv3.2.2.weight',
        'model.22.cv2.2.2.weight',
        'model.22.cv2.1.2.weight',
        'model.22.cv3.1.2.weight',
        'model.22.cv3.0.2.weight',
        'model.22.cv2.0.2.weight'
    ])

    # Mark slice operations as unprunable
    for node_group in oto._graph.node_groups.values():
        for node in node_group:
            if node.op_name == 'slice':
                node_group.is_prunable = False

    # Setup model configuration
    model.args = get_cfg(DEFAULT_CFG, None)
    model.to(config.device)
    model.train()
    return model, oto
"""Main training function"""
LOGGER.info("Initializing YOLOv8 with HESSO optimizer...")
# Setup model and OTO
model, oto = setup_model_and_oto()
# Get dataloaders
train_dataset, train_dataloader, val_dataset, val_dataloader, data_config = create_train_val_data(
config.dataset_yaml,
batch_size=config.batch_size,
imgsz=config.imgsz,
stride=config.stride,
workers=config.workers
)
if train_dataloader is None:
LOGGER.error("Failed to create training dataloader. Exiting.")
# Create HESSO optimizer
optimizer = oto.hesso(
variant="sgd",
lr=0.01, # Main model LR
first_momentum=0.9,
weight_decay=1e-4,
target_group_sparsity=0.5,
start_pruning_step = 0 * len(train_dataloader), # Start pruning at epoch 5
pruning_periods = 5, # Spread the pruning over 5 periods
pruning_steps = config.epochs * len(train_dataloader)
)
LOGGER.info("Starting training with HESSO optimizer...")
lf = lambda x: max(1 - x / config.epochs, 0) * (1.0 - 0.01) + 0.01
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
# Training loop
for epoch in range(config.epochs):
model.train()
progress_bar = tqdm(
enumerate(train_dataloader),
total=len(train_dataloader),
desc=f"Epoch {epoch+1}/{config.epochs}"
)
epoch_loss = 0.0
lr_scheduler.step()
for batch_idx, batch in progress_bar:
# Preprocess batch
batch = preprocess_batch(batch, config.multi_scale, config.stride)
# Forward pass and loss computation
loss, loss_items = model(batch)
loss = loss.sum()
# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Update progress
epoch_loss += loss.item()
progress_bar.set_postfix({
'Loss': f'{loss.item():.4f}',
'Avg Loss': f'{epoch_loss/(batch_idx+1):.4f}'
})
# End of epoch logging
LOGGER.info(
f"Epoch {epoch+1} completed. "
f"Average Loss: {epoch_loss/len(train_dataloader):.4f}"
)
LOGGER.info("Training completed!")
oto.construct_subnet(out_dir='./cache')
import os
full_model_size = os.stat(oto.full_group_sparse_model_path)
compressed_model_size = os.stat(oto.compressed_model_path)
print("Size of full model : ", full_model_size.st_size / (1024 ** 3), "GBs")
print("Size of compress model : ", compressed_model_size.st_size / (1024 ** 3), "GBs")
full_model = torch.load(oto.full_group_sparse_model_path)
compressed_model = torch.load(oto.compressed_model_path)But after I run
compressed_model(torch.rand(1,3,640,640).cuda())
I get this error:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[23], line 1
----> 1 compressed_model(torch.rand(1,3,640,640).cuda())

File /torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /ultralytics/nn/tasks.py:139, in BaseModel.forward(self, x, *args, **kwargs)
    137 if isinstance(x, dict):  # for cases of training and validating while training.
    138     return self.loss(x, *args, **kwargs)
--> 139 return self.predict(x, *args, **kwargs)

File /ultralytics/nn/tasks.py:157, in BaseModel.predict(self, x, profile, visualize, augment, embed)
    155 if augment:
    156     return self._predict_augment(x)
--> 157 return self._predict_once(x, profile, visualize, embed)

File /ultralytics/nn/tasks.py:180, in BaseModel._predict_once(self, x, profile, visualize, embed)
    178 if profile:
    179     self._profile_one_layer(m, x, dt)
--> 180 x = m(x)  # run
    181 y.append(x if m.i in self.save else None)  # save output
    182 if visualize:

File /torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /ultralytics/nn/modules/block.py:319, in C2f.forward(self, x)
    317 y = list(self.cv1(x).chunk(2, 1))
    318 y.extend(m(y[-1]) for m in self.m)
--> 319 return self.cv2(torch.cat(y, 1))

File /torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /ultralytics/nn/modules/conv.py:93, in Conv.forward_fuse(self, x)
     83 def forward_fuse(self, x):
     84     """
     85     Apply convolution and activation without batch normalization.
     86
    (...)
     91         (torch.Tensor): Output tensor.
     92     """
---> 93     return self.act(self.conv(x))

File /torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /torch/nn/modules/conv.py:460, in Conv2d.forward(self, input)
    459 def forward(self, input: Tensor) -> Tensor:
--> 460     return self._conv_forward(input, self.weight, self.bias)

File /torch/nn/modules/conv.py:456, in Conv2d._conv_forward(self, input, weight, bias)
    452 if self.padding_mode != 'zeros':
    453     return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    454                     weight, bias, self.stride,
    455                     _pair(0), self.dilation, self.groups)
--> 456 return F.conv2d(input, weight, bias, self.stride,
    457                 self.padding, self.dilation, self.groups)

RuntimeError: Given groups=1, weight of size [112, 192, 1, 1], expected input[1, 177, 40, 40] to have 192 channels, but got 177 channels instead
I am assuming that the number of channels changed after channel pruning, which is causing this issue. Can you please help resolve this, or tell me if I have done something wrong in my code?
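For reference, a small diagnostic along these lines should show exactly which layers are affected (a minimal sketch using standard PyTorch forward pre-hooks; find_channel_mismatches is just an illustrative helper name, and it assumes the compressed_model loaded above):

import torch
import torch.nn as nn

def find_channel_mismatches(model, dummy_input):
    """Run one forward pass and report Conv2d layers whose incoming channel
    count no longer matches what their (pruned) weights expect."""
    mismatches = []
    handles = []

    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            # Pre-hook fires before the conv runs, so the failing layer is
            # still recorded before the RuntimeError is raised.
            def hook(mod, inputs, name=name):
                expected = mod.weight.shape[1] * mod.groups
                actual = inputs[0].shape[1]
                if expected != actual:
                    mismatches.append((name, expected, actual))
            handles.append(module.register_forward_pre_hook(hook))

    try:
        model(dummy_input)
    except RuntimeError as e:
        print("Forward stopped with:", e)
    finally:
        for h in handles:
            h.remove()

    for name, expected, actual in mismatches:
        print(f"{name}: weights expect {expected} channels, got {actual}")
    return mismatches

# Hypothetical usage, mirroring the failing call above:
# find_channel_mismatches(compressed_model, torch.rand(1, 3, 640, 640).cuda())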