1+ from __future__ import annotations
2+
13from abc import ABC , abstractmethod
2- from typing import Optional , List , Dict , Any
4+ from typing import List , Dict , Any
35
46from ai21 .models import Penalty , CompletionsResponse
7+ from ai21 .types import NOT_GIVEN , NotGiven
8+ from ai21 .utils .typing import remove_not_given
59
610
711class Completion (ABC ):
@@ -13,18 +17,18 @@ def create(
1317 model : str ,
1418 prompt : str ,
1519 * ,
16- max_tokens : int = 64 ,
17- num_results : int = 1 ,
18- min_tokens = 0 ,
19- temperature = 0.7 ,
20- top_p = 1 ,
21- top_k_return = 0 ,
22- custom_model : Optional [ str ] = None ,
23- stop_sequences : Optional [ List [str ]] = () ,
24- frequency_penalty : Optional [ Penalty ] = None ,
25- presence_penalty : Optional [ Penalty ] = None ,
26- count_penalty : Optional [ Penalty ] = None ,
27- epoch : Optional [ int ] = None ,
20+ max_tokens : int | NotGiven = NOT_GIVEN ,
21+ num_results : int | NotGiven = NOT_GIVEN ,
22+ min_tokens : int | NotGiven = NOT_GIVEN ,
23+ temperature : float | NOT_GIVEN = NOT_GIVEN ,
24+ top_p : float | NotGiven = NOT_GIVEN ,
25+ top_k_return : int | NotGiven = NOT_GIVEN ,
26+ custom_model : str | NotGiven = NOT_GIVEN ,
27+ stop_sequences : List [str ] | NotGiven = NOT_GIVEN ,
28+ frequency_penalty : Penalty | NotGiven = NOT_GIVEN ,
29+ presence_penalty : Penalty | NotGiven = NOT_GIVEN ,
30+ count_penalty : Penalty | NotGiven = NOT_GIVEN ,
31+ epoch : int | NotGiven = NOT_GIVEN ,
2832 ** kwargs ,
2933 ) -> CompletionsResponse :
3034 """
@@ -54,32 +58,34 @@ def _create_body(
5458 self ,
5559 model : str ,
5660 prompt : str ,
57- max_tokens : Optional [ int ] ,
58- num_results : Optional [ int ] ,
59- min_tokens : Optional [ int ] ,
60- temperature : Optional [ float ] ,
61- top_p : Optional [ int ] ,
62- top_k_return : Optional [ int ] ,
63- custom_model : Optional [ str ] ,
64- stop_sequences : Optional [ List [str ]] ,
65- frequency_penalty : Optional [ Penalty ] ,
66- presence_penalty : Optional [ Penalty ] ,
67- count_penalty : Optional [ Penalty ] ,
68- epoch : Optional [ int ] ,
61+ max_tokens : int | NotGiven ,
62+ num_results : int | NotGiven ,
63+ min_tokens : int | NotGiven ,
64+ temperature : float | NotGiven ,
65+ top_p : float | NotGiven ,
66+ top_k_return : int | NotGiven ,
67+ custom_model : str | NotGiven ,
68+ stop_sequences : List [str ] | NotGiven ,
69+ frequency_penalty : Penalty | NotGiven ,
70+ presence_penalty : Penalty | NotGiven ,
71+ count_penalty : Penalty | NotGiven ,
72+ epoch : int | NotGiven ,
6973 ):
70- return {
71- "model" : model ,
72- "customModel" : custom_model ,
73- "prompt" : prompt ,
74- "maxTokens" : max_tokens ,
75- "numResults" : num_results ,
76- "minTokens" : min_tokens ,
77- "temperature" : temperature ,
78- "topP" : top_p ,
79- "topKReturn" : top_k_return ,
80- "stopSequences" : stop_sequences or [],
81- "frequencyPenalty" : None if frequency_penalty is None else frequency_penalty .to_dict (),
82- "presencePenalty" : None if presence_penalty is None else presence_penalty .to_dict (),
83- "countPenalty" : None if count_penalty is None else count_penalty .to_dict (),
84- "epoch" : epoch ,
85- }
74+ return remove_not_given (
75+ {
76+ "model" : model ,
77+ "customModel" : custom_model ,
78+ "prompt" : prompt ,
79+ "maxTokens" : max_tokens ,
80+ "numResults" : num_results ,
81+ "minTokens" : min_tokens ,
82+ "temperature" : temperature ,
83+ "topP" : top_p ,
84+ "topKReturn" : top_k_return ,
85+ "stopSequences" : stop_sequences ,
86+ "frequencyPenalty" : NOT_GIVEN if frequency_penalty is NOT_GIVEN else frequency_penalty .to_dict (),
87+ "presencePenalty" : NOT_GIVEN if presence_penalty is NOT_GIVEN else presence_penalty .to_dict (),
88+ "countPenalty" : NOT_GIVEN if count_penalty is NOT_GIVEN else count_penalty .to_dict (),
89+ "epoch" : epoch ,
90+ }
91+ )
0 commit comments