Add function that caches sampling results #277
base: main
Conversation
ricardoV94 commented on Dec 4, 2023 (edited)
Force-pushed from 5711f96 to 9f0d5c7
When would that be useful?
When rerunning notebooks, or any workflow that saves and loads traces while you are still tinkering with the model. You don't need to bother naming the traces or overwriting old ones, since the cache key is automatically derived from the model and its data.
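To make that workflow concrete, here is a minimal usage sketch. The function name `cache_sampling`, its import path, and the `dir` argument are illustrative assumptions; the thread does not show the final API:

```python
import pymc as pm

# Hypothetical import path for the helper added in this PR.
from pymc_experimental.utils.cache import cache_sampling

with pm.Model() as model:
    mu = pm.Normal("mu")
    y = pm.Normal("y", mu=mu, sigma=1.0, observed=[0.1, -0.3, 0.2])

    # Wrap pm.sample: the cache key is derived from the model graph, its
    # data, and the sampling kwargs, so rerunning an unchanged notebook
    # loads the stored trace instead of resampling.
    cached_sample = cache_sampling(pm.sample, dir="traces")
    idata = cached_sample(draws=1000, chains=2)
```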
Force-pushed from df085cc to 5b20bdf
```python
    return name, props


def hash_from_fg(fg: FunctionGraph) -> str:
```
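For context, one plausible way such a helper could work; this is a hedged sketch built on `pytensor.dprint`, not the implementation in this PR:

```python
import hashlib

import pytensor
from pytensor.graph import FunctionGraph


def hash_from_fg_sketch(fg: FunctionGraph) -> str:
    # Serialize the graph to its debug-print text, which is deterministic
    # for a fixed graph, then hash the bytes. The PR's version may instead
    # walk the graph and hash each node's name and _props.
    text = pytensor.dprint(fg, file="str", print_type=True)
    return hashlib.sha256(text.encode("utf-8")).hexdigest()
```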
Maybe move this to pytensor?
Too experimental for now
I usually rely on things like MLflow for storing artifacts like this.
I'm not familiar with MLflow. The idea here is that it pairs the saved traces to the exact model/sampling function (and arguments) that were used; basically, the model and the function kwargs are the cache key. Does this have any parallel to your workflow with MLflow?
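A sketch of the pairing described above; the helper name and key layout are assumptions for illustration only:

```python
import hashlib


def cache_key(model_hash: str, fn_name: str, kwargs: dict) -> str:
    # Combine the hash of the model graph with the sampling function's
    # name and keyword arguments, so a stored trace is only reused when
    # all three match.
    payload = repr((model_hash, fn_name, sorted(kwargs.items())))
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


key = cache_key("abc123", "sample", {"draws": 1000, "chains": 2})
file_path = f"traces/{key}.nc"  # cached trace stored under the derived key
```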
Force-pushed from 5b20bdf to 334a211
```python
    name = str(obj)
    props = str(getattr(obj, "_props", lambda: {})())
```
Suggested change:

```python
    name = str(obj)
    if hasattr(obj, "_props"):
        prop_dict = obj._props_dict()
        props = str(
            {k: get_name_and_props(v) for k, v in prop_dict.items()}
        )
    else:
        props = str({})
```
This is just to make sure that potential recursion of `_props` is handled.
`_props` are not recursive.
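For readers unfamiliar with the convention under discussion, a self-contained sketch of the pytensor-style `__props__` protocol (the class is illustrative, not taken from pytensor):

```python
# Classes declare __props__, and _props() returns the corresponding
# attribute values as a tuple. In pytensor these values are typically
# plain parameters (dtypes, axes, flags), which is the sense in which
# the comment above calls them "not recursive".
class ExampleOp:
    __props__ = ("dtype", "axis")

    def __init__(self, dtype, axis):
        self.dtype = dtype
        self.axis = axis

    def _props(self):
        return tuple(getattr(self, name) for name in self.__props__)


print(ExampleOp("float64", 0)._props())  # ('float64', 0)
```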
```python
    if os.path.exists(file_path):
        os.remove(file_path)
    if not os.path.exists(dir):
        os.mkdir(dir)
```
Suggested change:

```python
        os.makedirs(dir, exist_ok=True)
```
I think it's better to use `os.makedirs` because it creates intermediate directories when required.
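A quick illustration of the difference (paths are made up):

```python
import os

# os.mkdir("traces/model_abc") raises FileNotFoundError when "traces"
# does not exist yet; os.makedirs creates the whole chain, and with
# exist_ok=True it is a no-op if the directories are already there.
os.makedirs(os.path.join("traces", "model_abc"), exist_ok=True)
```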
```python
    az.to_netcdf(idata_out, file_path)

    # We save inferencedata separately and extend if needed
    if extend_inferencedata:
```
This looks weird to me. The `sampling_fn` goes into making up the hash, so if someone first calls `sample` and then `sample_posterior_predictive` with extend, the second cache will include both the `posterior` and `posterior_predictive` groups, while the first cache will only include the `posterior` group. I think it's cleaner to never allow extending the idata in place, and to force users to combine the different `InferenceData` objects themselves.
Good catch
Wait, no, I don't see what the problem is. We only ever save new idatas coming out of the `sampling_fn`, and these never extend the previous one.
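For reference, a sketch of the combine-it-yourself pattern proposed above, using standard PyMC/ArviZ calls (`model` is assumed to be defined as in the earlier sketch):

```python
import pymc as pm

with model:
    # Each call would be cached under its own key, so the stored files
    # never contain each other's groups.
    idata = pm.sample(draws=1000)
    pp = pm.sample_posterior_predictive(idata, extend_inferencedata=False)

# The user combines the groups explicitly, instead of the caching layer
# extending the InferenceData in place.
idata.extend(pp)
```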