Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions py/torch_tensorrt/_TensorRTProxyModule.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ctypes
import importlib
import importlib.util
import importlib.metadata
import logging
import os
import platform
Expand Down Expand Up @@ -142,14 +143,29 @@ def alias_tensorrt() -> None:
if package_imported:
return

# in order not to break or change the existing behavior, we only build and run with tensorrt by default, tensorrt-rtx is for experiment only
# if we want to test with tensorrt-rtx, we have to build the wheel with --use-rtx and test with USE_TRT_RTX=true
# eg: USE_TRT_RTX=true python test.py
# in future, we can do dynamic linking either to tensorrt or tensorrt-rtx based on the gpu type
# Determine if this installation is the RTX variant based on the installed wheel name.
# This checks which distribution provides the `torch_tensorrt` package:
# - 'torch-tensorrt-rtx' => use tensorrt_rtx
# - 'torch-tensorrt' => use tensorrt
use_rtx = False
if (use_rtx_env_var := os.environ.get("USE_TRT_RTX")) is not None:
if use_rtx_env_var.lower() == "true":
try:
pkg_map = importlib.metadata.packages_distributions()
dist_names = pkg_map.get("torch_tensorrt", []) or []
normalized = {name.replace("_", "-").lower() for name in dist_names}
if "torch-tensorrt-rtx" in normalized:
use_rtx = True
except Exception:
# Best-effort fallback: prefer standard tensorrt unless only tensorrt_rtx is available
try:
importlib.import_module("tensorrt")
use_rtx = False
except Exception:
try:
importlib.import_module("tensorrt_rtx")
use_rtx = True
except Exception:
use_rtx = False

package_name = "tensorrt_rtx" if use_rtx else "tensorrt"

if not use_rtx:
Expand Down
36 changes: 35 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import subprocess
import sys
import warnings
import atexit
from dataclasses import dataclass
from datetime import datetime
from distutils.cmd import Command
Expand Down Expand Up @@ -138,6 +139,33 @@ def load_dep_info():
if use_rtx_env_var == "1" or use_rtx_env_var.lower() == "true":
USE_TRT_RTX = True

# Distribution name: keep import package as `torch_tensorrt`, but vary project
# name so wheels for RTX vs standard TensorRT are distinct.
PROJECT_NAME = "torch_tensorrt_rtx" if USE_TRT_RTX else "torch_tensorrt"
#
# Ensure METADATA Name matches RTX vs standard build even with static pyproject.
# We temporarily rewrite [project].name in pyproject.toml for the current build.
#
_pyproject_path = get_root_dir() / "pyproject.toml"
_pyproject_backup = None
if USE_TRT_RTX:
try:
_pyproject_backup = _pyproject_path.read_text(encoding="utf-8")
updated = []
for line in _pyproject_backup.splitlines(keepends=True):
if line.strip().startswith("name = ") and '"torch_tensorrt"' in line:
updated.append('name = "torch-tensorrt-rtx"\n')
else:
updated.append(line)
_new_content = "".join(updated)
if _new_content != _pyproject_backup:
_pyproject_path.write_text(_new_content, encoding="utf-8")
atexit.register(
lambda: _pyproject_path.write_text(_pyproject_backup, encoding="utf-8")
)
except Exception:
# Non-fatal; filename will still distinguish variants.
pass
if (release_env_var := os.environ.get("RELEASE")) is not None:
if release_env_var == "1":
RELEASE = True
Expand Down Expand Up @@ -340,6 +368,12 @@ def finalize_options(self):
self.root_is_pure = False

def run(self):
# Ensure wheel metadata/project name reflects RTX vs standard build,
# even when pyproject.toml provides static [project].name.
try:
self.distribution.metadata.name = PROJECT_NAME
except Exception:
pass
if not PY_ONLY:
build_libtorchtrt_cxx11_abi(develop=False, rt_only=NO_TS)
copy_libtorchtrt(rt_only=NO_TS)
Expand Down Expand Up @@ -806,7 +840,7 @@ def get_requirements():


setup(
name="torch_tensorrt",
name=PROJECT_NAME,
ext_modules=ext_modules,
version=__version__,
cmdclass={
Expand Down
Loading