Skip to content

Commit

Permalink
Merge pull request #894 from REANNZ/fix-ed-extensions
Browse files Browse the repository at this point in the history
Fix: render extensions also for EntityDescriptor and IdPSSODescriptor
  • Loading branch information
c00kiemon5ter authored Jan 31, 2023
2 parents 01f5567 + 30243a8 commit aa0de7c
Showing 1 changed file with 55 additions and 0 deletions.
55 changes: 55 additions & 0 deletions src/saml2/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,17 @@ def do_idpsso_descriptor(conf, cert=None, enc_cert=None):
idpsso = md.IDPSSODescriptor()
idpsso.protocol_support_enumeration = samlp.NAMESPACE

exts = conf.getattr("extensions", "idp")
if exts:
if idpsso.extensions is None:
idpsso.extensions = md.Extensions()

for key, val in exts.items():
_ext = do_extensions(key, val)
if _ext:
for _e in _ext:
idpsso.extensions.add_extension_element(_e)

endps = conf.getattr("endpoints", "idp")
if endps:
for (endpoint, instlist) in do_endpoints(endps, ENDPOINTS["idp"]).items():
Expand Down Expand Up @@ -578,6 +589,17 @@ def do_aa_descriptor(conf, cert=None, enc_cert=None):
aad = md.AttributeAuthorityDescriptor()
aad.protocol_support_enumeration = samlp.NAMESPACE

exts = conf.getattr("extensions", "aa")
if exts:
if aad.extensions is None:
aad.extensions = md.Extensions()

for key, val in exts.items():
_ext = do_extensions(key, val)
if _ext:
for _e in _ext:
aad.extensions.add_extension_element(_e)

endps = conf.getattr("endpoints", "aa")

if endps:
Expand Down Expand Up @@ -606,6 +628,17 @@ def do_aq_descriptor(conf, cert=None, enc_cert=None):
aqs = md.AuthnAuthorityDescriptor()
aqs.protocol_support_enumeration = samlp.NAMESPACE

exts = conf.getattr("extensions", "aa")
if exts:
if aqs.extensions is None:
aqs.extensions = md.Extensions()

for key, val in exts.items():
_ext = do_extensions(key, val)
if _ext:
for _e in _ext:
aqs.extensions.add_extension_element(_e)

endps = conf.getattr("endpoints", "aq")

if endps:
Expand All @@ -626,6 +659,17 @@ def do_pdp_descriptor(conf, cert=None, enc_cert=None):

pdp.protocol_support_enumeration = samlp.NAMESPACE

exts = conf.getattr("extensions", "pdp")
if exts:
if pdp.extensions is None:
pdp.extensions = md.Extensions()

for key, val in exts.items():
_ext = do_extensions(key, val)
if _ext:
for _e in _ext:
pdp.extensions.add_extension_element(_e)

endps = conf.getattr("endpoints", "pdp")

if endps:
Expand Down Expand Up @@ -675,6 +719,17 @@ def entity_descriptor(confd):
if confd.contact_person is not None:
entd.contact_person = do_contact_persons_info(confd.contact_person)

exts = confd.extensions
if exts:
if not entd.extensions:
entd.extensions = md.Extensions()

for key, val in exts.items():
_ext = do_extensions(key, val)
if _ext:
for _e in _ext:
entd.extensions.add_extension_element(_e)

if confd.entity_attributes:
if not entd.extensions:
entd.extensions = md.Extensions()
Expand Down

0 comments on commit aa0de7c

Please sign in to comment.