
Commit 9541f49 (1 parent: 8f68193)

One more scalar -> tensor fix for lamb optimizer

File tree: 1 file changed (+4, -2 lines)

timm/optim/lamb.py

Lines changed: 4 additions & 2 deletions
@@ -98,7 +98,7 @@ def step(self, closure=None):
                 and returns the loss.
         """
         device = self.param_groups[0]["params"][0].device
-        one_tensor = torch.tensor(1.0, device=device)
+        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
 
         loss = None
         if closure is not None:
@@ -115,7 +115,9 @@ def step(self, closure=None):
                 global_grad_norm.add_(grad.pow(2).sum())
 
         global_grad_norm = torch.sqrt(global_grad_norm)
-        max_grad_norm = self.defaults['max_grad_norm']
+        # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
+        # scalar types properly https://github.com/pytorch/pytorch/issues/9190
+        max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
         clip_global_grad_norm = torch.where(
             global_grad_norm > max_grad_norm,
             global_grad_norm / max_grad_norm,
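For context, a minimal standalone sketch of the pattern this commit applies: wrapping Python scalars as tensors on the parameter device before handing them to torch.where, since scalar arguments are not promoted reliably (pytorch/pytorch#9190). The gradient-norm value below is a made-up placeholder, not taken from the optimizer, and this snippet is an illustration rather than the timm LAMB implementation.

# Sketch of the scalar -> tensor pattern used in the commit (placeholder values).
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
one_tensor = torch.tensor(1.0, device=device)  # scalar 1.0 as a tensor for torch.where

# Placeholder for the accumulated squared gradient norm (made-up value).
global_grad_norm = torch.sqrt(torch.tensor(7.5, device=device))

# The fix: wrap the scalar hyperparameter as a tensor on the same device.
max_grad_norm = torch.tensor(1.0, device=device)

# Scale factor: norms above the threshold are scaled down, others pass through unchanged.
clip_global_grad_norm = torch.where(
    global_grad_norm > max_grad_norm,
    global_grad_norm / max_grad_norm,
    one_tensor)
print(clip_global_grad_norm)  # ~2.739 for this made-up input

With all three arguments as tensors on one device, torch.where behaves predictably regardless of dtype; passing a bare Python float in the scalar positions is what the FIXME comment in the diff refers to.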

0 commit comments
