diff --git a/backend/api/tests/views/test_views_info.py b/backend/api/tests/views/test_views_info.py index 3f74f9be7..58b7efeef 100644 --- a/backend/api/tests/views/test_views_info.py +++ b/backend/api/tests/views/test_views_info.py @@ -7,6 +7,7 @@ from api.tests.common import AuthenticatedClient from orchestrator.client import OrchestratorClient +from orchestrator.resources import OrchestratorVersion @override_settings(LEDGER_CHANNELS={"mychannel": {"chaincode": {"name": "mycc"}, "model_export_enabled": True}}) @@ -32,7 +33,7 @@ def test_authenticated(self): client = AuthenticatedClient() with mock.patch.object( - OrchestratorClient, "query_version", return_value={"orchestrator": "foo", "chaincode": "bar"} + OrchestratorClient, "query_version", return_value=OrchestratorVersion(server="foo", chaincode="bar") ): response = client.get(self.url, **self.extra) diff --git a/backend/backend/views.py b/backend/backend/views.py index cb4d4dc92..db42630e8 100644 --- a/backend/backend/views.py +++ b/backend/backend/views.py @@ -52,7 +52,7 @@ def get(self, request, *args, **kwargs): channel_name = get_channel_name(request) channel = settings.LEDGER_CHANNELS[channel_name] - orchestrator_versions = {} + orchestrator_versions = None if not settings.ISOLATED: with get_orchestrator_client(channel_name) as client: @@ -60,15 +60,15 @@ def get(self, request, *args, **kwargs): res["channel"] = channel_name res["version"] = settings.BACKEND_VERSION - res["orchestrator_version"] = orchestrator_versions.get("orchestrator") + res["orchestrator_version"] = orchestrator_versions.server if orchestrator_versions is not None else None res["config"]["model_export_enabled"] = channel["model_export_enabled"] res["user"] = request.user.get_username() if hasattr(request.user, "channel"): res["user_role"] = request.user.channel.role - if orchestrator_versions.get("chaincode"): - res["chaincode_version"] = orchestrator_versions.get("chaincode") + if orchestrator_versions and orchestrator_versions.chaincode: + res["chaincode_version"] = orchestrator_versions.chaincode return ApiResponse(res) diff --git a/backend/orchestrator/client.py b/backend/orchestrator/client.py index fa9497e45..78f28da6d 100644 --- a/backend/orchestrator/client.py +++ b/backend/orchestrator/client.py @@ -39,6 +39,7 @@ from orchestrator.resources import ComputePlanStatus from orchestrator.resources import ComputeTask from orchestrator.resources import ComputeTaskInputAsset +from orchestrator.resources import OrchestratorVersion logger = structlog.get_logger(__name__) @@ -530,15 +531,12 @@ def subscribe_to_events(self, channel_name=None, start_event_id=""): return (MessageToDict(event, **CONVERT_SETTINGS) for event in events_stream) - def query_version( - self, - ): + def query_version(self) -> OrchestratorVersion: data = self._info_client.QueryVersion( info_pb2.QueryVersionParam(), metadata=self._metadata, ) - data = MessageToDict(data, **CONVERT_SETTINGS) - return data + return OrchestratorVersion.from_grpc(data) def __enter__(self): return self diff --git a/backend/orchestrator/resources.py b/backend/orchestrator/resources.py index ecf1dfbcc..558b621d4 100644 --- a/backend/orchestrator/resources.py +++ b/backend/orchestrator/resources.py @@ -12,6 +12,7 @@ from orchestrator import computetask_pb2 from orchestrator import datamanager_pb2 from orchestrator import datasample_pb2 +from orchestrator import info_pb2 from orchestrator import model_pb2 @@ -344,3 +345,12 @@ class InvalidInputAsset(Exception): def __init__(self, actual: AssetKind, expected: AssetKind): message = f"Invalid asset kind, expected {expected} but have {actual}" super().__init__(message) + + +class OrchestratorVersion(pydantic.BaseModel): + server: str + chaincode: str + + @classmethod + def from_grpc(cls, orc_version: info_pb2.QueryVersionResponse) -> OrchestratorVersion: + return cls(server=orc_version.orchestrator, chaincode=orc_version.chaincode)