Skip to content

Commit

Permalink
BUG allow to import keras from tensorflow (scikit-learn-contrib#532)
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre authored Jun 12, 2019
1 parent 85422e8 commit 2fa0596
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 9 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new/v0.5.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,7 @@ Bug
:class:`imblearn.ensemble.RUSBoostClassifier` to get a decision stump as a
weak learner as in the original paper.
:pr:`545` by :user:`Christos Aridas <chkoar>`.

- Allow to import ``keras`` directly from ``tensorflow`` in the
:mod:`imblearn.keras`.
:pr:`531` by :user:`Guillaume Lemaitre <glemaitre>`.
49 changes: 40 additions & 9 deletions imblearn/keras/_generator.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,43 @@
"""Implement generators for ``keras`` which will balance the data."""
from __future__ import division


# This is a trick to avoid an error during tests collection with pytest. We
# avoid the error when importing the package raise the error at the moment of
# creating the instance.
try:
import keras
ParentClass = keras.utils.Sequence
HAS_KERAS = True
except ImportError:
ParentClass = object
HAS_KERAS = False
def import_keras():
"""Try to import keras from keras and tensorflow.
This is possible to import the sequence from keras or tensorflow.
Keras is not ducktyping ``Sequence`` before 2.3 and we need import from
all possible library to ensure that the ``isinstance(...)`` is not going
to fail. This function can be modified when we support Keras 2.3.
"""

def import_from_keras():
try:
import keras
return (keras.utils.Sequence,), True
except ImportError:
return tuple(), False

def import_from_tensforflow():
try:
from tensorflow import keras
return (keras.utils.Sequence,), True
except ImportError:
return tuple(), False

ParentClassKeras, has_keras_k = import_from_keras()
ParentClassTensorflow, has_keras_tf = import_from_tensforflow()
has_keras = has_keras_k or has_keras_tf
if has_keras:
ParentClass = (ParentClassKeras + ParentClassTensorflow)
else:
ParentClass = (object,)
return ParentClass, has_keras


ParentClass, HAS_KERAS = import_keras()

from scipy.sparse import issparse

Expand All @@ -29,7 +56,7 @@
'NeighbourhoodCleaningRule', 'TomekLinks')


class BalancedBatchGenerator(ParentClass):
class BalancedBatchGenerator(*ParentClass):
"""Create balanced batches when training a keras model.
Create a keras ``Sequence`` which is given to ``fit_generator``. The
Expand Down Expand Up @@ -102,6 +129,10 @@ class BalancedBatchGenerator(ParentClass):
... epochs=10, verbose=0)
"""

# flag for keras sequence duck-typing
use_sequence_api = True

def __init__(self, X, y, sample_weight=None, sampler=None, batch_size=32,
keep_sparse=False, random_state=None):
if not HAS_KERAS:
Expand Down

0 comments on commit 2fa0596

Please sign in to comment.