Skip to content

Commit f145659

Browse files
authored
Reincarnate TensorBoardWSGIApp as new DataProvider-friendly entry point (#2576)
1 parent 0de506e commit f145659

File tree

4 files changed

+134
-42
lines changed

4 files changed

+134
-42
lines changed

tensorboard/backend/application.py

Lines changed: 106 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
# ==============================================================================
1515
"""TensorBoard WSGI Application Logic.
1616
17-
TensorBoardApplication constructs TensorBoard as a WSGI application.
18-
It handles serving static assets, and implements TensorBoard data APIs.
17+
Provides TensorBoardWSGIApp for building a TensorBoard WSGI app.
1918
"""
2019

2120
from __future__ import absolute_import
@@ -107,64 +106,98 @@ def standard_tensorboard_wsgi(flags, plugin_loaders, assets_zip_provider):
107106
:type plugin_loaders: list[base_plugin.TBLoader]
108107
:rtype: TensorBoardWSGI
109108
"""
110-
event_file_active_filter = _get_event_file_active_filter(flags)
111-
multiplexer = event_multiplexer.EventMultiplexer(
112-
size_guidance=DEFAULT_SIZE_GUIDANCE,
113-
tensor_size_guidance=tensor_size_guidance_from_flags(flags),
114-
purge_orphaned_data=flags.purge_orphaned_data,
115-
max_reload_threads=flags.max_reload_threads,
116-
event_file_active_filter=event_file_active_filter)
117-
if flags.generic_data == 'false':
118-
data_provider = None
119-
else:
120-
data_provider = event_data_provider.MultiplexerDataProvider(multiplexer)
121-
loading_multiplexer = multiplexer
109+
data_provider = None
110+
multiplexer = None
122111
reload_interval = flags.reload_interval
123-
db_uri = flags.db
124-
db_connection_provider = None
125-
# For DB import mode, create a DB file if we weren't given one.
126-
if flags.db_import and not flags.db:
127-
tmpdir = tempfile.mkdtemp(prefix='tbimport')
128-
atexit.register(shutil.rmtree, tmpdir)
129-
db_uri = 'sqlite:%s/tmp.sqlite' % tmpdir
130112
if flags.db_import:
131113
# DB import mode.
132-
logger.info('Importing logdir into DB at %s', db_uri)
114+
db_uri = flags.db
115+
# Create a temporary DB file if we weren't given one.
116+
if not db_uri:
117+
tmpdir = tempfile.mkdtemp(prefix='tbimport')
118+
atexit.register(shutil.rmtree, tmpdir)
119+
db_uri = 'sqlite:%s/tmp.sqlite' % tmpdir
133120
db_connection_provider = create_sqlite_connection_provider(db_uri)
134-
loading_multiplexer = db_import_multiplexer.DbImportMultiplexer(
121+
logger.info('Importing logdir into DB at %s', db_uri)
122+
multiplexer = db_import_multiplexer.DbImportMultiplexer(
123+
db_uri=db_uri,
135124
db_connection_provider=db_connection_provider,
136125
purge_orphaned_data=flags.purge_orphaned_data,
137126
max_reload_threads=flags.max_reload_threads)
138127
elif flags.db:
139128
# DB read-only mode, never load event logs.
140129
reload_interval = -1
141-
db_connection_provider = create_sqlite_connection_provider(db_uri)
130+
db_connection_provider = create_sqlite_connection_provider(flags.db)
131+
multiplexer = _DbModeMultiplexer(flags.db, db_connection_provider)
132+
else:
133+
# Regular logdir loading mode.
134+
multiplexer = event_multiplexer.EventMultiplexer(
135+
size_guidance=DEFAULT_SIZE_GUIDANCE,
136+
tensor_size_guidance=tensor_size_guidance_from_flags(flags),
137+
purge_orphaned_data=flags.purge_orphaned_data,
138+
max_reload_threads=flags.max_reload_threads,
139+
event_file_active_filter=_get_event_file_active_filter(flags))
140+
if flags.generic_data != 'false':
141+
data_provider = event_data_provider.MultiplexerDataProvider(multiplexer)
142+
143+
if reload_interval >= 0:
144+
# We either reload the multiplexer once when TensorBoard starts up, or we
145+
# continuously reload the multiplexer.
146+
path_to_run = parse_event_files_spec(flags.logdir)
147+
start_reloading_multiplexer(
148+
multiplexer, path_to_run, reload_interval, flags.reload_task)
149+
return TensorBoardWSGIApp(
150+
flags, plugin_loaders, data_provider, assets_zip_provider, multiplexer)
151+
152+
153+
def TensorBoardWSGIApp(
154+
flags,
155+
plugins,
156+
data_provider=None,
157+
assets_zip_provider=None,
158+
deprecated_multiplexer=None):
159+
"""Constructs a TensorBoard WSGI app from plugins and data providers.
160+
161+
Args:
162+
flags: An argparse.Namespace containing TensorBoard CLI flags.
163+
plugins: A list of TBLoader instances for the plugins to load.
164+
assets_zip_provider: See TBContext documentation for more information.
165+
data_provider: Instance of `tensorboard.data.provider.DataProvider`. May
166+
be `None` if `flags.generic_data` is set to `"false"`, in which case
167+
`deprecated_multiplexer` must be passed instead.
168+
deprecated_multiplexer: Optional `plugin_event_multiplexer.EventMultiplexer`
169+
to use for any plugins not yet enabled for the DataProvider API.
170+
Required if the data_provider argument is not passed.
171+
172+
Returns:
173+
A WSGI application that implements the TensorBoard backend.
174+
"""
175+
db_uri = None
176+
db_connection_provider = None
177+
if isinstance(
178+
deprecated_multiplexer,
179+
(db_import_multiplexer.DbImportMultiplexer, _DbModeMultiplexer)):
180+
db_uri = deprecated_multiplexer.db_uri
181+
db_connection_provider = deprecated_multiplexer.db_connection_provider
142182
plugin_name_to_instance = {}
143183
context = base_plugin.TBContext(
144184
data_provider=data_provider,
145185
db_connection_provider=db_connection_provider,
146186
db_uri=db_uri,
147187
flags=flags,
148188
logdir=flags.logdir,
149-
multiplexer=multiplexer,
189+
multiplexer=deprecated_multiplexer,
150190
assets_zip_provider=assets_zip_provider,
151191
plugin_name_to_instance=plugin_name_to_instance,
152192
window_title=flags.window_title)
153-
plugins = []
154-
for loader in plugin_loaders:
193+
tbplugins = []
194+
for loader in plugins:
155195
plugin = loader.load(context)
156196
if plugin is None:
157197
continue
158-
plugins.append(plugin)
198+
tbplugins.append(plugin)
159199
plugin_name_to_instance[plugin.plugin_name] = plugin
160-
161-
if reload_interval >= 0:
162-
# We either reload the multiplexer once when TensorBoard starts up, or we
163-
# continuously reload the multiplexer.
164-
path_to_run = parse_event_files_spec(flags.logdir)
165-
start_reloading_multiplexer(
166-
loading_multiplexer, path_to_run, reload_interval, flags.reload_task)
167-
return TensorBoardWSGI(plugins, flags.path_prefix)
200+
return TensorBoardWSGI(tbplugins, flags.path_prefix)
168201

169202

170203
class TensorBoardWSGI(object):
@@ -531,3 +564,40 @@ def _get_event_file_active_filter(flags):
531564
if inactive_secs < 0:
532565
return lambda timestamp: True
533566
return lambda timestamp: timestamp + inactive_secs >= time.time()
567+
568+
569+
class _DbModeMultiplexer(event_multiplexer.EventMultiplexer):
570+
"""Shim EventMultiplexer to use when in read-only DB mode.
571+
572+
In read-only DB mode, the EventMultiplexer is nonfunctional - there is no
573+
logdir to reload, and the data is all exposed via SQL. This class represents
574+
the do-nothing EventMultiplexer for that purpose, which serves only as a
575+
conduit for DB-related parameters.
576+
577+
The load APIs raise exceptions if called, and the read APIs always
578+
return empty results.
579+
"""
580+
def __init__(self, db_uri, db_connection_provider):
581+
"""Constructor for `_DbModeMultiplexer`.
582+
583+
Args:
584+
db_uri: A URI to the database file in use.
585+
db_connection_provider: Provider function for creating a DB connection.
586+
"""
587+
logger.info('_DbModeMultiplexer initializing for %s', db_uri)
588+
super(_DbModeMultiplexer, self).__init__()
589+
self.db_uri = db_uri
590+
self.db_connection_provider = db_connection_provider
591+
logger.info('_DbModeMultiplexer done initializing')
592+
593+
def AddRun(self, path, name=None):
594+
"""Unsupported."""
595+
raise NotImplementedError()
596+
597+
def AddRunsFromDirectory(self, path, name=None):
598+
"""Unsupported."""
599+
raise NotImplementedError()
600+
601+
def Reload(self):
602+
"""Unsupported."""
603+
raise NotImplementedError()

tensorboard/backend/event_processing/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ py_library(
262262
deps = [
263263
":directory_watcher",
264264
":event_file_loader",
265+
":event_multiplexer",
265266
":io_wrapper",
266267
":sqlite_writer",
267268
"//tensorboard:data_compat",

tensorboard/backend/event_processing/db_import_multiplexer.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from tensorboard.backend.event_processing import directory_watcher
3232
from tensorboard.backend.event_processing import event_file_loader
3333
from tensorboard.backend.event_processing import io_wrapper
34+
from tensorboard.backend.event_processing import plugin_event_multiplexer
3435
from tensorboard.backend.event_processing import sqlite_writer
3536
from tensorboard.compat import tf
3637
from tensorboard.compat.proto import event_pb2
@@ -39,28 +40,35 @@
3940

4041
logger = tb_logging.get_logger()
4142

42-
class DbImportMultiplexer(object):
43+
44+
class DbImportMultiplexer(plugin_event_multiplexer.EventMultiplexer):
4345
"""A loading-only `EventMultiplexer` that populates a SQLite DB.
4446
45-
This EventMultiplexer only loads data; it provides no read APIs.
47+
This EventMultiplexer only loads data; the read APIs always return empty
48+
results, since all data is accessed instead via SQL against the
49+
db_connection_provider wrapped by this multiplexer.
4650
"""
4751

4852
def __init__(self,
53+
db_uri,
4954
db_connection_provider,
5055
purge_orphaned_data,
5156
max_reload_threads):
5257
"""Constructor for `DbImportMultiplexer`.
5358
5459
Args:
60+
db_uri: A URI to the database file in use.
5561
db_connection_provider: Provider function for creating a DB connection.
5662
purge_orphaned_data: Whether to discard any events that were "orphaned" by
5763
a TensorFlow restart.
5864
max_reload_threads: The max number of threads that TensorBoard can use
5965
to reload runs. Each thread reloads one run at a time. If not provided,
6066
reloads runs serially (one after another).
6167
"""
62-
logger.info('DbImportMultiplexer initializing')
63-
self._db_connection_provider = db_connection_provider
68+
logger.info('DbImportMultiplexer initializing for %s', db_uri)
69+
super(DbImportMultiplexer, self).__init__()
70+
self.db_uri = db_uri
71+
self.db_connection_provider = db_connection_provider
6472
self._purge_orphaned_data = purge_orphaned_data
6573
self._max_reload_threads = max_reload_threads
6674
self._event_sink = None
@@ -70,13 +78,17 @@ def __init__(self,
7078
logger.warn(
7179
'--db_import does not yet support purging orphaned data')
7280

73-
conn = self._db_connection_provider()
81+
conn = self.db_connection_provider()
7482
# Set the DB in WAL mode so reads don't block writes.
7583
conn.execute('PRAGMA journal_mode=wal')
7684
conn.execute('PRAGMA synchronous=normal') # Recommended for WAL mode
7785
sqlite_writer.initialize_schema(conn)
7886
logger.info('DbImportMultiplexer done initializing')
7987

88+
def AddRun(self, path, name=None):
89+
"""Unsupported; instead use AddRunsFromDirectory."""
90+
raise NotImplementedError("Unsupported; use AddRunsFromDirectory()")
91+
8092
def AddRunsFromDirectory(self, path, name=None):
8193
"""Load runs from a directory; recursively walks subdirectories.
8294
@@ -111,7 +123,7 @@ def Reload(self):
111123
# Defer event sink creation until needed; this ensures it will only exist in
112124
# the thread that calls Reload(), since DB connections must be thread-local.
113125
if not self._event_sink:
114-
self._event_sink = _SqliteWriterEventSink(self._db_connection_provider)
126+
self._event_sink = _SqliteWriterEventSink(self.db_connection_provider)
115127
# Use collections.deque() for speed when we don't need blocking since it
116128
# also has thread-safe appends/pops.
117129
loader_queue = collections.deque(six.itervalues(self._run_loaders))

tensorboard/backend/event_processing/db_import_multiplexer_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def setUp(self):
4545
db_file_name = os.path.join(self.get_temp_dir(), 'db')
4646
self.db_connection_provider = lambda: sqlite3.connect(db_file_name)
4747
self.multiplexer = db_import_multiplexer.DbImportMultiplexer(
48+
db_uri='sqlite:' + db_file_name,
4849
db_connection_provider=self.db_connection_provider,
4950
purge_orphaned_data=False,
5051
max_reload_threads=1)
@@ -150,6 +151,14 @@ def test_manual_name(self):
150151
self.assertEqual(self._get_runs(), [os.path.join('some', 'nested', 'name'),
151152
os.path.join('some', 'nested', 'name')])
152153

154+
def test_empty_read_apis(self):
155+
path = self.get_temp_dir()
156+
add_event(path)
157+
self.assertEmpty(self.multiplexer.Runs())
158+
self.multiplexer.AddRunsFromDirectory(path)
159+
self.multiplexer.Reload()
160+
self.assertEmpty(self.multiplexer.Runs())
161+
153162

154163
if __name__ == '__main__':
155164
tf.test.main()

0 commit comments

Comments
 (0)