Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(typing): update to latest version of Pyright and fix errors #1105

Merged
merged 4 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ops/charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ class RelationCreatedEvent(RelationEvent):
can occur before units for those applications have started. All existing
relations should be established before start.
"""
unit: None
unit: None # pyright: ignore[reportIncompatibleVariableOverride]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the cleanest fix here is to remove the type from RelationEvent and add it to RelationChangedEvent. The downside would be that I'm not sure if we can still then have the attribute have documentation in the base class, although it would be in all the subclasses.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, I had a play with this, and I don't like the fact that then the docs for RelationEvent wouldn't have "unit" defined at all, which seems very weird. Or were you suggesting something else?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, that was what I was suggesting, and I indeed wondered if that meant it was not possible to have it in the RelationEvent doc (unless there's some Sphinx trick for this I don't know). I agree that's not great.

In balance, I'm fine keeping it how you have it with the ignores. The goal was to have the types correct in the various subclasses so people didn't need a bunch of assert unit is not None and that all still works, so there doesn't seem to be much value in trying too hard to get rids of the ignores.

"""Always ``None``."""


Expand All @@ -481,7 +481,7 @@ class RelationJoinedEvent(RelationEvent):
remote ``private-address`` setting, which is always available when
the relation is created and is by convention not deleted.
"""
unit: model.Unit
unit: model.Unit # pyright: ignore[reportIncompatibleVariableOverride]
"""The remote unit that has triggered this event."""


Expand Down Expand Up @@ -523,7 +523,7 @@ class RelationDepartedEvent(RelationEvent):
Once all callback methods bound to this event have been run for such a
relation, the unit agent will fire the :class:`RelationBrokenEvent`.
"""
unit: model.Unit
unit: model.Unit # pyright: ignore[reportIncompatibleVariableOverride]
"""The remote unit that has triggered this event."""

def __init__(self, handle: 'Handle', relation: 'model.Relation',
Expand Down Expand Up @@ -580,7 +580,7 @@ class RelationBrokenEvent(RelationEvent):
bound to this event is being executed, it is guaranteed that no remote units
are currently known locally.
"""
unit: None
unit: None # pyright: ignore[reportIncompatibleVariableOverride]
"""Always ``None``."""


Expand Down
15 changes: 8 additions & 7 deletions ops/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,17 +277,18 @@ def get_secret(self, *, id: Optional[str] = None, label: Optional[str] = None) -
return Secret(self._backend, id=id, label=label, content=content)


if typing.TYPE_CHECKING:
# (entity type, name): instance.
_WeakCacheType = weakref.WeakValueDictionary[
Tuple['UnitOrApplicationType', str],
Optional[Union['Unit', 'Application']]]


class _ModelCache:
def __init__(self, meta: 'ops.charm.CharmMeta', backend: '_ModelBackend'):
if typing.TYPE_CHECKING:
# (entity type, name): instance.
_weakcachetype = weakref.WeakValueDictionary[
Tuple['UnitOrApplicationType', str],
Optional[Union['Unit', 'Application']]]

self._meta = meta
self._backend = backend
self._weakrefs: _weakcachetype = weakref.WeakValueDictionary()
self._weakrefs: _WeakCacheType = weakref.WeakValueDictionary()

