Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ZarrTrace #7540

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft

Add ZarrTrace #7540

wants to merge 9 commits into from

Conversation

lucianopaz
Copy link
Contributor

@lucianopaz lucianopaz commented Oct 16, 2024

Description

This PR is related to #7503. It specifically focuses on having a way to store intermediate trace results and the step methods sampling state somewhere (See task 2 of #7508).

To be honest, the current situation of the MultiTrace and NDArray backends is terrible. These backend classes have inconsistent signatures across subclasses, and it's very awkward to write new backends that adhere to them. McBackend was an attempt to make things sane again. As far as I understand, McBackend does support ways to dump samples to disk instead of holding them in memory using the ClickHouse database. However, I found the backend a bit detached from arviz and xarray, and it seemed to be tightly linked to protocol buffers, which made it harder for me to see how I could customize stuff.

These considerations brought me to the approach I'm pursuing in this PR: add a backend that uses zarr. Using zarr has the following benefits:

  1. xarray can read zarr stores directly making it possible to write InferenceData objects to disk directly almost without even having to call a converter.
  2. zarr works with hierarchically structured data. It's possible to store arrays for each variable inside of a group (e.g. posterior, observed_data) directly.
  3. zarr arrays handle numpy arrays nicely. Fixed sized binary data can be fit into zarr arrays seemlessly.
  4. zarr arrays also have the possibility of storing object dtyped arrays using the numcodec package. This makes it possible use the same store to hold sample stats warning objects and step methods sampling_state in the same place as the actual samples from the posterior.
  5. zarr hierarchies can use many different kinds of storage: they can be held in memory, saved as a directory structure, inside of a zip file, or even remotely on s3 buckets.
  6. It's also possible to write to the same zarr object concurrently from different processes or threads, as long as a synchronization object is provided.
  7. zarr also stores the data using a compressed binary representation. The actual compressor can be customized.
  8. zarr arrays are chunked. This means that they don't need to be loaded entirely onto memory, making it possible to leave a smaller memory footprint while sampling. Another benefit of chunking is that write operations on different chunks should be completely independent from each other.

Having stated all of these considerations I intend to:

  • Build a zarr trace backend
  • Have ZarrTrace integrate well with pymc.sample
  • WONT DO NOW Replace the MultiTrace and NDArray backend defaults with their Zarr counterparts
  • Handle sampling state information in the zarr backend
  • Document ZarrTrace
  • Buffer write operations in ZarrChain.record
  • Record sampling state information periodically during sampling
  • Make it possible to load the zarr trace backend and resume sampling from it.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7540.org.readthedocs.build/en/7540/

@lucianopaz lucianopaz added enhancements trace-backend Traces and ArviZ stuff major Include in major changes release notes section labels Oct 16, 2024
Copy link

codecov bot commented Oct 16, 2024

Codecov Report

Attention: Patch coverage is 95.09044% with 19 lines in your changes missing coverage. Please review.

Project coverage is 92.88%. Comparing base (a507ea8) to head (196f668).
Report is 2 commits behind head on main.

Files with missing lines Patch % Lines
pymc/backends/zarr.py 96.81% 9 Missing ⚠️
pymc/sampling/parallel.py 84.37% 5 Missing ⚠️
pymc/sampling/mcmc.py 92.00% 4 Missing ⚠️
pymc/step_methods/state.py 91.66% 1 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7540      +/-   ##
==========================================
+ Coverage   92.83%   92.88%   +0.05%     
==========================================
  Files         106      107       +1     
  Lines       17669    18048     +379     
==========================================
+ Hits        16403    16764     +361     
- Misses       1266     1284      +18     
Files with missing lines Coverage Δ
pymc/backends/__init__.py 92.50% <100.00%> (+0.83%) ⬆️
pymc/step_methods/compound.py 97.58% <100.00%> (+0.02%) ⬆️
pymc/step_methods/hmc/quadpotential.py 84.63% <100.00%> (ø)
pymc/util.py 81.98% <100.00%> (+0.13%) ⬆️
pymc/step_methods/state.py 96.77% <91.66%> (-1.34%) ⬇️
pymc/sampling/mcmc.py 87.80% <92.00%> (+0.57%) ⬆️
pymc/sampling/parallel.py 88.08% <84.37%> (-0.42%) ⬇️
pymc/backends/zarr.py 96.81% <96.81%> (ø)

@lucianopaz
Copy link
Contributor Author

This is an important issue to keep track of when we'll eventually want to read the zarr store and create an InferenceData object using xarray and arviz

Comment on lines 95 to 109
_dtype = np.dtype(dtype)
if np.issubdtype(_dtype, np.floating):
return (np.nan, _dtype, None)
elif np.issubdtype(_dtype, np.integer):
return (-1_000_000, _dtype, None)
elif np.issubdtype(_dtype, "bool"):
return (False, _dtype, None)
elif np.issubdtype(_dtype, "str"):
return ("", _dtype, None)
elif np.issubdtype(_dtype, "datetime64"):
return (np.datetime64(0, "Y"), _dtype, None)
elif np.issubdtype(_dtype, "timedelta64"):
return (np.timedelta64(0, "Y"), _dtype, None)
else:
return (None, _dtype, numcodecs.Pickle())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question from my own ignorance, since I don't understand so much how fill values are implemented. Are we just hoping that these fill values don't actually occur in the data?

If so, this seems especially perilous for bool 😅

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, they are supposed to be the initialisation values for the entries. When the sampler completes its run, all entries will be filled with the correct value. Zarr just needs you to tell it what value to give to unwritten places. In the storage, these entries are never actually written, they are produced when you ask for the concrete values in the array.
The dangerous part is that xarray is interpreting fill_value as an indicator of whether the actual value should be masked to nan. This seems to be because of the netcdf standard treats fill_value as something completely different.
To keep things as clean as possible, I’ll store the draw_idx of each chain in a separate group that should never be converted to xarray.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, that makes a lot more sense now, thanks for the explanation!

In case it's non-obvious to more than me, maybe it would be helpful to try to make this more self-evident. Perhaps by calling the function get_initial_fill_value_and_codec, or make some little comment that the fill value is used for initialization?

@michaelosthege
Copy link
Member

the current situation of the MultiTrace and NDArray backends is terrible. These backend classes have inconsistent signatures across subclasses, and it's very awkward to write new backends that adhere to them.

Yes, therefore I would recommend not to use them for any new implementation.

McBackend was an attempt to make things sane again. As far as I understand, McBackend does support ways to dump samples to disk instead of holding them in memory using the ClickHouse database.

Just to clarify:

  • NumPyBackend is the go-to for in memory situations
  • ClickHouseBackend is for storing on disk (in a database that may even sit on a different machine!)

It should be quite simple to implement a ZarrBackend with McBackend!
I would recommend to do that first, because McBackend's test suite already covers all (?) of the nasty edge cases.

However, I found the backend a bit detached from arviz and xarray, and it seemed to be tightly linked to protocol buffers, which made it harder for me to see how I could customize stuff.

Yes and No. I would say ArviZ is a first-class citizen, because Run.to_inferencedata() is in the base class.
There are two things which McBackend does not integrate tightly:

  1. xarray because it doesn't/didn't support sparse arrays (needed for sparse stats or variables with varying shape)
  2. InferenceData groups other than .posterior

I consider 2. the primary weakness, and it's the only reason why I don't use McBackend by default.
I see that you added "the other" groups as properties to the ZarrTrace. Maybe this is something we should do on a more abstract level? Have McBackend define the signature of InferenceData without requiring a specific implementation for it?

First we must find answers to:

  • How do prior/posterior/log_likelihood data points arrive?
  • Do they arrive in some kind of "sampling" process that may get parallelized or doesn't fit into memory?
  • If yes, should the storage backend even make a difference between MCMC and forward sampling?

Protocol buffers

