Commit 2d1f95f

Merge pull request #190 from structuredllm/gemma
Fix Phi-4 issue
2 parents 1162fac + 61fafe8 commit 2d1f95f

File tree

2 files changed: +14 −1 lines changed

syncode/language_model.py

Lines changed: 7 additions & 1 deletion

@@ -111,7 +111,8 @@ def generate_grammar_constrained_completion(
         stop_criteria = []

         # Generate completions
-        if self.opp and (gen_mode == GenerationMode.SAMPLE or gen_mode == GenerationMode.GREEDY_SEARCH) and batch_size == 1: # Use our own implementation for greedy search and sampling
+        if self.opp and (gen_mode == GenerationMode.SAMPLE or gen_mode == GenerationMode.GREEDY_SEARCH) and batch_size == 1:
+            # Use our own implementation for greedy search and sampling
             generated_ids = self._generate(
                 inputs,
                 gen_config,

@@ -239,6 +240,11 @@ def _generate(
             # (the clone itself is always small)
             next_token_scores = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=token_ids.device)

+            if len(next_token_scores.shape) == 3:
+                # FIXME: This is a strange behaviour for some models like Phi-4
+                # We expect next_token_scores to be of shape (batch_size, vocab_size)
+                next_token_scores = next_token_scores[:, -1, :]
+
             if grammar_decoder is not None:
                 next_token = self._get_next_token(gen_mode, token_ids, logits_processor, next_token_scores)
                 is_valid = grammar_decoder.is_valid(token_ids, next_token)
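For context, a minimal standalone sketch of what the added guard does (this is not SynCode's actual code; the vocabulary size and shapes are hypothetical): if the sliced logits unexpectedly keep three dimensions, indexing the last position again restores the expected (batch_size, vocab_size) shape.

import torch

def normalize_next_token_scores(next_token_scores: torch.Tensor) -> torch.Tensor:
    # Expected shape: (batch_size, vocab_size). Per this commit, some models
    # (e.g. Phi-4) have been observed to leave an extra sequence dimension here.
    if len(next_token_scores.shape) == 3:
        # Keep only the scores for the last position.
        next_token_scores = next_token_scores[:, -1, :]
    return next_token_scores

scores_2d = torch.randn(1, 32064)     # well-formed: (batch_size, vocab_size)
scores_3d = torch.randn(1, 4, 32064)  # unexpected extra sequence dimension
assert normalize_next_token_scores(scores_2d).shape == (1, 32064)
assert normalize_next_token_scores(scores_3d).shape == (1, 32064)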

syncode/mask_store/byte_tokenizer.py

Lines changed: 7 additions & 0 deletions

@@ -188,6 +188,13 @@ def __init__(self, tokenizer, vocab_type=None):
         # Cache special token IDs as a set for faster lookups
         self.special_token_ids = set(getattr(tokenizer, "all_special_ids", []))

+        # Added tokens are typically special tokens:
+        # if the tokenizer exposes added_tokens_decoder, add its keys
+        # to special_token_ids
+        if hasattr(tokenizer, "added_tokens_decoder"):
+            self.special_token_ids.update(tokenizer.added_tokens_decoder.keys())
+
+
     @classmethod
     def from_pretrained(cls, model_id, vocab_type=None):
         """
