Skip to content

Commit 0ce0ffd

Browse files
refactor to move functions inside async
Signed-off-by: Jaya Venkatesh <jjayabaskar@nvidia.com>
1 parent 4e88de0 commit 0ce0ffd

File tree

4 files changed

+421
-20
lines changed

4 files changed

+421
-20
lines changed

distributed/client.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5457,20 +5457,22 @@ def unregister_worker_plugin(self, name, nanny=None):
54575457
return self.sync(self._unregister_worker_plugin, name=name, nanny=nanny)
54585458

54595459
def has_plugin(
5460-
self, plugin: str | WorkerPlugin | SchedulerPlugin | NannyPlugin | list
5460+
self, plugin: str | WorkerPlugin | SchedulerPlugin | NannyPlugin | Sequence
54615461
) -> bool | dict[str, bool]:
54625462
"""Check if plugin(s) are registered
54635463
54645464
Parameters
54655465
----------
5466-
plugin : str | plugin object | list
5467-
Plugin to check. You can use the plugin object directly or the plugin name. For plugin objects, they must have a 'name' attribute. You can also pass a list of plugin objects or names.
5466+
plugin : str | plugin object | Sequence
5467+
Plugin to check. You can use the plugin object directly or the plugin name.
5468+
For plugin objects, they must have a 'name' attribute. You can also pass
5469+
a sequence of plugin objects or names.
54685470
54695471
Returns
54705472
-------
54715473
bool or dict[str, bool]
54725474
If name is str: True if plugin is registered, False otherwise
5473-
If name is list: dict mapping names to registration status
5475+
If name is Sequence: dict mapping names to registration status
54745476
54755477
Examples
54765478
--------
@@ -5485,23 +5487,27 @@ def has_plugin(
54855487
>>> client.has_plugin([logging_plugin, 'other-plugin'])
54865488
{'logging-config': True, 'other-plugin': False}
54875489
"""
5488-
if isinstance(plugin, str):
5489-
result = self.sync(self._get_plugin_registration_status, names=[plugin])
5490-
return result[plugin]
5490+
return self.sync(self._has_plugin_async, plugin=plugin)
54915491

5492+
async def _has_plugin_async(
5493+
self, plugin: str | WorkerPlugin | SchedulerPlugin | NannyPlugin | Sequence
5494+
) -> bool | dict[str, bool]:
5495+
"""Async implementation for checking plugin registration"""
5496+
5497+
# Convert plugin to list of names
5498+
if isinstance(plugin, str):
5499+
names_to_check = [plugin]
5500+
return_single = True
54925501
elif isinstance(plugin, (WorkerPlugin, SchedulerPlugin, NannyPlugin)):
54935502
plugin_name = getattr(plugin, "name", None)
54945503
if plugin_name is None:
54955504
raise ValueError(
54965505
f"Plugin {funcname(type(plugin))} has no 'name' attribute. "
54975506
"Please add a 'name' attribute to your plugin class."
54985507
)
5499-
result = self.sync(
5500-
self._get_plugin_registration_status, names=[plugin_name]
5501-
)
5502-
return result[plugin_name]
5503-
5504-
elif isinstance(plugin, list):
5508+
names_to_check = [plugin_name]
5509+
return_single = True
5510+
elif isinstance(plugin, Sequence):
55055511
names_to_check = []
55065512
for p in plugin:
55075513
if isinstance(p, str):
@@ -5513,13 +5519,20 @@ def has_plugin(
55135519
f"Plugin {funcname(type(p))} has no 'name' attribute"
55145520
)
55155521
names_to_check.append(plugin_name)
5516-
return self.sync(self._get_plugin_registration_status, names=names_to_check)
5517-
5518-
async def _get_plugin_registration_status(
5519-
self, names: list[str]
5520-
) -> dict[str, bool]:
5521-
"""Async implementation for checking plugin registration"""
5522-
return await self.scheduler.get_plugin_registration_status(names=names)
5522+
return_single = False
5523+
else:
5524+
raise TypeError(
5525+
f"plugin must be a plugin object, name string, or Sequence. Got {type(plugin)}"
5526+
)
5527+
5528+
# Get status from scheduler
5529+
result = await self.scheduler.get_plugin_registration_status(names=names_to_check)
5530+
5531+
# Return single bool or dict based on input
5532+
if return_single:
5533+
return result[names_to_check[0]]
5534+
else:
5535+
return result
55235536

55245537
@property
55255538
def amm(self):

