From 2794449c414f02273bf2a3116d0c26b51ce123d3 Mon Sep 17 00:00:00 2001 From: Tony Meyer Date: Wed, 4 Oct 2023 09:20:01 +1300 Subject: [PATCH] test: add type hints to test_framework (#1025) * Change the type of `Handle`'s `key` (in `__init__` and `nest()` and the property) to explicitly allow `None * Add a CHANGES.md file to document changes (backfilled only to 2.7) * Tidy up the type hinting of `BoundStoredState.__getattr__` - multiple overloads need to be provided, and once that's done the ignores can be removed (and a bunch of type issues in the tests get resolved) * Ignore type where creating an object that's only partially complete, but the test only requires that much and fully creating it would be a lot of extra code for no real gain * Where's it's simple to do, pass the required attributes rather than `None` * Ignore types where we're checking handling of invalid types * Add casts around the very dynamic snapshot/restore functionality * For a couple of the "list of test cases" situations, move the documentation of what each element in the test case is to a new type definition that covers both what was there before and the (rough) expected type * Sprinkle type hints where required for pyright to be happy Partially addresses #1007 --- CHANGES.md | 9 ++ ops/framework.py | 14 +- pyproject.toml | 1 + test/test_framework.py | 346 ++++++++++++++++++++++++----------------- 4 files changed, 220 insertions(+), 150 deletions(-) create mode 100644 CHANGES.md diff --git a/CHANGES.md b/CHANGES.md new file mode 100644 index 000000000..cc12ba43c --- /dev/null +++ b/CHANGES.md @@ -0,0 +1,9 @@ +# 2.8.0 + +* The type of a `Handle`'s `key` was expanded from `str` to `str|None` + +# 2.7.0 + +* Added Unit.set_ports() +* Type checks now allow comparing a `JujuVersion` to a `str` +* Renamed `OpenPort` to `Port` (`OpenPort` remains as an alias) diff --git a/ops/framework.py b/ops/framework.py index 757911dcf..e77bc7e2d 100755 --- a/ops/framework.py +++ b/ops/framework.py @@ -101,7 +101,7 @@ class Handle: under the same parent and kind may have the same key. """ - def __init__(self, parent: Optional[Union['Handle', 'Object']], kind: str, key: str): + def __init__(self, parent: Optional[Union['Handle', 'Object']], kind: str, key: Optional[str]): if isinstance(parent, Object): # if it's not an Object, it will be either a Handle (good) or None (no parent) parent = parent.handle @@ -119,7 +119,7 @@ def __init__(self, parent: Optional[Union['Handle', 'Object']], kind: str, key: else: self._path = f"{kind}" # don't need f-string, but consistent with above - def nest(self, kind: str, key: str) -> 'Handle': + def nest(self, kind: str, key: Optional[str]) -> 'Handle': """Create a new handle as child of the current one.""" return Handle(self, kind, key) @@ -143,7 +143,7 @@ def kind(self) -> str: return self._kind @property - def key(self) -> str: + def key(self) -> Optional[str]: """Return the handle's key.""" return self._key @@ -1061,13 +1061,17 @@ def __init__(self, parent: Object, attr_name: str): if TYPE_CHECKING: @typing.overload - def __getattr__(self, key: Literal['on']) -> ObjectEvents: # type: ignore + def __getattr__(self, key: Literal['on']) -> ObjectEvents: + pass + + @typing.overload + def __getattr__(self, key: str) -> Any: pass def __getattr__(self, key: str) -> Any: # "on" is the only reserved key that can't be used in the data map. if key == "on": - return self._data.on # type: ignore # casting won't work for some reason + return self._data.on if key not in self._data: raise AttributeError(f"attribute '{key}' is not stored") return _wrap_stored(self._data, self._data[key]) diff --git a/pyproject.toml b/pyproject.toml index 43d596f95..0d2fd4bd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ include = ["ops/*.py", "ops/_private/*.py", "test/test_testing.py", "test/test_storage.py", "test/test_charm.py", + "test/test_framework.py", ] pythonVersion = "3.8" # check no python > 3.8 features are used pythonPlatform = "All" diff --git a/test/test_framework.py b/test/test_framework.py index 26f2d89f9..704d0bb27 100755 --- a/test/test_framework.py +++ b/test/test_framework.py @@ -21,13 +21,14 @@ import shutil import sys import tempfile +import typing from pathlib import Path from test.test_helpers import BaseTestCase, fake_script from unittest.mock import patch import ops from ops.framework import _BREAKPOINT_WELCOME_MESSAGE, _event_regex -from ops.storage import NoSnapshotError, SQLiteStorage +from ops.storage import JujuStorage, NoSnapshotError, SQLiteStorage class TestFramework(BaseTestCase): @@ -43,7 +44,7 @@ def setUp(self): def test_deprecated_init(self): # For 0.7, this still works, but it is deprecated. with self.assertLogs(level="WARNING") as cm: - framework = ops.Framework(':memory:', None, None, None) + framework = ops.Framework(':memory:', None, None, None) # type: ignore self.assertIn( "WARNING:ops.framework:deprecated: Framework now takes a Storage not a path", cm.output) @@ -63,13 +64,13 @@ def test_handle_path(self): def test_handle_attrs_readonly(self): handle = ops.Handle(None, 'kind', 'key') with self.assertRaises(AttributeError): - handle.parent = 'foo' + handle.parent = 'foo' # type: ignore with self.assertRaises(AttributeError): - handle.kind = 'foo' + handle.kind = 'foo' # type: ignore with self.assertRaises(AttributeError): - handle.key = 'foo' + handle.key = 'foo' # type: ignore with self.assertRaises(AttributeError): - handle.path = 'foo' + handle.path = 'foo' # type: ignore def test_restore_unknown(self): framework = self.create_framework() @@ -79,7 +80,7 @@ class Foo(ops.Object): handle = ops.Handle(None, "a_foo", "some_key") - framework.register_type(Foo, None, handle.kind) + framework.register_type(Foo, None, handle.kind) # type: ignore try: framework.load_snapshot(handle) @@ -91,14 +92,16 @@ class Foo(ops.Object): def test_snapshot_roundtrip(self): class Foo: - def __init__(self, handle, n): + handle_kind = 'foo' + + def __init__(self, handle: ops.Handle, n: int): self.handle = handle self.my_n = n - def snapshot(self): + def snapshot(self) -> typing.Dict[str, int]: return {"My N!": self.my_n} - def restore(self, snapshot): + def restore(self, snapshot: typing.Dict[str, int]): self.my_n = snapshot["My N!"] + 1 handle = ops.Handle(None, "a_foo", "some_key") @@ -106,19 +109,21 @@ def restore(self, snapshot): framework1 = self.create_framework(tmpdir=self.tmpdir) framework1.register_type(Foo, None, handle.kind) - framework1.save_snapshot(event) + framework1.save_snapshot(event) # type: ignore framework1.commit() framework1.close() framework2 = self.create_framework(tmpdir=self.tmpdir) framework2.register_type(Foo, None, handle.kind) event2 = framework2.load_snapshot(handle) + event2 = typing.cast(Foo, event2) self.assertEqual(event2.my_n, 2) - framework2.save_snapshot(event2) + framework2.save_snapshot(event2) # type: ignore del event2 gc.collect() event3 = framework2.load_snapshot(handle) + event3 = typing.cast(Foo, event3) self.assertEqual(event3.my_n, 3) framework2.drop_snapshot(event.handle) @@ -142,16 +147,16 @@ class MyNotifier(ops.Object): baz = ops.EventSource(MyEvent) class MyObserver(ops.Object): - def __init__(self, parent, key): + def __init__(self, parent: ops.Object, key: str): super().__init__(parent, key) - self.seen = [] - self.reprs = [] + self.seen: typing.List[str] = [] + self.reprs: typing.List[str] = [] - def on_any(self, event): + def on_any(self, event: ops.EventBase): self.seen.append(f"on_any:{event.handle.kind}") self.reprs.append(repr(event)) - def on_foo(self, event): + def on_foo(self, event: ops.EventBase): self.seen.append(f"on_foo:{event.handle.kind}") self.reprs.append(repr(event)) @@ -162,7 +167,7 @@ def on_foo(self, event): framework.observe(pub.bar, obs.on_any) with self.assertRaisesRegex(RuntimeError, "^Framework.observe requires a method"): - framework.observe(pub.baz, obs) + framework.observe(pub.baz, obs) # type: ignore pub.foo.emit() pub.bar.emit() @@ -188,13 +193,17 @@ class MyObserver(ops.Object): def _on_foo(self): assert False, 'should not be reached' - def _on_bar(self, event, extra): + def _on_bar(self, event: ops.EventBase, extra: typing.Any): assert False, 'should not be reached' - def _on_baz(self, event, extra=None, *, k): + def _on_baz(self, + event: ops.EventBase, + extra: typing.Optional[typing.Any] = None, + *, + k: typing.Any): assert False, 'should not be reached' - def _on_qux(self, event, extra=None): + def _on_qux(self, event: ops.EventBase, extra: typing.Optional[typing.Any] = None): assert False, 'should not be reached' framework = self.create_framework() @@ -202,11 +211,11 @@ def _on_qux(self, event, extra=None): obs = MyObserver(framework, "obs") with self.assertRaisesRegex(TypeError, "must accept event parameter"): - framework.observe(pub.foo, obs._on_foo) + framework.observe(pub.foo, obs._on_foo) # type: ignore with self.assertRaisesRegex(TypeError, "has extra required parameter"): - framework.observe(pub.bar, obs._on_bar) + framework.observe(pub.bar, obs._on_bar) # type: ignore with self.assertRaisesRegex(TypeError, "has extra required parameter"): - framework.observe(pub.baz, obs._on_baz) + framework.observe(pub.baz, obs._on_baz) # type: ignore framework.observe(pub.qux, obs._on_qux) def test_on_pre_commit_emitted(self): @@ -216,17 +225,17 @@ class PreCommitObserver(ops.Object): _stored = ops.StoredState() - def __init__(self, parent, key): + def __init__(self, parent: ops.Object, key: typing.Optional[str]): super().__init__(parent, key) - self.seen = [] + self.seen: typing.List[typing.Any] = [] self._stored.myinitdata = 40 - def on_pre_commit(self, event): + def on_pre_commit(self, event: ops.PreCommitEvent): self._stored.myinitdata = 41 self._stored.mydata = 42 self.seen.append(type(event)) - def on_commit(self, event): + def on_commit(self, event: ops.CommitEvent): # Modifications made here will not be persisted. self._stored.myinitdata = 42 self._stored.mydata = 43 @@ -239,8 +248,8 @@ def on_commit(self, event): framework.commit() - self.assertEqual(obs._stored.myinitdata, 41) - self.assertEqual(obs._stored.mydata, 42) + self.assertEqual(obs._stored.myinitdata, 41) # type: ignore + self.assertEqual(obs._stored.mydata, 42) # type: ignore self.assertTrue(obs.seen, [ops.PreCommitEvent, ops.CommitEvent]) framework.close() @@ -248,11 +257,11 @@ def on_commit(self, event): new_obs = PreCommitObserver(other_framework, None) - self.assertEqual(obs._stored.myinitdata, 41) - self.assertEqual(new_obs._stored.mydata, 42) + self.assertEqual(obs._stored.myinitdata, 41) # type: ignore + self.assertEqual(new_obs._stored.mydata, 42) # type: ignore with self.assertRaises(AttributeError): - new_obs._stored.myotherdata + new_obs._stored.myotherdata # type: ignore def test_defer_and_reemit(self): framework = self.create_framework() @@ -268,12 +277,12 @@ class MyNotifier2(ops.Object): c = ops.EventSource(MyEvent) class MyObserver(ops.Object): - def __init__(self, parent, key): + def __init__(self, parent: ops.Object, key: str): super().__init__(parent, key) - self.seen = [] - self.done = {} + self.seen: typing.List[str] = [] + self.done: typing.Dict[str, bool] = {} - def on_any(self, event): + def on_any(self, event: ops.EventBase): self.seen.append(event.handle.kind) if not self.done.get(event.handle.kind): event.defer() @@ -328,14 +337,14 @@ def test_custom_event_data(self): framework = self.create_framework() class MyEvent(ops.EventBase): - def __init__(self, handle, n): + def __init__(self, handle: ops.Handle, n: int): super().__init__(handle) self.my_n = n def snapshot(self): return {"My N!": self.my_n} - def restore(self, snapshot): + def restore(self, snapshot: typing.Dict[str, typing.Any]): super().restore(snapshot) self.my_n = snapshot["My N!"] + 1 @@ -343,11 +352,11 @@ class MyNotifier(ops.Object): foo = ops.EventSource(MyEvent) class MyObserver(ops.Object): - def __init__(self, parent, key): + def __init__(self, parent: ops.Object, key: str): super().__init__(parent, key) - self.seen = [] + self.seen: typing.List[str] = [] - def _on_foo(self, event): + def _on_foo(self, event: MyEvent): self.seen.append(f"on_foo:{event.handle.kind}={event.my_n}") event.defer() @@ -373,7 +382,7 @@ def _on_foo(self, event): def test_weak_observer(self): framework = self.create_framework() - observed_events = [] + observed_events: typing.List[str] = [] class MyEvent(ops.EventBase): pass @@ -382,10 +391,10 @@ class MyEvents(ops.ObjectEvents): foo = ops.EventSource(MyEvent) class MyNotifier(ops.Object): - on = MyEvents() + on = MyEvents() # type: ignore class MyObserver(ops.Object): - def _on_foo(self, event): + def _on_foo(self, event: ops.EventBase): observed_events.append("foo") pub = MyNotifier(framework, "1") @@ -405,7 +414,11 @@ def test_forget_and_multiple_objects(self): framework = self.create_framework() class MyObject(ops.Object): - pass + def snapshot(self) -> typing.Dict[str, typing.Any]: + raise NotImplementedError() + + def restore(self, snapshot: typing.Dict[str, typing.Any]) -> None: + raise NotImplementedError() o1 = MyObject(framework, "path") # Creating a second object at the same path should fail with RuntimeError @@ -430,30 +443,32 @@ def test_forget_and_multiple_objects_with_load_snapshot(self): framework = self.create_framework(tmpdir=self.tmpdir) class MyObject(ops.Object): - def __init__(self, parent, name): + def __init__(self, parent: ops.Object, name: str): super().__init__(parent, name) self.value = name def snapshot(self): - return self.value + return {"value": self.value} - def restore(self, value): - self.value = value + def restore(self, snapshot: typing.Dict[str, typing.Any]): + self.value = snapshot["value"] framework.register_type(MyObject, None, MyObject.handle_kind) o1 = MyObject(framework, "path") - framework.save_snapshot(o1) + framework.save_snapshot(o1) # type: ignore framework.commit() o_handle = o1.handle del o1 gc.collect() o2 = framework.load_snapshot(o_handle) + o2 = typing.cast(MyObject, o2) # Trying to load_snapshot a second object at the same path should fail with RuntimeError with self.assertRaises(RuntimeError): framework.load_snapshot(o_handle) # Unless we _forget the object first framework._forget(o2) o3 = framework.load_snapshot(o_handle) + o3 = typing.cast(MyObject, o3) self.assertEqual(o2.value, o3.value) # A loaded object also prevents direct creation of an object with self.assertRaises(RuntimeError): @@ -467,6 +482,7 @@ def restore(self, value): framework_copy2 = self.create_framework(tmpdir=self.tmpdir) framework_copy2.register_type(MyObject, None, MyObject.handle_kind) o_copy2 = framework_copy2.load_snapshot(o_handle) + o_copy2 = typing.cast(MyObject, o_copy2) self.assertEqual(o_copy2.value, "path") def test_events_base(self): @@ -480,18 +496,18 @@ class MyEvents(ops.ObjectEvents): bar = ops.EventSource(MyEvent) class MyNotifier(ops.Object): - on = MyEvents() + on = MyEvents() # type: ignore class MyObserver(ops.Object): - def __init__(self, parent, key): + def __init__(self, parent: ops.Object, key: str): super().__init__(parent, key) - self.seen = [] + self.seen: typing.List[str] = [] - def _on_foo(self, event): + def _on_foo(self, event: ops.EventBase): self.seen.append(f"on_foo:{event.handle.kind}") event.defer() - def _on_bar(self, event): + def _on_bar(self, event: ops.EventBase): self.seen.append(f"on_bar:{event.handle.kind}") pub = MyNotifier(framework, "1") @@ -519,15 +535,15 @@ class MyEvents(ops.ObjectEvents): foo = event with self.assertRaises(RuntimeError) as cm: - class OtherEvents(ops.ObjectEvents): + class OtherEvents(ops.ObjectEvents): # type: ignore foo = event self.assertEqual( str(cm.exception.__cause__), "EventSource(MyEvent) reused as MyEvents.foo and OtherEvents.foo") with self.assertRaises(RuntimeError) as cm: - class MyNotifier(ops.Object): - on = MyEvents() + class MyNotifier(ops.Object): # type: ignore + on = MyEvents() # type: ignore bar = event self.assertEqual( str(cm.exception.__cause__), @@ -540,17 +556,17 @@ def test_reemit_ignores_unknown_event_type(self): framework = self.create_framework() class MyEvent(ops.EventBase): - pass + handle_kind = "test" class MyNotifier(ops.Object): foo = ops.EventSource(MyEvent) class MyObserver(ops.Object): - def __init__(self, parent, key): + def __init__(self, parent: ops.Object, key: str): super().__init__(parent, key) - self.seen = [] + self.seen: typing.List[typing.Any] = [] - def _on_foo(self, event): + def _on_foo(self, event: ops.EventBase): self.seen.append(event.handle) event.defer() @@ -588,19 +604,19 @@ class MyEvents(ops.ObjectEvents): foo = ops.EventSource(MyFoo) class MyNotifier(ops.Object): - on = MyEvents() + on = MyEvents() # type: ignore bar = ops.EventSource(MyBar) class MyObserver(ops.Object): - def __init__(self, parent, key): + def __init__(self, parent: ops.Object, key: str): super().__init__(parent, key) - self.seen = [] + self.seen: typing.List[str] = [] - def _on_foo(self, event): + def _on_foo(self, event: ops.EventBase): self.seen.append(f"on_foo:{type(event).__name__}:{event.handle.kind}") event.defer() - def _on_bar(self, event): + def _on_bar(self, event: ops.EventBase): self.seen.append(f"on_bar:{type(event).__name__}:{event.handle.kind}") event.defer() @@ -632,15 +648,15 @@ class MyNotifier(ops.Object): on_b = MyEventsB() class MyObserver(ops.Object): - def __init__(self, parent, key): + def __init__(self, parent: ops.Object, key: str): super().__init__(parent, key) - self.seen = [] + self.seen: typing.List[str] = [] - def _on_foo(self, event): + def _on_foo(self, event: ops.EventBase): self.seen.append(f"on_foo:{type(event).__name__}:{event.handle.kind}") event.defer() - def _on_bar(self, event): + def _on_bar(self, event: ops.EventBase): self.seen.append(f"on_bar:{type(event).__name__}:{event.handle.kind}") event.defer() @@ -688,14 +704,14 @@ class NoneEvent(ops.EventBase): def test_event_key_roundtrip(self): class MyEvent(ops.EventBase): - def __init__(self, handle, value): + def __init__(self, handle: ops.Handle, value: typing.Any): super().__init__(handle) self.value = value def snapshot(self): return self.value - def restore(self, value): + def restore(self, value: typing.Any): self.value = value class MyNotifier(ops.Object): @@ -704,11 +720,11 @@ class MyNotifier(ops.Object): class MyObserver(ops.Object): has_deferred = False - def __init__(self, parent, key): + def __init__(self, parent: ops.Object, key: str): super().__init__(parent, key) - self.seen = [] + self.seen: typing.List[typing.Any] = [] - def _on_foo(self, event): + def _on_foo(self, event: MyEvent): self.seen.append((event.handle.key, event.value)) # Only defer the first event and once. if not MyObserver.has_deferred: @@ -741,8 +757,8 @@ def _on_foo(self, event): def test_helper_properties(self): framework = self.create_framework() - framework.model = 'test-model' - framework.meta = 'test-meta' + framework.model = 'test-model' # type: ignore + framework.meta = 'test-meta' # type: ignore my_obj = ops.Object(framework, 'my_obj') self.assertEqual(my_obj.model, framework.model) @@ -759,6 +775,8 @@ def test_snapshot_saving_restricted_to_simple_types(self): to_be_saved = {"bar": TestFramework} class FooEvent(ops.EventBase): + handle_kind = "test" + def snapshot(self): return to_be_saved @@ -783,7 +801,7 @@ class Events(ops.ObjectEvents): foo = ops.EventSource(FooEvent) class Emitter(ops.Object): - on = Events() + on = Events() # type: ignore framework = self.create_framework() e = Emitter(framework, 'key') @@ -823,14 +841,14 @@ class Events(ops.ObjectEvents): class ObjectWithStorage(ops.Object): _stored = ops.StoredState() - on = Events() + on = Events() # type: ignore - def __init__(self, framework, key): + def __init__(self, framework: ops.Framework, key: str): super().__init__(framework, key) self._stored.set_default(foo=2) self.framework.observe(self.on.event, self._on_event) - def _on_event(self, event): + def _on_event(self, event: ops.EventBase): event.defer() # This is an event that 'happened in the past' that doesn't have an associated notice. @@ -866,24 +884,22 @@ def setUp(self): self.addCleanup(shutil.rmtree, str(self.tmpdir)) def test_stored_dict_repr(self): - self.assertEqual(repr(ops.StoredDict(None, {})), "ops.framework.StoredDict()") - self.assertEqual( - repr( - ops.StoredDict( - None, { - "a": 1})), "ops.framework.StoredDict({'a': 1})") + self.assertEqual(repr(ops.StoredDict(None, {})), # type: ignore + "ops.framework.StoredDict()") + self.assertEqual(repr(ops.StoredDict(None, {"a": 1})), # type: ignore + "ops.framework.StoredDict({'a': 1})") def test_stored_list_repr(self): - self.assertEqual(repr(ops.StoredList(None, [])), "ops.framework.StoredList()") - self.assertEqual( - repr( - ops.StoredList( - None, [ - 1, 2, 3])), 'ops.framework.StoredList([1, 2, 3])') + self.assertEqual(repr(ops.StoredList(None, [])), # type: ignore + "ops.framework.StoredList()") + self.assertEqual(repr(ops.StoredList(None, [1, 2, 3])), # type: ignore + 'ops.framework.StoredList([1, 2, 3])') def test_stored_set_repr(self): - self.assertEqual(repr(ops.StoredSet(None, set())), 'ops.framework.StoredSet()') - self.assertEqual(repr(ops.StoredSet(None, {1})), 'ops.framework.StoredSet({1})') + self.assertEqual(repr(ops.StoredSet(None, set())), # type: ignore + 'ops.framework.StoredSet()') + self.assertEqual(repr(ops.StoredSet(None, {1})), # type: ignore + 'ops.framework.StoredSet({1})') def test_basic_state_storage(self): class SomeObject(ops.Object): @@ -907,7 +923,7 @@ class SomeObject(ops.Object): class Sub(SomeObject): pass - class SubSub(SomeObject): + class SubSub(Sub): pass self._stored_state_tests(SubSub) @@ -940,19 +956,24 @@ class FinalChild(StatedObject, Sibling): self._stored_state_tests(FinalChild) - def _stored_state_tests(self, cls): + def _stored_state_tests(self, cls: typing.Type[ops.Object]): + @typing.runtime_checkable + class _StoredProtocol(typing.Protocol): + _stored: ops.StoredState + framework = self.create_framework(tmpdir=self.tmpdir) obj = cls(framework, "1") + assert isinstance(obj, _StoredProtocol) try: - obj._stored.foo + obj._stored.foo # type: ignore except AttributeError as e: self.assertEqual(str(e), "attribute 'foo' is not stored") else: self.fail("AttributeError not raised") try: - obj._stored.on = "nonono" + obj._stored.on = "nonono" # type: ignore except AttributeError as e: self.assertEqual(str(e), "attribute 'on' is reserved and cannot be set") else: @@ -976,6 +997,7 @@ def _stored_state_tests(self, cls): # Since this has the same absolute object handle, it will get its state back. framework_copy = self.create_framework(tmpdir=self.tmpdir) obj_copy = cls(framework_copy, "1") + assert isinstance(obj_copy, _StoredProtocol) self.assertEqual(obj_copy._stored.foo, 42) self.assertEqual(obj_copy._stored.bar, "s") self.assertEqual(obj_copy._stored.baz, 4.2) @@ -1081,15 +1103,22 @@ class CustomObject: framework.commit() def test_mutable_types(self): - # Test and validation functions in a list of 2-tuples. + # Test and validation functions in a list of tuples. # Assignment and keywords like del are not supported in lambdas # so functions are used instead. - test_operations = [( - lambda: {}, # Operand A. - None, # Operand B. - {}, # Expected result. - lambda a, b: None, # Operation to perform. - lambda res, expected_res: self.assertEqual(res, expected_res) # Validation to perform. + test_case = typing.Tuple[ + typing.Callable[[], typing.Any], # Called to get operand A. + typing.Any, # Operand B. + typing.Any, # Expected result. + typing.Callable[[typing.Any, typing.Any], None], # Operation to perform. + typing.Callable[[typing.Any, typing.Any], typing.Any], # Validation to perform. + ] + test_operations: typing.List[test_case] = [( + lambda: {}, + None, + {}, + lambda a, b: None, + lambda res, expected_res: self.assertEqual(res, expected_res) ), ( lambda: {}, {'a': {}}, @@ -1115,7 +1144,7 @@ def test_mutable_types(self): lambda a, b: a['a'].pop(b), lambda res, expected_res: self.assertEqual(res, expected_res) ), ( - lambda: {'s': set()}, + lambda: {'s': set()}, # type: ignore 'a', {'s': {'a'}}, lambda a, b: a['s'].add(b), @@ -1239,11 +1268,16 @@ class SomeObject(ops.Object): _stored = ops.StoredState() class WrappedFramework(ops.Framework): - def __init__(self, store, charm_dir, meta, model, event_name): + def __init__(self, + store: typing.Union[SQLiteStorage, JujuStorage], + charm_dir: typing.Union[str, Path], + meta: ops.CharmMeta, + model: ops.Model, + event_name: str): super().__init__(store, charm_dir, meta, model, event_name) - self.snapshots = [] + self.snapshots: typing.List[typing.Any] = [] - def save_snapshot(self, value): + def save_snapshot(self, value: typing.Union[ops.StoredStateData, ops.EventBase]): if value.handle.path == 'SomeObject[1]/StoredStateData[_stored]': self.snapshots.append((type(value), value.snapshot())) return super().save_snapshot(value) @@ -1251,7 +1285,7 @@ def save_snapshot(self, value): # Validate correctness of modification operations. for get_a, b, expected_res, op, validate_op in test_operations: storage = SQLiteStorage(self.tmpdir / "framework.data") - framework = WrappedFramework(storage, self.tmpdir, None, None, "foo") + framework = WrappedFramework(storage, self.tmpdir, None, None, "foo") # type: ignore obj = SomeObject(framework, '1') obj._stored.a = get_a() @@ -1277,7 +1311,8 @@ def save_snapshot(self, value): framework.close() storage_copy = SQLiteStorage(self.tmpdir / "framework.data") - framework_copy = WrappedFramework(storage_copy, self.tmpdir, None, None, "foo") + framework_copy = WrappedFramework( + storage_copy, self.tmpdir, None, None, "foo") # type: ignore obj_copy2 = SomeObject(framework_copy, '1') @@ -1291,12 +1326,19 @@ def save_snapshot(self, value): framework_copy.close() def test_comparison_operations(self): - test_operations = [( - {"1"}, # Operand A. - {"1", "2"}, # Operand B. - lambda a, b: a < b, # Operation to test. - True, # Result of op(A, B). - False, # Result of op(B, A). + test_case = typing.Tuple[ + typing.Any, # Operand A. + typing.Any, # Operand B. + typing.Callable[[typing.Any, typing.Any], bool], # Operation to test. + bool, # Result of op(A, B). + bool, # Result of op(B, A). + ] + test_operations: typing.List[test_case] = [( + {"1"}, + {"1", "2"}, + lambda a, b: a < b, + True, + False, ), ( {"1"}, {"1", "2"}, @@ -1384,11 +1426,18 @@ class SomeObject(ops.Object): self.assertEqual(op(b, obj._stored.a), op_ba) def test_set_operations(self): - test_operations = [( - {"1"}, # A set to test an operation against (other_set). - lambda a, b: a | b, # An operation to test. - {"1", "a", "b"}, # The expected result of operation(obj._stored.set, other_set). - {"1", "a", "b"} # The expected result of operation(other_set, obj._stored.set). + test_case = typing.Tuple[ + typing.Set[str], # A set to test an operation against (other_set). + # An operation to test. + typing.Callable[[typing.Set[str], typing.Set[str]], typing.Set[str]], + typing.Set[str], # The expected result of operation(obj._stored.set, other_set). + typing.Set[str], # The expected result of operation(other_set, obj._stored.set). + ] + test_operations: typing.List[test_case] = [( + {"1"}, + lambda a, b: a | b, + {"1", "a", "b"}, + {"1", "a", "b"} ), ( {"a", "c"}, lambda a, b: a - b, @@ -1423,6 +1472,7 @@ class SomeObject(ops.Object): for i, (variable_operand, operation, ab_res, ba_res) in enumerate(test_operations): obj = SomeObject(framework, str(i)) obj._stored.set = {"a", "b"} + assert isinstance(obj._stored.set, ops.StoredSet) for a, b, expected in [ (obj._stored.set, variable_operand, ab_res), @@ -1468,11 +1518,11 @@ class StatefulObject(ops.Object): class GenericObserver(ops.Object): """Generic observer for the tests.""" - def __init__(self, parent, key): + def __init__(self, parent: ops.Object, key: str): super().__init__(parent, key) self.called = False - def callback_method(self, event): + def callback_method(self, event: ops.EventBase): """Set the instance .called to True.""" self.called = True @@ -1480,7 +1530,7 @@ def callback_method(self, event): @patch('sys.stderr', new_callable=io.StringIO) class BreakpointTests(BaseTestCase): - def test_ignored(self, fake_stderr): + def test_ignored(self, fake_stderr: io.StringIO): # It doesn't do anything really unless proper environment is there. with patch.dict(os.environ): os.environ.pop('JUJU_DEBUG_AT', None) @@ -1499,7 +1549,7 @@ def test_ignored(self, fake_stderr): self.assertEqual(mock.call_count, 0) self.assertEqual(fake_stderr.getvalue(), "") - def test_pdb_properly_called(self, fake_stderr): + def test_pdb_properly_called(self, fake_stderr: io.StringIO): # The debugger needs to leave the user in the frame where the breakpoint is executed, # which for the test is the frame we're calling it here in the test :). with patch.dict(os.environ, {'JUJU_DEBUG_AT': 'all'}): @@ -1512,7 +1562,7 @@ def test_pdb_properly_called(self, fake_stderr): self.assertEqual(mock.call_count, 1) self.assertEqual(mock.call_args, ((this_frame,), {})) - def test_welcome_message(self, fake_stderr): + def test_welcome_message(self, fake_stderr: io.StringIO): # Check that an initial message is shown to the user when code is interrupted. with patch.dict(os.environ, {'JUJU_DEBUG_AT': 'all'}): framework = self.create_framework() @@ -1520,7 +1570,7 @@ def test_welcome_message(self, fake_stderr): framework.breakpoint() self.assertEqual(fake_stderr.getvalue(), _BREAKPOINT_WELCOME_MESSAGE) - def test_welcome_message_not_multiple(self, fake_stderr): + def test_welcome_message_not_multiple(self, fake_stderr: io.StringIO): # Check that an initial message is NOT shown twice if the breakpoint is exercised # twice in the same run. with patch.dict(os.environ, {'JUJU_DEBUG_AT': 'all'}): @@ -1531,7 +1581,7 @@ def test_welcome_message_not_multiple(self, fake_stderr): framework.breakpoint() self.assertEqual(fake_stderr.getvalue(), _BREAKPOINT_WELCOME_MESSAGE) - def test_breakpoint_builtin_sanity(self, fake_stderr): + def test_breakpoint_builtin_sanity(self, fake_stderr: io.StringIO): # this just checks that calling breakpoint() works as expected # nothing really framework-dependent with patch.dict(os.environ): @@ -1545,7 +1595,7 @@ def test_breakpoint_builtin_sanity(self, fake_stderr): self.assertEqual(mock.call_count, 1) self.assertEqual(mock.call_args, ((this_frame,), {})) - def test_builtin_breakpoint_hooked(self, fake_stderr): + def test_builtin_breakpoint_hooked(self, fake_stderr: io.StringIO): # Verify that the proper hook is set. with patch.dict(os.environ, {'JUJU_DEBUG_AT': 'all'}): framework = self.create_framework() @@ -1555,7 +1605,7 @@ def test_builtin_breakpoint_hooked(self, fake_stderr): breakpoint() self.assertEqual(mock.call_count, 1) - def test_breakpoint_builtin_unset(self, fake_stderr): + def test_breakpoint_builtin_unset(self, fake_stderr: io.StringIO): # if no JUJU_DEBUG_AT, no call to pdb is done with patch.dict(os.environ): os.environ.pop('JUJU_DEBUG_AT', None) @@ -1568,7 +1618,7 @@ def test_breakpoint_builtin_unset(self, fake_stderr): self.assertEqual(mock.call_count, 0) - def test_breakpoint_names(self, fake_stderr): + def test_breakpoint_names(self, fake_stderr: io.StringIO): framework = self.create_framework() # Name rules: @@ -1629,10 +1679,14 @@ def test_breakpoint_names(self, fake_stderr): for name in not_really_names: with self.subTest(name=name): with self.assertRaises(TypeError) as cm: - framework.breakpoint(name) + framework.breakpoint(name) # type: ignore self.assertEqual(str(cm.exception), 'breakpoint names must be strings') - def check_trace_set(self, envvar_value, breakpoint_name, call_count): + def check_trace_set( + self, + envvar_value: str, + breakpoint_name: typing.Optional[str], + call_count: int): """Helper to check the diverse combinations of situations.""" with patch.dict(os.environ, {'JUJU_DEBUG_AT': envvar_value}): framework = self.create_framework() @@ -1640,19 +1694,19 @@ def check_trace_set(self, envvar_value, breakpoint_name, call_count): framework.breakpoint(breakpoint_name) self.assertEqual(mock.call_count, call_count) - def test_unnamed_indicated_all(self, fake_stderr): + def test_unnamed_indicated_all(self, fake_stderr: io.StringIO): # If 'all' is indicated, unnamed breakpoints will always activate. self.check_trace_set('all', None, 1) - def test_unnamed_indicated_hook(self, fake_stderr): + def test_unnamed_indicated_hook(self, fake_stderr: io.StringIO): # Special value 'hook' was indicated, nothing to do with any call. self.check_trace_set('hook', None, 0) - def test_named_indicated_specifically(self, fake_stderr): + def test_named_indicated_specifically(self, fake_stderr: io.StringIO): # Some breakpoint was indicated, and the framework call used exactly that name. self.check_trace_set('mybreak', 'mybreak', 1) - def test_named_indicated_unnamed(self, fake_stderr): + def test_named_indicated_unnamed(self, fake_stderr: io.StringIO): # Some breakpoint was indicated, but the framework call was unnamed with self.assertLogs(level="WARNING") as cm: self.check_trace_set('some-breakpoint', None, 0) @@ -1661,7 +1715,7 @@ def test_named_indicated_unnamed(self, fake_stderr): "(not found in the requested breakpoints: {'some-breakpoint'})" ]) - def test_named_indicated_somethingelse(self, fake_stderr): + def test_named_indicated_somethingelse(self, fake_stderr: io.StringIO): # Some breakpoint was indicated, but the framework call was with a different name with self.assertLogs(level="WARNING") as cm: self.check_trace_set('some-breakpoint', 'other-name', 0) @@ -1669,15 +1723,15 @@ def test_named_indicated_somethingelse(self, fake_stderr): "WARNING:ops.framework:Breakpoint 'other-name' skipped " "(not found in the requested breakpoints: {'some-breakpoint'})"]) - def test_named_indicated_ingroup(self, fake_stderr): + def test_named_indicated_ingroup(self, fake_stderr: io.StringIO): # A multiple breakpoint was indicated, and the framework call used a name among those. self.check_trace_set('some,mybreak,foobar', 'mybreak', 1) - def test_named_indicated_all(self, fake_stderr): + def test_named_indicated_all(self, fake_stderr: io.StringIO): # The framework indicated 'all', which includes any named breakpoint set. self.check_trace_set('all', 'mybreak', 1) - def test_named_indicated_hook(self, fake_stderr): + def test_named_indicated_hook(self, fake_stderr: io.StringIO): # The framework indicated the special value 'hook', nothing to do with any named call. self.check_trace_set('hook', 'mybreak', 0) @@ -1714,6 +1768,7 @@ def test_basic_interruption_enabled(self): framework.observe(publisher.install, observer.callback_method) with patch('sys.stderr', new_callable=io.StringIO) as fake_stderr: + fake_stderr = typing.cast(io.StringIO, fake_stderr) with patch('pdb.runcall') as mock: publisher.install.emit() @@ -1855,6 +1910,7 @@ def test_welcome_message_not_multiple(self): framework.observe(publisher.install, observer.callback_method) with patch('sys.stderr', new_callable=io.StringIO) as fake_stderr: + fake_stderr = typing.cast(io.StringIO, fake_stderr) with patch('pdb.runcall') as mock: publisher.install.emit() self.assertEqual(fake_stderr.getvalue(), _BREAKPOINT_WELCOME_MESSAGE)