-
Notifications
You must be signed in to change notification settings - Fork 92
/
Copy pathMutate.jl
431 lines (398 loc) · 13.9 KB
/
Mutate.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
module MutateModule
using DynamicExpressions:
AbstractExpressionNode,
Node,
preserve_sharing,
copy_node,
count_nodes,
count_constants,
simplify_tree!,
combine_operators
using ..CoreModule:
Options, MutationWeights, Dataset, RecordType, sample_mutation, DATA_TYPE, LOSS_TYPE
using ..ComplexityModule: compute_complexity
using ..LossFunctionsModule: score_func, score_func_batched
using ..CheckConstraintsModule: check_constraints
using ..AdaptiveParsimonyModule: RunningSearchStatistics
using ..PopMemberModule: PopMember
using ..MutationFunctionsModule:
gen_random_tree_fixed_size,
mutate_constant,
mutate_operator,
swap_operands,
append_random_op,
prepend_random_op,
insert_random_op,
delete_random_op!,
crossover_trees,
form_random_connection!,
break_random_connection!
using ..ConstantOptimizationModule: optimize_constants
using ..RecorderModule: @recorder
"""
    condition_mutation_weights!(weights, member, options, curmaxsize)

Adjust `weights` in place so that mutations which cannot apply to `member` are never
sampled. Returns `nothing`.

- `form_connection` and `break_connection` are zeroed when
  `preserve_sharing(typeof(member.tree))` is false.
- If the tree is a single leaf (`degree == 0`):
  - the operator-level mutations (`mutate_operator`, `swap_operands`, `delete_node`,
    `simplify`) are zeroed;
  - if that leaf is not a constant, `optimize` and `mutate_constant` are zeroed too;
  - the function then returns early.
- Otherwise:
  - `swap_operands` is zeroed when no binary node exists;
  - `mutate_constant` is scaled by `min(8, n_constants) / 8`;
  - growth mutations (`add_node`, `insert_node`) are zeroed once the complexity reaches
    `curmaxsize`;
  - `simplify` is zeroed when `options.should_simplify` is false.
"""
function condition_mutation_weights!(
    weights::MutationWeights, member::PopMember, options::Options, curmaxsize::Int
)
    tree = member.tree

    # Connection edits are disabled for node types that do not preserve sharing.
    if !preserve_sharing(typeof(tree))
        weights.form_connection = 0.0
        weights.break_connection = 0.0
    end

    # Guard: a lone leaf has no operator to mutate, swap, delete, or simplify.
    if tree.degree == 0
        weights.mutate_operator = 0.0
        weights.swap_operands = 0.0
        weights.delete_node = 0.0
        weights.simplify = 0.0
        # A non-constant leaf has nothing to optimize or perturb either.
        if !tree.constant
            weights.optimize = 0.0
            weights.mutate_constant = 0.0
        end
        return nothing
    end

    # Swapping operands is only implemented for binary operators.
    has_binary_node = any(node -> node.degree == 2, tree)
    has_binary_node || (weights.swap_operands = 0.0)

    # The more constants there are (saturating at 8), the likelier a constant mutation.
    weights.mutate_constant *= min(8, count_constants(tree)) / 8.0

    # At or above the current size limit, block mutations that grow the tree.
    if compute_complexity(member, options) >= curmaxsize
        weights.add_node = 0.0
        weights.insert_node = 0.0
    end

    options.should_simplify || (weights.simplify = 0.0)

    return nothing
end
# Build the member returned on every rejection path of `next_generation`:
# - a `copy_node` copy of the parent's tree,
# - the supplied `score`/`loss`,
# - the complexity reported by `compute_complexity` for the parent,
# - `parent` set to `parent_ref`.
function _rejected_mutation_member(member::PopMember, score, loss, options::Options, parent_ref)
    return PopMember(
        copy_node(member.tree),
        score,
        loss,
        options,
        compute_complexity(member, options);
        parent=parent_ref,
        deterministic=options.deterministic,
    )
end

