From f3fd0b1a84859c998168cea900ac335aae5374bb Mon Sep 17 00:00:00 2001 From: Fabio Bonelli Date: Wed, 4 Dec 2019 12:09:14 +0100 Subject: [PATCH] Refactor the SP metadata loading logic. The SP metadata gets loaded once at the start from all sources to index the entity ids and where they come from. Upon every request from an SP, the metadata of that SP gets reloaded from the last known source: if it's not there anymore, the server tries to look in all known sources again. Protip! Using: ``` metadata: local: - conf/*.xml ``` Allows dropping config files into the directory at runtime and having their configuration picked up without restarting the server. Other goodies: * Invalid SP metadata are logged and discarded, but no longer prevent the rest of the metadata from loading. (Fix #210) * Duplicate entityIds are discarded with precedence given to local sources (local > db > remote) (Fix #205). * Display the source of the SP metadata in the UI --- templates/home.html | 9 +- testenv/server.py | 15 +- testenv/spmetadata.py | 305 ++++++++++++++++++++++------- testenv/tests/test_spid_testenv.py | 2 +- testenv/tests/test_validators.py | 2 +- testenv/validators.py | 2 +- 6 files changed, 251 insertions(+), 84 deletions(-) diff --git a/templates/home.html b/templates/home.html index 7d3e5238..c4d134fb 100644 --- a/templates/home.html +++ b/templates/home.html @@ -10,11 +10,14 @@

