-
Notifications
You must be signed in to change notification settings - Fork 149
Exploring structural generic programming and layer APIs #613
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
{
  // See https://go.microsoft.com/fwlink/?LinkId=733558
  // for the documentation about the tasks.json format
  "version": "2.0.0",
  "tasks": [
    {
      "label": "swift-build",
      "type": "shell",
      // NOTE(review): machine-specific absolute toolchain path — other
      // contributors will need to edit this (or switch to a workspace
      // variable / PATH-resolved `swift`) before this task works for them.
      "command": "/usr/local/google/home/saeta/tmp/toolchains/1165/usr/bin/swift build",
      "problemMatcher": [],
      "group": {
        "kind": "build",
        "isDefault": true
      }
    }
  ]
}
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import TensorFlow | ||
import StructuralCore | ||
import PenguinStructures | ||
|
||
/// A simple model, where we don't have to write `callAsFunction`, thanks to `SequentialLayer`. | ||
/// A simple convolutional model whose `callAsFunction` is derived
/// automatically via `SequentialLayer`: the stored properties are applied in
/// declaration order (so property order here is semantically significant).
public struct MyModel: Module, Layer, SequentialLayer {
  public var conv = Conv2D<Float>(filterShape: (5, 5, 3, 6))
  public var pool = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
  public var flatten = Flatten<Float>()
  public var dense = Dense<Float>(inputSize: 36 * 6, outputSize: 10)
}
|
||
/// Like `MyModel`, but with one layer excluded from sequential application.
///
/// `denseSkipped` still contributes parameters to the model, but the
/// `@SequentialSkip` wrapper makes the derived `callAsFunction` pass a
/// `Tensor<Float>` through that slot unchanged.
public struct MyModelSkipping: Module, Layer, SequentialLayer {
  public var conv = Conv2D<Float>(filterShape: (5, 5, 3, 6))
  public var pool = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
  public var flatten = Flatten<Float>()
  // Made `public` for consistency with the sibling stored properties.
  @SequentialSkip(passing: Type<Tensor<Float>>()) public var denseSkipped = Dense<Float>(inputSize: 1, outputSize: 2)
  public var dense = Dense<Float>(inputSize: 36 * 6, outputSize: 10)
}
|
||
/// Like `MyModel`, but with a residual (skip) connection wrapped around
/// `denseSkipped` via the `@ResidualConnection` property wrapper.
///
/// NOTE(review): the PR discussion asked how to express models where two
/// different transformations apply to the *same* data — i.e. a "horizontal"
/// (parallel) axis rather than the "vertical" sequential axis that properties
/// encode. One proposed alternative to property wrappers is an explicit
/// combinator, e.g.:
///
///     struct ParallelLayers<Lhs: Layer, Rhs: Layer>: Layer
///     where Lhs.Input == Rhs.Input /* , ... */ {
///       var lhs: Lhs
///       var rhs: Rhs
///       // TODO: make merge func configurable.
///       @differentiable
///       public func callAsFunction(_ input: Lhs.Input) -> Lhs.Output {
///         return lhs(input) + rhs(input)
///       }
///     }
///
/// A structural "ParallelLayer" derivation (with a user-supplied pair-wise
/// reduction) was floated as another point in the design space.
public struct MyResidualModel: Module, Layer, SequentialLayer {
  public var conv = Conv2D<Float>(filterShape: (5, 5, 3, 6))
  public var pool = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
  public var flatten = Flatten<Float>()
  // Made `public` for consistency with the sibling stored properties.
  @ResidualConnection public var denseSkipped = Dense<Float>(inputSize: 36 * 6, outputSize: 36 * 6)
  public var dense = Dense<Float>(inputSize: 36 * 6, outputSize: 10)
}
|
||
// Below should be (eventually) auto-generated by the Swift compiler. | ||
|
||
/// Hand-written stand-in for what the Swift compiler would (eventually)
/// synthesize for `MyModel`'s structural conformance.
extension MyModel: DifferentiableStructural {
  // TODO: figure out why these didn't get automatically inferred.
  public typealias Input = Tensor<Float>
  public typealias Output = Tensor<Float>
  public typealias SequentialInput = Input
  public typealias SequentialOutput = Output

  /// The generic encoding of `MyModel`: one `StructuralProperty` per stored
  /// property, consed together in declaration order.
  public typealias StructuralRepresentation =
    StructuralStruct<
      StructuralCons<StructuralProperty<Conv2D<Float>>,
      StructuralCons<StructuralProperty<MaxPool2D<Float>>,
      StructuralCons<StructuralProperty<Flatten<Float>>,
      StructuralCons<StructuralProperty<Dense<Float>>,
      StructuralEmpty>>>>>

  // The remaining members are unimplemented placeholders; compiler synthesis
  // would provide real bodies.
  @differentiable
  public init(differentiableStructuralRepresentation: StructuralRepresentation) {
    fatalError()
  }