I used them because they are convenient for specifying a data structure and not having to write serialization/deserialization code for it.

And they are only for the constant metadata {constant_data, observed_data, coords, names dtypes, ...} because this this needs to be serializable from a semi-clean data structure supporting all the weird data types that users may put into their coords (timestamps anybody?).

From the Python perspective this could also be done with zarr or xarray (they can serialize to binary), but can you serialize/deserialize that in another language?
The protobufs can be compiled to C++ or Rust to easily read/write run metadata from those languages too!

Is that tight integration?

The important design decision is not which implementation is used to serialize/deserialize metadata, but rather to freeze and detach these (meta)data from chains, draws and stats:

  1. Determine it before starting the costly MCMC
  2. Serialize it to its own "blob" of data. Think of {constant_data, observed_data, coords, names dtypes, ...} as the "header" section of a trace.

@lucianopaz
Copy link
Contributor Author

Thanks @michaelosthege for the feedback!

McBackend was an attempt to make things sane again. As far as I understand, McBackend does support ways to dump samples to disk instead of holding them in memory using the ClickHouse database.

Just to clarify:

* `NumPyBackend` is the go-to for _in memory_ situations

* `ClickHouseBackend` is for storing on disk (in a database that may even sit on a different machine!)

It should be quite simple to implement a ZarrBackend with McBackend! I would recommend to do that first, because McBackend's test suite already covers all (?) of the nasty edge cases.

I understand what the two backends for McBackend offer and that McBackend already has a test suite. Despite this, I'll try to argue in favor of writing something that's detached from McBackend.

However, I found the backend a bit detached from arviz and xarray, and it seemed to be tightly linked to protocol buffers, which made it harder for me to see how I could customize stuff.

Yes and No. I would say ArviZ is a first-class citizen, because Run.to_inferencedata() is in the base class.

