diff --git a/stepcovnet/tf_config.py b/stepcovnet/tf_config.py index cc56e99..e5c7f12 100644 --- a/stepcovnet/tf_config.py +++ b/stepcovnet/tf_config.py @@ -1,9 +1,16 @@ from __future__ import absolute_import, division, print_function, unicode_literals +import logging + import tensorflow as tf from keras import mixed_precision -# tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True) +try: + tf.config.experimental.set_memory_growth( + tf.config.list_physical_devices("GPU")[0], enable=True + ) +except Exception: + logging.warning("Failed to set memory growth for GPU.", exc_info=True) MIXED_PRECISION_POLICY = mixed_precision.Policy("mixed_float16") mixed_precision.set_global_policy(MIXED_PRECISION_POLICY)