@@ -112,12 +112,13 @@ def test_init_with_all_params(self, boto_session):
112112 "endpoint_name" : "test-endpoint" ,
113113 "inference_component_name" : "test-component" ,
114114 "region_name" : "us-west-2" ,
115- "additional_args" : {"test_arg_name " : "test_arg_value " },
115+ "additional_args" : {"test_req_arg_name " : "test_req_arg_value " },
116116 }
117117 payload_config = {
118118 "stream" : False ,
119119 "max_tokens" : 1024 ,
120120 "temperature" : 0.7 ,
121+ "additional_args" : {"test_payload_arg_name" : "test_payload_arg_value" },
121122 }
122123 client_config = BotocoreConfig (user_agent_extra = "test-agent" )
123124
@@ -130,10 +131,11 @@ def test_init_with_all_params(self, boto_session):
130131
131132 assert model .endpoint_config ["endpoint_name" ] == "test-endpoint"
132133 assert model .endpoint_config ["inference_component_name" ] == "test-component"
133- assert model .endpoint_config ["additional_args" ]["test_arg_name " ] == "test_arg_value "
134+ assert model .endpoint_config ["additional_args" ]["test_req_arg_name " ] == "test_req_arg_value "
134135 assert model .payload_config ["stream" ] is False
135136 assert model .payload_config ["max_tokens" ] == 1024
136137 assert model .payload_config ["temperature" ] == 0.7
138+ assert model .payload_config ["additional_args" ]["test_payload_arg_name" ] == "test_payload_arg_value"
137139
138140 boto_session .client .assert_called_once_with (
139141 service_name = "sagemaker-runtime" ,
@@ -246,16 +248,24 @@ def test_format_request_with_additional_args(self, boto_session, endpoint_config
246248 endpoint_config_ext = {
247249 ** endpoint_config ,
248250 "additional_args" : {
249- "extra_key" : "extra_value" ,
251+ "extra_request_key" : "extra_request_value" ,
252+ },
253+ }
254+ payload_config_ext = {
255+ ** payload_config ,
256+ "additional_args" : {
257+ "extra_payload_key" : "extra_payload_value" ,
250258 },
251259 }
252260 model = SageMakerAIModel (
253261 boto_session = boto_session ,
254262 endpoint_config = endpoint_config_ext ,
255- payload_config = payload_config ,
263+ payload_config = payload_config_ext ,
256264 )
257265 request = model .format_request (messages )
258- assert request .get ("extra_key" ) == "extra_value"
266+ assert request .get ("extra_request_key" ) == "extra_request_value"
267+ payload = json .loads (request ["Body" ])
268+ assert payload .get ("extra_payload_key" ) == "extra_payload_value"
259269
260270 @pytest .mark .asyncio
261271 async def test_stream_with_streaming_enabled (self , sagemaker_client , model , messages ):
0 commit comments