diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 2ab9ae37af60a..81e6faf81734c 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -1642,6 +1642,33 @@ def forward(self, {', '.join(param_names)}): test_out = traced(*param_values) self.assertEqual(test_out, ref_out) + def test_normalize_quantized_eb(self): + target = torch.ops.quantized.embedding_bag_byte_rowwise_offsets + args = ( + torch.empty((2, 3), dtype=torch.uint8), + torch.empty((2,), dtype=torch.int64), + torch.empty((2,), dtype=torch.int64), + ) + norm_args_and_kwargs = normalize_function( + target, args, normalize_to_only_use_kwargs=True + ) + self.assertTrue(norm_args_and_kwargs is not None) + self.assertEqual( + set(norm_args_and_kwargs.kwargs.keys()), + { + "weight", + "indices", + "offsets", + "scale_grad_by_freq", + "mode", + "pruned_weights", + "per_sample_weights", + "compressed_indices_mapping", + "include_last_offset", + }, + ) + self.assertEqual(norm_args_and_kwargs.args, tuple()) + instantiate_device_type_tests(TestNormalizeOperators, globals()) diff --git a/torch/fx/operator_schemas.py b/torch/fx/operator_schemas.py index 5c2d484623983..0b72afaca2ed6 100644 --- a/torch/fx/operator_schemas.py +++ b/torch/fx/operator_schemas.py @@ -257,6 +257,8 @@ def normalize_function( if kwargs is None: kwargs = {} new_args_and_kwargs = None + if isinstance(target, OpOverloadPacket) or isinstance(target, OpOverload): + target = target.op if not isinstance(target, types.BuiltinFunctionType): target_for_analysis = target if target in boolean_dispatched: