50
50
from tensorflow_metadata .proto .v0 import statistics_pb2
51
51
52
52
53
# The combiner accumulates tables from the upstream and merges them when
# certain conditions are met. A merged table allows better vectorized
# processing, but we have to pay for copying and the RAM to contain the
# merged table. If the total byte size of accumulated tables exceeds this
# threshold a merge will be forced to avoid consuming too much memory.
_MERGE_TABLE_BYTE_SIZE_THRESHOLD = 20 << 20 # 20MiB
59
+
60
+
53
61
@beam .typehints .with_input_types (pa .Table )
54
62
@beam .typehints .with_output_types (statistics_pb2 .DatasetFeatureStatisticsList )
55
63
class GenerateStatisticsImpl (beam .PTransform ):
@@ -505,7 +513,8 @@ def extract_output(self, accumulator: List[float]
505
513
class _CombinerStatsGeneratorsCombineFnAcc (object ):
506
514
"""accumulator for _CombinerStatsGeneratorsCombineFn."""
507
515
508
- __slots__ = ['partial_accumulators' , 'input_tables' , 'curr_batch_size' ]
516
+ __slots__ = ['partial_accumulators' , 'input_tables' , 'curr_batch_size' ,
517
+ 'curr_byte_size' ]
509
518
510
519
def __init__ (self , partial_accumulators : List [Any ]):
511
520
# Partial accumulator states of the underlying CombinerStatsGenerators.
@@ -514,6 +523,8 @@ def __init__(self, partial_accumulators: List[Any]):
514
523
self .input_tables = []
515
524
# Current batch size.
516
525
self .curr_batch_size = 0
526
+ # Current total byte size of all the pa.Tables accumulated.
527
+ self .curr_byte_size = 0
517
528
518
529
519
530
@beam .typehints .with_input_types (pa .Table )
@@ -544,7 +555,7 @@ class _CombinerStatsGeneratorsCombineFn(beam.CombineFn):
544
555
"""
545
556
546
557
__slots__ = ['_generators' , '_desired_batch_size' , '_combine_batch_size' ,
547
- '_num_compacts' , '_num_instances' ]
558
+ '_combine_byte_size' , ' _num_compacts' , '_num_instances' ]
548
559
549
560
# This needs to be large enough to allow for efficient merging of
550
561
# accumulators in the individual stats generators, but shouldn't be too large
@@ -569,6 +580,8 @@ def __init__(
569
580
# Metrics
570
581
self ._combine_batch_size = beam .metrics .Metrics .distribution (
571
582
constants .METRICS_NAMESPACE , 'combine_batch_size' )
583
+ self ._combine_byte_size = beam .metrics .Metrics .distribution (
584
+ constants .METRICS_NAMESPACE , 'combine_byte_size' )
572
585
self ._num_compacts = beam .metrics .Metrics .counter (
573
586
constants .METRICS_NAMESPACE , 'num_compacts' )
574
587
self ._num_instances = beam .metrics .Metrics .counter (
@@ -596,6 +609,20 @@ def create_accumulator(self
596
609
return _CombinerStatsGeneratorsCombineFnAcc (
597
610
[g .create_accumulator () for g in self ._generators ])
598
611
612
+ def _should_do_batch (self , accumulator : _CombinerStatsGeneratorsCombineFnAcc ,
613
+ force : bool ) -> bool :
614
+ curr_batch_size = accumulator .curr_batch_size
615
+ if force and curr_batch_size > 0 :
616
+ return True
617
+
618
+ if curr_batch_size >= self ._desired_batch_size :
619
+ return True
620
+
621
+ if accumulator .curr_byte_size >= _MERGE_TABLE_BYTE_SIZE_THRESHOLD :
622
+ return True
623
+
624
+ return False
625
+
599
626
def _maybe_do_batch (
600
627
self ,
601
628
accumulator : _CombinerStatsGeneratorsCombineFnAcc ,
@@ -610,9 +637,9 @@ def _maybe_do_batch(
610
637
force: Force computation of stats even if accumulator has less examples
611
638
than the batch size.
612
639
"""
613
- batch_size = accumulator . curr_batch_size
614
- if ( force and batch_size > 0 ) or batch_size >= self ._desired_batch_size :
615
- self ._combine_batch_size .update (batch_size )
640
+ if self . _should_do_batch ( accumulator , force ):
641
+ self ._combine_batch_size . update ( accumulator . curr_batch_size )
642
+ self ._combine_byte_size .update (accumulator . curr_byte_size )
616
643
if len (accumulator .input_tables ) == 1 :
617
644
arrow_table = accumulator .input_tables [0 ]
618
645
else :
@@ -622,6 +649,7 @@ def _maybe_do_batch(
622
649
accumulator .partial_accumulators )
623
650
del accumulator .input_tables [:]
624
651
accumulator .curr_batch_size = 0
652
+ accumulator .curr_byte_size = 0
625
653
626
654
def add_input (
627
655
self , accumulator : _CombinerStatsGeneratorsCombineFnAcc ,
@@ -630,6 +658,7 @@ def add_input(
630
658
accumulator .input_tables .append (input_table )
631
659
num_rows = input_table .num_rows
632
660
accumulator .curr_batch_size += num_rows
661
+ accumulator .curr_byte_size += table_util .TotalByteSize (input_table )
633
662
self ._maybe_do_batch (accumulator )
634
663
self ._num_instances .inc (num_rows )
635
664
return accumulator
@@ -657,6 +686,7 @@ def merge_accumulators(
657
686
for acc in batched_accumulators :
658
687
result .input_tables .extend (acc .input_tables )
659
688
result .curr_batch_size += acc .curr_batch_size
689
+ result .curr_byte_size += acc .curr_byte_size
660
690
self ._maybe_do_batch (result )
661
691
batched_partial_accumulators .append (acc .partial_accumulators )
662
692
0 commit comments