Skip to content

Commit 8e2f334

Browse files
committed
make TensorBoardWSGIApp allow passing TBPlugin or TBLoader classes directly
1 parent cc8c7d3 commit 8e2f334

File tree

2 files changed

+47
-19
lines changed

2 files changed

+47
-19
lines changed

tensorboard/backend/application.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ def TensorBoardWSGIApp(
160160
161161
Args:
162162
flags: An argparse.Namespace containing TensorBoard CLI flags.
163-
plugins: A list of TBLoader subclasses for the plugins to load.
163+
plugins: A list of plugins, which can be provided as TBPlugin subclasses
164+
or TBLoader instances or subclasses.
164165
assets_zip_provider: See TBContext documentation for more information.
165166
data_provider: Instance of `tensorboard.data.provider.DataProvider`. May
166167
be `None` if `flags.generic_data` is set to `"false"` in which case
@@ -187,7 +188,8 @@ def TensorBoardWSGIApp(
187188
plugin_name_to_instance=plugin_name_to_instance,
188189
window_title=flags.window_title)
189190
tbplugins = []
190-
for loader in plugins:
191+
for plugin_spec in plugins:
192+
loader = _make_plugin_loader(plugin_spec)
191193
plugin = loader.load(context)
192194
if plugin is None:
193195
continue
@@ -597,3 +599,21 @@ def AddRunsFromDirectory(self, path, name=None):
597599
def Reload(self):
598600
"""Unsupported."""
599601
raise NotImplementedError()
602+
603+
604+
def _make_plugin_loader(plugin):
605+
"""Returns a plugin loader for the given plugin.
606+
607+
Args:
608+
plugin: A TBPlugin subclass, or a TBLoader instance or subclass.
609+
610+
Returns:
611+
A TBLoader for the given plugin.
612+
"""
613+
if isinstance(plugin, base_plugin.TBLoader):
614+
return plugin
615+
if issubclass(plugin, base_plugin.TBLoader):
616+
return plugin()
617+
if issubclass(plugin, base_plugin.TBPlugin):
618+
return base_plugin.BasicLoader(plugin)
619+
raise ValueError("Not a TBLoader or TBPlugin subclass: %s" % plugin)

tensorboard/backend/application_test.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,16 @@ def frontend_metadata(self):
135135
)
136136

137137

138+
class FakePluginLoader(base_plugin.TBLoader):
139+
"""Pass-through loader for FakePlugin with arbitrary arguments."""
140+
141+
def __init__(self, **kwargs):
142+
self._kwargs = kwargs
143+
144+
def load(self, context):
145+
return FakePlugin(context, **self._kwargs)
146+
147+
138148
class ApplicationTest(tb_test.TestCase):
139149
def setUp(self):
140150
plugins = [
@@ -519,23 +529,21 @@ def setUp(self):
519529
self.app = application.standard_tensorboard_wsgi(
520530
FakeFlags(logdir=self.get_temp_dir()),
521531
[
522-
base_plugin.BasicLoader(functools.partial(
523-
FakePlugin,
524-
plugin_name='foo',
525-
is_active_value=True,
526-
routes_mapping={'/foo_route': self._foo_handler},
527-
construction_callback=self._construction_callback)),
528-
base_plugin.BasicLoader(functools.partial(
529-
FakePlugin,
530-
plugin_name='bar',
531-
is_active_value=True,
532-
routes_mapping={
533-
'/bar_route': self._bar_handler,
534-
'/wildcard/*': self._wildcard_handler,
535-
'/wildcard/special/*': self._wildcard_special_handler,
536-
'/wildcard/special/exact': self._foo_handler,
537-
},
538-
construction_callback=self._construction_callback)),
532+
FakePluginLoader(
533+
plugin_name='foo',
534+
is_active_value=True,
535+
routes_mapping={'/foo_route': self._foo_handler},
536+
construction_callback=self._construction_callback),
537+
FakePluginLoader(
538+
plugin_name='bar',
539+
is_active_value=True,
540+
routes_mapping={
541+
'/bar_route': self._bar_handler,
542+
'/wildcard/*': self._wildcard_handler,
543+
'/wildcard/special/*': self._wildcard_special_handler,
544+
'/wildcard/special/exact': self._foo_handler,
545+
},
546+
construction_callback=self._construction_callback),
539547
],
540548
dummy_assets_zip_provider)
541549

0 commit comments

Comments
 (0)