Skip to content

[RSDK-7443] paginate data #598

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 109 additions & 80 deletions src/viam/app/data_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, List, Mapping, Optional, Sequence, Tuple
Expand Down Expand Up @@ -33,6 +34,7 @@
Filter,
GetDatabaseConnectionRequest,
GetDatabaseConnectionResponse,
Order,
RemoveBinaryDataFromDatasetByIDsRequest,
RemoveBoundingBoxFromImageByIDRequest,
RemoveTagsFromBinaryDataByFilterRequest,
Expand Down Expand Up @@ -109,6 +111,7 @@ async def main():

"""

@dataclass
class TabularData:
"""Class representing a piece of tabular data and associated metadata.

Expand All @@ -119,16 +122,17 @@ class TabularData:
time_received (datetime): the time the requested data was received.
"""

def __init__(self, data: Mapping[str, Any], metadata: CaptureMetadata, time_requested: datetime, time_received: datetime) -> None:
self.data = data
self.metadata = metadata
self.time_requested = time_requested
self.time_received = time_received

data: Mapping[str, Any]
"""The requested data"""

metadata: CaptureMetadata
"""The metadata associated with the data"""

time_requested: datetime
"""The time the data were requested"""

time_received: datetime
"""The time the data were received"""

def __str__(self) -> str:
return f"{self.data}\n{self.metadata}Time requested: {self.time_requested}\nTime received: {self.time_received}\n"
Expand All @@ -139,6 +143,7 @@ def __eq__(self, other: object) -> bool:
return False

# TODO (RSDK-6684): Revisit if this shadow type is necessary
@dataclass
class BinaryData:
"""Class representing a piece of binary data and associated metadata.

Expand All @@ -147,12 +152,11 @@ class BinaryData:
metadata (viam.proto.app.data.BinaryMetadata): the metadata from the request.
"""

def __init__(self, data: bytes, metadata: BinaryMetadata) -> None:
self.data = data
self.metadata = metadata

data: bytes
"""The request data"""

metadata: BinaryMetadata
"""The metadata associated with the data"""

def __str__(self) -> str:
return f"{self.data}\n{self.metadata}"
Expand Down Expand Up @@ -184,47 +188,65 @@ def __init__(self, channel: Channel, metadata: Mapping[str, str]):
async def tabular_data_by_filter(
self,
filter: Optional[Filter] = None,
limit: Optional[int] = None,
sort_order: Optional[Order.ValueType] = None,
last: Optional[str] = None,
count_only: bool = False,
dest: Optional[str] = None,
) -> List[TabularData]:
"""Filter and download tabular data.
) -> Tuple[List[TabularData], int, str]:
"""Filter and download tabular data. The data will be paginated into pages of `limit` items, and the pagination ID will be included
in the returned tuple. If a destination is provided, the data will be saved to that file.
If the file is not empty, it will be overwritten.

::

from viam.proto.app.data import Filter

my_data = []
last = None
my_filter = Filter(component_name="left_motor")
tabular_data = await data_client.tabular_data_by_filter(my_filter)
while True:
tabular_data, count, last = await data_client.tabular_data_by_filter(filter=my_filter, last=last)
if not tabular_data:
break
my_data.extend(tabular_data)


Args:
filter (viam.proto.app.data.Filter): Optional `Filter` specifying tabular data to retrieve. No `Filter` implies all tabular
data.
limit (int): The maximum number of entries to include in a page. Defaults to 50.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In order to match the API undefined fields, can have this default to Optional[int] = None and specify here that if it's unspecified, we'll retrieve 50? That way if the default ever changes in the server code, we don't need to change it in the SDK code as well (and can just update the comment)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup! Didn't realize the server set the limit, thought we had to do it in the client.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, could possibly update the comment to 'default to 50 if unspecified' to make it clear?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

sort_order (viam.proto.app.data.Order): The desired sort order of the data.
last (str): Optional string indicating the ID of the last-returned data.
If provided, the server will return the next data entries after the `last` ID.
count_only (bool): Whether to return only the total count of entries.
dest (str): Optional filepath for writing retrieved data.

Returns:
List[TabularData]: The tabular data.
int: The count (number of entries)
str: The last-returned page ID.
"""
filter = filter if filter else Filter()
last = ""
data = []

# `DataRequest`s are limited to 100 pieces of data, so we loop through calls until
# we are certain we've received everything.
while True:
data_request = DataRequest(filter=filter, limit=100, last=last)
request = TabularDataByFilterRequest(data_request=data_request, count_only=False)
response: TabularDataByFilterResponse = await self._data_client.TabularDataByFilter(request, metadata=self._metadata)
if not response.data or len(response.data) == 0:
break
data += [
DataClient.TabularData(
struct_to_dict(struct.data),
response.metadata[struct.metadata_index],
struct.time_requested.ToDatetime(),
struct.time_received.ToDatetime(),
)
for struct in response.data
]
last = response.last

