|
74 | 74 | from serialization import save_predicted_target_candidates
|
75 | 75 | from serialization import find_run_result
|
76 | 76 | from serialization import format_graph_pattern
|
| 77 | +from serialization import load_init_patterns |
77 | 78 | from serialization import load_results
|
78 | 79 | from serialization import pause_if_signaled_by_file
|
79 | 80 | from serialization import print_graph_pattern
|
@@ -1054,13 +1055,25 @@ def generate_init_population(
|
1054 | 1055 | sparql, timeout, gtp_scores,
|
1055 | 1056 | pb_fv=config.INIT_POPPB_FV,
|
1056 | 1057 | n=config.INIT_POPPB_FV_N,
|
| 1058 | + init_patterns=None, |
| 1059 | + pb_init_pattern=config.INIT_POPPB_INIT_PAT, |
1057 | 1060 | ):
|
1058 | 1061 | logger.info('generating init population of seed size %d', config.POPSIZE)
|
1059 | 1062 | population = []
|
1060 | 1063 |
|
1061 | 1064 | # Variable patterns:
|
1062 | 1065 | var_pats = generate_variable_patterns(config.POPSIZE)
|
1063 | 1066 |
|
| 1067 | + if init_patterns: |
| 1068 | + # replace var pats with random ones from init_patterns according to prob |
| 1069 | + var_pats = [ |
| 1070 | + p for p in var_pats |
| 1071 | + if random.random() > pb_init_pattern |
| 1072 | + ] |
| 1073 | + for _ in range(len(var_pats), config.POPSIZE): |
| 1074 | + var_pats.append(random.choice(init_patterns).copy()) |
| 1075 | + random.shuffle(var_pats) |
| 1076 | + |
1064 | 1077 | # initial run of mutate_fix_var to instantiate many of the variable patterns
|
1065 | 1078 | # TODO: maybe loop this? (why only try to fix one var?)
|
1066 | 1079 | to_fix = []
|
@@ -1210,6 +1223,7 @@ def check_quick_stop(
|
1210 | 1223 |
|
1211 | 1224 | def find_graph_patterns(
|
1212 | 1225 | sparql, run, gtp_scores,
|
| 1226 | + init_patterns=None, |
1213 | 1227 | user_callback_per_generation=None,
|
1214 | 1228 | ):
|
1215 | 1229 | timeout = calibrate_query_timeout(sparql)
|
@@ -1237,6 +1251,7 @@ def find_graph_patterns(
|
1237 | 1251 |
|
1238 | 1252 | population = generate_init_population(
|
1239 | 1253 | sparql, timeout, gtp_scores,
|
| 1254 | + init_patterns=init_patterns, |
1240 | 1255 | )
|
1241 | 1256 |
|
1242 | 1257 | # noinspection PyTypeChecker
|
@@ -1270,13 +1285,15 @@ def _find_graph_pattern_coverage_run(
|
1270 | 1285 | coverage_counts,
|
1271 | 1286 | gtp_scores,
|
1272 | 1287 | patterns,
|
| 1288 | + init_patterns=None, |
1273 | 1289 | user_callback_per_generation=None,
|
1274 | 1290 | user_callback_per_run=None,
|
1275 | 1291 | ):
|
1276 | 1292 | min_fitness = calc_min_fitness(gtp_scores, min_score)
|
1277 | 1293 |
|
1278 | 1294 | ngen, res_pop, hall_of_fame, toolbox = find_graph_patterns(
|
1279 | 1295 | sparql, run, gtp_scores,
|
| 1296 | + init_patterns=init_patterns, |
1280 | 1297 | user_callback_per_generation=user_callback_per_generation,
|
1281 | 1298 | )
|
1282 | 1299 |
|
@@ -1384,6 +1401,7 @@ def _find_graph_pattern_coverage_run(
|
1384 | 1401 | def find_graph_pattern_coverage(
|
1385 | 1402 | sparql,
|
1386 | 1403 | ground_truth_pairs,
|
| 1404 | + init_patterns=None, |
1387 | 1405 | min_score=config.MIN_SCORE,
|
1388 | 1406 | min_remaining_gain=config.MIN_REMAINING_GAIN,
|
1389 | 1407 | max_runs=config.NRUNS,
|
@@ -1430,6 +1448,7 @@ def find_graph_pattern_coverage(
|
1430 | 1448 | coverage_counts,
|
1431 | 1449 | gtp_scores,
|
1432 | 1450 | patterns,
|
| 1451 | + init_patterns=init_patterns, |
1433 | 1452 | user_callback_per_generation=user_callback_per_generation,
|
1434 | 1453 | user_callback_per_run=user_callback_per_run,
|
1435 | 1454 | )
|
@@ -1663,6 +1682,7 @@ def main(
|
1663 | 1682 | splitting_variant='random',
|
1664 | 1683 | train_filename=None,
|
1665 | 1684 | test_filename=None,
|
| 1685 | + init_patterns_filename=None, |
1666 | 1686 | print_train_test_sets=True,
|
1667 | 1687 | reset=False,
|
1668 | 1688 | print_topn_raw_patterns=0,
|
@@ -1734,12 +1754,18 @@ def main(
|
1734 | 1754 | # setup node expander
|
1735 | 1755 | sparql = SPARQLWrapper.SPARQLWrapper(sparql_endpoint)
|
1736 | 1756 |
|
| 1757 | + init_patterns = None |
| 1758 | + if init_patterns_filename: |
| 1759 | + init_patterns = load_init_patterns(init_patterns_filename) |
1737 | 1760 |
|
1738 | 1761 | if reset:
|
1739 | 1762 | remove_old_result_files()
|
1740 | 1763 | last_res = find_last_result()
|
1741 | 1764 | if not last_res:
|
1742 |
| - res = find_graph_pattern_coverage(sparql, semantic_associations) |
| 1765 | + res = find_graph_pattern_coverage( |
| 1766 | + sparql, semantic_associations, |
| 1767 | + init_patterns=init_patterns, |
| 1768 | + ) |
1743 | 1769 | result_patterns, coverage_counts, gtp_scores = res
|
1744 | 1770 | sys.stderr.flush()
|
1745 | 1771 |
|
|
0 commit comments