Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 741bdc2

Browse files
committed
Add support for models that reuse layers
1 parent 97a4eac commit 741bdc2

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

Models/LayerInit/AutoReuseLayer.swift

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import TensorFlow
2+
3+
/// A layer that runs its input through `outerLayer`, then `middleLayer`, and then
/// through the same `outerLayer` a second time.
///
/// The generic constraints require the outer layer's output type to be the middle
/// layer's input type and the middle layer's output type to be the outer layer's
/// input type, so the outer layer can be applied on both sides of the middle layer.
public struct AutoReuseLayerInstance<OuterLayer: Layer, MiddleLayer: Layer>: Layer
where
  OuterLayer.Output == MiddleLayer.Input,
  OuterLayer.Input == MiddleLayer.Output,
  OuterLayer.TangentVector.VectorSpaceScalar == MiddleLayer.TangentVector.VectorSpaceScalar
{
  /// The layer applied both first and last in the forward pass.
  public var outerLayer: OuterLayer

  /// The layer applied between the two applications of `outerLayer`.
  public var middleLayer: MiddleLayer

  /// Applies `outerLayer`, then `middleLayer`, then `outerLayer` again.
  ///
  /// - Parameter input: The value fed to the first application of `outerLayer`.
  /// - Returns: The result of the second application of `outerLayer`.
  @differentiable
  public func callAsFunction(_ input: OuterLayer.Input) -> OuterLayer.Output {
    let firstOuterResult = outerLayer(input)
    let middleResult = middleLayer(firstOuterResult)
    return outerLayer(middleResult)
  }
}
15+
16+
/// An `AutoLayer` description that builds an `AutoReuseLayerInstance`: one instance of
/// the outer layer wrapped around one instance of the middle layer.
///
/// The generic constraints require the outer and middle descriptions to have
/// mutually compatible input/output shape types and instance input/output types.
public struct AutoReuseLayer<OuterLayer: AutoLayer, MiddleLayer: AutoLayer>: AutoLayer
where
  OuterLayer.OutputShape == MiddleLayer.InputShape,
  MiddleLayer.OutputShape == OuterLayer.InputShape,
  OuterLayer.InstanceType.Output == MiddleLayer.InstanceType.Input,
  MiddleLayer.InstanceType.Output == OuterLayer.InstanceType.Input,
  OuterLayer.InstanceType.TangentVector.VectorSpaceScalar
    == MiddleLayer.InstanceType.TangentVector.VectorSpaceScalar
{
  /// Description of the layer that is built once and used twice.
  let outer: OuterLayer

  /// Description of the layer placed between the two uses of `outer`.
  let middle: MiddleLayer

  public typealias InstanceType = AutoReuseLayerInstance<
    OuterLayer.InstanceType, MiddleLayer.InstanceType
  >

  /// Creates a reuse description from an outer and a middle layer description.
  ///
  /// - Parameters:
  ///   - outer: Description of the layer to be reused.
  ///   - middle: Description of the layer placed between the two uses of `outer`.
  public init(outer: OuterLayer, middle: MiddleLayer) {
    self.outer = outer
    self.middle = middle
  }

  /// Builds the outer and middle layer instances and computes the final output shape.
  ///
  /// - Parameters:
  ///   - inputShape: Shape passed to the first build of `outer`.
  ///   - keyPathSoFar: Key path from `Prefix` to the instance being built; it is
  ///     extended with `outerLayer` / `middleLayer` before being handed to the
  ///     child builds.
  ///   - keyDict: Dictionary forwarded to the first `outer` build and to the
  ///     `middle` build.
  /// - Returns: The assembled instance together with the output shape reported by
  ///   building `outer` a second time on the middle layer's output shape.
  public func buildModelWithOutputShape<Prefix>(
    inputShape: OuterLayer.InputShape,
    keyPathSoFar: KeyPath<Prefix, InstanceType>,
    keyDict: inout [AnyAutoLayerKey: Any]
  ) -> (InstanceType, OuterLayer.OutputShape) {
    let outerKeyPath = keyPathSoFar.appending(path: \InstanceType.outerLayer)
    let middleKeyPath = keyPathSoFar.appending(path: \InstanceType.middleLayer)

    let (builtOuter, shapeAfterFirstOuter) = outer.buildModelWithOutputShape(
      inputShape: inputShape,
      keyPathSoFar: outerKeyPath,
      keyDict: &keyDict)
    let (builtMiddle, shapeAfterMiddle) = middle.buildModelWithOutputShape(
      inputShape: shapeAfterFirstOuter,
      keyPathSoFar: middleKeyPath,
      keyDict: &keyDict)

    // TODO(shadaj): tuples are not equatable
    // let shapeMismatchError: String = "Cannot reuse outer layer because original input size \(inputShape) does not match output size of middle layer \(middleOutputShape)"
    // precondition(inputShape == middleOutputShape, shapeMismatchError)

    // Build `outer` a second time only to learn the final output shape. The layer
    // it produces is thrown away, and whatever it writes to the dictionary lands
    // in a scratch dictionary instead of `keyDict`.
    var scratchKeys: [AnyAutoLayerKey: Any] = [:]
    let (_, shapeAfterSecondOuter) = outer.buildModelWithOutputShape(
      inputShape: shapeAfterMiddle,
      keyPathSoFar: outerKeyPath,
      keyDict: &scratchKeys)

    let instance = AutoReuseLayerInstance(outerLayer: builtOuter, middleLayer: builtMiddle)
    return (instance, shapeAfterSecondOuter)
  }
}

0 commit comments

Comments
 (0)