Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Structural generic layers 2 #650

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions .vscode/tasks.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
// See https://go.microsoft.com/fwlink/?LinkId=733558
// for the documentation about the tasks.json format
"version": "2.0.0",
"tasks": [
{
// Default build task: runs `swift build` through a shell.
"label": "swift-build",
"type": "shell",
// NOTE(review): this is an absolute path to a toolchain under one user's home directory;
// it will presumably not resolve on other machines. Confirm whether this file should be
// checked in as-is.
"command": "/usr/local/google/home/saeta/tmp/toolchains/1165/usr/bin/swift build",
"problemMatcher": [],
"group": {
"kind": "build",
"isDefault": true
}
}
]
}
3 changes: 3 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@ let package = Package(
.package(url: "https://github.com/apple/swift-protobuf.git", from: "1.10.0"),
.package(url: "https://github.com/apple/swift-argument-parser", .upToNextMinor(from: "0.2.0")),
.package(url: "https://github.com/google/swift-benchmark", .branch("master")),
.package(url: "https://github.com/google/swift-structural.git", .branch("master")),
.package(url: "https://github.com/saeta/penguin.git", .branch("master")),
],
targets: [
.target(name: "StructuralModelBuilding", dependencies: ["StructuralCore", "PenguinStructures"], path: "StructuralModelBuilding"),
.target(
name: "Checkpoints", dependencies: ["SwiftProtobuf", "ModelSupport"],
path: "Checkpoints"),
Expand Down
120 changes: 120 additions & 0 deletions StructuralModelBuilding/HParamInitExample.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import TensorFlow
import PenguinStructures


/// An example model that stores three layers as plain properties: a 2-D convolution, a flatten,
/// and a dense layer.
///
/// Only storage is declared here; protocol conformances are added in extensions.
public struct MyInitModel {
  /// The convolutional layer.
  var conv: Conv2D<Float>

  /// The flattening layer.
  var flatten: Flatten<Float>

  /// The dense layer.
  var dense: Dense<Float>
}
// See below for explicit conformances to `DifferentiableStructural` and `StaticKeyPathIterable`

// Thanks to `DifferentiableStructural` conformances, we can derive these protocols automagically!
extension MyInitModel: HParamInitLayer, Layer, SequentialLayer {
  // The associated types below are written out by hand because they are not inferred
  // automatically. :-(

  /// The tensor type fed into the model.
  public typealias Input = Tensor<Float>

  /// The tensor type produced by the model.
  public typealias Output = Tensor<Float>

  public typealias SequentialInput = Input
  public typealias SequentialOutput = Output

  /// Hyperparameters for the model, taken from its static structural representation.
  public typealias HParam = StaticStructuralRepresentation.HParam

  /// Placeholder initializer: not implemented yet, always traps.
  ///
  /// - Parameters:
  ///   - hparam: Hyperparameters for the model (currently unused).
  ///   - inputExample: A representative input (currently unused).
  public init(hparam: HParam, inputExample: Input) {
    // TODO: Derive this initializer automatically instead of writing it by hand.
    fatalError()
  }
}

/// Placeholder for an example that builds a `MyInitModel`.
///
/// Not implemented yet: calling this function always traps.
func sampleModelUsage() -> MyInitModel {
  fatalError("TODO: WRITE ME OUT!")
}

/// The hand-written counterpart of `MyInitModel`: every conformance requirement is spelled out
/// explicitly instead of being derived.
public struct MyInitModelExplicit: HParamInitLayer, StaticKeyPathIterable {
  var conv: Conv2D<Float>
  var flatten: Flatten<Float>
  var dense: Dense<Float>

  public typealias StaticKeyPaths = [PartialKeyPath<Self>]

  /// Key paths to the layer properties, in declaration order.
  ///
  /// The order must line up with the holder list in `HParam`: each holder is initialized from
  /// the key path at its position and traps if the key path has a different type.
  public static var staticKeyPaths = [
    \Self.conv,
    \Self.flatten,
    \Self.dense,
  ]

  /// Hyperparameters for the whole model: one holder per layer, chained as a cons-list.
  public typealias HParam = StructuralHParams<
    Self,
    HParamCons<
      Self,
      HParamHolder<Self, Conv2D<Float>>,
      HParamCons<
        Self,
        HParamHolder<Self, Flatten<Float>>,
        HParamHolder<Self, Dense<Float>>
      >
    >
  >

  /// Creates the model by initializing each layer in turn.
  ///
  /// Each layer receives the output of the previous layer as its input example, so later layers
  /// can read their input shape from it.
  ///
  /// - Parameters:
  ///   - hparam: Per-layer hyperparameters; each entry is force-unwrapped.
  ///   - inputExample: A representative input for the first layer.
  public init(hparam: HParam, inputExample: Tensor<Float>) {
    var activations = inputExample

    conv = .init(hparam: hparam.conv!, inputExample: activations)
    activations = conv(activations)  // Output of `conv` is the example for `flatten`.

    flatten = .init(hparam: hparam.flatten!, inputExample: activations)
    activations = flatten(activations)  // Output of `flatten` is the example for `dense`.

    dense = .init(hparam: hparam.dense!, inputExample: activations)
  }

  /// Applies `conv`, then `flatten`, then `dense` to `input`.
  @differentiable
  public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
    let convolved = conv(input)
    let flattened = flatten(convolved)
    return dense(flattened)
  }
}

