11import logging
2- import re
3- from typing import Any , Optional
2+ from typing import Optional
3+
4+ from duckdb import DuckDBPyConnection , DuckDBPyRelation
45
56from countess import VERSION
67from countess .core .parameters import (
1112 MultiParam ,
1213 StringParam ,
1314)
14- from countess .core .plugins import DuckdbThreadedTransformPlugin
15+ from countess .core .plugins import DuckdbSimplePlugin
16+ from countess .utils .duckdb import duckdb_escape_identifier , duckdb_escape_literal
1517
1618logger = logging .getLogger (__name__ )
1719
@@ -21,7 +23,7 @@ class OutputColumnsMultiParam(MultiParam):
2123 datatype = DataTypeChoiceParam ("Column Type" , "STRING" )
2224
2325
24- class RegexToolPlugin (DuckdbThreadedTransformPlugin ):
26+ class RegexToolPlugin (DuckdbSimplePlugin ):
2527 name = "Regex Tool"
2628 description = "Apply regular expressions to a column to make new column(s)"
2729 link = "https://countess-project.github.io/CountESS/included-plugins/#regex-tool"
@@ -33,37 +35,34 @@ class RegexToolPlugin(DuckdbThreadedTransformPlugin):
3335 drop_column = BooleanParam ("Drop Column" , False )
3436 drop_unmatch = BooleanParam ("Drop Unmatched Rows" , False )
3537
36- compiled_re = None
37-
38- def prepare (self , * a ) -> None :
39- super ().prepare (* a )
40- self .compiled_re = re .compile (self .regex .value )
41-
42- def add_fields (self ):
43- return {op .name .value : op .datatype .get_selected_type () for op in self .output }
38+ def execute (
39+ self , ddbc : DuckDBPyConnection , source : DuckDBPyRelation , row_limit : Optional [int ] = None
40+ ) -> Optional [DuckDBPyRelation ]:
41+ column_id = duckdb_escape_identifier (self .column .value )
42+ regexp_value = duckdb_escape_literal (self .regex .value )
43+ output_ids = [duckdb_escape_literal (o .name .value ) for o in self .output if o .name .value ]
44+ output_types = [
45+ duckdb_escape_identifier (o .name .value ) + " " + o .datatype .value for o in self .output if o .name .value
46+ ]
4447
45- def remove_fields (self , field_names ):
4648 if self .drop_column :
47- return [ self .column .value ]
49+ proj = "" . join ( duckdb_escape_identifier ( c ) + ", " for c in source . columns if c != self .column .value )
4850 else :
49- return []
51+ proj = "*, "
5052
51- def transform (self , data : dict [str , Any ]) -> Optional [dict [str , Any ]]:
52- assert self .compiled_re is not None
53- value = data [self .column .value ]
54- if value is not None :
55- try :
56- if match := self .compiled_re .match (str (value )):
57- data .update (
58- {op .name .value : op .datatype .cast_value (val ) for op , val in zip (self .output , match .groups ())}
59- )
60- return data
61- else :
62- logger .info ("%s didn't match" , repr (value ))
63- except (TypeError , ValueError ) as exc :
64- logger .warning ("Exception" , exc_info = exc )
53+ proj += f"""
54+ unnest(try_cast(
55+ regexp_extract({ column_id } , { regexp_value } , [{ ',' .join (output_ids )} ])
56+ as struct({ ',' .join (output_types )} )
57+ ))
58+ """
59+
60+ logger .debug ("VampseqScorePlugin proj %s" , proj )
6561
6662 if self .drop_unmatch :
67- return None
63+ filt = f"regexp_matches({ column_id } , { regexp_value } )"
64+ logger .debug ("VampseqScorePlugin filt %s" , filt )
65+ return source .filter (filt ).project (proj )
66+
6867 else :
69- return data
68+ return source . project ( proj )
0 commit comments