30 changes: 20 additions & 10 deletions tuned_lens/model_surgery.py
@@ -117,11 +117,16 @@ def get_final_norm(model: Model) -> Norm:
),
):
final_layer_norm = base_model.ln_f
elif isinstance(base_model, models.llama.modeling_llama.LlamaModel):
final_layer_norm = base_model.norm
elif isinstance(base_model, models.mistral.modeling_mistral.MistralModel):
final_layer_norm = base_model.norm
elif isinstance(base_model, models.gemma.modeling_gemma.GemmaModel):
elif isinstance(
base_model,
(
models.llama.modeling_llama.LlamaModel,
models.mistral.modeling_mistral.MistralModel,
models.gemma.modeling_gemma.GemmaModel,
models.qwen3.modeling_qwen3.Qwen3Model,
models.gpt_oss.modeling_gpt_oss.GptOssModel,
),
):
final_layer_norm = base_model.norm
else:
raise NotImplementedError(f"Unknown model type {type(base_model)}")
@@ -166,11 +171,16 @@ def get_transformer_layers(model: Model) -> tuple[str, th.nn.ModuleList]:
),
):
path_to_layers += ["h"]
elif isinstance(base_model, models.llama.modeling_llama.LlamaModel):
path_to_layers += ["layers"]
elif isinstance(base_model, models.mistral.modeling_mistral.MistralModel):
path_to_layers += ["layers"]
elif isinstance(base_model, models.gemma.modeling_gemma.GemmaModel):
elif isinstance(
base_model,
(
models.gpt_oss.modeling_gpt_oss.GptOssModel,
models.qwen3.modeling_qwen3.Qwen3Model,
models.llama.modeling_llama.LlamaModel,
models.mistral.modeling_mistral.MistralModel,
models.gemma.modeling_gemma.GemmaModel,
),
):
path_to_layers += ["layers"]
else:
raise NotImplementedError(f"Unknown model type {type(base_model)}")
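For context, a minimal sketch of how the consolidated branches are exercised for one of the newly supported architectures; the Qwen3 checkpoint name below is an illustrative assumption rather than something taken from this diff.

from transformers import AutoModelForCausalLM
from tuned_lens import model_surgery

# Llama-style architectures (Llama, Mistral, Gemma, Qwen3, GPT-OSS) now all hit
# the same branch, which reads the final norm from `base_model.norm` and the
# decoder blocks from `base_model.layers`.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
final_norm = model_surgery.get_final_norm(model)
path, layers = model_surgery.get_transformer_layers(model)
print(type(final_norm).__name__, path, len(layers))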
188 changes: 188 additions & 0 deletions tuned_lens/muon.py
@@ -0,0 +1,188 @@
"""Adapted from https://github.com/KellerJordan/Muon/blob/master/muon.py."""

import torch
import torch.distributed as dist
from torch import Tensor


def quintic_newtonschulz(G: Tensor, steps: int) -> Tensor:
"""Newton-Schulz iteration to compute the orthogonalization of G.

We opt to use a quintic iteration whose coefficients are selected to maximize the
slope at zero. For the purpose of minimizing steps, it turns out to be empirically
effective to keep increasing the slope at zero even beyond the point where the
iteration no longer converges all the way to one everywhere on the interval. This
iteration therefore does not produce UV^T but rather something like US'V^T where S'
is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
"""
# batched implementation by @scottjmaddox, put into practice by @YouJiacheng
assert G.ndim >= 2
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT

# Ensure spectral norm is at most 1
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
# Perform the NS iterations
for _ in range(steps):
# quintic strategy adapted from suggestion by @jxbz, @leloykun, @YouJiacheng
A = X @ X.mT
B = b * A + c * A @ A
X = a * X + B @ X

if G.size(-2) > G.size(-1):
X = X.mT
return X


class Muon(torch.optim.Optimizer):
"""Muon - MomentUm Orthogonalized by Newton-schulz.

Muon is a generalized steepest descent optimizer using the spectral norm on the
matrix-valued parameters. This means it always updates in the direction which
locally reduces the loss as much as possible, while constraining the update to have
a spectral norm given by the learning rate. It achieves this using a Newton-Schulz
iteration to orthogonalize the stochastic gradient (or momentum buffer) for each
matrix in the model before taking a step.

The spectral norm is an intuitive heuristic because, roughly speaking, it measures
the maximum change to the activations of a layer that can be caused by a change to
its weights. By constraining the worst-case change to the activations, we ensure
    that we do not destabilize training.

    This optimizer is unlikely to work well with small batch sizes, since it strongly
magnifies small singular values, which will be noisy given a small minibatch.
"""

def __init__(
self,
params,
lr: float = 1e-3,
momentum: float = 0.95,
nesterov: bool = True,
weight_decay: float = 0.1,
ns_steps: int = 5,
ddp: bool = True,
):
"""Initialize the Muon optimizer.

You will need to set the `ddp` flag to `False` if you are using FSDP or some
similar scheme where parameters are sharded across multiple processes.

Args:
params: Iterable of parameters to optimize.
lr: The learning rate used by the internal SGD.
momentum: The momentum used by the internal SGD.
nesterov: Whether to use Nesterov-style momentum in the internal SGD.
ns_steps: The number of Newton-Schulz iteration steps to use.
weight_decay: The decoupled weight decay to apply at each step.
ddp: Whether to distribute the work of Newton-Schulz across multiple
processes. This assumes that every process has all the parameters.
"""
defaults = dict(
lr=lr,
momentum=momentum,
nesterov=nesterov,
ns_steps=ns_steps,
weight_decay=weight_decay,
)
self.rank = dist.get_rank() if dist.is_initialized() and ddp else 0
self.world_size = dist.get_world_size() if dist.is_initialized() and ddp else 1

# Distributed Data Parallel (DDP) setup
if dist.is_initialized() and ddp:
param_groups = []

# Check that the user isn't doing some weird model parallelism
devices = {p.device for p in params}
device = next(iter(devices))
assert device.type == "cuda", "Muon only supports CUDA devices."
assert len(devices) == 1, "Muon does not support model parallelism."

# Group parameters by their device and number of elements. For each group,
# we pre-allocate a buffer to store the updates from all ranks.
for size in {p.numel() for p in params}:
b = torch.empty(
self.world_size, size, dtype=torch.bfloat16, device=device
)
group = dict(
params=[p for p in params if p.numel() == size],
update_buffer=b,
update_buffer_views=[b[i] for i in range(self.world_size)],
)
param_groups.append(group)

super().__init__(param_groups, defaults)
else:
super().__init__(params, defaults)

@torch.no_grad()
def step(self):
"""Performs a single optimization step."""
for group in self.param_groups:
params: list[Tensor] = group["params"]

# Apply decoupled weight decay to all parameters. This doesn't require any
# communication, since it's a simple element-wise operation.
if group["weight_decay"] > 0.0:
for p in params:
p.mul_(1 - group["lr"] * group["weight_decay"])

# These will be None / empty list if we're not using DDP
update_buffer: Tensor | None = group.get("update_buffer", None)
update_buffer_views: list[Tensor] = group.get("update_buffer_views", [])

beta = group["momentum"]
handle = None
params_world = None

def update_prev(): # optimized implementation contributed by @YouJiacheng
assert handle is not None and params_world is not None
handle.wait()

for p_world, g_world in zip(params_world, update_buffer_views):
# Heuristic from <https://arxiv.org/abs/2502.16982>
scale = 0.2 * max(p_world.shape) ** 0.5
p_world.add_(g_world.view_as(p_world), alpha=-group["lr"] * scale)

for i in range(0, len(params), self.world_size):
# Compute Muon update
if i + self.rank < len(params):
p = params[i + self.rank]
state = self.state[p]

g = p.grad
assert g is not None

# Apply momentum
if beta > 0.0:
if "exp_avg" not in state:
state["exp_avg"] = torch.zeros_like(g)

buf: Tensor = state["exp_avg"].lerp_(g, 1 - beta)
g = g.lerp_(buf, beta) if group["nesterov"] else buf

if g.ndim == 4: # for the case of conv filters
g = g.view(len(g), -1)

g = quintic_newtonschulz(g, steps=group["ns_steps"])
else:
g = update_buffer_views[self.rank]

if self.world_size > 1:
# async all_gather instead of sync all_reduce by @YouJiacheng
if i > 0:
update_prev()

handle = dist.all_gather_into_tensor(
update_buffer, g.flatten(), async_op=True
)
params_world = params[i : i + self.world_size]
else:
scale = 0.2 * max(params[i].shape) ** 0.5
params[i].add_(g, alpha=-group["lr"] * scale)

if self.world_size > 1:
update_prev()
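Below is a minimal single-process sketch of driving the optimizer defined above; the toy model, the hyperparameters, and the `ddp=False` setting are illustrative assumptions (no process group is initialized here).

import torch
from tuned_lens.muon import Muon

# Muon expects matrix-valued parameters, so use a bias-free linear layer.
model = torch.nn.Linear(64, 64, bias=False)
opt = Muon(list(model.parameters()), lr=0.02, momentum=0.95, ddp=False)

x = torch.randn(8, 64)
loss = model(x).pow(2).mean()
loss.backward()
opt.step()       # orthogonalizes the momentum-smoothed gradient via Newton-Schulz
opt.zero_grad()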
11 changes: 9 additions & 2 deletions tuned_lens/nn/lenses.py
@@ -12,6 +12,7 @@
from transformers import PreTrainedModel

from tuned_lens import load_artifacts
from tuned_lens.model_surgery import Norm
from tuned_lens.nn.unembed import Unembed

logger = logging.getLogger(__name__)
@@ -68,13 +69,16 @@ def __init__(
def from_model(
cls,
model: PreTrainedModel,
*,
final_norm: Optional[Norm] = None,
) -> "LogitLens":
"""Create a LogitLens from a pretrained model.

Args:
model: A pretrained model from the transformers library you wish to inspect.
final_norm: An optional final layer normalization to apply.
"""
unembed = Unembed(model)
unembed = Unembed(model, final_norm=final_norm)
return cls(unembed)

def transform_hidden(self, h: th.Tensor, idx: int) -> th.Tensor:
@@ -182,18 +186,21 @@ def from_model(
model: PreTrainedModel,
model_revision: Optional[str] = None,
bias: bool = True,
*,
final_norm: Optional[Norm] = None,
) -> "TunedLens":
"""Create a lens from a pretrained model.

Args:
model: The model to create the lens from.
            model_revision: The git revision of the model to use.
bias: Whether to use a bias in the linear translators.
final_norm: An optional final layer normalization to apply.

Returns:
A TunedLens instance.
"""
unembed = Unembed(model)
unembed = Unembed(model, final_norm=final_norm)
config = TunedLensConfig(
base_model_name_or_path=model.config.name_or_path,
base_model_revision=model_revision,
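A short sketch of the new keyword-only `final_norm` override; the checkpoint name and the identity stand-in are illustrative assumptions, not part of the change.

import torch as th
from transformers import AutoModelForCausalLM
from tuned_lens.nn.lenses import LogitLens

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")

# Default: the final norm is pulled from the model itself, as before.
logit_lens = LogitLens.from_model(model)

# Override: decode hidden states without the model's own final normalization.
raw_lens = LogitLens.from_model(model, final_norm=th.nn.Identity())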
5 changes: 4 additions & 1 deletion tuned_lens/nn/unembed.py
@@ -38,14 +38,17 @@ class Unembed(th.nn.Module):
def __init__(
self,
model: model_surgery.Model,
*,
final_norm: Optional[model_surgery.Norm] = None,
):
"""Initialize unmebed.

Args:
model: A HuggingFace model from which to extract the unembedding matrix.
            final_norm: An optional final layer normalization to apply before the
                unembedding. If not provided, it is read from the model.
"""
super().__init__()
final_norm = model_surgery.get_final_norm(model)
final_norm = final_norm or model_surgery.get_final_norm(model)
unembedding_matrix = model_surgery.get_unembedding_matrix(model)

self.final_norm = copy.deepcopy(final_norm)
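The same override at the Unembed level, where the fallback above lives; the model name and the fresh LayerNorm stand-in are illustrative assumptions.

import torch as th
from transformers import AutoModelForCausalLM
from tuned_lens.nn.unembed import Unembed

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")

unembed_default = Unembed(model)  # falls back to model_surgery.get_final_norm(model)
unembed_custom = Unembed(model, final_norm=th.nn.LayerNorm(model.config.hidden_size))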
3 changes: 1 addition & 2 deletions tuned_lens/scripts/eval_loop.py
@@ -223,8 +223,7 @@ def execute(self):
# Note since we are not training we can just move the lens to the device.
# No need to use DDP
lenses = {name: lens.to(self.dist.device) for name, lens in lenses.items()}
dl = self.dist.dataloader(data)
dl.seed(self.seed)
dl = self.dist.dataloader(data, self.seed)

for lens in lenses.values():
lens.eval()