-
Notifications
You must be signed in to change notification settings - Fork 6
/
export_network.go
94 lines (87 loc) · 2.54 KB
/
export_network.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
package cnns
import (
"encoding/json"
"fmt"
"io/ioutil"
"github.com/pkg/errors"
)
// ExportToFile saves the network structure and, optionally, its weights to a
// JSON file named fname.
//
// Each layer in wh.Layers is serialized as a NetLayerJSON entry chosen by its
// GetType() value ("conv", "relu", "pool" or "fc"); any other type aborts the
// export with an error. Weight data for "conv" and "fc" layers is copied only
// when saveWeights is true. The learning rate and momentum from wh.LP are
// stored alongside the network. The file is created with mode 0644 (before
// umask) or truncated if it already exists.
//
// An error is returned when a layer's concrete type does not match the type it
// reports, when a fully connected layer does not expose exactly one weight
// matrix, when marshalling fails, or when the file cannot be written.
func (wh *WholeNet) ExportToFile(fname string, saveWeights bool) error {
	save := NetJSON{
		Network:    &NetworkJSON{},
		Parameters: &LearningParams{},
	}
	for _, l := range wh.Layers {
		switch l.GetType() {
		case "conv":
			// Checked assertion: a mismatched layer is reported instead of panicking.
			conv, ok := l.(*ConvLayer)
			if !ok {
				return fmt.Errorf("layer reporting type 'conv' is not a *ConvLayer: %T", l)
			}
			kernels := l.GetWeights()
			newLayer := &NetLayerJSON{
				LayerType: "conv",
				InputSize: l.GetInputSize(),
				Parameters: &LayerParamsJSON{
					Stride:     l.GetStride(),
					KernelSize: conv.KernelSize,
				},
				// NOTE(review): allocated even when saveWeights is false, so it
				// marshals as a list of nulls; presumably the importer relies on
				// this length — confirm before changing.
				Weights: make([]*NestedData, len(kernels)),
			}
			if saveWeights {
				for k := range kernels {
					newLayer.Weights[k] = &NestedData{Data: kernels[k].RawMatrix().Data}
				}
			}
			save.Network.Layers = append(save.Network.Layers, newLayer)
		case "relu":
			newLayer := &NetLayerJSON{
				LayerType: "relu",
				InputSize: l.GetInputSize(),
			}
			save.Network.Layers = append(save.Network.Layers, newLayer)
		case "pool":
			// Checked assertion: a mismatched layer is reported instead of panicking.
			pool, ok := l.(*PoolingLayer)
			if !ok {
				return fmt.Errorf("layer reporting type 'pool' is not a *PoolingLayer: %T", l)
			}
			newLayer := &NetLayerJSON{
				LayerType: "pool",
				InputSize: l.GetInputSize(),
				Parameters: &LayerParamsJSON{
					Stride:          l.GetStride(),
					KernelSize:      pool.ExtendFilter,
					PoolingType:     pool.PoolingType.String(),
					ZeroPaddingType: pool.ZeroPadding.String(),
				},
			}
			save.Network.Layers = append(save.Network.Layers, newLayer)
		case "fc":
			newLayer := &NetLayerJSON{
				LayerType:  "fc",
				InputSize:  l.GetInputSize(),
				OutputSize: l.GetOutputSize(),
				// NOTE(review): single entry left nil when saveWeights is false,
				// mirroring the conv case above.
				Weights: make([]*NestedData, 1),
			}
			if saveWeights {
				weights := l.GetWeights()
				if len(weights) != 1 {
					return fmt.Errorf("Fully connected layer can have only 1 array for weights")
				}
				newLayer.Weights[0] = &NestedData{Data: weights[0].RawMatrix().Data}
			}
			save.Network.Layers = append(save.Network.Layers, newLayer)
		default:
			return fmt.Errorf("Unrecognized layer type: %v", l.GetType())
		}
	}
	save.Parameters.LearningRate = wh.LP.LearningRate
	save.Parameters.Momentum = wh.LP.Momentum

	saveJSON, err := json.Marshal(save)
	if err != nil {
		return errors.Wrap(err, "Can't marshal network to JSON")
	}
	if err = ioutil.WriteFile(fname, saveJSON, 0644); err != nil {
		return errors.Wrapf(err, "Can't write data to file '%s'", fname)
	}
	return nil
}