Add ZarrTrace #7540
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #7540      +/-   ##
==========================================
+ Coverage   92.83%   92.88%   +0.05%
==========================================
  Files         106      107       +1
  Lines       17669    18048     +379
==========================================
+ Hits        16403    16764     +361
- Misses       1266     1284      +18
This is an important issue to keep track of when we'll eventually want to read the zarr store and create an
pymc/backends/zarr.py
Outdated
    _dtype = np.dtype(dtype)
    if np.issubdtype(_dtype, np.floating):
        return (np.nan, _dtype, None)
    elif np.issubdtype(_dtype, np.integer):
        return (-1_000_000, _dtype, None)
    elif np.issubdtype(_dtype, "bool"):
        return (False, _dtype, None)
    elif np.issubdtype(_dtype, "str"):
        return ("", _dtype, None)
    elif np.issubdtype(_dtype, "datetime64"):
        return (np.datetime64(0, "Y"), _dtype, None)
    elif np.issubdtype(_dtype, "timedelta64"):
        return (np.timedelta64(0, "Y"), _dtype, None)
    else:
        return (None, _dtype, numcodecs.Pickle())
Question from my own ignorance, since I don't understand much about how fill values are implemented. Are we just hoping that these fill values don't actually occur in the data?
If so, this seems especially perilous for bool 😅
No, they are supposed to be the initialisation values for the entries. When the sampler completes its run, all entries will be filled with the correct value. Zarr just needs you to tell it what value to give to unwritten places. In the storage, these entries are never actually written, they are produced when you ask for the concrete values in the array.
The dangerous part is that xarray interprets fill_value as an indicator of whether the actual value should be masked to nan. This seems to be because the netCDF standard treats fill_value as something completely different.
To keep things as clean as possible, I’ll store the draw_idx of each chain in a separate group that should never be converted to xarray.
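The fill-value behaviour described above can be illustrated with a small toy model. This is only a sketch in plain Python, not zarr's implementation: the class name LazyChunkedArray and its attributes are made up for illustration. It shows that chunks which were never written are not stored at all, and that reads from them return the fill value.

```python
class LazyChunkedArray:
    """Toy 1-D array that stores only the chunks that were written to.

    Hypothetical illustration; zarr's real storage layer is more involved.
    """

    def __init__(self, length, chunk_size, fill_value):
        self.length = length
        self.chunk_size = chunk_size
        self.fill_value = fill_value
        # chunk index -> list of values; a missing key means "never written"
        self.chunks = {}

    def __setitem__(self, index, value):
        chunk_id, offset = divmod(index, self.chunk_size)
        # Materialise the chunk on first write, padded with the fill value.
        chunk = self.chunks.setdefault(
            chunk_id, [self.fill_value] * self.chunk_size
        )
        chunk[offset] = value

    def __getitem__(self, index):
        chunk_id, offset = divmod(index, self.chunk_size)
        chunk = self.chunks.get(chunk_id)
        # Unwritten chunks are synthesised from the fill value on read.
        return self.fill_value if chunk is None else chunk[offset]


draws = LazyChunkedArray(length=6, chunk_size=2, fill_value=-1_000_000)
draws[0] = 3  # only the first chunk is materialised
```

Once sampling has written every entry, the fill value no longer appears in the data, which is why a fill value such as False for bool is harmless as an initial value.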
Ah, that makes a lot more sense now, thanks for the explanation!
In case it's non-obvious to more than me, maybe it would be helpful to try to make this more self-evident. Perhaps by calling the function get_initial_fill_value_and_codec, or by adding a short comment that the fill value is used for initialization?
Yes, therefore I would recommend not to use them for any new implementation.
Just to clarify:
It should be quite simple to implement a
Yes and No. I would say ArviZ is a first-class citizen, because
I consider 2. the primary weakness, and it's the only reason why I don't use McBackend by default. First we must find answers to:
Protocol buffers I used them because they are convenient for specifying a data structure and not having to write serialization/deserialization code for it. And they are only for the constant metadata From the Python perspective this could also be done with Is that tight integration? The important design decision is not which implementation is used to serialize/deserialize metadata, but rather to freeze and detach these (meta)data from chains, draws and stats:
Thanks @michaelosthege for the feedback!
I understand what the two backends for McBackend offer and that McBackend already has a test suite. Despite this, I'll try to argue in favor of writing something that's detached from McBackend.
The way I see this is that McBackend offers a signature to convert from a kind of storage (like
The key thing is that I added these groups to the zarr hierarchy, having them as
I decided to only focus on MCMC for now, and I'm trying to make
Yes, you can deserialize almost all of the contents into C++ or Rust. zarr can be readable from python, Julia, C++, rust, javascript and Java. The only content that would not be readable in other languages would come from arrays with The latter isn't a problem in my opinion because it is related exclusively to the python pymc step methods, and I detached it to its own private group in the zarr hierarchy. The former might be more problematic, but since Having said that, there are other benefits that we would get if we were to rely on zarr directly, such as:
I think that these added benefits plus the drop in maintenance costs in the long run warrant using zarr directly and not through a new backend for McBackend.
@lucianopaz, have you done some benchmarks with this yet (in particular with S3)? I'm a bit concerned that with a (1, 1, ...) chunk size, I/O will be a bottleneck.
No, I haven't. But I've made the chunksize customizable now via the Anyway, my long term goal is to add something like checkpoints during sampling where the trace gets dumped into a file along with the sampling state of the step methods. I think that I'll eventually make the chunks align with that, so that we don't lose samples that were drawn before the checkpoint if sampling gets terminated afterwards (before having finished).
By the way, I've added a
I love this direction. I left a comment on ArviZ integration.
I also have more ideas of things that can be done to better integrate the sampling outputs with inferencedata, but it might be better to address them in follow-up PRs. Not having to go into the current ArviZ converter might help get these things off the ground. Many of these have been around since #5160
Also, anything on ArviZ side that can help with this let me know
Better sample_stats. sample_stats doesn't necessarily need to restrict itself to having chain, draw dimensions in all its variables. The mass matrix could also go in there, and even a divergence_id (with extra coordinate values or a multiindex to store the start and end points of divergences), which would complement the boolean diverging variable with chain, draw dimensions.
Samples in the unconstrained space. Related to #6721 and to a lesser extent arviz-devs/arviz-base#8
I absolutely love this! :D
pymc/backends/zarr.py
Outdated
def setup(self, draws: int, chain: int, sampler_vars: Sequence[dict] | None):  # type: ignore[override]
    self.chain = chain

def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
I did not check the source, but I think zarr will write the whole chunk each time we set a draw here, even if that chunk is not full yet. If that is indeed the case, and draws_per_chunk is >1, we should be able to speed this up a lot by buffering draws_per_chunk draws and then setting the values in one go.
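A minimal sketch of the buffering idea, written in plain Python so it does not depend on zarr. The DrawBuffer class and the write_block callback are hypothetical names used only for illustration; in a real backend the callback would be a slice assignment into the zarr array.

```python
class DrawBuffer:
    """Collect draws in memory and hand them over one block at a time.

    Hypothetical helper; `write_block(start, block)` stands in for a
    slice assignment such as array[start:start + len(block)] = block.
    """

    def __init__(self, draws_per_chunk, write_block):
        self.draws_per_chunk = draws_per_chunk
        self.write_block = write_block
        self.pending = []
        self.next_start = 0  # index of the first draw not yet written

    def record(self, draw):
        self.pending.append(draw)
        # Flush only when a full chunk's worth of draws has accumulated.
        if len(self.pending) == self.draws_per_chunk:
            self.flush()

    def flush(self):
        # Also called once at the end to write a trailing partial block.
        if self.pending:
            self.write_block(self.next_start, self.pending)
            self.next_start += len(self.pending)
            self.pending = []


written = []
buffer = DrawBuffer(
    draws_per_chunk=4,
    write_block=lambda start, block: written.append((start, list(block))),
)
for value in range(10):
    buffer.record(value)
buffer.flush()  # write the last, incomplete block
```

With 10 draws and 4 draws per chunk this performs three block writes instead of ten single-element writes. The trade-off is that draws still sitting in the buffer are lost if the process is killed before the next flush.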
Oh, it would be a shame if zarr itself doesn't buffer the write until the chunk is filled.
Great! Yes, let's try to address those in other PRs.
I don't know if
I tried to stay close to how pymc is returning things right now. I agree that this could be greatly improved, but maybe we can do so in follow up iterations.
I'll try to add this. It doesn't seem to be difficult.
I just did a little experiment with strace, to see what zarr does when we write a value to a chunk. I wrote a little script that creates the array, writes two chunks and then closes the file. With
import zarr
import numpy as np
import sys
store = zarr.DirectoryStore("zarr-test.zarr")
#store = zarr.LRUStoreCache(store, max_size=2**28)
data = zarr.open(store)
foo = data.array("foo", np.zeros((0, 0, 0)), chunks=(10, 10, 1))
foo.resize((1000, 10, 1))
# Mark the position in the code to make it easier to find the correct part
print("start--", flush=True, file=sys.stderr)
foo[0, 0, 0] = 1.0
print("mid--", flush=True, file=sys.stderr)
foo[1, 0, 0] = 2.0
print("done--", flush=True, file=sys.stderr)
The first write triggers this:
The second write triggers this:
So there's a lot going on for each write to the array. If I read this correctly, for the second store it actually reads the chunk from the disk, then modifies the chunk with the indexing update, writes the new chunk to a temporary file and then replaces the original chunk file with it. For the first write it skips reading in the chunk data, because there still is nothing there to read. So I think if we want to get good performance from this, we should try to combine writes, by buffering
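To put rough numbers on the read-modify-write pattern described above, here is a small cost model. It is a sketch under stated assumptions, not a measurement: it assumes every single-draw write rewrites the whole chunk file, that the chunk is read back first unless it is the first write into that chunk, and that buffered writes store each chunk exactly once. The function name chunk_io_counts is hypothetical.

```python
import math


def chunk_io_counts(n_draws, draws_per_chunk, buffered):
    """Count chunk-level reads and writes for one chain (toy model)."""
    n_chunks = math.ceil(n_draws / draws_per_chunk)
    if buffered:
        # One write per chunk, nothing has to be read back.
        return {"reads": 0, "writes": n_chunks}
    # Unbuffered: one write per draw, and a read before every write
    # except the first one into each chunk.
    return {"reads": n_draws - n_chunks, "writes": n_draws}


unbuffered = chunk_io_counts(1000, 10, buffered=False)
with_buffer = chunk_io_counts(1000, 10, buffered=True)
```

Under these assumptions, 1000 draws with 10 draws per chunk cost 900 chunk reads and 1000 chunk writes without buffering, versus 100 chunk writes with it.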
If the tests pass, the only steps left for this PR to be ready to merge are:
@ricardoV94, to ensure that the previous sampling state is consistent with the model and step methods that are fed to
Not at all :) Why do you need to know whether the model changed for this specific backend?
It's not for the backend. The backend works with what I've written so far. It stores all of the inference data information, and it also stores the step method's sampling state. Maybe I can leave this PR with that functionality only. What I also wanted to include was the ability for
I only need the hash utilities for starting sampling from a previously stored state.
The use case sounds great, but why do we need to check for model consistency? That should be the responsibility of the user. I'm sure functional PPLs like blackjax don't try to guess if the density function changed during evals?
The step methods have a reference to the tensor variable they are supposed to step. I wanted to make sure that if we're trying to set the state of a NUTS stepper, it should at least have been defined for the same variable. The problem is that the stored value won't be the exact same object, but a copy of it. That's why I want to check if its hash representation matches.
I don't think we should check, the step samplers should be a step function, given a logp/dlogp function and a previous state they propose the next value/state. They shouldn't need to know anything about a pymc model or pytensor variables? In theory they could work with densities provided by the user that don't even involve pymc, like nutpie does.
Matching should be done by variable names like posterior predictive does.
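A minimal sketch of what name-based matching could look like, assuming the stored sampling state and the step method each expose a list of variable names. The function check_state_matches and both arguments are hypothetical and not part of pymc.

```python
def check_state_matches(stored_var_names, step_var_names):
    """Raise if a stored state and a step method cover different variables.

    Hypothetical helper: compares names only, never tensor objects or hashes.
    """
    missing = sorted(set(step_var_names) - set(stored_var_names))
    unexpected = sorted(set(stored_var_names) - set(step_var_names))
    if missing or unexpected:
        raise ValueError(
            f"Stored state does not match step method variables: "
            f"missing={missing}, unexpected={unexpected}"
        )


# Order does not matter, only the set of names.
check_state_matches(["mu", "sigma"], ["sigma", "mu"])
```

Comparing names avoids the problem that the stored variable is a copy rather than the same object, at the cost of not detecting a variable that kept its name but changed its definition.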
Can we keep this PR about adding Zarr trace only? And not sampling state/ resuming
Re flaky test, @bwengals was surprised by the magnitude of the error and there is an open issue for it. If you're confident about there being no issue please close the associated issue. Perhaps as a separate PR to unblock other contributions, depending on how long this takes to get merged
Put a fix in for the flaky GP test here #7567
Yes, the only thing that I still need to figure out is what needs to be changed in the current
Your fix looks great. I was just trying out stuff to see if the CI would build. I'll drop the commit with the test patch.
Description
This PR is related to #7503. It specifically focuses on having a way to store intermediate trace results and the step methods sampling state somewhere (See task 2 of #7508).
To be honest, the current situation of the MultiTrace and NDArray backends is terrible. These backend classes have inconsistent signatures across subclasses, and it's very awkward to write new backends that adhere to them. McBackend was an attempt to make things sane again. As far as I understand, McBackend does support ways to dump samples to disk instead of holding them in memory using the ClickHouse database. However, I found the backend a bit detached from arviz and xarray, and it seemed to be tightly linked to protocol buffers, which made it harder for me to see how I could customize stuff.
These considerations brought me to the approach I'm pursuing in this PR: add a backend that uses zarr. Using zarr has the following benefits:
xarray can read zarr stores directly, making it possible to write InferenceData objects to disk directly almost without even having to call a converter.
object dtyped arrays using the numcodecs package. This makes it possible to use the same store to hold sample stats warning objects and step methods sampling_state in the same place as the actual samples from the posterior.
Having stated all of these considerations, I intend to:
ZarrTrace
integrate well with pymc.sample
Replace the MultiTrace and NDArray backend defaults with their Zarr counterparts
ZarrTrace
ZarrChain.record
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7540.org.readthedocs.build/en/7540/