|
1 | 1 | package flattener
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "bytes" |
4 | 5 | "fmt"
|
5 | 6 | "io"
|
6 | 7 |
|
7 | 8 | "github.com/onflow/flow-go/ledger"
|
8 | 9 | "github.com/onflow/flow-go/ledger/common/encoding"
|
| 10 | + "github.com/onflow/flow-go/ledger/common/hash" |
9 | 11 | "github.com/onflow/flow-go/ledger/common/utils"
|
10 | 12 | "github.com/onflow/flow-go/ledger/complete/mtrie/node"
|
11 | 13 | "github.com/onflow/flow-go/ledger/complete/mtrie/trie"
|
@@ -59,82 +61,127 @@ func EncodeNode(n *node.Node, lchildIndex uint64, rchildIndex uint64) []byte {
|
59 | 61 | return buf
|
60 | 62 | }
|
61 | 63 |
|
62 |
| -// ReadStorableNode reads a storable node from io |
63 |
| -func ReadStorableNode(reader io.Reader) (*StorableNode, error) { |
| 64 | +// ReadNode reconstructs a node from data read from reader. |
| 65 | +// TODO: reuse read buffer |
| 66 | +func ReadNode(reader io.Reader, getNode func(nodeIndex uint64) (*node.Node, error)) (*node.Node, error) { |
64 | 67 |
|
65 | 68 | // reading version
|
66 | 69 | buf := make([]byte, 2)
|
67 | 70 | read, err := io.ReadFull(reader, buf)
|
68 | 71 | if err != nil {
|
69 |
| - return nil, fmt.Errorf("error reading storable node, cannot read version part: %w", err) |
| 72 | + return nil, fmt.Errorf("failed to read serialized node, cannot read version part: %w", err) |
70 | 73 | }
|
71 | 74 | if read != len(buf) {
|
72 |
| - return nil, fmt.Errorf("not enough bytes read %d expected %d", read, len(buf)) |
| 75 | + return nil, fmt.Errorf("failed to read serialized node: not enough bytes read %d expected %d", read, len(buf)) |
73 | 76 | }
|
74 | 77 |
|
75 | 78 | version, _, err := utils.ReadUint16(buf)
|
76 | 79 | if err != nil {
|
77 |
| - return nil, fmt.Errorf("error reading storable node: %w", err) |
| 80 | + return nil, fmt.Errorf("failed to read serialized node: %w", err) |
78 | 81 | }
|
79 | 82 |
|
80 | 83 | if version > encodingDecodingVersion {
|
81 |
| - return nil, fmt.Errorf("error reading storable node: unsuported version %d > %d", version, encodingDecodingVersion) |
| 84 | + return nil, fmt.Errorf("failed to read serialized node: unsuported version %d > %d", version, encodingDecodingVersion) |
82 | 85 | }
|
83 | 86 |
|
84 | 87 | // reading fixed-length part
|
85 | 88 | buf = make([]byte, 2+8+8+2+8)
|
86 | 89 |
|
87 | 90 | read, err = io.ReadFull(reader, buf)
|
88 | 91 | if err != nil {
|
89 |
| - return nil, fmt.Errorf("error reading storable node, cannot read fixed-length part: %w", err) |
| 92 | + return nil, fmt.Errorf("failed to read serialized node, cannot read fixed-length part: %w", err) |
90 | 93 | }
|
91 | 94 | if read != len(buf) {
|
92 |
| - return nil, fmt.Errorf("not enough bytes read %d expected %d", read, len(buf)) |
| 95 | + return nil, fmt.Errorf("failed to read serialized node: not enough bytes read %d expected %d", read, len(buf)) |
93 | 96 | }
|
94 | 97 |
|
95 |
| - storableNode := &StorableNode{} |
| 98 | + var height, maxDepth uint16 |
| 99 | + var lchildIndex, rchildIndex, regCount uint64 |
| 100 | + var path, hashValue, encPayload []byte |
96 | 101 |
|
97 |
| - storableNode.Height, buf, err = utils.ReadUint16(buf) |
| 102 | + height, buf, err = utils.ReadUint16(buf) |
98 | 103 | if err != nil {
|
99 |
| - return nil, fmt.Errorf("error reading storable node: %w", err) |
| 104 | + return nil, fmt.Errorf("failed to read serialized node: %w", err) |
100 | 105 | }
|
101 | 106 |
|
102 |
| - storableNode.LIndex, buf, err = utils.ReadUint64(buf) |
| 107 | + lchildIndex, buf, err = utils.ReadUint64(buf) |
103 | 108 | if err != nil {
|
104 |
| - return nil, fmt.Errorf("error reading storable node: %w", err) |
| 109 | + return nil, fmt.Errorf("failed to read serialized node: %w", err) |
105 | 110 | }
|
106 | 111 |
|
107 |
| - storableNode.RIndex, buf, err = utils.ReadUint64(buf) |
| 112 | + rchildIndex, buf, err = utils.ReadUint64(buf) |
108 | 113 | if err != nil {
|
109 |
| - return nil, fmt.Errorf("error reading storable node: %w", err) |
| 114 | + return nil, fmt.Errorf("failed to read serialized node: %w", err) |
110 | 115 | }
|
111 | 116 |
|
112 |
| - storableNode.MaxDepth, buf, err = utils.ReadUint16(buf) |
| 117 | + maxDepth, buf, err = utils.ReadUint16(buf) |
113 | 118 | if err != nil {
|
114 |
| - return nil, fmt.Errorf("error reading storable node: %w", err) |
| 119 | + return nil, fmt.Errorf("failed to read serialized node: %w", err) |
115 | 120 | }
|
116 | 121 |
|
117 |
| - storableNode.RegCount, _, err = utils.ReadUint64(buf) |
| 122 | + regCount, _, err = utils.ReadUint64(buf) |
118 | 123 | if err != nil {
|
119 |
| - return nil, fmt.Errorf("error reading storable node: %w", err) |
| 124 | + return nil, fmt.Errorf("failed to read serialized node: %w", err) |
120 | 125 | }
|
121 | 126 |
|
122 |
| - storableNode.Path, err = utils.ReadShortDataFromReader(reader) |
| 127 | + path, err = utils.ReadShortDataFromReader(reader) |
123 | 128 | if err != nil {
|
124 | 129 | return nil, fmt.Errorf("cannot read key data: %w", err)
|
125 | 130 | }
|
126 | 131 |
|
127 |
| - storableNode.EncPayload, err = utils.ReadLongDataFromReader(reader) |
| 132 | + encPayload, err = utils.ReadLongDataFromReader(reader) |
128 | 133 | if err != nil {
|
129 | 134 | return nil, fmt.Errorf("cannot read value data: %w", err)
|
130 | 135 | }
|
131 | 136 |
|
132 |
| - storableNode.HashValue, err = utils.ReadShortDataFromReader(reader) |
| 137 | + hashValue, err = utils.ReadShortDataFromReader(reader) |
133 | 138 | if err != nil {
|
134 | 139 | return nil, fmt.Errorf("cannot read hashValue data: %w", err)
|
135 | 140 | }
|
136 | 141 |
|
137 |
| - return storableNode, nil |
| 142 | + // Create (and copy) hash from raw data. |
| 143 | + nodeHash, err := hash.ToHash(hashValue) |
| 144 | + if err != nil { |
| 145 | + return nil, fmt.Errorf("failed to decode hash from checkpoint: %w", err) |
| 146 | + } |
| 147 | + |
| 148 | + if len(path) > 0 { |
| 149 | + // Create (and copy) path from raw data. |
| 150 | + path, err := ledger.ToPath(path) |
| 151 | + if err != nil { |
| 152 | + return nil, fmt.Errorf("failed to decode path from checkpoint: %w", err) |
| 153 | + } |
| 154 | + |
| 155 | + // Decode payload (payload data isn't copied). |
| 156 | + payload, err := encoding.DecodePayload(encPayload) |
| 157 | + if err != nil { |
| 158 | + return nil, fmt.Errorf("failed to decode payload from checkpoint: %w", err) |
| 159 | + } |
| 160 | + |
| 161 | + // make a copy of payload |
| 162 | + var pl *ledger.Payload |
| 163 | + if payload != nil { |
| 164 | + pl = payload.DeepCopy() |
| 165 | + } |
| 166 | + |
| 167 | + n := node.NewNode(int(height), nil, nil, path, pl, nodeHash, maxDepth, regCount) |
| 168 | + return n, nil |
| 169 | + } |
| 170 | + |
| 171 | + // Get left child node by node index |
| 172 | + lchild, err := getNode(lchildIndex) |
| 173 | + if err != nil { |
| 174 | + return nil, fmt.Errorf("failed to find left child node: %w", err) |
| 175 | + } |
| 176 | + |
| 177 | + // Get right child node by node index |
| 178 | + rchild, err := getNode(rchildIndex) |
| 179 | + if err != nil { |
| 180 | + return nil, fmt.Errorf("failed to find right child node: %w", err) |
| 181 | + } |
| 182 | + |
| 183 | + n := node.NewNode(int(height), lchild, rchild, ledger.DummyPath, nil, nodeHash, maxDepth, regCount) |
| 184 | + return n, nil |
138 | 185 | }
|
139 | 186 |
|
140 | 187 | // EncodeTrie encodes trie root node
|
@@ -162,9 +209,8 @@ func EncodeTrie(rootNode *node.Node, rootIndex uint64) []byte {
|
162 | 209 | return buf
|
163 | 210 | }
|
164 | 211 |
|
165 |
| -// ReadStorableTrie reads a storable trie from io |
166 |
| -func ReadStorableTrie(reader io.Reader) (*StorableTrie, error) { |
167 |
| - storableTrie := &StorableTrie{} |
| 212 | +// ReadTrie reconstructs a trie from data read from reader. |
| 213 | +func ReadTrie(reader io.Reader, getNode func(nodeIndex uint64) (*node.Node, error)) (*trie.MTrie, error) { |
168 | 214 |
|
169 | 215 | // reading version
|
170 | 216 | buf := make([]byte, 2)
|
@@ -199,13 +245,26 @@ func ReadStorableTrie(reader io.Reader) (*StorableTrie, error) {
|
199 | 245 | if err != nil {
|
200 | 246 | return nil, fmt.Errorf("cannot read root index data: %w", err)
|
201 | 247 | }
|
202 |
| - storableTrie.RootIndex = rootIndex |
203 | 248 |
|
204 |
| - roothash, err := utils.ReadShortDataFromReader(reader) |
| 249 | + readRootHash, err := utils.ReadShortDataFromReader(reader) |
205 | 250 | if err != nil {
|
206 | 251 | return nil, fmt.Errorf("cannot read roothash data: %w", err)
|
207 | 252 | }
|
208 |
| - storableTrie.RootHash = roothash |
209 | 253 |
|
210 |
| - return storableTrie, nil |
| 254 | + rootNode, err := getNode(rootIndex) |
| 255 | + if err != nil { |
| 256 | + return nil, fmt.Errorf("cannot find root node: %w", err) |
| 257 | + } |
| 258 | + |
| 259 | + mtrie, err := trie.NewMTrie(rootNode) |
| 260 | + if err != nil { |
| 261 | + return nil, fmt.Errorf("restoring trie failed: %w", err) |
| 262 | + } |
| 263 | + |
| 264 | + rootHash := mtrie.RootHash() |
| 265 | + if !bytes.Equal(readRootHash, rootHash[:]) { |
| 266 | + return nil, fmt.Errorf("restoring trie failed: roothash doesn't match") |
| 267 | + } |
| 268 | + |
| 269 | + return mtrie, nil |
211 | 270 | }
|
0 commit comments