Skip to content

Commit 70f4cba

Browse files
committed
routing+migration: make payment failure msg optional
1 parent f6f81be commit 70f4cba

File tree

3 files changed

+60
-15
lines changed

3 files changed

+60
-15
lines changed

channeldb/migration32/mission_control_store.go

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,14 @@ func newPaymentFailure(sourceIdx *int,
163163
sourceIdx: tlv.NewPrimitiveRecord[tlv.TlvType0](
164164
uint8(*sourceIdx),
165165
),
166-
msg: tlv.NewRecordT[tlv.TlvType1](failureMessage{failureMsg}),
166+
}
167+
168+
if failureMsg != nil {
169+
info.msg = tlv.SomeRecordT(
170+
tlv.NewRecordT[tlv.TlvType1](
171+
failureMessage{failureMsg},
172+
),
173+
)
167174
}
168175

169176
return &paymentFailure{
@@ -237,7 +244,7 @@ func decodePaymentFailure(r io.Reader, val interface{}, _ *[8]byte,
237244
// paymentFailureInfo holds additional information about a payment failure.
238245
type paymentFailureInfo struct {
239246
sourceIdx tlv.RecordT[tlv.TlvType0, uint8]
240-
msg tlv.RecordT[tlv.TlvType1, failureMessage]
247+
msg tlv.OptionalRecordT[tlv.TlvType1, failureMessage]
241248
}
242249

243250
// Record returns a TLV record that can be used to encode/decode a
@@ -263,10 +270,17 @@ func (r *paymentFailureInfo) Record() tlv.Record {
263270

264271
func encodePaymentFailureInfo(w io.Writer, val interface{}, _ *[8]byte) error {
265272
if v, ok := val.(*paymentFailureInfo); ok {
273+
recordProducers := []tlv.RecordProducer{
274+
&v.sourceIdx,
275+
}
276+
v.msg.WhenSome(
277+
func(r tlv.RecordT[tlv.TlvType1, failureMessage]) {
278+
recordProducers = append(recordProducers, &r)
279+
},
280+
)
281+
266282
return lnwire.EncodeRecordsTo(
267-
w, lnwire.ProduceRecordsSorted(
268-
&v.sourceIdx, &v.msg,
269-
),
283+
w, lnwire.ProduceRecordsSorted(recordProducers...),
270284
)
271285
}
272286

@@ -279,14 +293,19 @@ func decodePaymentFailureInfo(r io.Reader, val interface{}, _ *[8]byte,
279293
if v, ok := val.(*paymentFailureInfo); ok {
280294
var h paymentFailureInfo
281295

282-
_, err := lnwire.DecodeRecords(
296+
msg := tlv.ZeroRecordT[tlv.TlvType1, failureMessage]()
297+
typeMap, err := lnwire.DecodeRecords(
283298
r,
284-
lnwire.ProduceRecordsSorted(&h.sourceIdx, &h.msg)...,
299+
lnwire.ProduceRecordsSorted(&h.sourceIdx, &msg)...,
285300
)
286301
if err != nil {
287302
return err
288303
}
289304

305+
if _, ok := typeMap[h.msg.TlvType()]; ok {
306+
h.msg = tlv.SomeRecordT(msg)
307+
}
308+
290309
*v = h
291310

292311
return nil

routing/missioncontrol.go

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -848,7 +848,14 @@ func newPaymentFailure(sourceIdx *int,
848848
sourceIdx: tlv.NewPrimitiveRecord[tlv.TlvType0](
849849
uint8(*sourceIdx),
850850
),
851-
msg: tlv.NewRecordT[tlv.TlvType1](failureMessage{failureMsg}),
851+
}
852+
853+
if failureMsg != nil {
854+
info.msg = tlv.SomeRecordT(
855+
tlv.NewRecordT[tlv.TlvType1](
856+
failureMessage{failureMsg},
857+
),
858+
)
852859
}
853860

854861
return &paymentFailure{
@@ -922,7 +929,7 @@ func decodePaymentFailure(r io.Reader, val interface{}, _ *[8]byte,
922929
// paymentFailureInfo holds additional information about a payment failure.
923930
type paymentFailureInfo struct {
924931
sourceIdx tlv.RecordT[tlv.TlvType0, uint8]
925-
msg tlv.RecordT[tlv.TlvType1, failureMessage]
932+
msg tlv.OptionalRecordT[tlv.TlvType1, failureMessage]
926933
}
927934

928935
// Record returns a TLV record that can be used to encode/decode a
@@ -948,10 +955,17 @@ func (r *paymentFailureInfo) Record() tlv.Record {
948955

949956
func encodePaymentFailureInfo(w io.Writer, val interface{}, _ *[8]byte) error {
950957
if v, ok := val.(*paymentFailureInfo); ok {
958+
recordProducers := []tlv.RecordProducer{
959+
&v.sourceIdx,
960+
}
961+
v.msg.WhenSome(
962+
func(r tlv.RecordT[tlv.TlvType1, failureMessage]) {
963+
recordProducers = append(recordProducers, &r)
964+
},
965+
)
966+
951967
return lnwire.EncodeRecordsTo(
952-
w, lnwire.ProduceRecordsSorted(
953-
&v.sourceIdx, &v.msg,
954-
),
968+
w, lnwire.ProduceRecordsSorted(recordProducers...),
955969
)
956970
}
957971

@@ -964,14 +978,19 @@ func decodePaymentFailureInfo(r io.Reader, val interface{}, _ *[8]byte,
964978
if v, ok := val.(*paymentFailureInfo); ok {
965979
var h paymentFailureInfo
966980

967-
_, err := lnwire.DecodeRecords(
981+
msg := tlv.ZeroRecordT[tlv.TlvType1, failureMessage]()
982+
typeMap, err := lnwire.DecodeRecords(
968983
r,
969-
lnwire.ProduceRecordsSorted(&h.sourceIdx, &h.msg)...,
984+
lnwire.ProduceRecordsSorted(&h.sourceIdx, &msg)...,
970985
)
971986
if err != nil {
972987
return err
973988
}
974989

990+
if _, ok := typeMap[h.msg.TlvType()]; ok {
991+
h.msg = tlv.SomeRecordT(msg)
992+
}
993+
975994
*v = h
976995

977996
return nil

routing/result_interpretation.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,14 @@ func (i *interpretedResult) processFail(rt *mcRoute, failure paymentFailure) {
139139
failure.info.WhenSome(
140140
func(r tlv.RecordT[tlv.TlvType0, paymentFailureInfo]) {
141141
idx = int(r.Val.sourceIdx.Val)
142-
failMsg = r.Val.msg.Val.FailureMessage
142+
143+
r.Val.msg.WhenSome(
144+
func(msg tlv.RecordT[tlv.TlvType1,
145+
failureMessage]) {
146+
147+
failMsg = msg.Val.FailureMessage
148+
},
149+
)
143150
},
144151
)
145152

0 commit comments

Comments
 (0)