1- from typing import Optional
21import logging
3-
4- from duckdb import DuckDBPyConnection , DuckDBPyRelation
2+ from typing import Iterable , Optional
53
64from countess import VERSION
7- from countess .core .parameters import PerNumericColumnArrayParam , BooleanParam , ColumnOrNoneChoiceParam
8- from countess .core .plugins import DuckdbSimplePlugin
5+ from countess .core .parameters import BooleanParam , ColumnOrNoneChoiceParam , PerNumericColumnArrayParam
6+ from countess .core .plugins import DuckdbSqlPlugin
97from countess .utils .duckdb import duckdb_escape_identifier , duckdb_escape_literal
108
119logger = logging .getLogger (__name__ )
1210
13- class CorrelationPlugin (DuckdbSimplePlugin ):
11+
12+ class CorrelationPlugin (DuckdbSqlPlugin ):
1413 """Correlations"""
1514
1615 name = "Correlation Tool"
@@ -21,28 +20,24 @@ class CorrelationPlugin(DuckdbSimplePlugin):
2120 columns = PerNumericColumnArrayParam ("Columns" , BooleanParam ("Correlate?" , False ))
2221 group = ColumnOrNoneChoiceParam ("Group" )
2322
24- def execute (
25- self , ddbc : DuckDBPyConnection , source : DuckDBPyRelation , row_limit : Optional [int ] = None
26- ) -> Optional [DuckDBPyRelation ]:
27-
23+ def sql (self , table_name : str , columns : Iterable [str ]) -> Optional [str ]:
2824 grp = duckdb_escape_identifier (self .group .value ) if self .group .is_not_none () else None
2925
3026 if sum (1 for c in self .columns .params if c .value ) < 2 :
3127 return None
3228
33- sql = " union all " .join (f"""
34- select { (grp + ", " ) if grp else "" }
35- { duckdb_escape_literal (c1 .label )} as column_x,
36- { duckdb_escape_literal (c2 .label )} as column_y,
37- corr({ duckdb_escape_identifier (c2 .label )} ,{ duckdb_escape_identifier (c1 .label )} ) as correlation_coefficient,
38- covar_pop({ duckdb_escape_identifier (c2 .label )} ,{ duckdb_escape_identifier (c1 .label )} ) as covariance_population,
39- regr_r2({ duckdb_escape_identifier (c2 .label )} ,{ duckdb_escape_identifier (c1 .label )} ) as pearsons_r2
40- from { source .alias }
41- { ("group by " + grp ) if grp else "" }
42- """
29+ return " union all " .join (
30+ f"""
31+ select { (grp + ", " ) if grp else "" }
32+ { duckdb_escape_literal (c1 .label )} as column_x,
33+ { duckdb_escape_literal (c2 .label )} as column_y,
34+ corr({ duckdb_escape_identifier (c2 .label )} ,{ duckdb_escape_identifier (c1 .label )} ) as correlation_coefficient,
35+ covar_pop({ duckdb_escape_identifier (c2 .label )} ,{ duckdb_escape_identifier (c1 .label )} ) as covariance_population,
36+ regr_r2({ duckdb_escape_identifier (c2 .label )} ,{ duckdb_escape_identifier (c1 .label )} ) as pearsons_r2
37+ from { table_name }
38+ { ("group by " + grp ) if grp else "" }
39+ """
4340 for c1 in self .columns .params
4441 for c2 in self .columns .params
4542 if c1 .value and c2 .value and c1 .label < c2 .label
4643 )
47- logger .debug ("CorrelationPlugin.execute sql %s" , sql )
48- return ddbc .sql (sql )
0 commit comments