This repository has been archived by the owner on Dec 3, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_training_script.py
191 lines (141 loc) · 5.37 KB
/
generate_training_script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import re
import json
from sys import argv, exit
# Cell markers: a code cell is selected for a role when its first source line
# (after stripping surrounding whitespace) matches the corresponding pattern.
# Cells marked as argument cells are scanned for `NAME = value` assignments.
ARG_CELL_PATTERN = re.compile(r'^# arguments cell$')
# Cells marked as training cells supply the body of the generated train().
TRAIN_CELL_PATTERN = re.compile(r'^# training cell$')
# Cells whose first line is an `import ...` statement are copied to the top of
# the generated script.
# NOTE(review): a cell starting with `from x import y` does not match this.
IMPORT_PATTERN = re.compile(r'^import .*$')
# Name of the generated script file.
TRAIN_FILE_NAME = 'train.py'
# Number of spaces per indentation level in the generated code.
INDENTATION = 2
def is_marked_cell(cell, marker):
    """Return True if *cell* is a code cell whose first line matches *marker*.

    Args:
        cell: notebook cell dict with 'cell_type' and 'source' keys.
        marker: compiled pattern (or pattern string) matched with ``re.match``
            against the whitespace-stripped first source line.

    Returns:
        bool: True only for a non-empty code cell whose first line matches.
    """
    if cell['cell_type'] != 'code':
        return False
    source = cell['source']
    # The notebook JSON format allows 'source' to be either a list of lines or
    # one multi-line string; normalise so that index 0 is always a whole line
    # rather than the first character of the string.
    if isinstance(source, str):
        source = source.splitlines()
    if not source:
        return False
    return re.match(marker, source[0].strip()) is not None
def get_code_cells_with_matching_first_line(notebook, marker):
    """Collect, in notebook order, every cell accepted by is_marked_cell."""
    selected = []
    for candidate in notebook['cells']:
        if is_marked_cell(candidate, marker):
            selected.append(candidate)
    return selected
def is_comment(line):
    """Tell whether the first non-whitespace character of *line* is '#'."""
    first_char = line.lstrip()[:1]
    return first_char == '#'
def extract_args(arg_cells):
    """Extract simple ``NAME = value`` assignments from the argument cells.

    Args:
        arg_cells: iterable of notebook cell dicts; each cell's 'source' is
            iterated line by line.

    Returns:
        list of ``(name, value)`` string tuples in the order encountered.
        Comment lines and lines that are not a plain single assignment to an
        identifier are skipped.
    """
    args = []
    for cell in arg_cells:
        for line in cell['source']:
            if is_comment(line):
                continue
            # Strip the trailing comment *before* counting '=' so that an '='
            # inside a comment can neither hide a real assignment nor turn a
            # non-assignment into a malformed one-element pair.
            clean_line = strip_trailing_comment(line)
            if clean_line.count('=') != 1:
                continue
            name, value = (part.strip() for part in clean_line.split('='))
            # Rejects augmented assignments ('x += 1'), keyword arguments in
            # calls ('f(a=1)') and lines with a missing right-hand side.
            if not name.isidentifier() or not value:
                continue
            args.append((name, value))
    return args
def strip_trailing_comment(line):
    """Return *line* up to its first '#', with surrounding whitespace removed.

    Deliberately naive: a '#' inside a string literal is also treated as the
    start of a comment. Handling that reliably would need a real tokenizer.
    """
    code_part, _, _ = line.partition('#')
    return code_part.strip()
# (type name, pattern) pairs tried in order by detect_type. Each pattern must
# match the *entire* value (detect_type uses fullmatch), so no anchors needed.
_pattern_lookups = [
    ('int', re.compile(r'[-+]?\d+')),
    # decimal with optional exponent, bare exponent form (1e-3), or a
    # rational such as 1/3
    ('float', re.compile(
        r'[-+]?(?:(?:\d+\.\d*|\.\d+)(?:[eE][-+]?\d+)?|\d+[eE][-+]?\d+|\d+/\d+)'
    )),
    # single- or double-quoted string literal
    ('str', re.compile(r"'.*'|\".*\"")),
]