/// Builds a `MyInitModelExplicit` with a 3x3 convolution producing 10 channels and a dense layer
/// of size 10, using a zero tensor of shape `[5, 28, 28, 1]` as the example input.
///
/// - Returns: The freshly initialized model.
func makeExplicitModel() -> MyInitModelExplicit {
  var hparams = MyInitModelExplicit.HParam()
  hparams.conv = .init(height: 3, width: 3, channels: 10)  // Fully typesafe!
  // `Flatten` carries no real hyperparameters, but its entry is still stored explicitly so that
  // every layer's hyperparameters are populated before `build` reads them.
  hparams.flatten = Empty()
  hparams.dense = .init(size: 10)

  return hparams.build(for: Tensor<Float>(zeros: [5, 28, 28, 1]))
}

// TODO: Figure out how StaticKeyPathIterable's capabilities can be constructed just from `Structural`.
extension MyInitModel: StaticKeyPathIterable {
  public typealias StaticKeyPaths = [PartialKeyPath<Self>]

  /// Key paths to the layer properties of `MyInitModel`, in declaration order.
  public static var staticKeyPaths: StaticKeyPaths = [
    \Self.conv,
    \Self.flatten,
    \Self.dense,
  ]
}

// Hand-written `DifferentiableStructural` conformance. Only the associated types carry
// information; every member body is a stub that traps.
extension MyInitModel: DifferentiableStructural {

  /// The static (type-level) shape of `MyInitModel`: a cons-list with one entry per property.
  public typealias StaticStructuralRepresentation =
    StaticStructuralStruct<
      MyInitModel,
      StructuralCons<
        StaticStructuralProperty<MyInitModel, Conv2D<Float>>,
        StructuralCons<
          StaticStructuralProperty<MyInitModel, Flatten<Float>>,
          StaticStructuralProperty<MyInitModel, Dense<Float>>
        >
      >
    >

  /// Stub: always traps.
  public static var staticStructuralRepresentation: StaticStructuralRepresentation {
    fatalError()
  }

  /// The value-level shape of `MyInitModel`: a cons-list with one entry per property.
  public typealias StructuralRepresentation =
    StructuralStruct<
      MyInitModel,
      StructuralCons<
        StructuralProperty<MyInitModel, Conv2D<Float>>,
        StructuralCons<
          StructuralProperty<MyInitModel, Flatten<Float>>,
          StructuralProperty<MyInitModel, Dense<Float>>
        >
      >
    >

