Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
language: python
cache: pip
python:
- "2.7"
- "3.4"
- "3.5"
- "3.6"
Expand Down
79 changes: 76 additions & 3 deletions jupyter_tensorboard/tensorboard_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
import threading
import time
import inspect
import itertools
from collections import namedtuple
import logging
Expand Down Expand Up @@ -108,8 +109,77 @@ def _ReloadForever():
return thread


def TensorBoardWSGIApp(logdir, plugins, multiplexer,
reload_interval, path_prefix="", reload_task="auto"):
def is_tensorboard_greater_than_or_equal_to20():
# tensorflow<1.4 will be
# (logdir, plugins, multiplexer, reload_interval)

# tensorflow>=1.4, <1.12 will be
# (logdir, plugins, multiplexer, reload_interval, path_prefix)

# tensorflow>=1.12, <1.14 will be
# (logdir, plugins, multiplexer, reload_interval,
# path_prefix='', reload_task='auto')

# tensorflow 2.0 will be
# (flags, plugins, data_provider=None, assets_zip_provider=None,
# deprecated_multiplexer=None)

s = inspect.signature(application.TensorBoardWSGIApp)
first_parameter_name = list(s.parameters.keys())[0]
return first_parameter_name == 'flags'


def TensorBoardWSGIApp_2x(
flags, plugins,
data_provider=None,
assets_zip_provider=None,
deprecated_multiplexer=None):

logdir = flags.logdir
multiplexer = deprecated_multiplexer
reload_interval = flags.reload_interval

path_to_run = application.parse_event_files_spec(logdir)
if reload_interval:
thread = start_reloading_multiplexer(
multiplexer, path_to_run, reload_interval)
else:
application.reload_multiplexer(multiplexer, path_to_run)
thread = None

db_uri = None
db_connection_provider = None

plugin_name_to_instance = {}

from tensorboard.plugins import base_plugin
context = base_plugin.TBContext(
data_provider=data_provider,
db_connection_provider=db_connection_provider,
db_uri=db_uri,
flags=flags,
logdir=flags.logdir,
multiplexer=deprecated_multiplexer,
assets_zip_provider=assets_zip_provider,
plugin_name_to_instance=plugin_name_to_instance,
window_title=flags.window_title)

tbplugins = []
for loader in plugins:
plugin = loader.load(context)
if plugin is None:
continue
tbplugins.append(plugin)
plugin_name_to_instance[plugin.plugin_name] = plugin

tb_app = application.TensorBoardWSGI(tbplugins)
manager.add_instance(logdir, tb_app, thread)
return tb_app


def TensorBoardWSGIApp_1x(
logdir, plugins, multiplexer,
reload_interval, path_prefix="", reload_task="auto"):
path_to_run = application.parse_event_files_spec(logdir)
if reload_interval:
thread = start_reloading_multiplexer(
Expand All @@ -122,7 +192,10 @@ def TensorBoardWSGIApp(logdir, plugins, multiplexer,
return tb_app


application.TensorBoardWSGIApp = TensorBoardWSGIApp
if is_tensorboard_greater_than_or_equal_to20():
application.TensorBoardWSGIApp = TensorBoardWSGIApp_2x
else:
application.TensorBoardWSGIApp = TensorBoardWSGIApp_1x


class TensorboardManger(dict):
Expand Down
7 changes: 6 additions & 1 deletion tests/test_tensorboard_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
def tf_logs(tmpdir_factory):

import numpy as np
import tensorflow as tf
try:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
except ImportError:
import tensorflow as tf

x = np.random.rand(5)
y = 3 * x + 1 + 0.05 * np.random.rand(5)

Expand Down
6 changes: 4 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[tox]
envlist = {py27,py34,py35,py36}-tensorflow{13,14,15,16,17,18,19,110,111,112,113}
envlist = {py34,py35,py36}-tensorflow{13,14,15,16,17,18,19,110,111,112,113,200}


[testenv]
deps =
Expand All @@ -15,6 +16,7 @@ deps =
tensorflow111: tensorflow>=1.11, <1.12
tensorflow112: tensorflow>=1.12, <1.13
tensorflow113: tensorflow<=1.13, <1.14
tensorflow200: tensorflow<=2.0, <2.1

commands =
pytest
Expand All @@ -23,4 +25,4 @@ alwayscopy = True

[testenv:py36]
commands =
flake8
flake8