
Commit 4922e66

Merge b3b48dd into 2fa042b

File tree

12 files changed: +237 -119 lines changed

examples/pt2/README.md

Lines changed: 10 additions & 10 deletions
@@ -1,30 +1,30 @@
 ## PyTorch 2.x integration
 
-PyTorch 2.0 brings more compiler options to PyTorch; for you that should mean better performance, either in the form of lower latency or lower memory consumption. Integrating PyTorch 2.0 is fairly trivial, but for now the support will be experimental until the official release, while we rely on the nightly builds.
+PyTorch 2.0 brings more compiler options to PyTorch; for you that should mean better performance, either in the form of lower latency or lower memory consumption. Integrating PyTorch 2.0 is fairly trivial, but for now the support will be experimental, given that most public benchmarks have focused on training instead of inference.
 
 We strongly recommend you leverage newer hardware, so for GPUs that would be an Ampere architecture. You'll get even more benefit from server GPU deployments like A10G and A100 vs consumer cards, but you should expect to see some speedup on any Volta or Ampere architecture.
 
 ## Get started
 
-Install torchserve with nightly torch binaries:
+Install torchserve and ensure that you're using at least `torch>=2.0.0`:
 
-```
-python ts_scripts/install_dependencies.py --cuda=cu117 --nightly_torch
+```sh
+python ts_scripts/install_dependencies.py --cuda=cu117
 pip install torchserve torch-model-archiver
 ```
 
 ## Package your model
 
-PyTorch 2.0 supports several compiler backends; you pick the one you want by passing in an optional file `compile.json` during your model packaging:
+PyTorch 2.0 supports several compiler backends; you pick the one you want by passing in an optional file `model_config.yaml` during your model packaging:
 
-`{"pt2" : "inductor"}`
+`pt2: "inductor"`
 
-As an example, let's expand our getting started guide, with the only difference being passing in the extra `compile.json` file:
+As an example, let's expand our getting started guide, with the only difference being passing in the extra `model_config.yaml` file:
 
 ```
 mkdir model_store
-torch-model-archiver --model-name densenet161 --version 1.0 --model-file ./serve/examples/image_classifier/densenet_161/model.py --export-path model_store --extra-files ./serve/examples/image_classifier/index_to_name.json,./serve/examples/image_classifier/compile.json --handler image_classifier
-torchserve --start --ncs --model-store model_store --models densenet161.mar
+torch-model-archiver --model-name densenet161 --version 1.0 --model-file ./serve/examples/image_classifier/densenet_161/model.py --export-path model_store --extra-files ./serve/examples/image_classifier/index_to_name.json --handler image_classifier
+torchserve --start --ncs --model-store model_store --models densenet161.mar --config-file model_config.yaml
 ```
 
 The exact same approach works with any other model; what's going on under the hood is the below:
@@ -35,7 +35,7 @@ opt_mod = torch.compile(mod)
 # 2. Train the optimized module
 # ....
 # 3. Save the original module (weights are shared)
-torch.save(model, "model.pt")
+torch.save(model, "model.pt")
 
 # 4. Load the non-optimized model
 mod = torch.load(model)
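To make the `pt2` flag concrete, here is a minimal sketch of the compile-then-serve idea with the `inductor` backend; the `MyModel` module and tensor values are illustrative, not part of this commit.

```python
import torch

# Illustrative module; any nn.Module can be handed to torch.compile the same way
class MyModel(torch.nn.Module):
    def forward(self, x):
        return 0.5 * x + 2

mod = MyModel()

# Roughly what the serving side does when model_config.yaml contains `pt2: "inductor"`:
# compile the loaded module with the requested backend
opt_mod = torch.compile(mod, backend="inductor")

# Inference runs through the optimized module; weights stay shared with `mod`
print(opt_mod(torch.tensor([[3.0]])))  # tensor([[3.5]])
```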
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+import torch
+
+from ts.torch_handler.base_handler import BaseHandler
+
+
+class CompileHandler(BaseHandler):
+    def __init__(self):
+        super().__init__()
+
+    def initialize(self, context):
+        super().initialize(context)
+
+    def preprocess(self, data):
+        instances = data[0]["body"]["instances"]
+        input_tensor = torch.as_tensor(instances, dtype=torch.float32)
+        return input_tensor
+
+    def postprocess(self, data):
+        # Convert the output tensor to a list and return only the third prediction
+        return data.tolist()[2]
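For context, a small sketch of what this handler does to a request, assuming the `{"instances": ...}` payload used by the tests below; the `HalfPlusTwo` module is a hypothetical stand-in for the example model.

```python
import torch

# Hypothetical stand-in for the half_plus_two example model: y = 0.5 * x + 2
class HalfPlusTwo(torch.nn.Module):
    def forward(self, x):
        return 0.5 * x + 2

# Shape of the request as it reaches preprocess
data = [{"body": {"instances": [[1.0], [2.0], [3.0]]}}]

# preprocess: pull the instances out of the request body
input_tensor = torch.as_tensor(data[0]["body"]["instances"], dtype=torch.float32)

# inference through the model
output = HalfPlusTwo()(input_tensor)

# postprocess: tolist() gives [[2.5], [3.0], [3.5]]; [2] keeps the third prediction
print(output.tolist()[2])  # [3.5]
```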
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+pt2 : "inductor"
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+pt2 : "torchxla_trace_once"

test/pytest/test_data/torch_xla/compile.json

Lines changed: 0 additions & 1 deletion
This file was deleted.

test/pytest/test_torch_compile.py

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+import glob
+import json
+import os
+import subprocess
+import time
+from pathlib import Path
+
+import pytest
+import torch
+from pkg_resources import packaging
+
+PT_2_AVAILABLE = (
+    True
+    if packaging.version.parse(torch.__version__) >= packaging.version.parse("2.0")
+    else False
+)
+
+CURR_FILE_PATH = Path(__file__).parent
+TEST_DATA_DIR = os.path.join(CURR_FILE_PATH, "test_data", "torch_compile")
+
+MODEL_FILE = os.path.join(TEST_DATA_DIR, "model.py")
+HANDLER_FILE = os.path.join(TEST_DATA_DIR, "compile_handler.py")
+YAML_CONFIG = os.path.join(TEST_DATA_DIR, "pt2.yaml")
+
+
+SERIALIZED_FILE = os.path.join(TEST_DATA_DIR, "model.pt")
+MODEL_STORE_DIR = os.path.join(TEST_DATA_DIR, "model_store")
+MODEL_NAME = "half_plus_two"
+
+
+@pytest.mark.skipif(PT_2_AVAILABLE == False, reason="torch version is < 2.0.0")
+class TestTorchCompile:
+    def teardown_class(self):
+        subprocess.run("torchserve --stop", shell=True, check=True)
+        time.sleep(10)
+
+    def test_archive_model_artifacts(self):
+        assert len(glob.glob(MODEL_FILE)) == 1
+        assert len(glob.glob(YAML_CONFIG)) == 1
+        subprocess.run(f"cd {TEST_DATA_DIR} && python model.py", shell=True, check=True)
+        subprocess.run(f"mkdir -p {MODEL_STORE_DIR}", shell=True, check=True)
+        subprocess.run(
+            f"torch-model-archiver --model-name {MODEL_NAME} --version 1.0 --model-file {MODEL_FILE} --serialized-file {SERIALIZED_FILE} --config-file {YAML_CONFIG} --export-path {MODEL_STORE_DIR} --handler {HANDLER_FILE} -f",
+            shell=True,
+            check=True,
+        )
+        assert len(glob.glob(SERIALIZED_FILE)) == 1
+        assert len(glob.glob(os.path.join(MODEL_STORE_DIR, f"{MODEL_NAME}.mar"))) == 1
+
+    def test_start_torchserve(self):
+        cmd = f"torchserve --start --ncs --models {MODEL_NAME}.mar --model-store {MODEL_STORE_DIR}"
+        subprocess.run(
+            cmd,
+            shell=True,
+            check=True,
+        )
+        time.sleep(10)
+        assert len(glob.glob("logs/access_log.log")) == 1
+        assert len(glob.glob("logs/model_log.log")) == 1
+        assert len(glob.glob("logs/ts_log.log")) == 1
+
+    def test_server_status(self):
+        result = subprocess.run(
+            "curl http://localhost:8080/ping",
+            shell=True,
+            capture_output=True,
+            check=True,
+        )
+        expected_server_status_str = '{"status": "Healthy"}'
+        expected_server_status = json.loads(expected_server_status_str)
+        assert json.loads(result.stdout) == expected_server_status
+
+    def test_registered_model(self):
+        result = subprocess.run(
+            "curl http://localhost:8081/models",
+            shell=True,
+            capture_output=True,
+            check=True,
+        )
+        expected_registered_model_str = '{"models": [{"modelName": "half_plus_two", "modelUrl": "half_plus_two.mar"}]}'
+        expected_registered_model = json.loads(expected_registered_model_str)
+        assert json.loads(result.stdout) == expected_registered_model
+
+    def test_serve_inference(self):
+        request_data = {"instances": [[1.0], [2.0], [3.0]]}
+        request_json = json.dumps(request_data)
+
+        result = subprocess.run(
+            f"curl -s -X POST -H \"Content-Type: application/json;\" http://localhost:8080/predictions/half_plus_two -d '{request_json}'",
+            shell=True,
+            capture_output=True,
+            check=True,
+        )
+
+        string_result = result.stdout.decode("utf-8")
+        float_result = float(string_result)
+        expected_result = 3.5
+
+        assert float_result == expected_result
+
+        model_log_path = glob.glob("logs/model_log.log")[0]
+        with open(model_log_path, "rt") as model_log_file:
+            model_log = model_log_file.read()
+        assert "Compiled model with backend inductor" in model_log

test/pytest/test_torch_xla.py

Lines changed: 4 additions & 4 deletions
@@ -21,10 +21,10 @@
 TORCHXLA_AVAILABLE = False
 
 CURR_FILE_PATH = Path(__file__).parent
-TORCH_XLA_TEST_DATA_DIR = os.path.join(CURR_FILE_PATH, "test_data")
+TORCH_XLA_TEST_DATA_DIR = os.path.join(CURR_FILE_PATH, "test_data", "torch_compile")
 
 MODEL_FILE = os.path.join(TORCH_XLA_TEST_DATA_DIR, "model.py")
-EXTRA_FILE = os.path.join(TORCH_XLA_TEST_DATA_DIR, "compile.json")
+YAML_CONFIG = os.path.join(TORCH_XLA_TEST_DATA_DIR, "xla.yaml")
 CONFIG_PROPERTIES = os.path.join(TORCH_XLA_TEST_DATA_DIR, "config.properties")
 
 SERIALIZED_FILE = os.path.join(TORCH_XLA_TEST_DATA_DIR, "model.pt")
@@ -40,14 +40,14 @@ def teardown_class(self):
 
     def test_archive_model_artifacts(self):
         assert len(glob.glob(MODEL_FILE)) == 1
-        assert len(glob.glob(EXTRA_FILE)) == 1
+        assert len(glob.glob(YAML_CONFIG)) == 1
         assert len(glob.glob(CONFIG_PROPERTIES)) == 1
         subprocess.run(
             f"cd {TORCH_XLA_TEST_DATA_DIR} && python model.py", shell=True, check=True
         )
         subprocess.run(f"mkdir -p {MODEL_STORE_DIR}", shell=True, check=True)
         subprocess.run(
-            f"torch-model-archiver --model-name {MODEL_NAME} --version 1.0 --model-file {MODEL_FILE} --serialized-file {SERIALIZED_FILE} --extra-files {EXTRA_FILE} --export-path {MODEL_STORE_DIR} --handler base_handler -f",
+            f"torch-model-archiver --model-name {MODEL_NAME} --version 1.0 --model-file {MODEL_FILE} --serialized-file {SERIALIZED_FILE} --config-file {YAML_CONFIG} --export-path {MODEL_STORE_DIR} --handler base_handler -f",
             shell=True,
             check=True,
         )

test/pytest/test_utils.py

Lines changed: 19 additions & 9 deletions
@@ -118,31 +118,41 @@ def model_archiver_command_builder(
     handler=None,
     extra_files=None,
     force=False,
+    config_file=None,
 ):
-    cmd = "torch-model-archiver"
+    # Initialize a list to store the command-line arguments
+    cmd_parts = ["torch-model-archiver"]
 
+    # Append arguments to the list
     if model_name:
-        cmd += " --model-name {0}".format(model_name)
+        cmd_parts.append(f"--model-name {model_name}")
 
     if version:
-        cmd += " --version {0}".format(version)
+        cmd_parts.append(f"--version {version}")
 
     if model_file:
-        cmd += " --model-file {0}".format(model_file)
+        cmd_parts.append(f"--model-file {model_file}")
 
     if serialized_file:
-        cmd += " --serialized-file {0}".format(serialized_file)
+        cmd_parts.append(f"--serialized-file {serialized_file}")
 
     if handler:
-        cmd += " --handler {0}".format(handler)
+        cmd_parts.append(f"--handler {handler}")
 
     if extra_files:
-        cmd += " --extra-files {0}".format(extra_files)
+        cmd_parts.append(f"--extra-files {extra_files}")
+
+    if config_file:
+        cmd_parts.append(f"--config-file {config_file}")
 
     if force:
-        cmd += " --force"
+        cmd_parts.append("--force")
+
+    # Append the export-path argument to the list
+    cmd_parts.append(f"--export-path {MODEL_STORE}")
 
-    cmd += " --export-path {0}".format(MODEL_STORE)
+    # Convert the list into a string to represent the complete command
+    cmd = " ".join(cmd_parts)
 
     return cmd
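For reference, a quick usage sketch of the refactored builder with the new `config_file` argument; the file paths are illustrative, and `MODEL_STORE` is the module-level constant already used in test_utils.py.

```python
# Hypothetical call from within test_utils.py's module scope;
# keyword arguments mirror the parameters visible in the diff
cmd = model_archiver_command_builder(
    model_name="half_plus_two",
    version="1.0",
    model_file="model.py",
    serialized_file="model.pt",
    handler="compile_handler.py",
    config_file="pt2.yaml",
    force=True,
)
print(cmd)
# torch-model-archiver --model-name half_plus_two --version 1.0 --model-file model.py
#   --serialized-file model.pt --handler compile_handler.py --config-file pt2.yaml
#   --force --export-path <MODEL_STORE>   (shown wrapped; actual output is one line)
```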
