
[ENH] Implement load_model function for ensemble classifiers #2631


Merged
40 changes: 40 additions & 0 deletions aeon/classification/deep_learning/_inception_time.py
@@ -350,6 +350,46 @@ def _predict_proba(self, X) -> np.ndarray:

        return probs

    @classmethod
    def load_model(cls, model_path, classes):
        """Load pre-trained classifiers instead of fitting.

        When calling this function, all functionalities such as predict and
        predict_proba can be used with the loaded models.

        Parameters
        ----------
        model_path : list of str
            The list of paths to the saved models to load, each including the
            model name and ".keras" extension.
        classes : np.ndarray
            The set of unique classes the pre-trained loaded model is trained
            to predict during the classification task.

        Returns
        -------
        classifier : InceptionTimeClassifier
            The ensemble classifier built from the loaded models, ready for
            prediction.
        """
        assert isinstance(
            model_path, list
        ), "model_path should be a list of paths to the models"

        classifier = cls()
        classifier.classifiers_ = []

        for path in model_path:
            clf = IndividualInceptionClassifier()
            clf.load_model(path, classes)
            classifier.classifiers_.append(clf)

        classifier.n_classifiers = len(classifier.classifiers_)

        classifier.classes_ = classes
        classifier.n_classes_ = len(classes)
        classifier.is_fitted = True

        return classifier

    @classmethod
    def _get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.
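
A minimal usage sketch of the new class method (not part of the diff; the "./models/" directory and toy data are illustrative): fit the ensemble once with save_best_model=True, collect the saved ".keras" files, then rebuild the ensemble without refitting.

import glob
import os

from aeon.classification.deep_learning import InceptionTimeClassifier
from aeon.testing.data_generation import make_example_3d_numpy

X, y = make_example_3d_numpy(
    n_cases=10, n_channels=1, n_timepoints=12, return_y=True
)

# Fit once, keeping the best model of each ensemble member on disk.
# Assumes the "./models/" directory already exists.
clf = InceptionTimeClassifier(n_epochs=1, save_best_model=True, file_path="./models/")
clf.fit(X, y)

# Rebuild the ensemble from the saved files without refitting.
model_paths = glob.glob(os.path.join("./models/", f"{clf.best_file_name}*.keras"))
loaded = InceptionTimeClassifier.load_model(
    model_path=model_paths, classes=clf.classes_
)
print(loaded.predict(X))
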
40 changes: 40 additions & 0 deletions aeon/classification/deep_learning/_lite_time.py
@@ -282,6 +282,46 @@ def _predict_proba(self, X) -> np.ndarray:

        return probs

    @classmethod
    def load_model(cls, model_path, classes):
        """Load pre-trained classifiers instead of fitting.

        When calling this function, all functionalities such as predict and
        predict_proba can be used with the loaded models.

        Parameters
        ----------
        model_path : list of str
            The list of paths to the saved models to load, each including the
            model name and ".keras" extension.
        classes : np.ndarray
            The set of unique classes the pre-trained loaded model is trained
            to predict during the classification task.

        Returns
        -------
        classifier : LITETimeClassifier
            The ensemble classifier built from the loaded models, ready for
            prediction.
        """
        assert isinstance(
            model_path, list
        ), "model_path should be a list of paths to the models"

        classifier = cls()
        classifier.classifiers_ = []

        for path in model_path:
            clf = IndividualLITEClassifier()
            clf.load_model(model_path=path, classes=classes)
            classifier.classifiers_.append(clf)

        classifier.n_classifiers = len(classifier.classifiers_)

        classifier.classes_ = classes
        classifier.n_classes_ = len(classes)
        classifier.is_fitted = True

        return classifier

    @classmethod
    def _get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.
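
The same pattern applies to LITETimeClassifier; a sketch (illustrative paths and data) showing that predict_proba also works on the loaded ensemble:

import glob
import os

from aeon.classification.deep_learning import LITETimeClassifier
from aeon.testing.data_generation import make_example_3d_numpy

X, y = make_example_3d_numpy(
    n_cases=10, n_channels=1, n_timepoints=12, return_y=True
)

# Assumes the "./models/" directory already exists.
clf = LITETimeClassifier(n_epochs=1, save_best_model=True, file_path="./models/")
clf.fit(X, y)

model_paths = glob.glob(os.path.join("./models/", f"{clf.best_file_name}*.keras"))
loaded = LITETimeClassifier.load_model(model_path=model_paths, classes=clf.classes_)
probs = loaded.predict_proba(X)  # shape (n_cases, n_classes)
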
47 changes: 47 additions & 0 deletions aeon/classification/deep_learning/tests/test_inception_time.py
@@ -0,0 +1,47 @@
"""Tests for save/load functionality of InceptionTimeClassifier."""

import glob
import os
import tempfile

import numpy as np
import pytest

from aeon.classification.deep_learning import InceptionTimeClassifier
from aeon.testing.data_generation import make_example_3d_numpy
from aeon.utils.validation._dependencies import _check_soft_dependencies


@pytest.mark.skipif(
not _check_soft_dependencies("tensorflow", severity="none"),
reason="skip test if required soft dependency not available",
)
def test_save_load_inceptiontime():
"""Test saving and loading for InceptionTimeClassifier."""
with tempfile.TemporaryDirectory() as temp:
temp_dir = os.path.join(temp, "")

X, y = make_example_3d_numpy(
n_cases=10, n_channels=1, n_timepoints=12, return_y=True
)

model = InceptionTimeClassifier(
n_epochs=1, random_state=42, save_best_model=True, file_path=temp_dir
)
model.fit(X, y)

y_pred_orig = model.predict(X)

model_file = glob.glob(os.path.join(temp_dir, f"{model.best_file_name}*.keras"))

loaded_model = InceptionTimeClassifier.load_model(
model_path=model_file, classes=model.classes_
)

assert isinstance(loaded_model, InceptionTimeClassifier)

preds = loaded_model.predict(X)
assert isinstance(preds, np.ndarray)

assert len(preds) == len(y)
np.testing.assert_array_equal(preds, y_pred_orig)
49 changes: 49 additions & 0 deletions aeon/classification/deep_learning/tests/test_lite_time.py
@@ -0,0 +1,49 @@
"""Tests for save/load functionality of LiteTimeClassifier."""

import glob
import os
import tempfile

import numpy as np
import pytest

from aeon.classification.deep_learning import LITETimeClassifier
from aeon.testing.data_generation import make_example_3d_numpy
from aeon.utils.validation._dependencies import _check_soft_dependencies


@pytest.mark.skipif(
not _check_soft_dependencies("tensorflow", severity="none"),
reason="skip test if required soft dependency not available",
)
def test_save_load_litetim():
"""Test saving and loading for LiteTimeClassifier."""
with tempfile.TemporaryDirectory() as temp:
temp_dir = os.path.join(temp, "")

X, y = make_example_3d_numpy(
n_cases=10, n_channels=1, n_timepoints=12, return_y=True
)

model = LITETimeClassifier(
n_epochs=1, random_state=42, save_best_model=True, file_path=temp_dir
)
model.fit(X, y)

y_pred_orig = model.predict(X)

model_files = glob.glob(
os.path.join(temp_dir, f"{model.best_file_name}*.keras")
)

loaded_model = LITETimeClassifier.load_model(
model_path=model_files, classes=model.classes_
)

assert isinstance(loaded_model, LITETimeClassifier)

preds = loaded_model.predict(X)
assert isinstance(preds, np.ndarray)

assert len(preds) == len(y)
np.testing.assert_array_equal(preds, y_pred_orig)