def detect_type(str_value):
    """Return the type name ('int', 'float' or 'str') of a literal's source text.

    Args:
        str_value: the right-hand side of an assignment as a stripped string.

    Returns:
        The name of the first type in ``_pattern_lookups`` whose pattern
        matches the whole of *str_value*.

    Raises:
        ValueError: if no pattern matches.
    """
    for typ, pattern in _pattern_lookups:
        if pattern.fullmatch(str_value):
            return typ
    raise ValueError(f'Could not match type of: {str_value}')
def create_parser_arg(parser, arg_name, arg_value):
    """Build the source text of one ``add_argument`` call.

    *parser* is the name of the parser variable in the generated code,
    *arg_name* becomes the ``--`` option name and *arg_value* is inserted
    verbatim as the default. The ``type=`` is whatever detect_type reports
    for *arg_value*, so a ValueError from detect_type propagates.
    """
    arg_type = detect_type(arg_value)
    parser_call = f'{parser}.add_argument(\'--{arg_name}\', default={arg_value}, type={arg_type})'
    return parser_call
class CodeBuilder:
    """Accumulates lines of generated source, tracking an indentation level."""

    def __init__(self):
        # current nesting depth; each level is INDENTATION spaces wide
        self.level = 0
        # formatted lines, each already indented and newline-terminated
        self.code = []

    def indent(self):
        """Go one nesting level deeper."""
        self.level = self.level + 1

    def unindent(self):
        """Go one nesting level up, never below zero."""
        self.level = self.level - 1 if self.level > 1 else 0

    def add(self, s):
        """Append *s* at the current indentation, ensuring it ends in a newline."""
        padding = ' ' * (self.level * INDENTATION)
        terminator = '' if s.endswith('\n') else '\n'
        self.code.append(padding + s + terminator)

    def build(self):
        """Return all added lines joined into one string."""
        return ''.join(self.code)
def make_parser_code(args):
    """Generate the source of ``parse_arguments()`` for the given arguments.

    *args* is an iterable of ``(name, value)`` string pairs; each one becomes
    an ``add_argument`` call between the parser's creation and the final
    ``parse_args()`` return.
    """
    parser = 'parser'
    # Assemble the function body first, then emit it under the def header.
    body = [f'{parser} = ArgumentParser()']
    body.extend(create_parser_arg(parser, name, value) for name, value in args)
    body.append(f'return {parser}.parse_args()')

    code = CodeBuilder()
    code.add('def parse_arguments():')
    code.indent()
    for statement in body:
        code.add(statement)
    return code.build()
def make_arg_pattern(args):
    """Compile a pattern that finds whole-word occurrences of any argument name.

    Args:
        args: iterable of ``(name, value)`` pairs; only the names are used.

    Returns:
        A compiled pattern with three groups, as expected by substitute_arg:
        group 2 is the matched name, groups 1 and 3 are always empty strings
        (they only assert that no word character is adjacent). If *args* is
        empty the returned pattern never matches.
    """
    names = [re.escape(name) for name, _ in args]
    if not names:
        # An empty alternation would match between any two non-word
        # characters and corrupt the code; return a pattern that cannot match.
        return re.compile(r'(?!)')
    or_pattern_string = '|'.join(names)
    # Zero-width lookarounds instead of consuming (\W) delimiters: names at the
    # start/end of a line and names separated by one character ('A*B') are
    # all found, because no delimiter is eaten by the previous match.
    return re.compile(r'((?<!\w))(' + or_pattern_string + r')((?!\w))')
