Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions litellm/llms/bedrock/base_aws_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,14 @@ def get_bedrock_model_id(
model_id = BaseAWSLLM._get_model_id_from_model_with_spec(
model_id, spec="openai"
)
elif provider == "qwen2" and "qwen2/" in model_id:
model_id = BaseAWSLLM._get_model_id_from_model_with_spec(
model_id, spec="qwen2"
)
elif provider == "qwen3" and "qwen3/" in model_id:
model_id = BaseAWSLLM._get_model_id_from_model_with_spec(
model_id, spec="qwen3"
)
return model_id

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,72 @@ def test_qwen2_provider_detection():
assert config is not None
assert isinstance(config, AmazonQwen2Config)


def test_qwen2_model_id_extraction_with_arn():
"""Test that model ID is correctly extracted from bedrock/qwen2/arn... paths"""
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM

# Test case: bedrock/qwen2/arn:aws:bedrock:us-east-1:123456789012:imported-model/test-qwen2
# The qwen2/ prefix should be stripped, leaving only the ARN for encoding
model = "qwen2/arn:aws:bedrock:us-east-1:123456789012:imported-model/test-qwen2"
provider = "qwen2"

result = BaseAWSLLM.get_bedrock_model_id(
optional_params={},
provider=provider,
model=model
)

# The result should NOT contain "qwen2/" - it should be stripped
assert "qwen2/" not in result
# The result should be URL-encoded ARN
assert "arn%3Aaws%3Abedrock" in result or "arn:aws:bedrock" in result


def test_qwen2_model_id_extraction_without_qwen2_prefix():
"""Test that model ID extraction doesn't strip qwen2/ when provider is not qwen2"""
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM

# Test case: just a model name without qwen2/ prefix
model = "arn:aws:bedrock:us-east-1:123456789012:imported-model/test-qwen2"
provider = "qwen2"

result = BaseAWSLLM.get_bedrock_model_id(
optional_params={},
provider=provider,
model=model
)

# Result should be encoded ARN
assert "arn" in result.lower() or "aws" in result.lower()


def test_qwen2_get_bedrock_model_id_with_various_formats():
"""Test get_bedrock_model_id with various Qwen2 model path formats"""
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM

test_cases = [
{
"model": "qwen2/arn:aws:bedrock:us-east-1:123456789012:imported-model/test-qwen2",
"provider": "qwen2",
"should_not_contain": "qwen2/",
"description": "Qwen2 imported model ARN"
},
{
"model": "bedrock/qwen2/arn:aws:bedrock:us-east-1:123456789012:imported-model/test-qwen2",
"provider": "qwen2",
"should_not_contain": "qwen2/",
"description": "Bedrock prefixed Qwen2 ARN"
}
]

for test_case in test_cases:
result = BaseAWSLLM.get_bedrock_model_id(
optional_params={},
provider=test_case["provider"],
model=test_case["model"]
)

assert test_case["should_not_contain"] not in result, \
f"Failed for {test_case['description']}: {test_case['should_not_contain']} found in {result}"

Loading