This repository was archived by the owner on Apr 23, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 149
Introduce new layer initialization APIs with automatic shape computation #584
Closed
Closed
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
a11892f
Initial prototype for layer init with automatic shapes
shadaj ebb4e3e
Don't expose output shape type in buildModel
shadaj fc31111
Add all layers needed for LeNet-5
shadaj db7e19a
Port VGG16 to use new layer init API
shadaj 3d3bdc5
Update CMake config
shadaj dc7a595
Support ResNet-style models with skip connections
shadaj 0a24d9b
Update CMake config
shadaj e13728a
Rename AutoSequencedDefinition to AutoSequenced for naming consistency
shadaj 630c952
Explore LayerModule protocol to enable better type inference
shadaj f1628ed
Implement initial API for accessing instance layers
shadaj edbe93c
Simplify API for getting layers by key
shadaj 97a4eac
Simplify ResNet block definition
shadaj 741bdc2
Add support for models that reuse layers
shadaj 7f69b64
Update CMake config
shadaj 82d9d83
Dynamically unpack tuples where input shape dimension is unknown
shadaj fff5176
Use tuple unpacking to check shapes when building AutoReuseLayer
shadaj 26ebd4e
Add initial docs
shadaj 402b4f2
Use triple-slash style for docs
shadaj File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
add_subdirectory(ImageClassification) | ||
add_subdirectory(LayerInit) | ||
add_subdirectory(Recommendation) | ||
add_subdirectory(Text) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import TensorFlow | ||
|
||
/// A blueprint for a `BatchNorm` layer whose feature count is read from the input shape.
///
/// The output shape is the input shape, unchanged.
public struct AutoBatchNorm<Shape, Scalar>: AutoLayer where Scalar: TensorFlowFloatingPoint {
    let axis: Int
    let momentum: Scalar
    let epsilon: Scalar

    public typealias InstanceType = BatchNorm<Scalar>
    public typealias InputShape = Shape
    public typealias OutputShape = Shape

    /// Creates a batch normalization blueprint.
    /// - Parameters:
    ///   - axis: Selects which element of the input shape is the feature count; negative values
    ///     count from the end. The same value is forwarded unchanged to `BatchNorm`.
    ///   - momentum: Forwarded unchanged to `BatchNorm`.
    ///   - epsilon: Forwarded unchanged to `BatchNorm`.
    public init(
        axis: Int = -1,
        momentum: Scalar = 0.99,
        epsilon: Scalar = 0.001
    ) {
        self.axis = axis
        self.momentum = momentum
        self.epsilon = epsilon
    }

