Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 56 additions & 4 deletions tbview/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
INFO = '[INFO]'
DEBUG = '[DEBUG]'

# Maximum number of points to render before applying sampling
MAX_PLOT_POINTS = 1000

class TensorboardViewer:
def __init__(self, event_path, event_tag) -> None:
# Support single or multiple runs
Expand Down Expand Up @@ -283,6 +286,13 @@ def plot(self, tbox):
xlabel = 'time since start (h)'
fmt = '{:.1f}'
x_vals = [r / divisor for r in rel]

# Keep original values for range tracking before sampling
values_for_range = values

# Sample data points if there are too many to render efficiently
x_vals, values = self._sample_data(x_vals, values)

# Compute per-run ETA and speed (steps/s) using train/epoch, always show if available
eta_str = None
speed_str = None
Expand Down Expand Up @@ -326,10 +336,10 @@ def plot(self, tbox):
global_xmin_step = s_first
if global_xmax_step is None or s_last > global_xmax_step:
global_xmax_step = s_last
# track global y range for ylim validation
if values:
vmin = min(values)
vmax = max(values)
# track global y range for ylim validation (use original values before sampling)
if values_for_range:
vmin = min(values_for_range)
vmax = max(values_for_range)
if global_ymin is None or vmin < global_ymin:
global_ymin = vmin
if global_ymax is None or vmax > global_ymax:
Expand Down Expand Up @@ -537,6 +547,48 @@ def _moving_average(self, values, window):
smoothed.append(total / count)
return smoothed

def _sample_data(self, x_vals, y_vals, max_points=MAX_PLOT_POINTS):
"""
Sample data points to reduce rendering overhead for large datasets.
Uses uniform sampling while preserving first and last points.

Args:
x_vals: List of x-axis values
y_vals: List of y-axis values
max_points: Maximum number of points to keep (default: MAX_PLOT_POINTS)

Returns:
Tuple of (sampled_x_vals, sampled_y_vals)
"""
if len(x_vals) <= max_points or len(y_vals) <= max_points:
return x_vals, y_vals

if len(x_vals) != len(y_vals):
# Safety check - should not happen in practice
return x_vals, y_vals

n = len(x_vals)
# Calculate step size for uniform sampling
# We want to keep approximately max_points
step = n / max_points

# Always keep first point
sampled_x = [x_vals[0]]
sampled_y = [y_vals[0]]

# Sample intermediate points uniformly
for i in range(1, max_points - 1):
idx = int(i * step)
if idx < n - 1: # Ensure we don't accidentally include the last point twice
sampled_x.append(x_vals[idx])
sampled_y.append(y_vals[idx])

# Always keep last point
sampled_x.append(x_vals[-1])
sampled_y.append(y_vals[-1])

return sampled_x, sampled_y

def _format_duration(self, seconds):
try:
secs = max(0, int(round(seconds)))
Expand Down
35 changes: 35 additions & 0 deletions tests/test_viewer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,38 @@ def test_compute_run_epoch_eta_and_speed_from_epoch_series():
assert speed is not None and abs(speed - 1.0) < 1e-6


def test_sample_data_returns_original_when_below_threshold():
dummy = Dummy()
x = list(range(500))
y = [i * 2 for i in range(500)]
x_sampled, y_sampled = TensorboardViewer._sample_data(dummy, x, y, max_points=1000)
assert x_sampled == x
assert y_sampled == y


def test_sample_data_reduces_points_when_above_threshold():
dummy = Dummy()
x = list(range(5000))
y = [i * 2 for i in range(5000)]
x_sampled, y_sampled = TensorboardViewer._sample_data(dummy, x, y, max_points=1000)
# Should reduce to approximately max_points
assert len(x_sampled) <= 1000
assert len(y_sampled) <= 1000
assert len(x_sampled) == len(y_sampled)
# Should preserve first and last points
assert x_sampled[0] == x[0]
assert x_sampled[-1] == x[-1]
assert y_sampled[0] == y[0]
assert y_sampled[-1] == y[-1]


def test_sample_data_preserves_correspondence():
dummy = Dummy()
x = list(range(3000))
y = [i * 3 + 5 for i in range(3000)]
x_sampled, y_sampled = TensorboardViewer._sample_data(dummy, x, y, max_points=500)
# Check that x and y values still correspond (y = x * 3 + 5)
for xs, ys in zip(x_sampled, y_sampled):
assert ys == xs * 3 + 5