|
42 | 42 | 9112: 9113, # v1.1 |
43 | 43 | 8387: 9113, # v1.0 |
44 | 44 | 7020: 7021, # v0.5 |
| 45 | + 11028: 10583, # v2.2 registration CHANGE ME |
45 | 46 | } |
46 | 47 | _CIVISML_TEMPLATE = None # CivisML training template to use |
| 48 | +REGISTRATION_TEMPLATES = [11028, # v2.2 |
| 49 | + ] |
47 | 50 |
|
48 | 51 |
|
49 | 52 | class ModelError(RuntimeError): |
@@ -713,6 +716,8 @@ def _get_template_ids(self, client): |
713 | 716 | global _CIVISML_TEMPLATE |
714 | 717 | if _CIVISML_TEMPLATE is None: |
715 | 718 | for t_id in sorted(_PRED_TEMPLATES)[::-1]: |
| 719 | + if t_id in REGISTRATION_TEMPLATES: |
| 720 | + continue |
716 | 721 | try: |
717 | 722 | # Check that we can access the template |
718 | 723 | client.templates.get_scripts(id=t_id) |
@@ -783,6 +788,98 @@ def __setstate__(self, state): |
783 | 788 | template_ids = self._get_template_ids(self._client) |
784 | 789 | self.train_template_id, self.predict_template_id = template_ids |
785 | 790 |
|
@classmethod
def register_pretrained_model(cls, model, dependent_variable=None,
                              features=None, model_name=None,
                              dependencies=None, git_token_name=None,
                              skip_model_check=False, verbose=False,
                              client=None):
    """Make a scikit-learn Estimator usable with CivisML scoring

    Unless ``model`` is already a Civis File ID, the estimator is
    serialized with ``joblib`` and uploaded as a Civis File. A custom
    script is then created from the newest template listed in
    ``REGISTRATION_TEMPLATES`` and run. This call blocks until the
    run's future resolves.

    Parameters
    ----------
    model : sklearn.base.BaseEstimator or int
        The model object. This must be a fitted scikit-learn compatible
        Estimator object, or else the integer Civis File ID of a
        pickle which stores such an object.
    dependent_variable : str or list of str, optional
        Sent to the registration job as the ``TARGET_COLUMN``
        argument (space-joined). A single string is treated as a
        one-element list. Presumably the name(s) of the column(s) the
        model predicts; confirm against the registration template.
    features : str or list of str, optional
        Sent to the registration job as the ``FEATURE_COLUMNS``
        argument (space-joined). A single string is treated as a
        one-element list.
    model_name : str, optional
        Name given to the custom script. Defaults to a name built from
        the class name of ``model``. Truncated to 255 characters.
    dependencies : str or list of str, optional
        Sent to the registration job as the ``DEPENDENCIES`` argument
        (space-joined). A single string is treated as a one-element
        list.
    git_token_name : str, optional
        Name of a credential of type "Custom". If given, the ID of the
        matching credential is sent to the registration job as the
        ``GIT_CRED`` argument.
    skip_model_check : bool, optional
        Sent to the registration job as the ``SKIP_MODEL_CHECK``
        argument.
    verbose : bool, optional
        Sent to the registration job as the ``DEBUG`` argument.
    client : :class:`~civis.APIClient`, optional
        If not provided, an :class:`~civis.APIClient` object will be
        created with ``resources='all'``.

    Returns
    -------
    :class:`~civis.ml.ModelPipeline`

    Raises
    ------
    ValueError
        If ``git_token_name`` is given and it does not match exactly
        one "Custom" credential.
    """
    client = client or APIClient(resources='all')

    # Accept a bare string wherever a list of strings is expected.
    if isinstance(dependent_variable, six.string_types):
        dependent_variable = [dependent_variable]
    if isinstance(features, six.string_types):
        features = [features]
    if isinstance(dependencies, six.string_types):
        dependencies = [dependencies]
    if not model_name:
        model_name = ("Pretrained {} model for "
                      "CivisML".format(model.__class__.__name__))
        model_name = model_name[:255]  # Max size is 255 characters

    if isinstance(model, (int, float, six.string_types)):
        # The caller supplied the ID of an already-uploaded file.
        model_file_id = int(model)
    else:
        # Create the directory before entering ``try`` so that the
        # cleanup in ``finally`` never refers to an unbound name.
        tempdir = tempfile.mkdtemp()
        try:
            fout = os.path.join(tempdir, 'model_for_civisml.pkl')
            joblib.dump(model, fout, compress=3)
            with open(fout, 'rb') as _fout:
                # NB: Using the name "estimator.pkl" means that
                # CivisML doesn't need to copy this input to a file
                # with a different name.
                model_file_id = cio.file_to_civis(_fout, 'estimator.pkl',
                                                  client=client)
        finally:
            shutil.rmtree(tempdir)

    args = {'MODEL_FILE_ID': str(model_file_id),
            'SKIP_MODEL_CHECK': skip_model_check,
            'DEBUG': verbose}
    if dependent_variable is not None:
        args['TARGET_COLUMN'] = ' '.join(dependent_variable)
    if features is not None:
        args['FEATURE_COLUMNS'] = ' '.join(features)
    if dependencies is not None:
        args['DEPENDENCIES'] = ' '.join(dependencies)
    if git_token_name:
        creds = find(client.credentials.list(),
                     name=git_token_name,
                     type='Custom')
        # Require exactly one match. With no match at all, indexing
        # ``creds[0]`` below would fail with an unhelpful IndexError.
        if len(creds) != 1:
            raise ValueError("Unique credential with name '{}' for "
                             "remote git hosting service not found!"
                             .format(git_token_name))
        args['GIT_CRED'] = creds[0].id

    # Register with the newest (highest-ID) registration template.
    template_id = max(REGISTRATION_TEMPLATES)
    container = client.scripts.post_custom(
        from_template_id=template_id,
        name=model_name,
        arguments=args)
    log.info('Created custom script %s.', container.id)

    run = client.scripts.post_custom_runs(container.id)
    log.debug('Started job %s, run %s.', container.id, run.id)

    fut = ModelFuture(container.id, run.id, client=client,
                      poll_on_creation=False)
    fut.result()  # Block until the registration run has resolved.
    log.info('Model registration complete.')

    return ModelPipeline.from_existing(fut.job_id, fut.run_id, client)
| 882 | + |
786 | 883 | @classmethod |
787 | 884 | def from_existing(cls, train_job_id, train_run_id='latest', client=None): |
788 | 885 | """Create a :class:`ModelPipeline` object from existing model IDs |
|
0 commit comments