-
Notifications
You must be signed in to change notification settings - Fork 141
Add Type Change Pass #2149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add Type Change Pass #2149
Conversation
This change adds a new DaCE Pass that can replace one simple DaCe dtype with another. This change was successfully tested on a complex ICON sdfg. Here is a sample usage: ``` import dace import numpy as np from dace.transformation import pass_pipeline as ppl from dace.transformation.passes.type_change import TypeChange N = dace.symbol('N') @dace.program def simple(a: dace.float64[N], b: dace.float64[1]): for i in range(N): a[i] = a[i] * 2.0 + b[0] sdfg = simple.to_sdfg() tc = TypeChange(dace.float64, dace.float32) type_change_pipeline = ppl.Pipeline([tc]) print("Pipeline created") results = type_change_pipeline.apply_pass(sdfg, {}) print(results) # {'TypeChange': 6} A = np.ones(10, dtype=np.float32) B = np.ones(1, dtype=np.float64) sdfg(A, B, N=10) print(A) ```
I’m not sure this is a good candidate for a pass. Maybe if it checked for reductions and, e.g., transcendental ops, for higher precision it would be more appropriate. Otherwise, as it is right now it should be a utility function. Additionally, please avoid type(x) == y checks |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some comments for code improvements. Additionally, I agree with the point raised by @tbennun
Thank you both for your comments
I'm happy to move this into a utility function then. Where is the correct place to put it / do you have a similar function I can mimic
Yes, good point. Will use |
This change adds a new DaCE Pass that can replace one simple DaCe dtype with another. This change was successfully tested on a complex ICON sdfg. Here is a sample usage: ``` import dace import numpy as np from dace.transformation.helpers import replace_sdfg_dtypes N = dace.symbol('N') @dace.program def simple(a: dace.float64[N], b: dace.float64[1]): for i in range(N): a[i] = a[i] * 2.0 + b[0] sdfg = simple.to_sdfg() results = replace_sdfg_dtypes(sdfg, dace.float64, dace.float32) print(results) # {'TypeChange': 6} A = np.ones(10, dtype=np.float32) B = np.ones(1, dtype=np.float64) sdfg(A, B, N=10) print(A) ```
65d50a1
to
ffa9c8d
Compare
Changes:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The change looks good to me, however the PR is missing unit tests. This is a new feature and as such is certainly not covered by existing tests. Please ensure there are at least 2-3 unit tests to check for correct type changing behavior w.r.t. the data types and places the utility function modifies.
Perfect, will do |
This change adds a new DaCE Pass that can replace one simple DaCe dtype with another. This change was successfully tested on a complex ICON sdfg. Here is a sample usage: