Skip to content

Commit b799994

Browse files
committed
fix: patch up stream object type and other bugs
1 parent 0654465 commit b799994

File tree

4 files changed

+229
-83
lines changed

4 files changed

+229
-83
lines changed

btrdb/conn.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import logging
2020
import os
2121
import re
22+
from typing import List
2223
import uuid as uuidlib
2324
from concurrent.futures import ThreadPoolExecutor
2425

@@ -255,8 +256,7 @@ def streams(self, *identifiers, versions=None, is_collection_prefix=False):
255256

256257
if versions and len(versions) != len(identifiers):
257258
raise ValueError("number of versions does not match identifiers")
258-
259-
streams = []
259+
streams: List[Stream] = []
260260
for ident in identifiers:
261261
if isinstance(ident, uuidlib.UUID):
262262
streams.append(self.stream_from_uuid(ident))
@@ -277,15 +277,17 @@ def streams(self, *identifiers, versions=None, is_collection_prefix=False):
277277
is_collection_prefix=is_collection_prefix,
278278
tags={"name": parts[-1]},
279279
)
280-
if len(found) == 1:
280+
if isinstance(found, Stream):
281+
streams.append(found)
282+
continue
283+
if isinstance(found, list) and len(found) == 1:
281284
streams.append(found[0])
282285
continue
283286
raise StreamNotFoundError(f"Could not identify stream `{ident}`")
284287

285288
raise ValueError(
286289
f"Could not identify stream based on `{ident}`. Identifier must be UUID or collection/name."
287290
)
288-
289291
obj = StreamSet(streams)
290292

291293
if versions:

btrdb/stream.py

Lines changed: 76 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import json
1919
import logging
2020
import re
21+
from typing import List
2122
import uuid as uuidlib
2223
import warnings
2324
from collections import deque
@@ -532,9 +533,7 @@ def arrow_insert(self, data: pa.Table, merge: str = "never") -> int:
532533
version = []
533534
for tab in table_slices:
534535
version.append(
535-
self._btrdb.ep.arrowInsertValues(
536-
uu=self.uuid, values=tab, policy=merge
537-
)
536+
self._btrdb.ep.arrowInsertValues(uu=self.uuid, values=tab, policy=merge)
538537
)
539538
return max(version)
540539

@@ -768,10 +767,12 @@ def arrow_values(self, start, end, version: int = 0) -> pa.Table:
768767
if len(tables) > 0:
769768
return pa.concat_tables(tables)
770769
else:
771-
schema = pa.schema([
772-
pa.field('time', pa.timestamp('ns', tz='UTC'), nullable=False),
773-
pa.field('value', pa.float64(), nullable=False),
774-
])
770+
schema = pa.schema(
771+
[
772+
pa.field("time", pa.timestamp("ns", tz="UTC"), nullable=False),
773+
pa.field("value", pa.float64(), nullable=False),
774+
]
775+
)
775776
return pa.Table.from_arrays([pa.array([]), pa.array([])], schema=schema)
776777

