Skip to content

Commit c6cf374

Browse files
Asynchronous metadata fetching using celery beat - PR 8518
1 parent 39b56cd commit c6cf374

File tree

8 files changed

+204
-138
lines changed

8 files changed

+204
-138
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ logs
9191
chromedriver.log
9292
ghostdriver.log
9393

94+
### Celery artifacts ###
95+
celerybeat-schedule
96+
9497
### Unknown artifacts
9598
database.sqlite
9699
courseware/static/js/mathjax/*

common/djangoapps/third_party_auth/admin.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from config_models.admin import ConfigurationModelAdmin, KeyedConfigurationModelAdmin
99
from .models import OAuth2ProviderConfig, SAMLProviderConfig, SAMLConfiguration, SAMLProviderData
10+
from .tasks import fetch_saml_metadata
1011

1112
admin.site.register(OAuth2ProviderConfig, KeyedConfigurationModelAdmin)
1213

@@ -29,6 +30,17 @@ def has_data(self, inst):
2930
has_data.short_description = u'Metadata Ready'
3031
has_data.boolean = True
3132

33+
def save_model(self, request, obj, form, change):
34+
"""
35+
Post save: Queue an asynchronous metadata fetch to update SAMLProviderData.
36+
We only want to do this for manual edits done using the admin interface.
37+
38+
Note: This only works if the celery worker and the app worker are using the
39+
same 'configuration' cache.
40+
"""
41+
super(SAMLProviderConfigAdmin, self).save_model(request, obj, form, change)
42+
fetch_saml_metadata.apply_async((), countdown=2)
43+
3244
admin.site.register(SAMLProviderConfig, SAMLProviderConfigAdmin)
3345

3446

@@ -54,7 +66,7 @@ def key_summary(self, inst):
5466

5567

5668
class SAMLProviderDataAdmin(admin.ModelAdmin):
57-
""" Django Admin class for SAMLProviderData """
69+
""" Django Admin class for SAMLProviderData (Read Only) """
5870
list_display = ('entity_id', 'is_valid', 'fetched_at', 'expires_at', 'sso_url')
5971
readonly_fields = ('is_valid', )
6072

common/djangoapps/third_party_auth/management/commands/saml.py

Lines changed: 15 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,10 @@
22
"""
33
Management commands for third_party_auth
44
"""
5-
import datetime
6-
import dateutil.parser
75
from django.core.management.base import BaseCommand, CommandError
8-
from lxml import etree
9-
import requests
10-
from onelogin.saml2.utils import OneLogin_Saml2_Utils
11-
from third_party_auth.models import SAMLConfiguration, SAMLProviderConfig, SAMLProviderData
12-
13-
#pylint: disable=superfluous-parens,no-member
14-
15-
16-
class MetadataParseError(Exception):
17-
""" An error occurred while parsing the SAML metadata from an IdP """
18-
pass
6+
import logging
7+
from third_party_auth.models import SAMLConfiguration
8+
from third_party_auth.tasks import fetch_saml_metadata
199

2010

2111
class Command(BaseCommand):
@@ -27,120 +17,21 @@ def handle(self, *args, **options):
2717
raise CommandError("saml requires one argument: pull")
2818

2919
if not SAMLConfiguration.is_enabled():
30-
self.stdout.write("Warning: SAML support is disabled via SAMLConfiguration.\n")
20+
raise CommandError("SAML support is disabled via SAMLConfiguration.")
3121

3222
subcommand = args[0]
3323

3424
if subcommand == "pull":
35-
self.cmd_pull()
25+
log_handler = logging.StreamHandler(self.stdout)
26+
log_handler.setLevel(logging.DEBUG)
27+
log = logging.getLogger('third_party_auth.tasks')
28+
log.propagate = False
29+
log.addHandler(log_handler)
30+
num_changed, num_failed, num_total = fetch_saml_metadata()
31+
self.stdout.write(
32+
"\nDone. Fetched {num_total} total. {num_changed} were updated and {num_failed} failed.\n".format(
33+
num_changed=num_changed, num_failed=num_failed, num_total=num_total
34+
)
35+
)
3636
else:
3737
raise CommandError("Unknown argment: {}".format(subcommand))
38-
39-
@staticmethod
40-
def tag_name(tag_name):
41-
""" Get the namespaced-qualified name for an XML tag """
42-
return '{urn:oasis:names:tc:SAML:2.0:metadata}' + tag_name
43-
44-
def cmd_pull(self):
45-
""" Fetch the metadata for each provider and update the DB """
46-
# First make a list of all the metadata XML URLs:
47-
url_map = {}
48-
for idp_slug in SAMLProviderConfig.key_values('idp_slug', flat=True):
49-
config = SAMLProviderConfig.current(idp_slug)
50-
if not config.enabled:
51-
continue
52-
url = config.metadata_source
53-
if url not in url_map:
54-
url_map[url] = []
55-
if config.entity_id not in url_map[url]:
56-
url_map[url].append(config.entity_id)
57-
# Now fetch the metadata:
58-
for url, entity_ids in url_map.items():
59-
try:
60-
self.stdout.write("\n→ Fetching {}\n".format(url))
61-
if not url.lower().startswith('https'):
62-
self.stdout.write("→ WARNING: This URL is not secure! It should use HTTPS.\n")
63-
response = requests.get(url, verify=True) # May raise HTTPError or SSLError or ConnectionError
64-
response.raise_for_status() # May raise an HTTPError
65-
66-
try:
67-
parser = etree.XMLParser(remove_comments=True)
68-
xml = etree.fromstring(response.text, parser)
69-
except etree.XMLSyntaxError:
70-
raise
71-
# TODO: Can use OneLogin_Saml2_Utils to validate signed XML if anyone is using that
72-
73-
for entity_id in entity_ids:
74-
self.stdout.write("→ Processing IdP with entityID {}\n".format(entity_id))
75-
public_key, sso_url, expires_at = self._parse_metadata_xml(xml, entity_id)
76-
self._update_data(entity_id, public_key, sso_url, expires_at)
77-
except Exception as err: # pylint: disable=broad-except
78-
self.stderr.write(u"→ ERROR: {}\n\n".format(err.message))
79-
80-
@classmethod
81-
def _parse_metadata_xml(cls, xml, entity_id):
82-
"""
83-
Given an XML document containing SAML 2.0 metadata, parse it and return a tuple of
84-
(public_key, sso_url, expires_at) for the specified entityID.
85-
86-
Raises MetadataParseError if anything is wrong.
87-
"""
88-
if xml.tag == cls.tag_name('EntityDescriptor'):
89-
entity_desc = xml
90-
else:
91-
if xml.tag != cls.tag_name('EntitiesDescriptor'):
92-
raise MetadataParseError("Expected root element to be <EntitiesDescriptor>, not {}".format(xml.tag))
93-
entity_desc = xml.find(".//{}[@entityID='{}']".format(cls.tag_name('EntityDescriptor'), entity_id))
94-
if not entity_desc:
95-
raise MetadataParseError("Can't find EntityDescriptor for entityID {}".format(entity_id))
96-
97-
expires_at = None
98-
if "validUntil" in xml.attrib:
99-
expires_at = dateutil.parser.parse(xml.attrib["validUntil"])
100-
if "cacheDuration" in xml.attrib:
101-
cache_expires = OneLogin_Saml2_Utils.parse_duration(xml.attrib["cacheDuration"])
102-
if expires_at is None or cache_expires < expires_at:
103-
expires_at = cache_expires
104-
105-
sso_desc = entity_desc.find(cls.tag_name("IDPSSODescriptor"))
106-
if not sso_desc:
107-
raise MetadataParseError("IDPSSODescriptor missing")
108-
if 'urn:oasis:names:tc:SAML:2.0:protocol' not in sso_desc.get("protocolSupportEnumeration"):
109-
raise MetadataParseError("This IdP does not support SAML 2.0")
110-
111-
# Now we just need to get the public_key and sso_url
112-
public_key = sso_desc.findtext("./{}//{}".format(
113-
cls.tag_name("KeyDescriptor"), "{http://www.w3.org/2000/09/xmldsig#}X509Certificate"
114-
))
115-
if not public_key:
116-
raise MetadataParseError("Public Key missing. Expected an <X509Certificate>")
117-
public_key = public_key.replace(" ", "")
118-
binding_elements = sso_desc.iterfind("./{}".format(cls.tag_name("SingleSignOnService")))
119-
sso_bindings = {element.get('Binding'): element.get('Location') for element in binding_elements}
120-
try:
121-
# The only binding supported by python-saml and python-social-auth is HTTP-Redirect:
122-
sso_url = sso_bindings['urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect']
123-
except KeyError:
124-
raise MetadataParseError("Unable to find SSO URL with HTTP-Redirect binding.")
125-
return public_key, sso_url, expires_at
126-
127-
def _update_data(self, entity_id, public_key, sso_url, expires_at):
128-
"""
129-
Update/Create the SAMLProviderData for the given entity ID.
130-
"""
131-
data_obj = SAMLProviderData.current(entity_id)
132-
fetched_at = datetime.datetime.now()
133-
if data_obj and (data_obj.public_key == public_key and data_obj.sso_url == sso_url):
134-
data_obj.expires_at = expires_at
135-
data_obj.fetched_at = fetched_at
136-
data_obj.save()
137-
self.stdout.write("→ Updated existing SAMLProviderData. Nothing has changed.\n")
138-
else:
139-
SAMLProviderData.objects.create(
140-
entity_id=entity_id,
141-
fetched_at=fetched_at,
142-
expires_at=expires_at,
143-
sso_url=sso_url,
144-
public_key=public_key,
145-
)
146-
self.stdout.write("→ Created new record for SAMLProviderData\n")
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Code to manage fetching and storing the metadata of IdPs.
4+
"""
5+
#pylint: disable=no-member
6+
from celery.task import task # pylint: disable=import-error,no-name-in-module
7+
import datetime
8+
import dateutil.parser
9+
import logging
10+
from lxml import etree
11+
import requests
12+
from onelogin.saml2.utils import OneLogin_Saml2_Utils
13+
from third_party_auth.models import SAMLConfiguration, SAMLProviderConfig, SAMLProviderData
14+
15+
log = logging.getLogger(__name__)
16+
17+
SAML_XML_NS = 'urn:oasis:names:tc:SAML:2.0:metadata' # The SAML Metadata XML namespace
18+
19+
20+
class MetadataParseError(Exception):
21+
""" An error occurred while parsing the SAML metadata from an IdP """
22+
pass
23+
24+
25+
@task(name='third_party_auth.fetch_saml_metadata')
26+
def fetch_saml_metadata():
27+
"""
28+
Fetch and store/update the metadata of all IdPs
29+
30+
This task should be run on a daily basis.
31+
It's OK to run this whether or not SAML is enabled.
32+
33+
Return value:
34+
tuple(num_changed, num_failed, num_total)
35+
num_changed: Number of providers that are either new or whose metadata has changed
36+
num_failed: Number of providers that could not be updated
37+
num_total: Total number of providers whose metadata was fetched
38+
"""
39+
if not SAMLConfiguration.is_enabled():
40+
return (0, 0, 0) # Nothing to do until SAML is enabled.
41+
42+
num_changed, num_failed = 0, 0
43+
44+
# First make a list of all the metadata XML URLs:
45+
url_map = {}
46+
for idp_slug in SAMLProviderConfig.key_values('idp_slug', flat=True):
47+
config = SAMLProviderConfig.current(idp_slug)
48+
if not config.enabled:
49+
continue
50+
url = config.metadata_source
51+
if url not in url_map:
52+
url_map[url] = []
53+
if config.entity_id not in url_map[url]:
54+
url_map[url].append(config.entity_id)
55+
# Now fetch the metadata:
56+
for url, entity_ids in url_map.items():
57+
try:
58+
log.info("Fetching %s", url)
59+
if not url.lower().startswith('https'):
60+
log.warning("This SAML metadata URL is not secure! It should use HTTPS. (%s)", url)
61+
response = requests.get(url, verify=True) # May raise HTTPError or SSLError or ConnectionError
62+
response.raise_for_status() # May raise an HTTPError
63+
64+
try:
65+
parser = etree.XMLParser(remove_comments=True)
66+
xml = etree.fromstring(response.text, parser)
67+
except etree.XMLSyntaxError:
68+
raise
69+
# TODO: Can use OneLogin_Saml2_Utils to validate signed XML if anyone is using that
70+
71+
for entity_id in entity_ids:
72+
log.info(u"Processing IdP with entityID %s", entity_id)
73+
public_key, sso_url, expires_at = _parse_metadata_xml(xml, entity_id)
74+
changed = _update_data(entity_id, public_key, sso_url, expires_at)
75+
if changed:
76+
log.info(u"→ Created new record for SAMLProviderData")
77+
num_changed += 1
78+
else:
79+
log.info(u"→ Updated existing SAMLProviderData. Nothing has changed.")
80+
except Exception as err: # pylint: disable=broad-except
81+
log.exception(err.message)
82+
num_failed += 1
83+
return (num_changed, num_failed, len(url_map))
84+
85+
86+
def _parse_metadata_xml(xml, entity_id):
87+
"""
88+
Given an XML document containing SAML 2.0 metadata, parse it and return a tuple of
89+
(public_key, sso_url, expires_at) for the specified entityID.
90+
91+
Raises MetadataParseError if anything is wrong.
92+
"""
93+
if xml.tag == etree.QName(SAML_XML_NS, 'EntityDescriptor'):
94+
entity_desc = xml
95+
else:
96+
if xml.tag != etree.QName(SAML_XML_NS, 'EntitiesDescriptor'):
97+
raise MetadataParseError("Expected root element to be <EntitiesDescriptor>, not {}".format(xml.tag))
98+
entity_desc = xml.find(
99+
".//{}[@entityID='{}']".format(etree.QName(SAML_XML_NS, 'EntityDescriptor'), entity_id)
100+
)
101+
if not entity_desc:
102+
raise MetadataParseError("Can't find EntityDescriptor for entityID {}".format(entity_id))
103+
104+
expires_at = None
105+
if "validUntil" in xml.attrib:
106+
expires_at = dateutil.parser.parse(xml.attrib["validUntil"])
107+
if "cacheDuration" in xml.attrib:
108+
cache_expires = OneLogin_Saml2_Utils.parse_duration(xml.attrib["cacheDuration"])
109+
if expires_at is None or cache_expires < expires_at:
110+
expires_at = cache_expires
111+
112+
sso_desc = entity_desc.find(etree.QName(SAML_XML_NS, "IDPSSODescriptor"))
113+
if not sso_desc:
114+
raise MetadataParseError("IDPSSODescriptor missing")
115+
if 'urn:oasis:names:tc:SAML:2.0:protocol' not in sso_desc.get("protocolSupportEnumeration"):
116+
raise MetadataParseError("This IdP does not support SAML 2.0")
117+
118+
# Now we just need to get the public_key and sso_url
119+
public_key = sso_desc.findtext("./{}//{}".format(
120+
etree.QName(SAML_XML_NS, "KeyDescriptor"), "{http://www.w3.org/2000/09/xmldsig#}X509Certificate"
121+
))
122+
if not public_key:
123+
raise MetadataParseError("Public Key missing. Expected an <X509Certificate>")
124+
public_key = public_key.replace(" ", "")
125+
binding_elements = sso_desc.iterfind("./{}".format(etree.QName(SAML_XML_NS, "SingleSignOnService")))
126+
sso_bindings = {element.get('Binding'): element.get('Location') for element in binding_elements}
127+
try:
128+
# The only binding supported by python-saml and python-social-auth is HTTP-Redirect:
129+
sso_url = sso_bindings['urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect']
130+
except KeyError:
131+
raise MetadataParseError("Unable to find SSO URL with HTTP-Redirect binding.")
132+
return public_key, sso_url, expires_at
133+
134+
135+
def _update_data(entity_id, public_key, sso_url, expires_at):
136+
"""
137+
Update/Create the SAMLProviderData for the given entity ID.
138+
Return value:
139+
False if nothing has changed and existing data's "fetched at" timestamp is just updated.
140+
True if a new record was created. (Either this is a new provider or something changed.)
141+
"""
142+
data_obj = SAMLProviderData.current(entity_id)
143+
fetched_at = datetime.datetime.now()
144+
if data_obj and (data_obj.public_key == public_key and data_obj.sso_url == sso_url):
145+
data_obj.expires_at = expires_at
146+
data_obj.fetched_at = fetched_at
147+
data_obj.save()
148+
return False
149+
else:
150+
SAMLProviderData.objects.create(
151+
entity_id=entity_id,
152+
fetched_at=fetched_at,
153+
expires_at=expires_at,
154+
sso_url=sso_url,
155+
public_key=public_key,
156+
)
157+
return True

common/djangoapps/third_party_auth/tests/specs/test_testshib.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
"""
22
Third_party_auth integration tests using a mock version of the TestShib provider
33
"""
4-
from django.core.management import call_command
54
from django.core.urlresolvers import reverse
65
import httpretty
76
from mock import patch
8-
import StringIO
97
from student.tests.factories import UserFactory
8+
from third_party_auth.tasks import fetch_saml_metadata
109
from third_party_auth.tests import testutil
1110
import unittest
1211

@@ -209,15 +208,11 @@ def _configure_testshib_provider(self, **kwargs):
209208
self.configure_saml_provider(**kwargs)
210209

211210
if fetch_metadata:
212-
stdout = StringIO.StringIO()
213-
stderr = StringIO.StringIO()
214211
self.assertTrue(httpretty.is_enabled())
215-
call_command('saml', 'pull', stdout=stdout, stderr=stderr)
216-
stdout = stdout.getvalue().decode('utf-8')
217-
stderr = stderr.getvalue().decode('utf-8')
218-
self.assertEqual(stderr, '')
219-
self.assertIn(u'Fetching {}'.format(TESTSHIB_METADATA_URL), stdout)
220-
self.assertIn(u'Created new record for SAMLProviderData', stdout)
212+
num_changed, num_failed, num_total = fetch_saml_metadata()
213+
self.assertEqual(num_failed, 0)
214+
self.assertEqual(num_changed, 1)
215+
self.assertEqual(num_total, 1)
221216

222217
def _fake_testshib_login_and_return(self):
223218
""" Mocked: the user logs in to TestShib and then gets redirected back """

0 commit comments

Comments
 (0)