Skip to content

Commit 14f23e0

Browse files
committed
Merge branch 'update-pre-commit' into py312
2 parents e9efa5d + ea1617f commit 14f23e0

33 files changed

+104
-70
lines changed

.pre-commit-config.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ exclude: |
77
)$
88
repos:
99
- repo: https://github.com/pre-commit/pre-commit-hooks
10-
rev: v4.4.0
10+
rev: v4.5.0
1111
hooks:
1212
- id: debug-statements
1313
exclude: |
@@ -20,23 +20,23 @@ repos:
2020
)$
2121
- id: check-merge-conflict
2222
- repo: https://github.com/asottile/pyupgrade
23-
rev: v3.3.1
23+
rev: v3.15.0
2424
hooks:
2525
- id: pyupgrade
2626
args: [--py39-plus]
2727
- repo: https://github.com/psf/black
28-
rev: 23.1.0
28+
rev: 23.12.1
2929
hooks:
3030
- id: black
3131
language_version: python3
3232
- repo: https://github.com/pycqa/flake8
33-
rev: 6.0.0
33+
rev: 7.0.0
3434
hooks:
3535
- id: flake8
3636
additional_dependencies:
3737
- flake8-comprehensions
3838
- repo: https://github.com/pycqa/isort
39-
rev: 5.12.0
39+
rev: 5.13.2
4040
hooks:
4141
- id: isort
4242
- repo: https://github.com/humitos/mirrors-autoflake.git
@@ -54,7 +54,7 @@ repos:
5454
)$
5555
args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable']
5656
- repo: https://github.com/pre-commit/mirrors-mypy
57-
rev: v1.0.0
57+
rev: v1.8.0
5858
hooks:
5959
- id: mypy
6060
language: python

pytensor/compile/debugmode.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,7 @@ def _lessbroken_deepcopy(a):
689689
else:
690690
rval = copy.deepcopy(a)
691691

692-
assert type(rval) == type(a), (type(rval), type(a))
692+
assert type(rval) is type(a), (type(rval), type(a))
693693

694694
if isinstance(rval, np.ndarray):
695695
assert rval.dtype == a.dtype
@@ -1156,7 +1156,7 @@ def __str__(self):
11561156
return str(self.__dict__)
11571157

11581158
def __eq__(self, other):
1159-
rval = type(self) == type(other)
1159+
rval = type(self) is type(other)
11601160
if rval:
11611161
# nodes are not compared because this comparison is
11621162
# supposed to be true for corresponding events that happen

pytensor/compile/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def __init__(self, fn, itypes, otypes, infer_shape):
246246
self.infer_shape = self._infer_shape
247247

248248
def __eq__(self, other):
249-
return type(self) == type(other) and self.__fn == other.__fn
249+
return type(self) is type(other) and self.__fn == other.__fn
250250

251251
def __hash__(self):
252252
return hash(type(self)) ^ hash(self.__fn)

pytensor/compile/profiling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,8 +1084,8 @@ def min_memory_generator(executable_nodes, viewed_by, view_of):
10841084
viewof_change = []
10851085
# Use to track view_of changes
10861086

1087-
viewedby_add = defaultdict(lambda: [])
1088-
viewedby_remove = defaultdict(lambda: [])
1087+
viewedby_add = defaultdict(list)
1088+
viewedby_remove = defaultdict(list)
10891089
# Use to track viewed_by changes
10901090

10911091
for var in node.outputs:

pytensor/graph/basic.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
TypeVar,
2424
Union,
2525
cast,
26+
overload,
2627
)
2728

2829
import numpy as np
@@ -718,7 +719,7 @@ def __eq__(self, other):
718719
return True
719720

720721
return (
721-
type(self) == type(other)
722+
type(self) is type(other)
722723
and self.id == other.id
723724
and self.type == other.type
724725
)
@@ -1301,9 +1302,31 @@ def clone_get_equiv(
13011302
return memo
13021303

13031304

1305+
@overload
1306+
def general_toposort(
1307+
outputs: Iterable[T],
1308+
deps: None,
1309+
compute_deps_cache: Callable[[T], Optional[Union[OrderedSet, list[T]]]],
1310+
deps_cache: Optional[dict[T, list[T]]],
1311+
clients: Optional[dict[T, list[T]]],
1312+
) -> list[T]:
1313+
...
1314+
1315+
1316+
@overload
13041317
def general_toposort(
13051318
outputs: Iterable[T],
13061319
deps: Callable[[T], Union[OrderedSet, list[T]]],
1320+
compute_deps_cache: None,
1321+
deps_cache: None,
1322+
clients: Optional[dict[T, list[T]]],
1323+
) -> list[T]:
1324+
...
1325+
1326+
1327+
def general_toposort(
1328+
outputs: Iterable[T],
1329+
deps: Optional[Callable[[T], Union[OrderedSet, list[T]]]],
13071330
compute_deps_cache: Optional[
13081331
Callable[[T], Optional[Union[OrderedSet, list[T]]]]
13091332
] = None,
@@ -1345,7 +1368,7 @@ def general_toposort(
13451368
if deps_cache is None:
13461369
deps_cache = {}
13471370

1348-
def _compute_deps_cache(io):
1371+
def _compute_deps_cache_(io):
13491372
if io not in deps_cache:
13501373
d = deps(io)
13511374

@@ -1363,6 +1386,8 @@ def _compute_deps_cache(io):
13631386
else:
13641387
return deps_cache[io]
13651388

1389+
_compute_deps_cache = _compute_deps_cache_
1390+
13661391
else:
13671392
_compute_deps_cache = compute_deps_cache
13681393

@@ -1451,15 +1476,14 @@ def io_toposort(
14511476
)
14521477
return order
14531478

1454-
compute_deps = None
1455-
compute_deps_cache = None
14561479
iset = set(inputs)
1457-
deps_cache: dict = {}
14581480

14591481
if not orderings: # ordering can be None or empty dict
14601482
# Specialized function that is faster when no ordering.
14611483
# Also include the cache in the function itself for speed up.
14621484

1485+
deps_cache: dict = {}
1486+
14631487
def compute_deps_cache(obj):
14641488
if obj in deps_cache:
14651489
return deps_cache[obj]
@@ -1478,6 +1502,14 @@ def compute_deps_cache(obj):
14781502
deps_cache[obj] = rval
14791503
return rval
14801504

1505+
topo = general_toposort(
1506+
outputs,
1507+
deps=None,
1508+
compute_deps_cache=compute_deps_cache,
1509+
deps_cache=deps_cache,
1510+
clients=clients,
1511+
)
1512+
14811513
else:
14821514
# the inputs are used only here in the function that decides what
14831515
# 'predecessors' to explore
@@ -1494,13 +1526,13 @@ def compute_deps(obj):
14941526
assert not orderings.get(obj, None)
14951527
return rval
14961528

1497-
topo = general_toposort(
1498-
outputs,
1499-
deps=compute_deps,
1500-
compute_deps_cache=compute_deps_cache,
1501-
deps_cache=deps_cache,
1502-
clients=clients,
1503-
)
1529+
topo = general_toposort(
1530+
outputs,
1531+
deps=compute_deps,
1532+
compute_deps_cache=None,
1533+
deps_cache=None,
1534+
clients=clients,
1535+
)
15041536
return [o for o in topo if isinstance(o, Apply)]
15051537

15061538

pytensor/graph/null_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def values_eq(self, a, b, force_same_dtype=True):
3333
raise ValueError("NullType has no values to compare")
3434

3535
def __eq__(self, other):
36-
return type(self) == type(other)
36+
return type(self) is type(other)
3737

3838
def __hash__(self):
3939
return hash(type(self))

pytensor/graph/rewriting/basic.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -951,8 +951,8 @@ class MetaNodeRewriter(NodeRewriter):
951951

952952
def __init__(self):
953953
self.verbose = config.metaopt__verbose
954-
self.track_dict = defaultdict(lambda: [])
955-
self.tag_dict = defaultdict(lambda: [])
954+
self.track_dict = defaultdict(list)
955+
self.tag_dict = defaultdict(list)
956956
self._tracks = []
957957
self.rewriters = []
958958

@@ -2406,13 +2406,15 @@ def importer(node):
24062406
if node is not current_node:
24072407
q.append(node)
24082408

2409-
chin = None
2409+
chin: Optional[Callable] = None
24102410
if self.tracks_on_change_inputs:
24112411

2412-
def chin(node, i, r, new_r, reason):
2412+
def chin_(node, i, r, new_r, reason):
24132413
if node is not current_node and not isinstance(node, str):
24142414
q.append(node)
24152415

2416+
chin = chin_
2417+
24162418
u = self.attach_updater(
24172419
fgraph, importer, None, chin=chin, name=getattr(self, "name", None)
24182420
)

pytensor/graph/rewriting/unify.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def __new__(cls, constraint, token=None, prefix=""):
5858
return obj
5959

6060
def __eq__(self, other):
61-
if type(self) == type(other):
62-
return self.token == other.token and self.constraint == other.constraint
61+
if type(self) is type(other):
62+
return self.token is other.token and self.constraint == other.constraint
6363
return NotImplemented
6464

6565
def __hash__(self):

pytensor/graph/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def __hash__(self):
229229
if "__eq__" not in dct:
230230

231231
def __eq__(self, other):
232-
return type(self) == type(other) and tuple(
232+
return type(self) is type(other) and tuple(
233233
getattr(self, a) for a in props
234234
) == tuple(getattr(other, a) for a in props)
235235

pytensor/ifelse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __init__(self, n_outs, as_view=False, name=None):
7878
self.name = name
7979

8080
def __eq__(self, other):
81-
if type(self) != type(other):
81+
if type(self) is not type(other):
8282
return False
8383
if self.as_view != other.as_view:
8484
return False

0 commit comments

Comments
 (0)