diff --git a/cleverhans/tf2/attacks/fast_gradient_method.py b/cleverhans/tf2/attacks/fast_gradient_method.py index ae3210f20..4f5279fb4 100644 --- a/cleverhans/tf2/attacks/fast_gradient_method.py +++ b/cleverhans/tf2/attacks/fast_gradient_method.py @@ -55,6 +55,9 @@ def fast_gradient_method( if clip_max is not None: asserts.append(tf.math.less_equal(x, clip_max)) + # cast to tensor if provided as numpy array + x = tf.cast(x, tf.float32) + if y is None: # Using model predictions as ground truth to avoid label leaking y = tf.argmax(model_fn(x), 1)