Skip to content

Commit

Permalink
API: Add banned_strings
Browse files Browse the repository at this point in the history
From exllamav2: banned_strings is a list of strings that the generator will refuse to output. As soon as a partial match begins, a checkpoint is saved that the generator can rewind to if need be. Subsequent tokens are then held until the match is resolved: if the full banned string is matched, the held tokens are discarded; if not, they are emitted.
  • Loading branch information
DocShotgun committed May 10, 2024
1 parent a1df226 commit c0b631b
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 0 deletions.
18 changes: 18 additions & 0 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,7 @@ def generate_gen_sync(
)

stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])
add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
logit_bias = kwargs.get("logit_bias")
Expand Down Expand Up @@ -960,6 +961,22 @@ def generate_gen_sync(
)
min_tokens = 0

# Check if banned_strings is supported
# TODO: Remove when a new version of ExllamaV2 is released
if banned_strings:
begin_stream_signature = signature(self.generator.begin_stream_ex)

try:
_bound_vars = begin_stream_signature.bind_partial(
banned_strings=[]
)
begin_stream_args["banned_strings"] = banned_strings
except TypeError:
logger.warning(
"banned_strings is not supported by the currently "
"installed ExLlamaV2 version."
)

# Log generation options to console
# Some options are too large, so log the args instead
log_generation_params(
Expand All @@ -979,6 +996,7 @@ def generate_gen_sync(
logprobs=request_logprobs,
stop_conditions=stop_conditions,
banned_tokens=banned_tokens,
banned_strings=banned_strings,
logit_bias=logit_bias,
)

Expand Down
9 changes: 9 additions & 0 deletions common/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ class BaseSamplerRequest(BaseModel):
default_factory=lambda: get_default_sampler_value("stop", [])
)

banned_strings: Optional[Union[str, List[str]]] = Field(
default_factory=lambda: get_default_sampler_value("banned_strings", [])
)

token_healing: Optional[bool] = Field(
default_factory=lambda: get_default_sampler_value("token_healing", False)
)
Expand Down Expand Up @@ -257,6 +261,10 @@ def to_gen_params(self, **kwargs):
if self.stop and isinstance(self.stop, str):
self.stop = [self.stop]

# Convert banned_strings to an array of strings
if self.banned_strings and isinstance(self.banned_strings, str):
self.banned_strings = [self.banned_strings]

# Convert string banned tokens to an integer list
if self.banned_tokens and isinstance(self.banned_tokens, str):
self.banned_tokens = [
Expand All @@ -268,6 +276,7 @@ def to_gen_params(self, **kwargs):
"min_tokens": self.min_tokens,
"generate_window": self.generate_window,
"stop": self.stop,
"banned_strings": self.banned_strings,
"add_bos_token": self.add_bos_token,
"ban_eos_token": self.ban_eos_token,
"skip_special_tokens": self.skip_special_tokens,
Expand Down
4 changes: 4 additions & 0 deletions sampler_overrides/sample_preset.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ stop:
override: []
force: false
additive: false
banned_strings:
override: []
force: false
additive: false
token_healing:
override: false
force: false
Expand Down

0 comments on commit c0b631b

Please sign in to comment.