-
Notifications
You must be signed in to change notification settings - Fork 825
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support tensor and optimizer serialization #6087
Conversation
wyg1997
commented
Aug 28, 2021
•
edited
Loading
edited
- 支持 Tensor/DType 的 pickle 序列化和反序列化
- Optimizer state_dict/load_state_dict 接口保存/加载训练状态
Optimizer 的训练状态可以通过: import pickle
with open("optim_state.pkl", "wb") as f:
pickle.dump(sgd.state_dict(), f) 来保存 |
这和 pytorch 是不是还不对齐。torch.save/load 是支持 pickle 的。 我们为了支持超大权重,是每个权重保存一个文件的,不像 pytorch 把整个 state_dict 整体序列化。所以支持 pickle 之后 save 逻辑可以改成:遍历传入的 dict,如果元素是 tensor,则走原来的逻辑(为了不引入风险),如果不是 tensor,则生成它的 .pkl 文件保存下来。 |
我们 flow.save/load 接口现在是文件夹保存/加载的状态,这两个接口对齐了,就可以直接 flow.save/load(optim.state_dict)了。但是这是不是和 lazy 的 checkpoint 用法不一致了? |
lazy 的 checkpoint 是指神马,现在保存网络的权重也是通过 flow.save(module.state_dict()) 保存的。 |
好的,我改一下 |
Note: 讨论之后发现对齐 flow.save/load 需要支持嵌套的 dict,对已有机制的改动量比较大,不适合现在改。所以保持这个 pr 目前的方式 |
Speed stats:
|