Skip to content

Commit 24e002e

Browse files
FindHaometa-codesync[bot]
authored andcommitted
Add gzip support for load_ndjson() (#207)
Summary: This PR adds support for loading gzip-compressed NDJSON files in the `load_ndjson()` function, fixing an issue where the CLI claimed to support `.ndjson.gz` files but the function only used `open()`. ## Supported Formats | Format | Extension | Description | |--------|-----------|-------------| | Uncompressed | `.ndjson` | Standard NDJSON (existing) | | Gzip compressed | `.ndjson.gz` | Whole file compressed | | Gzip member concatenation | `.bin.ndjson` | Each line compressed separately | ## Changes - **`tools/prettify_ndjson.py`**: - Added `import gzip` - Added `_is_gzip_file()` helper function to detect compressed files - Modified `load_ndjson()` to use `gzip.open()` for compressed files - Updated docstring to document supported formats - **`tests/test_tritonparse.py`**: - Added `test_load_ndjson_gzip_support()` test using existing `.ndjson.gz` test file ## Testing Uses existing test file: `tests/example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz` Pull Request resolved: #207 Reviewed By: wychi Differential Revision: D88171069 Pulled By: FindHao fbshipit-source-id: 701a238a3d9d34d1d096834088a4ac87cb16ed09
1 parent 311e016 commit 24e002e

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

tests/test_tritonparse.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,35 @@ def test_loc_alias_parsing(self):
307307

308308
print("✓ All loc alias parsing tests passed")
309309

310+
def test_load_ndjson_gzip_support(self):
311+
"""Test that load_ndjson can load .ndjson.gz files."""
312+
from pathlib import Path
313+
314+
from tritonparse.tools.prettify_ndjson import load_ndjson
315+
316+
# Use existing .ndjson.gz test file
317+
gz_file = (
318+
Path(__file__).parent
319+
/ "example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz"
320+
)
321+
322+
# Verify file exists
323+
self.assertTrue(gz_file.exists(), f"Test file not found: {gz_file}")
324+
325+
# Load and verify
326+
events = load_ndjson(gz_file)
327+
self.assertIsInstance(events, list)
328+
self.assertGreater(len(events), 0, "Should load at least one event")
329+
330+
# Verify we have expected event types
331+
event_types = {e.get("event_type") for e in events if isinstance(e, dict)}
332+
self.assertTrue(
333+
"compilation" in event_types or "launch" in event_types,
334+
f"Expected compilation or launch events, got: {event_types}",
335+
)
336+
337+
print(f"✓ Successfully loaded {len(events)} events from .ndjson.gz file")
338+
310339

311340
class TestTritonparseCUDA(unittest.TestCase):
312341
"""CUDA tests (require GPU)"""

tritonparse/tools/prettify_ndjson.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,19 @@
3939
"""
4040

4141
import argparse
42+
import gzip
4243
import json
4344
import sys
4445
from pathlib import Path
4546
from typing import Any, List, Union
4647

4748

49+
def _is_gzip_file(file_path: Path) -> bool:
50+
"""Check if file is gzip compressed (.gz or .bin.ndjson)."""
51+
path_str = str(file_path)
52+
return path_str.endswith(".gz") or path_str.endswith(".bin.ndjson")
53+
54+
4855
def parse_line_ranges(lines_arg: str) -> set[int]:
4956
"""
5057
Parse line ranges from string like "1,2,3,5-10" into a set of line numbers.
@@ -106,6 +113,9 @@ def load_ndjson(
106113
"""
107114
Load NDJSON file and return list of JSON objects.
108115
116+
Supports uncompressed (.ndjson), gzip compressed (.ndjson.gz),
117+
and gzip member concatenation (.bin.ndjson) formats.
118+
109119
Args:
110120
file_path: Path to the NDJSON file
111121
not_save_irs: Whether to NOT save file_content and python_source for compilation events
@@ -122,8 +132,13 @@ def load_ndjson(
122132
filtered_compilation_events = 0
123133
total_lines_processed = 0
124134

135+
# Determine if file is gzip compressed
136+
is_compressed = _is_gzip_file(file_path)
137+
opener = gzip.open if is_compressed else open
138+
mode = "rt" if is_compressed else "r"
139+
125140
try:
126-
with open(file_path, "r", encoding="utf-8") as f:
141+
with opener(file_path, mode, encoding="utf-8") as f:
127142
# enumerate(f, 1) starts line numbering from 1 (1-based indexing)
128143
for line_num, line in enumerate(f, 1):
129144
line = line.strip()

0 commit comments

Comments
 (0)