Skip to content

Commit

Permalink
schedule: update types
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener authored and inducer committed Nov 13, 2024
1 parent 920cb49 commit 57f6654
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 60 deletions.
85 changes: 38 additions & 47 deletions loopy/schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,11 @@

import logging
import sys
from collections.abc import Hashable, Iterator, Mapping, Sequence, Set
from dataclasses import dataclass, replace
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Dict,
FrozenSet,
Hashable,
Iterator,
Mapping,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
)

Expand Down Expand Up @@ -155,7 +146,7 @@ class Barrier(ScheduleItem):

def gather_schedule_block(
schedule: Sequence[ScheduleItem], start_idx: int
) -> Tuple[Sequence[ScheduleItem], int]:
) -> tuple[Sequence[ScheduleItem], int]:
assert isinstance(schedule[start_idx], BeginBlockItem)
level = 0

Expand All @@ -176,7 +167,7 @@ def gather_schedule_block(

def generate_sub_sched_items(
schedule: Sequence[ScheduleItem], start_idx: int
) -> Iterator[Tuple[int, ScheduleItem]]:
) -> Iterator[tuple[int, ScheduleItem]]:
if not isinstance(schedule[start_idx], BeginBlockItem):
yield start_idx, schedule[start_idx]

Expand All @@ -203,7 +194,7 @@ def generate_sub_sched_items(

def get_insn_ids_for_block_at(
schedule: Sequence[ScheduleItem], start_idx: int
) -> FrozenSet[str]:
) -> frozenset[str]:
return frozenset(
sub_sched_item.insn_id
for i, sub_sched_item in generate_sub_sched_items(
Expand All @@ -212,7 +203,7 @@ def get_insn_ids_for_block_at(


def find_used_inames_within(
kernel: LoopKernel, sched_index: int) -> AbstractSet[str]:
kernel: LoopKernel, sched_index: int) -> set[str]:
assert kernel.linearization is not None
sched_item = kernel.linearization[sched_index]

Expand All @@ -234,7 +225,7 @@ def find_used_inames_within(
return result


def find_loop_nest_with_map(kernel: LoopKernel) -> Mapping[str, AbstractSet[str]]:
def find_loop_nest_with_map(kernel: LoopKernel) -> Mapping[str, set[str]]:
"""Returns a dictionary mapping inames to other inames that are
always nested with them.
"""
Expand All @@ -257,11 +248,11 @@ def find_loop_nest_with_map(kernel: LoopKernel) -> Mapping[str, AbstractSet[str]
return result


def find_loop_nest_around_map(kernel: LoopKernel) -> Mapping[str, AbstractSet[str]]:
def find_loop_nest_around_map(kernel: LoopKernel) -> Mapping[str, set[str]]:
"""Returns a dictionary mapping inames to other inames that are
always nested around them.
"""
result: Dict[str, Set[str]] = {}
result: dict[str, set[str]] = {}

all_inames = kernel.all_inames()

Expand Down Expand Up @@ -299,14 +290,14 @@ def find_loop_nest_around_map(kernel: LoopKernel) -> Mapping[str, AbstractSet[st

def find_loop_insn_dep_map(
kernel: LoopKernel,
loop_nest_with_map: Mapping[str, AbstractSet[str]],
loop_nest_around_map: Mapping[str, AbstractSet[str]]
) -> Mapping[str, AbstractSet[str]]:
loop_nest_with_map: Mapping[str, Set[str]],
loop_nest_around_map: Mapping[str, Set[str]]
) -> Mapping[str, set[str]]:
"""Returns a dictionary mapping inames to other instruction ids that need to
be scheduled before the iname should be eligible for scheduling.
"""

result: Dict[str, Set[str]] = {}
result: dict[str, set[str]] = {}

from loopy.kernel.data import ConcurrentTag, IlpBaseTag
for insn in kernel.instructions:
Expand Down Expand Up @@ -372,7 +363,7 @@ def find_loop_insn_dep_map(


def group_insn_counts(kernel: LoopKernel) -> Mapping[str, int]:
result: Dict[str, int] = {}
result: dict[str, int] = {}

for insn in kernel.instructions:
for grp in insn.groups:
Expand All @@ -382,7 +373,7 @@ def group_insn_counts(kernel: LoopKernel) -> Mapping[str, int]:


def gen_dependencies_except(
kernel: LoopKernel, insn_id: str, except_insn_ids: AbstractSet[str]
kernel: LoopKernel, insn_id: str, except_insn_ids: Set[str]
) -> Iterator[str]:
insn = kernel.id_to_insn[insn_id]
for dep_id in insn.depends_on:
Expand All @@ -396,9 +387,9 @@ def gen_dependencies_except(


def get_priority_tiers(
wanted: AbstractSet[int],
priorities: AbstractSet[Sequence[int]]
) -> Iterator[AbstractSet[int]]:
wanted: Set[int],
priorities: Set[Sequence[int]]
) -> Iterator[set[int]]:
# Get highest priority tier candidates: These are the first inames
# of all the given priority constraints
candidates = set()
Expand Down Expand Up @@ -677,32 +668,32 @@ class SchedulerState:
order with instruction priorities as tie breaker.
"""
kernel: LoopKernel
loop_nest_around_map: Mapping[str, AbstractSet[str]]
loop_insn_dep_map: Mapping[str, AbstractSet[str]]
loop_nest_around_map: Mapping[str, set[str]]
loop_insn_dep_map: Mapping[str, set[str]]

breakable_inames: AbstractSet[str]
ilp_inames: AbstractSet[str]
vec_inames: AbstractSet[str]
concurrent_inames: AbstractSet[str]
breakable_inames: set[str]
ilp_inames: set[str]
vec_inames: set[str]
concurrent_inames: set[str]

insn_ids_to_try: Optional[AbstractSet[str]]
insn_ids_to_try: set[str] | None
active_inames: Sequence[str]
entered_inames: FrozenSet[str]
enclosing_subkernel_inames: Tuple[str, ...]
entered_inames: frozenset[str]
enclosing_subkernel_inames: tuple[str, ...]
schedule: Sequence[ScheduleItem]
scheduled_insn_ids: AbstractSet[str]
unscheduled_insn_ids: AbstractSet[str]
scheduled_insn_ids: frozenset[str]
unscheduled_insn_ids: set[str]
preschedule: Sequence[ScheduleItem]
prescheduled_insn_ids: AbstractSet[str]
prescheduled_inames: AbstractSet[str]
prescheduled_insn_ids: set[str]
prescheduled_inames: set[str]
may_schedule_global_barriers: bool
within_subkernel: bool
group_insn_counts: Mapping[str, int]
active_group_counts: Mapping[str, int]
insns_in_topologically_sorted_order: Sequence[InstructionBase]

@property
def last_entered_loop(self) -> Optional[str]:
def last_entered_loop(self) -> str | None:
if self.active_inames:
return self.active_inames[-1]
else:
Expand All @@ -718,7 +709,7 @@ def get_insns_in_topologically_sorted_order(
kernel: LoopKernel) -> Sequence[InstructionBase]:
from pytools.graph import compute_topological_order

rev_dep_map: Dict[str, Set[str]] = {
rev_dep_map: dict[str, set[str]] = {
not_none(insn.id): set() for insn in kernel.instructions}
for insn in kernel.instructions:
for dep in insn.depends_on:
Expand All @@ -733,7 +724,7 @@ def get_insns_in_topologically_sorted_order(
# Instead of returning these features as a key, we assign an id to
# each set of features to avoid comparing them which can be expensive.
insn_id_to_feature_id = {}
insn_features: Dict[Hashable, int] = {}
insn_features: dict[Hashable, int] = {}
for insn in kernel.instructions:
feature = (insn.within_inames, insn.groups, insn.conflicts_with_groups)
if feature not in insn_features:
Expand Down Expand Up @@ -890,7 +881,7 @@ def _get_outermost_diverging_inames(
tree: LoopTree,
within1: InameStrSet,
within2: InameStrSet
) -> Tuple[InameStr, InameStr]:
) -> tuple[InameStr, InameStr]:
"""
For loop nestings *within1* and *within2*, returns the first inames at which
the loops nests diverge in the loop nesting tree *tree*.
Expand Down Expand Up @@ -2180,7 +2171,7 @@ def __init__(self, kernel):
def generate_loop_schedules(
kernel: LoopKernel,
callables_table: CallablesTable,
debug_args: Optional[Dict[str, Any]] = None) -> Iterator[LoopKernel]:
debug_args: Mapping[str, Any] | None = None) -> Iterator[LoopKernel]:
"""
.. warning::
Expand Down Expand Up @@ -2236,7 +2227,7 @@ def _postprocess_schedule(kernel, callables_table, gen_sched):
def _generate_loop_schedules_inner(
kernel: LoopKernel,
callables_table: CallablesTable,
debug_args: Optional[Dict[str, Any]]) -> Iterator[LoopKernel]:
debug_args: Mapping[str, Any] | None) -> Iterator[LoopKernel]:
if debug_args is None:
debug_args = {}

Expand Down Expand Up @@ -2337,7 +2328,7 @@ def _generate_loop_schedules_inner(
get_insns_in_topologically_sorted_order(kernel)),
)

schedule_gen_kwargs: Dict[str, Any] = {}
schedule_gen_kwargs: dict[str, Any] = {}

def print_longest_dead_end():
if debug.interactive:
Expand Down Expand Up @@ -2402,7 +2393,7 @@ def print_longest_dead_end():


schedule_cache: WriteOncePersistentDict[
Tuple[LoopKernel, CallablesTable],
tuple[LoopKernel, CallablesTable],
LoopKernel
] = WriteOncePersistentDict(
"loopy-schedule-cache-v4-"+DATA_MODEL_VERSION,
Expand Down
23 changes: 12 additions & 11 deletions loopy/schedule/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@
THE SOFTWARE.
"""

from collections.abc import Hashable, Iterator, Sequence
from dataclasses import dataclass
from functools import cached_property
from typing import Generic, Hashable, Iterator, List, Optional, Sequence, Tuple, TypeVar
from typing import Generic, TypeVar

from immutables import Map

Expand Down Expand Up @@ -70,11 +71,11 @@ class Tree(Generic[NodeT]):
this allocates a new stack frame for each iteration of the operation.
"""

_parent_to_children: Map[NodeT, Tuple[NodeT, ...]]
_child_to_parent: Map[NodeT, Optional[NodeT]]
_parent_to_children: Map[NodeT, tuple[NodeT, ...]]
_child_to_parent: Map[NodeT, NodeT | None]

@staticmethod
def from_root(root: NodeT) -> "Tree[NodeT]":
def from_root(root: NodeT) -> Tree[NodeT]:
return Tree(Map({root: ()}),
Map({root: None}))

Expand All @@ -89,7 +90,7 @@ def root(self) -> NodeT:
return guess

@memoize_method
def ancestors(self, node: NodeT) -> Tuple[NodeT, ...]:
def ancestors(self, node: NodeT) -> tuple[NodeT, ...]:
"""
Returns a :class:`tuple` of nodes that are ancestors of *node*.
"""
Expand All @@ -104,15 +105,15 @@ def ancestors(self, node: NodeT) -> Tuple[NodeT, ...]:

return (parent,) + self.ancestors(parent)

def parent(self, node: NodeT) -> Optional[NodeT]:
def parent(self, node: NodeT) -> NodeT | None:
"""
Returns the parent of *node*.
"""
assert node in self

return self._child_to_parent[node]

def children(self, node: NodeT) -> Tuple[NodeT, ...]:
def children(self, node: NodeT) -> tuple[NodeT, ...]:
"""
Returns the children of *node*.
"""
Expand Down Expand Up @@ -150,7 +151,7 @@ def __contains__(self, node: NodeT) -> bool:
"""Return *True* if *node* is a node in the tree."""
return node in self._child_to_parent

def add_node(self, node: NodeT, parent: NodeT) -> "Tree[NodeT]":
def add_node(self, node: NodeT, parent: NodeT) -> Tree[NodeT]:
"""
Returns a :class:`Tree` with added node *node* having a parent
*parent*.
Expand All @@ -165,7 +166,7 @@ def add_node(self, node: NodeT, parent: NodeT) -> "Tree[NodeT]":
.set(node, ())),
self._child_to_parent.set(node, parent))

def replace_node(self, node: NodeT, new_node: NodeT) -> "Tree[NodeT]":
def replace_node(self, node: NodeT, new_node: NodeT) -> Tree[NodeT]:
"""
Returns a copy of *self* with *node* replaced with *new_node*.
"""
Expand Down Expand Up @@ -207,7 +208,7 @@ def replace_node(self, node: NodeT, new_node: NodeT) -> "Tree[NodeT]":
return Tree(parent_to_children_mut.finish(),
child_to_parent_mut.finish())

def move_node(self, node: NodeT, new_parent: Optional[NodeT]) -> "Tree[NodeT]":
def move_node(self, node: NodeT, new_parent: NodeT | None) -> Tree[NodeT]:
"""
Returns a copy of *self* with node *node* as a child of *new_parent*.
"""
Expand Down Expand Up @@ -262,7 +263,7 @@ def __str__(self) -> str:
├── D
└── E
"""
def rec(node: NodeT) -> List[str]:
def rec(node: NodeT) -> list[str]:
children_result = [rec(c) for c in self.children(node)]

def post_process_non_last_child(children: Sequence[str]) -> list[str]:
Expand Down
5 changes: 3 additions & 2 deletions loopy/transform/precompute.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ def storage_axis_exprs(storage_axis_sources, args) -> Sequence[ExpressionT]:
# {{{ gather rule invocations

class RuleInvocationGatherer(RuleAwareIdentityMapper):
def __init__(self, rule_mapping_context, kernel, subst_name, subst_tag, within):
def __init__(self, rule_mapping_context, kernel, subst_name, subst_tag, within) \
-> None:
super().__init__(rule_mapping_context)

from loopy.symbolic import SubstitutionRuleExpander
Expand All @@ -167,7 +168,7 @@ def __init__(self, rule_mapping_context, kernel, subst_name, subst_tag, within):
self.subst_tag = subst_tag
self.within = within

self.access_descriptors: List[RuleAccessDescriptor] = []
self.access_descriptors: list[RuleAccessDescriptor] = []

def map_substitution(self, name, tag, arguments, expn_state):
process_me = name == self.subst_name
Expand Down

0 comments on commit 57f6654

Please sign in to comment.