Skip to content

Commit 0f7c501

Browse files
committed
Update the model for the Python notebook
1 parent 53c435f commit 0f7c501

File tree

7 files changed

+74
-26
lines changed

7 files changed

+74
-26
lines changed

notebooks/example_python.ipynb

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,40 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 13,
5+
"execution_count": 12,
66
"id": "23530ae1",
77
"metadata": {},
88
"outputs": [
99
{
1010
"name": "stdout",
1111
"output_type": "stream",
1212
"text": [
13-
"Loading Lark base parser from cache: cache/parsers/python_lalr_parser.pkl\n"
13+
"[2025-06-03 16:10:30,358-root] - Loading model meta-llama/Llama-3.2-1B with device:cuda, device_map:auto, torch_dtype:torch.bfloat16\n",
14+
"[2025-06-03 16:10:31,670-root] - Loading model meta-llama/Llama-3.2-1B with device:cuda, device_map:auto, torch_dtype:torch.bfloat16\n"
1415
]
1516
}
1617
],
1718
"source": [
19+
"import sys\n",
20+
"sys.path.append('..') # Assuming we are in the root directory\n",
1821
"from syncode import Syncode\n",
1922
"import warnings\n",
2023
"warnings.filterwarnings('ignore')\n",
2124
"\n",
22-
"model_name = \"WizardLM/WizardCoder-1B-V1.0\"\n",
25+
"model_name = \"meta-llama/Llama-3.2-1B\"\n",
2326
"\n",
2427
"# Load the unconstrained original model\n",
2528
"llm = Syncode(model = model_name, mode='original', max_new_tokens=200)\n",
2629
"\n",
2730
"# Load the Syncode augmented model\n",
28-
"syn_llm = Syncode(model = model_name, mode='grammar_mask', grammar='python', parse_output_only=False)"
31+
"syn_llm = Syncode(\n",
32+
" model = model_name, \n",
33+
" mode='grammar_mask', \n",
34+
" grammar='python', \n",
35+
" parse_output_only=False,\n",
36+
" indent=True,\n",
37+
" opp=False\n",
38+
" )"
2939
]
3040
},
3141
{
@@ -39,10 +49,17 @@
3949
},
4050
{
4151
"cell_type": "code",
42-
"execution_count": 16,
52+
"execution_count": 8,
4353
"id": "490cddb3",
4454
"metadata": {},
4555
"outputs": [
56+
{
57+
"name": "stderr",
58+
"output_type": "stream",
59+
"text": [
60+
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
61+
]
62+
},
4663
{
4764
"name": "stdout",
4865
"output_type": "stream",
@@ -51,10 +68,11 @@
5168
" '''Return if prime'''\n",
5269
" if n < 2:\n",
5370
" return False\n",
54-
" for i in range(2, int(n**0.5)+1):\n",
71+
" for i in range(2, n):\n",
5572
" if n % i == 0:\n",
5673
" return False\n",
57-
" return True\n"
74+
" return True\n",
75+
"\n"
5876
]
5977
},
6078
{
@@ -64,7 +82,7 @@
6482
"traceback": [
6583
"Traceback \u001b[0;36m(most recent call last)\u001b[0m:\n",
6684
"\u001b[0m File \u001b[1;32m~/anaconda3/envs/codex/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3577\u001b[0m in \u001b[1;35mrun_code\u001b[0m\n exec(code_obj, self.user_global_ns, self.user_ns)\u001b[0m\n",
67-
"\u001b[0;36m Cell \u001b[0;32mIn[16], line 4\u001b[0;36m\n\u001b[0;31m exec(output)\u001b[0;36m\n",
85+
"\u001b[0;36m Cell \u001b[0;32mIn[8], line 4\u001b[0;36m\n\u001b[0;31m exec(output)\u001b[0;36m\n",
6886
"\u001b[0;36m File \u001b[0;32m<string>:3\u001b[0;36m\u001b[0m\n\u001b[0;31m if n < 2:\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mIndentationError\u001b[0m\u001b[0;31m:\u001b[0m unindent does not match any outer indentation level\n"
6987
]
7088
}
@@ -78,22 +96,25 @@
7896
},
7997
{
8098
"cell_type": "code",
81-
"execution_count": 15,
99+
"execution_count": 13,
82100
"id": "76cd93f5",
83101
"metadata": {},
84102
"outputs": [
103+
{
104+
"name": "stderr",
105+
"output_type": "stream",
106+
"text": [
107+
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
108+
]
109+
},
85110
{
86111
"name": "stdout",
87112
"output_type": "stream",
88113
"text": [
89114
"def is_prime(n):\n",
90115
" '''Return if prime'''\n",
91-
" if n < 2:\n",
92-
" return False\n",
93-
" for i in range(2, int(n**0.5) + 1):\n",
94-
" if n % i == 0:\n",
95-
" return False\n",
96-
" return True\n"
116+
" return n > 1 and all(n % i!= 0 for i in range(2, int(n**0.5) + 1))\n",
117+
"\n"
97118
]
98119
}
99120
],

syncode/grammar_mask/grammar_constrainer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import torch
2-
import syncode.common as common
3-
from transformers import LogitsProcessor, PreTrainedTokenizer
2+
from transformers import PreTrainedTokenizer
43
from syncode.mask_store.byte_tokenizer import ByteTokenizer
54
from syncode.parse_result import AcceptSequence, RemainderState
65
from syncode.parsers.incremental_parser import IncrementalParser, ParseResult
@@ -53,7 +52,9 @@ def __init__(self,
5352
batch_size=1,
5453
dev_mode=False,
5554
parser='lalr',
56-
mode='grammar_mask'):
55+
mode='grammar_mask',
56+
indent=False
57+
):
5758

