Skip to content

Commit

Permalink
refactor: orchestrator client return info object (#454)
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandre-P <alexandre.picosson@owkin.com>
  • Loading branch information
AlexandrePicosson authored Sep 13, 2022
1 parent bf7c7bc commit e66a8d9
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 10 deletions.
3 changes: 2 additions & 1 deletion backend/api/tests/views/test_views_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from api.tests.common import AuthenticatedClient
from orchestrator.client import OrchestratorClient
from orchestrator.resources import OrchestratorVersion


@override_settings(LEDGER_CHANNELS={"mychannel": {"chaincode": {"name": "mycc"}, "model_export_enabled": True}})
Expand All @@ -32,7 +33,7 @@ def test_authenticated(self):
client = AuthenticatedClient()

with mock.patch.object(
OrchestratorClient, "query_version", return_value={"orchestrator": "foo", "chaincode": "bar"}
OrchestratorClient, "query_version", return_value=OrchestratorVersion(server="foo", chaincode="bar")
):
response = client.get(self.url, **self.extra)

Expand Down
8 changes: 4 additions & 4 deletions backend/backend/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,23 @@ def get(self, request, *args, **kwargs):
channel_name = get_channel_name(request)
channel = settings.LEDGER_CHANNELS[channel_name]

orchestrator_versions = {}
orchestrator_versions = None

if not settings.ISOLATED:
with get_orchestrator_client(channel_name) as client:
orchestrator_versions = client.query_version()

res["channel"] = channel_name
res["version"] = settings.BACKEND_VERSION
res["orchestrator_version"] = orchestrator_versions.get("orchestrator")
res["orchestrator_version"] = orchestrator_versions.server if orchestrator_versions is not None else None
res["config"]["model_export_enabled"] = channel["model_export_enabled"]

res["user"] = request.user.get_username()
if hasattr(request.user, "channel"):
res["user_role"] = request.user.channel.role

if orchestrator_versions.get("chaincode"):
res["chaincode_version"] = orchestrator_versions.get("chaincode")
if orchestrator_versions and orchestrator_versions.chaincode:
res["chaincode_version"] = orchestrator_versions.chaincode

return ApiResponse(res)

Expand Down
8 changes: 3 additions & 5 deletions backend/orchestrator/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from orchestrator.resources import ComputePlanStatus
from orchestrator.resources import ComputeTask
from orchestrator.resources import ComputeTaskInputAsset
from orchestrator.resources import OrchestratorVersion

logger = structlog.get_logger(__name__)

Expand Down Expand Up @@ -530,15 +531,12 @@ def subscribe_to_events(self, channel_name=None, start_event_id=""):

return (MessageToDict(event, **CONVERT_SETTINGS) for event in events_stream)

def query_version(
self,
):
def query_version(self) -> OrchestratorVersion:
data = self._info_client.QueryVersion(
info_pb2.QueryVersionParam(),
metadata=self._metadata,
)
data = MessageToDict(data, **CONVERT_SETTINGS)
return data
return OrchestratorVersion.from_grpc(data)

def __enter__(self):
return self
Expand Down
10 changes: 10 additions & 0 deletions backend/orchestrator/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from orchestrator import computetask_pb2
from orchestrator import datamanager_pb2
from orchestrator import datasample_pb2
from orchestrator import info_pb2
from orchestrator import model_pb2


Expand Down Expand Up @@ -344,3 +345,12 @@ class InvalidInputAsset(Exception):
def __init__(self, actual: AssetKind, expected: AssetKind):
message = f"Invalid asset kind, expected {expected} but have {actual}"
super().__init__(message)


class OrchestratorVersion(pydantic.BaseModel):
server: str
chaincode: str

@classmethod
def from_grpc(cls, orc_version: info_pb2.QueryVersionResponse) -> OrchestratorVersion:
return cls(server=orc_version.orchestrator, chaincode=orc_version.chaincode)

0 comments on commit e66a8d9

Please sign in to comment.