Skip to content

Commit 8586e97

Browse files
author
Stephen Hoover
committed
ENH BYO trained model
If you train a scikit-learn compatible estimator outside of Civis Platform, you can use this to upload it to Civis Platform and prepare it for scoring with CivisML. There's a new Custom Script which will introspect metadata necessary for CivisML and make itself appear sufficiently like a CivisML training job that it can be used as input to a scoring job.
1 parent 664f6b3 commit 8586e97

File tree

2 files changed

+104
-3
lines changed

2 files changed

+104
-3
lines changed

civis/ml/_model.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,11 @@
4242
9112: 9113, # v1.1
4343
8387: 9113, # v1.0
4444
7020: 7021, # v0.5
45+
11028: 10583, # v2.2 registration CHANGE ME
4546
}
4647
_CIVISML_TEMPLATE = None # CivisML training template to use
48+
REGISTRATION_TEMPLATES = [11028, # v2.2
49+
]
4750

4851

4952
class ModelError(RuntimeError):
@@ -713,6 +716,8 @@ def _get_template_ids(self, client):
713716
global _CIVISML_TEMPLATE
714717
if _CIVISML_TEMPLATE is None:
715718
for t_id in sorted(_PRED_TEMPLATES)[::-1]:
719+
if t_id in REGISTRATION_TEMPLATES:
720+
continue
716721
try:
717722
# Check that we can access the template
718723
client.templates.get_scripts(id=t_id)
@@ -783,6 +788,98 @@ def __setstate__(self, state):
783788
template_ids = self._get_template_ids(self._client)
784789
self.train_template_id, self.predict_template_id = template_ids
785790

791+
@classmethod
792+
def register_pretrained_model(cls, model, dependent_variable=None,
793+
features=None, model_name=None,
794+
dependencies=None, git_token_name=None,
795+
skip_model_check=False, verbose=False,
796+
client=None):
797+
"""Make a scikit-learn Estimator usable with CivisML scoring
798+
799+
Parameters
800+
----------
801+
model : sklearn.base.BaseEstimator or int
802+
The model object. This must be a fitted scikit-learn compatible
803+
Estimator object, or else the integer Civis File ID of a
804+
pickle which stores such an object.
805+
dependent_variable
806+
features
807+
model_name
808+
dependencies
809+
git_token_name
810+
skip_model_check
811+
verbose
812+
client
813+
814+
Returns
815+
-------
816+
:class:`~civis.ml.ModelPipeline`
817+
"""
818+
client = client or APIClient(resources='all')
819+
820+
if isinstance(dependent_variable, six.string_types):
821+
dependent_variable = [dependent_variable]
822+
if isinstance(features, six.string_types):
823+
features = [features]
824+
if isinstance(dependencies, six.string_types):
825+
dependencies = [dependencies]
826+
if not model_name:
827+
model_name = ("Pretrained {} model for "
828+
"CivisML".format(model.__class__.__name__))
829+
model_name = model_name[:255] # Max size is 255 characters
830+
831+
if isinstance(model, (int, float, six.string_types)):
832+
model_file_id = int(model)
833+
else:
834+
try:
835+
tempdir = tempfile.mkdtemp()
836+
fout = os.path.join(tempdir, 'model_for_civisml.pkl')
837+
joblib.dump(model, fout, compress=3)
838+
with open(fout, 'rb') as _fout:
839+
# NB: Using the name "estimator.pkl" means that
840+
# CivisML doesn't need to copy this input to a file
841+
# with a different name.
842+
model_file_id = cio.file_to_civis(_fout, 'estimator.pkl',
843+
client=client)
844+
finally:
845+
shutil.rmtree(tempdir)
846+
847+
args = {'MODEL_FILE_ID': str(model_file_id),
848+
'SKIP_MODEL_CHECK': skip_model_check,
849+
'DEBUG': verbose}
850+
if dependent_variable is not None:
851+
args['TARGET_COLUMN'] = ' '.join(dependent_variable)
852+
if features is not None:
853+
args['FEATURE_COLUMNS'] = ' '.join(features)
854+
if dependencies is not None:
855+
args['DEPENDENCIES'] = ' '.join(dependencies)
856+
if git_token_name:
857+
creds = find(client.credentials.list(),
858+
name=git_token_name,
859+
type='Custom')
860+
if len(creds) > 1:
861+
raise ValueError("Unique credential with name '{}' for "
862+
"remote git hosting service not found!"
863+
.format(git_token_name))
864+
args['GIT_CRED'] = creds[0].id
865+
866+
template_id = max(REGISTRATION_TEMPLATES)
867+
container = client.scripts.post_custom(
868+
from_template_id=template_id,
869+
name=model_name,
870+
arguments=args)
871+
log.info('Created custom script %s.', container.id)
872+
873+
run = client.scripts.post_custom_runs(container.id)
874+
log.debug('Started job %s, run %s.', container.id, run.id)
875+
876+
fut = ModelFuture(container.id, run.id, client=client,
877+
poll_on_creation=False)
878+
fut.result()
879+
log.info('Model registration complete.')
880+
881+
return ModelPipeline.from_existing(fut.job_id, fut.run_id, client)
882+
786883
@classmethod
787884
def from_existing(cls, train_job_id, train_run_id='latest', client=None):
788885
"""Create a :class:`ModelPipeline` object from existing model IDs

civis/ml/tests/test_model.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@
3939
from civis.ml import _model
4040

4141

42+
LATEST_TRAIN_TEMPLATE = 10582
43+
LATEST_PRED_TEMPLATE = 10583
44+
45+
4246
def setup_client_mock(script_id=-10, run_id=100, state='succeeded',
4347
run_outputs=None):
4448
"""Return a Mock set up for use in testing container scripts
@@ -682,7 +686,7 @@ def test_modelpipeline_init_newest():
682686
mp = _model.ModelPipeline(LogisticRegression(), 'test', etl=etl,
683687
client=mock_client)
684688
assert mp.etl == etl
685-
assert mp.train_template_id == max(_model._PRED_TEMPLATES)
689+
assert mp.train_template_id == LATEST_TRAIN_TEMPLATE
686690
# clean up
687691
_model._CIVISML_TEMPLATE = None
688692

@@ -787,7 +791,7 @@ def test_modelpipeline_classmethod_constructor_defaults(
787791
def test_modelpipeline_classmethod_constructor_future_train_version():
788792
# Test handling attempts to restore a model created with a newer
789793
# version of CivisML.
790-
current_max_template = max(_model._PRED_TEMPLATES)
794+
current_max_template = LATEST_TRAIN_TEMPLATE
791795
cont = container_response_stub(current_max_template + 1000)
792796
mock_client = mock.Mock()
793797
mock_client.scripts.get_containers.return_value = cont
@@ -892,7 +896,7 @@ def test_modelpipeline_train_df(mock_ccr, mock_stash, mp_setup):
892896
train_data = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
893897
assert 'res' == mp.train(train_data)
894898
mock_stash.assert_called_once_with(
895-
train_data, max(_model._PRED_TEMPLATES.keys()), client=mock.ANY)
899+
train_data, LATEST_TRAIN_TEMPLATE, client=mock.ANY)
896900
assert mp.train_result_ == 'res'
897901

898902

0 commit comments

Comments
 (0)