Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add missing types to tests.util. #14597

Merged
merged 16 commits into from
Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/14597.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints.
13 changes: 3 additions & 10 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,6 @@ exclude = (?x)
|tests/server_notices/test_resource_limits_server_notices.py
|tests/test_state.py
|tests/test_terms_auth.py
|tests/util/test_async_helpers.py
|tests/util/test_batching_queue.py
|tests/util/test_dict_cache.py
|tests/util/test_expiring_cache.py
|tests/util/test_file_consumer.py
|tests/util/test_linearizer.py
|tests/util/test_logcontext.py
|tests/util/test_lrucache.py
|tests/util/test_rwlock.py
|tests/util/test_wheel_timer.py
)$

[mypy-synapse.federation.transport.client]
Expand Down Expand Up @@ -137,6 +127,9 @@ disallow_untyped_defs = True
[mypy-tests.util.caches.test_descriptors]
disallow_untyped_defs = False

[mypy-tests.util.*]
disallow_untyped_defs = True

[mypy-tests.utils]
disallow_untyped_defs = True

Expand Down
118 changes: 63 additions & 55 deletions tests/util/test_async_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
from typing import Generator, List, NoReturn, Optional

from parameterized import parameterized_class

Expand Down Expand Up @@ -41,8 +42,8 @@


class ObservableDeferredTest(TestCase):
def test_succeed(self):
origin_d = Deferred()
def test_succeed(self) -> None:
origin_d: "Deferred[int]" = Deferred()
observable = ObservableDeferred(origin_d)

observer1 = observable.observe()
Expand All @@ -52,16 +53,18 @@ def test_succeed(self):
self.assertFalse(observer2.called)

# check the first observer is called first
def check_called_first(res):
def check_called_first(res: int) -> int:
self.assertFalse(observer2.called)
return res

observer1.addBoth(check_called_first)

# store the results
results = [None, None]
results: List[Optional[ObservableDeferred[int]]] = [None, None]

def check_val(res, idx):
def check_val(
res: ObservableDeferred[int], idx: int
) -> ObservableDeferred[int]:
results[idx] = res
return res

Expand All @@ -72,8 +75,8 @@ def check_val(res, idx):
self.assertEqual(results[0], 123, "observer 1 callback result")
self.assertEqual(results[1], 123, "observer 2 callback result")

def test_failure(self):
origin_d = Deferred()
def test_failure(self) -> None:
origin_d: Deferred = Deferred()
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
observable = ObservableDeferred(origin_d, consumeErrors=True)

observer1 = observable.observe()
Expand All @@ -83,16 +86,16 @@ def test_failure(self):
self.assertFalse(observer2.called)

# check the first observer is called first
def check_called_first(res):
def check_called_first(res: int) -> int:
self.assertFalse(observer2.called)
return res

observer1.addBoth(check_called_first)

# store the results
results = [None, None]
results: List[Optional[ObservableDeferred[str]]] = [None, None]

def check_val(res, idx):
def check_val(res: ObservableDeferred[str], idx: int) -> None:
results[idx] = res
return None

Expand All @@ -103,10 +106,12 @@ def check_val(res, idx):
raise Exception("gah!")
except Exception as e:
origin_d.errback(e)
assert results[0] is not None
self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
assert results[1] is not None
self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")

def test_cancellation(self):
def test_cancellation(self) -> None:
"""Test that cancelling an observer does not affect other observers."""
origin_d: "Deferred[int]" = Deferred()
observable = ObservableDeferred(origin_d, consumeErrors=True)
Expand Down Expand Up @@ -136,37 +141,38 @@ def test_cancellation(self):


class TimeoutDeferredTest(TestCase):
def setUp(self):
def setUp(self) -> None:
self.clock = Clock()

def test_times_out(self):
def test_times_out(self) -> None:
"""Basic test case that checks that the original deferred is cancelled and that
the timing-out deferred is errbacked
"""
cancelled = [False]
cancelled = False

def canceller(_d):
cancelled[0] = True
def canceller(_d: Deferred) -> None:
nonlocal cancelled
cancelled = True
clokep marked this conversation as resolved.
Show resolved Hide resolved

non_completing_d = Deferred(canceller)
non_completing_d: Deferred = Deferred(canceller)
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)

self.assertNoResult(timing_out_d)
self.assertFalse(cancelled[0], "deferred was cancelled prematurely")
self.assertFalse(cancelled, "deferred was cancelled prematurely")

self.clock.pump((1.0,))

self.assertTrue(cancelled[0], "deferred was not cancelled by timeout")
self.assertTrue(cancelled, "deferred was not cancelled by timeout")
self.failureResultOf(timing_out_d, defer.TimeoutError)

def test_times_out_when_canceller_throws(self):
def test_times_out_when_canceller_throws(self) -> None:
"""Test that we have successfully worked around
https://twistedmatrix.com/trac/ticket/9534"""

def canceller(_d):
def canceller(_d: Deferred) -> None:
raise Exception("can't cancel this deferred")

non_completing_d = Deferred(canceller)
non_completing_d: Deferred = Deferred(canceller)
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)

self.assertNoResult(timing_out_d)
Expand All @@ -175,22 +181,24 @@ def canceller(_d):

self.failureResultOf(timing_out_d, defer.TimeoutError)

def test_logcontext_is_preserved_on_cancellation(self):
blocking_was_cancelled = [False]
def test_logcontext_is_preserved_on_cancellation(self) -> None:
blocking_was_cancelled = False

@defer.inlineCallbacks
def blocking():
non_completing_d = Deferred()
def blocking() -> Generator["Deferred[object]", object, None]:
nonlocal blocking_was_cancelled

non_completing_d: Deferred = Deferred()
with PreserveLoggingContext():
try:
yield non_completing_d
except CancelledError:
blocking_was_cancelled[0] = True
blocking_was_cancelled = True
raise

with LoggingContext("one") as context_one:
# the errbacks should be run in the test logcontext
def errback(res, deferred_name):
def errback(res: Failure, deferred_name: str) -> Failure:
self.assertIs(
current_context(),
context_one,
Expand All @@ -209,7 +217,7 @@ def errback(res, deferred_name):
self.clock.pump((1.0,))

self.assertTrue(
blocking_was_cancelled[0], "non-completing deferred was not cancelled"
blocking_was_cancelled, "non-completing deferred was not cancelled"
)
self.failureResultOf(timing_out_d, defer.TimeoutError)
self.assertIs(current_context(), context_one)
Expand All @@ -220,13 +228,13 @@ class _TestException(Exception):


class ConcurrentlyExecuteTest(TestCase):
def test_limits_runners(self):
def test_limits_runners(self) -> None:
"""If we have more tasks than runners, we should get the limit of runners"""
started = 0
waiters = []
processed = []

async def callback(v):
async def callback(v: int) -> None:
# when we first enter, bump the start count
nonlocal started
started += 1
Expand All @@ -235,7 +243,7 @@ async def callback(v):
processed.append(v)

# wait for the goahead before returning
d2 = Deferred()
d2: "Deferred[int]" = Deferred()
waiters.append(d2)
await d2

Expand Down Expand Up @@ -265,16 +273,16 @@ async def callback(v):
self.assertCountEqual(processed, [1, 2, 3, 4, 5])
self.successResultOf(d2)

def test_preserves_stacktraces(self):
def test_preserves_stacktraces(self) -> None:
"""Test that the stacktrace from an exception thrown in the callback is preserved"""
d1 = Deferred()
d1: "Deferred[int]" = Deferred()

async def callback(v):
async def callback(v: int) -> None:
# alas, this doesn't work at all without an await here
await d1
raise _TestException("bah")

async def caller():
async def caller() -> None:
try:
await concurrently_execute(callback, [1], 2)
except _TestException as e:
Expand All @@ -290,17 +298,17 @@ async def caller():
d1.callback(0)
self.successResultOf(d2)

def test_preserves_stacktraces_on_preformed_failure(self):
def test_preserves_stacktraces_on_preformed_failure(self) -> None:
"""Test that the stacktrace on a Failure returned by the callback is preserved"""
d1 = Deferred()
d1: "Deferred[int]" = Deferred()
f = Failure(_TestException("bah"))

async def callback(v):
async def callback(v: int) -> None:
# alas, this doesn't work at all without an await here
await d1
await defer.fail(f)

async def caller():
async def caller() -> None:
try:
await concurrently_execute(callback, [1], 2)
except _TestException as e:
Expand Down Expand Up @@ -336,7 +344,7 @@ def wrap_deferred(self, deferred: "Deferred[str]") -> "Deferred[str]":
else:
raise ValueError(f"Unsupported wrapper type: {self.wrapper}")

def test_succeed(self):
def test_succeed(self) -> None:
"""Test that the new `Deferred` receives the result."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = self.wrap_deferred(deferred)
Expand All @@ -346,7 +354,7 @@ def test_succeed(self):
self.assertTrue(wrapper_deferred.called)
self.assertEqual("success", self.successResultOf(wrapper_deferred))

def test_failure(self):
def test_failure(self) -> None:
"""Test that the new `Deferred` receives the `Failure`."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = self.wrap_deferred(deferred)
Expand All @@ -361,7 +369,7 @@ def test_failure(self):
class StopCancellationTests(TestCase):
"""Tests for the `stop_cancellation` function."""

def test_cancellation(self):
def test_cancellation(self) -> None:
"""Test that cancellation of the new `Deferred` leaves the original running."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = stop_cancellation(deferred)
Expand All @@ -384,7 +392,7 @@ def test_cancellation(self):
class DelayCancellationTests(TestCase):
"""Tests for the `delay_cancellation` function."""

def test_deferred_cancellation(self):
def test_deferred_cancellation(self) -> None:
"""Test that cancellation of the new `Deferred` waits for the original."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred)
Expand All @@ -405,12 +413,12 @@ def test_deferred_cancellation(self):
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)

def test_coroutine_cancellation(self):
def test_coroutine_cancellation(self) -> None:
"""Test that cancellation of the new `Deferred` waits for the original."""
blocking_deferred: "Deferred[None]" = Deferred()
completion_deferred: "Deferred[None]" = Deferred()

async def task():
async def task() -> NoReturn:
await blocking_deferred
completion_deferred.callback(None)
# Raise an exception. Twisted should consume it, otherwise unwanted
Expand All @@ -434,7 +442,7 @@ async def task():
# Now that the original coroutine has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)

def test_suppresses_second_cancellation(self):
def test_suppresses_second_cancellation(self) -> None:
"""Test that a second cancellation is suppressed.

Identical to `test_cancellation` except the new `Deferred` is cancelled twice.
Expand All @@ -459,7 +467,7 @@ def test_suppresses_second_cancellation(self):
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)

def test_propagates_cancelled_error(self):
def test_propagates_cancelled_error(self) -> None:
"""Test that a `CancelledError` from the original `Deferred` gets propagated."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred)
Expand All @@ -472,14 +480,14 @@ def test_propagates_cancelled_error(self):
self.assertTrue(wrapper_deferred.called)
self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value)

def test_preserves_logcontext(self):
def test_preserves_logcontext(self) -> None:
"""Test that logging contexts are preserved."""
blocking_d: "Deferred[None]" = Deferred()

async def inner():
async def inner() -> None:
await make_deferred_yieldable(blocking_d)

async def outer():
async def outer() -> None:
with LoggingContext("c") as c:
try:
await delay_cancellation(inner())
Expand All @@ -503,7 +511,7 @@ async def outer():
class AwakenableSleeperTests(TestCase):
"Tests AwakenableSleeper"

def test_sleep(self):
def test_sleep(self) -> None:
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)

Expand All @@ -518,7 +526,7 @@ def test_sleep(self):
reactor.advance(0.6)
self.assertTrue(d.called)

def test_explicit_wake(self):
def test_explicit_wake(self) -> None:
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)

Expand All @@ -535,7 +543,7 @@ def test_explicit_wake(self):

reactor.advance(0.6)

def test_multiple_sleepers_timeout(self):
def test_multiple_sleepers_timeout(self) -> None:
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)

Expand All @@ -555,7 +563,7 @@ def test_multiple_sleepers_timeout(self):
reactor.advance(0.6)
self.assertTrue(d2.called)

def test_multiple_sleepers_wake(self):
def test_multiple_sleepers_wake(self) -> None:
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)

Expand Down
Loading