12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- from typing import Optional
15
+ import copy
16
+ from typing import Any , Optional
16
17
17
18
from typing_extensions import override
18
19
20
+ from oumi .core .configs import GenerationParams , ModelParams
21
+ from oumi .core .types .conversation import Conversation
19
22
from oumi .inference .remote_inference_engine import RemoteInferenceEngine
20
23
21
24
@@ -33,3 +36,33 @@ def base_url(self) -> Optional[str]:
33
36
def api_key_env_varname (self ) -> Optional [str ]:
34
37
"""Return the default environment variable name for the OpenAI API key."""
35
38
return "OPENAI_API_KEY"
39
+
40
+ @override
41
+ def _convert_conversation_to_api_input (
42
+ self ,
43
+ conversation : Conversation ,
44
+ generation_params : GenerationParams ,
45
+ model_params : ModelParams ,
46
+ ) -> dict [str , Any ]:
47
+ """Converts a conversation to an OpenAI input.
48
+
49
+ Documentation: https://platform.openai.com/docs/api-reference/chat/create
50
+
51
+ Args:
52
+ conversation: The conversation to convert.
53
+ generation_params: Parameters for generation during inference.
54
+ model_params: Model parameters to use during inference.
55
+
56
+ Returns:
57
+ Dict[str, Any]: A dictionary representing the OpenAI input.
58
+ """
59
+ # o1-preview does NOT support logit_bias.
60
+ if model_params .model_name == "o1-preview" :
61
+ generation_params = copy .deepcopy (generation_params )
62
+ generation_params .logit_bias = {}
63
+
64
+ return super ()._convert_conversation_to_api_input (
65
+ conversation = conversation ,
66
+ generation_params = generation_params ,
67
+ model_params = model_params ,
68
+ )
0 commit comments