Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions lightning/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def __init__(self, host="http://localhost:3000", ipython=False, dbcloud=False, a

if ipython:
self.enable_ipython()
else:
self.ipython_enabled = False

if dbcloud:
self.enable_dbcloud()
Expand All @@ -26,6 +28,11 @@ def __repr__(self):
else:
return 'Lightning server at host: %s' % self.host


def get_ipython_markup_link(self):
return '%s/js/ipython-comm.js' % (self.host)


def enable_ipython(self, **kwargs):
"""
Enable plotting in the iPython notebook.
Expand All @@ -39,10 +46,19 @@ def enable_ipython(self, **kwargs):
# https://github.com/jakevdp/mpld3/blob/master/mpld3/_display.py#L357

from IPython.core.getipython import get_ipython
from IPython.display import display, HTML

self.ipython_enabled = True
ip = get_ipython()
formatter = ip.display_formatter.formatters['text/html']
formatter.for_type(Visualization, lambda viz, kwds=kwargs: viz.get_html())

r = requests.get(self.get_ipython_markup_link(), auth=self.auth)
ipython_comm_markup = '<script>' + r.text + '</script>'

display(HTML(ipython_comm_markup))


def disable_ipython(self):
"""
Disable plotting in the iPython notebook.
Expand All @@ -51,6 +67,8 @@ def disable_ipython(self):
but will not appear in the notebook.
"""
from IPython.core.getipython import get_ipython

self.ipython_enabled = False
ip = get_ipython()
formatter = ip.display_formatter.formatters['text/html']
formatter.type_printers.pop(Visualization, None)
Expand Down Expand Up @@ -82,7 +100,7 @@ def create_session(self, name=None):
Can create a session with the provided name, otherwise session name
will be "Session No." with the number automatically generated.
"""
self.session = Session.create(self.host, name=name, auth=self.auth)
self.session = Session.create(self, name=name)
return self.session

def use_session(self, session_id):
Expand All @@ -92,7 +110,7 @@ def use_session(self, session_id):
Specify a lightning session by id number. Check the number of an existing
session in the attribute lightning.session.id.
"""
self.session = Session(host=self.host, id=session_id, auth=self.auth)
self.session = Session(lgn=self, id=session_id)
return self.session

def set_basic_auth(self, username, password):
Expand Down
15 changes: 8 additions & 7 deletions lightning/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ class Session(object):
name = None
visualizations = []

def __init__(self, host=None, id=None, json=None, auth=None):
self.host = host
def __init__(self, lgn=None, id=None, json=None):
self.lgn = lgn
self.host = lgn.host
self.auth = lgn.auth
self.id = id
self.auth = auth

if json:
self.id = json.get('id')
Expand All @@ -36,14 +37,14 @@ def open(self):
webbrowser.open(self.host + '/sessions/' + str(self.id) + '/feed/')

@classmethod
def create(cls, host, name=None, auth=None):
url = host + '/sessions/'
def create(cls, lgn, name=None):
url = lgn.host + '/sessions/'

payload = {}
if name:
payload = {'name': name}

headers = {'Content-type': 'application/json', 'Accept': 'text/plain'}

r = requests.post(url, data=json.dumps(payload), headers=headers, auth=auth)
return cls(host=host, json=r.json(), auth=auth)
r = requests.post(url, data=json.dumps(payload), headers=headers, auth=lgn.auth)
return cls(lgn=lgn, json=r.json())
25 changes: 25 additions & 0 deletions lightning/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ def __init__(self, session=None, json=None, auth=None):
self.id = json.get('id')
self.auth = auth

if self.session.lgn.ipython_enabled:
from IPython.kernel.comm import Comm
self.comm = Comm('lightning', {'id': self.id})
self.comm_handlers = {}
self.comm.on_msg(self._handle_comm_message)


def _format_url(self, url):
if not url.endswith('/'):
url += '/'
Expand Down Expand Up @@ -68,6 +75,24 @@ def delete(self):
url = self.get_permalink()
return requests.delete(url)


def on(self, event_name, handler):

if self.session.lgn.ipython_enabled:
self.comm_handlers[event_name] = handler

else:
raise Exception('The current implementation of this method is only compatible with IPython.')


def _handle_comm_message(self, message):
# Parsing logic taken from similar code in matplotlib
message = json.loads(message['content']['data'])

if message['type'] in self.comm_handlers:
self.comm_handlers[message['type']](message['data'])


@classmethod
def create(cls, session=None, data=None, images=None, type=None, options=None):

Expand Down