distributed/diagnostics/tests/test_nanny_plugin.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,150 @@ async def test_nanny_plugin_with_broken_teardown_logs_on_close(c, s):
217217
logs = caplog.getvalue()
218218
assert "TestPlugin1 failed to teardown" in logs
219219
assert "test error" in logs
220+
221+
@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
222+
async def test_has_nanny_plugin_by_name(c, s, a):
223+
"""Test checking if nanny plugin is registered using string name"""
224+
225+
class DuckPlugin(NannyPlugin):
226+
name = "duck-plugin"
227+
228+
def setup(self, nanny):
229+
nanny.foo = 123
230+
231+
def teardown(self, nanny):
232+
pass
233+
234+
# Check non-existent plugin
235+
assert not await c.has_plugin("duck-plugin")
236+
237+
# Register plugin
238+
await c.register_plugin(DuckPlugin())
239+
assert a.foo == 123
240+
241+
# Check using string name
242+
assert await c.has_plugin("duck-plugin")
243+
244+
# Unregister and check again
245+
await c.unregister_worker_plugin("duck-plugin", nanny=True)
246+
assert not await c.has_plugin("duck-plugin")
247+
248+
249+
@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
250+
async def test_has_nanny_plugin_by_object(c, s, a):
251+
"""Test checking if nanny plugin is registered using plugin object"""
252+
253+
class DuckPlugin(NannyPlugin):
254+
name = "duck-plugin"
255+
256+
def setup(self, nanny):
257+
nanny.bar = 456
258+
259+
def teardown(self, nanny):
260+
pass
261+
262+
plugin = DuckPlugin()
263+
264+
# Check before registration
265+
assert not await c.has_plugin(plugin)
266+
267+
# Register and check
268+
await c.register_plugin(plugin)
269+
assert a.bar == 456
270+
assert await c.has_plugin(plugin)
271+
272+
# Unregister and check
273+
await c.unregister_worker_plugin("duck-plugin", nanny=True)
274+
assert not await c.has_plugin(plugin)
275+
276+
277+
@gen_cluster(client=True, nthreads=[("", 1), ("", 1)], Worker=Nanny)
278+
async def test_has_nanny_plugin_multiple_nannies(c, s, a, b):
279+
"""Test checking nanny plugin with multiple nannies"""
280+
281+
class DuckPlugin(NannyPlugin):
282+
name = "duck-plugin"
283+
284+
def setup(self, nanny):
285+
nanny.multi = "setup"
286+
287+
def teardown(self, nanny):
288+
pass
289+
290+
# Check before registration
291+
assert not await c.has_plugin("duck-plugin")
292+
293+
# Register plugin (should propagate to all nannies)
294+
await c.register_plugin(DuckPlugin())
295+
296+
# Verify both nannies have the plugin
297+
assert a.multi == "setup"
298+
assert b.multi == "setup"
299+
300+
# Check plugin is registered
301+
assert await c.has_plugin("duck-plugin")
302+
303+
@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
304+
async def test_has_nanny_plugin_custom_name_override(c, s, a):
305+
"""Test nanny plugin registered with custom name different from class name"""
306+
307+
class DuckPlugin(NannyPlugin):
308+
name = "duck-plugin"
309+
310+
def setup(self, nanny):
311+
nanny.custom = "test"
312+
313+
def teardown(self, nanny):
314+
pass
315+
316+
plugin = DuckPlugin()
317+
318+
# Register with custom name (overriding the class name attribute)
319+
await c.register_plugin(plugin, name="custom-override")
320+
321+
# Check with custom name works
322+
assert await c.has_plugin("custom-override")
323+
324+
# Original name won't work since we overrode it
325+
assert not await c.has_plugin("duck-plugin")
326+
327+
328+
@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
329+
async def test_has_nanny_plugin_list_check(c, s, a):
330+
"""Test checking multiple nanny plugins at once"""
331+
332+
class IdempotentPlugin(NannyPlugin):
333+
name = "idempotentplugin"
334+
335+
def setup(self, nanny):
336+
pass
337+
338+
def teardown(self, nanny):
339+
pass
340+
341+
class NonIdempotentPlugin(NannyPlugin):
342+
name = "nonidempotentplugin"
343+
344+
def setup(self, nanny):
345+
pass
346+
347+
def teardown(self, nanny):
348+
pass
349+
350+
# Check multiple before registration
351+
result = await c.has_plugin(["idempotentplugin", "nonidempotentplugin", "nonexistent"])
352+
assert result == {
353+
"idempotentplugin": False,
354+
"nonidempotentplugin": False,
355+
"nonexistent": False,
356+
}
357+
358+
# Register first plugin
359+
await c.register_plugin(IdempotentPlugin())
360+
result = await c.has_plugin(["idempotentplugin", "nonidempotentplugin"])
361+
assert result == {"idempotentplugin": True, "nonidempotentplugin": False}
362+
363+
# Register second plugin
364+
await c.register_plugin(NonIdempotentPlugin())
365+
result = await c.has_plugin(["idempotentplugin", "nonidempotentplugin"])
366+
assert result == {"idempotentplugin": True, "nonidempotentplugin": True}

