Skip to content

Commit

Permalink
Fix FBLearner workflows for Cogwheel tests
Browse files Browse the repository at this point in the history
Summary:
Fixed FBLearner workflows:
- benchmark_flow.py
- example_workflow.py

This should enable Cogwheel tests to stop failing

Reviewed By: lvdmaaten

Differential Revision: D31505436

fbshipit-source-id: 793b6c38fde9bc364ae4271a9d2965cd4215b6bc
  • Loading branch information
knottb authored and facebook-github-bot committed Oct 11, 2021
1 parent fef556f commit 09f6368
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion crypten/gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -1968,7 +1968,12 @@ class AutogradCrossEntropy(AutogradFunction):
@staticmethod
def forward(ctx, pred, target, skip_forward=False):
# NOTE: target is assumed to be one-hot vector.
softmax = pred.softmax(1)
assert pred.size() == target.size()

# Ignore batch dimension
dim = 1 if pred.dim() > 1 else 0
softmax = pred.softmax(dim)

ctx.save_multiple_for_backward([softmax, target])
ctx.mark_non_differentiable(target)
if skip_forward:
Expand Down

0 comments on commit 09f6368

Please sign in to comment.