Commit 121ec39

Add in a DuckdbSqlPlugin for all the plugins which just do some SQL
1 parent 5861f61 commit 121ec39

7 files changed: +78 -83 lines changed


countess/core/plugins.py

Lines changed: 17 additions & 0 deletions
@@ -156,6 +156,23 @@ def execute(
         raise NotImplementedError(f"{self.__class__}.execute")
 
 
+class DuckdbSqlPlugin(DuckdbSimplePlugin):
+    def execute(
+        self, ddbc: DuckDBPyConnection, source: DuckDBPyRelation, row_limit: Optional[int] = None
+    ) -> Optional[DuckDBPyRelation]:
+        sql = self.sql(source.alias, source.columns)
+        logger.debug(f"{self.__class__}.execute sql %s", sql)
+        if sql:
+            try:
+                return ddbc.sql(sql)
+            except duckdb.duckdb.DatabaseError as exc:
+                logger.warning(exc)
+        return None
+
+    def sql(self, table_name: str, columns: Iterable[str]) -> Optional[str]:
+        raise NotImplementedError(f"{self.__class__}.sql")
+
+
 class DuckdbInputPlugin(DuckdbPlugin):
     num_inputs = 0
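
The new base class means a plugin whose whole job is a single SQL statement now only has to supply sql(); execute() is inherited. A rough illustration follows: the DistinctRowsPlugin below is hypothetical, not part of this commit, and just shows the shape of a subclass.

    # Hypothetical sketch only, not part of this commit.
    from typing import Iterable, Optional

    from countess.core.plugins import DuckdbSqlPlugin
    from countess.utils.duckdb import duckdb_escape_identifier


    class DistinctRowsPlugin(DuckdbSqlPlugin):
        """Illustrative plugin: de-duplicate rows with a single SELECT DISTINCT."""

        name = "Distinct Rows (example)"

        def sql(self, table_name: str, columns: Iterable[str]) -> Optional[str]:
            cols = ", ".join(duckdb_escape_identifier(c) for c in columns)
            # Returning None here would make the inherited execute() return None.
            return f"SELECT DISTINCT {cols} FROM {table_name}"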

countess/plugins/correlation.py

Lines changed: 17 additions & 22 deletions
@@ -1,16 +1,15 @@
-from typing import Optional
 import logging
-
-from duckdb import DuckDBPyConnection, DuckDBPyRelation
+from typing import Iterable, Optional
 
 from countess import VERSION
-from countess.core.parameters import PerNumericColumnArrayParam, BooleanParam, ColumnOrNoneChoiceParam
-from countess.core.plugins import DuckdbSimplePlugin
+from countess.core.parameters import BooleanParam, ColumnOrNoneChoiceParam, PerNumericColumnArrayParam
+from countess.core.plugins import DuckdbSqlPlugin
 from countess.utils.duckdb import duckdb_escape_identifier, duckdb_escape_literal
 
 logger = logging.getLogger(__name__)
 
-class CorrelationPlugin(DuckdbSimplePlugin):
+
+class CorrelationPlugin(DuckdbSqlPlugin):
     """Correlations"""
 
     name = "Correlation Tool"
@@ -21,28 +20,24 @@ class CorrelationPlugin(DuckdbSimplePlugin):
     columns = PerNumericColumnArrayParam("Columns", BooleanParam("Correlate?", False))
     group = ColumnOrNoneChoiceParam("Group")
 
-    def execute(
-        self, ddbc: DuckDBPyConnection, source: DuckDBPyRelation, row_limit: Optional[int] = None
-    ) -> Optional[DuckDBPyRelation]:
-
+    def sql(self, table_name: str, columns: Iterable[str]) -> Optional[str]:
         grp = duckdb_escape_identifier(self.group.value) if self.group.is_not_none() else None
 
         if sum(1 for c in self.columns.params if c.value) < 2:
             return None
 
-        sql = " union all ".join(f"""
-            select {(grp + ", ") if grp else ""}
-                {duckdb_escape_literal(c1.label)} as column_x,
-                {duckdb_escape_literal(c2.label)} as column_y,
-                corr({duckdb_escape_identifier(c2.label)},{duckdb_escape_identifier(c1.label)}) as correlation_coefficient,
-                covar_pop({duckdb_escape_identifier(c2.label)},{duckdb_escape_identifier(c1.label)}) as covariance_population,
-                regr_r2({duckdb_escape_identifier(c2.label)},{duckdb_escape_identifier(c1.label)}) as pearsons_r2
-            from {source.alias}
-            {("group by "+grp) if grp else ""}
-        """
+        return " union all ".join(
+            f"""
+            select {(grp + ", ") if grp else ""}
+                {duckdb_escape_literal(c1.label)} as column_x,
+                {duckdb_escape_literal(c2.label)} as column_y,
+                corr({duckdb_escape_identifier(c2.label)},{duckdb_escape_identifier(c1.label)}) as correlation_coefficient,
+                covar_pop({duckdb_escape_identifier(c2.label)},{duckdb_escape_identifier(c1.label)}) as covariance_population,
+                regr_r2({duckdb_escape_identifier(c2.label)},{duckdb_escape_identifier(c1.label)}) as pearsons_r2
+            from {table_name}
+            {("group by "+grp) if grp else ""}
+            """
            for c1 in self.columns.params
            for c2 in self.columns.params
            if c1.value and c2.value and c1.label < c2.label
        )
-        logger.debug("CorrelationPlugin.execute sql %s", sql)
-        return ddbc.sql(sql)
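
The pairwise query structure is unchanged; the method now returns the SQL string instead of executing it. A minimal sketch of the pair selection, with hypothetical column labels, shows why each unordered pair appears exactly once thanks to the c1.label < c2.label guard:

    # Minimal sketch of the pair-selection logic above; column labels are hypothetical.
    selected = ["count_a", "count_b", "count_c"]
    pairs = [(c1, c2) for c1 in selected for c2 in selected if c1 < c2]
    print(pairs)  # [('count_a', 'count_b'), ('count_a', 'count_c'), ('count_b', 'count_c')]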

countess/plugins/group_by.py

Lines changed: 14 additions & 16 deletions
@@ -1,11 +1,11 @@
 import logging
-from typing import Optional
+from typing import Iterable, Optional
 
 from duckdb import DuckDBPyConnection, DuckDBPyRelation
 
 from countess import VERSION
 from countess.core.parameters import BooleanParam, PerColumnArrayParam, TabularMultiParam
-from countess.core.plugins import DuckdbSimplePlugin
+from countess.core.plugins import DuckdbSqlPlugin
 from countess.utils.duckdb import duckdb_escape_identifier
 
 logger = logging.getLogger(__name__)
@@ -30,7 +30,7 @@ def _op(op_name, col_name):
     return f"{op_call}{col_ident}) AS {col_output}"
 
 
-class GroupByPlugin(DuckdbSimplePlugin):
+class GroupByPlugin(DuckdbSqlPlugin):
     """Groups by an arbitrary column and rolls up rows"""
 
     name = "Group By"
@@ -42,9 +42,7 @@ class GroupByPlugin(DuckdbSimplePlugin):
     columns = PerColumnArrayParam("Columns", ColumnMultiParam("Column"))
     join = BooleanParam("Join Back?")
 
-    def execute(
-        self, ddbc: DuckDBPyConnection, source: DuckDBPyRelation, row_limit: Optional[int] = None
-    ) -> Optional[DuckDBPyRelation]:
+    def sql(self, table_name: str, columns: Iterable[str]) -> Optional[str]:
         column_params = list(self.columns.get_column_params())
         columns = (
             ", ".join(
@@ -60,16 +58,16 @@ def execute(
             for col_name, col_param in column_params
             if col_param.params["index"].value
         )
-        if group_by:
-            sql = f"SELECT {group_by}, {columns} FROM {source.alias} GROUP BY {group_by}"
-        else:
-            sql = f"SELECT {columns} FROM {source.alias}"
-
         if self.join:
             if group_by:
-                sql = f"SELECT * FROM {source.alias} JOIN ({sql}) USING ({group_by})"
+                return (
+                    f"SELECT * FROM {table_name} JOIN (SELECT {group_by}, {columns} "
+                    f"FROM {table_name} GROUP BY {group_by}) USING ({group_by})"
+                )
             else:
-                sql = f"SELECT * FROM {source.alias} CROSS JOIN ({sql})"
-
-        logger.debug("GroupByPlugin.execute sql %s", sql)
-        return ddbc.sql(sql)
+                return f"SELECT * FROM {table_name} CROSS JOIN (SELECT {columns} FROM {table_name})"
+        else:
+            if group_by:
+                return f"SELECT {group_by}, {columns} FROM {table_name} GROUP BY {group_by}"
+            else:
+                return f"SELECT {columns} FROM {table_name}"

countess/plugins/join.py

Lines changed: 7 additions & 2 deletions
@@ -85,10 +85,15 @@ def execute_multi(
         )
         if row_limit is not None:
             query += f" LIMIT {row_limit}"
-        logger.debug(query)
+
+        logger.debug("JoinPlugin.execute_multi tables[0] %s %d", tables[0].alias, len(tables[0]))
+        logger.debug("JoinPlugin.execute_multi tables[1] %s %d", tables[1].alias, len(tables[1]))
+        logger.debug("JoinPlugin.execute_multi query %s", query)
 
         try:
-            return ddbc.sql(query)
+            rel = ddbc.sql(query)
+            logger.debug("JoinPlugin.execute_multi output %d", len(rel))
+            return rel
         except duckdb.ConversionException as exc:
             logger.info(exc)
             return None
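
The new messages are emitted at DEBUG level, so they are silent by default. One way to surface them when driving CountESS from a script, assuming nothing else has configured logging already:

    import logging

    # Show JoinPlugin's table-size, query and output-size debug messages.
    logging.basicConfig(level=logging.DEBUG)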

countess/plugins/score_scale.py

Lines changed: 8 additions & 16 deletions
@@ -1,7 +1,5 @@
 import logging
-from typing import Optional
-
-from duckdb import DuckDBPyConnection, DuckDBPyRelation
+from typing import Iterable, Optional
 
 from countess import VERSION
 from countess.core.parameters import (
@@ -13,7 +11,7 @@
     StringParam,
     TabularMultiParam,
 )
-from countess.core.plugins import DuckdbSimplePlugin
+from countess.core.plugins import DuckdbSqlPlugin
 from countess.utils.duckdb import duckdb_escape_identifier, duckdb_escape_literal
 
 logger = logging.getLogger(__name__)
@@ -41,7 +39,7 @@ def filter(self):
         raise NotImplementedError()
 
 
-class ScoreScalingPlugin(DuckdbSimplePlugin):
+class ScoreScalingPlugin(DuckdbSqlPlugin):
     name = "Score Scaling"
     description = "Scaled Scores using variant classification"
     version = VERSION
@@ -58,13 +56,11 @@ def __init__(self, *a, **k):
         self.classifiers[0].label = "Scale to 0.0"
         self.classifiers[1].label = "Scale to 1.0"
 
-    def execute(
-        self, ddbc: DuckDBPyConnection, source: DuckDBPyRelation, row_limit: Optional[int] = None
-    ) -> Optional[DuckDBPyRelation]:
+    def sql(self, table_name: str, columns: Iterable[str]) -> Optional[str]:
         score_col_id = duckdb_escape_identifier(self.score_col.value)
         scaled_col_id = duckdb_escape_identifier(self.scaled_col.value)
 
-        all_columns = ",".join("T0." + duckdb_escape_identifier(c) for c in source.columns if self.scaled_col != c)
+        all_columns = ",".join("T0." + duckdb_escape_identifier(c) for c in columns if self.scaled_col != c)
 
         if self.group_col.is_not_none():
             group_col_id = "T0." + duckdb_escape_identifier(self.group_col.value)
@@ -73,17 +69,13 @@ def execute(
 
         c0, c1 = self.classifiers
 
-        sql = f"""
+        return f"""
             select {all_columns}, ({score_col_id} - T1.score_0) / (T1.score_1 - T1.score_0) as {scaled_col_id}
-            from {source.alias} T0 join (
+            from {table_name} T0 join (
                 select {group_col_id} as score_group,
                     median({score_col_id}) filter ({c0.filter()}) as score_0,
                     median({score_col_id}) filter ({c1.filter()}) as score_1
-                from {source.alias} T0
+                from {table_name} T0
                 group by score_group
             ) T1 on ({group_col_id} = T1.score_group)
         """
-
-        logger.debug("ScoreScalingPlugin sql %s", sql)
-
-        return ddbc.sql(sql)
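
The scaling expression itself is untouched: within each group, the median score of the "Scale to 0.0" class maps to 0 and the median of the "Scale to 1.0" class maps to 1. A quick numeric sketch with made-up medians:

    # Made-up group medians: score_0 from the "Scale to 0.0" class,
    # score_1 from the "Scale to 1.0" class.
    score_0, score_1 = 2.0, 8.0
    for score in (2.0, 5.0, 8.0):
        print(score, (score - score_0) / (score_1 - score_0))  # 0.0, 0.5, 1.0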

countess/plugins/vampseq.py

Lines changed: 8 additions & 13 deletions
@@ -1,11 +1,11 @@
 import logging
-from typing import Optional
+from typing import Iterable, Optional
 
 from duckdb import DuckDBPyConnection, DuckDBPyRelation
 
 from countess import VERSION
 from countess.core.parameters import ColumnOrNoneChoiceParam, FloatParam, PerNumericColumnArrayParam, TabularMultiParam
-from countess.core.plugins import DuckdbSimplePlugin
+from countess.core.plugins import DuckdbSqlPlugin
 from countess.utils.duckdb import duckdb_escape_identifier, duckdb_escape_literal
 
 logger = logging.getLogger(__name__)
@@ -15,7 +15,7 @@ class CountColumnParam(TabularMultiParam):
     weight = FloatParam("Weight")
 
 
-class VampSeqScorePlugin(DuckdbSimplePlugin):
+class VampSeqScorePlugin(DuckdbSqlPlugin):
     name = "VAMP-seq Scoring"
     description = "Calculate scores from weighed bin counts"
     version = VERSION
@@ -32,17 +32,15 @@ def prepare(self, ddbc: DuckDBPyConnection, source: Optional[DuckDBPyRelation])
         for n, c in enumerate(count_cols):
             c.weight.value = (n + 1) / len(count_cols)
 
-    def execute(
-        self, ddbc: DuckDBPyConnection, source: DuckDBPyRelation, row_limit: Optional[int] = None
-    ) -> Optional[DuckDBPyRelation]:
+    def sql(self, table_name: str, columns: Iterable[str]) -> Optional[str]:
         weighted_columns = {
             duckdb_escape_identifier(name): duckdb_escape_literal(param.weight.value)
             for name, param in self.columns.get_column_params()
             if param.weight.value is not None
         }
 
         if not weighted_columns:
-            return source
+            return None
 
         if self.group_col.is_not_none():
             group_col_id = "T0." + duckdb_escape_identifier(self.group_col.value)
@@ -53,14 +51,11 @@
         weighted_counts = " + ".join(f"T0.{k} * {v} / T1.{k}" for k, v in weighted_columns.items())
         total_counts = " + ".join(f"T0.{k} / T1.{k}" for k in weighted_columns.keys())
 
-        sql = f"""
+        return f"""
             select T0.*, ({weighted_counts}) / ({total_counts}) as score
-            from {source.alias} T0 join (
+            from {table_name} T0 join (
                 select {group_col_id} as score_group, {sums}
-                from {source.alias} T0
+                from {table_name} T0
                 group by score_group
             ) T1 on ({group_col_id} = T1.score_group)
         """
-
-        logger.debug("VampseqScorePlugin sql %s", sql)
-        return ddbc.sql(sql)
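
The score formula is also unchanged: each bin count is divided by the corresponding per-group T1 value (presumably the bin totals, built from {sums} outside this hunk), and the score is the weighted mean of those normalised frequencies. A small numeric sketch with made-up counts, assuming four bins and the default (n + 1) / len(count_cols) weights set in prepare():

    # Made-up numbers: one variant's counts per bin, per-group bin totals,
    # and the default weights 0.25, 0.5, 0.75, 1.0 for four bins.
    counts = [10, 20, 40, 30]
    totals = [1000, 1000, 2000, 1000]
    weights = [0.25, 0.50, 0.75, 1.00]

    freqs = [c / t for c, t in zip(counts, totals)]
    score = sum(w * f for w, f in zip(weights, freqs)) / sum(freqs)
    print(round(score, 3))  # approximately 0.719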

countess/plugins/variant.py

Lines changed: 7 additions & 14 deletions
@@ -1,9 +1,7 @@
 import logging
 import string
 from functools import lru_cache
-from typing import Any, Optional
-
-from duckdb import DuckDBPyConnection, DuckDBPyRelation
+from typing import Any, Iterable, Optional
 
 from countess import VERSION
 from countess.core.parameters import (
@@ -17,7 +15,7 @@
     StringCharacterSetParam,
     StringParam,
 )
-from countess.core.plugins import DuckdbParallelTransformPlugin, DuckdbSimplePlugin
+from countess.core.plugins import DuckdbParallelTransformPlugin, DuckdbSqlPlugin
 from countess.utils.duckdb import duckdb_escape_identifier
 from countess.utils.variant import TooManyVariationsException, find_variant_string
 
@@ -135,16 +133,14 @@ def transform(self, data: dict[str, Any]) -> Optional[dict[str, Any]]:
         return data
 
 
-class VariantClassifier(DuckdbSimplePlugin):
+class VariantClassifier(DuckdbSqlPlugin):
     name = "Protein Variant Classifier"
     description = "Classifies protein variants into simple types"
     version = VERSION
 
     variant_col = ColumnChoiceParam("Protein variant Column", "variant")
 
-    def execute(
-        self, ddbc: DuckDBPyConnection, source: DuckDBPyRelation, row_limit: Optional[int] = None
-    ) -> Optional[DuckDBPyRelation]:
+    def sql(self, table_name: str, columns: Iterable[str]) -> Optional[str]:
         variant_col_id = duckdb_escape_identifier(self.variant_col.value)
         output_col_id = duckdb_escape_identifier(self.variant_col + "_type")
 
@@ -154,23 +150,20 @@ def execute(
         # once for each distinct variant string. Then the cases
         # in the outer select use the parts of the regex match to
         # classify the variant.
-        sql = rf"""
+        return rf"""
            select S.*, case when T.a != '' or T.c == '' and T.e == '=' then 'W'
                             when T.c != '' and (T.c = T.e or T.e = '=') then 'S'
                             when T.e = 'Ter' or T.e = '*' then 'N'
                             when T.c != '' and T.d != '' and T.e != '' then 'M'
                             else '?'
                        end as {output_col_id}
-            from {source.alias} S join (
+            from {table_name} S join (
                select {variant_col_id} as z, unnest(regexp_extract(
                    {variant_col_id},
                    '(_?[Ww][Tt])|(p.)?([A-Z][a-z]*)?(\d+)?([A-Z][a-z]*|[=*])?',
                    ['a','b','c','d','e']
                ))
-            from {source.alias}
+            from {table_name}
            group by z
            ) T on S.{variant_col_id} = T.z
        """
-
-        logger.debug("VariantClassifier sql %s", sql)
-        return ddbc.sql(sql)
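
The classification CASE only changes its table reference. Read against the regex groups (a: a _WT-style match, c: reference residue, d: position, e: alternate residue, '=' or '*'), it assigns one letter per variant. The strings below are hypothetical, and the readings in the comments (wild-type, synonymous, nonsense, missense) are an interpretation, not spelled out in the code:

    # Hypothetical inputs and the letter the CASE expression above assigns.
    examples = {
        "_wt": "W",         # the (_?[Ww][Tt]) alternative matched: wild-type
        "p.Leu25=": "S",    # alternate is '=': synonymous
        "p.Leu25Leu": "S",  # reference equals alternate: synonymous
        "p.Trp26Ter": "N",  # alternate is a stop: nonsense
        "p.Gly12Asp": "M",  # reference, position and alternate all present: missense
        "oddball": "?",     # nothing useful matched
    }
    for variant, letter in examples.items():
        print(variant, letter)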
