Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaowuhu committed Jan 11, 2024
1 parent 0e9e65e commit 48b994b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
9 changes: 5 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2461,16 +2461,17 @@ def aten_upsample_nearest2d_backward(
raise NotImplementedError()


@torch_op("aten::upsample_nearest3d", trace_only=True)
def aten_upsample_nearest3d(
self: TensorType,
output_size: INT64,
self: TReal,
size: INT64,
scales_d: Optional[float] = None,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
) -> TensorType:
) -> TReal:
"""upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor"""

raise NotImplementedError()
return op.Identity(self)


def aten_upsample_nearest3d_backward(
Expand Down
1 change: 1 addition & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def run_test_output_match(
),
kwargs=repr(cpu_sample.kwargs),
):
if i != 0: continue

Check warning

Code scanning / lintrunner

RUFF/E701 Warning test

test_behavior, reason = _should_skip_xfail_test_sample(op.name, cpu_sample, dtype)

with ops_test_common.normal_xfail_skip_test_behaviors(test_behavior, reason):
Expand Down
5 changes: 5 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2151,6 +2151,11 @@ def _where_input_wrangler(
matcher=lambda sample: "scale_factor" in sample.kwargs,
reason="fixme: the scale_factor tests",
),
TorchLibOpInfo(
"nn.functional.upsample_nearest3d",
nn_ops.aten_upsample_nearest3d,
trace_only=True,
),
TorchLibOpInfo("ones_like", core_ops.aten_ones_like, trace_only=True),
TorchLibOpInfo(
"roll",
Expand Down

0 comments on commit 48b994b

Please sign in to comment.