Skip to content

Commit 2bd81bf

Browse files
author
drajnic
committed
feat: Add support for litellm as a provider
- Allow users to configure a litellm API base and key to use a litellm instance as a model provider.
- Implement model discovery from the /models and /model_group/info endpoints of the litellm instance.
- Use the discovered model information, including pricing, to enable cost calculation for litellm models.
- Improve the cost display to show token usage even when the cost is zero.
1 parent e4fc2f5 commit 2bd81bf

File tree

8 files changed

+236
-12
lines changed

8 files changed

+236
-12
lines changed

aider/args.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,14 @@ def get_parser(default_config_files, git_root):
7777
"--openai-api-base",
7878
help="Specify the api base url",
7979
)
80+
group.add_argument(
81+
"--litellm-api-base",
82+
help="Specify the litellm api base url",
83+
)
84+
group.add_argument(
85+
"--litellm-api-key",
86+
help="Specify the litellm api key",
87+
)
8088
group.add_argument(
8189
"--openai-api-type",
8290
help="(deprecated, use --set-env OPENAI_API_TYPE=<value>)",

aider/coders/base_coder.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2055,17 +2055,20 @@ def format_cost(value):
20552055
else:
20562056
return f"{value:.{max(2, 2 - int(math.log10(magnitude)))}f}"
20572057

2058-
cost_report = (
2059-
f"Cost: ${format_cost(self.message_cost)} message,"
2060-
f" ${format_cost(self.total_cost)} session."
2061-
)
2058+
if cost:
2059+
cost_report = (
2060+
f"Cost: ${format_cost(self.message_cost)} message,"
2061+
f" ${format_cost(self.total_cost)} session."
2062+
)
20622063

2063-
if cache_hit_tokens and cache_write_tokens:
2064-
sep = "\n"
2065-
else:
2066-
sep = " "
2064+
if cache_hit_tokens and cache_write_tokens:
2065+
sep = "\n"
2066+
else:
2067+
sep = " "
20672068

2068-
self.usage_report = tokens_report + sep + cost_report
2069+
self.usage_report = tokens_report + sep + cost_report
2070+
else:
2071+
self.usage_report = tokens_report
20692072

20702073
def compute_costs_from_tokens(
20712074
self, prompt_tokens, completion_tokens, cache_write_tokens, cache_hit_tokens

aider/main.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,60 @@ def register_litellm_models(git_root, model_metadata_fname, io, verbose=False):
409409
return 1
410410

411411

412+
def discover_litellm_models(io, verbose=False):
    """Discover models served by a litellm proxy and register their metadata.

    Reads LITELLM_API_BASE (and optionally LITELLM_API_KEY) from the
    environment, queries the proxy's /models endpoint for model ids and
    their owners, then /model_group/info for pricing and token limits, and
    stores the combined metadata under ``litellm/<model_group>`` names in
    models.model_info_manager.local_model_metadata.

    Best-effort: any failure is reported via io.tool_warning and discovery
    is skipped, so a misbehaving proxy never breaks startup.
    """
    litellm_api_base = os.environ.get("LITELLM_API_BASE")
    if not litellm_api_base:
        return

    try:
        import requests

        headers = {}
        api_key = os.environ.get("LITELLM_API_KEY")
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        # Normalize the base url once; both endpoints share it.
        base = litellm_api_base.rstrip("/")

        def fetch(path):
            # Shared GET with the auth headers and the manager's SSL setting.
            return requests.get(
                base + path,
                headers=headers,
                timeout=5,
                verify=models.model_info_manager.verify_ssl,
            )

        # First, get the models and their owners.
        response = fetch("/models")
        if response.status_code != 200:
            io.tool_warning(
                f"Error fetching models from {base}/models: {response.status_code}"
            )
            return

        models_data = response.json()
        model_owners = {
            model_info.get("id"): model_info.get("owned_by")
            for model_info in models_data.get("data", [])
        }

        # Now, get the per-group pricing and token-limit info.
        response = fetch("/model_group/info")
        if response.status_code == 200:
            model_group_data = response.json()
            for model_info in model_group_data.get("data", []):
                model_group = model_info.get("model_group")
                if not model_group:
                    continue
                models.model_info_manager.local_model_metadata[f"litellm/{model_group}"] = {
                    "litellm_provider": "litellm",
                    "mode": "chat",
                    "owned_by": model_owners.get(model_group),
                    "input_cost_per_token": model_info.get("input_cost_per_token"),
                    "output_cost_per_token": model_info.get("output_cost_per_token"),
                    "max_input_tokens": model_info.get("max_input_tokens"),
                    "max_output_tokens": model_info.get("max_output_tokens"),
                }
            if verbose:
                io.tool_output(f"Discovered model info from {base}/model_group/info")
    except Exception as e:
        # Best-effort discovery: warn and continue without litellm metadata.
        io.tool_warning(f"Error fetching model info from litellm: {e}")
464+
465+
412466
def sanity_check_repo(repo, io):
413467
if not repo:
414468
return True
@@ -619,6 +673,10 @@ def get_io(pretty):
619673
handle_deprecated_model_args(args, io)
620674
if args.openai_api_base:
621675
os.environ["OPENAI_API_BASE"] = args.openai_api_base
676+
if args.litellm_api_base:
677+
os.environ["LITELLM_API_BASE"] = args.litellm_api_base
678+
if args.litellm_api_key:
679+
os.environ["LITELLM_API_KEY"] = args.litellm_api_key
622680
if args.openai_api_version:
623681
io.tool_warning(
624682
"--openai-api-version is deprecated, use --set-env OPENAI_API_VERSION=<value>"
@@ -755,6 +813,7 @@ def get_io(pretty):
755813

756814
register_models(git_root, args.model_settings_file, io, verbose=args.verbose)
757815
register_litellm_models(git_root, args.model_metadata_file, io, verbose=args.verbose)
816+
discover_litellm_models(io, verbose=args.verbose)
758817

759818
if args.list_models:
760819
models.print_matching_models(io, args.list_models)

aider/models.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -953,8 +953,17 @@ def send_completion(self, messages, functions, stream, temperature=None):
953953
if self.is_deepseek_r1():
954954
messages = ensure_alternating_roles(messages)
955955

956+
model_name = self.name
957+
if self.info.get("litellm_provider") == "litellm":
958+
owned_by = self.info.get("owned_by")
959+
model_id = model_name[len("litellm/") :]
960+
if owned_by:
961+
model_name = f"{owned_by}/{model_id}"
962+
else:
963+
model_name = model_id
964+
956965
kwargs = dict(
957-
model=self.name,
966+
model=model_name,
958967
stream=stream,
959968
)
960969

@@ -997,6 +1006,14 @@ def send_completion(self, messages, functions, stream, temperature=None):
9971006

9981007
self.github_copilot_token_to_open_ai_key(kwargs["extra_headers"])
9991008

1009+
if self.info.get("litellm_provider") == "litellm":
1010+
litellm_api_base = os.environ.get("LITELLM_API_BASE")
1011+
if litellm_api_base:
1012+
kwargs["api_base"] = litellm_api_base
1013+
litellm_api_key = os.environ.get("LITELLM_API_KEY")
1014+
if litellm_api_key:
1015+
kwargs["api_key"] = litellm_api_key
1016+
10001017
res = litellm.completion(**kwargs)
10011018
return hash_object, res
10021019

aider/onboarding.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import requests
1313

14-
from aider import urls
14+
from aider import models, urls
1515
from aider.io import InputOutput
1616

1717

@@ -73,6 +73,11 @@ def try_to_select_default_model():
7373
if api_key_value:
7474
return model_name
7575

76+
if os.environ.get("LITELLM_API_BASE") and os.environ.get("LITELLM_API_KEY"):
77+
for model_name, model_info in models.model_info_manager.local_model_metadata.items():
78+
if model_info.get("litellm_provider") == "litellm":
79+
return model_name
80+
7681
return None
7782

7883

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
name = "aider-chat"
44
description = "Aider is AI pair programming in your terminal"
55
readme = "README.md"
6+
version = "1.0.0-dev"
67
classifiers = [
78
"Development Status :: 4 - Beta",
89
"Environment :: Console",

tests/basic/test_main.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,25 @@
1414
from aider.coders import Coder
1515
from aider.dump import dump # noqa: F401
1616
from aider.io import InputOutput
17-
from aider.main import check_gitignore, load_dotenv_files, main, setup_git
17+
from aider.main import (
18+
check_gitignore,
19+
discover_litellm_models,
20+
load_dotenv_files,
21+
main,
22+
setup_git,
23+
)
1824
from aider.utils import GitTemporaryDirectory, IgnorantTemporaryDirectory, make_repo
1925

2026

27+
class DummyResponse:
    """Minimal stand-in for a requests.Response.

    Exposes just the two pieces of the interface the litellm discovery
    code touches: a ``status_code`` attribute and a ``json()`` method
    returning a pre-canned payload.
    """

    def __init__(self, json_data, status_code=200):
        self.json_data = json_data
        self.status_code = status_code

    def json(self):
        """Return the canned payload given at construction time."""
        return self.json_data
34+
35+
2136
class TestMain(TestCase):
2237
def setUp(self):
2338
self.original_env = os.environ.copy()
@@ -45,6 +60,59 @@ def tearDown(self):
4560
self.input_patcher.stop()
4661
self.webbrowser_patcher.stop()
4762

63+
def test_litellm_discover_models(self):
    """
    discover_litellm_models should register metadata assembled from the
    /models and /model_group/info responses of the litellm proxy.
    """
    models_payload = {
        "data": [
            {
                "id": "bedrock-claude-opus-4.1",
                "object": "model",
                "created": 1677610602,
                "owned_by": "openai",
            }
        ]
    }

    model_group_payload = {
        "data": [
            {
                "model_group": "bedrock-claude-opus-4.1",
                "providers": ["bedrock"],
                "max_input_tokens": 200000,
                "max_output_tokens": 32000,
                "input_cost_per_token": 0.000015,
                "output_cost_per_token": 0.000075,
            }
        ]
    }

    def fake_get(url, **kwargs):
        # Route each discovery endpoint to its canned payload.
        if "/models" in url:
            return DummyResponse(models_payload)
        if "/model_group/info" in url:
            return DummyResponse(model_group_payload)
        return DummyResponse({}, 404)

    with patch("requests.get", fake_get):
        os.environ["LITELLM_API_BASE"] = "http://localhost:4000"
        os.environ["LITELLM_API_KEY"] = "test-key"

        io = MagicMock()
        discover_litellm_models(io)

        from aider import models

        info = models.model_info_manager.get_model_info("litellm/bedrock-claude-opus-4.1")

        assert info["max_input_tokens"] == 200000
        assert info["input_cost_per_token"] == 0.000015
        assert info["output_cost_per_token"] == 0.000075
        assert info["litellm_provider"] == "litellm"
        assert info["owned_by"] == "openai"
115+
48116
def test_main_with_empty_dir_no_files_on_command(self):
49117
main(["--no-git", "--exit", "--yes"], input=DummyInput(), output=DummyOutput())
50118

tests/basic/test_models.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import os
12
import unittest
23
from unittest.mock import ANY, MagicMock, patch
34

5+
from aider.coders.base_coder import Coder
46
from aider.models import (
57
ANTHROPIC_BETA_HEADER,
68
Model,
@@ -17,13 +19,16 @@ def setUp(self):
1719
from aider.models import MODEL_SETTINGS
1820

1921
self._original_settings = MODEL_SETTINGS.copy()
22+
self.original_env = os.environ.copy()
2023

2124
def tearDown(self):
2225
"""Restore original MODEL_SETTINGS after each test"""
2326
from aider.models import MODEL_SETTINGS
2427

2528
MODEL_SETTINGS.clear()
2629
MODEL_SETTINGS.extend(self._original_settings)
30+
os.environ.clear()
31+
os.environ.update(self.original_env)
2732

2833
def test_get_model_info_nonexistent(self):
2934
manager = ModelInfoManager()
@@ -558,6 +563,64 @@ def test_use_temperature_in_send_completion(self, mock_completion):
558563
timeout=600,
559564
)
560565

566+
@patch("aider.models.litellm.completion")
def test_litellm_send_completion(self, mock_completion):
    """
    send_completion should rewrite a litellm-provider model name using its
    owner and pass the configured litellm api base and key through to
    litellm.completion.
    """
    os.environ["LITELLM_API_BASE"] = "http://localhost:4000"
    os.environ["LITELLM_API_KEY"] = "test-key"

    model = Model("litellm/my-model")
    model.info["litellm_provider"] = "litellm"
    model.info["owned_by"] = "my-provider"

    msgs = [{"role": "user", "content": "Hello"}]
    model.send_completion(msgs, None, True)

    mock_completion.assert_called_once_with(
        model="my-provider/my-model",
        messages=msgs,
        stream=True,
        temperature=0,
        api_base="http://localhost:4000",
        api_key="test-key",
        timeout=600,
    )
590+
591+
@patch("aider.coders.base_coder.litellm.completion_cost")
def test_litellm_cost_calculation(self, mock_completion_cost):
    """
    A litellm model with per-token prices should produce a non-empty usage
    report containing the computed dollar cost.
    """
    os.environ["LITELLM_API_BASE"] = "http://localhost:4000"
    os.environ["LITELLM_API_KEY"] = "test-key"

    model = Model("litellm/my-model")
    model.info["litellm_provider"] = "litellm"
    model.info["input_cost_per_token"] = 0.00001
    model.info["output_cost_per_token"] = 0.00002

    msgs = [{"role": "user", "content": "Hello"}]

    # Fake a completion with 10 prompt / 20 completion tokens and no caching.
    completion = MagicMock()
    completion.usage.prompt_tokens = 10
    completion.usage.completion_tokens = 20
    completion.usage.prompt_cache_hit_tokens = 0
    completion.usage.cache_read_input_tokens = 0
    completion.usage.cache_creation_input_tokens = 0

    # 10 * $0.00001 + 20 * $0.00002 = $0.0005
    mock_completion_cost.return_value = (10 * 0.00001) + (20 * 0.00002)

    coder = Coder.create(main_model=model, io=MagicMock())
    coder.message_tokens_sent = 0
    coder.message_tokens_received = 0
    coder.total_cost = 0
    coder.message_cost = 0
    coder.calculate_and_show_tokens_and_cost(msgs, completion)

    self.assertIsNotNone(coder.usage_report)
    self.assertIn("Cost: $0.0005", coder.usage_report)
623+
561624

562625
if __name__ == "__main__":
563626
unittest.main()

0 commit comments

Comments
 (0)