
Commit 46f2559

sryap authored and facebook-github-bot committed
Add optimized TBE training forward (#1804)
Summary:
Pull Request resolved: #1804

This diff adds the frontend changes and tests for TBE v2 (D43634651).

The `FBGEMM_EXPERIMENTAL_TBE` environment variable flag is added for enabling/disabling the new implementation at runtime. If `FBGEMM_EXPERIMENTAL_TBE` is not set, TBE uses the original implementation. If `FBGEMM_EXPERIMENTAL_TBE=1`, TBE uses the new implementation. If a TBE use case is not supported by the new implementation, TBE falls back to the original implementation. By default, `FBGEMM_EXPERIMENTAL_TBE` is not set.

The new implementation can also be enabled by passing `use_experimental_tbe=True` when instantiating the TBE operator:

```
emb_op = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=...,
    ...,
    use_experimental_tbe=True,
)
```

Reviewed By: jianyuh

Differential Revision: D44479772

fbshipit-source-id: b961811488a25904a3f34660c553067b1ab93c95
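For illustration, a minimal sketch (not part of the commit message) of the environment-variable path described above. The flag has to be set before the TBE operator is constructed, since it is read once in `__init__` (see the diff in `split_table_batched_embeddings_ops_training.py` below):

```python
# Sketch under the assumptions above: enable the experimental TBE forward
# process-wide via the environment variable instead of the constructor flag.
import os

# Must be set before SplitTableBatchedEmbeddingBagsCodegen is instantiated,
# because the flag is resolved once in its __init__.
os.environ["FBGEMM_EXPERIMENTAL_TBE"] = "1"

# Leaving the variable unset (or setting it to anything other than "1")
# keeps the original implementation.
```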
1 parent 27ef9a0 commit 46f2559

5 files changed: 66 additions, 6 deletions

fbgemm_gpu/codegen/lookup_args.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ class CommonArgs(NamedTuple):
     lxu_cache_locations: torch.Tensor
     output_dtype: int
     vbe_metadata: VBEMetadata
+    is_experimental: bool
 
 
 class OptimizerArgs(NamedTuple):

fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template

Lines changed: 1 addition & 0 deletions
@@ -277,5 +277,6 @@ def invoke(
             max_counter=max_counter,
         {% endif %}
         output_dtype=common_args.output_dtype,
+        is_experimental=common_args.is_experimental,
     )
 {% endif %}

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 19 additions & 0 deletions
@@ -9,6 +9,7 @@
 
 import enum
 import logging
+import os
 from dataclasses import dataclass, field
 from itertools import accumulate
 from math import log2
@@ -209,6 +210,7 @@ def __init__(  # noqa C901
         device: Optional[Union[str, int, torch.device]] = None,
         bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
         uvm_non_rowwise_momentum: bool = False,  # place non-rowwise momentum on UVM
+        use_experimental_tbe: bool = False,  # set to True to use TBE v2 (only support NVIDIA GPUs)
     ) -> None:
         super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
 
@@ -595,6 +597,22 @@ def __init__(  # noqa C901
 
         self.step = 0
 
+        # Check whether to use TBE v2
+        is_experimental = False
+        fbgemm_exp_tbe = os.environ.get("FBGEMM_EXPERIMENTAL_TBE")
+        if use_experimental_tbe:
+            is_experimental = True
+            logging.info(
+                "use_experimental_tbe is set to True; Use experimental TBE: True"
+            )
+        elif fbgemm_exp_tbe is not None:
+            is_experimental = int(fbgemm_exp_tbe) == 1
+            logging.info(
+                f"FBGEMM_EXPERIMENTAL_TBE is set to {fbgemm_exp_tbe}; "
+                f"Use experimental TBE: {is_experimental}"
+            )
+        self.is_experimental: bool = is_experimental
+
     def _register_nonpersistent_buffers(self, prefix: str) -> None:
         # NOTE: make TorchScript work!
         self.register_buffer(
@@ -811,6 +829,7 @@ def forward(  # noqa: C901
             lxu_cache_locations=lxu_cache_locations,
             output_dtype=self.output_dtype,
             vbe_metadata=vbe_metadata,
+            is_experimental=self.is_experimental,
         )
 
         if self.optimizer == OptimType.EXACT_SGD:
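The flag-resolution logic added to `__init__` above reduces to roughly the following standalone rule (a paraphrase for illustration only; the helper name is made up and does not exist in this commit):

```python
import os


def resolve_is_experimental(use_experimental_tbe: bool) -> bool:
    """Paraphrase of the resolution logic added in __init__ above."""
    # The constructor argument takes priority and forces the new implementation on.
    if use_experimental_tbe:
        return True
    # Otherwise the FBGEMM_EXPERIMENTAL_TBE environment variable decides;
    # only the value 1 enables the new implementation.
    fbgemm_exp_tbe = os.environ.get("FBGEMM_EXPERIMENTAL_TBE")
    if fbgemm_exp_tbe is not None:
        return int(fbgemm_exp_tbe) == 1
    # Default: keep the original implementation.
    return False
```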

fbgemm_gpu/fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py

Lines changed: 1 addition & 0 deletions
@@ -451,6 +451,7 @@ def forward(
                 max_B_feature_rank=-1,
                 output_size=-1,
             ),
+            is_experimental=False,
         )
 
         momentum1 = invokers.lookup_args.Momentum(

fbgemm_gpu/test/split_table_batched_embeddings_test.py

Lines changed: 44 additions & 6 deletions
@@ -156,6 +156,7 @@ def execute_forward_(  # noqa C901
         pooling_mode: PoolingMode,
         use_cpu: bool,
         output_dtype: SparseType,
+        use_experimental_tbe: bool,
     ) -> None:
         # NOTE: cache is not applicable to CPU version.
         assume(not use_cpu or not use_cache)
@@ -324,6 +325,7 @@ def execute_forward_(  # noqa C901
             cache_algorithm=cache_algorithm,
             pooling_mode=pooling_mode,
             output_dtype=output_dtype,
+            use_experimental_tbe=use_experimental_tbe,
         )
         # NOTE: test TorchScript-compatible!
         cc = torch.jit.script(cc)
@@ -412,6 +414,7 @@ def test_forward_cpu_int8(
             pooling_mode,
             use_cpu,
             SparseType.FP32,
+            False,  # use_experimental_tbe
         )
 
     def test_forward_cpu_fp32(
@@ -456,6 +459,7 @@ def test_forward_cpu_fp32(
             pooling_mode,
             use_cpu,
             SparseType.FP32,
+            False,  # use_experimental_tbe
         )
 
     @unittest.skipIf(*gpu_unavailable)
@@ -505,11 +509,22 @@ def test_forward_gpu_no_cache_int8(
             pooling_mode,
             use_cpu,
             SparseType.FP32,
+            False,  # use_experimental_tbe
         )
 
     @unittest.skipIf(*gpu_unavailable)
+    @given(
+        use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False),
+    )
+    @settings(
+        verbosity=Verbosity.verbose,
+        max_examples=MAX_EXAMPLES_LONG_RUNNING,
+        deadline=None,
+        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
+    )
     def test_forward_gpu_no_cache_fp16(
         self,
+        use_experimental_tbe: bool,
     ) -> None:
         weights_precision = SparseType.FP16
         use_cpu = False
@@ -527,15 +542,17 @@ def test_forward_gpu_no_cache_fp16(
             [
                 PoolingMode.SUM,
                 PoolingMode.MEAN,
-                PoolingMode.NONE,
             ]
+            + ([PoolingMode.NONE] if not use_experimental_tbe else [])
         )
         if pooling_mode == PoolingMode.NONE:
             mixed = False
             mixed_B = False
         else:
             mixed = random.choice([True, False])
-            mixed_B = random.choice([True, False])
+            mixed_B = (
+                random.choice([True, False]) if not use_experimental_tbe else False
+            )
         if pooling_mode == PoolingMode.SUM:
             weighted = random.choice([True, False])
         else:
@@ -555,11 +572,22 @@ def test_forward_gpu_no_cache_fp16(
             pooling_mode,
             use_cpu,
             SparseType.FP32,
+            use_experimental_tbe,
         )
 
     @unittest.skipIf(*gpu_unavailable)
+    @given(
+        use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False),
+    )
+    @settings(
+        verbosity=Verbosity.verbose,
+        max_examples=MAX_EXAMPLES_LONG_RUNNING,
+        deadline=None,
+        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
+    )
     def test_forward_gpu_no_cache_fp32(
         self,
+        use_experimental_tbe: bool,
     ) -> None:
         weights_precision = SparseType.FP32
         use_cpu = False
@@ -577,15 +605,17 @@ def test_forward_gpu_no_cache_fp32(
             [
                 PoolingMode.SUM,
                 PoolingMode.MEAN,
-                PoolingMode.NONE,
             ]
+            + ([PoolingMode.NONE] if not use_experimental_tbe else [])
         )
         if pooling_mode == PoolingMode.NONE:
             mixed = False
             mixed_B = False
         else:
             mixed = random.choice([True, False])
-            mixed_B = random.choice([True, False])
+            mixed_B = (
+                random.choice([True, False]) if not use_experimental_tbe else False
+            )
         if pooling_mode == PoolingMode.SUM:
             weighted = random.choice([True, False])
         else:
@@ -605,6 +635,7 @@ def test_forward_gpu_no_cache_fp32(
             pooling_mode,
             use_cpu,
             SparseType.FP32,
+            use_experimental_tbe,
         )
 
     @unittest.skipIf(*gpu_unavailable)
@@ -668,11 +699,13 @@ def test_forward_gpu_uvm_cache_int8(
             pooling_mode,
             use_cpu,
             output_dtype,
+            False,  # use_experimental_tbe
         )
 
     @unittest.skipIf(*gpu_unavailable)
     @given(
         cache_algorithm=st.sampled_from(CacheAlgorithm),
+        use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False),
     )
     @settings(
         verbosity=Verbosity.verbose,
@@ -683,6 +716,7 @@ def test_forward_gpu_uvm_cache_int8(
     def test_forward_gpu_uvm_cache_fp16(
         self,
         cache_algorithm: CacheAlgorithm,
+        use_experimental_tbe: bool,
     ) -> None:
         weights_precision = SparseType.FP16
         use_cpu = False
@@ -698,8 +732,8 @@ def test_forward_gpu_uvm_cache_fp16(
             [
                 PoolingMode.SUM,
                 PoolingMode.MEAN,
-                PoolingMode.NONE,
             ]
+            + ([PoolingMode.NONE] if not use_experimental_tbe else [])
         )
         output_dtype = random.choice(
             [
@@ -731,11 +765,13 @@ def test_forward_gpu_uvm_cache_fp16(
             pooling_mode,
             use_cpu,
             output_dtype,
+            use_experimental_tbe,
         )
 
     @unittest.skipIf(*gpu_unavailable)
     @given(
         cache_algorithm=st.sampled_from(CacheAlgorithm),
+        use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False),
     )
     @settings(
         verbosity=Verbosity.verbose,
@@ -746,6 +782,7 @@ def test_forward_gpu_uvm_cache_fp16(
     def test_forward_gpu_uvm_cache_fp32(
         self,
         cache_algorithm: CacheAlgorithm,
+        use_experimental_tbe: bool,
     ) -> None:
         weights_precision = SparseType.FP32
         use_cpu = False
@@ -761,8 +798,8 @@ def test_forward_gpu_uvm_cache_fp32(
             [
                 PoolingMode.SUM,
                 PoolingMode.MEAN,
-                PoolingMode.NONE,
             ]
+            + ([PoolingMode.NONE] if not use_experimental_tbe else [])
         )
         output_dtype = random.choice(
             [
@@ -794,6 +831,7 @@ def test_forward_gpu_uvm_cache_fp32(
             pooling_mode,
             use_cpu,
             output_dtype,
+            use_experimental_tbe,
         )
 
     @unittest.skipIf(*gpu_unavailable)
