
ENH: speed up DataFrame.plot using LineCollection #61532

Open
@Abdelgha-4

Description:

When plotting line charts of DataFrames with many columns or rows, DataFrame.plot() currently creates a separate Line2D object for each column. This per-artist overhead becomes significant on large datasets.
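A quick way to see the per-column artists is to count the Line2D objects on the Axes returned by DataFrame.plot. This sketch assumes the non-interactive Agg backend and a small random-walk DataFrame; ax.get_lines() returns every Line2D attached to the Axes.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend, no window needed

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Small DataFrame: 100 rows, 50 columns of random walks
df = pd.DataFrame(np.random.randn(100, 50).cumsum(axis=0))

ax = df.plot(legend=False)

# The line plot adds one Line2D artist per column
n_lines = len(ax.get_lines())
print(n_lines)  # 50
plt.close(ax.figure)
```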

Drawing all columns with a single LineCollection (from matplotlib.collections) instead can yield substantial speedups. In my benchmarks, plotting via LineCollection was ~2.5× faster on large DataFrames with many columns.
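LineCollection takes its segments as a sequence of (N, 2) arrays, and a single 3-D NumPy array of shape (n_columns, n_rows, 2) also works because it iterates as one (n_rows, 2) array per column. A minimal sketch, assuming every column shares the same x values, builds that array without a per-column Python loop; frame_to_segments is a hypothetical helper, not part of pandas or matplotlib.

```python
import numpy as np
import pandas as pd
from matplotlib.collections import LineCollection


def frame_to_segments(df: pd.DataFrame, x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: turn each column of df into one polyline.

    Returns an array of shape (n_columns, n_rows, 2) where [..., 0] holds
    the shared x values and [..., 1] holds each column's y values.
    """
    y = df.to_numpy(dtype=float).T        # (n_columns, n_rows)
    xs = np.broadcast_to(x, y.shape)      # same x row repeated for every column
    return np.stack([xs, y], axis=-1)     # (n_columns, n_rows, 2)


df = pd.DataFrame(np.random.randn(500, 2000).cumsum(axis=0))
segments = frame_to_segments(df, np.arange(len(df.index)))
collection = LineCollection(segments)  # one artist for all 2000 columns
print(segments.shape)  # (2000, 500, 2)
```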

Minimal example:

```python
# Imports and data generation
import itertools

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.collections import LineCollection

num_rows = 500
num_cols = 2000

test_df = pd.DataFrame(np.random.randn(num_rows, num_cols).cumsum(axis=0))

# Simply using DataFrame.plot (5.6 secs)
test_df.plot(legend=False, figsize=(12, 8))
plt.show()

# Optimized version using LineCollection (2.2 secs)
# Positional x values; they match the index because test_df has a default RangeIndex
x = np.arange(len(test_df.index))
# One (num_rows, 2) array of (x, y) points per column
lines = [np.column_stack([x, test_df[col].to_numpy()]) for col in test_df.columns]
# Cycle the default Matplotlib colors across all lines
default_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
color_cycle = list(itertools.islice(itertools.cycle(default_colors), len(lines)))

# A single artist holding every line
line_collection = LineCollection(lines, colors=color_cycle)
fig, ax = plt.subplots(figsize=(12, 8))
ax.add_collection(line_collection)
ax.margins(0.05)  # also triggers autoscaling to the collection's data limits
plt.show()
```
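The timings above are wall-clock figures from interactive runs. To compare the two approaches without a GUI, one option (an assumption, not necessarily how the numbers above were measured) is to time figure creation plus one forced Agg render. plot_with_pandas, plot_with_collection and time_draw are hypothetical helper names used only for this sketch.

```python
import itertools
import time

import matplotlib

matplotlib.use("Agg")  # render off-screen so timings exclude the GUI event loop

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.collections import LineCollection


def plot_with_pandas(df):
    """Hypothetical wrapper: current DataFrame.plot path (one Line2D per column)."""
    ax = df.plot(legend=False, figsize=(12, 8))
    return ax.figure


def plot_with_collection(df):
    """Hypothetical wrapper: all columns drawn as one LineCollection."""
    x = np.arange(len(df.index))
    y = df.to_numpy(dtype=float).T
    segments = np.stack([np.broadcast_to(x, y.shape), y], axis=-1)
    default_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    colors = list(itertools.islice(itertools.cycle(default_colors), len(segments)))
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.add_collection(LineCollection(segments, colors=colors))
    ax.autoscale_view()  # fit the view to the collection's data limits
    return fig


def time_draw(plot_fn, df):
    """Seconds spent building the figure and forcing one full Agg render."""
    start = time.perf_counter()
    fig = plot_fn(df)
    fig.canvas.draw()
    elapsed = time.perf_counter() - start
    plt.close(fig)
    return elapsed


if __name__ == "__main__":
    test_df = pd.DataFrame(np.random.randn(500, 2000).cumsum(axis=0))
    for fn in (plot_with_pandas, plot_with_collection):
        print(f"{fn.__name__}: {time_draw(fn, test_df):.2f} s")
```

Timing fig.canvas.draw() rather than plt.show() keeps the measurement independent of the GUI backend; absolute numbers will vary with hardware and library versions.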

Note: the ~2.5× speedup applies to DataFrames with an integer index. For DataFrames with a DatetimeIndex, the speedup reaches ~27× when combined with the workaround described in #61398.
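With a DatetimeIndex, the timestamps must be converted to numbers before they can go into a LineCollection. A minimal sketch, assuming Matplotlib's date2num float representation is acceptable; this is not necessarily the workaround from #61398, only one way to put dates on the x-axis of a collection.

```python
import matplotlib

matplotlib.use("Agg")

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.collections import LineCollection

# Random-walk data on a daily DatetimeIndex
index = pd.date_range("2024-01-01", periods=500, freq="D")
df = pd.DataFrame(np.random.randn(500, 200).cumsum(axis=0), index=index)

# Convert the timestamps once to Matplotlib's float date representation
x = mdates.date2num(df.index.to_numpy())

# (n_columns, n_rows, 2) segments sharing the converted x values
y = df.to_numpy(dtype=float).T
segments = np.stack([np.broadcast_to(x, y.shape), y], axis=-1)

fig, ax = plt.subplots(figsize=(12, 8))
ax.add_collection(LineCollection(segments))
ax.xaxis_date()      # interpret x values as dates for ticks and labels
ax.autoscale_view()  # fit the view to the collection's data limits
fig.autofmt_xdate()  # rotate date labels so they do not overlap
```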

Thank you for considering this suggestion!
