Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion torchrec/distributed/embedding_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import torch
from fbgemm_gpu.split_table_batched_embeddings_ops import EmbeddingLocation
from torch import nn
from torch import fx, nn
from torchrec.distributed.types import (
ModuleSharder,
ParameterStorage,
Expand Down Expand Up @@ -88,6 +88,17 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
if self.id_score_list_features is not None:
self.id_score_list_features.record_stream(stream)

def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> fx.node.Argument:
return tracer.create_node(
"call_function",
SparseFeatures,
args=(
tracer.create_arg(self.id_list_features),
tracer.create_arg(self.id_score_list_features),
),
kwargs={},
)


class SparseFeaturesList(Multistreamable):
def __init__(self, features: List[SparseFeatures]) -> None:
Expand All @@ -109,6 +120,14 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
for feature in self.features:
feature.record_stream(stream)

def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> fx.node.Argument:
return tracer.create_node(
"call_function",
SparseFeaturesList,
args=(tracer.create_arg(self.features),),
kwargs={},
)


class ListOfSparseFeaturesList(Multistreamable):
def __init__(self, features: List[SparseFeaturesList]) -> None:
Expand All @@ -130,6 +149,14 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
for feature in self.features_list:
feature.record_stream(stream)

def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> fx.node.Argument:
return tracer.create_node(
"call_function",
ListOfSparseFeaturesList,
args=(tracer.create_arg(self.features_list),),
kwargs={},
)


@dataclass
class ShardedConfig:
Expand Down