diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
index ff2481c..66f1ad2 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
@@ -122,7 +122,7 @@ def augmentation(
     ):
         # assert that plugin requires mixed precision to be set
         assert (
-            train_args.bf16 or train_args.bf16
+            train_args.bf16 is True or train_args.fp16 is True
         ), f"{self.__class__} requires mixed precision argument `--fp16` or `--bf16`"
 
         # This is designed to be a passthrough if training scenario is