Skip to content

Commit

Permalink
Compute kernel naming fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jack-melchert committed Jan 19, 2024
1 parent bd1caf9 commit 7da2f1a
Showing 1 changed file with 78 additions and 87 deletions.
165 changes: 78 additions & 87 deletions archipelago/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,45 +443,44 @@ def branch_delay_match_within_kernels(
if match is not None:
ports_with_unique_latenices[match] = []
for kernel_port, d1 in latency_dict.items():
for port_num, d2 in d1.items():
if d2["pe_port"] != []:
port_nodes = []
for compute_file_tile, compute_file_port in d2["pe_port"]:
found = False
for pe in graph.get_tiles():
if (
graph.id_to_name[str(pe)]
== f"{match}$inner_compute${compute_file_tile}"
):
found_port = False
for source in graph.sources[pe]:
if source.port in port_remap_r:
port = port_remap_r[source.port]
if d1["pe_port"] != []:
port_nodes = []
for compute_file_tile, compute_file_port in d1["pe_port"]:
found = False
for pe in graph.get_tiles():
if (
graph.id_to_name[str(pe)]
== f"{match}$inner_compute${compute_file_tile}"
):
found_port = False
for source in graph.sources[pe]:
if source.port in port_remap_r:
port = port_remap_r[source.port]
if port == compute_file_port:
found = True
found_port = True
port_nodes.append(source)

for source_node, dest_node in graph.removed_edges:
if dest_node == pe:
if source_node.port in port_remap_r:
port = port_remap_r[source_node.port]
if port == compute_file_port:
found = True
found_port = True
port_nodes.append(source)

for source_node, dest_node in graph.removed_edges:
if dest_node == pe:
if source_node.port in port_remap_r:
port = port_remap_r[source_node.port]
if port == compute_file_port:
found = True
found_port = True
port_nodes.append(source_node)
port_nodes.append(source_node)

if not found_port:
print("Couldn't find pe port")
print(latency_dict)
breakpoint()
if not found_port:
print("Couldn't find pe port")
print(latency_dict)
breakpoint()

if not found:
print("Couldn't find pe")
print(latency_dict)
breakpoint()
if not found:
print("Couldn't find pe")
print(latency_dict)
breakpoint()

ports_with_unique_latenices[match].append(port_nodes)
ports_with_unique_latenices[match].append(port_nodes)

# Then branch delay match the nodes without unique latencies
for kernel in node_cycles:
Expand Down Expand Up @@ -708,51 +707,48 @@ def calculate_latencies(
match = find_closest_match(kernel, list(node_latencies.keys()))
if match is not None:
for kernel_port, d1 in latency_dict.items():
for port_num, d2 in d1.items():
if d2["pe_port"] == [] and match in max_latencies:
kernel_latencies[kernel][kernel_port][port_num][
"latency"
] = max_latencies[match]
elif d2["pe_port"] != []:
found = False
for compute_file_tile, compute_file_port in d2["pe_port"]:
# Within this loop, all the ports should have the same latency
found_lat = None
for pe in graph.get_tiles():
if (
graph.id_to_name[str(pe)]
== f"{match}$inner_compute${compute_file_tile}"
):
found_port = False
for source in graph.sources[pe]:
if source.port in port_remap_r:
port = port_remap_r[source.port]
if port == compute_file_port:
reg = graph.get_connected_reg(source)
if reg is not None:
lat = node_latencies[match][reg]
else:
lat = node_latencies[match][source]

if found_lat is not None:
assert (
lat == found_lat
), f"Found multiple latencies for {kernel} {kernel_port} {port_num} {compute_file_tile} {compute_file_port} {lat} {found_lat}"
kernel_latencies[kernel][kernel_port][
port_num
]["latency"] = lat
found = True
found_port = True
found_lat = lat
break
if not found_port:
found = True
kernel_latencies[kernel][kernel_port][port_num][
"latency"
] = node_latencies[match][graph.sources[pe][0]]

if not found:
print("Couldn't find tile port in kernel latencies", kernel)
if d1["pe_port"] == [] and match in max_latencies:
kernel_latencies[kernel][kernel_port][
"latency"
] = max_latencies[match]
elif d1["pe_port"] != []:
found = False
for compute_file_tile, compute_file_port in d1["pe_port"]:
# Within this loop, all the ports should have the same latency
found_lat = None
for pe in graph.get_tiles():
if (
graph.id_to_name[str(pe)]
== f"{match}$inner_compute${compute_file_tile}"
):
found_port = False
for source in graph.sources[pe]:
if source.port in port_remap_r:
port = port_remap_r[source.port]
if port == compute_file_port:
reg = graph.get_connected_reg(source)
if reg is not None:
lat = node_latencies[match][reg]
else:
lat = node_latencies[match][source]

if found_lat is not None:
assert (
lat == found_lat
), f"Found multiple latencies for {kernel} {kernel_port} {compute_file_tile} {compute_file_port} {lat} {found_lat}"
kernel_latencies[kernel][kernel_port]["latency"] = lat
found = True
found_port = True
found_lat = lat
break
if not found_port:
found = True
kernel_latencies[kernel][kernel_port][
"latency"
] = node_latencies[match][graph.sources[pe][0]]

if not found:
print("Couldn't find tile port in kernel latencies", kernel)

return kernel_latencies, stencil_valid_adjust

Expand Down Expand Up @@ -812,20 +808,15 @@ def update_kernel_latencies(
if "hcompute_output_cgra_stencil" in kernel:
for kernel_port, d1 in latency_dict.items():
if "input_cgra_stencil" or "in2_output_cgra_stencil" in kernel_port:
for port_num, d2 in d1.items():
d2["latency"] = updated_kernel_latencies[kernel][
d1["latency"] = updated_kernel_latencies[kernel][
kernel_port
][port_num]["latency"]
]["latency"]
if "hcompute_input_cgra_stencil" in kernel:
for kernel_port, d1 in latency_dict.items():
for port_num, d2 in d1.items():
d2["latency"] = ub_latencies["input_cgra_stencil"][port_num][
"latency"
]
d1["latency"] = ub_latencies["input_cgra_stencil"]["latency"]
if "hcompute_kernel_cgra_stencil" in kernel:
for kernel_port, d1 in latency_dict.items():
for port_num, d2 in d1.items():
d2["latency"] = min(
d1["latency"] = min(
value["latency"]
for value in ub_latencies["kernel_cgra_stencil"].values()
)
Expand Down

0 comments on commit 7da2f1a

Please sign in to comment.