Skip to content

Commit

Permalink
Merge pull request #278 from onelogin/xmlparser
Browse files Browse the repository at this point in the history
See #221 and #267. Custom lxml parser based on the one defined in defusedxml
  • Loading branch information
pitbulk authored Jan 12, 2021
2 parents 910a288 + 64fbe77 commit 4a3efac
Show file tree
Hide file tree
Showing 10 changed files with 167 additions and 23 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
},
test_suite='tests',
install_requires=[
'lxml>=3.3.5',
'dm.xmlsec.binding==1.3.7',
'isodate>=0.5.0',
'defusedxml>=0.6.0',
Expand Down
12 changes: 6 additions & 6 deletions src/onelogin/saml2/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@

from base64 import b64encode
from urllib import quote_plus
from defusedxml.lxml import tostring

from onelogin.saml2.settings import OneLogin_Saml2_Settings
from onelogin.saml2.response import OneLogin_Saml2_Response
from onelogin.saml2.authn_request import OneLogin_Saml2_Authn_Request
from onelogin.saml2.constants import OneLogin_Saml2_Constants
from onelogin.saml2.errors import OneLogin_Saml2_Error
from onelogin.saml2.logout_response import OneLogin_Saml2_Logout_Response
from onelogin.saml2.constants import OneLogin_Saml2_Constants
from onelogin.saml2.utils import OneLogin_Saml2_Utils, xmlsec
from onelogin.saml2.logout_request import OneLogin_Saml2_Logout_Request
from onelogin.saml2.authn_request import OneLogin_Saml2_Authn_Request
from onelogin.saml2.response import OneLogin_Saml2_Response
from onelogin.saml2.settings import OneLogin_Saml2_Settings
from onelogin.saml2.utils import OneLogin_Saml2_Utils, xmlsec
from onelogin.saml2.xmlparser import tostring


class OneLogin_Saml2_Auth(object):
Expand Down
2 changes: 1 addition & 1 deletion src/onelogin/saml2/idp_metadata_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
import ssl

from copy import deepcopy
from defusedxml.lxml import fromstring

from onelogin.saml2.constants import OneLogin_Saml2_Constants
from onelogin.saml2.utils import OneLogin_Saml2_Utils
from onelogin.saml2.xmlparser import fromstring


class OneLogin_Saml2_IdPMetadataParser(object):
Expand Down
5 changes: 2 additions & 3 deletions src/onelogin/saml2/logout_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@
from zlib import decompress
from base64 import b64encode, b64decode
from lxml import etree
from defusedxml.lxml import fromstring
from xml.dom.minidom import Document

from onelogin.saml2.constants import OneLogin_Saml2_Constants
from onelogin.saml2.utils import OneLogin_Saml2_Utils
from onelogin.saml2.errors import OneLogin_Saml2_Error, OneLogin_Saml2_ValidationError

from onelogin.saml2.utils import OneLogin_Saml2_Utils
from onelogin.saml2.xmlparser import fromstring

class OneLogin_Saml2_Logout_Request(object):
"""
Expand Down
5 changes: 2 additions & 3 deletions src/onelogin/saml2/logout_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@
"""

from base64 import b64encode, b64decode
from defusedxml.lxml import fromstring

from xml.dom.minidom import Document
from defusedxml.minidom import parseString

from onelogin.saml2.constants import OneLogin_Saml2_Constants
from onelogin.saml2.utils import OneLogin_Saml2_Utils
from onelogin.saml2.errors import OneLogin_Saml2_Error, OneLogin_Saml2_ValidationError
from onelogin.saml2.utils import OneLogin_Saml2_Utils
from onelogin.saml2.xmlparser import fromstring


class OneLogin_Saml2_Logout_Response(object):
Expand Down
2 changes: 1 addition & 1 deletion src/onelogin/saml2/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def add_x509_key_descriptors(metadata, cert=None, add_encryption=True):
if cert is None or cert == '':
return metadata
try:
xml = parseString(metadata.encode('utf-8'), forbid_dtd=True)
xml = parseString(metadata.encode('utf-8'), forbid_dtd=True, forbid_entities=True, forbid_external=True)
except Exception as e:
raise Exception('Error parsing metadata. ' + e.message)

Expand Down
4 changes: 2 additions & 2 deletions src/onelogin/saml2/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@

from base64 import b64decode
from copy import deepcopy
from defusedxml.lxml import tostring, fromstring
from xml.dom.minidom import Document

from onelogin.saml2.constants import OneLogin_Saml2_Constants
from onelogin.saml2.utils import OneLogin_Saml2_Utils, return_false_on_exception
from onelogin.saml2.errors import OneLogin_Saml2_Error, OneLogin_Saml2_ValidationError
from onelogin.saml2.utils import OneLogin_Saml2_Utils, return_false_on_exception
from onelogin.saml2.xmlparser import tostring, fromstring


class OneLogin_Saml2_Response(object):
Expand Down
9 changes: 5 additions & 4 deletions src/onelogin/saml2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from hashlib import sha1, sha256, sha384, sha512
from isodate import parse_duration as duration_parser
from lxml import etree
from defusedxml.lxml import tostring, fromstring
from os.path import basename, dirname, join
import re
from sys import stderr
Expand All @@ -36,6 +35,7 @@

from onelogin.saml2.constants import OneLogin_Saml2_Constants
from onelogin.saml2.errors import OneLogin_Saml2_Error, OneLogin_Saml2_ValidationError
from onelogin.saml2.xmlparser import tostring, fromstring


if not globals().get('xmlsec_setup', False):
Expand Down Expand Up @@ -165,11 +165,12 @@ def validate_xml(xml, schema, debug=False):

return 'invalid_xml'

return parseString(tostring(dom, encoding='unicode').encode('utf-8'), forbid_dtd=True)
return parseString(tostring(dom, encoding='unicode').encode('utf-8'), forbid_dtd=True, forbid_entities=True, forbid_external=True)

@staticmethod
def element_text(node):
etree.strip_tags(node, etree.Comment)
# Double check, the LXML Parser already removes comments
#etree.strip_tags(node, etree.Comment)
return node.text

@staticmethod
Expand Down Expand Up @@ -717,7 +718,7 @@ def generate_name_id(value, sp_nq, sp_format=None, cert=None, debug=False, nq=No

edata = enc_ctx.encryptXml(enc_data, elem[0])

newdoc = parseString(tostring(edata, encoding='unicode').encode('utf-8'), forbid_dtd=True)
newdoc = parseString(tostring(edata, encoding='unicode').encode('utf-8'), forbid_dtd=True, forbid_entities=True, forbid_external=True)

if newdoc.hasChildNodes():
child = newdoc.firstChild
Expand Down
145 changes: 145 additions & 0 deletions src/onelogin/saml2/xmlparser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Based on the lxml example from defusedxml
#
# Copyright (c) 2013 by Christian Heimes <christian@python.org>
# Licensed to PSF under a Contributor Agreement.
# See https://www.python.org/psf/license for licensing details.
"""lxml.etree protection"""

from __future__ import print_function, absolute_import

import threading

from lxml import etree as _etree

from defusedxml.lxml import DTDForbidden, EntitiesForbidden, NotSupportedError

# True when the installed lxml reports a major version of 3 or higher;
# check_docinfo() consults this before it iterates DTD entities.
LXML3 = _etree.LXML_VERSION[0] >= 3

__origin__ = "lxml.etree"

# Serialisation is re-exported from lxml.etree unchanged.
tostring = _etree.tostring


class RestrictedElement(_etree.ElementBase):
    """Element class whose iteration methods hide unwanted node types.

    Each child/descendant/sibling iterator inherited from
    ``lxml.etree.ElementBase`` is wrapped so that nodes whose type is listed
    in ``blacklist`` (entities, processing instructions, comments) are never
    yielded. The class only filters what iteration returns; it does not
    itself remove anything from the tree.
    """

    __slots__ = ()
    # Node types skipped by every iteration method below.
    blacklist = (_etree._Entity, _etree._ProcessingInstruction, _etree._Comment)

    def _filter(self, iterator):
        """Yield the items of *iterator* that are not blacklisted nodes."""
        blacklist = self.blacklist
        for child in iterator:
            if isinstance(child, blacklist):
                continue
            yield child

    def __iter__(self):
        """Iterate over direct children, skipping blacklisted nodes."""
        iterator = super(RestrictedElement, self).__iter__()
        return self._filter(iterator)

    def iterchildren(self, tag=None, reversed=False):
        """Filtered version of ``iterchildren``."""
        iterator = super(RestrictedElement, self).iterchildren(tag=tag, reversed=reversed)
        return self._filter(iterator)

    def iter(self, tag=None, *tags):
        """Filtered version of ``iter``; accepts any number of tags."""
        # ``tag`` is forwarded positionally on purpose: writing
        # ``iter(tag=tag, *tags)`` binds the first extra tag to ``tag`` too and
        # raises TypeError as soon as more than one tag is supplied.
        iterator = super(RestrictedElement, self).iter(tag, *tags)
        return self._filter(iterator)

    def iterdescendants(self, tag=None, *tags):
        """Filtered version of ``iterdescendants``; accepts any number of tags."""
        # Positional forwarding for the same reason as in ``iter``.
        iterator = super(RestrictedElement, self).iterdescendants(tag, *tags)
        return self._filter(iterator)

    def itersiblings(self, tag=None, preceding=False):
        """Filtered version of ``itersiblings``."""
        iterator = super(RestrictedElement, self).itersiblings(tag=tag, preceding=preceding)
        return self._filter(iterator)

    def getchildren(self):
        """Return the direct children as a list, without blacklisted nodes."""
        iterator = super(RestrictedElement, self).__iter__()
        return list(self._filter(iterator))

    def getiterator(self, tag=None):
        """Filtered version of ``getiterator``."""
        iterator = super(RestrictedElement, self).getiterator(tag)
        return self._filter(iterator)


class GlobalParserTLS(threading.local):
    """Holder for the default XML parser, one cached instance per thread.

    The class derives from ``threading.local``, so the parser stored on an
    instance by ``setDefaultParser`` is separate for each thread, while
    ``parser_config`` and ``element_class`` are ordinary class attributes.
    """

    # Keyword arguments handed to lxml.etree.XMLParser.
    parser_config = {
        "resolve_entities": False,
        'remove_comments': True,
        'no_network': True,
        'remove_pis': True,
        'huge_tree': False
    }

    # Element class installed on new parsers; None leaves lxml's lookup alone.
    element_class = RestrictedElement

    def createDefaultParser(self):
        """Build a new XMLParser from ``parser_config``.

        When ``element_class`` is set, it is installed as the parser's
        default element class.
        """
        new_parser = _etree.XMLParser(**self.parser_config)
        wanted_class = self.element_class
        if self.element_class is None:
            return new_parser
        new_parser.set_element_class_lookup(
            _etree.ElementDefaultClassLookup(element=wanted_class)
        )
        return new_parser

    def setDefaultParser(self, parser):
        """Remember *parser* as the default parser on this instance."""
        self._default_parser = parser

    def getDefaultParser(self):
        """Return the remembered parser, creating and storing one if absent."""
        cached = getattr(self, "_default_parser", None)
        if cached is not None:
            return cached
        fresh = self.createDefaultParser()
        self.setDefaultParser(fresh)
        return fresh


# Single module-level GlobalParserTLS instance; its getDefaultParser method is
# exposed as a module-level callable used by parse() and fromstring().
_parser_tls = GlobalParserTLS()
getDefaultParser = _parser_tls.getDefaultParser


def check_docinfo(elementtree, forbid_dtd=False, forbid_entities=True):
"""Check docinfo of an element tree for DTD and entity declarations
The check for entity declarations needs lxml 3 or newer. lxml 2.x does
not support dtd.iterentities().
"""
docinfo = elementtree.docinfo
if docinfo.doctype:
if forbid_dtd:
raise DTDForbidden(docinfo.doctype, docinfo.system_url, docinfo.public_id)
if forbid_entities and not LXML3:
# lxml < 3 has no iterentities()
raise NotSupportedError("Unable to check for entity declarations " "in lxml 2.x")

if forbid_entities:
for dtd in docinfo.internalDTD, docinfo.externalDTD:
if dtd is None:
continue
for entity in dtd.iterentities():
raise EntitiesForbidden(entity.name, entity.content, None, None, None, None)


def parse(source, parser=None, base_url=None, forbid_dtd=True, forbid_entities=True):
    """Parse *source* into an element tree and validate its docinfo.

    Uses the module's default parser unless *parser* is given, then runs
    ``check_docinfo`` on the result before returning the tree.
    """
    active_parser = getDefaultParser() if parser is None else parser
    tree = _etree.parse(source, active_parser, base_url=base_url)
    check_docinfo(tree, forbid_dtd, forbid_entities)
    return tree


def fromstring(text, parser=None, base_url=None, forbid_dtd=True, forbid_entities=True):
    """Parse *text* and return its root element after validating docinfo.

    Uses the module's default parser unless *parser* is given. The docinfo
    check runs on the tree that owns the root element.
    """
    active_parser = getDefaultParser() if parser is None else parser
    root = _etree.fromstring(text, active_parser, base_url=base_url)
    check_docinfo(root.getroottree(), forbid_dtd, forbid_entities)
    return root


# XML is an alias of fromstring.
XML = fromstring


def iterparse(*args, **kwargs):
    """Always raise ``NotSupportedError``; iterparse is not offered here."""
    reason = "iterparse not available"
    raise NotSupportedError(reason)
5 changes: 2 additions & 3 deletions tests/src/OneLogin/saml2_tests/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@

from base64 import b64decode
import json
from defusedxml.lxml import fromstring
from lxml import etree
from os.path import dirname, join, exists
import unittest
from xml.dom.minidom import Document, parseString

from onelogin.saml2.constants import OneLogin_Saml2_Constants
from onelogin.saml2.errors import OneLogin_Saml2_Error, OneLogin_Saml2_ValidationError
from onelogin.saml2.settings import OneLogin_Saml2_Settings
from onelogin.saml2.utils import OneLogin_Saml2_Utils
from onelogin.saml2.errors import OneLogin_Saml2_Error, OneLogin_Saml2_ValidationError
from onelogin.saml2.xmlparser import fromstring


class OneLogin_Saml2_Utils_Test(unittest.TestCase):
Expand Down Expand Up @@ -1035,7 +1035,6 @@ def testValidateSign(self):
with self.assertRaisesRegexp(OneLogin_Saml2_ValidationError, "Expected exactly one signature node; got 0."):
OneLogin_Saml2_Utils.validate_sign(wrapping_attack1, cert, raise_exceptions=True)


if __name__ == '__main__':
runner = unittest.TextTestRunner()
unittest.main(testRunner=runner)

0 comments on commit 4a3efac

Please sign in to comment.