feat: support aten.trunc dynamo converter #2543
Conversation
```python
return impl.elementwise.trunc_div(
    ctx, target, source_ir, f"{name}_trunc", input_val, dividend
)
```
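For context, truncation can be expressed as a truncating division by one, which is presumably what this call relies on (the exact semantics live in `impl.elementwise.trunc_div`). A minimal pure-Python sketch of that identity, with `trunc_div` reimplemented from scratch since Python's `//` floors rather than truncates:

```python
def trunc_div(a: float, b: float) -> float:
    # C-style truncating division: round the quotient toward zero.
    # Python's // floors instead, so adjust when the signs differ
    # and the division is inexact.
    q = a // b
    if q < 0 and q * b != a:
        q += 1
    return q

# trunc(x) is then just trunc_div(x, 1):
print(trunc_div(2.7, 1.0))   # truncates toward zero
print(trunc_div(-2.7, 1.0))  # not floored to -3
```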
Based on the documentation of `torch.trunc`, it seems that the behavior differs based on the data type. Specifically, for `float` inputs the output is also a `float`, and for `int` inputs the output is also an `int`. Does this line up with the behavior of `trunc_div` as well?
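To illustrate the dtype concern, here is a plain-Python sketch (not the converter itself) of the behavior described in the `torch.trunc` docs: truncation rounds toward zero, the input dtype is preserved, and an integer input is already truncated, so it should pass through unchanged:

```python
import math

def trunc_like(x):
    # Mimics the documented torch.trunc dtype behavior:
    # integers are already truncated, so return them as-is;
    # floats are rounded toward zero but stay floats.
    if isinstance(x, int):
        return x
    return float(math.trunc(x))

print(trunc_like(5))     # int in, int out
print(trunc_like(-2.7))  # float in, float out, rounded toward zero
```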
Additionally, could a line be added such as:

```python
if dtype not in (float16, float32):
    return input_val
```

This can help to avoid unnecessary layer insertion.
Thanks for pointing this out! I added the code below since TRT engine inputs cannot also be TRT engine outputs:

```python
if input_val.dtype not in (trt.float16, trt.float32):
    return impl.cast.to_copy(
        ctx,
        target,
        source_ir,
        f"{name}_copy",
        input_val,
        input_val.dtype,
        force_layer=True,
    )
```
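The constraint above can be modeled outside TensorRT: if the converter simply returned its input tensor, a network output could alias a network input, which TensorRT rejects, so forcing a same-dtype copy layer gives the output its own producer. A toy sketch of that pattern (the `Tensor`/`identity_layer` names are illustrative, not the TensorRT API):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Tensor:
    name: str
    producer: Optional[str] = None  # None means a raw network input

def identity_layer(t: Tensor, name: str) -> Tensor:
    # Insert a no-op layer so the result has its own producer,
    # analogous to impl.cast.to_copy(..., force_layer=True).
    return Tensor(name=f"{name}_copy", producer=name)

def convert_trunc_int(inp: Tensor) -> Tensor:
    # For integer dtypes trunc is a no-op, but we still must not
    # hand the raw network input back as a network output.
    return identity_layer(inp, "trunc")

out = convert_trunc_int(Tensor("x"))
print(out.producer)  # the output no longer aliases the input
```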
Looks good to me!
Description

Support `aten.trunc` dynamo converter.

Fixes #2536
Type of change
Checklist: