Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] Update chat properties #1709

Merged
merged 1 commit into from
Mar 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# the specific language governing permissions and limitations under the License.
from typing import Optional, Union, List, Dict

from pydantic.v1 import BaseModel, Field, validator, root_validator
from pydantic import BaseModel, Field, validator, root_validator
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we ready to upgrade pydantic?

Copy link
Contributor Author

@xyang16 xyang16 Apr 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@frankfliu Sorry, I just saw your comment. We are already installing pydantic 2.6 in the Docker image, so this change does not upgrade it.



class ChatProperties(BaseModel):
Expand All @@ -22,25 +22,28 @@ class ChatProperties(BaseModel):
"""

messages: List[Dict[str, str]]
model: Optional[str] # UNUSED
frequency_penalty: Optional[float] = 0
logit_bias: Optional[dict] = None
logprobs: Optional[bool] = False
top_logprobs: Optional[int] # Currently only support 1
max_new_tokens: Optional[int] = Field(alias="max_tokens")
n: Optional[int] = 1 # Currently only support 1
presence_penalty: Optional[float] = 0
seed: Optional[int]
stop_sequences: Optional[Union[str, list]] = Field(alias="stop")
temperature: Optional[int] = 1
top_p: Optional[int] = 1
user: Optional[str]
model: Optional[str] = Field(default=None, exclude=True) # Unused
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = Field(default=None, exclude=True)
logprobs: Optional[bool] = Field(default=False, exclude=True)
top_logprobs: Optional[int] = Field(default=None,
serialization_alias="logprobs")
max_tokens: Optional[int] = Field(default=None,
serialization_alias="max_new_tokens")
n: Optional[int] = Field(default=1, exclude=True)
presence_penalty: Optional[float] = 0.0
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: Optional[bool] = False
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
user: Optional[str] = Field(default=None, exclude=True)

@validator('messages', pre=True)
def validate_messages(
cls, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
if messages is None:
return messages
return None

for message in messages:
if not ("role" in message and "content" in message):
Expand All @@ -52,17 +55,31 @@ def validate_messages(
@validator('frequency_penalty', pre=True)
def validate_frequency_penalty(cls, frequency_penalty: float) -> float:
if frequency_penalty is None:
return frequency_penalty
return None

frequency_penalty = float(frequency_penalty)
if frequency_penalty < -2.0 or frequency_penalty > 2.0:
raise ValueError("frequency_penalty must be between -2.0 and 2.0.")
return frequency_penalty

@validator('top_logprobs', pre=True)
def validate_top_logprobs(cls, top_logprobs: float) -> float:
@validator('logit_bias', pre=True)
def validate_logit_bias(
        cls, logit_bias: Optional[Dict[str, float]]
) -> Optional[Dict[str, float]]:
    """Validate that every ``logit_bias`` value lies in [-100.0, 100.0].

    :param logit_bias: mapping of token ids (as strings) to bias values,
        or ``None`` when the caller did not supply a bias map.
    :return: the unchanged mapping, or ``None`` when not provided.
    :raises ValueError: if any bias value is outside [-100.0, 100.0].
    """
    if logit_bias is None:
        return None

    # Only the bias range is checked here; token ids are not validated.
    # Iterate values() directly — the keys are not needed for the check.
    for bias in logit_bias.values():
        if bias < -100.0 or bias > 100.0:
            raise ValueError(
                "logit_bias value must be between -100 and 100.")
    return logit_bias

@validator('top_logprobs')
def validate_top_logprobs(cls, top_logprobs: int, values):
if top_logprobs is None:
return top_logprobs
return None

if not values.get('logprobs'):
return None

top_logprobs = int(top_logprobs)
if top_logprobs < 0 or top_logprobs > 20:
Expand All @@ -72,7 +89,7 @@ def validate_top_logprobs(cls, top_logprobs: float) -> float:
@validator('presence_penalty', pre=True)
def validate_presence_penalty(cls, presence_penalty: float) -> float:
if presence_penalty is None:
return presence_penalty
return None

presence_penalty = float(presence_penalty)
if presence_penalty < -2.0 or presence_penalty > 2.0:
Expand All @@ -82,9 +99,9 @@ def validate_presence_penalty(cls, presence_penalty: float) -> float:
@validator('temperature', pre=True)
def validate_temperature(cls, temperature: float) -> float:
if temperature is None:
return temperature
return None

temperature = float(temperature)
if temperature < 0 or temperature > 2:
if temperature < 0.0 or temperature > 2.0:
raise ValueError("temperature must be between 0 and 2.")
return temperature
16 changes: 6 additions & 10 deletions engines/python/setup/djl_python/chat_completions/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,11 @@ def parse_chat_completions_request(inputs: map, is_rolling_batch: bool,
f"Cannot provide chat completion for tokenizer: {tokenizer.__class__}, "
f"please ensure that your tokenizer supports chat templates.")
chat_params = ChatProperties(**inputs)
_inputs = tokenizer.apply_chat_template(chat_params.messages,
tokenize=False)
_param = chat_params.dict(exclude_unset=True,
exclude={
'messages', 'model', 'logit_bias',
'top_logprobs', 'n', 'user'
})
_param["details"] = True
_param["output_formatter"] = "jsonlines_chat" if inputs.get(
"stream", False) else "json_chat"
_param = chat_params.model_dump(by_alias=True, exclude_unset=True)
_messages = _param.pop("messages")
_inputs = tokenizer.apply_chat_template(_messages, tokenize=False)
_param["details"] = True # Enable details for chat completions
_param[
"output_formatter"] = "jsonlines_chat" if chat_params.stream else "json_chat"

return _inputs, _param
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def translate_triton_params(self, parameters: dict) -> dict:
parameters["temperature"] = parameters.get("temperature", 0.8)
if "length_penalty" in parameters.keys():
parameters['len_penalty'] = parameters.pop('length_penalty')
parameters["streaming"] = parameters.get("streaming", True)
parameters["streaming"] = parameters.pop(
"stream", parameters.get("streaming", True))
return parameters

@stop_on_any_exception
Expand Down
60 changes: 36 additions & 24 deletions engines/python/setup/djl_python/tests/test_properties_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,32 +617,34 @@ def test_chat_min_configs():
chat_configs = ChatProperties(**properties)
self.assertEqual(chat_configs.messages, properties["messages"])
self.assertIsNone(chat_configs.model)
self.assertEqual(chat_configs.frequency_penalty, 0)
self.assertEqual(chat_configs.frequency_penalty, 0.0)
self.assertIsNone(chat_configs.logit_bias)
self.assertFalse(chat_configs.logprobs)
self.assertIsNone(chat_configs.top_logprobs)
self.assertIsNone(chat_configs.max_new_tokens)
self.assertIsNone(chat_configs.max_tokens)
self.assertEqual(chat_configs.n, 1)
self.assertEqual(chat_configs.presence_penalty, 0)
self.assertEqual(chat_configs.presence_penalty, 0.0)
self.assertIsNone(chat_configs.seed)
self.assertIsNone(chat_configs.stop_sequences)
self.assertEqual(chat_configs.temperature, 1)
self.assertEqual(chat_configs.top_p, 1)
self.assertIsNone(chat_configs.stop)
self.assertFalse(chat_configs.stream)
self.assertEqual(chat_configs.temperature, 1.0)
self.assertEqual(chat_configs.top_p, 1.0)
self.assertIsNone(chat_configs.user)

def test_chat_all_configs():
properties["model"] = "model"
properties["frequency_penalty"] = "1"
properties["logit_bias"] = {"2435": -100, "640": -100}
properties["frequency_penalty"] = "1.0"
properties["logit_bias"] = {"2435": -100.0, "640": -100.0}
properties["logprobs"] = "false"
properties["top_logprobs"] = "3"
properties["max_tokens"] = "256"
properties["n"] = "1"
properties["presence_penalty"] = "1"
properties["presence_penalty"] = "1.0"
properties["seed"] = "123"
properties["stop"] = "stop"
properties["temperature"] = "1"
properties["top_p"] = "3"
properties["stop"] = ["stop"]
properties["stream"] = "true"
properties["temperature"] = "1.0"
properties["top_p"] = "3.0"
properties["user"] = "user"

chat_configs = ChatProperties(**properties)
Expand All @@ -652,18 +654,18 @@ def test_chat_all_configs():
float(properties['frequency_penalty']))
self.assertEqual(chat_configs.logit_bias, properties['logit_bias'])
self.assertFalse(chat_configs.logprobs)
self.assertEqual(chat_configs.top_logprobs,
int(properties['top_logprobs']))
self.assertEqual(chat_configs.max_new_tokens,
self.assertIsNone(chat_configs.top_logprobs)
self.assertEqual(chat_configs.max_tokens,
int(properties['max_tokens']))
self.assertEqual(chat_configs.n, int(properties['n']))
self.assertEqual(chat_configs.presence_penalty,
float(properties['presence_penalty']))
self.assertEqual(chat_configs.seed, int(properties['seed']))
self.assertEqual(chat_configs.stop_sequences, properties['stop'])
self.assertEqual(chat_configs.stop, properties['stop'])
self.assertTrue(chat_configs.stream)
self.assertEqual(chat_configs.temperature,
int(properties['temperature']))
self.assertEqual(chat_configs.top_p, int(properties['top_p']))
float(properties['temperature']))
self.assertEqual(chat_configs.top_p, float(properties['top_p']))
self.assertEqual(chat_configs.user, properties['user'])

def test_invalid_configs():
Expand All @@ -678,34 +680,44 @@ def test_invalid_configs():
ChatProperties(**test_properties)

test_properties = dict(properties)
test_properties["frequency_penalty"] = "-3"
test_properties["frequency_penalty"] = "-3.0"
with self.assertRaises(ValueError):
ChatProperties(**test_properties)
test_properties["frequency_penalty"] = "3.0"
with self.assertRaises(ValueError):
ChatProperties(**test_properties)

test_properties = dict(properties)
test_properties["logit_bias"] = {"2435": -100.0, "640": 200.0}
with self.assertRaises(ValueError):
ChatProperties(**test_properties)
test_properties["frequency_penalty"] = "3"
test_properties["logit_bias"] = {"2435": -200.0, "640": 100.0}
with self.assertRaises(ValueError):
ChatProperties(**test_properties)

test_properties = dict(properties)
test_properties["logprobs"] = "true"
test_properties["top_logprobs"] = "-1"
with self.assertRaises(ValueError):
ChatProperties(**test_properties)
test_properties["logprobs"] = "true"
test_properties["top_logprobs"] = "30"
with self.assertRaises(ValueError):
ChatProperties(**test_properties)

test_properties = dict(properties)
test_properties["presence_penalty"] = "-3"
test_properties["presence_penalty"] = "-3.0"
with self.assertRaises(ValueError):
ChatProperties(**test_properties)
test_properties["presence_penalty"] = "3"
test_properties["presence_penalty"] = "3.0"
with self.assertRaises(ValueError):
ChatProperties(**test_properties)

test_properties = dict(properties)
test_properties["temperature"] = "-1"
test_properties["temperature"] = "-1.0"
with self.assertRaises(ValueError):
ChatProperties(**test_properties)
test_properties["temperature"] = "3"
test_properties["temperature"] = "3.0"
with self.assertRaises(ValueError):
ChatProperties(**test_properties)

Expand Down
Loading