diff --git a/tensorflow_hub/keras_layer.py b/tensorflow_hub/keras_layer.py index 9ed6bb1ce..78bdc8d3c 100644 --- a/tensorflow_hub/keras_layer.py +++ b/tensorflow_hub/keras_layer.py @@ -116,12 +116,12 @@ def __init__(self, handle, trainable=False, arguments=None, **kwargs): if hasattr(self._func, "trainable_variables"): for v in self._func.trainable_variables: self._add_existing_weight(v, trainable=True) - trainable_variables = set(self._func.trainable_variables) + trainable_variables = {id(v) for v in self._func.trainable_variables} else: - trainable_variables = set() + trainable_variables = {} if hasattr(self._func, "variables"): for v in self._func.variables: - if v not in trainable_variables: + if id(v) not in trainable_variables: self._add_existing_weight(v, trainable=False) # Forward the callable's regularization losses (if any).