The way I see this is that McBackend offers a signature to convert from a kind of storage (like MultiTrace) into another one (arviz.InferenceData). I understand that with this method, you guarantee that there should always be a method to go from an McBackend Run to arviz.InferenceData, but you have to handle a lot of transformation logic in this conversion (just like the extra conversion logic that's already in pymc.backends.arviz). In my opinion, this isn't tight integration. Having the data stored in native zarr makes it possible to generate xarray.Dataset objects with a simple xr.open_zarr(store, group) calls, and then these can be wrapped into an InferenceData object with a simple InferenceData(posterior=zarr_posterior, ...) (and potentially even into the future DataTree objects, since zarr hierarchies are already very much tree-like).

There are two things which McBackend does not integrate tightly:

1. `xarray` because it doesn't/didn't support sparse arrays (needed for sparse stats or variables with varying shape)

xarray does not support sparse arrays. At the moment, the posterior samples are initialized as "empty" zarr arrays (in practice, filled arrays with a fill_value). The nice thing about zarr arrays is that these filled, uninitialized places, don't take up almost any space because they aren't actually stored. If queried, their value gets set from the fill_value attribute. xarray still needs to figure out pydata/xarray#5475 though.

2. `InferenceData` groups other than `.posterior`

I consider 2. the primary weakness, and it's the only reason why I don't use McBackend by default. I see that you added "the other" groups as properties to the ZarrTrace.

The key thing is that I added these groups to the zarr hierarchy, having them as ZarrTrace properties is not necessary. By having them in a single shared zarr entity, they are stored almost like an InferenceData from zarr. I need to actually check if arviz has a from_zarr method, because that would be the direct conversion method from a ZarrTrace to an InferenceData object without having to add any extra conversion code.

Maybe this is something we should do on a more abstract level? Have McBackend define the signature of InferenceData without requiring a specific implementation for it?

First we must find answers to:

* How do prior/posterior/log_likelihood data points arrive?

* Do they arrive in some kind of "sampling" process that may get parallelized or doesn't fit into memory?

* If yes, should the storage backend even make a difference between MCMC and forward sampling?

I decided to only focus on MCMC for now, and I'm trying to make ZarrTrace handle concurrent writes to the zarr store from multiple processes during sampling. Having said that, it's almost effortless to add other groups to a zarr hierarchy, and the same store could house prior, prior_predictive, posterior_predictive and predictions as well without having to handle almost any extra logic.

Protocol buffers

I used them because they are convenient for specifying a data structure and not having to write serialization/deserialization code for it.

And they are only for the constant metadata {constant_data, observed_data, coords, names dtypes, ...} because this this needs to be serializable from a semi-clean data structure supporting all the weird data types that users may put into their coords (timestamps anybody?).

From the Python perspective this could also be done with zarr or xarray (they can serialize to binary), but can you serialize/deserialize that in another language? The protobufs can be compiled to C++ or Rust to easily read/write run metadata from those languages too!

Yes, you can deserialize almost all of the contents into C++ or Rust. zarr can be readable from python, Julia, C++, rust, javascript and Java. The only content that would not be readable in other languages would come from arrays with object dtype. At the moment, this is limited to two things:

  1. SamplerWarning
  2. StepMethodState

The latter isn't a problem in my opinion because it is related exclusively to the python pymc step methods, and I detached it to its own private group in the zarr hierarchy. The former might be more problematic, but since SamplerWarning is a dataclass, it could potentially be converted into a dictionary and then represented as a json object, which zarr can serialize without problems.

Having said that, there are other benefits that we would get if we were to rely on zarr directly, such as:

  • Offloading the maintenance cost of the storage backend code
  • Growing set of features that will become available to us as time goes by
  • Seamless compression of the arrays to save storage space
  • Integration with multiple on disk storage options that range from directory structure, zipfiles and multiple SQL and no-SQL databases
  • Integration with distributed or cloud storage like S3, Hadoop, Google Cloud Storage and Azure storage blob, and also fsspec.

I think that these added benefits plus the drop in maintenance costs in the long run warrant using zarr directly and not through a new backend for McBackend.

@maresb
Copy link
Contributor

maresb commented Oct 17, 2024

@lucianopaz, have you done some benchmarks with this yet (in particular with S3)? I'm a bit concerned that with (1, 1, ...) chunk size that I/O will be a bottleneck.

@lucianopaz
Copy link
Contributor Author

@lucianopaz, have you done some benchmarks with this yet (in particular with S3)? I'm a bit concerned that with (1, 1, ...) chunk size that I/O will be a bottleneck.

No, I haven't. But I've made the chunksize customizable now via the draws_per_chunk parameter. @aseyboldt said that we could try to use a different chunk size depending on the dimensionality of the RV.

Anyway, my long term goal is to add something like checkpoints during sampling where the trace gets dumped into a file along with the sampling state of the step methods. I think that I'll eventually make the chunks align with that, so that we don't lose samples that were drawn before the checkpoint if sampling gets terminated afterwards (before having finished).

@lucianopaz
Copy link
Contributor Author

By the way, I've added a to_inferencedata method to the zarr trace. I had to do it because I wanted to ensure that the zarr store had consolidated metadata (if it didn't, xarray would complain) and because I needed to pass mask_and_scale=False to xarray.open_zarr (which arviz doesn't allow in from_zarr). Anyway, you can see for yourselves that the conversion code is extremely short because the stored data is already aligned with what arviz wants.

Copy link
Member

@OriolAbril OriolAbril left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I love this direction. I left a comment on ArviZ integration.

I also have more ideas of things that can be done to integrate better the sampling outputs with inferencedata but it might be better to address them in follow up PRs. Not having to go into the current ArviZ converter might help get this things off the ground. Many of these are around since #5160

Also, anything on ArviZ side that can help with this let me know


Better sample_stats. sample_stats doesn't necessarily need to restrict itself to having chain, draw dimensions in all its variables. the mass matrix could also go in there and a divergence_id even (with extra coordinate values or a multiindex to store the start and end points of divergences) which would complement the boolean diverging variable with chain, draw dimension.

samples in the unconstrained space. related to #6721 and to a lesser extent arviz-devs/arviz-base#8

pymc/backends/zarr.py Outdated Show resolved Hide resolved
Copy link
Member

@aseyboldt aseyboldt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I absolutely love this! :D

def setup(self, draws: int, chain: int, sampler_vars: Sequence[dict] | None): # type: ignore[override]
self.chain = chain

def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not check the source, but I think zarr will write the whole chunk each time we set a draw here, even if that chunk is not full yet. If that is indeed the case, we should be able to speed this up a lot if draws_per_chunk is >1 if we buffer draws_per_chunk draws, and the set values in one go.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, it would be a shame that zarr itself doesn't buffer the write until the chunk is filled.

@lucianopaz
Copy link
Contributor Author

I also have more ideas of things that can be done to integrate better the sampling outputs with inferencedata but it might be better to address them in follow up PRs. Not having to go into the current ArviZ converter might help get this things off the ground. Many of these are around since #5160

Great! Yes, let's try to address those in other PRs.

Also, anything on ArviZ side that can help with this let me know

I don't know if from_zarr might have to be updated a bit? How did you handle the fill_value from zarr? I ran into problems on my side and had to add mask_and_scale=False when I opened the group.

Better sample_stats. sample_stats doesn't necessarily need to restrict itself to having chain, draw dimensions in all its variables. the mass matrix could also go in there and a divergence_id even (with extra coordinate values or a multiindex to store the start and end points of divergences) which would complement the boolean diverging variable with chain, draw dimension.

I tried to stay close to how pymc is returning things right now. I agree that this could be greatly improved, but maybe we can do so in follow up iterations.

samples in the unconstrained space. related to #6721 and to a lesser extent arviz-devs/arviz-base#8

I'll try to add this. It doesn't seem to be difficult.

@aseyboldt
Copy link
Member

I just did a little experiment with strace, to see what zarr does when we write a value to a chunk.

I wrote a little script that creates the array, writes two chunks and then closes the file. With strace python the-script.py we can see all the syscalls it uses in between

import zarr
import numpy as np
import sys

store = zarr.DirectoryStore("zarr-test.zarr")
#store = zarr.LRUStoreCache(store, max_size=2**28)
data = zarr.open(store)
foo = data.array("foo", np.zeros((0, 0, 0)), chunks=(10, 10, 1))
foo.resize((1000, 10, 1))

# Mark the position in the code to make it easier to find the correct part
print("start--", flush=True, file=sys.stderr)

foo[0, 0, 0] = 1.0
print("mid--", flush=True, file=sys.stderr)
foo[1, 0, 0] = 2.0

print("done--", flush=True, file=sys.stderr)

The first write triggers this:

write(2, "start--\n", 8start--
)                = 8
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0", 0x7ffd56fa6d50) = -1 ENOENT (No such file or directory)
getpid()                                = 414823
futex(0x7b79166b0a88, FUTEX_WAKE_PRIVATE, 2147483647) = 0
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0", 0x7ffd56fa6cd0) = -1 ENOENT (No such file or directory)
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo", {st_mode=S_IFDIR|0755, st_size=14, ...}) = 0
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo", {st_mode=S_IFDIR|0755, st_size=14, ...}) = 0
fstat(3, {st_mode=S_IFCHR|0666, st_rdev=makedev(0x1, 0x9), ...}) = 0
read(3, "\231;\177Y\rp\267\323W?\16\212n\372\346\372", 16) = 16
openat(AT_FDCWD, "/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0.993b7f590d7047d3973f0e8a6efae6fa.partial", O_WRONLY|O_CREAT|O_TRUNC|O_CLOEXEC, 0666) = 4
fstat(4, {st_mode=S_IFREG|0644, st_size=0, ...}) = 0
ioctl(4, TCGETS, 0x7ffd56fa6c10)        = -1 ENOTTY (Inappropriate ioctl for device)
lseek(4, 0, SEEK_CUR)                   = 0
write(4, "\2\0011\10 \3\0\0 \3\0\0/\0\0\0\24\0\0\0\27\0\0\0\37\0\1\0\377\377F\37"..., 47) = 47
close(4)                                = 0
rename("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0.993b7f590d7047d3973f0e8a6efae6fa.partial", "/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0") = 0
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0.993b7f590d7047d3973f0e8a6efae6fa.partial", 0x7ffd56fa6d30) = -1 ENOENT (No such file or directory)

The second write triggers this:

write(2, "mid--\n", 6mid--
)                  = 6
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0", {st_mode=S_IFREG|0644, st_size=47, ...}) = 0
openat(AT_FDCWD, "/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0", O_RDONLY|O_CLOEXEC) = 4
fstat(4, {st_mode=S_IFREG|0644, st_size=47, ...}) = 0
ioctl(4, TCGETS, 0x7ffd56fa6c30)        = -1 ENOTTY (Inappropriate ioctl for device)
lseek(4, 0, SEEK_CUR)                   = 0
lseek(4, 0, SEEK_CUR)                   = 0
fstat(4, {st_mode=S_IFREG|0644, st_size=47, ...}) = 0
read(4, "\2\0011\10 \3\0\0 \3\0\0/\0\0\0\24\0\0\0\27\0\0\0\37\0\1\0\377\377F\37"..., 48) = 47
read(4, "", 1)                          = 0
close(4)                                = 0
getpid()                                = 414823
getpid()                                = 414823
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0", {st_mode=S_IFREG|0644, st_size=47, ...}) = 0
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo", {st_mode=S_IFDIR|0755, st_size=24, ...}) = 0
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo", {st_mode=S_IFDIR|0755, st_size=24, ...}) = 0
fstat(3, {st_mode=S_IFCHR|0666, st_rdev=makedev(0x1, 0x9), ...}) = 0
read(3, "q\313\267\231t\364\276E\235\223J4\354}\372\240", 16) = 16
openat(AT_FDCWD, "/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0.71cbb79974f44e459d934a34ec7dfaa0.partial", O_WRONLY|O_CREAT|O_TRUNC|O_CLOEXEC, 0666) = 4
fstat(4, {st_mode=S_IFREG|0644, st_size=0, ...}) = 0
ioctl(4, TCGETS, 0x7ffd56fa6c10)        = -1 ENOTTY (Inappropriate ioctl for device)
lseek(4, 0, SEEK_CUR)                   = 0
write(4, "\2\0011\10 \3\0\0 \3\0\0006\0\0\0\24\0\0\0\36\0\0\0\37\0\1\0\377\377F\37"..., 54) = 54
close(4)                                = 0
rename("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0.71cbb79974f44e459d934a34ec7dfaa0.partial", "/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0") = 0
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0.71cbb79974f44e459d934a34ec7dfaa0.partial", 0x7ffd56fa6d30) = -1 ENOENT (No such file or directory)
write(2, "done--\n", 7done--
)                 = 7

So there's a lot going on for each write to the array. If I read this correctly, for the second store it actually reads the chunk from the disc, then modifies the chunk with the indexing update, writes the new chunk to a temporary file and then replaces that with the original chunk file.

For the first write it skips reading in the chunk data, because there still is nothing there to read.

So I think if we want to get good performance from this, we should try to combine writes, by buffering draws_per_chunk items, and then writing them in one go.

@lucianopaz
Copy link
Contributor Author

If the tests pass, the only steps left for this PR to be ready to merge are:

  • Documentation
  • Start sampling from a previous run's sampling state

@ricardoV94, to ensure that the previous sampling state is consistent with the model and step methods that are fed to pymc.sample, I might need to grab these hash utilities. Since I can't build the model nor the step methods from the sampling state, I need to add a check for consistency (e.g. if data changed, if the step methods are not applied to the same variables, if the variable's computation graph are different, an informative error should be thrown). I guess I would like to ask if you think that those hash utilities are in a good enough stage to go into one of the main repos?

