Skip to content
This repository was archived by the owner on Nov 16, 2023. It is now read-only.

Remove stored references to X and y in BasePredictor. #195

Merged
merged 2 commits into from
Jul 18, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/python/nimbusml/base_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from sklearn.base import BaseEstimator
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_is_fitted

from . import Pipeline
from .internal.core.base_pipeline_item import BasePipelineItem
Expand Down Expand Up @@ -49,8 +48,6 @@ def fit(self, X, y=None, **params):
"Classifier can't train when only one class is "
"present.")
self.classes_ = unique_classes
self.X_ = X
self.y_ = y

# Clear cached summary since it should not
# retain its value after a new call to fit
Expand All @@ -69,13 +66,24 @@ def fit(self, X, y=None, **params):
set_shape(self, X)
return self

@property
def _is_fitted(self):
"""
Tells if the predictor was trained.
"""
return (hasattr(self, 'model_') and
self.model_ and
os.path.isfile(self.model_))

@trace
def _invoke_inference_method(self, method, X, **params):
"""
Returns predictions. Can be predicted labels, probabilities
or else decision values.
"""
check_is_fitted(self, ["X_", "y_"])
if not self._is_fitted:
raise ValueError("Model is not fitted. "
"fit() must be called before {}.".format(method))

# Check that the input is of the same shape as the one passed
# during
Expand Down