-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Remove 'model_.' prefix from onnx model initializers in training (#3881)
* Remove 'model_.' prefix for onnx model initializers in training * fix test case remove redundant device test * rename * Fix state_dict/load_state_dict with frozen_weight * nit * Add monkey patch for pt opset 10 * remove pt patch in CI * nit: newline
- Loading branch information
Showing
6 changed files
with
179 additions
and
229 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import torch | ||
|
||
from torch.onnx import symbolic_opset10 | ||
from torch.onnx.symbolic_helper import parse_args | ||
|
||
@parse_args('v', 'v', 'v', 'v', 'i', 'none') | ||
def nll_loss(g, self, target, weight=None, reduction='mean', ignore_index=-100): | ||
if not weight and not ignore_index: | ||
return g.op("nll_loss", self, target) | ||
elif ignore_index: | ||
ignore_index_ = g.op("Constant", value_t=torch.tensor(ignore_index, dtype=torch.int64)) | ||
eq_ = g.op("Equal", target, ignore_index_) | ||
not_eq_ = g.op("Not", eq_) | ||
weight_ = g.op("Cast", not_eq_, to_i=1) # FLOAT = 1; // float | ||
not_eq_int64_ = g.op("Cast", not_eq_, to_i=7) #INT64 = 7; // int64_t | ||
target_ = g.op("Mul", target, not_eq_int64_) | ||
# if weight: | ||
# weight_ = g.op("Mul", weight_, weight) | ||
return g.op("nll_loss", self, target_, weight_) | ||
|
||
symbolic_opset10.nll_loss = nll_loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.