[SR-13698] Conditionally conform Optional to Differentiable #53072

Closed
@dan-zheng

Description

@dan-zheng
Previous ID SR-13698
Radar None
Original Reporter @dan-zheng
Type Sub-task
Status Closed
Resolution Done
Additional Detail from JIRA
Votes 1
Component/s
Labels Sub-task
Assignee @dan-zheng
Priority Medium

md5: 13b8627652d2f04d11be34db0a4c54ae

Parent-Task:

blocks:

  • SR-13700 Differentiation transform: support optional-related operations

is duplicated by:

  • TF-365 [AD] Make Optional conditionally conform to Differentiable

Issue Description:

Optional can conditionally conform to Differentiable when the Wrapped type does.

extension Optional : Differentiable where Wrapped : Differentiable {
  ...
}

Changes should be made in Optional.swift.


Prototype:

extension Optional : Differentiable where Wrapped : Differentiable {
  public enum DifferentiableView : Equatable, AdditiveArithmetic, Differentiable {
    case none
    case some(Wrapped.TangentVector)

    public static var zero: Self { .some(Wrapped.TangentVector.zero) }
    public static func + (lhs: Self, rhs: Self) -> Self {
      switch (lhs, rhs) {
        case (.none, .none): return .none
        case let (x, .none): return x
        case let (.none, y): return y
        case let (.some(x), .some(y)): return .some(x + y)
      }
    }

    public static func - (lhs: Self, rhs: Self) -> Self {
      switch (lhs, rhs) {
        case (.none, .none): return .none
        case let (x, .none): return x
        case let (.none, .some(y)): return .some(.zero - y)
        case let (.some(x), .some(y)): return .some(x - y)
      }
    }

    public typealias TangentVector = DifferentiableView
    public typealias AllDifferentiableVariables = Self
    public var allDifferentiableVariables: AllDifferentiableVariables {
      get { self }
      set { self = newValue }
    }
    public mutating func move(along direction: TangentVector) {
      switch (self, direction) {
      case (_, .none): return
      case let (.none, y): self = y
      case let (.some(x), .some(y)):
        var wrapped = x
        wrapped.move(along: y)
        self = .some(wrapped)
      }
    }
  }

  public typealias TangentVector = DifferentiableView
  public typealias AllDifferentiableVariables = Self
  public var allDifferentiableVariables: AllDifferentiableVariables {
    get { self }
    set { self = newValue }
  }
  public mutating func move(along direction: TangentVector) {
    switch (self, direction) {
    case (_, .none): return
    case (.none, _): fatalError("Cannot move `.none` along a non-`.none` direction")
    case let (.some(x), .some(y)):
      var wrapped = x
      wrapped.move(along: y)
      self = .some(wrapped)
    }
  }
}

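Optional uses a separate TangentVector type, DifferentiableView, similar to Array.DifferentiableView. A separate type matters because every TangentVector must conform to AdditiveArithmetic: if Optional were its own TangentVector, Optional itself would have to conform to AdditiveArithmetic.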
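For experimenting without a toolchain that ships the `_Differentiation` module, here is a minimal, self-contained sketch of the prototype. It assumes a hypothetical stand-in protocol, `MiniDifferentiable`, in place of the real `Differentiable`, keeping only `TangentVector` and `move(along:)`. `allDifferentiableVariables` and the tangent's own `Differentiable` conformance are omitted. The tangent enum conforms to the standard library's `AdditiveArithmetic`; `Optional` itself never has to.

```swift
// Hypothetical stand-in for `Differentiable`, reduced to the two requirements
// this sketch needs. This is NOT the real `_Differentiation` API.
protocol MiniDifferentiable {
  associatedtype TangentVector: AdditiveArithmetic
  mutating func move(along direction: TangentVector)
}

extension Double: MiniDifferentiable {
  typealias TangentVector = Double
  mutating func move(along direction: Double) { self += direction }
}

// Conditional conformance: `Optional<Wrapped>` conforms only when `Wrapped` does.
extension Optional: MiniDifferentiable where Wrapped: MiniDifferentiable {
  // Separate tangent type: `.none` means "no tangent", `.some` wraps the
  // wrapped value's tangent. Only this enum is AdditiveArithmetic.
  enum DifferentiableView: Equatable, AdditiveArithmetic {
    case none
    case some(Wrapped.TangentVector)

    static var zero: Self { .some(.zero) }

    static func + (lhs: Self, rhs: Self) -> Self {
      switch (lhs, rhs) {
      case (.none, .none): return .none
      case let (x, .none): return x
      case let (.none, y): return y
      case let (.some(x), .some(y)): return .some(x + y)
      }
    }

    static func - (lhs: Self, rhs: Self) -> Self {
      switch (lhs, rhs) {
      case (.none, .none): return .none
      case let (x, .none): return x
      case let (.none, .some(y)): return .some(.zero - y)
      case let (.some(x), .some(y)): return .some(x - y)
      }
    }
  }

  typealias TangentVector = DifferentiableView

  mutating func move(along direction: DifferentiableView) {
    switch (self, direction) {
    case (_, .none):
      return  // A `.none` direction leaves the value unchanged.
    case (.none, .some):
      fatalError("Cannot move `.none` along a non-`.none` direction")
    case let (.some(x), .some(y)):
      var wrapped = x
      wrapped.move(along: y)  // Delegate to the wrapped type.
      self = .some(wrapped)
    }
  }
}

// Usage: move a `Double?` along a wrapped tangent.
var value: Double? = 2.0
value.move(along: .some(0.5))
```

Because `+` returns the other operand whenever one side is `.none`, a `.none` tangent acts as "no contribution" when tangents are accumulated.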


Note that differentiation of active Optional values is blocked by TF-583 (support for active enum values).

func loop_array(_ array: [Float]) -> Float {
  var result: Float = 1
  // The for-in loop calls `next()` on the array's iterator, which returns an active `Float?` value.
  for x in array {
    result = result * x
  }
  return result
}
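To see where the optional comes from, the loop can be written out by hand. This is a sketch of roughly what a `for x in array` loop does (the compiler's actual lowering may differ in detail); `loopArrayDesugared` is a hypothetical name used only for illustration. Each call to `next()` yields a `Float?`, and when `array` is active that optional is active too.

```swift
// Hand-desugared equivalent of `loop_array`: iterate via `makeIterator()`
// and `next()`, which returns `Float?` (nil once the array is exhausted).
func loopArrayDesugared(_ array: [Float]) -> Float {
  var result: Float = 1
  var iterator = array.makeIterator()
  while let x = iterator.next() {  // `next()` produces the `Float?` value.
    result = result * x
  }
  return result
}
```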