distributed/diagnostics/tests/test_scheduler_plugin.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,3 +753,141 @@ def __init__(self, instance=None):
753753
await s.register_scheduler_plugin(plugin=dumps(third))
754754
assert "nonidempotentplugin" in s.plugins
755755
assert s.plugins["nonidempotentplugin"].instance == "third"
756+
757+
@gen_cluster(client=True)
758+
async def test_has_scheduler_plugin_by_name(c, s, a, b):
759+
"""Test checking if scheduler plugin is registered using string name"""
760+
761+
class Dummy1(SchedulerPlugin):
762+
name = "Dummy1"
763+
764+
def start(self, scheduler):
765+
scheduler.foo = "bar"
766+
767+
# Check non-existent plugin
768+
assert not await c.has_plugin("Dummy1")
769+
770+
# Register plugin
771+
await c.register_plugin(Dummy1())
772+
assert s.foo == "bar"
773+
774+
# Check using string name
775+
assert await c.has_plugin("Dummy1")
776+
777+
# Unregister and check again
778+
await c.unregister_scheduler_plugin("Dummy1")
779+
assert not await c.has_plugin("Dummy1")
780+
781+
782+
@gen_cluster(client=True)
783+
async def test_has_scheduler_plugin_by_object(c, s, a, b):
784+
"""Test checking if scheduler plugin is registered using plugin object"""
785+
786+
class Dummy2(SchedulerPlugin):
787+
name = "Dummy2"
788+
789+
def start(self, scheduler):
790+
scheduler.check_value = 42
791+
792+
plugin = Dummy2()
793+
794+
# Check before registration
795+
assert not await c.has_plugin(plugin)
796+
797+
# Register and check
798+
await c.register_plugin(plugin)
799+
assert s.check_value == 42
800+
assert await c.has_plugin(plugin)
801+
802+
# Unregister and check
803+
await c.unregister_scheduler_plugin("Dummy2")
804+
assert not await c.has_plugin(plugin)
805+
806+
807+
@gen_cluster(client=True)
808+
async def test_has_plugin_mixed_scheduler_and_worker_types(c, s, a, b):
809+
"""Test checking scheduler and worker plugins together"""
810+
from distributed import WorkerPlugin
811+
812+
class MyPlugin(SchedulerPlugin):
813+
name = "MyPlugin"
814+
815+
def start(self, scheduler):
816+
scheduler.my_value = "scheduler"
817+
818+
class MyWorkerPlugin(WorkerPlugin):
819+
name = "MyWorkerPlugin"
820+
821+
def setup(self, worker):
822+
worker.my_value = "worker"
823+
824+
sched_plugin = MyPlugin()
825+
work_plugin = MyWorkerPlugin()
826+
827+
# Register both types
828+
await c.register_plugin(sched_plugin)
829+
await c.register_plugin(work_plugin)
830+
831+
# Verify both registered
832+
assert s.my_value == "scheduler"
833+
assert a.my_value == "worker"
834+
assert b.my_value == "worker"
835+
836+
# Check both with list of names
837+
result = await c.has_plugin(["MyPlugin", "MyWorkerPlugin"])
838+
assert result == {"MyPlugin": True, "MyWorkerPlugin": True}
839+
840+
# Check both with objects
841+
assert await c.has_plugin(sched_plugin)
842+
assert await c.has_plugin(work_plugin)
843+
844+
# Check non-existent alongside real ones
845+
result = await c.has_plugin(["MyPlugin", "nonexistent", "MyWorkerPlugin"])
846+
assert result == {
847+
"MyPlugin": True,
848+
"nonexistent": False,
849+
"MyWorkerPlugin": True
850+
}
851+
852+
853+
@gen_cluster(client=True, nthreads=[])
854+
async def test_has_scheduler_plugin_no_workers(c, s):
855+
"""Test checking scheduler plugin when no workers exist"""
856+
857+
class Plugin(SchedulerPlugin):
858+
name = "plugin"
859+
860+
def start(self, scheduler):
861+
scheduler.no_worker_test = True
862+
863+
# Check before registration
864+
assert not await c.has_plugin("plugin")
865+
866+
# Register plugin when no workers exist
867+
await c.register_plugin(Plugin())
868+
assert s.no_worker_test is True
869+
870+
# Check after registration
871+
assert await c.has_plugin("plugin")
872+
873+
874+
@gen_cluster(client=True)
875+
async def test_has_scheduler_plugin_custom_name_override(c, s, a, b):
876+
"""Test scheduler plugin registered with custom name different from class name"""
877+
878+
class Dummy3(SchedulerPlugin):
879+
name = "Dummy3"
880+
881+
def start(self, scheduler):
882+
scheduler.name_test = "custom"
883+
884+
plugin = Dummy3()
885+
886+
# Register with custom name (overriding the class name attribute)
887+
await c.register_plugin(plugin, name="custom-override")
888+
889+
# Check with custom name works
890+
assert await c.has_plugin("custom-override")
891+
892+
# Original name won't work since we overrode it
893+
assert not await c.has_plugin("Dummy3")

0 commit comments

Comments
 (0)