Skip to content

Commit b67d233

Browse files
authored
Add Model.save_to_oss and Model.load_from_oss (#2817)
* add save_to_oss/load_from_oss * change pickle protocol * add more explanations on oss_model_dir doc * fix ut
1 parent 4d46f6e commit b67d233

File tree

6 files changed

+229
-85
lines changed

6 files changed

+229
-85
lines changed

python/runtime/model/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@
1414
from runtime.model.metadata import collect_metadata # noqa: F401
1515
from runtime.model.metadata import load_metadata # noqa: F401
1616
from runtime.model.metadata import save_metadata # noqa: F401
17-
from runtime.model.model import EstimatorType, Model, load # noqa: F401
17+
from runtime.model.model import EstimatorType, Model # noqa: F401

python/runtime/model/model.py

Lines changed: 121 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
"""This module saves or loads the SQLFlow model.
1414
"""
1515
import os
16+
import tempfile
1617
from enum import Enum
1718

19+
import runtime.oss as oss
1820
from runtime.model.db import read_with_generator, write_with_generator
1921
from runtime.model.tar import unzip_dir, zip_dir
2022

@@ -24,10 +26,10 @@
2426
import pickle
2527

2628
# archive the current work director into a tarball
27-
tarball = "model.tar.gz"
29+
TARBALL_NAME = "model.tar.gz"
2830

2931
# serialize the Model object into file
30-
model_obj_file = "sqlflow_model.pkl"
32+
MODEL_OBJ_FILE_NAME = "sqlflow_model.pkl"
3133

3234

3335
class EstimatorType(Enum):
@@ -67,67 +69,145 @@ def __init__(self, typ, meta):
6769
"""
6870
self._typ = typ
6971
self._meta = meta
70-
self._dump_file = "sqlflow_model.pkl"
7172

72-
def save(self, datasource, table, cwd="./"):
73-
"""This save function would archive all the files on work director
74-
into a tarball, and saved it into DBMS with the specified table name.
73+
def _zip(self, local_dir, tarball):
74+
"""
75+
Zip the model information and all files in local_dir into a tarball.
7576
7677
Args:
77-
datasource: string
78-
the connection string to DBMS.
79-
table: string
80-
the saved table name.
78+
local_dir (str): the local directory.
79+
tarball (str): the tarball path.
80+
81+
Returns:
82+
None.
8183
"""
84+
model_obj_file = os.path.join(local_dir, MODEL_OBJ_FILE_NAME)
8285
_dump_pkl(self, model_obj_file)
83-
zip_dir(cwd, tarball)
86+
zip_dir(local_dir, tarball, arcname="./")
87+
os.remove(model_obj_file)
8488

85-
def _bytes_reader(filename, buf_size=8 * 32):
86-
def _gen():
87-
with open(filename, "rb") as f:
88-
while True:
89-
data = f.read(buf_size)
90-
if data:
91-
yield data
92-
else:
93-
break
89+
@staticmethod
90+
def _unzip(local_dir, tarball):
91+
"""
92+
Unzip the tarball into local_dir and deserialize the model
93+
information.
9494
95-
return _gen
95+
Args:
96+
local_dir (str): the local directory.
97+
tarball (str): the tarball path.
9698
97-
write_with_generator(datasource, table, _bytes_reader(tarball))
99+
Returns:
100+
Model: a Model object represent the model type and meta
101+
information.
102+
"""
103+
model_obj_file = os.path.join(local_dir, MODEL_OBJ_FILE_NAME)
104+
unzip_dir(tarball, local_dir)
105+
model = _load_pkl(model_obj_file)
106+
os.remove(model_obj_file)
107+
return model
98108

109+
def save_to_db(self, datasource, table, local_dir=os.getcwd()):
110+
"""
111+
This save function would archive all the files on local_dir
112+
into a tarball, and save it into DBMS with the specified table
113+
name.
99114
100-
def load(datasource, table, cwd="./"):
101-
"""Load the saved model from DBMS and unzip it on the work director.
115+
Args:
116+
datasource (str): the connection string to DBMS.
117+
table (str): the saved table name.
118+
local_dir (str): the local directory to save.
102119
103-
Args:
104-
datasource: string
105-
The connection string to DBMS
120+
Returns:
121+
None.
122+
"""
123+
with tempfile.TemporaryDirectory() as tmp_dir:
124+
tarball = os.path.join(tmp_dir, TARBALL_NAME)
125+
self._zip(local_dir, tarball)
126+
127+
def _bytes_reader(filename, buf_size=8 * 32):
128+
def _gen():
129+
with open(filename, "rb") as f:
130+
while True:
131+
data = f.read(buf_size)
132+
if data:
133+
yield data
134+
else:
135+
break
136+
137+
return _gen
138+
139+
write_with_generator(datasource, table, _bytes_reader(tarball))
140+
141+
@staticmethod
142+
def load_from_db(datasource, table, local_dir=os.getcwd()):
143+
"""
144+
Load the saved model from DBMS and unzip it on local_dir.
106145
107-
table: string
108-
The table name which saved in DBMS
146+
Args:
147+
datasource (str): the connection string to DBMS
148+
table (str): the table name which saved in DBMS
149+
local_dir (str): the local directory to load.
109150
110-
Returns:
111-
Model: a Model object represent the model type and meta information.
112-
"""
113-
gen = read_with_generator(datasource, table)
114-
with open(tarball, "wb") as f:
115-
for data in gen():
116-
f.write(bytes(data))
151+
Returns:
152+
Model: a Model object represent the model type and meta
153+
information.
154+
"""
155+
with tempfile.TemporaryDirectory() as tmp_dir:
156+
tarball = os.path.join(tmp_dir, TARBALL_NAME)
157+
gen = read_with_generator(datasource, table)
158+
with open(tarball, "wb") as f:
159+
for data in gen():
160+
f.write(bytes(data))
117161

118-
unzip_dir(tarball, cwd)
119-
return _load_pkl(os.path.join(cwd, model_obj_file))
162+
return Model._unzip(local_dir, tarball)
163+
164+
def save_to_oss(self, oss_model_dir, local_dir=os.getcwd()):
165+
"""
166+
This save function would archive all the files on local_dir
167+
into a tarball, and save it into OSS model directory.
168+
169+
Args:
170+
oss_model_dir (str): the OSS model directory to save.
171+
It is in the format of oss://bucket/path/to/dir/.
172+
local_dir (str): the local directory to save.
173+
174+
Returns:
175+
None.
176+
"""
177+
with tempfile.TemporaryDirectory() as tmp_dir:
178+
tarball = os.path.join(tmp_dir, TARBALL_NAME)
179+
self._zip(local_dir, tarball)
180+
oss.save_file(oss_model_dir, tarball, TARBALL_NAME)
181+
182+
@staticmethod
183+
def load_from_oss(oss_model_dir, local_dir=os.getcwd()):
184+
"""
185+
Load the saved model from OSS and unzip it on local_dir.
186+
187+
Args:
188+
oss_model_dir (str): the OSS model directory to load.
189+
It is in the format of oss://bucket/path/to/dir/.
190+
local_dir (str): the local directory to load.
191+
192+
Returns:
193+
Model: a Model object represent the model type and meta
194+
information.
195+
"""
196+
with tempfile.TemporaryDirectory() as tmp_dir:
197+
tarball = os.path.join(tmp_dir, TARBALL_NAME)
198+
oss.load_file(oss_model_dir, tarball, TARBALL_NAME)
199+
return Model._unzip(local_dir, tarball)
120200

121201

122202
def _dump_pkl(obj, to_file):
123203
"""Dump the Python object to file with Pickle.
124204
"""
125205
with open(to_file, "wb") as f:
126-
pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
206+
pickle.dump(obj, f, protocol=2)
127207

128208

129-
def _load_pkl(filename):
209+
def _load_pkl(from_file):
130210
"""Load the Python object from a file with Pickle.
131211
"""
132-
with open(filename, "rb") as f:
212+
with open(from_file, "rb") as f:
133213
return pickle.load(f)

python/runtime/model/model_test.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
import tempfile
1616
import unittest
1717

18-
from runtime.model import EstimatorType, Model, load
18+
import runtime.oss as oss
19+
from runtime.model import EstimatorType, Model
1920
from runtime.testing import get_datasource
2021

2122

@@ -26,23 +27,51 @@ def setUp(self):
2627
def tearDown(self):
2728
os.chdir(self.cur_dir)
2829

29-
def test_save(self):
30+
def test_save_load_db(self):
3031
table = "sqlflow_models.test_model"
31-
meta = {"train_params": {"n_classes": 3}}
32+
meta = {"model_params": {"n_classes": 3}}
3233
m = Model(EstimatorType.XGBOOST, meta)
3334
datasource = get_datasource()
3435

3536
# save mode
3637
with tempfile.TemporaryDirectory() as d:
37-
os.chdir(d)
38-
m.save(datasource, table)
38+
m.save_to_db(datasource, table, d)
3939

4040
# load model
4141
with tempfile.TemporaryDirectory() as d:
42-
os.chdir(d)
43-
m = load(datasource, table)
42+
m = Model.load_from_db(datasource, table, d)
4443
self.assertEqual(m._meta, meta)
4544

45+
@unittest.skipUnless(
46+
os.getenv("SQLFLOW_OSS_AK") and os.getenv("SQLFLOW_OSS_SK"),
47+
"skip when SQLFLOW_OSS_AK or SQLFLOW_OSS_SK is not set")
48+
def test_save_load_oss(self):
49+
bucket = oss.get_models_bucket()
50+
meta = {"model_params": {"n_classes": 3}}
51+
m = Model(EstimatorType.XGBOOST, meta)
52+
53+
oss_dir = "unknown/model_test_dnn_classifier/"
54+
oss_model_path = "oss://%s/%s" % (bucket.bucket_name, oss_dir)
55+
56+
oss.delete_oss_dir_recursive(bucket, oss_dir)
57+
58+
# save model
59+
def save_to_oss():
60+
with tempfile.TemporaryDirectory() as d:
61+
m.save_to_oss(oss_model_path, d)
62+
63+
# load model
64+
def load_from_oss():
65+
with tempfile.TemporaryDirectory() as d:
66+
return Model.load_from_oss(oss_model_path, d)
67+
68+
with self.assertRaises(Exception):
69+
load_from_oss()
70+
71+
save_to_oss()
72+
m = load_from_oss()
73+
self.assertEqual(m._meta, meta)
74+
4675

4776
if __name__ == '__main__':
4877
unittest.main()

python/runtime/model/tar.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import tarfile
1818

1919

20-
def zip_dir(src_dir, tarball):
20+
def zip_dir(src_dir, tarball, arcname=None):
2121
"""To compress a directory into tarball.
2222
2323
Args:
@@ -26,9 +26,12 @@ def zip_dir(src_dir, tarball):
2626
2727
tarball: string
2828
The output tarball name.
29+
30+
arcname: string
31+
The output name of src_dir in the tarball.
2932
"""
3033
with tarfile.open(tarball, "w:gz") as tar:
31-
tar.add(src_dir, recursive=True)
34+
tar.add(src_dir, arcname=arcname, recursive=True)
3235

3336

3437
def unzip_dir(tarball, dest_dir=os.getcwd()):

0 commit comments

Comments
 (0)