Skip to content

Commit

Permalink
[inference] check for unsupported model generate args (#2627)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffra authored Dec 19, 2022
1 parent df985fa commit 5676f5e
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions deepspeed/inference/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def __init__(self, model, config):

self._get_model_config_generate(config) # keep for weird backward compatibility

# route generate() through our _generate wrapper when the wrapped module defines generate
if hasattr(self.module, "generate"):
self.generate = self._generate

if hasattr(self.module, "config"):
DSPolicy.hf_model_config = self.module.config

Expand Down Expand Up @@ -148,8 +152,6 @@ def _get_model_config_generate(self, config):
self.config = getattr(self.module,
'config',
None) if config.config is None else config.config
# todo: clarify with Reza if this gets used anywhere
self.generate = getattr(self.module, 'generate', None)

def remove_mask_prepare_for_bloom(self):
if hasattr(self.module, 'transformer'):
Expand Down Expand Up @@ -518,3 +520,19 @@ def forward(self, *inputs, **kwargs):
self._model_times.append(duration)

return outputs

def _generate(self, *inputs, **kwargs):
    """Validate generation arguments, then delegate to ``self.module.generate``.

    The requested beam count is resolved from ``kwargs["generation_config"]``
    (its ``num_beams`` attribute, when present) and then from
    ``kwargs["num_beams"]``, which takes precedence. A value of ``None`` at
    either level is treated as "not specified" instead of being compared
    against an int, so an explicit ``num_beams=None`` does not raise
    ``TypeError``.

    Args:
        *inputs: Positional arguments forwarded unchanged to
            ``self.module.generate``.
        **kwargs: Keyword arguments forwarded unchanged to
            ``self.module.generate``.

    Returns:
        Whatever ``self.module.generate`` returns.

    Raises:
        NotImplementedError: If the resolved ``num_beams`` is greater than 1.
    """
    num_beams = 1

    # A missing config, a config without ``num_beams``, or a ``None`` value
    # all leave the default of 1 in place.
    config_beams = getattr(kwargs.get("generation_config"), "num_beams", None)
    if config_beams is not None:
        num_beams = config_beams

    # An explicit keyword overrides the config value, unless it is ``None``.
    kwarg_beams = kwargs.get("num_beams")
    if kwarg_beams is not None:
        num_beams = kwarg_beams

    if num_beams > 1:
        raise NotImplementedError(
            "DeepSpeed does not support `num_beams` > 1, if this is important to you please "
            "add your request to: https://github.com/microsoft/DeepSpeed/issues/2506"
        )

    return self.module.generate(*inputs, **kwargs)

0 comments on commit 5676f5e

Please sign in to comment.