Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Map type decoding broken #125

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions aiochclient/_types.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ RE_NULLABLE = re.compile(r"^Nullable\((.*)\)$")
# Matches a full ``LowCardinality(...)`` type name and captures the wrapped type.
RE_LOW_CARDINALITY = re.compile(r"^LowCardinality\((.*)\)$")
# Matches a full ``Map(...)`` type name and captures everything between the parentheses.
RE_MAP = re.compile(r"^Map\((.*)\)$")
# Matches a single quote that is not immediately preceded by a backslash.
RE_REPLACE_QUOTE = re.compile(r"(?<!\\)'")

# Matches ``SimpleAggregateFunction(<first arg>, <rest>)``: the first argument may
# not contain a comma, and the capture group is everything after that first
# comma (leading whitespace skipped) up to the closing parenthesis.
RE_SIMPLE_AGG = re.compile(r"^SimpleAggregateFunction\([^,]+,\s*(.+)\)$")

cdef str remove_single_quotes(str string):
if string[0] == string[-1] == "'":
Expand Down Expand Up @@ -172,6 +172,8 @@ cdef class StrType:
def __cinit__(self, str name, bint container):
    # Store the ClickHouse type name and whether this value sits inside a
    # container type.
    #
    # NOTE: nothing else may be parsed out of ``name`` here. Running
    # RE_SIMPLE_AGG over it and indexing the result with ``[0]`` raises
    # IndexError for every name that is not ``SimpleAggregateFunction(...)``,
    # which would break construction of this type.
    self.name = name
    self.container = container

cdef str _convert(self, str string):
string = decode(string.encode())
Expand Down Expand Up @@ -758,6 +760,27 @@ cdef class DecimalType:
cpdef object convert(self, bytes value):
    # Decode the raw bytes to text, then build a Decimal from that text.
    text = value.decode()
    return Decimal(text)

cdef class SimpleAggregateFunctionType:
    """ Converter for ``SimpleAggregateFunction(func, T)`` type names.

    The wrapped type ``T`` is extracted with RE_SIMPLE_AGG and every
    conversion is delegated to the converter of that type.
    """
    cdef:
        str name
        bint container
        # Converter object for the wrapped type, obtained from what_py_type().
        object inner_type

    def __cinit__(self, str name, bint container):
        self.name = name
        self.container = container
        matches = RE_SIMPLE_AGG.findall(name)
        if not matches:
            # A name that does not have the ``(func, T)`` shape would make
            # ``findall(...)[0]`` raise a bare IndexError; report it the same
            # way other unknown type names are reported.
            raise ChClientError(f"Unrecognized type name: '{name}'")
        # NOTE(review): the wrapped type is always built with container=False,
        # regardless of ``container`` -- confirm this is intended.
        self.inner_type = what_py_type(matches[0], container=False)

    cdef object _convert(self, str string):
        # Delegate text parsing to the wrapped type.
        return self.inner_type.p_type(string)

    cpdef object p_type(self, str string):
        """ Converts an already decoded text value via the wrapped type """
        return self._convert(string)

    cpdef object convert(self, bytes value):
        """ Decodes raw bytes, then converts the text via the wrapped type """
        return self._convert(decode(value))