777778
def aligned_windows(self, start, end, pointwidth, version=0):
@@ -879,20 +880,24 @@ def arrow_aligned_windows(
879880
logger.debug(f"For stream - {self.uuid} - {self.name}")
880881
start = to_nanoseconds(start)
881882
end = to_nanoseconds(end)
882-
tables = list(self._btrdb.ep.arrowAlignedWindows(
883-
self.uuid, start=start, end=end, pointwidth=pointwidth, version=version
884-
))
883+
tables = list(
884+
self._btrdb.ep.arrowAlignedWindows(
885+
self.uuid, start=start, end=end, pointwidth=pointwidth, version=version
886+
)
887+
)
885888
if len(tables) > 0:
886889
return pa.concat_tables(tables)
887890
else:
888-
schema = pa.schema([
889-
pa.field('time', pa.timestamp('ns', tz='UTC'), nullable=False),
890-
pa.field('mean', pa.float64(), nullable=False),
891-
pa.field('min', pa.float64(), nullable=False),
892-
pa.field('max', pa.float64(), nullable=False),
893-
pa.field('count', pa.uint64(), nullable=False),
894-
pa.field('stddev', pa.float64(), nullable=False),
895-
])
891+
schema = pa.schema(
892+
[
893+
pa.field("time", pa.timestamp("ns", tz="UTC"), nullable=False),
894+
pa.field("mean", pa.float64(), nullable=False),
895+
pa.field("min", pa.float64(), nullable=False),
896+
pa.field("max", pa.float64(), nullable=False),
897+
pa.field("count", pa.uint64(), nullable=False),
898+
pa.field("stddev", pa.float64(), nullable=False),
899+
]
900+
)
896901
return pa.Table.from_arrays([pa.array([]) for _ in range(5)], schema=schema)
897902

898903
def windows(self, start, end, width, depth=0, version=0):
@@ -986,25 +991,29 @@ def arrow_windows(
986991
raise NotImplementedError(_arrow_not_impl_str.format("arrow_windows"))
987992
start = to_nanoseconds(start)
988993
end = to_nanoseconds(end)
989-
tables = list(self._btrdb.ep.arrowWindows(
990-
self.uuid,
991-
start=start,
992-
end=end,
993-
width=width,
994-
depth=0,
995-
version=version,
996-
))
994+
tables = list(
995+
self._btrdb.ep.arrowWindows(
996+
self.uuid,
997+
start=start,
998+
end=end,
999+
width=width,
1000+
depth=0,
1001+
version=version,
1002+
)
1003+
)
9971004
if len(tables) > 0:
9981005
return pa.concat_tables(tables)
9991006
else:
1000-
schema = pa.schema([
1001-
pa.field('time', pa.timestamp('ns', tz='UTC'), nullable=False),
1002-
pa.field('mean', pa.float64(), nullable=False),
1003-
pa.field('min', pa.float64(), nullable=False),
1004-
pa.field('max', pa.float64(), nullable=False),
1005-
pa.field('count', pa.uint64(), nullable=False),
1006-
pa.field('stddev', pa.float64(), nullable=False),
1007-
])
1007+
schema = pa.schema(
1008+
[
1009+
pa.field("time", pa.timestamp("ns", tz="UTC"), nullable=False),
1010+
pa.field("mean", pa.float64(), nullable=False),
1011+
pa.field("min", pa.float64(), nullable=False),
1012+
pa.field("max", pa.float64(), nullable=False),
1013+
pa.field("count", pa.uint64(), nullable=False),
1014+
pa.field("stddev", pa.float64(), nullable=False),
1015+
]
1016+
)
10081017
return pa.Table.from_arrays([pa.array([]) for _ in range(5)], schema=schema)
10091018

10101019
def nearest(self, time, version, backward=False):
@@ -1085,8 +1094,14 @@ class StreamSetBase(Sequence):
10851094
A lightweight wrapper around a list of stream objects
10861095
"""
10871096

1088-
def __init__(self, streams):
1089-
self._streams = streams
1097+
def __init__(self, streams: List[Stream]):
1098+
self._streams: List[Stream] = []
1099+
for stream in streams:
1100+
if not isinstance(stream, Stream):
1101+
raise BTRDBTypeError(
1102+
f"streams must be of type Stream {stream}, {type(stream)}"
1103+
)
1104+
self._streams.append(stream)
10901105
if len(self._streams) < 1:
10911106
raise ValueError(
10921107
f"Trying to create streamset with an empty list of streams {self._streams}."
@@ -1541,7 +1556,7 @@ def _streamset_data(self, as_iterators=False):
15411556
_ = params.pop("sampling_frequency", None)
15421557
versions = self._pinned_versions
15431558
if versions == None:
1544-
versions = {s.uuid : 0 for s in self}
1559+
versions = {s.uuid: 0 for s in self}
15451560

15461561
if self.pointwidth is not None:
15471562
# create list of stream.aligned_windows data
@@ -1734,12 +1749,12 @@ def values(self):
17341749
result.append([point[0] for point in stream_data])
17351750
return result
17361751

1737-
def arrow_values(self, name_callable=lambda s : s.collection + '/' + s.name):
1752+
def arrow_values(self, name_callable=lambda s: s.collection + "/" + s.name):
17381753
"""Return a pyarrow table of stream values based on the streamset parameters."""
17391754
params = self._params_from_filters()
17401755
versions = self._pinned_versions
17411756
if versions == None:
1742-
versions = {s.uuid : 0 for s in self}
1757+
versions = {s.uuid: 0 for s in self}
17431758

17441759
if params.get("sampling_frequency", None) is None:
17451760
_ = params.pop("sampling_frequency", None)
@@ -1797,13 +1812,20 @@ def arrow_values(self, name_callable=lambda s : s.collection + '/' + s.name):
17971812
table = list(self._btrdb.ep.arrowMultiValues(**params))
17981813
if len(table) > 0:
17991814
data = pa.concat_tables(table)
1800-
data = data.rename_columns(["time"] + [name_callable(s) for s in self._streams])
1815+
data = data.rename_columns(
1816+
["time"] + [name_callable(s) for s in self._streams]
1817+
)
18011818
else:
18021819
schema = pa.schema(
1803-
[pa.field('time', pa.timestamp('ns', tz='UTC'), nullable=False)]
1804-
+ [pa.field(name_callable(s), pa.float64(), nullable=False) for s in self._streams],
1820+
[pa.field("time", pa.timestamp("ns", tz="UTC"), nullable=False)]
1821+
+ [
1822+
pa.field(name_callable(s), pa.float64(), nullable=False)
1823+
for s in self._streams
1824+
],
1825+
)
1826+
data = pa.Table.from_arrays(
1827+
[pa.array([]) for i in range(1 + len(self._streams))], schema=schema
18051828
)
1806-
data = pa.Table.from_arrays([pa.array([]) for i in range(1+len(self._streams))], schema=schema)
18071829
return data
18081830

18091831
def __repr__(self):
@@ -1828,6 +1850,15 @@ def __getitem__(self, item):
18281850

18291851
return self._streams[item]
18301852

1853+
def __contains__(self, item):
1854+
if isinstance(item, str):
1855+
for stream in self._streams:
1856+
if str(stream.uuid()) == item:
1857+
return True
1858+
return False
1859+
1860+
return item in self._streams
1861+
18311862
def __len__(self):
18321863
return len(self._streams)
18331864

@@ -1865,12 +1896,14 @@ def __init__(
18651896
if self.start is not None and self.end is not None and self.start >= self.end:
18661897
raise BTRDBValueError("`start` must be strictly less than `end` argument")
18671898

1899+
18681900
def _to_period_ns(fs: int):
18691901
"""Convert sampling rate to sampling period in ns."""
18701902
period = 1 / fs
18711903
period_ns = period * 1e9
18721904
return int(period_ns)
18731905

1906+
18741907
def _coalesce_table_deque(tables: deque):
18751908
main_table = tables.popleft()
18761909
idx = 0

tests/btrdb/test_conn.py

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,61 @@
1919
from unittest.mock import Mock, PropertyMock, call, patch
2020

2121
import pytest
22+
import uuid
23+
2224

2325
from btrdb.conn import BTrDB, Connection
2426
from btrdb.endpoint import Endpoint
2527
from btrdb.exceptions import *
2628
from btrdb.grpcinterface import btrdb_pb2
29+
from btrdb.stream import Stream
30+
31+
##########################################################################
32+
## Fixtures
33+
##########################################################################
34+
35+
36+
@pytest.fixture
37+
def stream1():
38+
uu = uuid.UUID("0d22a53b-e2ef-4e0a-ab89-b2d48fb2592a")
39+
stream = Mock(Stream)
40+
stream.version = Mock(return_value=11)
41+
stream.uuid = Mock(return_value=uu)
42+
type(stream).collection = PropertyMock(return_value="fruits/apple")
43+
type(stream).name = PropertyMock(return_value="gala")
44+
stream.tags = Mock(return_value={"name": "gala", "unit": "volts"})
45+
stream.annotations = Mock(return_value=({"owner": "ABC", "color": "red"}, 11))
46+
stream._btrdb = Mock()
47+
return stream
48+
49+
50+
@pytest.fixture
51+
def stream2():
52+
uu = uuid.UUID("17dbe387-89ea-42b6-864b-f505cdb483f5")
53+
stream = Mock(Stream)
54+
stream.version = Mock(return_value=22)
55+
stream.uuid = Mock(return_value=uu)
56+
type(stream).collection = PropertyMock(return_value="fruits/orange")
57+
type(stream).name = PropertyMock(return_value="blood")
58+
stream.tags = Mock(return_value={"name": "blood", "unit": "amps"})
59+
stream.annotations = Mock(return_value=({"owner": "ABC", "color": "orange"}, 22))
60+
stream._btrdb = Mock()
61+
return stream
62+
63+
64+
@pytest.fixture
65+
def stream3():
66+
uu = uuid.UUID("17dbe387-89ea-42b6-864b-e2ef0d22a53b")
67+
stream = Mock(Stream)
68+
stream.version = Mock(return_value=33)
69+
stream.uuid = Mock(return_value=uu)
70+
type(stream).collection = PropertyMock(return_value="fruits/banana")
71+
type(stream).name = PropertyMock(return_value="yellow")
72+
stream.tags = Mock(return_value={"name": "yellow", "unit": "watts"})
73+
stream.annotations = Mock(return_value=({"owner": "ABC", "color": "yellow"}, 33))
74+
stream._btrdb = Mock()
75+
return stream
76+
2777

2878
##########################################################################
2979
## Connection Tests
@@ -91,7 +141,7 @@ def test_streams_recognizes_uuid(self, mock_func):
91141
"""
92142
db = BTrDB(None)
93143
uuid1 = uuidlib.UUID("0d22a53b-e2ef-4e0a-ab89-b2d48fb2592a")
94-
mock_func.return_value = [1]
144+
mock_func.return_value = Stream(db, uuid1)
95145
db.streams(uuid1)
96146

97147
mock_func.assert_called_once()
@@ -104,7 +154,7 @@ def test_streams_recognizes_uuid_string(self, mock_func):
104154
"""
105155
db = BTrDB(None)
106156
uuid1 = "0d22a53b-e2ef-4e0a-ab89-b2d48fb2592a"
107-
mock_func.return_value = [1]
157+
mock_func.return_value = Stream(db, uuid1)
108158
db.streams(uuid1)
109159

110160
mock_func.assert_called_once()
@@ -117,7 +167,9 @@ def test_streams_handles_path(self, mock_func):
117167
"""
118168
db = BTrDB(None)
119169
ident = "zoo/animal/dog"
120-
mock_func.return_value = [1]
170+
mock_func.return_value = [
171+
Stream(db, "0d22a53b-e2ef-4e0a-ab89-b2d48fb2592a"),
172+
]
121173
db.streams(ident, "0d22a53b-e2ef-4e0a-ab89-b2d48fb2592a")
122174

123175
mock_func.assert_called_once()
@@ -139,12 +191,10 @@ def test_streams_raises_err(self, mock_func):
139191
with pytest.raises(StreamNotFoundError) as exc:
140192
db.streams(ident)
141193

142-
mock_func.return_value = [1, 2]
143-
with pytest.raises(StreamNotFoundError) as exc:
144-
db.streams(ident)
145-
146194
# check that does not raise if one returned
147-
mock_func.return_value = [1]
195+
mock_func.return_value = [
196+
Stream(db, ident),
197+
]
148198
db.streams(ident)
149199

150200
def test_streams_raises_valueerror(self):

0 commit comments

Comments
 (0)