  @derivative(of: init(differentiableStructuralRepresentation:))
  public static func _vjp_init(differentiableStructuralRepresentation: StructuralRepresentation)
    -> (value: Self, pullback: (TangentVector) -> StructuralRepresentation.TangentVector)
  {
    fatalError()
  }

  @differentiable
  public var differentiableStructuralRepresentation: StructuralRepresentation {
    fatalError()
  }

  @derivative(of: differentiableStructuralRepresentation)
  public func _vjp_differentiableStructuralRepresentation()
    -> (value: StructuralRepresentation, pullback: (StructuralRepresentation.TangentVector) -> TangentVector)
  {
    fatalError()
  }
}
|
||
/// Hand-written stand-in for compiler-synthesized structural conformance for
/// `MyModelSkipping`.
extension MyModelSkipping: DifferentiableStructural {
  // TODO: figure out why these didn't get automatically inferred.
  public typealias Input = Tensor<Float>
  public typealias Output = Tensor<Float>
  public typealias SequentialInput = Input
  public typealias SequentialOutput = Output

  /// One `StructuralProperty` per stored property in declaration order; the
  /// wrapped `denseSkipped` appears as its wrapper storage,
  /// `SequentialSkip<Dense<Float>, Tensor<Float>>`.
  public typealias StructuralRepresentation =
    StructuralStruct<
      StructuralCons<StructuralProperty<Conv2D<Float>>,
      StructuralCons<StructuralProperty<MaxPool2D<Float>>,
      StructuralCons<StructuralProperty<Flatten<Float>>,
      StructuralCons<StructuralProperty<SequentialSkip<Dense<Float>, Tensor<Float>>>,
      StructuralCons<StructuralProperty<Dense<Float>>,
      StructuralEmpty>>>>>>

  // Unimplemented placeholders; compiler synthesis would provide real bodies.
  @differentiable
  public init(differentiableStructuralRepresentation: StructuralRepresentation) {
    fatalError()
  }

  @derivative(of: init(differentiableStructuralRepresentation:))
  public static func _vjp_init(differentiableStructuralRepresentation: StructuralRepresentation)
    -> (value: Self, pullback: (TangentVector) -> StructuralRepresentation.TangentVector)
  {
    fatalError()
  }

  @differentiable
  public var differentiableStructuralRepresentation: StructuralRepresentation {
    fatalError()
  }

  @derivative(of: differentiableStructuralRepresentation)
  public func _vjp_differentiableStructuralRepresentation()
    -> (value: StructuralRepresentation, pullback: (StructuralRepresentation.TangentVector) -> TangentVector)
  {
    fatalError()
  }
}
|
||
/// Hand-written stand-in for compiler-synthesized structural conformance for
/// `MyResidualModel`.
extension MyResidualModel: DifferentiableStructural {
  // TODO: figure out why these didn't get automatically inferred.
  public typealias Input = Tensor<Float>
  public typealias Output = Tensor<Float>
  public typealias SequentialInput = Input
  public typealias SequentialOutput = Output

  /// One `StructuralProperty` per stored property in declaration order; the
  /// wrapped `denseSkipped` appears as `ResidualConnection<Dense<Float>>`.
  public typealias StructuralRepresentation =
    StructuralStruct<
      StructuralCons<StructuralProperty<Conv2D<Float>>,
      StructuralCons<StructuralProperty<MaxPool2D<Float>>,
      StructuralCons<StructuralProperty<Flatten<Float>>,
      StructuralCons<StructuralProperty<ResidualConnection<Dense<Float>>>,
      StructuralCons<StructuralProperty<Dense<Float>>,
      StructuralEmpty>>>>>>

  // Unimplemented placeholders; compiler synthesis would provide real bodies.
  @differentiable
  public init(differentiableStructuralRepresentation: StructuralRepresentation) {
    fatalError()
  }

  @derivative(of: init(differentiableStructuralRepresentation:))
  public static func _vjp_init(differentiableStructuralRepresentation: StructuralRepresentation)
    -> (value: Self, pullback: (TangentVector) -> StructuralRepresentation.TangentVector)
  {
    fatalError()
  }

  @differentiable
  public var differentiableStructuralRepresentation: StructuralRepresentation {
    fatalError()
  }

  @derivative(of: differentiableStructuralRepresentation)
  public func _vjp_differentiableStructuralRepresentation()
    -> (value: StructuralRepresentation, pullback: (StructuralRepresentation.TangentVector) -> TangentVector)
  {
    fatalError()
  }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import TensorFlow | ||
import StructuralCore | ||
import PenguinStructures | ||
|
||
// TODO: Pick a better name.
// TODO: Consider splitting the inductive cases into a separate protocol.
/// A layer composed of a sequential application of its constituent field
/// layers: each stored property's output feeds the next property's input.
public protocol SequentialLayer: Differentiable {
  /// The input type of the first constituent layer.
  associatedtype SequentialInput: Differentiable  // TODO: support embedding layers.
  /// The output type of the last constituent layer.
  associatedtype SequentialOutput: Differentiable

