Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions tests/test_tritonparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,181 @@ def test_loc_alias_parsing(self):

print("✓ All loc alias parsing tests passed")

def test_load_ndjson_gzip_support(self):
    """Check that ``load_ndjson`` can read a gzip-compressed ``.ndjson.gz`` file."""
    from pathlib import Path

    from tritonparse.tools.prettify_ndjson import load_ndjson

    # Reuse the compressed fixture stored next to this test module.
    fixture_root = Path(__file__).parent
    gz_file = (
        fixture_root
        / "example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz"
    )
    self.assertTrue(gz_file.exists(), f"Test file not found: {gz_file}")

    events = load_ndjson(gz_file)
    self.assertIsInstance(events, list)
    self.assertGreater(len(events), 0, "Should load at least one event")

    # Collect the distinct event types of every dict-shaped entry.
    event_types = set()
    for entry in events:
        if isinstance(entry, dict):
            event_types.add(entry.get("event_type"))

    # At least one of the two core event kinds has to be present.
    has_core_event = not event_types.isdisjoint({"compilation", "launch"})
    self.assertTrue(
        has_core_event,
        f"Expected compilation or launch events, got: {event_types}",
    )

    print(f"✓ Successfully loaded {len(events)} events from .ndjson.gz file")

def test_list_kernels_empty(self):
    """An empty event list must yield an empty kernel listing."""
    from tritonparse.info.kernel_query import list_kernels

    self.assertEqual(list_kernels([]), [])

def test_list_kernels_single(self):
    """List kernels when only one kernel's launches remain in the event list."""
    from pathlib import Path

    from tritonparse.info.kernel_query import list_kernels
    from tritonparse.tools.prettify_ndjson import load_ndjson

    # Real trace fixture stored next to this test module.
    trace_path = (
        Path(__file__).parent
        / "example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz"
    )
    events = load_ndjson(trace_path)

    def keep(event):
        # Non-launch events stay in so list_kernels has to skip them itself;
        # launch events survive only for fused_op_kernel (4 launches).
        if event.get("event_type") != "launch":
            return True
        launched = event.get("compilation_metadata", {}).get("name")
        return launched == "fused_op_kernel"

    filtered_events = [event for event in events if keep(event)]

    result = list_kernels(filtered_events)
    self.assertEqual(len(result), 1)
    only_kernel = result[0]
    self.assertEqual(only_kernel.name, "fused_op_kernel")
    self.assertEqual(only_kernel.total_launches, 4)

def test_list_kernels_multiple(self):
    """List kernels from a trace that contains two different kernels."""
    from pathlib import Path

    from tritonparse.info.kernel_query import list_kernels
    from tritonparse.tools.prettify_ndjson import load_ndjson

    # Real trace fixture stored next to this test module.
    trace_path = (
        Path(__file__).parent
        / "example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz"
    )
    result = list_kernels(load_ndjson(trace_path))
    self.assertEqual(len(result), 2)

    # The listing is expected to come back ordered by kernel name.
    self.assertEqual(
        [summary.name for summary in result],
        ["fused_op_kernel", "matmul_kernel"],
    )

    # Per-kernel launch counts recorded in the fixture.
    launches = {summary.name: summary.total_launches for summary in result}
    self.assertEqual(launches["matmul_kernel"], 1553)
    self.assertEqual(launches["fused_op_kernel"], 4)

def test_find_launch_index_valid(self):
    """Resolve known (kernel name, launch_id) pairs to matching launch events."""
    from pathlib import Path

    from tritonparse.info.kernel_query import find_launch_index_by_kernel
    from tritonparse.tools.prettify_ndjson import load_ndjson

    # Real trace fixture stored next to this test module.
    trace_path = (
        Path(__file__).parent
        / "example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz"
    )
    events = load_ndjson(trace_path)

    # First and second launch of fused_op_kernel, then first launch of matmul_kernel.
    lookups = (
        ("fused_op_kernel", 0),
        ("fused_op_kernel", 1),
        ("matmul_kernel", 0),
    )
    for kernel_name, launch_id in lookups:
        index = find_launch_index_by_kernel(events, kernel_name, launch_id)
        located = events[index]
        # The returned index must point at a launch event of the requested kernel.
        self.assertEqual(located.get("event_type"), "launch")
        self.assertEqual(
            located.get("compilation_metadata", {}).get("name"),
            kernel_name,
        )