  /// Stub: always traps.
  @differentiable
  public init(differentiableStructuralRepresentation: StructuralRepresentation) {
    fatalError()
  }

  /// Stub derivative of `init(differentiableStructuralRepresentation:)`: always traps.
  @derivative(of: init(differentiableStructuralRepresentation:))
  public static func _vjp_init(
    differentiableStructuralRepresentation: StructuralRepresentation
  ) -> (value: Self, pullback: (TangentVector) -> StructuralRepresentation.TangentVector) {
    fatalError()
  }

  /// Stub: both the getter and the setter always trap.
  @differentiable
  public var differentiableStructuralRepresentation: StructuralRepresentation {
    get { fatalError() }
    set { fatalError() }
  }

  /// Stub derivative of `differentiableStructuralRepresentation`: always traps.
  @derivative(of: differentiableStructuralRepresentation)
  public func _vjp_differentiableStructuralRepresentation() -> (
    value: StructuralRepresentation,
    pullback: (StructuralRepresentation.TangentVector) -> TangentVector
  ) {
    fatalError()
  }
}
221 changes: 221 additions & 0 deletions StructuralModelBuilding/HParamInitLayer.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
import TensorFlow
import PenguinStructures

/// A layer that can be constructed from hyperparameters plus a representative input.
///
/// `HParam` carries the layer's hyperparameters; the input example supplies whatever must be
/// read off the data itself (shape information, etc.).
public protocol HParamInitLayer: Layer {
  /// The hyperparameters of `Self`.
  associatedtype HParam

  /// Creates a layer configured by `hparam` and suitable for inputs like `inputExample`.
  ///
  /// - Parameters:
  ///   - hparam: The hyperparameters for the new layer.
  ///   - inputExample: A representative input for the new layer.
  init(hparam: HParam, inputExample: Input)
}

/// A type that exposes a collection of key paths into itself statically, i.e. without needing
/// an instance (compare `KeyPathIterable`).
public protocol StaticKeyPathIterable {
  /// The collection type holding the key paths; its elements are `PartialKeyPath<Self>`.
  associatedtype StaticKeyPaths: Collection
  where StaticKeyPaths.Element == PartialKeyPath<Self>

  /// The key paths of `Self`.
  static var staticKeyPaths: StaticKeyPaths { get }
}

// Retroactive conformances.

extension Conv2D: HParamInitLayer {
  /// Hyperparameters for a `Conv2D` layer.
  ///
  /// The number of input channels is not part of the hyperparameters; it is read from the
  /// example input passed to `init(hparam:inputExample:)`.
  public struct HParam {
    /// Creates hyperparameters for a convolution.
    ///
    /// - Parameters:
    ///   - height: The height of the filter.
    ///   - width: The width of the filter.
    ///   - channels: The number of output channels.
    ///   - strides: The strides of the sliding window for spatial dimensions.
    ///   - padding: The padding passed on to the layer.
    public init(
      height: Int,
      width: Int,
      channels: Int,
      strides: (Int, Int) = (1, 1),
      padding: Padding = .valid
    ) {
      self.height = height
      self.width = width
      self.channels = channels
      self.strides = strides
      self.padding = padding
    }

    /// The height of the filter.
    var height: Int
    /// The width of the filter.
    var width: Int
    /// The number of output channels.
    var channels: Int

    // The number of input channels is automatically inferred from the provided input tensor.

    /// The strides of the sliding window for spatial dimensions.
    var strides: (Int, Int)

    /// The padding passed on to the layer.
    var padding: Padding

    // TODO: Add others (e.g. initializers, useBias, dilation, etc).
  }

