Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Fixed:
- Tree Exporter: `tree_to_mermaid` fix for root node to have node attribute.

## [0.19.4] - 2024-08-15
### Changed:
- Docs: Clean CSS for playground.
- Misc: Refactor tests for `tree_to_mermaid`.
- Misc: Allow untyped calls in mypy type checking due to ImageFont.truetype call.
### Fixed:
- Tree Exporter: `tree_to_mermaid` fix where the node colour is added wrongly to the wrong node.
- Tree Exporter: `tree_to_mermaid` fix where the node attribute is added wrongly to the wrong node.
- Misc: Fix and update code examples in docstring.
- Misc: Fix test cases for pydot due to code upgrade.

Expand Down
50 changes: 29 additions & 21 deletions bigtree/tree/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -1572,9 +1572,7 @@ def tree_to_mermaid(
assert_key_in_dict("edge_arrow", edge_arrow, edge_arrows)

mermaid_template = """```mermaid\n{title}{line_style}\nflowchart {rankdir}\n{flows}\n{styles}\n```"""
flowchart_template = (
"{from_ref}{from_name} {arrow}{arrow_label} {to_ref}{to_name}{to_style}"
)
flowchart_template = "{from_ref}{from_name}{from_style} {arrow}{arrow_label} {to_ref}{to_name}{to_style}"
style_template = "classDef {style_name} {style}"

# Content
Expand Down Expand Up @@ -1658,43 +1656,53 @@ def _get_attr(
if not node.is_root:
# Get custom style (node_shape_attr)
_parent_node_name = ""
_from_style = ""
if node.parent.is_root:
_parent_node_shape_choice = _get_attr(
node.parent, node_shape_attr, node_shape
)
_parent_node_shape = node_shapes[_parent_node_shape_choice]
_parent_node_name = _parent_node_shape.format(label=node.parent.name)
_node_shape_choice = _get_attr(node, node_shape_attr, node_shape)
_node_shape = node_shapes[_node_shape_choice]
_node_name = _node_shape.format(label=node.name)
# Get custom style for root (node_shape_attr, node_attr)
_parent_node_name = node_shapes[
_get_attr(node.parent, node_shape_attr, node_shape)
].format(label=node.parent.name)

if _get_attr(node.parent, node_attr, "") and len(styles) < 2:
_from_style = _get_attr(node.parent, node_attr, "")
_from_style_class = (
f"""class{node.parent.get_attr("mermaid_name")}"""
)
styles.append(
style_template.format(
style_name=_from_style_class, style=_from_style
)
)
_from_style = f":::{_from_style_class}"
_node_name = node_shapes[
_get_attr(node, node_shape_attr, node_shape)
].format(label=node.name)

# Get custom style (edge_arrow_attr, edge_label)
_arrow_choice = _get_attr(node, edge_arrow_attr, edge_arrow)
_arrow = edge_arrows[_arrow_choice]
_arrow = edge_arrows[_get_attr(node, edge_arrow_attr, edge_arrow)]
_arrow_label = (
f"|{node.get_attr(edge_label)}|" if node.get_attr(edge_label) else ""
)

# Get custom style (node_attr)
_flow_style = _get_attr(node, node_attr, "")
if _flow_style:
_flow_style_class = f"""class{node.get_attr("mermaid_name")}"""
_to_style = _get_attr(node, node_attr, "")
if _to_style:
_to_style_class = f"""class{node.get_attr("mermaid_name")}"""
styles.append(
style_template.format(
style_name=_flow_style_class, style=_flow_style
)
style_template.format(style_name=_to_style_class, style=_to_style)
)
_flow_style = f":::{_flow_style_class}"
_to_style = f":::{_to_style_class}"

flows.append(
flowchart_template.format(
from_ref=node.parent.get_attr("mermaid_name"),
from_name=_parent_node_name,
from_style=_from_style,
arrow=_arrow,
arrow_label=_arrow_label,
to_ref=node.get_attr("mermaid_name"),
to_name=_node_name,
to_style=_flow_style,
to_style=_to_style,
)
)

Expand Down
27 changes: 27 additions & 0 deletions tests/tree/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -2431,6 +2431,33 @@ def get_node_attr(node):
mermaid_md = tree_to_mermaid(tree_node_no_attr, node_attr=get_node_attr)
assert mermaid_md == self.MERMAID_STR_NODE_ATTR

def test_tree_to_mermaid_node_attr_root(self, tree_node_no_attr):
def get_node_attr(node):
if node.node_name == "a":
return "fill:green,stroke:black"
elif node.node_name in ["g", "h"]:
return "fill:red,stroke:black,stroke-width:2"
return ""

mermaid_md = tree_to_mermaid(tree_node_no_attr, node_attr=get_node_attr)
expected_str = (
"""```mermaid\n"""
"""%%{ init: { \'flowchart\': { \'curve\': \'basis\' } } }%%\n"""
"""flowchart TB\n"""
"""0("a"):::class0 --> 0-0("b")\n"""
"""0-0 --> 0-0-0("d")\n"""
"""0-0 --> 0-0-1("e")\n"""
"""0-0-1 --> 0-0-1-0("g"):::class0-0-1-0\n"""
"""0-0-1 --> 0-0-1-1("h"):::class0-0-1-1\n"""
"""0("a") --> 0-1("c")\n"""
"""0-1 --> 0-1-0("f")\n"""
"""classDef default stroke-width:1\n"""
"""classDef class0 fill:green,stroke:black\n"""
"""classDef class0-0-1-0 fill:red,stroke:black,stroke-width:2\n"""
"""classDef class0-0-1-1 fill:red,stroke:black,stroke-width:2\n```"""
)
assert mermaid_md == expected_str


class TestTreeToNewick:
@staticmethod
Expand Down