Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit adaaf94

Browse files
committed
Port VGG16 to use new layer init API
1 parent a08917a commit adaaf94

File tree

5 files changed

+117
-30
lines changed

5 files changed

+117
-30
lines changed

Examples/VGG-Imagewoof/main.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import TensorFlow
1919
let batchSize = 32
2020

2121
let dataset = Imagewoof(batchSize: batchSize, inputSize: .full, outputSize: 224)
22-
var model = VGG16(classCount: 10)
22+
var model = makeVGG16(classCount: 10).buildModel(inputShape: (224, 224, 3))
2323
let optimizer = SGD(for: model, learningRate: 0.02, momentum: 0.9, decay: 0.0005)
2424

2525
print("Starting training...")

Models/ImageClassification/VGG.swift

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,50 @@
1313
// limitations under the License.
1414

1515
import TensorFlow
16+
import LayerInit
1617

1718
// Original Paper:
1819
// "Very Deep Convolutional Networks for Large-Scale Image Recognition"
1920
// Karen Simonyan, Andrew Zisserman
2021
// https://arxiv.org/abs/1409.1556
2122

23+
public typealias AutoVGGBlock = AutoSequencedDefinition<AutoSequencedMany<AutoConv2D<Float>>, AutoMaxPool2D<Float>>
24+
/// Builds the definition of one VGG stage: `blockCount` 3x3, same-padded, ReLU-activated
/// convolutions followed by a 2x2 max pool with stride 2.
///
/// - Parameters:
///   - featureCounts: Channel counts, in the same four-element tuple form that `VGGBlock`
///     takes. Only `.1` (output channels of the first convolution) and `.3` (output channels
///     of every following convolution) are read here; `.0` and `.2` are not used.
///   - blockCount: Number of convolutions in the stage. Must be at least 1.
/// - Returns: The sequenced convolutions followed by the max-pool definition.
func makeVGGBlock(featureCounts: (Int, Int, Int, Int), blockCount: Int) -> AutoVGGBlock {
    // Without this check, `1..<blockCount` traps with an unrelated range-bounds message
    // whenever blockCount is below 1.
    precondition(blockCount >= 1, "A VGG block needs at least one convolution, got \(blockCount).")

    let firstConv = AutoConv2D<Float>(
        filterShape: (3, 3), outputChannels: featureCounts.1,
        padding: .same,
        activation: relu)
    let followingConvs = (1..<blockCount).map { _ in
        AutoConv2D<Float>(
            filterShape: (3, 3), outputChannels: featureCounts.3,
            padding: .same,
            activation: relu)
    }

    return AutoSequencedMany(layers: [firstConv] + followingConvs)
        .then(AutoMaxPool2D(poolSize: (2, 2), strides: (2, 2)))
}
38+
39+
// TODO(shadaj): these deeply nested generic typealiases are unwieldy; find a way to avoid spelling them out.
40+
public typealias AutoVGG16Backbone = AutoSequencedDefinition<AutoSequencedDefinition<AutoSequencedDefinition<AutoSequencedDefinition<AutoVGGBlock, AutoVGGBlock>, AutoVGGBlock>, AutoVGGBlock>, AutoVGGBlock>
41+
public typealias AutoVGG16 = AutoSequencedDefinition<AutoSequencedDefinition<AutoSequencedDefinition<AutoSequencedDefinition<AutoVGG16Backbone, AutoFlatten<Float>>, AutoDense<Float>>, AutoDense<Float>>, AutoDense<Float>>
42+
43+
/// Assembles the VGG16 definition: five convolutional stages followed by a flatten step
/// and three dense layers.
///
/// - Parameter classCount: Output size of the final dense layer. Defaults to 1000.
/// - Returns: The un-built model definition; call `buildModel(inputShape:)` on it to obtain
///   a concrete model.
public func makeVGG16(classCount: Int = 1000) -> AutoVGG16 {
    // Convolutional stages, chained left to right so the nesting matches `AutoVGG16Backbone`.
    let convStages: AutoVGG16Backbone =
        makeVGGBlock(featureCounts: (3, 64, 64, 64), blockCount: 2)
        .then(makeVGGBlock(featureCounts: (64, 128, 128, 128), blockCount: 2))
        .then(makeVGGBlock(featureCounts: (128, 256, 256, 256), blockCount: 3))
        .then(makeVGGBlock(featureCounts: (256, 512, 512, 512), blockCount: 3))
        .then(makeVGGBlock(featureCounts: (512, 512, 512, 512), blockCount: 3))

    // Classifier head: flatten, two 4096-wide ReLU layers, then the class scores.
    return convStages
        .then(AutoFlatten<Float>())
        .then(AutoDense<Float>(outputSize: 4096, activation: relu))
        .then(AutoDense<Float>(outputSize: 4096, activation: relu))
        .then(AutoDense<Float>(outputSize: classCount))
}
59+
2260
public struct VGGBlock: Layer {
2361
var blocks: [Conv2D<Float>] = []
2462
var maxpool = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
@@ -40,34 +78,6 @@ public struct VGGBlock: Layer {
4078
}
4179
}
4280

