Skip to content

create fx argument for sparsefeature types #708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion torchrec/distributed/embedding_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import torch
from fbgemm_gpu.split_table_batched_embeddings_ops import EmbeddingLocation
from torch import nn
from torch import fx, nn
from torchrec.distributed.types import (
ModuleSharder,
ParameterStorage,
Expand Down Expand Up @@ -88,6 +88,17 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
if self.id_score_list_features is not None:
self.id_score_list_features.record_stream(stream)

def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> fx.node.Argument:
return tracer.create_node(
"call_function",
SparseFeatures,
args=(
tracer.create_arg(self.id_list_features),
tracer.create_arg(self.id_score_list_features),
),
kwargs={},
)


class SparseFeaturesList(Multistreamable):
def __init__(self, features: List[SparseFeatures]) -> None:
Expand All @@ -109,6 +120,14 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
for feature in self.features:
feature.record_stream(stream)

def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> fx.node.Argument:
return tracer.create_node(
"call_function",
SparseFeaturesList,
args=(tracer.create_arg(self.features),),
kwargs={},
)


class ListOfSparseFeaturesList(Multistreamable):
def __init__(self, features: List[SparseFeaturesList]) -> None:
Expand All @@ -130,6 +149,14 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
for feature in self.features_list:
feature.record_stream(stream)

def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> fx.node.Argument:
return tracer.create_node(
"call_function",
ListOfSparseFeaturesList,
args=(tracer.create_arg(self.features_list),),
kwargs={},
)


@dataclass
class ShardedConfig:
Expand Down