Service Provider configurati {% for item in sp_list %} - - {{ item['entityID'] }} + + {{ item['entityID'] }} + + loaded from {{ item['location'] }} + {% endfor %} -{% endblock %} \ No newline at end of file +{% endblock %} diff --git a/testenv/server.py b/testenv/server.py index bd93e00a..9813ba22 100644 --- a/testenv/server.py +++ b/testenv/server.py @@ -284,7 +284,7 @@ def _handle_http_post(self, action): def _get_certificates_by_issuer(self, issuer): try: - return self._registry.get(issuer).certs() + return self._registry.load(issuer).certs() except KeyError: self._raise_error( 'entity ID {} non registrato, impossibile ricavare' @@ -357,7 +357,7 @@ def users(self): 'primary_attributes': spid_main_fields, 'secondary_attributes': spid_secondary_fields, 'users': self.user_manager.all(), - 'sp_list': self._registry.all(), + 'sp_list': self._registry.load_all().keys(), 'can_add_user': can_add_user } ) @@ -393,8 +393,9 @@ def index(self): **{ 'sp_list': [ { - "entityID": sp - } for sp in self._registry.all() + "entityID": entity_id, + "location": sp_metadata.location, + } for (entity_id, sp_metadata) in self._registry.load_all().items() ], } ) @@ -405,7 +406,7 @@ def get_destination(self, req, sp_id): acs_index = getattr(req, 'assertion_consumer_service_index', None) protocol_binding = getattr(req, 'protocol_binding', None) if acs_index is not None: - acss = self._registry.get( + acss = self._registry.load( sp_id).assertion_consumer_service(index=acs_index) if acss: destination = acss[0].get('Location') @@ -590,7 +591,7 @@ def login(self): atcs_idx ) ) - sp_metadata = self._registry.get(sp_id) + sp_metadata = self._registry.load(sp_id) required = [] optional = [] if atcs_idx and sp_metadata: @@ -791,7 +792,7 @@ def continue_response(self): def _sp_single_logout_service(self, issuer_name): _slo = None try: - _slo = self._registry.get(issuer_name).single_logout_services[0] + _slo = self._registry.load(issuer_name).single_logout_services[0] except Exception: pass return 
_slo diff --git a/testenv/spmetadata.py b/testenv/spmetadata.py index fb33f79d..925316c9 100644 --- a/testenv/spmetadata.py +++ b/testenv/spmetadata.py @@ -2,6 +2,7 @@ from itertools import chain import requests +from lxml.etree import LxmlError from testenv import config, log from testenv.exceptions import DeserializationError, MetadataLoadError, MetadataNotFoundError, ValidationError @@ -31,37 +32,117 @@ class ServiceProviderMetadataRegistry: def __init__(self): - self._loaders = [] - for source_type, source_params in list(config.params.metadata.items()): - self._loaders.append({ - 'local': ServiceProviderMetadataFileLoader, - 'remote': ServiceProviderMetadataHTTPLoader, - 'db': ServiceProviderMetadataDbLoader, - }[source_type](source_params)) self._validators = ValidatorGroup([ XMLMetadataFormatValidator(), ServiceProviderMetadataXMLSchemaValidator(), SpidMetadataValidator(), ]) + self._index_metadata() + + def load(self, entity_id): + """ + Loads the metadata of a Service Provider. + + Args: + entity_id (str): Entity id of the SP (usually a URL or a URN). - def get(self, entity_id): + Returns: + A ServiceProviderMetadata instance. + + Raises + MetadataNotFoundError: If there is no metadata associated to + the entity id. + DeserializationError: If the metadata associated to the entity id + is not valid. + """ entity_id = entity_id.strip() - for loader in self._loaders: + + fresh_metadata = None + + metadata = self._metadata.get(entity_id, None) + if not metadata: + # Try to reload all sources to see if the unknown entity id was added there + # somewhere. + logger.debug( + "Unknown entityId '{}`, reloading all the sources.".format(entity_id) + ) + self._index_metadata() + else: + # We got an known entity id, try to load its metadata the previously known + # location. 
try: - metadata = loader.get(entity_id) - try: - self._validators.validate(metadata.xml) - return metadata - except ValidationError as e: - raise DeserializationError(metadata.xml, e.details) - except MetadataNotFoundError: + fresh_metadata = metadata.loader.load(metadata.location) + if fresh_metadata.entity_id != entity_id: + raise MetadataLoadError + except MetadataLoadError as e: + logger.debug( + ("{}\n" + "Cannot find entityId '{}` at its previous location '{}`" + "reloading all the sources").format(e, entity_id, metadata.location) + ) + self._index_metadata() + + if not fresh_metadata: + try: + metadata = self._metadata[entity_id] + fresh_metadata = metadata.loader.load(metadata.location) + except (KeyError, MetadataLoadError): + raise MetadataNotFoundError(entity_id) + + if metadata.entity_id != entity_id: + raise MetadataNotFoundError(entity_id) + try: + self._validators.validate(fresh_metadata.xml) + except ValidationError as e: + raise DeserializationError(fresh_metadata.xml, e.details) + + return fresh_metadata + + def load_all(self): + """ + Returns a dict containing all ServerProviderMetadata loaded, + indexed by entityId. + """ + self._index_metadata() + + return self._metadata + + def _index_metadata(self): + """ + Populate self._metadata with the up to date information from all the + configured SP metadata. + """ + + # dict of { entity_id: ServiceProviderMetadata } + self._metadata = {} + + # Possible sources of metadata, ordered by preference + # (ie. the first source will be preferred in case of duplicate + # entity ids). 
+ SOURCE_TYPES = ['local', 'db', 'remote'] + + for source_type in reversed(SOURCE_TYPES): + if source_type not in config.params.metadata: continue - raise MetadataNotFoundError(entity_id) + source_params = config.params.metadata[source_type] + + loader = { + 'local': ServiceProviderMetadataFileLoader, + 'remote': ServiceProviderMetadataHTTPLoader, + 'db': ServiceProviderMetadataDbLoader, + }[source_type](source_params) + + metadata = loader.load_all() + for dup in set(metadata.keys()).intersection(set(self._metadata)): + logger.info( + "Discarding duplicate entity_id `{}' from '{}`.".format( + dup, + self._metadata[dup].location + ) + ) - def all(self): - """Returns the list of entityIDs of all the known Service Providers""" - return [i for loader in self._loaders for i in loader.all()] + self._metadata.update(metadata) registry = None @@ -72,62 +153,126 @@ def build_metadata_registry(): registry = ServiceProviderMetadataRegistry() -class ServiceProviderMetadataFileLoader: - """Loads metadata from the configured files +class LoadAllMixin(object): + def load_all(self): + """ + Loads all the available SP metadata, skipping duplicates. - This could be improved automatically reloading the metadata when - file timestamps change - """ + Returns: + A dict containing all local ServerProviderMetadata loaded, + indexed by entityId. 
+ """ + metadata = None + ret = {} - def __init__(self, conf): - self._metadata = {} - - files = [glob(entry) for entry in conf] - for file in list(chain.from_iterable(files)): + for location in self._locations: try: - with open(file, 'rb') as fp: - metadata = ServiceProviderMetadata(fp.read()) - self._metadata[metadata.entity_id] = metadata - logger.debug("Loaded metadata for: " + metadata.entity_id) - except Exception as e: - raise MetadataLoadError( - "Impossibile leggere il file '{}': '{}'".format(file, e) + metadata = self.load(location) + except MetadataLoadError as e: + logger.info( + "Skipping '{}` because of a load error: {}".format(location, e) ) + continue + + if metadata.entity_id in ret: + logger.info( + "Discarding duplicate entity_id `{}' from '{}`.".format( + metadata.entity_id, + metadata.location + ) + ) + continue + + ret[metadata.entity_id] = metadata - def get(self, entity_id): + return ret + + +class ServiceProviderMetadataFileLoader(LoadAllMixin, object): + """ + Loads SP metadata from a list of files. + + Args: + locations (list of str): List of paths to load. Paths can also contain + globbing metacharacters. + """ + + def __init__(self, locations): + files = [glob(entry) for entry in locations] + + self._locations = list(chain.from_iterable(files)) + + def load(self, location): + """ + Loads the SP metadata from file. + + Args: + location (str): The path of file. + + Returns: + A ServiceProviderMetadata instance. + + Raises: + MetadataLoadError: If the load fails. 
+ """ try: - return self._metadata[entity_id] - except KeyError: - raise MetadataNotFoundError(entity_id) + with open(location, 'rb') as fp: + metadata = ServiceProviderMetadata(fp.read(), self, location) + except (IOError, LxmlError) as e: + raise MetadataLoadError( + "Failed to load '{}': '{}'".format(location, e) + ) + logger.debug( + "Loaded metadata for '{}` from '{}`".format( + metadata.entity_id, + location + ) + ) + return metadata - def all(self): - return list(self._metadata.keys()) +class ServiceProviderMetadataHTTPLoader(LoadAllMixin, object): + """ + Loads SP metadata from a list of HTTP URLs. -class ServiceProviderMetadataHTTPLoader: - """Loads metadata from the configured URLs""" + Args: + urls (list of str): List of HTTP URLs to load. + """ - def __init__(self, conf): - self._metadata = {} - for url in conf: - try: - response = requests.get(url) - response.raise_for_status() - metadata = ServiceProviderMetadata(response.content) - self._metadata[metadata.entity_id] = metadata - except Exception as e: - raise MetadataLoadError( - "La richiesta all'endpoint HTTP '{}': '{}'".format(url, e) - ) + def __init__(self, locations): + self._locations = locations + + def load(self, location): + """ + Loads the SP metadata from HTTP. + + Args: + location (str): The URL of the metadata to load. + + Returns: + A ServiceProviderMetadata instance. + + Raises: + MetadataLoadError: If the load fails. 
+ """ - def get(self, entity_id): try: - return self._metadata[entity_id] - except KeyError: - raise MetadataNotFoundError(entity_id) + response = requests.get(location) + response.raise_for_status() + metadata = ServiceProviderMetadata(response.content, self, location) + except Exception as e: + raise MetadataLoadError( + "Request to HTTP endpoint '{}': '{}'".format(location, e) + ) - def all(self): - return list(self._metadata.keys()) + logger.debug( + "Loaded metadata for '{}` from '{}`".format( + metadata.entity_id, + location + ) + ) + + return metadata class ServiceProviderMetadataDbLoader: @@ -136,20 +281,38 @@ class ServiceProviderMetadataDbLoader: def __init__(self, conf): self._provider = DatabaseSPProvider(conf) - def get(self, entity_id): + def load(self, entity_id): metadata = self._provider.get(entity_id) if metadata is None: raise MetadataNotFoundError(entity_id) - return ServiceProviderMetadata(metadata) - - def all(self): - return list(self._provider.all().keys()) - + return ServiceProviderMetadata(metadata, self, 'db') + + def load_all(self): + """ + Returns a dict containing all 'db' ServerProviderMetadata loaded, + indexed by entityId.""" + return { + entity_id: ServiceProviderMetadata(xml, self, 'db') + for (entity_id, xml) in self._provider.all().items() + } -class ServiceProviderMetadata: - def __init__(self, xml): +class ServiceProviderMetadata(object): + """ + Object representing the metadata of a Service Provider. + + Args: + xml (str): The metadata as XML. + loader (instance of ServiceProviderMetadata{File,HTTP,Db}Loader): The loader the + metadata was loaded with. + location (str): The source the metadata was loaded from. + It's a path for 'local' metadata, a URL for 'remote' and + the string 'db' for 'db'. 
+ """ + def __init__(self, xml, loader, location): self.xml = xml + self.loader = loader + self.location = location self._metadata = saml_to_dict(xml) @property diff --git a/testenv/tests/test_spid_testenv.py b/testenv/tests/test_spid_testenv.py index d4e3e1c3..11a6d0fb 100644 --- a/testenv/tests/test_spid_testenv.py +++ b/testenv/tests/test_spid_testenv.py @@ -29,7 +29,7 @@ def _sp_single_logout_service(server, issuer_name, binding): - _slo = server._registry.get(issuer_name).single_logout_service( + _slo = server._registry.load(issuer_name).single_logout_service( binding=binding ) return _slo[0] diff --git a/testenv/tests/test_validators.py b/testenv/tests/test_validators.py index 543a5cbc..1ecf9ecf 100644 --- a/testenv/tests/test_validators.py +++ b/testenv/tests/test_validators.py @@ -68,7 +68,7 @@ class FakeRegistry: def __init__(self, metadata): self._metadata = metadata.copy() - def get(self, entity_id): + def load(self, entity_id): return self._metadata.get(entity_id) @property diff --git a/testenv/validators.py b/testenv/validators.py index c24a58f8..5cfd2362 100644 --- a/testenv/validators.py +++ b/testenv/validators.py @@ -499,7 +499,7 @@ def validate(self, request): 'Issuer non presente nella {}'.format(req_type) ) try: - sp_metadata = self._registry.get(issuer_name) + sp_metadata = self._registry.load(issuer_name) except MetadataNotFoundError: raise UnknownEntityIDError( 'L\'entity ID "{}" indicato nell\'elemento non corrisponde a nessun Service Provider registrato in questo Identity Provider di test.'.format(issuer_name)