Skip to content

Commit

Permalink
Backend pytorch: Model.predict operator clear gradient (lululxvi#1037)
Browse files Browse the repository at this point in the history
  • Loading branch information
mitchelldaneker authored Nov 15, 2022
1 parent b184122 commit 38f4899
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions deepxde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,8 @@ def op(inputs):
"Model.predict() with auxiliary variable hasn't been implemented "
"for backend pytorch."
)
# Clear cached Jacobians and Hessians.
grad.clear()
y = utils.to_numpy(y)
elif backend_name == "paddle":
self.net.eval()
Expand Down

0 comments on commit 38f4899

Please sign in to comment.