Add TensorBlobManager for efficient tensor storage #19

Open · wants to merge 1 commit into main
325 changes: 324 additions & 1 deletion in tritonparse/structured_logging.py
@@ -1,12 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import atexit
import gzip
import hashlib
import importlib
import inspect
import json
import logging
import math
import os
import tempfile
from collections import defaultdict
from dataclasses import asdict, is_dataclass
from datetime import date, datetime
@@ -32,7 +35,17 @@
DEFAULT_TRACE_FILE_PREFIX = (
f"dedicated_log_triton_trace_{os.getenv('USER', 'unknown')}_"
)
# Enable launch trace. WARNING: this will overwrite launch_metadata for each Triton kernel.
TRITON_TRACE_LAUNCH = os.getenv("TRITON_TRACE_LAUNCH", None) in ["1", "true", "True"]
# Enable tensor blob storage
TRITONPARSE_SAVE_TENSOR_BLOBS = os.getenv("TRITONPARSE_SAVE_TENSOR_BLOBS", "0") in ["1", "true", "True"]
# Tensor size limit in bytes (default 10GB)
TRITONPARSE_TENSOR_SIZE_LIMIT = int(os.getenv("TRITONPARSE_TENSOR_SIZE_LIMIT", str(10 * 1024 * 1024 * 1024)))
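# Example (sketch): these flags are read once at import time, so they must be
# set before this module is imported, e.g.
#   TRITONPARSE_SAVE_TENSOR_BLOBS=1 TRITONPARSE_TENSOR_SIZE_LIMIT=$((1<<30)) python run.py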

TRITON_TRACE_HANDLER = None
# Global tensor blob manager instance
TENSOR_BLOB_MANAGER = None

if importlib.util.find_spec("torch") is not None:
TORCH_INSTALLED = True
import torch
@@ -41,6 +54,131 @@
TORCH_INSTALLED = False


class TensorBlobManager:
"""
Manager for storing tensor data as content-addressed blobs.

Uses BLAKE2b hashing for content addressing and stores blobs in a two-level
directory structure to avoid filesystem limitations with large numbers of files.
"""

def __init__(self, root_dir: Optional[str] = None):
self.root_dir = None
self.hash_to_path_cache = {} # In-memory cache for hash -> path mapping
if root_dir:
self.set_root_dir(root_dir)

def set_root_dir(self, root_dir: str):
"""Set the root directory for blob storage."""
self.root_dir = Path(root_dir) / "saved_tensors"
self.root_dir.mkdir(parents=True, exist_ok=True)
log.debug(f"TensorBlobManager: using root directory {self.root_dir}")

def _compute_hash(self, data: bytes) -> str:
"""Compute BLAKE2b hash of the data."""
return hashlib.blake2b(data).hexdigest()

def _get_blob_path(self, hash_hex: str) -> Path:
"""Get the file path for a given hash using two-level directory structure."""
if not self.root_dir:
raise ValueError("Root directory not set")

# Two-level directory: first 2 chars / full_hash.bin
subdir = hash_hex[:2]
filename = f"{hash_hex}.bin"
return (self.root_dir / subdir / filename).resolve()
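    # Layout example: hash "ab3f09..." is stored at
    # <root>/saved_tensors/ab/ab3f09....bin, so at most 256 subdirectories
    # fan out under saved_tensors/.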

def _get_tensor_size_bytes(self, tensor) -> int:
"""Get tensor size in bytes before serialization."""
if hasattr(tensor, 'numel') and hasattr(tensor, 'element_size'):
return tensor.numel() * tensor.element_size()
return 0

def save_tensor_blob(self, tensor) -> Dict[str, Any]:
"""
Save tensor as a blob and return metadata.

Args:
tensor: PyTorch tensor to save

Returns:
Dictionary with blob metadata or error information:
- Success: {'tensor_hash': str, 'blob_path': str, 'blob_size': int, 'serialization_method': str}
- Error: {'error': str, 'tensor_hash': None}
"""
if not self.root_dir:
return {'error': 'Blob storage not initialized', 'tensor_hash': None}

try:
# Check tensor size before serialization
tensor_size = self._get_tensor_size_bytes(tensor)
if tensor_size > TRITONPARSE_TENSOR_SIZE_LIMIT:
log.warning(
f"Tensor size {tensor_size} bytes exceeds limit {TRITONPARSE_TENSOR_SIZE_LIMIT} bytes, skipping blob storage"
)
return {
'error': f'Tensor size {tensor_size} bytes exceeds limit {TRITONPARSE_TENSOR_SIZE_LIMIT} bytes',
'tensor_hash': None
}

            # Serialize tensor using torch.save
            # TODO: Consider async serialization for very large tensors to avoid blocking
            if not TORCH_INSTALLED:
                return {'error': 'PyTorch not available for tensor serialization', 'tensor_hash': None}
            import io
            buffer = io.BytesIO()
            torch.save(tensor.cpu(), buffer)

blob_data = buffer.getvalue()
hash_hex = self._compute_hash(blob_data)

# Check if we already have this blob
if hash_hex in self.hash_to_path_cache:
blob_path = self.hash_to_path_cache[hash_hex]
if blob_path.exists():
return {
'tensor_hash': hash_hex,
'blob_path': str(blob_path),
'blob_size': len(blob_data),
'serialization_method': 'torch_save'
}

# Create blob file
blob_path = self._get_blob_path(hash_hex)
blob_path.parent.mkdir(parents=True, exist_ok=True)

# Atomic write using temporary file + rename
with tempfile.NamedTemporaryFile(
mode='wb',
dir=blob_path.parent,
prefix=f".tmp_{hash_hex}_",
delete=False
) as tmp_file:
tmp_file.write(blob_data)
tmp_path = Path(tmp_file.name)

# Atomic rename
tmp_path.rename(blob_path)

# Cache the path
self.hash_to_path_cache[hash_hex] = blob_path

log.debug(f"Saved tensor blob: {hash_hex} -> {blob_path}")

return {
'tensor_hash': hash_hex,
'blob_path': str(blob_path),
'blob_size': len(blob_data),
'serialization_method': 'torch_save'
}

except Exception as e:
error_msg = f"Failed to save tensor blob: {str(e)}"
log.error(error_msg)
return {'error': error_msg, 'tensor_hash': None}
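

def _demo_blob_dedup():
    # Illustrative sketch (not part of the PR): identical tensor bytes dedupe
    # to a single blob via the in-memory hash cache above. Assumes torch is
    # installed and that torch.save is deterministic for equal tensors -- the
    # same assumption the cache itself relies on.
    mgr = TensorBlobManager(root_dir="/tmp/tritonparse_demo")
    a = mgr.save_tensor_blob(torch.arange(4))
    b = mgr.save_tensor_blob(torch.arange(4))
    assert a["tensor_hash"] == b["tensor_hash"]
    assert a["blob_path"] == b["blob_path"]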


class TritonLogRecord(logging.LogRecord):
"""
Custom LogRecord class for structured logging of Triton operations.
Expand Down Expand Up @@ -461,7 +599,7 @@ def init_logs():
DEBUG:tritonparse_trace:
lines by blocking propagation to the root logger.
"""
global TRITON_TRACE_HANDLER, triton_trace_folder
global TRITON_TRACE_HANDLER, triton_trace_folder, TENSOR_BLOB_MANAGER

# Basic logger settings (safe to run on every call)
triton_trace_log.setLevel(logging.DEBUG)
@@ -486,6 +624,15 @@ def init_logs():
if TRITON_TRACE_HANDLER not in triton_trace_log.handlers:
TRITON_TRACE_HANDLER.setFormatter(TritonJsonFormatter())
triton_trace_log.addHandler(TRITON_TRACE_HANDLER)

# Initialize tensor blob manager if enabled
if TRITONPARSE_SAVE_TENSOR_BLOBS:
if TENSOR_BLOB_MANAGER is None:
TENSOR_BLOB_MANAGER = TensorBlobManager()

# Set or update root directory for blob storage
if root_dir and TENSOR_BLOB_MANAGER.root_dir is None:
TENSOR_BLOB_MANAGER.set_root_dir(root_dir)


def trace_structured_triton(
Expand Down Expand Up @@ -613,6 +760,182 @@ def maybe_trace_triton(
return trace_data


from triton.knobs import LaunchHook, JITHook


def extract_arg_info(arg_dict):
"""
Extract detailed information from kernel arguments, especially for PyTorch tensors.

Args:
arg_dict: Dictionary of kernel arguments

Returns:
Dictionary with extracted argument information including tensor properties
"""
global TENSOR_BLOB_MANAGER

extracted_args = {}

for arg_name, arg_value in arg_dict.items():
arg_info = {}

# Check if it's a PyTorch tensor
if hasattr(arg_value, 'shape') and hasattr(arg_value, 'dtype'):
arg_info['type'] = 'tensor'
arg_info['shape'] = list(arg_value.shape)
arg_info['dtype'] = str(arg_value.dtype)
arg_info['device'] = str(arg_value.device)
arg_info['stride'] = list(arg_value.stride())
arg_info['numel'] = arg_value.numel()
arg_info['is_contiguous'] = arg_value.is_contiguous()
arg_info['element_size'] = arg_value.element_size()
arg_info['storage_offset'] = arg_value.storage_offset()
# Memory usage in bytes
arg_info['memory_usage'] = arg_value.numel() * arg_value.element_size()
# Add data_ptr for memory tracking (optional)
if hasattr(arg_value, 'data_ptr'):
arg_info['data_ptr'] = hex(arg_value.data_ptr())

# Add tensor blob storage if enabled
if TRITONPARSE_SAVE_TENSOR_BLOBS and TENSOR_BLOB_MANAGER is not None:
blob_info = TENSOR_BLOB_MANAGER.save_tensor_blob(arg_value)
arg_info.update(blob_info)

# Handle scalar values
elif isinstance(arg_value, (int, float, bool)):
arg_info['type'] = type(arg_value).__name__
arg_info['value'] = arg_value
# Handle strings
elif isinstance(arg_value, str):
arg_info['type'] = 'str'
arg_info['value'] = arg_value
arg_info['length'] = len(arg_value)
# Handle other types
else:
arg_info['type'] = type(arg_value).__name__
# Try to convert to string for logging, but be safe about it
try:
arg_info['repr'] = str(arg_value)
if len(arg_info['repr']) > 200: # Truncate very long representations
arg_info['repr'] = arg_info['repr'][:200] + "..."
            except Exception:
arg_info['repr'] = f"<{type(arg_value).__name__} object>"

extracted_args[arg_name] = arg_info

return extracted_args
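

def _demo_extract_arg_info():
    # Illustrative sketch (not part of the PR): the shape of the dict that
    # extract_arg_info returns, with blob storage disabled. Assumes torch is
    # installed; values shown are representative of a default torch build.
    info = extract_arg_info({"x": torch.zeros(2, 3), "n": 6})
    assert info["x"]["type"] == "tensor"
    assert info["x"]["shape"] == [2, 3]  # plus dtype, device, stride, ...
    assert info["n"] == {"type": "int", "value": 6}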


def add_launch_metadata(grid, metadata, arg_dict):
# Extract detailed argument information
extracted_args = extract_arg_info(arg_dict)
return {"launch_metadata_tritonparse": (grid, metadata, extracted_args)}


class JITHookImpl(JITHook):
"""
JIT Hook implementation that overrides or sets the launch_metadata function for Triton kernels.

This hook is essential for capturing detailed kernel launch information beyond the basic
metadata (like kernel name) that Triton provides by default. Without setting a custom
launch_metadata function, only minimal launch information is available as shown in:
https://github.com/triton-lang/triton/blob/7ce287dc24b43476cdeb30529089ac361564505d/python/triton/compiler/compiler.py#L504

By intercepting the JIT compilation process and setting a custom launch_metadata function,
we can capture comprehensive runtime information including grid parameters, kernel metadata,
and argument dictionaries for detailed analysis and logging.
"""

def __call__(
self,
*,
key: str,
repr: str,
fn,
compile,
is_manual_warmup: bool,
already_compiled: bool,
) -> Optional[bool]:
"""
Override or set the launch_metadata function for the JIT-compiled kernel.

This method is called during the JIT compilation process and allows us to
inject our custom launch_metadata function that will be used to collect
detailed kernel launch information.

Args:
key: Unique identifier for the kernel
repr: String representation of the kernel
fn: The JIT function object
compile: Compilation function
is_manual_warmup: Whether this is a manual warmup call
already_compiled: Whether the kernel is already compiled

Returns:
True to continue with compilation, None/False to skip
"""
launch_metadata_fn = fn.jit_function.launch_metadata
if launch_metadata_fn is not None:
log.warning(
f"fn {fn} launch_metadata_fn is not None: {launch_metadata_fn}. It will be overridden by tritonparse."
)
fn.jit_function.launch_metadata = add_launch_metadata
return True


class LaunchHookImpl(LaunchHook):
"""
Launch Hook implementation for capturing and logging kernel launch metadata.

This hook is responsible for intercepting kernel launches and extracting the detailed
    metadata that was set up by JITHookImpl. It provides the entry point for
    kernel execution, allowing comprehensive logging and analysis of kernel
    launches, including timing, parameters, and execution context.

The metadata captured includes:
- Kernel name and function details
- Grid dimensions and launch parameters
- Kernel arguments and their values
- Stream information
- Custom metadata added by the launch_metadata function
"""

def enter(self, metadata):
"""
Handle kernel launch entry point.

This method is called when a kernel is about to be launched, providing
access to all the launch metadata for logging, profiling, or analysis.

Args:
metadata: LazyDict containing comprehensive launch information including
kernel name, function, stream, grid parameters, and custom data
format: {'name': 'add_kernel', 'function': None, 'stream': 0,
'launch_metadata_tritonparse': (grid, self.metadata, extracted_args)}
where extracted_args contains detailed info for each argument:
- For tensors: shape, dtype, device, stride, memory_usage, etc.
- For scalars: type and value
- For other types: type and string representation
            defined here:
            https://github.com/triton-lang/triton/blob/7ce287dc24b43476cdeb30529089ac361564505d/python/triton/compiler/compiler.py#L512
"""
trace_data = defaultdict(dict)
metadata_dict = metadata.get()
trace_data["name"] = metadata_dict["name"]
trace_data["function"] = metadata_dict["function"]
trace_data["stream"] = metadata_dict["stream"]
launch_metadata_tritonparse = metadata_dict.get("launch_metadata_tritonparse", None)
if launch_metadata_tritonparse is not None:
trace_data["grid"] = launch_metadata_tritonparse[0]
trace_data["metadata"] = launch_metadata_tritonparse[1]
trace_data["extracted_args"] = launch_metadata_tritonparse[2] # Now contains detailed arg info
trace_structured_triton("launch", metadata_fn=lambda: convert(trace_data))
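

def _demo_register_hooks():
    # Illustrative sketch (not part of the PR): wiring these hooks into
    # Triton. The knob attribute names below are an assumption about the
    # triton.knobs API; the PR's init() presumably performs the equivalent.
    import triton

    triton.knobs.runtime.jit_post_compile_hook = JITHookImpl()
    triton.knobs.runtime.launch_enter_hook = LaunchHookImpl()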



def init(trace_folder: Optional[str] = None):
"""
Initialize the structured logging system for Triton compilation.
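For reference, a typical end-to-end usage sketch (the body of init is truncated above; this assumes, per the docstrings, that init installs the trace handler and hooks). Note that TRITONPARSE_SAVE_TENSOR_BLOBS must be set in the environment before the module is imported, since the flag is read at import time:

import tritonparse.structured_logging as tsl

tsl.init(trace_folder="/tmp/tritonparse_traces")
# Run Triton kernels; launch events are traced, and with blob storage
# enabled, tensor arguments are presumably saved under
# <trace_folder>/saved_tensors/.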