Skip to content

Commit 8384ad4

Browse files
authored
Add --uv flag (#27)
* use uv * add --uv flag * undo readme changes for PR * add test * fix readme again
1 parent 4fcc2fd commit 8384ad4

File tree

3 files changed

+33
-8
lines changed

3 files changed

+33
-8
lines changed

tests/test_installer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,17 @@ def test_get_pip_commands_valid():
9292
assert result == expected
9393

9494

95+
def test_get_pip_commands_with_uv():
96+
cmds = [["package1"], ["package2", "--upgrade"]]
97+
expected = [
98+
["uv", "pip", "install", "package1"],
99+
["uv", "pip", "install", "package2", "--upgrade"],
100+
]
101+
102+
result = get_pip_commands(cmds, use_uv=True)
103+
assert result == expected
104+
105+
95106
def test_get_pip_commands_none_input():
96107
cmds = [["package1"], None]
97108
with pytest.raises(AssertionError):

torchruntime/__main__.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@ def print_usage(entry_command: str):
1515
1616
Examples:
1717
{entry_command} install
18+
{entry_command} install --uv
1819
{entry_command} install torch==2.2.0 torchvision==0.17.0
19-
{entry_command} install torch>=2.0.0 torchaudio
20+
{entry_command} install --uv torch>=2.0.0 torchaudio
2021
{entry_command} install torch==2.1.* torchvision>=0.16.0 torchaudio==2.1.0
2122
2223
{entry_command} test # Runs all tests (import, devices, math, functions)
@@ -31,6 +32,9 @@ def print_usage(entry_command: str):
3132
If no packages are specified, the latest available versions
3233
of torch, torchaudio and torchvision will be installed.
3334
35+
Options:
36+
--uv Use uv instead of pip for installation
37+
3438
Version specification formats (follows pip format):
3539
package==2.1.0 Exact version
3640
package>=2.0.0 Minimum version
@@ -56,16 +60,21 @@ def main():
5660
command = sys.argv[1]
5761

5862
if command == "install":
59-
package_versions = sys.argv[2:] if len(sys.argv) > 2 else None
60-
install(package_versions)
63+
args = sys.argv[2:] if len(sys.argv) > 2 else []
64+
use_uv = "--uv" in args
65+
# Remove --uv from args to get package list
66+
package_versions = [arg for arg in args if arg != "--uv"] if args else None
67+
install(package_versions, use_uv=use_uv)
6168
elif command == "test":
6269
subcommand = sys.argv[2] if len(sys.argv) > 2 else "all"
6370
test(subcommand)
6471
elif command == "info":
6572
info()
6673
else:
6774
print(f"Unknown command: {command}")
68-
print_usage()
75+
entry_path = sys.argv[0]
76+
cli = "python -m torchruntime" if "__main__.py" in entry_path else "torchruntime"
77+
print_usage(cli)
6978

7079

7180
if __name__ == "__main__":

torchruntime/installer.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,13 @@ def get_install_commands(torch_platform, packages):
7676
raise ValueError(f"Unsupported platform: {torch_platform}")
7777

7878

79-
def get_pip_commands(cmds):
79+
def get_pip_commands(cmds, use_uv=False):
8080
assert not any(cmd is None for cmd in cmds)
81-
return [PIP_PREFIX + cmd for cmd in cmds]
81+
if use_uv:
82+
pip_prefix = ["uv", "pip", "install"]
83+
else:
84+
pip_prefix = [sys.executable, "-m", "pip", "install"]
85+
return [pip_prefix + cmd for cmd in cmds]
8286

8387

8488
def run_commands(cmds):
@@ -87,13 +91,14 @@ def run_commands(cmds):
8791
subprocess.run(cmd)
8892

8993

90-
def install(packages=[]):
94+
def install(packages=[], use_uv=False):
9195
"""
9296
packages: a list of strings with package names (and optionally their versions in pip-format). e.g. ["torch", "torchvision"] or ["torch>=2.0", "torchaudio==0.16.0"]. Defaults to ["torch", "torchvision", "torchaudio"].
97+
use_uv: bool, whether to use uv for installation. Defaults to False.
9398
"""
9499

95100
gpu_infos = get_gpus()
96101
torch_platform = get_torch_platform(gpu_infos)
97102
cmds = get_install_commands(torch_platform, packages)
98-
cmds = get_pip_commands(cmds)
103+
cmds = get_pip_commands(cmds, use_uv=use_uv)
99104
run_commands(cmds)

0 commit comments

Comments
 (0)