Skip to content

Commit 60a8a7e

Browse files
committed
...
1 parent d9431d1 commit 60a8a7e

File tree

8 files changed

+140
-180
lines changed

8 files changed

+140
-180
lines changed

countess/core/parameters.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -471,10 +471,10 @@ def copy(self) -> "ChoiceParam":
471471

472472
class DataTypeChoiceParam(ChoiceParam):
473473
DATA_TYPES: Mapping[str, tuple[type, Any, Type[ScalarParam]]] = {
474-
"string": (str, "", StringParam),
475-
"number": (float, math.nan, FloatParam),
476-
"integer": (int, 0, IntegerParam),
477-
"boolean": (bool, False, BooleanParam),
474+
"VARCHAR": (str, "", StringParam),
475+
"FLOAT": (float, math.nan, FloatParam),
476+
"INTEGER": (int, 0, IntegerParam),
477+
"BOOLEAN": (bool, False, BooleanParam),
478478
}
479479

480480
def __init__(self, label: str, value: Optional[str] = None, choices: Optional[Iterable[str]] = None):

countess/core/pipeline.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,7 @@ def run(self, ddbc):
8383

8484
assert isinstance(self.plugin, DuckdbPlugin)
8585
if self.is_dirty:
86-
sources = [pn.run(ddbc) for pn in self.parent_nodes]
86+
sources = {pn.name: pn.run(ddbc) for pn in self.parent_nodes }
8787
ddbc.sql(f"DROP TABLE IF EXISTS n_{self.uuid}")
8888
self.plugin.execute_multi(ddbc, sources).to_table(f"n_{self.uuid}")
8989
self.result = ddbc.table(f"n_{self.uuid}")
@@ -176,9 +176,7 @@ def run(self):
176176
start_time = time.time()
177177
for node in self.traverse_nodes():
178178
node.load_config()
179-
node.result = node.plugin.execute_multi(ddbc, [pn.result for pn in node.parent_nodes])
180-
logger.debug("Got result ...")
181-
logger.debug("... %d", len(node.result))
179+
node.result = node.plugin.execute_multi(ddbc, {pn.name: pn.result for pn in node.parent_nodes})
182180

183181
logger.info("Finished, elapsed time: %d", time.time() - start_time)
184182

countess/core/plugins.py

Lines changed: 26 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,9 @@
1919
import importlib
2020
import importlib.metadata
2121
import logging
22-
from typing import Any, Iterable, List, Optional, Sequence, Type, Union
22+
from typing import Any, Iterable, List, Optional, Sequence, Tuple, Type, Union, Mapping
2323

24+
import duckdb
2425
from duckdb import DuckDBPyConnection, DuckDBPyRelation
2526

2627
from countess.core.parameters import BaseParam, FileArrayParam, FileParam, HasSubParametersMixin, MultiParam
@@ -109,16 +110,17 @@ class DuckdbPlugin(BasePlugin):
109110
# XXX expand this, or find in library somewhere
110111
ALLOWED_TYPES = {"INTEGER", "VARCHAR", "FLOAT"}
111112

112-
def execute_multi(self, ddbc: DuckDBPyConnection, sources: List[DuckDBPyRelation]) -> Optional[DuckDBPyRelation]:
113+
def execute_multi(self, ddbc: DuckDBPyConnection, sources: Mapping[str, DuckDBPyRelation]) -> Optional[DuckDBPyRelation]:
113114
raise NotImplementedError(f"{self.__class__}.execute_multi")
114115

115116

116117
class DuckdbSimplePlugin(DuckdbPlugin):
117-
def execute_multi(self, ddbc: DuckDBPyConnection, sources: List[DuckDBPyRelation]) -> Optional[DuckDBPyRelation]:
118+
def execute_multi(self, ddbc: DuckDBPyConnection, sources: Mapping[str, DuckDBPyRelation]) -> Optional[DuckDBPyRelation]:
119+
tables = list(sources.values())
118120
if len(sources) > 1:
119-
return self.execute(ddbc, duckdb_concatenate(sources))
121+
return self.execute(ddbc, duckdb_concatenate(tables))
120122
elif len(sources) == 1:
121-
return self.execute(ddbc, sources[0])
123+
return self.execute(ddbc, tables[0])
122124
else:
123125
return self.execute(ddbc, None)
124126

