|
6 | 6 |
|
7 | 7 | from bigtree.node.node import Node |
8 | 8 | from bigtree.tree.search import find_path |
| 9 | +from bigtree.utils.assertions import assert_key_in_dict, assert_str_in_list |
| 10 | +from bigtree.utils.constants import ExportConstants, MermaidConstants |
9 | 11 | from bigtree.utils.exceptions import ( |
10 | 12 | optional_dependencies_image, |
11 | 13 | optional_dependencies_pandas, |
|
41 | 43 |
|
42 | 44 | T = TypeVar("T", bound=Node) |
43 | 45 |
|
44 | | -available_styles = { |
45 | | - "ansi": ("| ", "|-- ", "`-- "), |
46 | | - "ascii": ("| ", "|-- ", "+-- "), |
47 | | - "const": ("\u2502 ", "\u251c\u2500\u2500 ", "\u2514\u2500\u2500 "), |
48 | | - "const_bold": ("\u2503 ", "\u2523\u2501\u2501 ", "\u2517\u2501\u2501 "), |
49 | | - "rounded": ("\u2502 ", "\u251c\u2500\u2500 ", "\u2570\u2500\u2500 "), |
50 | | - "double": ("\u2551 ", "\u2560\u2550\u2550 ", "\u255a\u2550\u2550 "), |
51 | | - "custom": ("", "", ""), |
52 | | -} |
53 | | - |
54 | 46 |
|
55 | 47 | def print_tree( |
56 | 48 | tree: T, |
@@ -324,6 +316,7 @@ def yield_tree( |
324 | 316 | style (str): style of print, defaults to abstract style |
325 | 317 | custom_style (Iterable[str]): style of stem, branch and final stem, used when `style` is set to 'custom' |
326 | 318 | """ |
| 319 | + available_styles = ExportConstants.AVAILABLE_STYLES |
327 | 320 | if style not in available_styles.keys(): |
328 | 321 | raise ValueError( |
329 | 322 | f"Choose one of {available_styles.keys()} style, use `custom` to define own style" |
@@ -1001,66 +994,16 @@ def tree_to_mermaid( |
1001 | 994 | """ |
1002 | 995 | from bigtree.tree.helper import clone_tree |
1003 | 996 |
|
1004 | | - rankdirs = ["TB", "BT", "LR", "RL"] |
1005 | | - line_shapes = [ |
1006 | | - "basis", |
1007 | | - "bumpX", |
1008 | | - "bumpY", |
1009 | | - "cardinal", |
1010 | | - "catmullRom", |
1011 | | - "linear", |
1012 | | - "monotoneX", |
1013 | | - "monotoneY", |
1014 | | - "natural", |
1015 | | - "step", |
1016 | | - "stepAfter", |
1017 | | - "stepBefore", |
1018 | | - ] |
1019 | | - node_shapes = { |
1020 | | - "rounded_edge": """("{label}")""", |
1021 | | - "stadium": """(["{label}"])""", |
1022 | | - "subroutine": """[["{label}"]]""", |
1023 | | - "cylindrical": """[("{label}")]""", |
1024 | | - "circle": """(("{label}"))""", |
1025 | | - "asymmetric": """>"{label}"]""", |
1026 | | - "rhombus": """{{"{label}"}}""", |
1027 | | - "hexagon": """{{{{"{label}"}}}}""", |
1028 | | - "parallelogram": """[/"{label}"/]""", |
1029 | | - "parallelogram_alt": """[\\"{label}"\\]""", |
1030 | | - "trapezoid": """[/"{label}"\\]""", |
1031 | | - "trapezoid_alt": """[\\"{label}"/]""", |
1032 | | - "double_circle": """((("{label}")))""", |
1033 | | - } |
1034 | | - edge_arrows = { |
1035 | | - "normal": "-->", |
1036 | | - "bold": "==>", |
1037 | | - "dotted": "-.->", |
1038 | | - "open": "---", |
1039 | | - "bold_open": "===", |
1040 | | - "dotted_open": "-.-", |
1041 | | - "invisible": "~~~", |
1042 | | - "circle": "--o", |
1043 | | - "cross": "--x", |
1044 | | - "double_normal": "<-->", |
1045 | | - "double_circle": "o--o", |
1046 | | - "double_cross": "x--x", |
1047 | | - } |
| 997 | + rankdirs = MermaidConstants.RANK_DIR |
| 998 | + line_shapes = MermaidConstants.LINE_SHAPES |
| 999 | + node_shapes = MermaidConstants.NODE_SHAPES |
| 1000 | + edge_arrows = MermaidConstants.EDGE_ARROWS |
1048 | 1001 |
|
1049 | 1002 | # Assertions |
1050 | | - if rankdir not in rankdirs: |
1051 | | - raise ValueError(f"Invalid input, check `rankdir` should be one of {rankdirs}") |
1052 | | - if node_shape not in node_shapes: |
1053 | | - raise ValueError( |
1054 | | - f"Invalid input, check `node_shape` should be one of {node_shapes.keys()}" |
1055 | | - ) |
1056 | | - if line_shape not in line_shapes: |
1057 | | - raise ValueError( |
1058 | | - f"Invalid input, check `line_shape` should be one of {line_shapes}" |
1059 | | - ) |
1060 | | - if edge_arrow not in edge_arrows: |
1061 | | - raise ValueError( |
1062 | | - f"Invalid input, check `edge_arrow` should be one of {edge_arrows.keys()}" |
1063 | | - ) |
| 1003 | + assert_str_in_list("rankdir", rankdir, rankdirs) |
| 1004 | + assert_key_in_dict("node_shape", node_shape, node_shapes) |
| 1005 | + assert_str_in_list("line_shape", line_shape, line_shapes) |
| 1006 | + assert_key_in_dict("edge_arrow", edge_arrow, edge_arrows) |
1064 | 1007 |
|
1065 | 1008 | mermaid_template = """```mermaid\n{title}{line_style}\nflowchart {rankdir}\n{flows}\n{styles}\n```""" |
1066 | 1009 | flowchart_template = "{from_node_ref}{from_node_name}{flow_style} {arrow}{arrow_label} {to_node_ref}{to_node_name}" |
@@ -1104,7 +1047,7 @@ def mermaid_name(self) -> str: |
1104 | 1047 | """ |
1105 | 1048 | if self.is_root: |
1106 | 1049 | return "0" |
1107 | | - return f"{self.parent.mermaid_name}{self.parent.children.index(self)}" |
| 1050 | + return f"{self.parent.mermaid_name}-{self.parent.children.index(self)}" |
1108 | 1051 |
|
1109 | 1052 | tree_mermaid = clone_tree(tree, MermaidNode) |
1110 | 1053 | default_edge_arrow = edge_arrows[edge_arrow] |
|
0 commit comments