Skip to content

Commit 3b02cbd

Browse files
committed
taprootassets: add generic Copy fn test
1 parent 65e2ee1 commit 3b02cbd

File tree

2 files changed

+350
-1
lines changed

2 files changed

+350
-1
lines changed

copy_test.go

Lines changed: 349 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,349 @@
1+
package taprootassets
2+
3+
import (
4+
"fmt"
5+
"math/rand"
6+
"reflect"
7+
"strings"
8+
"testing"
9+
10+
"github.com/davecgh/go-spew/spew"
11+
"github.com/lightninglabs/taproot-assets/fn"
12+
"github.com/lightninglabs/taproot-assets/tapfreighter"
13+
"github.com/pmezard/go-difflib/difflib"
14+
)
15+
16+
// FillFakeData recursively fills a struct with dummy values.
17+
func FillFakeData[T any](t *testing.T, debug bool, maxDepth int, v T) {
18+
if t != nil {
19+
t.Helper()
20+
}
21+
22+
val := reflect.ValueOf(v)
23+
name := val.Type().Elem().Name()
24+
fillFakeData(t, debug, 0, maxDepth, val, name)
25+
}
26+
27+
// fillFakeData is the recursive helper to fill a value with fake data.
28+
func fillFakeData(t *testing.T, debug bool, depth, maxDepth int,
29+
v reflect.Value, path string) {
30+
31+
if t != nil {
32+
t.Helper()
33+
}
34+
35+
if depth > maxDepth || !v.IsValid() {
36+
return
37+
}
38+
39+
indent := strings.Repeat(" ", depth)
40+
41+
log := func(format string, args ...any) {
42+
if debug {
43+
if t != nil {
44+
t.Logf(indent+format, args...)
45+
} else {
46+
fmt.Printf(indent+format+"\n", args...)
47+
}
48+
}
49+
}
50+
switch v.Kind() {
51+
case reflect.Ptr:
52+
if v.IsNil() {
53+
ptr := reflect.New(v.Type().Elem())
54+
v.Set(ptr)
55+
56+
log("ptr: %s (%s)", path, v.Type())
57+
}
58+
59+
fillFakeData(t, debug, depth+1, maxDepth, v.Elem(), path)
60+
61+
case reflect.Struct:
62+
typ := v.Type()
63+
for i := range v.NumField() {
64+
field := v.Field(i)
65+
fieldType := typ.Field(i)
66+
67+
if !field.CanSet() {
68+
continue
69+
}
70+
71+
fieldPath := fmt.Sprintf("%s.%s", path, fieldType.Name)
72+
fillFakeData(
73+
t, debug, depth+1, maxDepth, field, fieldPath,
74+
)
75+
}
76+
77+
case reflect.Slice:
78+
if v.Type().Elem().Kind() == reflect.Uint8 {
79+
// Special case: []byte.
80+
b := make([]byte, randomLen())
81+
for i := range b {
82+
b[i] = byte(rand.Intn(256))
83+
}
84+
85+
v.SetBytes(b)
86+
log("[]byte: %s = %v", path, b)
87+
88+
return
89+
}
90+
91+
elemType := v.Type().Elem()
92+
length := randomLen()
93+
slice := reflect.MakeSlice(v.Type(), length, length)
94+
95+
for i := range length {
96+
elemPath := fmt.Sprintf("%s[%d]", path, i)
97+
98+
var elem reflect.Value
99+
if elemType.Kind() == reflect.Ptr {
100+
elem = reflect.New(elemType.Elem())
101+
102+
fillFakeData(
103+
t, debug, depth+1, maxDepth,
104+
elem.Elem(), elemPath,
105+
)
106+
} else {
107+
elem = reflect.New(elemType).Elem()
108+
109+
fillFakeData(
110+
t, debug, depth+1, maxDepth, elem,
111+
elemPath,
112+
)
113+
}
114+
115+
slice.Index(i).Set(elem)
116+
}
117+
118+
v.Set(slice)
119+
log("slice: %s (len=%d)", path, length)
120+
121+
case reflect.Array:
122+
for i := range v.Len() {
123+
fillFakeData(
124+
t, debug, depth+1, maxDepth, v.Index(i),
125+
fmt.Sprintf("%s[%d]", path, i),
126+
)
127+
}
128+
129+
log("array: %s (len=%d)", path, v.Len())
130+
131+
case reflect.Map:
132+
keyType := v.Type().Key()
133+
valType := v.Type().Elem()
134+
m := reflect.MakeMap(v.Type())
135+
length := randomLen()
136+
137+
for i := range length {
138+
key := reflect.New(keyType).Elem()
139+
140+
fillFakeData(
141+
t, debug, depth+1, maxDepth, key,
142+
fmt.Sprintf("%s[key%d]", path, i),
143+
)
144+
145+
val := reflect.New(valType).Elem()
146+
147+
fillFakeData(
148+
t, debug, depth+1, maxDepth, val,
149+
fmt.Sprintf("%s[val%d]", path, i),
150+
)
151+
152+
m.SetMapIndex(key, val)
153+
}
154+
155+
v.Set(m)
156+
log("map: %s (len=%d)", path, length)
157+
158+
default:
159+
assignDummyPrimitive(t, debug, indent, v, path)
160+
}
161+
}
162+
163+
// assignDummyPrimitive assigns dummy values to primitive type values.
164+
func assignDummyPrimitive(t *testing.T, debug bool, indent string,
165+
v reflect.Value, path string) {
166+
167+
log := func(format string, args ...any) {
168+
if debug {
169+
if t != nil {
170+
t.Logf(indent+format, args...)
171+
} else {
172+
fmt.Printf(indent+format+"\n", args...)
173+
}
174+
}
175+
}
176+
177+
switch v.Kind() {
178+
case reflect.String:
179+
s := randomString()
180+
v.SetString(s)
181+
log("string: %s = %q", path, s)
182+
183+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
184+
reflect.Int64:
185+
186+
i := rand.Int63n(1_000_000)
187+
v.SetInt(i)
188+
log("int: %s = %d", path, i)
189+
190+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
191+
reflect.Uint64:
192+
193+
u := uint64(rand.Intn(1_000_000))
194+
v.SetUint(u)
195+
log("uint: %s = %d", path, u)
196+
197+
case reflect.Bool:
198+
b := rand.Intn(2) == 0
199+
v.SetBool(b)
200+
log("bool: %s = %v", path, b)
201+
202+
case reflect.Float32, reflect.Float64:
203+
f := rand.Float64() * 1_000
204+
v.SetFloat(f)
205+
log("float: %s = %f", path, f)
206+
}
207+
}
208+
209+
func randomString() string {
210+
return fmt.Sprintf("val_%d", rand.Intn(100_000))
211+
}
212+
213+
func randomLen() int {
214+
return rand.Intn(3)
215+
}
216+
217+
// checkAliasing walks the fields and check for shared references.
218+
func checkAliasing(t *testing.T, debug, strict bool, f1, f2 reflect.Value,
219+
path string) {
220+
221+
t.Helper()
222+
223+
if !f1.IsValid() || !f2.IsValid() {
224+
return
225+
}
226+
227+
switch f1.Kind() {
228+
case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Func,
229+
reflect.Chan:
230+
231+
if f1.IsNil() || f2.IsNil() {
232+
return
233+
}
234+
235+
if f1.Pointer() == f2.Pointer() {
236+
msg := fmt.Sprintf("Aliasing detected at path: %s "+
237+
"(shared %s)", path, f1.Kind())
238+
239+
if strict {
240+
t.Fatalf(msg)
241+
}
242+
243+
if debug {
244+
t.Logf("WARNING %s", msg)
245+
}
246+
}
247+
248+
// Recurse into slice/map values
249+
switch f1.Kind() {
250+
case reflect.Slice:
251+
for i := 0; i < f1.Len() && i < f2.Len(); i++ {
252+
checkAliasing(
253+
t, debug, strict,
254+
f1.Index(i), f2.Index(i),
255+
fmt.Sprintf("%s[%d]", path, i),
256+
)
257+
}
258+
case reflect.Map:
259+
for _, key := range f1.MapKeys() {
260+
v1 := f1.MapIndex(key)
261+
v2 := f2.MapIndex(key)
262+
checkAliasing(
263+
t, debug, strict,
264+
v1, v2, fmt.Sprintf("%s[%v]", path,
265+
key.Interface()),
266+
)
267+
}
268+
}
269+
270+
case reflect.Struct:
271+
for i := range f1.NumField() {
272+
field := f1.Type().Field(i)
273+
274+
// Skip unexported fields.
275+
if !f1.Field(i).CanInterface() {
276+
continue
277+
}
278+
279+
childPath := fmt.Sprintf("%s.%s", path, field.Name)
280+
checkAliasing(
281+
t, debug, strict,
282+
f1.Field(i), f2.Field(i), childPath,
283+
)
284+
}
285+
}
286+
}
287+
288+
// CopyFnTest checks that the Copy method returns a value that:
289+
// 1) is deeply equal
290+
// 2) does not alias mutable fields (pointers, slices, maps)
291+
func CopyFnTest[T fn.Copyable[T]](t *testing.T, debug, strict bool,
292+
original T) {
293+
294+
originalVal := reflect.ValueOf(original)
295+
copied := original.Copy()
296+
copiedVal := reflect.ValueOf(copied)
297+
298+
if !reflect.DeepEqual(original, copied) {
299+
diff := difflib.UnifiedDiff{
300+
A: difflib.SplitLines(
301+
spew.Sdump(original),
302+
),
303+
B: difflib.SplitLines(
304+
spew.Sdump(copied),
305+
),
306+
FromFile: "Original",
307+
FromDate: "",
308+
ToFile: "Copied",
309+
ToDate: "",
310+
Context: 3,
311+
}
312+
diffText, _ := difflib.GetUnifiedDiffString(diff)
313+
314+
t.Fatalf("Copied value is not deeply equal to the orginal:\n%v",
315+
diffText)
316+
}
317+
318+
if originalVal.Kind() == reflect.Ptr {
319+
originalVal = originalVal.Elem()
320+
copiedVal = copiedVal.Elem()
321+
}
322+
323+
for i := range originalVal.NumField() {
324+
f1 := originalVal.Field(i)
325+
f2 := copiedVal.Field(i)
326+
name := originalVal.Type().Field(i).Name
327+
328+
checkAliasing(t, debug, strict, f1, f2, name)
329+
}
330+
}
331+
332+
// TestCopy tests known Copy() functions.
333+
func TestCopy(t *testing.T) {
334+
// Set to true to debug print.
335+
debug := false
336+
337+
// Please set depth values carefully. Sometimes our copy functions are
338+
// deeply nested in other packages and do not need changes. Sometimes
339+
// types are recursive and too deep copy may end up in stack-overlow.
340+
t.Run("tapfreighter.OutboundParcel", func(t *testing.T) {
341+
const maxDepth = 5
342+
p := &tapfreighter.OutboundParcel{}
343+
FillFakeData(t, debug, maxDepth, p)
344+
345+
// We allow aliasing here deep down (for now).
346+
strict := false
347+
CopyFnTest(t, debug, strict, p)
348+
})
349+
}

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ require (
3737
github.com/lightningnetwork/lnd/tlv v1.3.1
3838
github.com/lightningnetwork/lnd/tor v1.1.6
3939
github.com/ory/dockertest/v3 v3.10.0
40+
github.com/pmezard/go-difflib v1.0.0
4041
github.com/prometheus/client_golang v1.14.0
4142
github.com/stretchr/testify v1.10.0
4243
github.com/urfave/cli v1.22.14
@@ -148,7 +149,6 @@ require (
148149
github.com/opencontainers/image-spec v1.1.0 // indirect
149150
github.com/opencontainers/runc v1.2.0 // indirect
150151
github.com/pkg/errors v0.9.1 // indirect
151-
github.com/pmezard/go-difflib v1.0.0 // indirect
152152
github.com/prometheus/client_model v0.3.0 // indirect
153153
github.com/prometheus/common v0.37.0 // indirect
154154
github.com/prometheus/procfs v0.8.0 // indirect

0 commit comments

Comments
 (0)