Skip to content

Commit 206ccaf

Browse files
committed
Add function extractor module for reproducer utility functions
This commit introduces a new module, `function_extractor.py`, which extracts utility functions from `utils.py` and `load_tensor.py` using the `inspect` module. The extracted functions are intended for use in the reproducer template, allowing for the generation of standalone code. Additionally, the `placeholder_replacer.py` has been updated to include a new placeholder for utility functions, which will be replaced with the extracted code during the reproduction process. Changes include: - New `extract_utility_functions` function to gather necessary utility functions. - Integration of utility function extraction into the placeholder replacement logic. - Minor updates to `utils.py` for function renaming and additional utility functions. This enhancement aims to streamline the reproducer generation process by ensuring all required utility functions are readily available in the generated code.
1 parent c5f04de commit 206ccaf

File tree

4 files changed

+344
-483
lines changed

4 files changed

+344
-483
lines changed
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""
2+
Function extractor for reproducer utility functions.
3+
4+
This module extracts utility functions from utils.py and load_tensor.py
5+
using the inspect module, and generates standalone code for reproducers.
6+
"""
7+
8+
import importlib.util
9+
import inspect
10+
from pathlib import Path
11+
12+
13+
def extract_utility_functions() -> str:
14+
"""
15+
Extract all utility functions needed for the reproducer template.
16+
17+
Returns:
18+
str: Complete Python code including imports and all utility functions.
19+
"""
20+
# Import the modules
21+
utils_module = _import_module_from_path(
22+
"tritonparse.reproducer.utils", Path(__file__).parent / "utils.py"
23+
)
24+
load_tensor_module = _import_module_from_path(
25+
"tritonparse.tools.load_tensor",
26+
Path(__file__).parent.parent / "tools" / "load_tensor.py",
27+
)
28+
29+
# Functions to extract from utils.py
30+
utils_functions = [
31+
"_get_triton_tensor_types",
32+
"create_args_from_json_file",
33+
"create_args_from_json",
34+
"_apply_stride_and_offset",
35+
"_create_base_tensor",
36+
"_create_tensor",
37+
"_create_arg_from_info",
38+
]
39+
40+
# Functions to extract from load_tensor.py
41+
load_tensor_functions = [
42+
"load_tensor",
43+
]
44+
45+
# Extract all function source code
46+
extracted_code = []
47+
48+
# Add required imports
49+
imports = _generate_imports()
50+
extracted_code.append(imports)
51+
52+
# Add TRITON_KERNELS_CUSTOM_TYPES constant
53+
constant_code = inspect.getsource(utils_module).split("\n")
54+
for i, line in enumerate(constant_code):
55+
if line.startswith("TRITON_KERNELS_CUSTOM_TYPES"):
56+
# Find the end of this statement
57+
j = i
58+
while j < len(constant_code) and not constant_code[j].rstrip().endswith(
59+
")"
60+
):
61+
j += 1
62+
constant_lines = constant_code[i : j + 1]
63+
extracted_code.append("\n".join(constant_lines))
64+
break
65+
66+
extracted_code.append("")
67+
68+
# Extract load_tensor function
69+
for func_name in load_tensor_functions:
70+
func = getattr(load_tensor_module, func_name)
71+
source = inspect.getsource(func)
72+
extracted_code.append(source)
73+
74+
# Extract utils functions
75+
for func_name in utils_functions:
76+
func = getattr(utils_module, func_name)
77+
source = inspect.getsource(func)
78+
extracted_code.append(source)
79+
80+
return "\n\n".join(extracted_code)
81+
82+
83+
def _import_module_from_path(module_name: str, file_path: Path):
84+
"""
85+
Import a module from a file path.
86+
87+
Args:
88+
module_name: Name for the module
89+
file_path: Path to the Python file
90+
91+
Returns:
92+
The imported module
93+
"""
94+
spec = importlib.util.spec_from_file_location(module_name, file_path)
95+
module = importlib.util.module_from_spec(spec)
96+
spec.loader.exec_module(module)
97+
return module
98+
99+
100+
def _generate_imports() -> str:
101+
"""
102+
Generate the import statements needed for the extracted functions.
103+
104+
Returns:
105+
str: Import statements as a single string
106+
"""
107+
imports = [
108+
"import gzip",
109+
"import hashlib",
110+
"import importlib",
111+
"import importlib.util",
112+
"import io",
113+
"import json",
114+
"import logging",
115+
"import sys",
116+
"from functools import lru_cache",
117+
"from pathlib import Path",
118+
"from typing import Union",
119+
"",
120+
"import torch",
121+
]
122+
return "\n".join(imports)
123+

tritonparse/reproducer/placeholder_replacer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import Any, Dict, Protocol
44

5+
from tritonparse.reproducer.function_extractor import extract_utility_functions
56
from tritonparse.reproducer.ingestion.ndjson import ContextBundle
67
from tritonparse.reproducer.types import KernelImportMode
78
from tritonparse.reproducer.utils import (
@@ -82,6 +83,9 @@ def __init__(self):
8283
)
8384
self.register("# {{KERNEL_SYSPATH_PLACEHOLDER}}", self._replace_kernel_syspath)
8485
self.register("# {{KERNEL_IMPORT_PLACEHOLDER}}", self._replace_kernel_import)
86+
self.register(
87+
"# {{UTILITY_FUNCTIONS_PLACEHOLDER}}", self._replace_utility_functions
88+
)
8589
self.register(
8690
"# {{KERNEL_INVOCATION_PLACEHOLDER}}", self._replace_kernel_invocation
8791
)
@@ -217,6 +221,13 @@ def _replace_kernel_import(
217221
else:
218222
raise ValueError(f"Unknown kernel_import mode: {kernel_import}")
219223

224+
def _replace_utility_functions(
225+
self, code: str, context_bundle: ContextBundle, **kwargs
226+
) -> str:
227+
"""Replace the utility functions placeholder with extracted functions."""
228+
utility_code = extract_utility_functions()
229+
return code.replace("# {{UTILITY_FUNCTIONS_PLACEHOLDER}}", utility_code)
230+
220231
def _replace_kernel_invocation(
221232
self, code: str, context_bundle: ContextBundle, **kwargs
222233
) -> str:

0 commit comments

Comments
 (0)