cdef dict CH_TYPES_MAPPING = {
"Bool": BoolType,
Expand Down Expand Up @@ -796,22 +819,24 @@ cdef dict CH_TYPES_MAPPING = {
"IPv4": IPv4Type,
"IPv6": IPv6Type,
"Nested": NestedType,
"SimpleAggregateFunction": SimpleAggregateFunctionType,
}


cdef what_py_type(str name, bint container = False):
    """ Returns the converter instance for a clickhouse type name.

    ``SimpleAggregateFunction(...)`` names get a SimpleAggregateFunctionType.
    For ``AggregateFunction(...)`` names the lookup key is the text after the
    first comma; for every other name it is the part before the first ``(``.
    Raises ChClientError when the name cannot be resolved.
    """
    name = name.strip()
    try:
        if name.startswith('SimpleAggregateFunction'):
            return SimpleAggregateFunctionType(name, container=container)
        if name.startswith('AggregateFunction'):
            found = re.findall(r',(.*)\)', name)
            if not found:
                # No ", <type>)" part: indexing ``found`` would raise an
                # IndexError that the handler below does not catch.
                raise ChClientError(f"Unrecognized type name: '{name}'")
            ch_type = found[0].strip()
        else:
            ch_type = name.split("(")[0]
        return CH_TYPES_MAPPING[ch_type](name, container=container)
    except KeyError:
        raise ChClientError(f"Unrecognized type name: '{name}'")


cpdef what_py_converter(str name, bint container = False):
    """ Returns the ``convert`` method of the type resolved for a clickhouse type name """
    resolved = what_py_type(name, container)
    return resolved.convert
Expand Down
32 changes: 21 additions & 11 deletions aiochclient/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def convert(self, value: bytes) -> IPv4Address:
return self.p_type(value.decode())

@staticmethod
def unconvert(value: UUID) -> bytes:
def unconvert(value: IPv4Address) -> bytes:
return b"%a" % str(value)


Expand Down Expand Up @@ -323,14 +323,27 @@ def __init__(self, name: str, **kwargs):
self.value_type = what_py_type(tps[comma_index + 1 :], container=True)

def p_type(self, string: str) -> dict:
    """Parse ClickHouse's textual Map representation into a dict.

    The expected input is ``{key1:value1,key2:value2}``. Pairs are
    separated by commas and key/value by the first colon, but only where
    those characters appear outside single-quoted strings and outside
    nested ``[]``, ``()`` or ``{}`` groups. Each raw key and value is then
    handed to ``self.key_type.p_type`` / ``self.value_type.p_type``.

    Raises:
        ChClientError: if a pair has no top-level ``:`` separator.
    """

    def split_top_level(text: str, sep: str, limit: int = -1) -> list:
        # Split ``text`` on ``sep`` at nesting depth 0 and outside quotes.
        # ``limit`` caps the number of splits; -1 means unlimited.
        pieces = []
        depth = 0
        quoted = False
        escaped = False
        start = 0
        for pos, char in enumerate(text):
            if escaped:
                # Character after a backslash is always literal.
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == "'":
                quoted = not quoted
            elif quoted:
                continue
            elif char in "[({":
                depth += 1
            elif char in "])}":
                depth -= 1
            elif char == sep and depth == 0 and limit != 0:
                pieces.append(text[start:pos])
                start = pos + 1
                limit -= 1
        pieces.append(text[start:])
        return pieces

    body = string
    if body[:1] == "{" and body[-1:] == "}":
        body = body[1:-1]
    if not body:
        return {}

    parsed = {}
    for pair in split_top_level(body, ","):
        fields = split_top_level(pair, ":", 1)
        if len(fields) != 2:
            raise ChClientError(f"Invalid Map format: {string}")
        raw_key, raw_value = fields
        key = self.key_type.p_type(raw_key)
        parsed[key] = self.value_type.p_type(raw_value)
    return parsed

def convert(self, value: bytes) -> dict:
    """Decode the raw bytes to text and parse the text with ``p_type``."""
    # A second, unreachable ``return`` that went through ``self.decode``
    # was dropped.
    # NOTE(review): presumably element types undo escaping themselves, so
    # unescaping here as well would be applied twice -- confirm.
    return self.p_type(value.decode())

@staticmethod
def unconvert(value) -> bytes:
Expand All @@ -349,10 +362,7 @@ def __init__(self, name: str, **kwargs):
self.type = what_py_type(RE_ARRAY.findall(name)[0], container=True)

def p_type(self, string: str) -> list:
    """Parse a textual array into a list of converted elements.

    The first and last characters of ``string`` are dropped, the rest is
    split by ``self.seq_parser`` and every piece is converted with the
    element type's ``p_type``.
    """
    items = []
    for raw_item in self.seq_parser(string[1:-1]):
        items.append(self.type.p_type(raw_item))
    return items

def convert(self, value: bytes) -> list:
    """Decode the raw bytes to text and parse the text with ``p_type``."""
    text = value.decode()
    return self.p_type(text)
Expand Down
59 changes: 59 additions & 0 deletions tests/test_map_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pytest
from aiohttp import ClientSession
from aiochclient import ChClient


@pytest.mark.asyncio
async def test_map_decoding_multiple_pairs():
    """A map literal with three pairs decodes to a dict with all three items."""
    query = "SELECT map('a', '1', 'b', '2', 'c', '3') AS data"
    expected = {'a': '1', 'b': '2', 'c': '3'}
    async with ClientSession() as http_session:
        ch = ChClient(
            http_session,
            url="http://localhost:8123",
            user="default",
            password="",
        )
        rows = await ch.fetch(query, decode=True)
        assert len(rows) == 1
        assert rows[0]['data'] == expected, "Multi-pair map decoding failed"


@pytest.mark.asyncio
async def test_map_decoding_single_pair():
    """A map literal with one pair decodes to a one-item dict."""
    query = "SELECT map('x', 'y') AS data"
    async with ClientSession() as http_session:
        ch = ChClient(http_session, url="http://localhost:8123")
        rows = await ch.fetch(query, decode=True)
        assert len(rows) == 1
        first_row = rows[0]
        assert first_row['data'] == {'x': 'y'}, "Single-pair map decoding failed"


@pytest.mark.asyncio
async def test_map_decoding_empty():
    """An empty map literal decodes to an empty dict."""
    query = "SELECT map() AS data"
    async with ClientSession() as http_session:
        ch = ChClient(http_session, url="http://localhost:8123")
        rows = await ch.fetch(query, decode=True)
        assert len(rows) == 1
        first_row = rows[0]
        assert first_row['data'] == {}, "Empty map decoding failed"


@pytest.mark.asyncio
async def test_map_table_insert_and_fetch():
    """Round-trip a Map(String, String) value through a real table.

    The table is dropped before creation so rows left behind by an earlier
    failed run cannot change the row count, and dropped again in ``finally``
    so a failing assertion does not leak the table.
    """
    async with ClientSession() as session:
        client = ChClient(session, url="http://localhost:8123")
        await client.execute("DROP TABLE IF EXISTS test_map")
        try:
            await client.execute(
                "CREATE TABLE IF NOT EXISTS test_map (id UInt8, data Map(String, String)) ENGINE = Memory"
            )
            await client.execute(
                "INSERT INTO test_map VALUES",
                (1, {'a': '1', 'b': '2', 'c': '3'})
            )
            result = await client.fetch(
                "SELECT data FROM test_map WHERE id = 1",
                decode=True
            )
            assert len(result) == 1
            assert result[0]['data'] == {'a': '1', 'b': '2', 'c': '3'}, "Table map decoding failed"
        finally:
            # IF EXISTS keeps cleanup from masking the original failure when
            # the table was never created.
            await client.execute("DROP TABLE IF EXISTS test_map")