Commit 4e25392

Fix serialization doc display (#525)
Summary: att
Test Plan: checkout rendered doc
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent c61aa44 commit 4e25392

File tree

2 files changed: +85, -91 lines changed


docs/source/index.rst

Lines changed: 9 additions & 12 deletions
@@ -68,18 +68,6 @@ with more content coming soon.
    overview
    getting-started
 
-.. toctree::
-   :glob:
-   :maxdepth: 1
-   :caption: Concepts
-   :hidden:
-
-   dtypes
-   quantization
-   sparsity
-   performant_kernels
-   serialization
-
 .. toctree::
    :glob:
    :maxdepth: 1
@@ -99,3 +87,12 @@ with more content coming soon.
    api_ref_dtypes
 ..
    api_ref_kernel
+
+.. toctree::
+   :glob:
+   :maxdepth: 1
+   :caption: Tutorials
+   :hidden:
+
+   serialization
+

docs/source/serialization.rst

Lines changed: 76 additions & 79 deletions
@@ -3,101 +3,98 @@ Serialization
 
 Serialization and deserialization are important questions that people care about, especially when integrating torchao with other libraries. Here we describe how serialization and deserialization work for torchao optimized (quantized or sparsified) models.
 
-High level serialization and deserialization flow
-=================================================
-
-```python
-import copy
-import tempfile
-import torch
-from torchao.quantization.quant_api import (
-    quantize_,
-    int4_weight_only,
-)
-
-class ToyLinearModel(torch.nn.Module):
-    def __init__(self, m=64, n=32, k=64):
-        super().__init__()
-        self.linear1 = torch.nn.Linear(m, n, bias=False)
-        self.linear2 = torch.nn.Linear(n, k, bias=False)
-
-    def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
-        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)
+Serialization and deserialization flow
+======================================
+
+Here is the serialization and deserialization flow::
+
+    import copy
+    import tempfile
+    import torch
+    from torchao.quantization.quant_api import (
+        quantize_,
+        int4_weight_only,
+    )
+
+    class ToyLinearModel(torch.nn.Module):
+        def __init__(self, m=64, n=32, k=64):
+            super().__init__()
+            self.linear1 = torch.nn.Linear(m, n, bias=False)
+            self.linear2 = torch.nn.Linear(n, k, bias=False)
+
+        def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
+            return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)
+
+        def forward(self, x):
+            x = self.linear1(x)
+            x = self.linear2(x)
+            return x
+
+    dtype = torch.bfloat16
+    m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda")
+    print(f"original model size: {get_model_size_in_bytes(m) / 1024 / 1024} MB")
+
+    example_inputs = m.example_inputs(dtype=dtype, device="cuda")
+    quantize_(m, int4_weight_only())
+    print(f"quantized model size: {get_model_size_in_bytes(m) / 1024 / 1024} MB")
+
+    ref = m(*example_inputs)
+    with tempfile.NamedTemporaryFile() as f:
+        torch.save(m.state_dict(), f)
+        f.seek(0)
+        state_dict = torch.load(f)
+
+    with torch.device("meta"):
+        m_loaded = ToyLinearModel(1024, 1024, 1024).eval().to(dtype)
+
+    # `linear.weight` is nn.Parameter, so we check the type of `linear.weight.data`
+    print(f"type of weight before loading: {type(m_loaded.linear1.weight.data), type(m_loaded.linear2.weight.data)}")
+    m_loaded.load_state_dict(state_dict, assign=True)
+    print(f"type of weight after loading: {type(m_loaded.linear1.weight), type(m_loaded.linear2.weight)}")
+
+    res = m_loaded(*example_inputs)
+    assert torch.equal(res, ref)
 
-    def forward(self, x):
-        x = self.linear1(x)
-        x = self.linear2(x)
-        return x
 
-dtype = torch.bfloat16
-m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda")
-print(f"original model size: {get_model_size_in_bytes(m) / 1024 / 1024} MB")
-
-example_inputs = m.example_inputs(dtype=dtype, device="cuda")
-quantize_(m, int4_weight_only())
-print(f"quantized model size: {get_model_size_in_bytes(m) / 1024 / 1024} MB")
-
-ref = m(*example_inputs)
-with tempfile.NamedTemporaryFile() as f:
-    torch.save(m.state_dict(), f)
-    f.seek(0)
-    state_dict = torch.load(f)
+What happens when serializing an optimized model?
+=================================================
+To serialize an optimized model, we just need to call ``torch.save(m.state_dict(), f)``, because in torchao we use tensor subclasses to represent different dtypes and to support optimization techniques like quantization and sparsity. So after optimization, the only thing that changes is that the weight Tensor is replaced with an optimized weight Tensor; the model structure is not changed at all. For example:
 
-with torch.device("meta"):
-    m_loaded = ToyLinearModel(1024, 1024, 1024).eval().to(dtype)
+original floating point model ``state_dict``::
+
+    {"linear1.weight": float_weight1, "linear2.weight": float_weight2}
 
-# `linear.weight` is nn.Parameter, so we check the type of `linear.weight.data`
-print(f"type of weight before loading: {type(m_loaded.linear1.weight.data), type(m_loaded.linear2.weight.data)}")
-m_loaded.load_state_dict(state_dict, assign=True)
-print(f"type of weight after loading: {type(m_loaded.linear1.weight), type(m_loaded.linear2.weight)}")
+quantized model ``state_dict``::
 
-res = m_loaded(*example_inputs)
-assert torch.equal(res, ref)
+    {"linear1.weight": quantized_weight1, "linear2.weight": quantized_weight2, ...}
 
-```
 
-What happens when serializing an optimized model?
-=================================================
-To serialize an optimized model, we just need to call `torch.save(m.state_dict(), f)`, because in torchao we use tensor subclasses to represent different dtypes and to support optimization techniques like quantization and sparsity. So after optimization, the only thing that changes is that the weight Tensor is replaced with an optimized weight Tensor; the model structure is not changed at all. For example:
+The size of the quantized model is typically smaller than the original floating point model, but it also depends on the specific technique and implementation you are using. You can print the model size with the ``torchao.utils.get_model_size_in_bytes`` utility function; for the above example using int4_weight_only quantization, the size reduction is around 4x::
 
-original floating point model `state_dict`:
-```
-{"linear1.weight": float_weight1, "linear2.weight": float_weight2}
-```
+    original model size: 4.0 MB
+    quantized model size: 1.0625 MB
 
-quantized model `state_dict`:
-```
-{"linear1.weight": quantized_weight1, "linear2.weight": quantized_weight2, ...}
-```
+
+What happens when deserializing an optimized model?
+===================================================
+To deserialize an optimized model, we can initialize the floating point model on the `meta <https://pytorch.org/docs/stable/meta.html>`__ device and then load the optimized ``state_dict`` with ``assign=True`` using `model.load_state_dict <https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict>`__::
 
-The size of the quantized model is typically smaller than the original floating point model, but it also depends on the specific technique and implementation you are using. You can print the model size with the `torchao.utils.get_model_size_in_bytes` utility function; for the above example using int4_weight_only quantization, the size reduction is around 4x:
 
-```
-original model size: 4.0 MB
-quantized model size: 1.0625 MB
-```
+    with torch.device("meta"):
+        m_loaded = ToyLinearModel(1024, 1024, 1024).eval().to(dtype)
 
-What happens when deserializing an optimized model?
-===================================================
-To deserialize an optimized model, we can initialize the floating point model on the `meta <https://pytorch.org/docs/stable/meta.html>`__ device and then load the optimized `state_dict` with `assign=True` using `model.load_state_dict <https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict>`__:
+    print(f"type of weight before loading: {type(m_loaded.linear1.weight), type(m_loaded.linear2.weight)}")
+    m_loaded.load_state_dict(state_dict, assign=True)
+    print(f"type of weight after loading: {type(m_loaded.linear1.weight), type(m_loaded.linear2.weight)}")
 
-```
-with torch.device("meta"):
-    m_loaded = ToyLinearModel(1024, 1024, 1024).eval().to(dtype)
 
-print(f"type of weight before loading: {type(m_loaded.linear1.weight), type(m_loaded.linear2.weight)}")
-m_loaded.load_state_dict(state_dict, assign=True)
-print(f"type of weight after loading: {type(m_loaded.linear1.weight), type(m_loaded.linear2.weight)}")
-```
+The reason we initialize the model on the ``meta`` device is to avoid initializing the original floating point model, since the original floating point model may not fit into the device that we want to use for inference.
 
-The reason we initialize the model on the `meta` device is to avoid initializing the original floating point model, since the original floating point model may not fit into the device that we want to use for inference.
+What happens in ``m_loaded.load_state_dict(state_dict, assign=True)`` is that the corresponding weights (e.g. m_loaded.linear1.weight) are updated with the Tensors in ``state_dict``, which are optimized tensor subclass instances (e.g. int4 ``AffineQuantizedTensor``). No dependency on torchao is needed for this to work.
 
-What happens in `m_loaded.load_state_dict(state_dict, assign=True)` is that the corresponding weights (e.g. m_loaded.linear1.weight) are updated with the Tensors in `state_dict`, which are optimized tensor subclass instances (e.g. int4 `AffineQuantizedTensor`). No dependency on torchao is needed for this to work.
+We can also verify that the weight is properly loaded by checking the type of the weight tensor::
 
-We can also verify that the weight is properly loaded by checking the type of the weight tensor:
-```
-type of weight before loading: (<class 'torch.Tensor'>, <class 'torch.Tensor'>)
-type of weight after loading: (<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>, <class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>)
+    type of weight before loading: (<class 'torch.Tensor'>, <class 'torch.Tensor'>)
+    type of weight after loading: (<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>, <class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>)
 
-```
 
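The updated doc quotes a roughly 4x size reduction for this example. As a quick sanity check (a rough sketch, not part of the commit), the numbers are consistent with simple byte accounting; note that the doc snippet assumes ``get_model_size_in_bytes`` is imported from ``torchao.utils``, and the per-group metadata layout below (group size 128, a few bytes of scale/zero-point per group) is an assumption about the int4 packing rather than something stated here:

    # Rough size accounting for the ToyLinearModel(1024, 1024, 1024) example above.
    # Assumption: int4_weight_only packs two 4-bit weights per byte and keeps about
    # 4 bytes of scale/zero-point metadata per group of 128 weights.
    params = 2 * 1024 * 1024                        # two 1024x1024 Linear layers, bias=False
    bf16_bytes = params * 2                         # 2 bytes per bfloat16 weight
    int4_bytes = params // 2                        # packed 4-bit weights
    meta_bytes = (params // 128) * 4                # assumed per-group quantization metadata
    print(bf16_bytes / 1024 / 1024)                 # 4.0, matches "original model size: 4.0 MB"
    print((int4_bytes + meta_bytes) / 1024 / 1024)  # 1.0625, matches "quantized model size: 1.0625 MB"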

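Relatedly, since ``quantize_`` swaps the weight tensors in place, the ``state_dict`` keys are identical before and after quantization; only the value types change. A minimal sketch that reuses the quantized model ``m`` from the example above (an illustration, not part of the commit):

    # Keys stay "linear1.weight" / "linear2.weight"; the values are now tensor
    # subclass instances rather than plain torch.Tensor.
    for name, value in m.state_dict().items():
        print(name, type(value))
    # expected output, approximately:
    # linear1.weight <class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>
    # linear2.weight <class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>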