Skip to content

Commit

Permalink
Fix __name__ on a reconstructed NestedUserFunctionVariable (pytorch…
Browse files Browse the repository at this point in the history
…#118768)

```
def f():
    def g():
        return ()

    print(g.__name__)

f()
```

The following script should print `g` (with or without torch.compile),
but prints `f.<locals>.g` with torch.compile.

The problem looks like we use the co_qualname when reconstructing the
NestedUserFunctionVariable. I switched this over to use the co_name.

Pull Request resolved: pytorch#118768
Approved by: https://github.com/yanboliang, https://github.com/jansel
  • Loading branch information
zou3519 authored and pytorchmergebot committed Feb 1, 2024
1 parent b0e65dd commit 318e6ff
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 33 deletions.
5 changes: 4 additions & 1 deletion scripts/compile_tests/update_failures.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ def get_intersection_and_outside(a_dict, b_dict):
def build_dict(keys):
result = {}
for k in keys:
result[k] = a_dict.get(k, b_dict[k])
if k in a_dict:
result[k] = a_dict[k]
else:
result[k] = b_dict[k]
return result

return build_dict(intersection), build_dict(outside)
Expand Down
20 changes: 20 additions & 0 deletions test/dynamo/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2260,6 +2260,26 @@ def fn(param, param2):
self.assertEqual(opt_fn(param, param), fn(param, param))
self.assertEqual(cnts.frame_count, 2) # Recompiles

def test_reconstructed_name(self):
lst = []

@torch._dynamo.disable
def disallowed(g):
lst.append(g.__name__)

def f():
def g():
return ()

disallowed(g)

f_opt = torch._dynamo
opt_f = torch._dynamo.optimize(backend="eager")(f)
opt_f()
f()
self.assertEqual(len(lst), 2)
self.assertEqual(lst[0], lst[1])

@unittest.skipIf(
sys.version_info < (3, 10),
"zip strict kwargs not implemented for Python < 3.10",
Expand Down
3 changes: 2 additions & 1 deletion torch/_dynamo/variables/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
from ..utils import get_first_attr, make_cell
from .base import typestr, VariableTracker
from .constant import ConstantVariable

if TYPE_CHECKING:
from torch._guards import Source
Expand Down Expand Up @@ -508,7 +509,7 @@ def reconstruct(self, codegen):
codegen.load_import_from(__name__, "_create_nested_fn")
codegen(self.code)
codegen.extend_output([codegen._create_load_const(self.f_globals)])
codegen(self.fn_name)
codegen(ConstantVariable.create(self.code.value.co_name))

if self.defaults:
codegen(self.defaults)
Expand Down
32 changes: 1 addition & 31 deletions torch/testing/_internal/dynamo_test_failures.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,8 +1533,6 @@
"TestNamedTensor.test_masked_fill", # test_namedtensor
"TestNamedTensor.test_addmv", # test_namedtensor
"TestNamedTensor.test_cummax_cummin", # test_namedtensor
"TestNamedTensor.test_no_jit_script_support", # test_namedtensor
"TestNamedTensor.test_no_jit_tracer_support", # test_namedtensor
"TestNamedTensor.test_rename_rename_map", # test_namedtensor
"TestNamedTensor.test_mm", # test_namedtensor
"TestNamedTensor.test_no_save_support", # test_namedtensor
Expand Down Expand Up @@ -1602,11 +1600,9 @@
"TestFX.test_immutable_list_pytree_ops", # test_fx
"TestCommonPass.test_correctness_CSEPass_TakeList_cpu", # test_fx
"TestPassManager.test_pass_manager", # test_fx
"TestFX.test_user_friendly_call_provenance_with_function", # test_fx
"TestCommonPass.test_correctness_CSEPass_MutationMetadata_cpu", # test_fx
"TestCommonPass.test_correctness_CSEPass_MutationTorchTensorCall_cpu", # test_fx
"TestCommonPass.test_correctness_CSEPass_MutationInput_cpu", # test_fx
"TestFX.test_fn_type_annotation_empty", # test_fx
"TestFX.test_immutable_dict_pytree_ops", # test_fx
"TestCommonPass.test_correctness_factory_CSEPass_MutationFactory_cpu", # test_fx
"TestCommonPass.test_correctness_factory_CSEPass_FactoryFunctionCall_cpu", # test_fx
Expand Down Expand Up @@ -1922,7 +1918,6 @@
"TestNestedCheckpoint.test_nested_checkpoint_kwargs_early_stop_True", # test_autograd
"TestAutograd.test_gradcheck_forward_ad", # test_autograd
"TestAutograd.test_access_saved_tensor_twice_without_recomputation_works", # test_autograd
"TestMultithreadAutograd.test_fork_join_in_middle", # test_autograd
"TestAutograd.test_hook_closure_cycle_use_custom_function_True_use_tensor_hook_False", # test_autograd
"TestAutograd.test_accumulate_grad_tensor_reference", # test_autograd
"TestAutogradInferenceMode.test_inference_mode_inf_tensor_in_inf_mode_inplace_op", # test_autograd
Expand Down Expand Up @@ -2368,43 +2363,17 @@
"TestQuantizePT2EQATModels.test_qat_resnet18", # test_quantization.py
"TestQuantizePT2EQATModels.test_qat_mobilenet_v2", # test_quantization.py
"TestObserver.test_per_channel_observers", # test_quantization.py
"TestCustomOp.test_impl_cpu", # test_custom_ops
"TestCustomOp.test_backward_tensorlist_input_requires_list_grads_none_or_Tensor", # test_custom_ops
"TestCustomOp.test_define_with_tags_single", # test_custom_ops
"TestCustomOp.test_autogen_aten_ops_are_pt2_compliant", # test_custom_ops
"TestCustomOp.test_backward_output_differentiability_tensorlist", # test_custom_ops
"TestCustomOp.test_backward_output_differentiability_type", # test_custom_ops
"TestCustomOp.test_impl_meta", # test_custom_ops
"TestCustomOp.test_impl_invalid_devices", # test_custom_ops
"TestCustomOp.test_new_data_dependent_symint", # test_custom_ops
"TestCustomOp.test_define_with_tags_list", # test_custom_ops
"TestCustomOp.test_backward_tensorlist_input_requires_list_grads", # test_custom_ops
"TestCustomOp.test_not_implemented_error", # test_custom_ops
"TestCustomOp.test_impl_device_cpu", # test_custom_ops
"TestCustomOp.test_backward_returns_dict", # test_custom_ops
"TestCustomOp.test_autograd_notimplemented", # test_custom_ops
"TestCustomOp.test_backward_grads_are_tensor_or_none", # test_custom_ops
"TestCustomOp.test_backward_dict_requires_keys_for_input_optional_tensors", # test_custom_ops
"TestCustomOp.test_backward_output_differentiability_non_tensor", # test_custom_ops
"TestCustomOp.test_lifetime", # test_custom_ops
"TestCustomOp.test_impl_device_function", # test_custom_ops
"TestCustomOp.test_builtin_torchscript_ops", # test_custom_ops
"TestCustomOpTestingCPU.test_missing_functionalization_cpu", # test_custom_ops
"TestCustomOp.test_define_with_tags_tuple", # test_custom_ops
"TestCustomOp.test_builtin_aten_ops_are_pt2_compliant", # test_custom_ops
"TestCustomOp.test_save_for_backward_inputs_are_namedtuple", # test_custom_ops
"TestCustomOp.test_autograd_notimplemented_gradmode", # test_custom_ops
"TestGenerateOpcheckTests.test_opcheck_bad_op", # test_custom_ops
"TestCustomOp.test_backward_dict_invalid_keys", # test_custom_ops
"TestCustomOp.test_backward_tensorlist_input_requires_list_grads_with_same_numel", # test_custom_ops
"TestCustomOp.test_duplicate_impl", # test_custom_ops
"TestCustomOp.test_backward_output_differentiability_numel", # test_custom_ops
"TestCustomOp.test_backward_dict_requires_keys_for_input_tensors", # test_custom_ops
"TestCustomOp.test_legacy_define", # test_custom_ops
"TestCustomOpTestingCPU.test_opcheck_fails_basic_cpu", # test_custom_ops
"TestCustomOp.test_backward_dict_grad_for_nontensor", # test_custom_ops
"TestCustomOp.test_backward_partially_registered", # test_custom_ops
"TestCustomOp.test_basic_make_fx", # test_custom_ops
"TestPythonRegistration.test_alias_analysis", # test_python_dispatch
"TestPythonDispatch.test_torch_dispatch_mode_subclass_priority", # test_python_dispatch
"TestPythonDispatch.test_strides_slow_path", # test_python_dispatch
Expand Down Expand Up @@ -7325,6 +7294,7 @@
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_1_shape2_cpu", # test_transformers.py
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_2_shape1_cpu", # test_transformers.py
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_1_shape1_cpu", # test_transformers.py
"TestCustomOpTestingCPU.test_opcheck_fails_basic_cpu", # test_custom_ops.py
}


Expand Down

0 comments on commit 318e6ff

Please sign in to comment.