140 changes: 139 additions & 1 deletion bergson/distributed.py
@@ -1,14 +1,152 @@
import os
import socket
from collections import defaultdict
from typing import Any, Callable

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, start_processes
from torch.distributed.tensor import (
DTensor,
Partial,
Replicate,
Shard,
distribute_tensor,
)
from torch.nn.utils.parametrize import register_parametrization
from torch.utils.checkpoint import (
CheckpointPolicy,
checkpoint,
create_selective_checkpoint_contexts,
)

from .config import DistributedConfig


def grad_tree(
outputs: torch.Tensor,
inputs: dict[str, torch.Tensor],
grad_outputs: dict[str, torch.Tensor] | None = None,
**kwargs,
) -> dict[str, torch.Tensor]:
"""Compute grads of loss wrt inputs dict, returning a dict with the same keys.

Args:
outputs: The output tensor to compute gradients for.
inputs: A dict of input tensors to compute gradients with respect to.
grad_outputs: Optional dict of gradient outputs for each output tensor.
**kwargs: Additional keyword arguments to pass to torch.autograd.grad.
"""
if grad_outputs is not None:
kwargs["grad_outputs"] = list(grad_outputs.values())

grads = torch.autograd.grad(
outputs,
list(inputs.values()),
**kwargs,
allow_unused=True,
)
return dict(zip(inputs, grads))
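
# Editorial usage sketch (not part of this diff; the tensors are illustrative):
#
#     w = {
#         "a": torch.randn(3, requires_grad=True),
#         "b": torch.randn(3, requires_grad=True),
#     }
#     loss = (w["a"] * w["b"]).sum()
#     grads = grad_tree(loss, w)  # same keys as `w`: {"a": ..., "b": ...}
#
# Because `allow_unused=True` is always passed, inputs that do not contribute
# to `outputs` map to None instead of raising an error.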


def fsdp_policy():
    """Selective checkpoint contexts that force FSDP collectives (the all-gather
    and its matching wait) to be recomputed in backward instead of saved."""

    def _fsdp_recomp_policy():
def _custom_policy(ctx, func, *args, **kwargs):
to_recompute = func in {
torch.ops._c10d_functional.all_gather_into_tensor.default, # type: ignore[attr-defined]
torch.ops._c10d_functional.wait_tensor.default, # type: ignore[attr-defined]
}
return (
CheckpointPolicy.MUST_RECOMPUTE
if to_recompute
else CheckpointPolicy.MUST_SAVE
)

return _custom_policy

return create_selective_checkpoint_contexts(_fsdp_recomp_policy())


class ReplicateComputation(torch.nn.Module):
    """Parametrization that redistributes a sharded DTensor parameter to a
    replicated layout for compute and returns the local tensor; gradients flow
    back as averaged partials, and the gather is recomputed in backward via the
    selective checkpoint policy above."""

    def replicate_compute(self, x):
return x.redistribute(
placements=(Replicate(),),
).to_local(grad_placements=(Partial(reduce_op="avg"),))

def forward(self, x):
return checkpoint(
self.replicate_compute, x, use_reentrant=False, context_fn=fsdp_policy
)


def shallow_copy(tensor_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Create a shallow copy of a dict of tensors, handling tied weights."""
# For each unique tensor, construct a list of the places in the model where it
# appears. This is a bit wonky, but it is the best way to handle tied weights.
tensor_to_paths = defaultdict(list)
for path, param in tensor_dict.items():
tensor_to_paths[param].append(path)

# Use a while loop to avoid modifying the dict while iterating over it. We don't
# want to hold onto both the original and copied versions of each parameter.
tensor_dict = {}
while tensor_to_paths:
t, paths = tensor_to_paths.popitem()

if isinstance(t, DTensor):
t2 = DTensor.from_local(t.to_local(), t.device_mesh, t.placements)
else:
t2 = torch.Tensor(t.data)

        # Re-insert the copy under every path so tied weights stay tied
        t2.requires_grad_(t.requires_grad)
        for path in paths:
            tensor_dict[path] = t2

return tensor_dict
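
# Editorial usage sketch (not part of this diff): tied weights share a single
# tensor object, so the copy is re-registered under every path that referenced
# the original, keeping the tie intact.
#
#     emb = torch.randn(4, 8, requires_grad=True)
#     copied = shallow_copy({"embed.weight": emb, "lm_head.weight": emb})
#     assert copied["embed.weight"] is copied["lm_head.weight"]
#     assert set(copied) == {"embed.weight", "lm_head.weight"}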


def simple_fsdp(model: torch.nn.Module) -> torch.nn.Module:
"""SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile"""
# For each unique parameter, construct a list of the places in the model where it
# appears. This is a bit wonky, but it is the best way to handle tied weights.
param_to_paths = defaultdict(list)
for path, param in model.named_parameters(remove_duplicate=False):
param_to_paths[param].append(path)

# Use a while loop to avoid modifying the dict while iterating over it. We don't
# want to hold onto both the original and distributed versions of each parameter.
while param_to_paths:
param, paths = param_to_paths.popitem()

# Create a new distributed version of this param
dist_param = torch.nn.Parameter(
distribute_tensor(param, placements=(Shard(0),))
)

# Update all occurrences of this parameter in the model
for path in paths:
# Find the module that has a reference to this parameter
mod_name, _, p_name = path.rpartition(".")
mod = model.get_submodule(mod_name)

# Re-register the parameter with sharding and replication
mod.register_parameter(p_name, dist_param)
register_parametrization(
mod,
p_name,
ReplicateComputation(),
unsafe=True,
)

return model
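
# Editorial usage sketch (not part of this diff). Assumes the default process
# group and a device mesh are already set up so `distribute_tensor` can resolve
# a mesh; `MyModel` and `batch` are placeholders.
#
#     model = simple_fsdp(MyModel().cuda())  # params become Shard(0) DTensors
#     model = torch.compile(model)           # SimpleFSDP is designed for compile
#     loss = model(batch).sum()
#     loss.backward()  # grads average via the Partial(reduce_op="avg") placement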


Worker = Callable[[int, int, int, object], None]
"""A worker function for distributed training."""


def dist_worker(
worker: Callable,
*worker_args,
@@ -28,7 +166,7 @@ def dist_worker(

def launch_distributed_run(
process_name: str,
worker,
worker: Worker,
const_worker_args: list[Any],
dist_config: DistributedConfig | None = None,
):
6 changes: 0 additions & 6 deletions bergson/huggingface.py
@@ -279,8 +279,6 @@ def on_step_end(
# Read normalizers off of the optimizer state. We need to figure out
# what type of optimizer this is first.
for group in optimizer.param_groups:
lr_sqrt = group["lr"] ** 0.5

for param in group["params"]:
name = param_to_name[param].removesuffix(".weight")
if name not in self.collector.target_info:
@@ -299,10 +297,6 @@
else:
continue

# Scale the gradient by the current learning rate. It's factorized
# so we multiply each factor by the square root of the LR.
norm.row *= lr_sqrt
norm.col *= lr_sqrt
normalizers[name] = norm

proc.normalizers = normalizers
106 changes: 106 additions & 0 deletions bergson/models.py
@@ -0,0 +1,106 @@
from typing import NamedTuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from bergson.utils.math import weighted_ce


class Output(NamedTuple):
loss: torch.Tensor | None
logits: torch.Tensor


class BasicBlock(nn.Module):
expansion = 1

def __init__(self, in_channels, out_channels, stride=1):
super().__init__()

self.conv1 = nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
)
self.bn1 = nn.BatchNorm2d(out_channels)

self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(out_channels)

self.shortcut = nn.Identity()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=stride, bias=False
),
nn.BatchNorm2d(out_channels),
)

def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
return F.relu(out)
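
# Editorial shape check (not part of this diff): a strided block halves the
# spatial dimensions and changes the channel width, and the 1x1 shortcut keeps
# the residual addition shape-compatible.
#
#     block = BasicBlock(64, 128, stride=2)
#     y = block(torch.randn(1, 64, 32, 32))
#     assert y.shape == (1, 128, 16, 16)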


class ResNetCIFAR(nn.Module):
def __init__(self, num_classes=10):
super().__init__()

self.in_channels = 64

# CIFAR stem
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)

# Residual stages (2,2,2,2)
self.layer1 = self._make_layer(64, 2, stride=1)
self.layer2 = self._make_layer(128, 2, stride=2)
self.layer3 = self._make_layer(256, 2, stride=2)
self.layer4 = self._make_layer(512, 2, stride=2)

self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512, num_classes)

self._init_weights()

def _make_layer(self, out_channels, blocks, stride):
layers = []
layers.append(BasicBlock(self.in_channels, out_channels, stride))
self.in_channels = out_channels
for _ in range(1, blocks):
layers.append(BasicBlock(out_channels, out_channels))
return nn.Sequential(*layers)

def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")

    def forward(self, pixel_values, labels=None, *, example_weight=None):
x = F.relu(self.bn1(self.conv1(pixel_values)))

x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)

x = self.avgpool(x)
x = torch.flatten(x, 1)
logits = self.fc(x)

if labels is not None:
loss = weighted_ce(
labels=labels,
logits=logits,
example_weight=example_weight,
)
return Output(loss, logits)

return Output(None, logits)
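
# Editorial usage sketch (not part of this diff): a CIFAR-sized forward and
# backward pass on random placeholder data.
#
#     model = ResNetCIFAR(num_classes=10)
#     pixel_values = torch.randn(8, 3, 32, 32)
#     labels = torch.randint(0, 10, (8,))
#     out = model(pixel_values, labels)  # Output(loss=scalar, logits=[8, 10])
#     out.loss.backward()
#
# With labels=None the loss is skipped and Output(None, logits) is returned,
# matching the HuggingFace-style (loss, logits) convention.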