Skip to content

Commit

Permalink
refactor(stronger mypy checks): remove redundant casts
Browse files Browse the repository at this point in the history
  • Loading branch information
goerlibe committed Sep 11, 2023
1 parent 89971d0 commit 8f1acce
Show file tree
Hide file tree
Showing 22 changed files with 83 additions and 155 deletions.
16 changes: 7 additions & 9 deletions discopop_explorer/PETGraphX.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,7 @@ def get_nesting_level(self, pet: PETGraphX, return_invert_result: bool = True) -
parent_nesting_levels.append(
min(
2,
cast(LoopNode, parent_node).get_nesting_level(
pet, return_invert_result=False
),
parent_node.get_nesting_level(pet, return_invert_result=False),
)
)

Expand All @@ -322,7 +320,7 @@ def get_entry_node(self, pet: PETGraphX) -> Optional[Node]:
if pet.node_at(s) not in pet.direct_children(self)
]
if len(predecessors_outside_loop_body) > 0:
return cast(Node, node)
return node
return None


Expand Down Expand Up @@ -361,7 +359,7 @@ def get_entry_cu_id(self, pet: PETGraphX) -> NodeID:
def get_exit_cu_ids(self, pet: PETGraphX) -> Set[NodeID]:
exit_cu_ids: Set[NodeID] = set()
if self.children_cu_ids is not None:
for child_cu_id in cast(List[NodeID], self.children_cu_ids):
for child_cu_id in self.children_cu_ids:
if (
len(pet.out_edges(child_cu_id, EdgeType.SUCCESSOR)) == 0
and len(pet.in_edges(child_cu_id, EdgeType.SUCCESSOR)) != 0
Expand Down Expand Up @@ -1154,7 +1152,7 @@ def unused_check_alias(self, s: NodeID, t: NodeID, d: Dependency, root_loop: Nod
parent_func_source = self.get_parent_function(self.node_at(t))

res = False
d_var_name_str = cast(str, str(d.var_name))
d_var_name_str = str(d.var_name)

if self.unused_is_global(d_var_name_str, sub) and not (
self.is_passed_by_reference(d, parent_func_sink)
Expand Down Expand Up @@ -1278,7 +1276,7 @@ def get_variables(self, nodes: Sequence[Node]) -> Dict[Variable, Set[MemoryRegio
):
if dep.var_name == var_name.name:
if dep.memory_region is not None:
res[var_name].add(cast(MemoryRegion, dep.memory_region))
res[var_name].add(dep.memory_region)
return res

def get_undefined_variables_inside_loop(
Expand Down Expand Up @@ -1559,7 +1557,7 @@ def check_reachability(self, target: Node, source: Node, edge_types: List[EdgeTy
return True
else:
if e[0] not in visited:
queue.append(cast(Node, self.node_at(e[0])))
queue.append(self.node_at(e[0]))
return False

def is_predecessor(self, source_id: NodeID, target_id: NodeID) -> bool:
Expand Down Expand Up @@ -1696,7 +1694,7 @@ def get_memory_regions(self, nodes: List[CUNode], var_name: str) -> Set[MemoryRe
for s, t, d in out_deps:
if d.var_name == var_name:
if d.memory_region is not None:
mem_regs.add(cast(MemoryRegion, d.memory_region))
mem_regs.add(d.memory_region)
return mem_regs

def get_path_nodes_between(
Expand Down
2 changes: 1 addition & 1 deletion discopop_explorer/generate_Data_CUInst.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __output_dependencies_of_type(
+ dep_identifier
+ cast(str, dep[2].source_line)
+ "|"
+ cast(str, dep[2].var_name)
+ dep[2].var_name
+ ","
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,8 @@ def get_cu_and_varname_to_memory_regions(
if dep.var_name is None or dep.memory_region is None or len(dep.memory_region) == 0:
continue
if dep.var_name not in result_dict[cu_id]:
result_dict[cu_id][VarName(cast(str, dep.var_name))] = set()
result_dict[cu_id][VarName(cast(str, dep.var_name))].add(
MemoryRegion(cast(str, dep.memory_region))
)
result_dict[cu_id][VarName(dep.var_name)] = set()
result_dict[cu_id][VarName(dep.var_name)].add(dep.memory_region)

return result_dict

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,16 +313,16 @@ def calculate_host_liveness(
for _, target, dep in out_data_edges:
if target in comb_gpu_reg.device_cu_ids:
if dep.var_name is not None:
shared_variables.add(VarName(cast(str, dep.var_name)))
shared_variables.add(VarName(dep.var_name))
if dep.memory_region is not None:
shared_memory_regions.add(MemoryRegion(cast(str, dep.memory_region)))
shared_memory_regions.add(MemoryRegion(dep.memory_region))

for source, _, dep in in_data_edges:
if source in comb_gpu_reg.device_cu_ids:
if dep.var_name is not None:
shared_variables.add(VarName(cast(str, dep.var_name)))
shared_variables.add(VarName(dep.var_name))
if dep.memory_region is not None:
shared_memory_regions.add(MemoryRegion(cast(str, dep.memory_region)))
shared_memory_regions.add(dep.memory_region)
for var_name in shared_variables:
if var_name not in host_liveness_lists:
host_liveness_lists[var_name] = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,9 @@ def propagate_writes(
pet_node = pet.node_at(cu_id)
if type(pet_node) != CUNode:
continue
cu_node = cast(CUNode, pet_node)
if cu_node.return_instructions_count > 0:
if pet_node.return_instructions_count > 0:
# propagate write to calling cus
parent_function = pet.get_parent_function(cu_node)
parent_function = pet.get_parent_function(pet_node)
callees = [
s for s, t, d in pet.in_edges(parent_function.id, EdgeType.CALLSNODE)
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def identify_updates(
# get parent functions
parent_functions: Set[NodeID] = set()
for region in comb_gpu_reg.contained_regions:
parent_functions.add(cast(NodeID, pet.get_parent_function(pet.node_at(region.node_id)).id))
parent_functions.add(pet.get_parent_function(pet.node_at(region.node_id)).id)

for parent_function_id in parent_functions:
print("IDENTIFY UPDATES FOR: ", pet.node_at(parent_function_id).name, file=sys.stderr)
Expand All @@ -410,7 +410,7 @@ def identify_updates(
]:
in_successor_edges = pet.in_edges(function_child_id, EdgeType.SUCCESSOR)
if len(in_successor_edges) == 0 and pet.node_at(function_child_id).type == NodeType.CU:
entry_points.append(cast(NodeID, function_child_id))
entry_points.append(function_child_id)

for entry_point in entry_points:
print(
Expand Down Expand Up @@ -780,7 +780,7 @@ def identify_updates_in_unrolled_function_graphs(
# get parent functions
parent_functions: Set[NodeID] = set()
for region in comb_gpu_reg.contained_regions:
parent_functions.add(cast(NodeID, pet.get_parent_function(pet.node_at(region.node_id)).id))
parent_functions.add(pet.get_parent_function(pet.node_at(region.node_id)).id)

for parent_function_id in parent_functions:
print("IDENTIFY UPDATES FOR: ", pet.node_at(parent_function_id).name, file=sys.stderr)
Expand All @@ -791,7 +791,7 @@ def identify_updates_in_unrolled_function_graphs(
]:
in_successor_edges = pet.in_edges(function_child_id, EdgeType.SUCCESSOR)
if len(in_successor_edges) == 0 and pet.node_at(function_child_id).type == NodeType.CU:
entry_points.append(cast(NodeID, function_child_id))
entry_points.append(function_child_id)

for entry_point in entry_points:
print(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ def setCollapseClause(self, pet: PETGraphX, node_id: NodeID, res):
# information from the LLVM debug information
# check if distance between first CU of node_id and cn_id is 2 steps on the successor graph
potentials: Set[Node] = set()
for succ1 in pet.direct_successors(cast(Node, loop_entry_node)):
for succ1 in pet.direct_successors(loop_entry_node):
for succ2 in pet.direct_successors(succ1):
potentials.add(succ2)
if cast(LoopNode, cn_id).get_entry_node(pet) in potentials:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def determineDataMapping(self) -> None:
if self.pet.node_at(source_cu_id) not in region_cus:
if dep.dtype == DepType.RAW:
if dep.var_name not in consumed_vars and dep.var_name is not None:
consumed_vars.append(cast(str, dep.var_name))
consumed_vars.append(dep.var_name)

# determine variables which are read afterwards and written in the region
produced_vars: List[str] = []
Expand All @@ -297,7 +297,7 @@ def determineDataMapping(self) -> None:
if self.pet.node_at(sink_cu_id) not in region_cus:
if dep.dtype in [DepType.RAW, DepType.WAW]:
if dep.var_name not in produced_vars and dep.var_name is not None:
produced_vars.append(cast(str, dep.var_name))
produced_vars.append(dep.var_name)

# gather consumed, produced, allocated and deleted variables from mapping information
map_to_vars: List[str] = []
Expand Down Expand Up @@ -351,10 +351,9 @@ def old_mapData(self) -> None:

for i in range(0, self.numRegions):
for j in self.cascadingLoopsInRegions[i]:
tmp_result = self.findGPULoop(j)
if tmp_result is None:
gpuLoop = self.findGPULoop(j)
if gpuLoop is None:
continue
gpuLoop: GPULoopPattern = cast(GPULoopPattern, tmp_result)
gpuLoop.printGPULoop()
print(f"==============================================")
for i in range(0, self.numRegions):
Expand All @@ -370,15 +369,14 @@ def old_mapData(self) -> None:
)
visitedVars: Set[Variable] = set()
while t >= 0:
tmp_result = self.findGPULoop(
loopIter = self.findGPULoop(
self.cascadingLoopsInRegions[i][t]
) # loopIter contains GPU loops inside the parent region

if tmp_result is None:
if loopIter is None:
t -= 1
continue

loopIter: GPULoopPattern = cast(GPULoopPattern, tmp_result)
varis: Set[Variable] = set([])
varis.update(loopIter.map_type_alloc)
varis.update(loopIter.map_type_to)
Expand Down
22 changes: 7 additions & 15 deletions discopop_explorer/pattern_detectors/task_parallelism/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def __filter_data_sharing_clauses_suppress_shared_loop_index(
# only consider task suggestions
if type(suggestion) != TaskParallelismInfo:
continue
suggestion = cast(TaskParallelismInfo, suggestion)
if suggestion.type is not TPIType.TASK:
continue
# get parent loops of suggestion
Expand Down Expand Up @@ -92,7 +91,6 @@ def __filter_data_sharing_clauses_by_function(
# only consider task suggestions
if type(suggestion) != TaskParallelismInfo:
continue
suggestion = cast(TaskParallelismInfo, suggestion)
if suggestion.type not in [TPIType.TASK, TPIType.TASKLOOP]:
continue
# get function containing the task cu
Expand Down Expand Up @@ -302,7 +300,6 @@ def __filter_data_sharing_clauses_by_scope(
# only consider task suggestions
if type(suggestion) != TaskParallelismInfo:
continue
suggestion = cast(TaskParallelismInfo, suggestion)
if suggestion.type is not TPIType.TASK:
continue
# get function containing the task cu
Expand Down Expand Up @@ -348,15 +345,14 @@ def __filter_sharing_clause(
# accept global vars
continue
# get CU which contains var_def_line
optional_var_def_cu: Optional[Node] = None
var_def_cu: Optional[Node] = None
for child_cu in get_cus_inside_function(pet, parent_function_cu):
if line_contained_in_region(
var_def_line, child_cu.start_position(), child_cu.end_position()
):
optional_var_def_cu = child_cu
if optional_var_def_cu is None:
var_def_cu = child_cu
if var_def_cu is None:
continue
var_def_cu = cast(Node, optional_var_def_cu)
# 1. check control flow (reverse BFS from suggestion._node to parent_function
if __reverse_reachable_w_o_breaker(
pet, pet.node_at(suggestion.node_id), parent_function_cu, var_def_cu, []
Expand Down Expand Up @@ -434,16 +430,14 @@ def remove_duplicate_data_sharing_clauses(suggestions: List[PatternInfo]) -> Lis
:param suggestions: List[PatternInfo]
:return: Modified List of PatternInfos
"""
result = []
result: List[PatternInfo] = []
for s in suggestions:
if not type(s) == TaskParallelismInfo:
result.append(s)
else:
s_tpi = cast(TaskParallelismInfo, s)
s_tpi.in_dep = list(set(s_tpi.in_dep))
s_tpi.out_dep = list(set(s_tpi.out_dep))
s_tpi.in_out_dep = list(set(s_tpi.in_out_dep))
s = cast(PatternInfo, s_tpi)
s.in_dep = list(set(s.in_dep))
s.out_dep = list(set(s.out_dep))
s.in_out_dep = list(set(s.in_out_dep))
result.append(s)
return result

Expand Down Expand Up @@ -716,7 +710,6 @@ def filter_data_depend_clauses(
# only consider task suggestions
if type(suggestion) != TaskParallelismInfo:
continue
suggestion = cast(TaskParallelismInfo, suggestion)
if suggestion.type not in [TPIType.TASK, TPIType.TASKLOOP]:
continue
for var in suggestion.in_dep:
Expand All @@ -741,7 +734,6 @@ def filter_data_depend_clauses(
# only consider task suggestions
if type(suggestion) != TaskParallelismInfo:
continue
suggestion = cast(TaskParallelismInfo, suggestion)
if suggestion.type not in [TPIType.TASK, TPIType.TASKLOOP]:
continue
# get function containing the task cu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,12 @@ def group_task_suggestions(pet: PETGraphX, suggestions: List[PatternInfo]) -> Li
:return: Updated suggestions"""
task_suggestions = [
s
for s in [
cast(TaskParallelismInfo, e) for e in suggestions if type(e) == TaskParallelismInfo
]
for s in [e for e in suggestions if type(e) == TaskParallelismInfo]
if s.type is TPIType.TASK
]
taskwait_suggestions = [
s
for s in [
cast(TaskParallelismInfo, e) for e in suggestions if type(e) == TaskParallelismInfo
]
for s in [e for e in suggestions if type(e) == TaskParallelismInfo]
if s.type is TPIType.TASKWAIT
]
# mark preceding suggestions for each taskwait suggestion
Expand Down Expand Up @@ -116,7 +112,7 @@ def group_task_suggestions(pet: PETGraphX, suggestions: List[PatternInfo]) -> Li
)
# valid, overwrite sug.task_group if value is not None
if value is not None:
sug.task_group = [cast(int, value)]
sug.task_group = [value]
return suggestions


Expand All @@ -130,30 +126,13 @@ def sort_output(suggestions: List[PatternInfo]) -> List[PatternInfo]:
sorted_suggestions = []
tmp_dict: Dict[str, List[Tuple[str, PatternInfo]]] = dict()
for sug in suggestions:
if type(sug) == ParallelRegionInfo:
sug_par = cast(ParallelRegionInfo, sug)
# Note: Code duplicated for type correctness
# get start_line and file_id for sug
if ":" not in sug_par.region_start_line:
start_line = sug_par.region_start_line
file_id = sug_par.start_line[0 : sug_par.start_line.index(":")]
else:
start_line = sug_par.region_start_line
file_id = start_line[0 : start_line.index(":")]
start_line = start_line[start_line.index(":") + 1 :]
# split suggestions by file-id
if file_id not in tmp_dict:
tmp_dict[file_id] = []
tmp_dict[file_id].append((start_line, sug))
elif type(sug) == TaskParallelismInfo:
sug_task = cast(TaskParallelismInfo, sug)
# Note: Code duplicated for type correctness
if isinstance(sug, (ParallelRegionInfo, TaskParallelismInfo)):
# get start_line and file_id for sug
if ":" not in sug_task.region_start_line:
start_line = sug_task.region_start_line
file_id = sug_task.start_line[0 : sug_task.start_line.index(":")]
if ":" not in sug.region_start_line:
start_line = sug.region_start_line
file_id = sug.start_line[0 : sug.start_line.index(":")]
else:
start_line = sug_task.region_start_line
start_line = sug.region_start_line
file_id = start_line[0 : start_line.index(":")]
start_line = start_line[start_line.index(":") + 1 :]
# split suggestions by file-id
Expand Down
Loading

0 comments on commit 8f1acce

Please sign in to comment.