  /// Creates a convolution from `hparam`, taking the input channel count from `inputExample`.
  ///
  /// - Parameters:
  ///   - hparam: Filter size, output channel count, strides and padding.
  ///   - inputExample: A rank-4 tensor whose last dimension is read as the input channel count.
  /// - Precondition: `inputExample` has rank 4.
  public init(hparam: HParam, inputExample: Tensor<Scalar>) {
    precondition(
      inputExample.shape.count == 4,
      "input example must be a rank-4 tensor, got: \(inputExample.shape)")
    let inputChannelCount = inputExample.shape[3]  // Assumes channels last.

    self.init(
      filterShape: (hparam.height, hparam.width, inputChannelCount, hparam.channels),
      strides: hparam.strides,
      padding: hparam.padding)
  }
}

extension Flatten: HParamInitLayer {
  /// `Flatten` uses `Empty` as its hyperparameter type.
  public typealias HParam = Empty

  /// Creates a flatten layer; both arguments are ignored.
  ///
  /// - Parameters:
  ///   - hparam: Unused.
  ///   - inputExample: Unused.
  public init(hparam: HParam, inputExample: Tensor<Scalar>) {
    self.init()
  }
}

extension Dense: HParamInitLayer {
  /// Hyperparameters for a `Dense` layer.
  ///
  /// Only the output size is stored; the input size is read from the example input.
  public struct HParam {
    /// Creates hyperparameters for a dense layer producing `size` outputs.
    public init(size: Int) {
      self.size = size
    }

    /// The output size of the Dense layer.
    var size: Int

    // Input size comes from the inputExample.
    // TODO: Add extra hparams here!
  }

  /// Creates a dense layer whose input size is the second dimension of `inputExample`.
  ///
  /// - Parameters:
  ///   - hparam: Supplies the output size.
  ///   - inputExample: A rank-2 tensor laid out as `[batch, inputSize]`.
  /// - Precondition: `inputExample` has rank 2.
  public init(hparam: HParam, inputExample: Tensor<Scalar>) {
    let exampleShape = inputExample.shape
    precondition(exampleShape.count == 2, "input example must be a matrix, got: \(inputExample.shape)")

    // Layout is [batch, inputSize].
    let featureCount = exampleShape[1]
    self.init(inputSize: featureCount, outputSize: hparam.size)
  }
}

/// Storage for the hyperparameters of one or more layers of `Model`, addressed by key path.
public protocol _HParamHolderProtocol {
  /// The model whose layers' hyperparameters are stored.
  associatedtype Model: StaticKeyPathIterable

  /// Creates empty storage for the key path(s) of `keyPaths` starting at `index`.
  init(keyPaths: Model.StaticKeyPaths, index: Model.StaticKeyPaths.Index)

  /// Returns the hyperparameters stored for `keyPath`, or `nil` if this storage does not cover
  /// `keyPath`.
  func getHParam<V: HParamInitLayer>(keyPath: KeyPath<Model, V>) -> V.HParam?

  /// Stores `value` as the hyperparameters for `keyPath`.
  mutating func storeHParam<V: HParamInitLayer>(keyPath: KeyPath<Model, V>, value: V.HParam)
}

/// A struct that holds an optionally initialized hyperparameter for a constituent layer.
public struct HParamHolder<Model: StaticKeyPathIterable, Layer: HParamInitLayer>: _HParamHolderProtocol {
  /// The key path of the layer (within `Model`) whose hyperparameters are held.
  var keyPath: KeyPath<Model, Layer>

  /// The stored hyperparameters, or `nil` if none have been stored yet.
  var value: Layer.HParam?

  /// Creates an empty holder for the layer at `keyPath`.
  public init(_ keyPath: KeyPath<Model, Layer>) {
    self.keyPath = keyPath
  }

  /// Creates an empty holder for the key path found at `index` in `keyPaths`.
  ///
  /// - Precondition: `keyPaths[index]` is a `KeyPath<Model, Layer>`.
  public init(keyPaths: Model.StaticKeyPaths, index: Model.StaticKeyPaths.Index) {
    let pkp = keyPaths[index]
    guard let kp = pkp as? KeyPath<Model, Layer> else {
      preconditionFailure("Key path \(pkp) at index \(index) in keyPaths \(keyPaths) not of expected type: \(KeyPath<Model, Layer>.self).")
    }
    self.keyPath = kp
  }

