Skip to content

Commit 97c5026

Browse files
committed
taprootassets: add generic Copy fn test
1 parent 65e2ee1 commit 97c5026

File tree

2 files changed

+356
-1
lines changed

2 files changed

+356
-1
lines changed

copy_test.go

Lines changed: 355 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,355 @@
1+
package taprootassets
2+
3+
import (
4+
"fmt"
5+
"math/rand"
6+
"reflect"
7+
"strings"
8+
"testing"
9+
10+
"github.com/davecgh/go-spew/spew"
11+
"github.com/lightninglabs/taproot-assets/fn"
12+
"github.com/lightninglabs/taproot-assets/tapfreighter"
13+
"github.com/pmezard/go-difflib/difflib"
14+
)
15+
16+
// FillFakeData recursively fills a struct with dummy values.
17+
func FillFakeData[T any](t *testing.T, debug bool, maxDepth int, v T) {
18+
if t != nil {
19+
t.Helper()
20+
}
21+
22+
val := reflect.ValueOf(v)
23+
name := val.Type().Elem().Name()
24+
fillFakeData(t, debug, 0, maxDepth, val, name)
25+
}
26+
27+
// fillFakeData is the recursive helper to fill a value with fake data.
28+
func fillFakeData(t *testing.T, debug bool, depth, maxDepth int,
29+
v reflect.Value, path string) {
30+
31+
if t != nil {
32+
t.Helper()
33+
}
34+
35+
if depth > maxDepth || !v.IsValid() {
36+
return
37+
}
38+
39+
indent := strings.Repeat(" ", depth)
40+
41+
log := func(format string, args ...any) {
42+
if debug {
43+
if t != nil {
44+
t.Logf(indent+format, args...)
45+
} else {
46+
fmt.Printf(indent+format+"\n", args...)
47+
}
48+
}
49+
}
50+
switch v.Kind() {
51+
case reflect.Ptr:
52+
if v.IsNil() {
53+
ptr := reflect.New(v.Type().Elem())
54+
v.Set(ptr)
55+
56+
log("ptr: %s (%s)", path, v.Type())
57+
}
58+
59+
fillFakeData(t, debug, depth+1, maxDepth, v.Elem(), path)
60+
61+
case reflect.Struct:
62+
typ := v.Type()
63+
for i := range v.NumField() {
64+
field := v.Field(i)
65+
fieldType := typ.Field(i)
66+
67+
if !field.CanSet() {
68+
continue
69+
}
70+
71+
fieldPath := fmt.Sprintf("%s.%s", path, fieldType.Name)
72+
fillFakeData(
73+
t, debug, depth+1, maxDepth, field, fieldPath,
74+
)
75+
}
76+
77+
case reflect.Slice:
78+
if v.Type().Elem().Kind() == reflect.Uint8 {
79+
// Special case: []byte.
80+
b := make([]byte, randomLen())
81+
for i := range b {
82+
b[i] = byte(rand.Intn(256))
83+
}
84+
85+
v.SetBytes(b)
86+
log("[]byte: %s = %v", path, b)
87+
88+
return
89+
}
90+
91+
elemType := v.Type().Elem()
92+
length := randomLen()
93+
slice := reflect.MakeSlice(v.Type(), length, length)
94+
95+
for i := range length {
96+
elemPath := fmt.Sprintf("%s[%d]", path, i)
97+
98+
var elem reflect.Value
99+
if elemType.Kind() == reflect.Ptr {
100+
elem = reflect.New(elemType.Elem())
101+
102+
fillFakeData(
103+
t, debug, depth+1, maxDepth,
104+
elem.Elem(), elemPath,
105+
)
106+
} else {
107+
elem = reflect.New(elemType).Elem()
108+
109+
fillFakeData(
110+
t, debug, depth+1, maxDepth, elem,
111+
elemPath,
112+
)
113+
}
114+
115+
slice.Index(i).Set(elem)
116+
}
117+
118+
v.Set(slice)
119+
log("slice: %s (len=%d)", path, length)
120+
121+
case reflect.Array:
122+
for i := range v.Len() {
123+
fillFakeData(
124+
t, debug, depth+1, maxDepth, v.Index(i),
125+
fmt.Sprintf("%s[%d]", path, i),
126+
)
127+
}
128+
129+
log("array: %s (len=%d)", path, v.Len())
130+
131+
case reflect.Map:
132+
keyType := v.Type().Key()
133+
valType := v.Type().Elem()
134+
m := reflect.MakeMap(v.Type())
135+
length := randomLen()
136+
137+
for i := range length {
138+
key := reflect.New(keyType).Elem()
139+
140+
fillFakeData(
141+
t, debug, depth+1, maxDepth, key,
142+
fmt.Sprintf("%s[key%d]", path, i),
143+
)
144+
145+
val := reflect.New(valType).Elem()
146+
147+
fillFakeData(
148+
t, debug, depth+1, maxDepth, val,
149+
fmt.Sprintf("%s[val%d]", path, i),
150+
)
151+
152+
m.SetMapIndex(key, val)
153+
}
154+
155+
v.Set(m)
156+
log("map: %s (len=%d)", path, length)
157+
158+
default:
159+
assignDummyPrimitive(t, debug, indent, v, path)
160+
}
161+
}
162+
163+
// assignDummyPrimitive assigns dummy values to primitive type values.
164+
func assignDummyPrimitive(t *testing.T, debug bool, indent string,
165+
v reflect.Value, path string) {
166+
167+
log := func(format string, args ...any) {
168+
if debug {
169+
if t != nil {
170+
t.Logf(indent+format, args...)
171+
} else {
172+
fmt.Printf(indent+format+"\n", args...)
173+
}
174+
}
175+
}
176+
177+
switch v.Kind() {
178+
case reflect.String:
179+
s := randomString()
180+
v.SetString(s)
181+
log("string: %s = %q", path, s)
182+
183+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
184+
reflect.Int64:
185+
186+
i := rand.Int63n(1_000_000)
187+
v.SetInt(i)
188+
log("int: %s = %d", path, i)
189+
190+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
191+
reflect.Uint64:
192+
193+
u := uint64(rand.Intn(1_000_000))
194+
v.SetUint(u)
195+
log("uint: %s = %d", path, u)
196+
197+
case reflect.Bool:
198+
b := rand.Intn(2) == 0
199+
v.SetBool(b)
200+
log("bool: %s = %v", path, b)
201+
202+
case reflect.Float32, reflect.Float64:
203+
f := rand.Float64() * 1_000
204+
v.SetFloat(f)
205+
log("float: %s = %f", path, f)
206+
207+
default:
208+
}
209+
}
210+
211+
func randomString() string {
212+
return fmt.Sprintf("val_%d", rand.Intn(100_000))
213+
}
214+
215+
func randomLen() int {
216+
return rand.Intn(3)
217+
}
218+
219+
// checkAliasing walks the fields and check for shared references.
220+
func checkAliasing(t *testing.T, debug, strict bool, f1, f2 reflect.Value,
221+
path string) {
222+
223+
t.Helper()
224+
225+
if !f1.IsValid() || !f2.IsValid() {
226+
return
227+
}
228+
229+
switch f1.Kind() {
230+
case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Func,
231+
reflect.Chan:
232+
233+
if f1.IsNil() || f2.IsNil() {
234+
return
235+
}
236+
237+
if f1.Pointer() == f2.Pointer() {
238+
msg := fmt.Sprintf("Aliasing detected at path: %s "+
239+
"(shared %s)", path, f1.Kind())
240+
241+
if strict {
242+
t.Fatalf(msg)
243+
}
244+
245+
if debug {
246+
t.Logf("WARNING %s", msg)
247+
}
248+
}
249+
250+
// Recurse into slice/map values.
251+
switch f1.Kind() {
252+
case reflect.Slice:
253+
for i := 0; i < f1.Len() && i < f2.Len(); i++ {
254+
checkAliasing(
255+
t, debug, strict,
256+
f1.Index(i), f2.Index(i),
257+
fmt.Sprintf("%s[%d]", path, i),
258+
)
259+
}
260+
case reflect.Map:
261+
for _, key := range f1.MapKeys() {
262+
v1 := f1.MapIndex(key)
263+
v2 := f2.MapIndex(key)
264+
checkAliasing(
265+
t, debug, strict,
266+
v1, v2, fmt.Sprintf("%s[%v]", path,
267+
key.Interface()),
268+
)
269+
}
270+
271+
default:
272+
}
273+
274+
case reflect.Struct:
275+
for i := range f1.NumField() {
276+
field := f1.Type().Field(i)
277+
278+
// Skip unexported fields.
279+
if !f1.Field(i).CanInterface() {
280+
continue
281+
}
282+
283+
childPath := fmt.Sprintf("%s.%s", path, field.Name)
284+
checkAliasing(
285+
t, debug, strict,
286+
f1.Field(i), f2.Field(i), childPath,
287+
)
288+
}
289+
290+
default:
291+
}
292+
}
293+
294+
// CopyFnTest checks that the Copy method returns a value that:
295+
// 1) is deeply equal
296+
// 2) does not alias mutable fields (pointers, slices, maps)
297+
func CopyFnTest[T fn.Copyable[T]](t *testing.T, debug, strict bool,
298+
original T) {
299+
300+
originalVal := reflect.ValueOf(original)
301+
copied := original.Copy()
302+
copiedVal := reflect.ValueOf(copied)
303+
304+
if !reflect.DeepEqual(original, copied) {
305+
diff := difflib.UnifiedDiff{
306+
A: difflib.SplitLines(
307+
spew.Sdump(original),
308+
),
309+
B: difflib.SplitLines(
310+
spew.Sdump(copied),
311+
),
312+
FromFile: "Original",
313+
FromDate: "",
314+
ToFile: "Copied",
315+
ToDate: "",
316+
Context: 3,
317+
}
318+
diffText, _ := difflib.GetUnifiedDiffString(diff)
319+
320+
t.Fatalf("Copied value is not deeply equal to the orginal:\n%v",
321+
diffText)
322+
}
323+
324+
if originalVal.Kind() == reflect.Ptr {
325+
originalVal = originalVal.Elem()
326+
copiedVal = copiedVal.Elem()
327+
}
328+
329+
for i := range originalVal.NumField() {
330+
f1 := originalVal.Field(i)
331+
f2 := copiedVal.Field(i)
332+
name := originalVal.Type().Field(i).Name
333+
334+
checkAliasing(t, debug, strict, f1, f2, name)
335+
}
336+
}
337+
338+
// TestCopy tests known Copy() functions.
339+
func TestCopy(t *testing.T) {
340+
// Set to true to debug print.
341+
debug := false
342+
343+
// Please set depth values carefully. Sometimes our copy functions are
344+
// deeply nested in other packages and do not need changes. Sometimes
345+
// types are recursive and too deep copy may end up in stack-overlow.
346+
t.Run("tapfreighter.OutboundParcel", func(t *testing.T) {
347+
const maxDepth = 5
348+
p := &tapfreighter.OutboundParcel{}
349+
FillFakeData(t, debug, maxDepth, p)
350+
351+
// We allow aliasing here deep down (for now).
352+
strict := false
353+
CopyFnTest(t, debug, strict, p)
354+
})
355+
}

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ require (
3737
github.com/lightningnetwork/lnd/tlv v1.3.1
3838
github.com/lightningnetwork/lnd/tor v1.1.6
3939
github.com/ory/dockertest/v3 v3.10.0
40+
github.com/pmezard/go-difflib v1.0.0
4041
github.com/prometheus/client_golang v1.14.0
4142
github.com/stretchr/testify v1.10.0
4243
github.com/urfave/cli v1.22.14
@@ -148,7 +149,6 @@ require (
148149
github.com/opencontainers/image-spec v1.1.0 // indirect
149150
github.com/opencontainers/runc v1.2.0 // indirect
150151
github.com/pkg/errors v0.9.1 // indirect
151-
github.com/pmezard/go-difflib v1.0.0 // indirect
152152
github.com/prometheus/client_model v0.3.0 // indirect
153153
github.com/prometheus/common v0.37.0 // indirect
154154
github.com/prometheus/procfs v0.8.0 // indirect

0 commit comments

Comments
 (0)