Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Unreleased

- [added] The `db.Reference` type now provides a `listen()` API for
receiving realtime update events from the Firebase Database.
- [added] The `db.reference()` method now optionally takes a `url`
parameter. This can be used to access multiple Firebase Databases
in the same project more easily.
Expand Down
141 changes: 58 additions & 83 deletions firebase_admin/_sseclient.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -10,14 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""SSEClient module to handle streaming of realtime changes on the database
to the firebase-admin-sdk
"""
"""SSEClient module to stream realtime updates in the Firebase Database."""

import re
import time
import warnings
import six

from google.auth import transport
import requests


Expand All @@ -26,80 +27,63 @@
end_of_field = re.compile(r'\r\n\r\n|\r\r|\n\n')


class KeepAuthSession(requests.Session):
"""A session that does not drop Authentication on redirects between domains"""
class KeepAuthSession(transport.requests.AuthorizedSession):
"""A session that does not drop authentication on redirects between domains."""

def __init__(self, credential):
super(KeepAuthSession, self).__init__(credential)

def rebuild_auth(self, prepared_request, response):
pass


class SSEClient(object):
"""SSE Client Class"""
"""SSE client implementation."""

def __init__(self, url, session, retry=3000, **kwargs):
"""Initializes the SSEClient.

def __init__(self, url, session, last_id=None, retry=3000, **kwargs):
"""Initialize the SSEClient
Args:
url: the url to connect to
session: the requests.session()
last_id: optional id
retry: the interval in ms
**kwargs: extra kwargs will be sent to requests.get
url: The remote url to connect to.
session: The requests session.
retry: The retry interval in milliseconds (optional).
**kwargs: Extra kwargs that will be sent to ``requests.get()`` (optional).
"""
self.should_connect = True
self.url = url
self.last_id = last_id
self.retry = retry
self.session = session
self.retry = retry
self.requests_kwargs = kwargs
self.should_connect = True
self.last_id = None
self.buf = u'' # Keep data here as it streams in

headers = self.requests_kwargs.get('headers', {})
# The SSE spec requires making requests with Cache-Control: nocache
headers['Cache-Control'] = 'no-cache'
# The 'Accept' header is not required, but explicit > implicit
headers['Accept'] = 'text/event-stream'

self.requests_kwargs['headers'] = headers

# Keep data here as it streams in
self.buf = u''

self._connect()

def close(self):
"""Close the SSE Client instance"""
# TODO: check if AttributeError is needed to catch here
"""Closes the SSEClient instance."""
self.should_connect = False
self.retry = 0
self.resp.close()
# self.resp.raw._fp.fp.raw._sock.shutdown(socket.SHUT_RDWR)
# self.resp.raw._fp.fp.raw._sock.close()


def _connect(self):
"""connects to the server using requests"""
"""Connects to the server using requests."""
if self.should_connect:
success = False
while not success:
if self.last_id:
self.requests_kwargs['headers']['Last-Event-ID'] = self.last_id
# Use session if set. Otherwise fall back to requests module.
self.requester = self.session or requests
self.resp = self.requester.get(self.url, stream=True, **self.requests_kwargs)

self.resp_iterator = self.resp.iter_content(decode_unicode=True)

# TODO: Ensure we're handling redirects. Might also stick the 'origin'
# attribute on Events like the Javascript spec requires.
self.resp.raise_for_status()
success = True
if self.last_id:
self.requests_kwargs['headers']['Last-Event-ID'] = self.last_id
self.resp = self.session.get(self.url, stream=True, **self.requests_kwargs)
self.resp_iterator = self.resp.iter_content(decode_unicode=True)
self.resp.raise_for_status()
else:
raise StopIteration()

def _event_complete(self):
"""Checks if the event is completed by matching regular expression

Returns:
boolean: True if the regex matched meaning end of event, else False
"""
"""Checks if the event is completed by matching regular expression."""
return re.search(end_of_field, self.buf) is not None

def __iter__(self):
Expand All @@ -113,8 +97,6 @@ def __next__(self):
except (StopIteration, requests.RequestException):
time.sleep(self.retry / 1000.0)
self._connect()


# The SSE spec only supports resuming from a whole message, so
# if we have half a message we should throw it out.
head, sep, tail = self.buf.rpartition('\n')
Expand All @@ -123,56 +105,54 @@ def __next__(self):

split = re.split(end_of_field, self.buf)
head = split[0]
tail = "".join(split[1:])
tail = ''.join(split[1:])

self.buf = tail
msg = Event.parse(head)
event = Event.parse(head)

if msg.data == "credential is no longer valid":
if event.data == 'credential is no longer valid':
self._connect()
return None

if msg.data == 'null':
elif event.data == 'null':
return None

# If the server requests a specific retry delay, we need to honor it.
if msg.retry:
self.retry = msg.retry
if event.retry:
self.retry = event.retry

# last_id should only be set if included in the message. It's not
# forgotten if a message omits it.
if msg.event_id:
self.last_id = msg.event_id

return msg
if event.event_id:
self.last_id = event.event_id
return event

if six.PY2:
next = __next__
def next(self):
return self.__next__()


class Event(object):
"""Event class to handle the events fired by SSE"""
"""Event represents the events fired by SSE."""

sse_line_pattern = re.compile('(?P<name>[^:]*):?( ?(?P<value>.*))?')

def __init__(self, data='', event='message', event_id=None, retry=None):
def __init__(self, data='', event_type='message', event_id=None, retry=None):
self.data = data
self.event = event
self.event_type = event_type
self.event_id = event_id
self.retry = retry

@classmethod
def parse(cls, raw):
"""Given a possibly-multiline string representing an SSE message, parse it
and return a Event object.
"""Given a possibly-multiline string representing an SSE message, parses it
and returns an Event object.

Args:
raw: the raw data to parse
raw: the raw data to parse.

Returns:
Event: newly intialized Event() object with the parameters initialized
Event: newly intialized ``Event`` object with the parameters initialized.
"""
msg = cls()
event = cls()
for line in raw.split('\n'):
match = cls.sse_line_pattern.match(line)
if match is None:
Expand All @@ -185,22 +165,17 @@ def parse(cls, raw):
if name == '':
# line began with a ":", so is a comment. Ignore
continue

if name == 'data':
elif name == 'data':
# If we already have some data, then join to it with a newline.
# Else this is it.
if msg.data:
msg.data = '%s\n%s' % (msg.data, value)
if event.data:
event.data = '%s\n%s' % (event.data, value)
else:
msg.data = value
event.data = value
elif name == 'event':
msg.event = value
event.event_type = value
elif name == 'id':
msg.event_id = value
event.event_id = value
elif name == 'retry':
msg.retry = int(value)

return msg

def __str__(self):
return self.data
event.retry = int(value)
return event
Loading