2 changes: 1 addition & 1 deletion python/runtime/model/__init__.py
@@ -14,4 +14,4 @@
from runtime.model.metadata import collect_metadata # noqa: F401
from runtime.model.metadata import load_metadata # noqa: F401
from runtime.model.metadata import save_metadata # noqa: F401
from runtime.model.model import EstimatorType, Model, load # noqa: F401
from runtime.model.model import EstimatorType, Model # noqa: F401
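Note: with the module-level load() dropped from the package exports, callers move to the class methods added in model.py below. A minimal sketch of the updated call site, assuming a placeholder datasource string and a table that already holds a saved model:

import tempfile

from runtime.model import Model

# Hypothetical values: replace the datasource string and table name with a
# real SQLFlow connection string and an existing model table.
datasource = "mysql://user:password@tcp(127.0.0.1:3306)/"
with tempfile.TemporaryDirectory() as local_dir:
    model = Model.load_from_db(datasource, "sqlflow_models.my_model", local_dir)
    print(model._meta)  # the meta dict recorded at save time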
162 changes: 121 additions & 41 deletions python/runtime/model/model.py
@@ -13,8 +13,10 @@
"""This module saves or loads the SQLFlow model.
"""
import os
import tempfile
from enum import Enum

import runtime.oss as oss
from runtime.model.db import read_with_generator, write_with_generator
from runtime.model.tar import unzip_dir, zip_dir

@@ -24,10 +26,10 @@
import pickle

# archive the current working directory into a tarball
tarball = "model.tar.gz"
TARBALL_NAME = "model.tar.gz"

# serialize the Model object into a file
model_obj_file = "sqlflow_model.pkl"
MODEL_OBJ_FILE_NAME = "sqlflow_model.pkl"


class EstimatorType(Enum):
@@ -67,67 +69,145 @@ def __init__(self, typ, meta):
"""
self._typ = typ
self._meta = meta
self._dump_file = "sqlflow_model.pkl"

def save(self, datasource, table, cwd="./"):
"""This save function would archive all the files on work director
into a tarball, and saved it into DBMS with the specified table name.
def _zip(self, local_dir, tarball):
"""
Zip the model information and all files in local_dir into a tarball.

Args:
datasource: string
the connection string to DBMS.
table: string
the saved table name.
local_dir (str): the local directory.
tarball (str): the tarball path.

Returns:
None.
"""
model_obj_file = os.path.join(local_dir, MODEL_OBJ_FILE_NAME)
_dump_pkl(self, model_obj_file)
zip_dir(cwd, tarball)
zip_dir(local_dir, tarball, arcname="./")
os.remove(model_obj_file)

def _bytes_reader(filename, buf_size=8 * 32):
def _gen():
with open(filename, "rb") as f:
while True:
data = f.read(buf_size)
if data:
yield data
else:
break
@staticmethod
def _unzip(local_dir, tarball):
"""
Unzip the tarball into local_dir and deserialize the model
information.

return _gen
Args:
local_dir (str): the local directory.
tarball (str): the tarball path.

write_with_generator(datasource, table, _bytes_reader(tarball))
Returns:
Model: a Model object representing the model type and meta
information.
"""
model_obj_file = os.path.join(local_dir, MODEL_OBJ_FILE_NAME)
unzip_dir(tarball, local_dir)
model = _load_pkl(model_obj_file)
os.remove(model_obj_file)
return model

def save_to_db(self, datasource, table, local_dir=os.getcwd()):
"""
This save function archives all the files in local_dir
into a tarball and saves it into the DBMS with the specified
table name.

def load(datasource, table, cwd="./"):
"""Load the saved model from DBMS and unzip it on the work director.
Args:
datasource (str): the connection string to DBMS.
table (str): the saved table name.
local_dir (str): the local directory whose files are archived and saved.

Args:
datasource: string
The connection string to DBMS
Returns:
None.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
tarball = os.path.join(tmp_dir, TARBALL_NAME)
self._zip(local_dir, tarball)

def _bytes_reader(filename, buf_size=8 * 32):
def _gen():
with open(filename, "rb") as f:
while True:
data = f.read(buf_size)
if data:
yield data
else:
break

return _gen

write_with_generator(datasource, table, _bytes_reader(tarball))

@staticmethod
def load_from_db(datasource, table, local_dir=os.getcwd()):
"""
Load the saved model from the DBMS and unzip it into local_dir.

table: string
The table name under which the model was saved in DBMS
Args:
datasource (str): the connection string to DBMS
table (str): the table name under which the model was saved in the DBMS.
local_dir (str): the local directory to extract the model into.

Returns:
Model: a Model object representing the model type and meta information.
"""
gen = read_with_generator(datasource, table)
with open(tarball, "wb") as f:
for data in gen():
f.write(bytes(data))
Returns:
Model: a Model object representing the model type and meta
information.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
tarball = os.path.join(tmp_dir, TARBALL_NAME)
gen = read_with_generator(datasource, table)
with open(tarball, "wb") as f:
for data in gen():
f.write(bytes(data))

unzip_dir(tarball, cwd)
return _load_pkl(os.path.join(cwd, model_obj_file))
return Model._unzip(local_dir, tarball)

def save_to_oss(self, oss_model_dir, local_dir=os.getcwd()):
"""
This save function archives all the files in local_dir
into a tarball and saves it into the OSS model directory.

Args:
oss_model_dir (str): the OSS model directory to save to,
in the format of oss://bucket/path/to/dir/.
local_dir (str): the local directory whose files are archived and saved.

Returns:
None.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
tarball = os.path.join(tmp_dir, TARBALL_NAME)
self._zip(local_dir, tarball)
oss.save_file(oss_model_dir, tarball, TARBALL_NAME)

@staticmethod
def load_from_oss(oss_model_dir, local_dir=os.getcwd()):
"""
Load the saved model from OSS and unzip it into local_dir.

Args:
oss_model_dir (str): the OSS model directory to load from,
in the format of oss://bucket/path/to/dir/.
local_dir (str): the local directory to extract the model into.

Returns:
Model: a Model object representing the model type and meta
information.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
tarball = os.path.join(tmp_dir, TARBALL_NAME)
oss.load_file(oss_model_dir, tarball, TARBALL_NAME)
return Model._unzip(local_dir, tarball)


def _dump_pkl(obj, to_file):
"""Dump the Python object to file with Pickle.
"""
with open(to_file, "wb") as f:
pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
pickle.dump(obj, f, protocol=2)


def _load_pkl(filename):
def _load_pkl(from_file):
"""Load the Python object from a file with Pickle.
"""
with open(filename, "rb") as f:
with open(from_file, "rb") as f:
return pickle.load(f)
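Taken together, the new API replaces the old save()/load() pair with explicit DB and OSS round-trips. A minimal end-to-end sketch (not part of the diff), assuming placeholder datasource and OSS paths and that training has already written its files into the work directory:

import tempfile

from runtime.model import EstimatorType, Model

meta = {"model_params": {"n_classes": 3}}
model = Model(EstimatorType.XGBOOST, meta)

# DB round-trip; the datasource string and table name are placeholders.
datasource = "mysql://user:password@tcp(127.0.0.1:3306)/"
with tempfile.TemporaryDirectory() as workdir:
    # ... training code would write estimator files into workdir here ...
    model.save_to_db(datasource, "sqlflow_models.demo_model", workdir)
with tempfile.TemporaryDirectory() as workdir:
    restored = Model.load_from_db(datasource, "sqlflow_models.demo_model", workdir)
    assert restored._meta == meta  # workdir now holds the unpacked model files

# OSS round-trip; the bucket path is a placeholder and needs OSS credentials.
oss_model_dir = "oss://my-bucket/models/demo/"
with tempfile.TemporaryDirectory() as workdir:
    model.save_to_oss(oss_model_dir, workdir)
with tempfile.TemporaryDirectory() as workdir:
    restored = Model.load_from_oss(oss_model_dir, workdir)
    assert restored._meta == meta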
43 changes: 36 additions & 7 deletions python/runtime/model/model_test.py
@@ -15,7 +15,8 @@
import tempfile
import unittest

from runtime.model import EstimatorType, Model, load
import runtime.oss as oss
from runtime.model import EstimatorType, Model
from runtime.testing import get_datasource


@@ -26,23 +27,51 @@ def setUp(self):
def tearDown(self):
os.chdir(self.cur_dir)

def test_save(self):
def test_save_load_db(self):
table = "sqlflow_models.test_model"
meta = {"train_params": {"n_classes": 3}}
meta = {"model_params": {"n_classes": 3}}
m = Model(EstimatorType.XGBOOST, meta)
datasource = get_datasource()

# save model
with tempfile.TemporaryDirectory() as d:
os.chdir(d)
m.save(datasource, table)
m.save_to_db(datasource, table, d)

# load model
with tempfile.TemporaryDirectory() as d:
os.chdir(d)
m = load(datasource, table)
m = Model.load_from_db(datasource, table, d)
self.assertEqual(m._meta, meta)

@unittest.skipUnless(
os.getenv("SQLFLOW_OSS_AK") and os.getenv("SQLFLOW_OSS_SK"),
"skip when SQLFLOW_OSS_AK or SQLFLOW_OSS_SK is not set")
def test_save_load_oss(self):
bucket = oss.get_models_bucket()
meta = {"model_params": {"n_classes": 3}}
m = Model(EstimatorType.XGBOOST, meta)

oss_dir = "unknown/model_test_dnn_classifier/"
oss_model_path = "oss://%s/%s" % (bucket.bucket_name, oss_dir)

oss.delete_oss_dir_recursive(bucket, oss_dir)

# save model
def save_to_oss():
with tempfile.TemporaryDirectory() as d:
m.save_to_oss(oss_model_path, d)

# load model
def load_from_oss():
with tempfile.TemporaryDirectory() as d:
return Model.load_from_oss(oss_model_path, d)

with self.assertRaises(Exception):
load_from_oss()

save_to_oss()
m = load_from_oss()
self.assertEqual(m._meta, meta)


if __name__ == '__main__':
unittest.main()
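The OSS case is guarded by skipUnless, so it only runs when access-key credentials are present. A hypothetical way to exercise both cases locally; the credential values below are placeholders, and real credentials are needed for the OSS round-trip to actually pass:

import os
import unittest

# Placeholder credentials; without them the OSS test is skipped and only the
# DB round-trip against get_datasource() runs.
os.environ.setdefault("SQLFLOW_OSS_AK", "my-access-key")
os.environ.setdefault("SQLFLOW_OSS_SK", "my-secret-key")

unittest.main(module="runtime.model.model_test", exit=False, verbosity=2)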
7 changes: 5 additions & 2 deletions python/runtime/model/tar.py
@@ -17,7 +17,7 @@
import tarfile


def zip_dir(src_dir, tarball):
def zip_dir(src_dir, tarball, arcname=None):
"""To compress a directory into tarball.

Args:
@@ -26,9 +26,12 @@ def zip_dir(src_dir, tarball):

tarball: string
The output tarball name.

arcname: string
The output name of src_dir in the tarball.
"""
with tarfile.open(tarball, "w:gz") as tar:
tar.add(src_dir, recursive=True)
tar.add(src_dir, arcname=arcname, recursive=True)


def unzip_dir(tarball, dest_dir=os.getcwd()):
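The new arcname parameter is what lets Model._zip store files at the archive root (arcname="./") instead of under the source directory's own path, so a later unzip_dir into a different local_dir places the files directly there. A small sketch of the effect, using a throwaway file:

import os
import tarfile
import tempfile

from runtime.model.tar import unzip_dir, zip_dir

with tempfile.TemporaryDirectory() as src, tempfile.TemporaryDirectory() as out:
    open(os.path.join(src, "weights.bin"), "wb").close()
    tarball = os.path.join(out, "model.tar.gz")

    zip_dir(src, tarball, arcname="./")
    with tarfile.open(tarball, "r:gz") as tar:
        # members are rooted at "./" rather than under src's own path
        print(tar.getnames())

    unzip_dir(tarball, out)  # weights.bin lands directly under out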