|
| 1 | +import TensorFlow |
| 2 | + |
| 3 | +public struct AutoReuseLayerInstance<OuterLayer: Layer, MiddleLayer: Layer>: Layer |
| 4 | +where |
| 5 | + OuterLayer.Output == MiddleLayer.Input, OuterLayer.Input == MiddleLayer.Output, |
| 6 | + OuterLayer.TangentVector.VectorSpaceScalar == MiddleLayer.TangentVector.VectorSpaceScalar { |
| 7 | + public var outerLayer: OuterLayer |
| 8 | + public var middleLayer: MiddleLayer |
| 9 | + |
| 10 | + @differentiable |
| 11 | + public func callAsFunction(_ input: OuterLayer.Input) -> OuterLayer.Output { |
| 12 | + return outerLayer(middleLayer(outerLayer(input))) |
| 13 | + } |
| 14 | +} |
| 15 | + |
| 16 | +public struct AutoReuseLayer<OuterLayer: AutoLayer, MiddleLayer: AutoLayer>: AutoLayer |
| 17 | +where |
| 18 | + OuterLayer.OutputShape == MiddleLayer.InputShape, |
| 19 | + MiddleLayer.OutputShape == OuterLayer.InputShape, |
| 20 | + OuterLayer.InstanceType.Output == MiddleLayer.InstanceType.Input, |
| 21 | + MiddleLayer.InstanceType.Output == OuterLayer.InstanceType.Input, |
| 22 | + OuterLayer.InstanceType.TangentVector.VectorSpaceScalar == MiddleLayer.InstanceType.TangentVector.VectorSpaceScalar { |
| 23 | + let outer: OuterLayer |
| 24 | + let middle: MiddleLayer |
| 25 | + |
| 26 | + public typealias InstanceType = AutoReuseLayerInstance<OuterLayer.InstanceType, MiddleLayer.InstanceType> |
| 27 | + |
| 28 | + public init(outer: OuterLayer, middle: MiddleLayer) { |
| 29 | + self.outer = outer |
| 30 | + self.middle = middle |
| 31 | + } |
| 32 | + |
| 33 | + public func buildModelWithOutputShape<Prefix>(inputShape: OuterLayer.InputShape, keyPathSoFar: KeyPath<Prefix, InstanceType>, keyDict: inout [AnyAutoLayerKey: Any]) -> (InstanceType, OuterLayer.OutputShape) { |
| 34 | + let (outerLayer, firstOuterOutputShape) = outer.buildModelWithOutputShape(inputShape: inputShape, keyPathSoFar: keyPathSoFar.appending(path: \InstanceType.outerLayer), keyDict: &keyDict) |
| 35 | + let (middleLayer, middleOutputShape) = middle.buildModelWithOutputShape(inputShape: firstOuterOutputShape, keyPathSoFar: keyPathSoFar.appending(path: \InstanceType.middleLayer), keyDict: &keyDict) |
| 36 | + |
| 37 | + var tempDictionary: [AnyAutoLayerKey: Any] = [:] |
| 38 | + |
| 39 | + // TODO(shadaj): tuples are not equatable |
| 40 | + // let shapeMismatchError: String = "Cannot reuse outer layer because original input size \(inputShape) does not match output size of middle layer \(middleOutputShape)" |
| 41 | + // precondition(inputShape == middleOutputShape, shapeMismatchError) |
| 42 | + let (_, finalOutputShape) = outer.buildModelWithOutputShape(inputShape: middleOutputShape, keyPathSoFar: keyPathSoFar.appending(path: \InstanceType.outerLayer), keyDict: &tempDictionary) |
| 43 | + |
| 44 | + return (AutoReuseLayerInstance(outerLayer: outerLayer, middleLayer: middleLayer), finalOutputShape) |
| 45 | + } |
| 46 | +} |
0 commit comments