Skip to content

Commit

Permalink
Add test that import torch doesn't modify global logging state (pyt…
Browse files Browse the repository at this point in the history
…orch#87629)

Fixes pytorch#87626

Also adds the same test for `import functorch`. Users have complained at
us when we do modify the global logging state, which has happened in the
past.

Test Plan:
- tested locally; I added `logging.basicConfig` to `torch/__init__.py`
and checked that the test got triggered
Pull Request resolved: pytorch#87629
Approved by: https://github.com/albanD
  • Loading branch information
zou3519 authored and pytorchmergebot committed Oct 26, 2022
1 parent 422f946 commit 642b63e
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions test/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1818,6 +1818,27 @@ def test_no_warning_on_import(self) -> None:
cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8")
self.assertEquals(out, "")

@unittest.skipIf(IS_WINDOWS, "importing torch+CUDA on CPU results in warning")
@parametrize('path', ['torch', 'functorch'])
def test_no_mutate_global_logging_on_import(self, path) -> None:
# Calling logging.basicConfig, among other things, modifies the global
# logging state. It is not OK to modify the global logging state on
# `import torch` (or other submodules we own) because users do not expect it.
expected = 'abcdefghijklmnopqrstuvwxyz'
commands = [
'import logging',
f'import {path}',
'_logger = logging.getLogger("torch_test_testing")',
'logging.root.addHandler(logging.StreamHandler())',
'logging.root.setLevel(logging.INFO)',
f'_logger.info("{expected}")'
]
out = subprocess.check_output(
[sys.executable, "-W", "all", "-c", "; ".join(commands)],
stderr=subprocess.STDOUT,
).decode("utf-8")
self.assertEqual(out.strip(), expected)

class TestOpInfos(TestCase):
def test_sample_input(self) -> None:
a, b, c, d, e = [object() for _ in range(5)]
Expand Down Expand Up @@ -1913,6 +1934,7 @@ def test_opinfo_error_generators(self, device, op):


instantiate_device_type_tests(TestOpInfoSampleFunctions, globals())
instantiate_parametrized_tests(TestImports)


if __name__ == '__main__':
Expand Down

0 comments on commit 642b63e

Please sign in to comment.