Skip to content

Commit 9c749a9

Browse files
cache credential exchange by thread ID & initiator
Signed-off-by: Andrew Whitehead <cywolf@gmail.com>
1 parent b1c0838 commit 9c749a9

File tree

5 files changed

+69
-63
lines changed

5 files changed

+69
-63
lines changed

aries_cloudagent/holder/tests/test_indy.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import pytest
1414

1515
from ...messaging.issue_credential.v1_0.messages.inner.credential_preview import (
16-
CredentialPreview
16+
CredentialPreview,
1717
)
1818

1919

@@ -73,7 +73,7 @@ async def test_get_credential_attrs_mime_types(self, mock_nonsec_get_wallet_reco
7373
"type": IndyHolder.RECORD_TYPE_MIME_TYPES,
7474
"id": cred_id,
7575
"value": "value",
76-
"tags": dummy_tags
76+
"tags": dummy_tags,
7777
}
7878
mock_nonsec_get_wallet_record.return_value = json.dumps(dummy_rec)
7979

@@ -88,12 +88,8 @@ async def test_get_credential_attrs_mime_types(self, mock_nonsec_get_wallet_reco
8888
dummy_rec["type"],
8989
f"{IndyHolder.RECORD_TYPE_MIME_TYPES}::{dummy_rec['id']}",
9090
json.dumps(
91-
{
92-
"retrieveType": True,
93-
"retrieveValue": True,
94-
"retrieveTags": True
95-
}
96-
)
91+
{"retrieveType": False, "retrieveValue": True, "retrieveTags": True}
92+
),
9793
)
9894

9995
assert mime_types == dummy_tags
@@ -106,7 +102,7 @@ async def test_get_credential_attr_mime_type(self, mock_nonsec_get_wallet_record
106102
"type": IndyHolder.RECORD_TYPE_MIME_TYPES,
107103
"id": cred_id,
108104
"value": "value",
109-
"tags": dummy_tags
105+
"tags": dummy_tags,
110106
}
111107
mock_nonsec_get_wallet_record.return_value = json.dumps(dummy_rec)
112108

@@ -121,12 +117,8 @@ async def test_get_credential_attr_mime_type(self, mock_nonsec_get_wallet_record
121117
dummy_rec["type"],
122118
f"{IndyHolder.RECORD_TYPE_MIME_TYPES}::{dummy_rec['id']}",
123119
json.dumps(
124-
{
125-
"retrieveType": True,
126-
"retrieveValue": True,
127-
"retrieveTags": True
128-
}
129-
)
120+
{"retrieveType": False, "retrieveValue": True, "retrieveTags": True}
121+
),
130122
)
131123

132124
assert a_mime_type == dummy_tags["a"]
@@ -233,7 +225,7 @@ async def test_delete_credential(
233225
self,
234226
mock_nonsec_del_wallet_record,
235227
mock_nonsec_get_wallet_record,
236-
mock_prover_del_cred
228+
mock_prover_del_cred,
237229
):
238230
mock_wallet = async_mock.MagicMock()
239231
holder = IndyHolder(mock_wallet)
@@ -242,18 +234,14 @@ async def test_delete_credential(
242234
"type": "typ",
243235
"id": "ident",
244236
"value": "value",
245-
"tags": {
246-
"a": json.dumps("1"),
247-
"b": json.dumps("2")
248-
}
237+
"tags": {"a": json.dumps("1"), "b": json.dumps("2")},
249238
}
250239
)
251240

252241
credential = await holder.delete_credential("credential_id")
253242

254243
mock_prover_del_cred.assert_called_once_with(
255-
mock_wallet.handle,
256-
"credential_id"
244+
mock_wallet.handle, "credential_id"
257245
)
258246

259247
@async_mock.patch("indy.anoncreds.prover_create_proof")

aries_cloudagent/messaging/credentials/manager.py

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -340,14 +340,10 @@ async def receive_request(self, credential_request_message: CredentialRequest):
340340

341341
credential_request = json.loads(credential_request_message.request)
342342