@ricardoV94
Copy link
Member

I guess I would like to ask if you think that those hash utilities are in a good enough stage to go into one of the main repos?

Not at all :)

Why do you need to know whether the model changed for this specific backend?

@lucianopaz
Copy link
Contributor Author

Why do you need to know whether the model changed for this specific backend?

It's not for the backend. The backend works with what I've written so far. It stores all of the inference data information, and it also stores the step method's sampling state. Maybe I can leave this PR with that functionality only. What I also wanted to include was the ability for pm.sample to grab a ZarrTrace and set the step method's sampling state to whatever value was stored. That would allow:

  1. the ZarrTrace to act as a checkpoint file during sampling (if the machine that is running pm.sample crashes, it might be possible to resume sampling from the partially complete ZarrTrace file)
  2. to run the tuning phase somewhere and then re-use those tuning results to start drawing samples later (I don't think this is good practice, but I've heard people ask for it).

I only need the hash utilities for starting sampling from a previously stored state.

@ricardoV94
Copy link
Member

ricardoV94 commented Nov 7, 2024

  • o act as a checkpoint file during sampling (if the machine that is running pm.sample crashes, it might be possible to resume sampling from the partially complete ZarrTrace file)

  • to run the tuning phase somewhere and then re-use those tuning results to start drawing samples later (I don't think this is good practice, but I've heard people ask for it).