  /// Applies the constituent layers to `input`, in order.
  @differentiable(wrt: (self, input))
  func sequentialApply(_ input: SequentialInput) -> SequentialOutput
}
|
||
/// Derives both `sequentialApply` and `callAsFunction` for any layer whose
/// structural representation is itself a `SequentialLayer` with matching
/// input/output types: application is delegated to the representation.
extension SequentialLayer
where
  Self: DifferentiableStructural & Layer,
  Self.StructuralRepresentation: SequentialLayer,
  SequentialInput == StructuralRepresentation.SequentialInput,
  SequentialOutput == StructuralRepresentation.SequentialOutput,
  SequentialInput == Input,
  SequentialOutput == Output
{
  @differentiable
  public func sequentialApply(_ input: SequentialInput) -> SequentialOutput {
    differentiableStructuralRepresentation.sequentialApply(input)
  }

  @differentiable
  public func callAsFunction(_ input: Input) -> Output {
    sequentialApply(input)
  }
}
|
||
/// Inductive case: a cons cell applies its head layer, then feeds the result
/// to the rest of the list.
extension StructuralCons: SequentialLayer where Value: SequentialLayer, Next: SequentialLayer, Next.SequentialInput == Value.SequentialOutput {
  public typealias SequentialInput = Value.SequentialInput
  public typealias SequentialOutput = Next.SequentialOutput

  @differentiable
  public func sequentialApply(_ input: SequentialInput) -> SequentialOutput {
    // Head first, then the tail consumes the head's output.
    next.sequentialApply(value.sequentialApply(input))
  }
}
|
||
/// Base case: a single stored property that is itself a `Layer` is applied
/// directly.
extension StructuralProperty: SequentialLayer where Value: Layer {
  public typealias SequentialInput = Value.Input
  public typealias SequentialOutput = Value.Output

  @differentiable
  public func sequentialApply(_ input: SequentialInput) -> SequentialOutput {
    value(input)
  }
}
|
||
/// A struct wrapper delegates sequential application to its property list.
extension StructuralStruct: SequentialLayer where Properties: SequentialLayer {
  public typealias SequentialInput = Properties.SequentialInput
  public typealias SequentialOutput = Properties.SequentialOutput

  @differentiable
  public func sequentialApply(_ input: SequentialInput) -> SequentialOutput {
    properties.sequentialApply(input)
  }
}
|
||
/// Terminal case: the empty list is the identity layer.
///
/// BAD: the input/output types are hard-coded to `Tensor<Float>` because
/// `StructuralEmpty` has no type information to infer them from — this
/// prevents models whose final output is any other type.
extension StructuralEmpty: SequentialLayer {
  public typealias SequentialInput = Tensor<Float>  // BAD!
  public typealias SequentialOutput = Tensor<Float>  // BAD!

  @differentiable
  public func sequentialApply(_ input: SequentialInput) -> SequentialOutput {
    input
  }
}
|
||
/// Allows skipping a field in a `SequentialLayer`.
///
/// The wrapped layer keeps its parameters, but sequential application treats
/// this slot as the identity on `PassingType`.
///
/// - SeeAlso: `SequentialLayer`.
@propertyWrapper
public struct SequentialSkip<Underlying, PassingType: Differentiable>: KeyPathIterable {
  /// The layer being skipped.
  public var wrappedValue: Underlying

  /// Creates an instance; `passing` only pins down `PassingType` (the type
  /// forwarded unchanged through this slot) and is otherwise unused.
  public init(wrappedValue: Underlying, passing _: Type<PassingType>) {
    self.wrappedValue = wrappedValue
  }
}
|
||
/// A skipped field contributes nothing to the tangent space.
extension SequentialSkip: Differentiable {
  // `Empty2` (not `Empty`) works around an `ElementaryFunctions` conformance
  // issue; see `StructuralDifferentiability.swift`.
  public typealias TangentVector = Empty2

  public mutating func move(along direction: TangentVector) {}

  public var zeroTangentVectorInitializer: () -> TangentVector {
    { TangentVector() }
  }
}
|
||
/// Trivial Euclidean view: the tangent space is empty.
extension SequentialSkip: EuclideanDifferentiable {
  public var differentiableVectorView: Self.TangentVector {
    TangentVector()
  }
}
|
||
// See error in `StructuralDifferentiability.swift` regarding `ElementaryFunctions` conformance for `Empty`.
/// As a `Layer`, a skipped field is the identity function on `PassingType`.
extension SequentialSkip: Layer {
  public typealias SequentialInput = PassingType
  public typealias SequentialOutput = PassingType

