Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve support for some XSD features #17

Merged
merged 5 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fixes
  • Loading branch information
cre-os committed Jan 21, 2025
commit 0579dff7a45e6cfcd3cb6fe5763bc4b24e9b1d22
6 changes: 4 additions & 2 deletions docs/how_it_works.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,10 @@ original one.

### Recursive XSD

Recursive XML schemas are not supported, because most of the time they will result in cycles in foreign key constraints
dependencies, which we cannot handle easily.
Recursive XML schemas are not fully supported, because they result in cycles in tables dependencies, which would make
the process much more complex. Whenever a field which would introduce a dependency cycle is detected in the XSD, it is
discarded with a warning, which means that the corresponding data in XML files will not be imported. The rest of the
data should be processed correctly.

### Mixed content elements

Expand Down
8 changes: 2 additions & 6 deletions src/xml2db/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,7 @@ def _extract_node(
for field_type, key, field in model_table.fields:
if field_type == "col":
content_key = (
(
f"{key[:-5]}__attr"
if key.endswith("_attr")
else f"{key}__attr"
)
(f"{key[:-5]}__attr" if field.has_suffix else f"{key}__attr")
if field.is_attr
else key
)
Expand Down Expand Up @@ -334,7 +330,7 @@ def _build_node(node_type: str, node_pk: int) -> tuple:
content_key = (
(
f"{rel_name[:-5]}__attr"
if rel_name.endswith("_attr")
if rel.has_suffix
else f"{rel_name}__attr"
)
if rel.is_attr
Expand Down
8 changes: 7 additions & 1 deletion src/xml2db/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,17 +434,20 @@ def get_occurs(particle):
max_length,
allow_empty,
) = recurse_parse_simple_type([attrib.type])
suffix = attrib_name in children_names
parent_table.add_column(
f"{attrib_name}{'_attr' if attrib_name in children_names else ''}",
f"{attrib_name}{'_attr' if suffix else ''}",
data_type,
[0, 1],
min_length,
max_length,
True,
suffix,
False,
allow_empty,
None,
)

nested_containers = []
# go through the children to add either arguments either relations to the current element
for child in parent_node:
Expand All @@ -470,6 +473,7 @@ def get_occurs(particle):
if child.parent
and child.parent.max_occurs != 1
and child.parent.model != "choice"
and child.max_occurs == 1
else None
),
)
Expand Down Expand Up @@ -498,6 +502,7 @@ def get_occurs(particle):
max_length,
False,
False,
False,
allow_empty,
nested_containers[-1][1],
)
Expand Down Expand Up @@ -556,6 +561,7 @@ def get_occurs(particle):
min_length,
max_length,
False,
False,
True,
allow_empty,
None,
Expand Down
2 changes: 2 additions & 0 deletions src/xml2db/table/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def __init__(
min_length: int,
max_length: Union[int, None],
is_attr: bool,
has_suffix: bool,
is_content: bool,
allow_empty: bool,
ngroup: Union[int, None],
Expand All @@ -181,6 +182,7 @@ def __init__(
self.min_length = min_length
self.max_length = max_length
self.is_attr = is_attr
self.has_suffix = has_suffix
self.is_content = is_content
self.allow_empty = allow_empty
self.ngroup = ngroup
Expand Down
1 change: 1 addition & 0 deletions src/xml2db/table/reused_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def get_col(temp=False):
False,
False,
False,
False,
None,
self.config,
self.data_model,
Expand Down
3 changes: 3 additions & 0 deletions src/xml2db/table/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def add_column(
min_length: int,
max_length: Union[int, None],
is_attr: bool,
has_suffix: bool,
is_content: bool,
allow_empty: bool,
ngroup: Union[str, None],
Expand All @@ -143,6 +144,7 @@ def add_column(
min_length: minimum length
max_length: maximum length
is_attr: is XML attribute or element?
has_suffix: for an attribute, do we need the '_attr' suffix?
is_content: is content of a mixed type element?
allow_empty: is nullable?
ngroup: a string id signaling that the column belongs to a nested sequence
Expand All @@ -155,6 +157,7 @@ def add_column(
min_length,
max_length,
is_attr,
has_suffix,
is_content,
allow_empty,
ngroup,
Expand Down
8 changes: 7 additions & 1 deletion src/xml2db/table/transformed_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def _transform_to_choice(self) -> None:
False,
False,
False,
False,
None,
self.config,
self.data_model,
Expand All @@ -89,6 +90,7 @@ def _transform_to_choice(self) -> None:
max(max_lengths) if all(e is not None for e in max_lengths) else None,
False,
False,
False,
any(allow_empty),
None,
self.config,
Expand Down Expand Up @@ -193,6 +195,7 @@ def _elevate_relation_1(
child_field.min_length,
child_field.max_length,
child_field.is_attr,
child_field.has_suffix,
child_field.is_content,
child_field.allow_empty,
child_field.ngroup,
Expand Down Expand Up @@ -276,9 +279,12 @@ def simplify_table(self) -> Tuple[dict, dict]:

# if the table can be transformed, stop here
if self._is_table_choice_transform_applicable():
fields_transform = {}
for col in self.columns.values():
fields_transform[(self.type_name, col.name)] = (None, "join")
self._transform_to_choice()
self.is_simplified = True
return {self.type_name: "choice"}, {}
return {self.type_name: "choice"}, fields_transform

# loop through field to transform them if need be
out_fields = []
Expand Down
104 changes: 68 additions & 36 deletions src/xml2db/xml_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def _parse_xml_node(
key
!= "{http://www.w3.org/2001/XMLSchema-instance}noNamespaceSchemaLocation"
):
content[f"{key}__attr"] = [val]
content[f"{key}__attr"] = [val.strip() if val.strip() else val]

if node.text and node.text.strip():
content["value"] = [node.text.strip()]
Expand All @@ -138,22 +138,26 @@ def _parse_xml_node(
key = element.tag.split("}")[1] if "}" in element.tag else element.tag
node_type_key = (node_type, key)
value = None
if element.text and element.text.strip():
value = element.text
transform = self.model.fields_transforms.get(
node_type_key, (None, "join")
)[1]
if element.text:
value = (
element.text.strip() if element.text.strip() else element.text
)
if node_type_key not in self.model.fields_transforms:
# skip the node if it is not in the data model
continue
transform = self.model.fields_transforms[node_type_key][1]
if transform != "join":
value = self._parse_xml_node(
self.model.fields_transforms[node_type_key][0],
element,
transform not in ["elevate", "elevate_wo_prefix"],
hash_maps,
)
if key in content:
content[key].append(value)
else:
content[key] = [value]
if value is not None:
if key in content:
content[key].append(value)
else:
content[key] = [value]

node = self._transform_node(node_type, content)

Expand Down Expand Up @@ -190,19 +194,25 @@ def _parse_iterative(
hash_maps = {}

joined_values = False
skipped_nodes = 0
for event, element in etree.iterparse(
xml_file,
recover=recover,
events=["start", "end"],
remove_blank_text=True,
):
key = element.tag.split("}")[1] if "}" in element.tag else element.tag
if event == "start":

if event == "start" and skipped_nodes > 0:
skipped_nodes += 1

elif event == "start":
if nodes_stack[-1][0]:
node_type_key = (nodes_stack[-1][0], key)
node_type, transform = self.model.fields_transforms.get(
node_type_key, (None, "join")
)
if node_type_key not in self.model.fields_transforms:
skipped_nodes += 1
continue
node_type, transform = self.model.fields_transforms[node_type_key]
else:
node_type, transform = self.model.root_table, None
joined_values = transform == "join"
Expand All @@ -213,28 +223,41 @@ def _parse_iterative(
attrib_key
!= "{http://www.w3.org/2001/XMLSchema-instance}noNamespaceSchemaLocation"
):
content[f"{attrib_key}__attr"] = [attrib_val]
content[f"{attrib_key}__attr"] = [
attrib_val.strip() if attrib_val.strip() else attrib_val
]
nodes_stack.append((node_type, content))

elif event == "end" and skipped_nodes > 0:
skipped_nodes -= 1

elif event == "end":
# joined_values was set with the previous "start" event just before
# joined_values was set with the previous "start" event just before and corresponds to lists of simple
# type elements
if joined_values:
value = None
if element.text:
if key in nodes_stack[-1][1]:
nodes_stack[-1][1][key].append(element.text)
if element.text.strip():
value = element.text.strip()
else:
nodes_stack[-1][1][key] = [element.text]
value = element.text
if key in nodes_stack[-1][1]:
nodes_stack[-1][1][key].append(value)
else:
nodes_stack[-1][1][key] = [value]

# else, we have completed a complex type node
else:
node = nodes_stack.pop()
if nodes_stack[-1][0]:
node_type_key = (nodes_stack[-1][0], key)
node_type, transform = self.model.fields_transforms.get(
node_type_key, (None, "join")
)
node_type, transform = self.model.fields_transforms[
node_type_key
]
else:
node_type, transform = self.model.root_table, None
if element.text:
node[1]["value"] = [element.text]
if element.text and element.text.strip():
node[1]["value"] = [element.text.strip()]
node = self._transform_node(*node)
if transform not in ["elevate", "elevate_wo_prefix"]:
node = self._compute_hash_deduplicate(node, hash_maps)
Expand Down Expand Up @@ -293,18 +316,26 @@ def _compute_hash_deduplicate(self, node: tuple, hash_maps: dict) -> tuple:
A tuple of (node_type, content, hash) representing a node after deduplication
"""
node_type, content = node
if node_type not in self.model.tables:
return "", None, b""
table = self.model.tables[node_type]

h = self.model.model_config["record_hash_constructor"]()
for field_type, name, field in table.fields:
if field_type == "col":
if field.is_attr:
if f"{name}__attr" in content:
h.update(str(content[f"{name}__attr"]).encode("utf-8"))
elif f"{name[:-5]}__attr" in content:
h.update(str(content[f"{name[:-5]}__attr"]).encode("utf-8"))
else:
h.update(str(None).encode("utf-8"))
h.update(
str(
content.get(
(
f"{name[:-5]}__attr"
if field.has_suffix
else f"{name}__attr"
),
None,
)
).encode("utf-8")
)
else:
h.update(str(content.get(name, None)).encode("utf-8"))
elif field_type == "rel1":
Expand Down Expand Up @@ -429,14 +460,14 @@ def check_transformed_node(node_type, element):
text_content = None
if field_type == "col":
if rel.is_attr:
if f"{rel_name}__attr" in content:
attributes[rel.name_chain[-1][0]] = content[
f"{rel_name}__attr"
][0]
elif f"{rel_name[:-5]}__attr" in content:
if rel.has_suffix and f"{rel_name[:-5]}__attr" in content:
attributes[rel.name_chain[-1][0][:-5]] = content[
f"{rel_name[:-5]}__attr"
][0]
elif not rel.has_suffix and f"{rel_name}__attr" in content:
attributes[rel.name_chain[-1][0]] = content[
f"{rel_name}__attr"
][0]
elif rel_name in content:
if rel.is_content:
text_content = content[rel_name][0]
Expand All @@ -462,7 +493,8 @@ def check_transformed_node(node_type, element):
if prev_ngroup and rel.ngroup != prev_ngroup:
for ngroup_children in zip_longest(*ngroup_stack):
for child in ngroup_children:
nodes_stack[-1][1].append(child)
if child is not None:
nodes_stack[-1][1].append(child)
ngroup_stack = []
prev_ngroup = rel.ngroup
if len(children) > 0:
Expand Down
13 changes: 12 additions & 1 deletion tests/fixtures.py → tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@

from xml2db import DataModel

models_path = "tests/sample_models"


def list_xml_path(test_config, key):
path = os.path.join(models_path, test_config["id"], key)
if os.path.isdir(path):
return [os.path.join(path, f) for f in os.listdir(path)]
return []


@pytest.fixture
def conn_string():
Expand All @@ -14,7 +23,9 @@ def conn_string():
def setup_db_model(conn_string, model_config):
db_schema = f"test_xml2db"
model = DataModel(
xsd_file=model_config.get("xsd_path"),
xsd_file=str(
os.path.join(models_path, model_config["id"], model_config["xsd"])
),
short_name=model_config.get("id"),
connection_string=conn_string,
db_schema=db_schema,
Expand Down
Loading
Loading