File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -98,7 +98,7 @@ def step(self, closure=None):
98
98
and returns the loss.
99
99
"""
100
100
device = self .param_groups [0 ]["params" ][0 ].device
101
- one_tensor = torch .tensor (1.0 , device = device )
101
+ one_tensor = torch .tensor (1.0 , device = device ) # because torch.where doesn't handle scalars correctly
102
102
103
103
loss = None
104
104
if closure is not None :
@@ -115,7 +115,9 @@ def step(self, closure=None):
115
115
global_grad_norm .add_ (grad .pow (2 ).sum ())
116
116
117
117
global_grad_norm = torch .sqrt (global_grad_norm )
118
- max_grad_norm = self .defaults ['max_grad_norm' ]
118
+ # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
119
+ # scalar types properly https://github.com/pytorch/pytorch/issues/9190
120
+ max_grad_norm = torch .tensor (self .defaults ['max_grad_norm' ], device = device )
119
121
clip_global_grad_norm = torch .where (
120
122
global_grad_norm > max_grad_norm ,
121
123
global_grad_norm / max_grad_norm ,
You can’t perform that action at this time.
0 commit comments