-
Notifications
You must be signed in to change notification settings - Fork 74
Updates to LM PR #891
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: lm_workload
Are you sure you want to change the base?
Updates to LM PR #891
Conversation
MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅ |
loss = -jnp.sum(targets * jax.nn.log_softmax(logits, axis=-1)) | ||
return loss | ||
# TODO(kasimbeg): add weights? | ||
metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For this workload I don't think we need weights for the cross-entropy calculation. Maybe we should explicitly del any weights?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The weights are used to identify padded elements in the validation split and correctly calculate the number of tokens returned in the eval dict.
ds, | ||
) | ||
|
||
return iter(it) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we use itertools.cycle here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No I don't think so. The input_pipeline already calls .repeat() on the train split and we don't want cycle on the validation split.
…efficiency into lm_workload_priya
No description provided.