def test_find_launch_index_kernel_not_found(self):
    """Looking up an unknown kernel name must raise ValueError naming the kernel."""
    from pathlib import Path

    from tritonparse.info.kernel_query import find_launch_index_by_kernel
    from tritonparse.tools.prettify_ndjson import load_ndjson

    # Real trace fixture stored next to this test module.
    trace_path = (
        Path(__file__).parent
        / "example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz"
    )
    events = load_ndjson(trace_path)

    with self.assertRaises(ValueError) as raised:
        find_launch_index_by_kernel(events, "nonexistent_kernel", 0)

    # The message has to say what went wrong and for which kernel.
    message = str(raised.exception)
    for fragment in ("not found", "nonexistent_kernel"):
        self.assertIn(fragment, message)

def test_find_launch_index_out_of_range(self):
    """A launch_id beyond the kernel's launch count must raise ValueError."""
    from pathlib import Path

    from tritonparse.info.kernel_query import find_launch_index_by_kernel
    from tritonparse.tools.prettify_ndjson import load_ndjson

    # Real trace fixture stored next to this test module.
    trace_path = (
        Path(__file__).parent
        / "example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz"
    )
    events = load_ndjson(trace_path)

    # fused_op_kernel has only 4 launches (ids 0-3), so launch_id=10 is invalid.
    with self.assertRaises(ValueError) as raised:
        find_launch_index_by_kernel(events, "fused_op_kernel", 10)

    # The message reports the actual count, the requested id and the valid range.
    message = str(raised.exception)
    expected_fragments = (
        "has only 4 launches",
        "--launch-id 10",
        "Valid range: 0 to 3",
    )
    for fragment in expected_fragments:
        self.assertIn(fragment, message)


class TestTritonparseCUDA(unittest.TestCase):
"""CUDA tests (require GPU)"""
Expand Down
24 changes: 24 additions & 0 deletions tritonparse/info/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

"""
Info module for querying kernel information from NDJSON trace files.

This module provides core query functions for kernel information:
- Listing all kernels with their launch counts
- Finding launch events by kernel name and launch ID
- Querying launch information for specific kernels
"""

from tritonparse.info.kernel_query import (
find_launch_index_by_kernel,
KernelSummary,
LaunchInfo,
list_kernels,
)

# Names exported by `from tritonparse.info import *`; all of them are
# imported from tritonparse.info.kernel_query in this module.
__all__ = [
"KernelSummary",
"LaunchInfo",
"list_kernels",
"find_launch_index_by_kernel",
]
107 changes: 107 additions & 0 deletions tritonparse/info/kernel_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

"""
Core query functions for kernel information from NDJSON trace files.

This module provides functions to query kernel launch information from parsed
event lists. It supports both raw log files and parsed ndjson files (with launch_diff events).
"""

from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class KernelSummary:
    """Aggregated launch statistics for a single kernel name."""

    # Kernel name.
    name: str
    # Hash string associated with the kernel.
    hash: str
    # Number of launch events counted for this kernel.
    total_launches: int


@dataclass
class LaunchInfo:
    """Description of one particular launch of a kernel."""

    # 0-based ordinal of this launch.
    launch_id: int
    # 0-based position of the launch event in the events list.
    line_index: int
    # NOTE(review): presumably the launch grid dimensions -- confirm with the
    # code that populates this field.
    grid: List[int]


