Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 142 additions & 6 deletions benchmarks/imagenet/resnet50/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def warn_with_traceback(message, category, filename, lineno, file=None, line=Non
from argparse import ArgumentParser
from datetime import datetime
from pathlib import Path
from typing import Dict, Sequence, Union
from typing import Dict, Optional, Sequence, Union

import barlowtwins
import byol
Expand Down Expand Up @@ -67,6 +67,9 @@ def warn_with_traceback(message, category, filename, lineno, file=None, line=Non
parser.add_argument("--float32-matmul-precision", type=str, default="high")
parser.add_argument("--strategy", default="ddp_find_unused_parameters_true")
parser.add_argument("--seed", type=int, default=None)
parser.add_argument(
"--verbose", action="store_true", help="Print full configuration before training"
)

METHODS = {
"barlowtwins": {
Expand All @@ -85,6 +88,125 @@ def warn_with_traceback(message, category, filename, lineno, file=None, line=Non
}


def create_full_config(
    args_dict: Dict,
    method: str,
    method_dir: Path,
    world_size: int,
    *,
    base_lr: float = 0.075,
) -> Dict:
    """Create the full resolved configuration with derived values.

    The input ``args_dict`` is not mutated; a shallow copy is enriched with
    derived training values, environment information, and serialization-safe
    path strings.

    Args:
        args_dict: Dictionary of parsed arguments.
        method: SSL method name.
        method_dir: Directory for logging this method.
        world_size: Number of devices/processes.
        base_lr: Base learning rate used for the approximate LR scaling.
            Defaults to the SimCLR value of 0.075.

    Returns:
        Full configuration dictionary with derived values.
    """
    config = args_dict.copy()

    # Derived training values.
    config["method"] = method
    config["world_size"] = world_size
    config["global_batch_size"] = args_dict["batch_size_per_device"] * world_size

    # Approximate effective learning rate (SimCLR scaling: lr * sqrt(batch_size)).
    # NOTE: the actual LR is computed in each method's configure_optimizers;
    # this value is informational only.
    config["effective_lr_approx"] = base_lr * (config["global_batch_size"] ** 0.5)

    # Environment info for reproducibility.
    config["pytorch_version"] = torch.__version__
    config["cuda_available"] = torch.cuda.is_available()
    config["timestamp"] = datetime.now().isoformat()
    config["log_directory"] = str(method_dir)

    # Convert Path objects to strings so the config is serializable
    # (value reassignment during items() iteration is safe: no keys change).
    for key, value in config.items():
        if isinstance(value, Path):
            config[key] = str(value)

    return config


def save_config(config: Dict, output_dir: Path) -> None:
    """Persist the resolved configuration as a plain-text file.

    Args:
        config: Configuration dictionary to save.
        output_dir: Directory to save config file.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # One "key: value" line per entry, under a fixed header banner.
    config_path = output_dir / "config.txt"
    header = "FULL RESOLVED CONFIGURATION\n" + "=" * 80 + "\n\n"
    body = "".join(f"{key}: {value}\n" for key, value in config.items())
    config_path.write_text(header + body)
    print_rank_zero(f" 💾 Config saved to: {config_path}")


def print_config(config: Dict) -> None:
    """Pretty print the full configuration.

    Args:
        config: Configuration dictionary to print.
    """
    divider = "=" * 80
    print_rank_zero("\n" + divider)
    print_rank_zero("🔧 FULL RESOLVED CONFIGURATION")
    print_rank_zero(divider)

    # Display order of known keys, grouped by category.
    categories = {
        "📂 Paths": ["train_dir", "val_dir", "log_dir", "log_directory", "ckpt_path"],
        "🧠 Method": ["method", "methods"],
        "📊 Data": [
            "num_classes",
            "batch_size_per_device",
            "global_batch_size",
            "num_workers",
        ],
        "🎓 Training": ["epochs", "effective_lr_approx", "float32_matmul_precision"],
        "📈 Evaluation": [
            "skip_knn_eval",
            "skip_linear_eval",
            "skip_finetune_eval",
            "knn_k",
            "knn_t",
        ],
        "💻 Hardware": [
            "accelerator",
            "devices",
            "world_size",
            "precision",
            "strategy",
            "cuda_available",
        ],
        "🔧 Other": ["seed", "compile_model", "pytorch_version", "timestamp"],
    }

    # Emit each non-empty category, preserving the config's insertion order.
    for category, keys in categories.items():
        selected = {key: config[key] for key in config if key in keys}
        if selected:
            print_rank_zero(f"\n{category}:")
            for key, value in selected.items():
                print_rank_zero(f"  {key}: {value}")

    # Any keys not covered by a category above land in "Additional".
    known = {key for keys in categories.values() for key in keys}
    extras = {key: config[key] for key in config if key not in known}
    if extras:
        print_rank_zero("\n📝 Additional:")
        for key, value in extras.items():
            print_rank_zero(f"  {key}: {value}")

    print_rank_zero(divider + "\n")


def main(
train_dir: Path,
val_dir: Path,
Expand All @@ -106,11 +228,13 @@ def main(
ckpt_path: Union[Path, None],
float32_matmul_precision: str,
strategy: str,
seed: int | None = None,
seed: Optional[int] = None,
verbose: bool = False,
) -> None:
print_rank_zero(f"Args: {locals()}")
seed_everything(seed, workers=True, verbose=True)
# Store args for config creation
args_dict = locals().copy()

seed_everything(seed, workers=True, verbose=True)
torch.set_float32_matmul_precision(float32_matmul_precision)

method_names = methods or METHODS.keys()
Expand All @@ -120,6 +244,18 @@ def main(
log_dir / method / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
).resolve()
print_rank_zero(f"Logging to {method_dir}")

# Create and save full resolved config
world_size = devices if accelerator != "cpu" else 1
config = create_full_config(args_dict, method, method_dir, world_size)

# Print config if verbose flag is set
if verbose:
print_config(config)

# Save config files
save_config(config, method_dir)

model = METHODS[method]["model"](
batch_size_per_device=batch_size_per_device, num_classes=num_classes
)
Expand Down Expand Up @@ -233,7 +369,7 @@ def pretrain(
shuffle=True,
num_workers=num_workers,
drop_last=True,
persistent_workers=True,
persistent_workers=num_workers > 0,
)

# Setup validation data.
Expand All @@ -251,7 +387,7 @@ def pretrain(
batch_size=batch_size_per_device,
shuffle=False,
num_workers=num_workers,
persistent_workers=True,
persistent_workers=num_workers > 0,
)

# Train model.
Expand Down