Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,7 @@ def _all_precisions_supported(enabled_precisions: Set[dtype]) -> bool:

def validate_compile_settings(self) -> None:
if ENABLED_FEATURES.tensorrt_rtx:
if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
raise RuntimeError("TensorRT-RTX does not support bfloat16!")
# The below checks are not relevant for TensorRT-RTX
return

if (
Expand Down
4 changes: 0 additions & 4 deletions tests/py/dynamo/conversion/test_binary_ops_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,6 @@ def forward(self, x, y):
if op[0].__name__ not in ["pow.Tensor_Tensor", "fmod.Tensor"]
]
)
@unittest.skipIf(
torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx,
"bf16 is not supported for tensorrt_rtx",
)
def test_elementwise_ops_bf16(self, _, orig_op):
class TestModule(nn.Module):
def __init__(self, orig_op):
Expand Down
4 changes: 0 additions & 4 deletions tests/py/dynamo/conversion/test_casts.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,6 @@ def forward(self, x):
precision=torch.float,
)

@unittest.skipIf(
torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx,
"bf16 is not supported for tensorrt_rtx",
)
def test_to_copy_bfloat16(self):
class ToCopyBFloat16(nn.Module):
def forward(self, x):
Expand Down
4 changes: 1 addition & 3 deletions tests/py/dynamo/llm/test_llm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
def test_llm_decoder_layer(precision):
from run_llm import compile_torchtrt
from torchtrt_ext import register_sdpa

if torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx and precision == "BF16":
pytest.skip("TensorRT-RTX does not support bfloat16, skipping test")

with torch.inference_mode():
args = argparse.Namespace()
args.debug = False
Expand Down
4 changes: 0 additions & 4 deletions tests/py/dynamo/models/test_dtype_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,6 @@ def forward(self, x):
),
"Platform does not have BF16 support",
)
@unittest.skipIf(
torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx,
"bf16 is not supported for tensorrt_rtx",
)
class TestBF16Support(TestCase):
@unittest.skipIf(
not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime,
Expand Down
3 changes: 0 additions & 3 deletions tests/py/dynamo/models/test_dyn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,6 @@ def test_resnet_dynamic(ir, dtype):
"""
Tests the Resnet18 model (which is fully convertible) with dynamic shapes
"""
if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
pytest.skip("TensorRT-RTX does not support bfloat16")

import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda").to(dtype)
Expand Down
21 changes: 0 additions & 21 deletions tests/py/dynamo/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,6 @@ def test_resnet18_torch_exec_ops(ir):
"torchvision is not installed",
)
def test_mobilenet_v2(ir, dtype):
if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
pytest.skip("TensorRT-RTX does not support bfloat16")

model = models.mobilenet_v2(pretrained=True).eval().to("cuda").to(dtype)
input = torch.randn((1, 3, 224, 224)).to("cuda").to(dtype)

Expand Down Expand Up @@ -237,9 +234,6 @@ def test_mobilenet_v2(ir, dtype):
"timm or torchvision not installed",
)
def test_efficientnet_b0(ir, dtype):
if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
pytest.skip("TensorRT-RTX does not support bfloat16")

model = (
timm.create_model("efficientnet_b0", pretrained=True)
.eval()
Expand Down Expand Up @@ -284,9 +278,6 @@ def test_efficientnet_b0(ir, dtype):
"transformers is required to run this test",
)
def test_bert_base_uncased(ir, dtype):
if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
pytest.skip("TensorRT-RTX does not support bfloat16")

from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased").cuda().eval().to(dtype)
Expand Down Expand Up @@ -425,10 +416,6 @@ def test_resnet18_half(ir):


@pytest.mark.unit
@unittest.skipIf(
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
"tensorrt_rtx does not support bfloat16",
)
def test_cosmos_true_div(ir):
class CosmosLearnablePositionalEmbed(torch.nn.Module):
def __init__(
Expand Down Expand Up @@ -527,10 +514,6 @@ def forward(


@pytest.mark.unit
@unittest.skipIf(
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
"bf16 is not supported for tensorrt_rtx",
)
@pytest.mark.critical
def test_bf16_model(ir):
class MyModule(torch.nn.Module):
Expand Down Expand Up @@ -576,10 +559,6 @@ def forward(self, x):


@pytest.mark.unit
@unittest.skipIf(
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
"bf16 is not supported for tensorrt_rtx",
)
@pytest.mark.critical
def test_bf16_fallback_model(ir):
class MyModule(torch.nn.Module):
Expand Down
6 changes: 0 additions & 6 deletions tests/py/dynamo/models/test_models_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,9 +408,6 @@ def test_base_int8(ir, dtype):
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.utils import export_torch_mode

if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
pytest.skip("TensorRT-RTX does not support bfloat16")

class SimpleNetwork(torch.nn.Module):
def __init__(self):
super(SimpleNetwork, self).__init__()
Expand Down Expand Up @@ -469,9 +466,6 @@ def test_base_int8_dynamic_shape(ir, dtype):
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.utils import export_torch_mode

if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
pytest.skip("TensorRT-RTX does not support bfloat16")

class SimpleNetwork(torch.nn.Module):
def __init__(self):
super(SimpleNetwork, self).__init__()
Expand Down
Loading