8 changes: 8 additions & 0 deletions .idea/.gitignore
5 changes: 5 additions & 0 deletions .idea/codeStyles/codeStyleConfig.xml
9 changes: 9 additions & 0 deletions .idea/libraries/my_test_package.xml
6 changes: 6 additions & 0 deletions .idea/misc.xml
8 changes: 8 additions & 0 deletions .idea/modules.xml
6 changes: 6 additions & 0 deletions .idea/vcs.xml
(Generated IDE config files; diffs not rendered.)

8 changes: 8 additions & 0 deletions aider/args.py
@@ -77,6 +77,14 @@ def get_parser(default_config_files, git_root):
         "--openai-api-base",
         help="Specify the api base url",
     )
+    group.add_argument(
+        "--litellm-api-base",
+        help="Specify the litellm api base url",
+    )
+    group.add_argument(
+        "--litellm-api-key",
+        help="Specify the litellm api key",
+    )
     group.add_argument(
         "--openai-api-type",
         help="(deprecated, use --set-env OPENAI_API_TYPE=<value>)",
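For orientation, a minimal stand-alone sketch of how the two new flags are meant to be consumed (the real wiring appears in aider/main.py below; the bare argparse parser here is a stand-in for aider's configargparse group):

```python
import argparse
import os

# Stand-in parser; aider declares these options on a configargparse group instead.
parser = argparse.ArgumentParser()
parser.add_argument("--litellm-api-base", help="Specify the litellm api base url")
parser.add_argument("--litellm-api-key", help="Specify the litellm api key")

args = parser.parse_args(["--litellm-api-base", "http://localhost:4000"])
if args.litellm_api_base:
    # Mirrors main.py below: the env var is what discover_litellm_models() reads.
    os.environ["LITELLM_API_BASE"] = args.litellm_api_base
```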
21 changes: 12 additions & 9 deletions aider/coders/base_coder.py
@@ -2055,17 +2055,20 @@ def format_cost(value):
             else:
                 return f"{value:.{max(2, 2 - int(math.log10(magnitude)))}f}"
 
-        cost_report = (
-            f"Cost: ${format_cost(self.message_cost)} message,"
-            f" ${format_cost(self.total_cost)} session."
-        )
+        if cost:
+            cost_report = (
+                f"Cost: ${format_cost(self.message_cost)} message,"
+                f" ${format_cost(self.total_cost)} session."
+            )
 
-        if cache_hit_tokens and cache_write_tokens:
-            sep = "\n"
-        else:
-            sep = " "
+            if cache_hit_tokens and cache_write_tokens:
+                sep = "\n"
+            else:
+                sep = " "
 
-        self.usage_report = tokens_report + sep + cost_report
+            self.usage_report = tokens_report + sep + cost_report
+        else:
+            self.usage_report = tokens_report
 
     def compute_costs_from_tokens(
         self, prompt_tokens, completion_tokens, cache_write_tokens, cache_hit_tokens
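The effect of the new `if cost:` guard, as a minimal stand-in sketch (names and values are illustrative, not from a real session): when no cost was computed, the usage report is now just the token counts instead of a cost line with zeroed dollar figures.

```python
# Stand-in values; aider derives these from the model's pricing metadata.
tokens_report = "Tokens: 1.2k sent, 350 received."
message_cost, total_cost = 0.0042, 0.0126
cost = message_cost  # falsy (0 or None) when pricing info is unavailable

if cost:
    cost_report = f"Cost: ${message_cost:.4f} message, ${total_cost:.4f} session."
    usage_report = tokens_report + " " + cost_report
else:
    usage_report = tokens_report

print(usage_report)
# Tokens: 1.2k sent, 350 received. Cost: $0.0042 message, $0.0126 session.
```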
59 changes: 59 additions & 0 deletions aider/main.py
@@ -409,6 +409,60 @@ def register_litellm_models(git_root, model_metadata_fname, io, verbose=False):
         return 1
 
 
+def discover_litellm_models(io, verbose=False):
+    litellm_api_base = os.environ.get("LITELLM_API_BASE")
+    if not litellm_api_base:
+        return
+
+    try:
+        import requests
+
+        headers = {}
+        api_key = os.environ.get("LITELLM_API_KEY")
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
+
+        # First, get the models and their owners
+        url = litellm_api_base.rstrip("/") + "/models"
+
+        response = requests.get(
+            url, headers=headers, timeout=5, verify=models.model_info_manager.verify_ssl
+        )
+        if response.status_code != 200:
+            io.tool_warning(f"Error fetching models from {url}: {response.status_code}")
+            return
+
+        models_data = response.json()
+        model_owners = {
+            model_info.get("id"): model_info.get("owned_by")
+            for model_info in models_data.get("data", [])
+        }
+
+        # Now, get the model group info
+        url = litellm_api_base.rstrip("/") + "/model_group/info"
+        response = requests.get(
+            url, headers=headers, timeout=5, verify=models.model_info_manager.verify_ssl
+        )
+        if response.status_code == 200:
+            model_group_data = response.json()
+            for model_info in model_group_data.get("data", []):
+                model_group = model_info.get("model_group")
+                if model_group:
+                    models.model_info_manager.local_model_metadata[f"litellm/{model_group}"] = {
+                        "litellm_provider": "litellm",
+                        "mode": "chat",
+                        "owned_by": model_owners.get(model_group),
+                        "input_cost_per_token": model_info.get("input_cost_per_token"),
+                        "output_cost_per_token": model_info.get("output_cost_per_token"),
+                        "max_input_tokens": model_info.get("max_input_tokens"),
+                        "max_output_tokens": model_info.get("max_output_tokens"),
+                    }
+            if verbose:
+                io.tool_output(f"Discovered model info from {url}")
+    except Exception as e:
+        io.tool_warning(f"Error fetching model info from litellm: {e}")
+
+
 def sanity_check_repo(repo, io):
     if not repo:
         return True
@@ -619,6 +673,10 @@ def get_io(pretty):
     handle_deprecated_model_args(args, io)
     if args.openai_api_base:
         os.environ["OPENAI_API_BASE"] = args.openai_api_base
+    if args.litellm_api_base:
+        os.environ["LITELLM_API_BASE"] = args.litellm_api_base
+    if args.litellm_api_key:
+        os.environ["LITELLM_API_KEY"] = args.litellm_api_key
     if args.openai_api_version:
         io.tool_warning(
             "--openai-api-version is deprecated, use --set-env OPENAI_API_VERSION=<value>"
@@ -755,6 +813,7 @@ def get_io(pretty):
 
     register_models(git_root, args.model_settings_file, io, verbose=args.verbose)
    register_litellm_models(git_root, args.model_metadata_file, io, verbose=args.verbose)
+    discover_litellm_models(io, verbose=args.verbose)
 
     if args.list_models:
         models.print_matching_models(io, args.list_models)
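A hypothetical end-to-end sketch of the discovery flow above, assuming a LiteLLM proxy is actually listening on localhost:4000 (the key is a stand-in, and MagicMock stands in for aider's io object):

```python
import os
from unittest.mock import MagicMock

from aider import models
from aider.main import discover_litellm_models

os.environ["LITELLM_API_BASE"] = "http://localhost:4000"  # assumed proxy address
os.environ["LITELLM_API_KEY"] = "sk-example"  # stand-in key

discover_litellm_models(MagicMock(), verbose=True)

# Each discovered model group is registered under a "litellm/" prefix.
for name, info in models.model_info_manager.local_model_metadata.items():
    if info.get("litellm_provider") == "litellm":
        print(name, "owned by", info.get("owned_by"))
```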
19 changes: 18 additions & 1 deletion aider/models.py
@@ -953,8 +953,17 @@ def send_completion(self, messages, functions, stream, temperature=None):
         if self.is_deepseek_r1():
             messages = ensure_alternating_roles(messages)
 
+        model_name = self.name
+        if self.info.get("litellm_provider") == "litellm":
+            owned_by = self.info.get("owned_by")
+            model_id = model_name[len("litellm/") :]
+            if owned_by:
+                model_name = f"{owned_by}/{model_id}"
+            else:
+                model_name = model_id
+
         kwargs = dict(
-            model=self.name,
+            model=model_name,
             stream=stream,
         )
 
@@ -997,6 +1006,14 @@ def send_completion(self, messages, functions, stream, temperature=None):
 
         self.github_copilot_token_to_open_ai_key(kwargs["extra_headers"])
 
+        if self.info.get("litellm_provider") == "litellm":
+            litellm_api_base = os.environ.get("LITELLM_API_BASE")
+            if litellm_api_base:
+                kwargs["api_base"] = litellm_api_base
+            litellm_api_key = os.environ.get("LITELLM_API_KEY")
+            if litellm_api_key:
+                kwargs["api_key"] = litellm_api_key
+
         res = litellm.completion(**kwargs)
         return hash_object, res
 
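The renaming logic in the first hunk, extracted as a self-contained sketch: the internal "litellm/" prefix is stripped, and the proxy's owned_by (when present) becomes the provider prefix that the completion call routes on. The values mirror the test fixture below.

```python
def rewrite_model_name(name, info):
    # Non-proxy models pass through untouched.
    if info.get("litellm_provider") != "litellm":
        return name
    model_id = name[len("litellm/"):]
    owned_by = info.get("owned_by")
    return f"{owned_by}/{model_id}" if owned_by else model_id

info = {"litellm_provider": "litellm", "owned_by": "openai"}
assert rewrite_model_name("litellm/bedrock-claude-opus-4.1", info) == (
    "openai/bedrock-claude-opus-4.1"
)
```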
7 changes: 6 additions & 1 deletion aider/onboarding.py
@@ -11,7 +11,7 @@
 
 import requests
 
-from aider import urls
+from aider import models, urls
 from aider.io import InputOutput
 
 
@@ -73,6 +73,11 @@ def try_to_select_default_model():
         if api_key_value:
             return model_name
 
+    if os.environ.get("LITELLM_API_BASE") and os.environ.get("LITELLM_API_KEY"):
+        for model_name, model_info in models.model_info_manager.local_model_metadata.items():
+            if model_info.get("litellm_provider") == "litellm":
+                return model_name
+
     return None
 
 
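The new fallback in try_to_select_default_model, sketched with a stand-in metadata cache: when both LiteLLM env vars are set, the first proxy model found in the discovery cache becomes the default (dict insertion order decides which one that is).

```python
import os

# Stand-in for models.model_info_manager.local_model_metadata after discovery.
local_model_metadata = {
    "litellm/bedrock-claude-opus-4.1": {"litellm_provider": "litellm"},
}
os.environ["LITELLM_API_BASE"] = "http://localhost:4000"
os.environ["LITELLM_API_KEY"] = "test-key"

default = None
if os.environ.get("LITELLM_API_BASE") and os.environ.get("LITELLM_API_KEY"):
    default = next(
        (
            name
            for name, info in local_model_metadata.items()
            if info.get("litellm_provider") == "litellm"
        ),
        None,
    )
print(default)  # litellm/bedrock-claude-opus-4.1
```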
70 changes: 69 additions & 1 deletion tests/basic/test_main.py
@@ -14,10 +14,25 @@
 from aider.coders import Coder
 from aider.dump import dump  # noqa: F401
 from aider.io import InputOutput
-from aider.main import check_gitignore, load_dotenv_files, main, setup_git
+from aider.main import (
+    check_gitignore,
+    discover_litellm_models,
+    load_dotenv_files,
+    main,
+    setup_git,
+)
 from aider.utils import GitTemporaryDirectory, IgnorantTemporaryDirectory, make_repo
 
 
+class DummyResponse:
+    def __init__(self, json_data, status_code=200):
+        self.json_data = json_data
+        self.status_code = status_code
+
+    def json(self):
+        return self.json_data
+
+
 class TestMain(TestCase):
     def setUp(self):
         self.original_env = os.environ.copy()
@@ -45,6 +60,59 @@ def tearDown(self):
         self.input_patcher.stop()
         self.webbrowser_patcher.stop()
 
+    def test_litellm_discover_models(self):
+        """
+        discover_litellm_models should register correct metadata taken from the
+        downloaded (and locally cached) models JSON payload.
+        """
+        models_payload = {
+            "data": [
+                {
+                    "id": "bedrock-claude-opus-4.1",
+                    "object": "model",
+                    "created": 1677610602,
+                    "owned_by": "openai",
+                }
+            ]
+        }
+
+        model_group_payload = {
+            "data": [
+                {
+                    "model_group": "bedrock-claude-opus-4.1",
+                    "providers": ["bedrock"],
+                    "max_input_tokens": 200000,
+                    "max_output_tokens": 32000,
+                    "input_cost_per_token": 0.000015,
+                    "output_cost_per_token": 0.000075,
+                }
+            ]
+        }
+
+        def mock_get(url, **kwargs):
+            if "/models" in url:
+                return DummyResponse(models_payload)
+            elif "/model_group/info" in url:
+                return DummyResponse(model_group_payload)
+            return DummyResponse({}, 404)
+
+        with patch("requests.get", mock_get):
+            os.environ["LITELLM_API_BASE"] = "http://localhost:4000"
+            os.environ["LITELLM_API_KEY"] = "test-key"
+
+            io = MagicMock()
+            discover_litellm_models(io)
+
+            from aider import models
+
+            info = models.model_info_manager.get_model_info("litellm/bedrock-claude-opus-4.1")
+
+            assert info["max_input_tokens"] == 200000
+            assert info["input_cost_per_token"] == 0.000015
+            assert info["output_cost_per_token"] == 0.000075
+            assert info["litellm_provider"] == "litellm"
+            assert info["owned_by"] == "openai"
+
     def test_main_with_empty_dir_no_files_on_command(self):
         main(["--no-git", "--exit", "--yes"], input=DummyInput(), output=DummyOutput())
 
Loading