Skip to content

Update to work with new networkx dispatching #68

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/publish_pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
- name: Check with twine
run: python -m twine check --strict dist/*
- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@v1.8.6
uses: pypa/gh-action-pypi-publish@v1.8.10
with:
user: __token__
password: ${{ secrets.PYPI_TOKEN }}
5 changes: 3 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
activate-environment: testing
- name: Install dependencies
run: |
conda install -c conda-forge python-graphblas scipy pandas pytest-cov pytest-randomly
conda install -c conda-forge python-graphblas scipy pandas pytest-cov pytest-randomly pytest-mpl
# matplotlib lxml pygraphviz pydot sympy # Extra networkx deps we don't need yet
pip install git+https://github.com/networkx/networkx.git@main --no-deps
pip install -e . --no-deps
Expand All @@ -39,7 +39,8 @@ jobs:
python -c 'import sys, graphblas_algorithms; assert "networkx" not in sys.modules'
coverage run --branch -m pytest --color=yes -v --check-structure
coverage report
NETWORKX_GRAPH_CONVERT=graphblas pytest --color=yes --pyargs networkx --cov --cov-append
# NETWORKX_GRAPH_CONVERT=graphblas pytest --color=yes --pyargs networkx --cov --cov-append
./run_nx_tests.sh --color=yes --cov --cov-append
coverage report
coverage xml
- name: Coverage
Expand Down
30 changes: 18 additions & 12 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ci:
# See: https://pre-commit.ci/#configuration
autofix_prs: false
autoupdate_schedule: monthly
autoupdate_schedule: quarterly
skip: [no-commit-to-branch]
fail_fast: true
default_language_version:
Expand All @@ -17,21 +17,27 @@ repos:
rev: v4.4.0
hooks:
- id: check-added-large-files
- id: check-case-conflict
- id: check-merge-conflict
- id: check-symlinks
- id: check-ast
- id: check-toml
- id: check-yaml
- id: debug-statements
- id: end-of-file-fixer
exclude_types: [svg]
- id: mixed-line-ending
- id: trailing-whitespace
- id: name-tests-test
args: ["--pytest-test-first"]
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.13
rev: v0.14
hooks:
- id: validate-pyproject
name: Validate pyproject.toml
# I don't yet trust ruff to do what autoflake does
- repo: https://github.com/PyCQA/autoflake
rev: v2.1.1
rev: v2.2.0
hooks:
- id: autoflake
args: [--in-place]
Expand All @@ -40,7 +46,7 @@ repos:
hooks:
- id: isort
- repo: https://github.com/asottile/pyupgrade
rev: v3.4.0
rev: v3.10.1
hooks:
- id: pyupgrade
args: [--py38-plus]
Expand All @@ -50,38 +56,38 @@ repos:
- id: auto-walrus
args: [--line-length, "100"]
- repo: https://github.com/psf/black
rev: 23.3.0
rev: 23.7.0
hooks:
- id: black
# - id: black-jupyter
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.270
rev: v0.0.285
hooks:
- id: ruff
args: [--fix-only, --show-fixes]
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
rev: 6.1.0
hooks:
- id: flake8
additional_dependencies: &flake8_dependencies
# These versions need updated manually
- flake8==6.0.0
- flake8-bugbear==23.5.9
- flake8==6.1.0
- flake8-bugbear==23.7.10
- flake8-simplify==0.20.0
- repo: https://github.com/asottile/yesqa
rev: v1.4.0
rev: v1.5.0
hooks:
- id: yesqa
additional_dependencies: *flake8_dependencies
- repo: https://github.com/codespell-project/codespell
rev: v2.2.4
rev: v2.2.5
hooks:
- id: codespell
types_or: [python, rst, markdown]
additional_dependencies: [tomli]
files: ^(graphblas_algorithms|docs)/
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.270
rev: v0.0.285
hooks:
- id: ruff
# `pyroma` may help keep our package standards up to date if best practices change.
Expand Down
69 changes: 60 additions & 9 deletions graphblas_algorithms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,20 +171,71 @@ class Dispatcher:
# End auto-generated code: dispatch

@staticmethod
def convert_from_nx(graph, weight=None, *, name=None):
def convert_from_nx(
graph,
edge_attrs=None,
node_attrs=None,
preserve_edge_attrs=False,
preserve_node_attrs=False,
preserve_graph_attrs=False,
name=None,
graph_name=None,
*,
weight=None, # For nx.__version__ <= 3.1
):
import networkx as nx

from .classes import DiGraph, Graph, MultiDiGraph, MultiGraph

if preserve_edge_attrs:
if graph.is_multigraph():
attrs = set().union(
*(
datadict
for nbrs in graph._adj.values()
for keydict in nbrs.values()
for datadict in keydict.values()
)
)
else:
attrs = set().union(
*(datadict for nbrs in graph._adj.values() for datadict in nbrs.values())
)
if len(attrs) == 1:
[attr] = attrs
edge_attrs = {attr: None}
elif attrs:
raise NotImplementedError("`preserve_edge_attrs=True` is not fully implemented")
if node_attrs:
raise NotImplementedError("non-None `node_attrs` is not yet implemented")
if preserve_node_attrs:
attrs = set().union(*(datadict for node, datadict in graph.nodes(data=True)))
if attrs:
raise NotImplementedError("`preserve_node_attrs=True` is not implemented")
if edge_attrs:
if len(edge_attrs) > 1:
raise NotImplementedError(
"Multiple edge attributes is not implemented (bad value for edge_attrs)"
)
if weight is not None:
raise TypeError("edge_attrs and weight both given")
[[weight, default]] = edge_attrs.items()
if default is not None and default != 1:
raise NotImplementedError(f"edge default != 1 is not implemented; got {default}")

if isinstance(graph, nx.MultiDiGraph):
return MultiDiGraph.from_networkx(graph, weight=weight)
if isinstance(graph, nx.MultiGraph):
return MultiGraph.from_networkx(graph, weight=weight)
if isinstance(graph, nx.DiGraph):
return DiGraph.from_networkx(graph, weight=weight)
if isinstance(graph, nx.Graph):
return Graph.from_networkx(graph, weight=weight)
raise TypeError(f"Unsupported type of graph: {type(graph)}")
G = MultiDiGraph.from_networkx(graph, weight=weight)
elif isinstance(graph, nx.MultiGraph):
G = MultiGraph.from_networkx(graph, weight=weight)
elif isinstance(graph, nx.DiGraph):
G = DiGraph.from_networkx(graph, weight=weight)
elif isinstance(graph, nx.Graph):
G = Graph.from_networkx(graph, weight=weight)
else:
raise TypeError(f"Unsupported type of graph: {type(graph)}")
if preserve_graph_attrs:
G.graph.update(graph.graph)
return G

@staticmethod
def convert_to_nx(obj, *, name=None):
Expand Down
24 changes: 21 additions & 3 deletions graphblas_algorithms/tests/test_match_nx.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,29 @@
"Matching networkx namespace requires networkx to be installed", allow_module_level=True
)
else:
from networkx.classes import backends # noqa: F401
try:
from networkx.utils import backends

IS_NX_30_OR_31 = False
except ImportError: # pragma: no cover (import)
# This is the location in nx 3.1
from networkx.classes import backends # noqa: F401

IS_NX_30_OR_31 = True


def isdispatched(func):
"""Can this NetworkX function dispatch to other backends?"""
if IS_NX_30_OR_31:
return (
callable(func)
and hasattr(func, "dispatchname")
and func.__module__.startswith("networkx")
)
return (
callable(func) and hasattr(func, "dispatchname") and func.__module__.startswith("networkx")
callable(func)
and hasattr(func, "preserve_edge_attrs")
and func.__module__.startswith("networkx")
)


Expand All @@ -37,7 +53,9 @@ def dispatchname(func):
# Haha, there should be a better way to get this
if not isdispatched(func):
raise ValueError(f"Function is not dispatched in NetworkX: {func.__name__}")
return func.dispatchname
if IS_NX_30_OR_31:
return func.dispatchname
return func.name


def fullname(func):
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -214,12 +214,14 @@ ignore = [
"RET502", # Do not implicitly `return None` in function able to return non-`None` value
"RET503", # Missing explicit `return` at the end of function able to return non-`None` value
"RET504", # Unnecessary variable assignment before `return` statement
"RUF012", # Mutable class attributes should be annotated with `typing.ClassVar` (Note: no annotations yet)
"S110", # `try`-`except`-`pass` detected, consider logging the exception (Note: good advice, but we don't log)
"S112", # `try`-`except`-`continue` detected, consider logging the exception (Note: good advice, but we don't log)
"SIM102", # Use a single `if` statement instead of nested `if` statements (Note: often necessary)
"SIM105", # Use contextlib.suppress(...) instead of try-except-pass (Note: try-except-pass is much faster)
"SIM108", # Use ternary operator ... instead of if-else-block (Note: if-else better for coverage and sometimes clearer)
"TRY003", # Avoid specifying long messages outside the exception class (Note: why?)
"FIX001", "FIX002", "FIX003", "FIX004", # flake8-fixme (like flake8-todos)

# Ignored categories
"C90", # mccabe (Too strict, but maybe we should make things less complex)
Expand Down
7 changes: 5 additions & 2 deletions run_nx_tests.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#!/bin/bash
NETWORKX_GRAPH_CONVERT=graphblas pytest --pyargs networkx "$@"
# NETWORKX_GRAPH_CONVERT=graphblas pytest --pyargs networkx --cov --cov-report term-missing "$@"
NETWORKX_GRAPH_CONVERT=graphblas \
NETWORKX_TEST_BACKEND=graphblas \
NETWORKX_FALLBACK_TO_NX=True \
pytest --pyargs networkx "$@"
# pytest --pyargs networkx --cov --cov-report term-missing "$@"
4 changes: 2 additions & 2 deletions scripts/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

datapaths = [
Path(__file__).parent / ".." / "data",
Path("."),
Path(),
]


Expand All @@ -37,7 +37,7 @@ def find_data(dataname):
if dataname not in download_data.data_urls:
raise FileNotFoundError(f"Unable to find data file for {dataname}")
curpath = Path(download_data.main([dataname])[0])
return curpath.resolve().relative_to(Path(".").resolve())
return curpath.resolve().relative_to(Path().resolve())


def get_symmetry(file_or_mminfo):
Expand Down
2 changes: 1 addition & 1 deletion scripts/download_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def main(datanames, overwrite=False):
for name in datanames:
target = datapath / f"{name}.mtx"
filenames.append(target)
relpath = target.resolve().relative_to(Path(".").resolve())
relpath = target.resolve().relative_to(Path().resolve())
if not overwrite and target.exists():
print(f"{relpath} already exists; skipping", file=sys.stderr)
continue
Expand Down