  /// Returns the hyperparameters for `keyPath`, or `nil` if this holder is for another key path.
  ///
  /// If nothing has been stored yet and `Layer.HParam` conforms to `DefaultInitializable`, a
  /// default-initialized value is returned; otherwise reading an unset value traps.
  public func getHParam<V: HParamInitLayer>(keyPath: KeyPath<Model, V>) -> V.HParam? {
    guard keyPath == self.keyPath else { return nil }

    if let value = value {
      // The key paths are equal, so `V` is `Layer` and the cast cannot fail.
      return (value as! V.HParam)
    }

    // Nothing stored: fall back to a default value when the hyperparameter type offers one, so
    // that layers without meaningful hyperparameters need not be configured by hand.
    if let defaultable = Layer.HParam.self as? DefaultInitializable.Type {
      return (defaultable.init() as! V.HParam)
    }

    preconditionFailure("HParam for \(keyPath) has not yet been initialized!")
  }

  /// Stores `value` if `keyPath` is this holder's key path; otherwise does nothing.
  public mutating func storeHParam<V: HParamInitLayer>(keyPath: KeyPath<Model, V>, value: V.HParam) {
    if keyPath == self.keyPath {
      self.value = (value as! Layer.HParam)  // Force downcast to catch errors.
    }
  }
}

/// A cons cell chaining hyperparameter storage: `value` is the head, `next` is the rest of the
/// list.
public struct HParamCons<Model, Value: _HParamHolderProtocol, Next: _HParamHolderProtocol>: _HParamHolderProtocol
where Value.Model == Model, Next.Model == Model {
  /// Storage for the head of the list.
  var value: Value

  /// Storage for the remainder of the list.
  var next: Next

  /// Creates empty storage: the head uses the key path at `index`, the tail starts at the
  /// following index.
  public init(keyPaths: Model.StaticKeyPaths, index: Model.StaticKeyPaths.Index) {
    value = Value(keyPaths: keyPaths, index: index)
    let tailStart = keyPaths.index(after: index)
    next = Next(keyPaths: keyPaths, index: tailStart)
  }

  /// Offers `value` to both the head and the tail; each decides whether `keyPath` is its own.
  public mutating func storeHParam<V: HParamInitLayer>(keyPath: KeyPath<Model, V>, value: V.HParam) {
    self.value.storeHParam(keyPath: keyPath, value: value)
    self.next.storeHParam(keyPath: keyPath, value: value)
  }

  /// Returns the head's answer for `keyPath` if it has one, otherwise the tail's.
  public func getHParam<V: HParamInitLayer>(keyPath: KeyPath<Model, V>) -> V.HParam? {
    if let fromHead = value.getHParam(keyPath: keyPath) {
      return fromHead
    }
    return next.getHParam(keyPath: keyPath)
  }
}

/// The hyperparameters of `Layer`, exposed as one optional member per constituent layer.
///
/// Members are addressed with the key paths of `Layer` itself (for example `hparams.conv`) and
/// are backed by `Values`.
@dynamicMemberLookup
public struct StructuralHParams<Layer: HParamInitLayer /*& StaticKeyPathIterable*/, Values: _HParamHolderProtocol>: DefaultInitializable where Values.Model == Layer {
  /// Creates hyperparameters with empty storage built from `Layer.staticKeyPaths`.
  public init() {
    let keyPaths = Layer.staticKeyPaths
    values = .init(keyPaths: keyPaths, index: keyPaths.startIndex)
  }

  /// The underlying per-layer storage.
  var values: Values

