
Commit 3e86d7f

Merge 967a5f7 into 98ca286 (2 parents: 98ca286 + 967a5f7)

File tree: 3 files changed, +133 -6 lines changed
New test file — Lines changed: 113 additions & 0 deletions

@@ -0,0 +1,113 @@
import os
import pathlib

import pytest
import requests
import test_utils
import torch

CURR_FILE_PATH = os.path.dirname(os.path.realpath(__file__))
REPO_ROOT = os.path.normpath(os.path.join(CURR_FILE_PATH, "..", ".."))
MODELSTORE_DIR = os.path.join(REPO_ROOT, "model_store")
data_file_kitten = os.path.join(REPO_ROOT, "examples/image_classifier/kitten.jpg")
HF_TRANSFORMERS_EXAMPLE_DIR = os.path.join(
    REPO_ROOT, "examples/Huggingface_Transformers/"
)


def test_no_model_loaded():
    """
    Validates that TorchServe returns response code 404 if no model is loaded.
    """
    os.makedirs(MODELSTORE_DIR, exist_ok=True)  # Create model store directory
    test_utils.start_torchserve(model_store=MODELSTORE_DIR)

    response = requests.post(
        url="http://localhost:8080/models/alexnet/invoke",
        data=open(data_file_kitten, "rb"),
    )
    assert response.status_code == 404, "Model not loaded error expected"


@pytest.mark.skipif(
    not ((torch.cuda.device_count() > 0) and torch.cuda.is_available()),
    reason="Test to be run on GPU only",
)
def test_oom_on_model_load():
    """
    Validates that TorchServe returns response code 507 if there is OOM on model loading.
    """
    # Create model store directory
    pathlib.Path(test_utils.MODEL_STORE).mkdir(parents=True, exist_ok=True)

    # Start TorchServe
    test_utils.start_torchserve(no_config_snapshots=True)

    # Register model
    params = {
        "model_name": "BERTSeqClassification",
        "url": "https://torchserve.pytorch.org/mar_files/BERTSeqClassification.mar",
        "batch_size": 1,
        "initial_workers": 16,
    }
    response = test_utils.register_model_with_params(params)

    assert response.status_code == 507, "OOM Error expected"

    test_utils.stop_torchserve()


@pytest.mark.skipif(
    not ((torch.cuda.device_count() > 0) and torch.cuda.is_available()),
    reason="Test to be run on GPU only",
)
def test_oom_on_invoke():
    # Create model store directory
    pathlib.Path(test_utils.MODEL_STORE).mkdir(parents=True, exist_ok=True)

    # Start TorchServe
    test_utils.start_torchserve(no_config_snapshots=True)

    # Register model
    params = {
        "model_name": "BERTSeqClassification",
        "url": "https://torchserve.pytorch.org/mar_files/BERTSeqClassification.mar",
        "batch_size": 8,
        "initial_workers": 12,
    }
    response = test_utils.register_model_with_params(params)

    input_text = os.path.join(
        REPO_ROOT,
        "examples",
        "Huggingface_Transformers",
        "Seq_classification_artifacts",
        "sample_text_captum_input.txt",
    )

    # Chain 8 curl requests with && per iteration;
    # send multiple requests to make sure we hit OOM
    for i in range(10):
        response = os.popen(
            f"curl http://127.0.0.1:8080/models/BERTSeqClassification/invoke -T {input_text} && "
            f"curl http://127.0.0.1:8080/models/BERTSeqClassification/invoke -T {input_text} && "
            f"curl http://127.0.0.1:8080/models/BERTSeqClassification/invoke -T {input_text} && "
            f"curl http://127.0.0.1:8080/models/BERTSeqClassification/invoke -T {input_text} && "
            f"curl http://127.0.0.1:8080/models/BERTSeqClassification/invoke -T {input_text} && "
            f"curl http://127.0.0.1:8080/models/BERTSeqClassification/invoke -T {input_text} && "
            f"curl http://127.0.0.1:8080/models/BERTSeqClassification/invoke -T {input_text} && "
            f"curl http://127.0.0.1:8080/models/BERTSeqClassification/invoke -T {input_text} "
        )
        response = response.read()

    # If OOM is hit, we expect code 507 to be present in the response string
    lines = response.split("\n")
    output = ""
    for line in lines:
        if "code" in line:
            output = line.strip()
            break
    assert output == '"code": 507,', "OOM Error expected"
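
For comparison, here is a minimal sketch (not part of the commit) of the same 507 check done with requests instead of shelling out to curl. The helper name invoke_returns_oom is hypothetical, and the sketch assumes TorchServe is already running on 127.0.0.1:8080 with BERTSeqClassification registered:

import requests

def invoke_returns_oom(input_path: str) -> bool:
    # TorchServe signals out-of-memory with HTTP status 507, so the status
    # code can be checked directly instead of grepping the response body
    # for '"code": 507,'.
    with open(input_path, "rb") as f:
        response = requests.post(
            url="http://127.0.0.1:8080/models/BERTSeqClassification/invoke",
            data=f,
        )
    return response.status_code == 507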

ts/model_service_worker.py

Lines changed: 8 additions & 0 deletions

@@ -144,6 +144,14 @@ def load_model(self, load_model_request):
             return service, "loaded model {}".format(model_name), 200
         except MemoryError:
             return None, "System out of memory", 507
+        except RuntimeError as ex:  # pylint: disable=broad-except
+            if "CUDA" in str(ex):
+                # Handles Case A: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED (close to OOM) and
+                # Case B: CUDA out of memory (OOM)
+                return None, "System out of memory", 507
+            else:
+                # Sanity test cases fail without this
+                return None, "Unknown exception", 500

     def handle_connection(self, cl_socket):
         """

ts/service.py

Lines changed: 12 additions & 6 deletions

@@ -132,15 +132,21 @@ def predict(self, batch):
         # noinspection PyBroadException
         try:
             ret = self._entry_point(input_batch, self.context)
-        except PredictionException as e:
-            logger.error("Prediction error", exc_info=True)
-            return create_predict_response(None, req_id_map, e.message, e.error_code)
         except MemoryError:
             logger.error("System out of memory", exc_info=True)
             return create_predict_response(None, req_id_map, "Out of resources", 507)
-        except Exception:  # pylint: disable=broad-except
-            logger.warning("Invoking custom service failed.", exc_info=True)
-            return create_predict_response(None, req_id_map, "Prediction failed", 503)
+        except PredictionException as e:
+            logger.error("Prediction error", exc_info=True)
+            return create_predict_response(None, req_id_map, e.message, e.error_code)
+        except Exception as ex:  # pylint: disable=broad-except
+            if "CUDA" in str(ex):
+                # Handles Case A: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED (close to OOM) and
+                # Case B: CUDA out of memory (OOM)
+                logger.error("CUDA out of memory", exc_info=True)
+                return create_predict_response(None, req_id_map, "Out of resources", 507)
+            else:
+                logger.warning("Invoking custom service failed.", exc_info=True)
+                return create_predict_response(None, req_id_map, "Prediction failed", 503)

         if not isinstance(ret, list):
             logger.warning(
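
For reference, a minimal sketch (assumes a CUDA device; the allocation size is illustrative) of the failure mode the broad handler targets. PyTorch raises a RuntimeError whose message contains "CUDA out of memory" when an allocation exceeds device memory, which is what the '"CUDA" in str(ex)' check keys on:

import torch

try:
    # Deliberately oversized allocation (2**40 float32 values, ~4 TiB).
    torch.empty(1 << 40, device="cuda")
except RuntimeError as ex:
    # Message looks like: "CUDA out of memory. Tried to allocate ..."
    assert "CUDA" in str(ex)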
