Skip to content

Add kernel tracing functionality #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 124 additions & 9 deletions tritonparse/structured_logging.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import atexit
import fnmatch
import importlib
import inspect
import json
Expand All @@ -12,7 +13,7 @@
from datetime import date, datetime
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, Mapping, Optional, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Union

import triton

Expand All @@ -27,6 +28,10 @@
triton_trace_folder = os.environ.get("TRITON_TRACE", None)
# Enable debug logging for tritonparse itself
TRITONPARSE_DEBUG = os.getenv("TRITONPARSE_DEBUG", None) in ["1", "true", "True"]
# Kernel allowlist for filtering traced kernels. Use comma separated list of fnmatch patterns.
TRITONPARSE_KERNEL_ALLOWLIST = os.environ.get("TRITONPARSE_KERNEL_ALLOWLIST", None)
# Parsed kernel allowlist patterns (set during init)
_KERNEL_ALLOWLIST_PATTERNS: Optional[List[str]] = None
# The compilation information will be stored to /logs/DEFAULT_TRACE_FILE_PREFIX by default
# unless other flags disable or set another store. Add USER to avoid permission issues in shared servers.
DEFAULT_TRACE_FILE_PREFIX = (
Expand Down Expand Up @@ -171,14 +176,28 @@ def maybe_enable_debug_logging():
"""
This logging is for logging module itself, not for logging the triton compilation.
"""
if TRITONPARSE_DEBUG and not log.hasHandlers():
log_handler = logging.StreamHandler()
log_handler.setLevel(logging.DEBUG)
log_handler.setFormatter(
logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
if TRITONPARSE_DEBUG:
# Always set debug level if TRITONPARSE_DEBUG is set
log.setLevel(logging.DEBUG)
log.addHandler(log_handler)

# Prevent propagation to root logger to avoid duplicate messages
log.propagate = False

# Check if we already have a debug handler
has_debug_handler = any(
isinstance(handler, logging.StreamHandler)
and handler.level <= logging.DEBUG
for handler in log.handlers
)

if not has_debug_handler:
log_handler = logging.StreamHandler()
log_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s[%(levelname)s] %(message)s")
formatter.default_time_format = '%Y%m%d %H:%M:%S'
formatter.default_msec_format = None
log_handler.setFormatter(formatter)
log.addHandler(log_handler)


def get_stack_trace(skip=1):
Expand Down Expand Up @@ -210,6 +229,87 @@ def get_stack_trace(skip=1):
return frames


def parse_kernel_allowlist() -> Optional[List[str]]:
"""
Parse the kernel allowlist from environment variable.

Returns:
List[str] or None: List of kernel name patterns to trace, or None if all kernels should be traced
"""
if not TRITONPARSE_KERNEL_ALLOWLIST:
return None

# Split by comma and strip whitespace
patterns = [pattern.strip() for pattern in TRITONPARSE_KERNEL_ALLOWLIST.split(",")]
# Filter out empty patterns
patterns = [pattern for pattern in patterns if pattern]

if not patterns:
return None

log.debug(f"Kernel allowlist patterns: {patterns}")
return patterns


def extract_kernel_name(src) -> Optional[str]:
"""
Extract kernel name from the source object.

Args:
src (Union[ASTSource, IRSource]): Source object containing kernel information

Returns:
str or None: Kernel name if extractable, None otherwise
"""
from triton.compiler import IRSource

try:
if isinstance(src, IRSource):
# For IRSource, try to get name from attributes if available
if hasattr(src, 'name'):
return src.name
# Fallback to string representation parsing if needed
return None
else:
# For ASTSource, get the function name
if hasattr(src, 'fn') and hasattr(src.fn, 'fn') and hasattr(src.fn.fn, '__name__'):
return src.fn.fn.__name__
return None
except Exception as e:
log.debug(f"Error extracting kernel name: {e}")
return None


def should_trace_kernel(kernel_name: Optional[str], allowlist_patterns: Optional[List[str]]) -> bool:
"""
Check if a kernel should be traced based on the allowlist.

Args:
kernel_name (str or None): Name of the kernel
allowlist_patterns (List[str] or None): List of patterns to match against

Returns:
bool: True if the kernel should be traced, False otherwise
"""
# If no allowlist is set, trace all kernels
if allowlist_patterns is None:
return True

# If we can't extract kernel name, don't trace (conservative approach)
if kernel_name is None:
log.debug("Cannot extract kernel name, skipping trace")
return False

# Check if kernel name matches any pattern in the allowlist
for pattern in allowlist_patterns:
if fnmatch.fnmatch(kernel_name, pattern):
log.debug(f"Kernel '{kernel_name}' matches pattern '{pattern}', will trace")
return True

log.debug(f"Kernel '{kernel_name}' does not match any allowlist pattern, skipping trace")
return False


def extract_python_source_info(trace_data: Dict[str, Any], source):
"""
Extract Python source code information from the source object and add it to trace_data.
Expand Down Expand Up @@ -569,6 +669,13 @@ def maybe_trace_triton(
Returns:
Dict[str, Any]: Dictionary containing all collected trace data, even if tracing is disabled
"""
# Check kernel allowlist early to avoid unnecessary work
if _KERNEL_ALLOWLIST_PATTERNS is not None:
kernel_name = extract_kernel_name(src)
if not should_trace_kernel(kernel_name, _KERNEL_ALLOWLIST_PATTERNS):
# Return empty dict to indicate no tracing was done
return {}

# Initialize a dictionary with defaultdict to avoid key errors
trace_data = defaultdict(dict)
# Add cache_hit to metadata
Expand Down Expand Up @@ -623,7 +730,7 @@ def init(trace_folder: Optional[str] = None):
Args:
trace_folder (Optional[str]): The folder to store the trace files.
"""
global triton_trace_folder
global triton_trace_folder, _KERNEL_ALLOWLIST_PATTERNS
maybe_enable_debug_logging()
if triton_trace_folder is not None and trace_folder is not None:
log.info(
Expand All @@ -633,5 +740,13 @@ def init(trace_folder: Optional[str] = None):
)
if trace_folder is not None:
triton_trace_folder = trace_folder

# Parse and store kernel allowlist configuration
_KERNEL_ALLOWLIST_PATTERNS = parse_kernel_allowlist()
if _KERNEL_ALLOWLIST_PATTERNS:
log.info(f"Kernel allowlist enabled with patterns: {_KERNEL_ALLOWLIST_PATTERNS}")
else:
log.debug("Kernel allowlist not set, tracing all kernels")

init_logs()
triton.knobs.compilation.listener = maybe_trace_triton