fix: Add special cases for clone and to_copy where input of graph is output #2265

Merged
gs-olive merged 1 commit into main from get_clone_fixes
Sep 20, 2023

gs-olive
Collaborator

Description

  • TRT does not allow a graph input to also be a graph output. This situation comes up in many real models, especially when an input is cloned or copied and then returned.
  • The current converters register these operators as no-ops, which causes TRT engine building to fail on such inputs.
  • Rather than creating an identity layer for every clone or copy node, check whether that node is the only operator on a placeholder (input) and insert the identity layer only in that case.
  • Coalesce the implementations of clone and to_copy, which are effectively the same operator.
  • Add test cases to validate the new behavior.
  • Add a new boilerplate converter validator utility to support this case.
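
A minimal sketch of why the no-op conversion fails and how an identity layer avoids the problem. Everything here is a toy model written for illustration: `ToyTensor`, `ToyNetwork`, `convert_clone`, and the `force_identity` flag are hypothetical names, not the TensorRT or Torch-TensorRT API. The only behavior taken from the description above is that an input tensor may not be marked as an output.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class ToyTensor:
    """Hypothetical stand-in for a TRT tensor; only the name matters here."""
    name: str


@dataclass
class ToyNetwork:
    """Toy model of the input/output constraint, not the TensorRT API."""
    inputs: set = field(default_factory=set)
    outputs: list = field(default_factory=list)

    def add_input(self, name):
        tensor = ToyTensor(name)
        self.inputs.add(tensor)
        return tensor

    def add_identity(self, tensor):
        # An identity layer yields a distinct tensor carrying the same data.
        return ToyTensor(tensor.name + "_identity")

    def mark_output(self, tensor):
        # Mirrors the constraint: a network input cannot also be an output.
        if tensor in self.inputs:
            raise ValueError(f"{tensor.name} is an input and cannot be an output")
        self.outputs.append(tensor)


def convert_clone(network, tensor, force_identity):
    """Shared conversion path for clone and to_copy in this sketch."""
    if force_identity:
        return network.add_identity(tensor)
    return tensor  # no-op: hand the incoming tensor straight through


net = ToyNetwork()
x = net.add_input("x")

# No-op conversion returns the input itself, so marking it as an output fails.
try:
    net.mark_output(convert_clone(net, x, force_identity=False))
    noop_failed = False
except ValueError:
    noop_failed = True

# With an identity layer, the output is a different tensor and is accepted.
net.mark_output(convert_clone(net, x, force_identity=True))
```

In this sketch the identity layer is requested only when needed, so clone and copy nodes elsewhere in a graph stay free of extra layers.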

Addresses bug in #1565

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)

Checklist:

  • [ x ] My code follows the style guidelines of this project (You can use the linters)
  • [ x ] I have performed a self-review of my own code
  • [ x ] I have commented my code, particularly in hard-to-understand areas and hacks
  • [ x ] I have made corresponding changes to the documentation
  • [ x ] I have added tests to verify my fix or my feature
  • [ x ] New and existing unit tests pass locally with my changes
  • [ x ] I have added the relevant labels to my PR so that relevant reviewers are notified

@gs-olive gs-olive self-assigned this Aug 25, 2023
@github-actions github-actions bot added the component: api [Python], component: conversion, component: converters, component: dynamo, and component: tests labels Aug 25, 2023
@github-actions github-actions bot requested a review from apbose August 25, 2023 17:45

@github-actions github-actions bot left a comment

Code conforms to C++ style guidelines

@github-actions github-actions bot left a comment

Code conforms to Python style guidelines

@gs-olive gs-olive force-pushed the get_clone_fixes branch 2 times, most recently from 4f4556f to bcfa6c1 Compare August 25, 2023 20:56

@narendasan
Collaborator

@gs-olive is this only for graphs like:

graph(x: Tensor):
  return x

Is it worth even converting these into engines?

@gs-olive
Collaborator Author

@narendasan - This change is primarily for cases encountered in detectron-style models where inputs are also outputs of the engine. It might be something like:

graph(x: Tensor, y: Tensor):
  # A lot of logic/operations applied to y
  ...
  return x, y_new

@narendasan
Collaborator

I see

@@ -103,9 +104,13 @@ def convert_binary_elementwise(
    # dtype but we don't have a way to detect whether it makes sense for the
    # scalar to be float or half. Hence we go with the lhs dtype.
    if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
        rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype)
Collaborator

Why is this being converted to a NumPy array instead of a torch.Tensor? Is it because the torch.Tensor was causing clones to be produced in the graph?

Collaborator Author

Creating torch.Tensor objects within the torch.compile scope causes them to be "Fake-ified", which removes the data they contain.

        and any(user.op == "output" for user in list(node.users.keys()))
    )


Collaborator

Why is any() used here? I believed the check above covers the to_copy and clone cases, which have only one node argument. Or are there other cases for this?

Collaborator Author

any() is used here because we are checking whether any of the users of the node is an output, meaning that the node is the only function between a placeholder and an output. We are effectively searching for subgraphs where an input is followed by a function, which is followed by an output.
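
The check discussed in this thread can be sketched as follows. This is an illustration under stated assumptions, not the PR's code: `ToyNode` and `link` are hypothetical helpers that expose only the `op`, `all_input_nodes`, and `users` attributes found on `torch.fx.Node`, the function name `is_only_operator_on_placeholder` is assumed, and the first three conditions are guessed from the discussion. Only the final `any(...)` line is taken from the diff fragment above.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # identity-based hashing, so nodes can be dict keys
class ToyNode:
    """Hypothetical stand-in for the torch.fx.Node attributes used below."""
    op: str
    all_input_nodes: list = field(default_factory=list)
    users: dict = field(default_factory=dict)


def link(producer, consumer):
    """Record that `consumer` takes `producer` as an argument."""
    consumer.all_input_nodes.append(producer)
    producer.users[consumer] = None


def is_only_operator_on_placeholder(node):
    # True when the node is a function whose single input is a placeholder
    # and at least one of its users is the graph output.
    return (
        node.op == "call_function"
        and len(node.all_input_nodes) == 1
        and node.all_input_nodes[0].op == "placeholder"
        and any(user.op == "output" for user in list(node.users.keys()))
    )


# graph(x, y): return clone(x), neg(clone(x)), clone(relu(y))
x, y = ToyNode("placeholder"), ToyNode("placeholder")
clone_x, neg_x = ToyNode("call_function"), ToyNode("call_function")
relu_y, clone_y = ToyNode("call_function"), ToyNode("call_function")
output = ToyNode("output")

link(x, clone_x)
link(clone_x, output)   # clone_x feeds the output directly...
link(clone_x, neg_x)    # ...and also another function
link(neg_x, output)
link(y, relu_y)
link(relu_y, clone_y)
link(clone_y, output)

needs_identity = is_only_operator_on_placeholder(clone_x)
```

Here `clone_x` has two users, only one of which is the output, so a check requiring every user to be an output would miss it while `any` catches it. `relu_y` and `clone_y` do not match because the former never reaches the output directly and the latter does not read from a placeholder.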

@gs-olive gs-olive merged commit ac007ce into main Sep 20, 2023
@gs-olive gs-olive deleted the get_clone_fixes branch September 20, 2023 18:37
Labels
cla signed, component: api [Python], component: conversion, component: converters, component: dynamo, component: tests