Amazon SageMaker AI Model Provider implementation #30
Conversation
Here is a PR for the OpenAI model provider. It also introduces a base class from which we can derive other OpenAI-compatible providers. If we ask that customers host their models on SageMaker with OpenAI compatibility, then we can reduce a lot of code duplication (see, for example, how …)
This is very interesting @pgrayy, thank you for bringing it to my attention. What do you think we should do here? Do you want to merge your OpenAI PR first, and then I build a class derived from the …
Updated implementation based on the suggestions in this PR. Please re-review and let me know if you need me to change more code. I've also run the tests and everything seems to work just fine :)
We now have the OpenAI provider PR merged. I think it would be worth updating the SageMaker provider to use the new base class. The code should look similar to:

```python
import logging
from typing import Any, Iterable, Optional, TypedDict, cast

import boto3
from botocore.config import Config as BotocoreConfig
from typing_extensions import Unpack, override

logger = logging.getLogger(__name__)


# OpenAIModel is the base class introduced in the merged OpenAI provider PR.
class SageMakerModel(OpenAIModel):

    class SageMakerConfig(TypedDict, total=False):
        endpoint_name: str
        inference_component_name: Optional[str]
        model_id: str
        params: Optional[dict[str, Any]]

    def __init__(
        self,
        boto_session: Optional[boto3.Session] = None,
        boto_client_config: Optional[BotocoreConfig] = None,
        region_name: Optional[str] = None,
        **model_config: Unpack[SageMakerConfig],
    ) -> None:
        self.config = dict(model_config)
        logger.debug("config=<%s> | initializing", self.config)

        boto_session = boto_session or boto3.Session(region_name=region_name)
        # Create the runtime client from the resolved session.
        self.client = boto_session.client(
            service_name="sagemaker-runtime",
            config=boto_client_config,
        )

    @override
    def update_config(self, **model_config: Unpack[SageMakerConfig]) -> None:
        self.config.update(model_config)

    @override
    def get_config(self) -> SageMakerConfig:
        return cast(SageMakerModel.SageMakerConfig, self.config)

    @override
    def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
        # Yield events in the format that the base OpenAIModel.format_chunk expects.
        ...
```
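For context, caller-side usage might look roughly like this. This is a minimal sketch; the region, endpoint name, and parameter values are placeholders, not values from this PR:

```python
# Hypothetical usage sketch: endpoint name and params are illustrative only.
model = SageMakerModel(
    region_name="us-east-1",
    endpoint_name="my-openai-compatible-endpoint",
    params={"max_tokens": 512, "temperature": 0.7},
)

# Config can be adjusted after construction via the overridden accessors.
model.update_config(params={"temperature": 0.2})
print(model.get_config())
```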
Thanks @pgrayy. I've done some tests, and obtaining a response from SageMaker AI in completions format does not seem very straightforward. I'd suggest we go ahead with the current implementation, and I will work with the SageMaker AI service team to figure out the best approach to improve the implementation and support OpenAI completions as the response format from the endpoint.
Can you elaborate on this? SageMaker is a custom model hosting solution, so you should be able to conform to any format. What challenges are you seeing in implementing your handler to return OpenAI-compatible payloads? I can help do some testing on my end.
If you could test, that would be great! The problem is how the streamed response comes back from the DJL container.
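For reference, the kind of adaptation under discussion might look roughly like the sketch below. It assumes the container emits newline-delimited JSON chunks over the SageMaker response stream; the endpoint name, chunk schema, and framing are assumptions, not details confirmed in this thread (the framing of DJL output is exactly the open question above):

```python
import json
from typing import Any, Iterable

import boto3


def stream_chunks(endpoint_name: str, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
    """Sketch: read a SageMaker response stream and yield parsed JSON chunks.

    Assumes newline-delimited JSON framing; real DJL containers may frame
    the stream differently.
    """
    client = boto3.client("sagemaker-runtime")
    response = client.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        Body=json.dumps(request),
        ContentType="application/json",
    )

    buffer = b""
    for event in response["Body"]:
        # Each event carries a PayloadPart of raw bytes; a JSON chunk may be
        # split across parts, so buffer until a full line is available.
        buffer += event.get("PayloadPart", {}).get("Bytes", b"")
        while b"\n" in buffer:
            line, buffer = buffer.split(b"\n", 1)
            if line.strip():
                yield json.loads(line)
```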
Closed this pull request to work on a new implementation based on the OpenAI model provider. Refer to #176.
Description
Support for Amazon SageMaker AI endpoints as Model Provider
Related Issues
PR #16
Documentation PR
[Link to related associated PR in the agent-docs repo]
Type of Change
New feature
Testing
Yes
Checklist
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.