data_request = DataRequest(filter=filter)
if limit:
data_request.limit = limit
if sort_order:
data_request.sort_order = sort_order
if last:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, thanks for updating here and in BinaryDataByFilter! Similar to my PR-level question about testing, could we confirm that keeping these unspecified gives back what we expect in an SDK call?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't manually tested this code -- I don't have a robot with a ton of data. But also, this is just client code. I added tests to make sure that the client is forwarding values properly to the server, but beyond that the client is simply going to trust that the server returns back the appropriate data for the request

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, so confirming that if these are unspecified, the proto field is unspecified. If so, that matches what we'd want.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTMed, but I just added you into the Data/ML Dev Org so that you can test against this, just so we can be extra sure that this is following the behavior we expect before giving customer updates.

data_request.last = last
request = TabularDataByFilterRequest(data_request=data_request, count_only=count_only)
response: TabularDataByFilterResponse = await self._data_client.TabularDataByFilter(request, metadata=self._metadata)
data = [
DataClient.TabularData(
struct_to_dict(struct.data),
response.metadata[struct.metadata_index],
struct.time_requested.ToDatetime(),
struct.time_received.ToDatetime(),
)
for struct in response.data
]

if dest:
try:
Expand All @@ -233,59 +255,72 @@ async def tabular_data_by_filter(
file.flush()
except Exception as e:
LOGGER.error(f"Failed to write tabular data to file {dest}", exc_info=e)
return data
return data, response.count, response.last

async def binary_data_by_filter(
self, filter: Optional[Filter] = None, dest: Optional[str] = None, include_file_data: bool = True, num_files: Optional[int] = None
) -> List[BinaryData]:
"""Filter and download binary data.
self,
filter: Optional[Filter] = None,
limit: Optional[int] = None,
sort_order: Optional[Order.ValueType] = None,
last: Optional[str] = None,
include_binary_data: bool = True,
count_only: bool = False,
include_internal_data: bool = False,
dest: Optional[str] = None,
) -> Tuple[List[BinaryData], int, str]:
"""Filter and download binary data. The data will be paginated into pages of `limit` items, and the pagination ID will be included
in the returned tuple. If a destination is provided, the data will be saved to that file.
If the file is not empty, it will be overwritten.

::

from viam.proto.app.data import Filter

my_filter = Filter(component_type="camera")
binary_data = await data_client.binary_data_by_filter(my_filter)

Args:
filter (Optional[viam.proto.app.data.Filter]): Optional `Filter` specifying binary data to retrieve. No `Filter` implies all
binary data.
dest (Optional[str]): Optional filepath for writing retrieved data.
include_file_data (bool): Boolean specifying whether to actually include the binary file data with each retrieved file. Defaults
to true (i.e., both the files' data and metadata are returned).
num_files (Optional[str]): Number of binary data to return. Passing 0 returns all binary data matching the filter no matter.
Defaults to 100 if no binary data is requested, otherwise 10. All binary data or the first `num_files` will be returned,
whichever comes first.
my_data = []
last = None
my_filter = Filter(component_name="camera")
while True:
data, count, last = await data_client.binary_data_by_filter(filter=my_filter, last=last)
if not data:
break
my_data.extend(data)

Raises:
ValueError: If `num_files` is less than 0.
Args:
filter (viam.proto.app.data.Filter): Optional `Filter` specifying binary data to retrieve. No `Filter` implies all binary
data.
limit (int): The maximum number of entries to include in a page. Defaults to 50.
sort_order (viam.proto.app.data.Order): The desired sort order of the data.
last (str): Optional string indicating the ID of the last-returned data.
If provided, the server will return the next data entries after the `last` ID.
include_binary_data (bool): Boolean specifying whether to actually include the binary file data with each retrieved file.
Defaults to true (i.e., both the files' data and metadata are returned).
count_only (bool): Whether to return only the total count of entries.
include_internal_data (bool): Whether to return the internal data. Internal data is used for Viam-specific data ingestion,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're exposing include_internal_data exposed in the SDK, would great to add the corresponding param for TabularDataByFilter for consistency

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it be included?? I can remove it. I saw it in the proto so I added it

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok added just to make things quick

like cloud SLAM. Defaults to `False`
dest (str): Optional filepath for writing retrieved data.

Returns:
List[BinaryData]: The binary data.
int: The count (number of entries)
str: The last-returned page ID.
"""
num_files = num_files if num_files else 10 if include_file_data else 100
if num_files < 0:
raise ValueError("num_files must be at least 0.")
filter = filter if filter else Filter()
limit = 1 if include_file_data else 100
last = ""
data = []

# `DataRequest`s are limited in pieces of data, so we loop through calls until
# we are certain we've received everything.
while True:
new_data, last = await self._binary_data_by_filter(filter=filter, limit=limit, include_binary=include_file_data, last=last)
if not new_data or len(new_data) == 0:
break
elif num_files != 0 and len(new_data) > num_files:
data += new_data[0:num_files]
break
else:
data += new_data
num_files -= len(new_data)
if num_files == 0:
break

data_request = DataRequest(filter=filter)
if limit:
data_request.limit = limit
if sort_order:
data_request.sort_order = sort_order
if last:
data_request.last = last
request = BinaryDataByFilterRequest(
data_request=data_request,
include_binary=include_binary_data,
count_only=count_only,
include_internal_data=include_internal_data,
)
response: BinaryDataByFilterResponse = await self._data_client.BinaryDataByFilter(request, metadata=self._metadata)
data = [DataClient.BinaryData(data.binary, data.metadata) for data in response.data]
if dest:
try:
file = open(dest, "w")
Expand All @@ -294,13 +329,7 @@ async def binary_data_by_filter(
except Exception as e:
LOGGER.error(f"Failed to write binary data to file {dest}", exc_info=e)

return data

async def _binary_data_by_filter(self, filter: Filter, limit: int, include_binary: bool, last: str) -> Tuple[List[BinaryData], str]:
data_request = DataRequest(filter=filter, limit=limit, last=last)
request = BinaryDataByFilterRequest(data_request=data_request, count_only=False, include_binary=include_binary)
response: BinaryDataByFilterResponse = await self._data_client.BinaryDataByFilter(request, metadata=self._metadata)
return [DataClient.BinaryData(data.binary, data.metadata) for data in response.data], response.last
return data, response.count, response.last

async def binary_data_by_ids(
self,
Expand Down
15 changes: 13 additions & 2 deletions tests/mocks/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,14 @@ async def TabularDataByFilter(self, stream: Stream[TabularDataByFilterRequest, T
time_received=datetime_to_timestamp(tabular_data.time_received),
)
)
await stream.send_message(TabularDataByFilterResponse(data=tabular_response_structs, metadata=tabular_metadata))
await stream.send_message(
TabularDataByFilterResponse(
data=tabular_response_structs,
metadata=tabular_metadata,
count=len(tabular_response_structs),
last="LAST_TABULAR_DATA_PAGE_ID",
)
)
self.was_tabular_data_requested = True

async def BinaryDataByFilter(self, stream: Stream[BinaryDataByFilterRequest, BinaryDataByFilterResponse]) -> None:
Expand All @@ -798,7 +805,11 @@ async def BinaryDataByFilter(self, stream: Stream[BinaryDataByFilterRequest, Bin
self.filter = request.data_request.filter
self.include_binary = request.include_binary
await stream.send_message(
BinaryDataByFilterResponse(data=[BinaryData(binary=data.data, metadata=data.metadata) for data in self.binary_response])
BinaryDataByFilterResponse(
data=[BinaryData(binary=data.data, metadata=data.metadata) for data in self.binary_response],
count=len(self.binary_response),
last="LAST_BINARY_DATA_PAGE_ID",
)
)
self.was_binary_data_requested = True

Expand Down
8 changes: 6 additions & 2 deletions tests/test_data_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,17 +123,21 @@ class TestClient:
async def test_tabular_data_by_filter(self, service: MockData):
async with ChannelFor([service]) as channel:
client = DataClient(channel, DATA_SERVICE_METADATA)
tabular_data = await client.tabular_data_by_filter(filter=FILTER)
tabular_data, count, last = await client.tabular_data_by_filter(filter=FILTER)
assert tabular_data == TABULAR_RESPONSE
assert count == len(tabular_data)
assert last != ""
self.assert_filter(filter=service.filter)

@pytest.mark.asyncio
async def test_binary_data_by_filter(self, service: MockData):
async with ChannelFor([service]) as channel:
client = DataClient(channel, DATA_SERVICE_METADATA)
binary_data = await client.binary_data_by_filter(filter=FILTER, include_file_data=INCLUDE_BINARY)
binary_data, count, last = await client.binary_data_by_filter(filter=FILTER, include_binary_data=INCLUDE_BINARY)
assert service.include_binary == INCLUDE_BINARY
assert binary_data == BINARY_RESPONSE
assert count == len(binary_data)
assert last != ""
self.assert_filter(filter=service.filter)

@pytest.mark.asyncio
Expand Down