Skip to content

RSDK-7192 - Provisioning wrappers #577

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions src/viam/app/provisioning_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from typing import Mapping, List, Optional

from grpclib.client import Channel

from viam import logging
from viam.proto.provisioning import (
CloudConfig,
GetNetworkListRequest,
GetNetworkListResponse,
GetSmartMachineStatusRequest,
GetSmartMachineStatusResponse,
NetworkInfo,
ProvisioningServiceStub,
SetNetworkCredentialsRequest,
SetSmartMachineCredentialsRequest,
)

LOGGER = logging.getLogger(__name__)


class ProvisioningClient:
"""gRPC client for getting and setting smart machine info.

Constructor is used by `ViamClient` to instantiate relevant service stubs. Calls to
`ProvisioningClient` methods should be made through `ViamClient`.

Establish a connection::

import asyncio

from viam.rpc.dial import DialOptions, Credentials
from viam.app.viam_client import ViamClient


async def connect() -> ViamClient:
# Replace "<API-KEY>" (including brackets) with your API key and "<API-KEY-ID>" with your API key ID
dial_options = DialOptions.with_api_key("<API-KEY>", "<API-KEY-ID>")
return await ViamClient.create_from_dial_options(dial_options)


async def main():

# Make a ViamClient
viam_client = await connect()
# Instantiate a ProvisioningClient to run provisioning client API methods on
provisioning_client = viam_client.provisioning_client
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't you have to actually add the provisioning_client as a property on viam_client for this to work?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤦 yep you're totally right! Thanks for catching that.


viam_client.close()

if __name__ == '__main__':
asyncio.run(main())

"""

def __init__(self, channel: Channel, metadata: Mapping[str, str]):
"""Create a `ProvisioningClient` that maintains a connection to app.

Args:
channel (grpclib.client.Channel): Connection to app.
metadata (Mapping[str, str]): Required authorization token to send requests to app.
"""
self._metadata = metadata
self._provisioning_client = ProvisioningServiceStub(channel)
self._channel = channel

_provisioning_client: ProvisioningServiceStub
_metadata: Mapping[str, str]
_channel: Channel

async def get_network_list(self) -> List[NetworkInfo]:
"""Returns list of networks that are visible to the Smart Machine."""
request = GetNetworkListRequest()
resp: GetNetworkListResponse = await self._provisioning_client.GetNetworkList(request, metadata=self._metadata)
return list(resp.networks)

async def get_smart_machine_status(self) -> GetSmartMachineStatusResponse:
"""Returns the status of the smart machine."""
request = GetSmartMachineStatusRequest()
return await self._provisioning_client.GetSmartMachineStatus(request, metadata=self._metadata)

async def set_network_credentials(self, network_type: str, ssid: str, psk: str) -> None:
"""Sets the network credentials of the Smart Machine.

Args:
network_type (str): The type of the network.
ssid (str): The SSID of the network.
psk (str): The network's passkey.
"""

request = SetNetworkCredentialsRequest(type=network_type, ssid=ssid, psk=psk)
await self._provisioning_client.SetNetworkCredentials(request, metadata=self._metadata)

async def set_smart_machine_credentials(self, cloud_config: Optional[CloudConfig] = None) -> None:
request = SetSmartMachineCredentialsRequest(cloud=cloud_config)
await self._provisioning_client.SetSmartMachineCredentials(request, metadata=self._metadata)
22 changes: 22 additions & 0 deletions src/viam/app/viam_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from viam.app.billing_client import BillingClient
from viam.app.data_client import DataClient
from viam.app.ml_training_client import MLTrainingClient
from viam.app.provisioning_client import ProvisioningClient
from viam.rpc.dial import DialOptions, _dial_app, _get_access_token

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -149,6 +150,27 @@ async def main():

return BillingClient(self._channel, self._metadata)

@property
def provisioning_client(self) -> ProvisioningClient:
"""Instantiate and return a `ProvisioningClient` used to make `provisioning` method calls.
To use the `ProvisioningClient`, you must first instantiate a `ViamClient`.

::

async def connect() -> ViamClient:
# Replace "<API-KEY>" (including brackets) with your API key and "<API-KEY-ID>" with your API key ID
dial_options = DialOptions.with_api_key("<API-KEY>", "<API-KEY-ID>")
return await ViamClient.create_from_dial_options(dial_options)


async def main():
viam_client = await connect()

# Instantiate a ProvisioningClient to run provisioning API methods on
provisioning_client = viam_client.provisioning_client
"""
return ProvisioningClient(self._channel, self._metadata)