5859
self.tokenizer = tokenizer
5960
self.byte_tokenizer = byte_tokenizer
@@ -82,6 +83,7 @@ def __init__(self,
8283
tokenizer=self.tokenizer,
8384
use_cache=use_cache,
8485
mode=mode, # Controls approximation strategy for token masking
86+
indent=indent
8587
)
8688

8789

syncode/grammar_mask/logits_processor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ def __init__(self,
2929
num_samples=1,
3030
dev_mode=False,
3131
parser='lalr',
32-
mode='grammar_mask'):
32+
mode='grammar_mask',
33+
indent=False
34+
):
3335

3436
self.tokenizer = tokenizer
3537
self.byte_tokenizer = ByteTokenizer(tokenizer)
@@ -44,7 +46,8 @@ def __init__(self,
4446
batch_size=num_samples,
4547
dev_mode=dev_mode,
4648
parser=parser,
47-
mode=mode
49+
mode=mode,
50+
indent=indent
4851
)
4952

5053
def reset(self):

syncode/infer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(
5151
seed: Optional[int] = None,
5252
opp: bool = True,
5353
device_map: Optional[str] = None,
54+
indent: bool = False,
5455
**kwargs
5556
):
5657
# Check inputs
@@ -102,6 +103,7 @@ def __init__(
102103
dev_mode=dev_mode,
103104
parser=parser,
104105
mode=mode,
106+
indent=indent
105107
)
106108

107109
# Set default max new tokens if not provided

syncode/mask_store/mask_store.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ def __init__(self,
5959
self._vocab,
6060
eos_token_id=self.eos_token_id,
6161
special_token_ids=self.special_token_ids,
62-
indent=indent, mode=mode
62+
indent=indent,
63+
mode=mode
6364
)
6465
terminal_names = [terminal.name for terminal in terminals]
6566

@@ -106,8 +107,10 @@ def init_mask_store(
106107
if use_cache and os.path.exists(fsm_path):
107108
try:
108109
with open(fsm_path, 'rb') as f:
109-
mask_store = pickle.load(f)
110-
return mask_store
110+
mask_store: MaskStore = pickle.load(f)
111+
if mask_store.indentation == indent:
112+
return mask_store
113+
111114
except Exception as e:
112115
logger.warning(f"Error loading mask store: {e}")
113116

syncode/parsers/incremental_parser.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,20 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult:
166166
self._handle_parsing_error(lexer_tokens, token, e)
167167

168168
# Compute current terminal string
169-
remainder_state, current_term_str, final_terminal = self._get_remainder(partial_code, lexing_incomplete=lexing_incomplete, parse_incomplete=parse_incomplete)
169+
remainder_state, current_term_str, final_terminal = self._get_remainder(
170+
partial_code,
171+
lexing_incomplete=lexing_incomplete,
172+
parse_incomplete=parse_incomplete
173+
)
170174

171-
return ParseResult.from_accept_terminals(self.cur_ac_terminals, self.next_ac_terminals, current_term_str, remainder_state, final_terminal=final_terminal, ignore_terminals=self.base_parser.lexer_conf.ignore)
175+
return ParseResult.from_accept_terminals(
176+
self.cur_ac_terminals,
177+
self.next_ac_terminals,
178+
current_term_str,
179+
remainder_state,
180+
final_terminal=final_terminal,
181+
ignore_terminals=self.base_parser.lexer_conf.ignore)
182+
172183

173184
def _get_remainder(self, code, lexing_incomplete=False, parse_incomplete=False):
174185
final_terminal = None

syncode/parsers/python_parser.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,21 @@ def get_acceptable_next_terminals(self, partial_code) -> ParseResult:
8585
remainder_state, final_terminal = None, None
8686

8787
# Compute current terminal string
88-
remainder_state, current_term_str, final_terminal = self._get_remainder(partial_code, lexing_incomplete=lexing_incomplete, parse_incomplete=parse_incomplete)
88+
remainder_state, current_term_str, final_terminal = self._get_remainder(
89+
partial_code,
90+
lexing_incomplete=lexing_incomplete,
91+
parse_incomplete=parse_incomplete
92+
)
8993

9094
cur_ac_terminals = self.cur_ac_terminals
9195
next_ac_terminals = self.next_ac_terminals
9296
next_ac_indents = None
9397

9498
if remainder_state == RemainderState.MAYBE_COMPLETE or remainder_state == RemainderState.COMPLETE:
9599
if len(self.parsed_lexer_tokens) > 0 and self.parsed_lexer_tokens[-1].type == '_NL':
100+
# Calculate the last indentation level
96101
last_indent_str = self.parsed_lexer_tokens[-1].value.split('\n')[-1]
102+
97103
last_indent = last_indent_str.count(' ') + last_indent_str.count('\t') * self.tab_len
98104
next_ac_indents = [indent-last_indent for indent in self.indent_level if indent >= last_indent]
99105

0 commit comments

Comments
 (0)