# Lines matching this pattern are dropped from the training code: blank lines,
# comment-only lines, and lines whose first non-whitespace text is "plt".
# NOTE(review): the '.' in 'plt.*' is an unescaped wildcard, so any line
# beginning with "plt" (not only "plt.") is removed -- confirm this is intended.
_remove_pattern = re.compile(r'^\s*(?:plt.*)?(?:#.*)?\s*$')
def substitute_arg(m):
    """Replacement callback: rewrite a matched NAME as ``args.name``.

    Expects a match with three groups: text before the name, the name itself,
    and text after it. The surrounding text is kept and the name is
    lower-cased and prefixed with ``args.``.
    """
    before, name, after = m.group(1, 2, 3)
    return ''.join((before, 'args.', name.lower(), after))
def make_train_code(args, training_cells):
    """Generate the source of ``train(args)`` from the training cells.

    Args:
        args: iterable of ``(name, value)`` pairs; every occurrence of a name
            in the training code is rewritten to ``args.<name lower-cased>``.
        training_cells: iterable of notebook cell dicts whose 'source' lines
            form the function body.

    Returns:
        The function definition as a single string.
    """
    arg_pattern = make_arg_pattern(args)

    # Gather the body: drop lines matched by _remove_pattern, then
    # substitute ARG with args.arg in what remains.
    notebook_code = []
    for cell in training_cells:
        for line in cell['source']:
            if re.match(_remove_pattern, line):
                continue
            notebook_code.append(re.sub(arg_pattern, substitute_arg, line))

    # A def with no body is a syntax error, so keep the generated script
    # valid even when no training code survived the filter.
    if not notebook_code:
        notebook_code.append('pass')

    builder = CodeBuilder()
    builder.add('def train(args):')
    builder.indent()
    for line in notebook_code:
        builder.add(line)
    return builder.build()
def make_main_code():
    """Generate the ``__main__`` guard that parses arguments and runs train()."""
    code = CodeBuilder()
    code.add('if __name__ == \'__main__\':')
    code.indent()
    for statement in ('args = parse_arguments()', 'train(args)'):
        code.add(statement)
    return code.build()
def make_import_code(import_cells):
    """Generate the import section of the script.

    Starts with the ArgumentParser import, followed by every source line of
    each cell in *import_cells*, unfiltered and in order.
    """
    code = CodeBuilder()
    code.add('from argparse import ArgumentParser')
    for source_lines in (cell['source'] for cell in import_cells):
        for source_line in source_lines:
            code.add(source_line)
    return code.build()
def convert_notebook(path):
    """Convert the notebook at *path* into a standalone training script.

    The script is written to TRAIN_FILE_NAME and consists of four parts, each
    followed by two newlines: imports, ``parse_arguments()``, ``train(args)``
    and the ``__main__`` guard.

    Args:
        path: path of the notebook (JSON) file to read.
    """
    # Notebooks are UTF-8 JSON; do not depend on the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as notebook_file:
        notebook = json.load(notebook_file)

    code_parts = []

    # imports
    import_cells = get_code_cells_with_matching_first_line(notebook, IMPORT_PATTERN)
    code_parts.append(make_import_code(import_cells))

    # parser -- option names are lower-cased to match the args.<name>
    # attributes that the training code is rewritten to use
    arg_cells = get_code_cells_with_matching_first_line(notebook, ARG_CELL_PATTERN)
    args = extract_args(arg_cells)
    lower_args = [(n.lower(), v) for n, v in args]
    code_parts.append(make_parser_code(lower_args))

    # training code -- needs the original-case names to find them in the source
    training_cells = get_code_cells_with_matching_first_line(notebook, TRAIN_CELL_PATTERN)
    code_parts.append(make_train_code(args, training_cells))

    # main_code
    code_parts.append(make_main_code())

    with open(TRAIN_FILE_NAME, 'w', encoding='utf-8') as train_file:
        for code_part in code_parts:
            train_file.write(code_part)
            train_file.write('\n\n')
# Command-line entry point: expects exactly one argument, the notebook path.
if __name__ == '__main__':
    if len(argv) != 2:
        # Passing a string to sys.exit prints it to stderr and exits with
        # status 1, so a usage error is reported as a failure to the caller.
        exit('usage: python generate_training_script.py notebook_path')
    _, notebook_path = argv
    convert_notebook(notebook_path)