1818from unittest import mock
1919
2020from google .protobuf import text_format
21+ import numpy as np
2122import tensorflow as tf
2223
2324from tensorboard .plugins .hparams import _keras
2425from tensorboard .plugins .hparams import metadata
2526from tensorboard .plugins .hparams import plugin_data_pb2
2627from tensorboard .plugins .hparams import summary_v2 as hp
2728
28- # Stay on Keras 2 for now: https://github.com/keras-team/keras/issues/18467.
29- version_fn = getattr (tf .keras , "version" , None )
30- if version_fn and version_fn ().startswith ("3." ):
31- import tf_keras as keras # Keras 2
32- else :
33- keras = tf .keras # Keras 2
34-
35- tf .compat .v1 .enable_eager_execution ()
36-
37-
3829class CallbackTest (tf .test .TestCase ):
3930 def setUp (self ):
4031 super ().setUp ()
@@ -46,12 +37,12 @@ def _initialize_model(self, writer):
4637 "optimizer" : "adam" ,
4738 HP_DENSE_NEURONS : 8 ,
4839 }
49- self .model = keras .models .Sequential (
40+ self .model = tf . keras .models .Sequential (
5041 [
51- keras .layers .Dense (
42+ tf . keras .layers .Dense (
5243 self .hparams [HP_DENSE_NEURONS ], input_shape = (1 ,)
5344 ),
54- keras .layers .Dense (1 , activation = "sigmoid" ),
45+ tf . keras .layers .Dense (1 , activation = "sigmoid" ),
5546 ]
5647 )
5748 self .model .compile (loss = "mse" , optimizer = self .hparams ["optimizer" ])
@@ -69,7 +60,7 @@ def mock_time():
6960 initial_time = mock_time .time
7061 with mock .patch ("time.time" , mock_time ):
7162 self ._initialize_model (writer = self .logdir )
72- self .model .fit (x = [(1 ,)], y = [(2 ,)], callbacks = [self .callback ])
63+ self .model .fit (x = tf . constant ( [(1 ,)]) , y = tf . constant ( [(2 ,)]) , callbacks = [self .callback ])
7364 final_time = mock_time .time
7465
7566 files = os .listdir (self .logdir )
@@ -142,7 +133,7 @@ def test_explicit_writer(self):
142133 filename_suffix = ".magic" ,
143134 )
144135 self ._initialize_model (writer = writer )
145- self .model .fit (x = [(1 ,)], y = [(2 ,)], callbacks = [self .callback ])
136+ self .model .fit (x = tf . constant ( [(1 ,)]) , y = tf . constant ( [(2 ,)]) , callbacks = [self .callback ])
146137
147138 files = os .listdir (self .logdir )
148139 self .assertEqual (len (files ), 1 , files )
@@ -158,15 +149,15 @@ def test_non_eager_failure(self):
158149 with self .assertRaisesRegex (
159150 RuntimeError , "only supported in TensorFlow eager mode"
160151 ):
161- self .model .fit (x = [( 1 ,)] , y = [( 2 ,)] , callbacks = [self .callback ])
152+ self .model .fit (x = np . ones (( 10 , 10 )) , y = np . ones (( 10 , 10 )) , callbacks = [self .callback ])
162153
163154 def test_reuse_failure (self ):
164155 self ._initialize_model (writer = self .logdir )
165- self .model .fit (x = [(1 ,)], y = [(2 ,)], callbacks = [self .callback ])
156+ self .model .fit (x = tf . constant ( [(1 ,)]) , y = tf . constant ( [(2 ,)]) , callbacks = [self .callback ])
166157 with self .assertRaisesRegex (
167158 RuntimeError , "cannot be reused across training sessions"
168159 ):
169- self .model .fit (x = [(1 ,)], y = [(2 ,)], callbacks = [self .callback ])
160+ self .model .fit (x = tf . constant ( [(1 ,)]) , y = tf . constant ( [(2 ,)]) , callbacks = [self .callback ])
170161
171162 def test_invalid_writer (self ):
172163 with self .assertRaisesRegex (
0 commit comments