    /// Builds the `BatchNorm` instance for `inputShape`.
    /// - Parameters:
    ///   - inputShape: The input shape, converted to an array of `Int` with `intTupleToArray`.
    ///   - keyPathSoFar: Unused by this layer.
    ///   - keyDict: Not modified by this layer.
    /// - Returns: The layer instance and `inputShape` itself as the output shape.
    /// - Precondition: `axis` lies in `-rank..<rank`, where `rank` is the number of dimensions
    ///   in `inputShape`.
    public func buildModelWithOutputShape<Prefix>(
        inputShape: Shape,
        keyPathSoFar: KeyPath<Prefix, InstanceType>,
        keyDict: inout [AnyAutoLayerKey: Any]
    ) -> (InstanceType, Shape) {
        let dimensions: [Int] = intTupleToArray(tuple: inputShape)
        let rank = dimensions.count

        // Reject out-of-range axes explicitly. Wrapping with `%` would silently select the
        // wrong dimension for `axis >= rank` and trap without explanation for rank 0 or
        // `axis < -rank`.
        precondition(
            axis >= -rank && axis < rank,
            "AutoBatchNorm axis \(axis) is out of range for an input shape of rank \(rank)")

        // NOTE(review): a non-negative `axis` indexes the shape handed to this method but is
        // also forwarded as-is to `BatchNorm`; confirm both use the same dimension numbering.
        let featureAxis = axis < 0 ? axis + rank : axis
        let featureCount = dimensions[featureAxis]

        let instance = BatchNorm<Scalar>(
            featureCount: featureCount, axis: axis, momentum: momentum, epsilon: epsilon)
        return (instance, inputShape)
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import TensorFlow | ||
|
||
public struct AutoConv2D<Scalar>: AutoLayer where Scalar: TensorFlowFloatingPoint { | ||
let filterShape: (Int, Int) | ||
let outputChannels: Int | ||
let strides: (Int, Int) | ||
let padding: Padding | ||
let dilations: (Int, Int) | ||
let activation: Conv2D<Scalar>.Activation | ||
let useBias: Bool | ||
let filterInitializer: ParameterInitializer<Scalar> | ||
let biasInitializer: ParameterInitializer<Scalar> | ||
|
||
public typealias InstanceType = Conv2D<Scalar> | ||
public typealias InputShape = (Int, Int, Int) | ||
public typealias OutputShape = (Int, Int, Int) | ||
|
||
/// Creates a 2-D convolution blueprint. Every argument is stored as-is and handed to
/// `Conv2D` when the layer is built.
/// - Parameters:
///   - filterShape: The kernel extents along input-shape elements 0 and 1.
///   - outputChannels: The number of output channels; becomes element 2 of the output shape.
///   - strides: The strides along input-shape elements 0 and 1.
///   - padding: The padding mode.
///   - dilations: The dilation factors forwarded to `Conv2D`.
///   - activation: The activation function forwarded to `Conv2D`.
///   - useBias: Forwarded to `Conv2D`.
///   - filterInitializer: The initializer forwarded to `Conv2D` for the filter.
///   - biasInitializer: The initializer forwarded to `Conv2D` for the bias.
public init(
    filterShape: (Int, Int),
    outputChannels: Int,
    strides: (Int, Int) = (1, 1),
    padding: Padding = .valid,
    dilations: (Int, Int) = (1, 1),
    activation: @escaping Conv2D<Scalar>.Activation = identity,
    useBias: Bool = true,
    filterInitializer: @escaping ParameterInitializer<Scalar> = glorotUniform(),
    biasInitializer: @escaping ParameterInitializer<Scalar> = zeros()
) {
    // Geometry of the convolution.
    self.filterShape = filterShape
    self.strides = strides
    self.dilations = dilations
    self.padding = padding
    self.outputChannels = outputChannels

    // Behaviour and parameter initialization.
    self.activation = activation
    self.useBias = useBias
    self.filterInitializer = filterInitializer
    self.biasInitializer = biasInitializer
}
Comment on lines
+4
to
+38
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This also just looks like a closure capture. |
||
|
||
/// Builds the `Conv2D` instance for `inputShape` and computes the shape it produces.
/// - Parameters:
///   - inputShape: The input shape. Elements 0 and 1 are the extents that `filterShape`,
///     `strides` and `dilations` apply to; element 2 is used as the input channel count.
///   - keyPathSoFar: Unused by this layer.
///   - keyDict: Not modified by this layer.
/// - Returns: The layer instance and its output shape, whose last element is `outputChannels`.
public func buildModelWithOutputShape<Prefix>(
    inputShape: (Int, Int, Int),
    keyPathSoFar: KeyPath<Prefix, InstanceType>,
    keyDict: inout [AnyAutoLayerKey: Any]
) -> (InstanceType, (Int, Int, Int)) {
    // Number of output positions when `extent` positions are walked with step `stride`.
    func outputExtent(_ extent: Int, stride: Int) -> Int {
        return Int(ceil(Float(extent) / Float(stride)))
    }

    let outputShape: (Int, Int, Int)
    if padding == .valid {
        // A dilated kernel of size k spans (k - 1) * dilation + 1 input positions, so the
        // number of positions it can start at is input - (k - 1) * dilation. With a
        // dilation of 1 this reduces to input - k + 1.
        outputShape = (
            outputExtent(inputShape.0 - (filterShape.0 - 1) * dilations.0, stride: strides.0),
            outputExtent(inputShape.1 - (filterShape.1 - 1) * dilations.1, stride: strides.1),
            outputChannels
        )
    } else {
        // With any other padding mode only the stride shrinks the extents.
        outputShape = (
            outputExtent(inputShape.0, stride: strides.0),
            outputExtent(inputShape.1, stride: strides.1),
            outputChannels
        )
    }

    let instance = Conv2D<Scalar>(
        filterShape: (filterShape.0, filterShape.1, inputShape.2, outputChannels),
        strides: strides, padding: padding, dilations: dilations,
        activation: activation, useBias: useBias,
        filterInitializer: filterInitializer, biasInitializer: biasInitializer
    )
    return (instance, outputShape)
}
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import TensorFlow | ||
|
||
public struct AutoDense<Scalar>: AutoLayer where Scalar: TensorFlowFloatingPoint { | ||
let outputSize: Int; | ||
let activation: Dense<Scalar>.Activation | ||
|
||
public typealias InstanceType = Dense<Scalar> | ||
public typealias InputShape = Int | ||
public typealias OutputShape = Int | ||
|
||
/// Creates a dense layer blueprint.
/// - Parameters:
///   - outputSize: The size of the layer's output; also reported as its output shape.
///   - activation: The activation function forwarded to `Dense` when the layer is built.
public init(
    outputSize: Int,
    activation: @escaping Dense<Scalar>.Activation = identity
) {
    self.activation = activation
    self.outputSize = outputSize
}
Comment on lines
+4
to
+14
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This also looks like a closure capture. |
||
|
||
/// Builds the `Dense` instance, using `inputShape` as its input size.
/// - Parameters:
///   - inputShape: The number of input features.
///   - keyPathSoFar: Unused by this layer.
///   - keyDict: Not modified by this layer.
/// - Returns: The layer instance and `outputSize` as the output shape.
public func buildModelWithOutputShape<Prefix>(
    inputShape: Int,
    keyPathSoFar: KeyPath<Prefix, InstanceType>,
    keyDict: inout [AnyAutoLayerKey: Any]
) -> (InstanceType, Int) {
    let dense = Dense<Scalar>(
        inputSize: inputShape, outputSize: outputSize, activation: activation)
    return (dense, outputSize)
}
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import TensorFlow | ||
|
||
/// A blueprint for a `Flatten` layer. It accepts any input shape that `intTupleToArray`
/// can convert and reports the product of its dimensions as the output shape.
public struct AutoFlatten<Scalar>: AutoLayer where Scalar: TensorFlowFloatingPoint {
    public typealias InstanceType = Flatten<Scalar>
    public typealias InputShape = Any
    public typealias OutputShape = Int

    /// Creates a flatten blueprint; there is nothing to configure.
    public init() {}

    /// Builds the `Flatten` instance and computes the flattened size.
    /// - Parameters:
    ///   - inputShape: The input shape, converted to an array with `intTupleToArray`.
    ///   - keyPathSoFar: Unused by this layer.
    ///   - keyDict: Not modified by this layer.
    /// - Returns: The layer instance and the product of all input dimensions.
    public func buildModelWithOutputShape<Prefix>(
        inputShape: Any,
        keyPathSoFar: KeyPath<Prefix, InstanceType>,
        keyDict: inout [AnyAutoLayerKey: Any]
    ) -> (InstanceType, Int) {
        let flatten = Flatten<Scalar>()
        let dimensions = intTupleToArray(tuple: inputShape)
        let flattenedSize = dimensions.reduce(1, *)
        return (flatten, flattenedSize)
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import TensorFlow | ||
|
||
/// A blueprint for a `Function` layer, which applies a user-defined function to its input.
public struct AutoFunction<Input: Differentiable, Output: Differentiable, InputShape, OutputShape>: AutoLayer {
    let fnShape: (InputShape) -> OutputShape
    let fn: @differentiable (Input) -> Output

    public typealias InstanceType = Function<Input, Output>
    public typealias InputShape = InputShape
    public typealias OutputShape = OutputShape

    /// Creates a function layer blueprint.
    /// - Parameters:
    ///   - fnShape: Computes the output shape of the function from the input shape.
    ///   - fn: Computes the output data of the function from the input data.
    public init(
        fnShape: @escaping (InputShape) -> OutputShape,
        fn: @escaping @differentiable (Input) -> Output
    ) {
        self.fn = fn
        self.fnShape = fnShape
    }

    /// Wraps `fn` in a `Function` layer and derives the output shape with `fnShape`.
    /// - Parameters:
    ///   - inputShape: The input shape handed to `fnShape`.
    ///   - keyPathSoFar: Unused by this layer.
    ///   - keyDict: Not modified by this layer.
    /// - Returns: The layer instance and the shape returned by `fnShape`.
    public func buildModelWithOutputShape<Prefix>(
        inputShape: InputShape,
        keyPathSoFar: KeyPath<Prefix, InstanceType>,
        keyDict: inout [AnyAutoLayerKey: Any]
    ) -> (InstanceType, OutputShape) {
        let instance: InstanceType = Function(fn)
        let outputShape = fnShape(inputShape)
        return (instance, outputShape)
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import TensorFlow | ||
|
||
/// A layer "blueprint": a description of a layer that can be turned into a `Layer`
/// instance for training once the input shape is known.
public protocol AutoLayer {
    /// The type of the layer instance built from this blueprint.
    associatedtype InstanceType: Layer

    /// The specific tuple of `Int`s that defines the input shape of the layer.
    associatedtype InputShape

    /// The specific tuple of `Int`s that defines the output shape of the layer.
    associatedtype OutputShape

    /// Creates a new instance of the layer described by this blueprint.
    /// - Parameters:
    ///   - inputShape: The shape of a single input instance (no batch dimension).
    ///   - keyPathSoFar: A `KeyPath` tracking the path from the root layer to the layer
    ///     instance being built.
    ///   - keyDict: A dictionary mapping each `AutoLayerKey` to the key path of the layer
    ///     instance it refers to.
    /// - Returns: A pair of the layer instance built for `inputShape` and the output shape
    ///   computed from that input shape.
    func buildModelWithOutputShape<Prefix>(
        inputShape: InputShape,
        keyPathSoFar: KeyPath<Prefix, InstanceType>,
        keyDict: inout [AnyAutoLayerKey: Any]
    ) -> (InstanceType, OutputShape)
}
|
||
extension AutoLayer {
    /// Builds an instance of the model for the given input shape.
    ///
    /// Building starts from the identity key path and an empty key dictionary. The computed
    /// output shape is discarded.
    /// - Parameter inputShape: The input shape handed to `buildModelWithOutputShape`.
    /// - Returns: The built layer together with the key mapping collected while building.
    public func buildModel(inputShape: InputShape) -> BuiltAutoLayer<InstanceType> {
        var collectedKeys: [AnyAutoLayerKey: Any] = [:]
        let built = buildModelWithOutputShape(
            inputShape: inputShape,
            keyPathSoFar: \InstanceType.self,
            keyDict: &collectedKeys)
        return BuiltAutoLayer(layer: built.0, keyMapping: collectedKeys)
    }
}
|
||
/// A layer instance containing a model built with the AutoLayer API.
/// Offers keyed access to layers through `AutoLayerKey`.
public struct BuiltAutoLayer<InstanceType: Layer>: Layer {
    /// The built model that this wrapper forwards to.
    public var layer: InstanceType

    /// Maps each key to a value expected to be a `KeyPath` rooted at `InstanceType`.
    @noDerivative let keyMapping: [AnyAutoLayerKey: Any]

    /// Wraps a built layer together with its key mapping.
    /// - Parameters:
    ///   - layer: The built model.
    ///   - keyMapping: The mapping from keys to key paths into `layer`.
    public init(layer: InstanceType, keyMapping: [AnyAutoLayerKey: Any]) {
        self.layer = layer
        self.keyMapping = keyMapping
    }

    /// Applies the wrapped layer to `input`.
    @differentiable
    public func callAsFunction(_ input: InstanceType.Input) -> InstanceType.Output {
        return layer(input)
    }

    /// Grabs a specific layer by the given `AutoLayerKey`.
    /// - Precondition: `index` is present in the key mapping and its stored value is a
    ///   `KeyPath<InstanceType, T>`.
    public subscript<T>(index: AutoLayerKey<T>) -> T {
        // Look up and cast in separate steps so that each failure mode traps with its own
        // message, instead of a bare forced-cast crash.
        guard let stored = keyMapping[index] else {
            preconditionFailure(
                "No layer is registered under the given AutoLayerKey in this built model")
        }
        guard let keyPath = stored as? KeyPath<InstanceType, T> else {
            preconditionFailure(
                "The key path stored for the given AutoLayerKey does not lead from "
                    + "\(InstanceType.self) to \(T.self)")
        }
        return layer[keyPath: keyPath]
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import TensorFlow | ||
|
||
/// A type-erased layer key. Keys are compared and hashed by object identity, so two keys
/// are equal only when they are the same instance.
public class AnyAutoLayerKey: Hashable {
    public static func == (lhs: AnyAutoLayerKey, rhs: AnyAutoLayerKey) -> Bool {
        return ObjectIdentifier(lhs) == ObjectIdentifier(rhs)
    }

    public func hash(into hasher: inout Hasher) {
        let identity = ObjectIdentifier(self)
        hasher.combine(identity)
    }
}
|
||
/// A key that can be attached to an `AutoLayer` so that the layer instance built from it
/// can be retrieved from the larger model it ends up in.
public class AutoLayerKey<T: Layer>: AnyAutoLayerKey {
    /// Creates a new, unique key.
    public override init() {
        super.init()
    }
}
|
||
/// A layer blueprint that associates an underlying blueprint with an `AutoLayerKey` so that the underlying instance can be accessed from the built model. | ||
public struct KeyedAutoLayer<Underlying: AutoLayer>: AutoLayer { | ||
let underlying: Underlying | ||
let key: AutoLayerKey<InstanceType> | ||
|
||
public typealias InstanceType = Underlying.InstanceType | ||
public typealias InputShape = Underlying.InputShape | ||
public typealias OutputShape = Underlying.OutputShape | ||
|
||
/// Associates `key` with the `underlying` blueprint.
/// - Parameters:
///   - underlying: The blueprint that does the actual building.
///   - key: The key under which the built instance is registered.
public init(
    _ underlying: Underlying,
    key: AutoLayerKey<InstanceType>
) {
    self.key = key
    self.underlying = underlying
}
|
||
Comment on lines
+19
to
+31
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This also looks a bit like a closure capture. |
||
/// Builds the underlying blueprint, then registers `keyPathSoFar` under this layer's key.
/// - Parameters:
///   - inputShape: Forwarded to the underlying blueprint.
///   - keyPathSoFar: Forwarded to the underlying blueprint and stored in `keyDict`.
///   - keyDict: Receives the entry for this layer's key after the underlying build.
/// - Returns: The instance and output shape produced by the underlying blueprint.
public func buildModelWithOutputShape<Prefix>(
    inputShape: InputShape,
    keyPathSoFar: KeyPath<Prefix, InstanceType>,
    keyDict: inout [AnyAutoLayerKey: Any]
) -> (InstanceType, OutputShape) {
    let built = underlying.buildModelWithOutputShape(
        inputShape: inputShape,
        keyPathSoFar: keyPathSoFar,
        keyDict: &keyDict)
    // Written after the underlying build, so this entry replaces any entry the underlying
    // blueprint stored under the same key.
    keyDict[key] = keyPathSoFar
    return built
}
} | ||
|
||
extension AutoLayer {
    /// Attaches a key to an existing layer blueprint.
    /// - Parameter key: The key under which the built instance becomes accessible.
    /// - Returns: A `KeyedAutoLayer` wrapping this blueprint.
    public func withKey(_ key: AutoLayerKey<InstanceType>) -> KeyedAutoLayer<Self> {
        let keyed = KeyedAutoLayer<Self>(self, key: key)
        return keyed
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import TensorFlow | ||
|
||
/// An `AutoLayer` that is described in terms of another blueprint, exposed through
/// `initializeLayer`.
public protocol AutoModule: AutoLayer {
    /// The type of the blueprint this module is made of.
    associatedtype LayerType: AutoLayer

    /// The blueprint this module is made of. The getter is `mutating`, so conforming types
    /// may change their own state when it is read.
    var initializeLayer: LayerType { mutating get }
}
|
||
extension AutoModule {
    public typealias InstanceType = LayerType.InstanceType
    public typealias InputShape = LayerType.InputShape
    public typealias OutputShape = LayerType.OutputShape

    /// Builds the module by forwarding every argument to the blueprint returned by
    /// `initializeLayer`.
    /// - Returns: The instance and output shape produced by that blueprint.
    public func buildModelWithOutputShape<Prefix>(
        inputShape: InputShape,
        keyPathSoFar: KeyPath<Prefix, InstanceType>,
        keyDict: inout [AnyAutoLayerKey: Any]
    ) -> (InstanceType, OutputShape) {
        // `initializeLayer` has a mutating getter, but this method is not mutating, so the
        // getter is invoked on a local copy.
        var mutableSelf = self
        let blueprint = mutableSelf.initializeLayer
        return blueprint.buildModelWithOutputShape(
            inputShape: inputShape,
            keyPathSoFar: keyPathSoFar,
            keyDict: &keyDict)
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import TensorFlow | ||
|
||
public struct AutoAvgPool2D<Scalar>: AutoLayer where Scalar: TensorFlowFloatingPoint { | ||
let poolSize: (Int, Int) | ||
let strides: (Int, Int) | ||
let padding: Padding | ||
|
||
public typealias InstanceType = AvgPool2D<Scalar> | ||
public typealias InputShape = (Int, Int, Int) | ||
public typealias OutputShape = (Int, Int, Int) | ||
|
||
/// Creates a 2-D average pooling blueprint. Every argument is stored as-is and handed to
/// `AvgPool2D` when the layer is built.
/// - Parameters:
///   - poolSize: The pooling window extents along input-shape elements 0 and 1.
///   - strides: The strides along input-shape elements 0 and 1.
///   - padding: The padding mode.
public init(
    poolSize: (Int, Int),
    strides: (Int, Int) = (1, 1),
    padding: Padding = .valid
) {
    self.padding = padding
    self.strides = strides
    self.poolSize = poolSize
}
|
||
Comment on lines
+4
to
+21
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This also looks like a closure capture. (Also below.) |
||
/// Builds the `AvgPool2D` instance and computes the shape it produces.
/// - Parameters:
///   - inputShape: The input shape. Elements 0 and 1 are pooled; element 2 is passed
///     through to the output shape unchanged.
///   - keyPathSoFar: Unused by this layer.
///   - keyDict: Not modified by this layer.
/// - Returns: The layer instance and its output shape.
public func buildModelWithOutputShape<Prefix>(
    inputShape: (Int, Int, Int),
    keyPathSoFar: KeyPath<Prefix, InstanceType>,
    keyDict: inout [AnyAutoLayerKey: Any]
) -> (InstanceType, (Int, Int, Int)) {
    // Number of output positions when `extent` positions are walked with step `stride`.
    func pooledExtent(_ extent: Int, stride: Int) -> Int {
        return Int(ceil(Float(extent) / Float(stride)))
    }

    // With `.valid` padding the window must fit inside the input, which leaves
    // input - pool + 1 starting positions; otherwise every input position is one.
    let isValid = padding == .valid
    let extent0 = isValid ? inputShape.0 - poolSize.0 + 1 : inputShape.0
    let extent1 = isValid ? inputShape.1 - poolSize.1 + 1 : inputShape.1

    let outputShape = (
        pooledExtent(extent0, stride: strides.0),
        pooledExtent(extent1, stride: strides.1),
        inputShape.2
    )

    let pool = AvgPool2D<Scalar>(poolSize: poolSize, strides: strides, padding: padding)
    return (pool, outputShape)
}
} | ||
|
||
/// A blueprint for a `GlobalAvgPool2D` layer. The output shape is the last element of the
/// three-element input shape.
public struct AutoGlobalAvgPool2D<Scalar>: AutoLayer where Scalar: TensorFlowFloatingPoint {
    public typealias InstanceType = GlobalAvgPool2D<Scalar>
    public typealias InputShape = (Int, Int, Int)
    public typealias OutputShape = Int

    /// Creates a global average pooling blueprint; there is nothing to configure.
    public init() {}

    /// Builds the `GlobalAvgPool2D` instance.
    /// - Parameters:
    ///   - inputShape: The input shape; only element 2 is read.
    ///   - keyPathSoFar: Unused by this layer.
    ///   - keyDict: Not modified by this layer.
    /// - Returns: The layer instance and `inputShape.2` as the output shape.
    public func buildModelWithOutputShape<Prefix>(
        inputShape: (Int, Int, Int),
        keyPathSoFar: KeyPath<Prefix, InstanceType>,
        keyDict: inout [AnyAutoLayerKey: Any]
    ) -> (InstanceType, Int) {
        let (_, _, channels) = inputShape
        return (GlobalAvgPool2D<Scalar>(), channels)
    }
}
|
||
/// A blueprint for a `MaxPool2D` layer that computes its output shape from the input shape.
public struct AutoMaxPool2D<Scalar>: AutoLayer where Scalar: TensorFlowFloatingPoint {
    let poolSize: (Int, Int)
    let strides: (Int, Int)
    let padding: Padding

    public typealias InstanceType = MaxPool2D<Scalar>
    public typealias InputShape = (Int, Int, Int)
    public typealias OutputShape = (Int, Int, Int)

    /// Creates a 2-D max pooling blueprint. Every argument is stored as-is and handed to
    /// `MaxPool2D` when the layer is built.
    /// - Parameters:
    ///   - poolSize: The pooling window extents along input-shape elements 0 and 1.
    ///   - strides: The strides along input-shape elements 0 and 1.
    ///   - padding: The padding mode.
    public init(
        poolSize: (Int, Int),
        strides: (Int, Int) = (1, 1),
        padding: Padding = .valid
    ) {
        self.padding = padding
        self.strides = strides
        self.poolSize = poolSize
    }

    /// Builds the `MaxPool2D` instance and computes the shape it produces.
    /// - Parameters:
    ///   - inputShape: The input shape. Elements 0 and 1 are pooled; element 2 is passed
    ///     through to the output shape unchanged.
    ///   - keyPathSoFar: Unused by this layer.
    ///   - keyDict: Not modified by this layer.
    /// - Returns: The layer instance and its output shape.
    public func buildModelWithOutputShape<Prefix>(
        inputShape: (Int, Int, Int),
        keyPathSoFar: KeyPath<Prefix, InstanceType>,
        keyDict: inout [AnyAutoLayerKey: Any]
    ) -> (InstanceType, (Int, Int, Int)) {
        // Number of output positions when `extent` positions are walked with step `stride`.
        func pooledExtent(_ extent: Int, stride: Int) -> Int {
            return Int(ceil(Float(extent) / Float(stride)))
        }

        // With `.valid` padding the window must fit inside the input, which leaves
        // input - pool + 1 starting positions; otherwise every input position is one.
        let isValid = padding == .valid
        let extent0 = isValid ? inputShape.0 - poolSize.0 + 1 : inputShape.0
        let extent1 = isValid ? inputShape.1 - poolSize.1 + 1 : inputShape.1

        let outputShape = (
            pooledExtent(extent0, stride: strides.0),
            pooledExtent(extent1, stride: strides.1),
            inputShape.2
        )

        let pool = MaxPool2D<Scalar>(poolSize: poolSize, strides: strides, padding: padding)
        return (pool, outputShape)
    }
}
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This to me looks just like a closure capture.