|
13 | 13 | """This module saves or loads the SQLFlow model. |
14 | 14 | """ |
15 | 15 | import os |
| 16 | +import tempfile |
16 | 17 | from enum import Enum |
17 | 18 |
|
| 19 | +import runtime.oss as oss |
18 | 20 | from runtime.model.db import read_with_generator, write_with_generator |
19 | 21 | from runtime.model.tar import unzip_dir, zip_dir |
20 | 22 |
|
|
24 | 26 | import pickle |
25 | 27 |
|
26 | 28 | # archive the current work director into a tarball |
27 | | -tarball = "model.tar.gz" |
| 29 | +TARBALL_NAME = "model.tar.gz" |
28 | 30 |
|
29 | 31 | # serialize the Model object into file |
30 | | -model_obj_file = "sqlflow_model.pkl" |
| 32 | +MODEL_OBJ_FILE_NAME = "sqlflow_model.pkl" |
31 | 33 |
|
32 | 34 |
|
33 | 35 | class EstimatorType(Enum): |
@@ -67,67 +69,145 @@ def __init__(self, typ, meta): |
67 | 69 | """ |
68 | 70 | self._typ = typ |
69 | 71 | self._meta = meta |
70 | | - self._dump_file = "sqlflow_model.pkl" |
71 | 72 |
|
72 | | - def save(self, datasource, table, cwd="./"): |
73 | | - """This save function would archive all the files on work director |
74 | | - into a tarball, and saved it into DBMS with the specified table name. |
| 73 | + def _zip(self, local_dir, tarball): |
| 74 | + """ |
| 75 | + Zip the model information and all files in local_dir into a tarball. |
75 | 76 |
|
76 | 77 | Args: |
77 | | - datasource: string |
78 | | - the connection string to DBMS. |
79 | | - table: string |
80 | | - the saved table name. |
| 78 | + local_dir (str): the local directory. |
| 79 | + tarball (str): the tarball path. |
| 80 | +
|
| 81 | + Returns: |
| 82 | + None. |
81 | 83 | """ |
| 84 | + model_obj_file = os.path.join(local_dir, MODEL_OBJ_FILE_NAME) |
82 | 85 | _dump_pkl(self, model_obj_file) |
83 | | - zip_dir(cwd, tarball) |
| 86 | + zip_dir(local_dir, tarball, arcname="./") |
| 87 | + os.remove(model_obj_file) |
84 | 88 |
|
85 | | - def _bytes_reader(filename, buf_size=8 * 32): |
86 | | - def _gen(): |
87 | | - with open(filename, "rb") as f: |
88 | | - while True: |
89 | | - data = f.read(buf_size) |
90 | | - if data: |
91 | | - yield data |
92 | | - else: |
93 | | - break |
| 89 | + @staticmethod |
| 90 | + def _unzip(local_dir, tarball): |
| 91 | + """ |
| 92 | + Unzip the tarball into local_dir and deserialize the model |
| 93 | + information. |
94 | 94 |
|
95 | | - return _gen |
| 95 | + Args: |
| 96 | + local_dir (str): the local directory. |
| 97 | + tarball (str): the tarball path. |
96 | 98 |
|
97 | | - write_with_generator(datasource, table, _bytes_reader(tarball)) |
| 99 | + Returns: |
| 100 | + Model: a Model object represent the model type and meta |
| 101 | + information. |
| 102 | + """ |
| 103 | + model_obj_file = os.path.join(local_dir, MODEL_OBJ_FILE_NAME) |
| 104 | + unzip_dir(tarball, local_dir) |
| 105 | + model = _load_pkl(model_obj_file) |
| 106 | + os.remove(model_obj_file) |
| 107 | + return model |
98 | 108 |
|
| 109 | + def save_to_db(self, datasource, table, local_dir=os.getcwd()): |
| 110 | + """ |
| 111 | + This save function would archive all the files on local_dir |
| 112 | + into a tarball, and save it into DBMS with the specified table |
| 113 | + name. |
99 | 114 |
|
100 | | -def load(datasource, table, cwd="./"): |
101 | | - """Load the saved model from DBMS and unzip it on the work director. |
| 115 | + Args: |
| 116 | + datasource (str): the connection string to DBMS. |
| 117 | + table (str): the saved table name. |
| 118 | + local_dir (str): the local directory to save. |
102 | 119 |
|
103 | | - Args: |
104 | | - datasource: string |
105 | | - The connection string to DBMS |
| 120 | + Returns: |
| 121 | + None. |
| 122 | + """ |
| 123 | + with tempfile.TemporaryDirectory() as tmp_dir: |
| 124 | + tarball = os.path.join(tmp_dir, TARBALL_NAME) |
| 125 | + self._zip(local_dir, tarball) |
| 126 | + |
| 127 | + def _bytes_reader(filename, buf_size=8 * 32): |
| 128 | + def _gen(): |
| 129 | + with open(filename, "rb") as f: |
| 130 | + while True: |
| 131 | + data = f.read(buf_size) |
| 132 | + if data: |
| 133 | + yield data |
| 134 | + else: |
| 135 | + break |
| 136 | + |
| 137 | + return _gen |
| 138 | + |
| 139 | + write_with_generator(datasource, table, _bytes_reader(tarball)) |
| 140 | + |
| 141 | + @staticmethod |
| 142 | + def load_from_db(datasource, table, local_dir=os.getcwd()): |
| 143 | + """ |
| 144 | + Load the saved model from DBMS and unzip it on local_dir. |
106 | 145 |
|
107 | | - table: string |
108 | | - The table name which saved in DBMS |
| 146 | + Args: |
| 147 | + datasource (str): the connection string to DBMS |
| 148 | + table (str): the table name which saved in DBMS |
| 149 | + local_dir (str): the local directory to load. |
109 | 150 |
|
110 | | - Returns: |
111 | | - Model: a Model object represent the model type and meta information. |
112 | | - """ |
113 | | - gen = read_with_generator(datasource, table) |
114 | | - with open(tarball, "wb") as f: |
115 | | - for data in gen(): |
116 | | - f.write(bytes(data)) |
| 151 | + Returns: |
| 152 | + Model: a Model object represent the model type and meta |
| 153 | + information. |
| 154 | + """ |
| 155 | + with tempfile.TemporaryDirectory() as tmp_dir: |
| 156 | + tarball = os.path.join(tmp_dir, TARBALL_NAME) |
| 157 | + gen = read_with_generator(datasource, table) |
| 158 | + with open(tarball, "wb") as f: |
| 159 | + for data in gen(): |
| 160 | + f.write(bytes(data)) |
117 | 161 |
|
118 | | - unzip_dir(tarball, cwd) |
119 | | - return _load_pkl(os.path.join(cwd, model_obj_file)) |
| 162 | + return Model._unzip(local_dir, tarball) |
| 163 | + |
| 164 | + def save_to_oss(self, oss_model_dir, local_dir=os.getcwd()): |
| 165 | + """ |
| 166 | + This save function would archive all the files on local_dir |
| 167 | + into a tarball, and save it into OSS model directory. |
| 168 | +
|
| 169 | + Args: |
| 170 | + oss_model_dir (str): the OSS model directory to save. |
| 171 | + It is in the format of oss://bucket/path/to/dir/. |
| 172 | + local_dir (str): the local directory to save. |
| 173 | +
|
| 174 | + Returns: |
| 175 | + None. |
| 176 | + """ |
| 177 | + with tempfile.TemporaryDirectory() as tmp_dir: |
| 178 | + tarball = os.path.join(tmp_dir, TARBALL_NAME) |
| 179 | + self._zip(local_dir, tarball) |
| 180 | + oss.save_file(oss_model_dir, tarball, TARBALL_NAME) |
| 181 | + |
| 182 | + @staticmethod |
| 183 | + def load_from_oss(oss_model_dir, local_dir=os.getcwd()): |
| 184 | + """ |
| 185 | + Load the saved model from OSS and unzip it on local_dir. |
| 186 | +
|
| 187 | + Args: |
| 188 | + oss_model_dir (str): the OSS model directory to load. |
| 189 | + It is in the format of oss://bucket/path/to/dir/. |
| 190 | + local_dir (str): the local directory to load. |
| 191 | +
|
| 192 | + Returns: |
| 193 | + Model: a Model object represent the model type and meta |
| 194 | + information. |
| 195 | + """ |
| 196 | + with tempfile.TemporaryDirectory() as tmp_dir: |
| 197 | + tarball = os.path.join(tmp_dir, TARBALL_NAME) |
| 198 | + oss.load_file(oss_model_dir, tarball, TARBALL_NAME) |
| 199 | + return Model._unzip(local_dir, tarball) |
120 | 200 |
|
121 | 201 |
|
122 | 202 | def _dump_pkl(obj, to_file): |
123 | 203 | """Dump the Python object to file with Pickle. |
124 | 204 | """ |
125 | 205 | with open(to_file, "wb") as f: |
126 | | - pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) |
| 206 | + pickle.dump(obj, f, protocol=2) |
127 | 207 |
|
128 | 208 |
|
129 | | -def _load_pkl(filename): |
| 209 | +def _load_pkl(from_file): |
130 | 210 | """Load the Python object from a file with Pickle. |
131 | 211 | """ |
132 | | - with open(filename, "rb") as f: |
| 212 | + with open(from_file, "rb") as f: |
133 | 213 | return pickle.load(f) |
0 commit comments