def list_kernels(events: List[Dict[str, Any]]) -> List[KernelSummary]:
    """
    List all kernels with their launch counts.

    Only events whose ``event_type`` is ``"launch"`` are counted. A launch
    event is attributed to the kernel named in its ``compilation_metadata``;
    launch events without metadata (missing or null) or without a non-empty
    name are ignored. The hash reported for a kernel is the one carried by
    the last launch event seen for that name.

    Args:
        events: List of parsed event dictionaries from NDJSON file

    Returns:
        List of KernelSummary objects, sorted by kernel name
    """
    launch_counts: Dict[str, int] = defaultdict(int)
    latest_hash: Dict[str, Any] = {}

    for event in events:
        if event.get("event_type") != "launch":
            continue

        # `or {}` covers both a missing key and an explicit JSON null, which
        # `.get(key, {})` would return as None and then fail on `.get("name")`.
        comp_meta = event.get("compilation_metadata") or {}
        kernel_name = comp_meta.get("name")
        if not kernel_name:
            continue

        launch_counts[kernel_name] += 1
        latest_hash[kernel_name] = comp_meta.get("hash", "")

    # Iterate the names in sorted order for consistent output.
    return [
        KernelSummary(
            name=name,
            hash=latest_hash[name],
            total_launches=launch_counts[name],
        )
        for name in sorted(launch_counts)
    ]


def find_launch_index_by_kernel(
    events: List[Dict[str, Any]], kernel_name: str, launch_id: int
) -> int:
    """
    Find the 0-based line index for a kernel's N-th launch.

    Only events whose ``event_type`` is ``"launch"`` are considered; launch
    events whose ``compilation_metadata`` is missing or null never match.

    Args:
        events: List of parsed event dictionaries
        kernel_name: Exact kernel name to match (case-sensitive)
        launch_id: 0-based launch index for the kernel

    Returns:
        0-based line index (index in events list)

    Raises:
        ValueError: If kernel not found or launch_id out of range
    """
    # Number of launches of `kernel_name` seen so far.
    count = 0
    for i, event in enumerate(events):
        if event.get("event_type") != "launch":
            continue

        # `or {}` covers both a missing key and an explicit JSON null, which
        # `.get(key, {})` would return as None and then fail on `.get("name")`.
        comp_meta = event.get("compilation_metadata") or {}
        if comp_meta.get("name") != kernel_name:
            continue

        if count == launch_id:
            return i
        count += 1

    # The scan ran to the end, so `count` is now the kernel's total number of
    # launches and can be reported in the error message.
    if count == 0:
        raise ValueError(f"Kernel '{kernel_name}' not found")
    raise ValueError(
        f"Kernel '{kernel_name}' has only {count} launches, "
        f"but --launch-id {launch_id} was requested. Valid range: 0 to {count - 1}"
    )
17 changes: 16 additions & 1 deletion tritonparse/tools/prettify_ndjson.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,19 @@
"""

import argparse
import gzip
import json
import sys
from pathlib import Path
from typing import Any, List, Union


def _is_gzip_file(file_path: Path) -> bool:
"""Check if file is gzip compressed (.gz or .bin.ndjson)."""
path_str = str(file_path)
return path_str.endswith(".gz") or path_str.endswith(".bin.ndjson")


def parse_line_ranges(lines_arg: str) -> set[int]:
"""
Parse line ranges from string like "1,2,3,5-10" into a set of line numbers.
Expand Down Expand Up @@ -106,6 +113,9 @@ def load_ndjson(
"""
Load NDJSON file and return list of JSON objects.

Supports uncompressed (.ndjson), gzip compressed (.ndjson.gz),
and gzip member concatenation (.bin.ndjson) formats.

Args:
file_path: Path to the NDJSON file
not_save_irs: Whether to NOT save file_content and python_source for compilation events
Expand All @@ -122,8 +132,13 @@ def load_ndjson(
filtered_compilation_events = 0
total_lines_processed = 0

# Determine if file is gzip compressed
is_compressed = _is_gzip_file(file_path)
opener = gzip.open if is_compressed else open
mode = "rt" if is_compressed else "r"

try:
with open(file_path, "r", encoding="utf-8") as f:
with opener(file_path, mode, encoding="utf-8") as f:
# enumerate(f, 1) starts line numbering from 1 (1-based indexing)
for line_num, line in enumerate(f, 1):
line = line.strip()
Expand Down