Skip to content

Commit 034e8da

Browse files
committed
added --select-pair method to choose how to select a representative pair
1 parent 6bd39fd commit 034e8da

File tree

1 file changed

+46
-12
lines changed

1 file changed

+46
-12
lines changed

mmpdblib/cli/generate.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,18 @@ def get_num_available_cpus():
364364
help = "Only consider rules with at least N matched molecular pairs",
365365
)
366366

367+
@click.option(
368+
"--select-pair",
369+
"select_pair_method",
370+
type = click.Choice(("first", "quadratic", "min")),
371+
default = "first",
372+
help = (
373+
"If 'first' (fastest), select a representative pair arbitrarily. "
374+
"If 'quadratic', minimize sum of num_heavies**2. "
375+
"If 'min', use the minimum num_heavies for either side."
376+
),
377+
)
378+
367379
@click.option(
368380
"--output",
369381
"-o",
@@ -424,6 +436,7 @@ def generate(
424436
subqueries,
425437
radius,
426438
min_pairs,
439+
select_pair_method,
427440
output_file,
428441
columns,
429442
headers,
@@ -516,6 +529,7 @@ def generate(
516529
from_smiles_list = from_smiles_list,
517530
radius = radius,
518531
min_pairs = min_pairs,
532+
select_pair_method = select_pair_method,
519533
reporter = reporter,
520534
):
521535
if to_process is None:
@@ -532,7 +546,29 @@ def generate(
532546
output_line = output_format_str.format_map(result)
533547
output_file.write(output_line)
534548

535-
549+
550+
SELECT_PAIR_SQL = """
551+
SELECT cmpd1.public_id, cmpd1.clean_smiles, cmpd2.public_id, cmpd2.clean_smiles
552+
FROM pair,
553+
compound AS cmpd1,
554+
compound AS cmpd2
555+
WHERE pair.rule_environment_id = ?
556+
AND pair.compound1_id = cmpd1.id
557+
AND pair.compound2_id = cmpd2.id
558+
<ORDER>
559+
LIMIT 1
560+
"""
561+
_select_pair_sql_table = {
562+
"first": SELECT_PAIR_SQL.replace(
563+
"<ORDER>\n",
564+
""),
565+
"quadratic": SELECT_PAIR_SQL.replace(
566+
"<ORDER>",
567+
"ORDER BY cmpd1.clean_num_heavies * cmpd1.clean_num_heavies + cmpd2.clean_num_heavies * cmpd2.clean_num_heavies"),
568+
"min": SELECT_PAIR_SQL.replace(
569+
"<ORDER>",
570+
"ORDER BY MIN(cmpd1.clean_num_heavies, cmpd2.clean_num_heavies)"),
571+
}
536572

537573

538574
def generate_unwelded_from_constant(
@@ -543,6 +579,7 @@ def generate_unwelded_from_constant(
543579
from_smiles_list,
544580
radius,
545581
min_pairs,
582+
select_pair_method,
546583
reporter,
547584
):
548585
# I need to get the database.execute() so "?" is handled portably
@@ -642,17 +679,14 @@ def generate_unwelded_from_constant(
642679

643680
if pair_cursor is not None:
644681
# Pick a representative pair
645-
pair_result = db.execute("""
646-
SELECT cmpd1.public_id, cmpd1.clean_smiles, cmpd2.public_id, cmpd2.clean_smiles
647-
FROM pair,
648-
compound AS cmpd1,
649-
compound AS cmpd2
650-
WHERE pair.rule_environment_id = ?
651-
AND pair.compound1_id = cmpd1.id
652-
AND pair.compound2_id = cmpd2.id
653-
--ORDER BY cmpd1.clean_num_heavies * cmpd1.clean_num_heavies + cmpd2.clean_num_heavies * cmpd2.clean_num_heavies
654-
LIMIT 1
655-
""", (rule_environment_id,), cursor = pair_cursor)
682+
select_pair_sql = _select_pair_sql_table[select_pair_method]
683+
684+
pair_result = db.execute(
685+
select_pair_sql,
686+
(rule_environment_id,),
687+
cursor = pair_cursor,
688+
)
689+
656690
have_one = False
657691
for pair_from_id, pair_from_smiles, pair_to_id, pair_to_smiles in pair_result:
658692
have_one = True

0 commit comments

Comments
 (0)