
HESSO with YOLOv8 #13


Description

@sumeshp99

I tried HESSO with YOLOv8 on the VOC dataset. I followed the sanity check and the ResNet tutorial as references, and this is the final code I ended up with.

import torch
import torch.nn as nn
from pathlib import Path
from ultralytics import YOLO
from ultralytics.data.build import build_yolo_dataset, build_dataloader
from ultralytics.data.utils import check_det_dataset
from ultralytics.utils import DEFAULT_CFG, LOGGER
from ultralytics.cfg import get_cfg
from tqdm import tqdm

# Import OTO components
from only_train_once import OTO
from only_train_once.quantization.quant_model import model_to_quantize_model
from only_train_once.quantization.quant_layers import QuantizationMode

# --- Configuration ---
class Config:
    dataset_yaml = 'VOC.yaml'
    imgsz = 640
    batch_size = 16
    stride = 32
    workers = 8
    epochs = 3
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    learning_rate = 1e-3
    weight_decay = 1e-4
    target_sparsity = 0.5
    multi_scale = False

config = Config()

def create_train_val_data(yaml_file, batch_size=16, imgsz=640, stride=32, workers=0):
    """Create train/val dataloaders for YOLOv8"""
    try:
        LOGGER.info(f"Checking dataset: {yaml_file}...")
        data_config = check_det_dataset(yaml_file)
        if 'path' not in data_config:
            data_config['path'] = Path(yaml_file).parent.parent.as_posix()
    except Exception as e:
        LOGGER.error(f"Failed to load/download '{yaml_file}': {e}")
        LOGGER.error("Please ensure your YAML file is correct and accessible.")
        return None, None, None, None, None

    LOGGER.info("Dataset check complete.")

    # Shared dataset args (copy the defaults so we don't mutate the global DEFAULT_CFG)
    args = get_cfg(DEFAULT_CFG)
    args.imgsz = imgsz
    
    train_dataset, train_dataloader = None, None
    val_dataset, val_dataloader = None, None

    if 'train' in data_config:
        train_path = data_config['train']
        LOGGER.info(f"Building training dataset from: {train_path}")
        train_dataset = build_yolo_dataset(
            cfg=args,
            img_path=train_path,
            batch=batch_size,
            data=data_config,
            mode='train',
            stride=stride
        )
        
        LOGGER.info("Building training dataloader...")
        train_dataloader = build_dataloader(
            dataset=train_dataset,
            batch=batch_size,
            workers=workers
        )

    # Validation data
    if 'val' in data_config:
        val_path = data_config['val']
        LOGGER.info(f"Building validation dataset from: {val_path}")
        val_dataset = build_yolo_dataset(
            cfg=args,
            img_path=val_path,
            batch=batch_size,
            data=data_config,
            mode='val',
            stride=stride
        )
        
        LOGGER.info("Building validation dataloader...")
        val_dataloader = build_dataloader(
            dataset=val_dataset,
            batch=batch_size,
            workers=workers
        )

    return train_dataset, train_dataloader, val_dataset, val_dataloader, data_config
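
# Optional smoke test for the loaders (a minimal sketch; uncomment to run it
# standalone — it assumes the standard YOLO batch dict with 'img' and 'cls' keys):
# _, dl, _, _, _ = create_train_val_data('VOC.yaml', batch_size=2, workers=0)
# b = next(iter(dl))
# print(b['img'].shape, b['img'].dtype)  # e.g. torch.Size([2, 3, 640, 640]) torch.uint8
# print(b['cls'].shape)                  # one class id per labelled box in the batch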

def preprocess_batch(batch: dict, multi_scale: bool = False, stride: int = 32) -> dict:
    """Preprocess a batch of images for YOLOv8 training"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Move all tensors to device
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.to(device, non_blocking=device.type == "cuda")
    
    # Normalize images
    batch["img"] = batch["img"].float() / 255.0
    
    # Multi-scale training (optional)
    if multi_scale:
        import random
        import math
        
        imgs = batch["img"]
        sz = (
            random.randrange(int(640 * 0.5), int(640 * 1.5 + 16))
            // stride * stride
        )
        sf = sz / max(imgs.shape[2:])
        
        if sf != 1:
            ns = [
                math.ceil(x * sf / 32) * 32 for x in imgs.shape[2:]
            ]
            imgs = nn.functional.interpolate(
                imgs, size=ns, mode="bilinear", align_corners=False
            )
        batch["img"] = imgs
    
    return batch
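
# With the values above, multi-scale draws sz from {320, 352, ..., 960}:
# randrange(int(640*0.5), int(640*1.5+16)) samples from [320, 976), and
# // stride * stride snaps the result down to a multiple of 32 before the
# batch is bilinearly resized.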

def setup_model_and_oto():
    """Setup YOLOv8 model with OTO and HESSO optimizer"""
    # Load YOLOv8 and unwrap the underlying nn.Module (DetectionModel)
    model = YOLO('yolov8n.pt').model
    
    # Freeze specific layers
    freeze_list = []
    always_freeze_names = [".dfl"]
    freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
    
    for k, v in model.named_parameters():
        if any(x in k for x in freeze_layer_names):
            v.requires_grad = False
        elif not v.requires_grad and v.dtype.is_floating_point:
            v.requires_grad = True
    
    # Convert to quantized model
    # model = model_to_quantize_model(
    #     model, 
    #     quant_mode=QuantizationMode.WEIGHT_AND_ACTIVATION
    # )
    
    # Setup OTO
    dummy_input = torch.randn(1, 3, 640, 640)
    oto = OTO(model=model, dummy_input=dummy_input)
    
    # Mark unprunable parameters (YOLOv8 detection head)
    oto.mark_unprunable_by_param_names([
        'model.22.dfl.conv.weight',
        'model.22.cv3.2.2.weight',
        'model.22.cv2.2.2.weight',
        'model.22.cv2.1.2.weight',
        'model.22.cv3.1.2.weight',
        'model.22.cv3.0.2.weight',
        'model.22.cv2.0.2.weight'
    ])
    
    # Mark slice operations as unprunable
    for node_group in oto._graph.node_groups.values():
        for node in node_group:
            if node.op_name == 'slice':
                node_group.is_prunable = False
    
    # Setup model configuration
    model.args = get_cfg(DEFAULT_CFG, None)
    model.to(config.device)
    model.train()
    
    return model, oto
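
# Optional: after calling setup_model_and_oto(), render OTO's dependency graph
# (per the OTO README; useful for checking how the C2f chunk/concat pattern
# was grouped into node groups):
# oto.visualize(view=False, out_dir='./cache')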


"""Main training function"""
LOGGER.info("Initializing YOLOv8 with HESSO optimizer...")

# Setup model and OTO
model, oto = setup_model_and_oto()

# Get dataloaders
train_dataset, train_dataloader, val_dataset, val_dataloader, data_config = create_train_val_data(
    config.dataset_yaml, 
    batch_size=config.batch_size, 
    imgsz=config.imgsz, 
    stride=config.stride, 
    workers=config.workers
)

if train_dataloader is None:
    LOGGER.error("Failed to create training dataloader. Exiting.")
    raise SystemExit(1)

# Create HESSO optimizer
optimizer = oto.hesso(
    variant="sgd",
    lr=0.01,  # main model LR
    first_momentum=0.9,
    weight_decay=1e-4,
    target_group_sparsity=0.5,
    start_pruning_step=0 * len(train_dataloader),  # start pruning immediately (step 0)
    pruning_periods=5,  # spread the sparsity ramp over 5 periods
    pruning_steps=config.epochs * len(train_dataloader)  # finish the ramp by the last step
)
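
# How this schedule plays out, assuming HESSO raises group sparsity evenly per
# period (as in the OTO tutorials) and N = len(train_dataloader):
#   start_pruning_step = 0    -> pruning begins at the very first step
#   pruning_steps = 3 * N     -> the ramp spans all 3 epochs
#   pruning_periods = 5       -> sparsity grows in 5 increments of 0.1
#                                (0.5 target / 5 periods), one every 3N/5 steps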

LOGGER.info("Starting training with HESSO optimizer...")

# Linear LR decay from 1.0x to 0.01x of the base LR over config.epochs
lf = lambda x: max(1 - x / config.epochs, 0) * (1.0 - 0.01) + 0.01
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

# Training loop
for epoch in range(config.epochs):
    model.train()
    progress_bar = tqdm(
        enumerate(train_dataloader), 
        total=len(train_dataloader), 
        desc=f"Epoch {epoch+1}/{config.epochs}"
    )
    epoch_loss = 0.0
    for batch_idx, batch in progress_bar:
        # Preprocess batch
        batch = preprocess_batch(batch, config.multi_scale, config.stride)
        
        # Forward pass and loss computation
        loss, loss_items = model(batch)
        loss = loss.sum()
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Update progress
        epoch_loss += loss.item()
        progress_bar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Avg Loss': f'{epoch_loss/(batch_idx+1):.4f}'
        })
    
    # Step the LR scheduler once per epoch, after the optimizer updates
    lr_scheduler.step()

    # End of epoch logging
    LOGGER.info(
        f"Epoch {epoch+1} completed. "
        f"Average Loss: {epoch_loss/len(train_dataloader):.4f}"
    )

LOGGER.info("Training completed!")

oto.construct_subnet(out_dir='./cache')
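
# construct_subnet writes two checkpoints under out_dir: the full model with
# pruned structures zeroed out, and the compressed model with those structures
# physically removed. Their paths are exposed via the attributes used below.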

import os

full_model_size = os.stat(oto.full_group_sparse_model_path)
compressed_model_size = os.stat(oto.compressed_model_path)
print("Size of full model       : ", full_model_size.st_size / (1024 ** 3), "GB")
print("Size of compressed model : ", compressed_model_size.st_size / (1024 ** 3), "GB")

full_model = torch.load(oto.full_group_sparse_model_path)
compressed_model = torch.load(oto.compressed_model_path)

But when I run

compressed_model(torch.rand(1,3,640,640).cuda())

I get this error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[23], line 1
----> 1 compressed_model(torch.rand(1,3,640,640).cuda())

File /torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)

File /ultralytics/nn/tasks.py:139, in BaseModel.forward(self, x, *args, **kwargs)
    137 if isinstance(x, dict):  # for cases of training and validating while training.
    138     return self.loss(x, *args, **kwargs)
--> 139 return self.predict(x, *args, **kwargs)

File /ultralytics/nn/tasks.py:157, in BaseModel.predict(self, x, profile, visualize, augment, embed)
    155 if augment:
    156     return self._predict_augment(x)
--> 157 return self._predict_once(x, profile, visualize, embed)

File /ultralytics/nn/tasks.py:180, in BaseModel._predict_once(self, x, profile, visualize, embed)
    178 if profile:
    179     self._profile_one_layer(m, x, dt)
--> 180 x = m(x)  # run
    181 y.append(x if m.i in self.save else None)  # save output

File /ultralytics/nn/modules/block.py:319, in C2f.forward(self, x)
    317 y = list(self.cv1(x).chunk(2, 1))
    318 y.extend(m(y[-1]) for m in self.m)
--> 319 return self.cv2(torch.cat(y, 1))

File /ultralytics/nn/modules/conv.py:93, in Conv.forward_fuse(self, x)
---> 93     return self.act(self.conv(x))

File /torch/nn/modules/conv.py:460, in Conv2d.forward(self, input)
--> 460     return self._conv_forward(input, self.weight, self.bias)

File /torch/nn/modules/conv.py:456, in Conv2d._conv_forward(self, input, weight, bias)
--> 456 return F.conv2d(input, weight, bias, self.stride,
    457                 self.padding, self.dilation, self.groups)

RuntimeError: Given groups=1, weight of size [112, 192, 1, 1], expected input[1, 177, 40, 40] to have 192 channels, but got 177 channels instead

I am assuming that channel pruning changed the number of channels somewhere, which causes this mismatch. Can you please help me resolve this issue, or tell me if I have done something wrong in my code?
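
To narrow down where the channel counts diverge, one thing that might help is hooking every Conv2d in the compressed model and printing any mismatch before the forward crashes. Below is a minimal sketch in plain PyTorch (nothing OTO-specific; report_channel_mismatches is a hypothetical helper name, and compressed_model is the checkpoint loaded above):

import torch
import torch.nn as nn

def report_channel_mismatches(model, dummy):
    """Hook every Conv2d, print any layer whose incoming channel count
    does not match what its weight expects, then run one forward pass."""
    handles = []

    def make_hook(name, conv):
        def hook(module, inputs):
            x = inputs[0]
            expected = conv.weight.shape[1] * conv.groups  # expected in_channels
            if x.shape[1] != expected:
                print(f"{name}: got {x.shape[1]} channels, expected {expected}")
        return hook

    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            handles.append(m.register_forward_pre_hook(make_hook(name, m)))
    try:
        model(dummy)
    except RuntimeError as e:
        print(f"forward stopped: {e}")
    finally:
        for h in handles:
            h.remove()

# compressed_model is the checkpoint loaded above via torch.load
report_channel_mismatches(compressed_model.cuda().eval(),
                          torch.rand(1, 3, 640, 640).cuda())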
