
Commit 6a37bfc

brills authored and tfx-copybara committed

Flush the accumulated Tables if their total byte size exceeds a threshold.

This addresses the issue where a fixed desired_batch_size may produce an excessively large merged Table.

PiperOrigin-RevId: 297463177

1 parent aae8510 commit 6a37bfc
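
For context, here is a minimal standalone sketch of the flushing policy this commit introduces. The threshold constant and the row/byte counters mirror the diff below; the ToyAccumulator class, the _DESIRED_BATCH_SIZE value, and the _table_byte_size helper are hypothetical stand-ins, not TFDV APIs:

import pyarrow as pa

_MERGE_TABLE_BYTE_SIZE_THRESHOLD = 20 << 20  # 20MiB, same as the diff below.
_DESIRED_BATCH_SIZE = 1000  # Hypothetical value; TFDV tunes this elsewhere.


def _table_byte_size(table):
  # Rough stand-in for tfx_bsl's table_util.TotalByteSize: sums the Arrow
  # buffers backing every column chunk.
  return sum(buf.size
             for column in table.columns
             for chunk in column.chunks
             for buf in chunk.buffers() if buf is not None)


class ToyAccumulator(object):
  """Hypothetical mirror of _CombinerStatsGeneratorsCombineFnAcc."""

  def __init__(self):
    self.input_tables = []
    self.curr_batch_size = 0  # Rows accumulated so far.
    self.curr_byte_size = 0   # Bytes accumulated so far.

  def add(self, table):
    self.input_tables.append(table)
    self.curr_batch_size += table.num_rows
    self.curr_byte_size += _table_byte_size(table)
    self.maybe_flush(force=False)

  def maybe_flush(self, force):
    # Flush when forced, when enough rows are batched, or -- the new
    # condition in this commit -- when the accumulated tables already
    # exceed the byte-size threshold.
    should_flush = ((force and self.curr_batch_size > 0) or
                    self.curr_batch_size >= _DESIRED_BATCH_SIZE or
                    self.curr_byte_size >= _MERGE_TABLE_BYTE_SIZE_THRESHOLD)
    if not should_flush:
      return
    merged = pa.concat_tables(self.input_tables)  # One big table to process.
    print('flushed %d rows, ~%d bytes' % (merged.num_rows,
                                          self.curr_byte_size))
    del self.input_tables[:]
    self.curr_batch_size = 0
    self.curr_byte_size = 0

With the 20 MiB cap, a stream of wide or string-heavy rows flushes well before 1000 rows accumulate, which is exactly the failure mode the commit message describes.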

File tree

1 file changed (+35, -5)

tensorflow_data_validation/statistics/stats_impl.py

Lines changed: 35 additions & 5 deletions
@@ -50,6 +50,14 @@
 from tensorflow_metadata.proto.v0 import statistics_pb2
 
 
+# The combiner accumulates tables from the upstream and merge them when certain
+# conditions are met. A merged table would allow better vectorized processing,
+# but we have to pay for copying and the RAM to contain the merged table.
+# If the total byte size of accumulated tables exceeds this threshold a merge
+# will be forced to avoid consuming too much memory.
+_MERGE_TABLE_BYTE_SIZE_THRESHOLD = 20 << 20  # 20MiB
+
+
 @beam.typehints.with_input_types(pa.Table)
 @beam.typehints.with_output_types(statistics_pb2.DatasetFeatureStatisticsList)
 class GenerateStatisticsImpl(beam.PTransform):
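
A quick sanity check of the constant above: a left shift by 20 bits multiplies by 2^20, so the threshold works out to exactly 20 MiB:

>>> 20 << 20
20971520
>>> 20 * 2 ** 20
20971520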
@@ -505,7 +513,8 @@ def extract_output(self, accumulator: List[float]
 class _CombinerStatsGeneratorsCombineFnAcc(object):
   """accumulator for _CombinerStatsGeneratorsCombineFn."""
 
-  __slots__ = ['partial_accumulators', 'input_tables', 'curr_batch_size']
+  __slots__ = ['partial_accumulators', 'input_tables', 'curr_batch_size',
+               'curr_byte_size']
 
   def __init__(self, partial_accumulators: List[Any]):
     # Partial accumulator states of the underlying CombinerStatsGenerators.
@@ -514,6 +523,8 @@ def __init__(self, partial_accumulators: List[Any]):
     self.input_tables = []
     # Current batch size.
     self.curr_batch_size = 0
+    # Current total byte size of all the pa.Tables accumulated.
+    self.curr_byte_size = 0
 
 
 @beam.typehints.with_input_types(pa.Table)
@@ -544,7 +555,7 @@ class _CombinerStatsGeneratorsCombineFn(beam.CombineFn):
   """
 
   __slots__ = ['_generators', '_desired_batch_size', '_combine_batch_size',
-               '_num_compacts', '_num_instances']
+               '_combine_byte_size', '_num_compacts', '_num_instances']
 
   # This needs to be large enough to allow for efficient merging of
   # accumulators in the individual stats generators, but shouldn't be too large
@@ -569,6 +580,8 @@ def __init__(
     # Metrics
     self._combine_batch_size = beam.metrics.Metrics.distribution(
         constants.METRICS_NAMESPACE, 'combine_batch_size')
+    self._combine_byte_size = beam.metrics.Metrics.distribution(
+        constants.METRICS_NAMESPACE, 'combine_byte_size')
     self._num_compacts = beam.metrics.Metrics.counter(
         constants.METRICS_NAMESPACE, 'num_compacts')
     self._num_instances = beam.metrics.Metrics.counter(
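
The new combine_byte_size distribution can be read back from a finished pipeline like any other Beam metric. A sketch, assuming the runner supports metric queries (the pipeline variable p is hypothetical):

from apache_beam.metrics.metric import MetricsFilter

result = p.run()
result.wait_until_finish()
query = result.metrics().query(MetricsFilter().with_name('combine_byte_size'))
for dist in query['distributions']:
  # Distribution results expose the sum/count/min/max of the updates.
  print(dist.key.metric.name, dist.committed)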
@@ -596,6 +609,20 @@ def create_accumulator(self
     return _CombinerStatsGeneratorsCombineFnAcc(
         [g.create_accumulator() for g in self._generators])
 
+  def _should_do_batch(self, accumulator: _CombinerStatsGeneratorsCombineFnAcc,
+                       force: bool) -> bool:
+    curr_batch_size = accumulator.curr_batch_size
+    if force and curr_batch_size > 0:
+      return True
+
+    if curr_batch_size >= self._desired_batch_size:
+      return True
+
+    if accumulator.curr_byte_size >= _MERGE_TABLE_BYTE_SIZE_THRESHOLD:
+      return True
+
+    return False
+
   def _maybe_do_batch(
       self,
       accumulator: _CombinerStatsGeneratorsCombineFnAcc,
@@ -610,9 +637,9 @@ def _maybe_do_batch(
       force: Force computation of stats even if accumulator has less examples
         than the batch size.
     """
-    batch_size = accumulator.curr_batch_size
-    if (force and batch_size > 0) or batch_size >= self._desired_batch_size:
-      self._combine_batch_size.update(batch_size)
+    if self._should_do_batch(accumulator, force):
+      self._combine_batch_size.update(accumulator.curr_batch_size)
+      self._combine_byte_size.update(accumulator.curr_byte_size)
       if len(accumulator.input_tables) == 1:
         arrow_table = accumulator.input_tables[0]
       else:
@@ -622,6 +649,7 @@
           accumulator.partial_accumulators)
       del accumulator.input_tables[:]
       accumulator.curr_batch_size = 0
+      accumulator.curr_byte_size = 0
 
   def add_input(
       self, accumulator: _CombinerStatsGeneratorsCombineFnAcc,
@@ -630,6 +658,7 @@
     accumulator.input_tables.append(input_table)
     num_rows = input_table.num_rows
     accumulator.curr_batch_size += num_rows
+    accumulator.curr_byte_size += table_util.TotalByteSize(input_table)
     self._maybe_do_batch(accumulator)
     self._num_instances.inc(num_rows)
     return accumulator
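
table_util.TotalByteSize comes from tfx_bsl. For intuition about what it measures, recent pyarrow versions expose a comparable per-table byte count (an assumption: the two accountings may differ on sliced or shared buffers):

import pyarrow as pa

table = pa.Table.from_pydict({'f1': [[1.0], [2.0, 3.0], [4.0]]})
print(table.nbytes)  # Total size of the buffers backing this table.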
@@ -657,6 +686,7 @@ def merge_accumulators(
     for acc in batched_accumulators:
       result.input_tables.extend(acc.input_tables)
       result.curr_batch_size += acc.curr_batch_size
+      result.curr_byte_size += acc.curr_byte_size
       self._maybe_do_batch(result)
       batched_partial_accumulators.append(acc.partial_accumulators)

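For completeness, a hedged sketch of driving these internals through the public transform; the module paths match TFDV around this revision, and the default StatsOptions is assumed to be sufficient:

import apache_beam as beam
import pyarrow as pa

from tensorflow_data_validation.statistics import stats_impl
from tensorflow_data_validation.statistics import stats_options

with beam.Pipeline() as p:
  _ = (p
       | beam.Create([pa.Table.from_pydict({'f1': [[1], [2], [3]]})])
       | stats_impl.GenerateStatisticsImpl(stats_options.StatsOptions())
       | beam.Map(print))

Each element of the input PCollection is a pa.Table (per the typehints in the diff), and the transform emits a DatasetFeatureStatisticsList proto.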