@@ -175,6 +177,10 @@ def load_file(
175177
raise NotImplementedError(f"{self.__class__}.load_file")
176178

177179

180+
class DuckdbSaveFilePlugin(DuckdbSimplePlugin):
181+
num_outputs = 0
182+
183+
178184
class DuckdbFilterPlugin(DuckdbSimplePlugin):
179185
def input_columns(self) -> dict[str, str]:
180186
raise NotImplementedError(f"{self.__class__}.input_columns")
@@ -267,27 +273,30 @@ def execute(self, ddbc, source):
267273
logger.debug("DuckDbTransformPlugin.query output_type %s", output_type)
268274
logger.debug("DuckDbTransformPlugin.query project_fields %s", project_fields)
269275

276+
# if the function already exists, remove it
270277
try:
271-
ddbc.create_function(
272-
name=function_name,
273-
function=self.transform_tuple,
274-
parameters=input_types,
275-
return_type=output_type,
276-
null_handling="special",
277-
side_effects=False,
278-
)
279-
return source.project(project_fields)
280-
finally:
278+
ddbc.remove_function(function_name)
279+
except duckdb.InvalidInputException:
280+
# it didn't exist
281281
pass
282-
# ddbc.remove_function(function_name)
282+
283+
ddbc.create_function(
284+
name=function_name,
285+
function=self.transform_tuple,
286+
parameters=input_types,
287+
return_type=output_type,
288+
null_handling="special",
289+
side_effects=False,
290+
)
291+
return source.project(project_fields)
283292

284293
def transform_tuple(self, *data):
285294
logger.debug("DuckDbTransformPlugin.transform_tuple %s", data)
286295
r = self.transform(dict(zip([k for k in self.input_columns().keys() if k is not None], data)))
287296
logger.debug("DuckDbTransformPlugin.transform_tuple %s", r)
288297
return r
289298

290-
def transform(self, data: dict[str, Any]):
299+
def transform(self, data: dict[str, Any]) -> Union[dict[str, Any], Tuple[Any], None]:
291300
"""This will be called for each row, with the columns nominated in
292301
`self.input_columns` as parameters. Return a tuple with the same
293302
value types as (or a dictionary with the same keys and value types as)

countess/gui/tabular.py

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import time
33
import tkinter as tk
44
from functools import partial
5+
import logging
56
from math import ceil, floor, isinf, isnan
67
from tkinter import ttk
78
from typing import Callable, Optional, Union
@@ -11,6 +12,9 @@
1112
from countess.gui.widgets import ResizingFrame, copy_to_clipboard, get_icon
1213
from countess.utils.duckdb import duckdb_dtype_is_integer, duckdb_dtype_is_numeric, duckdb_escape_identifier
1314

15+
16+
logger = logging.getLogger(__name__)
17+
1418
# XXX columns should automatically resize based on information
1519
# from _column_xscrollcommand which can tell if they're
1620
# overflowing. Or maybe use
@@ -19,12 +23,21 @@
1923

2024

2125
def column_format_for(table: DuckDBPyRelation, column: str) -> str:
22-
dtype = table[column].dtypes[0]
26+
#logger.debug("column_format_for column %s %s", column, table.columns)
27+
28+
# XXX https://github.com/duckdb/duckdb/issues/15267
29+
dtype = table.project(duckdb_escape_identifier(column)).dtypes[0]
30+
31+
logger.debug("column_format_for dtype %s", dtype)
32+
2333
if duckdb_dtype_is_numeric(dtype):
2434
# Work out the maximum width required to represent the integer part in this
2535
# column, so we can pad values to that width.
2636
column_esc = duckdb_escape_identifier(column)
27-
column_min, column_max = table.aggregate(f"min({column_esc}), max({column_esc})").fetchone()
37+
column_min_max = table.aggregate(f"min({column_esc}), max({column_esc})").fetchone()
38+
if column_min_max is None:
39+
return "%s"
40+
column_min, column_max = column_min_max
2841
if column_min is None or isnan(column_min) or isinf(column_min):
2942
column_min = -100
3043
if column_max is None or isnan(column_max) or isinf(column_max):
@@ -191,12 +204,9 @@ def refresh(self, new_offset=0):
191204
# with some window managers. Needs refactoring.
192205

193206
new_offset = max(0, min(self.length - self.height, int(new_offset)))
194-
offset_diff = new_offset - self.offset
195207

196208
rows = self.table.limit(self.height, offset=new_offset).fetchall()
197-
for column_num, (column_name, column_widget, column_format) in enumerate(
198-
zip(self.table.columns, self.columns, self.column_formats)
199-
):
209+
for column_num, (column_widget, column_format) in enumerate(zip(self.columns, self.column_formats)):
200210
column_widget["state"] = tk.NORMAL
201211
column_widget.delete("1.0", tk.END)
202212
for row in rows:
@@ -217,6 +227,7 @@ def set_click_callback(self, click_callback) -> None:
217227
self.click_callback = click_callback
218228

219229
def set_sort_order(self, column_num: int, descending: Optional[bool] = None):
230+
assert self.ddbc is not None
220231
assert self.table is not None
221232

222233
if descending is None and column_num == self.sort_by_col:

countess/plugins/csv.py

Lines changed: 42 additions & 125 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,8 @@
55
from io import BufferedWriter, BytesIO
66
from typing import Any, List, Optional, Sequence, Tuple, Union
77

8-
import pandas as pd
8+
import duckdb
9+
from duckdb import DuckDBPyConnection, DuckDBPyRelation
910

1011
from countess import VERSION
1112
from countess.core.parameters import (
@@ -18,9 +19,9 @@
1819
MultiParam,
1920
StringParam,
2021
)
21-
from countess.core.plugins import PandasInputFilesPlugin, PandasOutputPlugin
22+
from countess.core.plugins import DuckdbLoadFilePlugin, DuckdbSaveFilePlugin
2223
from countess.utils.files import clean_filename
23-
from countess.utils.pandas import flatten_columns
24+
from countess.utils.duckdb import duckdb_escape_literal, duckdb_escape_identifier
2425

2526
CSV_FILE_TYPES: Sequence[Tuple[str, Union[str, List[str]]]] = [
2627
("CSV", [".csv", ".csv.gz"]),
@@ -34,10 +35,18 @@
3435
class ColumnsMultiParam(MultiParam):
3536
name = StringParam("Column Name", "")
3637
type = DataTypeOrNoneChoiceParam("Column Type")
37-
index = BooleanParam("Index?", False)
3838

3939

40-
class LoadCsvPlugin(PandasInputFilesPlugin):
40+
CSV_DELIMITER_CHOICES = {
41+
',': ',',
42+
';': ';',
43+
'|': '|',
44+
'TAB': '\t',
45+
'SPACE': ' ',
46+
'NONE': None
47+
}
48+
49+
class LoadCsvPlugin(DuckdbLoadFilePlugin):
4150
"""Load CSV files"""
4251

4352
name = "CSV Load"
@@ -46,78 +55,42 @@ class LoadCsvPlugin(PandasInputFilesPlugin):
4655
version = VERSION
4756
file_types = CSV_FILE_TYPES
4857

49-
delimiter = ChoiceParam("Delimiter", ",", choices=[",", ";", "TAB", "|", "WHITESPACE"])
50-
quoting = ChoiceParam("Quoting", "None", choices=["None", "Double-Quote", "Quote with Escape"])
51-
comment = ChoiceParam("Comment", "None", choices=["None", "#", ";"])
58+
delimiter = ChoiceParam("Delimiter", ",", choices=CSV_DELIMITER_CHOICES.keys())
5259
header = BooleanParam("CSV file has header row?", True)
5360
filename_column = StringParam("Filename Column", "")
5461
columns = ArrayParam("Columns", ColumnsMultiParam("Column"))
5562

56-
def read_file_to_dataframe(self, filename: str, file_param: BaseParam, row_limit=None):
57-
options: dict[str, Any] = {
58-
"header": 0 if self.header else None,
59-
}
60-
if row_limit is not None:
61-
options["nrows"] = row_limit
62-
63-
index_col_numbers = []
64-
65-
if len(self.columns):
66-
options["names"] = []
67-
options["usecols"] = []
68-
options["converters"] = {}
69-
70-
for n, pp in enumerate(self.columns):
71-
options["names"].append(str(pp.name) or f"column_{n}")
72-
if pp.type.is_not_none():
73-
if pp.index:
74-
index_col_numbers.append(len(options["usecols"]))
75-
options["usecols"].append(n)
76-
options["converters"][n] = pp["type"].cast_value
77-
78-
if self.delimiter == "TAB":
79-
options["delimiter"] = "\t"
80-
elif self.delimiter == "WHITESPACE":
81-
options["delim_whitespace"] = True
63+
def load_file(
64+
self, cursor: DuckDBPyConnection, filename: str, file_param: BaseParam, file_number: int
65+
) -> duckdb.DuckDBPyRelation:
66+
if self.header and len(self.columns) == 0:
67+
table = cursor.read_csv(
68+
filename,
69+
header = True,
70+
delimiter = CSV_DELIMITER_CHOICES[self.delimiter.value],
71+
)
72+
for column_name, column_dtype in zip(table.columns, table.dtypes):
73+
column_param = self.columns.add_row()
74+
column_param.name.value = column_name
75+
column_param.type.value = str(column_dtype)
8276
else:
83-
options["delimiter"] = str(self.delimiter)
84-
85-
if self.quoting == "None":
86-
options["quoting"] = csv.QUOTE_NONE
87-
elif self.quoting == "Double-Quote":
88-
options["quotechar"] = '"'
89-
options["doublequote"] = True
90-
elif self.quoting == "Quote with Escape":
91-
options["quotechar"] = '"'
92-
options["doublequote"] = False
93-
options["escapechar"] = "\\"
94-
95-
if self.comment.value != "None":
96-
options["comment"] = str(self.comment)
97-
98-
# XXX pd.read_csv(index_col=) is half the speed of pd.read_csv().set_index()
99-
100-
df = pd.read_csv(filename, **options)
101-
102-
while len(df.columns) > len(self.columns):
103-
self.columns.add_row()
104-
105-
if self.header:
106-
for n, col in enumerate(df.columns):
107-
if not self.columns[n].name:
108-
self.columns[n].name = str(col)
109-
self.columns[n].type = "string"
77+
table = cursor.read_csv(
78+
filename,
79+
header = False,
80+
skiprows = 1 if self.header else 0,
81+
delimiter = CSV_DELIMITER_CHOICES[self.delimiter.value],
82+
columns = { str(c.name): str(c.type) for c in self.columns } if self.columns else None
83+
)
11084

11185
if self.filename_column:
112-
df[str(self.filename_column)] = clean_filename(filename)
113-
114-
if index_col_numbers:
115-
df = df.set_index([df.columns[n] for n in index_col_numbers])
86+
escaped_filename = duckdb_escape_literal(clean_filename(filename))
87+
escaped_column = duckdb_escape_identifier(self.filename_column.value)
88+
table = table.project(f"*, {escaped_filename} AS {escaped_column}")
11689

117-
return df
90+
return table
11891

11992

120-
class SaveCsvPlugin(PandasOutputPlugin):
93+
class SaveCsvPlugin(DuckdbSaveFilePlugin):
12194
name = "CSV Save"
12295
description = "Save data as CSV or similar delimited text files"
12396
link = "https://countess-project.github.io/CountESS/included-plugins/#csv-writer"
@@ -135,61 +108,5 @@ class SaveCsvPlugin(PandasOutputPlugin):
135108
SEPARATORS = {",": ",", ";": ";", "SPACE": " ", "TAB": "\t"}
136109
QUOTING = {False: csv.QUOTE_MINIMAL, True: csv.QUOTE_NONNUMERIC}
137110

138-
def prepare(self, sources: list[str], row_limit: Optional[int] = None):
139-
if row_limit is None:
140-
logger.debug("SaveCsvPlugin.process %s prepare %s", self.name, self.filename)
141-
filename = str(self.filename)
142-
if filename.endswith(".gz"):
143-
self.filehandle = gzip.open(filename, "wb")
144-
elif filename.endswith(".bz2"):
145-
self.filehandle = bz2.open(filename, "wb")
146-
else:
147-
self.filehandle = open(filename, "wb")
148-
else:
149-
logger.debug("SaveCsvPlugin.process %s prepare BytesIO", self.name)
150-
self.filehandle = BytesIO()
151-
152-
self.csv_columns = None
153-
154-
def process(self, data: pd.DataFrame, source: str):
155-
# reset indexes so we can treat all columns equally.
156-
# if there's just a nameless index then we don't care about it, drop it.
157-
drop_index = data.index.name is None and data.index.names[0] is None
158-
dataframe = flatten_columns(data.reset_index(drop=drop_index))
159-
160-
# if this is our first dataframe to write then decide whether to
161-
# include the header or not.
162-
if self.csv_columns is None:
163-
self.csv_columns = list(dataframe.columns)
164-
emit_header = bool(self.header)
165-
else:
166-
# add in any columns we haven't seen yet in previous dataframes.
167-
for c in dataframe.columns:
168-
if c not in self.csv_columns:
169-
self.csv_columns.append(c)
170-
logger.warning("Added CSV Column %s with no header", repr(c))
171-
# fill in blanks for any columns which are in previous dataframes but not
172-
# in this one.
173-
dataframe = dataframe.assign(**{c: None for c in self.csv_columns if c not in dataframe.columns})
174-
emit_header = False
175-
176-
logger.debug(
177-
"SaveCsvPlugin.process %s writing rows %d columns %d", self.name, len(dataframe), len(self.csv_columns)
178-
)
179-
180-
dataframe.to_csv(
181-
self.filehandle,
182-
header=emit_header,
183-
columns=self.csv_columns,
184-
index=False,
185-
sep=self.SEPARATORS[str(self.delimiter)],
186-
quoting=self.QUOTING[bool(self.quoting)],
187-
) # type: ignore [call-overload]
188-
return []
189-
190-
def finalize(self):
191-
logger.debug("SaveCsvPlugin.process %s finalize", self.name)
192-
if isinstance(self.filehandle, BytesIO):
193-
yield self.filehandle.getvalue().decode("utf-8")
194-
else:
195-
self.filehandle.close()
111+
def execute(self, ddbc: DuckDBPyConnection, source: Optional[DuckDBPyRelation]) -> Optional [DuckDBPyRelation]:
112+
pass

0 commit comments

Comments
 (0)