Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ Lastly it will print a representation of the call tree to the terminal to allow
| ------------------ | --------------------- | -------- | ------- | ---------------------------------------------------- |
| positional src_root | `str` | ✅ | | Path to the root of the repository to scan |
| `--new-root` | `str` | ❌ | `''` | Optional new root path for output (default: empty, meaning same as src_root if optimisation is enabled). |
| `--call-tree-save-path` | `str` | ❌ | `'./call_tree.json'` | The location to save the generated call tree. Only used if `--optimise-src-code` isn't used. Defaults to `./call_tree.json`. |

| `--optimise-src-code` | Flag (no value) | ❌ | | Enable optimisation of the source code. |


Expand Down
72 changes: 51 additions & 21 deletions src/spaghettree/__main__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import argparse
from pprint import pformat
import json
from pathlib import Path

from spaghettree import Ok, Result
from spaghettree import Result
from spaghettree.adapters.io_wrapper import IOProtocol, IOWrapper
from spaghettree.domain.adj_mat import AdjMat
from spaghettree.domain.optimisation import (
cyan,
get_dwm,
get_top_suggested_merges,
yellow,
Expand All @@ -21,13 +21,30 @@
from spaghettree.logger import logger


def main(src_root: str, *, new_root: str = "", optimise_src_code: bool = False) -> Result:
def main(
src_root: str,
*,
new_root: str = "",
call_tree_save_path: str = "./call_tree.json",
optimise_src_code: bool = False,
) -> Result:
io = IOWrapper()
return run_process(io, src_root, new_root=new_root, optimise_src_code=optimise_src_code)
return run_process(
io,
src_root,
new_root=new_root,
optimise_src_code=optimise_src_code,
call_tree_save_path=call_tree_save_path,
)


def run_process(
io: IOProtocol, src_root: str, *, new_root: str = "", optimise_src_code: bool = False
io: IOProtocol,
src_root: str,
*,
new_root: str = "",
call_tree_save_path: str = "./call_tree.json",
optimise_src_code: bool = False,
) -> Result:
logger.info(f"*** RUNNING `spaghettree` {src_root = } {new_root = } ***")
src_code = io.read_files(src_root).unwrap()
Expand All @@ -39,27 +56,28 @@ def run_process(
call_tree = entities_res.and_then(create_call_tree).unwrap()

if optimise_src_code:
return optimise_entity_positions(
io=io,
res = optimise_entity_positions(
entities=entities,
location_map=location_map,
call_tree=call_tree,
src_root=src_root,
new_root=new_root,
).unwrap()
else:
adj_mat = AdjMat.from_call_tree_no_optimisation(call_tree).unwrap()
print( # noqa: T201
yellow(
f"Current Directed Weighted Modularity (DWM): {get_dwm(adj_mat.mat, adj_mat.communities): .5f}"
)
)
top_merges = get_top_suggested_merges(adj_mat).unwrap()

adj_mat = AdjMat.from_call_tree_no_optimisation(call_tree).unwrap()
print( # noqa: T201
yellow(
f"Current Directed Weighted Modularity (DWM): {get_dwm(adj_mat.mat, adj_mat.communities): .5f}"
)
)
top_merges = get_top_suggested_merges(adj_mat).unwrap()
for merge in top_merges:
merge.display()

for merge in top_merges:
merge.display()
res = {Path(call_tree_save_path).absolute(): json.dumps(call_tree, indent=4)}

return Ok(call_tree)
return io.write_files(res, ruff_root=new_root, format_code=optimise_src_code)


if __name__ == "__main__":
Expand All @@ -75,6 +93,13 @@ def run_process(
default="",
help="Optional new root path for output (default: empty, meaning same as src_root if optimisation is enabled).",
)
parser.add_argument(
"--call-tree-save-path",
dest="call_tree_save_path",
type=str,
default="./call_tree.json",
help="The location to save the generated call tree. Only used if `--optimise-src-code` isn't used. Defaults to `./call_tree.json`.",
)
parser.add_argument(
"--optimise-src-code",
dest="optimise_src_code",
Expand All @@ -83,6 +108,11 @@ def run_process(
)

args = parser.parse_args()
res = main(args.src_root, new_root=args.new_root, optimise_src_code=args.optimise_src_code)
call_tree = res.unwrap()
print(f"\n{cyan(pformat(call_tree))}") # noqa: T201
res = main(
args.src_root,
new_root=args.new_root,
call_tree_save_path=args.call_tree_save_path,
optimise_src_code=args.optimise_src_code,
)
if not res.is_ok():
raise res.error
16 changes: 11 additions & 5 deletions src/spaghettree/adapters/io_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def read_files(self, root: str | Path) -> Result: ...
@safe
def write(self, modified_code: str, filepath: str, *, format_code: bool = True) -> None: ...

def write_files(self, src_code: dict[str, str], ruff_root: str | None = None) -> Result: ...
def write_files(
self, src_code: dict[str, str], ruff_root: str | None = None, *, format_code: bool = True
) -> Result: ...


@attrs.define
Expand Down Expand Up @@ -70,11 +72,13 @@ def write(self, modified_code: str, filepath: str, *, format_code: bool = True)
if format_code:
self._run_ruff(filepath)

def write_files(self, src_code: dict[str, str], ruff_root: str | None = None) -> Result:
def write_files(
self, src_code: dict[str, str], ruff_root: str | None = None, *, format_code: bool = True
) -> Result:
results, fails = {}, {}

for filepath, modified_code in src_code.items():
if ruff_root is not None:
if not ruff_root or not format_code:
# format all at the end instead
res = self.write(modified_code, filepath, format_code=False)
else:
Expand Down Expand Up @@ -139,12 +143,14 @@ def read_files(self, root: str | Path) -> Result:
def write(self, modified_code: str, filepath: str, *, format_code: bool = True) -> None:
self.files[filepath] = format_code_str(modified_code) if format_code else modified_code

def write_files(self, src_code: dict[str, str], ruff_root: str | None = None) -> Result:
def write_files(
self, src_code: dict[str, str], ruff_root: str | None = None, *, format_code: bool = True
) -> Result:
results, fails = {}, {}

for filepath, modified_code in src_code.items():
if ruff_root is not None:
res = self.write(modified_code, filepath)
res = self.write(modified_code, filepath, format_code=format_code)

if res.is_ok():
results[filepath] = res.inner
Expand Down
10 changes: 7 additions & 3 deletions src/spaghettree/domain/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,13 @@ def add_referenced_imports(self, imports: set[ImportCST]) -> Self:
self.imports.update(imports)
return self

for imp in imports:
if imp.as_name in self.referenced or imp.module in sys.stdlib_module_names:
self.imports.add(imp)
self.imports.update(
{
imp
for imp in imports
if imp.as_name in self.referenced or imp.module in sys.stdlib_module_names
}
)
return self


Expand Down
5 changes: 1 addition & 4 deletions src/spaghettree/domain/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from functools import partial

from spaghettree import Result, safe
from spaghettree.adapters.io_wrapper import IOProtocol
from spaghettree.domain.adj_mat import AdjMat
from spaghettree.domain.entities import EntityCST, ImportCST, ImportType
from spaghettree.domain.optimisation import (
Expand All @@ -19,8 +18,7 @@
from spaghettree.logger import logger


def optimise_entity_positions( # noqa: PLR0913
io: IOProtocol,
def optimise_entity_positions(
entities: dict[str, EntityCST],
location_map: dict[str, EntityLocation],
call_tree: dict[str, list[str]],
Expand All @@ -44,7 +42,6 @@ def optimise_entity_positions( # noqa: PLR0913
)
.and_then(partial(create_new_filepaths, new_root=new_root or src_root))
.and_then(add_empty_inits_if_needed)
.and_then(partial(io.write_files, ruff_root=new_root or src_root))
)


Expand Down
3 changes: 2 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import shutil
from pathlib import Path
Expand Down Expand Up @@ -345,4 +346,4 @@ def test_run_process_return_call_tree(fixture_get_subset_files, expected_result)
io = FakeIOWrapper(files)
res = run_process(io, name, optimise_src_code=False)
assert res.is_ok()
assert res.inner == expected_result
assert json.loads(io.files[Path("./call_tree.json").absolute()]) == expected_result