The use case sounds great, but why do we need to check for model consistency? That should be the responsibility of the user. I'm sure functional PPLs like blackjax don't try to guess if the density function changed during evals?

@lucianopaz
Copy link
Contributor Author

  • o act as a checkpoint file during sampling (if the machine that is running pm.sample crashes, it might be possible to resume sampling from the partially complete ZarrTrace file)
  • to run the tuning phase somewhere and then re-use those tuning results to start drawing samples later (I don't think this is good practice, but I've heard people ask for it).

The use case sounds great, but why do we need to check for model consistency? That should be the responsibility of the user. I'm sure functional PPLs like blackjax don't try to guess if the density function changed during evals?

The step methods have a reference to the tensor variable they are supposed to step. I wanted to make sure that if we're trying to set the state of a NUTS stepper, it should at least have been defined for the same variable. The problem is that the stored value won't be the exact same object, but a copy of it. That's why I want to check if it's hash representation matches.

@ricardoV94
Copy link
Member

ricardoV94 commented Nov 7, 2024

I don't think we should check, the step samplers should be a step function, given a logp/dlogp function and a previous state they propose the next value/state.

They shouldn't need to know anything about a pymc model or pytensor variables? In theory they could work with densities provided by the user that doesn't even involve pymc, like nutpie does.

@ricardoV94
Copy link
Member

Matching should be done by variable names like posterior predictive does.

@ricardoV94
Copy link
Member

ricardoV94 commented Nov 10, 2024

Can we keep this PR about adding Zarr trace only?

And not sampling state/ resuming

@ricardoV94
Copy link
Member

ricardoV94 commented Nov 10, 2024

Re flaky test, @bwengals was surprised by the magnitude of the error and there is an open issue for it.

If you're confident about there being no issue please close the associated issue. Perhaps as a separate PR to unblock other contributions, depending on how long this takes to get merged

@bwengals
Copy link
Contributor

Put a fix in for the flaky GP test here #7567

@lucianopaz
Copy link
Contributor Author

Can we keep this PR about adding Zarr trace only?

And not sampling state/ resuming

Yes, the only thing that I still need to figure out is what needs to be changed in the current ZarrTrace and ZarrChain classes to make sure that the sampling state/resuming will be workable without a big refactor of these two classes.

@lucianopaz
Copy link
Contributor Author

Put a fix in for the flaky GP test here #7567

Your fix looks great. I was just trying out stuff to see if the CI would build. I'll drop the commit with the test patch.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancements major Include in major changes release notes section trace-backend Traces and ArviZ stuff
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants