Add ArrowTypes.jl dependency to serialize optimizers? #77

Open
ericphanson opened this issue May 20, 2022 · 15 comments
Labels
enhancement New feature or request

Comments

@ericphanson

ArrowTypes.jl is a lightweight package for defining how objects are serialized to the Arrow format by Arrow.jl. Arrow.jl itself is a heavy dependency; it is the package that actually does the serialization.

We could add a few ArrowTypes definitions in order to serialize optimizers to Arrow. Ref beacon-biosignals/LegolasFlux.jl#17 (comment)

I'd be interested to know whether PRs for that would be accepted here.

@CarloLucibello
Member

ArrowTypes.jl is super light indeed. Optimisers.jl is also very light, but I don't see any problem with taking on this new dependency, since it seems useful to people and defining these methods anywhere else would be type piracy.

@mcabbott
Member

Can you sketch what this does, for people who don't know anything about Arrow?

At the moment calling destructure on the tree of states doesn't work, as Leaf is a leaf... but it could trivially be made to store the momentum etc. arrays as a flat vector. And slightly less trivially be made to store the learning rates etc. too. Would this serve the same purpose or are they quite different?

@ericphanson
Author

ericphanson commented May 20, 2022

> Can you sketch what this does, for people who don't know anything about Arrow?

The purpose is similar to that of StructTypes.jl, if you are familiar with it: you define a couple of methods that describe precisely how to serialize an object into the Arrow format. For example: https://arrow.juliadata.org/dev/manual/#Custom-types.

Another way to say it: Arrow has a set of primitive types (https://arrow.apache.org/docs/format/Columnar.html?highlight=type#physical-memory-layout), and ArrowTypes describes how to map Julia types onto these Arrow primitive types, as well as what metadata to attach. This metadata is important. It is just a string such as "JuliaLang.Optimisers.ADAM" indicating that the value it is attached to should be reconstructed as an Optimisers.ADAM object. That string is not eval'd, though; instead, ArrowTypes methods are defined to do this matching for you and map strings back to types (using Val types, as you'd probably expect).

My interest in Arrow + Optimisers is in serializing optimisers so that one can, e.g., restart training after a crash. Using ArrowTypes to map Julia types to Arrow types lets one use Arrow.jl to write a complicated Julia object, like the nested optimiser state, into a vector of bytes that can be saved somewhere and reloaded later.

> At the moment calling destructure on the tree of states doesn't work, as Leaf is a leaf... but it could trivially be made to store the momentum etc. arrays as a flat vector. And slightly less trivially be made to store the learning rates etc. too. Would this serve the same purpose or are they quite different?

If there's a canonical way to map to a flat vector of primitive types and back, that would serve the purpose as well. As @ToucheSir has pointed out though, being able to serialize the nested structure has the advantage that the result is a bit more standalone and doesn't rely on having the code to reconstruct the nested structure from the flat vector.
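To make that trade-off concrete, here is a hedged Python sketch (not Optimisers.jl's destructure) of flattening a nested state of numeric arrays into one flat vector and back. Rebuilding needs the original nested structure as a template, which is exactly the dependency described above. The nested-dict state layout and the `flatten`/`unflatten` names are assumptions made for illustration, and only the arrays are stored, not hyperparameters such as learning rates.

```python
def flatten(state):
    """Depth-first: concatenate every list of numbers into one flat list."""
    flat = []
    def walk(node):
        if isinstance(node, dict):
            for key in sorted(node):  # fixed key order so unflatten can mirror it
                walk(node[key])
        else:  # a leaf: a list of floats, e.g. a momentum buffer
            flat.extend(node)
    walk(state)
    return flat

def unflatten(template, flat):
    """Rebuild a state shaped like `template` from the flat vector."""
    pos = 0
    def build(node):
        nonlocal pos
        if isinstance(node, dict):
            return {key: build(node[key]) for key in sorted(node)}
        n = len(node)
        chunk = flat[pos:pos + n]
        pos += n
        return chunk
    out = build(template)
    if pos != len(flat):
        raise ValueError("flat vector length does not match template")
    return out

state = {"layer1": {"weight": [0.1, 0.2], "bias": [0.0]},
         "layer2": {"weight": [0.3]}}
flat = flatten(state)
print(flat)  # [0.0, 0.1, 0.2, 0.3]
```

The flat vector on its own is just `[0.0, 0.1, 0.2, 0.3]`; without `state` (or code that can regenerate its shape) there is no way to tell which numbers belong to which layer.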

@ToucheSir
Member

On the topic of StructTypes, is there not some intermediate interface we can implement such that this functionality is not tied to the Arrow format? I'm envisioning an equivalent to Serde in Rust.

@ericphanson
Author

We don't have one; StructTypes is designed around JSON, afaik, and it is not expressive enough: it can't capture the metadata that makes round-tripping nested objects work smoothly.

I don't know how serde works, but I imagine it's challenging to have a very general intermediary. It would be nice, though. Perhaps @quinnj has thought about it.

@quinnj

quinnj commented May 21, 2022

Yeah, it's been low on my priority list to see whether there's some way to evolve StructTypes.jl/ArrowTypes.jl so we don't need both. They overlap quite a bit, so it's unfortunate when certain cases have to overload both to get both JSON and Arrow compatibility. I'll try to find some time next week to start sketching out a plan for the future. I'll also take a look at serde and see if we can draw some inspiration from it.

@ToucheSir
Member

I found out today that LightBSON.jl uses StructTypes. @ancapdev how has that worked out for you?

@ancapdev

> I found out today that LightBSON.jl uses StructTypes. @ancapdev how has that worked out for you?

I haven't actually used the StructTypes API myself; it just seemed like a low-effort compatibility layer to put in. My core use cases for LightBSON are high performance and/or long-term persistence, where I don't want any surprises, no runtime type lookups or method dispatch, and full control over how versioning is handled. LightBSON has its own extension points for that, and they are a lot more procedural than the declarative approaches of StructTypes and ArrowTypes. If you write types that are naturally representable in BSON, reading and writing them is trivial; if you don't, you need to put in a bit of work. Fully generic serialization was never the goal; BSON.jl does that.

@femtomc

femtomc commented Jun 3, 2022

@quinnj, did you get a chance to think about this over the last few weeks?

@ToucheSir ToucheSir added the enhancement New feature or request label Nov 16, 2022
@ericphanson
Author

One thought is that weak dependencies / package extensions might make adding ArrowTypes definitions zero-cost here. That would be Julia 1.9-only, or would need Requires.jl for pre-1.9 support.

@ToucheSir
Member

I still feel there is a need for some plan that will generalize this to working with other serialization formats, but I would be fine with a PR adding a package extension.

@ericphanson
Author

I took a look at it, but since Leafs are mutable and have meaningful object identity, I'm not sure there's a reliable way to do it at the level of an individual object.

In LegolasFlux we just use fcollect on the model to grab the arrays, and that works because fcollect outputs each array only once, no matter how many times it appears (since we fixed it, that is). So at the level of an entire model it can work.

But if we add serialization at the level of an individual Leaf, then we actually don't want serializing the whole state to "just work" by having Arrow recurse through the state, because that would serialize a repeated leaf as separate objects, leading to possible correctness issues if they are out of sync at deserialization time.
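The identity problem shows up with any format that serializes by plain recursion. In this hedged Python sketch, JSON stands in for Arrow and a dict stands in for a mutable Leaf; one leaf is shared by two entries of the state, and after a round trip the two entries are independent copies, so an update through one is no longer seen by the other.

```python
import json

# One mutable "leaf" (e.g. optimiser state for a tied weight) referenced twice.
shared_leaf = {"momentum": [0.0, 0.0]}
state = {"encoder": shared_leaf, "decoder": shared_leaf}
assert state["encoder"] is state["decoder"]  # same object before saving

# Naive recursive serialization writes the leaf out twice.
restored = json.loads(json.dumps(state))
print(restored["encoder"] is restored["decoder"])  # False

# The copies now drift apart: an update through one is invisible to the other.
restored["encoder"]["momentum"][0] = 1.0
print(restored["decoder"]["momentum"])  # [0.0, 0.0]
```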

So I think what we really want is a serialization API that serializes a whole state (possibly via fmap or something similar) in a way that respects object identity, rather than individual Leafs. Maybe it could work in two steps: an Optimisers-level function, serializable, that maps a state to something relocatable, i.e. independent of the Julia session (object identity), and then package extensions, or the user, take that and actually serialize it. Then another function undoes serializable. That's basically what MLJ does, iiuc.

Adding a level of indirection is also nice because it prevents changes to the internals of, e.g., Leaf from breaking deserialization.

@ToucheSir
Member

> So I think really what we want is a serialization api that serializes a whole state (possibly by fmaping or something) in a way that respects object identity, and not individual Leafs

I think both are required, because Leaf stores a lot of useful data besides its parameters. A two-step process is almost certainly necessary, though: probably some kind of walk that replaces unique leaves with ordinal IDs, plus a separate dict/list of leaves that can be indexed by ID. Either way, both the transformed state structure and the collection of leaves will have to implement some serialization interface, and it would be nice if that interface were generic.
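The two-step scheme discussed above could look roughly like this. It is a minimal Python illustration, not a proposed Optimisers.jl API: `to_serializable` walks the state, replaces each distinct leaf (by object identity) with an ordinal ID and collects the leaves in a list, and `from_serializable` undoes it, so leaves that were shared before saving are shared again after loading. The `Leaf` class and both function names are hypothetical, and JSON stands in for whatever backend (Arrow, BSON, HDF5, ...) does the final byte-level encoding.

```python
import json
from dataclasses import dataclass

@dataclass(eq=False)  # identity matters, so keep default identity-based equality
class Leaf:
    # Hypothetical stand-in for Optimisers.Leaf: rule data plus mutable state.
    rule: dict
    state: list

def to_serializable(tree):
    """Step 1: replace each distinct Leaf with {"leaf_id": n}; return (tree, leaves)."""
    ids = {}     # id(leaf) -> ordinal ID
    leaves = []  # ordinal ID -> plain-data leaf
    def walk(node):
        if isinstance(node, Leaf):
            key = id(node)
            if key not in ids:
                ids[key] = len(leaves)
                leaves.append({"rule": node.rule, "state": node.state})
            return {"leaf_id": ids[key]}
        if isinstance(node, dict):
            return {k: walk(v) for k, v in node.items()}
        return node  # other plain data passes through unchanged
    return walk(tree), leaves

def from_serializable(tree, leaves):
    """Inverse of step 1: rebuild each leaf once and reuse it for every reference."""
    built = [Leaf(rule=d["rule"], state=d["state"]) for d in leaves]
    def walk(node):
        if isinstance(node, dict):
            if set(node) == {"leaf_id"}:
                return built[node["leaf_id"]]
            return {k: walk(v) for k, v in node.items()}
        return node
    return walk(tree)

shared = Leaf(rule={"name": "Adam", "eta": 0.001}, state=[0.0, 0.0])
state = {"encoder": shared, "decoder": shared,
         "head": Leaf(rule={"name": "Descent"}, state=[])}

# Step 2: the result is plain data, so any backend can encode it; JSON here.
tree, leaves = to_serializable(state)
blob = json.dumps({"tree": tree, "leaves": leaves})

loaded = json.loads(blob)
restored = from_serializable(loaded["tree"], loaded["leaves"])
print(len(loaded["leaves"]))                       # 2
print(restored["encoder"] is restored["decoder"])  # True
```

Because IDs are assigned in traversal order rather than taken from memory addresses, the saved form does not depend on the session that wrote it. One limitation of this sketch: a user dict whose only key is `leaf_id` would be mistaken for a leaf reference, so a real design would need an unambiguous marker.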

@ericphanson
Author

What other serialization backends would be useful BTW? LightBSON via StructTypes? Something else also?

@ToucheSir
Member

Something BSON-ish, something HDF5-ish, maybe exploring newer DL-focused formats like https://github.com/huggingface/safetensors.

7 participants