13
13
// limitations under the License.
14
14
15
15
import TensorFlow
16
+ import LayerInit
16
17
17
18
// Original Paper:
18
19
// "Very Deep Convolutional Networks for Large-Scale Image Recognition"
19
20
// Karen Simonyan, Andrew Zisserman
20
21
// https://arxiv.org/abs/1409.1556
21
22
23
+ public typealias AutoVGGBlock = AutoSequencedDefinition < AutoSequencedMany < AutoConv2D < Float > > , AutoMaxPool2D < Float > >
24
+ func makeVGGBlock( featureCounts: ( Int , Int , Int , Int ) , blockCount: Int ) -> AutoVGGBlock {
25
+ var blocks : [ AutoConv2D < Float > ] = [
26
+ AutoConv2D < Float > ( filterShape: ( 3 , 3 ) , outputChannels: featureCounts. 1 ,
27
+ padding: . same,
28
+ activation: relu) ]
29
+ for _ in 1 ..< blockCount {
30
+ blocks += [ AutoConv2D ( filterShape: ( 3 , 3 ) , outputChannels: featureCounts. 3 ,
31
+ padding: . same,
32
+ activation: relu) ]
33
+ }
34
+
35
+ return AutoSequencedMany ( layers: blocks)
36
+ . then ( AutoMaxPool2D ( poolSize: ( 2 , 2 ) , strides: ( 2 , 2 ) ) )
37
+ }
38
+
39
+ // TODO(shadaj): oh no
40
+ public typealias AutoVGG16Backbone = AutoSequencedDefinition < AutoSequencedDefinition < AutoSequencedDefinition < AutoSequencedDefinition < AutoVGGBlock , AutoVGGBlock > , AutoVGGBlock > , AutoVGGBlock > , AutoVGGBlock >
41
+ public typealias AutoVGG16 = AutoSequencedDefinition < AutoSequencedDefinition < AutoSequencedDefinition < AutoSequencedDefinition < AutoVGG16Backbone , AutoFlatten < Float > > , AutoDense < Float > > , AutoDense < Float > > , AutoDense < Float > >
42
+
43
+ public func makeVGG16( classCount: Int = 1000 ) -> AutoVGG16 {
44
+ let layer1 = makeVGGBlock ( featureCounts: ( 3 , 64 , 64 , 64 ) , blockCount: 2 )
45
+ let layer2 = makeVGGBlock ( featureCounts: ( 64 , 128 , 128 , 128 ) , blockCount: 2 )
46
+ let layer3 = makeVGGBlock ( featureCounts: ( 128 , 256 , 256 , 256 ) , blockCount: 3 )
47
+ let layer4 = makeVGGBlock ( featureCounts: ( 256 , 512 , 512 , 512 ) , blockCount: 3 )
48
+ let layer5 = makeVGGBlock ( featureCounts: ( 512 , 512 , 512 , 512 ) , blockCount: 3 )
49
+
50
+ let flatten = AutoFlatten < Float > ( )
51
+ let dense1 = AutoDense < Float > ( outputSize: 4096 , activation: relu)
52
+ let dense2 = AutoDense < Float > ( outputSize: 4096 , activation: relu)
53
+ let output = AutoDense < Float > ( outputSize: classCount)
54
+
55
+ let backbone = layer1. then ( layer2) . then ( layer3) . then ( layer4) . then ( layer5)
56
+ let fullModel = backbone. then ( flatten) . then ( dense1) . then ( dense2) . then ( output)
57
+ return fullModel
58
+ }
59
+
22
60
public struct VGGBlock : Layer {
23
61
var blocks : [ Conv2D < Float > ] = [ ]
24
62
var maxpool = MaxPool2D < Float > ( poolSize: ( 2 , 2 ) , strides: ( 2 , 2 ) )
@@ -40,34 +78,6 @@ public struct VGGBlock: Layer {
40
78
}
41
79
}
42
80
43
- public struct VGG16 : Layer {
44
- var layer1 : VGGBlock
45
- var layer2 : VGGBlock
46
- var layer3 : VGGBlock
47
- var layer4 : VGGBlock
48
- var layer5 : VGGBlock
49
-
50
- var flatten = Flatten < Float > ( )
51
- var dense1 = Dense < Float > ( inputSize: 512 * 7 * 7 , outputSize: 4096 , activation: relu)
52
- var dense2 = Dense < Float > ( inputSize: 4096 , outputSize: 4096 , activation: relu)
53
- var output : Dense < Float >
54
-
55
- public init ( classCount: Int = 1000 ) {
56
- layer1 = VGGBlock ( featureCounts: ( 3 , 64 , 64 , 64 ) , blockCount: 2 )
57
- layer2 = VGGBlock ( featureCounts: ( 64 , 128 , 128 , 128 ) , blockCount: 2 )
58
- layer3 = VGGBlock ( featureCounts: ( 128 , 256 , 256 , 256 ) , blockCount: 3 )
59
- layer4 = VGGBlock ( featureCounts: ( 256 , 512 , 512 , 512 ) , blockCount: 3 )
60
- layer5 = VGGBlock ( featureCounts: ( 512 , 512 , 512 , 512 ) , blockCount: 3 )
61
- output = Dense ( inputSize: 4096 , outputSize: classCount)
62
- }
63
-
64
- @differentiable
65
- public func callAsFunction( _ input: Tensor < Float > ) -> Tensor < Float > {
66
- let backbone = input. sequenced ( through: layer1, layer2, layer3, layer4, layer5)
67
- return backbone. sequenced ( through: flatten, dense1, dense2, output)
68
- }
69
- }
70
-
71
81
public struct VGG19 : Layer {
72
82
var layer1 : VGGBlock
73
83
var layer2 : VGGBlock
0 commit comments