343-
credential_exchange_record = await CredentialExchange.retrieve_by_tag_filter(
344-
self.context,
345-
{
346-
"thread_id": credential_request_message._thread_id,
347-
},
348-
{
349-
"initiator": "self",
350-
},
343+
(
344+
credential_exchange_record
345+
) = await CredentialExchange.retrieve_by_thread_and_initiator(
346+
self.context, credential_request_message._thread_id, "self"
351347
)
352348
credential_exchange_record.credential_request = credential_request
353349
credential_exchange_record.state = CredentialExchange.STATE_REQUEST_RECEIVED
@@ -443,14 +439,8 @@ async def receive_credential(self, credential_message: CredentialIssue):
443439
try:
444440
(
445441
credential_exchange_record
446-
) = await CredentialExchange.retrieve_by_tag_filter(
447-
self.context,
448-
{
449-
"thread_id": credential_message._thread_id,
450-
},
451-
{
452-
"initiator": "external",
453-
},
442+
) = await CredentialExchange.retrieve_by_thread_and_initiator(
443+
self.context, credential_message._thread_id, "external"
454444
)
455445
except StorageNotFoundError:
456446

@@ -464,14 +454,8 @@ async def receive_credential(self, credential_message: CredentialIssue):
464454
# credential_request_metadata
465455
(
466456
credential_exchange_record
467-
) = await CredentialExchange.retrieve_by_tag_filter(
468-
self.context,
469-
{
470-
"thread_id": credential_message._thread.pthid,
471-
},
472-
{
473-
"initiator": "external",
474-
},
457+
) = await CredentialExchange.retrieve_by_thread_and_initiator(
458+
self.context, credential_message._thread.pthid, "external"
475459
)
476460

477461
# Copy values from parent but create new record on save (no id)
@@ -587,14 +571,10 @@ async def credential_stored(self, credential_stored_message: CredentialStored):
587571
"""
588572

589573
# Get current exchange record by thread id
590-
credential_exchange_record = await CredentialExchange.retrieve_by_tag_filter(
591-
self.context,
592-
{
593-
"thread_id": credential_stored_message._thread_id,
594-
},
595-
{
596-
"initiator": "self",
597-
},
574+
(
575+
credential_exchange_record
576+
) = await CredentialExchange.retrieve_by_thread_and_initiator(
577+
self.context, credential_stored_message._thread_id, "self"
598578
)
599579

600580
credential_exchange_record.state = CredentialExchange.STATE_STORED

aries_cloudagent/messaging/credentials/models/credential_exchange.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from marshmallow import fields
44

5+
from ....config.injection_context import InjectionContext
6+
57
from ...models.base_record import BaseRecord, BaseRecordSchema
68

79

@@ -100,6 +102,44 @@ def record_value(self) -> dict:
100102
)
101103
}
102104

105+
@classmethod
106+
async def retrieve_by_thread_and_initiator(
107+
cls, context: InjectionContext, thread_id: str, initiator: str
108+
) -> "CredentialExchange":
109+
"""Retrieve a credential exchange record by thread ID and inititator."""
110+
cache_key = f"credential_exchange_tidx::{thread_id}::{initiator}"
111+
record_id = await cls.get_cached_key(context, cache_key)
112+
if record_id:
113+
record = await cls.retrieve_by_id(context, record_id)
114+
else:
115+
record = await cls.retrieve_by_tag_filter(
116+
context, {"thread_id": thread_id}, {"initiator": initiator}
117+
)
118+
await cls.set_cached_key(context, cache_key, record.credential_exchange_id)
119+
return record
120+
121+
async def post_save(
122+
self,
123+
context: InjectionContext,
124+
new_record: bool,
125+
last_state: str,
126+
webhook: bool = None,
127+
):
128+
"""Perform post-save actions.
129+
130+
Args:
131+
context: The injection context to use
132+
new_record: Flag indicating if the record was just created
133+
last_state: The previous state value
134+
webhook: Adjust whether the webhook is called
135+
"""
136+
await super(CredentialExchange, self).post_save(
137+
context, new_record, last_state, webhook
138+
)
139+
if self.thread_id and self.initiator:
140+
cache_key = f"credential_exchange_tidx::{self.thread_id}::{self.initiator}"
141+
await self.set_cached_key(context, cache_key, self.credential_exchange_id)
142+
103143

104144
class CredentialExchangeSchema(BaseRecordSchema):
105145
"""Schema to allow serialization/deserialization of credential exchange records."""

aries_cloudagent/messaging/models/base_record.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -222,14 +222,14 @@ async def retrieve_by_tag_filter(
222222
post_filter: Additional value filters to apply after retrieval
223223
"""
224224
storage: BaseStorage = await context.inject(BaseStorage)
225-
result = storage.search_records(
225+
query = storage.search_records(
226226
cls.RECORD_TYPE,
227227
cls.prefix_tag_filter(tag_filter),
228228
None,
229229
{"retrieveTags": False},
230230
)
231231
found = None
232-
async for record in result:
232+
async for record in query:
233233
vals = json.loads(record.value)
234234
if not post_filter or match_post_filter(vals, post_filter):
235235
if found:
@@ -254,14 +254,14 @@ async def query(
254254
post_filter: Additional value filters to apply
255255
"""
256256
storage: BaseStorage = await context.inject(BaseStorage)
257-
result = storage.search_records(
257+
query = storage.search_records(
258258
cls.RECORD_TYPE,
259259
cls.prefix_tag_filter(tag_filter),
260260
None,
261261
{"retrieveTags": False},
262262
)
263263
result = []
264-
async for record in result:
264+
async for record in query:
265265
vals = json.loads(record.value)
266266
if not post_filter or match_post_filter(vals, post_filter):
267267
result.append(cls.from_storage(record.id, vals))

aries_cloudagent/messaging/models/tests/test_base_record.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ async def test_retrieve_uncached_id(self):
126126
mock_storage.get_record.return_value = stored
127127
result = await BaseRecordImpl.retrieve_by_id(context, record_id, False)
128128
mock_storage.get_record.assert_awaited_once_with(
129-
BaseRecordImpl.RECORD_TYPE, record_id
129+
BaseRecordImpl.RECORD_TYPE, record_id, {"retrieveTags": False}
130130
)
131131
set_cached_key.assert_awaited_once_with(
132132
context, cache_key.return_value, record_value
@@ -146,12 +146,10 @@ async def test_query(self):
146146
BaseRecordImpl.RECORD_TYPE, json.dumps(record_value), {}, record_id
147147
)
148148

149-
mock_storage.search_records.return_value.fetch_all = async_mock.CoroutineMock(
150-
return_value=[stored]
151-
)
149+
mock_storage.search_records.return_value.__aiter__.return_value = [stored]
152150
result = await BaseRecordImpl.query(context, tag_filter)
153151
mock_storage.search_records.assert_called_once_with(
154-
BaseRecordImpl.RECORD_TYPE, tag_filter
152+
BaseRecordImpl.RECORD_TYPE, tag_filter, None, {"retrieveTags": False}
155153
)
156154
assert result and isinstance(result[0], BaseRecordImpl)
157155
assert result[0]._id == record_id

0 commit comments

Comments
 (0)