Fix fp16 aten_native_batch_norm when bias is None and training is True #1217

Merged · 9 commits into main · Jan 23, 2024

Conversation

BowenBao (Contributor)

Fixes nfnet in https://github.com/microsoft/onnx-converters-private/issues/196

Two changes (see the sketch after this list):

  • Cast locally created weight and bias to input dtype.
  • Upcast input when computing mean and var.
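
A minimal sketch of what these two changes amount to, in plain PyTorch (the function name and structure are illustrative only, not the actual torchlib decomposition):

import torch

def _batch_norm_fp16_sketch(x, weight=None, bias=None, eps=1e-5):
    channels = x.shape[1]
    # Change 1: locally created weight/bias take the input dtype, so fp16
    # inputs are not mixed with fp32 constants.
    if weight is None:
        weight = torch.ones(channels, dtype=x.dtype, device=x.device)
    if bias is None:
        bias = torch.zeros(channels, dtype=x.dtype, device=x.device)

    # Change 2: mean/var are computed in float32, then the normalized
    # result is cast back to the input dtype.
    x32 = x.to(torch.float32)
    mean = x32.mean(dim=(0, 2, 3), keepdim=True)
    var = x32.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
    normalized = ((x32 - mean) / torch.sqrt(var + eps)).to(x.dtype)
    return normalized * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)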

Not sure how this was not covered by the unit tests; the test case seems to be there.

Minimized repro:

import torch

class Module(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(3, track_running_stats=False)

    def forward(self, x):
        # Calling functional batch_norm with bias=None and training=True
        # triggers the fp16 issue this PR fixes.
        # return self.bn(x)
        return torch.nn.functional.batch_norm(x, None, None, self.bn.weight, training=True)

model = Module().cuda().to(dtype=torch.float16).eval()
x = torch.randn(1, 3, 224, 224, dtype=torch.float16).cuda()

op = torch.onnx.dynamo_export(model, x)
op.save("report_bn.onnx")
op.save_diagnostics("report_bn.sarif")

import onnxruntime

sess = onnxruntime.InferenceSession("report_bn.onnx")
sess.run(None, {"l_x_": x.cpu().numpy()})
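
As a quick sanity check on top of the repro (assuming the session above was created successfully; the input name "l_x_" comes from the dynamo export, and the tolerances are arbitrary fp16-friendly values):

import numpy as np

ort_out = sess.run(None, {"l_x_": x.cpu().numpy()})[0]
eager_out = model(x).detach().cpu().numpy()
# Loose tolerances because both paths run in float16.
np.testing.assert_allclose(ort_out, eager_out, rtol=1e-2, atol=1e-3)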

BowenBao added the topic: torch_lib label (Related to the torch/aten function lib in development) on Dec 12, 2023

codecov bot commented Dec 12, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison: base (4de6c80) 78.84% vs. head (e6327ee) 78.84%. The report is 1 commit behind head on main.

❗ Current head e6327ee differs from pull request most recent head 97ca176. Consider uploading reports for the commit 97ca176 to get more accurate results

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1217   +/-   ##
=======================================
  Coverage   78.84%   78.84%           
=======================================
  Files         119      119           
  Lines       15690    15692    +2     
  Branches     2479     2479           
=======================================
+ Hits        12371    12373    +2     
  Misses       2911     2911           
  Partials      408      408           


github-actions bot commented Dec 12, 2023

Test Results

     24 files ±0      24 suites ±0    1h 43m 36s ⏱️ -2m 40s
 11,397 tests ±0    8,430 ✅ -3    2,953 💤 ±0     14 ❌ +3
258,108 runs +38  58,671 ✅ ±0  199,242 💤 ±0    195 ❌ +38

For more details on these failures, see this check.

Results for commit e6327ee. Comparison against base commit 4de6c80.

♻️ This comment has been updated with latest results.

BowenBao (Contributor, Author) commented Dec 13, 2023

I saw some batch_norm failures in CI; not sure if they are new. Also, why are they only failing on torch-nightly?

E onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] Inference error(s): (op_type:_aten_native_batch_norm_training_onnx, node name: _aten_native_batch_norm_training_onnx_1): [TypeInferenceError] Inferred elem type differs from existing elem type: (1) vs (10)

AssertionError: Output 0 mismatch

Cannot repro this locally with the same ort, onnx, and torch versions.
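
For reference, the two elem type codes in the shape inference error above map to ONNX tensor types as follows (quick check with the onnx Python package):

import onnx

# TensorProto enum values quoted in the error: (1) vs (10)
print(onnx.TensorProto.FLOAT)    # 1  -> float32
print(onnx.TensorProto.FLOAT16)  # 10 -> float16

That is, the inferred output type is float32 while the recorded type is float16, which points at an fp32/fp16 mismatch somewhere in the training-mode decomposition.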

justinchuby reopened this on Dec 20, 2023
justinchuby (Collaborator)

Looks like there are now assertion errors.

justinchuby merged commit b4c9a8b into main on Jan 23, 2024 (29 of 33 checks passed).
justinchuby deleted the bowbao/bn_fp16 branch on January 23, 2024 at 17:29.
Labels: topic: torch_lib (Related to the torch/aten function lib in development)