@@ -409,6 +409,60 @@ def register_litellm_models(git_root, model_metadata_fname, io, verbose=False):
409
409
return 1
410
410
411
411
412
def discover_litellm_models(io, verbose=False):
    """Register models advertised by a LiteLLM proxy as local model metadata.

    Does nothing unless the ``LITELLM_API_BASE`` environment variable is set
    to a non-empty value. When it is, two endpoints under that base URL are
    queried with a 5 second timeout each:

    * ``/models`` -- used to map each model ``id`` to its ``owned_by`` value.
    * ``/model_group/info`` -- per-group token costs and token limits.

    Every returned entry that has a ``model_group`` is stored in
    ``models.model_info_manager.local_model_metadata`` under the key
    ``litellm/<model_group>``. Fields missing from the response are stored
    as ``None``. If ``LITELLM_API_KEY`` is set, it is sent as a bearer token
    in the ``Authorization`` header of both requests.

    Discovery is best-effort: a non-200 response or any exception is reported
    through ``io.tool_warning`` and the function returns without raising.

    Args:
        io: Object exposing ``tool_warning(msg)`` and ``tool_output(msg)``.
        verbose: When true, report successful discovery via ``io.tool_output``.

    Returns:
        None.
    """
    litellm_api_base = os.environ.get("LITELLM_API_BASE")
    if not litellm_api_base:
        return

    try:
        # Imported lazily: only needed when discovery is enabled, and an
        # import failure is then reported as a warning instead of crashing.
        import requests

        base_url = litellm_api_base.rstrip("/")
        verify_ssl = models.model_info_manager.verify_ssl

        headers = {}
        api_key = os.environ.get("LITELLM_API_KEY")
        if api_key:
            # No padding around the token: the header value must be exactly
            # "Bearer <key>".
            headers["Authorization"] = f"Bearer {api_key}"

        # First, get the models and their owners.
        url = f"{base_url}/models"
        response = requests.get(url, headers=headers, timeout=5, verify=verify_ssl)
        if response.status_code != 200:
            io.tool_warning(f"Error fetching models from {url}: {response.status_code}")
            return

        model_owners = {
            entry.get("id"): entry.get("owned_by")
            for entry in response.json().get("data", [])
        }

        # Now, get the model group info.
        url = f"{base_url}/model_group/info"
        response = requests.get(url, headers=headers, timeout=5, verify=verify_ssl)
        if response.status_code != 200:
            # Report this failure the same way as the /models failure above
            # rather than silently registering nothing.
            io.tool_warning(f"Error fetching model info from {url}: {response.status_code}")
            return

        metadata = models.model_info_manager.local_model_metadata
        for entry in response.json().get("data", []):
            model_group = entry.get("model_group")
            if not model_group:
                continue
            # The key must be exactly "litellm/<group>" (no stray whitespace)
            # so that lookups by model name can find it.
            metadata[f"litellm/{model_group}"] = {
                "litellm_provider": "litellm",
                "mode": "chat",
                "owned_by": model_owners.get(model_group),
                "input_cost_per_token": entry.get("input_cost_per_token"),
                "output_cost_per_token": entry.get("output_cost_per_token"),
                "max_input_tokens": entry.get("max_input_tokens"),
                "max_output_tokens": entry.get("max_output_tokens"),
            }

        if verbose:
            io.tool_output(f"Discovered model info from {url}")
    except Exception as e:
        # Deliberately broad: discovery is optional and must never abort startup.
        io.tool_warning(f"Error fetching model info from litellm: {e}")
464
+
465
+
412
466
def sanity_check_repo (repo , io ):
413
467
if not repo :
414
468
return True
@@ -619,6 +673,10 @@ def get_io(pretty):
619
673
handle_deprecated_model_args (args , io )
620
674
if args .openai_api_base :
621
675
os .environ ["OPENAI_API_BASE" ] = args .openai_api_base
676
+ if args .litellm_api_base :
677
+ os .environ ["LITELLM_API_BASE" ] = args .litellm_api_base
678
+ if args .litellm_api_key :
679
+ os .environ ["LITELLM_API_KEY" ] = args .litellm_api_key
622
680
if args .openai_api_version :
623
681
io .tool_warning (
624
682
"--openai-api-version is deprecated, use --set-env OPENAI_API_VERSION=<value>"
@@ -755,6 +813,7 @@ def get_io(pretty):
755
813
756
814
register_models (git_root , args .model_settings_file , io , verbose = args .verbose )
757
815
register_litellm_models (git_root , args .model_metadata_file , io , verbose = args .verbose )
816
+ discover_litellm_models (io , verbose = args .verbose )
758
817
759
818
if args .list_models :
760
819
models .print_matching_models (io , args .list_models )
0 commit comments