def close(self):
"""Close opened channels used for the various service stubs initialized."""
if self._closed:
Expand Down
49 changes: 49 additions & 0 deletions tests/mocks/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,18 @@
FlatTensorDataUInt64,
FlatTensors,
)
from viam.proto.provisioning import (
NetworkInfo,
ProvisioningServiceBase,
GetNetworkListRequest,
GetNetworkListResponse,
GetSmartMachineStatusRequest,
GetSmartMachineStatusResponse,
SetNetworkCredentialsRequest,
SetNetworkCredentialsResponse,
SetSmartMachineCredentialsRequest,
SetSmartMachineCredentialsResponse,
)
from viam.proto.service.motion import (
Constraints,
GetPlanRequest,
Expand Down Expand Up @@ -698,6 +710,43 @@ async def do_command(self, command: Mapping[str, ValueTypes], *, timeout: Option
return {"command": command}


class MockProvisioning(ProvisioningServiceBase):
def __init__(
self,
smart_machine_status: GetSmartMachineStatusResponse,
network_info: List[NetworkInfo],
):
self.smart_machine_status = smart_machine_status
self.network_info = network_info

async def GetNetworkList(self, stream: Stream[GetNetworkListRequest, GetNetworkListResponse]) -> None:
request = await stream.recv_message()
assert request is not None
await stream.send_message(GetNetworkListResponse(networks=self.network_info))

async def GetSmartMachineStatus(self, stream: Stream[GetSmartMachineStatusRequest, GetSmartMachineStatusResponse]) -> None:
request = await stream.recv_message()
assert request is not None
await stream.send_message(self.smart_machine_status)

async def SetNetworkCredentials(self, stream: Stream[SetNetworkCredentialsRequest, SetNetworkCredentialsResponse]) -> None:
request = await stream.recv_message()
assert request is not None
self.network_type = request.type
self.ssid = request.ssid
self.psk = request.psk
await stream.send_message(SetNetworkCredentialsResponse())

async def SetSmartMachineCredentials(
self,
stream: Stream[SetSmartMachineCredentialsRequest, SetSmartMachineCredentialsResponse],
) -> None:
request = await stream.recv_message()
assert request is not None
self.cloud_config = request.cloud
await stream.send_message(SetSmartMachineCredentialsResponse())


class MockData(DataServiceBase):
def __init__(
self,
Expand Down
80 changes: 80 additions & 0 deletions tests/test_provisioning_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import pytest

from grpclib.testing import ChannelFor

from viam.app.provisioning_client import ProvisioningClient

from viam.proto.provisioning import GetSmartMachineStatusResponse, NetworkInfo, ProvisioningInfo, CloudConfig

from .mocks.services import MockProvisioning

ID = "id"
MODEL = "model"
MANUFACTURER = "acme"
PROVISIONING_INFO = ProvisioningInfo(fragment_id=ID, model=MODEL, manufacturer=MANUFACTURER)
HAS_CREDENTIALS = True
IS_ONLINE = True
NETWORK_TYPE = "type"
SSID = "ssid"
ERROR = "error"
ERRORS = [ERROR]
PSK = "psk"
SECRET = "secret"
APP_ADDRESS = "address"
NETWORK_INFO_LATEST = NetworkInfo(
type=NETWORK_TYPE,
ssid=SSID,
security="security",
signal=12,
connected=IS_ONLINE,
last_error=ERROR,
)
NETWORK_INFO = [NETWORK_INFO_LATEST]
SMART_MACHINE_STATUS_RESPONSE = GetSmartMachineStatusResponse(
provisioning_info=PROVISIONING_INFO,
has_smart_machine_credentials=HAS_CREDENTIALS,
is_online=IS_ONLINE,
latest_connection_attempt=NETWORK_INFO_LATEST,
errors=ERRORS
)
CLOUD_CONFIG = CloudConfig(id=ID, secret=SECRET, app_address=APP_ADDRESS)

AUTH_TOKEN = "auth_token"
PROVISIONING_SERVICE_METADATA = {"authorization": f"Bearer {AUTH_TOKEN}"}


@pytest.fixture(scope="function")
def service() -> MockProvisioning:
return MockProvisioning(smart_machine_status=SMART_MACHINE_STATUS_RESPONSE, network_info=NETWORK_INFO)


class TestClient:
@pytest.mark.asyncio
async def test_get_network_list(self, service: MockProvisioning):
async with ChannelFor([service]) as channel:
client = ProvisioningClient(channel, PROVISIONING_SERVICE_METADATA)
network_info = await client.get_network_list()
assert network_info == NETWORK_INFO

@pytest.mark.asyncio
async def test_get_smart_machine_status(self, service: MockProvisioning):
async with ChannelFor([service]) as channel:
client = ProvisioningClient(channel, PROVISIONING_SERVICE_METADATA)
smart_machine_status = await client.get_smart_machine_status()
assert smart_machine_status == SMART_MACHINE_STATUS_RESPONSE

@pytest.mark.asyncio
async def test_set_network_credentials(self, service: MockProvisioning):
async with ChannelFor([service]) as channel:
client = ProvisioningClient(channel, PROVISIONING_SERVICE_METADATA)
await client.set_network_credentials(network_type=NETWORK_TYPE, ssid=SSID, psk=PSK)
assert service.network_type == NETWORK_TYPE
assert service.ssid == SSID
assert service.psk == PSK

@pytest.mark.asyncio
async def test_set_smart_machine_credentials(self, service: MockProvisioning):
async with ChannelFor([service]) as channel:
client = ProvisioningClient(channel, PROVISIONING_SERVICE_METADATA)
await client.set_smart_machine_credentials(cloud_config=CLOUD_CONFIG)
assert service.cloud_config == CLOUD_CONFIG
Loading