Skip to content

Commit 966e316

Browse files
authored
[Bugfix] Fix pickle of input when async output processing is on (#9931)
Signed-off-by: Wallas Santos <wallashss@ibm.com>
1 parent 43300bd commit 966e316

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

tests/basic_correctness/test_basic_correctness.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,29 @@ def test_model_with_failure(vllm_runner) -> None:
156156
ModelInputForGPUWithSamplingMetadata)
157157
finally:
158158
os.remove(filename)
159+
160+
161+
def test_failure_with_async_out_proc(vllm_runner) -> None:
    """Regression test: a forward-pass failure while async output
    processing is enabled must still dump the model input to a pickle
    file (i.e. the input object must be picklable despite carrying an
    `async_callback`).

    The injected `ValueError` from the patched OPT forward triggers the
    dump path; the error message is parsed for the dump location and the
    file is removed afterwards.
    """
    filename = None
    try:
        with vllm_runner("facebook/opt-125m",
                         dtype="half",
                         enforce_eager=False,
                         gpu_memory_utilization=0.7) as vllm_model, \
             patch("vllm.model_executor.models.opt.OPTForCausalLM.forward",
                   side_effect=ValueError()):
            # Sanity check: this test is only meaningful when async
            # output processing is actually in use.
            model_config = vllm_model.model.llm_engine.model_config
            assert model_config.use_async_output_proc
            with pytest.raises(ValueError) as exc_info:
                vllm_model.generate_greedy('how to make pizza?', 250)
            # Escape the dot so ".pkl" matches the literal extension,
            # not an arbitrary character followed by "pkl".
            matches = re.search(r"input dumped to (.+)\.pkl",
                                str(exc_info.value))
            assert matches is not None

            filename = f"{matches.group(1)}.pkl"
    finally:
        # Clean up the dump file only if the error message was produced
        # and parsed successfully.
        if filename is not None:
            os.remove(filename)

vllm/worker/model_runner.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,18 @@ def from_broadcasted_tensor_dict(
136136
attn_backend, tensor_dict)
137137
return cls(**tensor_dict)
138138

139+
# Exclude `async_callback` to be able to pickle this object: the
# callback closes over engine state that is not serializable.
def __getstate__(self):
    """Return a picklable copy of ``__dict__`` without `async_callback`."""
    state = self.__dict__.copy()
    # pop() with a default instead of del: tolerate instances where the
    # attribute was never set in __dict__, avoiding a KeyError that
    # would turn the failure-dump path itself into a crash.
    state.pop("async_callback", None)
    return state
144+
145+
# TODO: What happens when we depickle this object?
# How can we update this callback to properly pass it to the engine?
def __setstate__(self, state):
    """Restore pickled state.

    `async_callback` was stripped by ``__getstate__``; it is reinstated
    here as ``None`` so attribute access keeps working after unpickling.
    """
    self.__dict__.update(state)
    # Write through __dict__ (not plain attribute assignment) so this
    # also works on frozen-dataclass subclasses, mirroring __getstate__.
    self.__dict__["async_callback"] = None
150+
139151

140152
@dataclass(frozen=True)
141153
class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):

0 commit comments

Comments
 (0)