  @differentiable
  public func callAsFunction(_ input: SequentialInput) -> SequentialOutput {
    input
  }
}
|
||
// Work around compiler error regarding retroactive conformance to `ElementaryFunctions` (See StructuralDifferentability.)
/// A zero-dimensional vector space: every operation returns the (unique)
/// empty value.
public struct Empty2: Differentiable, EuclideanDifferentiable, KeyPathIterable, PointwiseMultiplicative, ElementaryFunctions, VectorProtocol {
  public typealias VectorSpaceScalar = Float

  public func adding(_ scalar: Self.VectorSpaceScalar) -> Self { Self() }

  public func subtracting(_ scalar: Self.VectorSpaceScalar) -> Self { Self() }

  public func scaled(by scalar: Self.VectorSpaceScalar) -> Self { Self() }
}
|
||
// TODO: Make merge function configurable?
/// A residual (skip) connection: applies the underlying layer and adds the
/// result back onto the input, i.e. `input + underlying(input)`.
@propertyWrapper
public struct ResidualConnection<Underlying: Layer>: Layer where Underlying.Input == Underlying.Output, Underlying.Input == Tensor<Float> {  // TODO: generalize beyond Tensor<Float>!
  /// The layer wrapped by the residual connection.
  public var wrappedValue: Underlying

  public init(wrappedValue: Underlying) {
    self.wrappedValue = wrappedValue
  }

  @differentiable
  public func callAsFunction(_ input: Underlying.Input) -> Underlying.Output {
    input + wrappedValue(input)
  }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import PenguinStructures | ||
import StructuralCore | ||
import TensorFlow | ||
|
||
// TODO: Pick a better name.
/// A layer that propagates shapes during initialization: each layer consumes
/// the tracker describing its input shape and leaves behind its output shape.
public protocol ShapePropagatingLayer: SequentialLayer {
  // associatedtype ShapeTracker: FixedSizeArray // TODO: generalize shape tracking?

  /// Initializes `self`, updating `shapeTracker` to reflect the new output shape.
  init(shapeTracker: inout Tensor<Float>)  // TODO: different shape tracking type?
}
|
||
/// Derives shape-propagating initialization by building the structural
/// representation and converting it back into `Self`.
extension ShapePropagatingLayer
where
  Self: DifferentiableStructural & Layer,
  Self.StructuralRepresentation: ShapePropagatingLayer,
  SequentialInput == StructuralRepresentation.SequentialInput,
  SequentialOutput == StructuralRepresentation.SequentialOutput,
  SequentialInput == Input,
  SequentialOutput == Output
{
  public init(shapeTracker: inout Tensor<Float>) {
    // NOTE(review): this calls `init(structuralRepresentation:)` — presumably
    // a requirement inherited from the non-differentiable `Structural`
    // protocol — rather than `init(differentiableStructuralRepresentation:)`;
    // confirm that is intended.
    self.init(structuralRepresentation: StructuralRepresentation(shapeTracker: &shapeTracker))
  }
}
|
||
// Inductive cases | ||
|
||
/// Inductive case: initialize the head, then the tail, threading the shape
/// tracker through in order (the head's output shape is the tail's input).
extension StructuralCons: ShapePropagatingLayer
where
  Value: ShapePropagatingLayer,
  Next: ShapePropagatingLayer,
  Next.SequentialInput == Value.SequentialOutput
{
  public init(shapeTracker: inout Tensor<Float>) {
    // Order matters: `head` must consume the tracker before `tail`.
    let head = Value(shapeTracker: &shapeTracker)
    let tail = Next(shapeTracker: &shapeTracker)
    self.init(head, tail)
  }
}
|
||
/// Base case: a single property builds its layer from the tracked shape.
extension StructuralProperty: ShapePropagatingLayer where Value: ShapePropagatingLayer & Layer {
  public init(shapeTracker: inout Tensor<Float>) {
    self.init(Value(shapeTracker: &shapeTracker))
  }
}
|
||
/// A struct wrapper delegates shape-propagating initialization to its
/// property list.
extension StructuralStruct: ShapePropagatingLayer where Properties: ShapePropagatingLayer {
  public init(shapeTracker: inout Tensor<Float>) {
    self.init(Properties(shapeTracker: &shapeTracker))
  }
}
|
||
/// Terminal case: nothing to initialize; the tracker is left untouched.
extension StructuralEmpty: ShapePropagatingLayer {
  public init(shapeTracker: inout Tensor<Float>) {
    self.init()
  }
}
|
||
// TODO: HParam property wrapper? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, so `passing` is the type of both the previous layer's output and the following layer's input? I guess it can't be inferred, since we're using properties rather than an explicit type-level composition.