This repository was archived by the owner on Jul 1, 2023. It is now read-only.

[WIP]: Idea exploring easier layer init. #883

Closed
wants to merge 1 commit
18 changes: 18 additions & 0 deletions Sources/TensorFlow/Layer.swift
@@ -53,6 +53,24 @@ extension Layer {
}
}

// TODO: clean up the Input & Output requirements.
public protocol ShapedLayer: Layer where Input == Tensor<Float>, Output == Tensor<Float> {
/// The hyperparameters required to initialize `Self`.
associatedtype HyperParameters

/// Initializes `self` with the given hyperparameters to process inputs shaped `inputShape`.
init(hparams: HyperParameters, inputShape: TensorShape)

/// Initializes `self` for inputs shaped like `shapeTrackingTensor`, then replaces
/// `shapeTrackingTensor` with this layer's output so the next layer can infer its own shape.
init(hparams: HyperParameters, _ shapeTrackingTensor: inout Tensor<Float>)
}

extension ShapedLayer {
public init(hparams: HyperParameters, _ shapeTrackingTensor: inout Tensor<Float>) {
self.init(hparams: hparams, inputShape: shapeTrackingTensor.shape)
shapeTrackingTensor = self(shapeTrackingTensor)
}
}

/// An empty struct representing empty `TangentVector`s for parameterless layers.
public struct EmptyTangentVector: EuclideanDifferentiable, VectorProtocol, ElementaryFunctions,
PointwiseMultiplicative, KeyPathIterable
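The default implementation above enables a shape-threading idiom: a sample tensor is passed `inout` through a chain of layer initializers, so each layer reads its input shape from the previous layer's output. A hypothetical usage sketch (it relies on the `Conv2D`, `Flatten`, and `Dense` conformances added later in this diff; the shape comments assume the library's default valid padding and unit strides):

```swift
import TensorFlow

// Thread a zero-filled sample tensor through the initializers: each call reads the
// tracker's current shape, then overwrites the tracker with that layer's output.
var tracker = Tensor<Float>(zeros: [1, 28, 28, 1])
let conv = Conv2D<Float>(hparams: .init(3, channels: 8), &tracker)   // tracker: [1, 26, 26, 8]
let flatten = Flatten<Float>(hparams: (), &tracker)                  // tracker: [1, 5408]
let dense = Dense<Float>(hparams: .init(outputSize: 10), &tracker)   // tracker: [1, 10]
```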
30 changes: 30 additions & 0 deletions Sources/TensorFlow/Layers/Convolutional.swift
@@ -249,6 +249,36 @@ extension Conv2D {
}
}

extension Conv2D: ShapedLayer where Scalar == Float {
public struct HyperParameters {
// TODO: FIX ME! (Add strides, padding, etc.)
let filterHeight: Int
let filterWidth: Int
let outputChannels: Int

public init(filterHeight: Int, filterWidth: Int, outputChannels: Int) {
self.filterHeight = filterHeight
self.filterWidth = filterWidth
self.outputChannels = outputChannels
}

public init(_ height: Int, _ width: Int? = nil, channels: Int) {
self.filterHeight = height
self.filterWidth = width ?? height // Default to square.
self.outputChannels = channels
}
}

public init(hparams: HyperParameters, inputShape: TensorShape) {
precondition(inputShape.rank == 4, "Unexpected input shape: \(inputShape).") // THIS REQUIRES WORKING VMAP!
let inputChannels = inputShape[3] // Assuming channels last.
let filterShape = (
hparams.filterHeight, hparams.filterWidth, inputChannels, hparams.outputChannels
)
self.init(filterShape: filterShape)
}
}

/// A 3-D convolution layer for spatial/spatio-temporal convolution over images.
///
/// This layer creates a convolution filter that is convolved with the layer input to produce a
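A hypothetical usage sketch for the conformance above (not part of the diff): the input channel count is read from the last axis of `inputShape`, and anything not yet captured by `HyperParameters` falls back to `Conv2D`'s default initializer arguments (valid padding, unit strides, default parameter initializers).

```swift
// A 3x3 convolution producing 16 output channels, sized for a batch of
// 28x28 single-channel (channels-last) images.
let hparams = Conv2D<Float>.HyperParameters(3, channels: 16)  // width defaults to height (square filter)
let conv = Conv2D<Float>(hparams: hparams, inputShape: [32, 28, 28, 1])
// The inferred filter shape is [3, 3, 1, 16]; the input channel count (1) comes from inputShape[3].
```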
7 changes: 7 additions & 0 deletions Sources/TensorFlow/Layers/Core.swift
@@ -32,6 +32,13 @@ public struct Flatten<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer {
}
}

extension Flatten: ShapedLayer where Scalar == Float {
public init(hparams: (), inputShape: TensorShape) {
precondition(inputShape.rank > 1, "Unexpected shape: \(inputShape); must have rank > 1.")
self.init()
}
}

/// A reshape layer.
@frozen
public struct Reshape<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer {
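For `Flatten` the hyperparameter type is inferred as `()`, so there is nothing to configure and the input shape only feeds the rank check. A minimal hypothetical example:

```swift
// The empty tuple stands in for "no hyperparameters"; rank-0 and rank-1 shapes are rejected.
let flatten = Flatten<Float>(hparams: (), inputShape: [32, 26, 26, 16])
```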
23 changes: 23 additions & 0 deletions Sources/TensorFlow/Layers/Dense.swift
@@ -110,3 +110,26 @@ extension Dense {
activation: activation)
}
}

extension Dense: ShapedLayer where Scalar == Float {
public struct HyperParameters {
let outputSize: Int
let useBias: Bool
let activation: Activation
// TODO: figure out how to handle random init of params from a given seed / etc.

public init(outputSize: Int, useBias: Bool = false, activation: @escaping Activation = identity)
{
self.outputSize = outputSize
self.useBias = useBias
self.activation = activation
}
}

public init(hparams: HyperParameters, inputShape: TensorShape) {
precondition(inputShape.rank == 2, "Wrong input shape; got \(inputShape)")
self.init(
inputSize: inputShape[1], outputSize: hparams.outputSize, activation: hparams.activation,
useBias: hparams.useBias)
}
}
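A hypothetical usage sketch for the `Dense` conformance: the weight's first dimension comes from `inputShape[1]`, and note that this `HyperParameters` type defaults `useBias` to `false`, unlike `Dense`'s usual initializer.

```swift
// A 784 -> 10 dense layer for flattened, MNIST-sized inputs.
let dense = Dense<Float>(
  hparams: .init(outputSize: 10, activation: relu),
  inputShape: [32, 784])
// dense.weight has shape [784, 10]; the bias term is disabled because useBias defaults to false here.
```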
49 changes: 49 additions & 0 deletions Tests/TensorFlowTests/ShapedLayerExampleTests.swift
@@ -0,0 +1,49 @@
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import XCTest

@testable import TensorFlow

final class ShapedLayerExampleTests: XCTestCase {
func testSimple() {
let input = Tensor<Float>(zeros: [12, 28, 28, 1]) // MNIST-size; batch size 12.
let model = SimpleModel(forExample: input)
let output = model(input) // Should run without shape errors.
XCTAssertEqual([12, 10], output.shape)
}

static var allTests = [
("testSimple", testSimple)
]
}

extension ShapedLayerExampleTests {
struct SimpleModel: Layer {
var conv: Conv2D<Float>
var flatten: Flatten<Float>
var dense: Dense<Float>

init(forExample sampleInput: Tensor<Float>) {
var se = sampleInput // se == shapeExample
conv = Conv2D(hparams: .init(3, channels: 5), &se)
flatten = Flatten(hparams: (), &se)
dense = Dense(hparams: .init(outputSize: 10), &se)
}

func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
input.sequenced(through: conv, flatten, dense)
}
}
}
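For reference, a sketch of how the shape-tracking tensor evolves inside `SimpleModel.init` (assuming `Conv2D`'s default valid padding and unit strides); the intermediate shapes below are not asserted by the test above, only the final one:

```swift
var se = Tensor<Float>(zeros: [12, 28, 28, 1])
_ = Conv2D<Float>(hparams: .init(3, channels: 5), &se)
XCTAssertEqual([12, 26, 26, 5], se.shape)   // 28 - 3 + 1 = 26 with valid padding
_ = Flatten<Float>(hparams: (), &se)
XCTAssertEqual([12, 3380], se.shape)        // 26 * 26 * 5 = 3380
_ = Dense<Float>(hparams: .init(outputSize: 10), &se)
XCTAssertEqual([12, 10], se.shape)          // matches testSimple's expectation
```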
1 change: 1 addition & 0 deletions Tests/TensorFlowTests/XCTestManifests.swift
@@ -41,6 +41,7 @@ import XCTest
testCase(RuntimeTests.allTests),
testCase(SequencedTests.allTests),
testCase(SequentialTests.allTests),
testCase(ShapedLayerExampleTests.allTests),
testCase(TensorAutoDiffTests.allTests),
testCase(TensorGroupTests.allTests),