Skip to content

Commit dd1aab8

Browse files
committed
pyln.proto.message: expose array types, add set_field for Message class.
Exposing the array types is required for our dummyrunner in the lnprototest suite, since it wants to be able to generate fake fields. The set_field is similarly useful. Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
1 parent 2ead207 commit dd1aab8

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

contrib/pyln-proto/pyln/proto/message/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .array_types import SizedArrayType, DynamicArrayType, EllipsisArrayType
12
from .message import MessageNamespace, MessageType, Message, SubtypeType
23
from .fundamental_types import split_field, FieldType
34

@@ -10,6 +11,9 @@
1011
"SubtypeType",
1112
"FieldType",
1213
"split_field",
14+
"SizedArrayType",
15+
"DynamicArrayType",
16+
"EllipsisArrayType",
1317

1418
# fundamental_types
1519
'byte',

contrib/pyln-proto/pyln/proto/message/message.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -545,22 +545,24 @@ def __init__(self, messagetype: MessageType, **kwargs):
545545

546546
# Convert arguments from strings to values if necessary.
547547
for field in kwargs:
548-
f = self.messagetype.find_field(field)
549-
if f is None:
550-
raise ValueError("Unknown field {}".format(field))
551-
552-
v = kwargs[field]
553-
if isinstance(v, str):
554-
v, remainder = f.fieldtype.val_from_str(v)
555-
if remainder != '':
556-
raise ValueError('Unexpected {} at end of initializer for {}'.format(remainder, field))
557-
self.fields[field] = v
548+
self.set_field(field, kwargs[field])
558549

559550
bad_lens = self.messagetype.len_fields_bad(self.messagetype.name,
560551
self.fields)
561552
if bad_lens:
562553
raise ValueError("Inconsistent length fields: {}".format(bad_lens))
563554

555+
def set_field(self, field: str, val: Any) -> None:
556+
f = self.messagetype.find_field(field)
557+
if f is None:
558+
raise ValueError("Unknown field {}".format(field))
559+
560+
if isinstance(val, str):
561+
val, remainder = f.fieldtype.val_from_str(val)
562+
if remainder != '':
563+
raise ValueError('Unexpected {} at end of initializer for {}'.format(remainder, field))
564+
self.fields[field] = val
565+
564566
def missing_fields(self) -> List[str]:
565567
"""Are any required fields missing?"""
566568
missing: List[str] = []

0 commit comments

Comments
 (0)