From 2fa0596e1a139042f50b3d59971c00c06d566e02 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 12 Jun 2019 23:49:10 +0200 Subject: [PATCH] BUG allow to import keras from tensorflow (#532) --- doc/whats_new/v0.5.rst | 4 +++ imblearn/keras/_generator.py | 49 +++++++++++++++++++++++++++++------- 2 files changed, 44 insertions(+), 9 deletions(-) diff --git a/doc/whats_new/v0.5.rst b/doc/whats_new/v0.5.rst index a1a20baa6..0c01d561c 100644 --- a/doc/whats_new/v0.5.rst +++ b/doc/whats_new/v0.5.rst @@ -77,3 +77,7 @@ Bug :class:`imblearn.ensemble.RUSBoostClassifier` to get a decision stump as a weak learner as in the original paper. :pr:`545` by :user:`Christos Aridas `. + +- Allow to import ``keras`` directly from ``tensorflow`` in the + :mod:`imblearn.keras`. + :pr:`531` by :user:`Guillaume Lemaitre `. diff --git a/imblearn/keras/_generator.py b/imblearn/keras/_generator.py index 4ac793b5e..75f371e5e 100644 --- a/imblearn/keras/_generator.py +++ b/imblearn/keras/_generator.py @@ -1,16 +1,43 @@ """Implement generators for ``keras`` which will balance the data.""" -from __future__ import division + # This is a trick to avoid an error during tests collection with pytest. We # avoid the error when importing the package raise the error at the moment of # creating the instance. -try: - import keras - ParentClass = keras.utils.Sequence - HAS_KERAS = True -except ImportError: - ParentClass = object - HAS_KERAS = False +def import_keras(): + """Try to import keras from keras and tensorflow. + + This is possible to import the sequence from keras or tensorflow. + Keras is not ducktyping ``Sequence`` before 2.3 and we need import from + all possible library to ensure that the ``isinstance(...)`` is not going + to fail. This function can be modified when we support Keras 2.3. + """ + + def import_from_keras(): + try: + import keras + return (keras.utils.Sequence,), True + except ImportError: + return tuple(), False + + def import_from_tensforflow(): + try: + from tensorflow import keras + return (keras.utils.Sequence,), True + except ImportError: + return tuple(), False + + ParentClassKeras, has_keras_k = import_from_keras() + ParentClassTensorflow, has_keras_tf = import_from_tensforflow() + has_keras = has_keras_k or has_keras_tf + if has_keras: + ParentClass = (ParentClassKeras + ParentClassTensorflow) + else: + ParentClass = (object,) + return ParentClass, has_keras + + +ParentClass, HAS_KERAS = import_keras() from scipy.sparse import issparse @@ -29,7 +56,7 @@ 'NeighbourhoodCleaningRule', 'TomekLinks') -class BalancedBatchGenerator(ParentClass): +class BalancedBatchGenerator(*ParentClass): """Create balanced batches when training a keras model. Create a keras ``Sequence`` which is given to ``fit_generator``. The @@ -102,6 +129,10 @@ class BalancedBatchGenerator(ParentClass): ... epochs=10, verbose=0) """ + + # flag for keras sequence duck-typing + use_sequence_api = True + def __init__(self, X, y, sample_weight=None, sampler=None, batch_size=32, keep_sparse=False, random_state=None): if not HAS_KERAS: