|
12 | 12 | import math |
13 | 13 | import typing |
14 | 14 | from functools import partial |
15 | | -from typing import Iterable, List, Optional, Tuple |
| 15 | +from typing import Iterable, List, Optional, Set, Tuple |
16 | 16 |
|
17 | 17 | import torch |
18 | 18 | from executorch.backends.cadence.aot.memory_constraints import ( |
@@ -73,11 +73,11 @@ def collect_specs_from_graph_module( |
73 | 73 | # the fastest memory available |
74 | 74 | # flake8: noqa 'position_based_greedy_with_hierarchy' is too complex (13) |
75 | 75 | def position_based_greedy_with_hierarchy( |
76 | | - graph_module: torch.fx.GraphModule, |
77 | 76 | alignment: int, |
| 77 | + specs: Set[TensorSpec], |
| 78 | + graph_module: torch.fx.GraphModule, |
78 | 79 | graph_signature: ExportGraphSignature, |
79 | | - alloc_graph_input: bool, |
80 | | - alloc_graph_output: bool, |
| 80 | + extra_padding: int = 0, |
81 | 81 | *, |
82 | 82 | memory_config: MemoryConfig, |
83 | 83 | mem_constraints: MemConstraints, |
@@ -119,9 +119,7 @@ def memory_available(spec: TensorSpec) -> bool: |
119 | 119 |
|
120 | 120 | # Iterate over all the specs in sorted order |
121 | 121 | for spec in sorted( |
122 | | - collect_specs_from_graph_module( |
123 | | - graph_module, graph_signature, alloc_graph_input, alloc_graph_output |
124 | | - ), |
| 122 | + specs, |
125 | 123 | key=lambda spec: spec.allocated_memory, |
126 | 124 | reverse=True, |
127 | 125 | ): |
@@ -167,11 +165,11 @@ def memory_available(spec: TensorSpec) -> bool: |
167 | 165 |
|
168 | 166 | # Greedy tensor placement with the heuristics from arxiv.org/pdf/2001.03288.pdf |
169 | 167 | def greedy_by_size_for_offset_calculation_with_hierarchy( |
170 | | - graph_module: torch.fx.GraphModule, |
171 | 168 | alignment: int, |
| 169 | + specs: Set[TensorSpec], |
| 170 | + graph_module: torch.fx.GraphModule, |
172 | 171 | graph_signature: ExportGraphSignature, |
173 | | - alloc_graph_input: bool, |
174 | | - alloc_graph_output: bool, |
| 172 | + extra_padding: int = 0, |
175 | 173 | *, |
176 | 174 | memory_config: MemoryConfig, |
177 | 175 | mem_constraints: MemConstraints, |
@@ -199,9 +197,7 @@ def greedy_by_size_for_offset_calculation_with_hierarchy( |
199 | 197 |
|
200 | 198 | # Iterate over all the specs in sorted order |
201 | 199 | for spec in sorted( |
202 | | - collect_specs_from_graph_module( |
203 | | - graph_module, graph_signature, alloc_graph_input, alloc_graph_output |
204 | | - ), |
| 200 | + specs, |
205 | 201 | key=lambda spec: spec.allocated_memory, |
206 | 202 | reverse=True, |
207 | 203 | ): |
|
0 commit comments