@@ -431,9 +431,66 @@ def __init__(self, **kwargs):
431
431
)
432
432
433
433
434
class TextGenModelWrapper:
    """Adapter exposing a text_generation_server model through the
    HF-style forward-call signature used by the benchmarking pipelines.
    """

    def __init__(self, model):
        # Wrapped text-generation model; kept as a public attribute so
        # callers can reach the underlying object directly.
        self.model = model

    def parameters(self):
        """Delegate to the wrapped model's ``parameters()`` (e.g. for
        parameter counting or device inspection)."""
        return self.model.parameters()

    def __call__(self, input_ids, past_key_values, attention_mask, position_ids, return_dict, use_cache):
        # NOTE(review): ``return_dict`` and ``use_cache`` are accepted only
        # for signature compatibility and are not forwarded — presumably the
        # wrapped model does not take them; confirm against the
        # text_generation_server model API. Argument order is rearranged to
        # match the wrapped model's positional signature.
        forward = self.model
        return forward(input_ids, attention_mask, position_ids, past_key_values)
453
class TG_Pipeline(Pipeline):
    """Pipeline backed by text_generation_server ("textgen") models.

    Only CUDA devices are supported. Model loading goes through
    ``text_generation_server.get_model`` and the result is wrapped in
    :class:`TextGenModelWrapper` to match the HF-style forward signature.
    Most HF-specific hooks are unsupported and raise ``NotImplementedError``.
    """

    def __init__(self, **kwargs):
        # Run the base initializer first: ``self.device`` is set by
        # Pipeline.__init__, so checking it before the super() call would
        # raise AttributeError instead of the intended ValueError.
        super().__init__(**kwargs)
        # Compare the device *type* so indexed devices such as "cuda:0"
        # also pass (torch.device equality is sensitive to the index).
        if self.device.type != "cuda":
            raise ValueError(f"Textgen does not support device {self.device}")

    def _get_config(
        self,
        model_type: Optional[str],
        pretrained_config: Optional[str],
        config_args: Dict[str, Any],
    ) -> Optional[PretrainedConfig]:
        # Textgen models carry their own configuration; nothing to build here.
        return None

    def _create_model(self) -> PreTrainedModel:
        # Creating a model from scratch is not supported; use _load_pretrained.
        raise NotImplementedError()

    def _reload_model(self):
        raise NotImplementedError()

    def _save_pretrained(self, pretrained_model: str):
        raise NotImplementedError()

    def _load_pretrained(self, pretrained_model: str):
        # Imported lazily so the text_generation_server dependency is only
        # required when this backend is actually used.
        from text_generation_server import get_model

        pretrained_model, revision = parse_revision(pretrained_model)
        # NOTE(review): the two False flags presumably disable sharding and
        # quantization — confirm against the get_model signature.
        return TextGenModelWrapper(get_model(pretrained_model, revision, False, False))

    def _generate_hf(self, inputs: Dict, max_new_tokens: int, use_cache: bool):
        raise NotImplementedError()

    def _allocate_mock_cache(self, past_key_length: int, batch_size: int):
        raise NotImplementedError()
434
490
# Registry mapping a pipeline-class name (as given in configuration) to its
# implementation; a new backend must be registered here to be selectable.
_PIPELINE_CLASS_MAP = {
    "HF_Pipeline": HF_Pipeline,
    "DS_Pipeline": DS_Pipeline,
    "TG_Pipeline": TG_Pipeline,
}
438
495
439
496
0 commit comments