Skip to content

Commit eead823

Browse files
committed
skip mask functions in tracing
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent ca33deb commit eead823

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,15 +180,9 @@ def _wrap_if_possible(self, node: ast.AST) -> Union[ast.AST, ast.Assign, ast.Cal
180180
return node
181181

182182
if isinstance(node, ast.stmt):
183-
logger.debug("---- Autowrapper ----")
184-
logger.debug(ast.unparse(node))
185-
logger.debug("---------------------")
186183
return self._wrap_stmt(node)
187184

188185
elif isinstance(node, ast.expr):
189-
logger.debug("---- Autowrapper ----")
190-
logger.debug(ast.unparse(node))
191-
logger.debug("---------------------")
192186
return self._wrap_expr(node)
193187

194188
else:
@@ -254,6 +248,11 @@ def _wrap_stmt(self, node: ast.stmt) -> ast.Assign:
254248
# update local names with newly returned values
255249
self._local_names |= returns
256250

251+
# log newly created function definition
252+
logger.debug("---- Autowrapper ----")
253+
logger.debug(ast.unparse(ast.fix_missing_locations(fn_def)))
254+
logger.debug("---------------------")
255+
257256
return assign_call
258257

259258
def _wrap_expr(self, node: ast.expr) -> ast.Call:

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from torch.nn import Module
1919
from transformers import PreTrainedModel
2020
from transformers.configuration_utils import PretrainedConfig
21+
from transformers.masking_utils import LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING
2122
from transformers.utils.fx import HFTracer
2223

2324
from llmcompressor.modifiers import Modifier
@@ -169,10 +170,13 @@ class SequentialTracer(HFTracer):
169170
"""
170171

171172
def __init__(self, ancestors: Set[Module], offloaded: Set[Module]):
172-
super().__init__()
173173
self.ancestors = ancestors
174174
self.offloaded = offloaded
175175

176+
# skip any mask creation functions not already caught by the autowrapper
177+
autowrap_functions = tuple(LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING.values())
178+
super().__init__(autowrap_functions=autowrap_functions)
179+
176180
# check unlikely case that ancestors have direct params which are offloaded
177181
offloaded_ancestors = offloaded & ancestors
178182
if offloaded_ancestors:

0 commit comments

Comments
 (0)