
Add ZarrTrace and fix non-determinism with Generators in sample #7540

Open · wants to merge 12 commits into main from the zarr branch

Conversation

@lucianopaz (Contributor) commented Oct 16, 2024

Description

This PR is related to #7503. It specifically focuses on having a way to store intermediate trace results and the step methods' sampling state somewhere (see task 2 of #7508).

To be honest, the current situation of the MultiTrace and NDArray backends is terrible. These backend classes have inconsistent signatures across subclasses, and it's very awkward to write new backends that adhere to them. McBackend was an attempt to make things sane again. As far as I understand, McBackend does support dumping samples to disk (instead of holding them in memory) via the ClickHouse database. However, I found the backend a bit detached from arviz and xarray, and it seemed to be tightly linked to protocol buffers, which made it harder for me to see how I could customize stuff.

These considerations brought me to the approach I'm pursuing in this PR: add a backend that uses zarr. Using zarr has the following benefits:

  1. xarray can read zarr stores directly, making it possible to write InferenceData objects to disk almost without having to call a converter.
  2. zarr works with hierarchically structured data. It's possible to store the arrays for each variable directly inside a group (e.g. posterior, observed_data).
  3. zarr arrays handle numpy arrays nicely. Fixed-size binary data fits into zarr arrays seamlessly.
  4. zarr arrays can also store object-dtype arrays using the numcodecs package. This makes it possible to use the same store to hold sample stats warning objects and the step methods' sampling_state in the same place as the actual samples from the posterior.
  5. zarr hierarchies can use many different kinds of storage: they can be held in memory, saved as a directory structure, stored inside a zip file, or even kept remotely in S3 buckets.
  6. It's also possible to write to the same zarr object concurrently from different processes or threads, as long as a synchronization object is provided.
  7. zarr stores the data using a compressed binary representation, and the compressor can be customized.
  8. zarr arrays are chunked. This means they don't need to be loaded entirely into memory, making it possible to keep a smaller memory footprint while sampling. Another benefit of chunking is that write operations on different chunks are completely independent of each other. (A short illustrative sketch follows below.)
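To make points 1, 2 and 8 concrete, here is a minimal sketch (zarr 2.x; the store path, group names and shapes are made up for illustration and are not the PR's actual layout):

import numpy as np
import zarr

# any zarr store works here: MemoryStore, DirectoryStore, ZipStore, an fsspec mapper, ...
store = zarr.DirectoryStore("trace.zarr")
root = zarr.group(store=store, overwrite=True)

# one group per InferenceData group, one array per model variable
posterior = root.create_group("posterior")
mu = posterior.zeros(
    "mu",
    shape=(4, 1000, 3),  # (chains, draws, *variable_shape)
    chunks=(1, 1, 3),    # one draw per chunk -> incremental, independent writes
    dtype="f8",
)

# only chunks that are actually written end up in the store
mu[0, 0, :] = np.random.normal(size=3)

# xarray can then read the group back as a Dataset:
# import xarray as xr; ds = xr.open_zarr(store, group="posterior")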

Having stated all of these considerations I intend to:

  • Build a zarr trace backend
  • Have ZarrTrace integrate well with pymc.sample
  • WONT DO NOW Replace the MultiTrace and NDArray backend defaults with their Zarr counterparts
  • Handle sampling state information in the zarr backend
  • Document ZarrTrace
  • Buffer write operations in ZarrChain.record
  • Record sampling state information periodically during sampling
  • NEED TO DO IN SEPARATE PR Make it possible to load the zarr trace backend and resume sampling from it.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7540.org.readthedocs.build/en/7540/

@lucianopaz added the enhancements, trace-backend (Traces and ArviZ stuff), and major (Include in major changes release notes section) labels on Oct 16, 2024

codecov bot commented Oct 16, 2024

Codecov Report

Attention: Patch coverage is 92.61745% with 33 lines in your changes missing coverage. Please review.

Project coverage is 92.83%. Comparing base (5d51953) to head (482bb9f).
Report is 2 commits behind head on main.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| pymc/backends/zarr.py | 95.66% | 13 Missing ⚠️ |
| pymc/sampling/population.py | 33.33% | 10 Missing ⚠️ |
| pymc/sampling/parallel.py | 78.78% | 7 Missing ⚠️ |
| pymc/sampling/mcmc.py | 95.91% | 2 Missing ⚠️ |
| pymc/step_methods/state.py | 94.11% | 1 Missing ⚠️ |
Additional details and impacted files


@@            Coverage Diff             @@
##             main    #7540      +/-   ##
==========================================
- Coverage   92.83%   92.83%   -0.01%     
==========================================
  Files         106      107       +1     
  Lines       17748    18177     +429     
==========================================
+ Hits        16477    16875     +398     
- Misses       1271     1302      +31     
| Files with missing lines | Coverage Δ |
| --- | --- |
| pymc/backends/__init__.py | 92.68% <100.00%> (+0.79%) ⬆️ |
| pymc/model/core.py | 92.89% <100.00%> (+0.04%) ⬆️ |
| pymc/step_methods/compound.py | 97.58% <100.00%> (+0.02%) ⬆️ |
| pymc/step_methods/hmc/quadpotential.py | 84.63% <100.00%> (ø) |
| pymc/util.py | 82.86% <100.00%> (+1.01%) ⬆️ |
| pymc/step_methods/state.py | 95.52% <94.11%> (-2.60%) ⬇️ |
| pymc/sampling/mcmc.py | 87.45% <95.91%> (+1.32%) ⬆️ |
| pymc/sampling/parallel.py | 87.61% <78.78%> (-1.12%) ⬇️ |
| pymc/sampling/population.py | 70.83% <33.33%> (-3.85%) ⬇️ |
| pymc/backends/zarr.py | 95.66% <95.66%> (ø) |

@lucianopaz (Contributor Author)

This is an important issue to keep track of for when we'll eventually want to read the zarr store and create an InferenceData object using xarray and arviz.

Comment on lines 95 to 109
_dtype = np.dtype(dtype)
if np.issubdtype(_dtype, np.floating):
    return (np.nan, _dtype, None)
elif np.issubdtype(_dtype, np.integer):
    return (-1_000_000, _dtype, None)
elif np.issubdtype(_dtype, "bool"):
    return (False, _dtype, None)
elif np.issubdtype(_dtype, "str"):
    return ("", _dtype, None)
elif np.issubdtype(_dtype, "datetime64"):
    return (np.datetime64(0, "Y"), _dtype, None)
elif np.issubdtype(_dtype, "timedelta64"):
    return (np.timedelta64(0, "Y"), _dtype, None)
else:
    return (None, _dtype, numcodecs.Pickle())
Contributor:

Question from my own ignorance, since I don't understand very well how fill values are implemented. Are we just hoping that these fill values don't actually occur in the data?

If so, this seems especially perilous for bool 😅

Contributor Author:

No, they are supposed to be the initialisation values for the entries. When the sampler completes its run, all entries will be filled with the correct values. Zarr just needs you to tell it what value to give to unwritten places. In storage, these entries are never actually written; they are produced only when you ask for the concrete values in the array.
The dangerous part is that xarray interprets fill_value as an indicator of whether the actual value should be masked to nan. This seems to be because the netCDF standard treats fill_value as something completely different.
To keep things as clean as possible, I'll store the draw_idx of each chain in a separate group that should never be converted to xarray.
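A tiny illustration of that behaviour (zarr 2.x; shapes and values made up): unwritten chunks never hit the store, and reads of untouched regions just materialize fill_value.

import zarr

# fill_value is just what zarr hands back for regions that were never written
arr = zarr.create(shape=(4, 1000), chunks=(1, 10), dtype="f8", fill_value=float("nan"))
arr[0, :10] = 1.0                  # initializes exactly one chunk

print(arr[0, 0])                   # 1.0 -> actually written
print(arr[3, 999])                 # nan -> never written, produced from fill_value
print(arr.nchunks_initialized)     # 1   -> only one chunk exists in the store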

Contributor:

Ah, that makes a lot more sense now, thanks for the explanation!

In case it's non-obvious to more than just me, maybe it would be helpful to make this more self-evident, perhaps by calling the function get_initial_fill_value_and_codec, or by adding a small comment that the fill value is used for initialization?

@michaelosthege (Member)

the current situation of the MultiTrace and NDArray backends is terrible. These backend classes have inconsistent signatures across subclasses, and it's very awkward to write new backends that adhere to them.

Yes, therefore I would recommend not to use them for any new implementation.

McBackend was an attempt to make things sane again. As far as I understand, McBackend does support ways to dump samples to disk instead of holding them in memory using the ClickHouse database.

Just to clarify:

  • NumPyBackend is the go-to for in memory situations
  • ClickHouseBackend is for storing on disk (in a database that may even sit on a different machine!)

It should be quite simple to implement a ZarrBackend with McBackend!
I would recommend doing that first, because McBackend's test suite already covers all (?) of the nasty edge cases.

However, I found the backend a bit detached from arviz and xarray, and it seemed to be tightly linked to protocol buffers, which made it harder for me to see how I could customize stuff.

Yes and No. I would say ArviZ is a first-class citizen, because Run.to_inferencedata() is in the base class.
There are two things with which McBackend does not integrate tightly:

  1. xarray because it doesn't/didn't support sparse arrays (needed for sparse stats or variables with varying shape)
  2. InferenceData groups other than .posterior

I consider 2. the primary weakness, and it's the only reason why I don't use McBackend by default.
I see that you added "the other" groups as properties to the ZarrTrace. Maybe this is something we should do on a more abstract level? Have McBackend define the signature of InferenceData without requiring a specific implementation for it?

First we must find answers to:

  • How do prior/posterior/log_likelihood data points arrive?
  • Do they arrive in some kind of "sampling" process that may get parallelized or doesn't fit into memory?
  • If yes, should the storage backend even make a difference between MCMC and forward sampling?

Protocol buffers

I used them because they are convenient for specifying a data structure and not having to write serialization/deserialization code for it.

And they are only for the constant metadata {constant_data, observed_data, coords, names, dtypes, ...}, because this needs to be serializable from a semi-clean data structure supporting all the weird data types that users may put into their coords (timestamps, anybody?).

From the Python perspective this could also be done with zarr or xarray (they can serialize to binary), but can you serialize/deserialize that in another language?
The protobufs can be compiled to C++ or Rust to easily read/write run metadata from those languages too!

Is that tight integration?

The important design decision is not which implementation is used to serialize/deserialize metadata, but rather to freeze and detach these (meta)data from chains, draws and stats:

  1. Determine it before starting the costly MCMC
  2. Serialize it to its own "blob" of data. Think of {constant_data, observed_data, coords, names, dtypes, ...} as the "header" section of a trace.

@lucianopaz (Contributor Author)

Thanks @michaelosthege for the feedback!

McBackend was an attempt to make things sane again. As far as I understand, McBackend does support ways to dump samples to disk instead of holding them in memory using the ClickHouse database.

Just to clarify:

* `NumPyBackend` is the go-to for _in memory_ situations

* `ClickHouseBackend` is for storing on disk (in a database that may even sit on a different machine!)

It should be quite simple to implement a ZarrBackend with McBackend! I would recommend to do that first, because McBackend's test suite already covers all (?) of the nasty edge cases.

I understand what the two backends for McBackend offer and that McBackend already has a test suite. Despite this, I'll try to argue in favor of writing something that's detached from McBackend.

However, I found the backend a bit detached from arviz and xarray, and it seemed to be tightly linked to protocol buffers, which made it harder for me to see how I could customize stuff.

Yes and No. I would say ArviZ is a first-class citizen, because Run.to_inferencedata() is in the base class.

The way I see this is that McBackend offers a signature to convert from one kind of storage (like MultiTrace) into another (arviz.InferenceData). I understand that with this method you guarantee that there should always be a way to go from an McBackend Run to arviz.InferenceData, but you have to handle a lot of transformation logic in this conversion (just like the extra conversion logic that's already in pymc.backends.arviz). In my opinion, this isn't tight integration. Having the data stored in native zarr makes it possible to generate xarray.Dataset objects with simple xr.open_zarr(store, group) calls, and these can then be wrapped into an InferenceData object with a simple InferenceData(posterior=zarr_posterior, ...) (and potentially even into future DataTree objects, since zarr hierarchies are already very much tree-like).

There are two things which McBackend does not integrate tightly:

1. `xarray` because it doesn't/didn't support sparse arrays (needed for sparse stats or variables with varying shape)

xarray does not support sparse arrays. At the moment, the posterior samples are initialized as "empty" zarr arrays (in practice, filled arrays with a fill_value). The nice thing about zarr arrays is that these filled, uninitialized places take up almost no space because they aren't actually stored. If queried, their value gets set from the fill_value attribute. xarray still needs to figure out pydata/xarray#5475 though.

2. `InferenceData` groups other than `.posterior`

I consider 2. the primary weakness, and it's the only reason why I don't use McBackend by default. I see that you added "the other" groups as properties to the ZarrTrace.

The key thing is that I added these groups to the zarr hierarchy; having them as ZarrTrace properties is not necessary. By having them in a single shared zarr entity, they are stored almost like an InferenceData. I need to check whether arviz has a from_zarr method, because that would be the direct conversion path from a ZarrTrace to an InferenceData object without having to add any extra conversion code.

Maybe this is something we should do on a more abstract level? Have McBackend define the signature of InferenceData without requiring a specific implementation for it?

First we must find answers to:

* How do prior/posterior/log_likelihood data points arrive?

* Do they arrive in some kind of "sampling" process that may get parallelized or doesn't fit into memory?

* If yes, should the storage backend even make a difference between MCMC and forward sampling?

I decided to focus only on MCMC for now, and I'm trying to make ZarrTrace handle concurrent writes to the zarr store from multiple processes during sampling. Having said that, it's almost effortless to add other groups to a zarr hierarchy, and the same store could house prior, prior_predictive, posterior_predictive and predictions as well, with hardly any extra logic.

Protocol buffers

I used them because they are convenient for specifying a data structure and not having to write serialization/deserialization code for it.

And they are only for the constant metadata {constant_data, observed_data, coords, names, dtypes, ...}, because this needs to be serializable from a semi-clean data structure supporting all the weird data types that users may put into their coords (timestamps, anybody?).

From the Python perspective this could also be done with zarr or xarray (they can serialize to binary), but can you serialize/deserialize that in another language? The protobufs can be compiled to C++ or Rust to easily read/write run metadata from those languages too!

Yes, you can deserialize almost all of the contents into C++ or Rust. zarr is readable from Python, Julia, C++, Rust, JavaScript and Java. The only content that would not be readable in other languages would come from arrays with object dtype. At the moment, this is limited to two things:

  1. SamplerWarning
  2. StepMethodState

The latter isn't a problem in my opinion because it relates exclusively to the Python pymc step methods, and I detached it into its own private group in the zarr hierarchy. The former might be more problematic, but since SamplerWarning is a dataclass, it could potentially be converted into a dictionary and then represented as a JSON object, which zarr can serialize without problems.
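As a rough sketch of that JSON idea (the DummyWarning fields below are illustrative, not SamplerWarning's actual schema), a dataclass can be turned into a dict and stored through numcodecs' JSON object codec instead of Pickle:

from dataclasses import asdict, dataclass

import numcodecs
import zarr

@dataclass
class DummyWarning:  # stand-in, not SamplerWarning's actual fields
    kind: str
    message: str
    level: str

warnings = zarr.empty(10, dtype=object, object_codec=numcodecs.JSON())
warnings[0] = asdict(DummyWarning(kind="divergence", message="divergence after tuning", level="warn"))
print(warnings[0])  # a plain dict, readable from any language with a JSON parser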

Having said that, there are other benefits that we would get if we were to rely on zarr directly, such as:

  • Offloading the maintenance cost of the storage backend code
  • Growing set of features that will become available to us as time goes by
  • Seamless compression of the arrays to save storage space
  • Integration with multiple on-disk storage options, ranging from directory structures and zip files to several SQL and NoSQL databases
  • Integration with distributed or cloud storage like S3, Hadoop, Google Cloud Storage and Azure Blob Storage, as well as anything exposed through fsspec (see the sketch below)

I think that these added benefits plus the drop in maintenance costs in the long run warrant using zarr directly and not through a new backend for McBackend.
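Purely as an illustration of that storage flexibility (zarr 2.x API; all paths and names are made up):

import zarr

# held entirely in memory
root = zarr.group(store=zarr.MemoryStore())

# one file per chunk inside a directory tree
root = zarr.group(store=zarr.DirectoryStore("trace.zarr"), overwrite=True)

# a single zip file (needs an explicit close when done)
zip_store = zarr.ZipStore("trace.zip", mode="w")
root = zarr.group(store=zip_store)
root.create_group("posterior")
zip_store.close()

# remote storage goes through fsspec-style mappers, e.g. (bucket name is hypothetical):
# import s3fs
# root = zarr.group(store=s3fs.S3Map("my-bucket/traces/run1.zarr", s3=s3fs.S3FileSystem()))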

@maresb (Contributor) commented Oct 17, 2024

@lucianopaz, have you done some benchmarks with this yet (in particular with S3)? I'm a bit concerned that with a (1, 1, ...) chunk size, I/O will be a bottleneck.

@lucianopaz force-pushed the zarr branch 2 times, most recently from 3206597 to 69bb2ac on October 17, 2024 12:54
@lucianopaz (Contributor Author)

@lucianopaz, have you done some benchmarks with this yet (in particular with S3)? I'm a bit concerned that with a (1, 1, ...) chunk size, I/O will be a bottleneck.

No, I haven't. But I've now made the chunk size customizable via the draws_per_chunk parameter. @aseyboldt said that we could try to use a different chunk size depending on the dimensionality of the RV.

Anyway, my long-term goal is to add something like checkpoints during sampling, where the trace gets dumped into a file along with the sampling state of the step methods. I think that I'll eventually make the chunks align with that, so that we don't lose samples drawn before the checkpoint if sampling gets terminated afterwards (before having finished).

@lucianopaz (Contributor Author)

By the way, I've added a to_inferencedata method to the zarr trace. I had to do it because I wanted to ensure that the zarr store had consolidated metadata (if it didn't, xarray would complain) and because I needed to pass mask_and_scale=False to xarray.open_zarr (which arviz doesn't allow in from_zarr). Anyway, you can see for yourselves that the conversion code is extremely short because the stored data is already aligned with what arviz wants.
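Roughly, such a conversion boils down to something like this sketch (group names illustrative, not the PR's exact code):

import arviz as az
import xarray as xr
import zarr

store = zarr.DirectoryStore("trace.zarr")   # illustrative path
zarr.consolidate_metadata(store)            # otherwise xarray complains about missing consolidated metadata

idata = az.InferenceData(
    posterior=xr.open_zarr(store, group="posterior", mask_and_scale=False),
    sample_stats=xr.open_zarr(store, group="sample_stats", mask_and_scale=False),
    observed_data=xr.open_zarr(store, group="observed_data", mask_and_scale=False),
)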

@OriolAbril (Member) left a comment

I love this direction. I left a comment on ArviZ integration.

I also have more ideas of things that can be done to better integrate the sampling outputs with InferenceData, but it might be better to address them in follow-up PRs. Not having to go through the current ArviZ converter might help get these things off the ground. Many of these ideas have been around since #5160.

Also, if there's anything on the ArviZ side that can help with this, let me know.


Better sample_stats. sample_stats doesn't necessarily need to restrict itself to having chain, draw dimensions on all its variables. The mass matrix could also go in there, and even a divergence_id (with extra coordinate values or a multiindex to store the start and end points of divergences), which would complement the boolean diverging variable with chain, draw dimensions.

Samples in the unconstrained space. Related to #6721 and, to a lesser extent, arviz-devs/arviz-base#8.

pymc/backends/zarr.py (outdated review thread, resolved)
@aseyboldt (Member) left a comment

I absolutely love this! :D

def setup(self, draws: int, chain: int, sampler_vars: Sequence[dict] | None):  # type: ignore[override]
    self.chain = chain

def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
Member:

I did not check the source, but I think zarr will write the whole chunk each time we set a draw here, even if that chunk is not full yet. If that is indeed the case, we should be able to speed this up a lot when draws_per_chunk is >1 by buffering draws_per_chunk draws and then setting the values in one go.

Contributor Author:

Oh, it would be a shame if zarr itself doesn't buffer the write until the chunk is filled.

@lucianopaz (Contributor Author)

I also have more ideas of things that can be done to better integrate the sampling outputs with InferenceData, but it might be better to address them in follow-up PRs. Not having to go through the current ArviZ converter might help get these things off the ground. Many of these ideas have been around since #5160.

Great! Yes, let's try to address those in other PRs.

Also, anything on ArviZ side that can help with this let me know

I don't know if from_zarr might have to be updated a bit? How did you handle the fill_value from zarr? I ran into problems on my side and had to add mask_and_scale=False when I opened the group.

Better sample_stats. sample_stats doesn't necessarily need to restrict itself to having chain, draw dimensions on all its variables. The mass matrix could also go in there, and even a divergence_id (with extra coordinate values or a multiindex to store the start and end points of divergences), which would complement the boolean diverging variable with chain, draw dimensions.

I tried to stay close to how pymc is returning things right now. I agree that this could be greatly improved, but maybe we can do so in follow-up iterations.

Samples in the unconstrained space. Related to #6721 and, to a lesser extent, arviz-devs/arviz-base#8.

I'll try to add this. It doesn't seem to be difficult.

@lucianopaz force-pushed the zarr branch 2 times, most recently from 0c5e73b to ee0a36d on October 23, 2024 09:25
@aseyboldt (Member)

I just did a little experiment with strace, to see what zarr does when we write a value to a chunk.

I wrote a little script that creates the array, writes two chunks and then closes the file. With strace python the-script.py we can see all the syscalls it uses in between

import zarr
import numpy as np
import sys

store = zarr.DirectoryStore("zarr-test.zarr")
#store = zarr.LRUStoreCache(store, max_size=2**28)
data = zarr.open(store)
foo = data.array("foo", np.zeros((0, 0, 0)), chunks=(10, 10, 1))
foo.resize((1000, 10, 1))

# Mark the position in the code to make it easier to find the correct part
print("start--", flush=True, file=sys.stderr)

foo[0, 0, 0] = 1.0
print("mid--", flush=True, file=sys.stderr)
foo[1, 0, 0] = 2.0

print("done--", flush=True, file=sys.stderr)

The first write triggers this:

write(2, "start--\n", 8start--
)                = 8
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0", 0x7ffd56fa6d50) = -1 ENOENT (No such file or directory)
getpid()                                = 414823
futex(0x7b79166b0a88, FUTEX_WAKE_PRIVATE, 2147483647) = 0
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0", 0x7ffd56fa6cd0) = -1 ENOENT (No such file or directory)
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo", {st_mode=S_IFDIR|0755, st_size=14, ...}) = 0
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo", {st_mode=S_IFDIR|0755, st_size=14, ...}) = 0
fstat(3, {st_mode=S_IFCHR|0666, st_rdev=makedev(0x1, 0x9), ...}) = 0
read(3, "\231;\177Y\rp\267\323W?\16\212n\372\346\372", 16) = 16
openat(AT_FDCWD, "/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0.993b7f590d7047d3973f0e8a6efae6fa.partial", O_WRONLY|O_CREAT|O_TRUNC|O_CLOEXEC, 0666) = 4
fstat(4, {st_mode=S_IFREG|0644, st_size=0, ...}) = 0
ioctl(4, TCGETS, 0x7ffd56fa6c10)        = -1 ENOTTY (Inappropriate ioctl for device)
lseek(4, 0, SEEK_CUR)                   = 0
write(4, "\2\0011\10 \3\0\0 \3\0\0/\0\0\0\24\0\0\0\27\0\0\0\37\0\1\0\377\377F\37"..., 47) = 47
close(4)                                = 0
rename("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0.993b7f590d7047d3973f0e8a6efae6fa.partial", "/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0") = 0
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0.993b7f590d7047d3973f0e8a6efae6fa.partial", 0x7ffd56fa6d30) = -1 ENOENT (No such file or directory)

The second write triggers this:

write(2, "mid--\n", 6mid--
)                  = 6
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0", {st_mode=S_IFREG|0644, st_size=47, ...}) = 0
openat(AT_FDCWD, "/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0", O_RDONLY|O_CLOEXEC) = 4
fstat(4, {st_mode=S_IFREG|0644, st_size=47, ...}) = 0
ioctl(4, TCGETS, 0x7ffd56fa6c30)        = -1 ENOTTY (Inappropriate ioctl for device)
lseek(4, 0, SEEK_CUR)                   = 0
lseek(4, 0, SEEK_CUR)                   = 0
fstat(4, {st_mode=S_IFREG|0644, st_size=47, ...}) = 0
read(4, "\2\0011\10 \3\0\0 \3\0\0/\0\0\0\24\0\0\0\27\0\0\0\37\0\1\0\377\377F\37"..., 48) = 47
read(4, "", 1)                          = 0
close(4)                                = 0
getpid()                                = 414823
getpid()                                = 414823
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0", {st_mode=S_IFREG|0644, st_size=47, ...}) = 0
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo", {st_mode=S_IFDIR|0755, st_size=24, ...}) = 0
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo", {st_mode=S_IFDIR|0755, st_size=24, ...}) = 0
fstat(3, {st_mode=S_IFCHR|0666, st_rdev=makedev(0x1, 0x9), ...}) = 0
read(3, "q\313\267\231t\364\276E\235\223J4\354}\372\240", 16) = 16
openat(AT_FDCWD, "/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0.71cbb79974f44e459d934a34ec7dfaa0.partial", O_WRONLY|O_CREAT|O_TRUNC|O_CLOEXEC, 0666) = 4
fstat(4, {st_mode=S_IFREG|0644, st_size=0, ...}) = 0
ioctl(4, TCGETS, 0x7ffd56fa6c10)        = -1 ENOTTY (Inappropriate ioctl for device)
lseek(4, 0, SEEK_CUR)                   = 0
write(4, "\2\0011\10 \3\0\0 \3\0\0006\0\0\0\24\0\0\0\36\0\0\0\37\0\1\0\377\377F\37"..., 54) = 54
close(4)                                = 0
rename("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0.71cbb79974f44e459d934a34ec7dfaa0.partial", "/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0") = 0
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0.71cbb79974f44e459d934a34ec7dfaa0.partial", 0x7ffd56fa6d30) = -1 ENOENT (No such file or directory)
write(2, "done--\n", 7done--
)                 = 7

So there's a lot going on for each write to the array. If I read this correctly, for the second write it actually reads the chunk from disk, modifies the chunk with the indexing update, writes the new chunk to a temporary file, and then replaces the original chunk file with it.

For the first write it skips reading in the chunk data, because there still is nothing there to read.

So I think that if we want to get good performance from this, we should try to combine writes by buffering draws_per_chunk items and then writing them in one go.
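A minimal sketch of that buffering idea (hypothetical names, not the actual ZarrChain implementation): collect draws_per_chunk draws in memory and flush them with a single slice assignment, so each chunk file is written once instead of being read, modified and rewritten for every draw.

import numpy as np
import zarr

draws_per_chunk = 10
root = zarr.open_group("trace.zarr", mode="a")
mu = root.require_dataset(
    "posterior/mu",
    shape=(1, 1000, 3),                # (chains, draws, *variable_shape)
    chunks=(1, draws_per_chunk, 3),    # chunk boundaries align with the flush size
    dtype="f8",
)

buffer: list[np.ndarray] = []
draw_idx = 0
for _ in range(1000):
    buffer.append(np.random.normal(size=3))  # stand-in for one posterior draw
    if len(buffer) == draws_per_chunk:
        # one slice assignment per chunk instead of one read-modify-write per draw
        mu[0, draw_idx : draw_idx + draws_per_chunk] = np.stack(buffer)
        draw_idx += draws_per_chunk
        buffer.clear()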

pymc/sampling/mcmc.py (outdated review thread, resolved)
@ricardoV94 (Member)

@lucianopaz PR looks good. Left some small comments

@lucianopaz (Contributor Author) commented Dec 21, 2024

@ricardoV94, I addressed the last couple of comments. Can you check if you think this is now ready to merge? Once this is merged, we could check whether #7612 is still an issue or not. I actually went ahead and checked the example from the issue and it works fine now.

@@ -1611,6 +1611,7 @@ def compile_fn(
    inputs: Sequence[Variable] | None = None,
    mode=None,
    point_fn: bool = True,
+   borrow_vars: bool = False,
Member:

Why an option, and with this default?

It shouldn't need to be an option. The point is that if input=output there's no need for a deepcopy. Step samplers always copy the point before mutating it; it's not supposed to be mutated in place.

Also, there's no reason to treat this any differently than the NDArray backend.

Member:

Ah, this is happening inside Model. I didn't see the file header...

Member:

We don't need these changes, as the BaseTrace class is the one compiling the function and it doesn't go through here. So no code duplication. Was there any in a previous version when I left my comment?

I would revert these changes in model/core.py. They are not being used. The BaseTrace sidesteps Model altogether to create the point function.

Contributor Author:

Is this wrong then? I left it as False by default to keep backwards compatibility.

pymc/model/core.py (review thread, resolved)
draws_per_chain = total_draws_per_chain - tuning_steps_per_chain

total_n_tune = tuning_steps_per_chain.sum()
total_draws = draws_per_chain.sum()
Member:

I don't like this. I think what MultiTrace is doing is more sensible. You won't get junk draws if you interrupt earlier. Here you'll get zeros and False for integers and booleans in chains that are incomplete.

requirements.txt Outdated
@@ -8,3 +8,4 @@ rich>=13.7.1
scipy>=1.4.1
threadpoolctl>=3.1.0,<4.0.0
typing-extensions>=3.7.4
+zarr>=2.5.0,<3
Member:

Did you make it optional?

@ricardoV94 (Member)

@lucianopaz can you add a regression test for #7612?

@ricardoV94 ricardoV94 changed the title Add ZarrTrace Add ZarrTrace and fix non-determinism with Generators in sample Dec 21, 2024
pymc/sampling/mcmc.py (outdated review thread, resolved)
Comment on lines 456 to 462
self.fn = model.compile_fn(
    self.vars,
    inputs=model.value_vars,
    on_unused_input="ignore",
    borrow_vars=True,
)
Member:

Did I mess it up? Doesn't BaseTrace.__init__ already create the fn you need?

Member:

All my comments about borrow and trust_input were based on this assumption. Note that the base class does not call model.compile_fn, it sidesteps it completely, which is why I was surprised you were changing it previously.

Contributor Author:

No, the BaseTrace for each chain (ZarrChain) is initialized in ZarrTrace. Now I'm sending off the fn that is compiled in ZarrTrace over to every chain instance to use it without having to recompile it.

Member:

We should do that for MultiTrace as well. I had that in mind as a follow-up to #7578, where a lot of the wall time could be saved just by avoiding needless function compilations.

My earlier comment still stands: you shouldn't call model.compile_fn, and you should (unless you don't want the speedup for some reason) use trust_input and borrow the inputs and outputs to avoid needless deep copies. AFAICT the same logic should be applicable to the new and old traces.

@ricardoV94 (Member) commented Dec 21, 2024

However, the reason I didn't do it for multiprocessing yet is that we have to be careful about point functions with RNGs (say, with stochastic Deterministics). We actually need to copy the function and provide new shared RNGs for each function. If you're sharing self.fn across multiple chains (and they are called in different processes), the same concern may apply here?

Member:

Here is a related issue: #7588

Member:

If you don't copy the function and do the RNG swap, it's better not to share it for now, as it may mess up the RNG updates and you'll get crappy random draws.


Contributor Author:

Is the problem that the random generators inside self.fn are shared across chains? What I mean is: is the problem that, when different chains call the same fn, the random generator state will advance and mix up the different chains? I can see this happening in the context of a single process that runs multiple chains, either in threads or sequentially. But I don't see it happening across different processes.
Or is the problem that we don't manage to control the random generator state that gets used in fn the way we do with step methods? If that is the real problem, we could address it using RandomGeneratorState, along with random_generator_from_state and get_state_from_generator, to create copies of random generators that adhere to whatever state was provided.
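For reference, a standard numpy pattern (not this PR's RandomGeneratorState API) for keeping per-chain generators independent is to spawn one child generator per chain from a single SeedSequence:

import numpy as np

# spawn independent child generators from one seed, one per chain
seed_seq = np.random.SeedSequence(1234)
chain_rngs = [np.random.default_rng(child) for child in seed_seq.spawn(4)]

# each chain's copy of the point function would hold one of these generators,
# so advancing one chain's state can never affect another chain
draws = [rng.normal(size=3) for rng in chain_rngs]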

Contributor Author:

Anyway, I think that #7588 really captures the core problem with what might be going on, so I don't think it's worth opening a different issue. This PR doesn't close that issue, though, and I don't think its intention should be to do so. I have an idea to try to tackle #7588 using the RandomGeneratorState stuff that I developed here.

@lucianopaz force-pushed the zarr branch 2 times, most recently from 069ae75 to 9f4c013 on December 23, 2024 06:54
Labels: bug, enhancements, major (Include in major changes release notes section), samplers, trace-backend (Traces and ArviZ stuff)

Successfully merging this pull request may close these issues:

Sampling with generators as seeding no longer deterministic

7 participants