Skip to content

Commit

Permalink
Use settrace
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Rossi committed Aug 23, 2021
1 parent 3525e04 commit 19e207d
Show file tree
Hide file tree
Showing 4 changed files with 304 additions and 74 deletions.
15 changes: 1 addition & 14 deletions google/cloud/ndb/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,16 +674,6 @@ def new_value(old_value):
pass


def _syncpoint_update_key():
"""A no-op function meant to be patched for testing.
Should be replaced by `orchestrate.syncpoint` using `mock.patch` during testing to
orchestrate concurrent testing scenarios.
See: `tests.unit.test_concurrency`
"""


@tasklets.tasklet
def _update_key(key, new_value):
success = False
Expand All @@ -693,10 +683,7 @@ def _update_key(key, new_value):
utils.logging_debug(log, "old value: {}", old_value)

value = new_value(old_value)
utils.logging_debug(log, "new value: {}", value)

if __debug__:
_syncpoint_update_key()
utils.logging_debug(log, "new value: {}", value) # pragma: SYNCPOINT update key

if old_value is not None:
utils.logging_debug(log, "compare and swap")
Expand Down
160 changes: 125 additions & 35 deletions tests/unit/orchestrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,17 @@

import itertools
import math
import sys
import threading
import tokenize

try:
import queue
except ImportError: # pragma: NO PY3 COVER
import Queue as queue


def orchestrate(*tests):
def orchestrate(*tests, **kwargs):
"""
Orchestrate a deterministic concurrency test.
Expand All @@ -40,9 +42,8 @@ def orchestrate(*tests):
`orchestrate` runs each passed in test function in its own thread. Threads then
"take turns" running. Turns are defined by setting syncpoints in the code under
test. To do this, you'll write a no-op function and call it at the point where you'd
like your code to pause and give another thread a turn. In your test, then, use
`mock.patch` to replace your no-op function with :func:`syncpoint` in your test.
test, using comment containing "pragma: SYNCPOINT". `orchestrate` will scan the code
under test and add syncpoints where it finds these comments.
For example, let's say you have the following code in production::
Expand All @@ -52,19 +53,16 @@ def hither_and_yon(destination):
You've found there's a concurrency bug when two threads execute this code with the
same destination, and you think that by adding a syncpoint between the calls to
`hither` and `yon` you can reproduce the problem in a regression test. First you'd,
write a no-op function, include it in your production code, and call it in
`hither_and_yon`::
def _syncpoint_123():
pass
`hither` and `yon` you can reproduce the problem in a regression test. First add a
comment with "pragma: SYNCPOINT" to the code under test::
def hither_and_yon(destination):
hither(destination)
_syncpoint_123()
hither(destination) # pragma: SYNCPOINT
yon(destination)
Now you can write a test to exercise `hither_and_yon` running in parallel::
When testing with orchestrate, there will now be a syncpoint, or a pause, after the
call to `hither` and before the call to `yon`. Now you can write a test to exercise
`hither_and_yon` running in parallel::
from unittest import mock
from tests.unit import orchestrate
Expand All @@ -91,8 +89,8 @@ def test_hither_and_yon():
encountered when executing the test.
Once the counts have been taken, `orchestrate` will construct a test sequence that
represents the all the turns taken by the passed in tests, with each value in the
sequence representing the the index of the test whose turn it is in the sequence. In
represents all of the turns taken by the passed in tests, with each value in the
sequence representing the index of the test whose turn it is in the sequence. In
this example, then, it would produce::
[0, 0, 1, 1]
Expand Down Expand Up @@ -122,25 +120,49 @@ def test_hither_and_yon():
have over 17 thousand scenarios. In general, use the least number of steps/threads
you can get away with and still expose the behavior you want to correct.
For the same reason as above, if you have many concurrent tests, when writing a new
test, make sure you're not accidentally patching syncpoints intended for other
tests, as this will add steps to your tests. While it's not problematic from a
testing standpoint to have extra steps in your tests, it can use computing resources
unnecessarily. Using different no-op functions with different names for different
tests can help with this.
For the same reason as above, its recommended that if you have many concurrent
tests, that you name your syncpoints so that you're not accidentally using
syncpoints intended for other tests, as this will add steps to your tests. While
it's not problematic from a testing standpoint to have extra steps in your tests, it
can use computing resources unnecessarily. A name can be added to any syncpoint
after the `SYNCPOINT` keyword in the pragma definition::
def hither_and_yon(destination):
hither(destination) # pragma: SYNCPOINT hither and yon
yon(destination)
In your test, then, pass that name to `orchestrate` to cause it to use only
syncpoints with that name::
orchestrate.orchestrate(
test_hither_and_yon, test_hither_and_yon, name="hither and yon"
)
As soon as any error or failure is detected, no more scenarios are run
and that error is propagated to the main thread.
One limitation of `orchestrate` is that it cannot really be used with `coverage`,
since both tools use `sys.set_trace`. Any code that needs verifiable test coverage
should have additional tests that do not use `orchestrate`, since code that is run
under orchestrate will not show up in a coverage report generated by `coverage`.
Args:
tests (Tuple[Callable]): Test functions to be run. These functions will not be
called with any arguments, so they must not have any required arguments.
name (Optional[str]): Only use syncpoints with the given name. If omitted, only
unnamed syncpoints will be used.
Returns:
Tuple[int]: A tuple of the count of the number turns for test passed in. Can be
used a sanity check in tests to make sure you understand what's actually
happening during a test.
"""
name = kwargs.pop("name", None)
if kwargs:
raise TypeError(
"Unexpected keyword arguments: {}".format(", ".join(kwargs.keys()))
)

# Produce an initial test sequence. The fundamental question we're always trying to
# answer is "whose turn is it?" First we'll find out how many "turns" each test
# needs to complete when run serially and use that to construct a sequence of
Expand All @@ -150,7 +172,7 @@ def test_hither_and_yon():
test_sequence = []
counts = []
for index, test in enumerate(tests):
thread = _TestThread(test)
thread = _TestThread(test, name)
for count in itertools.count(1): # pragma: NO BRANCH
# Pragma is required because loop never finishes naturally.
thread.go()
Expand All @@ -169,7 +191,7 @@ def test_hither_and_yon():

# Test each sequence
for test_sequence in sequences:
threads = [_TestThread(test) for test in tests]
threads = [_TestThread(test, name) for test in tests]
try:
for index in test_sequence:
threads[index].go()
Expand All @@ -191,17 +213,6 @@ def test_hither_and_yon():
return tuple(counts)


def syncpoint():
"""End a thread's "turn" at this point.
This will generally be inserted by `mock.patch` to replace a no-op function in
production code. See documentation for :func:`orchestrate`.
"""
conductor = _local.conductor
conductor.notify()
conductor.standby()


_local = threading.local()


Expand Down Expand Up @@ -235,18 +246,53 @@ def go(self):
self._go.put(None)


_SYNCPOINTS = {}
"""Dict[str, Dict[str, Set[int]]]: Dict mapping source fileneme to a dict mapping
syncpoint name to set of line numbers where syncpoints with that name occur in the
source file.
"""


def _get_syncpoints(filename):
"""Find syncpoints in a source file.
Does a simple tokenization of the source file, looking for comments with "pragma:
SYNCPOINT", and populates _SYNCPOINTS using the syncpoint name and line number in
the source file.
"""
_SYNCPOINTS[filename] = syncpoints = {}

# Use tokenize to find pragma comments
with open(filename, "r") as pyfile:
tokens = tokenize.generate_tokens(pyfile.readline)
for type, value, start, end, line in tokens:
if type == tokenize.COMMENT and "pragma: SYNCPOINT" in value:
name = value.split("SYNCPOINT", 1)[1].strip()
if not name:
name = None

if name not in syncpoints:
syncpoints[name] = set()

lineno, column = start
syncpoints[name].add(lineno)


class _TestThread:
"""A thread for a test function."""

thread = None
finished = False
error = None
at_syncpoint = False

def __init__(self, test):
def __init__(self, test, name):
self.test = test
self.name = name
self.conductor = _Conductor()

def _run(self):
sys.settrace(self._trace)
_local.conductor = self.conductor
try:
self.test()
Expand All @@ -256,6 +302,50 @@ def _run(self):
self.finished = True
self.conductor.notify()

def _sync(self):
# Tell main thread we're finished, for now
self.conductor.notify()

# Wait for the main thread to tell us to go again
self.conductor.standby()

def _trace(self, frame, event, arg):
"""Argument to `sys.settrace`.
Handles frames during test run, syncing at syncpoints, when found.
Returns:
`None` if no more tracing is required for the function call, `self._trace`
if tracing should continue.
"""
if self.at_syncpoint:
# We hit a syncpoint on the previous call, so now we sync.
self._sync()
self.at_syncpoint = False

filename = frame.f_globals.get("__file__")
if not filename:
# Can't trace code without a source file
return

if filename.endswith(".pyc"):
filename = filename[:-1]

if filename not in _SYNCPOINTS:
_get_syncpoints(filename)

syncpoints = _SYNCPOINTS[filename].get(self.name)
if not syncpoints:
# This file doesn't contain syncpoints, don't continue to trace
return

# We've hit a syncpoint. Execute whatever line the syncpoint is on and then
# sync next time this gets called.
if frame.f_lineno in syncpoints:
self.at_syncpoint = True

return self._trace

def go(self):
if self.finished:
return
Expand Down
14 changes: 4 additions & 10 deletions tests/unit/test_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@

import pytest

try:
from unittest import mock
except ImportError: # pragma: NO PY3 COVER
import mock

from google.cloud.ndb import _cache
from google.cloud.ndb import global_cache as global_cache_module
from google.cloud.ndb import tasklets
Expand All @@ -31,7 +26,7 @@
log = logging.getLogger(__name__)


def cache_factories():
def cache_factories(): # pragma: NO COVER
yield global_cache_module._InProcessGlobalCache

def redis_cache():
Expand All @@ -48,7 +43,6 @@ def memcache_cache():


@pytest.mark.parametrize("cache_factory", cache_factories())
@mock.patch("google.cloud.ndb._cache._syncpoint_update_key", orchestrate.syncpoint)
def test_global_cache_concurrent_write_692(cache_factory, context_factory):
"""Regression test for #692
Expand All @@ -57,7 +51,7 @@ def test_global_cache_concurrent_write_692(cache_factory, context_factory):
key = b"somekey"

@tasklets.synctasklet
def lock_unlock_key():
def lock_unlock_key(): # pragma: NO COVER
lock = yield _cache.global_lock_for_write(key)
cache_value = yield _cache.global_get(key)
assert lock in cache_value
Expand All @@ -66,9 +60,9 @@ def lock_unlock_key():
cache_value = yield _cache.global_get(key)
assert lock not in cache_value

def run_test():
def run_test(): # pragma: NO COVER
global_cache = cache_factory()
with context_factory(global_cache=global_cache).use():
lock_unlock_key()

orchestrate.orchestrate(run_test, run_test)
orchestrate.orchestrate(run_test, run_test, name="update key")
Loading

0 comments on commit 19e207d

Please sign in to comment.