Skip to content

Commit 0b2d793

Browse files
authored
Allow passing TBPlugin/TBLoader subclasses to TensorBoardWSGIApp (#2582)
1 parent 87d2927 commit 0b2d793

File tree

7 files changed

+104
-42
lines changed

7 files changed

+104
-42
lines changed

tensorboard/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ py_test(
149149
":program",
150150
":test",
151151
"//tensorboard/backend:application",
152+
"//tensorboard/plugins:base_plugin",
152153
"//tensorboard/plugins/core:core_plugin",
153154
"@org_pocoo_werkzeug",
154155
],

tensorboard/backend/application.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ def TensorBoardWSGIApp(
160160
161161
Args:
162162
flags: An argparse.Namespace containing TensorBoard CLI flags.
163-
plugins: A list of TBLoader subclasses for the plugins to load.
163+
plugins: A list of plugins, which can be provided as TBPlugin subclasses
164+
or TBLoader instances or subclasses.
164165
assets_zip_provider: See TBContext documentation for more information.
165166
data_provider: Instance of `tensorboard.data.provider.DataProvider`. May
166167
be `None` if `flags.generic_data` is set to `"false"` in which case
@@ -191,7 +192,8 @@ def TensorBoardWSGIApp(
191192
plugin_name_to_instance=plugin_name_to_instance,
192193
window_title=flags.window_title)
193194
tbplugins = []
194-
for loader in plugins:
195+
for plugin_spec in plugins:
196+
loader = make_plugin_loader(plugin_spec)
195197
plugin = loader.load(context)
196198
if plugin is None:
197199
continue
@@ -601,3 +603,22 @@ def AddRunsFromDirectory(self, path, name=None):
601603
def Reload(self):
602604
"""Unsupported."""
603605
raise NotImplementedError()
606+
607+
608+
def make_plugin_loader(plugin_spec):
609+
"""Returns a plugin loader for the given plugin.
610+
611+
Args:
612+
plugin_spec: A TBPlugin subclass, or a TBLoader instance or subclass.
613+
614+
Returns:
615+
A TBLoader for the given plugin.
616+
"""
617+
if isinstance(plugin_spec, base_plugin.TBLoader):
618+
return plugin_spec
619+
if isinstance(plugin_spec, type):
620+
if issubclass(plugin_spec, base_plugin.TBLoader):
621+
return plugin_spec()
622+
if issubclass(plugin_spec, base_plugin.TBPlugin):
623+
return base_plugin.BasicLoader(plugin_spec)
624+
raise TypeError("Not a TBLoader or TBPlugin subclass: %r" % (plugin_spec,))

tensorboard/backend/application_test.py

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,16 @@ def frontend_metadata(self):
135135
)
136136

137137

138+
class FakePluginLoader(base_plugin.TBLoader):
139+
"""Pass-through loader for FakePlugin with arbitrary arguments."""
140+
141+
def __init__(self, **kwargs):
142+
self._kwargs = kwargs
143+
144+
def load(self, context):
145+
return FakePlugin(context, **self._kwargs)
146+
147+
138148
class ApplicationTest(tb_test.TestCase):
139149
def setUp(self):
140150
plugins = [
@@ -357,6 +367,26 @@ def testSlashlessRoute(self):
357367
application.TensorBoardWSGI([self._make_plugin('runaway')])
358368

359369

370+
class MakePluginLoaderTest(tb_test.TestCase):
371+
372+
def testMakePluginLoader_pluginClass(self):
373+
loader = application.make_plugin_loader(FakePlugin)
374+
self.assertIsInstance(loader, base_plugin.BasicLoader)
375+
self.assertIs(loader.plugin_class, FakePlugin)
376+
377+
def testMakePluginLoader_pluginLoaderClass(self):
378+
loader = application.make_plugin_loader(FakePluginLoader)
379+
self.assertIsInstance(loader, FakePluginLoader)
380+
381+
def testMakePluginLoader_pluginLoader(self):
382+
loader = FakePluginLoader()
383+
self.assertIs(loader, application.make_plugin_loader(loader))
384+
385+
def testMakePluginLoader_invalidType(self):
386+
with six.assertRaisesRegex(self, TypeError, 'FakePlugin'):
387+
application.make_plugin_loader(FakePlugin())
388+
389+
360390
class GetEventFileActiveFilterTest(tb_test.TestCase):
361391

362392
def testDisabled(self):
@@ -519,23 +549,21 @@ def setUp(self):
519549
self.app = application.standard_tensorboard_wsgi(
520550
FakeFlags(logdir=self.get_temp_dir()),
521551
[
522-
base_plugin.BasicLoader(functools.partial(
523-
FakePlugin,
524-
plugin_name='foo',
525-
is_active_value=True,
526-
routes_mapping={'/foo_route': self._foo_handler},
527-
construction_callback=self._construction_callback)),
528-
base_plugin.BasicLoader(functools.partial(
529-
FakePlugin,
530-
plugin_name='bar',
531-
is_active_value=True,
532-
routes_mapping={
533-
'/bar_route': self._bar_handler,
534-
'/wildcard/*': self._wildcard_handler,
535-
'/wildcard/special/*': self._wildcard_special_handler,
536-
'/wildcard/special/exact': self._foo_handler,
537-
},
538-
construction_callback=self._construction_callback)),
552+
FakePluginLoader(
553+
plugin_name='foo',
554+
is_active_value=True,
555+
routes_mapping={'/foo_route': self._foo_handler},
556+
construction_callback=self._construction_callback),
557+
FakePluginLoader(
558+
plugin_name='bar',
559+
is_active_value=True,
560+
routes_mapping={
561+
'/bar_route': self._bar_handler,
562+
'/wildcard/*': self._wildcard_handler,
563+
'/wildcard/special/*': self._wildcard_special_handler,
564+
'/wildcard/special/exact': self._foo_handler,
565+
},
566+
construction_callback=self._construction_callback),
539567
],
540568
dummy_assets_zip_provider)
541569

tensorboard/default.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,20 @@
6060
# Ordering matters. The order in which these lines appear determines the
6161
# ordering of tabs in TensorBoard's GUI.
6262
_PLUGINS = [
63-
core_plugin.CorePluginLoader(),
63+
core_plugin.CorePluginLoader,
6464
scalars_plugin.ScalarsPlugin,
6565
custom_scalars_plugin.CustomScalarsPlugin,
6666
images_plugin.ImagesPlugin,
6767
audio_plugin.AudioPlugin,
68-
debugger_plugin_loader.DebuggerPluginLoader(),
68+
debugger_plugin_loader.DebuggerPluginLoader,
6969
graphs_plugin.GraphsPlugin,
7070
distributions_plugin.DistributionsPlugin,
7171
histograms_plugin.HistogramsPlugin,
7272
text_plugin.TextPlugin,
7373
pr_curves_plugin.PrCurvesPlugin,
74-
profile_plugin_loader.ProfilePluginLoader(),
75-
beholder_plugin_loader.BeholderPluginLoader(),
76-
interactive_inference_plugin_loader.InteractiveInferencePluginLoader(),
74+
profile_plugin_loader.ProfilePluginLoader,
75+
beholder_plugin_loader.BeholderPluginLoader,
76+
interactive_inference_plugin_loader.InteractiveInferencePluginLoader,
7777
hparams_plugin.HParamsPlugin,
7878
mesh_plugin.MeshPlugin,
7979
]

tensorboard/plugins/base_plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,10 +257,10 @@ def __init__(self, plugin_class):
257257
258258
:param plugin_class: :class:`TBPlugin`
259259
"""
260-
self._plugin_class = plugin_class
260+
self.plugin_class = plugin_class
261261

262262
def load(self, context):
263-
return self._plugin_class(context)
263+
return self.plugin_class(context)
264264

265265

266266
class FlagsError(ValueError):

tensorboard/program.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -118,16 +118,13 @@ def __init__(self,
118118
"""Creates new instance.
119119
120120
Args:
121-
plugins: A list of TensorBoard plugins to load, as TBLoader instances or
122-
TBPlugin classes. If not specified, defaults to first-party plugins.
121+
plugin: A list of TensorBoard plugins to load, as TBPlugin classes or
122+
TBLoader instances or classes. If not specified, defaults to first-party
123+
plugins.
123124
assets_zip_provider: Delegates to TBContext or uses default if None.
124125
server_class: An optional factory for a `TensorBoardServer` to use
125126
for serving the TensorBoard WSGI app. If provided, its callable
126127
signature should match that of `TensorBoardServer.__init__`.
127-
128-
:type plugins: list[Union[base_plugin.TBLoader, Type[base_plugin.TBPlugin]]]
129-
:type assets_zip_provider: () -> file
130-
:type server_class: class
131128
"""
132129
if plugins is None:
133130
from tensorboard import default
@@ -136,13 +133,7 @@ def __init__(self,
136133
assets_zip_provider = get_default_assets_zip_provider()
137134
if server_class is None:
138135
server_class = create_port_scanning_werkzeug_server
139-
def make_loader(plugin):
140-
if isinstance(plugin, base_plugin.TBLoader):
141-
return plugin
142-
if issubclass(plugin, base_plugin.TBPlugin):
143-
return base_plugin.BasicLoader(plugin)
144-
raise ValueError("Not a TBLoader or TBPlugin subclass: %s" % plugin)
145-
self.plugin_loaders = [make_loader(p) for p in plugins]
136+
self.plugin_loaders = [application.make_plugin_loader(p) for p in plugins]
146137
self.assets_zip_provider = assets_zip_provider
147138
self.server_class = server_class
148139
self.flags = None

tensorboard/program_test.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,39 @@
2424

2525
from tensorboard import program
2626
from tensorboard import test as tb_test
27+
from tensorboard.plugins import base_plugin
2728
from tensorboard.plugins.core import core_plugin
2829

2930

3031
class TensorBoardTest(tb_test.TestCase):
3132
"""Tests the TensorBoard program."""
3233

34+
def testPlugins_pluginClass(self):
35+
tb = program.TensorBoard(plugins=[core_plugin.CorePlugin])
36+
self.assertIsInstance(tb.plugin_loaders[0], base_plugin.BasicLoader)
37+
self.assertIs(tb.plugin_loaders[0].plugin_class, core_plugin.CorePlugin)
38+
39+
def testPlugins_pluginLoaderClass(self):
40+
tb = program.TensorBoard(plugins=[core_plugin.CorePluginLoader])
41+
self.assertIsInstance(tb.plugin_loaders[0], core_plugin.CorePluginLoader)
42+
43+
def testPlugins_pluginLoader(self):
44+
loader = core_plugin.CorePluginLoader()
45+
tb = program.TensorBoard(plugins=[loader])
46+
self.assertIs(tb.plugin_loaders[0], loader)
47+
48+
def testPlugins_invalidType(self):
49+
plugin_instance = core_plugin.CorePlugin(base_plugin.TBContext())
50+
with six.assertRaisesRegex(self, TypeError, 'CorePlugin'):
51+
tb = program.TensorBoard(plugins=[plugin_instance])
52+
3353
def testConfigure(self):
34-
# Many useful flags are defined under the core plugin.
35-
tb = program.TensorBoard(plugins=[core_plugin.CorePluginLoader()])
54+
tb = program.TensorBoard(plugins=[core_plugin.CorePluginLoader])
3655
tb.configure(logdir='foo')
37-
self.assertStartsWith(tb.flags.logdir, 'foo')
56+
self.assertEqual(tb.flags.logdir, 'foo')
3857

58+
def testConfigure_unknownFlag(self):
59+
tb = program.TensorBoard(plugins=[core_plugin.CorePlugin])
3960
with six.assertRaisesRegex(self, ValueError, 'Unknown TensorBoard flag'):
4061
tb.configure(foo='bar')
4162

0 commit comments

Comments
 (0)