Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Add tests for Float8Tensor at graph boundaries #196

Closed
wants to merge 3 commits into from

Conversation

drisspg
Copy link
Contributor

@drisspg drisspg commented Jan 24, 2024

Summary

This is a copy with some tweaks of: #166

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 24, 2024
@drisspg drisspg mentioned this pull request Jan 24, 2024
12 tasks
@facebook-github-bot
Copy link
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@drisspg drisspg requested a review from bdhirsh January 24, 2024 20:08
x = torch.randn(16, 16, device="cuda")
y_compiled = compiled_mod(x)

assert not isinstance(
Copy link
Contributor Author

@drisspg drisspg Jan 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets call flatten so that this test doens't become stale

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rip this assert haha. I promised Voz I would add this check more natively in dynamo here: pytorch/pytorch#118211, so eventually you (hopefully) won't need to feel so paranoid

@facebook-github-bot
Copy link
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@@ -76,5 +78,79 @@ def test_inductor(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtyp
_test_compile_base("inductor", fullgraph, emulate, linear_type, dtype)


class TestGraphBreaks:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(1) test_compile is a weird file when we have test/dynamo/*
(2) Can we use regular dynamo testing infra? compile counter, etc?
(3) 1 test class per test file please
(4) This class name is wrong? It's testing fp8 + dynamo, not testing graph breaks (I would imagine TestGraphBreaks would test things like deep dynamo internal working)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1.) We don't have test/dynamo/*
2.) Sure is there a pointer for this?
3.) I hadn't heard of this rule before is this in some style guide convention somewhere?
4.) I thought thats implied by being in Float8Experimetnal, no?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(1) I know, but you can put a test there that depends on this, and have CI install it, right? Alternatively, if you depend on dynamo, just pull the deps out

(2) Yes! Check any test under test/dynamo/*, you dont have to inherit from the base class, but it does cover things like .reset() for you nicely (which I think we missed here!) - grep for CompileCounter :)

(3) No, but I think it helps with organization. Its just a tiny nipick

(4) True, that is my mistake.


mod = self.MockLinear(graph_break=False).cuda()
x = torch.randn(2, 2, device="cuda")
compiled_to_float = torch.compile(to_float)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to_float could hit a skip, and look exactly identical, and still return as compiled_to_float, and this test would pass.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you are right thats a bug 120 should be removed

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, thats not what I mean! I mean that if you dont use our test infra that counts frames, you could hit a skip, get a frame_count of 0, and this test would pass!

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed offline - this is the only grounds for my reject.

Copy link

@voznesenskym voznesenskym left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use existing dynamo test infra.

@drisspg drisspg requested a review from voznesenskym January 24, 2024 21:04
@facebook-github-bot
Copy link
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@drisspg
Copy link
Contributor Author

drisspg commented Jan 24, 2024

@voznesenskym Updated to use CompileCounterWithBackend

@facebook-github-bot
Copy link
Contributor

@drisspg merged this pull request in 289c122.

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Merged
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants