Commit b617cf8
Turn on script and sharded output comparison (pytorch#1687)
Summary:
Pull Request resolved: pytorch#1687

As titled

Reviewed By: sarckk

Differential Revision: D53497548

fbshipit-source-id: 5f88616e24cadb63ab589caaff40869cec19bbd5
gnahzg authored and facebook-github-bot committed Feb 8, 2024
1 parent 3f1f29f commit b617cf8
Showing 1 changed file with 3 additions and 6 deletions.
torchrec/distributed/tests/test_infer_shardings.py (3 additions, 6 deletions)

@@ -1025,8 +1025,8 @@ def test_rw_uneven_sharding_mutiple_table(
 
         gm: torch.fx.GraphModule = symbolic_trace(sharded_model)
         gm_script = torch.jit.script(gm)
-        _ = gm_script(*inputs[0])
-        # TODO (drqiangzhang): Add comparison between scripted and nonscripted model outputs
+        gm_script_output = gm_script(*inputs[0])
+        assert_close(sharded_output, gm_script_output)
 
     @unittest.skipIf(
         torch.cuda.device_count() <= 1,
@@ -1169,8 +1169,7 @@ def test_sharded_quant_fp_ebc_tw(self, weight_dtype: torch.dtype) -> None:
         assert count_registered_fp == world_size
 
         sharded_output = sharded_model(*inputs[0])
-        # TODO(ivankobzarev): check the correctness of non_sharded vs sharded
-        # assert_close(non_sharded_output, sharded_output)
+        assert_close(non_sharded_output, sharded_output)
 
         gm: torch.fx.GraphModule = symbolic_trace(
             sharded_model,
@@ -1197,5 +1196,3 @@ def test_sharded_quant_fp_ebc_tw(self, weight_dtype: torch.dtype) -> None:
         print(f"gm_script:\n{gm_script}")
         gm_script_output = gm_script(*inputs[0])
         assert_close(sharded_output, gm_script_output)
-        _ = gm_script(*inputs[0])
-        # TODO (drqiangzhang): Add comparison between scripted and nonscripted model outputs
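
For readers outside the torchrec test suite, the pattern this commit turns on is: run the jit-scripted fx GraphModule on the same inputs as the eager (sharded) model and assert the outputs match, instead of discarding the scripted output behind a TODO. Below is a minimal, self-contained sketch of that pattern using only core PyTorch; `TinyModel` is a hypothetical stand-in for the sharded torchrec model, not the actual test fixture.

```python
# Sketch of the check this commit enables, assuming only core PyTorch.
# TinyModel is a hypothetical stand-in for the sharded torchrec model;
# the real test compares sharded_output against gm_script_output.
import torch
from torch.fx import symbolic_trace
from torch.testing import assert_close


class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


model = TinyModel()
inputs = (torch.randn(2, 8),)

# Eager output plays the role of sharded_output in the test.
eager_output = model(*inputs)

# fx-trace the model, script the resulting GraphModule, and run it.
gm: torch.fx.GraphModule = symbolic_trace(model)
gm_script = torch.jit.script(gm)
gm_script_output = gm_script(*inputs)

# The comparison the old TODOs deferred: scripted vs. non-scripted outputs.
assert_close(eager_output, gm_script_output)
```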
