Skip to content

Commit 8482bfc

Browse files
perf: Use flyweight for node fields (#1654)
1 parent 9128c4a commit 8482bfc

File tree

5 files changed

+214
-38
lines changed

5 files changed

+214
-38
lines changed

bigframes/core/bigframe_node.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,9 @@
2020
import functools
2121
import itertools
2222
import typing
23-
from typing import Callable, Dict, Generator, Iterable, Mapping, Set, Tuple
23+
from typing import Callable, Dict, Generator, Iterable, Mapping, Sequence, Set, Tuple
2424

2525
from bigframes.core import identifiers
26-
import bigframes.core.guid
2726
import bigframes.core.schema as schemata
2827
import bigframes.dtypes
2928

@@ -163,7 +162,7 @@ def roots(self) -> typing.Set[BigFrameNode]:
163162
# TODO: Store some local data lazily for select, aggregate nodes.
164163
@property
165164
@abc.abstractmethod
166-
def fields(self) -> Iterable[Field]:
165+
def fields(self) -> Sequence[Field]:
167166
...
168167

169168
@property

bigframes/core/nodes.py

Lines changed: 47 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
import google.cloud.bigquery as bq
3535

36-
from bigframes.core import identifiers, local_data
36+
from bigframes.core import identifiers, local_data, sequences
3737
from bigframes.core.bigframe_node import BigFrameNode, COLUMN_SET, Field
3838
import bigframes.core.expression as ex
3939
from bigframes.core.ordering import OrderingExpression, RowOrdering
@@ -87,7 +87,7 @@ def child_nodes(self) -> typing.Sequence[BigFrameNode]:
8787
return (self.child,)
8888

8989
@property
90-
def fields(self) -> Iterable[Field]:
90+
def fields(self) -> Sequence[Field]:
9191
return self.child.fields
9292

9393
@property
@@ -226,8 +226,8 @@ def added_fields(self) -> Tuple[Field, ...]:
226226
return (Field(self.indicator_col, bigframes.dtypes.BOOL_DTYPE, nullable=False),)
227227

228228
@property
229-
def fields(self) -> Iterable[Field]:
230-
return itertools.chain(
229+
def fields(self) -> Sequence[Field]:
230+
return sequences.ChainedSequence(
231231
self.left_child.fields,
232232
self.added_fields,
233233
)
@@ -321,15 +321,15 @@ def order_ambiguous(self) -> bool:
321321
def explicitly_ordered(self) -> bool:
322322
return self.propogate_order
323323

324-
@property
325-
def fields(self) -> Iterable[Field]:
326-
left_fields = self.left_child.fields
324+
@functools.cached_property
325+
def fields(self) -> Sequence[Field]:
326+
left_fields: Iterable[Field] = self.left_child.fields
327327
if self.type in ("right", "outer"):
328328
left_fields = map(lambda x: x.with_nullable(), left_fields)
329-
right_fields = self.right_child.fields
329+
right_fields: Iterable[Field] = self.right_child.fields
330330
if self.type in ("left", "outer"):
331331
right_fields = map(lambda x: x.with_nullable(), right_fields)
332-
return itertools.chain(left_fields, right_fields)
332+
return (*left_fields, *right_fields)
333333

334334
@property
335335
def joins_nulls(self) -> bool:
@@ -430,10 +430,10 @@ def explicitly_ordered(self) -> bool:
430430
return True
431431

432432
@property
433-
def fields(self) -> Iterable[Field]:
433+
def fields(self) -> Sequence[Field]:
434434
# TODO: Output names should probably be aligned beforehand or be part of concat definition
435435
# TODO: Handle nullability
436-
return (
436+
return tuple(
437437
Field(id, field.dtype)
438438
for id, field in zip(self.output_ids, self.children[0].fields)
439439
)
@@ -505,7 +505,7 @@ def explicitly_ordered(self) -> bool:
505505
return True
506506

507507
@functools.cached_property
508-
def fields(self) -> Iterable[Field]:
508+
def fields(self) -> Sequence[Field]:
509509
return (
510510
Field(self.output_id, next(iter(self.start.fields)).dtype, nullable=False),
511511
)
@@ -626,12 +626,20 @@ class ReadLocalNode(LeafNode):
626626
session: typing.Optional[bigframes.session.Session] = None
627627

628628
@property
629-
def fields(self) -> Iterable[Field]:
630-
fields = (Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items)
629+
def fields(self) -> Sequence[Field]:
630+
fields = tuple(
631+
Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items
632+
)
631633
if self.offsets_col is not None:
632-
return itertools.chain(
633-
fields,
634-
(Field(self.offsets_col, bigframes.dtypes.INT_DTYPE, nullable=False),),
634+
return tuple(
635+
itertools.chain(
636+
fields,
637+
(
638+
Field(
639+
self.offsets_col, bigframes.dtypes.INT_DTYPE, nullable=False
640+
),
641+
),
642+
)
635643
)
636644
return fields
637645

@@ -767,8 +775,8 @@ def session(self):
767775
return self.table_session
768776

769777
@property
770-
def fields(self) -> Iterable[Field]:
771-
return (
778+
def fields(self) -> Sequence[Field]:
779+
return tuple(
772780
Field(col_id, dtype, self.source.table.schema_by_id[source_id].is_nullable)
773781
for col_id, dtype, source_id in self.scan_list.items
774782
)
@@ -881,8 +889,8 @@ def non_local(self) -> bool:
881889
return True
882890

883891
@property
884-
def fields(self) -> Iterable[Field]:
885-
return itertools.chain(self.child.fields, self.added_fields)
892+
def fields(self) -> Sequence[Field]:
893+
return sequences.ChainedSequence(self.child.fields, self.added_fields)
886894

887895
@property
888896
def relation_ops_created(self) -> int:
@@ -1097,7 +1105,7 @@ def _validate(self):
10971105
raise ValueError(f"Reference to column not in child: {ref.id}")
10981106

10991107
@functools.cached_property
1100-
def fields(self) -> Iterable[Field]:
1108+
def fields(self) -> Sequence[Field]:
11011109
input_fields_by_id = {field.id: field for field in self.child.fields}
11021110
return tuple(
11031111
Field(
@@ -1192,8 +1200,8 @@ def added_fields(self) -> Tuple[Field, ...]:
11921200
return tuple(fields)
11931201

11941202
@property
1195-
def fields(self) -> Iterable[Field]:
1196-
return itertools.chain(self.child.fields, self.added_fields)
1203+
def fields(self) -> Sequence[Field]:
1204+
return sequences.ChainedSequence(self.child.fields, self.added_fields)
11971205

11981206
@property
11991207
def variables_introduced(self) -> int:
@@ -1263,7 +1271,7 @@ def non_local(self) -> bool:
12631271
return True
12641272

12651273
@property
1266-
def fields(self) -> Iterable[Field]:
1274+
def fields(self) -> Sequence[Field]:
12671275
return (Field(self.col_id, bigframes.dtypes.INT_DTYPE, nullable=False),)
12681276

12691277
@property
@@ -1313,7 +1321,7 @@ def non_local(self) -> bool:
13131321
return True
13141322

13151323
@functools.cached_property
1316-
def fields(self) -> Iterable[Field]:
1324+
def fields(self) -> Sequence[Field]:
13171325
# TODO: Use child nullability to infer grouping key nullability
13181326
by_fields = (self.child.field_by_id[ref.id] for ref in self.by_column_ids)
13191327
if self.dropna:
@@ -1411,8 +1419,8 @@ def non_local(self) -> bool:
14111419
return True
14121420

14131421
@property
1414-
def fields(self) -> Iterable[Field]:
1415-
return itertools.chain(self.child.fields, [self.added_field])
1422+
def fields(self) -> Sequence[Field]:
1423+
return sequences.ChainedSequence(self.child.fields, (self.added_field,))
14161424

14171425
@property
14181426
def variables_introduced(self) -> int:
@@ -1547,7 +1555,7 @@ def row_preserving(self) -> bool:
15471555
return False
15481556

15491557
@property
1550-
def fields(self) -> Iterable[Field]:
1558+
def fields(self) -> Sequence[Field]:
15511559
fields = (
15521560
Field(
15531561
field.id,
@@ -1561,11 +1569,17 @@ def fields(self) -> Iterable[Field]:
15611569
for field in self.child.fields
15621570
)
15631571
if self.offsets_col is not None:
1564-
return itertools.chain(
1565-
fields,
1566-
(Field(self.offsets_col, bigframes.dtypes.INT_DTYPE, nullable=False),),
1572+
return tuple(
1573+
itertools.chain(
1574+
fields,
1575+
(
1576+
Field(
1577+
self.offsets_col, bigframes.dtypes.INT_DTYPE, nullable=False
1578+
),
1579+
),
1580+
)
15671581
)
1568-
return fields
1582+
return tuple(fields)
15691583

15701584
@property
15711585
def relation_ops_created(self) -> int:

bigframes/core/schema.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
from dataclasses import dataclass
1818
import functools
1919
import typing
20+
from typing import Sequence
2021

2122
import google.cloud.bigquery
2223
import pyarrow
2324

24-
import bigframes.core.guid
2525
import bigframes.dtypes
2626

2727
ColumnIdentifierType = str
@@ -35,7 +35,10 @@ class SchemaItem:
3535

3636
@dataclass(frozen=True)
3737
class ArraySchema:
38-
items: typing.Tuple[SchemaItem, ...]
38+
items: Sequence[SchemaItem]
39+
40+
def __iter__(self):
41+
yield from self.items
3942

4043
@classmethod
4144
def from_bq_table(

bigframes/core/sequences.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import collections.abc
18+
import functools
19+
import itertools
20+
from typing import Iterable, Iterator, Sequence, TypeVar
21+
22+
ColumnIdentifierType = str
23+
24+
25+
T = TypeVar("T")

# Further optimizations possible:
# * Support mapping operators
# * Support insertions and deletions


class ChainedSequence(collections.abc.Sequence[T]):
    """
    Memory-optimized sequence from composing chain of existing sequences.

    Will use the provided parts as underlying storage - so do not mutate provided parts.
    May merge small underlying parts for better access performance.
    """

    def __init__(self, *parts: Sequence[T]):
        # Flatten nested ChainedSequences and coalesce tiny fragments so
        # iteration and random access stay cheap even after many chains.
        # Could build an index that makes random access faster?
        self._parts: tuple[Sequence[T], ...] = tuple(
            _defrag_parts(_flatten_parts(parts))
        )
        # Precompute the total length eagerly. The previous @functools.cache
        # on __len__ keyed the cache on `self`, keeping every instance alive
        # for the life of the process (ruff B019 leak); a plain attribute
        # gives the same O(1) lookup without pinning instances.
        self._length: int = sum(map(len, self._parts))

    def __getitem__(self, index):
        """Return the item (or tuple, for slices) at ``index``."""
        if isinstance(index, slice):
            # Rare path: materialize once and delegate to tuple slicing.
            return tuple(self)[index]
        if index < 0:
            index += len(self)
        if index < 0:
            raise IndexError("Index out of bounds")

        # Walk the parts, tracking a cumulative offset, until the part
        # containing `index` is found.
        offset = 0
        for part in self._parts:
            if (index - offset) < len(part):
                return part[index - offset]
            offset += len(part)
        raise IndexError("Index out of bounds")

    def __len__(self) -> int:
        return self._length

    def __iter__(self):
        for part in self._parts:
            yield from part


def _flatten_parts(parts: Iterable[Sequence[T]]) -> Iterator[Sequence[T]]:
    """Expand nested ChainedSequences into their underlying parts."""
    for part in parts:
        if isinstance(part, ChainedSequence):
            yield from part._parts
        else:
            yield part


# Should be a cache-friendly chunk size?
_TARGET_SIZE = 128
_MAX_MERGABLE = 32


def _defrag_parts(parts: Iterable[Sequence[T]]) -> Iterator[Sequence[T]]:
    """
    Merge small chunks into larger chunks for better performance.

    Parts larger than _MAX_MERGABLE pass through untouched; smaller parts are
    buffered and emitted as one merged tuple once roughly _TARGET_SIZE items
    have accumulated.
    """
    parts_queue: list[Sequence[T]] = []
    queued_items = 0
    for part in parts:
        # too big, just yield from the buffer
        if len(part) > _MAX_MERGABLE:
            yield from parts_queue
            parts_queue = []
            queued_items = 0
            yield part
        else:  # can be merged, so lets add to the queue
            parts_queue.append(part)
            queued_items += len(part)
            # if queue has reached target size, merge, dump and reset queue
            if queued_items >= _TARGET_SIZE:
                yield tuple(itertools.chain(*parts_queue))
                parts_queue = []
                queued_items = 0

    yield from parts_queue

0 commit comments

Comments
 (0)