@@ -103,8 +103,12 @@ __global__ void train(Memory<Vector, Index> head_embeddings, Memory<Vector, Inde
103103 sample_loss += weight * -log (prob + kEpsilon );
104104 } else {
105105 gradient = prob;
106- if (adversarial_temperature > kEpsilon )
106+ if (adversarial_temperature > kEpsilon ) {
107107 weight = safe_exp ((logit - bias) / adversarial_temperature) / normalizer;
108+ // the normalizer may be out of date in ASGD
109+ // so we need to clip the weight
110+ weight = min (weight, Float (1 ));
111+ }
108112 else
109113 weight = 1.0 / num_negative;
110114 sample_loss += weight * -log (1 - prob + kEpsilon );
@@ -198,8 +202,12 @@ __global__ void train_1_moment(Memory<Vector, Index> head_embeddings, Memory<Vec
198202 sample_loss += weight * -log (prob + kEpsilon );
199203 } else {
200204 gradient = prob;
201- if (adversarial_temperature > kEpsilon )
205+ if (adversarial_temperature > kEpsilon ) {
202206 weight = safe_exp ((logit - bias) / adversarial_temperature) / normalizer;
207+ // the normalizer may be out of date in ASGD
208+ // so we need to clip the weight
209+ weight = min (weight, Float (1 ));
210+ }
203211 else
204212 weight = 1.0 / num_negative;
205213 sample_loss += weight * -log (1 - prob + kEpsilon );
@@ -298,8 +306,12 @@ __global__ void train_2_moment(Memory<Vector, Index> head_embeddings, Memory<Vec
298306 sample_loss += weight * -log (prob + kEpsilon );
299307 } else {
300308 gradient = prob;
301- if (adversarial_temperature > kEpsilon )
309+ if (adversarial_temperature > kEpsilon ) {
302310 weight = safe_exp ((logit - bias) / adversarial_temperature) / normalizer;
311+ // the normalizer may be out of date in ASGD
312+ // so we need to clip the weight
313+ weight = min (weight, Float (1 ));
314+ }
303315 else
304316 weight = 1.0 / num_negative;
305317 sample_loss += weight * -log (1 - prob + kEpsilon );
0 commit comments