
SABlock parameters when using more heads #7661

@NabJa


Describe the bug
The number of parameters in SABlock should increase with the number of heads (num_heads). However, it does not, which limits comparability with well-known model scales such as ViT-S and ViT-B.

To Reproduce
Steps to reproduce the behavior:

import torch.nn as nn
from monai.networks.nets import ViT

def count_trainable_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Create ViT models with different numbers of heads
vit_b = ViT(1, 224, 16, num_heads=12)
vit_s = ViT(1, 224, 16, num_heads=6)

print("ViT with 12 heads parameters:", count_trainable_parameters(vit_b))
print("ViT with 6 heads parameters:", count_trainable_parameters(vit_s))

>>> ViT with 12 heads parameters: 90282240
>>> ViT with 6 heads parameters: 90282240

Expected behavior
The number of trainable parameters should increase as the number of heads increases.
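
To make the expectation concrete, here is a minimal sketch in plain Python that counts the weights of a single self-attention block under two assumed designs. The helper `attention_params` and its `head_dim` argument are hypothetical and are not part of MONAI. In the first design, hidden_size is split evenly across heads; in the second, each head has a fixed width (64 is used only as an illustrative value). It is meant to show the arithmetic, not to describe how SABlock is actually implemented.

```python
from typing import Optional


def attention_params(
    hidden_size: int,
    num_heads: int,
    head_dim: Optional[int] = None,
    bias: bool = True,
) -> int:
    """Count the weights of one self-attention block (hypothetical helper).

    If head_dim is None, hidden_size is split evenly across the heads.
    Otherwise every head gets a fixed width of head_dim.
    """
    if head_dim is None:
        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size must be divisible by num_heads")
        head_dim = hidden_size // num_heads

    inner = num_heads * head_dim  # total width of all heads together
    # Joint query/key/value projection: hidden_size -> 3 * inner
    qkv = hidden_size * 3 * inner + (3 * inner if bias else 0)
    # Output projection: inner -> hidden_size
    out = inner * hidden_size + (hidden_size if bias else 0)
    return qkv + out


if __name__ == "__main__":
    for heads in (6, 12):
        split = attention_params(768, heads)               # hidden_size split across heads
        fixed = attention_params(768, heads, head_dim=64)  # fixed per-head width (assumed)
        print(f"heads={heads}: split={split}, fixed_head_dim={fixed}")
```

Under the split assumption the count depends only on hidden_size, so 6 and 12 heads give the same number. Under the fixed-width assumption the count grows with the number of heads, which is the behavior requested here.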

Environment

================================
Printing MONAI config...
================================
MONAI version: 0.8.1rc4+1384.g139182ea
Numpy version: 1.26.4
Pytorch version: 2.2.2+cpu
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 139182ea52725aa3c9214dc18082b9837e32f9a2
MONAI __file__: C:\Users\<username>\MONAI\monai\__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: 5.3.0
Nibabel version: 5.2.1
scikit-image version: 0.23.1
scipy version: 1.13.0
Pillow version: 10.3.0
Tensorboard version: 2.16.2
gdown version: 4.7.3
TorchVision version: 0.17.2+cpu
tqdm version: 4.66.2
lmdb version: 1.4.1
psutil version: 5.9.8
pandas version: 2.2.2
einops version: 0.7.0
transformers version: 4.39.3
mlflow version: 2.12.1
pynrrd version: 1.0.0
clearml version: 1.15.1

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies


================================
Printing system config...
================================
System: Windows
Win32 version: ('10', '10.0.22621', 'SP0', 'Multiprocessor Free')
Win32 edition: Professional
Platform: Windows-10-10.0.22621-SP0
Processor: Intel64 Family 6 Model 142 Stepping 12, GenuineIntel
Machine: AMD64
Python version: 3.11.8
Process name: python.exe
Command: ['python', '-c', 'import monai; monai.config.print_debug_info()']
Open files: [popenfile(path='C:\\Windows\\System32\\de-DE\\KernelBase.dll.mui', fd=-1), popenfile(path='C:\\Windows\\System32\\de-DE\\kernel32.dll.mui', fd=-1), popenfile(path='C:\\Windows\\System32\\de-DE\\tzres.dll.mui', fd=-1)]
Num physical CPUs: 4
Num logical CPUs: 8
Num usable CPUs: 8
CPU usage (%): [3.9, 0.2, 3.7, 0.9, 3.9, 3.9, 2.8, 32.2]
CPU freq. (MHz): 1803
Load avg. in last 1, 5, 15 mins (%): [0.0, 0.0, 0.0]
Disk usage (%): 83.1
Avg. sensor temp. (Celsius): UNKNOWN for given OS
Total physical memory (GB): 15.8
Available memory (GB): 5.5
Used memory (GB): 10.2

================================
Printing GPU config...
================================
Num GPUs: 0
Has CUDA: False
cuDNN enabled: False
NVIDIA_TF32_OVERRIDE: None
TORCH_ALLOW_TF32_CUBLAS_OVERRIDE: None