"""
    next_generation(dataset, member, temperature, curmaxsize, running_search_statistics,
                    options; tmp_recorder)

Go through one simulated-annealing mutation step for `member`.

Steps:
1. Copy `options.mutation_weights`, adjust the copy with `condition_mutation_weights!`,
   and draw a mutation type from it with `sample_mutation`.
2. Apply the mutation to a copy of `member.tree`. Retry (at most 10 attempts in total)
   until `check_constraints` passes.
3. Score the mutated tree and keep it with probability `min(1, p)`:
   - `p` starts at 1;
   - when `options.annealing` is set, `p` is multiplied by
     `exp(-delta / (temperature * options.alpha))`, with `delta = afterScore - beforeScore`;
   - when `options.use_frequency` is set, `p` is multiplied by the ratio of the old to the
     new complexity's entry in `running_search_statistics.normalized_frequencies`. Sizes
     outside `1:options.maxsize` use `1e-6`.

The `:simplify`, `:optimize`, and `:do_nothing` mutations return immediately. They skip
both the constraint check and the acceptance test. A mutated tree whose score is `NaN` is
always rejected.

When `options.batching` is set, the pre-mutation score and loss are recomputed with
`score_func_batched` rather than read from `member`.

# Returns
A tuple `(new_member, mutation_accepted, num_evals)`:
- `new_member`: the mutated member when accepted. Otherwise, a member built from a copy of
  `member.tree` with the pre-mutation score and loss.
- `mutation_accepted`: whether the mutation was kept.
- `num_evals`: evaluation cost in units of full-dataset passes. Each batched evaluation
  counts `batch_size / n` and each full evaluation counts `1`, plus whatever
  `optimize_constants` reports.
"""
function next_generation(
    dataset::D,
    member::P,
    temperature,
    curmaxsize::Int,
    running_search_statistics::RunningSearchStatistics,
    options::Options;
    tmp_recorder::RecordType,
)::Tuple{
    P,Bool,Float64
} where {T,L,D<:Dataset{T,L},N<:AbstractExpressionNode{T},P<:PopMember{T,L,N}}
    parent_ref = member.ref
    num_evals = 0.0

    #TODO - reconsider this
    beforeScore, beforeLoss = if options.batching
        num_evals += (options.batch_size / dataset.n)
        score_func_batched(dataset, member, options)
    else
        member.score, member.loss
    end

    nfeatures = dataset.nfeatures

    # Work on a copy so the shared `options.mutation_weights` is never modified.
    weights = copy(options.mutation_weights)
    condition_mutation_weights!(weights, member, options, curmaxsize)
    mutation_choice = sample_mutation(weights)

    successful_mutation = false
    #TODO: Currently we dont take this \/ into account
    is_success_always_possible = true
    attempts = 0
    max_attempts = 10

    #############################################
    # Mutations
    #############################################
    local tree
    while (!successful_mutation) && attempts < max_attempts
        # Every attempt restarts from a fresh copy of the parent's tree.
        tree = copy_node(member.tree)
        successful_mutation = true
        if mutation_choice == :mutate_constant
            tree = mutate_constant(tree, temperature, options)
            @recorder tmp_recorder["type"] = "constant"
            # Mutating a constant shouldn't invalidate an already-valid function
            is_success_always_possible = true
        elseif mutation_choice == :mutate_operator
            tree = mutate_operator(tree, options)
            @recorder tmp_recorder["type"] = "operator"
            # Can always mutate to the same operator
            is_success_always_possible = true
        elseif mutation_choice == :swap_operands
            tree = swap_operands(tree)
            @recorder tmp_recorder["type"] = "swap_operands"
            is_success_always_possible = true
        elseif mutation_choice == :add_node
            if rand() < 0.5
                tree = append_random_op(tree, options, nfeatures)
                @recorder tmp_recorder["type"] = "append_op"
            else
                tree = prepend_random_op(tree, options, nfeatures)
                @recorder tmp_recorder["type"] = "prepend_op"
            end
            # Can potentially have a situation without success
            is_success_always_possible = false
        elseif mutation_choice == :insert_node
            tree = insert_random_op(tree, options, nfeatures)
            @recorder tmp_recorder["type"] = "insert_op"
            is_success_always_possible = false
        elseif mutation_choice == :delete_node
            tree = delete_random_op!(tree, options, nfeatures)
            @recorder tmp_recorder["type"] = "delete_op"
            is_success_always_possible = true
        elseif mutation_choice == :simplify
            # Simplification shouldn't hurt complexity; unless some non-symmetric constraint
            # to commutative operator... It is accepted unconditionally and keeps the
            # parent's score and loss.
            @assert options.should_simplify
            # Use the returned root, as is done for `combine_operators` below, so a root
            # that gets replaced is not lost.
            tree = simplify_tree!(tree, options.operators)
            if tree isa Node
                tree = combine_operators(tree, options.operators)
            end
            @recorder tmp_recorder["type"] = "partial_simplify"
            return (
                PopMember(
                    tree,
                    beforeScore,
                    beforeLoss,
                    options;
                    parent=parent_ref,
                    deterministic=options.deterministic,
                ),
                true,
                num_evals,
            )
        elseif mutation_choice == :randomize
            # We select a random size, though the generated tree
            # may have fewer nodes than we request.
            tree_size_to_generate = rand(1:curmaxsize)
            tree = gen_random_tree_fixed_size(tree_size_to_generate, options, nfeatures, T)
            @recorder tmp_recorder["type"] = "regenerate"
            is_success_always_possible = true
        elseif mutation_choice == :optimize
            # Constant optimization is accepted unconditionally; its evaluation cost is
            # added to `num_evals`.
            cur_member = PopMember(
                tree,
                beforeScore,
                beforeLoss,
                options,
                compute_complexity(member, options);
                parent=parent_ref,
                deterministic=options.deterministic,
            )
            cur_member, new_num_evals = optimize_constants(dataset, cur_member, options)
            num_evals += new_num_evals
            @recorder tmp_recorder["type"] = "optimize"
            return (cur_member, true, num_evals)
        elseif mutation_choice == :do_nothing
            @recorder begin
                tmp_recorder["type"] = "identity"
                tmp_recorder["result"] = "accept"
                tmp_recorder["reason"] = "identity"
            end
            return (
                PopMember(
                    tree,
                    beforeScore,
                    beforeLoss,
                    options,
                    compute_complexity(member, options);
                    parent=parent_ref,
                    deterministic=options.deterministic,
                ),
                true,
                num_evals,
            )
        elseif mutation_choice == :form_connection
            tree = form_random_connection!(tree)
            @recorder tmp_recorder["type"] = "form_connection"
            is_success_always_possible = true
        elseif mutation_choice == :break_connection
            tree = break_random_connection!(tree)
            @recorder tmp_recorder["type"] = "break_connection"
            is_success_always_possible = true
        else
            error("Unknown mutation choice: $mutation_choice")
        end

        successful_mutation =
            successful_mutation && check_constraints(tree, options, curmaxsize)
        attempts += 1
    end
    #############################################

    if !successful_mutation
        @recorder begin
            tmp_recorder["result"] = "reject"
            tmp_recorder["reason"] = "failed_constraint_check"
        end
        return (
            _rejected_mutation_member(member, beforeScore, beforeLoss, options, parent_ref),
            false,
            num_evals,
        )
    end

    if options.batching
        afterScore, afterLoss = score_func_batched(dataset, tree, options)
        num_evals += (options.batch_size / dataset.n)
    else
        afterScore, afterLoss = score_func(dataset, tree, options)
        num_evals += 1
    end

    if isnan(afterScore)
        @recorder begin
            tmp_recorder["result"] = "reject"
            tmp_recorder["reason"] = "nan_loss"
        end
        return (
            _rejected_mutation_member(member, beforeScore, beforeLoss, options, parent_ref),
            false,
            num_evals,
        )
    end

    # Acceptance probability: exp(-delta/T) under annealing, times a frequency ratio.
    probChange = 1.0
    if options.annealing
        delta = afterScore - beforeScore
        probChange *= exp(-delta / (temperature * options.alpha))
    end

    # -1 is passed through to PopMember when the new complexity was not computed here.
    newSize = -1
    if options.use_frequency
        oldSize = compute_complexity(member, options)
        newSize = compute_complexity(tree, options)
        old_frequency = if (0 < oldSize <= options.maxsize)
            running_search_statistics.normalized_frequencies[oldSize]
        else
            1e-6
        end
        new_frequency = if (0 < newSize <= options.maxsize)
            running_search_statistics.normalized_frequencies[newSize]
        else
            1e-6
        end
        probChange *= old_frequency / new_frequency
    end

    if probChange < rand()
        @recorder begin
            tmp_recorder["result"] = "reject"
            tmp_recorder["reason"] = "annealing_or_frequency"
        end
        return (
            _rejected_mutation_member(member, beforeScore, beforeLoss, options, parent_ref),
            false,
            num_evals,
        )
    end

    @recorder begin
        tmp_recorder["result"] = "accept"
        tmp_recorder["reason"] = "pass"
    end
    return (
        PopMember(
            tree,
            afterScore,
            afterLoss,
            options,
            newSize;
            parent=parent_ref,
            deterministic=options.deterministic,
        ),
        true,
        num_evals,
    )