  /// Accesses the hyperparameters of the constituent layer at `keyPath`.
  ///
  /// Reading returns whatever `values` reports for `keyPath`. Writing stores the new value;
  /// writing `nil` is a programmer error and traps.
  public subscript<T: HParamInitLayer>(dynamicMember keyPath: KeyPath<Layer, T>) -> T.HParam? {
    get { values.getHParam(keyPath: keyPath) }
    set {
      guard let newValue = newValue else {
        fatalError("Cannot unset hparam values! (Attempted to unset keyPath: \(keyPath))")
      }
      values.storeHParam(keyPath: keyPath, value: newValue)
    }
  }
}

extension StructuralHParams where Layer.HParam == Self {
  /// Creates a `Layer` from these hyperparameters.
  ///
  /// - Parameter exampleInput: A representative input, forwarded to the layer's initializer.
  /// - Returns: The newly initialized layer.
  public func build(for exampleInput: Layer.Input) -> Layer {
    Layer(hparam: self, inputExample: exampleInput)
  }
}


// Build type using StaticStructural!

// Lets a `StructuralCons` of hyperparameter storage act as hyperparameter storage itself.
extension StructuralCons: _HParamHolderProtocol
where Value: _HParamHolderProtocol, Next: _HParamHolderProtocol, Value.Model == Next.Model {
  public typealias Model = Value.Model

  /// Creates empty storage: the head uses the key path at `index`, the tail starts at the
  /// following index.
  public init(keyPaths: Model.StaticKeyPaths, index: Model.StaticKeyPaths.Index) {
    let tailStart = keyPaths.index(after: index)
    let head = Value(keyPaths: keyPaths, index: index)
    let tail = Next(keyPaths: keyPaths, index: tailStart)
    self.init(head, tail)
  }

  /// Offers `value` to both the head and the tail; each decides whether `keyPath` is its own.
  public mutating func storeHParam<V: HParamInitLayer>(keyPath: KeyPath<Model, V>, value: V.HParam) {
    self.value.storeHParam(keyPath: keyPath, value: value)
    self.next.storeHParam(keyPath: keyPath, value: value)
  }

  /// Returns the head's answer for `keyPath` if it has one, otherwise the tail's.
  public func getHParam<V: HParamInitLayer>(keyPath: KeyPath<Model, V>) -> V.HParam? {
    if let fromHead = value.getHParam(keyPath: keyPath) {
      return fromHead
    }
    return next.getHParam(keyPath: keyPath)
  }
}

/// A structural description that names the hyperparameter storage type matching it.
public protocol _HParamHolderStructural: BaseTypeProtocol {
  /// The hyperparameter storage type; its `Model` is this description's `BaseType`.
  associatedtype HParam: _HParamHolderProtocol
  where HParam.Model == BaseType
}

extension StaticStructuralStruct
where
  Properties: _HParamHolderStructural,
  BaseType: HParamInitLayer,
  Properties.HParam.Model == BaseType
{
  /// The hyperparameters of `BaseType`, backed by the storage type its properties name.
  public typealias HParam = StructuralHParams<BaseType, Properties.HParam>
}

extension StaticStructuralProperty: _HParamHolderStructural
where BaseType: StaticKeyPathIterable, Value: HParamInitLayer {
  /// A single property maps to a single holder for that property's layer type.
  public typealias HParam = HParamHolder<BaseType, Value>
}

extension StructuralCons: _HParamHolderStructural
where
  Value: _HParamHolderStructural,
  Next: _HParamHolderStructural,
  Value.BaseType == Next.BaseType
{
  /// A cons of descriptions maps to a cons of the head's and the tail's storage types.
  public typealias HParam = HParamCons<BaseType, Value.HParam, Next.HParam>
}

// Intentionally empty: the typealias that would supply `HParam` from the static structural
// representation is left commented out below.
extension HParamInitLayer
where Self: Structural, Self.StaticStructuralRepresentation: _HParamHolderStructural {
  // public typealias HParam = Self.StaticStructuralRepresentation.HParam // Note: this causes problems! :-(
}
Loading