Skip to content

Feature list item #24

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/renderers.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ If you are considering using `XML` for your API, you may want to consider implem

**.charset**: `utf-8`

**item_tag_name**: `list-item`
**.item_tag_name**: `list-item`

**.root_tag_name**: `root`

**.override_item_tag_name**: `False`
20 changes: 19 additions & 1 deletion rest_framework_xml/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

from django.utils import six
from django.utils.xmlutils import SimplerXMLGenerator
from django.utils.six.moves import StringIO
from django.utils.six import StringIO
from django.utils.encoding import force_text
from rest_framework.renderers import BaseRenderer
from xml.etree import ElementTree as ET


class XMLRenderer(BaseRenderer):
Expand All @@ -20,6 +21,7 @@ class XMLRenderer(BaseRenderer):
charset = 'utf-8'
item_tag_name = 'list-item'
root_tag_name = 'root'
override_item_tag_name = False

def render(self, data, accepted_media_type=None, renderer_context=None):
"""
Expand All @@ -38,8 +40,24 @@ def render(self, data, accepted_media_type=None, renderer_context=None):

xml.endElement(self.root_tag_name)
xml.endDocument()

if self.override_item_tag_name:
self._do_override_item_tag_name(stream)

return stream.getvalue()

def _do_override_item_tag_name(self, stream):
root = ET.fromstring(stream.getvalue())
for parent in root.findall('.//*list-item/..'):
child_name = parent.tag[0:-1]
for child in list(parent):
child.tag = child_name

stream.truncate(0)
stream.seek(0)
stream.write('<?xml version="1.0" encoding="utf-8"?>\n')
stream.write(ET.tostring(root).decode('utf-8'))

def _to_xml(self, xml, data):
if isinstance(data, (list, tuple)):
for item in data:
Expand Down
51 changes: 47 additions & 4 deletions tests/test_renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from django.test import TestCase
from django.test.utils import skipUnless
from django.utils.six.moves import StringIO
from django.utils.six import StringIO
from django.utils.translation import gettext_lazy
from rest_framework_xml.renderers import XMLRenderer
from rest_framework_xml.parsers import XMLParser
Expand All @@ -33,6 +33,41 @@ class XMLRendererTestCase(TestCase):
]
}

_complex_order_data = {
"creation_date": datetime.datetime(2017, 7, 1, 14, 30, 00),
"orderId": 1,
"positions": [
{
"posNo": 1,
"amount": 3,
"messages": [
{
"type": "O",
"code": "xyz"
},
{
"type": "L",
"code": "zyx"
}
]
},
{
"posNo": 2,
"amount": 1,
"messages": [
{
"type": "O",
"code": "xyz"
},
{
"type": "L",
"code": "zyx"
}
]
}
]
}

def test_render_string(self):
"""
Test XML rendering.
Expand Down Expand Up @@ -104,6 +139,14 @@ def test_render_lazy(self):
content = renderer.render({'field': lazy}, 'application/xml')
self.assertXMLContains(content, '<field>hello</field>')

def test_render_override_list_item(self):
renderer = XMLRenderer()
renderer.root_tag_name = 'order'
renderer.override_item_tag_name = True
content = renderer.render(self._complex_order_data, 'application/xml')
self.assertXMLContains(content, '<position>', renderer.root_tag_name)
self.assertXMLContains(content, '<message>', renderer.root_tag_name)

@skipUnless(etree, 'defusedxml not installed')
def test_render_and_parse_complex_data(self):
"""
Expand All @@ -117,7 +160,7 @@ def test_render_and_parse_complex_data(self):
error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out))
self.assertEqual(self._complex_data, complex_data_out, error_msg)

def assertXMLContains(self, xml, string):
self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>'))
self.assertTrue(xml.endswith('</root>'))
def assertXMLContains(self, xml, string, root_tag='root'):
self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<{0}>'.format(root_tag)))
self.assertTrue(xml.endswith('</{0}>'.format(root_tag)))
self.assertTrue(string in xml, '%r not in %r' % (string, xml))