Skip to content

Commit cc10463

Browse files
authored
Add encode and decode for array of array. (#594)
1 parent 08d33c7 commit cc10463

File tree

2 files changed

+127
-5
lines changed

2 files changed

+127
-5
lines changed

gel/protocol/codecs/array.pyx

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,22 @@ cdef class BaseArrayCodec(BaseCodec):
3333
cdef encode(self, WriteBuffer buf, object obj):
3434
cdef:
3535
WriteBuffer elem_data
36+
WriteBuffer tuple_elem_data
3637
int32_t ndims = 1
3738
Py_ssize_t objlen
3839
Py_ssize_t i
3940

4041
if not isinstance(
4142
self.sub_codec,
42-
(ScalarCodec, TupleCodec, NamedTupleCodec, EnumCodec,
43-
RangeCodec, MultiRangeCodec)
43+
(
44+
ScalarCodec,
45+
TupleCodec,
46+
NamedTupleCodec,
47+
EnumCodec,
48+
RangeCodec,
49+
MultiRangeCodec,
50+
ArrayCodec,
51+
)
4452
):
4553
raise TypeError(
4654
'only arrays of scalars are supported (got type {!r})'.format(
@@ -67,7 +75,20 @@ cdef class BaseArrayCodec(BaseCodec):
6775
)
6876
else:
6977
try:
70-
self.sub_codec.encode(elem_data, item)
78+
if isinstance(self.sub_codec, ArrayCodec):
79+
# This is an array of array.
80+
# Wrap the inner array with a tuple.
81+
tuple_elem_data = WriteBuffer.new()
82+
self.sub_codec.encode(tuple_elem_data, item)
83+
84+
elem_data.write_int32(4 + 4 + tuple_elem_data.len()) # buffer length
85+
elem_data.write_int32(1) # tuple_elem_count
86+
elem_data.write_int32(0) # reserved
87+
elem_data.write_buffer(tuple_elem_data)
88+
89+
else:
90+
self.sub_codec.encode(elem_data, item)
91+
7192
except TypeError as e:
7293
raise ValueError(
7394
'invalid array element: {}'.format(
@@ -121,8 +142,27 @@ cdef class BaseArrayCodec(BaseCodec):
121142
if elem_len == -1:
122143
elem = None
123144
else:
124-
frb_slice_from(&elem_buf, buf, elem_len)
125-
elem = self.sub_codec.decode(&elem_buf)
145+
146+
if isinstance(self.sub_codec, ArrayCodec):
147+
# This is an array of array
148+
# Unwrap the tuple from the inner array.
149+
tuple_elem_count = <Py_ssize_t><uint32_t>hton.unpack_int32(frb_read(buf, 4))
150+
if tuple_elem_count != 1:
151+
raise RuntimeError(
152+
f'cannot decode inner array: expected 1 '
153+
f'element, got {tuple_elem_count}')
154+
155+
frb_read(buf, 4) # reserved
156+
tuple_elem_len = hton.unpack_int32(frb_read(buf, 4))
157+
158+
elem = self.sub_codec.decode(
159+
frb_slice_from(&elem_buf, buf, tuple_elem_len)
160+
)
161+
162+
else:
163+
frb_slice_from(&elem_buf, buf, elem_len)
164+
elem = self.sub_codec.decode(&elem_buf)
165+
126166
if frb_get_len(&elem_buf):
127167
raise RuntimeError(
128168
f'unexpected trailing data in buffer after '

tests/test_array.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#
2+
# This source file is part of the EdgeDB open source project.
3+
#
4+
# Copyright 2019-present MagicStack Inc. and the EdgeDB authors.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
19+
20+
from gel import _testbase as tb
21+
22+
23+
class TestArrayOfArray(tb.SyncQueryTestCase):
24+
def setUp(self):
25+
super().setUp()
26+
27+
if self.client.query_required_single('''
28+
select sys::get_version().major < 7
29+
'''):
30+
self.skipTest("Test needs nested arrays")
31+
32+
async def test_array_of_array_01(self):
33+
# basic array of array
34+
self.assertEqual(
35+
self.client.query_single(
36+
'select <array<array<int64>>>[]'
37+
),
38+
[[]],
39+
)
40+
self.assertEqual(
41+
self.client.query_single(
42+
'select [[1]]'
43+
),
44+
[[1]],
45+
)
46+
self.assertEqual(
47+
self.client.query_single(
48+
'select [[[1]]]'
49+
),
50+
[[[1]]],
51+
)
52+
self.assertEqual(
53+
self.client.query_single(
54+
'select [[[[1]]]]'
55+
),
56+
[[[[1]]]],
57+
)
58+
self.assertEqual(
59+
self.client.query_single(
60+
'select [[1], [2, 3], [4, 5, 6, 7]]'
61+
),
62+
[[1], [2, 3], [4, 5, 6, 7]],
63+
)
64+
65+
async def test_array_of_array_02(self):
66+
# check that array tuple array still works
67+
self.assertEqual(
68+
self.client.query_single(
69+
'select [([1],)]'
70+
),
71+
[([1],)],
72+
)
73+
74+
async def test_array_of_array_03(self):
75+
# check encoding array of array
76+
self.assertEqual(
77+
self.client.query_single(
78+
'select <array<array<int64>>>$0',
79+
[[1]],
80+
),
81+
[[1]],
82+
)

0 commit comments

Comments
 (0)