@typing.overload
def get(self, entity_type: Type['Unit'], name: str) -> 'Unit': ... # noqa
Expand Down
8 changes: 4 additions & 4 deletions ops/pebble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1614,7 +1614,7 @@ def _websocket_to_writer(ws: '_WebSocket', writer: '_WebsocketWriter',
break

if encoding is not None:
chunk = chunk.decode(encoding)
chunk = typing.cast(bytes, chunk).decode(encoding)
writer.write(chunk)


Expand Down Expand Up @@ -2019,7 +2019,7 @@ def _wait_change_using_wait(self, change_id: ChangeID, timeout: Optional[float])

def _wait_change(self, change_id: ChangeID, timeout: Optional[float] = None) -> Change:
"""Call the wait-change API endpoint directly."""
query = {}
query: Dict[str, Any] = {}
if timeout is not None:
query['timeout'] = _format_timeout(timeout)

Expand Down Expand Up @@ -2255,7 +2255,7 @@ def _encode_multipart(self, metadata: Dict[str, Any], path: str,
elif isinstance(source, bytes):
source_io: _AnyStrFileLikeIO = io.BytesIO(source)
else:
source_io: _AnyStrFileLikeIO = source
source_io: _AnyStrFileLikeIO = source # type: ignore
boundary = binascii.hexlify(os.urandom(16))
path_escaped = path.replace('"', '\\"').encode('utf-8') # NOQA: test_quote_backslashes
content_type = f"multipart/form-data; boundary=\"{boundary.decode('utf-8')}\"" # NOQA: test_quote_backslashes
Expand Down Expand Up @@ -2736,7 +2736,7 @@ def get_checks(
Returns:
List of :class:`CheckInfo` objects.
"""
query = {}
query: Dict[str, Any] = {}
if level is not None:
query['level'] = level.value
if names:
Expand Down
4 changes: 2 additions & 2 deletions ops/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import subprocess
from datetime import timedelta
from pathlib import Path
from typing import Any, Callable, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, Generator, List, Optional, Tuple, Union, cast

import yaml # pyright: ignore[reportMissingModuleSource]

Expand Down Expand Up @@ -205,7 +205,7 @@ def notices(self, event_path: Optional[str] = None) -> '_NoticeGenerator':
if not rows:
break
for row in rows:
yield tuple(row)
yield cast(_Notice, tuple(row))


class JujuStorage:
Expand Down
5 changes: 3 additions & 2 deletions ops/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import shutil
import signal
import tempfile
import typing
import uuid
import warnings
from contextlib import contextmanager
Expand Down Expand Up @@ -3017,7 +3018,7 @@ def push(
file_path.write_bytes(source)
else:
# If source is binary, open file in binary mode and ignore encoding param
is_binary = isinstance(source.read(0), bytes)
is_binary = isinstance(source.read(0), bytes) # type: ignore
open_mode = 'wb' if is_binary else 'w'
open_encoding = None if is_binary else encoding
with file_path.open(open_mode, encoding=open_encoding) as f:
Expand Down Expand Up @@ -3141,7 +3142,7 @@ def _transform_exec_handler_output(self,
f"exec handler must return bytes if encoding is None,"
f"not {data.__class__.__name__}")
else:
return io.StringIO(data)
return io.StringIO(typing.cast(str, data))

def exec(
self,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,5 @@ reportMissingModuleSource = false
reportPrivateUsage = false
reportUnnecessaryIsInstance = false
reportUnnecessaryComparison = false
disableBytesTypePromotions = false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docs (I had to look up what this did!) say the default is false. Why do we need to add it in here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's odd. When I comment out that line and leave it as the default, Pyright prints a bunch of this message:

$ tox -e static
static: commands[0]> pyright
/home/ben/w/operator/ops/pebble.py
  /home/ben/w/operator/ops/pebble.py:1617:21 - error: Unnecessary "cast" call; type is already "bytes" (reportUnnecessaryCast)
  /home/ben/w/operator/ops/pebble.py:2993:52 - error: Argument of type "bytearray" cannot be assigned to parameter "buf" of type "bytes" in function "_next_part_boundary"
    "bytearray" is incompatible with "bytes"
...

Oh wait, it looks like the default is actually true in "strict" mode: https://github.com/microsoft/pyright/blob/main/docs/configuration.md#diagnostic-rule-defaults

Makes me wonder if we're causing ourselves a bunch of extra pain by being in strict mode...

stubPath = ""
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ flake8-builtins~=2.1
pyproject-flake8~=6.1
pep8-naming~=0.13
pytest~=7.2
pyright==1.1.317
pyright==1.1.345
pytest-operator~=0.23
coverage[toml]~=7.0
typing_extensions~=4.2
Expand Down
3 changes: 1 addition & 2 deletions test/test_charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@ def test_observe_decorated_method(self):
# way we know of to cleanly decorate charm event observers.
events: typing.List[ops.EventBase] = []

def dec(fn: typing.Callable[['MyCharm', ops.EventBase], None] # noqa: F821
) -> typing.Callable[..., None]:
def dec(fn: typing.Any) -> typing.Callable[..., None]:
# simple decorator that appends to the nonlocal
# `events` list all events it receives
@functools.wraps(fn)
Expand Down
52 changes: 28 additions & 24 deletions test/test_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,31 @@ def _on_event(self, event: ops.EventBase):
'ObjectWithStorage[obj]/on/event[1]']))


MutableTypesTestCase = typing.Tuple[
typing.Callable[[], typing.Any], # Called to get operand A.
typing.Any, # Operand B.
typing.Any, # Expected result.
typing.Callable[[typing.Any, typing.Any], None], # Operation to perform.
typing.Callable[[typing.Any, typing.Any], typing.Any], # Validation to perform.
]

ComparisonOperationsTestCase = typing.Tuple[
typing.Any, # Operand A.
typing.Any, # Operand B.
typing.Callable[[typing.Any, typing.Any], bool], # Operation to test.
bool, # Result of op(A, B).
bool, # Result of op(B, A).
]

SetOperationsTestCase = typing.Tuple[
typing.Set[str], # A set to test an operation against (other_set).
# An operation to test.
typing.Callable[[typing.Set[str], typing.Set[str]], typing.Set[str]],
typing.Set[str], # The expected result of operation(obj._stored.set, other_set).
typing.Set[str], # The expected result of operation(other_set, obj._stored.set).
]


class TestStoredState(BaseTestCase):

def setUp(self):
Expand Down Expand Up @@ -1116,14 +1141,7 @@ def test_mutable_types(self):
# Test and validation functions in a list of tuples.
# Assignment and keywords like del are not supported in lambdas
# so functions are used instead.
test_case = typing.Tuple[
typing.Callable[[], typing.Any], # Called to get operand A.
typing.Any, # Operand B.
typing.Any, # Expected result.
typing.Callable[[typing.Any, typing.Any], None], # Operation to perform.
typing.Callable[[typing.Any, typing.Any], typing.Any], # Validation to perform.
]
test_operations: typing.List[test_case] = [(
test_operations: typing.List[MutableTypesTestCase] = [(
lambda: {},
None,
{},
Expand Down Expand Up @@ -1336,14 +1354,7 @@ def save_snapshot(self, value: typing.Union[ops.StoredStateData, ops.EventBase])
framework_copy.close()

def test_comparison_operations(self):
test_case = typing.Tuple[
typing.Any, # Operand A.
typing.Any, # Operand B.
typing.Callable[[typing.Any, typing.Any], bool], # Operation to test.
bool, # Result of op(A, B).
bool, # Result of op(B, A).
]
test_operations: typing.List[test_case] = [(
test_operations: typing.List[ComparisonOperationsTestCase] = [(
{"1"},
{"1", "2"},
lambda a, b: a < b,
Expand Down Expand Up @@ -1436,14 +1447,7 @@ class SomeObject(ops.Object):
self.assertEqual(op(b, obj._stored.a), op_ba)

def test_set_operations(self):
test_case = typing.Tuple[
typing.Set[str], # A set to test an operation against (other_set).
# An operation to test.
typing.Callable[[typing.Set[str], typing.Set[str]], typing.Set[str]],
typing.Set[str], # The expected result of operation(obj._stored.set, other_set).
typing.Set[str], # The expected result of operation(other_set, obj._stored.set).
]
test_operations: typing.List[test_case] = [(
test_operations: typing.List[SetOperationsTestCase] = [(
{"1"},
lambda a, b: a | b,
{"1", "a", "b"},
Expand Down
23 changes: 13 additions & 10 deletions test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2378,7 +2378,14 @@ def test_unresolved_ingress_addresses(self):
self.assertEqual(binding.network.ingress_addresses, ['foo.bar.baz.com'])


_metric_and_label_pair = typing.Tuple[typing.Dict[str, float], typing.Dict[str, str]]
_MetricAndLabelPair = typing.Tuple[typing.Dict[str, float], typing.Dict[str, str]]


_ValidMetricsTestCase = typing.Tuple[
typing.Mapping[str, typing.Union[int, float]],
typing.Mapping[str, str],
typing.List[typing.List[str]],
]


class TestModelBackend(unittest.TestCase):
Expand Down Expand Up @@ -2851,12 +2858,8 @@ def test_juju_log(self):
[['juju-log', '--log-level', 'BAR', '--', 'foo']])

def test_valid_metrics(self):
_caselist = typing.List[typing.Tuple[
typing.Mapping[str, typing.Union[int, float]],
typing.Mapping[str, str],
typing.List[typing.List[str]]]]
fake_script(self, 'add-metric', 'exit 0')
test_cases: _caselist = [(
test_cases: typing.List[_ValidMetricsTestCase] = [(
OrderedDict([('foo', 42), ('b-ar', 4.5), ('ba_-z', 4.5), ('a', 1)]),
OrderedDict([('de', 'ad'), ('be', 'ef_ -')]),
[['add-metric', '--labels', 'de=ad,be=ef_ -',
Expand All @@ -2871,7 +2874,7 @@ def test_valid_metrics(self):
self.assertEqual(fake_script_calls(self, clear=True), expected_calls)

def test_invalid_metric_names(self):
invalid_inputs: typing.List[_metric_and_label_pair] = [
invalid_inputs: typing.List[_MetricAndLabelPair] = [
({'': 4.2}, {}),
({'1': 4.2}, {}),
({'1': -4.2}, {}),
Expand All @@ -2890,7 +2893,7 @@ def test_invalid_metric_names(self):
self.backend.add_metrics(metrics, labels)

def test_invalid_metric_values(self):
invalid_inputs: typing.List[_metric_and_label_pair] = [
invalid_inputs: typing.List[_MetricAndLabelPair] = [
({'a': float('+inf')}, {}),
({'a': float('-inf')}, {}),
({'a': float('nan')}, {}),
Expand All @@ -2902,7 +2905,7 @@ def test_invalid_metric_values(self):
self.backend.add_metrics(metrics, labels)

def test_invalid_metric_labels(self):
invalid_inputs: typing.List[_metric_and_label_pair] = [
invalid_inputs: typing.List[_MetricAndLabelPair] = [
({'foo': 4.2}, {'': 'baz'}),
({'foo': 4.2}, {',bar': 'baz'}),
({'foo': 4.2}, {'b=a=r': 'baz'}),
Expand All @@ -2913,7 +2916,7 @@ def test_invalid_metric_labels(self):
self.backend.add_metrics(metrics, labels)

def test_invalid_metric_label_values(self):
invalid_inputs: typing.List[_metric_and_label_pair] = [
invalid_inputs: typing.List[_MetricAndLabelPair] = [
({'foo': 4.2}, {'bar': ''}),
({'foo': 4.2}, {'bar': 'b,az'}),
({'foo': 4.2}, {'bar': 'b=az'}),
Expand Down
11 changes: 6 additions & 5 deletions test/test_pebble.py
Original file line number Diff line number Diff line change
Expand Up @@ -2425,7 +2425,7 @@ def _parse_write_multipart(self,
for part in message.walk():
name = part.get_param('name', header='Content-Disposition')
if name == 'request':
req = json.loads(part.get_payload())
req = json.loads(typing.cast(str, part.get_payload()))
benhoyt marked this conversation as resolved.
Show resolved Hide resolved
elif name == 'files':
# decode=True, ironically, avoids decoding bytes to str
content = part.get_payload(decode=True)
Expand Down Expand Up @@ -3092,10 +3092,11 @@ def test_wait_exit_nonzero(self):
process = self.client.exec(['false'])
with self.assertRaises(pebble.ExecError) as cm:
process.wait()
self.assertEqual(cm.exception.command, ['false'])
self.assertEqual(cm.exception.exit_code, 1)
self.assertEqual(cm.exception.stdout, None)
self.assertEqual(cm.exception.stderr, None)
exc = typing.cast(pebble.ExecError[str], cm.exception)
self.assertEqual(exc.command, ['false'])
self.assertEqual(exc.exit_code, 1)
self.assertIsNone(exc.stdout)
self.assertIsNone(exc.stderr)

self.assertEqual(self.client.requests, [
('POST', '/v1/exec', None, self.build_exec_data(['false'])),
Expand Down
7 changes: 4 additions & 3 deletions test/test_real_pebble.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,10 @@ def test_exec_wait_output(self):
with self.assertRaises(pebble.ExecError) as cm:
process = self.client.exec(['/bin/sh', '-c', 'echo OUT; echo ERR >&2; exit 42'])
process.wait_output()
self.assertEqual(cm.exception.exit_code, 42)
self.assertEqual(cm.exception.stdout, 'OUT\n')
self.assertEqual(cm.exception.stderr, 'ERR\n')
exc = typing.cast(pebble.ExecError[str], cm.exception)
self.assertEqual(exc.exit_code, 42)
self.assertEqual(exc.stdout, 'OUT\n')
self.assertEqual(exc.stderr, 'ERR\n')

def test_exec_send_stdin(self):
process = self.client.exec(['awk', '{ print toupper($0) }'], stdin='foo\nBar\n')
Expand Down
Loading