
Commit e4cdf8f

Fix some issues and bump version to v0.4.14

1 parent 26b00f6

File tree: 6 files changed (+32, -25 lines)

README.md

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ SynCode depends on HuggingFace [transformers](https://github.com/huggingface/tra
 
 | SynCode version | Required transformers version | Python version |
 | -------------- | ----------------------------- | -------------- |
-| `v0.4.13` (latest) | `v4.51.0` | 3.6 - 3.12 |
+| `v0.4.14` (latest) | `v4.51.3` | 3.6 - 3.12 |
 
 **Note:** Python 3.13 is not currently supported due to dependency constraints.
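
For reference, a quick way to check that an installed environment matches this row of the table (standard-library `importlib.metadata`; nothing SynCode-specific):

```python
from importlib.metadata import version

# After upgrading per the table above, this should print 0.4.14 and 4.51.3.
print("syncode:", version("syncode"))
print("transformers:", version("transformers"))
```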

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "syncode"
-version="0.4.13"
+version="0.4.14"
 requires-python = ">=3.6,<3.13"
 description = "Grammar-guided code generation tool"
 readme = "README.md"

setup.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 
 setuptools.setup(
     name="syncode",
-    version="0.4.13",
+    version="0.4.14",
     author="Shubham Ugare",
     author_email="shubhamugare@gmail.com",
     description="This package provides the tool for grammar augmented LLM generation.",

syncode/grammar_mask/grammar_constrainer.py

Lines changed: 24 additions & 18 deletions
@@ -120,11 +120,11 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->
         self._set_start_from(input_ids)
 
         input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=-1)
-        partial_code, remainder_bytes = self._get_partial_codes(input_ids)[0]
+        partial_output, remainder_bytes = self._get_partial_outputs(input_ids)[0]
 
-        res, skip = self._parse_partial_code(
+        res, skip = self._parse_partial_output(
             idx=0,
-            partial_code=partial_code,
+            partial_output=partial_output,
             remainder_bytes=remainder_bytes,
             accepted_generation=False
         )
@@ -142,7 +142,7 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->
         is_valid = self.dfa_mask_store.is_valid_prefix(res)
 
         if is_valid:
-            self._update_valid_state(partial_code, 0, res)
+            self._update_valid_state(partial_output, 0, res)
 
         return is_valid
 
@@ -163,11 +163,11 @@ def mask_scores(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) ->
             torch.FloatTensor: The masked scores.
         """
         self._set_start_from(input_ids) # start_from is used for choosing where the parsing should start
-        partial_codes = self._get_partial_codes(input_ids)
+        partial_outputs = self._get_partial_outputs(input_ids)
 
-        for idx, (partial_code, remainder_bytes) in enumerate(partial_codes):
+        for idx, (partial_output, remainder_bytes) in enumerate(partial_outputs):
             # 1. Parsing
-            res, skip = self._parse_partial_code(idx, partial_code, remainder_bytes, accepted_generation=True)
+            res, skip = self._parse_partial_output(idx, partial_output, remainder_bytes, accepted_generation=True)
             if skip: continue
 
             # 2. Computing the accept mask
@@ -187,23 +187,29 @@ def mask_scores(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) ->
 
         return scores
 
-    def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: bytes, accepted_generation=True) -> tuple[ParseResult, bool]:
+    def _parse_partial_output(
+        self,
+        idx: int,
+        partial_output: str,
+        remainder_bytes: bytes,
+        accepted_generation=True
+    ) -> tuple[ParseResult, bool]:
         """
         Parse the partial code and return the result.
         """
         skip = False
         res = None
 
         try:
-            res = self.inc_parser.get_acceptable_next_terminals(partial_code)
+            res = self.inc_parser.get_acceptable_next_terminals(partial_output)
 
             if len(remainder_bytes) > 0:
                 res.remainder_state = RemainderState.INCOMPLETE
                 res.remainder = res.remainder.encode('utf-8') + remainder_bytes
             else:
                 res.remainder = res.remainder.encode('utf-8')
 
-            self._update_valid_state(partial_code, idx, res)
+            self._update_valid_state(partial_output, idx, res)
         except Exception as e:
             if self.dev_mode == True and accepted_generation:
                 logger.info("-"*50)
@@ -213,45 +219,45 @@ def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: byte
             elif self.parse_failed == False and accepted_generation:
                 self.parse_failed = True
                 logger.info("-"*50)
-                logger.info(f"Parsing failed! Falling back to unconstrained decoding.\nException: {e}\nPartial code: {partial_code}\nParsed lexical tokens: {self.inc_parser.parsed_lexer_tokens}")
+                logger.info(f"Parsing failed! Falling back to unconstrained decoding.\nException: {e}\nPartial code: {partial_output}\nParsed lexical tokens: {self.inc_parser.parsed_lexer_tokens}")
                 logger.info("-"*50)
             skip = True
         return res, skip
 
-    def _get_partial_codes(self, input_ids: torch.LongTensor) -> list[(str, bytes)]:
+    def _get_partial_outputs(self, input_ids: torch.LongTensor) -> list[(str, bytes)]:
         """
         Get the partial codes for the input_ids and return the remainder bytes if the partial code is not a valid UTF-8 string.
         """
         output = []
         for idx in range(len(input_ids)):
             if self.parse_output_only:
-                partial_code, remainder_bytes = self._bytes_to_string(
+                partial_output, remainder_bytes = self._bytes_to_string(
                     self.byte_tokenizer.decode(
                         input_ids[idx, self.start_from:].tolist(), skip_special_tokens=True)
                 )
             else:
-                partial_code, remainder_bytes = self._bytes_to_string(
+                partial_output, remainder_bytes = self._bytes_to_string(
                     self.byte_tokenizer.decode(
                         input_ids[idx].tolist(), skip_special_tokens=True)
                 )
-            output.append((partial_code, remainder_bytes))
+            output.append((partial_output, remainder_bytes))
         return output
 
-    def _update_valid_state(self, partial_code: str, idx: int, r: ParseResult):
+    def _update_valid_state(self, partial_output: str, idx: int, r: ParseResult):
         """
         This a simple heuristic to cut off the generated output at the end of the function.
         TODO: Put this under a flag to enable/disable this heuristic.
         """
         if idx < len(self.function_ends):
             if r.function_end: # If the function end is not None, then the last valid state is the function end
                 if self.function_ends[idx] is None: self.function_ends[idx] = []
-                self.function_ends[idx].append(len(partial_code) - len(r.remainder))
+                self.function_ends[idx].append(len(partial_output) - len(r.remainder))
 
         if idx < len(self.last_valid_state):
             for accept_seq in r.accept_sequences:
                 # 'EOF' is special terminal since $END does not work with python
                 if accept_seq[0] == '$END' or accept_seq[0] == 'EOF':
-                    self.last_valid_state[idx] = len(partial_code) - len(r.remainder)
+                    self.last_valid_state[idx] = len(partial_output) - len(r.remainder)
 
     @staticmethod
     def _bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]:
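
The `remainder_bytes` handling above relies on `_bytes_to_string` splitting a decoded byte sequence into its longest valid UTF-8 prefix plus any trailing bytes of an incomplete multi-byte character. A minimal sketch of that idea (an illustration, not the exact SynCode implementation):

```python
def bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]:
    # A UTF-8 character is at most 4 bytes long, so only the last 3 bytes
    # can belong to a character that was cut off mid-generation.
    for cut in range(len(byte_sequence), max(len(byte_sequence) - 4, -1), -1):
        try:
            return byte_sequence[:cut].decode("utf-8"), byte_sequence[cut:]
        except UnicodeDecodeError:
            continue
    return "", byte_sequence

# bytes_to_string(b"abc\xe2\x82") == ("abc", b"\xe2\x82")  -- a truncated '€'
```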

syncode/mask_store/byte_tokenizer.py

Lines changed: 3 additions & 2 deletions
@@ -188,11 +188,12 @@ def __init__(self, tokenizer, vocab_type=None):
         # Cache special token IDs as a set for faster lookups
         self.special_token_ids = set(getattr(tokenizer, "all_special_ids", []))
 
+        # NOTE: This seems to be problematic in some cases where regular tokens like "\t" are treated as special tokens
         # Added tokens are typically special tokens
         # if added_tokens_decoder is not None self.tokenizer.added_tokens_decoder.keys()
         # to special_token_ids
-        if hasattr(tokenizer, "added_tokens_decoder"):
-            self.special_token_ids.update(tokenizer.added_tokens_decoder.keys())
+        # if hasattr(tokenizer, "added_tokens_decoder"):
+        #     self.special_token_ids.update(tokenizer.added_tokens_decoder.keys())
 
 
     @classmethod
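
The NOTE in the diff above refers to tokenizers whose `added_tokens_decoder` includes ordinary tokens (such as `"\t"`) alongside genuine special tokens, so unioning it into `special_token_ids` misclassified them. One way to inspect a given checkpoint for this (the model name is only a placeholder; any HuggingFace tokenizer works):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
special = set(tok.all_special_ids)
added = set(tok.added_tokens_decoder.keys())

# Non-empty output means the removed code would have marked ordinary
# added tokens as special for this tokenizer.
print(sorted(added - special))
```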

syncode/parsers/python_parser.py

Lines changed: 2 additions & 2 deletions
@@ -35,12 +35,12 @@ def _get_indentation(self, partial_code) -> int:
         return tab_len
 
     def get_acceptable_next_terminals(self, partial_code) -> ParseResult:
-        # Stores the sequence of tokens that the parser has seen in the order
-        interactive = self.interactive
         lexer_tokens, lexing_incomplete = self._lex_code(partial_code)
+        self.next_ac_terminals = self._accepts(self.interactive)
 
         # Restore the previous state of the parser
         self._restore_recent_parser_state(lexer_tokens)
+        interactive = self.interactive
 
         next_ac_indents = None
 
0 commit comments
