Skip to content

Commit d591be1

Browse files
committed
pipeline
1 parent 81d322c commit d591be1

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

src/pipeline.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,9 +431,66 @@ def __init__(self, **kwargs):
431431
)
432432

433433

434+
class TextGenModelWrapper:
435+
def __init__(self, model):
436+
self.model = model
437+
438+
def parameters(self):
439+
return self.model.parameters()
440+
441+
def __call__(
442+
self,
443+
input_ids,
444+
past_key_values,
445+
attention_mask,
446+
position_ids,
447+
return_dict,
448+
use_cache,
449+
):
450+
return self.model(input_ids, attention_mask, position_ids, past_key_values)
451+
452+
453+
class TG_Pipeline(Pipeline):
454+
def __init__(self, **kwargs):
455+
if self.device != torch.device("cuda"):
456+
raise ValueError(f"Textgen does not support device {self.device}")
457+
458+
super().__init__(**kwargs)
459+
460+
def _get_config(
461+
self,
462+
model_type: Optional[str],
463+
pretrained_config: Optional[str],
464+
config_args: Dict[str, Any],
465+
) -> Optional[PretrainedConfig]:
466+
return None
467+
468+
def _create_model(self) -> PreTrainedModel:
469+
raise NotImplementedError()
470+
471+
def _reload_model(self):
472+
raise NotImplementedError()
473+
474+
def _save_pretrained(self, pretrained_model: str):
475+
raise NotImplementedError()
476+
477+
def _load_pretrained(self, pretrained_model: str):
478+
from text_generation_server import get_model
479+
480+
pretrained_model, revision = parse_revision(pretrained_model)
481+
return TextGenModelWrapper(get_model(pretrained_model, revision, False, False))
482+
483+
def _generate_hf(self, inputs: Dict, max_new_tokens: int, use_cache: bool):
484+
raise NotImplementedError()
485+
486+
def _allocate_mock_cache(self, past_key_length: int, batch_size: int):
487+
raise NotImplementedError()
488+
489+
434490
_PIPELINE_CLASS_MAP = {
435491
"HF_Pipeline": HF_Pipeline,
436492
"DS_Pipeline": DS_Pipeline,
493+
"TG_Pipeline": TG_Pipeline,
437494
}
438495

439496

0 commit comments

Comments
 (0)