Skip to content

Commit b66ff54

Browse files
authored
[Remote Inference] Supported params are ignored (#1562)
1 parent 7af96df commit b66ff54

21 files changed

+219
-100
lines changed

src/oumi/core/inference/base_inference_engine.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,20 @@ def infer(
7979
"Only one of input or inference_config.input_path should be provided."
8080
)
8181

82-
if inference_config and inference_config.generation:
83-
generation_params = inference_config.generation
84-
self._check_unsupported_params(generation_params)
85-
else:
86-
generation_params = self._generation_params
82+
# Ensure the inference config has up-to-date generation parameters.
83+
if inference_config:
84+
if inference_config.generation:
85+
self._check_unsupported_params(inference_config.generation)
86+
elif self._generation_params:
87+
inference_config = copy.deepcopy(inference_config)
88+
inference_config.generation = self._generation_params
89+
90+
# Warn the user: They provided an inference config without generation
91+
# params, so what was the point of providing it in the first place?
92+
logger.warning(
93+
"No generation parameters provided in the inference config. Using "
94+
"the generation parameters that the engine was initialized with."
95+
)
8796

8897
if input is not None:
8998
return self.infer_online(input, inference_config)

src/oumi/inference/anthropic_inference_engine.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from typing_extensions import override
1818

19-
from oumi.core.configs import GenerationParams, RemoteParams
19+
from oumi.core.configs import GenerationParams, ModelParams, RemoteParams
2020
from oumi.core.types.conversation import Conversation, Message, Role
2121
from oumi.inference.remote_inference_engine import RemoteInferenceEngine
2222
from oumi.utils.logging import logger
@@ -55,7 +55,10 @@ def api_key_env_varname(self) -> Optional[str]:
5555

5656
@override
5757
def _convert_conversation_to_api_input(
58-
self, conversation: Conversation, generation_params: GenerationParams
58+
self,
59+
conversation: Conversation,
60+
generation_params: GenerationParams,
61+
model_params: ModelParams,
5962
) -> dict[str, Any]:
6063
"""Converts a conversation to an Anthropic API input.
6164
@@ -68,6 +71,7 @@ def _convert_conversation_to_api_input(
6871
Args:
6972
conversation: The Oumi Conversation object to convert.
7073
generation_params: Parameters for text generation.
74+
model_params: Model parameters to use during inference.
7175
7276
Returns:
7377
Dict[str, Any]: A dictionary containing the formatted input for the
@@ -98,7 +102,7 @@ def _convert_conversation_to_api_input(
98102
# Build request body
99103
# See https://docs.anthropic.com/claude/reference/messages_post
100104
body = {
101-
"model": self._model,
105+
"model": model_params.model_name,
102106
"messages": self._get_list_of_message_json_dicts(
103107
messages, group_adjacent_same_role_turns=True
104108
),

src/oumi/inference/gcp_inference_engine.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import pydantic
1919
from typing_extensions import override
2020

21-
from oumi.core.configs import GenerationParams, RemoteParams
21+
from oumi.core.configs import GenerationParams, ModelParams, RemoteParams
2222
from oumi.core.configs.params.guided_decoding_params import GuidedDecodingParams
2323
from oumi.core.types.conversation import Conversation
2424
from oumi.inference.remote_inference_engine import RemoteInferenceEngine
@@ -73,7 +73,10 @@ def _get_request_headers(
7373

7474
@override
7575
def _convert_conversation_to_api_input(
76-
self, conversation: Conversation, generation_params: GenerationParams
76+
self,
77+
conversation: Conversation,
78+
generation_params: GenerationParams,
79+
model_params: ModelParams,
7780
) -> dict[str, Any]:
7881
"""Converts a conversation to an OpenAI input.
7982
@@ -82,12 +85,13 @@ def _convert_conversation_to_api_input(
8285
Args:
8386
conversation: The conversation to convert.
8487
generation_params: Parameters for generation during inference.
88+
model_params: Model parameters to use during inference.
8589
8690
Returns:
8791
Dict[str, Any]: A dictionary representing the Vertex input.
8892
"""
8993
api_input = {
90-
"model": self._model,
94+
"model": model_params.model_name,
9195
"messages": self._get_list_of_message_json_dicts(
9296
conversation.messages, group_adjacent_same_role_turns=True
9397
),

src/oumi/inference/gemini_inference_engine.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from typing_extensions import override
1818

19-
from oumi.core.configs import GenerationParams
19+
from oumi.core.configs import GenerationParams, ModelParams
2020
from oumi.core.types.conversation import Conversation
2121
from oumi.inference.gcp_inference_engine import (
2222
_convert_guided_decoding_config_to_api_input,
@@ -37,7 +37,10 @@ class GoogleGeminiInferenceEngine(RemoteInferenceEngine):
3737

3838
@override
3939
def _convert_conversation_to_api_input(
40-
self, conversation: Conversation, generation_params: GenerationParams
40+
self,
41+
conversation: Conversation,
42+
generation_params: GenerationParams,
43+
model_params: ModelParams,
4144
) -> dict[str, Any]:
4245
"""Converts a conversation to an Gemini API input.
4346
@@ -46,12 +49,13 @@ def _convert_conversation_to_api_input(
4649
Args:
4750
conversation: The conversation to convert.
4851
generation_params: Parameters for generation during inference.
52+
model_params: Model parameters to use during inference.
4953
5054
Returns:
5155
Dict[str, Any]: A dictionary representing the Gemini input.
5256
"""
5357
api_input = {
54-
"model": self._model,
58+
"model": model_params.model_name,
5559
"messages": self._get_list_of_message_json_dicts(
5660
conversation.messages, group_adjacent_same_role_turns=True
5761
),

src/oumi/inference/openai_inference_engine.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Optional
15+
import copy
16+
from typing import Any, Optional
1617

1718
from typing_extensions import override
1819

20+
from oumi.core.configs import GenerationParams, ModelParams
21+
from oumi.core.types.conversation import Conversation
1922
from oumi.inference.remote_inference_engine import RemoteInferenceEngine
2023

2124

@@ -33,3 +36,33 @@ def base_url(self) -> Optional[str]:
3336
def api_key_env_varname(self) -> Optional[str]:
3437
"""Return the default environment variable name for the OpenAI API key."""
3538
return "OPENAI_API_KEY"
39+
40+
@override
41+
def _convert_conversation_to_api_input(
42+
self,
43+
conversation: Conversation,
44+
generation_params: GenerationParams,
45+
model_params: ModelParams,
46+
) -> dict[str, Any]:
47+
"""Converts a conversation to an OpenAI input.
48+
49+
Documentation: https://platform.openai.com/docs/api-reference/chat/create
50+
51+
Args:
52+
conversation: The conversation to convert.
53+
generation_params: Parameters for generation during inference.
54+
model_params: Model parameters to use during inference.
55+
56+
Returns:
57+
Dict[str, Any]: A dictionary representing the OpenAI input.
58+
"""
59+
# o1-preview does NOT support logit_bias.
60+
if model_params.model_name == "o1-preview":
61+
generation_params = copy.deepcopy(generation_params)
62+
generation_params.logit_bias = {}
63+
64+
return super()._convert_conversation_to_api_input(
65+
conversation=conversation,
66+
generation_params=generation_params,
67+
model_params=model_params,
68+
)

0 commit comments

Comments (0)