43-
public struct VGG16: Layer {
44-
var layer1: VGGBlock
45-
var layer2: VGGBlock
46-
var layer3: VGGBlock
47-
var layer4: VGGBlock
48-
var layer5: VGGBlock
49-
50-
var flatten = Flatten<Float>()
51-
var dense1 = Dense<Float>(inputSize: 512 * 7 * 7, outputSize: 4096, activation: relu)
52-
var dense2 = Dense<Float>(inputSize: 4096, outputSize: 4096, activation: relu)
53-
var output: Dense<Float>
54-
55-
public init(classCount: Int = 1000) {
56-
layer1 = VGGBlock(featureCounts: (3, 64, 64, 64), blockCount: 2)
57-
layer2 = VGGBlock(featureCounts: (64, 128, 128, 128), blockCount: 2)
58-
layer3 = VGGBlock(featureCounts: (128, 256, 256, 256), blockCount: 3)
59-
layer4 = VGGBlock(featureCounts: (256, 512, 512, 512), blockCount: 3)
60-
layer5 = VGGBlock(featureCounts: (512, 512, 512, 512), blockCount: 3)
61-
output = Dense(inputSize: 4096, outputSize: classCount)
62-
}
63-
64-
@differentiable
65-
public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
66-
let backbone = input.sequenced(through: layer1, layer2, layer3, layer4, layer5)
67-
return backbone.sequenced(through: flatten, dense1, dense2, output)
68-
}
69-
}
70-
7181
public struct VGG19: Layer {
7282
var layer1: VGGBlock
7383
var layer2: VGGBlock

Models/LayerInit/AutoPool.swift

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,46 @@ public struct AutoAvgPool2D<Scalar>: AutoLayer where Scalar: TensorFlowFloatingP
4242
), outputShape)
4343
}
4444
}
45+
46+
/// Shape-tracking definition of a `MaxPool2D` layer.
///
/// Holds the pooling hyperparameters and, given an input shape, produces both the concrete
/// `MaxPool2D` instance and the shape that instance will output.
public struct AutoMaxPool2D<Scalar>: AutoLayer where Scalar: TensorFlowFloatingPoint {
    let poolSize: (Int, Int)
    let strides: (Int, Int)
    let padding: Padding

    public typealias InstanceType = MaxPool2D<Scalar>
    public typealias InputShape = (Int, Int, Int)
    public typealias OutputShape = (Int, Int, Int)

    /// Creates a max-pool definition.
    ///
    /// - Parameters:
    ///   - poolSize: Size of the pooling window along the first two shape components.
    ///   - strides: Step of the window along the first two shape components. Defaults to `(1, 1)`.
    ///   - padding: Padding mode handed to `MaxPool2D`. Defaults to `.valid`.
    public init(
        poolSize: (Int, Int),
        strides: (Int, Int) = (1, 1),
        padding: Padding = .valid
    ) {
        self.poolSize = poolSize
        self.strides = strides
        self.padding = padding
    }

