
Commit 39eec48

init_patterns_filename arg to start with "good" init patterns
The file should be a simple JSON list of graph patterns (dicts); it makes it possible to explore the search space from an already known, smarter initialization instead of just plain variable patterns. The commit also adds the INIT_POPPB_INIT_PAT = .75 probability config variable, which balances between the old variable patterns and the patterns from the given file.
1 parent 89507c1 commit 39eec48
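As a rough illustration of the file format described above (a plain JSON list of graph pattern dicts), the following sketch shows how such a file might be produced from existing results. It is not part of the commit: it assumes the patterns are GraphPattern instances with a to_dict() counterpart to the GraphPattern.from_dict() used by the new load_init_patterns() below, and the file name is made up.

# Hypothetical helper, not part of this commit: dump "good" patterns into a
# JSON file usable via --init_patterns_filename. Assumes GraphPattern offers
# a to_dict() counterpart to the from_dict() used when loading.
import json


def save_init_patterns(patterns, fn='good_init_patterns.json'):
    with open(fn, 'w') as f:
        json.dump([p.to_dict() for p in patterns], f, indent=2)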


4 files changed (+45, -1 lines)


config/defaults.py

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@
 INIT_POP_LEN_BETA = 30. # beta value in a length beta distribution
 INIT_POPPB_FV = 0.9 # probability to fix a variable in init population
 INIT_POPPB_FV_N = 5 # allow up to n instantiations for each fixed variable
+INIT_POPPB_INIT_PAT = .75 # probability to use pattern from init patterns file
 VARPAT_REINTRO = 10 # number of variable patterns re-introduced each generation
 HOFPAT_REINTRO = 10 # number of hall of fame patterns re-introduced each gen
 LOGLVL_EVAL = 10 # loglvl for eval logs (10: DEBUG, 20: INFO)

gp_learner.py

Lines changed: 27 additions & 1 deletion
@@ -74,6 +74,7 @@
 from serialization import save_predicted_target_candidates
 from serialization import find_run_result
 from serialization import format_graph_pattern
+from serialization import load_init_patterns
 from serialization import load_results
 from serialization import pause_if_signaled_by_file
 from serialization import print_graph_pattern
@@ -1054,13 +1055,25 @@ def generate_init_population(
         sparql, timeout, gtp_scores,
         pb_fv=config.INIT_POPPB_FV,
         n=config.INIT_POPPB_FV_N,
+        init_patterns=None,
+        pb_init_pattern=config.INIT_POPPB_INIT_PAT,
 ):
     logger.info('generating init population of seed size %d', config.POPSIZE)
     population = []

     # Variable patterns:
     var_pats = generate_variable_patterns(config.POPSIZE)

+    if init_patterns:
+        # replace var pats with random ones from init_patterns according to prob
+        var_pats = [
+            p for p in var_pats
+            if random.random() > pb_init_pattern
+        ]
+        for _ in range(len(var_pats), config.POPSIZE):
+            var_pats.append(random.choice(init_patterns).copy())
+        random.shuffle(var_pats)
+
     # initial run of mutate_fix_var to instantiate many of the variable patterns
     # TODO: maybe loop this? (why only try to fix one var?)
     to_fix = []
@@ -1210,6 +1223,7 @@ def check_quick_stop(

 def find_graph_patterns(
         sparql, run, gtp_scores,
+        init_patterns=None,
         user_callback_per_generation=None,
 ):
     timeout = calibrate_query_timeout(sparql)
@@ -1237,6 +1251,7 @@ def find_graph_patterns(

     population = generate_init_population(
         sparql, timeout, gtp_scores,
+        init_patterns=init_patterns,
     )

     # noinspection PyTypeChecker
@@ -1270,13 +1285,15 @@ def _find_graph_pattern_coverage_run(
         coverage_counts,
         gtp_scores,
         patterns,
+        init_patterns=None,
         user_callback_per_generation=None,
         user_callback_per_run=None,
 ):
     min_fitness = calc_min_fitness(gtp_scores, min_score)

     ngen, res_pop, hall_of_fame, toolbox = find_graph_patterns(
         sparql, run, gtp_scores,
+        init_patterns=init_patterns,
         user_callback_per_generation=user_callback_per_generation,
     )

@@ -1384,6 +1401,7 @@ def _find_graph_pattern_coverage_run(
 def find_graph_pattern_coverage(
         sparql,
         ground_truth_pairs,
+        init_patterns=None,
         min_score=config.MIN_SCORE,
         min_remaining_gain=config.MIN_REMAINING_GAIN,
         max_runs=config.NRUNS,
@@ -1430,6 +1448,7 @@ def find_graph_pattern_coverage(
         coverage_counts,
         gtp_scores,
         patterns,
+        init_patterns=init_patterns,
         user_callback_per_generation=user_callback_per_generation,
         user_callback_per_run=user_callback_per_run,
     )
@@ -1663,6 +1682,7 @@ def main(
         splitting_variant='random',
         train_filename=None,
         test_filename=None,
+        init_patterns_filename=None,
         print_train_test_sets=True,
         reset=False,
         print_topn_raw_patterns=0,
@@ -1734,12 +1754,18 @@ def main(
     # setup node expander
     sparql = SPARQLWrapper.SPARQLWrapper(sparql_endpoint)

+    init_patterns = None
+    if init_patterns_filename:
+        init_patterns = load_init_patterns(init_patterns_filename)

     if reset:
         remove_old_result_files()
     last_res = find_last_result()
     if not last_res:
-        res = find_graph_pattern_coverage(sparql, semantic_associations)
+        res = find_graph_pattern_coverage(
+            sparql, semantic_associations,
+            init_patterns=init_patterns,
+        )
         result_patterns, coverage_counts, gtp_scores = res
         sys.stderr.flush()

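To make the effect of the new pb_init_pattern argument tangible, here is a small standalone sketch of the mixing step added to generate_init_population() above. It uses placeholder strings instead of real GraphPattern objects and stand-in values for config.POPSIZE and config.INIT_POPPB_INIT_PAT; with the default of 0.75, roughly a quarter of the seeds stay plain variable patterns and the rest are drawn (with replacement) from the loaded init patterns.

# Standalone illustration (not part of the commit) of the new mixing logic.
import random

POPSIZE = 200           # stand-in for config.POPSIZE
PB_INIT_PATTERN = 0.75  # stand-in for config.INIT_POPPB_INIT_PAT

var_pats = ['var_pat'] * POPSIZE               # placeholder variable patterns
init_patterns = ['init_pat_a', 'init_pat_b']   # placeholder loaded patterns

# each variable pattern survives only with probability 1 - PB_INIT_PATTERN
var_pats = [p for p in var_pats if random.random() > PB_INIT_PATTERN]
# refill up to POPSIZE with random picks from the init patterns, then shuffle
for _ in range(len(var_pats), POPSIZE):
    var_pats.append(random.choice(init_patterns))
random.shuffle(var_pats)

print(sum(1 for p in var_pats if p == 'var_pat'), 'of', POPSIZE,
      'seeds are still plain variable patterns')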

run.py

Lines changed: 7 additions & 0 deletions
@@ -72,6 +72,13 @@
     type=config.str_to_bool,
 )

+parser.add_argument(
+    "--init_patterns_filename",
+    help="file with nicer patterns to be used in init population",
+    action="store",
+    default=None,
+)
+
 parser.add_argument(
     "--reset",
     help="remove previous training's result files if existing (otherwise "

serialization.py

Lines changed: 10 additions & 0 deletions
@@ -354,6 +354,16 @@ def lazy_print_header():
     )


+def load_init_patterns(fn):
+    with open(fn, 'r') as f:
+        data = json.load(f)
+    init_patterns = [
+        GraphPattern.from_dict(d)
+        for d in data
+    ]
+    return init_patterns
+
+
 def save_predicted_target_candidates(gps, gtps, gtp_gp_tcs):
     fn = path.join(
         config.RESDIR, 'predicted_train_target_candidates.pkl.gz')
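A minimal usage sketch for the new loader; the file name is hypothetical, only load_init_patterns() itself is part of this commit:

# Illustrative only: read the patterns back as GraphPattern objects.
from serialization import load_init_patterns

init_patterns = load_init_patterns('good_init_patterns.json')  # hypothetical file
print('loaded %d init patterns' % len(init_patterns))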