end
"""Generate a generation via crossover of two members."""
function crossover_generation(
    member1::P, member2::P, dataset::D, curmaxsize::Int, options::Options
)::Tuple{P,P,Bool,Float64} where {T,L,D<:Dataset{T,L},P<:PopMember{T,L}}
    # Breed two children from `member1` and `member2` with `crossover_trees`. Retry until
    # both children pass `check_constraints`, making at most `max_tries + 1` crossovers.
    #
    # Returns `(baby1, baby2, crossover_accepted, num_evals)`:
    # - On success, the babies are new members built from the children and scored with
    #   `score_func` (or `score_func_batched` when `options.batching` is set). `baby1`
    #   records `member1.ref` as its parent and `baby2` records `member2.ref`.
    # - If no valid pair is found, `member1` and `member2` themselves (not copies) are
    #   returned with `crossover_accepted == false`.
    # - `num_evals` is in units of full-dataset passes. A batched evaluation counts
    #   `batch_size / n` and a full evaluation counts `1`, as in `next_generation`.
    tree1 = member1.tree
    tree2 = member2.tree
    num_evals = 0.0

    # We breed these until constraints are no longer violated:
    child_tree1, child_tree2 = crossover_trees(tree1, tree2)
    num_tries = 1
    max_tries = 10
    afterSize1 = -1
    afterSize2 = -1
    while true
        afterSize1 = compute_complexity(child_tree1, options)
        afterSize2 = compute_complexity(child_tree2, options)
        # Both trees satisfy constraints
        if check_constraints(child_tree1, options, curmaxsize, afterSize1) &&
            check_constraints(child_tree2, options, curmaxsize, afterSize2)
            break
        end
        if num_tries > max_tries
            return member1, member2, false, num_evals  # Fail.
        end
        child_tree1, child_tree2 = crossover_trees(tree1, tree2)
        num_tries += 1
    end

    # The precomputed complexities are passed through so they are not recomputed.
    if options.batching
        afterScore1, afterLoss1 = score_func_batched(
            dataset, child_tree1, options; complexity=afterSize1
        )
        afterScore2, afterLoss2 = score_func_batched(
            dataset, child_tree2, options; complexity=afterSize2
        )
        num_evals += 2 * (options.batch_size / dataset.n)
    else
        afterScore1, afterLoss1 = score_func(
            dataset, child_tree1, options; complexity=afterSize1
        )
        afterScore2, afterLoss2 = score_func(
            dataset, child_tree2, options; complexity=afterSize2
        )
        # Two full-dataset evaluations were made (one per child). The previous
        # `batch_size / n` under-counted them, unlike the batched branch above.
        num_evals += 2
    end

    baby1 = PopMember(
        child_tree1,
        afterScore1,
        afterLoss1,
        options,
        afterSize1;
        parent=member1.ref,
        deterministic=options.deterministic,
    )
    baby2 = PopMember(
        child_tree2,
        afterScore2,
        afterLoss2,
        options,
        afterSize2;
        parent=member2.ref,
        deterministic=options.deterministic,
    )
    return baby1, baby2, true, num_evals
end
end