    /// Builds the `MaxPool2D` layer and computes its output shape.
    ///
    /// The first two components of `inputShape` are pooled; the third is passed through
    /// unchanged.
    ///
    /// - Parameter inputShape: Shape of the tensor the layer will receive.
    /// - Returns: The layer instance paired with the shape it produces.
    public func buildModelWithOutputShape(inputShape: (Int, Int, Int)) -> (InstanceType, (Int, Int, Int)) {
        let usesValidPadding = padding == .valid

        // With `.valid` padding only full windows count, so the usable span shrinks by
        // `window - 1`; otherwise the whole input extent is covered.
        func pooledExtent(input: Int, window: Int, stride: Int) -> Int {
            let span = usesValidPadding ? input - window + 1 : input
            return Int(ceil(Float(span) / Float(stride)))
        }

        let pooledShape = (
            pooledExtent(input: inputShape.0, window: poolSize.0, stride: strides.0),
            pooledExtent(input: inputShape.1, window: poolSize.1, stride: strides.1),
            inputShape.2
        )
        let layer = MaxPool2D<Scalar>(poolSize: poolSize, strides: strides, padding: padding)

        return (layer, pooledShape)
    }
}

Models/LayerInit/AutoSequenced.swift

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,37 @@ extension AutoLayer {
2727
return AutoSequencedDefinition<Self, T>(first: self, second: other)
2828
}
2929
}
30+
31+
/// A layer that applies an array of same-typed layers one after another.
///
/// Each element's output is fed to the next element, which is why the element type must have
/// matching `Input` and `Output` types.
public struct AutoSequencedManyInstance<LayerType: Layer>: Layer
where LayerType.Input == LayerType.Output {
    var layers: [LayerType]

    /// Passes `input` through every layer in `layers`, in array order.
    ///
    /// - Parameter input: Value handed to the first layer.
    /// - Returns: The output of the last layer, or `input` itself when `layers` is empty.
    @differentiable
    public func callAsFunction(_ input: LayerType.Input) -> LayerType.Output {
        return layers.differentiableReduce(input) { activation, layer in layer(activation) }
    }
}
40+
41+
/// Definition of a homogeneous stack of layers that are applied in sequence.
///
/// Every element shares one `AutoLayer` type whose input and output shapes (and instance
/// input and output types) coincide, so the output of one can feed the next.
public struct AutoSequencedMany<LayerType: AutoLayer>: AutoLayer
where
    LayerType.OutputShape == LayerType.InputShape,
    LayerType.InstanceType.Input == LayerType.InstanceType.Output {
    let layers: [LayerType]

    public typealias InstanceType = AutoSequencedManyInstance<LayerType.InstanceType>

    /// Creates a definition from the given layer definitions, kept in the given order.
    public init(layers: [LayerType]) {
        self.layers = layers
    }

    /// Builds every layer in order, threading each layer's output shape into the next.
    ///
    /// - Parameter inputShape: Shape fed to the first layer definition.
    /// - Returns: The built stack and the output shape of the last layer (equal to
    ///   `inputShape` when there are no layers).
    public func buildModelWithOutputShape(inputShape: LayerType.InputShape) -> (InstanceType, LayerType.OutputShape) {
        var currentShape = inputShape
        var instances: [LayerType.InstanceType] = []
        instances.reserveCapacity(layers.count)

        for definition in layers {
            let (instance, nextShape) = definition.buildModelWithOutputShape(inputShape: currentShape)
            instances.append(instance)
            currentShape = nextShape
        }

        return (AutoSequencedManyInstance(layers: instances), currentShape)
    }
}

Package.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ let package = Package(
3030
.target(
3131
name: "ModelSupport", dependencies: ["SwiftProtobuf", "STBImage"], path: "Support",
3232
exclude: ["STBImage"]),
33-
.target(name: "ImageClassificationModels", path: "Models/ImageClassification"),
33+
.target(name: "ImageClassificationModels", dependencies: ["LayerInit"], path: "Models/ImageClassification"),
3434
.target(name: "VideoClassificationModels", path: "Models/Spatiotemporal"),
3535
.target(name: "TextModels", dependencies: ["Datasets"], path: "Models/Text"),
3636
.target(name: "RecommendationModels", path: "Models/Recommendation"),

0 commit comments

Comments
 (0)