# evolutionary_sampler.py
import argparse
from collections import OrderedDict, Counter
from datetime import datetime, timedelta
from functools import wraps
import glob
import os
import platform
import re
import signal
import shutil
import sys
import tempfile
import traceback
import typing
import logging
logging.getLogger('git').setLevel(logging.WARNING)
logging.getLogger('numba').setLevel(logging.WARNING)
logging.getLogger('urllib3').setLevel(logging.WARNING)
logging.getLogger('urllib3.connectionpool').setLevel(logging.WARNING)
from git.repo import Repo
import numpy as np
import tatsu
import tatsu.ast
import tatsu.grammars
import torch
from tqdm import tqdm, trange
from viztracer import VizTracer
import wandb
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(os.path.join(os.path.dirname(__file__), '../src'))
import src # type: ignore
# from ast_parser import SETUP, PREFERENCES, TERMINAL, SCORING
import ast_printer
import ast_parser
from ast_context_fixer import ASTContextFixer
from ast_counter_sampler import *
from ast_counter_sampler import parse_or_load_counter, ASTSampler, RegrowthSampler, SamplingException, MCMC_REGRWOTH, PRIOR_COUNT, LENGTH_PRIOR
from ast_mcmc_regrowth import _load_pickle_gzip, InitialProposalSamplerType, create_initial_proposal_sampler
from ast_utils import *
from evolutionary_sampler_behavioral_features import build_behavioral_features_featurizer, BehavioralFeatureSet, BehavioralFeaturizer, DEFAULT_N_COMPONENTS
from evolutionary_sampler_diversity import *
from evolutionary_sampler_utils import Selector, UCBSelector, ThompsonSamplingSelector
from fitness_energy_utils import load_model_and_feature_columns, load_data_from_path, save_data, get_data_path, DEFAULT_SAVE_MODEL_NAME, evaluate_single_game_energy_contributions
from fitness_features import *
from fitness_ngram_models import *
from fitness_ngram_models import VARIABLE_PATTERN
from latest_model_paths import LATEST_AST_N_GRAM_MODEL_PATH, LATEST_FITNESS_FEATURIZER_PATH,\
    LATEST_FITNESS_FUNCTION_DATE_ID, LATEST_REAL_GAMES_PATH
#LATEST_SPECIFIC_OBJECTS_AST_N_GRAM_MODEL_PATH, LATEST_SPECIFIC_OBJECTS_FITNESS_FEATURIZER_PATH, LATEST_SPECIFIC_OBJECTS_FITNESS_FUNCTION_DATE_ID
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'reward-machine')))
from compile_predicate_statistics_full_database import DUCKDB_TMP_FOLDER, DUCKDB_QUERY_LOG_FOLDER # type: ignore
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.INFO)
logger.addHandler(handler)
import multiprocessing
from multiprocessing import pool as mpp
# import multiprocess as multiprocessing
# from multiprocess import pool as mpp
def istarmap(self, func, iterable, chunksize=1):
    """A starmap-style version of imap: applies func to argument tuples, lazily yielding results."""
    self._check_running()
    if chunksize < 1:
        raise ValueError('Chunksize must be 1+, not {0:n}'.format(chunksize))
    task_batches = mpp.Pool._get_tasks(func, iterable, chunksize)  # type: ignore
    result = mpp.IMapIterator(self)
    self._taskqueue.put(
        (
            self._guarded_task_generation(result._job,  # type: ignore
                                          mpp.starmapstar,  # type: ignore
                                          task_batches),
            result._set_length  # type: ignore
        ))
    return (item for chunk in result for item in chunk)
mpp.Pool.istarmap = istarmap # type: ignore
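# Illustrative sketch (comments only) of how the istarmap patch above can be used to
# drive a progress bar over lazily-yielded results; `mutate_one` and `arg_tuples` are
# hypothetical stand-ins:
#
# with multiprocessing.Pool(n_workers) as pool:
#     for result in tqdm(pool.istarmap(mutate_one, arg_tuples, chunksize=4), total=len(arg_tuples)):
#         handle(result)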
parser = argparse.ArgumentParser(description='Evolutionary Sampler')
parser.add_argument('--grammar-file', type=str, default=DEFAULT_GRAMMAR_FILE)
parser.add_argument('--parse-counter', action='store_true')
parser.add_argument('--counter-output-path', type=str, default=DEFAULT_COUNTER_OUTPUT_PATH)
parser.add_argument('--use-specific-objects-models', action='store_true')
DEFAULT_FITNESS_FUNCTION_DATE_ID = LATEST_FITNESS_FUNCTION_DATE_ID
parser.add_argument('--fitness-function-date-id', type=str, default=DEFAULT_FITNESS_FUNCTION_DATE_ID)
DEFAULT_FITNESS_FEATURIZER_PATH = LATEST_FITNESS_FEATURIZER_PATH
parser.add_argument('--fitness-featurizer-path', type=str, default=DEFAULT_FITNESS_FEATURIZER_PATH)
DEFAULT_FITNESS_FUNCTION_MODEL_NAME = DEFAULT_SAVE_MODEL_NAME
parser.add_argument('--fitness-function-model-name', type=str, default=DEFAULT_FITNESS_FUNCTION_MODEL_NAME)
parser.add_argument('--no-flip-fitness-sign', action='store_true')
DEFAULT_POPULATION_SIZE = 100
parser.add_argument('--population-size', type=int, default=DEFAULT_POPULATION_SIZE)
DEFAULT_N_STEPS = 100
parser.add_argument('--n-steps', type=int, default=DEFAULT_N_STEPS)
# TODO: rewrite these arguments to the things this sampler actually needs
# DEFAULT_PLATEAU_PATIENCE_STEPS = 1000
# parser.add_argument('--plateau-patience-steps', type=int, default=DEFAULT_PLATEAU_PATIENCE_STEPS)
# DEFAULT_MAX_STEPS = 20000
# parser.add_argument('--max-steps', type=int, default=DEFAULT_MAX_STEPS)
# DEFAULT_N_SAMPLES_PER_STEP = 1
# parser.add_argument('--n-samples-per-step', type=int, default=DEFAULT_N_SAMPLES_PER_STEP)
# parser.add_argument('--non-greedy', action='store_true')
# DEFAULT_ACCEPTANCE_TEMPERATURE = 1.0
# parser.add_argument('--acceptance-temperature', type=float, default=DEFAULT_ACCEPTANCE_TEMPERATURE)
# MICROBIAL_GA = 'microbial_ga'
# MICROBIAL_GA_WITH_BEAM_SEARCH = 'microbial_ga_with_beam_search'
# WEIGHTED_BEAM_SEARCH = 'weighted_beam_search'
MAP_ELITES = 'map_elites'
# SAMPLER_TYPES = [MICROBIAL_GA, MICROBIAL_GA_WITH_BEAM_SEARCH, WEIGHTED_BEAM_SEARCH, MAP_ELITES]
SAMPLER_TYPES = [MAP_ELITES]
parser.add_argument('--sampler-type', type=str, required=True, choices=SAMPLER_TYPES)
# parser.add_argument('--diversity-scorer-type', type=str, required=False, choices=DIVERSITY_SCORERS)
# parser.add_argument('--diversity-scorer-k', type=int, default=1)
# parser.add_argument('--diversity-score-threshold', type=float, default=0.0)
# parser.add_argument('--diversity-threshold-absolute', action='store_true')
# parser.add_argument('--microbial-ga-crossover-full-sections', action='store_true')
# parser.add_argument('--microbial-ga-crossover-type', type=int, default=2)
# DEFAULT_MICROBIAL_GA_MIN_N_CROSSOVERS = 1
# parser.add_argument('--microbial-ga-n-min-loser-crossovers', type=int, default=DEFAULT_MICROBIAL_GA_MIN_N_CROSSOVERS)
# DEFAULT_MICROBIAL_GA_MAX_N_CROSSOVERS = 5
# parser.add_argument('--microbial-ga-n-max-loser-crossovers', type=int, default=DEFAULT_MICROBIAL_GA_MAX_N_CROSSOVERS)
# DEFAULT_BEAM_SEARCH_K = 10
# parser.add_argument('--beam-search-k', type=int, default=DEFAULT_BEAM_SEARCH_K)
DEFAULT_GENERATION_SIZE = 1024
parser.add_argument('--map-elites-generation-size', type=int, default=DEFAULT_GENERATION_SIZE)
parser.add_argument('--map-elites-key-type', type=int, default=0)
parser.add_argument('--map-elites-weight-strategy', type=int, default=0)
parser.add_argument('--map-elites-initialization-strategy', type=int, default=0)
parser.add_argument('--map-elites-population-seed-path', type=str, default=None)
parser.add_argument('--map-elites-initial-candidate-pool-size', type=int, default=None)
parser.add_argument('--map-elites-use-crossover', action='store_true')
parser.add_argument('--map-elites-use-cognitive-operators', action='store_true')
features_group = parser.add_mutually_exclusive_group(required=True)
features_group.add_argument('--map-elites-behavioral-features-key', type=str, default=None)
features_group.add_argument('--map-elites-custom-behavioral-features-key', type=str, default=None,
                            choices=[feature_set_enum.value for feature_set_enum in BehavioralFeatureSet])
features_group.add_argument('--map-elites-pca-behavioral-features-indices', nargs='+', type=int, default=None)
parser.add_argument('--map-elites-pca-behavioral-features-ast-file-path', type=str, default=LATEST_REAL_GAMES_PATH)
parser.add_argument('--map-elites-pca-behavioral-features-bins-per-feature', type=int, default=None)
parser.add_argument('--map-elites-pca-behavioral-features-n-components', type=int, default=None)
parser.add_argument('--map-elites-behavioral-feature-exemplar-distance-type', type=str, default=None)
parser.add_argument('--map-elites-behavioral-feature-exemplar-distance-metric', type=str, default=None)
parser.add_argument('--map-elites-good-threshold', type=float, default=None)
parser.add_argument('--map-elites-great-threshold', type=float, default=None)
parser.add_argument('--prior-only-sampling', action='store_true')
DEFAULT_RELATIVE_PATH = '.'
parser.add_argument('--relative-path', type=str, default=DEFAULT_RELATIVE_PATH)
DEFAULT_NGRAM_MODEL_PATH = LATEST_AST_N_GRAM_MODEL_PATH
parser.add_argument('--ngram-model-path', type=str, default=DEFAULT_NGRAM_MODEL_PATH)
DEFUALT_RANDOM_SEED = 33
parser.add_argument('--random-seed', type=int, default=DEFUALT_RANDOM_SEED)
parser.add_argument('--initial-proposal-type', type=int, default=0)
parser.add_argument('--sample-patience', type=int, default=100)
parser.add_argument('--sample-parallel', action='store_true')
DEFAULT_START_METHOD = 'spawn'
parser.add_argument('--parallel-start-method', type=str, default=DEFAULT_START_METHOD)
parser.add_argument('--parallel-n-workers', type=int, default=8)
parser.add_argument('--parallel-chunksize', type=int, default=1)
parser.add_argument('--parallel-maxtasksperchild', type=int, default=None)
parser.add_argument('--parallel-use-plain-map', action='store_true')
parser.add_argument('--verbose', type=int, default=0)
parser.add_argument('--should-tqdm', action='store_true')
parser.add_argument('--within-step-tqdm', action='store_true')
parser.add_argument('--compute-diversity-metrics', action='store_true')
parser.add_argument('--save-interval', type=int, default=0)
parser.add_argument('--omit-rules', type=str, nargs='*')
parser.add_argument('--omit-tokens', type=str, nargs='*')
parser.add_argument('--sampler-prior-count', action='append', type=int, default=[])
parser.add_argument('--sampler-filter-func-key', type=str)
parser.add_argument('--no-weight-insert-delete-nodes-by-length', action='store_true')
DEFAULT_MAX_SAMPLE_TOTAL_SIZE = 1024 * 1024 * 5 # ~20x larger than the largest game in the real dataset
parser.add_argument('--max-sample-total-size', type=int, default=DEFAULT_MAX_SAMPLE_TOTAL_SIZE)
DEFAULT_MAX_SAMPLE_DEPTH = 16 # 24 # deeper than the deepest game, which has depth 23, and this is for a single node regrowth
parser.add_argument('--max-sample-depth', type=int, default=DEFAULT_MAX_SAMPLE_DEPTH)
DEFAULT_MAX_SAMPLE_NODES = 128 # 256 # longer than most games, but limiting a single node regrowth, not an entire game
parser.add_argument('--max-sample-nodes', type=int, default=DEFAULT_MAX_SAMPLE_NODES)
DEFAULT_OUTPUT_NAME = 'evo-sampler'
parser.add_argument('--output-name', type=str, default=DEFAULT_OUTPUT_NAME)
DEFAULT_OUTPUT_FOLDER = './samples'
parser.add_argument('--output-folder', type=str, default=DEFAULT_OUTPUT_FOLDER)
parser.add_argument('--wandb', action='store_true')
DEFAULT_WANDB_PROJECT = 'game-generation-map-elites'
parser.add_argument('--wandb-project', type=str, default=DEFAULT_WANDB_PROJECT)
DEFAULT_WANDB_ENTITY = 'guy'
parser.add_argument('--wandb-entity', type=str, default=DEFAULT_WANDB_ENTITY)
parser.add_argument('--profile', action='store_true')
parser.add_argument('--profile-output-file', type=str, default='tracer.json')
parser.add_argument('--profile-output-folder', type=str, default=tempfile.gettempdir())
parser.add_argument('--resume', action='store_true')
parser.add_argument('--resume-max-days-back', type=int, default=1)
parser.add_argument('--start-step', type=int, default=0)
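# Illustrative invocation (comment only; values are hypothetical, and exactly one of the
# mutually exclusive --map-elites-*behavioral-features* flags must be provided):
#
#   python evolutionary_sampler.py --sampler-type map_elites \
#       --map-elites-custom-behavioral-features-key <feature-set-name> \
#       --population-size 100 --n-steps 100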
class CrossoverType(Enum):
    SAME_RULE = 0
    SAME_PARENT_INITIAL_SELECTOR = 1
    SAME_PARENT_FULL_SELECTOR = 2
    SAME_PARENT_RULE = 3
    SAME_PARENT_RULE_INITIAL_SELECTOR = 4
    SAME_PARENT_RULE_FULL_SELECTOR = 5


def _get_node_key(node: typing.Any):
    if isinstance(node, tatsu.ast.AST):
        if node.parseinfo.rule is None:  # type: ignore
            raise ValueError('Node has no rule')
        return node.parseinfo.rule  # type: ignore
    else:
        return type(node).__name__


def node_info_to_key(crossover_type: CrossoverType, node_info: ast_parser.ASTNodeInfo):
    if crossover_type == CrossoverType.SAME_RULE:
        return _get_node_key(node_info[0])
    elif crossover_type == CrossoverType.SAME_PARENT_INITIAL_SELECTOR:
        return '_'.join([_get_node_key(node_info[1]), str(node_info[2][0])])
    elif crossover_type == CrossoverType.SAME_PARENT_FULL_SELECTOR:
        return '_'.join([_get_node_key(node_info[1]), *[str(s) for s in node_info[2]]])
    elif crossover_type == CrossoverType.SAME_PARENT_RULE:
        return '_'.join([_get_node_key(node_info[1]), _get_node_key(node_info[0])])
    elif crossover_type == CrossoverType.SAME_PARENT_RULE_INITIAL_SELECTOR:
        return '_'.join([_get_node_key(node_info[1]), str(node_info[2][0]), _get_node_key(node_info[0])])
    elif crossover_type == CrossoverType.SAME_PARENT_RULE_FULL_SELECTOR:
        return '_'.join([_get_node_key(node_info[1]), *[str(s) for s in node_info[2]], _get_node_key(node_info[0])])
    else:
        raise ValueError(f'Invalid crossover type {crossover_type}')
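# Illustrative sketch (comment only) of the keys produced above for a hypothetical node
# parsed by rule 'pref_predicate_and', whose parent was parsed by rule 'pref_body' and is
# reached via the selector ['body', 2] (rule names here depend on the grammar):
#   SAME_RULE                          -> 'pref_predicate_and'
#   SAME_PARENT_INITIAL_SELECTOR       -> 'pref_body_body'
#   SAME_PARENT_FULL_SELECTOR          -> 'pref_body_body_2'
#   SAME_PARENT_RULE                   -> 'pref_body_pref_predicate_and'
#   SAME_PARENT_RULE_INITIAL_SELECTOR  -> 'pref_body_body_pref_predicate_and'
#   SAME_PARENT_RULE_FULL_SELECTOR     -> 'pref_body_body_2_pref_predicate_and'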
ASTType: typing.TypeAlias = typing.Union[tuple, tatsu.ast.AST]
T = typing.TypeVar('T')
class SingleStepResults(typing.NamedTuple):
    samples: typing.List[ASTType]
    fitness_scores: typing.List[float]
    parent_infos: typing.List[typing.Dict[str, typing.Any]]
    diversity_scores: typing.List[float]
    sample_features: typing.List[typing.Dict[str, typing.Any]]
    operators: typing.List[str]

    def __len__(self):
        return len(self.samples)

    def accumulate(self, other: 'SingleStepResults'):
        self.samples.extend(other.samples)
        self.fitness_scores.extend(other.fitness_scores)
        if other.parent_infos is not None: self.parent_infos.extend(other.parent_infos)
        if other.diversity_scores is not None: self.diversity_scores.extend(other.diversity_scores)
        if other.sample_features is not None: self.sample_features.extend(other.sample_features)
        if other.operators is not None: self.operators.extend(other.operators)
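# Illustrative sketch (comment only): merging per-worker step results into one container;
# `per_worker_results` is a hypothetical list of SingleStepResults instances:
#
# combined = SingleStepResults([], [], [], [], [], [])
# for chunk in per_worker_results:
#     combined.accumulate(chunk)
# assert len(combined) == sum(len(chunk) for chunk in per_worker_results)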
def no_op_operator(games: typing.Union[ASTType, typing.List[ASTType]], rng=None):
    return games


def handle_multiple_inputs(operator):
    @wraps(operator)
    def wrapped_operator(self, games: typing.Union[ASTType, typing.List[ASTType]], rng: np.random.Generator, *args, **kwargs):
        if not isinstance(games, list):
            return operator(self, games, rng=rng, *args, **kwargs)

        if len(games) == 1:
            return operator(self, games[0], rng=rng, *args, **kwargs)
        else:
            operator_outputs = [operator(self, game, rng=rng, *args, **kwargs) for game in games]
            outputs = []
            for out in operator_outputs:
                if isinstance(out, list):
                    outputs.extend(out)
                else:
                    outputs.append(out)
            return outputs

    return wrapped_operator
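# Illustrative sketch (comment only) of the decorator's contract: an operator written for
# a single game transparently accepts a list of games, and list-valued outputs are
# flattened into a single flat list:
#
# @handle_multiple_inputs
# def _operator(self, game, rng): ...      # hypothetical single-game operator
#
# self._operator(game, rng)                # -> single output
# self._operator([game_a, game_b], rng)    # -> flat list of outputs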
# def msgpack_function_outputs(function):
# @wraps(function)
# def wrapped_function(*args, **kwargs):
# outputs = function(*args, **kwargs)
# return msgpack.packb(outputs)
# return wrapped_function
PARENT_INDEX = 'parent_index'
class PopulationBasedSampler():
    '''
    This is a type of game sampler which uses an evolutionary strategy to climb a
    provided fitness function. It's a population-based alternative to the MCMC sampler.
    # TODO: store statistics about which locations are more likely to receive beneficial mutations?
    # TODO: keep track of 'lineages'
    '''
    args: argparse.Namespace
    candidates: SingleStepResults
    context_fixer: ASTContextFixer
    counter: ASTRuleValueCounter
    diversity_scorer: typing.Optional[DiversityScorer]
    diversity_scorer_type: typing.Optional[str]
    feature_names: typing.List[str]
    first_sampler_key: str
    fitness_featurizer: ASTFitnessFeaturizer
    fitness_featurizer_path: str
    fitness_function: typing.Callable[[torch.Tensor], float]
    fitness_function_date_id: str
    fitness_function_model_name: str
    flip_fitness_sign: bool
    generation_diversity_scores: np.ndarray
    generation_diversity_scores_index: int
    generation_index: int
    grammar: str
    grammar_parser: tatsu.grammars.Grammar  # type: ignore
    initial_samplers: typing.Dict[str, typing.Callable[[], ASTType]]
    max_sample_depth: int
    max_sample_nodes: int
    max_sample_total_size: int
    n_processes: int
    n_workers: int
    output_folder: str
    output_name: str
    postprocessor: ast_parser.ASTSamplePostprocessor
    population: typing.List[ASTType]
    population_size: int
    random_seed: int
    regrowth_sampler: RegrowthSampler
    resume: bool
    relative_path: str
    rng: np.random.Generator
    sample_filter_func: typing.Optional[typing.Callable[[ASTType, typing.Dict[str, typing.Any], float], bool]]
    sample_parallel: bool
    sampler_keys: typing.List[str]
    sampler_kwargs: typing.Dict[str, typing.Any]
    sampler_prior_count: typing.List[int]
    samplers: typing.Dict[str, ASTSampler]
    saving: bool
    signal_received: bool
    success_by_generation_and_operator: typing.List[typing.Dict[str, int]]
    verbose: int
    weight_insert_delete_nodes_by_length: bool
    def __init__(self,
                 args: argparse.Namespace,
                 population_size: int = DEFAULT_POPULATION_SIZE,
                 verbose: int = 0,
                 initial_proposal_type: InitialProposalSamplerType = InitialProposalSamplerType.MAP,
                 fitness_featurizer_path: str = DEFAULT_FITNESS_FEATURIZER_PATH,
                 fitness_function_date_id: str = DEFAULT_FITNESS_FUNCTION_DATE_ID,
                 fitness_function_model_name: str = DEFAULT_SAVE_MODEL_NAME,
                 flip_fitness_sign: bool = True,
                 relative_path: str = DEFAULT_RELATIVE_PATH,
                 output_folder: str = DEFAULT_OUTPUT_FOLDER,
                 output_name: str = DEFAULT_OUTPUT_NAME,
                 ngram_model_path: str = DEFAULT_NGRAM_MODEL_PATH,
                 sampler_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None,
                 section_sampler_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None,
                 sample_patience: int = 100,
                 sample_parallel: bool = False,
                 n_workers: int = 1,
                 diversity_scorer_type: typing.Optional[str] = None,
                 diversity_scorer_k: int = 1,
                 diversity_score_threshold: float = 0.0,
                 diversity_threshold_absolute: bool = False,
                 sample_filter_func: typing.Optional[typing.Callable[[ASTType, typing.Dict[str, typing.Any], float], bool]] = None,
                 sampler_prior_count: typing.List[int] = [PRIOR_COUNT],
                 weight_insert_delete_nodes_by_length: bool = True,
                 max_sample_depth: int = DEFAULT_MAX_SAMPLE_DEPTH,
                 max_sample_nodes: int = DEFAULT_MAX_SAMPLE_NODES,
                 max_sample_total_size: int = DEFAULT_MAX_SAMPLE_TOTAL_SIZE,
                 resume: bool = False,
                 ):
        self.args = args
        self.population_size = population_size
        self.verbose = verbose
        self.sample_patience = sample_patience
        self.sample_parallel = sample_parallel
        self.n_workers = n_workers
        self.n_processes = n_workers + 1  # including the main process
        self.diversity_scorer_type = diversity_scorer_type
        self.grammar = open(args.grammar_file).read()
        self.grammar_parser = typing.cast(tatsu.grammars.Grammar, tatsu.compile(self.grammar))
        self.counter = parse_or_load_counter(args, self.grammar_parser)
        self.relative_path = relative_path
        self.output_folder = output_folder
        self.output_name = output_name
        self.fitness_featurizer_path = fitness_featurizer_path
        self.fitness_featurizer = _load_pickle_gzip(fitness_featurizer_path)
        self.fitness_function_date_id = fitness_function_date_id
        self.fitness_function_model_name = fitness_function_model_name
        self.fitness_function, self.feature_names = load_model_and_feature_columns(fitness_function_date_id, name=fitness_function_model_name, relative_path=relative_path)  # type: ignore
        self.flip_fitness_sign = flip_fitness_sign
        self.diversity_scorer_type = diversity_scorer_type
        self.diversity_scorer_k = diversity_scorer_k
        self.diversity_score_threshold = diversity_score_threshold
        self.diversity_threshold_absolute = diversity_threshold_absolute
        self.diversity_scorer = None
        if self.diversity_scorer_type is not None:
            self.diversity_scorer = create_diversity_scorer(self.diversity_scorer_type, k=diversity_scorer_k, featurizer=self.fitness_featurizer, feature_names=self.feature_names)
        self.sample_filter_func = sample_filter_func
        self.sampler_prior_count = sampler_prior_count
        self.weight_insert_delete_nodes_by_length = weight_insert_delete_nodes_by_length
        self.max_sample_depth = max_sample_depth
        self.max_sample_nodes = max_sample_nodes
        self.max_sample_total_size = max_sample_total_size
        self.random_seed = args.random_seed + self._process_index()
        self.rng = np.random.default_rng(self.random_seed)

        # Used to generate the initial population of complete games
        if sampler_kwargs is None:
            sampler_kwargs = {}
        self.sampler_kwargs = sampler_kwargs
        self.samplers = {f'prior{pc}': ASTSampler(self.grammar_parser, self.counter,
                                                  max_sample_depth=self.max_sample_depth,
                                                  max_sample_nodes=self.max_sample_nodes,
                                                  seed=self.random_seed + pc,
                                                  prior_rule_count=pc, prior_token_count=pc,
                                                  length_prior={n: pc for n in LENGTH_PRIOR},
                                                  **sampler_kwargs) for pc in sampler_prior_count}
        self.sampler_keys = list(self.samplers.keys())
        self.first_sampler_key = self.sampler_keys[0]

        # Used to fix the AST context after crossover / mutation
        self.context_fixer = ASTContextFixer(self.samplers[self.first_sampler_key], rng=np.random.default_rng(self.random_seed), strict=False)
        self.initial_samplers = {  # type: ignore
            key: create_initial_proposal_sampler(initial_proposal_type, self.samplers[key], self.context_fixer,
                                                 ngram_model_path, section_sampler_kwargs)
            for key in self.sampler_keys
        }

        # Used as the mutation operator to modify existing games
        self.regrowth_sampler = RegrowthSampler(self.samplers, seed=self.random_seed, rng=np.random.default_rng(self.random_seed))

        # Initialize the candidate pools for each generation
        self.candidates = SingleStepResults([], [], [], [], [], [])
        self.postprocessor = ast_parser.ASTSamplePostprocessor()
        self.generation_index = 0
        self.fitness_metrics_history = []
        self.diversity_metrics_history = []
        self.success_by_generation_and_operator = []
        self.generation_diversity_scores = np.zeros(self.population_size)
        self.generation_diversity_scores_index = -1
        self.saving = False
        self.population_initialized = False
        self.signal_received = False
        self.resume = resume
    def initialize_population(self):
        """
        Separated into a second method that must be called explicitly, so that subclasses
        can finish their own initialization before the population is generated.
        """
        self.population_initialized = True

        # Do any preliminary initialization
        self._pre_population_sample_setup()

        # Generate the initial population
        self._inner_initialize_population()
        pop = self.population
        if isinstance(pop, dict):
            pop = pop.values()
        # logger.debug(f'Mean initial population_size: {np.mean([object_total_size(p) for p in pop]):.3f}')

    def _pre_population_sample_setup(self):
        pass

    def _inner_initialize_population(self):
        self.set_population([self._gen_init_sample(idx) for idx in trange(self.population_size, desc='Generating initial population')])
    def _proposal_to_features(self, proposal: ASTType) -> typing.Dict[str, typing.Any]:
        return typing.cast(dict, self.fitness_featurizer.parse(proposal, return_row=True))  # type: ignore

    def _features_to_tensor(self, features: typing.Dict[str, typing.Any]) -> torch.Tensor:
        return torch.tensor([features[name] for name in self.feature_names], dtype=torch.float32)  # type: ignore

    def _evaluate_fitness(self, features: torch.Tensor) -> float:
        fitness_function = self.fitness_function
        if 'wrapper' in fitness_function.named_steps:  # type: ignore
            fitness_function.named_steps['wrapper'].eval()  # type: ignore
        score = fitness_function.transform(features).item()
        return -score if self.flip_fitness_sign else score

    def _score_proposal(self, proposal: ASTType, return_features: bool = False):
        proposal_features = self._proposal_to_features(proposal)
        proposal_tensor = self._features_to_tensor(proposal_features)
        proposal_fitness = self._evaluate_fitness(proposal_tensor)
        if return_features:
            return proposal_fitness, proposal_features
        return proposal_fitness
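    # Illustrative sketch (comment only) of the scoring pipeline the three helpers above
    # compose: featurize the AST into a dict, order the features into a tensor following
    # self.feature_names, then score it (flipping the sign if configured so higher is better):
    #
    #   features = self._proposal_to_features(proposal)
    #   tensor = self._features_to_tensor(features)
    #   fitness = self._evaluate_fitness(tensor)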
    def _process_index(self):
        identity = multiprocessing.current_process()._identity  # type: ignore
        if identity is None or len(identity) == 0:
            return 0
        return identity[0] % self.n_processes

    def _sampler(self, rng: np.random.Generator) -> ASTSampler:
        return self.samplers[self._choice(self.sampler_keys, rng=rng)]  # type: ignore

    def _initial_sampler(self, rng: np.random.Generator):
        return self.initial_samplers[self._choice(self.sampler_keys, rng=rng)]  # type: ignore

    def _rename_game(self, game: ASTType, name: str) -> None:
        replace_child(game[1], ['game_name'], name)  # type: ignore
    def __getstate__(self):
        # Copy the object's state from self.__dict__, which contains all our instance
        # attributes. Always use the dict.copy() method to avoid modifying the original state.
        state = self.__dict__.copy()
        # Remove the unpicklable entries when spawning a new process, rather than saving
        if not self.saving:
            state['population'] = {}
            state['fitness_values'] = {}
            state['archive_cell_first_occupied'] = {}
        # Make sure we're not marking the saved model as being saved
        state['saving'] = False
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Set a unique random seed per process X generation index
        self.random_seed = self.args.random_seed + (self._process_index() * (self.generation_index + 1))
        self.rng = np.random.default_rng(self.random_seed)
        self.saving = False
        for sampler_key in self.samplers:
            self.samplers[sampler_key].rng = np.random.default_rng(self.random_seed + self.samplers[sampler_key].prior_rule_count)
        self.regrowth_sampler.seed = self.random_seed
        self.regrowth_sampler.rng = np.random.default_rng(self.random_seed)
        self.context_fixer.rng = np.random.default_rng(self.random_seed)
        # trying to hard-code these as dicts to see which might be changing during iteration?!
        # self.population = dict(self.population)
        # self.fitness_values = dict(self.fitness_values)
        # self.archive_cell_first_occupied = dict(self.archive_cell_first_occupied)
        # print(f'Set state, population type: {type(self.population)} | fitness_values type: {type(self.fitness_values)} | archive_cell_first_occupied type: {type(self.archive_cell_first_occupied)}')
    def save(self, suffix: typing.Optional[str] = None, log_message: bool = True):
        self.saving = True
        output_name = self.output_name
        if suffix is not None:
            output_name += f'_{suffix}'
        save_data(self, self.output_folder, output_name, self.relative_path, log_message=log_message)
        self.saving = False

    def set_population(self, population: typing.List[typing.Any], fitness_values: typing.Optional[typing.List[float]] = None):
        '''
        Set the initial population of the sampler
        '''
        self.population = population
        self.population_size = len(population)
        if fitness_values is None:
            fitness_values = typing.cast(typing.List[float], [self._score_proposal(game, return_features=False) for game in self.population])
        self.fitness_values = fitness_values
        self.best_fitness = max(self.fitness_values)
        self.mean_fitness = np.mean(self.fitness_values)
        self.std_fitness = np.std(self.fitness_values)
        if self.diversity_scorer is not None:
            self.diversity_scorer.set_population(self.population)

    def _best_individual(self):
        return self.population[np.argmax(self.fitness_values)]

    def _print_game(self, game):
        print(ast_printer.ast_to_string(game, "\n"))
    def _choice(self, iterable: typing.Sequence[T], n: int = 1, rng: typing.Optional[np.random.Generator] = None,
                weights: typing.Optional[typing.Sequence[float]] = None) -> typing.Union[T, typing.List[T]]:
        '''
        Small hack to get around the rng invalid __array_struct__ error
        '''
        if rng is None:
            rng = self.rng
        # try:
        if n == 1:
            idx = rng.choice(len(iterable), p=weights)
            return iterable[idx]
        else:
            idxs = rng.choice(len(iterable), size=n, replace=False, p=weights)
            return [iterable[idx] for idx in idxs]
        # except ValueError as e:
        #     logger.error(f'Error in choice: len = {len(iterable)}, {n} = n, weights shape = {weights.shape}: {e}')  # type: ignore
        #     logger.error(traceback.format_exc())
        #     raise e
    def _gen_init_sample(self, idx, rng=None):
        '''
        Helper function for generating an initial sample (repeating until one is generated
        without errors)
        '''
        sample = None
        if rng is None:
            rng = self.rng

        while sample is None:
            try:
                sample = typing.cast(tuple, self._initial_sampler(rng).sample(global_context=dict(original_game_id=f'evo-{idx}')))
                if self.sample_filter_func is not None:
                    sample_fitness, sample_features = self._score_proposal(sample, return_features=True)  # type: ignore
                    if not self.sample_filter_func(sample, sample_features, sample_fitness):
                        sample = None
            except RecursionError:
                if self.verbose >= 2: logger.info(f'Recursion error in sample {idx} -- skipping')
            except SamplingException:
                if self.verbose >= 2: logger.info(f'Sampling exception in sample {idx} -- skipping')
            except ValueError:
                if self.verbose >= 2: logger.info(f'Value error in sample {idx} -- skipping')

        return sample
    def _sample_mutation(self, rng: np.random.Generator, operators: typing.Optional[typing.List[typing.Callable]] = None) -> typing.Callable[[typing.Union[ASTType, typing.List[ASTType]], np.random.Generator], typing.Union[ASTType, typing.List[ASTType]]]:
        if operators is None:
            operators = [self._gen_regrowth_sample, self._insert, self._delete]
        return self._choice(operators, rng=rng)  # type: ignore

    def _randomly_mutate_game(self, game: typing.Union[ASTType, typing.List[ASTType]], rng: np.random.Generator,
                              operators: typing.Optional[typing.List[typing.Callable]] = None) -> typing.Union[ASTType, typing.List[ASTType]]:
        return self._sample_mutation(rng, operators)(game, rng)
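    # Illustrative sketch (comment only): by default, _sample_mutation picks uniformly
    # among regrowth, insert, and delete, and _randomly_mutate_game applies the chosen
    # operator immediately:
    #
    #   mutated = self._randomly_mutate_game(game, rng)
    #   # equivalent to: operator = self._sample_mutation(rng); mutated = operator(game, rng)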
    @handle_multiple_inputs
    def _gen_regrowth_sample(self, game: ASTType, rng: np.random.Generator):
        # Set the source AST of the regrowth sampler to the current game
        self.regrowth_sampler.set_source_ast(game)
        return self._regrowth(rng)

    def _regrowth(self, rng: np.random.Generator, node_key_to_regrow: typing.Optional[typing.Hashable] = None) -> ASTType:
        '''
        Helper function for generating a new sample from an existing game (repeating until one is generated
        without errors)
        '''
        new_proposal = None
        sample_generated = False

        while not sample_generated:
            try:
                new_proposal = self.regrowth_sampler.sample(sample_index=0, update_game_id=False, rng=rng, node_key_to_regrow=node_key_to_regrow)
                self.context_fixer.fix_contexts(new_proposal)
                sample_generated = True
                # In this context I don't need this expensive check for identical samples, as it's just a noop
                # if ast_printer.ast_to_string(new_proposal) == ast_printer.ast_to_string(game):  # type: ignore
                #     if self.verbose >= 2: print('Regrowth generated identical games, repeating')
                # else:
                #     sample_generated = True
            except RecursionError as e:
                if self.verbose >= 2: logger.info(f'Recursion error in regrowth, skipping sample: {e.args}')
            except SamplingException as e:
                if self.verbose >= 2: logger.info(f'Sampling exception in regrowth, skipping sample: {e.args}')
            except ValueError:
                if self.verbose >= 2: logger.info('Value error in regrowth, skipping sample')

        return new_proposal  # type: ignore
    def _get_valid_insert_or_delete_nodes(
            self, game: ASTType, insert: bool = True,
            weigh_nodes_by_length: bool = True, shortest_weight_maximal: bool = False, return_keys: bool = False,
    ) -> typing.Tuple[typing.List[typing.Tuple[tatsu.ast.AST, typing.List[typing.Union[str, int]], str, typing.Dict[str, typing.Any], typing.Dict[str, typing.Any]]], np.ndarray]:
        '''
        Returns a list of every node in the game which is a valid candidate for insertion or deletion
        (i.e. can have more than one child). Each entry in the list is of the form:
            (parent, selector, section, global_context, local_context)
        '''
        self.regrowth_sampler.set_source_ast(game)

        # Collect all nodes whose final selector is an integer (i.e. an index into a list) and whose parent
        # yields a list when its first selector is applied. Also make sure that the list has a minimum length
        valid_nodes = []
        for node, parent, selector, _, section, global_context, local_context in self.regrowth_sampler.parent_mapping.values():
            first_parent = parent[selector[0]]
            if isinstance(selector[-1], int) and isinstance(first_parent, list):
                parent_length = len(first_parent)
                if insert and parent_length >= 1:
                    valid_nodes.append((node, parent, selector[0], section, global_context, local_context))
                elif not insert:
                    min_length = self.samplers[self.first_sampler_key].rules[parent.parseinfo.rule][selector[0]][MIN_LENGTH]  # type: ignore
                    if parent_length >= min_length + 1:
                        valid_nodes.append((node, parent, selector[0], section, global_context, local_context))

        if len(valid_nodes) == 0:
            raise SamplingException('No valid nodes found for insertion or deletion')

        # Dedupe valid nodes based on their parent and selector
        valid_node_keys = set()
        output_valid_nodes = []
        output_node_keys = []
        output_node_weights = []
        for node, parent, selector, section, global_context, local_context in valid_nodes:
            key = (*self.regrowth_sampler._ast_key(parent), selector)
            if key not in valid_node_keys:
                valid_node_keys.add(key)
                output_valid_nodes.append((parent, selector, section, global_context, local_context))
                output_node_keys.append(self.regrowth_sampler._ast_key(node))
                output_node_weights.append(len(parent[selector]))

        if len(output_valid_nodes) > 0:
            if not weigh_nodes_by_length:
                output_node_weights = np.ones(len(output_valid_nodes)) / len(output_valid_nodes)
            else:
                output_node_weights = np.array(output_node_weights, dtype=float)
                if shortest_weight_maximal:
                    output_node_weights = output_node_weights.max() + output_node_weights.min() - output_node_weights
                output_node_weights /= output_node_weights.sum()
        else:
            output_node_weights = np.array([])

        if return_keys:
            return output_valid_nodes, output_node_weights, output_node_keys  # type: ignore

        return output_valid_nodes, output_node_weights
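    # Illustrative note (comment only) on the weighting above: with weigh_nodes_by_length,
    # a candidate list of length w receives raw weight w; when shortest_weight_maximal is
    # set, the weights are reflected as w' = max(w) + min(w) - w, so the shortest lists
    # become the most likely targets. E.g. raw lengths [1, 2, 5] -> reflected [5, 4, 1]
    # -> normalized [0.5, 0.4, 0.1].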
    @handle_multiple_inputs
    def _insert(self, game: ASTType, rng: np.random.Generator):
        '''
        Attempt to insert a new node into the provided game by identifying a node which can have multiple
        children and inserting a new child into it. The new node is sampled from the parent rule's posterior
        '''
        # Make a copy of the game
        new_game = deepcopy_ast(game)
        valid_nodes, valid_node_weights = self._get_valid_insert_or_delete_nodes(
            new_game, insert=True, weigh_nodes_by_length=self.weight_insert_delete_nodes_by_length, shortest_weight_maximal=True)
        if len(valid_nodes) == 0:
            raise SamplingException('No valid nodes found for insertion')

        # Select a random node from the list of valid nodes
        parent, selector, section, global_context, local_context = self._choice(valid_nodes, rng=rng, weights=valid_node_weights)  # type: ignore
        parent_rule = parent.parseinfo.rule  # type: ignore
        parent_rule_posterior_dict = self._sampler(rng).rules[parent_rule][selector]
        assert "length_posterior" in parent_rule_posterior_dict, f"Rule {parent_rule} does not have a length posterior"

        # Sample a new rule from the parent rule posterior (parent_rule_posterior_dict['rule_posterior'])
        new_rule = posterior_dict_sample(rng, parent_rule_posterior_dict['rule_posterior'])
        sample_global_context = global_context.copy()  # type: ignore
        sample_global_context['rng'] = rng

        new_node = None
        while new_node is None:
            try:
                new_node = self._sampler(rng).sample(new_rule, global_context=sample_global_context, local_context=local_context)  # type: ignore
            except RecursionError as e:
                if self.verbose >= 2: logger.info(f'Recursion error in insert, skipping sample: {e.args}')
            except SamplingException as e:
                if self.verbose >= 2: logger.info(f'Sampling exception in insert, skipping sample: {e.args}')
            except ValueError:
                if self.verbose >= 2: logger.info('Value error in insert, skipping sample')

        if isinstance(new_node, tuple):
            new_node = new_node[0]

        # Insert the new node into the parent at a random index
        parent[selector].insert(rng.integers(len(parent[selector]) + 1), new_node)  # type: ignore

        # Do any necessary context-fixing
        self.context_fixer.fix_contexts(new_game, crossover_child=new_node)  # type: ignore

        return new_game
    @handle_multiple_inputs
    def _delete(self, game: ASTType, rng: np.random.Generator):
        '''
        Attempt to delete a node from the provided game by identifying a node which can have multiple
        children and deleting one of them
        '''
        # Make a copy of the game
        new_game = deepcopy_ast(game)
        valid_nodes, valid_node_weights = self._get_valid_insert_or_delete_nodes(
            new_game, insert=False, weigh_nodes_by_length=self.weight_insert_delete_nodes_by_length, shortest_weight_maximal=False)
        if len(valid_nodes) == 0:
            raise SamplingException('No valid nodes found for deletion')

        # Select a random node from the list of valid nodes
        parent, selector, section, global_context, local_context = self._choice(valid_nodes, rng=rng, weights=valid_node_weights)  # type: ignore
        parent_rule = parent.parseinfo.rule  # type: ignore
        parent_rule_posterior_dict = self._sampler(rng).rules[parent_rule][selector]
        assert "length_posterior" in parent_rule_posterior_dict, f"Rule {parent_rule} does not have a length posterior"

        # Delete a random node from the parent
        delete_index = rng.integers(len(parent[selector]))  # type: ignore
        child_to_delete = parent[selector][delete_index]  # type: ignore
        del parent[selector][delete_index]  # type: ignore

        # Do any necessary context-fixing
        self.context_fixer.fix_contexts(new_game, original_child=child_to_delete)  # type: ignore

        return new_game
    def _crossover(self, games: typing.Union[ASTType, typing.List[ASTType]],
                   rng: typing.Optional[np.random.Generator] = None,
                   crossover_type: typing.Optional[CrossoverType] = None,
                   crossover_first_game: bool = True, crossover_second_game: bool = True):
        '''
        Attempts to perform a crossover between the two given games. The crossover type determines
        how nodes in the game are categorized (i.e. by rule, by parent rule, etc.). The crossover
        is performed by finding the set of 'categories' that are present in both games, and then
        selecting a random category from which to sample the nodes that will be exchanged. If no
        categories are shared between the two games, then no crossover is performed
        '''
        if not crossover_first_game and not crossover_second_game:
            raise ValueError("At least one of crossover_first_game and crossover_second_game must be True")

        if rng is None:
            rng = self.rng

        if crossover_type is None:
            crossover_type = typing.cast(CrossoverType, self._choice(list(CrossoverType), rng=rng))

        game_2 = None
        if isinstance(games, list):
            game_1 = games[0]
            if len(games) > 1:
                game_2 = games[1]
        else:
            game_1 = games

        if game_2 is None:
            game_2 = typing.cast(ASTType, self._choice(self.population, rng=rng))

        if crossover_first_game:
            game_1 = deepcopy_ast(game_1)

        if crossover_second_game:
            game_2 = deepcopy_ast(game_2)

        # Create a map from crossover_type keys to lists of nodeinfos for each game
        self.regrowth_sampler.set_source_ast(game_1)
        game_1_crossover_map = defaultdict(list)
        for node_key in self.regrowth_sampler.node_keys:
            node_info = self.regrowth_sampler.parent_mapping[node_key]
            game_1_crossover_map[node_info_to_key(crossover_type, node_info)].append(node_info)

        self.regrowth_sampler.set_source_ast(game_2)
        game_2_crossover_map = defaultdict(list)
        for node_key in self.regrowth_sampler.node_keys:
            node_info = self.regrowth_sampler.parent_mapping[node_key]
            game_2_crossover_map[node_info_to_key(crossover_type, node_info)].append(node_info)

        # Find the set of crossover_type keys that are shared between the two games
        shared_crossover_keys = set(game_1_crossover_map.keys()).intersection(set(game_2_crossover_map.keys()))

        # If there are no shared crossover keys, then throw an exception
        if len(shared_crossover_keys) == 0:
            raise SamplingException("No crossover keys shared between the two games")

        # Select a random crossover key and a nodeinfo for each game with that key
        crossover_key = self._choice(list(shared_crossover_keys), rng=rng)
        game_1_selected_node_info = self._choice(game_1_crossover_map[crossover_key], rng=rng)
        game_2_selected_node_info = self._choice(game_2_crossover_map[crossover_key], rng=rng)

        # Create new copies of the nodes to be crossed over
        g1_node, g1_parent, g1_selector = game_1_selected_node_info[:3]
        g2_node, g2_parent, g2_selector = game_2_selected_node_info[:3]

        # Perform the crossover and fix the contexts of the new games
        if crossover_first_game:
            game_2_crossover_node = deepcopy_ast(g2_node, copy_type=ASTCopyType.NODE)
            replace_child(g1_parent, g1_selector, game_2_crossover_node)  # type: ignore
            self.context_fixer.fix_contexts(game_1, g1_node, game_2_crossover_node)  # type: ignore

        if crossover_second_game:
            game_1_crossover_node = deepcopy_ast(g1_node, copy_type=ASTCopyType.NODE)
            replace_child(g2_parent, g2_selector, game_1_crossover_node)  # type: ignore
            self.context_fixer.fix_contexts(game_2, g2_node, game_1_crossover_node)  # type: ignore

        return [game_1, game_2]
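    # Illustrative sketch (comment only): a typical call crosses one subtree in each
    # direction and returns both modified copies; passing a single game instead of a pair
    # crosses it over with a random member of self.population:
    #
    #   child_1, child_2 = self._crossover([parent_1, parent_2], rng)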
    def _crossover_insert(self, games: typing.Union[ASTType, typing.List[ASTType]],
                          rng: typing.Optional[np.random.Generator] = None,
                          crossover_first_game: bool = True, crossover_second_game: bool = True):
        if rng is None:
            rng = self.rng

        crossover_type = CrossoverType.SAME_PARENT_INITIAL_SELECTOR

        game_2 = None
        if isinstance(games, list):
            game_1 = games[0]
            if len(games) > 1:
                game_2 = games[1]
        else:
            game_1 = games

        if game_2 is None:
            game_2 = typing.cast(ASTType, self._choice(self.population, rng=rng))

        if crossover_first_game:
            game_1 = deepcopy_ast(game_1)

        if crossover_second_game:
            game_2 = deepcopy_ast(game_2)

        # Create a map from crossover_type keys to lists of nodeinfos for each game
        _, _, game_1_insertion_node_keys = self._get_valid_insert_or_delete_nodes(  # type: ignore
            game_1, insert=True, weigh_nodes_by_length=self.weight_insert_delete_nodes_by_length,
            shortest_weight_maximal=True, return_keys=True)
        game_1_crossover_map = defaultdict(list)
        for node_key in game_1_insertion_node_keys:
            node_info = self.regrowth_sampler.parent_mapping[node_key]
            game_1_crossover_map[node_info_to_key(crossover_type, node_info)].append(node_info)

        _, _, game_2_insertion_node_keys = self._get_valid_insert_or_delete_nodes(  # type: ignore
            game_2, insert=True, weigh_nodes_by_length=self.weight_insert_delete_nodes_by_length,
            shortest_weight_maximal=True, return_keys=True)
        game_2_crossover_map = defaultdict(list)